In [1]:
import torch                                        # root package
from torch.utils.data import Dataset, DataLoader    # dataset representation and loading
import torch.nn as nn                     # neural networks
import torch.nn.functional as F           # layers, activations and more
import torch.optim as optim 
from torchvision import datasets
from torchvision.transforms import ToTensor

torch.manual_seed(0)

<torch._C.Generator at 0x2210a6657b0>

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
train = datasets.MNIST(root='./data/',train=True, download=True, transform=ToTensor())
test = datasets.MNIST(root='./data/',train=False, download=True, transform=ToTensor())

In [4]:
bs = 64
train_loader = DataLoader(dataset=train, batch_size=bs, shuffle=True, num_workers=1, pin_memory=True)
test_loader = DataLoader(dataset=test, batch_size=bs, shuffle=True, num_workers=1, pin_memory=True)

In [5]:
class Net(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, d),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(d, d),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(d, 10)
        )
    def forward(self, x):
        return self.layers(x)

In [6]:
teacher_model = Net(1200)
teacher_model = teacher_model.to(device)
print(teacher_model)

Net(
  (layers): Sequential(
    (0): Linear(in_features=784, out_features=1200, bias=True)
    (1): Dropout(p=0.3, inplace=False)
    (2): ReLU()
    (3): Linear(in_features=1200, out_features=1200, bias=True)
    (4): Dropout(p=0.3, inplace=False)
    (5): ReLU()
    (6): Linear(in_features=1200, out_features=10, bias=True)
  )
)


In [7]:
optimizer = optim.Adam(params=teacher_model.parameters(), lr=0.001)
loss = torch.nn.CrossEntropyLoss()

In [8]:
def train(model, epochs, optimizer, loss):
    for epoch in range(epochs):
        model.train()
        total_loss = []
        accurate = 0
        total = 0
        for x,y in train_loader:
            x = x.to(device)
#             x = x.view(-1, 784)
            y = y.to(device)
            total += x.shape[0]
            y_hat = model(x.view(-1,784))
            _, pred_label = torch.max(y_hat.data, 1)
#             print(y.shape)
            accurate += torch.sum(pred_label==y)
            train_loss = loss(y_hat, y)
            total_loss.append(train_loss)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
        if epoch%10 == 0:
            print("Epoch : {}".format(epoch))
            print("Loss on train set : {} and Accuracy : {}".format((sum(total_loss)/len(train_loader)).item(), (accurate/total).item()))
            with torch.no_grad():
                model.eval()
                accurate = 0
                total = 0
                total_loss = []
                for x,y in test_loader:
                    x = x.to(device)
                    x = x.view(-1,784)                                               
                    y = y.to(device)
                    total += x.shape[0]
                    y_test = model(x)
                    _, pred_label = torch.max(y_test.data, 1)
#                     print(pred_label)
                    accurate += torch.sum(pred_label==y)
                    total_loss.append(loss(y_test,y))
                print("Loss on test set : {} and Accuracy : {}".format((sum(total_loss)/len(test_loader)).item(), (accurate/total).item()))
                print("-----------------------------------------------------------------------------------")

In [9]:
train(teacher_model,41, optimizer, loss)

Epoch : 0
Loss on train set : 0.22282832860946655 and Accuracy : 0.9310833215713501
Loss on test set : 0.11105051636695862 and Accuracy : 0.9652999639511108
-----------------------------------------------------------------------------------
Epoch : 10
Loss on train set : 0.03577291592955589 and Accuracy : 0.9890000224113464
Loss on test set : 0.08173023164272308 and Accuracy : 0.9824000000953674
-----------------------------------------------------------------------------------
Epoch : 20
Loss on train set : 0.03412889316678047 and Accuracy : 0.9917166829109192
Loss on test set : 0.10353590548038483 and Accuracy : 0.9802999496459961
-----------------------------------------------------------------------------------
Epoch : 30
Loss on train set : 0.023396972566843033 and Accuracy : 0.9942833185195923
Loss on test set : 0.10780889540910721 and Accuracy : 0.9861999750137329
-----------------------------------------------------------------------------------
Epoch : 40
Loss on train set : 0

In [10]:
softmax = nn.Softmax(dim=1)
mseLoss = nn.MSELoss()
def distillation_loss(yhat, y, temperature):
    smooth_smax_pred = softmax(yhat / temperature)
    smooth_target = softmax(y / temperature)
    #Original loss in paper
    # loss = 0.5 * torch.pow(smooth_smax_pred - smooth_target, 2)
    return mseLoss(smooth_smax_pred, smooth_target)

In [11]:
def train_student(teacher, model, epochs, temp, optimizer, loss):
    """
    temp: temperature for loss calculation
    teacher: teacher model
    epoch: total iterations over dataset
    model: the student model to be trained
    """
    for epoch in range(epochs):
        teacher.eval()
        model.train()
        total_loss = []
        accurate = 0
        total = 0
        for x,y in train_loader:
            x = x.to(device)
            y = y.to(device)
            target = teacher(x.view(-1,784)) #Get the target from the cumbersome model
            total += x.shape[0]
            y_hat = model(x.view(-1,784))
            _, pred_label = torch.max(y_hat.data, 1)
            accurate += torch.sum(pred_label==y)
            train_loss = distillation_loss(y_hat, target, temp) 
            total_loss.append(train_loss)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
        if epoch%5 == 0:
            print("Loss on train set : {} and Accuracy : {}".format((sum(total_loss)/len(train_loader)).item(), (accurate/total).item()))
            with torch.no_grad():
                model.eval()
                accurate = 0
                total = 0
                for x,y in test_loader:
                    x = x.to(device)
                    x = x.view(-1,784)                                               
                    y = y.to(device)
                    total += x.shape[0]
                    y_test = model(x)
                    _, pred_label = torch.max(y_test.data, 1)
                    accurate += torch.sum(pred_label==y)
                print("Accuracy on Test set : {}".format( (accurate/total).item()))

In [36]:
#training benchmark
benchmark_student = Net(100)
benchmark_student = benchmark_student.to(device)
optimizer_bench = optim.Adam(params=benchmark_student.parameters(), lr=0.001)
loss = torch.nn.CrossEntropyLoss()
train(benchmark_student, 11, optimizer_bench, loss)

Epoch : 0
Loss on train set : 0.4543033242225647 and Accuracy : 0.8642666935920715
Loss on test set : 0.17455440759658813 and Accuracy : 0.9465999603271484
-----------------------------------------------------------------------------------
Epoch : 10
Loss on train set : 0.09133455157279968 and Accuracy : 0.9708666801452637
Loss on test set : 0.08549182116985321 and Accuracy : 0.9753999710083008
-----------------------------------------------------------------------------------


In [40]:
student_model = Net(100)
student_model = student_model.to(device)
optimizer_student = optim.Adam(params=student_model.parameters(), lr=0.001)
loss = distillation_loss
train_student(teacher=teacher_model, model=student_model, epochs=11, temp=2.5, optimizer=optimizer_student, loss=loss)

Loss on train set : 0.02274789847433567 and Accuracy : 0.8457333445549011
Accuracy on Test set : 0.934499979019165
Loss on train set : 0.006044209934771061 and Accuracy : 0.9581166505813599
Accuracy on Test set : 0.9686999917030334
Loss on train set : 0.004588209558278322 and Accuracy : 0.9681333303451538
Accuracy on Test set : 0.9714999794960022


In [None]:
# 20 layer net can 