In [50]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import numpy as np
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomGrayscale, ToTensor, Normalize
from torchsummary import summary
from torch.optim import Adam
from torch.nn import MSELoss,  Sequential, CrossEntropyLoss
from tqdm import tqdm

In [51]:
# Hyperparameters shared by the cells below.
TrainBS = 64           # mini-batch size of the training loader
TestBS = 32            # mini-batch size of the test loader
Learning_Rate = 1e-4   # learning rate handed to Adam
KD_temp = 10           # temperature dividing the predictions in the distillation loss
alpha = 0.5            # weight of the distillation (teacher) loss term
beta = 0.5             # weight of the ground-truth loss term

# torch.manual_seed() needs an integer and coerces its argument with int().
# The previous value, np.random.uniform(), is a float in [0, 1), so it was
# always truncated to 0: the run looked randomly seeded but never was.
# Use an explicit integer so the seed is visible and can be changed on purpose.
Random_Seed = 0
torch.manual_seed(Random_Seed)

<torch._C.Generator at 0x1ac08752310>

In [52]:
# CIFAR-10 input pipelines.
# Training images get random horizontal flips and random grayscale conversion
# before being converted to tensors and normalized per channel with
# mean 0.5 / std 0.5; test images are only converted and normalized.
train_transform = Compose([
    RandomHorizontalFlip(),
    RandomGrayscale(),
    ToTensor(),
    Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
])
test_transform = Compose([
    ToTensor(),
    Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
])

# Both splits are stored under ./data/ and downloaded there if missing.
train_set = CIFAR10(root = './data/', train = True, download = True,
                    transform = train_transform)
test_set = CIFAR10(root = './data/', train = False, download = True,
                   transform = test_transform)

# Both loaders reshuffle their split on every pass.
Train_Data = DataLoader(train_set, batch_size = TrainBS, shuffle = True)
Test_Data = DataLoader(test_set, batch_size = TestBS, shuffle = True)

Files already downloaded and verified
Files already downloaded and verified


In [53]:
import torch.nn as nn
from torch.nn import Module, Sequential, Conv2d, BatchNorm2d, ReLU, MaxPool2d, AdaptiveAvgPool2d, Linear, Softmax
from torchvision.transforms import Resize, InterpolationMode
import torch.nn.functional as F

class Basic_Block(Module):
    """Residual block made of two 3x3 convolutions, each followed by BatchNorm.

    Args:
        Proc_Channel: number of channels the block produces (cast to int).
        DownSample: compared with ``== 1``. When true, the shortcut is a
            3x3 convolution + BatchNorm projection instead of an empty
            ``Sequential``. If, in addition, ``Proc_Channel`` is not 64, the
            block expects ``Proc_Channel / 2`` input channels and uses
            stride 2 in the first convolution and in the shortcut.

    The attribute names (``shortcut``, ``ConvLayer1``, ``ConvLayer2``) are the
    parameter-name prefixes of the module's ``state_dict``.
    """

    def __init__(self, Proc_Channel, DownSample):
        super().__init__()
        Proc_Channel = int(Proc_Channel)
        projecting = DownSample == 1

        # Only a projecting block that is not the 64-channel one changes the
        # channel count (x2) and the spatial resolution (stride 2).
        if projecting and Proc_Channel != 64:
            in_channels, stride = int(Proc_Channel / 2), 2
        else:
            in_channels, stride = Proc_Channel, 1

        if projecting:
            self.shortcut = Sequential(
                Conv2d(in_channels = in_channels,
                       out_channels = Proc_Channel,
                       kernel_size = 3,
                       stride = stride,
                       padding = 1),
                BatchNorm2d(Proc_Channel),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = Sequential()

        self.ConvLayer1 = Sequential(
            Conv2d(in_channels, Proc_Channel, 3, stride, 1),
            BatchNorm2d(Proc_Channel),
        )
        self.ConvLayer2 = Sequential(
            Conv2d(Proc_Channel, Proc_Channel, 3, 1, 1),
            BatchNorm2d(Proc_Channel),
        )

    def forward(self, x):
        """Return ``relu(relu(ConvLayer2(relu(ConvLayer1(x)))) + shortcut(x))``."""
        skip = self.shortcut(x)
        out = F.relu(self.ConvLayer1(x))
        # ReLU is applied to the main path before the addition and once more
        # to the sum.
        out = F.relu(self.ConvLayer2(out))
        return F.relu(out + skip)
    
class Bottle_neck(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The block outputs ``Proc_Channel * 4`` channels.

    Args:
        Proc_Channel: width of the inner 3x3 convolution (cast to int).
        DownSample: compared with ``== 1``; marks the first block of a stage.
            * false: the block expects ``Proc_Channel * 4`` input channels
              and keeps the spatial size (stride 1).
            * true and ``Proc_Channel == 64``: expects 64 input channels,
              stride 1.
            * true otherwise: expects ``Proc_Channel * 2`` input channels and
              halves the spatial size (stride 2).

    The shortcut is always a 3x3 convolution + BatchNorm projection.
    """

    def __init__(self, Proc_Channel, DownSample):
        super(Bottle_neck,self).__init__()
        Proc_Channel = int(Proc_Channel)
        stride = 1
        in_channels = int(Proc_Channel * 4)
        if DownSample == 1:
            if Proc_Channel == 64:
                in_channels = Proc_Channel
            else:
                stride = 2
                in_channels = int(Proc_Channel * 2)
        # The stride must be applied exactly once on the main path so that it
        # shrinks the feature map by the same factor as the shortcut. It was
        # previously passed to all three convolutions, which reduced the main
        # path 8x against 2x for the shortcut and made the addition in
        # forward() fail on mismatched shapes for every stride-2 block.
        self.ConvLayer1 = Sequential(Conv2d(in_channels, Proc_Channel, 1, 1, 0),
                            BatchNorm2d(Proc_Channel),
                            ReLU())
        self.ConvLayer2 = Sequential(Conv2d(Proc_Channel, Proc_Channel, 3, stride, 1),
                            BatchNorm2d(Proc_Channel),
                            ReLU())
        self.ConvLayer3 = Sequential(Conv2d(Proc_Channel, Proc_Channel * 4, 1, 1, 0),
                            BatchNorm2d(Proc_Channel * 4),
                            ReLU())
        self.shortcut = Sequential(Conv2d(in_channels, Proc_Channel * 4, 3, stride, 1),
                          BatchNorm2d(Proc_Channel * 4))

    def forward(self,x):
        """Return ``ConvLayer3(ConvLayer2(ConvLayer1(x))) + shortcut(x)``."""
        Residual = self.shortcut(x)
        x = self.ConvLayer1(x)
        x = self.ConvLayer2(x)
        x = self.ConvLayer3(x)
        # NOTE(review): unlike Basic_Block, no ReLU is applied after the
        # addition (ConvLayer3 already ends in a ReLU) -- confirm this is
        # intended before training a bottleneck-based network.
        x = x + Residual
        return x
class Resnet(Module):
    """ResNet-style 10-class classifier assembled from the block classes above.

    ``typ`` selects an entry of ``arch``: the block class, the number of
    blocks in each of the four stages, and the channel count entering the
    final linear layer.

    ``forward`` resizes its input to 224x224, runs the stem and the four
    stages, average-pools to 1x1, and returns ``softmax(relu(fc(features)))``,
    i.e. one row of class probabilities per sample.

    NOTE(review): the output is already softmax-normalised (after a ReLU on the
    linear layer's output). ``CrossEntropyLoss`` normally expects unnormalised
    logits -- confirm that feeding it these probabilities is intended.
    """

    # depth -> [block class, blocks per stage, channels entering the classifier]
    arch = {
        18: [Basic_Block, [2, 2, 2, 2], 512],
        34: [Basic_Block, [3, 4, 6, 3], 512],
        50: [Bottle_neck, [3, 4, 6, 3], 2048],
        101: [Bottle_neck, [3, 4, 23, 3], 2048],
        152: [Bottle_neck, [3, 8, 36, 3], 2048],
    }

    def __init__(self,typ):
        super().__init__()
        block, layer_arch, final_channel = Resnet.arch[typ]
        self.Resize = Resize((224,224), interpolation = InterpolationMode.BILINEAR)
        self.final_channel = final_channel
        # Stem: 7x7 stride-2 convolution, BatchNorm, ReLU, 3x3 stride-2 max-pool.
        self.stem = Sequential(
            Conv2d(in_channels = 3, out_channels = 64, kernel_size = 7, stride = 2, padding = 3),
            BatchNorm2d(64),
            ReLU(),
            MaxPool2d(kernel_size = 3, stride = 2, padding = 1),
        )
        blocks_1, blocks_2, blocks_3, blocks_4 = layer_arch[:4]
        self.stage1 = self.make_layer(block, 64, blocks_1, stride = 1)
        self.stage2 = self.make_layer(block, 128, blocks_2, stride = 2)
        self.stage3 = self.make_layer(block, 256, blocks_3, stride = 2)
        self.stage4 = self.make_layer(block, 512, blocks_4, stride = 2)
        self.Aver_pool = AdaptiveAvgPool2d((1,1))
        self.fc = Linear(self.final_channel, 10)
        self.softmax = Softmax(dim = 1)

    def make_layer(self, block, Proc_Channel, layer_arch, stride):
        """Build one stage of ``layer_arch`` blocks.

        Only the first block is created with ``DownSample=True``. ``stride``
        is accepted but not used here: each block derives its own stride from
        ``Proc_Channel`` and ``DownSample``.
        """
        return Sequential(*[block(Proc_Channel, position == 0)
                            for position in range(layer_arch)])

    def forward(self,x):
        """Map a batch of images to a ``(batch, 10)`` tensor of probabilities."""
        features = self.stem(self.Resize(x))
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            features = stage(features)
        # Collapse the 1x1 pooled map to (batch, final_channel).
        flat = self.Aver_pool(features).view(-1, self.final_channel)
        scores = F.relu(self.fc(flat))
        return self.softmax(scores)

In [54]:
"""
Teacher_Network = Resnet(34)
Teacher_Network = Teacher_Network.to('cuda')
optimizer = Adam(params = Teacher_Network.parameters(), lr = Learning_Rate)  
Loss_Function = CrossEntropyLoss()
"""

"\nTeacher_Network = Resnet(34)\nTeacher_Network = Teacher_Network.to('cuda')\noptimizer = Adam(params = Teacher_Network.parameters(), lr = Learning_Rate)  \nLoss_Function = CrossEntropyLoss()\n"

In [55]:
"""
epochs =  100
for epoch in range(1,epochs + 1):
    Teacher_Network.train()
    correct = 0
    avg_loss = 0
    for batch_idx, (data, target) in enumerate(tqdm(Train_Data)):
        data = data.to('cuda')
        target = target.to('cuda')
        label = np.zeros((len(target),  10))
        for idx,i in enumerate(target):
            label[idx][i] = 1
        label = torch.tensor(label, dtype = torch.float32)
        label = label.to('cuda')
        optimizer.zero_grad()
        output = Teacher_Network(data)
        output = output.view(-1,10)
        loss = Loss_Function (output, label)
        loss.backward()
        optimizer.step()
        avg_loss += loss
        correct += target.eq(output.data.max(1).indices).sum()
    print('Train_Epoch:{}\t Loss:{:.6f}\t Acc:{:.1f}'.format(
        epoch, avg_loss / len(Train_Data), 100.*correct / len(Train_Data.dataset)))
    avg_loss /= len(Train_Data) 
    acc = correct / len(Train_Data.dataset)
    with open('result.txt', 'a') as f:
        f.write('{\'Epoch\': ')
        f.write(str(epoch))
        f.write(', \'Avg_Loss\': ')
        f.write(str(avg_loss.item()))
        f.write(', \'Acc\': ')
        f.write(str(acc.item()))
        f.write('}\n')
    if epoch%10 ==0:
        torch.save(Teacher_Network.state_dict(), f'Res34_{epoch}.pth')
torch.save(Teacher_Network.state_dict(), f'Res34_Last.pth')
"""

"\nepochs =  100\nfor epoch in range(1,epochs + 1):\n    Teacher_Network.train()\n    correct = 0\n    avg_loss = 0\n    for batch_idx, (data, target) in enumerate(tqdm(Train_Data)):\n        data = data.to('cuda')\n        target = target.to('cuda')\n        label = np.zeros((len(target),  10))\n        for idx,i in enumerate(target):\n            label[idx][i] = 1\n        label = torch.tensor(label, dtype = torch.float32)\n        label = label.to('cuda')\n        optimizer.zero_grad()\n        output = Teacher_Network(data)\n        output = output.view(-1,10)\n        loss = Loss_Function (output, label)\n        loss.backward()\n        optimizer.step()\n        avg_loss += loss\n        correct += target.eq(output.data.max(1).indices).sum()\n    print('Train_Epoch:{}\t Loss:{:.6f}\t Acc:{:.1f}'.format(\n        epoch, avg_loss / len(Train_Data), 100.*correct / len(Train_Data.dataset)))\n    avg_loss /= len(Train_Data) \n    acc = correct / len(Train_Data.dataset)\n    with open(

In [56]:
"""
Student_Network = Resnet(18)
Student_Network = Student_Network.to('cuda')
optimizer = Adam(params = Student_Network.parameters(), lr = Learning_Rate)  
Loss_Function = CrossEntropyLoss()
"""

"\nStudent_Network = Resnet(18)\nStudent_Network = Student_Network.to('cuda')\noptimizer = Adam(params = Student_Network.parameters(), lr = Learning_Rate)  \nLoss_Function = CrossEntropyLoss()\n"

In [57]:
"""
epochs = 20
for epoch in range(1, epochs + 1):
    Student_Network.train()
    avg_loss = 0
    acc = 0
    for batch_idx, (data, target) in enumerate(tqdm(Train_Data)):
        #data = data.to('cuda')
        #target = target.to('cuda')
        label = np.zeros((len(target),10))
        for label_idx, i in enumerate(target):
            label[label_idx][i] = 1
        label = torch.tensor(label, dtype = torch.float32)
        #label = label.to('cuda')
        output = Student_Network(data)
        loss = Loss_Function(output, label)
        loss.backward()
        optimizer.step()
        avg_loss += loss
        correct = target.eq(data.max(1).indices).sum()
    print('Train_Epoch: {}\t Loss: {:.6f}\t Acc: {:.1f}'.format
          (epoch, avg_loss /= len(Train_Date), 100.*correct/len(Train_Data.dataset)))
    acc = correct / len(Train_Data.dataset)
    avg_loss /= len(Train_Data)
    with open('ResNet18_result.txt', 'a') as f:
        f.write('{\'Epoch\': ')
        f.write(str(epoch))
        f.write(', \'Avg_Loss\': ')
        f.write(str(avg_loss.item()))
        f.write(', \'Acc\': ')
        f.write(str(acc.item()))
        f.write('}\n')
    Student_Network.eval()
    correct = 0
    for batch_idx, (data, target) in enumerate(tqdm(Test_Data)):
        #data = data.to('cuda')
        #target = target.to('cuda')
        with torch.no_grad():
            output = Student_Network(data)
            correct += target.eq(output.max(1).indices).sum()
    print( format( correct / len( Test_Data.dataset ) * 100, '.2f') )
"""

"\nepochs = 20\nfor epoch in range(1, epochs + 1):\n    Student_Network.train()\n    avg_loss = 0\n    acc = 0\n    for batch_idx, (data, target) in enumerate(tqdm(Train_Data)):\n        #data = data.to('cuda')\n        #target = target.to('cuda')\n        label = np.zeros((len(target),10))\n        for label_idx, i in enumerate(target):\n            label[label_idx][i] = 1\n        label = torch.tensor(label, dtype = torch.float32)\n        #label = label.to('cuda')\n        output = Student_Network(data)\n        loss = Loss_Function(output, label)\n        loss.backward()\n        optimizer.step()\n        avg_loss += loss\n        correct = target.eq(data.max(1).indices).sum()\n    print('Train_Epoch: {}\t Loss: {:.6f}\t Acc: {:.1f}'.format\n          (epoch, avg_loss /= len(Train_Date), 100.*correct/len(Train_Data.dataset)))\n    acc = correct / len(Train_Data.dataset)\n    avg_loss /= len(Train_Data)\n    with open('ResNet18_result.txt', 'a') as f:\n        f.write('{'Epoch': ')\

In [58]:
# Distillation setup: a ResNet-34 teacher restored from the 'Res34_Last.pth'
# checkpoint and a freshly initialised ResNet-18 student, both on the GPU.
Teacher_Network = Resnet(34)
Student_Network = Resnet(18)
Teacher_Network.load_state_dict(torch.load('Res34_Last.pth'))
Teacher_Network, Student_Network = Teacher_Network.to('cuda'), Student_Network.to('cuda')

# Only the student's parameters are handed to the optimizer.
Loss_Function = CrossEntropyLoss()
optimizer = Adam(Student_Network.parameters(), lr = Learning_Rate)

In [59]:
# Knowledge-distillation training of the student, followed by a test-set
# evaluation after every epoch. The teacher stays in eval mode and is only
# run under torch.no_grad().
epochs = 1
Teacher_Network.eval()
# The loop variable must not be called `epochs`: that overwrote the epoch
# count while the body below reads `epoch`.
for epoch in range(1, epochs + 1):
    Student_Network.train()
    # Running totals over the whole epoch, kept as plain Python numbers so no
    # autograd graph is retained between batches.
    teacher_correct = 0
    student_correct = 0
    kd_avg_loss = 0.0
    gt_avg_loss = 0.0
    for batch_idx, (data,target) in enumerate(tqdm(Train_Data)):
        data = data.to('cuda')
        target = target.to('cuda')
        with torch.no_grad():
            Teacher_Pred = Teacher_Network(data)
        # Clear gradients left over from the previous batch; without this
        # they accumulate across steps.
        optimizer.zero_grad()
        Student_Pred = Student_Network(data)
        # NOTE(review): Resnet.forward, as defined earlier in this notebook,
        # ends in a softmax, so both predictions are already probabilities and
        # Teacher_Pred / KD_temp sums to 1 / KD_temp rather than 1. Confirm
        # that this is the intended distillation objective.
        Distill_Loss = Loss_Function(Student_Pred / KD_temp, Teacher_Pred / KD_temp)
        # CrossEntropyLoss accepts class indices directly; this gives the same
        # value as a one-hot float target without building one by hand.
        Normal_Loss = Loss_Function(Student_Pred, target)
        loss = Distill_Loss * alpha + Normal_Loss * beta
        loss.backward()
        optimizer.step()
        kd_avg_loss += Distill_Loss.item()
        gt_avg_loss += Normal_Loss.item()
        # Accumulate (+=) so the accuracy covers the epoch, not the last batch.
        teacher_correct += target.eq(Teacher_Pred.max(1).indices).sum().item()
        student_correct += target.eq(Student_Pred.max(1).indices).sum().item()

    num_batches = len(Train_Data)
    num_samples = len(Train_Data.dataset)
    kd_avg_loss /= num_batches
    gt_avg_loss /= num_batches
    Acc = student_correct / num_samples
    Tacc = teacher_correct / num_samples
    # Accuracies are printed as percentages; the log file keeps fractions.
    print('Epoch:{}\t KDLoss:{:.6f}\t GTLoss:{:.6f}\t TAcc:{:.1f}\t Acc:{:.1f}\t'.format(
          epoch, kd_avg_loss, gt_avg_loss, 100. * Tacc, 100. * Acc))
    with open('Resnet18KD.txt', 'a') as f:
        f.write('{\"Epoch\": ')
        f.write(str(epoch))
        f.write(', \"KDLoss\": ')
        f.write(str(kd_avg_loss))
        f.write(', \"GTLoss\": ')
        f.write(str(gt_avg_loss))
        f.write(', \"TAcc\": ')
        f.write(str(Tacc))
        f.write(', \"Acc\": ')
        f.write(str(Acc))
        f.write('}\n')

    # eval() has to be called; the bare attribute access was a no-op and left
    # the student in training mode during evaluation.
    Student_Network.eval()
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(tqdm(Test_Data)):
            data = data.to('cuda')
            target = target.to('cuda')
            output = Student_Network(data)
            correct += target.eq(output.max(1).indices).sum().item()
    print( format( correct / len( Test_Data.dataset ) * 100, '.2f') )

  0%|▏                                                                               | 2/782 [00:36<3:57:06, 18.24s/it]


KeyboardInterrupt: 