<a href="https://colab.research.google.com/github/rileytyh/NNDL-TransferLearning/blob/main/W4995_TransferLearning_Project_ConvNeXt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!unzip Released_Data_NNDL_2025-20250510T140910Z-001.zip
!unzip Released_Data_NNDL_2025/train_images.zip
!unzip Released_Data_NNDL_2025/test_images.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: test_images/7677.jpg    
  inflating: __MACOSX/test_images/._7677.jpg  
  inflating: test_images/8544.jpg    
  inflating: __MACOSX/test_images/._8544.jpg  
  inflating: test_images/9882.jpg    
  inflating: __MACOSX/test_images/._9882.jpg  
  inflating: test_images/792.jpg     
  inflating: __MACOSX/test_images/._792.jpg  
  inflating: test_images/1206.jpg    
  inflating: __MACOSX/test_images/._1206.jpg  
  inflating: test_images/6569.jpg    
  inflating: __MACOSX/test_images/._6569.jpg  
  inflating: test_images/5060.jpg    
  inflating: __MACOSX/test_images/._5060.jpg  
  inflating: test_images/10248.jpg   
  inflating: __MACOSX/test_images/._10248.jpg  
  inflating: test_images/3411.jpg    
  inflating: __MACOSX/test_images/._3411.jpg  
  inflating: test_images/11156.jpg   
  inflating: __MACOSX/test_images/._11156.jpg  
  inflating: test_images/4418.jpg    
  inflating: __MACOSX/test_images/._4418.jpg  

In [None]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datetime import datetime
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.models as models

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image

import timm

In [None]:
for m in timm.list_models(pretrained=True):
  print(m)

aimv2_1b_patch14_224.apple_pt
aimv2_1b_patch14_336.apple_pt
aimv2_1b_patch14_448.apple_pt
aimv2_3b_patch14_224.apple_pt
aimv2_3b_patch14_336.apple_pt
aimv2_3b_patch14_448.apple_pt
aimv2_huge_patch14_224.apple_pt
aimv2_huge_patch14_336.apple_pt
aimv2_huge_patch14_448.apple_pt
aimv2_large_patch14_224.apple_pt
aimv2_large_patch14_224.apple_pt_dist
aimv2_large_patch14_336.apple_pt
aimv2_large_patch14_336.apple_pt_dist
aimv2_large_patch14_448.apple_pt
bat_resnext26ts.ch_in1k
beit_base_patch16_224.in22k_ft_in22k
beit_base_patch16_224.in22k_ft_in22k_in1k
beit_base_patch16_384.in22k_ft_in22k_in1k
beit_large_patch16_224.in22k_ft_in22k
beit_large_patch16_224.in22k_ft_in22k_in1k
beit_large_patch16_384.in22k_ft_in22k_in1k
beit_large_patch16_512.in22k_ft_in22k_in1k
beitv2_base_patch16_224.in1k_ft_in1k
beitv2_base_patch16_224.in1k_ft_in22k
beitv2_base_patch16_224.in1k_ft_in22k_in1k
beitv2_large_patch16_224.in1k_ft_in1k
beitv2_large_patch16_224.in1k_ft_in22k
beitv2_large_patch16_224.in1k_ft_in22k_in1

In [None]:
# Create Dataset class for multilabel classification
class MultiClassImageDataset(Dataset):
    def __init__(self, ann_df, super_map_df, sub_map_df, img_dir, transform=None):
        self.ann_df = ann_df
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.ann_df)

    def __getitem__(self, idx):
        img_name = self.ann_df['image'][idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        super_idx = self.ann_df['superclass_index'][idx]
        super_label = self.super_map_df['class'][super_idx]

        sub_idx = self.ann_df['subclass_index'][idx]
        sub_label = self.sub_map_df['class'][sub_idx]

        if self.transform:
            image = self.transform(image)

        return image, super_idx, super_label, sub_idx, sub_label

class MultiClassImageTestDataset(Dataset):
    def __init__(self, super_map_df, sub_map_df, img_dir, transform=None):
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self): # Count files in img_dir
        return len([fname for fname in os.listdir(self.img_dir)])

    def __getitem__(self, idx):
        img_name = str(idx) + '.jpg'
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, img_name

In [None]:
train_ann_df = pd.read_csv('Released_Data_NNDL_2025/train_data.csv')
super_map_df = pd.read_csv('Released_Data_NNDL_2025/superclass_mapping.csv')
sub_map_df = pd.read_csv('Released_Data_NNDL_2025/subclass_mapping.csv')

train_img_dir = 'train_images'
test_img_dir = 'test_images'

timm_model_name = 'convnextv2_huge.fcmae_ft_in22k_in1k_384'

convnext_model = timm.create_model(timm_model_name, pretrained=True, num_classes=0)
convnext_data_config = timm.data.resolve_model_data_config(convnext_model)
convnext_transforms = timm.data.create_transform(**convnext_data_config, is_training=False)

last_layer_nodes = convnext_model.num_features
print(last_layer_nodes)

del convnext_model

# image_preprocessing = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(mean=(0), std=(1)),
# ])

# Create train and val split
train_dataset = MultiClassImageDataset(train_ann_df, super_map_df, sub_map_df, train_img_dir, transform=convnext_transforms)
train_dataset, val_dataset = random_split(train_dataset, [0.9, 0.1])

# Create test dataset
test_dataset = MultiClassImageTestDataset(super_map_df, sub_map_df, test_img_dir, transform=convnext_transforms)

# Create dataloaders
batch_size = 16
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=True)

test_loader = DataLoader(test_dataset,
                         batch_size=1,
                         shuffle=False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/2.64G [00:00<?, ?B/s]

2816


In [None]:
class GeneralModel(nn.Module):
    def __init__(self, freeze_main=True):
        super().__init__()

        self.freeze_main = freeze_main

        self.main_model = timm.create_model(timm_model_name, pretrained=True, num_classes=0)
        self.main_model.eval()

        self.super_fc = nn.Linear(last_layer_nodes, 4)
        self.sub_fc = nn.Linear(last_layer_nodes, 88)

    def forward(self, x):
        if self.freeze_main == True:
            with torch.no_grad():
                x = self.main_model(x)
        else:
                x = self.main_model(x)

        super_out = self.super_fc(x)
        sub_out = self.sub_fc(x)

        return super_out, sub_out

class Trainer():
    def __init__(self, model, criterion, optimizer, train_loader, val_loader, test_loader=None, device='cuda'):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

    def train_epoch(self):
        running_loss = 0.0
        for i, data in tqdm(enumerate(self.train_loader), total=len(self.train_loader)):
            inputs, super_labels, sub_labels = data[0].to(device), data[1].to(device), data[3].to(device)

            self.optimizer.zero_grad()
            super_outputs, sub_outputs = self.model(inputs)
            loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Training loss: {running_loss/i:.3f}')

    def validate_epoch(self):
        super_correct = 0
        sub_correct = 0
        total = 0
        running_loss = 0.0
        with torch.no_grad():
            for i, data in tqdm(enumerate(self.val_loader), total=len(self.val_loader)):
                inputs, super_labels, sub_labels = data[0].to(device), data[1].to(device), data[3].to(device)

                super_outputs, sub_outputs = self.model(inputs)
                loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)
                _, super_predicted = torch.max(super_outputs.data, 1)
                _, sub_predicted = torch.max(sub_outputs.data, 1)

                total += super_labels.size(0)
                super_correct += (super_predicted == super_labels).sum().item()
                sub_correct += (sub_predicted == sub_labels).sum().item()
                running_loss += loss.item()

        print(f'Validation loss: {running_loss/i:.3f}')
        print(f'Validation superclass acc: {100 * super_correct / total:.2f} %')
        print(f'Validation subclass acc: {100 * sub_correct / total:.2f} %')

    def test(self, save_to_csv=False, return_predictions=False):
        if not self.test_loader:
            raise NotImplementedError('test_loader not specified')

        # Evaluate on test set, in this simple demo no special care is taken for novel/unseen classes
        test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}
        with torch.no_grad():
            for i, data in tqdm(enumerate(self.test_loader), total=len(self.test_loader)):
                inputs, img_name = data[0].to(device), data[1]

                super_outputs, sub_outputs = self.model(inputs)
                _, super_predicted = torch.max(super_outputs.data, 1)
                _, sub_predicted = torch.max(sub_outputs.data, 1)

                test_predictions['image'].append(img_name[0])
                test_predictions['superclass_index'].append(super_predicted.item())
                test_predictions['subclass_index'].append(sub_predicted.item())

        test_predictions = pd.DataFrame(data=test_predictions)

        if save_to_csv:
            current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            if not os.path.exists(current_time):
                os.mkdir(current_time)
            test_predictions.to_csv(current_time + '/test_predictions.csv', index=False)

        if return_predictions:
            return test_predictions

In [None]:
# Init model and trainer
device = 'cuda'
model = GeneralModel()
# model.load_state_dict(torch.load('convnext_model_final.pt', weights_only=True))
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
trainer = Trainer(model, criterion, optimizer, train_loader, val_loader, test_loader)

In [None]:
# Training loop
for epoch in range(5):
    print(f'Epoch {epoch+1}')
    trainer.train_epoch()
    trainer.validate_epoch()
    print('')

print('Finished Training')

Epoch 1


100%|██████████| 354/354 [04:20<00:00,  1.36it/s]


Training loss: 0.436


100%|██████████| 40/40 [00:28<00:00,  1.39it/s]


Validation loss: 0.197
Validation superclass acc: 100.00 %
Validation subclass acc: 94.11 %

Epoch 2


100%|██████████| 354/354 [04:19<00:00,  1.36it/s]


Training loss: 0.088


100%|██████████| 40/40 [00:28<00:00,  1.39it/s]


Validation loss: 0.170
Validation superclass acc: 100.00 %
Validation subclass acc: 94.90 %

Epoch 3


100%|██████████| 354/354 [04:19<00:00,  1.36it/s]


Training loss: 0.053


100%|██████████| 40/40 [00:28<00:00,  1.39it/s]


Validation loss: 0.100
Validation superclass acc: 100.00 %
Validation subclass acc: 96.97 %

Epoch 4


100%|██████████| 354/354 [04:19<00:00,  1.36it/s]


Training loss: 0.034


100%|██████████| 40/40 [00:28<00:00,  1.39it/s]


Validation loss: 0.123
Validation superclass acc: 100.00 %
Validation subclass acc: 96.50 %

Epoch 5


100%|██████████| 354/354 [04:19<00:00,  1.36it/s]


Training loss: 0.038


100%|██████████| 40/40 [00:28<00:00,  1.39it/s]

Validation loss: 0.139
Validation superclass acc: 100.00 %
Validation subclass acc: 95.86 %

Finished Training





In [None]:
torch.save(model.state_dict(), 'convnext_model_final.pt')

In [None]:
test_predictions = trainer.test(save_to_csv=True, return_predictions=True)

100%|██████████| 11180/11180 [10:23<00:00, 17.92it/s]


In [None]:
# validation with collecting of prob scores to help with deciding threshold
super_correct = 0
sub_correct = 0
total = 0
running_loss = 0.

super_all_scores = []
sub_all_scores = []

super_all_energies = []
sub_all_energies = []

with torch.no_grad():
    for i, data in tqdm(enumerate(trainer.val_loader), total=len(trainer.val_loader)):
        inputs, super_labels, sub_labels = data[0].to(device), data[1].to(device), data[3].to(device)

        super_outputs, sub_outputs = trainer.model(inputs)
        loss = trainer.criterion(super_outputs, super_labels) + trainer.criterion(sub_outputs, sub_labels)
        _, super_predicted = torch.max(super_outputs.data, 1)
        _, sub_predicted = torch.max(sub_outputs.data, 1)

        super_probs = torch.nn.functional.softmax(super_outputs.data, 1)
        sub_probs = torch.nn.functional.softmax(sub_outputs.data, 1)

        super_scores, _ = torch.max(super_probs, 1)
        sub_scores, _ = torch.max(sub_probs, 1)

        super_energies = -torch.logsumexp(super_outputs.data, dim=1)
        sub_energies = -torch.logsumexp(sub_outputs.data, dim=1)

        super_all_scores.extend(super_scores.tolist())
        sub_all_scores.extend(sub_scores.tolist())

        super_all_energies.extend(super_energies.tolist())
        sub_all_energies.extend(sub_energies.tolist())

        total += super_labels.size(0)
        super_correct += (super_predicted == super_labels).sum().item()
        sub_correct += (sub_predicted == sub_labels).sum().item()
        running_loss += loss.item()

print(f'Validation loss: {running_loss/i:.3f}')
print(f'Validation superclass acc: {100 * super_correct / total:.2f} %')
print(f'Validation subclass acc: {100 * sub_correct / total:.2f} %')

100%|██████████| 40/40 [00:29<00:00,  1.38it/s]

Validation loss: 0.139
Validation superclass acc: 100.00 %
Validation subclass acc: 95.86 %





In [None]:
display(pd.DataFrame(super_all_scores).describe())
display(pd.DataFrame(sub_all_scores).describe())

Unnamed: 0,0
count,628.0
mean,0.99994
std,0.000296
min,0.993611
25%,0.999968
50%,0.999994
75%,0.999999
max,1.0


Unnamed: 0,0
count,628.0
mean,0.987351
std,0.056163
min,0.429258
25%,0.999165
50%,0.999879
75%,0.999975
max,1.0


In [None]:
display(pd.DataFrame(super_all_energies).describe())
display(pd.DataFrame(sub_all_energies).describe())

Unnamed: 0,0
count,628.0
mean,-9.146938
std,1.785976
min,-14.68965
25%,-10.299514
50%,-9.022823
75%,-7.847642
max,-4.207986


Unnamed: 0,0
count,628.0
mean,-12.245291
std,3.007796
min,-23.391033
25%,-13.70545
50%,-11.888737
75%,-10.219782
max,-4.362776


In [None]:
super_score_threshold = float(np.mean(super_all_scores) - 3 * np.std(super_all_scores))
sub_score_threshold = float(np.mean(sub_all_scores) - 3 * np.std(sub_all_scores))

print(super_score_threshold)
print(sub_score_threshold)

0.9990530562800088
0.8189966143347298


In [None]:
super_energy_threshold = float(np.mean(super_all_energies) - 3 * np.std(super_all_energies))
sub_energy_threshold = float(np.mean(sub_all_energies) - 3 * np.std(sub_all_energies))

print(super_energy_threshold)
print(sub_energy_threshold)

-14.50059670341956
-21.261492462608906


In [None]:
# with chosen thresholds, make novel OOD decisions as well using softmax scores
test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}
with torch.no_grad():
    for i, data in tqdm(enumerate(trainer.test_loader), total=len(trainer.test_loader)):
        inputs, img_name = data[0].to(device), data[1]

        super_outputs, sub_outputs = trainer.model(inputs)
        _, super_predicted = torch.max(super_outputs.data, 1)
        _, sub_predicted = torch.max(sub_outputs.data, 1)

        super_probs = torch.nn.functional.softmax(super_outputs.data, 1)
        sub_probs = torch.nn.functional.softmax(sub_outputs.data, 1)

        super_scores, _ = torch.max(super_probs, 1)
        sub_scores, _ = torch.max(sub_probs, 1)

        super_energies = -torch.logsumexp(super_outputs.data, dim=1)
        sub_energies = -torch.logsumexp(sub_outputs.data, dim=1)

        super_predicted = torch.where(super_scores < super_score_threshold,
                                      torch.tensor(3, device=super_predicted.device),
                                      super_predicted)
        sub_predicted = torch.where(sub_scores < sub_score_threshold,
                                    torch.tensor(87, device=sub_predicted.device),
                                    sub_predicted)

        test_predictions['image'].append(img_name[0])
        test_predictions['superclass_index'].append(super_predicted.item())
        test_predictions['subclass_index'].append(sub_predicted.item())

test_predictions = pd.DataFrame(data=test_predictions)

current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
if not os.path.exists(current_time):
    os.mkdir(current_time)
test_predictions.to_csv(current_time + '/test_predictions.csv', index=False)

100%|██████████| 11180/11180 [10:25<00:00, 17.88it/s]


In [None]:
# with chosen thresholds, make novel OOD decisions as well using energies
test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}
with torch.no_grad():
    for i, data in tqdm(enumerate(trainer.test_loader), total=len(trainer.test_loader)):
        inputs, img_name = data[0].to(device), data[1]

        super_outputs, sub_outputs = trainer.model(inputs)
        _, super_predicted = torch.max(super_outputs.data, 1)
        _, sub_predicted = torch.max(sub_outputs.data, 1)

        super_probs = torch.nn.functional.softmax(super_outputs.data, 1)
        sub_probs = torch.nn.functional.softmax(sub_outputs.data, 1)

        super_scores, _ = torch.max(super_probs, 1)
        sub_scores, _ = torch.max(sub_probs, 1)

        super_energies = -torch.logsumexp(super_outputs.data, dim=1)
        sub_energies = -torch.logsumexp(sub_outputs.data, dim=1)

        super_predicted = torch.where(super_energies < super_energy_threshold,
                                      torch.tensor(3, device=super_predicted.device),
                                      super_predicted)
        sub_predicted = torch.where(sub_energies < sub_energy_threshold,
                                    torch.tensor(87, device=sub_predicted.device),
                                    sub_predicted)

        test_predictions['image'].append(img_name[0])
        test_predictions['superclass_index'].append(super_predicted.item())
        test_predictions['subclass_index'].append(sub_predicted.item())

test_predictions = pd.DataFrame(data=test_predictions)

current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
if not os.path.exists(current_time):
    os.mkdir(current_time)
test_predictions.to_csv(current_time + '/test_predictions.csv', index=False)

100%|██████████| 11180/11180 [10:25<00:00, 17.86it/s]


In [None]:
# with handcoded manual thresholds, make novel OOD decisions as well using softmax scores
test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}
with torch.no_grad():
    for i, data in tqdm(enumerate(trainer.test_loader), total=len(trainer.test_loader)):
        inputs, img_name = data[0].to(device), data[1]

        super_outputs, sub_outputs = trainer.model(inputs)
        _, super_predicted = torch.max(super_outputs.data, 1)
        _, sub_predicted = torch.max(sub_outputs.data, 1)

        super_probs = torch.nn.functional.softmax(super_outputs.data, 1)
        sub_probs = torch.nn.functional.softmax(sub_outputs.data, 1)

        super_scores, _ = torch.max(super_probs, 1)
        sub_scores, _ = torch.max(sub_probs, 1)

        super_energies = -torch.logsumexp(super_outputs.data, dim=1)
        sub_energies = -torch.logsumexp(sub_outputs.data, dim=1)

        super_predicted = torch.where(super_scores < 0.95,
                                      torch.tensor(3, device=super_predicted.device),
                                      super_predicted)
        sub_predicted = torch.where(sub_scores < 0.75,
                                    torch.tensor(87, device=sub_predicted.device),
                                    sub_predicted)

        test_predictions['image'].append(img_name[0])
        test_predictions['superclass_index'].append(super_predicted.item())
        test_predictions['subclass_index'].append(sub_predicted.item())

test_predictions = pd.DataFrame(data=test_predictions)

current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
if not os.path.exists(current_time):
    os.mkdir(current_time)
test_predictions.to_csv(current_time + '/test_predictions.csv', index=False)

100%|██████████| 11180/11180 [10:25<00:00, 17.88it/s]
