## Dataset

In [1]:
import torch
from torchvision import transforms

# transforms.ToTensor() not needed as we use torchvision.io.read_image,
# which gives torch.Tensor instead of PIL.Image
# Data Augmentation transforms are mostly from Bazinga699/NCL
# https://github.com/Bazinga699/NCL/blob/2bbf193/lib/dataset/cui_cifar.py#L64
train_transform = transforms.Compose([
    transforms.ConvertImageDtype(torch.float32),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    # NOTE(ryanlee): Check normalize values
    transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
])
valid_transform = transforms.Compose([
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
])
test_transform = transforms.Compose([
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
])

In [2]:
from datasets.cifar10 import CIFAR10LTDataset

train_json_filepath = "data/json/cifar10_imbalance100/cifar10_imbalance100_train.json"
train_images_dirpath = "data/json/cifar10_imbalance100/images/"
train_dataset = CIFAR10LTDataset(
    json_filepath=train_json_filepath,
    images_dirpath=train_images_dirpath,
    transform=train_transform,
)

In [3]:
valid_json_filepath = "data/json/cifar10_imbalance100/cifar10_imbalance100_valid.json"
valid_images_dirpath = "data/json/cifar10_imbalance100/images/"
valid_dataset = CIFAR10LTDataset(
    json_filepath=valid_json_filepath,
    images_dirpath=valid_images_dirpath,
    transform=valid_transform,
)

In [4]:
len(train_dataset), len(valid_dataset)

(12406, 10000)

## DataLoader

In [5]:
# DataLoader Hyperparameters
DATALOADER__NUM_WORKERS = 8
DATALOADER__BATCH_SIZE = 128

In [6]:
# Compute weights
import json

import numpy as np

labels = np.arange(10)
with open(train_json_filepath, "r") as f:
    json_data = json.load(f)
sample_labels = [annotation["category_id"] for annotation in json_data["annotations"]]
sample_labels_count = np.array([len(np.where(sample_labels == l)[0]) for l in labels])
weights = 1. / sample_labels_count
sample_weights = np.array([weights[l] for l in sample_labels])

In [7]:
from torch.utils.data import DataLoader, WeightedRandomSampler

train_sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=50000, # https://stackoverflow.com/a/67802529
    # num_samples=len(train_dataset),
    replacement=True,
)
train_loader = DataLoader(
    train_dataset,
    sampler=train_sampler,
    # shuffle=True,
    batch_size=DATALOADER__BATCH_SIZE,
    num_workers=DATALOADER__NUM_WORKERS,
)

In [8]:
valid_loader = DataLoader(
    valid_dataset,
    batch_size=DATALOADER__BATCH_SIZE,
    num_workers=DATALOADER__NUM_WORKERS,
)

## Setup Model

In [9]:
# Model hyperparameters
MODEL__NAME = "resnet32__ldam_drw"
MODEL__RESNET_DEPTH = 32
MODEL_PARAMS = {
    "MODEL__NAME": MODEL__NAME,
    "MODEL__RESNET_DEPTH": MODEL__RESNET_DEPTH,
}

In [10]:
from ldam_drw_models import resnet32

net = resnet32()

In [11]:
# from m2m_models import resnet32

# net = resnet32()

In [12]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

count_parameters(net)

464154

In [13]:
net = net.cuda()

## Wandb

In [14]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mseungjaeryanlee[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Optimizer

In [15]:
# Optimizer Hyperparameters
OPTIM__LR = 0.1
OPTIM__MOMENTUM = 0.9
OPTIM__WEIGHT_DECAY = 2e-4

In [16]:
import torch.optim as optim

optimizer = optim.SGD(
    net.parameters(),
    lr=OPTIM__LR,
    momentum=OPTIM__MOMENTUM,
    weight_decay=OPTIM__WEIGHT_DECAY,
)

In [17]:
warmup_scheduler = optim.lr_scheduler.LinearLR(
    optimizer,
    # NOTE(ryanlee): 0 causes ZeroDivisionError
    start_factor=torch.finfo().tiny,
    end_factor=1,
    total_iters=5,
)
multistep_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[160,180],
    gamma=0.01,
)
scheduler = optim.lr_scheduler.ChainedScheduler([
    warmup_scheduler,
    multistep_scheduler,
])

## Prepare Training

In [18]:
# Training Hyperparameters
N_EPOCH = 200
SAVE_CKPT_EVERY_N_EPOCH = 10
LOAD_CKPT = False
LOAD_CKPT_FILEPATH = "checkpoints/.pt"
LOAD_CKPT_EPOCH = 0

In [19]:
import torch.nn as nn

criterion = nn.CrossEntropyLoss(reduction="none")

## Training Loop

In [20]:
def save_checkpoint(
    model,
    optimizer,
    checkpoint_filepath: str,
):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, checkpoint_filepath)


def load_checkpoint(
    model,
    optimizer,
    checkpoint_filepath: str,
):
    checkpoint = torch.load(checkpoint_filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [21]:
if LOAD_CKPT:
    load_checkpoint(net, optimizer, LOAD_CKPT_FILEPATH)

In [22]:
wandb_run = wandb.init(
    project="pure-noise",
    entity="brianryan",
)

wandb.config.update({
    # Data
    "dataloader__num_workers": DATALOADER__NUM_WORKERS,
    "dataloader__batch_size": DATALOADER__BATCH_SIZE,
    # Optimizer
    "optim__lr": OPTIM__LR,
    "optim__momentum": OPTIM__MOMENTUM,
    "optim__weight_decay": OPTIM__WEIGHT_DECAY,
    # Model
    **MODEL_PARAMS,
    # Training
    "n_epoch": N_EPOCH,
    "save_ckpt_every_n_epoch": SAVE_CKPT_EVERY_N_EPOCH,
    "load_ckpt": LOAD_CKPT,
    "load_ckpt_filepath": LOAD_CKPT_FILEPATH,
    "load_ckpt_epoch": LOAD_CKPT_EPOCH,
})

[34m[1mwandb[0m: Currently logged in as: [33mseungjaeryanlee[0m ([33mbrianryan[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [23]:
from collections import defaultdict
import os

import torch

start_epoch_i, end_epoch_i = 0, N_EPOCH
if LOAD_CKPT:
    start_epoch_i += LOAD_CKPT_EPOCH
    end_epoch_i += LOAD_CKPT_EPOCH
for epoch_i in range(start_epoch_i, end_epoch_i):
    # Save checkpoint
    if epoch_i % SAVE_CKPT_EVERY_N_EPOCH == 0:
        checkpoint_filepath = f"checkpoints/{wandb.run.name}__epoch_{epoch_i}.pt"
        os.makedirs("checkpoints/", exist_ok=True)
        save_checkpoint(net, optimizer, checkpoint_filepath)
        wandb.save(checkpoint_filepath)

    ## Training Phase
    net.train()
    train_losses = []
    train_labels = []
    train_preds = []
    for minibatch_i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.float().cuda()
        labels = labels.cuda()

        optimizer.zero_grad()
        outputs = net(inputs)
        losses = criterion(outputs, labels)
        losses.mean().backward()
        optimizer.step()

        preds = torch.argmax(outputs, dim=1)
        train_losses.extend(losses.cpu().detach().tolist())
        train_labels.extend(labels.cpu().detach().tolist())
        train_preds.extend(preds.cpu().detach().tolist())

    train_losses = np.array(train_losses)
    train_labels = np.array(train_labels)
    train_preds = np.array(train_preds)

    # Filter losses by classes
    train_loss_per_class_dict = {
        f"train_loss__class_{class_}": train_losses[np.where(train_labels == class_)[0]].mean()
        for class_ in np.arange(10)
    }
    # Filter preds by classes for accuracy
    train_acc_per_class_dict = {
        f"train_acc__class_{class_}": (train_preds == train_labels)[np.where(train_labels == class_)[0]].mean()
        for class_ in np.arange(10)
    }

    ## Validation Phase
    net.eval()
    with torch.no_grad():
        # Save all losses and labels for each example
        valid_losses = []
        valid_labels = []
        valid_preds = []
        for minibatch_i, (inputs, labels) in enumerate(valid_loader):
            inputs = inputs.float().cuda()
            labels = labels.cuda()

            outputs = net(inputs)
            losses = criterion(outputs, labels)
            preds = torch.argmax(outputs, dim=1)

            valid_losses.extend(losses.cpu().detach().tolist())
            valid_labels.extend(labels.cpu().detach().tolist())
            valid_preds.extend(preds.cpu().detach().tolist())

    valid_losses = np.array(valid_losses)
    valid_labels = np.array(valid_labels)
    valid_preds = np.array(valid_preds)

    # Filter losses by classes
    valid_loss_per_class_dict = {
        f"valid_loss__class_{class_}": valid_losses[np.where(valid_labels == class_)[0]].mean()
        for class_ in np.arange(10)
    }
    # Filter preds by classes for accuracy
    valid_acc_per_class_dict = {
        f"valid_acc__class_{class_}": (valid_preds == valid_labels)[np.where(valid_labels == class_)[0]].mean()
        for class_ in np.arange(10)
    }

    # Logging
    wandb.log({
        "epoch_i": epoch_i,
        "train_loss": np.mean(train_losses),
        "train_acc": np.mean(train_preds == train_labels),
        **train_loss_per_class_dict,
        **train_acc_per_class_dict,
        "valid_loss": np.mean(valid_losses),
        "valid_acc": np.mean(valid_preds == valid_labels),
        **valid_loss_per_class_dict,
        **valid_acc_per_class_dict,
        "lr": scheduler.get_last_lr()[0],
    })
    scheduler.step()

# Save the last epoch
checkpoint_filepath = f"checkpoints/{wandb.run.name}__epoch_{end_epoch_i}.pt"
save_checkpoint(net, optimizer, checkpoint_filepath)
wandb.save(checkpoint_filepath)

# Finish wandb run
wandb_run.finish()

VBox(children=(Label(value='74.701 MB of 74.701 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
epoch_i,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lr,▂███████████████████████████████▁▁▁▁▁▁▁▁
train_acc,▁▆▇▇████████████████████████████████████
train_acc__class_0,▁▆▇▇▇▇▇▇▇███████████████████████████████
train_acc__class_1,▁▇▇█████████████████████████████████████
train_acc__class_2,▁▄▆▇▇▇▇▇▇▇▇▇▇▇▇▇████████████████████████
train_acc__class_3,▁▄▆▇▇▇▇▇▇▇▇▇████████████████████████████
train_acc__class_4,▁▅▇▇████████████████████████████████████
train_acc__class_5,▁▅▇█████████████████████████████████████
train_acc__class_6,▁▆██████████████████████████████████████

0,1
epoch_i,199.0
lr,1e-05
train_acc,0.99816
train_acc__class_0,0.98985
train_acc__class_1,0.99758
train_acc__class_2,0.99619
train_acc__class_3,0.99842
train_acc__class_4,0.99979
train_acc__class_5,0.9998
train_acc__class_6,1.0
