In [None]:
import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from pycocotools.coco import COCO
from torchvision.models.detection import fasterrcnn_resnet50_fpn, ssdlite320_mobilenet_v3_large
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import os
from coco_dataset import coco_dataset_download as cocod
from torchvision import transforms

# Dataset configuration. Every path is derived from `dataset_dir` so the notebook
# does not depend on a user-specific absolute location.
dataset_dir = "./coco2017"
annotations_path = os.path.join(dataset_dir, "annotations", "instances_train2017.json")
NUM_TRAIN_SAMPLES = 50  # small subset for quick experiments; raise for real training

# Image-only preprocessing: fixed 360x360 resize, then PIL image -> tensor.
transform = transforms.Compose([
    transforms.Resize((360, 360)),
    transforms.ToTensor()
])

# NOTE: `transform` is applied to the image only. The target is the raw list of
# COCO annotation dicts, so its bboxes stay in original-image pixel coordinates
# and no longer line up with the resized image. Supply a `target_transform` if
# ground-truth boxes are ever needed together with these images.
coco_train = torchvision.datasets.CocoDetection(root=f"{dataset_dir}/train2017", annFile=annotations_path, transform=transform)
coco_subset = torch.utils.data.Subset(coco_train, range(NUM_TRAIN_SAMPLES))

In [None]:
def collate_detection_batch(batch):
    """Collate ``(image, annotations)`` samples into ``(images, annotations)`` tuples.

    The default collate tries to merge each sample's list of annotation dicts
    element-wise, which raises as soon as two images in a batch contain a
    different number of objects. Keeping per-image items in tuples lets
    variable-length detection targets through unchanged.
    """
    return tuple(zip(*batch))


train_loader = DataLoader(coco_subset, batch_size=2, shuffle=True, collate_fn=collate_detection_batch)

# Sanity check on one raw sample: image tensor shape and number of annotations
# (printing the whole sample would dump every segmentation polygon).
sample_image, sample_annotations = coco_train[0]
print(sample_image.shape, len(sample_annotations))

gpu_available = torch.cuda.is_available()
# is_available() checks that MPS is usable on this machine; is_built() only
# reports whether PyTorch was compiled with MPS support.
mps_available = torch.backends.mps.is_available()
# To use an accelerator: device = "mps" if mps_available else "cuda" if gpu_available else "cpu"
device = "cpu"

# TEACHER: COCO-pretrained Faster R-CNN, frozen in eval mode (inference only).
teacher_model = fasterrcnn_resnet50_fpn(weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1).to(device)
teacher_model.eval()

# STUDENT: COCO-pretrained SSDLite (MobileNetV3), the model being trained.
student_model = ssdlite320_mobilenet_v3_large(weights=torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights.COCO_V1).to(device)
student_model.train()

optimizer = optim.SGD(student_model.parameters(), lr=0.001, momentum=0.9)

# Loss Function
def distillation_loss(teacher_outputs, student_outputs, alpha=0.5, temperature=3.0):
    """Temperature-scaled KL-divergence distillation loss.

    Both inputs are treated as unnormalized logits and softened with a softmax
    over dim=1, so they must have identical shape with at least two dimensions
    (e.g. ``(batch, classes)``).

    Args:
        teacher_outputs: Teacher logits, used as the fixed target distribution.
        student_outputs: Student logits, same shape as ``teacher_outputs``.
        alpha: Currently unused; kept for API compatibility. Weighting against
            a hard-label loss would require ground-truth targets.
        temperature: Softening temperature. The loss is multiplied by
            ``temperature ** 2`` to keep gradient magnitudes comparable
            across temperatures.

    Returns:
        Scalar tensor ``KL(teacher || student) * temperature ** 2``, summed over
        all elements and divided by ``size(0)`` (``reduction="batchmean"``).

    Raises:
        ValueError: If the two tensors differ in shape. Without this check,
            ``kl_div`` would broadcast silently or fail with an opaque error.
    """
    if teacher_outputs.shape != student_outputs.shape:
        raise ValueError(
            f"teacher and student outputs must have the same shape, "
            f"got {tuple(teacher_outputs.shape)} and {tuple(student_outputs.shape)}"
        )

    teacher_probs = torch.softmax(teacher_outputs / temperature, dim=1)
    # kl_div expects the *input* as log-probabilities and the target as probabilities.
    student_log_probs = torch.log_softmax(student_outputs / temperature, dim=1)

    kl = torch.nn.functional.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    return kl * (temperature ** 2)

In [None]:
# Pseudo-label distillation: the frozen teacher's confident detections become
# training targets for the student, whose own detection losses are minimized.
# NOTE(review): torchvision detection models in train() mode must be called as
# `model(images, targets)` and return a dict of losses, so the previous
# `student_model(images)` call could not run. Post-NMS boxes from the two models
# also have different counts and cannot be compared with a KL loss. Logit-level
# distillation (`distillation_loss`) would need access to the raw head outputs.
PSEUDO_LABEL_SCORE_THRESHOLD = 0.5  # NOTE(review): heuristic cut-off, not tuned
num_epochs = 1


def _pseudo_targets(teacher_outputs, score_threshold):
    """Keep each image's teacher detections scoring >= `score_threshold` as {"boxes", "labels"} targets."""
    pseudo = []
    for detections in teacher_outputs:
        keep = detections["scores"] >= score_threshold
        pseudo.append({
            "boxes": detections["boxes"][keep],
            "labels": detections["labels"][keep],
        })
    return pseudo


for epoch in range(num_epochs):
    teacher_model.eval()
    student_model.train()
    running_loss = 0.0

    # Ground-truth annotations are not used; the teacher provides the supervision.
    for images, _ in train_loader:
        images = [image.to(device) for image in images]

        with torch.no_grad():
            teacher_outputs = teacher_model(images)
        pseudo_targets = _pseudo_targets(teacher_outputs, PSEUDO_LABEL_SCORE_THRESHOLD)

        # In train mode the student returns a dict of loss tensors; sum them.
        loss_dict = student_model(images, pseudo_targets)
        loss = sum(loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader)}")

print("Training completed.")