## **Load Libraries**

In [1]:
import torch
import copy
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import datasets, transforms, models
from tqdm import tqdm

from sklearn.metrics import confusion_matrix, classification_report

import warnings
warnings.filterwarnings('ignore')

## **Parameter Initialization**

In [2]:
batch_size = 16        # samples per DataLoader batch (also scales the subset sizes in load_dataset)
learning_rate = 0.001  # Adam learning rate for train_model and model_distillation
epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so random_split, weight initialisation and DataLoader
# shuffling are reproducible on Restart & Run All (manual_seed covers all devices).
seed = 42
torch.manual_seed(seed)

## **Data Loading and Preparation**

In [3]:

# Shared preprocessing: resize to 224x224, convert to a float tensor in [0, 1],
# then map every channel to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): the pretrained MobileNetV2 backbone was presumably trained with
# ImageNet mean/std rather than 0.5/0.5 — confirm this normalisation is intended.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

def load_dataset(root='/content/data', n_train=None, n_eval=None):
  """Download CIFAR-10 and return small train / validation / test subsets.

  The official training split is randomly divided 80/20 into train and
  validation sets, and the official test split is used as the test set. All
  three splits use the module-level ``transform``, so every model sees
  identically preprocessed images.

  Args:
    root: directory where torchvision stores / looks for the CIFAR-10 files.
    n_train: number of training samples to keep (default ``batch_size*4*24``).
    n_eval: number of validation samples and of test samples to keep
      (default ``batch_size*4*8``).

  Returns:
    ``(train_subset, val_subset, test_subset)`` as ``Subset`` objects.
  """
  if n_train is None:
    n_train = batch_size * 4 * 24
  if n_eval is None:
    n_eval = batch_size * 4 * 8

  full_train = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
  train_size = int(0.8 * len(full_train))
  val_size = len(full_train) - train_size
  train_data, val_data = random_split(full_train, [train_size, val_size])
  # Same preprocessing as train/val; a bare ToTensor() here would feed the
  # models un-resized, un-normalised images at test time.
  test_data = datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

  # Keep the first N items of each split (train/val order is already random
  # because of random_split). Clamp so small splits never yield bad indices.
  train_subset = Subset(train_data, list(range(min(n_train, len(train_data)))))
  val_subset = Subset(val_data, list(range(min(n_eval, len(val_data)))))
  test_subset = Subset(test_data, list(range(min(n_eval, len(test_data)))))

  return train_subset, val_subset, test_subset

def prepare_data(train_data, val_data, test_data=None, loader_batch_size=None):
  """Wrap datasets in DataLoaders (training data shuffled, evaluation data in order).

  Args:
    train_data: training dataset; its loader reshuffles every epoch.
    val_data: validation dataset; loaded in fixed order.
    test_data: optional test dataset; loaded in fixed order.
    loader_batch_size: batch size for every loader; defaults to the
      module-level ``batch_size``.

  Returns:
    ``(train_loader, val_loader, test_loader)`` when ``test_data`` is given,
    otherwise ``(train_loader, val_loader)``.
  """
  bs = batch_size if loader_batch_size is None else loader_batch_size
  train_loader = DataLoader(train_data, batch_size=bs, shuffle=True)
  val_loader = DataLoader(val_data, batch_size=bs, shuffle=False)
  # `is not None` rather than truthiness: a Dataset with len() == 0 is falsy and
  # would otherwise be mistaken for "no test set", changing the tuple arity.
  if test_data is not None:
    test_loader = DataLoader(test_data, batch_size=bs, shuffle=False)
    return train_loader, val_loader, test_loader

  return train_loader, val_loader



## **Model Loading**

In [4]:
def load_model(pretrained=True, width_mult=1.0, num_classes=10):
  """Build a torchvision MobileNetV2 with a fresh linear classifier head.

  Args:
    pretrained: load torchvision's pretrained backbone weights. Those weights
      exist only for the default ``width_mult=1.0`` architecture.
    width_mult: channel-width multiplier passed to ``models.mobilenet_v2``.
    num_classes: output size of the replacement classifier (10 for CIFAR-10).

  Returns:
    The model, moved to ``device``.

  Raises:
    ValueError: if ``pretrained`` is requested together with ``width_mult != 1.0``.
  """
  if pretrained and width_mult != 1.0:
    # Fail fast with a clear message instead of a state_dict shape mismatch
    # raised from inside torchvision.
    raise ValueError(
        f"pretrained weights are only available for width_mult=1.0, got {width_mult}")
  # NOTE: `pretrained=` is torchvision's legacy argument (newer releases prefer `weights=`).
  model = models.mobilenet_v2(pretrained=pretrained, width_mult=width_mult)
  # Swap the Linear layer at classifier[1] for a new, randomly initialised head.
  model.classifier[1] = nn.Linear(model.last_channel, num_classes)
  return model.to(device)


## **Evaluation**

In [None]:
def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader`` and print classification metrics.

    The model is switched to eval mode and run without gradient tracking.
    Prints a confusion matrix followed by scikit-learn's per-class
    precision / recall / F1 report. Nothing is returned.

    Args:
        model: classifier whose output has one score per class along dim 1.
        test_loader: iterable of ``(images, labels)`` batches.
    """
    model.eval()
    true_labels, predicted_labels = [], []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(test_loader):
            logits = model(batch_images.to(device))
            # Index of the highest score along the class dimension.
            batch_predictions = torch.max(logits, 1)[1]

            # Labels never leave the CPU; predictions are copied back from `device`.
            true_labels.extend(batch_labels.numpy())
            predicted_labels.extend(batch_predictions.cpu().numpy())

    matrix = confusion_matrix(true_labels, predicted_labels)
    print("Confusion Matrix:")
    print(matrix)

    per_class_report = classification_report(true_labels, predicted_labels)
    print("Classification Report:")
    print(per_class_report)

## **Training / Fine-tuning**

In [6]:
def train_model(model, train_loader, val_loader, epochs, gradient_acc_steps = 256):
  """Train ``model`` with Adam + cross-entropy and return its best validation checkpoint.

  Gradients are accumulated until at least ``gradient_acc_steps`` *samples*
  (not batches) have been processed, then clipped to norm 1.0 and applied.
  A partial accumulation left at the end of an epoch is applied as well, so
  no gradient carries over into the next epoch.

  Note: ``model`` itself is updated in place; the returned model is a separate
  deep copy.

  Args:
    model: classifier already placed on ``device``.
    train_loader: iterable of ``(images, labels)`` training batches.
    val_loader: iterable of ``(images, labels)`` validation batches.
    epochs: number of passes over ``train_loader``.
    gradient_acc_steps: number of samples per optimizer step.

  Returns:
    Deep copy of ``model`` from the epoch with the highest validation accuracy
    (a copy of the input model if accuracy never exceeds 0%).
  """
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

  best_model = copy.deepcopy(model).to(device)
  best_accuracy = 0.0

  def _optimizer_step():
    """Clip the accumulated gradients, apply them, and reset them."""
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()

  optimizer.zero_grad()  # start from clean gradients
  for epoch in range(epochs):
    print(f"*** Epoch: {epoch+1} ***")
    model.train()
    running_loss = 0.0
    n_accumulated = 0
    bar = tqdm(train_loader)
    for images, labels in bar:
      images, labels = images.to(device), labels.to(device)
      outputs = model(images)

      # Mean batch loss re-weighted by batch size, so the accumulated gradient
      # is that of (sum of per-sample losses) / gradient_acc_steps.
      unit_loss = criterion(outputs, labels)*len(images)/gradient_acc_steps
      unit_loss.backward()
      running_loss += unit_loss.item()

      n_accumulated += len(images)
      # `>=` instead of `n % gradient_acc_steps == 0`: still steps when
      # gradient_acc_steps is not a multiple of the batch size.
      if n_accumulated >= gradient_acc_steps:
        _optimizer_step()
        bar.set_description(f"Loss: {running_loss:.5f}")
        running_loss = 0.0
        n_accumulated = 0

    # Apply the leftover partial accumulation so it doesn't leak into the next epoch.
    if n_accumulated > 0:
      _optimizer_step()

    model.eval()
    with torch.no_grad():
      correct = 0
      total = 0
      for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

      val_accuracy = 100 * correct / total

      print(f"Validation Accuracy: {val_accuracy:.2f}%")

      if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        best_model = copy.deepcopy(model).to(device)

  return best_model

## **Model Distillation**

In [7]:
def distillation_loss(teacher_outputs, student_outputs, labels, alpha, temperature):
  """Knowledge-distillation loss: alpha * soft-target KD + (1 - alpha) * hard-label CE.

  Args:
    teacher_outputs: teacher logits, classes along dim 1.
    student_outputs: student logits, same shape as ``teacher_outputs``.
    labels: integer class targets for the cross-entropy term.
    alpha: weight of the soft (teacher-matching) term, in [0, 1].
    temperature: softening temperature applied to both sets of logits.

  Returns:
    Scalar loss tensor.
  """
  F = nn.functional
  soft_targets = F.softmax(teacher_outputs / temperature, dim=1)
  # kl_div takes log-probabilities as input and probabilities as target.
  student_log_probs = F.log_softmax(student_outputs / temperature, dim=1)
  # Conventional T**2 factor (Hinton et al., 2015) keeps the soft term's
  # gradient scale comparable across temperatures.
  soft_term = F.kl_div(student_log_probs, soft_targets, reduction='batchmean') * (temperature ** 2)
  hard_term = F.cross_entropy(student_outputs, labels)
  return alpha * soft_term + (1 - alpha) * hard_term


def model_distillation(teacher_model, student_model, train_loader, val_loader,
                    epochs, alpha=0.5, temperature=3, gradient_acc_steps=256):
  """Train ``student_model`` to mimic a frozen ``teacher_model`` (knowledge distillation).

  Each batch is scored by ``distillation_loss``: ``alpha`` weights the
  softened teacher-matching term and ``1 - alpha`` the hard-label
  cross-entropy. Gradients are accumulated until at least
  ``gradient_acc_steps`` samples have been processed, then clipped to norm 1.0
  and applied; a partial accumulation at epoch end is applied as well.
  After every epoch the student is evaluated on ``val_loader`` and the best
  checkpoint is kept.

  Note: ``student_model`` itself is updated in place; the returned model is a
  separate deep copy. Only the student's parameters are optimised.

  Args:
    teacher_model: trained classifier on ``device``; put in eval mode, never stepped.
    student_model: classifier to train, on ``device``.
    train_loader: iterable of ``(images, labels)`` training batches.
    val_loader: iterable of ``(images, labels)`` validation batches.
    epochs: number of passes over ``train_loader``.
    alpha: weight of the distillation term in ``distillation_loss``.
    temperature: softening temperature for ``distillation_loss``.
    gradient_acc_steps: number of samples per optimizer step.

  Returns:
    Deep copy of the student from the epoch with the highest validation
    accuracy (a copy of the initial student if accuracy never exceeds 0%).
  """
  optimizer = torch.optim.Adam(student_model.parameters(), lr=learning_rate)

  best_model = copy.deepcopy(student_model).to(device)
  best_accuracy = 0.0

  teacher_model.eval()

  def _val_accuracy(model):
    """Percentage of ``val_loader`` samples that ``model`` classifies correctly."""
    correct = 0
    total = 0
    with torch.no_grad():
      for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return 100 * correct / total

  def _optimizer_step():
    """Clip the accumulated student gradients, apply them, and reset them."""
    nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()

  # The teacher stays in eval mode and is never stepped, so its validation
  # accuracy cannot change between epochs: compute it once (lazily, on the
  # first epoch) instead of re-running it every epoch.
  teacher_accuracy = None

  optimizer.zero_grad()  # start from clean gradients
  for epoch in range(epochs):
    print(f"*** Epoch: {epoch+1} ***")
    student_model.train()
    running_loss = 0.0
    n_accumulated = 0
    bar = tqdm(train_loader)
    for images, labels in bar:
      images, labels = images.to(device), labels.to(device)
      with torch.no_grad():
        teacher_outputs = teacher_model(images)
      student_outputs = student_model(images)

      # Mean batch loss re-weighted by batch size, divided by the accumulation size.
      unit_loss = distillation_loss(
          teacher_outputs, student_outputs, labels, alpha, temperature
      )*len(images)/gradient_acc_steps
      unit_loss.backward()
      running_loss += unit_loss.item()

      n_accumulated += len(images)
      # `>=` instead of `n % gradient_acc_steps == 0`: still steps when
      # gradient_acc_steps is not a multiple of the batch size.
      if n_accumulated >= gradient_acc_steps:
        _optimizer_step()
        bar.set_description(f"Loss: {running_loss:.5f}")
        running_loss = 0.0
        n_accumulated = 0

    # Apply the leftover partial accumulation so it doesn't leak into the next epoch.
    if n_accumulated > 0:
      _optimizer_step()

    student_model.eval()
    if teacher_accuracy is None:
      teacher_accuracy = _val_accuracy(teacher_model)
    val_accuracy = _val_accuracy(student_model)

    if val_accuracy > best_accuracy:
      best_accuracy = val_accuracy
      best_model = copy.deepcopy(student_model).to(device)

    print(f"Validation Accuracy:- Teacher: {teacher_accuracy:.2f}%; Student: {val_accuracy:.2f}%")

  return best_model

## **Main Experiment: Teacher Fine-tuning, Student Baseline, and Distillation**

In [8]:
if __name__ == '__main__':
  # Epoch budget shared by the from-scratch student and the distilled student,
  # so the two are compared on equal training time. The teacher uses the
  # `epochs` value from the configuration cell.
  student_epochs = 50

  train_data, val_data, test_data = load_dataset()
  print("Train, validation, and test data size:")
  print(len(train_data), len(val_data), len(test_data))
  train_loader, val_loader, test_loader = prepare_data(train_data, val_data, test_data)
  print("Train, validation, and test data loader size:")
  print(len(train_loader), len(val_loader), len(test_loader))
  # NOTE(review): test_loader is only sized here; every evaluation below uses
  # val_loader, which is also what selects the best checkpoints.

  teacher_model = load_model()
  student_model = load_model(pretrained=False, width_mult=0.25)

  print("Initial teacher Model Performance:")
  evaluate_model(teacher_model, val_loader)
  print("\nInitial Student Model Performance:")
  evaluate_model(student_model, val_loader)

  # Fine-tune the pretrained teacher.
  print("\nTraining Teacher Model:")
  teacher_model = train_model(teacher_model, train_loader, val_loader, epochs)
  print("\nTeacher Model Performance after training:")
  evaluate_model(teacher_model, val_loader)

  # Baseline: train the small student from scratch on hard labels only.
  print("\nTraining Student Model:")
  student_model = train_model(student_model, train_loader, val_loader, student_epochs)
  print("\nStudent Model Performance after training:")
  evaluate_model(student_model, val_loader)

  # Distillation: a fresh student of the same size, trained against the teacher.
  print("\nDistillation:")
  student_model_distilled = load_model(pretrained=False, width_mult=0.25)
  distilled_model = model_distillation(teacher_model, student_model_distilled, train_loader, val_loader, student_epochs)
  print("\nStudent Model Performance after distillation:")
  evaluate_model(distilled_model, val_loader)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 48.7MB/s]


Extracting /content/data/cifar-10-python.tar.gz to /content/data
Files already downloaded and verified
Train, validation, and test data size:
1536 512 512
Train, validation, and test data loader size:
96 32 32


Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 62.1MB/s]


Initial teacher Model Performance:


100%|██████████| 32/32 [00:04<00:00,  7.39it/s]


Confusion Matrix:
[[ 8 32  0  4  8  0  0  0  0  0]
 [18 15  6  4 14  0  0  0  0  0]
 [ 8 31  1  2 13  0  0  1  0  0]
 [ 9 26  5  2 13  0  0  0  0  0]
 [ 7 20 10  1 10  0  0  0  0  0]
 [ 4 14  8  0 17  0  0  0  0  0]
 [ 7 24  1  1  7  0  0  0  0  0]
 [ 9 19  8  2 17  0  0  0  0  0]
 [ 3 36  2  4  5  0  0  1  0  0]
 [17 19 10  8  1  0  0  0  0  0]]
Classification Report:
              precision    recall  f1-score   support

           0       0.09      0.15      0.11        52
           1       0.06      0.26      0.10        57
           2       0.02      0.02      0.02        56
           3       0.07      0.04      0.05        55
           4       0.10      0.21      0.13        48
           5       0.00      0.00      0.00        43
           6       0.00      0.00      0.00        40
           7       0.00      0.00      0.00        55
           8       0.00      0.00      0.00        51
           9       0.00      0.00      0.00        55

    accuracy                    

100%|██████████| 32/32 [00:01<00:00, 26.76it/s]


Confusion Matrix:
[[ 0  0  0  0  0  0  0 52  0  0]
 [ 0  0  0  0  0  0  0 57  0  0]
 [ 0  0  0  0  0  0  0 56  0  0]
 [ 0  0  0  0  0  0  0 55  0  0]
 [ 0  0  0  0  0  0  0 48  0  0]
 [ 0  0  0  0  0  0  0 43  0  0]
 [ 0  0  0  0  0  0  0 40  0  0]
 [ 0  0  0  0  0  0  0 55  0  0]
 [ 0  0  0  0  0  0  0 51  0  0]
 [ 0  0  0  0  0  0  0 55  0  0]]
Classification Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        52
           1       0.00      0.00      0.00        57
           2       0.00      0.00      0.00        56
           3       0.00      0.00      0.00        55
           4       0.00      0.00      0.00        48
           5       0.00      0.00      0.00        43
           6       0.00      0.00      0.00        40
           7       0.11      1.00      0.19        55
           8       0.00      0.00      0.00        51
           9       0.00      0.00      0.00        55

    accuracy                    

Loss: 1.05923: 100%|██████████| 96/96 [00:08<00:00, 11.30it/s]


Validation Accuracy: 71.29%
*** Epoch: 2 ***


Loss: 0.48875: 100%|██████████| 96/96 [00:07<00:00, 12.38it/s]


Validation Accuracy: 80.08%
*** Epoch: 3 ***


Loss: 0.30681: 100%|██████████| 96/96 [00:07<00:00, 12.75it/s]


Validation Accuracy: 78.52%
*** Epoch: 4 ***


Loss: 0.20244: 100%|██████████| 96/96 [00:07<00:00, 12.36it/s]


Validation Accuracy: 80.47%
*** Epoch: 5 ***


Loss: 0.14286: 100%|██████████| 96/96 [00:07<00:00, 12.38it/s]


Validation Accuracy: 76.95%
*** Epoch: 6 ***


Loss: 0.14228: 100%|██████████| 96/96 [00:07<00:00, 12.80it/s]


Validation Accuracy: 79.30%
*** Epoch: 7 ***


Loss: 0.09280: 100%|██████████| 96/96 [00:08<00:00, 10.88it/s]


Validation Accuracy: 79.69%
*** Epoch: 8 ***


Loss: 0.10597: 100%|██████████| 96/96 [00:07<00:00, 12.20it/s]


Validation Accuracy: 80.47%
*** Epoch: 9 ***


Loss: 0.10677: 100%|██████████| 96/96 [00:08<00:00, 11.90it/s]


Validation Accuracy: 80.08%
*** Epoch: 10 ***


Loss: 0.08899: 100%|██████████| 96/96 [00:07<00:00, 12.02it/s]


Validation Accuracy: 79.88%

Teacher Model Performance after training:


100%|██████████| 32/32 [00:01<00:00, 20.29it/s]


Confusion Matrix:
[[46  1  2  0  0  0  0  0  3  0]
 [ 1 50  0  0  0  0  1  0  1  4]
 [ 3  0 43  2  1  3  1  3  0  0]
 [ 1  0  3 40  2  8  1  0  0  0]
 [ 2  0  3  1 35  2  1  4  0  0]
 [ 1  0  1  7  3 29  0  2  0  0]
 [ 1  0  2  1  2  0 34  0  0  0]
 [ 0  0  2  0  2  3  0 48  0  0]
 [ 7  2  0  1  0  0  0  1 40  0]
 [ 4  1  0  1  0  0  1  0  1 47]]
Classification Report:
              precision    recall  f1-score   support

           0       0.70      0.88      0.78        52
           1       0.93      0.88      0.90        57
           2       0.77      0.77      0.77        56
           3       0.75      0.73      0.74        55
           4       0.78      0.73      0.75        48
           5       0.64      0.67      0.66        43
           6       0.87      0.85      0.86        40
           7       0.83      0.87      0.85        55
           8       0.89      0.78      0.83        51
           9       0.92      0.85      0.89        55

    accuracy                    

Loss: 2.20450: 100%|██████████| 96/96 [00:05<00:00, 16.28it/s]


Validation Accuracy: 15.62%
*** Epoch: 2 ***


Loss: 2.10749: 100%|██████████| 96/96 [00:05<00:00, 18.55it/s]


Validation Accuracy: 17.97%
*** Epoch: 3 ***


Loss: 2.07493: 100%|██████████| 96/96 [00:04<00:00, 19.62it/s]


Validation Accuracy: 20.90%
*** Epoch: 4 ***


Loss: 2.03340: 100%|██████████| 96/96 [00:05<00:00, 19.12it/s]


Validation Accuracy: 22.46%
*** Epoch: 5 ***


Loss: 1.97457: 100%|██████████| 96/96 [00:05<00:00, 19.07it/s]


Validation Accuracy: 22.46%
*** Epoch: 6 ***


Loss: 1.83396: 100%|██████████| 96/96 [00:05<00:00, 18.99it/s]


Validation Accuracy: 25.39%
*** Epoch: 7 ***


Loss: 1.75444: 100%|██████████| 96/96 [00:04<00:00, 19.34it/s]


Validation Accuracy: 28.91%
*** Epoch: 8 ***


Loss: 1.71557: 100%|██████████| 96/96 [00:04<00:00, 19.77it/s]


Validation Accuracy: 29.30%
*** Epoch: 9 ***


Loss: 1.71089: 100%|██████████| 96/96 [00:05<00:00, 18.62it/s]


Validation Accuracy: 29.49%
*** Epoch: 10 ***


Loss: 1.77115: 100%|██████████| 96/96 [00:04<00:00, 19.92it/s]


Validation Accuracy: 32.42%
*** Epoch: 11 ***


Loss: 1.76063: 100%|██████████| 96/96 [00:05<00:00, 18.01it/s]


Validation Accuracy: 35.35%
*** Epoch: 12 ***


Loss: 1.57590: 100%|██████████| 96/96 [00:04<00:00, 19.88it/s]


Validation Accuracy: 33.59%
*** Epoch: 13 ***


Loss: 1.50320: 100%|██████████| 96/96 [00:05<00:00, 17.99it/s]


Validation Accuracy: 35.16%
*** Epoch: 14 ***


Loss: 1.51017: 100%|██████████| 96/96 [00:04<00:00, 20.15it/s]


Validation Accuracy: 35.35%
*** Epoch: 15 ***


Loss: 1.64003: 100%|██████████| 96/96 [00:05<00:00, 17.39it/s]


Validation Accuracy: 38.28%
*** Epoch: 16 ***


Loss: 1.45530: 100%|██████████| 96/96 [00:04<00:00, 20.00it/s]


Validation Accuracy: 39.65%
*** Epoch: 17 ***


Loss: 1.44463: 100%|██████████| 96/96 [00:05<00:00, 17.17it/s]


Validation Accuracy: 35.94%
*** Epoch: 18 ***


Loss: 1.42594: 100%|██████████| 96/96 [00:04<00:00, 19.85it/s]


Validation Accuracy: 35.94%
*** Epoch: 19 ***


Loss: 1.42978: 100%|██████████| 96/96 [00:05<00:00, 17.28it/s]


Validation Accuracy: 38.09%
*** Epoch: 20 ***


Loss: 1.41897: 100%|██████████| 96/96 [00:04<00:00, 20.17it/s]


Validation Accuracy: 36.91%
*** Epoch: 21 ***


Loss: 1.45703: 100%|██████████| 96/96 [00:05<00:00, 17.06it/s]


Validation Accuracy: 39.26%
*** Epoch: 22 ***


Loss: 1.47177: 100%|██████████| 96/96 [00:04<00:00, 19.75it/s]


Validation Accuracy: 39.06%
*** Epoch: 23 ***


Loss: 1.28698: 100%|██████████| 96/96 [00:05<00:00, 16.90it/s]


Validation Accuracy: 36.72%
*** Epoch: 24 ***


Loss: 1.25011: 100%|██████████| 96/96 [00:04<00:00, 19.93it/s]


Validation Accuracy: 39.65%
*** Epoch: 25 ***


Loss: 1.29690: 100%|██████████| 96/96 [00:05<00:00, 17.33it/s]


Validation Accuracy: 38.28%
*** Epoch: 26 ***


Loss: 1.32983: 100%|██████████| 96/96 [00:05<00:00, 17.32it/s]


Validation Accuracy: 39.65%
*** Epoch: 27 ***


Loss: 1.23477: 100%|██████████| 96/96 [00:05<00:00, 16.99it/s]


Validation Accuracy: 38.09%
*** Epoch: 28 ***


Loss: 1.19810: 100%|██████████| 96/96 [00:04<00:00, 19.91it/s]


Validation Accuracy: 38.28%
*** Epoch: 29 ***


Loss: 1.19443: 100%|██████████| 96/96 [00:05<00:00, 17.49it/s]


Validation Accuracy: 36.72%
*** Epoch: 30 ***


Loss: 1.04127: 100%|██████████| 96/96 [00:04<00:00, 19.79it/s]


Validation Accuracy: 38.09%
*** Epoch: 31 ***


Loss: 1.09074: 100%|██████████| 96/96 [00:05<00:00, 17.90it/s]


Validation Accuracy: 36.72%
*** Epoch: 32 ***


Loss: 1.09229: 100%|██████████| 96/96 [00:04<00:00, 19.69it/s]


Validation Accuracy: 38.67%
*** Epoch: 33 ***


Loss: 1.09169: 100%|██████████| 96/96 [00:05<00:00, 18.57it/s]


Validation Accuracy: 38.09%
*** Epoch: 34 ***


Loss: 0.96861: 100%|██████████| 96/96 [00:04<00:00, 19.57it/s]


Validation Accuracy: 38.87%
*** Epoch: 35 ***


Loss: 0.95252: 100%|██████████| 96/96 [00:04<00:00, 19.66it/s]


Validation Accuracy: 37.50%
*** Epoch: 36 ***


Loss: 0.94036: 100%|██████████| 96/96 [00:05<00:00, 18.56it/s]


Validation Accuracy: 38.09%
*** Epoch: 37 ***


Loss: 0.91202: 100%|██████████| 96/96 [00:04<00:00, 20.04it/s]


Validation Accuracy: 39.26%
*** Epoch: 38 ***


Loss: 0.87713: 100%|██████████| 96/96 [00:05<00:00, 17.87it/s]


Validation Accuracy: 38.67%
*** Epoch: 39 ***


Loss: 0.87994: 100%|██████████| 96/96 [00:04<00:00, 19.92it/s]


Validation Accuracy: 37.11%
*** Epoch: 40 ***


Loss: 0.74023: 100%|██████████| 96/96 [00:05<00:00, 17.32it/s]


Validation Accuracy: 38.67%
*** Epoch: 41 ***


Loss: 0.78921: 100%|██████████| 96/96 [00:04<00:00, 19.99it/s]


Validation Accuracy: 38.67%
*** Epoch: 42 ***


Loss: 0.69435: 100%|██████████| 96/96 [00:05<00:00, 17.01it/s]


Validation Accuracy: 37.11%
*** Epoch: 43 ***


Loss: 0.63540: 100%|██████████| 96/96 [00:04<00:00, 20.24it/s]


Validation Accuracy: 39.84%
*** Epoch: 44 ***


Loss: 0.64176: 100%|██████████| 96/96 [00:05<00:00, 17.21it/s]


Validation Accuracy: 39.26%
*** Epoch: 45 ***


Loss: 0.66434: 100%|██████████| 96/96 [00:04<00:00, 19.80it/s]


Validation Accuracy: 39.65%
*** Epoch: 46 ***


Loss: 0.56405: 100%|██████████| 96/96 [00:05<00:00, 17.14it/s]


Validation Accuracy: 37.70%
*** Epoch: 47 ***


Loss: 0.53949: 100%|██████████| 96/96 [00:04<00:00, 19.80it/s]


Validation Accuracy: 39.45%
*** Epoch: 48 ***


Loss: 0.62836: 100%|██████████| 96/96 [00:05<00:00, 16.91it/s]


Validation Accuracy: 39.06%
*** Epoch: 49 ***


Loss: 0.52210: 100%|██████████| 96/96 [00:04<00:00, 20.08it/s]


Validation Accuracy: 38.67%
*** Epoch: 50 ***


Loss: 0.53158: 100%|██████████| 96/96 [00:05<00:00, 17.17it/s]


Validation Accuracy: 37.30%

Student Model Performance after training:


100%|██████████| 32/32 [00:00<00:00, 34.75it/s]


Confusion Matrix:
[[21  3  7  5  2  0  1  0  8  5]
 [ 2 27  3  1  0  3  1  0  7 13]
 [ 5  0 22  8  4 10  4  2  0  1]
 [ 1  0  9 21  4  9  7  1  1  2]
 [ 5  2  7 10  8  4  8  2  2  0]
 [ 3  1  4 11  2 17  2  2  1  0]
 [ 1  0  7  6  1  4 15  3  3  0]
 [ 1  0  4  3  6 10  4 20  1  6]
 [11  4  2  2  1  3  0  0 28  0]
 [ 2 13  0  3  2  1  2  0  7 25]]
Classification Report:
              precision    recall  f1-score   support

           0       0.40      0.40      0.40        52
           1       0.54      0.47      0.50        57
           2       0.34      0.39      0.36        56
           3       0.30      0.38      0.34        55
           4       0.27      0.17      0.21        48
           5       0.28      0.40      0.33        43
           6       0.34      0.38      0.36        40
           7       0.67      0.36      0.47        55
           8       0.48      0.55      0.51        51
           9       0.48      0.45      0.47        55

    accuracy                    

Loss: 5.66864: 100%|██████████| 96/96 [00:06<00:00, 15.62it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 15.43%
*** Epoch: 2 ***


Loss: 5.37534: 100%|██████████| 96/96 [00:06<00:00, 15.90it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 22.07%
*** Epoch: 3 ***


Loss: 5.27867: 100%|██████████| 96/96 [00:06<00:00, 15.81it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 22.27%
*** Epoch: 4 ***


Loss: 4.85715: 100%|██████████| 96/96 [00:05<00:00, 17.33it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 24.61%
*** Epoch: 5 ***


Loss: 4.74464: 100%|██████████| 96/96 [00:06<00:00, 15.51it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 24.61%
*** Epoch: 6 ***


Loss: 4.56091: 100%|██████████| 96/96 [00:05<00:00, 17.20it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 26.95%
*** Epoch: 7 ***


Loss: 4.49948: 100%|██████████| 96/96 [00:06<00:00, 15.47it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 27.93%
*** Epoch: 8 ***


Loss: 4.46009: 100%|██████████| 96/96 [00:05<00:00, 17.25it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 29.88%
*** Epoch: 9 ***


Loss: 4.07720: 100%|██████████| 96/96 [00:05<00:00, 16.74it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 27.73%
*** Epoch: 10 ***


Loss: 4.15141: 100%|██████████| 96/96 [00:06<00:00, 15.55it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 27.54%
*** Epoch: 11 ***


Loss: 4.30052: 100%|██████████| 96/96 [00:05<00:00, 17.24it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 31.45%
*** Epoch: 12 ***


Loss: 4.16131: 100%|██████████| 96/96 [00:06<00:00, 15.49it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 32.42%
*** Epoch: 13 ***


Loss: 4.18549: 100%|██████████| 96/96 [00:05<00:00, 17.34it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 34.38%
*** Epoch: 14 ***


Loss: 4.02691: 100%|██████████| 96/96 [00:06<00:00, 15.30it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 35.94%
*** Epoch: 15 ***


Loss: 3.96658: 100%|██████████| 96/96 [00:05<00:00, 17.20it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 34.96%
*** Epoch: 16 ***


Loss: 4.03601: 100%|██████████| 96/96 [00:05<00:00, 16.59it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 34.77%
*** Epoch: 17 ***


Loss: 3.87167: 100%|██████████| 96/96 [00:06<00:00, 15.72it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 33.79%
*** Epoch: 18 ***


Loss: 3.77991: 100%|██████████| 96/96 [00:05<00:00, 17.24it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 35.16%
*** Epoch: 19 ***


Loss: 4.17713: 100%|██████████| 96/96 [00:06<00:00, 15.48it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 35.35%
*** Epoch: 20 ***


Loss: 3.72587: 100%|██████████| 96/96 [00:05<00:00, 17.15it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 37.89%
*** Epoch: 21 ***


Loss: 3.65346: 100%|██████████| 96/96 [00:06<00:00, 15.51it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 35.74%
*** Epoch: 22 ***


Loss: 3.63766: 100%|██████████| 96/96 [00:05<00:00, 17.34it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 35.35%
*** Epoch: 23 ***


Loss: 3.71225: 100%|██████████| 96/96 [00:06<00:00, 14.29it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 39.26%
*** Epoch: 24 ***


Loss: 3.65176: 100%|██████████| 96/96 [00:06<00:00, 15.41it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 38.67%
*** Epoch: 25 ***


Loss: 3.76204: 100%|██████████| 96/96 [00:05<00:00, 17.23it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 36.52%
*** Epoch: 26 ***


Loss: 3.59846: 100%|██████████| 96/96 [00:06<00:00, 15.36it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 38.87%
*** Epoch: 27 ***


Loss: 3.64759: 100%|██████████| 96/96 [00:05<00:00, 17.28it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 37.11%
*** Epoch: 28 ***


Loss: 3.18299: 100%|██████████| 96/96 [00:06<00:00, 15.57it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 37.89%
*** Epoch: 29 ***


Loss: 3.25882: 100%|██████████| 96/96 [00:05<00:00, 16.73it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 38.09%
*** Epoch: 30 ***


Loss: 3.07264: 100%|██████████| 96/96 [00:05<00:00, 16.91it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 38.28%
*** Epoch: 31 ***


Loss: 3.21905: 100%|██████████| 96/96 [00:06<00:00, 15.46it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 38.67%
*** Epoch: 32 ***


Loss: 3.09636: 100%|██████████| 96/96 [00:05<00:00, 17.39it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 39.84%
*** Epoch: 33 ***


Loss: 3.14585: 100%|██████████| 96/96 [00:06<00:00, 15.36it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 41.41%
*** Epoch: 34 ***


Loss: 2.92465: 100%|██████████| 96/96 [00:05<00:00, 17.31it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 39.26%
*** Epoch: 35 ***


Loss: 2.78828: 100%|██████████| 96/96 [00:06<00:00, 15.56it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 37.30%
*** Epoch: 36 ***


Loss: 3.11424: 100%|██████████| 96/96 [00:05<00:00, 16.79it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 40.23%
*** Epoch: 37 ***


Loss: 2.64586: 100%|██████████| 96/96 [00:05<00:00, 17.18it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 39.26%
*** Epoch: 38 ***


Loss: 2.23991: 100%|██████████| 96/96 [00:06<00:00, 15.40it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 39.06%
*** Epoch: 39 ***


Loss: 2.53571: 100%|██████████| 96/96 [00:05<00:00, 17.26it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 40.62%
*** Epoch: 40 ***


Loss: 2.66651: 100%|██████████| 96/96 [00:06<00:00, 15.29it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 41.41%
*** Epoch: 41 ***


Loss: 2.56639: 100%|██████████| 96/96 [00:05<00:00, 17.31it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 41.02%
*** Epoch: 42 ***


Loss: 2.21237: 100%|██████████| 96/96 [00:06<00:00, 15.46it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 41.60%
*** Epoch: 43 ***


Loss: 2.16795: 100%|██████████| 96/96 [00:05<00:00, 16.65it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 39.65%
*** Epoch: 44 ***


Loss: 2.37256: 100%|██████████| 96/96 [00:06<00:00, 15.03it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 40.23%
*** Epoch: 45 ***


Loss: 1.92802: 100%|██████████| 96/96 [00:06<00:00, 15.37it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 39.06%
*** Epoch: 46 ***


Loss: 2.07715: 100%|██████████| 96/96 [00:05<00:00, 17.28it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 41.02%
*** Epoch: 47 ***


Loss: 2.37765: 100%|██████████| 96/96 [00:06<00:00, 15.30it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 40.04%
*** Epoch: 48 ***


Loss: 1.92442: 100%|██████████| 96/96 [00:05<00:00, 17.27it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 42.38%
*** Epoch: 49 ***


Loss: 1.86587: 100%|██████████| 96/96 [00:05<00:00, 16.12it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 41.02%
*** Epoch: 50 ***


Loss: 2.01111: 100%|██████████| 96/96 [00:06<00:00, 15.97it/s]


Validation Accuracy:- Teacher: 80.47%; Student: 40.04%

Student Model Performance after distillation:


100%|██████████| 32/32 [00:00<00:00, 33.70it/s]

Confusion Matrix:
[[32  3  4  1  1  2  0  1  4  4]
 [ 1 35  2  0  0  0  3  0  2 14]
 [ 6  0 16  9  7  8  5  2  2  1]
 [ 1  0  9 10  5 17  7  5  1  0]
 [ 4  0  7  0 17  1  6  6  3  4]
 [ 1  0  5  8  4 16  2  3  1  3]
 [ 0  2  8  3  5  2 17  2  0  1]
 [ 0  0  6  2 15  4  3 23  0  2]
 [13  1  1  0  0  2  0  0 29  5]
 [ 5 14  2  0  1  0  4  4  3 22]]
Classification Report:
              precision    recall  f1-score   support

           0       0.51      0.62      0.56        52
           1       0.64      0.61      0.62        57
           2       0.27      0.29      0.28        56
           3       0.30      0.18      0.23        55
           4       0.31      0.35      0.33        48
           5       0.31      0.37      0.34        43
           6       0.36      0.42      0.39        40
           7       0.50      0.42      0.46        55
           8       0.64      0.57      0.60        51
           9       0.39      0.40      0.40        55

    accuracy                    




In [None]:
# model_distillation trains the student it receives in place, so pass a copy:
# `distilled_model` must still hold the first-loop checkpoint for the "before"
# evaluation below, not the last-epoch weights of this second loop.
distilled_model_2 = model_distillation(teacher_model, copy.deepcopy(distilled_model), train_loader, val_loader, 50)
print("\nStudent Model Performance before 2nd distillation loop:")
evaluate_model(distilled_model, val_loader)

In [11]:
# Evaluate the best-validation checkpoint returned by the second distillation loop.
print("\nStudent Model Performance after 2nd distillation loop:")
evaluate_model(model=distilled_model_2, test_loader=val_loader)


Student Model Performance after 2nd distillation loop:


100%|██████████| 32/32 [00:00<00:00, 32.99it/s]

Confusion Matrix:
[[25  6 12  1  0  1  0  0  5  2]
 [ 1 43  2  2  1  1  3  0  2  2]
 [ 4  1 22  7  4  7  7  2  1  1]
 [ 1  1 10 16  8 10  7  1  0  1]
 [ 2  0  8  4 16  7  1  6  3  1]
 [ 0  0  6  5  3 22  6  0  1  0]
 [ 0  1  5  2  2  3 24  2  1  0]
 [ 0  0  6  3 14  6  2 20  0  4]
 [11  2  1  5  0  0  1  0 30  1]
 [ 4 20  5  0  0  1  6  0  3 16]]
Classification Report:
              precision    recall  f1-score   support

           0       0.52      0.48      0.50        52
           1       0.58      0.75      0.66        57
           2       0.29      0.39      0.33        56
           3       0.36      0.29      0.32        55
           4       0.33      0.33      0.33        48
           5       0.38      0.51      0.44        43
           6       0.42      0.60      0.49        40
           7       0.65      0.36      0.47        55
           8       0.65      0.59      0.62        51
           9       0.57      0.29      0.39        55

    accuracy                    


