In [7]:
import torch
from model import Resnet
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import numpy as np
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomGrayscale, ToTensor, Normalize
from torchsummary import summary
from torch.optim import Adam
from torch.nn import MSELoss, KLDivLoss
from tqdm import tqdm

In [8]:
# Training hyper-parameters and RNG seeding.
TrainBS = 64          # training mini-batch size
TestBS = 64           # evaluation mini-batch size
Learning_Rate = 0.001
# torch.manual_seed converts its argument to int, so a random float in [0, 1)
# always collapsed to seed 0 and numpy was never seeded. Use an explicit integer
# seed (0 keeps the torch seed the notebook effectively used) and seed both
# libraries so shuffling, augmentation and weight init are reproducible.
Random_Seed = 0
torch.manual_seed(Random_Seed)
np.random.seed(Random_Seed)

<torch._C.Generator at 0x1e77c9242b0>

In [9]:
# CIFAR-10 loaders. ToTensor maps pixels to [0, 1] and Normalize(0.5, 0.5) per
# channel rescales them to [-1, 1]. Augmentation (random horizontal flip and
# random grayscale) is applied to the training split only.
Data_Root = '/data/'  # NOTE(review): absolute path; consider making this configurable
Normalize_Transform = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

Train_Data = DataLoader(dataset = CIFAR10(train = True,
                                          root = Data_Root,
                                          download = True,
                                          transform = Compose([RandomHorizontalFlip(),
                                                               RandomGrayscale(),
                                                               ToTensor(),
                                                               Normalize_Transform])),
                        batch_size = TrainBS,
                        shuffle = True)
# Evaluation order does not change accuracy, so the test split is read in a
# fixed order; this keeps evaluation runs repeatable.
Test_Data = DataLoader(dataset = CIFAR10(train = False,
                                         root = Data_Root,
                                         download = True,
                                         transform = Compose([ToTensor(),
                                                              Normalize_Transform])),
                       batch_size = TestBS,
                       shuffle = False)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
# Instantiate the project Resnet with argument 34 (presumably ResNet-34 depth;
# confirm against model.Resnet) and print a per-layer summary for a
# 3x224x224 input.
Network = Resnet(34)
summary(model = Network, input_size = (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Resize-1          [-1, 3, 224, 224]               0
            Conv2d-2         [-1, 64, 112, 112]           9,472
       BatchNorm2d-3         [-1, 64, 112, 112]             128
              ReLU-4         [-1, 64, 112, 112]               0
         MaxPool2d-5           [-1, 64, 56, 56]               0
            Conv2d-6           [-1, 64, 56, 56]          36,928
       BatchNorm2d-7           [-1, 64, 56, 56]             128
            Conv2d-8           [-1, 64, 56, 56]          36,928
       BatchNorm2d-9           [-1, 64, 56, 56]             128
           Conv2d-10           [-1, 64, 56, 56]          36,928
      BatchNorm2d-11           [-1, 64, 56, 56]             128
      Basic_Block-12           [-1, 64, 56, 56]               0
           Conv2d-13           [-1, 64, 56, 56]          36,928
      BatchNorm2d-14           [-1, 64,

In [None]:
# Teacher: a Resnet(34) restored from a saved checkpoint and frozen.
# Student: a Resnet(18) whose parameters the Adam optimizer updates.
Teacher_Network = Resnet(34)
# Load the weights into the teacher itself (not `Network`).
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
Teacher_Network.load_state_dict(torch.load('KD-resnet.pth', map_location = 'cpu'))
# The teacher only provides targets: eval mode and no gradients.
Teacher_Network.eval()
for _teacher_param in Teacher_Network.parameters():
    _teacher_param.requires_grad_(False)

Student_Network = Resnet(18)
optimizer = Adam(params = Student_Network.parameters(), lr = Learning_Rate)
# NOTE(review): Loss_Function is not referenced by the training-loop cell; confirm it is still needed.
Loss_Function = MSELoss()

In [None]:
# Knowledge-distillation hyper-parameters.
KD_temp = 1   # temperature T: logits are divided by this before the distillation loss
alpha = 0.5   # weight on the distillation term of the total loss
Beta = 0.5    # weight on the label (student) term of the total loss

# Loss objects.
# NOTE(review): in the usual KD formulation the *soft*-target term is the KL
# divergence against the teacher's softened outputs and the *hard*-target term
# is a loss against ground-truth labels; these two names appear swapped
# relative to that convention. Confirm intent before relying on the names.
Hard_target_Loss = KLDivLoss(reduction="batchmean")
Soft_target_Loss = MSELoss()

In [None]:
# Distillation training loop: the student learns from a weighted sum of
#   alpha * KL(student || teacher) on temperature-softened outputs, and
#   Beta  * cross-entropy against the ground-truth labels.
epochs = 10
# Put each batch on whatever device the student lives on.
device = next(Student_Network.parameters()).device
# The teacher is fixed: eval mode so BatchNorm uses its running statistics.
Teacher_Network.eval()
for epoch in range(1, epochs + 1):
    Student_Network.train()
    # Per-epoch accumulators, reset at the start of every epoch.
    avg_loss = 0.0
    correct = 0
    # Iterate the loader directly; enumerate() would bind the batch index to `data`.
    for data, target in tqdm(Train_Data, desc = 'Epoch {}'.format(epoch)):
        data, target = data.to(device), target.to(device)
        with torch.no_grad():
            Teacher_Pred = Teacher_Network(data)
        Student_Pred = Student_Network(data)
        # KLDivLoss expects log-probabilities as input and probabilities as target.
        # Scaling by T**2 keeps gradient magnitude comparable across temperatures.
        Distill_Loss = Hard_target_Loss((Student_Pred / KD_temp).log_softmax(dim = 1),
                                        (Teacher_Pred / KD_temp).softmax(dim = 1)) * KD_temp ** 2
        # Labels are class indices, so a classification loss is needed here;
        # MSE between (batch, classes) logits and (batch,) integer labels does not type-check.
        Student_Loss = torch.nn.functional.cross_entropy(Student_Pred, target)
        Loss = alpha * Distill_Loss + Beta * Student_Loss
        # Clear gradients from the previous step before back-propagating.
        optimizer.zero_grad()
        Loss.backward()
        optimizer.step()
        # .item() yields a Python float, so the autograd graph is not retained across batches.
        avg_loss += Loss.item()
        correct += Student_Pred.argmax(dim = 1).eq(target).sum().item()
    print('Epoch:{}\t Loss:{:.6f}\t Acc:{:.1f}'.format(
        epoch, avg_loss/len(Train_Data), 100.*correct/len(Train_Data.dataset)))