### Imports

In [2]:
import medmnist
from medmnist import INFO, Evaluator
import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch.nn as nn
from torch import utils, optim, device, inference_mode
import tqdm
from timeit import default_timer as timer
from tqdm.auto import tqdm
from torchmetrics import ConfusionMatrix
import mlxtend
from mlxtend.plotting import plot_confusion_matrix
import numpy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")





In [3]:
import medmnist
from medmnist import INFO, Evaluator

### Dataset

In [4]:
data_flag = 'pathmnist'
# data_flag = 'dermamnist'
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])

info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])

In [5]:
print('task:', task)
print('channels: ', n_channels)
print('classes:', n_classes)

task: multi-class
channels:  3
classes: 9


In [6]:
train_data = DataClass(root='pathmnist_data', split='train', transform=data_transform, size=224, mmap_mode='r' ,download=True)
val_data = DataClass(root='pathmnist_data', split='val', transform=data_transform, size=224, mmap_mode='r', download=True)
test_data = DataClass(root='pathmnist_data', split='test', transform=data_transform, size=224, mmap_mode='r', download=True)

Using downloaded and verified file: pathmnist_data\pathmnist_224.npz
Using downloaded and verified file: pathmnist_data\pathmnist_224.npz
Using downloaded and verified file: pathmnist_data\pathmnist_224.npz


In [7]:
train_sample = train_data[0]
print(train_sample[0].shape)

torch.Size([3, 224, 224])


In [None]:
# generate random data loaders for demonstration

In [8]:
# change data into dataloader form
BATCH_SIZE = 256
train_dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset=val_data, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

### ResnetModel

In [13]:
resnet = torchvision.models.resnet18(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\preet/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:01<00:00, 27.6MB/s]


In [None]:
resnet

### Teacher Models

In [51]:
# class TeacherResnet50(torch.nn.Module):
#     def __init__(self):
#         super(TeacherResnet50, self).__init__()
#         self.teacher_resnet50 = torchvision.models.resnet50(pretrained=True)
#         self.layers = list(self.teacher_resnet50.children())[:-1]
        
#         for param in self.teacher_resnet50.parameters():
#             param.requires_grad = False
        

#     def forward(self, x):
#         print(self.layers)
#         for layer in self.layers:
#             x = layer(x)
#         return x

In [9]:
class TeacherResnet50(torch.nn.Module):
    def __init__(self):
        super(TeacherResnet50, self).__init__()
        self.resnet50 = torchvision.models.resnet50(pretrained=True)
        self.resnet50.fc = torch.nn.Identity()  # Remove the last layer

    def forward(self, x):
        return self.resnet50(x)

In [10]:
# test the teacher model with a sample tensor
test_tensor = torch.randn(1, 3, 224, 224)
teacher_model = TeacherResnet50()
output = teacher_model(test_tensor)
print(output.shape)




torch.Size([1, 2048])


### Student Models

In [None]:
mobilenet = torchvision.models.mobilenet_v3_small(pretrained=True)
mobilenet

In [11]:
class StudentMobileNetV3(torch.nn.Module):
    def __init__(self, num_classes):
        super(StudentMobileNetV3, self).__init__()
        self.mobilenet_v3 = torchvision.models.mobilenet_v3_small(pretrained=False)
        self.mobilenet_v3.classifier[3] = torch.nn.Identity()  # Remove the last layer
        self.linear = torch.nn.Linear(1024, 2048)
        self.fc = torch.nn.Linear(1024, num_classes)  # Add a new fully connected layer for classification

    def forward(self, x):
        x = self.mobilenet_v3(x)
        return self.linear(x), self.fc(x)

In [12]:
# test the student model with a sample tensor
student_model = StudentMobileNetV3(num_classes=n_classes)
embedding, output = student_model(test_tensor)
print(embedding.shape)

torch.Size([1, 2048])




### Training With Cosine Loss and Student Softmax

In [14]:
def test_multiple_outputs(student_model, test_loader, device, validation = False):
    student_model.to(device)
    student_model.eval()
    
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            labels = labels.squeeze(1)
            inputs, labels = inputs.to(device), labels.to(device)
            
            _, outputs = student_model(inputs) # Disregard the first tensor of the tuple
            _, predicted = torch.max(outputs.data, 1)
            # print("predicted.shape =", predicted.shape)
            # print("label.shape =", labels.shape)
            
            # print("total number of samples in this batch:", labels.size(0))
            # print("correctly classified samples in this batch:", (predicted == labels).sum().item())
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # print("correct = ", correct)
    # print("total = ", total)
    
    accuracy = 100 * correct / total
    if validation:
        print(f"Validation Accuracy: {accuracy:.2f}%")
    else:
        print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [15]:
def train_cosine_loss(teacher, student, train_loader, val_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    cosine_loss = nn.CosineEmbeddingLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
    
    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode
    
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            labels = labels.type(torch.LongTensor)
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            # Forward pass with the teacher model and keep only the hidden representation
            with torch.no_grad():
                teacher_hidden_representation = teacher(inputs)
                
            # Forward pass with the student model
            student_hidden_representation, student_logits = student(inputs)

            # Calculate the cosine loss. Target is a vector of ones. From the loss formula above we can see that is the case where loss minimization leads to cosine similarity increase.
            hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))

            # Calculate the true label loss
            # print("student logits shape:", student_logits.shape)
            # print("labels shape:", labels.shape)
            labels = labels.squeeze(1)
            # print("labels shape:", labels.shape)
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
        
        # perform validation
        accuracy = test_multiple_outputs(student, val_loader, device, validation=True)

In [16]:
# Train and test the lightweight network with cross entropy loss
train_cosine_loss(teacher=teacher_model, student=student_model, train_loader=train_dataloader, val_loader=val_dataloader, epochs=100, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(student_model, test_dataloader, device)

Epoch 1/100, Loss: 0.3937030675905672
Validation Accuracy: 8.92%
Epoch 2/100, Loss: 0.4327007128687745
Validation Accuracy: 92.56%
Epoch 3/100, Loss: 0.17708203645253723
Validation Accuracy: 95.56%
Epoch 4/100, Loss: 0.13737372774630785
Validation Accuracy: 95.89%
Epoch 5/100, Loss: 0.11192736558785493
Validation Accuracy: 96.21%
Epoch 6/100, Loss: 0.10288054970177737
Validation Accuracy: 96.58%
Epoch 7/100, Loss: 0.09220552170352841
Validation Accuracy: 96.94%
Epoch 8/100, Loss: 0.08249164290133525
Validation Accuracy: 97.55%
Epoch 9/100, Loss: 0.08175810299475085
Validation Accuracy: 97.45%
Epoch 10/100, Loss: 0.07326256921938197
Validation Accuracy: 98.12%
Epoch 11/100, Loss: 0.07345142145641148
Validation Accuracy: 96.51%
Epoch 12/100, Loss: 0.07380022267303006
Validation Accuracy: 97.94%
Epoch 13/100, Loss: 0.0667750159574842
Validation Accuracy: 97.99%
Epoch 14/100, Loss: 0.06378577736904845
Validation Accuracy: 98.61%
Epoch 15/100, Loss: 0.06278720675882968
Validation Accuracy: 