# Task 1: Set Up the Environment and Import Libraries

In [None]:
# Mount Google Drive and make the project folder the current working directory.
# drive.mount() always mounts the whole Drive at `mount_point`; "connecting to one folder"
# is achieved by chdir-ing into it afterwards.

from google.colab import drive
import os

folder_name = "MyDrive/swin-transformer"  # project folder, relative to the mount point
mount_point = "/content/drive"

drive.mount(mount_point)

project_dir = os.path.join(mount_point, folder_name)
# Create the folder on first use so chdir does not fail on a fresh Drive.
os.makedirs(project_dir, exist_ok=True)
os.chdir(project_dir)

# Verify the current working directory
print(f"Current working directory: {os.getcwd()}")


In [None]:
!pip install transformers datasets
!pip install accelerate -U

In [None]:
import os

def create_folder(folder_name):
  """Create `folder_name` (and any missing parents) unless that path already exists.

  Prints a message saying which of the two cases happened.
  """
  if os.path.exists(folder_name):
    print(f"Folder already exists: {folder_name}")
    return
  os.makedirs(folder_name)
  print(f"Created folder: {folder_name}")

# Working folders used later: dataset cache, TensorBoard logs, training output.
for _folder in ("gtsrb_data", "logs", "trained_model"):
  create_folder(_folder)


In [None]:
from PIL import Image
from datasets import load_dataset, ClassLabel
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoImageProcessor, Swinv2ForImageClassification, TrainingArguments, Trainer
from torchvision import transforms
import sklearn
import numpy as np
import torch
from datetime import datetime
import os
import io

# Task 2: Load the Dataset

In [None]:
from datasets import load_dataset, concatenate_datasets
# Load the GTSRB (German Traffic Sign Recognition Benchmark) dataset, caching it under `path`.
path = 'gtsrb_data'
ds = load_dataset('bazyl/GTSRB', cache_dir=path)

# labels.txt must already exist in `path`. Each line has the form "class_id: 'class_name',".
id2label = {}
with open(os.path.join(path, 'labels.txt'), 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue  # tolerate blank / trailing lines
        # Split on the first ': ' only, so a class name may itself contain ': '.
        class_id, class_name = line.split(': ', 1)
        # Drop the trailing comma and the surrounding quotes, but keep the name's inner characters.
        class_name = class_name.rstrip(',').strip().strip("'\"")
        id2label[int(class_id)] = class_name
print(id2label)

temp_ds = ds['train']

# Build a small, class-balanced subset of the training split. For every class, draw
# num_train_samples_per_class training images plus 10% extra for validation.
num_train_samples_per_class = 100
num_val_samples_per_class = num_train_samples_per_class // 10
samples_per_class = num_train_samples_per_class + num_val_samples_per_class
subset_rng = np.random.default_rng(42)  # seeded so the subset and split are reproducible

# Read the label column once instead of running one full `filter` pass per class.
class_ids = np.fromiter(temp_ds['ClassId'], dtype=np.int64)
labels = temp_ds.unique('ClassId')

train_indices, val_indices = [], []
for label in labels:
    label_indices = np.flatnonzero(class_ids == label)
    if len(label_indices) < samples_per_class:
        raise ValueError(
            f"Class {label} has only {len(label_indices)} samples; {samples_per_class} requested")
    chosen = subset_rng.permutation(label_indices)[:samples_per_class]
    # Split per class, so the validation set is stratified (same count for every class).
    train_indices.extend(chosen[:num_train_samples_per_class].tolist())
    val_indices.extend(chosen[num_train_samples_per_class:].tolist())

# A single select per split (no repeated concatenation); shuffle so rows are not grouped by class.
train_ds = temp_ds.select(subset_rng.permutation(train_indices).tolist())
validation_ds = temp_ds.select(subset_rng.permutation(val_indices).tolist())

# Print or use the new train dataset
print(train_ds)
test_ds = ds['test']

def _decode_gtsrb_example(example):
    """Decode one GTSRB row into the columns used downstream.

    Reads the encoded image bytes at example['Path']['bytes'] and returns them as an RGB
    PIL image under 'image'. Copies example['ClassId'] to 'label'.
    """
    with Image.open(io.BytesIO(example['Path']['bytes'])) as img:
        rgb_image = img.convert('RGB')  # convert() returns a new image, so closing `img` is safe
    return {'image': rgb_image, 'label': example['ClassId']}

# Apply the same decoding to every split.
train_ds = train_ds.map(_decode_gtsrb_example)
test_ds = test_ds.map(_decode_gtsrb_example)
validation_ds = validation_ds.map(_decode_gtsrb_example)


# print the number of samples in each dataset and the number of classes
print(f"Number of samples in train dataset: {len(train_ds)}")
print(f"Number of samples in validation dataset: {len(validation_ds)}")
print(f"Number of samples in test dataset: {len(test_ds)}")
print(f"Number of classes: {len(train_ds.unique('label'))}")

# Task 3: Visualize the Dataset

In [None]:
# visualize a few GTSRB samples from train_ds
def show_samples(dataset, num_samples=10):
    """Plot the first `num_samples` images of `dataset` in one row, titled with their label.

    Args:
        dataset: indexable dataset whose rows contain 'image' (a PIL image) and 'label'.
        num_samples: maximum number of images to show; clamped to len(dataset).
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return
    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(1, num_samples, figsize=(2 * num_samples, 2.5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        sample = dataset[i]  # index once rather than once per field
        ax.imshow(sample['image'])
        ax.set_title(f"Label: {sample['label']}")
        ax.axis("off")
    fig.tight_layout()
    plt.show()

show_samples(train_ds)

In [None]:
# Per-class sample counts of the training subset (sampled per class, so roughly balanced).
train_labels = Counter(train_ds['label'])
if train_labels:
    print(f"Train samples per class: min={min(train_labels.values())}, max={max(train_labels.values())}")

def plot_frequency(dataset, title):
    """Plot a histogram of class-label frequencies, one bar per integer label.

    Args:
        dataset: sequence of integer class labels (e.g. a dataset's 'label' column).
        title: figure title.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.histplot(dataset, discrete=True, ax=ax)
    ax.set(title=title, xlabel="Class ID", ylabel="Number of samples")
    plt.show()

plot_frequency(train_ds['label'], 'Train dataset')
plot_frequency(validation_ds['label'], 'Validation dataset')

# Task 4: Create a Mapping of Class Names to Indices

In [None]:
# Inverse of id2label (class name -> class id); both mappings are passed to the model config.
label2id = {class_name: class_id for class_id, class_name in id2label.items()}
num_labels = len(id2label)
# Duplicate class names would silently collapse entries of label2id.
assert len(label2id) == num_labels, "labels.txt contains duplicate class names"

print(id2label)

# Task 5: Load the Image Processor of the Pretrained Model

In [None]:
# Checkpoint to fine-tune. Its image processor supplies the input size and normalization statistics.
PRETRAINED_MODEL_NAME = 'microsoft/swinv2-tiny-patch4-window16-256'
processor = AutoImageProcessor.from_pretrained(
    pretrained_model_name_or_path=PRETRAINED_MODEL_NAME,
)

# Task 6: Define Data Augmentations

In [None]:
image_mean, image_std = processor.image_mean, processor.image_std
height = processor.size['height']
width = processor.size['width']

normalize = transforms.Normalize(mean=image_mean, std=image_std)

# Shared tail: resize to the model's input resolution, convert to a tensor, normalize.
_common_steps = [transforms.Resize((height, width)), transforms.ToTensor(), normalize]

# Training prepends mild geometric augmentation (rotation up to 10 degrees, 10% shift,
# 0.9-1.1 scale). Evaluation uses only the shared steps.
_train_transforms = transforms.Compose(
    [transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1))] + _common_steps
)

_validation_transforms = transforms.Compose(list(_common_steps))

# Task 7: Implement Data Transformations

In [None]:
def _apply_to_images(batch, transform):
    """Convert every image of `batch` to RGB, run `transform` on it, and store the results as 'pixel_values'."""
    batch['pixel_values'] = [transform(image.convert('RGB')) for image in batch['image']]
    return batch

def train_transforms(batch):
    """On-the-fly transform for training batches (uses _train_transforms)."""
    return _apply_to_images(batch, _train_transforms)

def validation_transforms(batch):
    """On-the-fly transform for validation/test batches (uses _validation_transforms)."""
    return _apply_to_images(batch, _validation_transforms)

# set_transform applies these lazily on every access rather than materializing tensors.
train_ds.set_transform(train_transforms)
validation_ds.set_transform(validation_transforms)
test_ds.set_transform(validation_transforms)

# Task 8: Define the Collate Function for the DataLoader

In [None]:
def collate_fn(batch):
    """Stack a list of samples into the model-input dict expected by the Trainer.

    Each sample provides a 'pixel_values' tensor and an integer 'label'. The result maps
    'pixel_values' to the stacked tensor (new leading batch dim) and 'labels' to a tensor
    of the labels.
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample['pixel_values'])
        targets.append(sample['label'])
    return {'pixel_values': torch.stack(images), 'labels': torch.tensor(targets)}

# Task 9: Create a Model

In [None]:
# Load the pretrained SwinV2 checkpoint with a classification head sized for num_labels.
# ignore_mismatched_sizes lets weights whose shapes differ from the checkpoint (presumably
# the original head) be re-initialized instead of raising.
model = Swinv2ForImageClassification.from_pretrained(
    pretrained_model_name_or_path=PRETRAINED_MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Count only parameters that will be updated during training.
num_params = 0
for param in model.parameters():
    if param.requires_grad:
        num_params += param.numel()
print(f"Number of parameters: {num_params}")

print(model.config)

# Task 10: Define a Metric for the Model

In [None]:
def compute_metrics(eval_pred):
    """Compute the macro-averaged F1 score for Trainer evaluation.

    Args:
        eval_pred: (predictions, label_ids) pair passed by transformers.Trainer. The
            predictions are classifier logits with classes on the last axis.

    Returns:
        dict with key 'f1': the unweighted mean of the per-class F1 scores, so every class
        counts equally regardless of its support.
    """
    # `import sklearn` alone does not guarantee that the sklearn.metrics submodule is loaded.
    from sklearn.metrics import f1_score

    logits, labels = eval_pred
    # When a model emits several outputs the Trainer passes a tuple; logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    # zero_division=0 gives the same 0 score sklearn uses by default, without the warning.
    f1 = f1_score(labels, predictions, average='macro', zero_division=0)
    return {'f1': f1}

# Task 11: Set Up Trainer Arguments

In [None]:
output_dir = './trained_model'
# Unique, timestamped TensorBoard log directory per run so runs do not overwrite each other.
tensorboard_dir = f'./logs/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=20,
    learning_rate=1e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    warmup_ratio=0.1,
    # Log, evaluate and checkpoint once per epoch; eval and save strategies must match
    # for load_best_model_at_end.
    logging_strategy='epoch',
    # `eval_strategy` replaces `evaluation_strategy`, which was deprecated in transformers
    # v4.41 and is rejected by later releases.
    eval_strategy='epoch',
    save_strategy='epoch',
    # Restore the checkpoint with the lowest validation loss when training ends, and keep
    # at most two checkpoints on disk (the best one is always retained).
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=2,
    # Send logs only to TensorBoard rather than to every installed integration.
    logging_dir=tensorboard_dir,
    report_to="tensorboard",
    # Keep the raw 'image' column: the dataset transforms build 'pixel_values' from it.
    remove_unused_columns=False,
)

# Task 12: Create a Trainer Object

In [None]:
class MNISTTrainer(Trainer):
    """Trainer that batches samples with `collate_fn` unless another collator is given.

    The name is historical (the data here is GTSRB, not MNIST). It is kept so existing
    references keep working.
    """

    def __init__(self, *args, **kwargs):
        # Hand the collator to the base class rather than overwriting the attribute after
        # construction; an explicitly passed data_collator keyword is respected.
        kwargs.setdefault('data_collator', collate_fn)
        super().__init__(*args, **kwargs)

trainer = MNISTTrainer(model=model,
                  args=training_args,
                  compute_metrics=compute_metrics,
                  train_dataset=train_ds,
                  eval_dataset=validation_ds,
                  )

# Task 13: Evaluate the Model Before Training

In [None]:
# Baseline validation metrics before fine-tuning (the classification head is still untrained).
trainer.evaluate(metric_key_prefix="eval")

# Task 14: Train the Model

In [None]:
# Fine-tune all layers (no parameters are frozen). To train only the head instead, set
# requires_grad = False on every parameter whose name does not contain 'classifier'
# before calling train().
trainer.train(resume_from_checkpoint=None)

# Task 15: Visualize the Performance in TensorBoard

In [None]:
# Visualize training loss, learning rate and F1 score in TensorBoard.
# In a notebook use the %tensorboard magic, which renders inline and returns.
# A plain `!tensorboard ...` shell call starts a server that blocks the cell indefinitely.
# %load_ext tensorboard
# %tensorboard --logdir ./logs

# Task 16: Evaluate the Model on the Test Set

In [None]:
# Free disk space by deleting the per-epoch checkpoints. With load_best_model_at_end=True
# the best weights are already loaded in memory and are saved explicitly at the end.
# Only checkpoint folders are removed, and nothing fails if there are none.
import shutil
from pathlib import Path

for checkpoint_dir in Path(output_dir).glob('checkpoint-*'):
    shutil.rmtree(checkpoint_dir)

# Evaluate on the held-out test split. The "test_" prefix keeps these metrics distinct
# from the validation ("eval_") metrics in the logs.
trainer.evaluate(test_ds, metric_key_prefix="test")

# Task 17: Plot the Confusion Matrix

In [None]:
# With load_best_model_at_end=True the trainer already holds the best checkpoint's weights.
# To evaluate a previously saved model directory instead (e.g. after a kernel restart),
# set saved_model_dir to that directory.
saved_model_dir = None
if saved_model_dir is not None:
    # from_pretrained loads onto the CPU; move the model to the device the trainer feeds inputs to.
    trainer.model = Swinv2ForImageClassification.from_pretrained(saved_model_dir).to(trainer.args.device)

# setup the confusion matrix
from sklearn.metrics import confusion_matrix
import pandas as pd

def plot_confusion_matrix(y_true, y_pred, class_names, labels=None):
    """Plot a row-normalized confusion matrix as a heatmap.

    Each row is divided by its total, so cell (i, j) is the fraction of samples of true
    class i that were predicted as class j. The diagonal is therefore per-class recall.

    Args:
        y_true: true class ids.
        y_pred: predicted class ids.
        class_names: display names, one per class, in the order given by `labels`.
        labels: class ids fixing the row/column order. Defaults to 0..len(class_names)-1,
            so the matrix keeps every class even if it never occurs in y_true or y_pred.
    """
    class_names = list(class_names)
    if labels is None:
        labels = np.arange(len(class_names))
    cm = confusion_matrix(y_true, y_pred, labels=labels).astype(float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # Rows of classes absent from y_true stay 0 instead of becoming NaN.
    cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)
    df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)
    n_classes = len(class_names)
    fig, ax = plt.subplots(figsize=(max(10, 0.4 * n_classes), max(7, 0.35 * n_classes)))
    # Per-cell numbers are unreadable for many classes, so only annotate small matrices.
    sns.heatmap(df_cm, annot=n_classes <= 20, fmt='.2f', cmap='Blues', ax=ax)
    ax.set_xlabel('Predicted labels')
    ax.set_ylabel('True labels')
    plt.show()

# Run the model on the test split. label_ids hold the ground truth, and the argmax of the
# logits over the last axis gives the predicted class id.
predictions = trainer.predict(test_ds)
y_true, y_pred = predictions.label_ids, np.argmax(predictions.predictions, axis=-1)
class_names = list(id2label.values())
plot_confusion_matrix(y_true, y_pred, class_names)

# Find the (true, predicted) class pairs the model confuses most often on the test set.
class_ids = list(id2label.keys())
# Explicit labels make row/column i correspond to class_ids[i], even when a class never
# occurs in y_true or y_pred.
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
np.fill_diagonal(cm, 0)  # ignore correct predictions
top_k = min(5, cm.size)
# Sort the flattened matrix descending (argpartition alone leaves the top-k unordered).
max_indices = np.unravel_index(np.argsort(cm, axis=None)[::-1][:top_k], cm.shape)
for true_idx, pred_idx in zip(*max_indices):
    count = cm[true_idx, pred_idx]
    if count == 0:
        break  # every remaining pair was never confused
    print(f"True label: {id2label[class_ids[true_idx]]}, Predicted label: {id2label[class_ids[pred_idx]]}, Confusion value: {count}")

# Task 18: Save the Model and Metrics

In [None]:
# Save the best model and its image processor, so output_dir can be reloaded or pushed
# as a complete model.
trainer.save_model(output_dir)
processor.save_pretrained(output_dir)
# save_metrics expects a split name (it writes "<split>_results.json" into args.output_dir),
# not a path. Save the test-set metrics computed by trainer.predict above.
trainer.save_metrics("test", predictions.metrics)
# Writes trainer_state.json into args.output_dir (the same directory as output_dir).
trainer.save_state()

# Task 19: Push the Model to the Hugging Face Hub

In [None]:
# Push the fine-tuned model to the Hugging Face Hub. Log in interactively; never hardcode a token.
# NOTE(review): CLI command names differ across huggingface_hub releases (newer releases
# ship an `hf` CLI). Check these against the installed version. `trainer.push_to_hub()`
# is an alternative that uploads the trainer's output directory.
#!huggingface-cli login
#!huggingface-cli repo create swinv2-tiny-patch4-window16-256-gtsrb-ft --type model