In [1]:
import warnings
warnings.filterwarnings("ignore", message = "Applied workaround for CuDNN issue")

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR, OneCycleLR, ExponentialLR, CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler

import torchvision
from torchvision import transforms, datasets
from torchvision.datasets import ImageFolder

import matplotlib.pyplot as plt
from sklearn.metrics import f1_score
import os
import time
import numpy as np
import random

# Custom Imports
from tests import test_utils
from tests import loading_utils
from residual_networks.RKAN_ResNet import RKAN_ResNet
from residual_networks.RKAN_DenseNet import RKAN_DenseNet
from residual_networks.RKAN_RegNet import RKAN_RegNet
from residual_networks.RKAN_ResNeXt import RKAN_ResNeXt
from residual_networks.RKAN_WideResNet import RKAN_WideResNet
from residual_networks.RKAN_VGG import RKAN_VGG
from residual_networks.RKAN_ConvNeXt import RKAN_ConvNeXt

print(torch.cuda.is_available())

True


In [2]:
# Switch between repeatable cuDNN kernels and the autotuned (benchmark) ones.
REPRODUCIBLE = True


def set_seed(seed):
    """Seed every RNG used by the notebook and set the cuDNN flags.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus the CUDA devices
    when CUDA is available). The cuDNN flags follow the module-level
    ``REPRODUCIBLE`` switch: ``deterministic`` on and ``benchmark`` off when it
    is truthy, the opposite otherwise.

    Args:
        seed: Integer seed handed to each library.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    deterministic = bool(REPRODUCIBLE)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic


SEED = 10
set_seed(SEED)

def seed_worker(worker_id):
    """Seed NumPy and Python's ``random`` inside a DataLoader worker.

    Intended for use as ``worker_init_fn``. The seed is derived from
    ``torch.initial_seed()`` rather than from a fixed constant: per the PyTorch
    DataLoader documentation, each worker's torch seed is ``base_seed +
    worker_id`` with ``base_seed`` drawn from the loader's generator, so the
    value differs between workers and between epochs while remaining
    reproducible for a seeded generator. A constant seed would replay the same
    ``random`` / NumPy stream every time the workers are re-created.

    Args:
        worker_id: Index of the worker process. Unused here, but required by
            the ``worker_init_fn`` calling convention.
    """
    # Reduce to 32 bits so the value is a valid NumPy seed.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Dedicated generator with a fixed seed; it is passed to both DataLoaders
# through their `generator` argument.
g = torch.Generator()
g.manual_seed(SEED)

<torch._C.Generator at 0x79e87f0ce890>

In [3]:
# Per-dataset settings, stored in the order:
#   [batch_size, num_classes, dataset_size, input_size, train_status]
dataset_list = {
    "cifar_100": [256, 100, "small", 32, "full"],
    "cifar_10": [256, 10, "small", 32, "full"],
    "svhn": [256, 10, "small", 32, "full"],
    "tiny_imagenet": [256, 200, "medium", 64, "full"],
    "food_101": [256, 101, "large", 128, "full"],
    "imagenet_1k": [256, 1000, "large", 224, "transfer"],
}

# The active dataset is selected by its position in the table above.
dataset = list(dataset_list)[5]
batch_size, num_classes, dataset_size, input_size, train_status = dataset_list.get(dataset, [None] * 5)
print(f"dataset: {dataset}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Upscaling is only enabled for configurations tagged "large".
image_upscale = True
upscale_check = image_upscale and dataset_size == "large"

# Build `train_transform` / `val_transform` and `train_dataset` / `val_dataset`
# for the selected dataset. Optional pipeline steps are appended conditionally
# instead of being replaced by identity `transforms.Lambda(lambda x: x)`
# placeholders: a lambda cannot be pickled, which breaks multi-worker
# DataLoaders on platforms that start workers with "spawn".

# ImageNet channel statistics, shared by every branch that normalizes with them.
imagenet_mean, imagenet_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

if dataset in ["cifar_100", "cifar_10", "svhn"]:
    if upscale_check:
        input_size = 128
    if dataset == "cifar_100":
        normalize_mean, normalize_std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
        augment_policy = transforms.AutoAugmentPolicy.CIFAR10
    elif dataset == "cifar_10":
        normalize_mean, normalize_std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        augment_policy = transforms.AutoAugmentPolicy.CIFAR10
    else:
        normalize_mean, normalize_std = (0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)
        augment_policy = transforms.AutoAugmentPolicy.SVHN

    # Training: either upscale to the target size or pad-and-crop at native size.
    train_steps = []
    if upscale_check:
        train_steps.append(transforms.Resize(input_size, interpolation = transforms.InterpolationMode.BICUBIC))
    else:
        train_steps.append(transforms.RandomCrop(input_size, padding = 4))
    # No horizontal flip for SVHN.
    if dataset != "svhn":
        train_steps.append(transforms.RandomHorizontalFlip())
    train_steps += [
        transforms.AutoAugment(policy = augment_policy),
        transforms.ToTensor(),
        transforms.Normalize(normalize_mean, normalize_std),
    ]
    train_transform = transforms.Compose(train_steps)

    # Validation: resize only when upscaling, no augmentation.
    val_steps = []
    if upscale_check:
        val_steps.append(transforms.Resize(input_size, interpolation = transforms.InterpolationMode.BICUBIC))
    val_steps += [
        transforms.ToTensor(),
        transforms.Normalize(normalize_mean, normalize_std),
    ]
    val_transform = transforms.Compose(val_steps)

    if dataset == "cifar_100":
        train_dataset = torchvision.datasets.CIFAR100(root = "./cifar_data", train = True, download = False, transform = train_transform)
        val_dataset = torchvision.datasets.CIFAR100(root = "./cifar_data", train = False, download = False, transform = val_transform)
    elif dataset == "cifar_10":
        train_dataset = torchvision.datasets.CIFAR10(root = "./cifar_data", train = True, download = False, transform = train_transform)
        val_dataset = torchvision.datasets.CIFAR10(root = "./cifar_data", train = False, download = False, transform = val_transform)
    else:
        train_dataset = torchvision.datasets.SVHN(root = "./svhn_data", split = "train", download = False, transform = train_transform)
        val_dataset = torchvision.datasets.SVHN(root = "./svhn_data", split = "test", download = False, transform = val_transform)

elif dataset == "tiny_imagenet":
    train_transform = transforms.Compose([
        transforms.RandomCrop(input_size, padding = 4),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy = transforms.AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std)
    ])

    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std)
    ])
    train_dataset = ImageFolder(root = "./tiny-imagenet-200/train", transform = train_transform)
    val_dataset = ImageFolder(root = "./tiny-imagenet-200/val", transform = val_transform)

elif dataset == "food_101":
    image_dir = "food-101/images"
    train_txt = os.path.join("food-101/meta", "train.txt")
    test_txt = os.path.join("food-101/meta", "test.txt")

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy = transforms.AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
        transforms.Normalize(mean = imagenet_mean, std = imagenet_std)
    ])

    # Resize to 1.25x the target size, then take the central crop.
    val_transform = transforms.Compose([
        transforms.Resize(int(input_size * 1.25)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std)
    ])
    train_dataset = loading_utils.Food101Dataset(root_dir = image_dir, txt_file = train_txt, transform = train_transform)
    val_dataset = loading_utils.Food101Dataset(root_dir = image_dir, txt_file = test_txt, transform = val_transform)

elif dataset == "imagenet_1k":
    train_dir = "imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/train"
    val_dir = "imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/val"
    val_ann_dir = "imagenet-object-localization-challenge/ILSVRC/Annotations/CLS-LOC/val"
    # NOTE(review): presumably maps synset folder names to class indices so the
    # validation labels agree with ImageFolder's training labels - confirm in loading_utils.
    synset_to_class = loading_utils.generate_synset_to_class_mapping(train_dir)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy = transforms.AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
        transforms.Normalize(mean = imagenet_mean, std = imagenet_std)
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std)
    ])
    train_dataset = datasets.ImageFolder(root = train_dir, transform = train_transform)
    val_dataset = loading_utils.ImageNetValDataset(img_dir = val_dir, ann_dir = val_ann_dir, synset_to_class = synset_to_class, transform = val_transform)

else:
    raise ValueError(f"Unknown dataset '{dataset}'. Please specify 'cifar_100', 'cifar_10', 'svhn', 'tiny_imagenet', 'food_101', or 'imagenet_1k'.")

# Options shared by both loaders; only the dataset and the shuffling differ.
_loader_kwargs = dict(
    batch_size = batch_size,
    num_workers = 8,
    pin_memory = True,
    worker_init_fn = seed_worker,
    generator = g,
)
train_loader = DataLoader(train_dataset, shuffle = True, **_loader_kwargs)
test_loader = DataLoader(val_dataset, shuffle = False, **_loader_kwargs)

print(f"training images: {len(train_dataset)}")
print(f"validation images: {len(val_dataset)}")

# Fetch a single training batch to confirm the tensor shape, then stop.
for images, labels in train_loader:
    print(f"Image batch shape: {images.shape}")
    break

dataset: imagenet_1k
training images: 1281167
validation images: 50000
Image batch shape: torch.Size([256, 3, 224, 224])


### Training

In [4]:
def _reduce_factor_as_given(r):
    """Hand the reduce-factor setting to the model unchanged."""
    return r


def _reduce_factor_last(r):
    """Hand only the final entry of the reduce-factor sequence to the model."""
    return r[-1]


# One row per architecture family:
#   (wrapper class, reduce-factor adapter, {version name: single_conv flag})
_model_families = [
    (RKAN_ResNet, _reduce_factor_as_given,
     {"resnet18": False, "resnet34": False, "resnet50": True, "resnet101": True, "resnet152": True}),
    (RKAN_DenseNet, _reduce_factor_as_given,
     {"densenet121": True, "densenet169": True, "densenet201": True}),
    (RKAN_ResNeXt, _reduce_factor_as_given,
     {"resnext50_32x4d": True, "resnext101_32x8d": True}),
    (RKAN_WideResNet, _reduce_factor_as_given,
     {"wide_resnet50_2": True, "wide_resnet101_2": True}),
    (RKAN_ConvNeXt, _reduce_factor_as_given,
     {"convnext_tiny": True, "convnext_small": True, "convnext_base": True, "convnext_large": True}),
    (RKAN_RegNet, _reduce_factor_last,
     {"regnet_y_400mf": False, "regnet_y_800mf": False, "regnet_y_1_6gf": True, "regnet_y_3_2gf": True}),
    (RKAN_VGG, _reduce_factor_last,
     {"vgg11_bn": False, "vgg13_bn": False, "vgg16_bn": True, "vgg19_bn": True}),
]

# version name -> {"class", "reduce_factor", "single_conv"}, in family order.
model_configs = {
    version: {"class": family_cls, "reduce_factor": adapt_reduce_factor, "single_conv": single_conv}
    for family_cls, adapt_reduce_factor, versions in _model_families
    for version, single_conv in versions.items()
}

In [None]:
# --- Model -------------------------------------------------------------------
model_name = "regnet_y_800mf"
reduce_factor = [1, 1, 1, 1]
mechanisms = [None, None, None, "addition"]
# The "transfer" setting is the only one that asks the model for pretrained = True.
pretrained = train_status == "transfer"

model_config = model_configs.get(model_name)
if model_config is None:
    raise ValueError(f"Unknown model name: '{model_name}'.")
model = model_config["class"](
    version = model_name,
    num_classes = num_classes,
    reduce_factor = model_config["reduce_factor"](reduce_factor),
    dataset_size = dataset_size,
    pretrained = pretrained,
    mechanisms = mechanisms,
    single_conv = model_config["single_conv"],
).to(device)

# --- Optimization ------------------------------------------------------------
warmup_epochs = 0
use_cutmix = True
criterion = nn.CrossEntropyLoss()

if train_status == "full":
    # Training from scratch: SGD driven by a per-batch one-cycle schedule.
    initial_lr = 0.1
    num_epochs = 200
    optimizer = optim.SGD(model.parameters(), lr = initial_lr, weight_decay = 1e-3)
    # Alternative schedulers kept for experimentation:
    # scheduler = ReduceLROnPlateau(optimizer, "min", patience = 5, factor = 0.5)
    # scheduler = ExponentialLR(optimizer, gamma = 0.95)
    # scheduler = test_utils.DecreasingCosineAnnealingWarmRestarts(optimizer, T_0 = 150, T_mult = 1, eta_min = 1e-5, decay_factor = 0.5)
    scheduler = OneCycleLR(
        optimizer,
        max_lr = 0.1,
        steps_per_epoch = len(train_loader),
        epochs = num_epochs,
        pct_start = 0.2,
        div_factor = 10,
        final_div_factor = 500,
    )
elif train_status == "transfer":
    # Transfer setting: AdamW with per-epoch cosine annealing, CutMix disabled.
    use_cutmix = False
    initial_lr = 0.001
    num_epochs = 30
    optimizer = optim.AdamW(model.parameters(), lr = initial_lr, weight_decay = 1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max = num_epochs, eta_min = 1e-5)
else:
    raise ValueError(f"Unknown train_status: '{train_status}'.")

warmup_scheduler = LambdaLR(optimizer, lr_lambda = lambda epoch: test_utils.warmup_scheduler(epoch, warmup_epochs))
scaler = GradScaler()
# patience equals num_epochs, so the patience budget spans the whole run.
early_stopping = test_utils.EarlyStopping(patience = num_epochs, min_delta = 0, path = "best.pt")

# --- Per-epoch history ---------------------------------------------------------
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
best_loss = float("inf")
best_epoch = -1
# Norm histories exist only for models that expose a truthy `log_norms`.
if getattr(model, "log_norms", False):
    train_base_norms = []
    train_residual_norms = []
    train_combined_norms = []

def train(epoch):
    """Run one mixed-precision training epoch over ``train_loader``.

    Works on the notebook's module-level objects (``model``, ``optimizer``,
    ``scheduler``, ``scaler``, ``criterion``, ...). Appends the epoch's mean
    batch loss and top-1 accuracy to ``train_losses`` / ``train_accuracies``
    and prints a one-epoch summary.

    Args:
        epoch: Zero-based epoch index; used only for the printed header.
    """
    model.train()
    loss_sum = 0.0
    top1_hits = top5_hits = seen = 0

    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()

        # When CutMix is enabled it is applied to roughly half of the batches.
        apply_cutmix = use_cutmix and np.random.rand() < 0.5
        if apply_cutmix:
            inputs, targets_a, targets_b, lam = test_utils.cutmix_data(inputs, targets)
        with autocast():
            outputs = model(inputs)
            if apply_cutmix:
                # Loss is the lam-weighted blend over the two label sets.
                loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
            else:
                loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        # The per-batch OneCycleLR schedule only advances when the loss scale
        # did not shrink during this update.
        scale_held = scale_before <= scaler.get_scale()
        if scale_held and isinstance(scheduler, OneCycleLR):
            scheduler.step()

        if getattr(model, "log_norms", False):
            train_base_norms.extend(model.base_norms)
            train_residual_norms.extend(model.residual_norms)
            train_combined_norms.extend(model.combined_norms)
            model.reset_norms()

        loss_sum += loss.item()
        predicted = outputs.max(1)[1]
        seen += targets.size(0)
        # Accuracy is always measured against the original (unmixed) targets.
        top1_hits += predicted.eq(targets).sum().item()
        top5_hits += outputs.topk(5, dim = 1)[1].eq(targets.view(-1, 1)).sum().item()

    epoch_loss = loss_sum / len(train_loader)
    top1_accuracy = 100. * top1_hits / seen
    top5_accuracy = 100. * top5_hits / seen
    train_losses.append(epoch_loss)
    train_accuracies.append(top1_accuracy)

    print(f"Epoch [{epoch + 1}/{num_epochs}]\n"
          f"Training Loss: {epoch_loss:.4f}, "
          f"Top-1 Accuracy: {top1_accuracy:.2f}%, "
          f"Top-5 Accuracy: {top5_accuracy:.2f}%, "
          f"Learning Rate: {optimizer.param_groups[0]['lr']}")

def validate():
    """Evaluate the model on ``test_loader`` and do the end-of-epoch bookkeeping.

    Works on the notebook's module-level objects and additionally reads the
    globals ``epoch`` and ``start_time``, which the training loop must set
    before each call. Per call it:

    * appends the mean batch loss and top-1 accuracy to ``val_losses`` /
      ``val_accuracies`` and updates ``best_loss`` / ``best_epoch``;
    * prints loss, top-1 / top-5 accuracy, macro F1 and elapsed time;
    * steps the warmup scheduler or the epoch-level scheduler (OneCycleLR is
      stepped per batch in ``train`` and is skipped here);
    * feeds the mean validation loss to ``early_stopping``.

    Returns:
        bool: True when early stopping has triggered, False otherwise.
    """
    global best_loss, best_epoch
    model.eval()
    running_loss = 0.0
    correct = top5_correct = total = 0
    all_targets, all_preds = [], []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            top5_correct += torch.topk(outputs, 5, dim = 1)[1].eq(targets.view(-1, 1)).sum().item()

            all_targets.extend(targets.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())

    f1 = f1_score(all_targets, all_preds, average = "macro")
    val_loss = running_loss / len(test_loader)
    val_losses.append(val_loss)
    val_accuracies.append(100. * correct / total)
    if val_loss < best_loss:
        best_loss = val_loss
        best_epoch = epoch + 1

    elapsed_time = time.time() - start_time
    print(f"Validation Loss: {val_loss:.4f}, "
          f"Top-1 Accuracy: {100. * correct/total:.2f}%, "
          f"Top-5 Accuracy: {100. * top5_correct/total:.2f}%\n"
          f"Best Val Loss: {best_loss:.4f} (Epoch {best_epoch}), F1 Score: {f1:.4f}, "
          f"Time: {elapsed_time:.2f}s")

    if epoch < warmup_epochs:
        warmup_scheduler.step()
    elif isinstance(scheduler, ReduceLROnPlateau):
        scheduler.step(val_loss)
    elif not isinstance(scheduler, OneCycleLR):
        # OneCycleLR is advanced once per batch inside train().
        scheduler.step()

    # Pass the mean loss (not the summed batch losses) so early stopping sees
    # the same quantity that is logged and tracked as best_loss.
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping...")
        return True
    return False

# Release cached CUDA memory before the run starts.
torch.cuda.empty_cache()
# `epoch` and `start_time` must stay module-level names: validate() reads both
# as globals (best-epoch bookkeeping, warmup check and elapsed-time print).
for epoch in range(num_epochs):
    start_time = time.time()
    train(epoch)
    # validate() returns True once early stopping has triggered.
    if validate():
        break

# NOTE(review): profile_model lives in tests.test_utils; what it reports is not
# visible from this notebook - confirm before relying on its output.
test_utils.profile_model(model, input_size, device)
torch.cuda.empty_cache()

#### Plot

In [None]:
# Summary: the epoch with the highest validation accuracy and the loss at that epoch.
max_val_accuracy = max(val_accuracies)
best_epoch = val_accuracies.index(max_val_accuracy) + 1
min_val_loss = val_losses[best_epoch - 1]
print(f"Highest Accuracy: {max_val_accuracy:.2f}% at Epoch {best_epoch}, Loss: {min_val_loss:.4f}")

# Full per-epoch validation history as comma-separated rows.
print(', '.join(format(loss, ".4f") for loss in val_losses))
print(', '.join(format(accuracy, ".2f") for accuracy in val_accuracies))
if getattr(model, "log_norms", False):
    for _norm_history in (train_base_norms, train_residual_norms, train_combined_norms):
        print(', '.join(format(norms, ".4f") for norms in _norm_history))

# Curves are drawn only up to the best-accuracy epoch.
window_size = 1
smooth_train_losses = test_utils.moving_average(train_losses[:best_epoch], window_size)
smooth_val_losses = test_utils.moving_average(val_losses[:best_epoch], window_size)
smooth_train_accuracies = test_utils.moving_average(train_accuracies[:best_epoch], window_size)
smooth_val_accuracies = test_utils.moving_average(val_accuracies[:best_epoch], window_size)
x_smooth = np.arange(window_size, best_epoch + 1)


def _plot_history(train_curve, val_curve, quantity):
    """Draw the training and validation curves of one quantity against the epoch axis."""
    plt.plot(x_smooth, train_curve, label = f"Training {quantity}")
    plt.plot(x_smooth, val_curve, label = f"Validation {quantity}")
    plt.title(f"Training and Validation {quantity}")
    plt.xlabel("Epoch")
    plt.ylabel(quantity)
    plt.legend()
    plt.show()


_plot_history(smooth_train_losses, smooth_val_losses, "Loss")
_plot_history(smooth_train_accuracies, smooth_val_accuracies, "Accuracy")