<a href="https://colab.research.google.com/github/shreyash53/SMAI-Knowledge-Distilation/blob/main/KD_ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import time
import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import datasets, transforms
from torchsummary import summary
from torch.optim import lr_scheduler
import torch.nn.functional as F
import torch.nn as nn
import torchvision.models as models
from torch import nn, optim

In [None]:
# Preprocessing: resize every CIFAR-10 image to 224x224, convert it to a tensor
# and normalise each channel with a fixed mean/std.
# NOTE(review): these look like the ImageNet statistics used by torchvision's
# pretrained models - confirm against the torchvision model documentation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
trainset = datasets.CIFAR10('/content/train/', download=True, train=True, transform=transform)
valset = datasets.CIFAR10('/content/val/', download=True, train=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# Evaluation gains nothing from shuffling, and a fixed order keeps any per-batch
# outputs cached from this loader aligned with later passes over it.
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=False)
len_trainset = len(trainset)
len_valset = len(valset)
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/train/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting /content/train/cifar-10-python.tar.gz to /content/train/
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/val/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting /content/val/cifar-10-python.tar.gz to /content/val/


In [None]:
# Pull a single batch to sanity-check the tensor shapes produced by the loader.
dataiter = iter(trainloader)
# Use the built-in next(): the iterator protocol only guarantees __next__, and
# the .next() method is not available on DataLoader iterators in newer PyTorch.
images, labels = next(dataiter)
print(images.shape)
print(labels.shape)

torch.Size([64, 3, 224, 224])
torch.Size([64])


In [None]:
# Teacher network: a pretrained ResNet-50 whose existing parameters are frozen
# and whose final fully connected layer is replaced by a new 10-way classifier.
resnet = models.resnet50(pretrained=True)
resnet.requires_grad_(False)  # freeze everything loaded so far
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(in_features=num_ftrs, out_features=10)
resnet = resnet.to(device)
criterion = nn.CrossEntropyLoss()
# Only the freshly created classifier layer is handed to the optimizer.
optimizer = optim.Adam(params=resnet.fc.parameters())

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [None]:
def train_and_evaluate(model, trainloader, valloader, criterion, optimizer, len_trainset, len_valset, num_epochs=25):
   """Train ``model`` and return it carrying its best-validation weights.

   Each epoch performs one optimisation pass over ``trainloader`` followed by
   one gradient-free pass over ``valloader``. A snapshot of the weights is
   taken whenever the validation accuracy improves; the best snapshot is
   restored once, after the final epoch, so training itself is never rewound.

   Args:
      model: ``nn.Module`` to train. Batches are moved to the device of its
         first parameter (CPU if it has none).
      trainloader: iterable of ``(inputs, labels)`` training batches.
      valloader: iterable of ``(inputs, labels)`` validation batches.
      criterion: callable ``criterion(outputs, labels)`` returning a scalar loss.
      optimizer: optimizer stepping the parameters that should be trained.
      len_trainset: number of training samples, used to average the metrics.
      len_valset: number of validation samples, used to average the metrics.
      num_epochs: number of train/validate rounds.

   Returns:
      The same ``model`` object, with the best-validation weights loaded.
   """
   # Follow the model's own device instead of relying on a notebook-level global.
   first_param = next(model.parameters(), None)
   run_device = first_param.device if first_param is not None else torch.device('cpu')

   best_model_wts = copy.deepcopy(model.state_dict())
   best_acc = 0.0
   for epoch in range(num_epochs):
      print('Epoch {}/{}'.format(epoch, num_epochs - 1))
      print('-' * 10)

      # ---- training pass ----
      model.train()
      running_loss = 0.0
      running_corrects = 0
      for inputs, labels in trainloader:
         inputs = inputs.to(run_device)
         labels = labels.to(run_device)
         optimizer.zero_grad()
         outputs = model(inputs)
         loss = criterion(outputs, labels)
         preds = outputs.argmax(dim=1)
         loss.backward()
         optimizer.step()
         # loss is a batch mean, so weight it by the batch size before summing.
         running_loss += loss.item() * inputs.size(0)
         running_corrects += (preds == labels).sum().item()
      epoch_loss = running_loss / len_trainset
      epoch_acc = running_corrects / len_trainset
      print(' Train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss,
             epoch_acc))

      # ---- validation pass (no autograd graph is needed) ----
      model.eval()
      running_loss_val = 0.0
      running_corrects_val = 0
      with torch.no_grad():
         for inputs, labels in valloader:
            inputs = inputs.to(run_device)
            labels = labels.to(run_device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)
            running_loss_val += loss.item() * inputs.size(0)
            running_corrects_val += (preds == labels).sum().item()

      epoch_loss_val = running_loss_val / len_valset
      epoch_acc_val = running_corrects_val / len_valset

      if epoch_acc_val > best_acc:
         best_acc = epoch_acc_val
         best_model_wts = copy.deepcopy(model.state_dict())

      print(' Val Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss_val,
             epoch_acc_val))

      print()
      print('Best val Acc: {:.4f}'.format(best_acc))

   # Restore the best snapshot only once training has finished.
   model.load_state_dict(best_model_wts)
   return model

In [None]:
# Train the teacher for 10 epochs using the loaders, loss and optimizer set up above.
resnet_teacher = train_and_evaluate(
    model=resnet,
    trainloader=trainloader,
    valloader=valloader,
    criterion=criterion,
    optimizer=optimizer,
    len_trainset=len_trainset,
    len_valset=len_valset,
    num_epochs=10,
)

Epoch 0/9
----------
 Train Loss: 0.7646 Acc: 0.7479
 Val Loss: 0.5980 Acc: 0.7940

Best val Acc: 0.794000
Epoch 1/9
----------
 Train Loss: 0.5936 Acc: 0.7943
 Val Loss: 0.5693 Acc: 0.8013

Best val Acc: 0.801300
Epoch 2/9
----------
 Train Loss: 0.5670 Acc: 0.8029
 Val Loss: 0.5807 Acc: 0.7968

Best val Acc: 0.801300
Epoch 3/9
----------
 Train Loss: 0.5690 Acc: 0.8038
 Val Loss: 0.5765 Acc: 0.7984

Best val Acc: 0.801300
Epoch 4/9
----------
 Train Loss: 0.5672 Acc: 0.8028
 Val Loss: 0.5787 Acc: 0.7994

Best val Acc: 0.801300
Epoch 5/9
----------
 Train Loss: 0.5618 Acc: 0.8028
 Val Loss: 0.5809 Acc: 0.7978

Best val Acc: 0.801300
Epoch 6/9
----------
 Train Loss: 0.5661 Acc: 0.8038
 Val Loss: 0.5900 Acc: 0.7965

Best val Acc: 0.801300
Epoch 7/9
----------
 Train Loss: 0.5644 Acc: 0.8027
 Val Loss: 0.5961 Acc: 0.7979

Best val Acc: 0.801300
Epoch 8/9
----------
 Train Loss: 0.5651 Acc: 0.8048
 Val Loss: 0.5349 Acc: 0.8182

Best val Acc: 0.818200
Epoch 9/9
----------
 Train Loss: 0.5

In [None]:
# Write the teacher's state dict to the "teacher" file under PATH.
PATH = "/content/drive/MyDrive/ml_models"
torch.save(obj=resnet_teacher.state_dict(), f=f"{PATH}/teacher")

In [None]:
class Net(nn.Module):
   """
   Small convolutional student network meant to learn from the teacher
   network (ResNet-50 in this notebook).

   Architecture: two blocks of (conv3x3 -> ReLU -> conv3x3 -> ReLU -> 2x2
   max-pool) with 64 and 128 channels, global average pooling, then two
   linear layers (128 -> 32 -> 10).

   Input:  tensor of shape (N, 3, H, W).
   Output: tensor of shape (N, 10) holding unnormalised class scores.
   """
   def __init__(self):
      super(Net, self).__init__()
      self.layer1 = nn.Sequential(
         nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
         nn.ReLU(inplace=True),
         nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
         nn.ReLU(inplace=True),
         nn.MaxPool2d(kernel_size=2, stride=2, padding=0,
                      dilation=1, ceil_mode=False)
      )
      self.layer2 = nn.Sequential(
         nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
         nn.ReLU(inplace=True),
         nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
         nn.ReLU(inplace=True),
         nn.MaxPool2d(kernel_size=2, stride=2, padding=0,
                      dilation=1, ceil_mode=False)
      )
      # Global average pooling makes the classifier independent of H and W.
      self.pool1 = nn.AdaptiveAvgPool2d(output_size=(1, 1))
      self.fc1 = nn.Linear(128, 32)
      self.fc2 = nn.Linear(32, 10)
      # NOTE(review): dropout_rate is stored but never applied in forward(),
      # and fc1 feeds fc2 with no non-linearity in between - confirm intent.
      self.dropout_rate = 0.5

   def forward(self, x):
      """Return the (N, 10) class scores for a batch of images ``x``."""
      x = self.layer1(x)
      x = self.layer2(x)
      x = self.pool1(x)
      x = x.view(x.size(0), -1)  # (N, 128, 1, 1) -> (N, 128)
      x = self.fc1(x)
      x = self.fc2(x)
      # The forward pass must only compute the output; it must not build a
      # new Net and attach it to the result tensor on every call.
      return x


In [None]:
# Run one training batch through the teacher to check the shape of its output.
dataiter = iter(trainloader)
# Use the built-in next(): the .next() method is not available on DataLoader
# iterators in newer PyTorch releases.
images, labels = next(dataiter)
# Move the batch to the notebook's `device` (works on CPU-only machines too) and
# skip autograd bookkeeping, since this is only an inspection pass.
with torch.no_grad():
    out = resnet(images.to(device))
print(out.shape)

torch.Size([64, 10])


In [None]:
def loss_kd(outputs, labels, teacher_outputs, temparature, alpha):
   """Knowledge-distillation loss: soft-target KL term plus hard-label cross-entropy.

   Args:
      outputs: student logits, shape (N, C).
      labels: integer class labels, shape (N,).
      teacher_outputs: teacher logits, shape (N, C); treated as constants.
      temparature: softmax temperature applied to both sets of logits.
      alpha: weight of the soft-target term; the cross-entropy term is
         weighted by ``1 - alpha``.

   Returns:
      Scalar tensor ``alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE``.
   """
   student_log_probs = F.log_softmax(outputs / temparature, dim=1)
   # detach(): the teacher only provides targets and must not receive gradients.
   teacher_probs = F.softmax(teacher_outputs.detach() / temparature, dim=1)
   # 'batchmean' sums over classes and averages over the batch, which is the
   # KL divergence itself; the default 'mean' would also divide by the number
   # of classes and silently shrink this term.
   soft_loss = nn.KLDivLoss(reduction='batchmean')(student_log_probs, teacher_probs)
   hard_loss = F.cross_entropy(outputs, labels)
   KD_loss = soft_loss * (alpha * temparature * temparature) + hard_loss * (1. - alpha)
   return KD_loss

def get_outputs(model, dataloader):
   '''
   Run ``model`` over every batch of ``dataloader`` and collect its outputs.

   Used to get the output of the teacher network. Batches are moved to the
   device of the model's first parameter (CPU if it has none) and the forward
   passes run without gradient tracking.

   Returns a list with one NumPy array per batch, in iteration order. If the
   result is later matched to batches by index, ``dataloader`` must yield its
   batches in the same order on that later pass (i.e. it must not shuffle).
   '''
   first_param = next(model.parameters(), None)
   run_device = first_param.device if first_param is not None else torch.device('cpu')

   outputs = []
   with torch.no_grad():
      # Labels are not needed to compute the model's outputs.
      for inputs, _ in dataloader:
         output_batch = model(inputs.to(run_device)).cpu().numpy()
         outputs.append(output_batch)
   return outputs

In [None]:
def _model_device(model):
   """Return the device of ``model``'s first parameter (CPU if it has none)."""
   first_param = next(model.parameters(), None)
   return first_param.device if first_param is not None else torch.device('cpu')


def _teacher_logits(teacher_out, batch_idx, inputs, run_device):
   """Return the teacher outputs that belong to the current batch.

   ``teacher_out`` may be a callable (e.g. the teacher module), which is run on
   ``inputs`` without gradient tracking, or a sequence of precomputed per-batch
   outputs (NumPy arrays or tensors) that is indexed with ``batch_idx``.
   """
   if callable(teacher_out):
      with torch.no_grad():
         return teacher_out(inputs)
   return torch.as_tensor(teacher_out[batch_idx]).to(run_device)


def _run_kd_epoch(model, teacher_out, loss_kd, dataloader, temparature, alpha, optimizer=None):
   """Run one pass over ``dataloader``; update ``model`` only if ``optimizer`` is given.

   Returns ``(mean_loss, accuracy)`` averaged over the samples actually seen.
   """
   run_device = _model_device(model)
   training = optimizer is not None
   running_loss = 0.0
   running_corrects = 0
   seen = 0
   with torch.set_grad_enabled(training):
      for i, (images, labels) in enumerate(dataloader):
         inputs = images.to(run_device)
         labels = labels.to(run_device)
         outputs_teacher = _teacher_logits(teacher_out, i, inputs, run_device)
         if training:
            optimizer.zero_grad()
         outputs = model(inputs)
         loss = loss_kd(outputs, labels, outputs_teacher, temparature,
                        alpha)
         if training:
            loss.backward()
            optimizer.step()
         batch_size = inputs.size(0)
         # loss is treated as a batch mean, so weight it by the batch size.
         running_loss += loss.item() * batch_size
         running_corrects += (outputs.argmax(dim=1) == labels).sum().item()
         seen += batch_size
   denom = max(seen, 1)  # guard against an empty loader
   return running_loss / denom, running_corrects / denom


def train_kd(model,teacher_out, optimizer, loss_kd, dataloader, temparature, alpha):
   """Train the student ``model`` for one epoch with the distillation loss.

   Args:
      model: student network; it is put in train mode.
      teacher_out: either a callable teacher (evaluated on each batch) or a
         sequence of precomputed per-batch teacher outputs. A precomputed
         sequence is matched to batches by index, so ``dataloader`` must then
         yield batches in the order used when the outputs were computed.
      optimizer: optimizer over the student's parameters.
      loss_kd: callable ``loss_kd(outputs, labels, teacher_outputs, temparature, alpha)``.
      dataloader: iterable of ``(images, labels)`` batches.
      temparature, alpha: forwarded unchanged to ``loss_kd``.

   Prints the epoch's mean loss and accuracy; returns nothing.
   """
   model.train()
   epoch_loss, epoch_acc = _run_kd_epoch(model, teacher_out, loss_kd, dataloader,
                                         temparature, alpha, optimizer=optimizer)
   print(' Train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss,
          epoch_acc))

def eval_kd(model,teacher_out, optimizer, loss_kd, dataloader, temparature, alpha):
   """Evaluate the student ``model`` on ``dataloader`` without updating it.

   Takes the same arguments as ``train_kd``; ``optimizer`` is accepted only for
   signature compatibility and is not used. The model is put in eval mode and
   the pass runs without gradient tracking.

   Prints the mean loss and accuracy and returns the accuracy as a float.
   """
   model.eval()
   epoch_loss, epoch_acc = _run_kd_epoch(model, teacher_out, loss_kd, dataloader,
                                         temparature, alpha, optimizer=None)
   print(' Val Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss,
          epoch_acc))
   return epoch_acc

def train_and_evaluate_kd(model, teacher_model, optimizer, loss_kd, trainloader, valloader, temparature, alpha, num_epochs=25):
   """Distil ``teacher_model`` into ``model`` and return the best student.

   The teacher is evaluated on each batch as it is drawn, so its outputs always
   correspond to the samples the student sees, even when ``trainloader``
   shuffles. The caller's ``optimizer`` is used for every epoch, so its internal
   state carries over between epochs and it updates exactly the parameters the
   caller chose.

   Args:
      model: student network to train.
      teacher_model: teacher network; put in eval mode and never updated here.
         It should be a separate module from ``model`` and sit on the same device.
      optimizer: optimizer over the student's parameters.
      loss_kd: callable ``loss_kd(outputs, labels, teacher_outputs, temparature, alpha)``.
      trainloader: iterable of ``(images, labels)`` training batches.
      valloader: iterable of ``(images, labels)`` validation batches.
      temparature, alpha: forwarded unchanged to ``loss_kd``.
      num_epochs: number of train/validate rounds.

   Returns:
      The same ``model`` object with its best-validation weights loaded.
   """
   teacher_model.eval()
   best_model_wts = copy.deepcopy(model.state_dict())
   print("Teacher outputs are computed per batch; starting the training process-")
   best_acc = 0.0
   for epoch in range(num_epochs):
      print('Epoch {}/{}'.format(epoch, num_epochs - 1))
      print('-' * 10)

      # Train the student against the teacher's soft labels via loss_kd.
      train_kd(model, teacher_model, optimizer, loss_kd, trainloader,
               temparature, alpha)

      # Evaluate the student network.
      epoch_acc_val = eval_kd(model, teacher_model, optimizer, loss_kd,
                              valloader, temparature, alpha)
      if epoch_acc_val > best_acc:
         best_acc = epoch_acc_val
         best_model_wts = copy.deepcopy(model.state_dict())
         print('Best val Acc: {:.4f}'.format(best_acc))

   # Restore the best snapshot once, after the final epoch.
   model.load_state_dict(best_model_wts)
   return model

In [None]:
# Run distillation for 20 epochs with temperature 1 and alpha 0.5.
# NOTE(review): `resnet_teacher` is the object returned by
# train_and_evaluate(resnet, ...), i.e. the same module as `resnet`, so the
# "student" passed here is the teacher itself. The `Net` class is not used as
# a student anywhere in the visible notebook - confirm whether
# `Net().to(device)` was meant to be trained here instead.
stud = train_and_evaluate_kd(
    model=resnet,
    teacher_model=resnet_teacher,
    optimizer=optim.Adam(resnet.parameters()),
    loss_kd=loss_kd,
    trainloader=trainloader,
    valloader=valloader,
    temparature=1,
    alpha=0.5,
    num_epochs=20,
)

Teacher’s outputs are computed now starting the training process-
Epoch 0/19
----------


  "reduction: 'mean' divides the total loss by both the batch size and the support size."


 Train Loss: 0.5170 Acc: 0.8149
 Val Loss: 0.4927 Acc: 0.8252
Best val Acc: 0.825200
Epoch 1/19
----------
 Train Loss: 0.4940 Acc: 0.8205
 Val Loss: 0.5027 Acc: 0.8133
Epoch 2/19
----------
 Train Loss: 0.4876 Acc: 0.8233
 Val Loss: 0.5094 Acc: 0.8046
Epoch 3/19
----------
 Train Loss: 0.4841 Acc: 0.8279
 Val Loss: 0.4930 Acc: 0.8206
Epoch 4/19
----------
 Train Loss: 0.4815 Acc: 0.8274
 Val Loss: 0.4849 Acc: 0.8250
Epoch 5/19
----------
 Train Loss: 0.4764 Acc: 0.8329
 Val Loss: 0.4865 Acc: 0.8253
Best val Acc: 0.825300
Epoch 6/19
----------
 Train Loss: 0.4771 Acc: 0.8319
 Val Loss: 0.4761 Acc: 0.8321
Best val Acc: 0.832100
Epoch 7/19
----------
 Train Loss: 0.4729 Acc: 0.8345
 Val Loss: 0.4968 Acc: 0.8191
Epoch 8/19
----------
 Train Loss: 0.4726 Acc: 0.8378
 Val Loss: 0.4799 Acc: 0.8296
Epoch 9/19
----------
 Train Loss: 0.4695 Acc: 0.8390
 Val Loss: 0.4754 Acc: 0.8339
Best val Acc: 0.833900
Epoch 10/19
----------
 Train Loss: 0.4681 Acc: 0.8398
 Val Loss: 0.4765 Acc: 0.8331
Epoch

In [None]:
# Write the distilled model's state dict to the "student" file under PATH.
PATH = "/content/drive/MyDrive/ml_models"
torch.save(obj=stud.state_dict(), f=f"{PATH}/student")