In [18]:
%matplotlib inline

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets



In [35]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets

transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.2470, 0.2435, 0.2616]
)
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)


100%|██████████| 170M/170M [00:04<00:00, 41.3MB/s]


In [36]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

In [44]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets


class LightNN(nn.Module):
    def __init__(self, num_classes=10): # Changed num_classes to 100 for CIFAR10
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            # ... (your existing feature layers)
        )
        self.classifier = nn.Sequential(
            nn.Linear(3072, 256), # Changed input size to 4096
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x



In [39]:
import random
import numpy as np
import torch


def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(nn_deep.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:

            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)


            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [40]:
!git clone --recursive https://github.com/chenyaofo/image-classification-codebase
%cd image-classification-codebase

%pip install -qr requirements.txt

import torch
from IPython.display import clear_output

clear_output()
print('Setup complete. Using torch %s %s' % (torch.__version__, torch.cuda.get_device_properties(0) if torch.cuda.is_available() else 'CPU'))

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")
torch.manual_seed(42)
nn_deep =torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)
test_accuracy_deep = test(nn_deep, test_loader, device)

torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)

Setup complete. Using torch 2.6.0+cu124 _CudaDeviceProperties(name='Tesla T4', major=7, minor=5, total_memory=15095MB, multi_processor_count=40, uuid=d303970b-8fd4-18b4-1938-c9e4c60eef17, L2_cache_size=4MB)
Using cuda device


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


Test Accuracy: 92.12%


In [41]:
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)

In [42]:
 total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 272,474
LightNN parameters: 789,258


In [45]:
train(nn_light, train_loader, epochs=30, learning_rate=0.01, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)

Epoch 1/30, Loss: 2.3077011291328295
Epoch 2/30, Loss: 2.307016994032409
Epoch 3/30, Loss: 2.307546752188212
Epoch 4/30, Loss: 2.307515230934943
Epoch 5/30, Loss: 2.307847381552772
Epoch 6/30, Loss: 2.307508619849944
Epoch 7/30, Loss: 2.30776570154273
Epoch 8/30, Loss: 2.3071448272451414
Epoch 9/30, Loss: 2.3078355685524317
Epoch 10/30, Loss: 2.307326365614791
Epoch 11/30, Loss: 2.3071981488591264
Epoch 12/30, Loss: 2.307311330000153
Epoch 13/30, Loss: 2.3071203664745514
Epoch 14/30, Loss: 2.3082754611968994
Epoch 15/30, Loss: 2.3077036700285305
Epoch 16/30, Loss: 2.306735718646623
Epoch 17/30, Loss: 2.307705093832577
Epoch 18/30, Loss: 2.3071521749276944
Epoch 19/30, Loss: 2.3072343183600386
Epoch 20/30, Loss: 2.3073886611577494
Epoch 21/30, Loss: 2.307452281112866
Epoch 22/30, Loss: 2.3077442481389743
Epoch 23/30, Loss: 2.307466950867792
Epoch 24/30, Loss: 2.3072600950060598
Epoch 25/30, Loss: 2.3076211741513304
Epoch 26/30, Loss: 2.3073247729055106
Epoch 27/30, Loss: 2.3075498748008

In [46]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")

Teacher accuracy: 92.12%
Student accuracy: 11.07%


In [47]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()
    student.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()


            with torch.no_grad():
                teacher_logits = teacher(inputs)


            student_logits = student(inputs)
            teacher_logits = teacher_logits[:, :student_logits.shape[1]]
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            soft_targets_loss = nn.functional.kl_div(soft_prob, soft_targets, reduction='batchmean') * (T**2)


            #soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[10] * (T**2)


            label_loss = ce_loss(student_logits, labels)


            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=25, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

Epoch 1/25, Loss: 2.814057553210832
Epoch 2/25, Loss: 2.501636771594777
Epoch 3/25, Loss: 2.3669612535735225
Epoch 4/25, Loss: 2.283060855877674
Epoch 5/25, Loss: 2.2059308778294517
Epoch 6/25, Loss: 2.136520601599418
Epoch 7/25, Loss: 2.0705720521604922
Epoch 8/25, Loss: 2.0038733140891773
Epoch 9/25, Loss: 1.960922350054202
Epoch 10/25, Loss: 1.9123723650222544
Epoch 11/25, Loss: 1.8649200201034546
Epoch 12/25, Loss: 1.8276193205962705
Epoch 13/25, Loss: 1.7769482022965961
Epoch 14/25, Loss: 1.7348025564647391
Epoch 15/25, Loss: 1.709314928335302
Epoch 16/25, Loss: 1.6597290231138848
Epoch 17/25, Loss: 1.6365282069081846
Epoch 18/25, Loss: 1.5961654750282501
Epoch 19/25, Loss: 1.5611174740754734
Epoch 20/25, Loss: 1.5328156804794546
Epoch 21/25, Loss: 1.512585202141491
Epoch 22/25, Loss: 1.4759798510300228
Epoch 23/25, Loss: 1.4561062615240932
Epoch 24/25, Loss: 1.4343591644940779
Epoch 25/25, Loss: 1.402238358926895
Test Accuracy: 53.87%
Teacher accuracy: 92.12%
Student accuracy wit

In [48]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [49]:

print(list(new_nn_light.named_parameters()))

[('classifier.0.weight', Parameter containing:
tensor([[-0.0213,  0.0464,  0.0252,  ..., -0.0625, -0.0520, -0.0807],
        [-0.0159,  0.0659,  0.0684,  ..., -0.0203, -0.0667, -0.0614],
        [ 0.0298,  0.0016, -0.0663,  ...,  0.1611,  0.1309,  0.0628],
        ...,
        [ 0.0647, -0.0283, -0.0034,  ..., -0.0483, -0.1729, -0.1530],
        [-0.1118, -0.0651, -0.0721,  ...,  0.1015,  0.0606,  0.1160],
        [ 0.0406,  0.0104, -0.0519,  ..., -0.0704, -0.0019,  0.0761]],
       device='cuda:0', requires_grad=True)), ('classifier.0.bias', Parameter containing:
tensor([-7.4757e-01, -1.1043e+00, -1.4975e+00, -6.0935e-01, -9.3079e-01,
         1.0980e+00, -1.6646e+00, -1.4790e+00, -1.0734e+00, -1.1514e-01,
        -2.7319e-01, -8.1674e-01, -1.6549e+00, -6.7103e-01, -1.0799e+00,
        -6.9657e-01, -1.5094e+00, -4.8473e-01, -3.5016e-01, -3.2103e-01,
        -1.5437e+00, -7.3409e-01, -1.0916e+00, -8.3180e-01, -1.2105e+00,
        -1.6960e+00, -9.2564e-01, -2.3219e+00, -4.6447e-01, -1.7

In [50]:
import torch.nn.utils.prune as prune
def prune_by_magnitude(model, amount):

    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            # Get the weight tensor
            weight = module.weight
            # Find the indices of the top k largest weights
            num_prune = int(weight.size(0) * amount)
            _, indices = torch.topk(torch.abs(weight.view(-1)), num_prune)  # Flatten weights
            # Create a pruning mask
            mask = torch.ones_like(weight.view(-1))
            mask[indices] = 0
            mask = mask.view(weight.shape)
            # Apply the mask
            prune.custom_from_mask(module, name='weight', mask=mask)

prune_by_magnitude(new_nn_light, amount=0.5)  # Adjust the amount as needed

test_accuracy_pruned_light_with_kd = test(new_nn_light, test_loader, device)
print(f"pruned student accuracy with KD: {test_accuracy_pruned_light_with_kd:.2f}%")
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

Test Accuracy: 53.73%
pruned student accuracy with KD: 53.73%
Teacher accuracy: 92.12%
Student accuracy without teacher: 11.07%
Student accuracy with CE + KD: 53.87%


*italicized text*

---

Increase of 380% in accuracy