<a href="https://colab.research.google.com/github/sfarrukhm/making_models_small/blob/main/knowledge_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import os

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms


In [4]:
# Preprocess CIFAR-10: ToTensor scales pixels to [0, 1], then Normalize with
# mean=std=0.5 maps each RGB channel to [-1, 1] via (x - 0.5) / 0.5.
BATCH_SIZE = 64
# Set to an int (e.g. 10_000 / 2_000) for a quick experiment on a subset of the
# split; None uses the full split.
TRAIN_SUBSET_SIZE = None
TEST_SUBSET_SIZE = None

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Loading the CIFAR-10 dataset (downloaded into ./data on first use):
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
if TRAIN_SUBSET_SIZE is not None:
    train_dataset = torch.utils.data.Subset(train_dataset, range(TRAIN_SUBSET_SIZE))
if TEST_SUBSET_SIZE is not None:
    test_dataset = torch.utils.data.Subset(test_dataset, range(TEST_SUBSET_SIZE))

# Dataloaders: only the training split is shuffled so evaluation order is fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:08<00:00, 20.4MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
# construct the teacher model
class TeacherModel(nn.Module):
    """Larger CNN used as the distillation teacher.

    Two convolutional stages, each made of two 3x3 same-padding convolutions
    with ReLU followed by a 2x2 max-pool, then an MLP classification head.
    The head's 2048 input features equal 32 channels * 8 * 8, so the model
    assumes 32x32 inputs (the CIFAR-10 image size).

    Submodule names and Sequential indices (e.g. ``features.0.weight``) are
    part of the saved state_dict format and must not change.
    """

    def __init__(self, num_classes=10):
        super(TeacherModel, self).__init__()
        # Layers are created in the original order so parameter initialisation
        # consumes the RNG identically and state_dict keys stay the same.
        feature_layers = [
            nn.Conv2d(3, 128, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Linear(2048, 512), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        # Flatten everything except the batch dimension before the MLP head.
        flat = self.features(x).flatten(start_dim=1)
        return self.classifier(flat)


## Student Model (way lighter than the teacher model)
class StudentModel(nn.Module):
    """Lightweight CNN trained as the distillation student.

    Two stages of (3x3 same-padding conv with 10 channels -> ReLU -> 2x2
    max-pool), then a small MLP head. The head's 640 input features equal
    10 channels * 8 * 8, so the model assumes 32x32 inputs (CIFAR-10 size).

    Submodule names and Sequential indices are part of the saved state_dict
    format and must not change.
    """

    def __init__(self, num_classes=10):
        super(StudentModel, self).__init__()
        # Construction order is kept identical to preserve RNG-driven init.
        feature_layers = [
            nn.Conv2d(3, 10, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(10, 10, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Linear(640, 256), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        # Flatten everything except the batch dimension before the MLP head.
        flat = self.features(x).flatten(start_dim=1)
        return self.classifier(flat)


In [6]:
# Shared setup for training the teacher and student with plain cross-entropy
# (before distillation).
from collections import defaultdict

# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Per-run metric log (a dict of lists); nothing shown in this notebook writes to it.
log_dict = defaultdict(list)

def train(model, train_loader, num_epochs, learning_rate, device, save_model_path=None):
    """Train ``model`` with plain cross-entropy loss and Adam.

    Args:
        model: network to train; it is moved to ``device`` and put in train mode.
        train_loader: iterable of ``(images, labels)`` batches; ``len()`` of it
            is used to average the loss over each epoch.
        num_epochs: number of full passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: device string/object passed to ``.to()``.
        save_model_path: if not None, the final ``state_dict`` is saved there.

    Returns:
        list[float]: mean training loss of each epoch (also printed per epoch).
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model.to(device)
    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            optimizer.step()
            # .item() detaches the scalar so the graph is not kept alive.
            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

    if save_model_path is not None:
        torch.save(model.state_dict(), save_model_path)
    return epoch_losses

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct_predictions=0
    total=0
    with torch.no_grad():
        for images, labels in test_loader:
            images,labels=images.to(device), labels.to(device)
            outputs=model(images)

            _, predicted = torch.max(outputs, 1)

            total+=labels.size(0)
            correct_predictions+=(labels==predicted).sum()

    accuracy=100*correct_predictions/total
    print(f"Total correct predictions: {correct_predictions}")
    print(f"Total labels: {total}")
    print(f"Test Accuracy: {accuracy}")

    return correct_predictions, total, accuracy


In [None]:
## Cross-entropy run
# Teacher training. Reuse the saved checkpoint when it exists; otherwise train
# the teacher for 10 epochs and save it. Either way `teacher_model` is defined,
# so the evaluation cell below also works on a fresh kernel.
# The directory matches the path the distillation section loads the teacher
# from (this cell previously pointed at .../trained_models/knowledge_distillation).
torch.manual_seed(2342)
save_path = "/content/drive/MyDrive/deep_generative_models/knowledge_distillation"
teacher_ckpt_path = save_path + "/cifar_teacher_v1.pt"
teacher_model = TeacherModel(num_classes=10).to(device)
if os.path.exists(teacher_ckpt_path):
    teacher_model.load_state_dict(torch.load(teacher_ckpt_path, map_location=device))
else:
    os.makedirs(save_path, exist_ok=True)
    # train() writes the final state_dict to save_model_path itself.
    train(teacher_model, train_loader, 10, 0.001, device=device,
          save_model_path=teacher_ckpt_path)

In [None]:
# Evaluate the cross-entropy-trained teacher on the CIFAR-10 test split.
test_teacher = test(model=teacher_model, test_loader=test_loader, device=device)

Total correct predictions: 7156
Total labels: 10000
Test Accuracy: 71.55999755859375


In [40]:
# Student training with cross-entropy only (no teacher): the baseline that the
# distilled student is compared against. The seed is set before the model is
# built so the initial weights are reproducible.
torch.manual_seed(2342)
save_path = "/content/drive/MyDrive/deep_generative_models/knowledge_distillation"
student_model = StudentModel(num_classes=10).to(device)
train(
    student_model,
    train_loader,
    num_epochs=10,
    learning_rate=0.001,
    device=device,
    save_model_path=f"{save_path}/cifar_student_wo_teacher_v1.pt",
)

Epoch 1/10, Loss: 1.5366986973206405
Epoch 2/10, Loss: 1.2394564447500516
Epoch 3/10, Loss: 1.1167083738557517
Epoch 4/10, Loss: 1.0352113188989938
Epoch 5/10, Loss: 0.9662601889094429
Epoch 6/10, Loss: 0.9080696434849669
Epoch 7/10, Loss: 0.8518675702916997
Epoch 8/10, Loss: 0.8079483616535011
Epoch 9/10, Loss: 0.7643249590912133
Epoch 10/10, Loss: 0.7205251159570406


In [41]:
# Evaluate the teacher-free student baseline. Stored under its own name so it no
# longer overwrites the teacher's results held in `test_teacher`.
test_student_wo_teacher = test(student_model, test_loader, device)

Total correct predictions: 6766
Total labels: 10000
Test Accuracy: 67.65999603271484


### Distillation Run

In [None]:
# Rebuild the teacher architecture and restore its saved weights, mapping the
# tensors onto the active device. The bare name at the end displays the model.
teacher_model = TeacherModel()
teacher_state_dict = torch.load(
    "/content/drive/MyDrive/deep_generative_models/knowledge_distillation/cifar_teacher_v1.pt",
    map_location=device,
)
teacher_model.load_state_dict(teacher_state_dict)
teacher_model

In [38]:
def train_knowledge_distillation(teacher, student, train_loader, num_epochs,
                                 learning_rate, temperature, soft_training_loss_weight,
                                 ce_loss_weight, device):
    """Train ``student`` on both the hard labels and ``teacher``'s softened outputs.

    Per-batch loss:
        soft_training_loss_weight * T**2 * KL(softmax(t/T) || softmax(s/T))
        + ce_loss_weight * CrossEntropy(s, labels)
    where ``t``/``s`` are teacher/student logits and ``T`` is ``temperature``.
    The KL term is summed over classes and averaged over the batch. The ``T**2``
    factor keeps the soft-target gradients on a scale comparable to the
    hard-label term as ``T`` changes (Hinton et al., 2015).

    The teacher is put in eval mode and run under ``torch.no_grad()``; only the
    student's parameters are optimised (Adam).

    Args:
        teacher: trained model providing soft targets.
        student: model being trained.
        train_loader: iterable of ``(images, labels)`` batches; ``len()`` of it
            is used to average the loss over each epoch.
        num_epochs: number of full passes over ``train_loader``.
        learning_rate: Adam learning rate for the student.
        temperature: softmax temperature applied to both sets of logits.
        soft_training_loss_weight: weight of the distillation (KL) term.
        ce_loss_weight: weight of the hard-label cross-entropy term.
        device: device string/object passed to ``.to()``.

    Returns:
        list[float]: mean combined loss of each epoch (also printed per epoch).
    """
    teacher.eval()
    teacher.to(device)
    student.train()
    student.to(device)
    optimizer = torch.optim.Adam(student.parameters(), lr=learning_rate)

    loss_fn = torch.nn.CrossEntropyLoss()

    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            with torch.no_grad():
                teacher_logits = teacher(images)

            student_logits = student(images)

            # Work in log-space: softmax(...).log() underflows to log(0) = -inf
            # for very unlikely classes, and 0 * -inf makes the loss NaN.
            # log_softmax stays finite.
            teacher_log_probs = torch.log_softmax(teacher_logits / temperature, dim=-1)
            student_log_probs = torch.log_softmax(student_logits / temperature, dim=-1)
            soft_targets = teacher_log_probs.exp()

            # Kullback-Leibler divergence KL(teacher || student) between the two
            # softened distributions: summed over classes, averaged over the
            # batch, scaled by T^2.
            kl_div_loss = (
                (soft_targets * (teacher_log_probs - student_log_probs)).sum()
                / student_logits.size(0)
                * temperature ** 2
            )

            # Classification loss against the hard labels.
            ce_loss = loss_fn(student_logits, labels)

            # Weighted sum of the two losses.
            loss = soft_training_loss_weight * kl_div_loss + ce_loss_weight * ce_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")
    return epoch_losses


# Start teaching the student. Seed *before* building the student, as the
# cross-entropy-only baseline does, so both students start from the same
# initial weights and the comparison is like-for-like.
torch.manual_seed(2342)
student_model = StudentModel()
train_knowledge_distillation(teacher=teacher_model, student=student_model,
                             train_loader=train_loader, num_epochs=10,
                             learning_rate=0.001, temperature=7,
                             soft_training_loss_weight=0.1, ce_loss_weight=0.9,
                             device=device)

torch.save(student_model.state_dict(), "/content/drive/MyDrive/deep_generative_models/knowledge_distillation/cifar_student_trained_with_teacher.pt")




Epoch 1/10, Loss: 3.6267738403261776
Epoch 2/10, Loss: 2.92132333599393
Epoch 3/10, Loss: 2.6230877584508616
Epoch 4/10, Loss: 2.4164173636595003
Epoch 5/10, Loss: 2.2566777353396503
Epoch 6/10, Loss: 2.102309353668671
Epoch 7/10, Loss: 1.9779516960044041
Epoch 8/10, Loss: 1.8689032418038838
Epoch 9/10, Loss: 1.7709855746735088
Epoch 10/10, Loss: 1.687798717747564


In [39]:
# Evaluate the distilled student. The output saved with this cell came from a
# distillation run labelled temperature=5; re-running it simply evaluates the
# current `student_model`.
test_student = test(model=student_model, test_loader=test_loader, device=device)

Total correct predictions: 6535
Total labels: 10000
Test Accuracy: 65.3499984741211


In [37]:
# Evaluate the distilled student. The output saved with this cell came from a
# distillation run labelled temperature=7; re-running it simply evaluates the
# current `student_model`.
test_student = test(model=student_model, test_loader=test_loader, device=device)

Total correct predictions: 6749
Total labels: 10000
Test Accuracy: 67.48999786376953


In [35]:
# Evaluate the distilled student. The output saved with this cell came from a
# distillation run labelled temperature=13; re-running it simply evaluates the
# current `student_model`.
test_student = test(model=student_model, test_loader=test_loader, device=device)

Total correct predictions: 6639
Total labels: 10000
Test Accuracy: 66.38999938964844


In [33]:
# Evaluate the distilled student. The output saved with this cell came from a
# distillation run labelled temperature=10; re-running it simply evaluates the
# current `student_model`.
test_student = test(model=student_model, test_loader=test_loader, device=device)

Total correct predictions: 6721
Total labels: 10000
Test Accuracy: 67.20999908447266


In [31]:
# Evaluate the distilled student. The output saved with this cell came from an
# earlier distillation run whose temperature was not recorded; re-running it
# simply evaluates the current `student_model`.
test_student = test(model=student_model, test_loader=test_loader, device=device)

Total correct predictions: 6631
Total labels: 10000
Test Accuracy: 66.30999755859375


In [29]:
# Evaluate the distilled student. The output saved with this cell came from an
# earlier distillation run whose temperature was not recorded; re-running it
# simply evaluates the current `student_model`.
test_student = test(model=student_model, test_loader=test_loader, device=device)

Total correct predictions: 6573
Total labels: 10000
Test Accuracy: 65.72999572753906
