In [None]:
import os
import torch
import torchvision
from torch.utils.data import random_split
import torch.nn as nn
import torch.nn.functional as F

# Seed torch's global RNG so random_split, DataLoader shuffling and the random
# torchvision augmentations are reproducible across Restart & Run All.
random_seed = 123
torch.manual_seed(random_seed)

data_dir = '../data/Images'

# List class folders exactly the way ImageFolder discovers them (directories only,
# sorted by name), so the printed order matches the label indices ImageFolder assigns.
# A bare os.listdir() returns entries in arbitrary order and also picks up stray
# files (e.g. .DS_Store) that ImageFolder ignores.
classes = sorted(entry.name for entry in os.scandir(data_dir) if entry.is_dir())
print(classes)
print(f"length: {len(classes)}")

In [None]:
from matplotlib import pyplot as plt
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader
import numpy as np
from torch.utils.data.dataloader import DataLoader
from torchvision.models import ResNet152_Weights, EfficientNet_B0_Weights, Inception_V3_Weights


# Square input resolution (in pixels) each supported backbone is fed at.
MODEL_INPUT_SIZES = {
    'resnet18': 256,
    'resnet152': 224,
    'efficientnet_b0': 224,
    'vit_b_16': 224,
    'inception_v3': 299,
}


def dataset_setup(model_name='resnet18'):
    """Build the training-time image transform pipeline for ``model_name``.

    The pipeline is identical for every backbone except for the resize target:
    resize to a square of ``MODEL_INPUT_SIZES[model_name]`` pixels, random
    horizontal flip, random rotation of up to 10 degrees, color jitter, and
    conversion to a tensor.

    Args:
        model_name: One of the keys of ``MODEL_INPUT_SIZES``.

    Returns:
        A ``transforms.Compose`` pipeline.

    Raises:
        ValueError: If ``model_name`` is not a supported backbone.
    """
    # Look the size up before touching `transforms`, so an unsupported name fails
    # with a clear message rather than deep inside the pipeline construction.
    try:
        size = MODEL_INPUT_SIZES[model_name]
    except KeyError:
        supported = ', '.join(sorted(MODEL_INPUT_SIZES))
        raise ValueError(
            f"unsupported model_name {model_name!r}; expected one of: {supported}"
        ) from None

    # NOTE(review): no transforms.Normalize step. torchvision's pretrained weights expect
    # ImageNet mean/std normalization; confirm whether the model normalizes internally.
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
    ])


# Build the (augmented) training dataset for the chosen student backbone.
model_name = 'vit_b_16'
train_transform = dataset_setup(model_name=model_name)
dataset = ImageFolder(data_dir, transform=train_transform)
print(f"dataset size: {len(dataset)}")
batch_size = 64

# Fixed train/val sizes; the test split takes whatever remains. random_split requires
# the lengths to sum to len(dataset), so a hard-coded test size breaks as soon as
# images are added or removed.
train_size, val_size = 2800, 500
test_size = len(dataset) - train_size - val_size
if test_size < 0:
    raise ValueError(
        f"dataset has {len(dataset)} images; the split needs at least {train_size + val_size}"
    )
train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size])

# Validation/test images must not be randomly augmented. Reuse the same split indices on
# a second ImageFolder whose pipeline keeps only the deterministic steps (resize and
# ToTensor). ImageFolder enumerates samples in sorted class/file order, so the indices
# refer to the same images in both datasets.
random_augmentations = (transforms.RandomHorizontalFlip, transforms.RandomRotation, transforms.ColorJitter)
eval_transform = transforms.Compose(
    [step for step in train_transform.transforms if not isinstance(step, random_augmentations)]
)
eval_dataset = ImageFolder(data_dir, transform=eval_transform)
val_ds = torch.utils.data.Subset(eval_dataset, val_ds.indices)
test_ds = torch.utils.data.Subset(eval_dataset, test_ds.indices)

train_loader = DataLoader(train_ds, batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size*2, num_workers=4, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
from ConvClassifier import ConvClassifier
from Distillation import DistillationLoss

def train_distillation(teacher_model, student_model, num_epochs=5, train_loader=None, val_loader=None,
                       temperature=4.0, alpha=0.5, lr=3e-5, device=None, loss_fn=None):
    """Train ``student_model`` to mimic a frozen ``teacher_model`` via knowledge distillation.

    Each epoch makes one pass over ``train_loader``. For every batch, the teacher runs
    without gradients, and the student runs forward, backward and an Adam step. The
    student is then evaluated on ``val_loader`` through its ``valid_step`` method.

    Args:
        teacher_model: Model providing the targets. It is put in eval mode and never updated.
        student_model: Model being trained. It must expose ``valid_step(batch)`` returning a
            dict with ``'val_loss'`` and ``'val_acc'`` entries.
        num_epochs: Number of passes over ``train_loader``.
        train_loader: Iterable of ``(images, labels)`` batches (required).
        val_loader: Iterable of batches passed unchanged to ``student_model.valid_step``
            (required).
        temperature: Forwarded to ``DistillationLoss`` when ``loss_fn`` is not given.
        alpha: Forwarded to ``DistillationLoss`` when ``loss_fn`` is not given.
        lr: Adam learning rate for the student's parameters.
        device: Device that training batches are moved to. Defaults to the device of the
            student's first parameter.
        loss_fn: Optional callable ``(student_outputs, teacher_outputs, labels) -> loss``.
            Defaults to ``DistillationLoss(temperature=temperature, alpha=alpha)``.

    Returns:
        ``(history, student_model)``, where ``history`` holds one dict per epoch with
        ``'val_loss'``, ``'val_acc'`` and ``'train_loss'`` (epoch means).

    Raises:
        ValueError: If ``train_loader`` or ``val_loader`` is missing.
    """
    if train_loader is None or val_loader is None:
        raise ValueError("train_distillation requires both train_loader and val_loader")
    if device is None:
        # Send batches wherever the student lives, instead of relying on a global `device`.
        device = next(student_model.parameters()).device
    if loss_fn is None:
        loss_fn = DistillationLoss(temperature=temperature, alpha=alpha)

    teacher_model.eval()  # frozen teacher: eval-mode layers; its forward runs under no_grad below
    optimizer = torch.optim.Adam(student_model.parameters(), lr=lr)
    history = []

    for epoch in range(num_epochs):
        # Training phase
        student_model.train()
        train_losses = []
        for batch in train_loader:
            images, labels = batch[0].to(device), batch[1].to(device)

            with torch.no_grad():
                teacher_outputs = teacher_model(images)

            student_outputs = student_model(images)
            loss = loss_fn(student_outputs, teacher_outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        # Validation phase: no autograd graph is needed, so don't build one.
        student_model.eval()
        with torch.no_grad():
            outputs = [student_model.valid_step(batch) for batch in val_loader]
        # NOTE(review): valid_step receives the raw batch (not moved to `device` here), and its
        # results are averaged with np.mean. This presumably means valid_step handles device
        # placement and returns CPU/Python scalars; confirm in the student class.
        result = {
            'val_loss': np.mean([out['val_loss'] for out in outputs]),
            'val_acc': np.mean([out['val_acc'] for out in outputs]),
            'train_loss': np.mean(train_losses),
        }
        print(
            f"[Epoch {epoch+1}] train_loss: {result['train_loss']:.4f}, "
            f"val_loss: {result['val_loss']:.4f}, val_acc: {result['val_acc']:.4f}"
        )
        history.append(result)

    return history, student_model


# Teacher: ImageNet-pretrained Inception-v3, frozen for the whole run.
# NOTE(review): these weights carry a 1000-way ImageNet head, while the student is
# presumably built for len(dataset.classes) outputs. Confirm that DistillationLoss
# can reconcile the two logit shapes, or fine-tune the teacher on this dataset first.
# NOTE(review): the loaders are sized for `model_name` (224x224 for vit_b_16), whereas
# dataset_setup uses 299x299 for inception_v3. Confirm the teacher is meant to see 224px inputs.
teacher_model = torchvision.models.inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
teacher_model = teacher_model.to(device)
teacher_model.eval()

# Student: the classifier being distilled, built for the same backbone as the loaders.
student_model = ConvClassifier(model_name=model_name, dataset=dataset)
student_model = student_model.to(device)

history, student_model = train_distillation(
    teacher_model=teacher_model,
    student_model=student_model,
    train_loader=train_loader,
    val_loader=val_loader,
    temperature=4.0,
    alpha=0.5,
)
