In [1]:
# This notebook is based on PyTorch's Knowledge Distillation Tutorial, found here:
# https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html

%matplotlib inline

In [3]:
!pip install wandb -qU

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.1/20.1 MB[0m [31m71.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
# Log in to W&B account
import wandb

In [5]:
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, TensorDataset

# Check if GPU is available, and if not, use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Load Data

In [12]:
import pandas as pd
import time

In [7]:
X = pd.read_csv("initial_train_dataframe.csv")
y = pd.read_csv("initial_label_dataframe.csv")

In [8]:
numeric_features = [ "arrival_0", "arrival_15", "arrival_30", "arrival_45", "depart_0", "depart_15", "depart_30", "depart_45", "quarter_0", "quarter_15", "quarter_30", "quarter_45", "month_0", "day_of_week"]
labels= ["arrival_0_1H", "arrival_15_1H", "arrival_30_1H", "arrival_45_1H", "arrival_0_2H",  "arrival_15_2H", "arrival_30_2H", "arrival_45_2H", "arrival_0_3H", "arrival_15_3H", "arrival_30_3H", "arrival_45_3H"]
indices = list(range(0, X.shape[0]))
trainIndex = int(0.8*len(indices))
trainTensor = torch.tensor((X[numeric_features]).values[indices[0:trainIndex]], device="cpu").float()
trainLabels = torch.tensor((y[labels]).values[indices[0:trainIndex]], device="cpu").float()
valTensor = torch.tensor((X[numeric_features]).values[indices[trainIndex:-14]], device="cpu").float()
valLabels = torch.tensor((y[labels]).values[indices[trainIndex:-14]], device="cpu").float()

train_dataset = TensorDataset(trainTensor, trainLabels)
test_dataset = TensorDataset(valTensor, valLabels)

train_loader = DataLoader(train_dataset, batch_size=56, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=56, shuffle=True, num_workers=2)

Defining model classes and utility functions
============================================

Next, we need to define our model classes. Several user-defined
parameters need to be set here. We use two different architectures,
keeping the number of filters fixed across our experiments to ensure
fair comparisons. Both architectures are Convolutional Neural Networks
(CNNs) with a different number of convolutional layers that serve as
feature extractors, followed by a classifier with 12 classes. The number
of filters and neurons is smaller for the students.


In [9]:
# Deeper neural network class to be used as teacher:
class DeepNN(nn.Module):
    def __init__(self, num_classes=12):
        super(DeepNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(56, 14, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(14, 64, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=1, stride=2),
            nn.Conv1d(64, 128, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(128, 64, kernel_size=1, padding=0),
            nn.ReLU(),

            nn.Conv1d(64, 56, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=1, stride=2),

        )
        self.classifier = nn.Sequential(
            nn.Linear(4, 56),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(56, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(56, 14, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=1, stride=2),
            nn.Conv1d(14, 56, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=1, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(4, 56),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(56, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [13]:
loss_fn = lambda x,y: torch.exp(-torch.sqrt(((x-y)**2).mean())/10)

def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:

            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = -loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # log loss using W and B
        wandb.log({"loss": running_loss / len(train_loader)})

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):

    start = time.time() # record inference time

    model.to(device)
    model.eval()

    total = 0
    squared_error = 0.0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)

            total += 1
            squared_error += loss_fn(outputs, labels)

    accuracy = squared_error / total
    print(f"Test acc: {accuracy:.4f}")

    end = time.time()
    print(f"Elapsed time: {end-start:.2f}")
    wandb.log({"time": end-start}) # log inference time
    return accuracy

Initial training runs
==================

For reproducibility, we set the torch manual seed. We train
networks using different methods, so to compare them fairly, it makes
sense to initialize the networks with the same weights. Start by
training the teacher network:


In [27]:
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=12).to(device)
wandb.init(
    # Set the project
    project="HPML_KD",
    name=f"deep_nn_train",
    # Track hyperparameters and run metadata
    config={
    "learning_rate": 0.001,
    "architecture": "CNN",
    "dataset": "NASA Airport Throughput Challenge - FUSER Dataset",
    "epochs": 15,
    })
train(nn_deep, train_loader, epochs=15, learning_rate=0.001, device=device)

# Mark the run as finished
wandb.finish()

Epoch 1/15, Loss: -0.4648365443126081
Epoch 2/15, Loss: -0.5102230784611199
Epoch 3/15, Loss: -0.5132005197552446
Epoch 4/15, Loss: -0.5132960943749156
Epoch 5/15, Loss: -0.513472133646377
Epoch 6/15, Loss: -0.5132478924985892
Epoch 7/15, Loss: -0.5131508321426927
Epoch 8/15, Loss: -0.5136609368811781
Epoch 9/15, Loss: -0.513717073030746
Epoch 10/15, Loss: -0.514194377980674
Epoch 11/15, Loss: -0.5146414720402739
Epoch 12/15, Loss: -0.5143257230044173
Epoch 13/15, Loss: -0.514994304115399
Epoch 14/15, Loss: -0.5152962920003044
Epoch 15/15, Loss: -0.5155271582138805


0,1
loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,-0.51553


In [28]:
wandb.init(
    project="HPML_KD",
    name=f"deep_nn_test",
    config={
    "learning_rate": 0.001,
    "architecture": "CNN",
    "dataset": "NASA Airport Throughput Challenge - FUSER Dataset",
    # "epochs": 15,
    })
test_accuracy_deep = test(nn_deep, test_loader, device)

wandb.finish()

Test acc: 0.6357
Elapsed time: 0.34


In [29]:
# Instantiate the lightweight network:
torch.manual_seed(42)
nn_light = LightNN(num_classes=12).to(device)

We instantiate one more lightweight network model to compare their
performances. Back propagation is sensitive to weight initialization, so
we need to make sure these two networks have the exact same
initialization.


In [30]:
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=12).to(device)

To ensure we have created a copy of the first network, we inspect the
norm of its first layer. If it matches, then we are safe to conclude
that the networks are indeed the same.


In [31]:
# Print the norm of the first layer of the initial lightweight model
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
# Print the norm of the first layer of the new lightweight model
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())

Norm of 1st layer of nn_light: 2.190237045288086
Norm of 1st layer of new_nn_light: 2.190237045288086


Print the total number of parameters in each model:


In [32]:
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 22,938
LightNN parameters: 2,602


Train and test the lightweight network with cross entropy loss:


In [33]:
wandb.init(
    project="HPML_KD",
    name=f"light_nn_train",
    config={
    "learning_rate": 0.001,
    "architecture": "CNN",
    "dataset": "NASA Airport Throughput Challenge - FUSER Dataset",
    "epochs": 15,
    })
train(nn_light, train_loader, epochs=15, learning_rate=0.001, device=device)

wandb.finish()

Epoch 1/15, Loss: -0.4673674910213239
Epoch 2/15, Loss: -0.5052133432973307
Epoch 3/15, Loss: -0.5111270848744974
Epoch 4/15, Loss: -0.5123393821259277
Epoch 5/15, Loss: -0.5134803881279577
Epoch 6/15, Loss: -0.5141326758427361
Epoch 7/15, Loss: -0.5152957408001628
Epoch 8/15, Loss: -0.5158408610775067
Epoch 9/15, Loss: -0.5168480135190981
Epoch 10/15, Loss: -0.5182226090766371
Epoch 11/15, Loss: -0.5189879449030843
Epoch 12/15, Loss: -0.5189019357815338
Epoch 13/15, Loss: -0.5200413916819393
Epoch 14/15, Loss: -0.5204711543104519
Epoch 15/15, Loss: -0.5201884551931875


0,1
loss,█▃▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
loss,-0.52019


As we can see, based on test accuracy, we can now compare the deeper
network that is to be used as a teacher with the lightweight network
that is our supposed student. So far, our student has not intervened
with the teacher, therefore this performance is achieved by the student
itself. The metrics so far can be seen with the following lines:


In [34]:
wandb.init(
    project="HPML_KD",
    name=f"light_nn_test",
    config={
    "learning_rate": 0.001,
    "architecture": "CNN",
    "dataset": "NASA Airport Throughput Challenge - FUSER Dataset",
    # "epochs": 15,
    })
test_accuracy_light = test(nn_light, test_loader, device)

wandb.finish()

Test acc: 0.6372
Elapsed time: 0.44


In [35]:
print(f"Teacher accuracy: {test_accuracy_deep:.4f}%")
print(f"Student accuracy: {test_accuracy_light:.4f}%")

Teacher accuracy: 0.6357%
Student accuracy: 0.6372%


Knowledge distillation run: response-based knowledge
====================================================

In [36]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, target_loss_weight, loss_weight, device):

    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_out = teacher(inputs)

            # Forward pass with the student model
            student_out = student(inputs)

            # student-teacher loss
            student_teacher_loss = loss_fn(student_out, teacher_out)

            # Calculate the true label loss
            label_loss = loss_fn(student_out, labels)

            # Weighted sum of the two losses
            loss = -1 * (target_loss_weight * student_teacher_loss + loss_weight * label_loss)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        wandb.log({"loss": running_loss / len(train_loader)})

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Apply ``train_knowledge_distillation``
wandb.init(
    project="HPML_KD",
    name=f"light_nn_train_KD",
    config={
    "learning_rate": 0.001,
    "architecture": "CNN",
    "dataset": "NASA Airport Throughput Challenge - FUSER Dataset",
    "epochs": 15,
    })

train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=15, learning_rate=0.001, target_loss_weight=0.25, loss_weight=0.75, device=device)
print(f"Teacher accuracy: {test_accuracy_deep:.4f}%")
wandb.finish()

wandb.init(
    project="HPML_KD",
    name=f"light_nn_test_2",
    config={
    "learning_rate": 0.001,
    "architecture": "CNN",
    "dataset": "NASA Airport Throughput Challenge - FUSER Dataset",
    # "epochs": 15,
    })
test_accuracy_light = test(nn_light, test_loader, device)
print(f"Student accuracy without teacher: {test_accuracy_light:.4f}%")
wandb.finish()

wandb.init(
    project="HPML_KD",
    name=f"light_nn_test_KD",
    config={
    "learning_rate": 0.001,
    "architecture": "CNN",
    "dataset": "NASA Airport Throughput Challenge - FUSER Dataset",
    # "epochs": 15,
    })
test_accuracy_light_kd = test(new_nn_light, test_loader, device)
print(f"Student accuracy with KD: {test_accuracy_light_kd:.4f}%")
wandb.finish()

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.4f}%")
print(f"Student accuracy without teacher: {test_accuracy_light:.4f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_kd:.4f}%")

Epoch 1/15, Loss: -0.5356747969842186
Epoch 2/15, Loss: -0.6031586674455637
Epoch 3/15, Loss: -0.6107667301790402
Epoch 4/15, Loss: -0.6130794317196733
Epoch 5/15, Loss: -0.6133716441571903
Epoch 6/15, Loss: -0.6141026618000799
Epoch 7/15, Loss: -0.6156381351498369
Epoch 8/15, Loss: -0.6158110628874538
Epoch 9/15, Loss: -0.6163829123250212
Epoch 10/15, Loss: -0.6178820851130988
Epoch 11/15, Loss: -0.6193491559439954
Epoch 12/15, Loss: -0.6202494393522366
Epoch 13/15, Loss: -0.6215826855680813
Epoch 14/15, Loss: -0.6224061338284526
Epoch 15/15, Loss: -0.6224503560949819
Teacher accuracy: 0.6357%


0,1
loss,█▃▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
loss,-0.62245


Test acc: 0.6378
Elapsed time: 0.32
Student accuracy without teacher: 0.6378%


Test acc: 0.6399
Elapsed time: 0.32
Student accuracy with KD: 0.6399%


Teacher accuracy: 0.6357%
Student accuracy without teacher: 0.6378%
Student accuracy with CE + KD: 0.6399%


Cosine loss minimization run: KD with feature-based knowledge
============================

Now we add a new loss function to convey the information in the teacher's hidden layers. Our goal is to minimize this naive cosine loss function, thereby
making the student's flattened hidden layer more similar to that of the teacher.

We also include an average pooling layer to ensure that the vectors are the same
size so the cosine loss function is well defined.

The `CosineEmbeddingLoss`
is given by the following formula:

![Formula for
CosineEmbeddingLoss](https://pytorch.org/tutorials//../_static/img/knowledge_distillation/cosine_embedding_loss.png){.align-center
width="450px"}


In [37]:
class ModifiedDeepNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedDeepNNCosine, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(56, 14, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(14, 64, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=1, stride=2),
            nn.Conv1d(64, 128, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(128, 64, kernel_size=1, padding=0),
            nn.ReLU(),

            nn.Conv1d(64, 56, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=1, stride=2),

        )
        self.classifier = nn.Sequential(
            nn.Linear(4, 56),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(56, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 1)
        return x, flattened_conv_output_after_pooling

# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
class ModifiedLightNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNCosine, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(56, 14, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=1, stride=2),
            nn.Conv1d(14, 56, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=1, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(4, 56),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(56, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        return x, flattened_conv_output

# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep = ModifiedDeepNNCosine(num_classes=12).to(device)
modified_nn_deep.load_state_dict(nn_deep.state_dict())

# Once again ensure the norm of the first layer is the same for both networks
print("Norm of 1st layer for deep_nn:", torch.norm(nn_deep.features[0].weight).item())
print("Norm of 1st layer for modified_deep_nn:", torch.norm(modified_nn_deep.features[0].weight).item())

# Initialize a modified lightweight network with the same seed as our other lightweight instances. This will be trained from scratch to examine the effectiveness of cosine loss minimization.
torch.manual_seed(42)
modified_nn_light = ModifiedLightNNCosine(num_classes=12).to(device)
print("Norm of 1st layer:", torch.norm(modified_nn_light.features[0].weight).item())

Norm of 1st layer for deep_nn: 2.5286672115325928
Norm of 1st layer for modified_deep_nn: 2.5286672115325928
Norm of 1st layer: 2.190237045288086


Next, we modify the training loop to account for the tuple that the forward pass returns which includes the output and the hidden representation.


In [119]:
# for inputs, labels in train_loader:
#   print(inputs.shape, labels.shape)
#   break

torch.Size([56, 14]) torch.Size([56, 12])


In [38]:
# Create a sample input tensor
sample_input = torch.randn(56, 14).to(device)

# Pass the input through the student
light_outputs, hidden_representation = modified_nn_light(sample_input)

# Print the shapes of the tensors
print("Student logits shape:", light_outputs.shape) # batch_size x total_classes
print("Student hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size

# Pass the input through the teacher
light_outputs, hidden_representation = modified_nn_deep(sample_input)

# Print the shapes of the tensors
print("Teacher logits shape:", light_outputs.shape) # batch_size x total_classes
print("Teacher hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size

Student logits shape: torch.Size([56, 12])
Student hidden representation shape: torch.Size([56, 4])
Teacher logits shape: torch.Size([56, 12])
Teacher hidden representation shape: torch.Size([56, 4])


In [39]:
def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, loss_weight, device):

    cosine_loss = nn.CosineEmbeddingLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model and keep only the hidden representation
            with torch.no_grad():
                _, teacher_hidden_representation = teacher(inputs)

            # Forward pass with the student model
            student_outs, student_hidden_representation = student(inputs)

            # Calculate the cosine loss. Target is a vector of ones. From the loss formula above we can see that is the case where loss minimization leads to cosine similarity increase.
            hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))

            # Calculate the true label loss
            label_loss = -loss_fn(student_outs, labels)

            # Weighted sum of the two losses
            loss = hidden_rep_loss_weight * hidden_rep_loss + loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        wandb.log({"loss": running_loss / len(train_loader)})

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

We need to modify our test function for the same reason. Here we ignore
the hidden representation returned by the model.


In [40]:
def test_multiple_outputs(model, test_loader, device):

    start = time.time() # record inference time

    model.to(device)
    model.eval()

    total = 0
    squared_error = 0.0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs, _ = model(inputs)

            total += 1
            squared_error += loss_fn(outputs, labels)

    accuracy = squared_error / total
    print(f"Test acc: {accuracy:.4f}")

    end = time.time()

    print(f"Elapsed time: {end-start:.4f}")
    wandb.log({"time": end-start}) # log inference time
    return accuracy

In this case, we could easily include both knowledge distillation and
cosine loss minimization in the same function. It is common to combine
methods to achieve better performance in teacher-student paradigms. For
now, we can run a simple train-test session.


In [41]:
# Train and test the lightweight network
wandb.init(
    project="HPML_KD",
    name=f"cosine_loss_train",
    config={
    "learning_rate": 0.001,
    "architecture": "CNN",
    "dataset": "NASA Airport Throughput Challenge - FUSER Dataset",
    "epochs": 15,
    })
train_cosine_loss(teacher=modified_nn_deep, student=modified_nn_light, train_loader=train_loader, epochs=15, learning_rate=0.001, hidden_rep_loss_weight=0.25, loss_weight=0.75, device=device)

wandb.finish()

wandb.init(
    project="HPML_KD",
    name=f"cosine_loss_test",
    config={
    "learning_rate": 0.001,
    "architecture": "CNN",
    "dataset": "NASA Airport Throughput Challenge - FUSER Dataset",
    # "epochs": 15,
    })
test_accuracy_light_cosine_loss = test_multiple_outputs(modified_nn_light, test_loader, device)
wandb.finish()

Epoch 1/15, Loss: -0.20772462216809914
Epoch 2/15, Loss: -0.2606479236112235
Epoch 3/15, Loss: -0.2638599706914859
Epoch 4/15, Loss: -0.2656288054137946
Epoch 5/15, Loss: -0.2667099564505842
Epoch 6/15, Loss: -0.26694401070332757
Epoch 7/15, Loss: -0.26906462556447464
Epoch 8/15, Loss: -0.2700655847883072
Epoch 9/15, Loss: -0.2712587409030896
Epoch 10/15, Loss: -0.2733841678395439
Epoch 11/15, Loss: -0.27719020857788124
Epoch 12/15, Loss: -0.2829544541363518
Epoch 13/15, Loss: -0.28555448205707173
Epoch 14/15, Loss: -0.28762641806191147
Epoch 15/15, Loss: -0.2879176004626119


0,1
loss,█▃▃▃▃▃▃▃▂▂▂▁▁▁▁

0,1
loss,-0.28792


Test acc: 0.6408
Elapsed time: 0.2957


0,1
time,▁

0,1
time,0.29571


Intermediate regressor run: KD with relation-based knowledge
==========================

Our naive cosine loss minimization does not guarantee better results, so we will try relation-based knowledge, implemented through an intermediate regressor which minimizes MSE between the feature maps of the student and teacher.


In [42]:
# Pass the sample input only from the convolutional feature extractor
convolutional_fe_output_student = nn_light.features(sample_input)
convolutional_fe_output_teacher = nn_deep.features(sample_input)

# Print their shapes
print("Student's feature extractor output shape: ", convolutional_fe_output_student.shape)
print("Teacher's feature extractor output shape: ", convolutional_fe_output_teacher.shape)

Student's feature extractor output shape:  torch.Size([56, 4])
Teacher's feature extractor output shape:  torch.Size([56, 4])


In [43]:
class ModifiedDeepNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedDeepNNRegressor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(56, 14, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(14, 64, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=1, stride=2),
            nn.Conv1d(64, 128, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(128, 64, kernel_size=1, padding=0),
            nn.ReLU(),

            nn.Conv1d(64, 56, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=1, stride=2),

        )
        self.classifier = nn.Sequential(
            nn.Linear(4, 56),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(56, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        conv_feature_map = x
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, conv_feature_map

class ModifiedLightNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNRegressor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(56, 14, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=1, stride=2),
            nn.Conv1d(14, 56, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=1, stride=2),
        )
        # Include an extra regressor (in our case linear)
        self.regressor = nn.Sequential(
            nn.Conv1d(56, 56, kernel_size=1, padding=0)
        )
        self.classifier = nn.Sequential(
            nn.Linear(4, 56),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(56, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        regressor_output = self.regressor(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, regressor_output

After that, we have to update our train loop again. This time, we
extract the regressor output of the student, the feature map of the
teacher, we calculate the `MSE` on these tensors (they have the exact
same shape so it\'s properly defined) and we back propagate gradients
based on that loss, in addition to our loss function.


In [44]:
def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, loss_weight, device):

    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Again ignore teacher outputs
            with torch.no_grad():
                _, teacher_feature_map = teacher(inputs)

            # Forward pass with the student model
            student_logits, regressor_feature_map = student(inputs)

            # Calculate the loss
            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)

            # Calculate the true label loss
            label_loss = -loss_fn(student_logits, labels)

            # Weighted sum of the two losses
            loss = feature_map_weight * hidden_rep_loss + loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        wandb.log({"loss": running_loss / len(train_loader)})
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")


# Initialize a ModifiedLightNNRegressor
torch.manual_seed(42)
modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=12).to(device)

# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep_reg = ModifiedDeepNNRegressor(num_classes=12).to(device)
modified_nn_deep_reg.load_state_dict(nn_deep.state_dict())

# Train and test once again
wandb.init(
    project="HPML_KD",
    name=f"relation_KD_train",
    config={
    "learning_rate": 0.001,
    "architecture": "CNN",
    "dataset": "NASA Airport Throughput Challenge - FUSER Dataset",
    "epochs": 15,
    })
train_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=15, learning_rate=0.001, feature_map_weight=0.25, loss_weight=0.75, device=device)
wandb.finish()

wandb.init(
    project="HPML_KD",
    name=f"relation_KD_test",
    config={
    "learning_rate": 0.001,
    "architecture": "CNN",
    "dataset": "NASA Airport Throughput Challenge - FUSER Dataset",
    # "epochs": 15,
    })
test_accuracy_light_mse_loss = test_multiple_outputs(modified_nn_light_reg, test_loader, device)
wandb.finish()

Epoch 1/15, Loss: -0.11737132677064537
Epoch 2/15, Loss: -0.2624121413539393
Epoch 3/15, Loss: -0.2819834760488413
Epoch 4/15, Loss: -0.29507492882565567
Epoch 5/15, Loss: -0.31700853812999236
Epoch 6/15, Loss: -0.3277685209965934
Epoch 7/15, Loss: -0.33239490260331395
Epoch 8/15, Loss: -0.3370368047453725
Epoch 9/15, Loss: -0.3376909791947173
Epoch 10/15, Loss: -0.34245735811539735
Epoch 11/15, Loss: -0.34240963331426677
Epoch 12/15, Loss: -0.3402695272105951
Epoch 13/15, Loss: -0.3390910074162407
Epoch 14/15, Loss: -0.3363130316852381
Epoch 15/15, Loss: -0.3389919742989464


0,1
loss,█▃▃▂▂▁▁▁▁▁▁▁▁▁▁

0,1
loss,-0.33899


Test acc: 0.6320
Elapsed time: 0.7479


0,1
time,▁

0,1
time,0.74786


In [46]:
print(f"Teacher accuracy: {test_accuracy_deep:.4f}%")
print(f"Student accuracy without teacher: {test_accuracy_light:.4f}%")
print(f"Student accuracy with KD: {test_accuracy_light_kd:.4f}%")
print(f"Student accuracy with CosineLoss: {test_accuracy_light_cosine_loss:.4f}%")
print(f"Student accuracy with RegressorMSE: {test_accuracy_light_mse_loss:.4f}%")

Teacher accuracy: 0.6357%
Student accuracy without teacher: 0.6378%
Student accuracy with KD: 0.6399%
Student accuracy with CosineLoss: 0.6408%
Student accuracy with RegressorMSE: 0.6320%
