In [1]:
# Torch Tutorial for knowledge distillation: https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html

In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [4]:
# Run on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Data preparation

In [6]:
# Preprocessing for CIFAR-10 images: convert each image to a tensor, then normalize
# every RGB channel. NOTE: these mean/std values are the ImageNet statistics used by
# the PyTorch tutorial, not statistics computed on CIFAR-10 itself.
transforms_cifar = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Load the CIFAR-10 train split, then the test split, into ./data (downloaded on first use).
train_dataset, test_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transforms_cifar)
    for is_train in (True, False)
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████████████████████████████████████████████████████████| 170498071/170498071 [00:15<00:00, 11072661.09it/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [8]:
# Mini-batches of 128 images loaded by two worker processes. Training batches are
# shuffled; test batches keep the dataset order so evaluation is repeatable.
loader_opts = dict(batch_size=128, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_opts)

## Define Models

### DeepNN

In [12]:
# Deeper neural network class to be used as teacher
class DeepNN(nn.Module):
    """Teacher CNN: four 3x3 convolutions (3->128->64->64->32 channels) plus an MLP head.

    Every convolution uses padding=1 (spatial size preserved) and there are two 2x2
    max-pools, so the flattened feature size is 32 * (H/4) * (W/4). The head's
    2048 input features therefore assume 32x32 inputs (32 * 8 * 8 = 2048).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        conv_stack = [
            nn.Conv2d(3, 128, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [nn.Linear(2048, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, num_classes)]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        return self.classifier(self.features(x).flatten(start_dim=1))

### LightNN

In [31]:
# Lightweight neural network class to be used as student
class LightNN(nn.Module):
    """Student CNN: two conv->ReLU->max-pool stages (3->16->16 channels) plus an MLP head.

    Each stage keeps the spatial size through the padded 3x3 convolution and halves it
    with the 2x2 pool. The head's 1024 input features therefore assume 32x32 inputs
    (16 * 8 * 8 = 1024).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_block(in_channels, out_channels):
            # 3x3 same-padding convolution -> ReLU -> 2x2 max-pool (halves H and W).
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]

        self.features = nn.Sequential(*conv_block(3, 16), *conv_block(16, 16))
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256), nn.ReLU(), nn.Dropout(0.1), nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        feature_maps = self.features(x)
        flat = torch.flatten(feature_maps, start_dim=1)
        return self.classifier(flat)

### Training and testing functions

In [17]:
def train(model, train_loader, epochs, learning_rate, device):
    """Train ``model`` with cross-entropy loss and the Adam optimizer.

    Args:
        model: Classifier whose output is a ``(batch_size, num_classes)`` logits tensor.
        train_loader: Iterable of ``(inputs, labels)`` batches; ``labels`` holds one
            integer class index per image.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: Device the model and every batch are moved to.

    Returns:
        list[float]: Mean training loss of each epoch (also printed as training runs).

    Raises:
        ValueError: If ``train_loader`` yields no batches during an epoch.
    """
    # Place the model on `device` before the optimizer captures its parameters, so a
    # CPU-resident model is not fed GPU batches.
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average over the batches actually seen, and fail clearly on an empty loader
        # instead of dividing by zero.
        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")
        epoch_loss = running_loss / num_batches
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch+1}/{epochs}, Train_Loss: {epoch_loss}")
    return epoch_losses

In [18]:
def test(model, test_loader, device):
    model.to(device)
    model.eval()
    
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy: .2f}%")
    return accuracy

### Cross-entropy baselines

#### Running DeepNN (teacher)

In [21]:
# Train the teacher (DeepNN) with plain cross-entropy from a reproducible initialization,
# then record its test-set accuracy.
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10)
nn_deep.to(device)  # Module.to moves parameters in place and returns the same module
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)

Epoch 1/10, Train_Loss: 1.3621877626994687
Epoch 2/10, Train_Loss: 0.9039473713511397
Epoch 3/10, Train_Loss: 0.7056784529972564
Epoch 4/10, Train_Loss: 0.5657938920018618
Epoch 5/10, Train_Loss: 0.437897097805272
Epoch 6/10, Train_Loss: 0.3359904433683971
Epoch 7/10, Train_Loss: 0.2495878839012607
Epoch 8/10, Train_Loss: 0.18941869723903554
Epoch 9/10, Train_Loss: 0.15507488672995506
Epoch 10/10, Train_Loss: 0.12575501867610475
Test Accuracy:  75.04%


#### Running LightNN (student)

In [32]:
def _seeded_light_nn(seed=42):
    """Reseed torch, then build a LightNN on `device`; every call yields identical initial weights."""
    torch.manual_seed(seed)
    return LightNN(num_classes=10).to(device)


# Two students with identical initial weights: `nn_light` is trained with cross-entropy only,
# `new_nn_light` with knowledge distillation. Backpropagation is sensitive to weight
# initialization, so matching it means any accuracy gap comes from the training objective.
nn_light = _seeded_light_nn()
new_nn_light = _seeded_light_nn()

In [33]:
# Confirm both students start from identical weights. The first-layer norms give a quick
# visual check; the assertion compares every parameter tensor exactly, which equal norms
# alone cannot establish. Run this before either student is trained.
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())
assert all(
    torch.equal(p_ce, p_kd)
    for p_ce, p_kd in zip(nn_light.parameters(), new_nn_light.parameters())
), "nn_light and new_nn_light were not initialized identically"

Norm of 1st layer of nn_light: 2.327361822128296
Norm of 1st layer of new_nn_light: 2.327361822128296


In [35]:
def _count_params(net):
    """Total number of scalar parameters in `net`, formatted with thousands separators."""
    return f"{sum(p.numel() for p in net.parameters()):,}"


# Model size comparison: teacher vs. student.
total_params_deep = _count_params(nn_deep)
total_params_light = _count_params(nn_light)
print(f"DeepNN parameters: {total_params_deep}")
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 1,186,986
LightNN parameters: 267,738


In [36]:
# Baseline: train the first student on hard labels only (plain cross-entropy), with the
# same schedule as the teacher, then score it on the test set.
train(nn_light, train_loader, 10, 0.001, device)
test_accuracy_light_ce = test(model=nn_light, test_loader=test_loader, device=device)

Epoch 1/10, Train_Loss: 1.4584239163362156
Epoch 2/10, Train_Loss: 1.1506041994180216
Epoch 3/10, Train_Loss: 1.0183946664071144
Epoch 4/10, Train_Loss: 0.9245466835358563
Epoch 5/10, Train_Loss: 0.8482528150538959
Epoch 6/10, Train_Loss: 0.7804726214360094
Epoch 7/10, Train_Loss: 0.7131863570274295
Epoch 8/10, Train_Loss: 0.6577138483829206
Epoch 9/10, Train_Loss: 0.6036398121920388
Epoch 10/10, Train_Loss: 0.5515439357141705
Test Accuracy:  69.73%


In [38]:
# Teacher vs. cross-entropy-only student, to two decimal places.
for role, acc in (("Teacher", test_accuracy_deep), ("Student", test_accuracy_light_ce)):
    print(f"{role} accuracy: {acc:.2f}%")

Teacher accuracy: 75.04%
Student accuracy: 69.73%


### Training with knowledge distillation

In [45]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight,
                                 ce_loss_weight, device):
    """Train ``student`` to mimic a frozen ``teacher`` while also fitting the true labels.

    The per-batch loss is::

        soft_target_loss_weight * T**2 * KL(p_teacher_T || p_student_T)
            + ce_loss_weight * CrossEntropy(student_logits, labels)

    where ``p_*_T`` is the softmax of that model's logits divided by the temperature ``T``
    (Hinton et al., "Distilling the Knowledge in a Neural Network").

    Args:
        teacher: Trained network that provides soft targets. It is put in eval mode and
            its parameters are never optimized.
        student: Network being trained; only its parameters are optimized.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate for the student.
        T: Softmax temperature; larger values give softer distributions.
        soft_target_loss_weight: Weight of the distillation (KL) term.
        ce_loss_weight: Weight of the hard-label cross-entropy term.
        device: Device both models and every batch are moved to.

    Returns:
        list[float]: Mean combined loss of each epoch (also printed as training runs).

    Raises:
        ValueError: If ``train_loader`` yields no batches during an epoch.
    """
    # Place both networks on `device` before the optimizer captures the student's parameters.
    teacher.to(device)
    student.to(device)

    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()   # teacher in evaluation mode (e.g. dropout disabled)
    student.train()  # student in training mode

    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # The teacher's weights are not updated, so skip building its autograd graph.
            with torch.no_grad():
                teacher_logits = teacher(inputs)
            student_logits = student(inputs)

            # Temperature-softened log-probabilities for both models. Computing the teacher's
            # with log_softmax (rather than softmax(...).log()) keeps them finite: an
            # underflowed probability would give log(0) = -inf and make the KL term NaN
            # via 0 * -inf.
            teacher_log_probs = F.log_softmax(teacher_logits / T, dim=-1)
            student_log_probs = F.log_softmax(student_logits / T, dim=-1)

            # KL(teacher || student), summed over classes and averaged over the batch, then
            # scaled by T**2 as suggested by Hinton et al. (soft-target gradients shrink as 1/T**2).
            soft_targets_loss = F.kl_div(
                student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True
            ) * (T ** 2)

            # Standard cross-entropy against the true labels (temperature 1).
            label_loss = ce_loss(student_logits, labels)

            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")
        epoch_loss = running_loss / num_batches
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}")
    return epoch_losses

In [46]:
# Distill the trained teacher into the second student (initialized identically to nn_light).
# Temperature T=2 softens both distributions; the loss is 0.25 * KD term + 0.75 * cross-entropy.
kd_config = dict(epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75)
train_knowledge_distillation(
    teacher=nn_deep,
    student=new_nn_light,
    train_loader=train_loader,
    device=device,
    **kd_config,
)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

# Side-by-side: teacher, cross-entropy-only student, and distilled student.
for caption, accuracy in (
    ("Teacher accuracy", test_accuracy_deep),
    ("Student accuracy without teacher", test_accuracy_light_ce),
    ("Student accuracy with CE + KD", test_accuracy_light_ce_and_kd),
):
    print(f"{caption}: {accuracy:.2f}%")

Epoch 1/10, Loss: 2.3944996811849686
Epoch 2/10, Loss: 1.8428954312868435
Epoch 3/10, Loss: 1.6162752291125715
Epoch 4/10, Loss: 1.4622494927452654
Epoch 5/10, Loss: 1.3339652150793149
Epoch 6/10, Loss: 1.216781522916711
Epoch 7/10, Loss: 1.1258911095616762
Epoch 8/10, Loss: 1.0300040181030703
Epoch 9/10, Loss: 0.9425400297355164
Epoch 10/10, Loss: 0.8669911873005235
Test Accuracy:  71.04%
Teacher accuracy: 75.04%
Student accuracy without teacher: 69.73%
Student accuracy with CE + KD: 71.04%
