# Imports

In [None]:

from google.colab import drive
# Colab only: mount Google Drive so the project modules stored under
# /content/drive/MyDrive/ViT-codes can be imported later in the notebook.
drive.mount('/content/drive')

import os
import json

import matplotlib
import matplotlib.pyplot as plt
import matplotlib_inline.backend_inline
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import math
from prettytable import PrettyTable



Mounted at /content/drive


# Our ViT modules

In [None]:
import sys

# Directory on the mounted Drive that holds Modified_vit.py / Modified_utils.py.
VIT_CODE_DIR = '/content/drive/MyDrive/ViT-codes'
# Guard so re-running this cell does not keep appending duplicate entries to sys.path.
if VIT_CODE_DIR not in sys.path:
    sys.path.append(VIT_CODE_DIR)
from Modified_vit import ViTForClassification
from Modified_utils import visualize_attention, prepare_data,save_checkpoint,save_experiment,load_experiment, Modified_Trainer

In [None]:
# Attention-inspection run: a 4-layer ViT on 28x28 single-channel images (7x7 patches)
# trained while recording the attention maps of one block.
config = {
    "patch_size": 7,
    "embed_dim": 32,
    "num_hidden_layers": 4,
    # NOTE(review): 32 is not divisible by 3. If Modified_vit computes
    # head_dim = embed_dim // num_heads, each head gets 10 dims and only 30 of the
    # 32 embedding dims are used -- confirm, or prefer 2/4/8 heads.
    "num_attention_heads": 3,
    "hidden_dim": 64,  # Adjusted as 2 * embed_dim
    "dropout_val": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "initializer_range": 0.02,
    "image_size": 28,
    "num_classes": 10,
    "num_channels": 1,
    "qkv_bias": True,
    "use_faster_attention": True,
    "attention_block_index": 1  # Specify the block index you want to observe
}

SEED = 0
torch.manual_seed(SEED)  # reproducible weight init / shuffling

model = ViTForClassification(config)

# Distinct experiment name (the data is MNIST, not CIFAR-100) so this run's
# checkpoints are not overwritten by later runs.
exp_name = "ViT_MNIST_attention"
batch_size = 64
epochs = 5
lr = 0.001
save_model_every_n_epochs = 5
device = "cuda" if torch.cuda.is_available() else "cpu"

trainloader, testloader, _ = prepare_data(batch_size)

optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()
trainer = Modified_Trainer(model, optimizer, loss_fn, exp_name, device, config)
# With output_attentions=True the trainer returns the recorded attention maps,
# indexed as attention_matrices[epoch][batch].
attention_matrices = trainer.train(trainloader, testloader, epochs, save_model_every_n_epochs, output_attentions=True)


NameError: name 'ViTForClassification' is not defined

# Examining Attention Matrices:

Note that the input batch size is 64. Each 28 x 28 MNIST image is split into (4 x 4) = 16 patches of size (7 x 7), and we prepend one CLS token, giving 17 tokens. After flattening, each patch has dimension 49, which the linear projection E embeds into dimension 32. The input to the attention block is therefore [64, 17, 32], and the Query, Key and Value matrices are also of size [64, 17, 32]. The attention matrix has shape [64, 17, 17], representing the attention score of each token with respect to every other token. With multi-head attention, each head computes its own Query (Q), Key (K) and Value (V) matrices of shape [batch size, number of tokens, embedding dimension per head], where the per-head dimension is typically the embedding dimension divided by the number of heads. For n heads, we therefore obtain the following attention matrix:
- [64, n, 17, 17]  

where n is the number of attention-heads.

Now, the first index of `attention_matrices` is the epoch index. For example, if we train for 5 epochs, this index runs from 0 to 4. The second index is the batch index within that epoch, since one entry is stored per batch. MNIST contains 60000 training images and our batch size is 64. Since 60000/64 = 937.5, there are 938 batches, and the last one holds only 32 images. You can verify this yourself: `attention_matrices[0][937]` returns an entry, while `attention_matrices[0][938]` raises an IndexError. Each entry is itself a sequence whose element `[0]` is the attention tensor of shape [batch, n, 17, 17].


In [None]:
# Epoch index 0, batch index 0. Each stored entry's element [0] is the attention
# tensor of the observed block: [batch, heads, tokens, tokens].
attention_matrix = attention_matrices[0][0]
print(attention_matrix[0].size())

torch.Size([64, 3, 17, 17])


In [None]:
attention_matrix = attention_matrices[2][937]  # epoch index 2, batch index 937 = the last batch of the epoch
print(attention_matrix[0].shape) # the last batch holds only 60000 - 937*64 = 32 images, so this is [32,3,17,17], not [64,...]

torch.Size([32, 3, 17, 17])


# Examining Attention matrix in a particular head

Let's examine the attention matrix for epoch index 4 (the 5th epoch), batch index 55, image index 31 (the 32nd image) and head index 2 (the third head):


In [None]:
attention_matrices_index = attention_matrices[4][55] # epoch index 4 (5th epoch), batch index 55 (56th batch)
attention_matrix_at_indx = attention_matrices_index[0] # element [0] is the attention tensor: [batch, heads, tokens, tokens]
print(attention_matrix_at_indx.shape)

torch.Size([64, 3, 17, 17])


In [None]:
attention_matrices_img = attention_matrix_at_indx[31] # image index 31, i.e. the 32nd image of the batch -> [heads, tokens, tokens]

In [None]:
attention_matrix_head = attention_matrices_img[2] # head index 2, i.e. the third head (0-based) -> [tokens, tokens]
print(attention_matrix_head.shape)

torch.Size([17, 17])


In [None]:
print(attention_matrix_head)  # full 17x17 map; row i is the attention distribution of query token i over all tokens


tensor([[0.0144, 0.0152, 0.0117, 0.0095, 0.0059, 0.0249, 0.0443, 0.0182, 0.0896,
         0.0177, 0.4442, 0.1135, 0.0040, 0.0358, 0.0361, 0.0996, 0.0154],
        [0.0324, 0.0601, 0.0570, 0.0642, 0.0581, 0.0711, 0.0451, 0.0577, 0.0773,
         0.0566, 0.0710, 0.0633, 0.0458, 0.0418, 0.0574, 0.0815, 0.0597],
        [0.0590, 0.0577, 0.0563, 0.0610, 0.0711, 0.0623, 0.0491, 0.0773, 0.0640,
         0.0523, 0.0443, 0.0475, 0.0535, 0.0449, 0.0726, 0.0718, 0.0552],
        [0.1068, 0.0620, 0.0705, 0.0705, 0.1108, 0.0466, 0.0424, 0.0611, 0.0264,
         0.0580, 0.0094, 0.0187, 0.1290, 0.0492, 0.0564, 0.0227, 0.0597],
        [0.0360, 0.0361, 0.0303, 0.0350, 0.0305, 0.0511, 0.0604, 0.0678, 0.0867,
         0.0333, 0.1370, 0.0891, 0.0184, 0.0394, 0.0862, 0.1277, 0.0350],
        [0.0264, 0.0518, 0.0465, 0.0582, 0.0457, 0.0667, 0.0536, 0.0603, 0.0755,
         0.0476, 0.0970, 0.0817, 0.0350, 0.0388, 0.0641, 0.0985, 0.0525],
        [0.0657, 0.0745, 0.0932, 0.0765, 0.1357, 0.0507, 0.0218, 0.038

In [None]:
# Softmax attention: every query row (summed over the last dim) should be a
# probability distribution, i.e. sum to 1.
row_sums = torch.sum(attention_matrix_head, dim=-1)
print("row sums are", row_sums)
# NOTE(review): rows only summed to 1 after attention-probability dropout was removed;
# dropout applied to the probabilities in train mode zeroes/rescales entries.
assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-4), "attention rows do not sum to 1"

column sums are tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       device='cuda:0', grad_fn=<SumBackward1>)


# Implementing the Teacher-Student Trainer

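We set up a trainer that first trains a teacher model (for `epoch_teacher` = 20 epochs below). It then uses the teacher to distill knowledge into a smaller student model by adding a loss that matches the student's attention maps to the teacher's.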

In [None]:
class TS_Trainer:
    """Teacher-student trainer for ViT attention distillation.

    The teacher is trained with plain classification loss. The student is then trained
    with classification loss plus an attention-distillation term that pulls the student's
    attention maps towards the frozen teacher's maps.

    Both models are called as ``model(images)`` / ``model(images, output_attentions=True)``
    and must return ``(logits, attentions)``, where ``attentions[0]`` is the attention
    probability tensor of the observed block (expected [batch, heads, tokens, tokens];
    teacher and student must agree on this shape).
    """

    def __init__(self, teacher_model, student_model, optimizer_teacher, optimizer_student, loss_fn, exp_name, device, base_dir="experiments",
                 teacher_config=None, student_config=None, attention_loss_weight=1.0):
        """
        Args:
            teacher_model, student_model: models to train (moved to ``device``).
            optimizer_teacher, optimizer_student: optimizers over each model's parameters.
            loss_fn: classification loss applied to logits and labels.
            exp_name: base experiment name. Teacher and student results are saved under
                ``f"{exp_name}_teacher"`` and ``f"{exp_name}_student"`` so neither run
                overwrites the other's checkpoints.
            device: torch device or device string.
            base_dir: root directory passed to save_checkpoint / save_experiment.
            teacher_config, student_config: configs recorded by save_experiment. If None,
                the model's ``config`` attribute is used when it has one, otherwise the
                notebook-level ``config`` global (the previous behaviour).
            attention_loss_weight: multiplier on the attention-distillation term.
        """
        self.teacher_model = teacher_model.to(device)
        self.student_model = student_model.to(device)

        self.optimizer_teacher = optimizer_teacher
        self.optimizer_student = optimizer_student
        self.loss_fn = loss_fn
        self.exp_name = exp_name
        self.device = device
        self.base_dir = base_dir
        self.teacher_config = teacher_config
        self.student_config = student_config
        self.attention_loss_weight = attention_loss_weight

    @staticmethod
    def _attention_distillation_loss(student_attention, teacher_attention, eps=1e-8):
        """Mean KL(teacher || student) over every attention row (last dim).

        Both inputs are attention probabilities whose last dimension sums to 1.
        nn.CrossEntropyLoss is unsuitable here: it treats dim 1 (the heads) as the class
        dimension and applies a second softmax to already-normalised probabilities.
        """
        if student_attention.shape != teacher_attention.shape:
            raise ValueError(
                f"student attention {tuple(student_attention.shape)} and teacher attention "
                f"{tuple(teacher_attention.shape)} must have the same shape "
                "(same number of heads and tokens)")
        # clamp avoids log(0) for rows containing exact zeros.
        log_student = student_attention.clamp_min(eps).log()
        log_teacher = teacher_attention.clamp_min(eps).log()
        kl_per_row = (teacher_attention * (log_teacher - log_student)).sum(dim=-1)
        return kl_per_row.mean()

    def _resolve_config(self, explicit_config, model):
        """Config to record for ``model``: explicit > ``model.config`` > notebook global ``config``."""
        if explicit_config is not None:
            return explicit_config
        # Presumably ViTForClassification keeps its config as ``.config``; verify in Modified_vit.
        model_config = getattr(model, "config", None)
        if model_config is not None:
            return model_config
        return config  # legacy fallback: the notebook-level global, as before

    def _evaluate(self, model, testloader):
        """Return (accuracy, mean classification loss) of ``model`` on ``testloader``."""
        model.eval()
        total_loss, correct = 0.0, 0
        with torch.no_grad():  # evaluation needs no autograd graph
            for images, labels in testloader:
                images, labels = images.to(self.device), labels.to(self.device)
                logits, _ = model(images)
                loss = self.loss_fn(logits, labels)
                total_loss += loss.item() * len(images)
                predictions = torch.argmax(logits, dim=1)
                correct += torch.sum(predictions == labels).item()
        num_samples = len(testloader.dataset)
        return correct / num_samples, total_loss / num_samples

    def train_epoch_teacher(self, trainloader):
        """Train the teacher for one epoch; return the per-sample mean training loss."""
        self.teacher_model.train()
        total_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.to(self.device), labels.to(self.device)
            self.optimizer_teacher.zero_grad()
            logits = self.teacher_model(images)[0]
            loss = self.loss_fn(logits, labels)
            loss.backward()
            self.optimizer_teacher.step()
            total_loss += loss.item() * len(images)
        return total_loss / len(trainloader.dataset)

    def train_epoch_student(self, trainloader):
        """Train the student for one epoch with classification + attention distillation.

        Returns:
            (total_loss, classification_loss, attention_loss): per-sample averages over the
            epoch as Python floats, with total = classification + weight * attention.
        """
        self.student_model.train()
        # The teacher is only a fixed target: eval() disables its dropout so the target
        # maps are deterministic, and no_grad() below keeps it from being trained.
        self.teacher_model.eval()
        total_loss = total_classification = total_attention = 0.0
        for images, labels in trainloader:
            images, labels = images.to(self.device), labels.to(self.device)
            self.optimizer_student.zero_grad()

            student_logits, student_attention = self.student_model(images, output_attentions=True)
            with torch.no_grad():  # ENSURE THAT WE DON'T TRAIN THE TEACHER.
                _, teacher_attention = self.teacher_model(images, output_attentions=True)

            classification_loss = self.loss_fn(student_logits, labels)
            # Element [0] holds the attention tensor of the observed block.
            attention_loss = self._attention_distillation_loss(student_attention[0], teacher_attention[0])
            loss = classification_loss + self.attention_loss_weight * attention_loss

            loss.backward()
            self.optimizer_student.step()

            # .item() detaches the statistics so no autograd graph outlives the batch.
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            total_classification += classification_loss.item() * batch_size
            total_attention += attention_loss.item() * batch_size

        num_samples = len(trainloader.dataset)
        return total_loss / num_samples, total_classification / num_samples, total_attention / num_samples

    def evaluate_teacher(self, testloader):
        """Return (accuracy, mean classification loss) of the teacher on ``testloader``."""
        return self._evaluate(self.teacher_model, testloader)

    def evaluate_student(self, testloader):
        """Return (accuracy, mean classification loss) of the student on ``testloader``.

        Only the classification loss is computed; no distillation loss during evaluation.
        """
        return self._evaluate(self.student_model, testloader)

    def train_teacher(self, trainloader, testloader, epochs, save_model_every_n_epochs=0):
        """Train the teacher for ``epochs`` epochs, checkpointing and saving results.

        Returns:
            (train_losses, test_losses, accuracies), one entry per epoch.
        """
        exp_name = f"{self.exp_name}_teacher"
        train_losses, test_losses, accuracies = [], [], []
        for i in range(epochs):
            train_loss = self.train_epoch_teacher(trainloader)
            accuracy, test_loss = self.evaluate_teacher(testloader)
            train_losses.append(train_loss)
            test_losses.append(test_loss)
            accuracies.append(accuracy)
            print(f"Epoch: {i+1}, Teacher Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}")

            if save_model_every_n_epochs > 0 and (i + 1) % save_model_every_n_epochs == 0:
                save_checkpoint(exp_name, self.teacher_model, f"epoch_{i+1}", base_dir=self.base_dir)
                print(f'\tSaving checkpoint at epoch {i+1}')

        teacher_config = self._resolve_config(self.teacher_config, self.teacher_model)
        save_experiment(exp_name, teacher_config, self.teacher_model, train_losses, test_losses, accuracies, base_dir=self.base_dir)
        print(f'teacher_final and experiment details saved under {exp_name}')
        return train_losses, test_losses, accuracies

    # THIS IS THE ACTUAL NOVELTY: attention distillation from the trained teacher.
    def train_student(self, trainloader, testloader, epoch_student, save_model_every_n_epochs=0):
        """Train the student for ``epoch_student`` epochs with attention distillation.

        Returns:
            (train_losses, test_losses, attention_losses, accuracies), one entry per epoch.
        """
        exp_name = f"{self.exp_name}_student"
        train_losses, test_losses, attention_losses, accuracies = [], [], [], []
        for i in range(epoch_student):
            train_loss, classification_loss, attention_loss = self.train_epoch_student(trainloader)
            accuracy, test_loss = self.evaluate_student(testloader)
            train_losses.append(train_loss)
            test_losses.append(test_loss)
            attention_losses.append(attention_loss)
            accuracies.append(accuracy)
            print(f"Epoch: {i+1}, Student Train loss: {train_loss:.4f}, Classification loss: {classification_loss:.4f}, "
                  f"Attention Loss: {attention_loss:.4f}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}")

            if save_model_every_n_epochs > 0 and (i + 1) % save_model_every_n_epochs == 0:
                save_checkpoint(exp_name, self.student_model, f"epoch_{i+1}", base_dir=self.base_dir)
                print(f'\tSaving checkpoint at epoch {i+1}')

        student_config = self._resolve_config(self.student_config, self.student_model)
        save_experiment(exp_name, student_config, self.student_model, train_losses, test_losses, accuracies, base_dir=self.base_dir)
        print(f'student_final and experiment details saved under {exp_name}')
        return train_losses, test_losses, attention_losses, accuracies

In [None]:
# FIRST TRAINING A SMALLER MODEL NORMALLY:
# baseline 2-layer ViT trained without distillation, for comparison with the
# distilled student trained later.

config = {
    "patch_size": 7,
    "embed_dim": 32,
    "num_hidden_layers": 2,
    "num_attention_heads": 4,
    "hidden_dim": 64,  # Adjusted as 2 * embed_dim
    "dropout_val": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "initializer_range": 0.02,
    "image_size": 28,
    "num_classes": 10,
    "num_channels": 1,
    "qkv_bias": True,
    "use_faster_attention": True,
    "attention_block_index": 1  # Specify the block index you want to observe
}

SEED = 0
torch.manual_seed(SEED)  # reproducible weight init / shuffling

model = ViTForClassification(config)

# Distinct experiment name so this baseline does not overwrite another run's checkpoints.
exp_name = "ViT_MNIST_student_baseline"
batch_size = 64
epochs = 20
lr = 0.001
save_model_every_n_epochs = 5
device = "cuda" if torch.cuda.is_available() else "cpu"

trainloader, testloader, _ = prepare_data(batch_size)

optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()
trainer = Modified_Trainer(model, optimizer, loss_fn, exp_name, device, config)
trainer.train(trainloader, testloader, epochs, save_model_every_n_epochs, output_attentions=False)

Epoch: 1, Train loss: 1.5082, Test loss: 0.9648, Accuracy: 0.6519
Epoch: 2, Train loss: 0.6869, Test loss: 0.3586, Accuracy: 0.8916
Epoch: 3, Train loss: 0.3789, Test loss: 0.2567, Accuracy: 0.9193
Epoch: 4, Train loss: 0.2987, Test loss: 0.1924, Accuracy: 0.9384
Epoch: 5, Train loss: 0.2583, Test loss: 0.1706, Accuracy: 0.9465
	Saving checkpoint at epoch 5
Epoch: 6, Train loss: 0.2233, Test loss: 0.1483, Accuracy: 0.9523
Epoch: 7, Train loss: 0.2030, Test loss: 0.1357, Accuracy: 0.9542
Epoch: 8, Train loss: 0.1848, Test loss: 0.1246, Accuracy: 0.9603
Epoch: 9, Train loss: 0.1684, Test loss: 0.1074, Accuracy: 0.9643
Epoch: 10, Train loss: 0.1584, Test loss: 0.1018, Accuracy: 0.9654
	Saving checkpoint at epoch 10
Epoch: 11, Train loss: 0.1474, Test loss: 0.1060, Accuracy: 0.9672
Epoch: 12, Train loss: 0.1373, Test loss: 0.1000, Accuracy: 0.9688
Epoch: 13, Train loss: 0.1317, Test loss: 0.1025, Accuracy: 0.9690
Epoch: 14, Train loss: 0.1258, Test loss: 0.0870, Accuracy: 0.9718
Epoch: 15,

In [None]:
# Teacher: 4-layer ViT. Student: identical except for depth (2 layers), so the
# observed attention block has the same head/token dimensions for distillation.
config_teacher = {
    "patch_size": 7,
    "embed_dim": 32,
    "num_hidden_layers": 4,
    "num_attention_heads": 4,
    "hidden_dim": 64,  # Adjusted as 2 * embed_dim
    "dropout_val": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "initializer_range": 0.02,
    "image_size": 28,
    "num_classes": 10,
    "num_channels": 1,
    "qkv_bias": True,
    "use_faster_attention": True,
    "attention_block_index": 1  # Specify the block index you want to observe
}

# Derived from the teacher config so the two cannot drift apart accidentally.
config_student = {
    **config_teacher,
    "num_hidden_layers": 2,
    # NOTE(review): whether index 1 is the first or the second block depends on
    # Modified_vit's indexing convention -- confirm.
    "attention_block_index": 1,
}

exp_name = "ViT_Student_teacher"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SEED = 0
torch.manual_seed(SEED)  # reproducible weight init / shuffling

teacher_model = ViTForClassification(config_teacher)
student_model = ViTForClassification(config_student)

lr = 0.001
optimizer_teacher = optim.Adam(teacher_model.parameters(), lr=lr)
optimizer_student = optim.Adam(student_model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

save_model_every_n_epochs = 5
epoch_teacher = 20  # Adjust the number of epochs as needed
epoch_student = 20  # Adjust the number of epochs as needed
batch_size = 64  # Adjust the batch size as needed
trainloader, testloader, _ = prepare_data(batch_size)

trainer = TS_Trainer(teacher_model, student_model, optimizer_teacher, optimizer_student, loss_fn, exp_name, device, base_dir="experiments")
trainer.train_teacher(trainloader, testloader, epoch_teacher, save_model_every_n_epochs)  # teacher: epoch_teacher epochs
trainer.train_student(trainloader, testloader, epoch_student, save_model_every_n_epochs)  # student: epoch_student epochs, distilled from the teacher

Epoch: 1, Teacher Train loss: 1.4118, Test loss: 0.8152, Accuracy: 0.7263
Epoch: 2, Teacher Train loss: 0.6444, Test loss: 0.3684, Accuracy: 0.8889
Epoch: 3, Teacher Train loss: 0.3665, Test loss: 0.2782, Accuracy: 0.9130
Epoch: 4, Teacher Train loss: 0.2769, Test loss: 0.1715, Accuracy: 0.9473
Epoch: 5, Teacher Train loss: 0.2311, Test loss: 0.1651, Accuracy: 0.9484
	Saving checkpoint at epoch 5
Epoch: 6, Teacher Train loss: 0.2005, Test loss: 0.1458, Accuracy: 0.9565
Epoch: 7, Teacher Train loss: 0.1743, Test loss: 0.1176, Accuracy: 0.9632
Epoch: 8, Teacher Train loss: 0.1606, Test loss: 0.1223, Accuracy: 0.9644
Epoch: 9, Teacher Train loss: 0.1492, Test loss: 0.1033, Accuracy: 0.9665
Epoch: 10, Teacher Train loss: 0.1366, Test loss: 0.1070, Accuracy: 0.9663
	Saving checkpoint at epoch 10
Epoch: 11, Teacher Train loss: 0.1318, Test loss: 0.1001, Accuracy: 0.9686
Epoch: 12, Teacher Train loss: 0.1209, Test loss: 0.0939, Accuracy: 0.9718
Epoch: 13, Teacher Train loss: 0.1156, Test loss