<a href="https://colab.research.google.com/github/sfarrukhm/making_models_small/blob/main/knowledge_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torchvision


In [2]:
# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transform = torchvision.transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.50,0.5,0.5], std=[0.50,0.5,0.5]),
])

# Loading the CIFAR-10 dataset:

train_dataset =datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset =  datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# train_dataset = torch.utils.data.Subset((datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)),range(10000)) # Changed transforms to transform
# test_dataset =  torch.utils.data.Subset(datasets.CIFAR10(root='./data', train=False, download=True, transform=transform),range(2000)) # Changed transforms to transform
#Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:02<00:00, 77.8MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
# construct the teacher model
class TeacherModel(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features=nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3,
                      padding=1),
            nn.ReLU(),
            nn.Conv2d(128,64,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64,64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64,32,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.classifier=nn.Sequential(
            nn.Linear(2048,512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )
    def forward(self,x):
        x=self.features(x)
        x=torch.flatten(x,1)
        x=self.classifier(x)
        return x


## Student Model (way lighter than the teacher model)
class StudentModel(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features=nn.Sequential(
            nn.Conv2d(3,10,kernel_size=3, padding=1),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2,stride=2),
            nn.Conv2d(10,10,kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier=nn.Sequential(
            nn.Linear(640,256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256,num_classes)
        )
    def forward(self,x):
        x=self.features(x)
        x=torch.flatten(x,1)
        x=self.classifier(x)
        return x


In [4]:
# training the teacher and student model before the distiallation with cross entropy
# optimizer=torch.optim.Adam(m)
device="cuda" if torch.cuda.is_available() else 'cpu'
from collections import defaultdict
log_dict=defaultdict(list)

def train(model, train_loader, num_epochs, learning_rate,device, save_model_path=None):
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)

    model.to(device)
    model.train()
    for epoch in range(num_epochs):
        running_loss=0
        for images, labels in train_loader:
            images,labels=images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs=model(images)
            loss=loss_fn(outputs, labels)
            loss.backward()

            optimizer.step()
            running_loss+=loss.item()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")
    if save_model_path is not None:
        torch.save(model.state_dict(),save_model_path)

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct_predictions=0
    total=0
    with torch.no_grad():
        for images, labels in test_loader:
            images,labels=images.to(device), labels.to(device)
            outputs=model(images)

            _, predicted = torch.max(outputs, 1)

            total+=labels.size(0)
            correct_predictions+=(labels==predicted).sum()

    accuracy=100*correct_predictions/total
    print(f"Total correct predictions: {correct_predictions}")
    print(f"Total labels: {total}")
    print(f"Test Accuracy: {accuracy}")

    return correct_predictions, total, accuracy


In [15]:
## Cross-entropy run
# teacher traing
torch.manual_seed(2342)
save_path="/content/drive/MyDrive/deep_generative_models/trained_models/knowledge_distillation"
# teacher_model=TeacherModel(num_classes=10).to(device)
# train(teacher_model, train_loader,
#       10, 0.001,device=device, save_model_path=save_path+"/cifar_teacher_v1.pt")
# torch.save(teacher_model.state_dict(),save_path+"/cifar_teacher_v1.pt")

In [17]:
# testing the teacher
test_teacher = test(teacher_model, test_loader, device)

Total correct predictions: 7156
Total labels: 10000
Test Accuracy: 71.55999755859375


In [18]:
# student traing but without the support fo teacher
torch.manual_seed(2342)
save_path="/content/drive/MyDrive/deep_generative_models/trained_models/knowledge_distillation"
# student_model=StudentModel(num_classes=10).to(device)
# train(student_model, train_loader,
#       10, 0.001,device=device, save_model_path=save_path+"/cifar_student_wo_teacher_v1.pt")

Epoch 1/10, Loss: 1.5369352017674605
Epoch 2/10, Loss: 1.2399079882732742
Epoch 3/10, Loss: 1.1173738487388776
Epoch 4/10, Loss: 1.0368778001316978
Epoch 5/10, Loss: 0.9659609710011641
Epoch 6/10, Loss: 0.9065325933191782
Epoch 7/10, Loss: 0.8526395525011565
Epoch 8/10, Loss: 0.8091558261448161
Epoch 9/10, Loss: 0.7642902602320132
Epoch 10/10, Loss: 0.7214994926358123


In [20]:
# test_teacher = test(student_model, test_loader, device)

Total correct predictions: 6786
Total labels: 10000
Test Accuracy: 67.86000061035156


### Distillation Run

In [28]:
# load the trained teacher model
teacher_model=TeacherModel()
teacher_state_dict=torch.load("/content/drive/MyDrive/deep_generative_models/trained_models/knowledge_distillation/cifar_teacher_v1.pt",map_location=device)
teacher_model.load_state_dict(teacher_state_dict)
teacher_model

  teacher_state_dict=torch.load("/content/drive/MyDrive/deep_generative_models/trained_models/knowledge_distillation/cifar_teacher_v1.pt",map_location=device)


TeacherModel(
  (features): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=2048, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=512, out_features=10, bias=True)
  )
)

In [32]:
def train_knowledge_distillation(teacher,student, train_loader, num_epochs,
                                 learning_rate, temperature, soft_training_loss_weight,
                                 ce_loss_weight, device):
    teacher.eval()
    teacher.to(device)
    student.train()
    student.to(device)
    optimizer = torch.optim.Adam(student.parameters(), lr=learning_rate)

    loss_fn=torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        running_loss=0
        for images, labels in train_loader:

            optimizer.zero_grad()
            with torch.no_grad():
                teacher_logits=teacher(images)

            student_logits=student(images)

            soft_targets=torch.softmax(teacher_logits/temperature,
                                                dim=-1)
            soft_probs=torch.softmax(student_logits/temperature,dim=-1)

            # porbability distribution loss
            kl_div_loss= torch.sum(soft_targets*(soft_targets.log()-soft_probs.log()))/soft_probs.size(0)*temperature**2  #Kullback-Leibler (KL) divergence between two probabilit distributions modeling the same random variable

            # classification loss which is cross-entropy loss
            ce_loss=loss_fn(student_logits, labels)

            # weighted sum of the two losses
            loss=soft_training_loss_weight*kl_div_loss + ce_loss_weight*ce_loss

            loss.backward()

            optimizer.step()

            running_loss+=loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")


# starting teaching the student
student_model=StudentModel()
train_knowledge_distillation(teacher=teacher_model, student=student_model, train_loader=train_loader, num_epochs=10, learning_rate=0.001, temperature=2,
                             soft_training_loss_weight=0.25, ce_loss_weight=0.75, device=device)


torch.save(student_model.state_dict(),"/content/drive/MyDrive/deep_generative_models/trained_models/knowledge_distillation/cifar_student_trained_with_teacher.pt")




ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-32-b5e6b73b9d07>", line 45, in <cell line: 0>
    train_knowledge_distillation(teacher=teacher_model, student=student_model, train_loader=train_loader, num_epochs=10, learning_rate=0.001, temperature=2,
  File "<ipython-input-32-b5e6b73b9d07>", line 18, in train_knowledge_distillation
    teacher_logits=teacher(images)
                   ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<ipython-input-3-0487002f3d1e>", lin

TypeError: object of type 'NoneType' has no len()

In [81]:
teacher_model.eval()
with torch.no_grad():
    logits=teacher_model.forward(image_tensor.unsqueeze(0))

T=50
prob=torch.softmax(logits/T,dim=-1)
print(prob)
torch.argsort(prob[0],)

tensor([[0.0823, 0.1567, 0.0929, 0.1122, 0.0814, 0.0863, 0.0900, 0.0771, 0.0929,
         0.1282]])


tensor([7, 4, 0, 5, 6, 8, 2, 3, 9, 1])

In [77]:
train_dataset.class_to_idx

{'airplane': 0,
 'automobile': 1,
 'bird': 2,
 'cat': 3,
 'deer': 4,
 'dog': 5,
 'frog': 6,
 'horse': 7,
 'ship': 8,
 'truck': 9}

In [26]:
torch.manual_seed(234)
t=torch.rand(3,5,)
t=t/t.sum(dim=1, keepdim=True)
z=torch.rand(3,5)
z=z/z.sum(dim=1, keepdim=True)

kl=-1*(t*z.log()).sum()


In [25]:
-1*(t*z.log()).sum()/temperature**2

tensor(5.7194)

In [8]:
z

tensor([[ 0.6272,  0.1104,  0.3533, -0.2535,  0.0875],
        [-0.3873,  0.5706,  0.5746,  0.2521,  1.2906],
        [ 1.0185, -0.6426,  0.2849,  0.6353,  2.8411]])

In [58]:
torch.sort(a)

torch.return_types.sort(
values=tensor([  1,   2,   3,   4,   8,  10, 100]),
indices=tensor([2, 5, 0, 1, 3, 6, 4]))

In [55]:
a.sort()

torch.return_types.sort(
values=tensor([ 1,  2,  3,  4,  8,  9, 10]),
indices=tensor([2, 5, 0, 1, 3, 4, 6]))