In [1]:
#import
import torch
from torch import nn,optim
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision import datasets

In [2]:
# Construct Distiller class

class Distiller(nn.Module):
    def __init__(self,student,teacher):
        super(Distiller,self).__init__()
        self.student = student
        self.teacher = teacher
        self.teacher.eval()
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.KL = nn.KLDivLoss(reduction='batchmean')

    def forward(self,x):
        student_logits = self.student(x)
        with torch.no_grad():
            teacher_logits = self.teacher(x)
        return student_logits,teacher_logits
    
    def loss(self,student_logits,teacher_logits,T=20,alpha=0.5):
        student_softmax = nn.functional.softmax(student_logits/T,dim=1)
        
        teacher_softmax = nn.functional.softmax(teacher_logits/T,dim=1)
        loss = alpha*T*T*self.KL(student_softmax,teacher_softmax) + nn.functional.cross_entropy(student_logits,torch.argmax(teacher_softmax,dim=1))*(1. - alpha)
        return loss

    def train(self,trainloader,optimizer,epochs=10,T=20,alpha=0.5):
        self.student.train()
        for epoch in range(epochs):
            running_loss = 0.0
            for i,data in enumerate(tqdm(trainloader)):
                inputs,labels = data
                optimizer.zero_grad()
                student_logits,teacher_logits = self(inputs)
                loss = self.loss(student_logits,teacher_logits,T,alpha)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            print("Student loss: %.3f" % (running_loss/len(trainloader)))
            print('Teacher loss: %.3f' % (nn.functional.cross_entropy(teacher_logits,labels)))
            print('Distillation loss: %.3f' % (self.KL(nn.functional.softmax(student_logits/T,dim=1),nn.functional.softmax(teacher_logits/T,dim=1))))
            print('Epoch: %d, Loss: %.3f' % (epoch + 1, running_loss / len(trainloader)))
        print('Finished Training')

    def test(self,testloader):
        self.student.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data in tqdm(testloader):
                images, labels = data
                outputs = self.student(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum()
        print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

    def save(self,path):
        torch.save(self.student.state_dict(),path)

    def load(self,path):
        self.student.load_state_dict(torch.load(path))



In [3]:

# Creat Student and Teacher
class Student(nn.Module):
    def __init__(self, infeatures=784, outfeatures=10):
        super().__init__()
        self.fc1 = nn.Linear(infeatures, 256)
        self.LeakyReLU = nn.LeakyReLU()
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, outfeatures)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = self.LeakyReLU(x)
        x = self.fc2(x)
        x = self.LeakyReLU(x)
        x = self.fc3(x)
        x = self.softmax(x)
        return x

class Teacher(nn.Module):
    def __init__(self, infeatures=784, outfeatures=10):
        super().__init__()
        self.fc1 = nn.Linear(infeatures, 256)
        self.LeakyReLU = nn.LeakyReLU()
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, outfeatures)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = self.LeakyReLU(x)
        x = self.fc2(x)
        x = self.LeakyReLU(x)
        x = self.fc3(x)
        x = self.softmax(x)
        return x

In [4]:
#load data
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])

trainset = datasets.MNIST(root='./data',train=True,download=True,transform=transform)
trainloader = DataLoader(trainset,batch_size=64,shuffle=True)

testset = datasets.MNIST(root='./data',train=False,download=True,transform=transform)
testloader = DataLoader(testset,batch_size=64,shuffle=False)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:04<00:00, 2406425.02it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 40177676.23it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4454379.19it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 12802774.71it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [5]:
# Train teacher
teacher = Teacher()
teacher_optimizer = optim.Adam(teacher.parameters(),lr=0.001)
teacher.train()

for epoch in range(5):
    running_loss = 0.0
    for i,data in enumerate(tqdm(trainloader)):
        inputs,labels = data
        teacher_optimizer.zero_grad()
        outputs = teacher(inputs)
        loss = nn.functional.cross_entropy(outputs,labels)
        loss.backward()
        teacher_optimizer.step()
        running_loss += loss.item()
    print("Loss: %.3f" % (running_loss/len(trainloader)))
print('Finished Training')

100%|██████████| 938/938 [00:09<00:00, 96.76it/s] 


Loss: 1.663


100%|██████████| 938/938 [00:11<00:00, 81.57it/s]


Loss: 1.540


100%|██████████| 938/938 [00:10<00:00, 93.37it/s] 


Loss: 1.521


 97%|█████████▋| 907/938 [00:11<00:00, 83.10it/s]

In [None]:
# Distill teacher to student
student = Student()
distiller = Distiller(student,teacher)
optimizer = optim.Adam(student.parameters(),lr=0.001)
distiller.train(trainloader,optimizer,epochs=2,T=20,alpha=0.5)

#test
distiller.test(testloader)

#save
distiller.save('student.pth')

#load
student = Student()
distiller = Distiller(student,teacher)
distiller.load('student.pth')

#test
distiller.test(testloader)

# Path: test.py.ipynb

100%|██████████| 938/938 [00:11<00:00, 84.96it/s]


Student loss: -479.680
Teacher loss: 1.495
Distillation loss: -2.402
Epoch: 1, Loss: -479.680


100%|██████████| 938/938 [00:11<00:00, 83.99it/s]


Student loss: -479.737
Teacher loss: 1.492
Distillation loss: -2.402
Epoch: 2, Loss: -479.737
Finished Training


100%|██████████| 157/157 [00:01<00:00, 98.04it/s] 


Accuracy of the network on the 10000 test images: 93 %


100%|██████████| 157/157 [00:01<00:00, 92.62it/s]

Accuracy of the network on the 10000 test images: 93 %





In [None]:
# Train student from scratch for comparison
student = Student()
optimizer = optim.Adam(student.parameters(),lr=0.001)
student.train()

for epoch in range(5):
    running_loss = 0.0
    for i,data in enumerate(tqdm(trainloader)):
        inputs,labels = data
        optimizer.zero_grad()
        outputs = student(inputs)
        loss = nn.functional.cross_entropy(outputs,labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print("Loss: %.3f" % (running_loss/len(trainloader)))

100%|██████████| 938/938 [00:10<00:00, 88.06it/s] 


Loss: 1.614


100%|██████████| 938/938 [00:12<00:00, 78.03it/s] 


Loss: 1.535


100%|██████████| 938/938 [00:11<00:00, 80.70it/s]


Loss: 1.519


 17%|█▋        | 160/938 [00:01<00:09, 85.00it/s]