#### Knowledge Distillation

Knowledge is a training technique where small models are trained based on knowledge transfer from larger and computationally more expensive models without losing validity. This allows deployment on smaller and less powerful hardware leading to faster inference and more efficient evaluation.

In this notebook we cover experiments focused on comparing accuracy of lightweight neural network with a more powerful network using distillation.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu'
print(f"Using {device} device")

Using cuda device


Load dataset (CIFAR-10)

In [3]:
# Preprocessing input Normalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.456, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

# load cifar-10
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loaders = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loaders = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

100%|██████████| 170M/170M [00:04<00:00, 41.3MB/s]


Define the model

In [14]:
# Teacher Model
class TeacherModel(nn.Module):
    def __init__(self, num_classes=10):
        super(TeacherModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(32 * 8 * 8, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, num_classes),
        )
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        # x = x.view(x.size(0), -1)
        # x = self.classifier(x)
        return x


class StudentModel(nn.Module):
    def __init__(self, num_classes=10):
        super(StudentModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 8 * 8, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(256, num_classes),
        )
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

Create training loop

In [5]:
def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            # zero the gradients
            optimizer.zero_grad()
            logits = model(inputs)
            # calculate loss
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            _, predicted = torch.max(logits.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:2f}%")
    return accuracy

torch.manual_seed(42)
# Instantiate, train and evaluate the teacher model
teacher = TeacherModel().to(device)
train(teacher, train_loaders, 10, 0.001, device)
test_accuracy_teacher = test(teacher, test_loaders, device)

# Instantiate the student model
torch.manual_seed(42)
student_1 = StudentModel().to(device)

Epoch 1/10, Loss: 1.3964465220870874
Epoch 2/10, Loss: 0.9504235522521426
Epoch 3/10, Loss: 0.792588456207529
Epoch 4/10, Loss: 0.6896231275842623
Epoch 5/10, Loss: 0.6144788938638804
Epoch 6/10, Loss: 0.5494618303787983
Epoch 7/10, Loss: 0.49458054317842665
Epoch 8/10, Loss: 0.4420514222606064
Epoch 9/10, Loss: 0.40226563325394754
Epoch 10/10, Loss: 0.36865541019746106
Test Accuracy: 76.730000%


In [16]:
# instantiate one more student model to compare performance
torch.manual_seed(42)
student_2 = StudentModel().to(device)

In [8]:
# To ensure we have two different student models
print("Norm of 1st layer of student-1 model:", torch.norm(student_1.features[0].weight))
print("Norm of 1st layer of student-2 model:", torch.norm(student_2.features[0].weight))

Norm of 1st layer of student-1 model: tensor(2.3274, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)
Norm of 1st layer of student-2 model: tensor(2.3274, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)


In [17]:
# Compare parameters of the teacher and student models
total_params_teacher = sum(p.numel() for p in teacher.parameters())
print(f"Total parameters in teacher model: {total_params_teacher:,}")

total_params_student = sum(p.numel() for p in student_1.parameters())
print(f"Total parameters in student model: {total_params_student:,}")

Total parameters in teacher model: 1,186,986
Total parameters in student model: 267,738


In [18]:
# Train the student network
train(student_1, train_loaders, 10, 0.001, device)
test_accuracy_student_1 = test(student_1, test_loaders, device)

Epoch 1/10, Loss: 1.5056005944986173
Epoch 2/10, Loss: 1.2171901727423948
Epoch 3/10, Loss: 1.1019190862355634
Epoch 4/10, Loss: 1.0210292486431043
Epoch 5/10, Loss: 0.966056467474574
Epoch 6/10, Loss: 0.9174672881965442
Epoch 7/10, Loss: 0.8839037964868424
Epoch 8/10, Loss: 0.8416863677218137
Epoch 9/10, Loss: 0.8139965129859003
Epoch 10/10, Loss: 0.7895370634163127
Test Accuracy: 70.150000%


In [19]:
# Compare teacher accuracy with student accuracy
print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy: {test_accuracy_student_1:.2f}%")

Teacher accuracy: 76.73%
Student accuracy: 70.15%


#### Knowledge Distillation

Knowledge distillation incorporates additional loss into the traditional crossentropy loss which based on the softmax output of the teacher network. We assume that the output activate of the teacher network carries additional information that can be leverage by the student network.

In [21]:
def train_knowledge_distillation(teacher_model, student_model, train_loader, epochs, learning_rate, device, alpha=0.5, temperature=3, soft_target_loss_weight=0.5, ce_loss_weight=0.5):
    teacher_model.eval()
    student_model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            with torch.no_grad():
                teacher_logits = teacher_model(inputs)

            student_logits = student_model(inputs)

            # apply softmax to teacher logits
            teacher_softmax = nn.functional.softmax(teacher_logits / temperature, dim=1)
            # apply log on the softmax output from the student logits
            student_prob = nn.functional.log_softmax(student_logits / temperature, dim=1)

            # calculate the soft target loss
            soft_targets_loss = torch.sum(teacher_softmax * (teacher_softmax.log() - student_prob)) / student_prob.size()[0] * (temperature ** 2)

            # true label loss
            true_labels_loss = criterion(student_logits, labels)

            # Weighted sum of two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * true_labels_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Train with Knowledge distillation
train_knowledge_distillation(teacher, student_2, train_loaders, epochs=10, learning_rate=0.001, temperature=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_student_accuracy_ce_and_kd = test(student_2, test_loaders, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_student_1:.2f}%")
print(f"Student accuracy with CE + KD: {test_student_accuracy_ce_and_kd:.2f}%")

Epoch 1/10, Loss: 1.4863279846013355
Epoch 2/10, Loss: 1.3549058321491836
Epoch 3/10, Loss: 1.2576321841353346
Epoch 4/10, Loss: 1.1868097400268935
Epoch 5/10, Loss: 1.1198706733601174
Epoch 6/10, Loss: 1.0753767711427205
Epoch 7/10, Loss: 1.0390705127088005
Epoch 8/10, Loss: 1.0074640433196826
Epoch 9/10, Loss: 0.9701840968235679
Epoch 10/10, Loss: 0.9427998823582974
Test Accuracy: 70.580000%
Teacher accuracy: 76.73%
Student accuracy without teacher: 70.15%
Student accuracy with CE + KD: 70.58%
