# Task 1: Import Libraries

In [None]:
from PIL import Image
from datasets import load_dataset
from collections import Counter
import matplotlib.pyplot as plt 
import seaborn as sns
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
from torchvision import transforms
import sklearn
import numpy as np 
import torch
from datetime import datetime

# Task 2: Load the Dataset

In [None]:
# Load small slices of MNIST straight from the Hub: 1000 training images and
# 100 test images. Downloaded files are cached under ./mnist_dataset, so
# re-running this cell does not download them again.
DATASET_NAME = 'mnist'
CACHE_DIR = 'mnist_dataset'
SPLIT_SEED = 42

train_validation_ds = load_dataset(DATASET_NAME, cache_dir=CACHE_DIR, split='train[:1000]')
test_ds = load_dataset(DATASET_NAME, cache_dir=CACHE_DIR, split='test[:100]')

# Hold out 10% of the training slice for validation. The fixed seed keeps the
# split, and therefore the reported metrics, reproducible across runs.
splits = train_validation_ds.train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_ds = splits['train']
validation_ds = splits['test']

# Task 3: Visualize the Dataset

In [None]:
# Show one training image. Indexing the row first decodes only this example
# instead of materialising the entire 'image' column.
train_ds[449]['image']

In [None]:
# Count how often each label occurs in the training and validation splits.
train_labels = Counter(train_ds['label'])
val_labels = Counter(validation_ds['label'])

# Draw each split on its own axes. Clearing and redrawing a single figure
# with plt.clf() would leave only the last histogram visible in the output.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, label_counts, title in zip(axes,
                                   (train_labels, val_labels),
                                   ('Training dataset', 'Validation dataset')):
    split_labels = sorted(label_counts)
    ax.bar(split_labels, [label_counts[label] for label in split_labels])
    ax.set_xticks(split_labels)  # one tick per label present in the split
    ax.set_title(title)
    ax.set_xlabel('Label')
    ax.set_ylabel('Frequency')
fig.tight_layout()
plt.show()

# Task 4: Create Mappings Between Class Indices and Names

In [None]:
# Class names come from the dataset's 'label' feature. The model config needs
# the mapping in both directions: index -> name and name -> index.
class_names = train_ds.features['label'].names
id2label = dict(enumerate(class_names))
label2id = {name: index for index, name in id2label.items()}
num_labels = len(id2label)

# Task 5: Load the Image Processor for the Pretrained Model

In [None]:
# Checkpoint name shared by the image processor (this cell) and the
# classification model created further down.
PRETRAINED_MODEL_NAME = 'google/vit-base-patch16-224'

# Provides the resize target size and normalisation statistics used below.
processor = ViTImageProcessor.from_pretrained(
    pretrained_model_name_or_path=PRETRAINED_MODEL_NAME,
)

# Task 6: Define Data Augmentations

In [None]:
# Preprocess images the way the loaded processor specifies: resize to its
# target size, then normalise with its mean/std.
image_mean = processor.image_mean
image_std = processor.image_std
height = processor.size['height']
width = processor.size['width']

normalize = transforms.Normalize(mean=image_mean, std=image_std)


def _deterministic_steps():
    """Return the resize -> tensor -> normalize steps shared by every split."""
    return [
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        normalize,
    ]


# Training pipeline. The augmentation below is optional and currently disabled:
# transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
_train_transforms = transforms.Compose(_deterministic_steps())

# Validation/test pipeline: identical deterministic preprocessing.
_validation_transforms = transforms.Compose(_deterministic_steps())

# Task 7: Implement the Data Transformations

In [None]:
def _to_pixel_values(batch, pipeline):
    """Add a 'pixel_values' entry by running `pipeline` over each image in `batch`.

    Each image is converted to RGB first so every input has three channels
    (MNIST digits are presumably stored as single-channel grayscale).
    The batch dict is modified in place and also returned.
    """
    batch['pixel_values'] = [pipeline(picture.convert('RGB')) for picture in batch['image']]
    return batch


def train_transforms(batch):
    """Batch transform for the training split (uses `_train_transforms`)."""
    return _to_pixel_values(batch, _train_transforms)


def validation_transforms(batch):
    """Batch transform for the validation/test splits (uses `_validation_transforms`)."""
    return _to_pixel_values(batch, _validation_transforms)

# Register the batch transforms on each split. Training gets its own pipeline;
# validation and test share the deterministic one.
for split_ds, split_fn in ((train_ds, train_transforms),
                           (validation_ds, validation_transforms),
                           (test_ds, validation_transforms)):
    split_ds.set_transform(split_fn)

# Task 8: Define a Collate Function for the DataLoader

In [None]:
def collate_fn(batch):
    """Stack a list of examples into the batch dict the model consumes.

    Args:
        batch: list of dicts, each holding a 'pixel_values' tensor (all of the
            same shape) and an integer 'label'.

    Returns:
        dict with 'pixel_values' (the tensors stacked along a new first
        dimension) and 'labels' (a 1-D tensor of the labels).
    """
    stacked_images = torch.stack([example['pixel_values'] for example in batch])
    label_tensor = torch.tensor([example['label'] for example in batch])
    return dict(pixel_values=stacked_images, labels=label_tensor)

# Task 9: Create a Model

In [None]:
# Start from the pretrained ViT and attach a classification head sized for
# this dataset's labels. ignore_mismatched_sizes lets loading continue even
# though the checkpoint's own classifier was sized for its original label set.
model = ViTForImageClassification.from_pretrained(
    PRETRAINED_MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Task 10: Define a Metric for the Model

In [None]:
def compute_metrics(eval_pred):
    """Compute the macro-averaged F1 score for one evaluation pass.

    Args:
        eval_pred: a (predictions, labels) pair as passed by `Trainer`.
            `predictions` holds the model logits; the class axis is last.
            It may be a tuple when the model returns extra outputs, in which
            case the logits are its first element.

    Returns:
        dict with a single key 'f1': the unweighted mean of the per-class F1
        scores, so every class counts equally regardless of frequency.
    """
    # Import the submodule explicitly. A bare `import sklearn` does not
    # guarantee that `sklearn.metrics` is available as an attribute.
    from sklearn.metrics import f1_score

    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    return {'f1': f1_score(labels, predictions, average='macro')}

# Task 11: Set Up Trainer Arguments

In [None]:
import dataclasses

output_dir = './trained_model'
# Unique, timestamped directory for this run's TensorBoard logs.
tensorboard_dir = f'./logs/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'

# Newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy` and later dropped the old name. Use whichever field this
# installed version defines.
_training_arg_names = {field.name for field in dataclasses.fields(TrainingArguments)}
_eval_strategy_key = (
    'eval_strategy' if 'eval_strategy' in _training_arg_names else 'evaluation_strategy'
)

training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=20,
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_ratio=0.1,
    # Log, evaluate and checkpoint once per epoch. Evaluation and save
    # strategies must match for load_best_model_at_end.
    logging_dir=tensorboard_dir,
    logging_strategy='epoch',
    save_strategy='epoch',
    **{_eval_strategy_key: 'epoch'},
    # Write logs for TensorBoard explicitly; the notebook visualises them later.
    report_to='tensorboard',
    # Keep the raw 'image'/'label' columns: the on-the-fly transforms and
    # collate_fn read them, and the Trainer would otherwise drop them.
    remove_unused_columns=False,
    # Restore the checkpoint with the lowest validation loss after training.
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    greater_is_better=False,
)

# Task 12: Create a Trainer Object

In [None]:
class MNISTTrainer(Trainer):
    """`Trainer` that batches examples with `collate_fn` unless told otherwise.

    The datasets in this notebook yield per-example dicts with a
    'pixel_values' tensor and a 'label'. `collate_fn` stacks them into the
    'pixel_values'/'labels' batch passed to the model.
    """

    def __init__(self, *args, **kwargs):
        # Supply the collator through the constructor rather than overwriting
        # it afterwards, so a collator the caller passes explicitly is honoured.
        # data_collator is Trainer's third positional parameter, so only fill
        # it in when it was not passed positionally.
        if len(args) < 3 and kwargs.get('data_collator') is None:
            kwargs['data_collator'] = collate_fn
        super().__init__(*args, **kwargs)

# Train on train_ds. Evaluate on validation_ds, which also decides which
# checkpoint is restored at the end, using the F1 metric from compute_metrics.
trainer = MNISTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=validation_ds,
    compute_metrics=compute_metrics,
)

# Task 13: Evaluate the Model Before Training

In [None]:
# Baseline metrics on the validation split before any fine-tuning, for
# comparison with the evaluation after training.
baseline_metrics = trainer.evaluate()
baseline_metrics

# Task 14: Train the Model

In [None]:
# Optional: train only the classification head by freezing every other
# parameter. Uncomment the loop below to enable it.
# for name, param in model.named_parameters():
#     if 'classifier' not in name:
#         param.requires_grad = False

train_result = trainer.train()
train_result

# Task 15: Visualize the Performance in TensorBoard

In [None]:
# visualize the training loss, learning rate and f1 score using tensorboard in vscode
%load_ext tensorboard
#!tensorboard --logdir './logs'

# Task 16: Evaluate the Model

In [None]:
import shutil

# Remove the per-epoch checkpoints written to output_dir during training.
# With load_best_model_at_end=True the best weights are already loaded into
# the trainer's model, so the on-disk checkpoints are no longer needed.
# shutil is portable, unlike the shell `rm` command.
shutil.rmtree(output_dir, ignore_errors=True)

# validation_ds was used to select the best checkpoint, so its metrics are
# optimistic. test_ds was never seen during training or selection, so its
# metrics give the unbiased estimate.
validation_metrics = trainer.evaluate(validation_ds)
test_metrics = trainer.evaluate(test_ds, metric_key_prefix='test')
{**validation_metrics, **test_metrics}

# Task 17: Set Up the Confusion Matrix

# Task 18: Save the Model and Metrics

# Task 19: Set Up Inference with the Model