# Knowledge Distillation with Hint Training: Experiment 3

Multiple hint layers


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Check if GPU is available, and if not, use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 71014285.28it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
#Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

DeepNN(Not actually Deep): Shallow Neural Network but very wide with 3 Convolutional Layers and over 3 million parameters.
LightNN: much deeper network with 9 convolutional layers but much thinner with over 140,000 parameters.

In [5]:
# Wide Shallow neural network class to be used as teacher:
# 3 different instantiations where the hint layer parameters are returned, either first, second, or third convolutional layer

class DeepNN(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features2 = nn.Sequential(
            nn.Conv2d(128, 356, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features3 = nn.Sequential(
            nn.Conv2d(356, 356, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(3204, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features1(x)
        conv_feature_map = x
        x = self.features2(x)
        conv_feature_map2 = x
        x = self.features3(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, conv_feature_map, conv_feature_map2

class DeepNN2(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN2, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features2 = nn.Sequential(
            nn.Conv2d(128, 356, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features3 = nn.Sequential(
            nn.Conv2d(356, 356, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(3204, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )


    def forward(self, x):
        x = self.features1(x)
        x = self.features2(x)
        conv_feature_map = x
        x = self.features3(x)
        conv_feature_map2 = x
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, conv_feature_map, conv_feature_map2

class DeepNN3(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN3, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features2 = nn.Sequential(
            nn.Conv2d(128, 356, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features3 = nn.Sequential(
            nn.Conv2d(356, 356, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(3204, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features1(x)
        conv_feature_map = x
        x = self.features2(x)
        x = self.features3(x)
        conv_feature_map2 = x
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, conv_feature_map, conv_feature_map2



# Lightweight neural network class to be used as student, much deeper but thinner network:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        conv_feature_map = x
        conv_feature_map2 = x
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, conv_feature_map, conv_feature_map2

In [6]:
#checking shape of data
for x, y in train_loader:
    print(x.shape)
    break

#checking shape of data
for x, y in test_loader:
    print(x.shape)
    break

torch.Size([128, 3, 32, 32])
torch.Size([128, 3, 32, 32])


In [9]:
def train(model, train_loader, epochs, learning_rate, device, is_tuple):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            if is_tuple == True:
              outputs, map, map1 = model(inputs)
            if is_tuple == False:
              outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device, is_tuple):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            if is_tuple == True:
              outputs, map, _ = model(inputs)
            if is_tuple == False:
              outputs = model(inputs)

            #outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# Train Teacher Network

In [14]:
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=15, learning_rate=0.001, device=device, is_tuple=True)

# Instantiate the lightweight network:
torch.manual_seed(42)

Epoch 1/15, Loss: 1.4520263959989523
Epoch 2/15, Loss: 0.9534348714382143
Epoch 3/15, Loss: 0.7547467898987138
Epoch 4/15, Loss: 0.6344018812527132
Epoch 5/15, Loss: 0.5438377241344403
Epoch 6/15, Loss: 0.4661091280257915
Epoch 7/15, Loss: 0.40139796339032596
Epoch 8/15, Loss: 0.3346936466443874
Epoch 9/15, Loss: 0.2862583933888799
Epoch 10/15, Loss: 0.24660988885652074
Epoch 11/15, Loss: 0.21071514782622036
Epoch 12/15, Loss: 0.17448605391223107
Epoch 13/15, Loss: 0.16270225264532182
Epoch 14/15, Loss: 0.13989289174489963
Epoch 15/15, Loss: 0.12728366112846243


<torch._C.Generator at 0x7a98101d7e30>

In [15]:
test_accuracy_deep = test(nn_deep, test_loader, device, is_tuple=True)

Test Accuracy: 77.22%


# Train Student Network

In [16]:
nn_light = LightNN(num_classes=10).to(device)
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 3,201,122
LightNN parameters: 140,378


In [17]:
train(nn_light, train_loader, epochs=15, learning_rate=0.001, device=device, is_tuple=True)
test_accuracy_light_ce = test(nn_light, test_loader, device, is_tuple=True)

Epoch 1/15, Loss: 1.8485963030544388
Epoch 2/15, Loss: 1.4587232517769269
Epoch 3/15, Loss: 1.2620959483144227
Epoch 4/15, Loss: 1.140239791187179
Epoch 5/15, Loss: 1.0325023027332239
Epoch 6/15, Loss: 0.9571028027083258
Epoch 7/15, Loss: 0.8828101238936109
Epoch 8/15, Loss: 0.8320294876232781
Epoch 9/15, Loss: 0.7796697020530701
Epoch 10/15, Loss: 0.7351299082988973
Epoch 11/15, Loss: 0.697329757692259
Epoch 12/15, Loss: 0.6668614166624406
Epoch 13/15, Loss: 0.6359554065767762
Epoch 14/15, Loss: 0.6117550333595032
Epoch 15/15, Loss: 0.5793168609556945
Test Accuracy: 72.32%


In [22]:
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)

# Train Student Network with Knowledge Distillation and Cross Entropy from Hinton et. al.

In [23]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits, _, _ = teacher(inputs)

            # Forward pass with the student model
            student_logits, _, _ = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=15, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)

Epoch 1/15, Loss: 3.3155798631555893
Epoch 2/15, Loss: 2.658221813723864
Epoch 3/15, Loss: 2.303970820763532
Epoch 4/15, Loss: 2.04968343915232
Epoch 5/15, Loss: 1.856287304092856
Epoch 6/15, Loss: 1.728067234349068
Epoch 7/15, Loss: 1.6185380202120223
Epoch 8/15, Loss: 1.527414279215781
Epoch 9/15, Loss: 1.4510936541935366
Epoch 10/15, Loss: 1.3917772651023572
Epoch 11/15, Loss: 1.3282699612400415
Epoch 12/15, Loss: 1.2720586558437104
Epoch 13/15, Loss: 1.2197882305935521
Epoch 14/15, Loss: 1.186876413157529
Epoch 15/15, Loss: 1.1392659639458522


NameError: ignored

In [24]:
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device, is_tuple=True)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

Test Accuracy: 74.85%
Teacher accuracy: 77.22%
Student accuracy without teacher: 72.32%
Student accuracy with CE + KD: 74.85%


In [25]:
sample_input = torch.randn(128, 3, 32, 32).to(device)

In [26]:
convolutional_fe_output_student = nn_light.features(sample_input)
convolutional_fe_output_teacher = nn_deep.features1(sample_input)
# Print their shapes
print("Student's feature extractor output shape: ", convolutional_fe_output_student.shape)
print("Teacher's feature extractor output shape: ", convolutional_fe_output_teacher.shape)

Student's feature extractor output shape:  torch.Size([128, 48, 3, 3])
Teacher's feature extractor output shape:  torch.Size([128, 128, 16, 16])


# Modified Networks: Teacher stays the same. Student requires convolutional regressor on top of the Guided Layer. This increases the number of parameters to over 294,000 parameters.

In [27]:
class ModifiedLightNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNRegressor, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.features2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features3= nn.Sequential(
            nn.Conv2d(32, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )

        self.regressor = nn.Sequential(
            nn.Conv2d(16, 128, kernel_size=3, padding=1)
        )

        self.regressor1 = nn.Sequential(
            nn.Conv2d(32, 356, kernel_size=3, padding=1)
        )

        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features1(x)
        regressor_output = self.regressor(x)
        x = self.features2(x)
        regressor_output1 = self.regressor1(x)
        x = self.features3(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, regressor_output, regressor_output1



class ModifiedLightNNRegressor2(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNRegressor2, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.features2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features3= nn.Sequential(
            nn.Conv2d(32, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )

        self.regressor = nn.Sequential(
            nn.Conv2d(32, 356, kernel_size=3, padding=1)
        )

        self.regressor1 = nn.Sequential(
            nn.Conv2d(48, 356, kernel_size=3, padding=1)
        )

        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features1(x)
        x = self.features2(x)
        regressor_output = self.regressor(x)
        x = self.features3(x)
        regressor_output1 = self.regressor1(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, regressor_output, regressor_output1




class ModifiedLightNNRegressor3(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNRegressor3, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.features2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.features3= nn.Sequential(
            nn.Conv2d(32, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
        )

        self.regressor = nn.Sequential(
            nn.Conv2d(16, 128, kernel_size=3, padding=1)
        )

        self.regressor1 = nn.Sequential(
            nn.Conv2d(48, 356, kernel_size=3, padding=1)
        )

        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features1(x)
        regressor_output = self.regressor(x)
        x = self.features2(x)
        x = self.features3(x)
        regressor_output1 = self.regressor1(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, regressor_output, regressor_output1

In [28]:
mod_nn_light = ModifiedLightNNRegressor(num_classes=10).to(device)
convolutional_fe_output_student = mod_nn_light.features1(sample_input)
convolutional_fe_output_teacher = nn_deep.features1(sample_input)
# Print their shapes
print("Student's feature extractor output shape: ", convolutional_fe_output_student.shape)
print("Teacher's feature extractor output shape: ", convolutional_fe_output_teacher.shape)

Student's feature extractor output shape:  torch.Size([128, 16, 16, 16])
Teacher's feature extractor output shape:  torch.Size([128, 128, 16, 16])


In [29]:
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in mod_nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 3,201,122
LightNN parameters: 261,822


In [36]:
def test_multiple_outputs(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs, _, _ = model(inputs) # Disregard the second tensor of the tuple
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [None]:
import torch.nn as nn
import torch

class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, inputs, targets):
        loss = (torch.norm(inputs - targets)**2)*0.5
        #loss = -1 * (targets * torch.log(inputs) + (1 - targets) * torch.log(1 - inputs))
        return loss.mean()

# Start 2-Stage Training:
1. MSE Loss between Hint and Guided Layers minimized for 10 epochs.
2. Knowledge distillation run where CE loss in combination with KD loss function are weighted and minimized.  

In [31]:
def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    #mse_loss = CustomLoss()
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Again ignore teacher logits
            with torch.no_grad():
                _, teacher_feature_map, teacher_feature_map1 = teacher(inputs)

            # Forward pass with the student model
            student_logits, regressor_feature_map, regressor_feature_map1 = student(inputs)

            # Calculate the loss
            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)
            hidden_rep_loss1 = mse_loss(regressor_feature_map1, teacher_feature_map1)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Only minimize mse loss of teacher and regressor feature maps
            loss = (feature_map_weight*.5) * hidden_rep_loss + (feature_map_weight*.5) * hidden_rep_loss1 + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Notice how our test function remains the same here with the one we used in our previous case. We only care about the actual outputs because we measure accuracy.

# Initialize a ModifiedLightNNRegressor
torch.manual_seed(42)
modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=10).to(device)
modified_nn_light_reg2 = ModifiedLightNNRegressor2(num_classes=10).to(device)
modified_nn_light_reg3 = ModifiedLightNNRegressor3(num_classes=10).to(device)

# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep_reg = DeepNN(num_classes=10).to(device)
modified_nn_deep_reg.load_state_dict(nn_deep.state_dict())
modified_nn_deep_reg2 = DeepNN2(num_classes=10).to(device)
modified_nn_deep_reg2.load_state_dict(nn_deep.state_dict())
modified_nn_deep_reg3 = DeepNN3(num_classes=10).to(device)
modified_nn_deep_reg3.load_state_dict(nn_deep.state_dict())



#train(modified_nn_deep_reg, train_loader, epochs=10, learning_rate=0.001, device=device)
#test_accuracy_deep = test(modified_nn_deep_reg, test_loader, device)


#Freeze params after convolutional layers, train up until classifier
for param in modified_nn_light_reg.classifier.parameters():
  param.requires_grad = False
for param in modified_nn_light_reg2.classifier.parameters():
  param.requires_grad = False
for param in modified_nn_light_reg3.classifier.parameters():
  param.requires_grad = False

# Train and test once again
train_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
train_mse_loss(teacher=modified_nn_deep_reg2, student=modified_nn_light_reg2, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
train_mse_loss(teacher=modified_nn_deep_reg3, student=modified_nn_light_reg3, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)

Epoch 1/10, Loss: 1.4984829212393602
Epoch 2/10, Loss: 1.249401466468411
Epoch 3/10, Loss: 1.1174374798984479
Epoch 4/10, Loss: 1.0247597636469186
Epoch 5/10, Loss: 0.955969703471874
Epoch 6/10, Loss: 0.8872274105506175
Epoch 7/10, Loss: 0.8298273592653787
Epoch 8/10, Loss: 0.7902315162941623
Epoch 9/10, Loss: 0.7494779393800994
Epoch 10/10, Loss: 0.7178159351544002
Epoch 1/10, Loss: 1.5687385715182176
Epoch 2/10, Loss: 1.3227761654597718
Epoch 3/10, Loss: 1.1910338270694703
Epoch 4/10, Loss: 1.0946518448002809
Epoch 5/10, Loss: 1.0182523457595454
Epoch 6/10, Loss: 0.961217195329154
Epoch 7/10, Loss: 0.9122144968613334
Epoch 8/10, Loss: 0.8700367410469543
Epoch 9/10, Loss: 0.831103858892875
Epoch 10/10, Loss: 0.7895143320188498
Epoch 1/10, Loss: 1.5489999019276455
Epoch 2/10, Loss: 1.268882025233315
Epoch 3/10, Loss: 1.137671595034392
Epoch 4/10, Loss: 1.0455766849200745
Epoch 5/10, Loss: 0.9690823065655311
Epoch 6/10, Loss: 0.9038617126167278
Epoch 7/10, Loss: 0.8537990528604259
Epoch

In [34]:
#unfreeze parameters and train finally with knowledge distillation for stage 2

for param in modified_nn_light_reg.classifier.parameters():
  param.requires_grad = True
for param in modified_nn_light_reg2.classifier.parameters():
  param.requires_grad = True
for param in modified_nn_light_reg3.classifier.parameters():
  param.requires_grad = True

def train_knowledge_distillation1(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits, _, _ = teacher(inputs)

            # Forward pass with the student model
            student_logits, _, _  = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")



In [35]:
train_knowledge_distillation1(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
train_knowledge_distillation1(teacher=modified_nn_deep_reg2, student=modified_nn_light_reg2, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
train_knowledge_distillation1(teacher=modified_nn_deep_reg3, student=modified_nn_light_reg3, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)

Epoch 1/10, Loss: 1.7286391529585698
Epoch 2/10, Loss: 1.5791609884832827
Epoch 3/10, Loss: 1.4970128106339204
Epoch 4/10, Loss: 1.4377644885226588
Epoch 5/10, Loss: 1.3722594235559253
Epoch 6/10, Loss: 1.3299927423372293
Epoch 7/10, Loss: 1.2806868484562925
Epoch 8/10, Loss: 1.2346410397678385
Epoch 9/10, Loss: 1.2128590767645775
Epoch 10/10, Loss: 1.1707141903965064
Epoch 1/10, Loss: 1.757028256840718
Epoch 2/10, Loss: 1.6169207227199585
Epoch 3/10, Loss: 1.5284026120324878
Epoch 4/10, Loss: 1.4749924974978124
Epoch 5/10, Loss: 1.404397677277665
Epoch 6/10, Loss: 1.3571709455431575
Epoch 7/10, Loss: 1.3039591742293608
Epoch 8/10, Loss: 1.2618243357409602
Epoch 9/10, Loss: 1.2214463194617835
Epoch 10/10, Loss: 1.1825790301613186
Epoch 1/10, Loss: 1.7280333490322923
Epoch 2/10, Loss: 1.5871878648962816
Epoch 3/10, Loss: 1.5129780281535194
Epoch 4/10, Loss: 1.4528244907593788
Epoch 5/10, Loss: 1.4099713743800093
Epoch 6/10, Loss: 1.3571288241144945
Epoch 7/10, Loss: 1.3109316275552716
E

ValueError: ignored

In [37]:
print(f"Teacher Accuracy:")
test_accuracy_deep = test(modified_nn_deep_reg, test_loader, device, is_tuple= True)
print(f"Final Student Accuracy (Hint = 1,2):")
test_accuracy_light_ce_and_mse_loss1 = test_multiple_outputs(modified_nn_light_reg, test_loader, device)
print(f"Final Student Accuracy (Hint = 2,3):")
test_accuracy_light_ce_and_mse_loss2 = test_multiple_outputs(modified_nn_light_reg2, test_loader, device)
print(f"Final Student Accuracy (Hint = 1,3):")
test_accuracy_light_ce_and_mse_loss3 = test_multiple_outputs(modified_nn_light_reg3, test_loader, device)

Teacher Accuracy:
Test Accuracy: 77.22%
Final Student Accuracy (Hint = 1,2):
Test Accuracy: 75.07%
Final Student Accuracy (Hint = 2,3):
Test Accuracy: 72.99%
Final Student Accuracy (Hint = 1,3):
Test Accuracy: 74.08%


# Final Model Accuracies
As we can see, the student's accuracy is much improved with the 2-stage training following similar methodology in Romero et. al.

In [38]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
print(f"Student accuracy (hint=1,2 and guided = Beginning, middle) and Stage 1(RegressorMSE) + Stage 2(CE + KD): {test_accuracy_light_ce_and_mse_loss1:.2f}%")
print(f"Student accuracy (hint=2,3 and guided = Middle, end) and Stage 1(RegressorMSE) + Stage 2(CE + KD): {test_accuracy_light_ce_and_mse_loss2:.2f}%")
print(f"Student accuracy (hint=1,3 and guided = beginning, End) and Stage 1(RegressorMSE) + Stage 2(CE + KD): {test_accuracy_light_ce_and_mse_loss3:.2f}%")

Teacher accuracy: 77.22%
Student accuracy without teacher: 72.32%
Student accuracy with CE + KD: 74.85%
Student accuracy (hint=1,2 and guided = Beginning, middle) and Stage 1(RegressorMSE) + Stage 2(CE + KD): 75.07%
Student accuracy (hint=2,3 and guided = Middle, end) and Stage 1(RegressorMSE) + Stage 2(CE + KD): 72.99%
Student accuracy (hint=1,3 and guided = beginning, End) and Stage 1(RegressorMSE) + Stage 2(CE + KD): 74.08%


In [39]:
nn_light = LightNN(num_classes=100).to(device)
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")
total_params_light = "{:,}".format(sum(p.numel() for p in modified_nn_light_reg.parameters()))
print(f"LightNN parameters: {total_params_light}")
total_params_light = "{:,}".format(sum(p.numel() for p in modified_nn_light_reg2.parameters()))
print(f"LightNN parameters: {total_params_light}")
total_params_light = "{:,}".format(sum(p.numel() for p in modified_nn_light_reg3.parameters()))
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 3,201,122
LightNN parameters: 151,988
LightNN parameters: 261,822
LightNN parameters: 397,410
LightNN parameters: 313,086
