In [1]:
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
import os
import time
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
warnings.filterwarnings('ignore')

def show_image(image, label):
    image = image.permute(1, 2, 0)
    plt.imshow(image.squeeze())
    plt.title(f'Label: {label}')
    plt.show()

device = ("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using {device} device")
            
class Params:
    def __init__(self):
        self.batch_size = 96
        self.name = "resnet_152_sgd1"
        self.workers = 4
        self.momentum = 0.9
        self.weight_decay = 4e-5
        self.lr_step_size = 30
        self.lr_gamma = 0.1
        self.alpha = 0.5
        self.temperature = 3.0
        self.num_epochs = 20
        self.learning_rate = 0.0045

    def __repr__(self):
        return str(self.__dict__)
    
    def __eq__(self, other):
        return self.__dict__ == other.__dict__

params = Params()
training_folder_name = '/Users/revanth/Documents/Assignments/CV/TPA/ILSVRC/Data/CLS-LOC/train'
val_folder_name = '/Users/revanth/Documents/Assignments/CV/TPA/ILSVRC/Data/CLS-LOC/val'

train_transformation = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomResizedCrop(224, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
    transforms.RandomHorizontalFlip(0.5),
    # Normalize the pixel values (in R, G, and B channels)
    transforms.Normalize(mean=[0.485, 0.485, 0.406], std=[0.229, 0.224, 0.225])
])

# Define dataset with ImageFolder
train_dataset = torchvision.datasets.ImageFolder(
    root=training_folder_name,
    transform=train_transformation
)
classes = [ 'n04507155', 'n04254680', 'n03642806', 'n02782093', 'n03792782', 'n03393912', 'n03895866', 'n04204347', 'n03791053', 'n02701002' ] #, 'n02930766', 'n03594945', 'n03770679', 'n04037443', 'n03345487', 'n03417042', 'n04461696', 'n04467665', 'n03796401', 'n03977966', 'n04335435', 'n04380533', 'n03337140', 'n03179701', 'n04550184', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n12620546', 'n13133613', 'n12144580', 'n03017168', 'n03249569', 'n03447721', 'n03954731', 'n03481172', 'n03109150', 'n02951585', 'n03970156', 'n04154565', 'n04208210', 'n03000684', 'n03876231', 'n03691459', 'n03759954', 'n04152593', 'n03793489', 'n03271574', 'n04118776', 'n03196217', 'n04548280', 'n03197337', 'n04376876', 'n03706229', 'n04356056', 'n03085013', 'n04505470', 'n03666591', 'n03180011', 'n03485407', 'n03832673', 'n04004767', 'n04355933', 'n04074963', 'n02948072', 'n04456115', 'n03485794', 'n07579787', 'n03814906', 'n02795169', 'n04553703', 'n02783161', 'n02802426', 'n02808304', 'n04548362', 'n06794110', 'n03388183', 'n04540053', 'n04026417', 'n04404412', 'n04204238', 'n04597913', 'n09835506', 'n04557648', 'n03958227', 'n02786058', 'n04409515', 'n03223299', 'n07614500', 'n03729826', 'n04254777', 'n02988304', 'n03063599', 'n04116512', 'n03255030' ]
allowed_class_to_idx = {cls: train_dataset.class_to_idx[cls] for cls in classes if cls in train_dataset.class_to_idx}
filtered_samples = [
    (path, label) for path, label in train_dataset.samples if label in allowed_class_to_idx.values()
]

train_dataset.samples = filtered_samples
train_dataset.targets = [s[1] for s in filtered_samples]
train_dataset.class_to_idx = allowed_class_to_idx
train_dataset.classes = list(allowed_class_to_idx.keys())

print(f"Filtered dataset contains {len(train_dataset.samples)} samples across {len(train_dataset.classes)} classes.")

# Define the split ratios
train_size = int(0.1 * len(train_dataset))
val_size = int(0.05 * len(train_dataset)) #len(train_dataset) - train_size
rem_size = len(train_dataset) - train_size - val_size

# Split the dataset into training and validation sets
train_data, val_data, _ = random_split(train_dataset, [train_size, val_size, rem_size])

# Define the sampler and DataLoader for the training set
train_sampler = torch.utils.data.RandomSampler(train_data)
train_loader = DataLoader(
    train_data,
    batch_size=params.batch_size,
    sampler=train_sampler,
    num_workers=params.workers,
    pin_memory=True,
)

# Define the DataLoader for the validation set (no sampler needed, as we'll just shuffle it)
val_loader = DataLoader(
    val_data,
    batch_size=params.batch_size,
    shuffle=True,  # Shuffling validation data is optional
    num_workers=params.workers,
    pin_memory=True,
)

Using mps device
Filtered dataset contains 13000 samples across 10 classes.


In [2]:
class DistillationLoss(nn.Module):
    def __init__(self, alpha, temperature):
        super(DistillationLoss, self).__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, targets):
        # CrossEntropy Loss for student's hard predictions
        student_loss = self.criterion(student_logits, targets)
        
        # KL Divergence Loss for soft predictions
        distillation_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1),
            reduction='batchmean'
        ) * (self.temperature ** 2)

        # Combine losses with alpha weighting
        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

In [3]:
def distill(dataloader, teacher, model, epoch):
    size = len(dataloader.dataset)
    model.train()
    distillation_criterion = DistillationLoss(params.alpha, params.temperature)
    optimizer = optim.RMSprop(student_model.parameters(), lr=params.learning_rate, alpha=0.9, momentum=params.momentum, weight_decay=params.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=params.lr_step_size, gamma=params.lr_gamma)

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        with torch.no_grad():
            teacher_logits = teacher(X)
        student_logits = model(X)
        loss = distillation_criterion(student_logits, teacher_logits, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=5.0)
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            print(f"loss: {loss.item():.6f} [{batch * len(X)}/{size}]")

    scheduler.step()

def train(dataloader, teacher, model):
    start_epoch = 0
    # total_epochs = 20
    for epoch in range(start_epoch, params.num_epochs):
        distill(dataloader, teacher, model, epoch)
        test(val_loader, model, nn.CrossEntropyLoss(), epoch)
        

def test(dataloader, model, loss_fn, epoch):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    step = epoch * len(dataloader.dataset)
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} , Step: {step}\n")

In [4]:
teacher_model = torchvision.models.swin_t(weights="IMAGENET1K_V1").to(device)
teacher_model.eval()
student_model = torchvision.models.mobilenet_v2(weights=None).to(device) # weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V1
student_model.train()

train(train_loader, teacher_model, student_model)

torch.save(student_model.state_dict(), 'student_mobilenet_v3_large.pth')
print("Student model saved as 'student_mobilenet_v3_large.pth'")

loss: 3.843838 [0/1300]


python(74278) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74281) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74282) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74283) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 11.7%, Avg loss: 4.784631 , Step: 0



python(74296) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74300) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74301) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74302) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.833974 [0/1300]


python(74441) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74444) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74445) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74446) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 10.8%, Avg loss: 4.681284 , Step: 650



python(74464) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74467) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74468) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74470) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.759503 [0/1300]


python(74592) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74595) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74596) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74597) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 11.8%, Avg loss: 2.732299 , Step: 1300



python(74618) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74621) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74622) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74624) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.851325 [0/1300]


python(74754) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74757) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74760) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74761) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 11.2%, Avg loss: 3.037619 , Step: 1950



python(74775) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74779) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74781) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74782) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.890518 [0/1300]


python(74862) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74865) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74866) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74867) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 10.8%, Avg loss: 2.573600 , Step: 2600



python(74891) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74896) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74897) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74898) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.750341 [0/1300]


python(75037) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75040) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75041) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75043) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 10.8%, Avg loss: 2.598214 , Step: 3250



python(75071) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75074) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75076) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75077) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.767687 [0/1300]


python(75178) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75181) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75182) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75184) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 11.2%, Avg loss: 2.503245 , Step: 3900



python(75197) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75200) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75201) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75202) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.739387 [0/1300]


python(75319) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75322) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75323) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75324) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 11.4%, Avg loss: 2.800154 , Step: 4550



python(75333) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75336) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75337) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75338) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.930172 [0/1300]


python(75453) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75456) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75457) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75459) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 6.8%, Avg loss: 4.634768 , Step: 5200



python(75477) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75480) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75481) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75484) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.856771 [0/1300]


python(75586) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75589) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75590) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75592) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 11.1%, Avg loss: 3.951894 , Step: 5850



python(75616) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75619) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75620) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75621) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.782046 [0/1300]


python(75702) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75725) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75726) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75729) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 9.8%, Avg loss: 3.134273 , Step: 6500



python(75753) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75756) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75757) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75758) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.757365 [0/1300]


python(75839) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75842) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75843) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75847) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 11.1%, Avg loss: 3.310654 , Step: 7150



python(75922) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75925) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75926) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(75927) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.755320 [0/1300]


python(76018) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76021) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76022) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76024) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 6.8%, Avg loss: 22.484444 , Step: 7800



python(76039) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76042) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76043) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76044) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.748047 [0/1300]


python(76117) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76120) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76121) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76123) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 6.5%, Avg loss: 7.935783 , Step: 8450



python(76170) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76173) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76174) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76177) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.838407 [0/1300]


python(76293) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76296) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76297) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76299) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 9.4%, Avg loss: 3.749347 , Step: 9100



python(76311) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76314) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76315) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76316) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 2.028296 [0/1300]


python(76439) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76443) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76445) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76447) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 11.5%, Avg loss: 8.196679 , Step: 9750



python(76483) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76486) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76487) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76488) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.831571 [0/1300]


python(76549) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76552) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76553) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76554) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 18.2%, Avg loss: 2.878485 , Step: 10400



python(76574) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76577) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76579) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76580) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.639027 [0/1300]


python(76716) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76719) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76720) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76722) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 13.2%, Avg loss: 8.006303 , Step: 11050



python(76736) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76739) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76742) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76743) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.764007 [0/1300]


python(76841) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76844) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76845) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76846) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 17.4%, Avg loss: 2.646545 , Step: 11700



python(76866) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76869) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76870) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76871) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


loss: 1.879700 [0/1300]


python(76978) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76981) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76982) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(76983) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Error: 
 Accuracy: 10.9%, Avg loss: 6.872127 , Step: 12350

Student model saved as 'student_mobilenet_v3_large.pth'
