Importing libraries

In [0]:
# Import pytorch basic functions/classes
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Import torchvision functions/classes for MNIST import and data loaders
import torchvision
import torchvision.transforms as transforms

# Set device on which code is run
device = 'cuda'

Defining support functions

In [0]:
# Define support function used to convert label to one-hot encoded tensor
def convert_labels(labels):
    target = torch.zeros([len(labels), 10], dtype=torch.float32)
    for i, l in enumerate(labels):
      target[i][l] = 1.0
    return target

Define our network model (the hidden layers size is specified through the constructor)

In [0]:
# Define MLP model and its layers
class Model(nn.Module):

    def __init__(self, hidden_size=1200, dropout=0.0, hidden_dropout=0.0):
        super(Model, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.hidden1 = nn.Linear(784, hidden_size, bias=True)
        self.hidden1_dropout = nn.Dropout(hidden_dropout)
        self.hidden2 = nn.Linear(hidden_size, hidden_size, bias=True)
        self.hidden2_dropout = nn.Dropout(hidden_dropout)
        self.hidden3 = nn.Linear(hidden_size, 10, bias=True)

    def forward(self, x):

        x = self.dropout(x)
        x = F.relu(self.hidden1(x))
        x = self.hidden1_dropout(x)
        x = F.relu(self.hidden2(x))
        x = self.hidden2_dropout(x)
        x = self.hidden3(x)
        return x#, F.softmax(x)

Downloading MNIST dataset

In [27]:
# Define transform from PIL image to tensor and normalize to 1x768 pixels
train_transform = transforms.Compose([
  transforms.RandomAffine(0, (1/14, 1/14)),
  transforms.ToTensor(),
  transforms.Normalize((0.5,), (0.5,))
])

test_transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.5,), (0.5,))
])

# Set batch size for data loaders
batch_size = 128

# (Down)load training set
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# (Down)load test set
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

'''examples = enumerate(testloader)
batch_idx, (example_data, example_targets) = next(examples)

import matplotlib.pyplot as plt

fig = plt.figure()
for i in range(6):
  plt.subplot(2,3,i+1)
  plt.tight_layout()
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
  plt.title("Ground Truth: {}".format(example_targets[i]))
  plt.xticks([])
  plt.yticks([])
#fig'''

'examples = enumerate(testloader)\nbatch_idx, (example_data, example_targets) = next(examples)\n\nimport matplotlib.pyplot as plt\n\nfig = plt.figure()\nfor i in range(6):\n  plt.subplot(2,3,i+1)\n  plt.tight_layout()\n  plt.imshow(example_data[i][0], cmap=\'gray\', interpolation=\'none\')\n  plt.title("Ground Truth: {}".format(example_targets[i]))\n  plt.xticks([])\n  plt.yticks([])\n#fig'

Training the Deep Teacher Neural Network

In [26]:
# Setup model and move it to the GPU
net = Model(dropout=0.2, hidden_dropout=0.5)
net.to(device)

# Set up loss function and optimizer: 
#     using cross entropy loss because it's better for classification task

learning_rate = 0.01
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr= learning_rate, momentum=0.9)
#optimizer = torch.optim.Adam(net.parameters(), lr=0.0001, weight_decay=0.00001)

# Run over 100 epochs (1 epoch = visited all items in dataset)
for epoch in range(2000):

    running_loss = 0.0
    total = 0

    if(epoch%200 == 0 and epoch != 0):

      learning_rate = learning_rate - (0.001) # or maybe decrease by (learning_rate * 0.1)
      optimizer = optim.SGD(net.parameters(), lr= learning_rate, momentum=0.9)

    for i, data in enumerate(trainloader, 0):
        

        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = torch.flatten(inputs, start_dim=1).to(device)

        # This for not cross entropy
        #target = convert_labels(labels).to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        target = labels.to(device).long()
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        total += len(data)

        # print statistics
        running_loss += loss.item()
    # print every epoch
    print('[%d] loss: %.3f' % (epoch + 1, running_loss / total))

print('Finished Training')

# Save model after having finished training
PATH = './mnist_dropout_100_epoch.pth'
torch.save(net.state_dict(), PATH)

print('Saved Model')

[1] loss: 0.515
[2] loss: 0.265
[3] loss: 0.202
[4] loss: 0.174
[5] loss: 0.153
[6] loss: 0.140
[7] loss: 0.131
[8] loss: 0.122
[9] loss: 0.116
[10] loss: 0.110
[11] loss: 0.107
[12] loss: 0.103
[13] loss: 0.099
[14] loss: 0.094
[15] loss: 0.093
[16] loss: 0.091
[17] loss: 0.086
[18] loss: 0.087
[19] loss: 0.083
[20] loss: 0.080
[21] loss: 0.082
[22] loss: 0.078
[23] loss: 0.076
[24] loss: 0.076
[25] loss: 0.075
[26] loss: 0.072
[27] loss: 0.075
[28] loss: 0.070
[29] loss: 0.069
[30] loss: 0.067
[31] loss: 0.068
[32] loss: 0.068
[33] loss: 0.065
[34] loss: 0.065
[35] loss: 0.064
[36] loss: 0.067
[37] loss: 0.064
[38] loss: 0.062
[39] loss: 0.062
[40] loss: 0.062
[41] loss: 0.061
[42] loss: 0.062
[43] loss: 0.060
[44] loss: 0.060
[45] loss: 0.058
[46] loss: 0.058
[47] loss: 0.058
[48] loss: 0.057
[49] loss: 0.057
[50] loss: 0.055
[51] loss: 0.055
[52] loss: 0.056
[53] loss: 0.054
[54] loss: 0.056
[55] loss: 0.055
[56] loss: 0.053
[57] loss: 0.054
[58] loss: 0.053
[59] loss: 0.052
[60] l

Run Deep Teacher model on test set

In [28]:
# Instantiate model and load saved network parameters
net = Model().to(device)
net.load_state_dict(torch.load(PATH))

# Run model on test set and determine accuracy
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        inputs, labels = data
        inputs = torch.flatten(inputs, start_dim=1).to(device)
        target = convert_labels(labels).to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        _, target = torch.max(target.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

# Output model accuracy to user
print('Accuracy of the network on the 10000 test images: %d %% (%d wrong out of %d)' % (
    100 * correct / total, total - correct, total))

Accuracy of the network on the 10000 test images: 99 % (61 wrong out of 10000)


Train student model to mimic the teacher

In [29]:
import numpy as np
# Custom student loss: linear combination of 2 cross-entropy losses
#     The first one between student output and hard labels
#     The second one between student output and soft labels from teacher

def student_loss(outputs, labels, teacher_outputs, alpha, temperature):

    #loss = torch.mean((weight*(outputA - targetA)**2) + (1-weight)*(outputB-targetB)**2)

    #loss = weight*(F.cross_entropy(outputA, targetA)) + (1-weight)*(F.cross_entropy(outputB, targetB))

    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
                             F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
              F.cross_entropy(outputs, labels) * (1. - alpha)

    return KD_loss

# Setup student model and move it to the GPU
student_net = Model(hidden_size = 800)
student_net.to(device)

# Set up loss function and optimizer

optimizer = optim.SGD(student_net.parameters(), lr=0.001, momentum=0.9)
#optimizer = torch.optim.Adam(student_net.parameters(), lr=0.0001)

# Run over 100 epochs (1 epoch = visited all items in dataset)
for epoch in range(100):
    running_loss = 0.0
    total = 0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = torch.flatten(inputs, start_dim=1).to(device)
        target = labels.to(device).long() #convert_labels(labels).to(device)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # Set temperature and the weights for losses linear combination
        w = 0.3
        T = 20

        # Compute soft labels using deep teacher model previously trained
        outputs_teacher = net(inputs)
        #soft_labels = F.softmax(outputs_teacher/T, dim = 1)

        # Abomination to obtain hard_labels for custom cross entropy loss
        #teacher_hard_labels = torch.from_numpy(np.array([np.argmax(l.cpu().detach().numpy()) for l in soft_labels])).to(device).long()

        # Student forward + backward + optimize
        outputs_stud = student_net(inputs)
        #outputs_stud = F.softmax(output_stud/T, dim = 1)
        
        loss = student_loss(outputs_stud, target, outputs_teacher, w, T)
        loss.backward()
        optimizer.step()

        total += len(data)

        # print statistics
        running_loss += loss.item()
    # print every epoch
    print('[%d] loss: %.3f' % (epoch + 1, running_loss / total))

print('Finished Training')

# Save model after having finished training
STUD_PATH = './mnist_student_100_epoch.pth'
torch.save(student_net.state_dict(), STUD_PATH)

print('Saved Model')



[1] loss: 1.408
[2] loss: 0.948
[3] loss: 0.811
[4] loss: 0.748
[5] loss: 0.694
[6] loss: 0.643
[7] loss: 0.590
[8] loss: 0.530
[9] loss: 0.475
[10] loss: 0.420
[11] loss: 0.375
[12] loss: 0.338
[13] loss: 0.308
[14] loss: 0.285
[15] loss: 0.265
[16] loss: 0.248
[17] loss: 0.234
[18] loss: 0.220
[19] loss: 0.209
[20] loss: 0.197
[21] loss: 0.187
[22] loss: 0.180
[23] loss: 0.172
[24] loss: 0.166
[25] loss: 0.159
[26] loss: 0.153
[27] loss: 0.149
[28] loss: 0.143
[29] loss: 0.139
[30] loss: 0.135
[31] loss: 0.130
[32] loss: 0.127
[33] loss: 0.124
[34] loss: 0.120
[35] loss: 0.117
[36] loss: 0.115
[37] loss: 0.112
[38] loss: 0.111
[39] loss: 0.107
[40] loss: 0.105
[41] loss: 0.103
[42] loss: 0.102
[43] loss: 0.099
[44] loss: 0.096
[45] loss: 0.095
[46] loss: 0.094
[47] loss: 0.091
[48] loss: 0.089
[49] loss: 0.089
[50] loss: 0.089
[51] loss: 0.087
[52] loss: 0.085
[53] loss: 0.085
[54] loss: 0.082
[55] loss: 0.081
[56] loss: 0.080
[57] loss: 0.080
[58] loss: 0.079
[59] loss: 0.077
[60] l

Running student model on test set

In [30]:
stud_net = Model(hidden_size = 800).to(device)
stud_net.load_state_dict(torch.load(STUD_PATH))

# Run model on test set and determine accuracy
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        inputs, labels = data
        inputs = torch.flatten(inputs, start_dim=1).to(device)
        target = convert_labels(labels).to(device)
        outputs = stud_net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        _, target = torch.max(target.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

# Output model accuracy to user
print('Accuracy of the network on the 10000 test images: %d %% (%d wrong out of %d)' % (
    100 * correct / total, total - correct, total))

Accuracy of the network on the 10000 test images: 98 % (124 wrong out of 10000)
