# Knowledge Distillation with Hint Training: Experiment 1

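Three hint/guided-layer placements are compared: after the first (beginning), second (middle), and third (end) convolutional stage of the teacher and student.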


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Mini-batch size shared by the training and test loaders.
batch_size = 128

# Per-channel statistics passed to transforms.Normalize below.
# NOTE(review): the same std value is repeated for all three channels and the
# numbers may not have been computed on CIFAR-100 — confirm they are intended.
mean = torch.tensor([0.4914, 0.4822, 0.4465])
std = torch.tensor([0.2009, 0.2009, 0.2009])
# Convert images to tensors, then normalise each channel with mean/std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = mean, std = std)])
# CIFAR-100 train/test splits, downloaded into ./data when missing.
cifar_train_data = datasets.CIFAR100('./data', train=True, download=True, transform=transform)
cifar_test_data = datasets.CIFAR100('./data', train=False, download=True, transform=transform)

# Only the training loader shuffles; the test loader keeps the dataset order.
train_loader = torch.utils.data.DataLoader(cifar_train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(cifar_test_data, batch_size=batch_size)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Sanity check: print the shape of the first training batch, then stop iterating.
for x, y in train_loader:
    print(x.shape)
    break

torch.Size([128, 3, 32, 32])


Network definitions

In [5]:
# Wide, shallow neural network classes used as the teacher.
# Three variants that differ only in which feature map forward() returns as the hint: the output of the first, second, or third convolutional stage.

class DeepNN(nn.Module):
    """Teacher CNN whose hint is the output of the first convolutional stage.

    Three conv(3x3) -> ReLU -> max-pool stages (3 -> 128 -> 356 -> 356
    channels) feed a two-layer classifier. The classifier expects 3204
    flattened features (356 channels x 3 x 3), which is what 32x32 inputs
    produce after the three pooling steps.

    Args:
        num_classes: size of the logit vector produced by the classifier.
    """

    def __init__(self, num_classes=100):
        super().__init__()

        def stage(in_channels, out_channels, pool_kernel):
            # One conv stage: same-padding 3x3 conv, ReLU, stride-2 max-pool.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=pool_kernel, stride=2),
            )

        # Attribute names and creation order are kept so state-dict keys and
        # seeded initialisation stay the same.
        self.features1 = stage(3, 128, 2)
        self.features2 = stage(128, 356, 2)
        self.features3 = stage(356, 356, 4)
        self.classifier = nn.Sequential(
            nn.Linear(3204, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return ``(logits, conv_feature_map)``; the map is the ``features1`` output."""
        conv_feature_map = self.features1(x)
        deep = self.features3(self.features2(conv_feature_map))
        logits = self.classifier(torch.flatten(deep, 1))
        return logits, conv_feature_map

class DeepNN2(nn.Module):
    """Teacher CNN whose hint is the output of the second convolutional stage.

    Same layers as the other teacher variants in this notebook: three
    conv(3x3) -> ReLU -> max-pool stages (3 -> 128 -> 356 -> 356 channels)
    and a two-layer classifier fed with 3204 flattened features.

    Args:
        num_classes: size of the logit vector produced by the classifier.
    """

    def __init__(self, num_classes=100):
        super().__init__()

        def stage(in_channels, out_channels, pool_kernel):
            # One conv stage: same-padding 3x3 conv, ReLU, stride-2 max-pool.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=pool_kernel, stride=2),
            )

        # Attribute names and creation order are kept so state-dict keys and
        # seeded initialisation stay the same.
        self.features1 = stage(3, 128, 2)
        self.features2 = stage(128, 356, 2)
        self.features3 = stage(356, 356, 4)
        self.classifier = nn.Sequential(
            nn.Linear(3204, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return ``(logits, conv_feature_map)``; the map is the ``features2`` output."""
        conv_feature_map = self.features2(self.features1(x))
        flat = torch.flatten(self.features3(conv_feature_map), 1)
        logits = self.classifier(flat)
        return logits, conv_feature_map

class DeepNN3(nn.Module):
    """Teacher CNN whose hint is the output of the third convolutional stage.

    Same layers as the other teacher variants in this notebook: three
    conv(3x3) -> ReLU -> max-pool stages (3 -> 128 -> 356 -> 356 channels)
    and a two-layer classifier fed with 3204 flattened features.

    Args:
        num_classes: size of the logit vector produced by the classifier.
    """

    def __init__(self, num_classes=100):
        super().__init__()

        def stage(in_channels, out_channels, pool_kernel):
            # One conv stage: same-padding 3x3 conv, ReLU, stride-2 max-pool.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=pool_kernel, stride=2),
            )

        # Attribute names and creation order are kept so state-dict keys and
        # seeded initialisation stay the same.
        self.features1 = stage(3, 128, 2)
        self.features2 = stage(128, 356, 2)
        self.features3 = stage(356, 356, 4)
        self.classifier = nn.Sequential(
            nn.Linear(3204, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return ``(logits, conv_feature_map)``; the map is the ``features3`` output."""
        conv_feature_map = self.features3(self.features2(self.features1(x)))
        logits = self.classifier(torch.flatten(conv_feature_map, 1))
        return logits, conv_feature_map



# Lightweight neural network class to be used as student, much deeper but thinner network:
class LightNN(nn.Module):
    """Thin, deeper student CNN.

    ``features`` is a single ``nn.Sequential`` of three groups; each group is
    three conv(3x3) + ReLU pairs followed by a stride-2 max-pool, with widths
    16, 32 and 48. The classifier expects 432 flattened features
    (48 channels x 3 x 3).

    Args:
        num_classes: size of the logit vector produced by the classifier.
    """

    def __init__(self, num_classes=100):
        super().__init__()

        # (output channels, pooling kernel) for each of the three groups.
        group_specs = ((16, 2), (32, 2), (48, 4))
        layers = []
        in_channels = 3
        for out_channels, pool_kernel in group_specs:
            for _ in range(3):
                layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
                layers.append(nn.ReLU())
                in_channels = out_channels
            layers.append(nn.MaxPool2d(kernel_size=pool_kernel, stride=2))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return ``(logits, conv_feature_map)``; the map is the ``features`` output."""
        conv_feature_map = self.features(x)
        logits = self.classifier(torch.flatten(conv_feature_map, 1))
        return logits, conv_feature_map

In [6]:
# Plain (non-distillation) evaluation and training loops
def test(model, test_dl):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    correct = 0
    num = 0
    with torch.no_grad():
        for (X, y) in test_dl:
            X, y = X.to(device), y.to(device)
            y_pred , _ = model(X)
            y_pred = torch.argmax(y_pred, dim=1)
            correct += torch.eq(y, y_pred).sum()
            num += X.shape[0]

    accuracy = correct / num
    return accuracy.item()

def train(model, lr, num_epochs, train_dl, test_dl):
    """Train ``model`` with Adam and cross-entropy, evaluating after every epoch.

    Args:
        model: network whose forward pass returns ``(logits, extra)``.
        lr: Adam learning rate.
        num_epochs: number of passes over ``train_dl``.
        train_dl: iterable of ``(inputs, labels)`` training batches.
        test_dl: iterable of evaluation batches handed to ``test``.

    Returns:
        List with the test accuracy measured after each epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    history = []
    for epoch in range(1, num_epochs + 1):
        # test() leaves the model in eval mode, so re-enable training each epoch.
        model.train()
        for inputs, labels in train_dl:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            logits, _ = model(inputs)
            criterion(logits, labels).backward()
            optimizer.step()

        test_accuracy = test(model, test_dl)
        print(f"Test accuracy at epoch {epoch}: {test_accuracy:.4f}")
        history.append(test_accuracy)
    return history

In [7]:
# Train the teacher (DeepNN) from scratch for num_epochs (15) epochs.
lr =  0.001
num_epochs = 15
teacher = DeepNN()

# x receives the list of per-epoch test accuracies returned by train().
x = train(teacher, lr, num_epochs, train_loader, test_loader)
test_accuracy = test(teacher, test_loader)
print(f"Final teacher test accuracy on CIFAR-100: {test_accuracy:.4f}")

Test accuracy at epoch 1: 0.2438
Test accuracy at epoch 2: 0.3323
Test accuracy at epoch 3: 0.3779
Test accuracy at epoch 4: 0.4120
Test accuracy at epoch 5: 0.4261
Test accuracy at epoch 6: 0.4391
Test accuracy at epoch 7: 0.4518
Test accuracy at epoch 8: 0.4514
Test accuracy at epoch 9: 0.4642
Test accuracy at epoch 10: 0.4582
Test accuracy at epoch 11: 0.4607
Test accuracy at epoch 12: 0.4554
Test accuracy at epoch 13: 0.4497
Test accuracy at epoch 14: 0.4471
Test accuracy at epoch 15: 0.4475
Final teacher test accuracy on CIFAR-100: 0.4475


In [8]:
# Baseline: train the student (LightNN) on labels only, with no teacher involved.
lr =  0.001
num_epochs = 15
student = LightNN()

# x receives the list of per-epoch test accuracies returned by train().
x = train(student, lr, num_epochs, train_loader, test_loader)
test_accuracy = test(student, test_loader)
print(f"Final student test accuracy on CIFAR-100: {test_accuracy:.4f}")

Test accuracy at epoch 1: 0.0907
Test accuracy at epoch 2: 0.1576
Test accuracy at epoch 3: 0.1913
Test accuracy at epoch 4: 0.2227
Test accuracy at epoch 5: 0.2353
Test accuracy at epoch 6: 0.2529
Test accuracy at epoch 7: 0.2602
Test accuracy at epoch 8: 0.2897
Test accuracy at epoch 9: 0.3015
Test accuracy at epoch 10: 0.3096
Test accuracy at epoch 11: 0.3214
Test accuracy at epoch 12: 0.3169
Test accuracy at epoch 13: 0.3262
Test accuracy at epoch 14: 0.3260
Test accuracy at epoch 15: 0.3379
Final student test accuracy on CIFAR-100: 0.3379


In [9]:
# Evaluation and training loops for soft-target knowledge distillation
def test_d(model, test_dl):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    correct = 0
    num = 0
    with torch.no_grad():
        for (X, y) in test_dl:
            X, y = X.to(device), y.to(device)
            y_pred, _ = model(X)
            y_pred = torch.argmax(y_pred, dim=1)
            correct += torch.eq(y, y_pred).sum()
            num += X.shape[0]

    accuracy = correct / num
    return accuracy.item()

def train_d(teacher, student, lr, num_epochs, temperature, alpha, train_dl, test_dl):
    """Train ``student`` with soft-target knowledge distillation from ``teacher``.

    Per batch the loss is::

        (1 - alpha) * soft_targets_loss + alpha * label_loss

    where ``soft_targets_loss`` is the batch-averaged cross-entropy between
    the temperature-softened teacher and student distributions, scaled by
    ``temperature ** 2``, and ``label_loss`` is cross-entropy against the
    labels. With ``alpha = 0.75`` this equals the 0.25 / 0.75 split that was
    previously hard-coded.

    Args:
        teacher: network returning ``(logits, extra)``; never updated. It is
            switched to eval mode (and left there) so that dropout does not
            perturb the soft targets.
        student: network returning ``(logits, extra)``; optimised with Adam.
        lr: Adam learning rate.
        num_epochs: number of passes over ``train_dl``.
        temperature: softmax temperature; must be positive.
        alpha: weight of the hard-label loss.
        train_dl: iterable of ``(inputs, labels)`` training batches.
        test_dl: iterable of evaluation batches handed to ``test``.

    Returns:
        List with the student's test accuracy after each epoch.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher.to(device)
    student.to(device)
    # A freshly constructed module starts in train mode; the teacher must run
    # deterministically while it produces targets.
    teacher.eval()
    opt = torch.optim.Adam(student.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    test_accuracy_list = []
    for epoch in range(1, num_epochs + 1):
        # The per-epoch evaluation leaves the student in eval mode.
        student.train()
        for X, y in train_dl:
            X, y = X.to(device), y.to(device)
            opt.zero_grad()
            student_logits, _ = student(X)
            with torch.no_grad():
                teacher_logits, _ = teacher(X)

            soft_targets = nn.functional.softmax(teacher_logits / temperature, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / temperature, dim=-1)
            # Sum over classes and batch, then average over the batch dimension.
            soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (temperature**2)
            label_loss = loss_fn(student_logits, y)
            loss = (1 - alpha) * soft_targets_loss + alpha * label_loss

            loss.backward()
            opt.step()

        test_accuracy = test(student, test_dl)
        print(f"Test accuracy at epoch {epoch}: {test_accuracy:.4f}")
        test_accuracy_list.append(test_accuracy)
    return test_accuracy_list

In [10]:
# Soft-target knowledge distillation: a fresh LightNN student learns from the trained teacher.
lr =  0.001
num_epochs = 15
student = LightNN()

train_d(teacher, student, lr, num_epochs, temperature=2, alpha=.75, train_dl=train_loader, test_dl=test_loader)
test_accuracy = test_d(student, test_loader)
# The loaders are built from CIFAR-100, so the report names that dataset.
print(f"Final student test accuracy on CIFAR-100: {test_accuracy:.4f}")

Test accuracy at epoch 1: 0.1024
Test accuracy at epoch 2: 0.1761
Test accuracy at epoch 3: 0.2164
Test accuracy at epoch 4: 0.2459
Test accuracy at epoch 5: 0.2665
Test accuracy at epoch 6: 0.2826
Test accuracy at epoch 7: 0.2928
Test accuracy at epoch 8: 0.3093
Test accuracy at epoch 9: 0.3190
Test accuracy at epoch 10: 0.3180
Test accuracy at epoch 11: 0.3420
Test accuracy at epoch 12: 0.3456
Test accuracy at epoch 13: 0.3456
Test accuracy at epoch 14: 0.3484
Test accuracy at epoch 15: 0.3574
Final student test accuracy on CIFAR-10: 0.3574


In [11]:
class ModifiedLightNNRegressor(nn.Module):
    """Student CNN with a regressor attached after the first feature group.

    The backbone is three groups of three conv(3x3) + ReLU pairs followed by
    a stride-2 max-pool (widths 16, 32, 48) and a two-layer classifier. The
    regressor is a 3x3 convolution mapping the 16-channel ``features1``
    output to 128 channels.

    Args:
        num_classes: size of the logit vector produced by the classifier.
    """

    def __init__(self, num_classes=100):
        super().__init__()

        def group(in_channels, out_channels, pool_kernel):
            # Three same-padding 3x3 conv + ReLU pairs, then a stride-2 max-pool.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=pool_kernel, stride=2),
            )

        # Creation order (features, regressor, classifier) is kept so seeded
        # initialisation and state-dict ordering stay the same.
        self.features1 = group(3, 16, 2)
        self.features2 = group(16, 32, 2)
        self.features3 = group(32, 48, 4)

        self.regressor = nn.Sequential(
            nn.Conv2d(16, 128, kernel_size=3, padding=1)
        )

        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return ``(logits, regressor_output)``; the regressor reads ``features1``."""
        early = self.features1(x)
        regressor_output = self.regressor(early)
        late = self.features3(self.features2(early))
        logits = self.classifier(torch.flatten(late, 1))
        return logits, regressor_output



class ModifiedLightNNRegressor2(nn.Module):
    """Student CNN with a regressor attached after the second feature group.

    The backbone is three groups of three conv(3x3) + ReLU pairs followed by
    a stride-2 max-pool (widths 16, 32, 48) and a two-layer classifier. The
    regressor is a 3x3 convolution mapping the 32-channel ``features2``
    output to 356 channels.

    Args:
        num_classes: size of the logit vector produced by the classifier.
    """

    def __init__(self, num_classes=100):
        super().__init__()

        def group(in_channels, out_channels, pool_kernel):
            # Three same-padding 3x3 conv + ReLU pairs, then a stride-2 max-pool.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=pool_kernel, stride=2),
            )

        # Creation order (features, regressor, classifier) is kept so seeded
        # initialisation and state-dict ordering stay the same.
        self.features1 = group(3, 16, 2)
        self.features2 = group(16, 32, 2)
        self.features3 = group(32, 48, 4)

        self.regressor = nn.Sequential(
            nn.Conv2d(32, 356, kernel_size=3, padding=1)
        )

        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return ``(logits, regressor_output)``; the regressor reads ``features2``."""
        middle = self.features2(self.features1(x))
        regressor_output = self.regressor(middle)
        late = self.features3(middle)
        logits = self.classifier(torch.flatten(late, 1))
        return logits, regressor_output




class ModifiedLightNNRegressor3(nn.Module):
    """Student CNN with a regressor attached after the third feature group.

    The backbone is three groups of three conv(3x3) + ReLU pairs followed by
    a stride-2 max-pool (widths 16, 32, 48) and a two-layer classifier. The
    regressor is a 3x3 convolution mapping the 48-channel ``features3``
    output to 356 channels.

    Args:
        num_classes: size of the logit vector produced by the classifier.
    """

    def __init__(self, num_classes=100):
        super().__init__()

        def group(in_channels, out_channels, pool_kernel):
            # Three same-padding 3x3 conv + ReLU pairs, then a stride-2 max-pool.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=pool_kernel, stride=2),
            )

        # Creation order (features, regressor, classifier) is kept so seeded
        # initialisation and state-dict ordering stay the same.
        self.features1 = group(3, 16, 2)
        self.features2 = group(16, 32, 2)
        self.features3 = group(32, 48, 4)

        self.regressor = nn.Sequential(
            nn.Conv2d(48, 356, kernel_size=3, padding=1)
        )

        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return ``(logits, regressor_output)``; the regressor reads ``features3``."""
        late = self.features3(self.features2(self.features1(x)))
        regressor_output = self.regressor(late)
        logits = self.classifier(torch.flatten(late, 1))
        return logits, regressor_output

In [12]:
# Evaluation and training loops for stage 1 (hint / feature-map regression)
def test_d(model, test_dl):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    correct = 0
    num = 0
    with torch.no_grad():
        for (X, y) in test_dl:
            X, y = X.to(device), y.to(device)
            y_pred, _ = model(X)
            y_pred = torch.argmax(y_pred, dim=1)
            correct += torch.eq(y, y_pred).sum()
            num += X.shape[0]

    accuracy = correct / num
    return accuracy.item()

def train_d1(teacher, student, lr, num_epochs, temperature, alpha, train_dl, test_dl):
    """Stage-1 hint training: regress the student's feature map onto the teacher's.

    Per batch the loss is::

        0.25 * MSE(student regressor map, teacher hint map)
        + 0.75 * cross_entropy(student logits, labels)

    Args:
        teacher: network returning ``(logits, hint_feature_map)``; never
            updated. It is switched to eval mode (and left there) so that
            train-only layers cannot make the hint targets stochastic.
        student: network returning ``(logits, regressor_feature_map)`` whose
            map has the same shape as the teacher's hint; optimised with Adam.
        lr: Adam learning rate.
        num_epochs: number of passes over ``train_dl``.
        temperature: not used by this stage; kept for signature compatibility.
        alpha: not used by this stage; the 0.25 / 0.75 weights are fixed.
        train_dl: iterable of ``(inputs, labels)`` training batches.
        test_dl: iterable of evaluation batches handed to ``test``.

    Returns:
        List with the student's test accuracy after each epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher.to(device)
    student.to(device)
    # A freshly constructed module starts in train mode; the teacher only
    # provides targets and must behave deterministically.
    teacher.eval()
    opt = torch.optim.Adam(student.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()

    test_accuracy_list = []
    for epoch in range(1, num_epochs + 1):
        # The per-epoch evaluation leaves the student in eval mode.
        student.train()
        for X, y in train_dl:
            X, y = X.to(device), y.to(device)
            opt.zero_grad()
            student_logits, regressor_feature_map = student(X)
            with torch.no_grad():
                _, teacher_feature_map = teacher(X)

            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)
            label_loss = loss_fn(student_logits, y)
            loss = 0.25 * hidden_rep_loss + 0.75 * label_loss

            loss.backward()
            opt.step()

        test_accuracy = test(student, test_dl)
        print(f"Test accuracy at epoch {epoch}: {test_accuracy:.4f}")
        test_accuracy_list.append(test_accuracy)
    return test_accuracy_list


# Build the three student variants (regressor after stage 1, 2 and 3) with a fixed seed.
torch.manual_seed(42)
modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=100).to(device)
modified_nn_light_reg2 = ModifiedLightNNRegressor2(num_classes=100).to(device)
modified_nn_light_reg3 = ModifiedLightNNRegressor3(num_classes=100).to(device)

# The teacher variants share one architecture and differ only in which feature map
# forward() returns, so the already-trained teacher's weights are loaded into each.
# NOTE(review): these modules are never switched to eval mode in this cell.
modified_nn_deep_reg = DeepNN(num_classes=100).to(device)
modified_nn_deep_reg.load_state_dict(teacher.state_dict())
modified_nn_deep_reg2 = DeepNN2(num_classes=100).to(device)
modified_nn_deep_reg2.load_state_dict(teacher.state_dict())
modified_nn_deep_reg3 = DeepNN3(num_classes=100).to(device)
modified_nn_deep_reg3.load_state_dict(teacher.state_dict())

<All keys matched successfully>

In [13]:
# Stage 1: freeze each student's classifier so only the feature groups and the regressor are updated.
for param in modified_nn_light_reg.classifier.parameters():
  param.requires_grad = False
for param in modified_nn_light_reg2.classifier.parameters():
  param.requires_grad = False
for param in modified_nn_light_reg3.classifier.parameters():
  param.requires_grad = False

# Hint training for each teacher/student pair (hint after stage 1, 2 and 3).
# train_d1 does not read temperature or alpha, so the zeros are placeholders.
train_d1(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, lr=0.001, num_epochs=10, temperature=0, alpha = 0, train_dl=train_loader, test_dl=test_loader)
train_d1(teacher=modified_nn_deep_reg2, student=modified_nn_light_reg2, lr=0.001, num_epochs=10, temperature=0, alpha = 0, train_dl=train_loader, test_dl=test_loader)
train_d1(teacher=modified_nn_deep_reg3, student=modified_nn_light_reg3, lr=0.001, num_epochs=10, temperature=0, alpha = 0, train_dl=train_loader, test_dl=test_loader)

Test accuracy at epoch 1: 0.0245
Test accuracy at epoch 2: 0.0688
Test accuracy at epoch 3: 0.1067
Test accuracy at epoch 4: 0.1214
Test accuracy at epoch 5: 0.1375
Test accuracy at epoch 6: 0.1550
Test accuracy at epoch 7: 0.1680
Test accuracy at epoch 8: 0.1735
Test accuracy at epoch 9: 0.1833
Test accuracy at epoch 10: 0.1938
Test accuracy at epoch 1: 0.0566
Test accuracy at epoch 2: 0.0883
Test accuracy at epoch 3: 0.1189
Test accuracy at epoch 4: 0.1353
Test accuracy at epoch 5: 0.1471
Test accuracy at epoch 6: 0.1613
Test accuracy at epoch 7: 0.1729
Test accuracy at epoch 8: 0.1829
Test accuracy at epoch 9: 0.1945
Test accuracy at epoch 10: 0.1939
Test accuracy at epoch 1: 0.0337
Test accuracy at epoch 2: 0.0721
Test accuracy at epoch 3: 0.1027
Test accuracy at epoch 4: 0.1173
Test accuracy at epoch 5: 0.1286
Test accuracy at epoch 6: 0.1431
Test accuracy at epoch 7: 0.1490
Test accuracy at epoch 8: 0.1602
Test accuracy at epoch 9: 0.1703
Test accuracy at epoch 10: 0.1893


[0.03370000049471855,
 0.07209999859333038,
 0.10269999504089355,
 0.11729999631643295,
 0.12860000133514404,
 0.14309999346733093,
 0.14899998903274536,
 0.16019999980926514,
 0.17029999196529388,
 0.18930000066757202]

In [14]:
# Stage 2 preparation: unfreeze the classifiers so the whole student is trained with knowledge distillation.

for param in modified_nn_light_reg.classifier.parameters():
  param.requires_grad = True
for param in modified_nn_light_reg2.classifier.parameters():
  param.requires_grad = True
for param in modified_nn_light_reg3.classifier.parameters():
  param.requires_grad = True


def test_d(model, test_dl):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    correct = 0
    num = 0
    with torch.no_grad():
        for (X, y) in test_dl:
            X, y = X.to(device), y.to(device)
            y_pred, _ = model(X)
            y_pred = torch.argmax(y_pred, dim=1)
            correct += torch.eq(y, y_pred).sum()
            num += X.shape[0]

    accuracy = correct / num
    return accuracy.item()

def train_d(teacher, student, lr, num_epochs, temperature, alpha, train_dl, test_dl):
    """Train ``student`` with soft-target knowledge distillation from ``teacher``.

    Per batch the loss is::

        (1 - alpha) * soft_targets_loss + alpha * label_loss

    where ``soft_targets_loss`` is the batch-averaged cross-entropy between
    the temperature-softened teacher and student distributions, scaled by
    ``temperature ** 2``, and ``label_loss`` is cross-entropy against the
    labels. With ``alpha = 0.75`` this equals the 0.25 / 0.75 split that was
    previously hard-coded.

    Args:
        teacher: network returning ``(logits, extra)``; never updated. It is
            switched to eval mode (and left there) so that dropout does not
            perturb the soft targets.
        student: network returning ``(logits, extra)``; optimised with Adam.
        lr: Adam learning rate.
        num_epochs: number of passes over ``train_dl``.
        temperature: softmax temperature; must be positive.
        alpha: weight of the hard-label loss.
        train_dl: iterable of ``(inputs, labels)`` training batches.
        test_dl: iterable of evaluation batches handed to ``test``.

    Returns:
        List with the student's test accuracy after each epoch.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher.to(device)
    student.to(device)
    # A freshly constructed module starts in train mode; the teacher must run
    # deterministically while it produces targets.
    teacher.eval()
    opt = torch.optim.Adam(student.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    test_accuracy_list = []
    for epoch in range(1, num_epochs + 1):
        # The per-epoch evaluation leaves the student in eval mode.
        student.train()
        for X, y in train_dl:
            X, y = X.to(device), y.to(device)
            opt.zero_grad()
            student_logits, _ = student(X)
            with torch.no_grad():
                teacher_logits, _ = teacher(X)

            soft_targets = nn.functional.softmax(teacher_logits / temperature, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / temperature, dim=-1)
            # Sum over classes and batch, then average over the batch dimension.
            soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (temperature**2)
            label_loss = loss_fn(student_logits, y)
            loss = (1 - alpha) * soft_targets_loss + alpha * label_loss

            loss.backward()
            opt.step()

        test_accuracy = test(student, test_dl)
        print(f"Test accuracy at epoch {epoch}: {test_accuracy:.4f}")
        test_accuracy_list.append(test_accuracy)
    return test_accuracy_list

In [15]:
# Stage 2: soft-target knowledge distillation for each hint-pretrained student (hint after stage 1, 2 and 3).
train_d(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, lr=0.001, num_epochs=10, temperature=2, alpha = 0.75, train_dl=train_loader, test_dl=test_loader)
train_d(teacher=modified_nn_deep_reg2, student=modified_nn_light_reg2,  lr=0.001, num_epochs=10, temperature=2, alpha = 0.75, train_dl=train_loader, test_dl=test_loader)
train_d(teacher=modified_nn_deep_reg3, student=modified_nn_light_reg3,  lr=0.001, num_epochs=10, temperature=2, alpha = 0.75, train_dl=train_loader, test_dl=test_loader)

Test accuracy at epoch 1: 0.2553
Test accuracy at epoch 2: 0.2847
Test accuracy at epoch 3: 0.3084
Test accuracy at epoch 4: 0.3200
Test accuracy at epoch 5: 0.3347
Test accuracy at epoch 6: 0.3383
Test accuracy at epoch 7: 0.3509
Test accuracy at epoch 8: 0.3597
Test accuracy at epoch 9: 0.3662
Test accuracy at epoch 10: 0.3661
Test accuracy at epoch 1: 0.2707
Test accuracy at epoch 2: 0.3010
Test accuracy at epoch 3: 0.3183
Test accuracy at epoch 4: 0.3350
Test accuracy at epoch 5: 0.3442
Test accuracy at epoch 6: 0.3623
Test accuracy at epoch 7: 0.3638
Test accuracy at epoch 8: 0.3751
Test accuracy at epoch 9: 0.3806
Test accuracy at epoch 10: 0.3723
Test accuracy at epoch 1: 0.2432
Test accuracy at epoch 2: 0.2807
Test accuracy at epoch 3: 0.2922
Test accuracy at epoch 4: 0.3093
Test accuracy at epoch 5: 0.3261
Test accuracy at epoch 6: 0.3388
Test accuracy at epoch 7: 0.3457
Test accuracy at epoch 8: 0.3568
Test accuracy at epoch 9: 0.3568
Test accuracy at epoch 10: 0.3602


[0.24319998919963837,
 0.2806999981403351,
 0.2921999990940094,
 0.3093000054359436,
 0.3260999917984009,
 0.33879998326301575,
 0.3456999957561493,
 0.35679998993873596,
 0.35679998993873596,
 0.3601999878883362]

In [16]:
def test_multiple_outputs(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs, _ = model(inputs) # Disregard the second tensor of the tuple
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [17]:
# Final comparison. test() returns a fraction, so it is scaled to a percentage here.
print(f"Teacher Accuracy:")
test_accuracy_deep = test(modified_nn_deep_reg, test_loader)
print(f"{(100*test_accuracy_deep):.2f}%")

# `student` is whatever model that name was last bound to in an earlier cell.
# NOTE(review): this value is printed as a fraction, unlike the percentages around it.
test_accuracy = test_d(student, test_loader)
print(f"Final student accuracy KD: {test_accuracy:.4f}")

# Hint-trained students; test_multiple_outputs prints and returns a percentage.
print(f"Final Student Accuracy (Hint = 1):")
test_accuracy_light_ce_and_mse_loss1 = test_multiple_outputs(modified_nn_light_reg, test_loader, device)
print(f"Final Student Accuracy (Hint = 2):")
test_accuracy_light_ce_and_mse_loss2 = test_multiple_outputs(modified_nn_light_reg2, test_loader, device)
print(f"Final Student Accuracy (Hint = 3):")
test_accuracy_light_ce_and_mse_loss3 = test_multiple_outputs(modified_nn_light_reg3, test_loader, device)


Teacher Accuracy:
44.75%
Final student accuracy KD: 0.3574
Final Student Accuracy (Hint = 1):
Test Accuracy: 36.61%
Final Student Accuracy (Hint = 2):
Test Accuracy: 37.23%
Final Student Accuracy (Hint = 3):
Test Accuracy: 36.02%


In [18]:
# Parameter counts for the teacher and every student variant.
# Each line gets its own label so the four student counts can be told apart.
# nn_light is not used below; it is kept because other cells may reference it.
nn_light = LightNN(num_classes=100).to(device)
total_params_deep = "{:,}".format(sum(p.numel() for p in teacher.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in student.parameters()))
print(f"LightNN parameters: {total_params_light}")
# The regressor variants carry the extra hint-regression convolution.
total_params_light = "{:,}".format(sum(p.numel() for p in modified_nn_light_reg.parameters()))
print(f"LightNN + regressor (Hint = 1) parameters: {total_params_light}")
total_params_light = "{:,}".format(sum(p.numel() for p in modified_nn_light_reg2.parameters()))
print(f"LightNN + regressor (Hint = 2) parameters: {total_params_light}")
total_params_light = "{:,}".format(sum(p.numel() for p in modified_nn_light_reg3.parameters()))
print(f"LightNN + regressor (Hint = 3) parameters: {total_params_light}")

DeepNN parameters: 3,247,292
LightNN parameters: 151,988
LightNN parameters: 170,548
LightNN parameters: 254,872
LightNN parameters: 306,136
