# Knowledge Distillation with Hint Training: Experiment 1

Hint and Guided layer in beginning, middle and end.


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Check if GPU is available, and if not, use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 30284428.27it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
#Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

DeepNN(Not actually Deep): Shallow Neural Network but very wide with 3 Convolutional Layers and over 3 million parameters.
LightNN: much deeper network with 9 convolutional layers but much thinner with over 140,000 parameters.

In [4]:
# Wide Shallow neural network class to be used as teacher:
# 3 different instantiations where the hint layer parameters are returned, either first, second, or third convolutional layer

class DeepNN(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features2 = nn.Sequential(
            nn.Conv2d(128, 356, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features3 = nn.Sequential(
            nn.Conv2d(356, 356, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(3204, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features1(x)
        conv_feature_map = x
        x = self.features2(x)
        x = self.features3(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, conv_feature_map

class DeepNN2(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN2, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features2 = nn.Sequential(
            nn.Conv2d(128, 356, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features3 = nn.Sequential(
            nn.Conv2d(356, 356, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(3204, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )


    def forward(self, x):
        x = self.features1(x)
        x = self.features2(x)
        conv_feature_map = x
        x = self.features3(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, conv_feature_map

class DeepNN3(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN3, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features2 = nn.Sequential(
            nn.Conv2d(128, 356, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features3 = nn.Sequential(
            nn.Conv2d(356, 356, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(3204, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features1(x)
        x = self.features2(x)
        x = self.features3(x)
        conv_feature_map = x
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, conv_feature_map



# Lightweight neural network class to be used as student, much deeper but thinner network:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [5]:
#checking shape of data
for x, y in train_loader:
    print(x.shape)
    break

#checking shape of data
for x, y in test_loader:
    print(x.shape)
    break

torch.Size([128, 3, 32, 32])
torch.Size([128, 3, 32, 32])


In [6]:
def train(model, train_loader, epochs, learning_rate, device, is_tuple):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            if is_tuple == True:
              outputs, map = model(inputs)
            if is_tuple == False:
              outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device, is_tuple):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            if is_tuple == True:
              outputs, map = model(inputs)
            if is_tuple == False:
              outputs = model(inputs)

            #outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# Train Teacher Network

In [7]:
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=15, learning_rate=0.001, device=device, is_tuple=True)
test_accuracy_deep = test(nn_deep, test_loader, device, is_tuple=True)

# Instantiate the lightweight network:
torch.manual_seed(42)

Epoch 1/15, Loss: 1.4520263959989523
Epoch 2/15, Loss: 0.9534348714382143
Epoch 3/15, Loss: 0.7547467898987138
Epoch 4/15, Loss: 0.6344018812527132
Epoch 5/15, Loss: 0.5438377241344403
Epoch 6/15, Loss: 0.4661091280257915
Epoch 7/15, Loss: 0.40139796339032596
Epoch 8/15, Loss: 0.3346936466443874
Epoch 9/15, Loss: 0.2862583933888799
Epoch 10/15, Loss: 0.24660988885652074
Epoch 11/15, Loss: 0.21071514782622036
Epoch 12/15, Loss: 0.17448605391223107
Epoch 13/15, Loss: 0.16270225264532182
Epoch 14/15, Loss: 0.13989289174489963
Epoch 15/15, Loss: 0.12728366112846243
Test Accuracy: 77.22%


<torch._C.Generator at 0x7e97b4bebcd0>

# Train Student Network

In [8]:
nn_light = LightNN(num_classes=10).to(device)
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 3,201,122
LightNN parameters: 140,378


In [9]:
train(nn_light, train_loader, epochs=15, learning_rate=0.001, device=device, is_tuple=False)
test_accuracy_light_ce = test(nn_light, test_loader, device, is_tuple=False)

Epoch 1/15, Loss: 1.8517431895751173
Epoch 2/15, Loss: 1.4583035832475824
Epoch 3/15, Loss: 1.234693345816239
Epoch 4/15, Loss: 1.0854455113715833
Epoch 5/15, Loss: 0.977134457329655
Epoch 6/15, Loss: 0.8846762924243117
Epoch 7/15, Loss: 0.8166822044898177
Epoch 8/15, Loss: 0.751284743666344
Epoch 9/15, Loss: 0.7037253174025689
Epoch 10/15, Loss: 0.6664565334387142
Epoch 11/15, Loss: 0.6220652912278919
Epoch 12/15, Loss: 0.5970804447408222
Epoch 13/15, Loss: 0.5594652122091455
Epoch 14/15, Loss: 0.5318591693020842
Epoch 15/15, Loss: 0.4972910936683645
Test Accuracy: 74.77%


In [10]:
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)

# Train Student Network with Knowledge Distillation and Cross Entropy from Hinton et. al.

In [11]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits, _ = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=15, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device, is_tuple=False)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

Epoch 1/15, Loss: 3.346776022020813
Epoch 2/15, Loss: 2.7153116559128625
Epoch 3/15, Loss: 2.349225106446639
Epoch 4/15, Loss: 2.117342648298844
Epoch 5/15, Loss: 1.9277966824333992
Epoch 6/15, Loss: 1.8013699536433305
Epoch 7/15, Loss: 1.6828551268028786
Epoch 8/15, Loss: 1.5906910091409903
Epoch 9/15, Loss: 1.5117233006850532
Epoch 10/15, Loss: 1.458070687015953
Epoch 11/15, Loss: 1.3901310415219164
Epoch 12/15, Loss: 1.3491188313650049
Epoch 13/15, Loss: 1.2871591282622588
Epoch 14/15, Loss: 1.2514180635552272
Epoch 15/15, Loss: 1.2079072280613052
Test Accuracy: 73.89%
Teacher accuracy: 77.22%
Student accuracy without teacher: 74.77%
Student accuracy with CE + KD: 73.89%


In [12]:
sample_input = torch.randn(128, 3, 32, 32).to(device)

In [13]:
convolutional_fe_output_student = nn_light.features(sample_input)
convolutional_fe_output_teacher = nn_deep.features1(sample_input)
# Print their shapes
print("Student's feature extractor output shape: ", convolutional_fe_output_student.shape)
print("Teacher's feature extractor output shape: ", convolutional_fe_output_teacher.shape)

Student's feature extractor output shape:  torch.Size([128, 48, 3, 3])
Teacher's feature extractor output shape:  torch.Size([128, 128, 16, 16])


# Modified Networks: Teacher stays the same. Student requires convolutional regressor on top of the Guided Layer. This increases the number of parameters to over 294,000 parameters.

In [14]:
class ModifiedLightNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNRegressor, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.features2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features3= nn.Sequential(
            nn.Conv2d(32, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )

        self.regressor = nn.Sequential(
            nn.Conv2d(16, 128, kernel_size=3, padding=1)
        )

        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features1(x)
        regressor_output = self.regressor(x)
        x = self.features2(x)
        x = self.features3(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, regressor_output



class ModifiedLightNNRegressor2(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNRegressor2, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.features2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features3= nn.Sequential(
            nn.Conv2d(32, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )

        self.regressor = nn.Sequential(
            nn.Conv2d(32, 356, kernel_size=3, padding=1)
        )

        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features1(x)
        x = self.features2(x)
        regressor_output = self.regressor(x)
        x = self.features3(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, regressor_output




class ModifiedLightNNRegressor3(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNRegressor3, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.features2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features3= nn.Sequential(
            nn.Conv2d(32, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )

        self.regressor = nn.Sequential(
            nn.Conv2d(48, 356, kernel_size=3, padding=1)
        )

        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features1(x)
        x = self.features2(x)
        x = self.features3(x)
        regressor_output = self.regressor(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, regressor_output

In [15]:
mod_nn_light = ModifiedLightNNRegressor(num_classes=10).to(device)
convolutional_fe_output_student = mod_nn_light.features1(sample_input)
convolutional_fe_output_teacher = nn_deep.features1(sample_input)
# Print their shapes
print("Student's feature extractor output shape: ", convolutional_fe_output_student.shape)
print("Teacher's feature extractor output shape: ", convolutional_fe_output_teacher.shape)

Student's feature extractor output shape:  torch.Size([128, 16, 16, 16])
Teacher's feature extractor output shape:  torch.Size([128, 128, 16, 16])


In [16]:
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in mod_nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 3,201,122
LightNN parameters: 158,938


In [17]:
def test_multiple_outputs(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs, _ = model(inputs) # Disregard the second tensor of the tuple
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [18]:
import torch.nn as nn
import torch

class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, inputs, targets):
        loss = (torch.norm(inputs - targets)**2)*0.5
        #loss = -1 * (targets * torch.log(inputs) + (1 - targets) * torch.log(1 - inputs))
        return loss.mean()

# Start 2-Stage Training:
1. MSE Loss between Hint and Guided Layers minimized for 10 epochs.
2. Knowledge distillation run where CE loss in combination with KD loss function are weighted and minimized.  

In [19]:
def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    #mse_loss = CustomLoss()
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Again ignore teacher logits
            with torch.no_grad():
                _, teacher_feature_map = teacher(inputs)

            # Forward pass with the student model
            student_logits, regressor_feature_map = student(inputs)

            # Calculate the loss
            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Only minimize mse loss of teacher and regressor feature maps
            loss = label_loss#feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Notice how our test function remains the same here with the one we used in our previous case. We only care about the actual outputs because we measure accuracy.

# Initialize a ModifiedLightNNRegressor
torch.manual_seed(42)
modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=10).to(device)
modified_nn_light_reg2 = ModifiedLightNNRegressor2(num_classes=10).to(device)
modified_nn_light_reg3 = ModifiedLightNNRegressor3(num_classes=10).to(device)

# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep_reg = DeepNN(num_classes=10).to(device)
modified_nn_deep_reg.load_state_dict(nn_deep.state_dict())
modified_nn_deep_reg2 = DeepNN2(num_classes=10).to(device)
modified_nn_deep_reg2.load_state_dict(nn_deep.state_dict())
modified_nn_deep_reg3 = DeepNN3(num_classes=10).to(device)
modified_nn_deep_reg3.load_state_dict(nn_deep.state_dict())



#train(modified_nn_deep_reg, train_loader, epochs=10, learning_rate=0.001, device=device)
#test_accuracy_deep = test(modified_nn_deep_reg, test_loader, device)


#Freeze params after convolutional layers, train up until classifier
for param in modified_nn_light_reg.classifier.parameters():
  param.requires_grad = False
for param in modified_nn_light_reg2.classifier.parameters():
  param.requires_grad = False
for param in modified_nn_light_reg3.classifier.parameters():
  param.requires_grad = False

# Train and test once again
train_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
train_mse_loss(teacher=modified_nn_deep_reg2, student=modified_nn_light_reg2, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
train_mse_loss(teacher=modified_nn_deep_reg3, student=modified_nn_light_reg3, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)

Epoch 1/10, Loss: 1.892559883234751
Epoch 2/10, Loss: 1.5386585500234229
Epoch 3/10, Loss: 1.3878256393515545
Epoch 4/10, Loss: 1.281129860817014
Epoch 5/10, Loss: 1.2000640391388817
Epoch 6/10, Loss: 1.128105265710055
Epoch 7/10, Loss: 1.063829592274278
Epoch 8/10, Loss: 1.0044389043927497
Epoch 9/10, Loss: 0.9495981183198406
Epoch 10/10, Loss: 0.9041930549894758
Epoch 1/10, Loss: 1.94764286935177
Epoch 2/10, Loss: 1.6376763310883662
Epoch 3/10, Loss: 1.4775052744409312
Epoch 4/10, Loss: 1.3662629185430228
Epoch 5/10, Loss: 1.2686319573760947
Epoch 6/10, Loss: 1.1931388086979957
Epoch 7/10, Loss: 1.122466505945796
Epoch 8/10, Loss: 1.0584449883921982
Epoch 9/10, Loss: 1.0070317500387616
Epoch 10/10, Loss: 0.969907849798422
Epoch 1/10, Loss: 1.9602105946796935
Epoch 2/10, Loss: 1.6109210810697903
Epoch 3/10, Loss: 1.4027366940017856
Epoch 4/10, Loss: 1.264879050157259
Epoch 5/10, Loss: 1.1622212442290751
Epoch 6/10, Loss: 1.0828604310979624
Epoch 7/10, Loss: 1.0339541429144037
Epoch 8/

In [20]:
#unfreeze parameters and train finally with knowledge distillation for stage 2

for param in modified_nn_light_reg.classifier.parameters():
  param.requires_grad = True
for param in modified_nn_light_reg2.classifier.parameters():
  param.requires_grad = True
for param in modified_nn_light_reg3.classifier.parameters():
  param.requires_grad = True

def train_knowledge_distillation1(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits, _ = teacher(inputs)

            # Forward pass with the student model
            student_logits, _  = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")



In [21]:
train_knowledge_distillation1(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
train_knowledge_distillation1(teacher=modified_nn_deep_reg2, student=modified_nn_light_reg2, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
train_knowledge_distillation1(teacher=modified_nn_deep_reg3, student=modified_nn_light_reg3, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
print(f"Teacher Accuracy:")
test_accuracy_deep = test(modified_nn_deep_reg, test_loader, device, is_tuple= True)
print(f"Final Student Accuracy (Hint = 1):")
test_accuracy_light_ce_and_mse_loss1 = test_multiple_outputs(modified_nn_light_reg, test_loader, device)
print(f"Final Student Accuracy (Hint = 2):")
test_accuracy_light_ce_and_mse_loss2 = test_multiple_outputs(modified_nn_light_reg2, test_loader, device)
print(f"Final Student Accuracy (Hint = 3):")
test_accuracy_light_ce_and_mse_loss3 = test_multiple_outputs(modified_nn_light_reg3, test_loader, device)

Epoch 1/10, Loss: 1.6636662882612185
Epoch 2/10, Loss: 1.5187497803622194
Epoch 3/10, Loss: 1.4281714770495129
Epoch 4/10, Loss: 1.351959081866857
Epoch 5/10, Loss: 1.2860061791546815
Epoch 6/10, Loss: 1.2245834632907682
Epoch 7/10, Loss: 1.1704592335864406
Epoch 8/10, Loss: 1.135043044376861
Epoch 9/10, Loss: 1.0854161517394474
Epoch 10/10, Loss: 1.068695957246034
Epoch 1/10, Loss: 1.7640529642324618
Epoch 2/10, Loss: 1.626938057067754
Epoch 3/10, Loss: 1.5341035414229878
Epoch 4/10, Loss: 1.4719279841388888
Epoch 5/10, Loss: 1.4043076294462393
Epoch 6/10, Loss: 1.3539477021188078
Epoch 7/10, Loss: 1.3061789658368397
Epoch 8/10, Loss: 1.2564769089984162
Epoch 9/10, Loss: 1.2021003133805512
Epoch 10/10, Loss: 1.1746260920144103
Epoch 1/10, Loss: 1.6988924763086812
Epoch 2/10, Loss: 1.5742417817835308
Epoch 3/10, Loss: 1.4906383219277461
Epoch 4/10, Loss: 1.4271146677948934
Epoch 5/10, Loss: 1.3723106914773926
Epoch 6/10, Loss: 1.312894817203512
Epoch 7/10, Loss: 1.2773472601190552
Epoc

# Final Model Accuracies
As we can see, the student's accuracy is much improved with the 2-stage training following similar methodology in Romero et. al.

In [22]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
print(f"Student accuracy (hint=1 and guided = Beginning) and Stage 1(RegressorMSE) + Stage 2(CE + KD): {test_accuracy_light_ce_and_mse_loss1:.2f}%")
print(f"Student accuracy (hint=2 and guided = Middle) and Stage 1(RegressorMSE) + Stage 2(CE + KD): {test_accuracy_light_ce_and_mse_loss2:.2f}%")
print(f"Student accuracy (hint=3 and guided = End) and Stage 1(RegressorMSE) + Stage 2(CE + KD): {test_accuracy_light_ce_and_mse_loss3:.2f}%")

Teacher accuracy: 77.22%
Student accuracy without teacher: 74.77%
Student accuracy with CE + KD: 73.89%
Student accuracy (hint=1 and guided = Beginning) and Stage 1(RegressorMSE) + Stage 2(CE + KD): 74.45%
Student accuracy (hint=2 and guided = Middle) and Stage 1(RegressorMSE) + Stage 2(CE + KD): 72.47%
Student accuracy (hint=3 and guided = End) and Stage 1(RegressorMSE) + Stage 2(CE + KD): 74.31%


In [23]:
nn_light = LightNN(num_classes=100).to(device)
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")
total_params_light = "{:,}".format(sum(p.numel() for p in modified_nn_light_reg.parameters()))
print(f"LightNN parameters: {total_params_light}")
total_params_light = "{:,}".format(sum(p.numel() for p in modified_nn_light_reg2.parameters()))
print(f"LightNN parameters: {total_params_light}")
total_params_light = "{:,}".format(sum(p.numel() for p in modified_nn_light_reg3.parameters()))
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 3,201,122
LightNN parameters: 151,988
LightNN parameters: 158,938
LightNN parameters: 243,262
LightNN parameters: 294,526
