In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.autograd import Variable
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F

In [2]:
import torch.utils.data as Data
    
BATCH_SIZE = 100

train_dataset = datasets.MNIST(root='./data/',train=True,transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ]))
test_dataset = datasets.MNIST(root='./data/',train=False,transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ]))


train_loader = Data.DataLoader(dataset=train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_loader = Data.DataLoader(dataset=test_dataset,batch_size=BATCH_SIZE,shuffle=False)

In [3]:
class teacherNet(nn.Module):
    def __init__(self):
        super(teacherNet, self).__init__()
        self.layers=nn.Sequential(
            
            nn.Linear(784, 1200),
            nn.BatchNorm1d(1200),
            nn.Hardtanh(),
            
            nn.Linear(1200,1200),
            nn.BatchNorm1d(1200),
            nn.Hardtanh(),
            
            
            nn.Linear(1200,10),
            nn.BatchNorm1d(10),
            nn.Hardtanh(),
        )
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.layers(x)
        return x

teacher = teacherNet()
teacher.cuda()

teacherNet(
  (layers): Sequential(
    (0): Linear(in_features=784, out_features=1200, bias=True)
    (1): BatchNorm1d(1200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Hardtanh(min_val=-1.0, max_val=1.0)
    (3): Linear(in_features=1200, out_features=1200, bias=True)
    (4): BatchNorm1d(1200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): Hardtanh(min_val=-1.0, max_val=1.0)
    (6): Linear(in_features=1200, out_features=10, bias=True)
    (7): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): Hardtanh(min_val=-1.0, max_val=1.0)
  )
)

In [4]:
class studentNet(nn.Module):
    def __init__(self):
        super(studentNet, self).__init__()
        self.layers=nn.Sequential(
            
            nn.Linear(784, 100),
            nn.BatchNorm1d(100),
            nn.Hardtanh(),
            
            
            nn.Linear(100,10),
            nn.BatchNorm1d(10),
            nn.Hardtanh(),
        )
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.layers(x)
        return x

student = studentNet()
student.cuda()

studentNet(
  (layers): Sequential(
    (0): Linear(in_features=784, out_features=100, bias=True)
    (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Hardtanh(min_val=-1.0, max_val=1.0)
    (3): Linear(in_features=100, out_features=10, bias=True)
    (4): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): Hardtanh(min_val=-1.0, max_val=1.0)
  )
)

In [6]:
def test(model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            output =F.softmax(model(data),dim=-1)
            pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
            
    print('Test :  Accuracy: {}/{} ({:.2f}%)\n'.format(
        correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [14]:
import torch.optim as optim
# train teacher net
criterion1 = nn.CrossEntropyLoss()
optimizer1 = optim.Adam(teacher.parameters(), lr=1e-4)

teacher.train()
for epoch in range(1, 20):
    for data, target in train_loader:
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer1.zero_grad()
        output =F.softmax( teacher(data),dim=-1)
        loss = criterion1(output, target)


        optimizer1.zero_grad()
        loss.backward()
        optimizer1.step()

    test(teacher)

Test :  Accuracy: 9769/10000 (97.69%)

Test :  Accuracy: 9820/10000 (98.20%)

Test :  Accuracy: 9822/10000 (98.22%)

Test :  Accuracy: 9818/10000 (98.18%)



KeyboardInterrupt: 

In [8]:
optimizer2 = optim.Adam(student.parameters(), lr=1e-3)

criterion2 = nn.KLDivLoss()

In [None]:
# train student net without teacher


student.train()
for epoch in range(1, 5):
    prbar = tqdm(total=len(train_loader))
    prbar.set_description("student training epoch"+str(epoch))
    for data, target in train_loader:
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer2.zero_grad()
        output = F.softmax(student(data),dim=-1)
        loss = criterion1(output, target)


        optimizer2.zero_grad()
        loss.backward()
        optimizer2.step()

        prbar.update(1)
        prbar.set_postfix(loss=loss.item())
    test(student)
    prbar.close()

In [11]:
# train student net with the help of teacher

alpha = 0.8
T=3


student.train()
teacher.eval()
for epoch in range(1, 10):
    for data, target in train_loader:
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer2.zero_grad()
        student_output = student(data)
        student_output = F.softmax(student_output,dim=-1)
        loss1 = criterion1(student_output, target)
        with torch.no_grad():
            teacher_output = teacher(data)
        teacher_output = F.softmax(teacher_output/T,dim=-1)
        loss2 = criterion2(student_output,teacher_output)*T*T
        
        loss = loss1*(1-alpha) + loss2*alpha
        loss.backward()
        optimizer2.step()

    test(student)

Test :  Accuracy: 9649/10000 (96.49%)



KeyboardInterrupt: 