## Dataset

Description from Section D.2:

![Description from Section D.2](images/paper__image_augmentations.png)

In [1]:
# Data Hyperparameters
# Image size constants. They are passed as the height/width of the sampled
# pure-noise images in the training loop and recorded in the wandb run config.
# NOTE(review): the RandomCrop in the transform cell hard-codes 32 instead of
# reading these constants — keep them in sync if the size ever changes.
DATASET__IMAGE_HEIGHT = 32
DATASET__IMAGE_WIDTH = 32

In [2]:
import torch
from torchvision import transforms

import custom_transforms

# transforms.ToTensor() not needed as we use torchvision.io.read_image,
# which gives torch.Tensor instead of PIL.Image
# Data Augmentation transforms are mostly from Bazinga699/NCL
# https://github.com/Bazinga699/NCL/blob/2bbf193/lib/dataset/cui_cifar.py#L64
# Per-channel normalization constants shared by the train/valid/test pipelines.
# TODO: Verify where this number comes from: is it CIFAR-10 or CIFAR-10-LT?
_NORMALIZE_MEAN = (0.4914, 0.4822, 0.4465)
_NORMALIZE_STD = (0.2023, 0.1994, 0.2010)

# Colour / grayscale / blur stage of the training pipeline.
# TODO: Check if this is correct values for SIMCLR augmentation
_color_and_blur_augmentations = [
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3, sigma=[.1, 2.])
    ], p=0.5),
]


def _build_eval_transform():
    """Build the augmentation-free pipeline: float conversion, then normalization."""
    return transforms.Compose([
        transforms.ConvertImageDtype(torch.float32),
        transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ])


# Training pipeline: float conversion, geometric augmentations, Cutout,
# colour/blur augmentations, and finally normalization.
train_transform = transforms.Compose([
    transforms.ConvertImageDtype(torch.float32),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    custom_transforms.Cutout(n_holes=1, length=16),
    *_color_and_blur_augmentations,
    transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
])
# Validation and test share the same augmentation-free pipeline; each gets
# its own Compose instance.
valid_transform = _build_eval_transform()
test_transform = _build_eval_transform()

In [3]:
from datasets.cifar10 import CIFAR10LTDataset

In [4]:
# Training split: annotation JSON plus image directory for the
# cifar10_imbalance100 data. train_json_filepath is read again later when the
# per-class sample counts are computed.
# CIFAR10LTDataset comes from datasets.cifar10 (not shown here); presumably it
# applies `transform` to each loaded image — verify against that module.
train_json_filepath = "data/json/cifar10_imbalance100/cifar10_imbalance100_train.json"
train_images_dirpath = "data/json/cifar10_imbalance100/images/"
train_dataset = CIFAR10LTDataset(
    json_filepath=train_json_filepath,
    images_dirpath=train_images_dirpath,
    transform=train_transform,
)

In [5]:
# Validation split: its own annotation JSON, same image directory string as
# the training split, and the augmentation-free valid_transform.
valid_json_filepath = "data/json/cifar10_imbalance100/cifar10_imbalance100_valid.json"
valid_images_dirpath = "data/json/cifar10_imbalance100/images/"
valid_dataset = CIFAR10LTDataset(
    json_filepath=valid_json_filepath,
    images_dirpath=valid_images_dirpath,
    transform=valid_transform,
)

In [6]:
# Sanity check: number of examples in the training and validation splits.
len(train_dataset), len(valid_dataset)

(12406, 10000)

## DataLoader

In [7]:
# DataLoader Hyperparameters
# Shared by the training and validation DataLoaders and recorded in the wandb
# run config.
DATALOADER__NUM_WORKERS = 4
DATALOADER__BATCH_SIZE = 128

In [8]:
# Compute weights
import json

import numpy as np

# Class ids present in the dataset (10 classes).
labels = np.arange(10)
with open(train_json_filepath, "r") as f:
    json_data = json.load(f)
# One class id per training annotation, in annotation order.
sample_labels = [annotation["category_id"] for annotation in json_data["annotations"]]
sample_labels_tensor = torch.as_tensor(sample_labels, dtype=torch.long)
# Number of training samples per class (int64 tensor with one entry per class).
# A single bincount pass replaces ten scans that relied on `list == numpy scalar`
# broadcasting to an array.
sample_labels_count = torch.bincount(sample_labels_tensor, minlength=len(labels))
# Inverse-frequency weight per class; every sample then inherits the weight of
# its class, so rarer classes are drawn more often by the weighted sampler.
weights = 1. / sample_labels_count
sample_weights = weights[sample_labels_tensor]

In [9]:
from torch.utils.data import DataLoader, WeightedRandomSampler

# Draw training indices with replacement, using the per-sample inverse class
# frequency weights computed in the previous cell.
# NOTE: num_samples=50000 means one pass over train_loader yields 50000 draws,
# not one visit per training example (len(train_dataset) is 12406 above).
train_sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=50000, # https://stackoverflow.com/a/67802529
    replacement=True,
)
# Ordering is delegated to the sampler, so no shuffle argument is passed.
train_loader = DataLoader(
    train_dataset,
    sampler=train_sampler,
    batch_size=DATALOADER__BATCH_SIZE,
    num_workers=DATALOADER__NUM_WORKERS,
)

In [10]:
# Validation loader: no sampler or shuffle argument is given, so DataLoader's
# defaults determine the iteration order.
valid_loader = DataLoader(
    valid_dataset,
    batch_size=DATALOADER__BATCH_SIZE,
    num_workers=DATALOADER__NUM_WORKERS,
)

## OPeN

<img src="images/paper__noise_image.png" width=50%>

In [11]:
# OPeN Hyperparameters
OPEN__NOISE_RATIO = 1/3
OPEN__START_EPOCH = 160

In [12]:
# Gather all training images
# Keep whole minibatches and concatenate once along the batch dimension; this
# gives the same tensor as stacking every image individually, without first
# unrolling each batch into per-image tensors in a Python list.
# NOTE(review): the images come from train_loader, i.e. they are drawn by the
# weighted sampler and come from a dataset constructed with train_transform,
# so any statistics computed from them presumably describe augmented,
# normalized images rather than raw pixels — confirm this is intended.
train_image_batches = [minibatch_images for minibatch_images, _ in train_loader]
train_images = torch.cat(train_image_batches, dim=0)

In [13]:
# Find mean and std per channel
mean_per_channel = train_images.mean(dim=(0, 2, 3))
std_per_channel = train_images.std(dim=(0, 2, 3))
print(mean_per_channel, std_per_channel)

tensor([-0.7539, -0.7452, -0.6490]) tensor([1.4275, 1.4369, 1.4134])


In [14]:
# Edited code from paper to include clipping
def sample_noise_images(mean_per_channel, std_per_channel, n_images, height, width):
    """Draw pure-noise images with independent Gaussian pixels, clipped to [0, 1].

    Args:
        mean_per_channel: indexable with the Gaussian mean of channels 0, 1, 2.
        std_per_channel: indexable with the Gaussian std of channels 0, 1, 2.
        n_images: number of images to sample.
        height: image height.
        width: image width.

    Returns:
        Tensor of shape (n_images, 3, height, width) with values in [0, 1].
    """
    plane_shape = (n_images, 1, height, width)
    # Sample the three channel planes in order (0, 1, 2), one draw per channel.
    channel_planes = [
        torch.normal(mean_per_channel[channel], std_per_channel[channel], size=plane_shape)
        for channel in range(3)
    ]
    noise = torch.cat(channel_planes, dim=1)
    # NOTE(review): the clip assumes pixel values live in [0, 1]; if the
    # mean/std passed in were computed on normalized images, most of the
    # sampled values fall outside that range — confirm the intended space.
    return torch.clamp(noise, min=0, max=1)

<img src="images/paper__open.png" width=50%>

In [15]:
import torch

def replace_images_with_pure_noise(
    images,
    labels,
    mean_per_channel,
    std_per_channel,
    height,
    width,
    noise_ratio,
    class_counts=None,
):
    """Randomly replace images in a batch with pure-noise images.

    A sample whose class has `count` training samples is replaced with
    probability `(1 - count / max_count) * noise_ratio`, so samples of the
    most frequent class are never replaced.

    Args:
        images: batch tensor indexed by sample along dim 0. It is modified
            IN PLACE and also returned.
        labels: integer tensor of class ids, one per image, used to index
            `class_counts`.
        mean_per_channel: per-channel means forwarded to sample_noise_images.
        std_per_channel: per-channel stds forwarded to sample_noise_images.
        height: height of the sampled noise images.
        width: width of the sampled noise images.
        noise_ratio: upper bound on the replacement probability.
        class_counts: optional tensor of per-class training sample counts.
            When omitted, the notebook-level `sample_labels_count` is used,
            which is what the function always did before this parameter
            existed.

    Returns:
        Tuple `(images, noise_mask)` where `noise_mask` is a bool tensor of
        length `images.size(0)` that is True at the replaced positions.
    """
    if class_counts is None:
        # Backward-compatible fallback to the global computed in the weights cell.
        class_counts = sample_labels_count

    # Compute representation ratio
    representation_ratio = class_counts[labels] / torch.max(class_counts)

    # Compute probabilities to replace natural images with pure noise images
    noise_probs = (1 - representation_ratio) * noise_ratio

    # Sample indexes to replace with noise according to Bernoulli distribution
    noise_indices = torch.nonzero(torch.bernoulli(noise_probs)).view(-1)

    # Replace natural images with sampled pure noise images
    noise_images = sample_noise_images(
        mean_per_channel,
        std_per_channel,
        n_images=len(noise_indices),
        height=height,
        width=width,
    )
    images[noise_indices] = noise_images

    # Create mask for noise images - later used by DAR-BN
    noise_mask = torch.zeros(images.size(0), dtype=torch.bool)
    noise_mask[noise_indices] = True

    return images, noise_mask

## Model

In [16]:
# Model hyperparameters
# Passed to WideResNet as `depth` and `widen_factor` respectively, and recorded
# in the wandb run config.
MODEL__WIDERESNET_DEPTH = 28
MODEL__WIDERESNET_K = 10

In [17]:
from networks import WideResNet

# Build the 10-class classifier. WideResNet is imported from the project's
# `networks` module (not shown here).
net = WideResNet(
    num_classes=10,
    depth=MODEL__WIDERESNET_DEPTH,
    widen_factor=MODEL__WIDERESNET_K,
)

In [18]:
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters.

    Parameters with `requires_grad` set to False are not counted.
    """
    n_trainable = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            n_trainable += parameter.numel()
    return n_trainable

# Display the number of trainable parameters of the network built above.
count_parameters(net)

36479194

In [19]:
# Move the model to the GPU; the training loop likewise calls .cuda() on every
# batch, so a CUDA device is required to run this notebook.
net = net.cuda()

## Wandb

In [20]:
import wandb
# Authenticate with Weights & Biases before the run is created further below.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mseungjaeryanlee[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Optimizer

In [21]:
# Optimizer Hyperparameters
# Passed to the SGD optimizer as lr / momentum / weight_decay and recorded in
# the wandb run config.
OPTIM__LR = 0.1
OPTIM__MOMENTUM = 0.9
OPTIM__WEIGHT_DECAY = 2e-4

In [22]:
import torch.optim as optim

# SGD with momentum and weight decay over all model parameters.
optimizer = optim.SGD(
    net.parameters(),
    lr=OPTIM__LR,
    momentum=OPTIM__MOMENTUM,
    weight_decay=OPTIM__WEIGHT_DECAY,
)
# StepLR with step_size=1 multiplies the learning rate by gamma on every
# scheduler.step() call. The training loop calls step() only when epoch_i is
# 160 or 180, so the decay schedule lives in that loop, not in this object.
# NOTE(review): the scheduler state is not included in the checkpoints written
# by save_checkpoint, so a resumed run would restart from an un-stepped scheduler.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=1,
    gamma=0.01,
)

## Prepare Training

In [23]:
# Training Hyperparameters
# N_EPOCH: number of epochs run by the training loop.
# SAVE_CKPT_EVERY_N_EPOCH: a checkpoint is written at the start of every epoch
#   whose index is a multiple of this value (epoch 0 included).
# LOAD_CKPT / LOAD_CKPT_FILEPATH: when LOAD_CKPT is True, model and optimizer
#   state are restored from this path before training.
N_EPOCH = 5
SAVE_CKPT_EVERY_N_EPOCH = 10
LOAD_CKPT = False
LOAD_CKPT_FILEPATH = ""

In [24]:
import torch.nn as nn

# reduction="none" keeps one loss value per sample so the training loop can log
# per-class losses; the loop takes .mean() itself before calling backward().
criterion = nn.CrossEntropyLoss(reduction="none")

## Training Loop

In [25]:
def save_checkpoint(
    model,
    optimizer,
    checkpoint_filepath: str,
):
    """Write the model and optimizer state dicts to `checkpoint_filepath`.

    The file holds a dict with the keys 'model_state_dict' and
    'optimizer_state_dict', which is the layout load_checkpoint reads back.
    """
    checkpoint = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(checkpoint, checkpoint_filepath)


def load_checkpoint(
    model,
    optimizer,
    checkpoint_filepath: str,
    map_location=None,
):
    """Restore model and optimizer state in place from a checkpoint file.

    Args:
        model: module whose `load_state_dict` receives 'model_state_dict'.
        optimizer: optimizer whose `load_state_dict` receives
            'optimizer_state_dict'.
        checkpoint_filepath: path of a file written by save_checkpoint.
        map_location: forwarded to `torch.load`; lets a checkpoint saved on
            one device be loaded on another (e.g. "cpu"). The default None
            keeps torch.load's own default behaviour.

    Raises:
        KeyError: if the checkpoint lacks either expected key. The check runs
            before anything is loaded, so a malformed checkpoint cannot leave
            the model restored while the optimizer is not.
    """
    checkpoint = torch.load(checkpoint_filepath, map_location=map_location)

    required_keys = ('model_state_dict', 'optimizer_state_dict')
    missing_keys = [key for key in required_keys if key not in checkpoint]
    if missing_keys:
        raise KeyError(
            f"Checkpoint {checkpoint_filepath!r} is missing keys: {missing_keys}"
        )

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [26]:
# Optionally resume: restores model and optimizer state only (load_checkpoint
# does not touch the LR scheduler or any epoch counter).
if LOAD_CKPT:
    load_checkpoint(net, optimizer, LOAD_CKPT_FILEPATH)

In [27]:
# Start the wandb run that the training loop logs to; wandb_run.finish() is
# called at the end of the training cell.
wandb_run = wandb.init(
    project="pure-noise",
    entity="brianryan",
)

# Record every hyperparameter constant defined in the cells above so the run
# can be reproduced from its config.
wandb.config.update({
    # Dataset
    "dataset__image_height": DATASET__IMAGE_HEIGHT,
    "dataset__image_width": DATASET__IMAGE_WIDTH,
    # DataLoader
    "dataloader__num_workers": DATALOADER__NUM_WORKERS,
    "dataloader__batch_size": DATALOADER__BATCH_SIZE,
    # OPeN
    "open__start_epoch": OPEN__START_EPOCH,
    "open__noise_ratio": OPEN__NOISE_RATIO,
    # Optimizer
    "optim__lr": OPTIM__LR,
    "optim__momentum": OPTIM__MOMENTUM,
    "optim__weight_decay": OPTIM__WEIGHT_DECAY,
    # Model
    "model__wideresnet_depth": MODEL__WIDERESNET_DEPTH,
    "model__wideresnet_k": MODEL__WIDERESNET_K,
    # Training
    "n_epoch": N_EPOCH,
    "save_ckpt_every_n_epoch": SAVE_CKPT_EVERY_N_EPOCH,
    "load_ckpt": LOAD_CKPT,
    "load_ckpt_filepath": LOAD_CKPT_FILEPATH,
})

[34m[1mwandb[0m: Currently logged in as: [33mseungjaeryanlee[0m ([33mbrianryan[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [28]:
from collections import defaultdict
import os

import torch

def _per_class_mean(values, class_labels, key_prefix, n_classes=10):
    """Average `values` separately for each class id in `range(n_classes)`.

    Args:
        values: 1-D numpy array with one entry per example (losses, or a
            boolean "prediction was correct" array).
        class_labels: 1-D numpy array of class ids aligned with `values`.
        key_prefix: metric name; keys are `f"{key_prefix}__class_{class_id}"`.
        n_classes: number of classes to report.

    Returns:
        Dict from metric key to the class mean, or NaN when the class has no
        examples (avoids numpy's "mean of empty slice" warning).
    """
    per_class = {}
    for class_ in range(n_classes):
        class_mask = class_labels == class_
        per_class[f"{key_prefix}__class_{class_}"] = (
            values[class_mask].mean() if class_mask.any() else float("nan")
        )
    return per_class


def _save_and_upload_checkpoint(epoch_tag):
    """Save model/optimizer state under checkpoints/ and upload the file to wandb."""
    filepath = f"checkpoints/{wandb.run.name}__epoch_{epoch_tag}.pt"
    os.makedirs("checkpoints/", exist_ok=True)
    save_checkpoint(net, optimizer, filepath)
    wandb.save(filepath)
    return filepath


for epoch_i in range(N_EPOCH):
    # Save checkpoint (at the START of the epoch, so epoch 0 stores the
    # state the model had before this loop trained it)
    if epoch_i % SAVE_CKPT_EVERY_N_EPOCH == 0:
        checkpoint_filepath = _save_and_upload_checkpoint(epoch_i)

    ## Training Phase
    net.train()
    train_losses = []
    train_labels = []
    train_preds = []
    for minibatch_i, (inputs, labels) in enumerate(train_loader):
        if epoch_i >= OPEN__START_EPOCH:
            # OPeN: replace some images with pure noise. This happens before
            # the batch is moved to the GPU.
            # NOTE: noise_mask is not used anywhere else in this loop.
            inputs, noise_mask = replace_images_with_pure_noise(
                inputs,
                labels,
                mean_per_channel,
                std_per_channel,
                height=DATASET__IMAGE_HEIGHT,
                width=DATASET__IMAGE_WIDTH,
                noise_ratio=OPEN__NOISE_RATIO,
            )

        inputs = inputs.float().cuda()
        labels = labels.cuda()

        optimizer.zero_grad()
        outputs = net(inputs)
        # criterion returns one loss per sample; the mean is what gets optimized.
        losses = criterion(outputs, labels)
        losses.mean().backward()
        optimizer.step()

        preds = torch.argmax(outputs, dim=1)
        train_losses.extend(losses.detach().cpu().tolist())
        train_labels.extend(labels.detach().cpu().tolist())
        train_preds.extend(preds.detach().cpu().tolist())

    train_losses = np.array(train_losses)
    train_labels = np.array(train_labels)
    train_preds = np.array(train_preds)

    # Per-class training loss and accuracy
    train_loss_per_class_dict = _per_class_mean(train_losses, train_labels, "train_loss")
    train_acc_per_class_dict = _per_class_mean(train_preds == train_labels, train_labels, "train_acc")

    ## Validation Phase
    net.eval()
    with torch.no_grad():
        # Save all losses and labels for each example
        valid_losses = []
        valid_labels = []
        valid_preds = []
        for minibatch_i, (inputs, labels) in enumerate(valid_loader):
            inputs = inputs.float().cuda()
            labels = labels.cuda()

            outputs = net(inputs)
            losses = criterion(outputs, labels)
            preds = torch.argmax(outputs, dim=1)

            valid_losses.extend(losses.detach().cpu().tolist())
            valid_labels.extend(labels.detach().cpu().tolist())
            valid_preds.extend(preds.detach().cpu().tolist())

    valid_losses = np.array(valid_losses)
    valid_labels = np.array(valid_labels)
    valid_preds = np.array(valid_preds)

    # Per-class validation loss and accuracy
    valid_loss_per_class_dict = _per_class_mean(valid_losses, valid_labels, "valid_loss")
    valid_acc_per_class_dict = _per_class_mean(valid_preds == valid_labels, valid_labels, "valid_acc")

    # Logging
    wandb.log({
        "train_loss": np.mean(train_losses),
        "train_acc": np.mean(train_preds == train_labels),
        **train_loss_per_class_dict,
        **train_acc_per_class_dict,
        "valid_loss": np.mean(valid_losses),
        "valid_acc": np.mean(valid_preds == valid_labels),
        **valid_loss_per_class_dict,
        **valid_acc_per_class_dict,
    })
    # The scheduler is stepped only after epochs 160 and 180 have finished, so
    # the changed learning rate first applies to epochs 161 and 181.
    if epoch_i in [160, 180]:
        scheduler.step()

# Save final checkpoint
checkpoint_filepath = _save_and_upload_checkpoint(N_EPOCH)

# Finish wandb run
wandb_run.finish()

VBox(children=(Label(value='185.275 MB of 417.744 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.44…

0,1
train_acc,▁▃▅▇█
train_acc__class_0,▁▃▅▇█
train_acc__class_1,▁▄▅▇█
train_acc__class_2,▁▃▅▇█
train_acc__class_3,▁▁▂▅█
train_acc__class_4,▂▁▃▇█
train_acc__class_5,▁▃▄▆█
train_acc__class_6,▁▄▆▇█
train_acc__class_7,▁▅▆▇█
train_acc__class_8,▁▂▅▇█

0,1
train_acc,0.52898
train_acc__class_0,0.74359
train_acc__class_1,0.71162
train_acc__class_2,0.36498
train_acc__class_3,0.33882
train_acc__class_4,0.41451
train_acc__class_5,0.39093
train_acc__class_6,0.48102
train_acc__class_7,0.56038
train_acc__class_8,0.62221
