# Comparison of a CNN and a Vision Transformer Trained on HiRISE Mars Satellite Images
### by Aniruddha Prasad and Andrew Hartnett

This notebook compares the accuracy of a Convolutional Neural Network (CNN) and a Vision Transformer (ViT) trained on satellite images of Mars from the HiRISE dataset. The goal is to determine whether a ViT, which has been used as the state of the art for image classification in certain settings, outperforms the CNN when it is pre-trained on a large dataset and fine-tuned on this data. We will then train three versions of each model on progressively larger subsets of the data to see how each model's accuracy scales with training-set size. This will indicate which model is likely to perform best as more images accumulate over the years.

### Table of Contents

1. Prepare the Training Data - **WIP**
2. Define and Train the CNN - **WIP**
3. Define and Train the Vision Transformer (ViT) - **WIP**
4. Evaluate CNN vs ViT - **WIP**
5. Retrain CNN and ViT on small, medium, and full HiRISE - **WIP**
6. Compare three CNNs vs three ViTs - **WIP**

## 1. Prepare the training data

In [None]:
# Import all required libraries and functions
import os
import csv
import pathlib

import numpy as np
import pandas as pd
import cv2
from PIL import Image

import matplotlib
# matplotlib.use('PS')  # venv workaround: only effective before pyplot is imported, and it disables inline plots
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline

import tensorflow as tf
from tensorflow import keras
from keras import utils, layers
from keras.models import Sequential
from keras.layers import Input, Conv2D, MaxPool2D, Dense, Flatten, Dropout, BatchNormalization
from keras.preprocessing import image
from keras.utils import to_categorical
# Imports for dataset separation
from keras.preprocessing.image import ImageDataGenerator, img_to_array, array_to_img, load_img
from sklearn.model_selection import train_test_split

# Improve progress bar display (use the notebook-aware tqdm everywhere)
import tqdm
from tqdm import auto
tqdm.tqdm = tqdm.auto.tqdm

## How Neihusst preprocesses the data:

In [None]:
data_images = []
data_labels = []
rel_img_path = 'map-proj/' # folder prefix added to each image name for later loading

# Read the label file: each line holds "<image file name> <integer class label>"
with open('labels-map-proj.txt') as labels:
  for line in labels:
    if not line.strip():
      continue  # skip blank lines
    file_name, label = line.split()
    data_images.append(rel_img_path + file_name)
    data_labels.append(int(label))

# Split the data into training and testing sets (3820 images in total)
train_images, test_images, train_labels, test_labels = train_test_split(
    data_images, data_labels, test_size=0.15, random_state=666)
test_len = len(test_images)   # 573
train_len = len(train_images) # 3247

# Class names, indexed by integer label
class_labels = ['other','crater','dark_dune','streak',
                'bright_dune','impact','edge']
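The split above is random but not stratified, so with seven classes a rare class could end up under- or over-represented in the 573-image test set. A minimal sketch for checking the class proportions of each split, assuming it runs while `train_labels` and `test_labels` are still lists of integers (that is, before the one-hot conversion below); `class_distribution` is a hypothetical helper, not part of the original notebook:

```python
from collections import Counter

def class_distribution(labels, names):
    """Return {class name: fraction of samples} for a list of integer labels."""
    counts = Counter(labels)
    total = len(labels)
    return {names[i]: counts.get(i, 0) / total for i in range(len(names))}

# Hypothetical usage on the splits defined above:
# print(class_distribution(train_labels, class_labels))
# print(class_distribution(test_labels, class_labels))
```

If the proportions differ noticeably between splits, passing `stratify=data_labels` to `train_test_split` keeps them consistent.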


### Data Preprocessing

In [None]:
#convert image paths into numpy matrices
def parse_image(filename):
  img_obj = Image.open(filename)
  img = np.asarray(img_obj).astype(np.float32)
  #normalize image to 0-1 range
  img /= 255.0
  return img

train_images = np.array(list(map(parse_image, train_images)))
test_images = np.array(list(map(parse_image, test_images)))

### Convert labels to one-hot encoding

In [None]:
def to_one_hot(label):
  encoding = [0 for _ in range(len(class_labels))]
  encoding[label] = 1
  return np.array(encoding).astype(np.float32)

train_labels = np.array(list(map(to_one_hot, train_labels)))
test_labels = np.array(list(map(to_one_hot, test_labels)))
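The same encoding can be produced without a Python-level loop by selecting rows of an identity matrix. A minimal sketch (`to_one_hot_vectorized` is a hypothetical name) whose output matches `to_one_hot`: float32 rows containing a single 1. The already imported `keras.utils.to_categorical` does the same job.

```python
import numpy as np

def to_one_hot_vectorized(labels, num_classes):
    """One-hot encode integer labels by indexing rows of an identity matrix."""
    labels = np.asarray(labels, dtype=np.int64)
    return np.eye(num_classes, dtype=np.float32)[labels]

# Example with the 7 HiRISE classes: labels 'crater', 'other', 'edge'
encoded = to_one_hot_vectorized([1, 0, 6], num_classes=7)
print(encoded.shape)  # (3, 7)
```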

## 2. Define and Train the CNN - WIP

### Example: performing image classification with the Keras library

*https://www.tensorflow.org/tutorials/images/classification*

In this example, all images sit in a single directory with one subfolder per class, so each image is labeled by the folder it is stored in. We could organize the HiRISE images the same way.
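Following the tutorial's layout, the HiRISE files could be copied into one subfolder per class. A minimal standard-library sketch; `organize_by_class` and the output directory are hypothetical, and it assumes `image_paths` and `labels` are parallel lists like `data_images` and `data_labels` built from `labels-map-proj.txt`:

```python
import os
import shutil

def organize_by_class(image_paths, labels, class_names, out_dir):
    """Copy each image into out_dir/<class name>/ so the folder name acts as its label."""
    for name in class_names:
        os.makedirs(os.path.join(out_dir, name), exist_ok=True)
    for path, label in zip(image_paths, labels):
        dest = os.path.join(out_dir, class_names[label], os.path.basename(path))
        shutil.copy(path, dest)
    # Report how many files each class folder now holds
    return {name: len(os.listdir(os.path.join(out_dir, name))) for name in class_names}

# Hypothetical usage:
# organize_by_class(data_images, data_labels, class_labels, 'hirise-by-class')
```

The resulting directory could then be read with `tf.keras.utils.image_dataset_from_directory`, which infers labels from the subfolder names; passing `class_names=class_labels` keeps the original integer order instead of alphabetical order.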

In [None]:
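# Define the neural network (it would be nice to show the use of GridSearchCV and a param_grid for optimization)
num_classes = len(class_labels)

# parse_image returns 2-D arrays for grayscale images; add a channel axis so Conv2D accepts them
if train_images.ndim == 3:
    train_images = train_images[..., np.newaxis]
    test_images = test_images[..., np.newaxis]
img_height, img_width, img_channels = train_images.shape[1:]

model = Sequential([
  # No Rescaling layer: parse_image already scaled the pixels to the 0-1 range
  layers.Input(shape=(img_height, img_width, img_channels)),
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(num_classes)  # raw logits, hence from_logits=True below
])

# Compile the model. The labels are one-hot encoded, so use CategoricalCrossentropy
# rather than the tutorial's SparseCategoricalCrossentropy (which expects integer labels).
model.compile(optimizer='adam',
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Train the model.
# The tutorial feeds tf.data datasets (train_ds, val_ds); here we fit on the in-memory
# arrays directly and hold out the last 15% of the training set for validation.
epochs = 10
history = model.fit(train_images, train_labels, validation_split=0.15, epochs=epochs)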

In [None]:
# Print the best_params_ for GridSearchCV
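`GridSearchCV` expects a scikit-learn estimator, so a Keras model would need a wrapper such as the third-party `scikeras` package. A dependency-free alternative is a manual search over a `param_grid` dictionary. A minimal sketch; `grid_search` is hypothetical, and `score_fn` stands in for a function that builds and trains the CNN for one parameter combination and returns its validation accuracy:

```python
import itertools

def grid_search(param_grid, score_fn):
    """Try every combination in param_grid; return (best_params, best_score)."""
    keys = sorted(param_grid)
    best_params, best_score = None, float('-inf')
    for values in itertools.product(*(param_grid[k] for k in keys)):
        params = dict(zip(keys, values))
        score = score_fn(**params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score

def toy_score(filters, dropout):
    # Toy stand-in for "train the CNN with these settings and return validation accuracy"
    return filters / 64 - dropout

best_params, best_score = grid_search({'filters': [16, 32, 64], 'dropout': [0.0, 0.25]}, toy_score)
print(best_params)  # {'dropout': 0.0, 'filters': 64}
```

In the notebook, `score_fn` could build the `Sequential` model with the given settings, call `model.fit(..., validation_split=0.15)`, and return `max(history.history['val_accuracy'])`.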


## 3. Define and Train the Vision Transformer (ViT) - WIP
Website used as a source: https://theaisummer.com/hugging-face-vit/

In [None]:
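# TBR: load a CIFAR-10 subset (5,000 training and 2,000 test images) from Hugging Face datasets
from datasets import load_dataset
train_ds, test_ds = load_dataset('cifar10', split=['train[:5000]', 'test[:2000]'])

# TBR: hold out 10% of the training subset for validation
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']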

In [None]:
# Load the evaluation metric to maximize (accuracy).
# datasets.load_metric is deprecated and removed in newer `datasets` releases; the `evaluate` library replaces it.
import evaluate
metric = evaluate.load("accuracy")

# Instantiate ViT model
from transformers import ViTForImageClassification
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.eval()

In [None]:
# Feature extraction (ViTFeatureExtractor is superseded by ViTImageProcessor in newer transformers releases)
from transformers import ViTFeatureExtractor

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

def preprocess_images(examples):
    # Convert each PIL image to a uint8 array and move channels first: (H, W, C) -> (C, H, W)
    images = examples['img']
    images = [np.array(image, dtype=np.uint8) for image in images]
    images = [np.moveaxis(image, source=-1, destination=0) for image in images]
    # Resize and normalize the images the way the pretrained ViT expects
    inputs = feature_extractor(images=images)
    examples['pixel_values'] = inputs['pixel_values']
    return examples

from datasets import Features, ClassLabel, Array3D

features = Features({
    'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']),
    'img': Array3D(dtype="int64", shape=(3, 32, 32)),
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
})

preprocessed_train_ds = train_ds.map(preprocess_images, batched=True, features=features)
preprocessed_val_ds = val_ds.map(preprocess_images, batched=True, features=features)
preprocessed_test_ds = test_ds.map(preprocess_images, batched=True, features=features)
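The pretrained ViT checkpoints expect 3-channel images, whereas `preprocess_images` assumes the CIFAR-10 (H, W, C) layout. Assuming the HiRISE crops load as single-channel 2-D arrays, one option is to repeat the gray channel three times in channels-first order, in place of the `np.moveaxis` step (a 2-D image has no channel axis to move). A minimal sketch; `to_vit_layout` is hypothetical, and it works on raw uint8 pixels (before the `/ 255` scaling in `parse_image`) because the feature extractor applies its own rescaling, resizing, and normalization:

```python
import numpy as np

def to_vit_layout(gray_image):
    """Convert an (H, W) uint8 grayscale image into a (3, H, W) uint8 array."""
    gray_image = np.asarray(gray_image, dtype=np.uint8)
    if gray_image.ndim != 2:
        raise ValueError(f"expected a 2-D grayscale image, got shape {gray_image.shape}")
    # Repeat the single channel so it matches the 3-channel input of the pretrained ViT
    return np.repeat(gray_image[np.newaxis, :, :], 3, axis=0)

# Dummy 227x227 grayscale image standing in for one HiRISE crop
sample = (np.arange(227 * 227) % 256).astype(np.uint8).reshape(227, 227)
converted = to_vit_layout(sample)
print(converted.shape)  # (3, 227, 227)
```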

In [None]:
# Data collator - Used for forming batches from the dataset when training the model
from transformers import default_data_collator

data_collator = default_data_collator

In [None]:
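# Defining the model - Part 1: ViT pre-trained on ImageNet-21k, with a newly initialized classification head
from transformers import ViTForImageClassification

# num_labels=10 matches the CIFAR-10 placeholder; use len(class_labels) for HiRISE
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=10)

model.train()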

In [None]:
# Defining the model - Part 2: a custom classification head on top of the bare ViT encoder
import torch.nn as nn
from transformers import ViTModel
from transformers.modeling_outputs import SequenceClassifierOutput

class ViTForImageClassification2(nn.Module):

    def __init__(self, num_labels=10):

        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):

        outputs = self.vit(pixel_values=pixel_values)
        # Classify from the hidden state of the first ([CLS]) token
        cls_output = outputs.last_hidden_state[:, 0]
        logits = self.classifier(cls_output)
        loss = None

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
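The custom head classifies each image from the [CLS] token's hidden state, passes it through one linear layer, and scores it with cross-entropy. A NumPy sketch of that computation with random weights, for illustration only; the shapes assume ViT-Base at 224×224 with 16×16 patches (196 patch tokens plus 1 [CLS] token, hidden size 768), a batch of 2, and 10 labels:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, hidden, num_labels = 2, 197, 768, 10

# Stand-in for outputs.last_hidden_state: (batch, 1 CLS token + 196 patch tokens, hidden)
last_hidden_state = rng.standard_normal((batch, seq_len, hidden))
W = rng.standard_normal((hidden, num_labels)) * 0.02  # classifier weight (transposed vs. nn.Linear)
b = np.zeros(num_labels)                              # classifier bias

cls_output = last_hidden_state[:, 0]  # (batch, hidden): one vector per image
logits = cls_output @ W + b           # (batch, num_labels)

# Mean cross-entropy over the batch, as nn.CrossEntropyLoss computes by default
labels = np.array([3, 7])
shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
loss = -log_probs[np.arange(batch), labels].mean()
print(logits.shape, float(loss))
```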

In [None]:
# Calculate the metrics during evaluation (CUSTOM - May need to change)
def compute_metrics(eval_pred):
    
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)
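For a single-label task, accuracy is simply the fraction of argmax predictions that match the labels. A dependency-free sketch of `compute_metrics` under that assumption (`compute_metrics_numpy` is a hypothetical name) that returns a dict keyed by `'accuracy'`, like the accuracy metric does:

```python
import numpy as np

def compute_metrics_numpy(eval_pred):
    """Accuracy from (logits, labels) without an external metric library."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    return {"accuracy": float((predictions == np.asarray(labels)).mean())}

logits = np.array([[0.1, 2.0, -1.0],
                   [1.5, 0.2, 0.3],
                   [0.0, 0.1, 3.0],
                   [2.0, 1.0, 0.0]])
print(compute_metrics_numpy((logits, [1, 0, 2, 1])))  # {'accuracy': 0.75}
```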

In [None]:
# Training arguments (defined before the Trainer, which needs them)
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    "test-cifar-10",
    evaluation_strategy = "epoch",  # renamed eval_strategy in recent transformers releases
    save_strategy = "epoch",        # must match the evaluation strategy when load_best_model_at_end=True
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir='logs',
)

In [None]:
# Training the model
trainer = Trainer(
    model,
    args,
    train_dataset = preprocessed_train_ds,
    eval_dataset = preprocessed_val_ds,
    data_collator = data_collator,
    compute_metrics = compute_metrics,
)

trainer.train()

In [None]:
# Callbacks - This cell is not complete
from transformers.integrations import WandbCallback
callbacks = [WandbCallback(...)]

## 4. Evaluate CNN vs ViT - WIP

In [None]:
# Evaluating the CNN



In [None]:
# Evaluating the ViT
outputs = trainer.predict(preprocessed_test_ds)
y_pred = outputs.predictions.argmax(1)
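Beyond overall accuracy, a confusion matrix shows which classes each model mixes up. A minimal NumPy sketch (`confusion_matrix_np` is hypothetical; `sklearn.metrics.confusion_matrix` gives the same counts), assuming integer true labels such as `outputs.label_ids`; for the CNN, the one-hot `test_labels` would first need `argmax(axis=1)`:

```python
import numpy as np

def confusion_matrix_np(y_true, y_pred, num_classes):
    """cm[i, j] counts samples whose true class is i and predicted class is j."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

cm = confusion_matrix_np([0, 1, 1, 2, 2, 2], [0, 1, 2, 2, 2, 0], num_classes=3)
print(cm)
# Per-class recall: correct predictions divided by the number of true samples of each class
print(cm.diagonal() / cm.sum(axis=1))
```

Hypothetical usage for the ViT: `confusion_matrix_np(outputs.label_ids, y_pred, num_classes=10)`.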

In [None]:
# Make a plot for the clout


#### Short response to our findings:
Was the output expected? What optimizations did we make? Is either model overfitting or underfitting?

## 5. Retrain CNN and ViT on small, medium, and full HiRISE - WIP
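One way to build the small, medium, and full training sets is to shuffle the training indices once and take nested prefixes, so each smaller set is contained in the larger ones while the test set stays fixed. A minimal sketch; the 25% / 50% / 100% fractions and the helper `nested_subsets` are assumptions, not sizes chosen in this notebook:

```python
import numpy as np

def nested_subsets(n_samples, fractions, seed=666):
    """Return one index array per fraction; each smaller subset is a prefix of the larger ones."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_samples)
    return [order[: int(round(f * n_samples))] for f in fractions]

# 3247 = number of training images after the 85/15 split above
small, medium, full = nested_subsets(3247, fractions=(0.25, 0.5, 1.0))
print(len(small), len(medium), len(full))
```

Each index array can then select a training subset, e.g. `train_images[small]` and `train_labels[small]`, so that differences in accuracy reflect training-set size rather than a different draw of images.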

## 6. Compare three CNNs vs three ViTs - WIP

In [None]:
# Plot for the clout
