In [None]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
from torchvision.transforms.autoaugment import AutoAugmentPolicy

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image

In [None]:
!unzip data.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: data/train_images/283.jpg  
  inflating: __MACOSX/data/train_images/._283.jpg  
  inflating: data/train_images/4647.jpg  
  inflating: __MACOSX/data/train_images/._4647.jpg  
  inflating: data/train_images/3128.jpg  
  inflating: __MACOSX/data/train_images/._3128.jpg  
  inflating: data/train_images/2236.jpg  
  inflating: __MACOSX/data/train_images/._2236.jpg  
  inflating: data/train_images/5559.jpg  
  inflating: __MACOSX/data/train_images/._5559.jpg  
  inflating: data/train_images/6050.jpg  
  inflating: __MACOSX/data/train_images/._6050.jpg  
  inflating: data/train_images/1059.jpg  
  inflating: __MACOSX/data/train_images/._1059.jpg  
  inflating: data/train_images/3896.jpg  
  inflating: __MACOSX/data/train_images/._3896.jpg  
  inflating: data/train_images/2550.jpg  
  inflating: __MACOSX/data/train_images/._2550.jpg  
  inflating: data/train_images/4121.jpg  
  inflating: __MACOSX/data/train_images/

In [None]:
# Dataset class for joint superclass / subclass image classification
class MultiClassImageDataset(Dataset):
    """Labelled image dataset yielding an image with its superclass and subclass.

    Args:
        ann_df: Annotation frame with the columns ``'image'`` (file name
            relative to ``img_dir``), ``'superclass_index'`` and
            ``'subclass_index'``. One row per sample.
        super_map_df: Frame whose ``'class'`` column is looked up with the
            superclass index to obtain a readable label.
        sub_map_df: Frame whose ``'class'`` column is looked up with the
            subclass index to obtain a readable label.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, ann_df, super_map_df, sub_map_df, img_dir, transform=None):
        self.ann_df = ann_df
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.ann_df)

    def __getitem__(self, idx):
        """Return ``(image, super_idx, super_label, sub_idx, sub_label)`` for sample ``idx``.

        ``idx`` is a position (0 .. len-1), not an index label, so the dataset
        also works with annotation frames that were filtered or re-indexed.
        """
        # .iloc: positional row access, independent of the frame's index labels.
        img_name = self.ann_df['image'].iloc[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # Close the underlying file as soon as the pixel data has been copied.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        super_idx = self.ann_df['superclass_index'].iloc[idx]
        super_label = self.super_map_df['class'][super_idx]

        sub_idx = self.ann_df['subclass_index'].iloc[idx]
        sub_label = self.sub_map_df['class'][sub_idx]

        if self.transform:
            image = self.transform(image)

        return image, super_idx, super_label, sub_idx, sub_label

class MultiClassImageTestDataset(Dataset):
    """Unlabelled image dataset whose files are named ``0.jpg``, ``1.jpg``, ...

    Args:
        super_map_df: Superclass mapping frame (stored, not used by this class).
        sub_map_df: Subclass mapping frame (stored, not used by this class).
        img_dir: Directory that contains the ``<idx>.jpg`` files.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, super_map_df, sub_map_df, img_dir, transform=None):
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of ``.jpg`` images in ``img_dir``.

        Only visible ``.jpg`` files are counted: stray entries such as
        ``.DS_Store`` or ``._0.jpg`` would otherwise inflate the length and
        make ``__getitem__`` look for image files that do not exist.
        """
        return sum(
            1
            for fname in os.listdir(self.img_dir)
            if fname.lower().endswith('.jpg') and not fname.startswith('.')
        )

    def __getitem__(self, idx):
        """Return ``(image, img_name)`` for the file ``'<idx>.jpg'``."""
        img_name = str(idx) + '.jpg'
        img_path = os.path.join(self.img_dir, img_name)
        # Close the underlying file as soon as the pixel data has been copied.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, img_name

In [None]:
def _inverse_frequency_weights(class_indices):
    """Return normalised inverse-frequency weights for a column of class indices.

    The result has one entry per class index ``0 .. max(index)`` plus one
    trailing ``0.0`` entry. Class indices that never occur in
    ``class_indices`` get weight ``0.0`` at their own position, so the weight
    vector always stays aligned with the class indices.
    """
    counts = class_indices.value_counts().sort_index()
    if not counts.empty:
        # Fill gaps so that position k of the tensor really belongs to class k.
        counts = counts.reindex(range(int(counts.index.max()) + 1), fill_value=0)

    # 1/count for observed classes, 0 for classes without samples.
    inverse = (1.0 / counts.where(counts > 0)).fillna(0.0)
    total = inverse.sum()
    if total > 0:
        inverse = inverse / total

    weights = torch.tensor(inverse.values, dtype=torch.float32)
    # Extra zero-weight slot appended after the classes seen in the data.
    return torch.cat([weights, torch.tensor([0.0])])


def calculate_weights(train_ann_df):
    """Compute per-class loss weights for the superclass and subclass heads.

    Args:
        train_ann_df: Frame with the integer columns ``'superclass_index'``
            and ``'subclass_index'``.

    Returns:
        ``(superclass_weights, subclass_weights)``: two float32 tensors of
        inverse class frequencies normalised to sum to 1, each followed by one
        additional ``0.0`` entry. Both tensors are also printed.
    """
    superclass_weights = _inverse_frequency_weights(train_ann_df['superclass_index'])
    subclass_weights = _inverse_frequency_weights(train_ann_df['subclass_index'])

    print(superclass_weights)
    print(subclass_weights)
    return superclass_weights, subclass_weights

In [None]:
train_ann_df = pd.read_csv('data/train_data.csv')
# test_ann_df = pd.read_csv('data/test_data.csv')
super_map_df = pd.read_csv('data/superclass_mapping.csv')
sub_map_df = pd.read_csv('data/subclass_mapping.csv')

train_img_dir = 'data/train_images'
test_img_dir = 'data/test_images'

# Shared pieces of every preprocessing pipeline.
IMAGE_SIZE = (384, 384)
imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])

augmentation_setups = {
    # Deterministic pipeline: used for validation and test images.
    "baseline": transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        imagenet_normalize,
    ]),
    "autoaugment": transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.AutoAugment(policy=AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
        imagenet_normalize,
    ]),
    "manual_combo": transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        imagenet_normalize,
    ])
}

superclass_weights, subclass_weights = calculate_weights(train_ann_df)

# Create train and val split.
# Two views of the same annotations are needed: random_split only partitions
# indices, so splitting a single augmented dataset would also apply the random
# augmentations to the validation images and make the metrics noisy.
SPLIT_SEED = 42  # fixed seed -> the same train/val partition on every run
train_aug_dataset = MultiClassImageDataset(train_ann_df, super_map_df, sub_map_df, train_img_dir,
                                           transform=augmentation_setups["manual_combo"])
val_plain_dataset = MultiClassImageDataset(train_ann_df, super_map_df, sub_map_df, train_img_dir,
                                           transform=augmentation_setups["baseline"])
train_dataset, val_split = random_split(train_aug_dataset, [0.9, 0.1],
                                        generator=torch.Generator().manual_seed(SPLIT_SEED))
# Same held-out indices, but read through the deterministic pipeline.
val_dataset = torch.utils.data.Subset(val_plain_dataset, val_split.indices)

# Create test dataset
test_dataset = MultiClassImageTestDataset(super_map_df, sub_map_df, test_img_dir, transform=augmentation_setups["baseline"])

# Create dataloaders
batch_size = 32
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

# Evaluation order does not matter, so no shuffling for validation.
val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False)

test_loader = DataLoader(test_dataset,
                         batch_size=1,
                         shuffle=False)

tensor([0.3740, 0.3320, 0.2939, 0.0000])
tensor([0.0150, 0.0074, 0.0150, 0.0147, 0.0074, 0.0147, 0.0074, 0.0074, 0.0147,
        0.0150, 0.0150, 0.0147, 0.0150, 0.0150, 0.0147, 0.0150, 0.0147, 0.0150,
        0.0074, 0.0147, 0.0147, 0.0074, 0.0074, 0.0150, 0.0074, 0.0150, 0.0147,
        0.0074, 0.0074, 0.0073, 0.0074, 0.0074, 0.0150, 0.0144, 0.0150, 0.0072,
        0.0074, 0.0074, 0.0150, 0.0150, 0.0147, 0.0147, 0.0147, 0.0074, 0.0074,
        0.0150, 0.0074, 0.0074, 0.0150, 0.0074, 0.0074, 0.0147, 0.0074, 0.0150,
        0.0150, 0.0147, 0.0147, 0.0073, 0.0074, 0.0147, 0.0147, 0.0073, 0.0074,
        0.0072, 0.0074, 0.0074, 0.0072, 0.0150, 0.0150, 0.0073, 0.0074, 0.0074,
        0.0074, 0.0147, 0.0150, 0.0074, 0.0074, 0.0150, 0.0147, 0.0150, 0.0147,
        0.0073, 0.0147, 0.0147, 0.0074, 0.0150, 0.0147, 0.0000])


In [None]:
from re import sub
class EfficientNetV2MultiHead(nn.Module):
    """EfficientNetV2-S feature extractor with a superclass and a subclass head.

    The pooled backbone features are projected to 256 dimensions. The
    superclass head classifies that embedding; the subclass head classifies
    the embedding concatenated with the superclass probabilities.

    Args:
        num_super: Number of superclass outputs.
        num_sub: Number of subclass outputs.
        pretrained: Load ``EfficientNet_V2_S_Weights.DEFAULT`` when true,
            otherwise start from randomly initialised weights.
    """

    def __init__(self, num_super=4, num_sub=88, pretrained=True):
        super().__init__()
        weights = EfficientNet_V2_S_Weights.DEFAULT if pretrained else None
        pre_trained = efficientnet_v2_s(weights=weights)
        self.backbone = pre_trained.features
        self.avgpool = pre_trained.avgpool
        # 1280 = size of the flattened, pooled backbone output fed to this layer.
        self.fc = nn.Linear(1280, 256)
        self.dropout = nn.Dropout(0.3)
        # Attribute names (including the spelling) are part of the state_dict keys.
        self.classifer_super = nn.Linear(256, num_super)
        self.classifer_sub = nn.Linear(256 + num_super, num_sub)
        # Not read by forward(); kept for code that accesses `model.threshold`.
        self.threshold = 0.6

    def forward(self, x):
        """Return ``(super_out, sub_out)``, the un-normalised scores of both heads.

        Args:
            x: Input batch for the backbone (presumably normalised images of
                shape ``(batch, 3, H, W)`` — confirm against the data pipeline).

        Returns:
            ``super_out`` with ``num_super`` columns and ``sub_out`` with
            ``num_sub`` columns. No softmax is applied to the returned values.
        """
        x = self.backbone(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc(x))
        x = self.dropout(x)

        super_out = self.classifer_super(x)

        # Condition the subclass head on the predicted superclass distribution.
        super_probs = F.softmax(super_out, dim=1)
        x_with_super = torch.cat([x, super_probs], dim=1)

        sub_out = self.classifer_sub(x_with_super)
        return super_out, sub_out

class Trainer():
    """Train, validate and run inference for a model with two classification heads.

    The model is expected to return ``(super_outputs, sub_outputs)``. Labelled
    batches are indexed as ``data[0]`` (inputs), ``data[1]`` (superclass
    indices) and ``data[3]`` (subclass indices); test batches as ``data[0]``
    (inputs) and ``data[1]`` (image names).

    Args:
        model: The two-headed network.
        super_criterion: Loss applied to the superclass outputs.
        sub_criterion: Loss applied to the subclass outputs.
        optimizer: Optimizer stepping the model parameters.
        train_loader: Iterable of labelled training batches.
        val_loader: Iterable of labelled validation batches.
        test_loader: Optional iterable of test batches.
        device: Device that batches are moved to before the forward pass.
    """

    def __init__(self, model, super_criterion, sub_criterion, optimizer, train_loader, val_loader, test_loader=None, device='cuda'):
        self.model = model
        self.super_criterion = super_criterion
        self.sub_criterion = sub_criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        # Stored so the methods do not depend on a global `device` variable.
        self.device = device

    def _unpack_labelled_batch(self, data):
        """Move inputs, superclass indices and subclass indices to ``self.device``."""
        return data[0].to(self.device), data[1].to(self.device), data[3].to(self.device)

    def _joint_loss(self, super_outputs, sub_outputs, super_labels, sub_labels):
        """Sum of the superclass and the subclass loss."""
        return self.super_criterion(super_outputs, super_labels) + self.sub_criterion(sub_outputs, sub_labels)

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Prints and returns the loss averaged over the batches (``nan`` when
        the loader yields no batch).
        """
        self.model.train()  # enable dropout / batch-norm statistics updates
        running_loss = 0.0
        num_batches = 0
        for data in self.train_loader:
            inputs, super_labels, sub_labels = self._unpack_labelled_batch(data)

            self.optimizer.zero_grad()
            super_outputs, sub_outputs = self.model(inputs)
            loss = self._joint_loss(super_outputs, sub_outputs, super_labels, sub_labels)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        avg_loss = running_loss / num_batches if num_batches else float('nan')
        print(f'Training loss: {avg_loss:.3f}')
        return avg_loss

    def validate_epoch(self):
        """Evaluate on ``val_loader`` without gradient tracking.

        Prints the average loss and both accuracies and returns them as a dict
        with the keys ``'loss'``, ``'super_acc'`` and ``'sub_acc'``
        (accuracies in percent; ``nan`` when the loader yields no sample).
        """
        self.model.eval()  # deterministic forward pass: no dropout, fixed batch-norm stats
        super_correct = 0
        sub_correct = 0
        total = 0
        running_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for data in self.val_loader:
                inputs, super_labels, sub_labels = self._unpack_labelled_batch(data)

                super_outputs, sub_outputs = self.model(inputs)
                loss = self._joint_loss(super_outputs, sub_outputs, super_labels, sub_labels)
                super_predicted = super_outputs.argmax(dim=1)
                sub_predicted = sub_outputs.argmax(dim=1)

                total += super_labels.size(0)
                super_correct += (super_predicted == super_labels).sum().item()
                sub_correct += (sub_predicted == sub_labels).sum().item()
                running_loss += loss.item()
                num_batches += 1

        avg_loss = running_loss / num_batches if num_batches else float('nan')
        super_acc = 100 * super_correct / total if total else float('nan')
        sub_acc = 100 * sub_correct / total if total else float('nan')
        print(f'Validation loss: {avg_loss:.3f}')
        print(f'Validation superclass acc: {super_acc:.2f} %')
        print(f'Validation subclass acc: {sub_acc:.2f} %')
        return {'loss': avg_loss, 'super_acc': super_acc, 'sub_acc': sub_acc}

    def test(self, save_to_csv=False, return_predictions=False,
             super_threshold=0.6, sub_threshold=0.08,
             novel_super_idx=3, novel_sub_idx=87):
        """Predict superclass and subclass indices for every test image.

        A prediction whose highest softmax probability is below the head's
        threshold is replaced by that head's novel index.

        Args:
            save_to_csv: Write the predictions to ``'example_test_predictions.csv'``.
            return_predictions: Return the predictions as a DataFrame.
            super_threshold: Minimum top probability for a superclass prediction.
            sub_threshold: Minimum top probability for a subclass prediction.
            novel_super_idx: Index written for low-confidence superclass predictions.
            novel_sub_idx: Index written for low-confidence subclass predictions.

        Returns:
            A DataFrame with the columns ``'image'``, ``'superclass_index'`` and
            ``'subclass_index'`` when ``return_predictions`` is true, else ``None``.

        Raises:
            NotImplementedError: If no ``test_loader`` was given.
        """
        if self.test_loader is None:
            raise NotImplementedError('test_loader not specified')

        self.model.eval()  # required for meaningful outputs, especially with batch size 1
        test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}
        with torch.no_grad():
            for data in self.test_loader:
                inputs, img_names = data[0].to(self.device), data[1]

                super_outputs, sub_outputs = self.model(inputs)
                super_max_probs, super_predicted = F.softmax(super_outputs, dim=1).max(dim=1)
                sub_max_probs, sub_predicted = F.softmax(sub_outputs, dim=1).max(dim=1)

                # Low-confidence predictions are reported as the novel class.
                super_predicted = super_predicted.masked_fill(super_max_probs < super_threshold, novel_super_idx)
                sub_predicted = sub_predicted.masked_fill(sub_max_probs < sub_threshold, novel_sub_idx)

                # extend() keeps this correct for any test batch size.
                test_predictions['image'].extend(list(img_names))
                test_predictions['superclass_index'].extend(super_predicted.tolist())
                test_predictions['subclass_index'].extend(sub_predicted.tolist())

        test_predictions = pd.DataFrame(data=test_predictions)

        if save_to_csv:
            test_predictions.to_csv('example_test_predictions.csv', index=False)

        if return_predictions:
            return test_predictions

In [None]:
# Init model and trainer
# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_super_classes = 4
num_sub_classes = 88
model = EfficientNetV2MultiHead(num_super=num_super_classes, num_sub=num_sub_classes, pretrained=True).to(device)

# Fine-tune only the end of the backbone: freeze everything first ...
for param in model.backbone.parameters():
    param.requires_grad = False
# ... then unfreeze the last 3 blocks (the heads outside the backbone stay trainable).
last_blocks = list(model.backbone)[-3:]
for block in last_blocks:
    for param in block.parameters():
        param.requires_grad = True

# The class weights must live on the same device as the model outputs.
superclass_weights = superclass_weights.to(device)
subclass_weights = subclass_weights.to(device)
super_criterion = nn.CrossEntropyLoss(weight=superclass_weights)
sub_criterion = nn.CrossEntropyLoss(weight=subclass_weights)

# Hand only the trainable parameters to the optimizer; frozen ones never get gradients.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-4)
trainer = Trainer(model, super_criterion, sub_criterion, optimizer, train_loader, val_loader, test_loader, device=device)

In [None]:
# Release cached GPU memory that PyTorch is holding but not currently using
# (see the torch.cuda.empty_cache documentation) before training starts.
torch.cuda.empty_cache()

In [None]:
# Training loop: every epoch runs one training pass followed by one validation pass.
total_epochs = 20
epoch_steps = (trainer.train_epoch, trainer.validate_epoch)
for epoch in range(total_epochs):
    print(f'Epoch {epoch+1}')
    for run_step in epoch_steps:
        run_step()
    print('')

print('Finished Training')

Epoch 1
Training loss: 3.694
Validation loss: 2.231
Validation superclass acc: 98.89 %
Validation subclass acc: 53.50 %

Epoch 2
Training loss: 1.132
Validation loss: 0.778
Validation superclass acc: 99.84 %
Validation subclass acc: 83.12 %

Epoch 3
Training loss: 0.448
Validation loss: 0.433
Validation superclass acc: 100.00 %
Validation subclass acc: 88.54 %

Epoch 4
Training loss: 0.263
Validation loss: 0.309
Validation superclass acc: 100.00 %
Validation subclass acc: 92.04 %

Epoch 5
Training loss: 0.199
Validation loss: 0.303
Validation superclass acc: 100.00 %
Validation subclass acc: 92.52 %

Epoch 6
Training loss: 0.143
Validation loss: 0.252
Validation superclass acc: 99.84 %
Validation subclass acc: 94.43 %

Epoch 7
Training loss: 0.120
Validation loss: 0.227
Validation superclass acc: 100.00 %
Validation subclass acc: 95.06 %

Epoch 8
Training loss: 0.097
Validation loss: 0.210
Validation superclass acc: 100.00 %
Validation subclass acc: 94.75 %

Epoch 9
Training loss: 0.08

In [None]:
# Run inference on the test loader; the predictions are written to
# 'example_test_predictions.csv' and also kept in memory as a DataFrame.
test_predictions = trainer.test(save_to_csv=True, return_predictions=True)

In [None]:
# Quick script for evaluating generated predictions against the ground truth

# Ground-truth annotations of the test images. Loaded here so this cell does
# not depend on a variable that no earlier cell defines.
test_ann_df = pd.read_csv('data/test_data.csv')


def evaluate_predictions(pred_df, gt_df, novel_super_idx=3, novel_sub_idx=87):
    """Count correct predictions overall and split by seen / unseen classes.

    Rows of ``pred_df`` and ``gt_df`` are compared position by position; both
    frames need the columns ``'superclass_index'`` and ``'subclass_index'``.

    A sample counts as *unseen superclass* when its true superclass is
    ``novel_super_idx``. It counts as *seen subclass* when the true superclass
    is seen and the true subclass is not ``novel_sub_idx``; every other sample
    is *unseen subclass*.

    Returns:
        Dict mapping ``'super'`` / ``'sub'`` to a dict that maps ``'overall'``,
        ``'seen'`` and ``'unseen'`` to ``(correct, total)`` tuples.

    Raises:
        ValueError: If the two frames do not have the same number of rows.
    """
    if len(pred_df) != len(gt_df):
        raise ValueError(
            f'predictions ({len(pred_df)} rows) and ground truth ({len(gt_df)} rows) differ in length'
        )

    super_pred = pred_df['superclass_index'].to_numpy()
    sub_pred = pred_df['subclass_index'].to_numpy()
    super_gt = gt_df['superclass_index'].to_numpy()
    sub_gt = gt_df['subclass_index'].to_numpy()

    super_hit = super_pred == super_gt
    sub_hit = sub_pred == sub_gt

    seen_super = super_gt != novel_super_idx
    seen_sub = seen_super & (sub_gt != novel_sub_idx)

    def tally(hits, mask):
        # (number of correct predictions inside the mask, size of the mask)
        return int((hits & mask).sum()), int(mask.sum())

    everything = np.ones(len(super_gt), dtype=bool)
    return {
        'super': {
            'overall': tally(super_hit, everything),
            'seen': tally(super_hit, seen_super),
            'unseen': tally(super_hit, ~seen_super),
        },
        'sub': {
            'overall': tally(sub_hit, everything),
            'seen': tally(sub_hit, seen_sub),
            'unseen': tally(sub_hit, ~seen_sub),
        },
    }


def _format_accuracy(correct, total):
    """Format ``correct / total`` as a percentage, or 'n/a' for an empty bucket."""
    return f'{100*correct/total:.2f} %' if total else 'n/a'


eval_results = evaluate_predictions(test_predictions, test_ann_df)

print('Superclass Accuracy')
print(f"Overall: {_format_accuracy(*eval_results['super']['overall'])}")
print(f"Seen: {_format_accuracy(*eval_results['super']['seen'])}")
print(f"Unseen: {_format_accuracy(*eval_results['super']['unseen'])}")

print('\nSubclass Accuracy')
print(f"Overall: {_format_accuracy(*eval_results['sub']['overall'])}")
print(f"Seen: {_format_accuracy(*eval_results['sub']['seen'])}")
print(f"Unseen: {_format_accuracy(*eval_results['sub']['unseen'])}")

Superclass Accuracy
Overall: 43.83 %
Seen: 61.11 %
Unseen: 0.00 %

Subclass Accuracy
Overall: 2.03 %
Seen: 9.56 %
Unseen: 0.00 %
