In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import random
import os
from PIL import Image
import time
import datetime
from torch.utils.data import ConcatDataset
import copy

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# ---- STEP 1: Define Dataset Wrapper ----
class SimpleImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

# ---- STEP 2: Define the Incremental Model ----
class IncrementalResNet(nn.Module):
    def __init__(self, num_classes):
        super(IncrementalResNet, self).__init__()
        base_model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
        self.feature_extractor = nn.Sequential(*list(base_model.children())[:-1])  # remove fc
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        # with torch.no_grad():
        #     features = self.feature_extractor(x)
        #     features = torch.flatten(features, 1)
        features = self.feature_extractor(x)
        features = torch.flatten(features, 1)
        logits = self.fc(features)
        return logits

    def add_classes(self, num_new):
        old_weights = self.fc.weight.data.clone()
        old_bias = self.fc.bias.data.clone()

        new_fc = nn.Linear(2048, self.fc.out_features + num_new)
        new_fc.weight.data[:self.fc.out_features] = old_weights
        new_fc.bias.data[:self.fc.out_features] = old_bias

        self.fc = new_fc

# ---- STEP 3: Training Function ----
def run_epoch(phase, model, dataloader, optimizer, criterion, device):
    model.train()
    running_loss = 0
    running_corrects = 0
    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(phase == 'train'):
            outputs = model(imgs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if phase == 'train':
                loss.backward()
                optimizer.step()
        running_loss += loss.item() * imgs.size(0)
        running_corrects += torch.sum(preds == labels.data)
    num_samples = len(dataloader.dataset)
    return { 
        'loss': running_loss / num_samples, 
        'acc': running_corrects.double() / num_samples
    }

In [4]:
dataset_dir = "../insect-dataset/lepidoptera"

image_size = 224
img_header_footer_ratio = 1.1
normazile_x = [0.485, 0.456, 0.406]
normalize_y = [0.229, 0.224, 0.225]

def_transform = [
    transforms.Resize(int(image_size * img_header_footer_ratio)),
    transforms.CenterCrop((image_size, image_size)),
    transforms.RandomRotation(15, fill=(0, 0, 0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(normazile_x, normalize_y),
]

batch_size = 32

# Iteration 1

In [5]:
i1_train_dir = f"{dataset_dir}/i00001-train"
i1_val_dir = f"{dataset_dir}/i00001-val"

i1_classes = os.listdir(i1_train_dir)
num_i1_classes = len(i1_classes)
print(f"num_i1_classes = {num_i1_classes}")

i1_train_images = [ 
    (f"{i1_train_dir}/{class_dir}/{img}", i1_classes.index(class_dir)) 
    for class_dir in os.listdir(i1_train_dir) 
    for img in os.listdir(f"{i1_train_dir}/{class_dir}")
]
num_i1_train_images = len(i1_train_images)
print(f"num_i1_train_images = {num_i1_train_images}")
i1_train_dataset = SimpleImageDataset(
    image_paths = [ t[0] for t in i1_train_images],
    labels = [ t[1] for t in i1_train_images],
    transform = transforms.Compose(def_transform)
)
i1_train_loader = DataLoader(i1_train_dataset, batch_size=batch_size, shuffle=True)

i1_val_images = [ 
    (f"{i1_val_dir}/{class_dir}/{img}", i1_classes.index(class_dir)) 
    for class_dir in os.listdir(i1_val_dir) 
    for img in os.listdir(f"{i1_val_dir}/{class_dir}")
]
num_i1_val_images = len(i1_val_images)
print(f"num_i1_val_images = {num_i1_val_images}")
i1_val_dataset = SimpleImageDataset(
    image_paths = [ t[0] for t in i1_val_images],
    labels = [ t[1] for t in i1_val_images],
    transform = transforms.Compose(def_transform)
)
i1_val_loader = DataLoader(i1_val_dataset, batch_size=batch_size, shuffle=True)

num_i1_classes = 21
num_i1_train_images = 2460
num_i1_val_images = 46


In [6]:
model = IncrementalResNet(num_classes=num_i1_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

start_time = time.time()
for epoch in range(10):
    print(f"i1.e{epoch+1:03} | ", end='')
    
    train_result = run_epoch('train', model, i1_train_loader, optimizer, criterion, device)
    print(f"TRAIN loss:{train_result['loss']:.3f} acc:{train_result['acc']:.3f} | ", end='')
    
    val_result = run_epoch('val', model, i1_val_loader, optimizer, criterion, device)
    print(f"VAL loss: {val_result['loss']:.3f} acc:{val_result['acc']:.3f} | ", end='')
    
    print(f"Elapsed time: {datetime.timedelta(seconds=(time.time() - start_time))}")

i1_checkpoint = copy.deepcopy(model)

i1.e001 | TRAIN loss:1.355 acc:0.632 | VAL loss: 1.047 acc:0.696 | Elapsed time: 0:00:34.466492
i1.e002 | TRAIN loss:0.342 acc:0.885 | VAL loss: 0.537 acc:0.804 | Elapsed time: 0:01:10.626994
i1.e003 | TRAIN loss:0.174 acc:0.947 | VAL loss: 0.318 acc:0.870 | Elapsed time: 0:01:49.678433
i1.e004 | TRAIN loss:0.106 acc:0.963 | VAL loss: 0.309 acc:0.935 | Elapsed time: 0:02:30.762869
i1.e005 | TRAIN loss:0.060 acc:0.985 | VAL loss: 0.188 acc:0.957 | Elapsed time: 0:03:12.476010
i1.e006 | TRAIN loss:0.068 acc:0.978 | VAL loss: 0.168 acc:0.957 | Elapsed time: 0:03:55.005025
i1.e007 | TRAIN loss:0.044 acc:0.987 | VAL loss: 0.046 acc:1.000 | Elapsed time: 0:04:38.296883
i1.e008 | TRAIN loss:0.038 acc:0.989 | VAL loss: 0.046 acc:0.978 | Elapsed time: 0:05:22.042816
i1.e009 | TRAIN loss:0.027 acc:0.991 | VAL loss: 0.039 acc:1.000 | Elapsed time: 0:06:06.707914
i1.e010 | TRAIN loss:0.030 acc:0.991 | VAL loss: 0.044 acc:1.000 | Elapsed time: 0:06:52.235945


# Iteration 2
no overlapping classes

In [7]:
i2_train_dir = f"{dataset_dir}/i00002-train"
i2_val_dir = f"{dataset_dir}/i00002-val"

i2_classes = os.listdir(i2_train_dir)
num_i2_classes = len(i2_classes)
print(f"num_i2_classes = {num_i2_classes}")

i2_train_images = [ 
    (f"{i2_train_dir}/{class_dir}/{img}", num_i1_classes + i2_classes.index(class_dir)) 
    for class_dir in os.listdir(i2_train_dir) 
    for img in os.listdir(f"{i2_train_dir}/{class_dir}")
]
num_i2_train_images = len(i2_train_images)
print(f"num_i2_train_images = {num_i2_train_images}")
i2_train_dataset = SimpleImageDataset(
    image_paths = [ t[0] for t in i2_train_images],
    labels = [ t[1] for t in i2_train_images],
    transform = transforms.Compose(def_transform)
)
i2_train_loader = DataLoader(i2_train_dataset, batch_size=batch_size, shuffle=True)

i2_val_images = [ 
    (f"{i2_val_dir}/{class_dir}/{img}", num_i1_classes + i2_classes.index(class_dir)) 
    for class_dir in os.listdir(i2_val_dir) 
    for img in os.listdir(f"{i2_val_dir}/{class_dir}")
]
num_i2_val_images = len(i2_val_images)
print(f"num_i2_val_images = {num_i2_val_images}")
i2_val_dataset = SimpleImageDataset(
    image_paths = [ t[0] for t in i2_val_images],
    labels = [ t[1] for t in i2_val_images],
    transform = transforms.Compose(def_transform)
)
i2_val_loader = DataLoader(i2_val_dataset, batch_size=batch_size, shuffle=True)

num_i2_classes = 19
num_i2_train_images = 916
num_i2_val_images = 42


## with only i1 data

In [8]:
model = copy.deepcopy(i1_checkpoint)
model.add_classes(num_new=num_i2_classes)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print("Added new classes. New classifier size:", model.fc.out_features)

start_time = time.time()
for epoch in range(5):
    print(f"i2.e{epoch+1:03} | ", end='')
    
    train_result = run_epoch('train', model, i2_train_loader, optimizer, criterion, device)
    print(f"TRAIN loss:{train_result['loss']:.3f} acc:{train_result['acc']:.3f} | ", end='')
    
    val_result = run_epoch('val', model, i1_val_loader, optimizer, criterion, device)
    print(f"i1-VAL loss:{val_result['loss']:.3f} acc:{val_result['acc']:.3f} | ", end='')
    
    val_result = run_epoch('val', model, i2_val_loader, optimizer, criterion, device)
    print(f"i2-VAL loss:{val_result['loss']:.3f} acc:{val_result['acc']:.3f} | ", end='')
    
    print(f"Elapsed time: {datetime.timedelta(seconds=(time.time() - start_time))}")

Added new classes. New classifier size: 40
i2.e001 | TRAIN loss:1.833 acc:0.517 | i1-VAL loss:1.812 acc:0.500 | i2-VAL loss:1.867 acc:0.500 | Elapsed time: 0:00:18.000814
i2.e002 | TRAIN loss:0.452 acc:0.892 | i1-VAL loss:2.492 acc:0.326 | i2-VAL loss:0.833 acc:0.833 | Elapsed time: 0:00:35.832851
i2.e003 | TRAIN loss:0.129 acc:0.971 | i1-VAL loss:2.839 acc:0.261 | i2-VAL loss:0.427 acc:0.905 | Elapsed time: 0:00:54.128214
i2.e004 | TRAIN loss:0.061 acc:0.990 | i1-VAL loss:3.356 acc:0.283 | i2-VAL loss:0.237 acc:0.952 | Elapsed time: 0:01:12.253101
i2.e005 | TRAIN loss:0.031 acc:0.995 | i1-VAL loss:3.157 acc:0.239 | i2-VAL loss:0.313 acc:0.905 | Elapsed time: 0:01:30.693455


CATASTROPHIC FORGETTING... i1-VAL.acc down to 23.9%

## with i1+i2 data

In [9]:
model = copy.deepcopy(i1_checkpoint)
model.add_classes(num_new=num_i2_classes)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print("Added new classes. New classifier size:", model.fc.out_features)

combined_train_dataset = ConcatDataset([i1_train_dataset, i2_train_dataset])
combined_train_loader = DataLoader(combined_train_dataset, batch_size=batch_size, shuffle=True)

start_time = time.time()
for epoch in range(5):
    print(f"i2.e{epoch+1:03} | ", end='')
    
    train_result = run_epoch('train', model, combined_train_loader, optimizer, criterion, device)
    print(f"TRAIN loss:{train_result['loss']:.3f} acc:{train_result['acc']:.3f} | ", end='')
    
    val_result = run_epoch('val', model, i1_val_loader, optimizer, criterion, device)
    print(f"i1-VAL loss:{val_result['loss']:.3f} acc:{val_result['acc']:.3f} | ", end='')
    
    val_result = run_epoch('val', model, i2_val_loader, optimizer, criterion, device)
    print(f"i2-VAL loss:{val_result['loss']:.3f} acc:{val_result['acc']:.3f} | ", end='')
    
    print(f"Elapsed time: {datetime.timedelta(seconds=(time.time() - start_time))}")

Added new classes. New classifier size: 40
i2.e001 | TRAIN loss:0.351 acc:0.909 | i1-VAL loss:0.206 acc:0.913 | i2-VAL loss:1.163 acc:0.714 | Elapsed time: 0:01:03.489736
i2.e002 | TRAIN loss:0.090 acc:0.973 | i1-VAL loss:0.257 acc:0.935 | i2-VAL loss:0.662 acc:0.857 | Elapsed time: 0:02:07.336270
i2.e003 | TRAIN loss:0.058 acc:0.987 | i1-VAL loss:0.012 acc:1.000 | i2-VAL loss:0.308 acc:0.952 | Elapsed time: 0:03:11.821860
i2.e004 | TRAIN loss:0.033 acc:0.992 | i1-VAL loss:0.071 acc:0.978 | i2-VAL loss:0.278 acc:0.929 | Elapsed time: 0:04:15.221807
i2.e005 | TRAIN loss:0.043 acc:0.988 | i1-VAL loss:0.056 acc:0.978 | i2-VAL loss:0.189 acc:1.000 | Elapsed time: 0:05:18.450692


## with only i1 data + using distillation loss

In [11]:
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)

def run_epoch_v2(phase, model, dataloader, optimizer, criterion, device, teacher_model=None, distill_lambda=1.0, temperature=2.0):
    model.train() if phase == 'train' else model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0

    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(phase == 'train'):
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            # Add distillation loss if teacher is given
            if teacher_model is not None and phase == 'train':
                with torch.no_grad():
                    teacher_outputs = teacher_model(imgs)
                loss += distill_lambda * distillation_loss(outputs, teacher_outputs, temperature)

            if phase == 'train':
                loss.backward()
                optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        _, preds = torch.max(outputs, 1)
        total_correct += torch.sum(preds == labels).item()
        total_samples += imgs.size(0)

    return {
        "loss": total_loss / total_samples,
        "acc": total_correct / total_samples,
    }

In [20]:
model = copy.deepcopy(i1_checkpoint)
model.add_classes(num_new=num_i2_classes)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print("Added new classes. New classifier size:", model.fc.out_features)

teacher_model = copy.deepcopy(model)
teacher_model.eval()
for p in teacher_model.parameters():
    p.requires_grad = False

for epoch in range(5):
    print(f"i2.e{epoch+1:03} | ", end='')

    train_result = run_epoch_v2('train', model, i2_train_loader, optimizer, criterion, device, 
                                teacher_model=teacher_model, distill_lambda=1.0, temperature=2.0)
    print(f"TRAIN loss:{train_result['loss']:.3f} acc:{train_result['acc']:.3f} | ", end='')

    val_result_i1 = run_epoch_v2('val', model, i1_val_loader, optimizer, criterion, device, 
                                teacher_model=None, distill_lambda=1.0, temperature=2.0)
    print(f"i1-VAL loss:{val_result_i1['loss']:.3f} acc:{val_result_i1['acc']:.3f} | ", end='')

    val_result_i2 = run_epoch_v2('val', model, i2_val_loader, optimizer, criterion, device, 
                                teacher_model=None, distill_lambda=1.0, temperature=2.0)
    print(f"i2-VAL loss:{val_result_i2['loss']:.3f} acc:{val_result_i2['acc']:.3f} | ", end='')
    
    print(f"Elapsed time: {datetime.timedelta(seconds=(time.time() - start_time))}")

Added new classes. New classifier size: 40
i2.e001 | TRAIN loss:2.862 acc:0.456 | i1-VAL loss:0.290 acc:1.000 | i2-VAL loss:1.826 acc:0.619 | Elapsed time: 0:37:28.882383
i2.e002 | TRAIN loss:1.825 acc:0.824 | i1-VAL loss:0.236 acc:1.000 | i2-VAL loss:0.929 acc:0.833 | Elapsed time: 0:37:47.412774
i2.e003 | TRAIN loss:1.524 acc:0.927 | i1-VAL loss:0.224 acc:1.000 | i2-VAL loss:0.606 acc:0.952 | Elapsed time: 0:38:06.830109
i2.e004 | TRAIN loss:1.423 acc:0.950 | i1-VAL loss:0.260 acc:0.978 | i2-VAL loss:0.511 acc:0.952 | Elapsed time: 0:38:26.658125
i2.e005 | TRAIN loss:1.368 acc:0.955 | i1-VAL loss:0.299 acc:1.000 | i2-VAL loss:0.464 acc:1.000 | Elapsed time: 0:38:46.674804
