## Dataset

Description from Section D.2:

![Description from Section D.2](images/paper__image_augmentations.png)

In [1]:
import torch
from torchvision import transforms

import custom_transforms

# transforms.ToTensor() not needed as we use torchvision.io.read_image,
# which gives torch.Tensor instead of PIL.Image
# Data Augmentation transforms are mostly from Bazinga699/NCL
# https://github.com/Bazinga699/NCL/blob/2bbf193/lib/dataset/cui_cifar.py#L64

# Per-channel normalization statistics, defined once and shared by every split
# so the train / valid / test pipelines cannot silently drift apart.
# TODO: Verify where this number comes from: is it CIFAR-10 or CIFAR-10-LT?
DATASET__NORMALIZE_MEAN = (0.4914, 0.4822, 0.4465)
DATASET__NORMALIZE_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: dtype conversion, geometric augmentation, Cutout,
# color / grayscale / blur augmentation, then normalization.
train_transform = transforms.Compose([
    # Convert to float32 first; the later transforms operate on float tensors.
    transforms.ConvertImageDtype(torch.float32),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    custom_transforms.Cutout(n_holes=1, length=16),
    # TODO: Check if this is correct values for SIMCLR augmentation
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3, sigma=[.1, 2.])
    ], p=0.5),
    transforms.Normalize(DATASET__NORMALIZE_MEAN, DATASET__NORMALIZE_STD),
])
# Evaluation pipelines: dtype conversion and normalization only, no augmentation.
valid_transform = transforms.Compose([
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(DATASET__NORMALIZE_MEAN, DATASET__NORMALIZE_STD),
])
test_transform = transforms.Compose([
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(DATASET__NORMALIZE_MEAN, DATASET__NORMALIZE_STD),
])

In [2]:
from datasets.cifar10 import CIFAR10LTDataset

In [3]:
# Training split: annotation JSON and image directory handed to CIFAR10LTDataset.
train_images_dirpath = "data/json/cifar10_imbalance100/images/"
train_json_filepath = "data/json/cifar10_imbalance100/cifar10_imbalance100_train.json"
# The augmenting train_transform is used for this split.
train_dataset = CIFAR10LTDataset(
    transform=train_transform,
    images_dirpath=train_images_dirpath,
    json_filepath=train_json_filepath,
)

In [4]:
# Validation split: annotation JSON and image directory handed to CIFAR10LTDataset.
valid_images_dirpath = "data/json/cifar10_imbalance100/images/"
valid_json_filepath = "data/json/cifar10_imbalance100/cifar10_imbalance100_valid.json"
# valid_transform only converts dtype and normalizes (no augmentation).
valid_dataset = CIFAR10LTDataset(
    transform=valid_transform,
    images_dirpath=valid_images_dirpath,
    json_filepath=valid_json_filepath,
)

In [5]:
# Number of examples in each split, displayed as (train, valid).
n_train_examples = len(train_dataset)
n_valid_examples = len(valid_dataset)
n_train_examples, n_valid_examples

(12406, 10000)

## DataLoader

In [6]:
# DataLoader Hyperparameters (shared by the training and validation loaders)
DATALOADER__BATCH_SIZE = 128  # passed as DataLoader(batch_size=...)
DATALOADER__NUM_WORKERS = 4  # passed as DataLoader(num_workers=...)

In [7]:
# Compute weights
import json

import numpy as np

# Per-sample sampling weights, inversely proportional to the frequency of each
# sample's class in the training annotations.
labels = np.arange(10)
with open(train_json_filepath, "r") as f:
    json_data = json.load(f)
# Convert to an ndarray up front: counting and indexing below are then
# vectorized and no longer depend on comparing a Python list to a NumPy scalar.
sample_labels = np.asarray(
    [annotation["category_id"] for annotation in json_data["annotations"]],
    dtype=np.int64,
)
# One pass over the labels; minlength keeps an entry for classes with no samples.
sample_labels_count = np.bincount(sample_labels, minlength=len(labels))
# Classes without any sample get weight 0 instead of a division by zero; such a
# weight is never looked up below because no sample carries that label.
weights = np.divide(
    1.0,
    sample_labels_count,
    out=np.zeros(len(sample_labels_count), dtype=np.float64),
    where=sample_labels_count > 0,
)
# Look up the class weight for every training sample.
sample_weights = weights[sample_labels]

In [8]:
from torch.utils.data import DataLoader, WeightedRandomSampler

# Draw training indices with replacement, each with probability proportional to
# its entry in sample_weights.
train_sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=50000, # https://stackoverflow.com/a/67802529
    replacement=True,
)
# The sampler decides the iteration order of the training loader.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=DATALOADER__BATCH_SIZE,
    num_workers=DATALOADER__NUM_WORKERS,
    sampler=train_sampler,
)

In [9]:
# Validation loader: no sampler and no shuffle argument, so the DataLoader
# defaults apply.
valid_loader = DataLoader(
    dataset=valid_dataset,
    num_workers=DATALOADER__NUM_WORKERS,
    batch_size=DATALOADER__BATCH_SIZE,
)

## Model

In [10]:
# Model hyperparameters
MODEL__WIDERESNET_K = 10  # forwarded to WideResNet as widen_factor
MODEL__WIDERESNET_DEPTH = 28  # forwarded to WideResNet as depth

In [11]:
from networks import WideResNet

# Build the classifier for the 10 classes using the model hyperparameters.
net = WideResNet(
    widen_factor=MODEL__WIDERESNET_K,
    depth=MODEL__WIDERESNET_DEPTH,
    num_classes=10,
)

In [12]:
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters.

    Args:
        model: Object exposing ``parameters()``, which yields items that have a
            ``requires_grad`` flag and a ``numel()`` method.

    Returns:
        Sum of ``numel()`` over the parameters whose ``requires_grad`` is truthy
        (0 when there are none).
    """
    total = 0
    for param in model.parameters():
        # Parameters excluded from gradient computation are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

# Display the number of trainable parameters of `net`.
count_parameters(net)

36479194

In [13]:
# Move the model to the GPU. The training loop also calls .cuda() on every
# minibatch, so this notebook requires a CUDA device.
net = net.cuda()

## Wandb

In [14]:
import wandb
# Log in to Weights & Biases before a run is created further down.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mseungjaeryanlee[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Optimizer

In [15]:
# Optimizer Hyperparameters (all forwarded to optim.SGD)
OPTIM__WEIGHT_DECAY = 2e-4
OPTIM__MOMENTUM = 0.9
OPTIM__LR = 0.1  # initial learning rate; the scheduler scales it later

In [16]:
import torch.optim as optim

# SGD with momentum and weight decay over all model parameters.
sgd_kwargs = dict(
    lr=OPTIM__LR,
    momentum=OPTIM__MOMENTUM,
    weight_decay=OPTIM__WEIGHT_DECAY,
)
optimizer = optim.SGD(net.parameters(), **sgd_kwargs)
# With step_size=1 every scheduler.step() call multiplies the learning rate by
# gamma. The training loop calls step() only at selected epochs, so this acts as
# a manually triggered decay rather than a per-epoch one.
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.01, step_size=1)

## Prepare Training

In [17]:
# Training Hyperparameters
# NOTE(review): the training loop only steps the LR scheduler when epoch_i is
# 160 or 180, which is never reached with 10 epochs -- confirm whether this
# short run is intentional.
N_EPOCH = 10

In [18]:
import torch.nn as nn

# reduction="none" keeps one loss value per example instead of a batch average.
# The training loop needs the per-example values to report per-class losses and
# calls .mean() itself before backpropagating.
criterion = nn.CrossEntropyLoss(reduction="none")

## Training Loop

In [19]:
# Hyperparameters recorded alongside the run.
run_config = {
    # Data
    "dataloader__num_workers": DATALOADER__NUM_WORKERS,
    "dataloader__batch_size": DATALOADER__BATCH_SIZE,
    # Optimizer
    "optim__lr": OPTIM__LR,
    "optim__momentum": OPTIM__MOMENTUM,
    "optim__weight_decay": OPTIM__WEIGHT_DECAY,
    # Model
    "model__wideresnet_depth": MODEL__WIDERESNET_DEPTH,
    "model__wideresnet_k": MODEL__WIDERESNET_K,
    # Training
    "n_epoch": N_EPOCH,
}

# Start the run, then attach the configuration to it.
wandb_run = wandb.init(entity="brianryan", project="pure-noise")
wandb.config.update(run_config)

[34m[1mwandb[0m: Currently logged in as: [33mseungjaeryanlee[0m ([33mbrianryan[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
from collections import defaultdict
import torch

def _mean_loss_per_class(per_sample_losses, per_sample_labels, prefix, num_classes=10):
    """Average per-sample losses separately for every class.

    Args:
        per_sample_losses: 1-D NumPy array with one loss value per example.
        per_sample_labels: 1-D NumPy array of class indices, aligned with
            ``per_sample_losses``.
        prefix: Key prefix, e.g. ``"train_loss"``.
        num_classes: Number of classes to report (indices 0..num_classes-1).

    Returns:
        Dict mapping ``f"{prefix}__class_{i}"`` to the mean loss of class ``i``.
        A class with no examples maps to NaN.
    """
    per_class = {}
    for class_ in range(num_classes):
        class_mask = per_sample_labels == class_
        key = f"{prefix}__class_{class_}"
        if class_mask.any():
            per_class[key] = per_sample_losses[class_mask].mean()
        else:
            # Explicit NaN instead of NumPy's empty-slice mean (same value, no warning).
            per_class[key] = float("nan")
    return per_class


for epoch_i in range(N_EPOCH):
    # Training Phase
    net.train()
    train_losses = []
    train_labels = []
    # The minibatch targets are deliberately not called `labels`: that name holds
    # the array of class indices created in the weight-computation cell, and
    # rebinding it here would silently change notebook state.
    for inputs, targets in train_loader:
        inputs = inputs.float().cuda()
        targets = targets.cuda()

        optimizer.zero_grad()
        outputs = net(inputs)
        # criterion returns one loss per example; average only for the backward pass.
        losses = criterion(outputs, targets)
        losses.mean().backward()
        optimizer.step()

        # Keep the per-example values for the per-class statistics below.
        train_losses.extend(losses.detach().cpu().tolist())
        train_labels.extend(targets.detach().cpu().tolist())

    train_losses = np.array(train_losses)
    train_labels = np.array(train_labels)

    # Filter losses by classes
    train_loss_per_class_dict = _mean_loss_per_class(train_losses, train_labels, "train_loss")

    # Validation Phase
    net.eval()
    with torch.no_grad():
        # Save all losses and labels for each example
        valid_losses = []
        valid_labels = []
        for inputs, targets in valid_loader:
            inputs = inputs.float().cuda()
            targets = targets.cuda()

            outputs = net(inputs)
            losses = criterion(outputs, targets)

            valid_losses.extend(losses.detach().cpu().tolist())
            valid_labels.extend(targets.detach().cpu().tolist())

        valid_losses = np.array(valid_losses)
        valid_labels = np.array(valid_labels)

        # Filter losses by classes
        valid_loss_per_class_dict = _mean_loss_per_class(valid_losses, valid_labels, "valid_loss")

    # One log call per epoch: overall and per-class mean losses for both splits.
    wandb.log({
        "train_loss": np.mean(train_losses),
        **train_loss_per_class_dict,
        "valid_loss": np.mean(valid_losses),
        **valid_loss_per_class_dict,
    })
    # The scheduler is stepped only after these epochs; each step multiplies the
    # learning rate by the scheduler's gamma.
    # NOTE(review): with N_EPOCH = 10 neither epoch is reached -- confirm intent.
    if epoch_i in [160, 180]:
        scheduler.step()

# Finish wandb run
wandb_run.finish()