# Knowledge Distillation with Hint Training: Experiment 3

Multiple hint layers


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Mini-batch size shared by the training and the test loader.
batch_size = 128

# Per-channel statistics handed to transforms.Normalize below.
# NOTE(review): the mean looks like the commonly quoted CIFAR-10 mean and the std is one value
# repeated for all three channels — confirm these are the intended statistics for CIFAR-100.
mean = torch.tensor([0.4914, 0.4822, 0.4465])
std = torch.tensor([0.2009, 0.2009, 0.2009])
# Convert images to tensors, then normalize each channel with the statistics above.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = mean, std = std)])
# CIFAR-100 train and test splits, stored under ./data (download=True is passed for both).
cifar_train_data = datasets.CIFAR100('./data', train=True, download=True, transform=transform)
cifar_test_data = datasets.CIFAR100('./data', train=False, download=True, transform=transform)

# Only the training loader shuffles; the test loader uses the DataLoader default order.
train_loader = torch.utils.data.DataLoader(cifar_train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(cifar_test_data, batch_size=batch_size)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:05<00:00, 29371384.10it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
# Sanity check: print the shape of the first training batch only, then stop iterating.
for x, y in train_loader:
    print(x.shape)
    break

torch.Size([128, 3, 32, 32])


Network initializations

In [23]:
# Wide, shallow networks used as the teacher. The three classes share one architecture and differ only in
# which pair of convolutional-stage outputs they return as hints:
# DeepNN -> stages 1 and 2, DeepNN2 -> stages 2 and 3, DeepNN3 -> stages 1 and 3.

class DeepNN(nn.Module):
    """Teacher CNN that returns the outputs of conv stages 1 and 2 as hint feature maps.

    The network is three conv -> ReLU -> max-pool stages followed by a two-layer
    classifier. The submodule names (features1, features2, features3, classifier)
    determine the state_dict keys, so they are part of the class's contract.

    Args:
        num_classes: size of the final linear layer's output (default 100).
    """

    @staticmethod
    def _stage(in_channels, out_channels, pool_kernel):
        """Build one conv(3x3, padding 1) -> ReLU -> max-pool(stride 2) stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=pool_kernel, stride=2),
        )

    def __init__(self, num_classes=100):
        super().__init__()
        self.features1 = self._stage(3, 128, pool_kernel=2)
        self.features2 = self._stage(128, 356, pool_kernel=2)
        self.features3 = self._stage(356, 356, pool_kernel=4)
        # 3204 = 356 channels * 3 * 3, the flattened stage-3 output for 32x32 inputs.
        self.classifier = nn.Sequential(
            nn.Linear(3204, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return ``(logits, stage1_output, stage2_output)`` for the batch ``x``."""
        stage1_out = self.features1(x)
        stage2_out = self.features2(stage1_out)
        flat = torch.flatten(self.features3(stage2_out), 1)
        logits = self.classifier(flat)
        return logits, stage1_out, stage2_out

class DeepNN2(nn.Module):
    """Teacher CNN that returns the outputs of conv stages 2 and 3 as hint feature maps.

    Layer layout and submodule names match DeepNN, so a DeepNN state_dict loads
    into this class; only the tensors returned alongside the logits differ.

    Args:
        num_classes: size of the final linear layer's output (default 100).
    """

    @staticmethod
    def _stage(in_channels, out_channels, pool_kernel):
        """Build one conv(3x3, padding 1) -> ReLU -> max-pool(stride 2) stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=pool_kernel, stride=2),
        )

    def __init__(self, num_classes=100):
        super().__init__()
        self.features1 = self._stage(3, 128, pool_kernel=2)
        self.features2 = self._stage(128, 356, pool_kernel=2)
        self.features3 = self._stage(356, 356, pool_kernel=4)
        # 3204 = 356 channels * 3 * 3, the flattened stage-3 output for 32x32 inputs.
        self.classifier = nn.Sequential(
            nn.Linear(3204, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return ``(logits, stage2_output, stage3_output)`` for the batch ``x``."""
        stage2_out = self.features2(self.features1(x))
        stage3_out = self.features3(stage2_out)
        logits = self.classifier(torch.flatten(stage3_out, 1))
        return logits, stage2_out, stage3_out

class DeepNN3(nn.Module):
    """Teacher CNN that returns the outputs of conv stages 1 and 3 as hint feature maps.

    Layer layout and submodule names match DeepNN, so a DeepNN state_dict loads
    into this class; only the tensors returned alongside the logits differ.

    Args:
        num_classes: size of the final linear layer's output (default 100).
    """

    @staticmethod
    def _stage(in_channels, out_channels, pool_kernel):
        """Build one conv(3x3, padding 1) -> ReLU -> max-pool(stride 2) stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=pool_kernel, stride=2),
        )

    def __init__(self, num_classes=100):
        super().__init__()
        self.features1 = self._stage(3, 128, pool_kernel=2)
        self.features2 = self._stage(128, 356, pool_kernel=2)
        self.features3 = self._stage(356, 356, pool_kernel=4)
        # 3204 = 356 channels * 3 * 3, the flattened stage-3 output for 32x32 inputs.
        self.classifier = nn.Sequential(
            nn.Linear(3204, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return ``(logits, stage1_output, stage3_output)`` for the batch ``x``."""
        stage1_out = self.features1(x)
        stage3_out = self.features3(self.features2(stage1_out))
        logits = self.classifier(torch.flatten(stage3_out, 1))
        return logits, stage1_out, stage3_out



# Lightweight neural network class to be used as student, much deeper but thinner network:
class LightNN(nn.Module):
    """Student CNN: nine narrow 3x3 conv layers in three pooled groups plus a small classifier.

    All convolutional layers live in one ``features`` Sequential (conv layers at
    indices 0, 2, 4, 7, 9, 11, 14, 16, 18), which fixes the state_dict keys.

    Args:
        num_classes: size of the final linear layer's output (default 100).
    """

    def __init__(self, num_classes=100):
        super().__init__()
        layers = []
        in_channels = 3
        # (output channels, max-pool kernel) for each group of three conv+ReLU pairs.
        for out_channels, pool_kernel in ((16, 2), (32, 2), (48, 4)):
            for _ in range(3):
                layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
                layers.append(nn.ReLU())
                in_channels = out_channels
            layers.append(nn.MaxPool2d(kernel_size=pool_kernel, stride=2))
        self.features = nn.Sequential(*layers)
        # 432 = 48 channels * 3 * 3, the flattened feature size for 32x32 inputs.
        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return ``(logits, features, features)``.

        The final pooled feature map is returned twice, so callers can unpack
        three values as they do for the other models in this notebook.
        """
        pooled = self.features(x)
        logits = self.classifier(torch.flatten(pooled, 1))
        return logits, pooled, pooled

In [8]:
# Plain supervised training and evaluation loops (no distillation)
def test(model, test_dl):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    correct = 0
    num = 0
    with torch.no_grad():
        for (X, y) in test_dl:
            X, y = X.to(device), y.to(device)
            y_pred , _, _ = model(X)
            y_pred = torch.argmax(y_pred, dim=1)
            correct += torch.eq(y, y_pred).sum()
            num += X.shape[0]

    accuracy = correct / num
    return accuracy.item()

def train(model, lr, num_epochs, train_dl, test_dl):
    """Train ``model`` with Adam and cross-entropy, evaluating after every epoch.

    Args:
        model: module whose forward returns a 3-tuple with the logits first.
        lr: learning rate for ``torch.optim.Adam``.
        num_epochs: number of passes over ``train_dl``.
        train_dl: iterable of ``(inputs, labels)`` training batches.
        test_dl: iterable of ``(inputs, labels)`` batches passed to ``test``.

    Returns:
        List with one test accuracy (float) per epoch, in epoch order.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    history = []
    for epoch in range(1, num_epochs + 1):
        # test() switches the model to eval mode, so re-enable training mode each epoch.
        model.train()
        for inputs, targets in train_dl:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            logits, _, _ = model(inputs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()

        epoch_accuracy = test(model, test_dl)
        print(f"Test accuracy at epoch {epoch}: {epoch_accuracy:.4f}")
        history.append(epoch_accuracy)
    return history

In [9]:
# Train the teacher (DeepNN) from scratch for num_epochs epochs (15 here) with learning rate lr.
lr =  0.001
num_epochs = 15
teacher = DeepNN()

# train() returns the per-epoch test accuracies; they are stored in x and not used again in this cell.
x = train(teacher, lr, num_epochs, train_loader, test_loader)
# Evaluate the trained teacher once more for the summary line below.
test_accuracy = test(teacher, test_loader)
print(f"Final teacher test accuracy on CIFAR-100: {test_accuracy:.4f}")

Test accuracy at epoch 1: 0.2317
Test accuracy at epoch 2: 0.3249
Test accuracy at epoch 3: 0.3723
Test accuracy at epoch 4: 0.4174
Test accuracy at epoch 5: 0.4372
Test accuracy at epoch 6: 0.4449
Test accuracy at epoch 7: 0.4490
Test accuracy at epoch 8: 0.4523
Test accuracy at epoch 9: 0.4561
Test accuracy at epoch 10: 0.4596
Test accuracy at epoch 11: 0.4596
Test accuracy at epoch 12: 0.4526
Test accuracy at epoch 13: 0.4586
Test accuracy at epoch 14: 0.4502
Test accuracy at epoch 15: 0.4520
Final teacher test accuracy on CIFAR-100: 0.4520


In [24]:
# Baseline: train the student (LightNN) on the labels only, with the same lr and epoch count as the teacher.
lr =  0.001
num_epochs = 15
student = LightNN()

# train() returns the per-epoch test accuracies; they are stored in x and not used again in this cell.
x = train(student, lr, num_epochs, train_loader, test_loader)
# Evaluate the trained student once more for the summary line below.
test_accuracy = test(student, test_loader)
print(f"Final student test accuracy on CIFAR-100: {test_accuracy:.4f}")

Test accuracy at epoch 1: 0.0889
Test accuracy at epoch 2: 0.1592
Test accuracy at epoch 3: 0.2067
Test accuracy at epoch 4: 0.2331
Test accuracy at epoch 5: 0.2586
Test accuracy at epoch 6: 0.2758
Test accuracy at epoch 7: 0.3001
Test accuracy at epoch 8: 0.3049
Test accuracy at epoch 9: 0.3176
Test accuracy at epoch 10: 0.3294
Test accuracy at epoch 11: 0.3422
Test accuracy at epoch 12: 0.3403
Test accuracy at epoch 13: 0.3431
Test accuracy at epoch 14: 0.3565
Test accuracy at epoch 15: 0.3384
Final student test accuracy on CIFAR-100: 0.3384


In [33]:
#Distillation loss
def test_d(model, test_dl):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    correct = 0
    num = 0
    with torch.no_grad():
        for (X, y) in test_dl:
            X, y = X.to(device), y.to(device)
            y_pred, _, _ = model(X)
            y_pred = torch.argmax(y_pred, dim=1)
            correct += torch.eq(y, y_pred).sum()
            num += X.shape[0]

    accuracy = correct / num
    return accuracy.item()

def train_d(teacher, student, lr, num_epochs, temperature, alpha, train_dl, test_dl):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher.to(device)
    student.to(device)
    opt = torch.optim.Adam(student.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    test_accuracy_list = []
    for epoch in range(1, num_epochs + 1):
        student.train()
        running_loss = 0.0
        for X, y in train_dl:
            X, y = X.to(device), y.to(device)
            opt.zero_grad()
            student_logits, _, _ = student(X)
            with torch.no_grad():
                teacher_logits, _, _ = teacher(X)

            student_probs = torch.softmax(student_logits / temperature, dim=1)


            soft_targets = nn.functional.softmax(teacher_logits / temperature, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / temperature, dim=-1)
            soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (temperature**2)
            label_loss = loss_fn(student_logits, y)
            loss = 0.25 * soft_targets_loss + 0.75* label_loss


            loss.backward()
            opt.step()

            running_loss += loss.item()

        #print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

        test_accuracy = test(student, test_dl)
        print(f"Test accuracy at epoch {epoch}: {test_accuracy:.4f}")
        test_accuracy_list.append(test_accuracy)
    return test_accuracy_list

In [34]:
# Train a fresh LightNN with soft-target distillation from the trained teacher.
# It is bound to its own name so the baseline `student` is not overwritten: the later
# summary cell reports `student` as "regular training" and needs the label-only model.
lr = 0.001
num_epochs = 15
student_kd = LightNN()

# temperature=2, alpha=0.75 (the per-epoch accuracy list returned by train_d is not kept).
train_d(teacher, student_kd, lr, num_epochs, 2, .75, train_loader, test_loader)
test_accuracy = test_d(student_kd, test_loader)
# The loaders serve CIFAR-100, so the message names that dataset.
print(f"Final student test accuracy on CIFAR-100: {test_accuracy:.4f}")

Test accuracy at epoch 1: 0.0702
Test accuracy at epoch 2: 0.1453
Test accuracy at epoch 3: 0.1937
Test accuracy at epoch 4: 0.2305
Test accuracy at epoch 5: 0.2454
Test accuracy at epoch 6: 0.2739
Test accuracy at epoch 7: 0.2870
Test accuracy at epoch 8: 0.3071
Test accuracy at epoch 9: 0.3226
Test accuracy at epoch 10: 0.3200
Test accuracy at epoch 11: 0.3457
Test accuracy at epoch 12: 0.3508
Test accuracy at epoch 13: 0.3566
Test accuracy at epoch 14: 0.3607
Test accuracy at epoch 15: 0.3683
Final student test accuracy on CIFAR-10: 0.3683


In [11]:
class ModifiedLightNNRegressor(nn.Module):
    """Student CNN with regressors on conv stages 1 and 2.

    The backbone is three groups of three 3x3 conv+ReLU layers, each group ending
    in a max-pool. ``regressor`` maps the 16-channel stage-1 output to 128 channels
    and ``regressor1`` maps the 32-channel stage-2 output to 356 channels, so the
    results can be compared with wider feature maps of the same spatial size.

    Args:
        num_classes: size of the final linear layer's output (default 100).
    """

    @staticmethod
    def _conv_group(in_channels, out_channels, pool_kernel):
        """Build three conv(3x3, padding 1)+ReLU pairs followed by a stride-2 max-pool."""
        layers = []
        for _ in range(3):
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
            in_channels = out_channels
        layers.append(nn.MaxPool2d(kernel_size=pool_kernel, stride=2))
        return nn.Sequential(*layers)

    def __init__(self, num_classes=100):
        super().__init__()
        self.features1 = self._conv_group(3, 16, pool_kernel=2)
        self.features2 = self._conv_group(16, 32, pool_kernel=2)
        self.features3 = self._conv_group(32, 48, pool_kernel=4)

        # Single-conv regressors, kept inside Sequential so their keys stay "regressor.0.*".
        self.regressor = nn.Sequential(nn.Conv2d(16, 128, kernel_size=3, padding=1))
        self.regressor1 = nn.Sequential(nn.Conv2d(32, 356, kernel_size=3, padding=1))

        # 432 = 48 channels * 3 * 3, the flattened stage-3 output for 32x32 inputs.
        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return ``(logits, regressed_stage1, regressed_stage2)`` for the batch ``x``."""
        stage1_out = self.features1(x)
        stage2_out = self.features2(stage1_out)
        hint_from_stage1 = self.regressor(stage1_out)
        hint_from_stage2 = self.regressor1(stage2_out)
        logits = self.classifier(torch.flatten(self.features3(stage2_out), 1))
        return logits, hint_from_stage1, hint_from_stage2



class ModifiedLightNNRegressor2(nn.Module):
    """Student CNN with regressors on conv stages 2 and 3.

    The backbone is three groups of three 3x3 conv+ReLU layers, each group ending
    in a max-pool. ``regressor`` maps the 32-channel stage-2 output to 356 channels
    and ``regressor1`` maps the 48-channel stage-3 output to 356 channels.

    Args:
        num_classes: size of the final linear layer's output (default 100).
    """

    @staticmethod
    def _conv_group(in_channels, out_channels, pool_kernel):
        """Build three conv(3x3, padding 1)+ReLU pairs followed by a stride-2 max-pool."""
        layers = []
        for _ in range(3):
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
            in_channels = out_channels
        layers.append(nn.MaxPool2d(kernel_size=pool_kernel, stride=2))
        return nn.Sequential(*layers)

    def __init__(self, num_classes=100):
        super().__init__()
        self.features1 = self._conv_group(3, 16, pool_kernel=2)
        self.features2 = self._conv_group(16, 32, pool_kernel=2)
        self.features3 = self._conv_group(32, 48, pool_kernel=4)

        # Single-conv regressors, kept inside Sequential so their keys stay "regressor.0.*".
        self.regressor = nn.Sequential(nn.Conv2d(32, 356, kernel_size=3, padding=1))
        self.regressor1 = nn.Sequential(nn.Conv2d(48, 356, kernel_size=3, padding=1))

        # 432 = 48 channels * 3 * 3, the flattened stage-3 output for 32x32 inputs.
        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return ``(logits, regressed_stage2, regressed_stage3)`` for the batch ``x``."""
        stage2_out = self.features2(self.features1(x))
        stage3_out = self.features3(stage2_out)
        hint_from_stage2 = self.regressor(stage2_out)
        hint_from_stage3 = self.regressor1(stage3_out)
        logits = self.classifier(torch.flatten(stage3_out, 1))
        return logits, hint_from_stage2, hint_from_stage3




class ModifiedLightNNRegressor3(nn.Module):
    """Student CNN with regressors on conv stages 1 and 3.

    The backbone is three groups of three 3x3 conv+ReLU layers, each group ending
    in a max-pool. ``regressor`` maps the 16-channel stage-1 output to 128 channels
    and ``regressor1`` maps the 48-channel stage-3 output to 356 channels.

    Args:
        num_classes: size of the final linear layer's output (default 100).
    """

    @staticmethod
    def _conv_group(in_channels, out_channels, pool_kernel):
        """Build three conv(3x3, padding 1)+ReLU pairs followed by a stride-2 max-pool."""
        layers = []
        for _ in range(3):
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
            in_channels = out_channels
        layers.append(nn.MaxPool2d(kernel_size=pool_kernel, stride=2))
        return nn.Sequential(*layers)

    def __init__(self, num_classes=100):
        super().__init__()
        self.features1 = self._conv_group(3, 16, pool_kernel=2)
        self.features2 = self._conv_group(16, 32, pool_kernel=2)
        self.features3 = self._conv_group(32, 48, pool_kernel=4)

        # Single-conv regressors, kept inside Sequential so their keys stay "regressor.0.*".
        self.regressor = nn.Sequential(nn.Conv2d(16, 128, kernel_size=3, padding=1))
        self.regressor1 = nn.Sequential(nn.Conv2d(48, 356, kernel_size=3, padding=1))

        # 432 = 48 channels * 3 * 3, the flattened stage-3 output for 32x32 inputs.
        self.classifier = nn.Sequential(
            nn.Linear(432, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return ``(logits, regressed_stage1, regressed_stage3)`` for the batch ``x``."""
        stage1_out = self.features1(x)
        stage3_out = self.features3(self.features2(stage1_out))
        hint_from_stage1 = self.regressor(stage1_out)
        hint_from_stage3 = self.regressor1(stage3_out)
        logits = self.classifier(torch.flatten(stage3_out, 1))
        return logits, hint_from_stage1, hint_from_stage3

In [12]:
#Distillation loss
def test_d(model, test_dl):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    correct = 0
    num = 0
    with torch.no_grad():
        for (X, y) in test_dl:
            X, y = X.to(device), y.to(device)
            y_pred, _, _ = model(X)
            y_pred = torch.argmax(y_pred, dim=1)
            correct += torch.eq(y, y_pred).sum()
            num += X.shape[0]

    accuracy = correct / num
    return accuracy.item()

def train_d1(teacher, student, lr, num_epochs, temperature, alpha, train_dl, test_dl):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher.to(device)
    student.to(device)
    opt = torch.optim.Adam(student.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()

    test_accuracy_list = []
    for epoch in range(1, num_epochs + 1):
        student.train()
        running_loss = 0.0
        for X, y in train_dl:
            X, y = X.to(device), y.to(device)
            opt.zero_grad()
            student_logits, regressor_feature_map, regressor_feature_map1 = student(X)
            with torch.no_grad():
                teacher_logits, teacher_feature_map, teacher_feature_map1 = teacher(X)

            hidden_rep_loss = mse_loss(regressor_feature_map, regressor_feature_map)
            hidden_rep_loss1 = mse_loss(regressor_feature_map1, regressor_feature_map1)
            label_loss = loss_fn(student_logits, y)
            loss = .75*label_loss + .125*hidden_rep_loss + .125*hidden_rep_loss1


            loss.backward()
            opt.step()

            running_loss += loss.item()

        #print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

        test_accuracy = test(student, test_dl)
        print(f"Test accuracy at epoch {epoch}: {test_accuracy:.4f}")
        test_accuracy_list.append(test_accuracy)
    return test_accuracy_list


# Create the three hint students (regressor pairs on stages 1&2, 2&3 and 1&3) after fixing the seed.
torch.manual_seed(42)
modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=100).to(device)
modified_nn_light_reg2 = ModifiedLightNNRegressor2(num_classes=100).to(device)
modified_nn_light_reg3 = ModifiedLightNNRegressor3(num_classes=100).to(device)

# The teacher variants are not retrained: each one is a new instance that receives the weights
# of the already trained `teacher`; they differ only in which feature maps forward() returns.
# NOTE(review): these new instances are never switched to eval() in this cell.
modified_nn_deep_reg = DeepNN(num_classes=100).to(device)
modified_nn_deep_reg.load_state_dict(teacher.state_dict())
modified_nn_deep_reg2 = DeepNN2(num_classes=100).to(device)
modified_nn_deep_reg2.load_state_dict(teacher.state_dict())
modified_nn_deep_reg3 = DeepNN3(num_classes=100).to(device)
modified_nn_deep_reg3.load_state_dict(teacher.state_dict())

<All keys matched successfully>

In [13]:
# Stage 1 setup: freeze the classifier of every hint student so that only the
# convolutional stages and the regressors are updated.
for hinted_student in (modified_nn_light_reg, modified_nn_light_reg2, modified_nn_light_reg3):
    for param in hinted_student.classifier.parameters():
        param.requires_grad = False

# Stage 1 training, one teacher/student pair per hint combination.
# temperature and alpha are required by the signature but not used by train_d1.
train_d1(
    teacher=modified_nn_deep_reg, student=modified_nn_light_reg,
    lr=0.001, num_epochs=10, temperature=0, alpha=0,
    train_dl=train_loader, test_dl=test_loader,
)
train_d1(
    teacher=modified_nn_deep_reg2, student=modified_nn_light_reg2,
    lr=0.001, num_epochs=10, temperature=0, alpha=0,
    train_dl=train_loader, test_dl=test_loader,
)
train_d1(
    teacher=modified_nn_deep_reg3, student=modified_nn_light_reg3,
    lr=0.001, num_epochs=10, temperature=0, alpha=0,
    train_dl=train_loader, test_dl=test_loader,
)

Test accuracy at epoch 1: 0.0289
Test accuracy at epoch 2: 0.0562
Test accuracy at epoch 3: 0.0810
Test accuracy at epoch 4: 0.1040
Test accuracy at epoch 5: 0.1138
Test accuracy at epoch 6: 0.1260
Test accuracy at epoch 7: 0.1417
Test accuracy at epoch 8: 0.1480
Test accuracy at epoch 9: 0.1617
Test accuracy at epoch 10: 0.1611
Test accuracy at epoch 1: 0.0341
Test accuracy at epoch 2: 0.0469
Test accuracy at epoch 3: 0.0646
Test accuracy at epoch 4: 0.0762
Test accuracy at epoch 5: 0.0912
Test accuracy at epoch 6: 0.1058
Test accuracy at epoch 7: 0.1228
Test accuracy at epoch 8: 0.1409
Test accuracy at epoch 9: 0.1389
Test accuracy at epoch 10: 0.1565
Test accuracy at epoch 1: 0.0247
Test accuracy at epoch 2: 0.0549
Test accuracy at epoch 3: 0.0715
Test accuracy at epoch 4: 0.0906
Test accuracy at epoch 5: 0.1023
Test accuracy at epoch 6: 0.1307
Test accuracy at epoch 7: 0.1279
Test accuracy at epoch 8: 0.1498
Test accuracy at epoch 9: 0.1551
Test accuracy at epoch 10: 0.1657


[0.024699999019503593,
 0.054899998009204865,
 0.07149999588727951,
 0.09059999883174896,
 0.1022999957203865,
 0.1306999921798706,
 0.12789998948574066,
 0.14980000257492065,
 0.1551000028848648,
 0.16569998860359192]

In [14]:
# Stage 2 setup: make the classifier parameters trainable again before the
# knowledge-distillation run.
for hinted_student in (modified_nn_light_reg, modified_nn_light_reg2, modified_nn_light_reg3):
    for param in hinted_student.classifier.parameters():
        param.requires_grad = True


def test_d(model, test_dl):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    correct = 0
    num = 0
    with torch.no_grad():
        for (X, y) in test_dl:
            X, y = X.to(device), y.to(device)
            y_pred, _, _ = model(X)
            y_pred = torch.argmax(y_pred, dim=1)
            correct += torch.eq(y, y_pred).sum()
            num += X.shape[0]

    accuracy = correct / num
    return accuracy.item()

def train_d(teacher, student, lr, num_epochs, temperature, alpha, train_dl, test_dl):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher.to(device)
    student.to(device)
    opt = torch.optim.Adam(student.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    test_accuracy_list = []
    for epoch in range(1, num_epochs + 1):
        student.train()
        running_loss = 0.0
        for X, y in train_dl:
            X, y = X.to(device), y.to(device)
            opt.zero_grad()
            student_logits, _, _ = student(X)
            with torch.no_grad():
                teacher_logits, _, _ = teacher(X)

            student_probs = torch.softmax(student_logits / temperature, dim=1)


            soft_targets = nn.functional.softmax(teacher_logits / temperature, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / temperature, dim=-1)
            soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (temperature**2)
            label_loss = loss_fn(student_logits, y)
            loss = 0.25 * soft_targets_loss + 0.75* label_loss


            loss.backward()
            opt.step()

            running_loss += loss.item()

        #print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

        test_accuracy = test(student, test_dl)
        print(f"Test accuracy at epoch {epoch}: {test_accuracy:.4f}")
        test_accuracy_list.append(test_accuracy)
    return test_accuracy_list

In [15]:
# Stage 2: continue training each hint student with soft-target distillation from its teacher copy
# (temperature 2, alpha 0.75, 10 epochs). Only the last call's returned list is displayed by the cell.
train_d(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, lr=0.001, num_epochs=10, temperature=2, alpha = 0.75, train_dl=train_loader, test_dl=test_loader)
train_d(teacher=modified_nn_deep_reg2, student=modified_nn_light_reg2,  lr=0.001, num_epochs=10, temperature=2, alpha = 0.75, train_dl=train_loader, test_dl=test_loader)
train_d(teacher=modified_nn_deep_reg3, student=modified_nn_light_reg3,  lr=0.001, num_epochs=10, temperature=2, alpha = 0.75, train_dl=train_loader, test_dl=test_loader)

Test accuracy at epoch 1: 0.2329
Test accuracy at epoch 2: 0.2528
Test accuracy at epoch 3: 0.2784
Test accuracy at epoch 4: 0.2986
Test accuracy at epoch 5: 0.3014
Test accuracy at epoch 6: 0.3100
Test accuracy at epoch 7: 0.3273
Test accuracy at epoch 8: 0.3254
Test accuracy at epoch 9: 0.3341
Test accuracy at epoch 10: 0.3348
Test accuracy at epoch 1: 0.2242
Test accuracy at epoch 2: 0.2625
Test accuracy at epoch 3: 0.2830
Test accuracy at epoch 4: 0.3049
Test accuracy at epoch 5: 0.3115
Test accuracy at epoch 6: 0.3349
Test accuracy at epoch 7: 0.3414
Test accuracy at epoch 8: 0.3470
Test accuracy at epoch 9: 0.3500
Test accuracy at epoch 10: 0.3609
Test accuracy at epoch 1: 0.2416
Test accuracy at epoch 2: 0.2684
Test accuracy at epoch 3: 0.2932
Test accuracy at epoch 4: 0.2908
Test accuracy at epoch 5: 0.3092
Test accuracy at epoch 6: 0.3317
Test accuracy at epoch 7: 0.3396
Test accuracy at epoch 8: 0.3437
Test accuracy at epoch 9: 0.3442
Test accuracy at epoch 10: 0.3552


[0.24159999191761017,
 0.2683999836444855,
 0.29319998621940613,
 0.290800005197525,
 0.3091999888420105,
 0.33169999718666077,
 0.33959999680519104,
 0.34369999170303345,
 0.3441999852657318,
 0.35519999265670776]

In [18]:
def test_multiple_outputs(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs, _, _ = model(inputs) # Disregard the second tensor of the tuple
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [31]:
# Teacher baseline, measured on the DeepNN copy that carries the trained teacher weights.
print("Teacher Accuracy:")
test_accuracy_deep = test(modified_nn_deep_reg, test_loader)
print(f"{100 * test_accuracy_deep:.2f}%")

# NOTE(review): `student` must refer to the label-only LightNN for this label to be accurate —
# confirm after any re-run of the notebook.
test_accuracy = test_d(student, test_loader)
print(f"Student accuracy regular training: {100 * test_accuracy:.2f}%")

# Hint-trained students; test_multiple_outputs prints the percentage itself.
print("Final Student Accuracy (Hint = 1,2):")
test_accuracy_light_ce_and_mse_loss1 = test_multiple_outputs(modified_nn_light_reg, test_loader, device)
print("Final Student Accuracy (Hint = 2,3):")
test_accuracy_light_ce_and_mse_loss2 = test_multiple_outputs(modified_nn_light_reg2, test_loader, device)
print("Final Student Accuracy (Hint = 1,3):")
test_accuracy_light_ce_and_mse_loss3 = test_multiple_outputs(modified_nn_light_reg3, test_loader, device)


Teacher Accuracy:
45.20%
Student accuracy regular training: 33.84%
Final Student Accuracy (Hint = 1,2):
Test Accuracy: 33.48%
Final Student Accuracy (Hint = 2,3):
Test Accuracy: 36.09%
Final Student Accuracy (Hint = 1,3):
Test Accuracy: 35.52%


In [21]:
# Parameter counts for the teacher, the plain student and the three hint students.
# The plain LightNN is counted on a fresh instance so this cell does not depend on
# which model is currently bound to `student`.
nn_light = LightNN(num_classes=100).to(device)
total_params_deep = "{:,}".format(sum(p.numel() for p in teacher.parameters()))
print(f"DeepNN parameters: {total_params_deep}")

# Each hint student adds two regressor convolutions to LightNN, so every variant
# gets its own label instead of being reported as plain "LightNN".
for count_label, counted_model in (
    ("LightNN", nn_light),
    ("LightNN + regressors (Hint = 1,2)", modified_nn_light_reg),
    ("LightNN + regressors (Hint = 2,3)", modified_nn_light_reg2),
    ("LightNN + regressors (Hint = 1,3)", modified_nn_light_reg3),
):
    total_params_light = "{:,}".format(sum(p.numel() for p in counted_model.parameters()))
    print(f"{count_label} parameters: {total_params_light}")

DeepNN parameters: 3,247,292
LightNN parameters: 151,988
LightNN parameters: 273,432
LightNN parameters: 409,020
LightNN parameters: 324,696
