<a href="https://colab.research.google.com/github/rileytyh/NNDL-TransferLearning/blob/main/W4995_TransferLearning_Project_ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!unzip Released_Data_NNDL_2025-20250510T140910Z-001.zip
!unzip Released_Data_NNDL_2025/train_images.zip
!unzip Released_Data_NNDL_2025/test_images.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: test_images/7677.jpg    
  inflating: __MACOSX/test_images/._7677.jpg  
  inflating: test_images/8544.jpg    
  inflating: __MACOSX/test_images/._8544.jpg  
  inflating: test_images/9882.jpg    
  inflating: __MACOSX/test_images/._9882.jpg  
  inflating: test_images/792.jpg     
  inflating: __MACOSX/test_images/._792.jpg  
  inflating: test_images/1206.jpg    
  inflating: __MACOSX/test_images/._1206.jpg  
  inflating: test_images/6569.jpg    
  inflating: __MACOSX/test_images/._6569.jpg  
  inflating: test_images/5060.jpg    
  inflating: __MACOSX/test_images/._5060.jpg  
  inflating: test_images/10248.jpg   
  inflating: __MACOSX/test_images/._10248.jpg  
  inflating: test_images/3411.jpg    
  inflating: __MACOSX/test_images/._3411.jpg  
  inflating: test_images/11156.jpg   
  inflating: __MACOSX/test_images/._11156.jpg  
  inflating: test_images/4418.jpg    
  inflating: __MACOSX/test_images/._4418.jpg  

In [None]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datetime import datetime
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.models as models

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image

import timm

In [None]:
# Dataset for the two-head (superclass + subclass) image classification task
class MultiClassImageDataset(Dataset):
    """Labelled image dataset returning an image with its superclass and subclass targets.

    Args:
        ann_df: annotation table with 'image', 'superclass_index' and
            'subclass_index' columns; one row per sample.
        super_map_df: table whose 'class' column is looked up with the superclass index.
        sub_map_df: table whose 'class' column is looked up with the subclass index.
        img_dir: directory containing the files named in ann_df['image'].
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, ann_df, super_map_df, sub_map_df, img_dir, transform=None):
        self.ann_df = ann_df
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Number of annotated samples."""
        return len(self.ann_df)

    def __getitem__(self, idx):
        """Return (image, super_idx, super_label, sub_idx, sub_label) for position idx."""
        # idx is a dataset position in [0, len), not a DataFrame label, so use
        # positional access; label lookup breaks once ann_df has been filtered
        # or otherwise lost its default 0..N-1 index.
        img_name = self.ann_df['image'].iloc[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        super_idx = self.ann_df['superclass_index'].iloc[idx]
        super_label = self.super_map_df['class'][super_idx]

        sub_idx = self.ann_df['subclass_index'].iloc[idx]
        sub_label = self.sub_map_df['class'][sub_idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, super_idx, super_label, sub_idx, sub_label

class MultiClassImageTestDataset(Dataset):
    """Unlabelled image dataset whose files are named '<idx>.jpg' inside img_dir.

    Args:
        super_map_df: superclass mapping table (stored, not used by this class).
        sub_map_df: subclass mapping table (stored, not used by this class).
        img_dir: directory holding the test images.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, super_map_df, sub_map_df, img_dir, transform=None):
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Number of image files in img_dir.

        Only visible '*.jpg' entries are counted. Counting every directory
        entry (hidden files such as '.DS_Store' or '._0.jpg', sub-folders,
        stray text files) would inflate the length and make the loader ask
        for indices that have no image.
        """
        return sum(
            1
            for fname in os.listdir(self.img_dir)
            if fname.endswith('.jpg') and not fname.startswith('.')
        )

    def __getitem__(self, idx):
        """Return (image, file_name) for the file '<idx>.jpg'."""
        img_name = str(idx) + '.jpg'
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, img_name

In [None]:
# Annotation / mapping tables and image folders
train_ann_df = pd.read_csv('Released_Data_NNDL_2025/train_data.csv')
super_map_df = pd.read_csv('Released_Data_NNDL_2025/superclass_mapping.csv')
sub_map_df = pd.read_csv('Released_Data_NNDL_2025/subclass_mapping.csv')

train_img_dir = 'train_images'
test_img_dir = 'test_images'

timm_model_name = 'resnetv2_101x3_bit.goog_in21k'

# Instantiate the backbone once, only to read its preprocessing config and
# feature width; it is deleted right after to free memory.
resnet_model = timm.create_model(timm_model_name, pretrained=True, num_classes=0)
resnet_data_config = timm.data.resolve_model_data_config(resnet_model)
resnet_transforms = timm.data.create_transform(**resnet_data_config, is_training=False)

last_layer_nodes = resnet_model.num_features
print(last_layer_nodes)

del resnet_model

# Create train and val split.
# The split is seeded so that the same images land in the validation set on
# every run; otherwise reloading a saved checkpoint in a fresh kernel would
# validate on images the checkpoint was trained on.
split_seed = 42
full_train_dataset = MultiClassImageDataset(train_ann_df, super_map_df, sub_map_df, train_img_dir, transform=resnet_transforms)
train_dataset, val_dataset = random_split(
    full_train_dataset,
    [0.9, 0.1],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create test dataset
test_dataset = MultiClassImageTestDataset(super_map_df, sub_map_df, test_img_dir, transform=resnet_transforms)

# Create dataloaders
batch_size = 16
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

# Evaluation does not depend on sample order, so the validation loader is not shuffled.
val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False)

test_loader = DataLoader(test_dataset,
                         batch_size=1,
                         shuffle=False)

6144


In [None]:
class GeneralModel(nn.Module):
    """timm backbone with two linear heads: one for superclasses, one for subclasses.

    Args:
        freeze_main: when truthy, the backbone is used as a fixed feature
            extractor: its parameters get requires_grad=False, it is kept in
            eval mode, and its forward pass runs under torch.no_grad().
            To fine-tune later, set freeze_main to False AND call
            main_model.requires_grad_(True).
        num_superclasses: output size of the superclass head (default 4).
        num_subclasses: output size of the subclass head (default 88).
    """

    def __init__(self, freeze_main=True, num_superclasses=4, num_subclasses=88):
        super().__init__()

        self.main_model = timm.create_model(timm_model_name, pretrained=True, num_classes=0)

        self.freeze_main = freeze_main
        if freeze_main:
            # Mark the backbone as non-trainable so filters on
            # p.requires_grad really select only the head parameters.
            self.main_model.requires_grad_(False)
            self.main_model.eval()

        # Read the feature width from the backbone itself rather than from a
        # notebook-level global, so the class works on its own.
        feature_dim = self.main_model.num_features
        self.super_fc = nn.Linear(feature_dim, num_superclasses)
        self.sub_fc = nn.Linear(feature_dim, num_subclasses)

    def train(self, mode=True):
        """Switch train/eval mode, keeping a frozen backbone in eval mode."""
        super().train(mode)
        if self.freeze_main:
            # nn.Module.train() recurses into children; undo that for the
            # frozen backbone so mode-dependent layers stay in inference mode.
            self.main_model.eval()
        return self

    def forward(self, x):
        """Return (superclass_logits, subclass_logits) for an image batch x."""
        if self.freeze_main:
            with torch.no_grad():
                features = self.main_model(x)
        else:
            features = self.main_model(x)

        super_out = self.super_fc(features)
        sub_out = self.sub_fc(features)

        return super_out, sub_out

class Trainer():
    """Runs training, validation and test prediction for a two-head classifier.

    The model is expected to return (superclass_logits, subclass_logits) and to
    expose `freeze_main` and `main_model` attributes. Training/validation
    batches are indexed as data[0]=inputs, data[1]=superclass index,
    data[3]=subclass index; test batches as data[0]=inputs, data[1]=file names.

    Args:
        model: the network to train / evaluate.
        criterion: loss applied to each head; the two losses are summed.
        optimizer: optimizer stepped once per training batch.
        train_loader, val_loader: iterables of labelled batches.
        test_loader: optional iterable of unlabelled batches.
        device: device that batches are moved to before the forward pass.
    """

    def __init__(self, model, criterion, optimizer, train_loader, val_loader, test_loader=None, device='cuda'):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        # Keep the device on the instance instead of relying on a notebook global.
        self.device = device

    def train_epoch(self):
        """Run one pass over train_loader; print and return the mean batch loss."""
        if not self.model.freeze_main:
            self.model.main_model.train()

        running_loss = 0.0
        num_batches = 0
        for data in tqdm(self.train_loader, total=len(self.train_loader)):
            inputs = data[0].to(self.device)
            super_labels = data[1].to(self.device)
            sub_labels = data[3].to(self.device)

            self.optimizer.zero_grad()
            super_outputs, sub_outputs = self.model(inputs)
            loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)
            loss.backward()
            # Step the optimizer this trainer was given, not a global of the same name.
            self.optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average over the number of batches seen (not the last index, which is one less).
        mean_loss = running_loss / max(num_batches, 1)
        print(f'Training loss: {mean_loss:.3f}')
        return mean_loss

    def validate_epoch(self):
        """Evaluate on val_loader; print and return loss and per-head accuracy (in %).

        Returns:
            dict with keys 'loss', 'super_acc' and 'sub_acc'.
        """
        if not self.model.freeze_main:
            self.model.main_model.eval()

        super_correct = 0
        sub_correct = 0
        total = 0
        running_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for data in tqdm(self.val_loader, total=len(self.val_loader)):
                inputs = data[0].to(self.device)
                super_labels = data[1].to(self.device)
                sub_labels = data[3].to(self.device)

                super_outputs, sub_outputs = self.model(inputs)
                loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)
                super_predicted = super_outputs.argmax(dim=1)
                sub_predicted = sub_outputs.argmax(dim=1)

                total += super_labels.size(0)
                super_correct += (super_predicted == super_labels).sum().item()
                sub_correct += (sub_predicted == sub_labels).sum().item()
                running_loss += loss.item()
                num_batches += 1

        mean_loss = running_loss / max(num_batches, 1)
        super_acc = 100 * super_correct / max(total, 1)
        sub_acc = 100 * sub_correct / max(total, 1)

        print(f'Validation loss: {mean_loss:.3f}')
        print(f'Validation superclass acc: {super_acc:.2f} %')
        print(f'Validation subclass acc: {sub_acc:.2f} %')
        return {'loss': mean_loss, 'super_acc': super_acc, 'sub_acc': sub_acc}

    def test(self, save_to_csv=False, return_predictions=False):
        """Predict both heads for every test image.

        Args:
            save_to_csv: when True, write 'test_predictions.csv' into a folder
                named after the current timestamp.
            return_predictions: when True, return the predictions DataFrame
                (columns: image, superclass_index, subclass_index).

        Raises:
            NotImplementedError: if no test_loader was supplied.
        """
        if not self.model.freeze_main:
            self.model.main_model.eval()

        if self.test_loader is None:
            raise NotImplementedError('test_loader not specified')

        # Plain argmax predictions; no special handling of novel/unseen classes here.
        test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}
        with torch.no_grad():
            for data in tqdm(self.test_loader, total=len(self.test_loader)):
                inputs, img_names = data[0].to(self.device), data[1]

                super_outputs, sub_outputs = self.model(inputs)
                super_predicted = super_outputs.argmax(dim=1)
                sub_predicted = sub_outputs.argmax(dim=1)

                # extend() keeps every sample of the batch, so any batch size works.
                test_predictions['image'].extend(list(img_names))
                test_predictions['superclass_index'].extend(super_predicted.tolist())
                test_predictions['subclass_index'].extend(sub_predicted.tolist())

        test_predictions = pd.DataFrame(data=test_predictions)

        if save_to_csv:
            current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            os.makedirs(current_time, exist_ok=True)
            test_predictions.to_csv(os.path.join(current_time, 'test_predictions.csv'), index=False)

        if return_predictions:
            return test_predictions

In [None]:
# Init model and trainer
# Use the GPU when one is visible, otherwise fall back to CPU so the notebook
# still runs (slowly) on a runtime without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GeneralModel(freeze_main=True)
# model.load_state_dict(torch.load('resnet_model_final.pt', weights_only=True))
model = model.to(device)
criterion = nn.CrossEntropyLoss()
# Only hand the optimizer the parameters that are flagged as trainable.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
trainer = Trainer(model, criterion, optimizer, train_loader, val_loader, test_loader, device=device)

In [None]:
# Training loop: six epochs, each one a training pass followed by a validation pass
for epoch in range(6):
    print(f'Epoch {epoch+1}')
    for run_phase in (trainer.train_epoch, trainer.validate_epoch):
        run_phase()
    print('')

print('Finished Training')

Epoch 1


100%|██████████| 354/354 [01:16<00:00,  4.63it/s]


Training loss: 1.308


100%|██████████| 40/40 [00:08<00:00,  4.75it/s]


Validation loss: 0.473
Validation superclass acc: 100.00 %
Validation subclass acc: 89.81 %

Epoch 2


100%|██████████| 354/354 [01:16<00:00,  4.62it/s]


Training loss: 0.440


100%|██████████| 40/40 [00:08<00:00,  4.70it/s]


Validation loss: 0.562
Validation superclass acc: 100.00 %
Validation subclass acc: 92.52 %

Epoch 3


100%|██████████| 354/354 [01:16<00:00,  4.61it/s]


Training loss: 0.347


100%|██████████| 40/40 [00:08<00:00,  4.67it/s]


Validation loss: 0.698
Validation superclass acc: 100.00 %
Validation subclass acc: 92.52 %

Epoch 4


100%|██████████| 354/354 [01:16<00:00,  4.62it/s]


Training loss: 0.224


100%|██████████| 40/40 [00:08<00:00,  4.74it/s]


Validation loss: 1.164
Validation superclass acc: 100.00 %
Validation subclass acc: 89.81 %

Epoch 5


100%|██████████| 354/354 [01:17<00:00,  4.59it/s]


Training loss: 0.303


100%|██████████| 40/40 [00:08<00:00,  4.67it/s]


Validation loss: 0.962
Validation superclass acc: 100.00 %
Validation subclass acc: 92.04 %

Epoch 6


100%|██████████| 354/354 [01:17<00:00,  4.59it/s]


Training loss: 0.284


100%|██████████| 40/40 [00:08<00:00,  4.74it/s]

Validation loss: 0.761
Validation superclass acc: 100.00 %
Validation subclass acc: 93.47 %

Finished Training





In [None]:
# Persist the model's state_dict to disk so it can be restored later with load_state_dict
torch.save(model.state_dict(), 'resnet_model_final.pt')

In [None]:
# Predict on the test loader: writes test_predictions.csv into a timestamped folder and returns the DataFrame
test_predictions = trainer.test(save_to_csv=True, return_predictions=True)

100%|██████████| 11180/11180 [06:52<00:00, 27.11it/s]


In [None]:
# Validation pass that also collects per-sample confidence statistics
# (max softmax probability and energy) to help with deciding novelty thresholds
super_correct = 0
sub_correct = 0
total = 0
running_loss = 0.
num_batches = 0

# Max softmax probability of each validation sample, per head
super_all_scores = []
sub_all_scores = []

# Energy (-logsumexp of the logits) of each validation sample, per head
super_all_energies = []
sub_all_energies = []

with torch.no_grad():
    for data in tqdm(trainer.val_loader, total=len(trainer.val_loader)):
        inputs, super_labels, sub_labels = data[0].to(device), data[1].to(device), data[3].to(device)

        super_outputs, sub_outputs = trainer.model(inputs)
        loss = trainer.criterion(super_outputs, super_labels) + trainer.criterion(sub_outputs, sub_labels)
        super_predicted = super_outputs.argmax(dim=1)
        sub_predicted = sub_outputs.argmax(dim=1)

        super_probs = torch.nn.functional.softmax(super_outputs, 1)
        sub_probs = torch.nn.functional.softmax(sub_outputs, 1)

        super_scores, _ = torch.max(super_probs, 1)
        sub_scores, _ = torch.max(sub_probs, 1)

        super_energies = -torch.logsumexp(super_outputs, dim=1)
        sub_energies = -torch.logsumexp(sub_outputs, dim=1)

        super_all_scores.extend(super_scores.tolist())
        sub_all_scores.extend(sub_scores.tolist())

        super_all_energies.extend(super_energies.tolist())
        sub_all_energies.extend(sub_energies.tolist())

        total += super_labels.size(0)
        super_correct += (super_predicted == super_labels).sum().item()
        sub_correct += (sub_predicted == sub_labels).sum().item()
        running_loss += loss.item()
        num_batches += 1

# Average over the number of batches (the last loop index would be one too small)
print(f'Validation loss: {running_loss / max(num_batches, 1):.3f}')
print(f'Validation superclass acc: {100 * super_correct / max(total, 1):.2f} %')
print(f'Validation subclass acc: {100 * sub_correct / max(total, 1):.2f} %')

100%|██████████| 40/40 [00:08<00:00,  4.58it/s]

Validation loss: 0.761
Validation superclass acc: 100.00 %
Validation subclass acc: 93.47 %





In [None]:
# Summary statistics of the validation max-softmax scores: superclass head first, then subclass head
display(
    pd.DataFrame(super_all_scores).describe(),
    pd.DataFrame(sub_all_scores).describe(),
)

Unnamed: 0,0
count,628.0
mean,0.999642
std,0.008483
min,0.787417
25%,0.999996
50%,0.999999
75%,1.0
max,1.0


Unnamed: 0,0
count,628.0
mean,0.99243
std,0.041941
min,0.544701
25%,1.0
50%,1.0
75%,1.0
max,1.0


In [None]:
# Summary statistics of the validation energies: superclass head first, then subclass head
display(
    pd.DataFrame(super_all_energies).describe(),
    pd.DataFrame(sub_all_energies).describe(),
)

Unnamed: 0,0
count,628.0
mean,-15.717142
std,2.368735
min,-22.388897
25%,-17.250791
50%,-15.313695
75%,-14.122219
max,-7.025426


Unnamed: 0,0
count,628.0
mean,-26.118684
std,12.992579
min,-91.561447
25%,-34.150433
50%,-25.585178
75%,-17.473681
max,7.543972


In [None]:
# Softmax-score cut-offs: two standard deviations below the mean validation score of each head
super_score_arr = np.asarray(super_all_scores)
sub_score_arr = np.asarray(sub_all_scores)

super_score_threshold = float(super_score_arr.mean() - 2 * super_score_arr.std())
sub_score_threshold = float(sub_score_arr.mean() - 2 * sub_score_arr.std())

for threshold in (super_score_threshold, sub_score_threshold):
    print(threshold)

0.9826892410836034
0.9086145793010311


In [None]:
# Energy cut-offs: two standard deviations below the mean validation energy of each head.
# NOTE(review): the energies collected here are -logsumexp(logits), so larger logits give
# LOWER energy; this is a lower-tail cut-off - confirm that is the intended tail for novelty detection.
super_energy_arr = np.asarray(super_all_energies)
sub_energy_arr = np.asarray(sub_all_energies)

super_energy_threshold = float(super_energy_arr.mean() - 2 * super_energy_arr.std())
sub_energy_threshold = float(sub_energy_arr.mean() - 2 * sub_energy_arr.std())

for threshold in (super_energy_threshold, sub_energy_threshold):
    print(threshold)

-20.450838235920944
-52.08314545406972


In [None]:
# with chosen thresholds, make novel OOD decisions as well using softmax scores
# A sample whose top softmax probability falls below the head's threshold is
# relabelled with that head's "novel" index.
novel_super_idx = 3
novel_sub_idx = 87

test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}
with torch.no_grad():
    for data in tqdm(trainer.test_loader, total=len(trainer.test_loader)):
        inputs, img_names = data[0].to(device), data[1]

        super_outputs, sub_outputs = trainer.model(inputs)
        super_predicted = super_outputs.argmax(dim=1)
        sub_predicted = sub_outputs.argmax(dim=1)

        # Confidence = highest class probability of each sample
        super_scores, _ = torch.max(torch.nn.functional.softmax(super_outputs, 1), 1)
        sub_scores, _ = torch.max(torch.nn.functional.softmax(sub_outputs, 1), 1)

        super_predicted = torch.where(super_scores < super_score_threshold,
                                      torch.tensor(novel_super_idx, device=super_predicted.device),
                                      super_predicted)
        sub_predicted = torch.where(sub_scores < sub_score_threshold,
                                    torch.tensor(novel_sub_idx, device=sub_predicted.device),
                                    sub_predicted)

        # extend() keeps every sample of the batch, so any test batch size works
        test_predictions['image'].extend(list(img_names))
        test_predictions['superclass_index'].extend(super_predicted.tolist())
        test_predictions['subclass_index'].extend(sub_predicted.tolist())

test_predictions = pd.DataFrame(data=test_predictions)

# Write the submission into a folder named after the current timestamp
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
os.makedirs(current_time, exist_ok=True)
test_predictions.to_csv(current_time + '/test_predictions.csv', index=False)

100%|██████████| 11180/11180 [06:54<00:00, 26.98it/s]


In [None]:
# with chosen thresholds, make novel OOD decisions as well using energies
# Energy = -logsumexp(logits): large logits (a confident, familiar sample) give a
# LOW energy, while small logits give a HIGH energy. A sample is therefore
# treated as novel when its energy lies ABOVE the upper tail of the validation
# energies (mean + 2 std). Flagging energies below a lower-tail cut-off would
# relabel the most confident samples instead.
novel_super_idx = 3
novel_sub_idx = 87

super_energy_upper = float(np.mean(super_all_energies) + 2 * np.std(super_all_energies))
sub_energy_upper = float(np.mean(sub_all_energies) + 2 * np.std(sub_all_energies))

test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}
with torch.no_grad():
    for data in tqdm(trainer.test_loader, total=len(trainer.test_loader)):
        inputs, img_names = data[0].to(device), data[1]

        super_outputs, sub_outputs = trainer.model(inputs)
        super_predicted = super_outputs.argmax(dim=1)
        sub_predicted = sub_outputs.argmax(dim=1)

        super_energies = -torch.logsumexp(super_outputs, dim=1)
        sub_energies = -torch.logsumexp(sub_outputs, dim=1)

        super_predicted = torch.where(super_energies > super_energy_upper,
                                      torch.tensor(novel_super_idx, device=super_predicted.device),
                                      super_predicted)
        sub_predicted = torch.where(sub_energies > sub_energy_upper,
                                    torch.tensor(novel_sub_idx, device=sub_predicted.device),
                                    sub_predicted)

        # extend() keeps every sample of the batch, so any test batch size works
        test_predictions['image'].extend(list(img_names))
        test_predictions['superclass_index'].extend(super_predicted.tolist())
        test_predictions['subclass_index'].extend(sub_predicted.tolist())

test_predictions = pd.DataFrame(data=test_predictions)

# Write the submission into a folder named after the current timestamp
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
os.makedirs(current_time, exist_ok=True)
test_predictions.to_csv(current_time + '/test_predictions.csv', index=False)

100%|██████████| 11180/11180 [06:54<00:00, 26.97it/s]
