In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.transforms import Compose, ToTensor, Lambda, Normalize

In [2]:
def _create_batch(unbatched_data, unbatched_label, unbatched_test_data, unbatched_test_label):
    unbatched_data = torch.split(unbatched_data, 100)
    unbatched_label = torch.split(unbatched_label, 100)
    unbatched_test_data = torch.split(unbatched_test_data, 100)
    unbatched_test_label = torch.split(unbatched_test_label, 100)
    return unbatched_data, unbatched_label, unbatched_test_data, unbatched_test_label


def get_dataset():
    """Load CIFAR-10, keep a 10,000-image training subset, and batch everything.

    Training labels are converted to one-hot float vectors of length 10; test
    labels keep whatever form the dataset yields. After sorting the training
    set by class, the first 1000 rows of every 5000-row slice are kept (this
    assumes each class contributes exactly 5000 training images), the subset is
    shuffled, and all four tensors are split by ``_create_batch``.

    Returns:
        tuple: ``(train_data_sampled, train_label_sampled, test_data,
        test_label)``, each a tuple of mini-batch tensors.
    """
    # Per-channel normalisation applied to both splits.
    image_transform = Compose([
        ToTensor(),
        Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ])
    # Integer class id -> one-hot float vector.
    one_hot = Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), 1))

    train_set = CIFAR10('./data', train=True, download=True, transform=image_transform, target_transform=one_hot)
    test_set = CIFAR10('./data', train=False, download=True, transform=image_transform)

    # A single full-size batch materialises each split as one tensor pair.
    train_data, train_label = next(iter(DataLoader(train_set, batch_size=len(train_set))))
    test_data, test_label = next(iter(DataLoader(test_set, batch_size=len(test_set))))

    print('train data: {}, train label: {}'.format(train_data.size(), train_label.size()))
    print('test data: {}, test label: {}'.format(test_data.size(), test_label.size()))

    # Group the training rows by class id.
    class_order = torch.argsort(torch.argmax(train_label, dim=1))
    label_by_class = train_label[class_order]
    data_by_class = train_data[class_order]

    # First 1000 rows of each 5000-row class slice.
    slice_starts = [class_idx * 5000 for class_idx in range(10)]
    data_parts = [data_by_class[start:(start + 1000)] for start in slice_starts]
    label_parts = [label_by_class[start:(start + 1000)] for start in slice_starts]

    # One shared permutation keeps images and labels aligned.
    shuffle = torch.randperm(10000)
    train_data_sampled = torch.concat(data_parts, dim=0)[shuffle]
    train_label_sampled = torch.concat(label_parts, dim=0)[shuffle]
    print(
        'train_data_sampled: {}, train_label_sampled: {}'.format(train_data_sampled.size(), train_label_sampled.size()))

    batches = _create_batch(train_data_sampled, train_label_sampled, test_data, test_label)
    train_data_sampled, train_label_sampled, test_data, test_label = batches
    print('train_data_sampled: {}, train_label_sampled: {}'.format(len(train_data_sampled), len(train_label_sampled)))
    print('train_data_sampled: {}, train_label_sampled: {}'.format(train_data_sampled[0].size(),
                                                                   train_label_sampled[0].size()))

    return train_data_sampled, train_label_sampled, test_data, test_label

In [3]:
def train(train_data_sampled, train_label_sampled, test_data, test_label, augmentation=None, num_epoch=100, lr=0.0001):
    """Fine-tune a ResNet-18 on pre-batched data and plot loss and accuracy.

    The network starts from torchvision's default ResNet-18 weights with its
    final layer replaced by a 10-output linear layer, and is optimised with
    Adam against an MSE loss on the supplied label vectors. Test accuracy is
    measured once per epoch.

    Args:
        train_data_sampled: Iterable of training input batches.
        train_label_sampled: Iterable of training label batches, aligned with
            ``train_data_sampled``.
        test_data: Iterable of test input batches.
        test_label: Iterable of test label batches holding class indices
            (they are compared against ``argmax`` predictions).
        augmentation: Optional callable ``(data, label) -> (data, label)``
            applied to each training batch before it is moved to the device.
        num_epoch: Number of passes over the training batches.
        lr: Adam learning rate.

    Returns:
        tuple: ``(running_acc, running_loss)`` -- accuracy first, loss second.
        ``running_acc`` has one test-accuracy value per epoch;
        ``running_loss`` has one loss value per training batch.
    """
    net = resnet18(weights=ResNet18_Weights.DEFAULT)
    net.fc = nn.Linear(512, 10)  # swap in a 10-way output layer
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    optimizer = Adam(net.parameters(), lr=lr)
    criterion = nn.MSELoss()
    net = net.to(device)

    def _plot_curve(values, y_name, title):
        # One standalone figure per tracked quantity.
        plt.figure(figsize=(10, 5))
        plt.plot(list(range(len(values))), values)
        plt.xlabel('iteration')
        plt.ylabel(y_name)
        plt.title(title)
        plt.show()

    loss_history = []
    acc_history = []
    loss = None  # stays None in the log line if there are no training batches

    for epoch in range(num_epoch):
        # ---- optimisation pass ----
        net.train()
        for batch, target in zip(train_data_sampled, train_label_sampled):
            if augmentation is not None:
                batch, target = augmentation(batch, target)
            batch, target = batch.to(device), target.to(device)

            optimizer.zero_grad()
            loss = criterion(net(batch), target)
            loss.backward()
            optimizer.step()
            loss_history.append(loss.item())

        # ---- evaluation pass ----
        net.eval()
        correct = torch.zeros(1).to(device)
        seen = 0
        with torch.no_grad():
            for eval_batch, eval_target in zip(test_data, test_label):
                eval_batch, eval_target = eval_batch.to(device), eval_target.to(device)
                predicted = torch.argmax(net(eval_batch), dim=1)
                correct = correct + torch.count_nonzero((predicted == eval_target).long())
                seen += eval_batch.size(0)
        epoch_acc = correct.item() / seen
        acc_history.append(epoch_acc)
        # `loss` here is the loss of the last training batch of the epoch.
        print('epoch: {}, loss: {}, acc: {}'.format(epoch + 1, loss, epoch_acc))

    _plot_curve(loss_history, 'loss', 'loss vs. iteration')
    _plot_curve(acc_history, 'acc', 'acc vs. iteration')

    return acc_history, loss_history

In [4]:
def mixup(data_batch, label_batch, alpha):
    """Blend every sample (and its label) with a randomly chosen partner.

    For each row ``i`` a partner row is drawn uniformly with replacement and a
    weight ``lam`` is drawn from ``Beta(alpha, alpha)``; the result is
    ``lam * x_i + (1 - lam) * x_partner`` for both data and labels. The blend
    is computed for the whole batch at once instead of row by row.

    Args:
        data_batch: Tensor whose first dimension indexes samples.
        label_batch: Floating-point label tensor (e.g. one-hot rows) with the
            same first dimension as ``data_batch``.
        alpha: Concentration parameter of the Beta distribution.

    Returns:
        tuple: ``(mixed_data, mixed_labels)`` as new tensors with the shapes
        and dtypes of the inputs. The inputs are not modified.
    """
    batch_size = data_batch.size(0)
    # Same draws, in the same order, as the per-sample implementation.
    partner_idx = np.random.choice(batch_size, batch_size)
    lam = torch.tensor(np.random.beta(alpha, alpha, size=batch_size))

    def _blend(batch):
        # Reshape the weights to (N, 1, ..., 1) so they broadcast over each sample.
        weight = lam.to(device=batch.device, dtype=batch.dtype).view(-1, *([1] * (batch.dim() - 1)))
        return weight * batch + (1 - weight) * batch[partner_idx]

    return _blend(data_batch), _blend(label_batch)

def cutout(data_batch, label_batch, k=16):
    """Zero out a random square patch in roughly half of the samples.

    Each sample is selected with probability 1/2. For a selected sample a
    centre pixel is drawn uniformly and a window of up to ``k`` x ``k`` pixels
    around it (clipped at the image border) is set to zero in every channel.

    The augmentation is applied to a copy of the batch. The previous version
    wrote through views into ``data_batch``, so when the same batches were
    reused every epoch the cut-outs accumulated in the stored training data.
    The column clamp also used the image height instead of the width.

    Args:
        data_batch: Tensor indexed as (sample, channel, row, column).
        label_batch: Labels for the batch; returned unchanged.
        k: Side length of the square patch in pixels.

    Returns:
        tuple: ``(augmented_batch, label_batch)`` where ``augmented_batch`` is
        a new tensor with the shape, dtype and device of ``data_batch``.
    """
    cutout_samples = data_batch.clone()
    height, width = data_batch.size(-2), data_batch.size(-1)

    # 0 -> leave the sample untouched, 1 -> apply the cut-out.
    cutout_mask = np.random.choice(2, data_batch.size(0))

    # Window extent before/after the centre; an even k puts the extra pixel after it.
    if k % 2 == 0:
        before, after = k // 2 - 1, k // 2
    else:
        before = after = (k - 1) // 2

    for sample_idx in range(data_batch.size(0)):
        if cutout_mask[sample_idx] != 1:
            continue
        center_row = int(np.random.choice(height, 1)[0])
        center_col = int(np.random.choice(width, 1)[0])
        row_min = max(0, center_row - before)
        row_max = min(height - 1, center_row + after)
        col_min = max(0, center_col - before)
        col_max = min(width - 1, center_col + after)
        cutout_samples[sample_idx, :, row_min:(row_max + 1), col_min:(col_max + 1)] = 0
    return cutout_samples, label_batch

def standard_augmentation(data_batch, label_batch, k=4):
    """Randomly translate each image by up to ``k`` pixels and flip it half the time.

    For every sample a vertical and a horizontal offset are drawn uniformly
    from ``[-k, k]``. The image is shifted by both offsets, with vacated pixels
    filled with zeros, and then mirrored left-right with probability 1/2.

    The horizontal shift is applied to the vertically shifted image, so the two
    translations compose. The previous version took the horizontal shift from
    the unshifted original, which overwrote the vertical shift in most columns
    and left a stale strip in the rest. The output buffer now follows the
    dtype/device of the input, and the flip works for any channel count.

    Args:
        data_batch: Tensor indexed as (sample, channel, row, column).
        label_batch: Labels for the batch; returned unchanged.
        k: Maximum absolute shift in pixels; expected to be smaller than the
            image height and width.

    Returns:
        tuple: ``(augmented_batch, label_batch)`` where ``augmented_batch`` is
        a new tensor shaped like ``data_batch``. The input is not modified.
    """
    standard_samples = torch.zeros_like(data_batch)
    height, width = data_batch.size(-2), data_batch.size(-1)
    offsets = list(range(-1 * k, k + 1))

    for sample_idx in range(data_batch.size(0)):
        data_sample = data_batch[sample_idx]
        upward_k, rightward_k = (int(v) for v in np.random.choice(offsets, 2))

        # Step 1: vertical shift into a scratch image (positive = content moves up).
        shifted = torch.zeros_like(data_sample)
        if upward_k > 0:
            shifted[:, :(height - upward_k), :] = data_sample[:, upward_k:, :]
        else:
            downward_k = -1 * upward_k
            shifted[:, downward_k:, :] = data_sample[:, :(height - downward_k), :]

        # Step 2: horizontal shift of the already shifted image (positive = content moves right).
        if rightward_k > 0:
            standard_samples[sample_idx, :, :, rightward_k:] = shifted[:, :, :(width - rightward_k)]
        else:
            leftward_k = -1 * rightward_k
            standard_samples[sample_idx, :, :, :(width - leftward_k)] = shifted[:, :, leftward_k:]

        # Step 3: mirror along the column axis with probability 1/2.
        if np.random.choice(2, 1)[0] == 1:
            standard_samples[sample_idx] = torch.flip(standard_samples[sample_idx], dims=[-1])
    return standard_samples, label_batch

def combined_augmentation(data_batch, label_batch, alpha, k_cutout=16, k_standard=4):
    """Chain the three augmentations: shift/flip, then cut-out, then mixup.

    Args:
        data_batch: Batch of input images.
        label_batch: Labels aligned with ``data_batch``.
        alpha: Beta-distribution parameter forwarded to ``mixup``.
        k_cutout: Patch size forwarded to ``cutout``.
        k_standard: Maximum shift forwarded to ``standard_augmentation``.

    Returns:
        tuple: The ``(data, labels)`` pair produced by the final ``mixup`` step.
    """
    shifted, shifted_labels = standard_augmentation(data_batch, label_batch, k=k_standard)
    occluded, occluded_labels = cutout(shifted, shifted_labels, k=k_cutout)
    mixed = mixup(occluded, occluded_labels, alpha)
    return mixed

In [None]:
# Baseline: no augmentation.
# train() returns (accuracy history, loss history) in that order, so unpack accordingly.
train_X, train_y, test_X, test_y = get_dataset()
running_acc_without_aug, running_loss_without_aug = train(train_X, train_y, test_X, test_y)

Files already downloaded and verified
Files already downloaded and verified
train data: torch.Size([50000, 3, 32, 32]), train label: torch.Size([50000, 10])
test data: torch.Size([10000, 3, 32, 32]), test label: torch.Size([10000])
train_data_sampled: torch.Size([10000, 3, 32, 32]), train_label_sampled: torch.Size([10000, 10])
train_data_sampled: 100, train_label_sampled: 100
train_data_sampled: torch.Size([100, 3, 32, 32]), train_label_sampled: torch.Size([100, 10])
epoch: 1, loss: 0.08498416095972061, acc: 0.3082
epoch: 2, loss: 0.0710790753364563, acc: 0.4341
epoch: 3, loss: 0.05815282464027405, acc: 0.504
epoch: 4, loss: 0.04718910902738571, acc: 0.5445
epoch: 5, loss: 0.036710530519485474, acc: 0.5678
epoch: 6, loss: 0.02772563323378563, acc: 0.5776
epoch: 7, loss: 0.02106798253953457, acc: 0.5811
epoch: 8, loss: 0.015844400972127914, acc: 0.583
epoch: 9, loss: 0.012305676005780697, acc: 0.5854
epoch: 10, loss: 0.01037162821739912, acc: 0.584
epoch: 11, loss: 0.009008882567286491,

In [None]:
# Mixup with alpha = 0.2 and alpha = 0.4.
# train() returns (accuracy history, loss history) in that order, so unpack accordingly.
train_X, train_y, test_X, test_y = get_dataset()
running_acc_mixup_2, running_loss_mixup_2 = train(train_X, train_y, test_X, test_y, augmentation=lambda x, y: mixup(x, y, 0.2))
train_X, train_y, test_X, test_y = get_dataset()
running_acc_mixup_4, running_loss_mixup_4 = train(train_X, train_y, test_X, test_y, augmentation=lambda x, y: mixup(x, y, 0.4))

In [None]:
# Cut-out augmentation.
# train() returns (accuracy history, loss history) in that order, so unpack accordingly.
train_X, train_y, test_X, test_y = get_dataset()
running_acc_cutout, running_loss_cutout = train(train_X, train_y, test_X, test_y, augmentation=cutout)

In [None]:
# Standard augmentation (random shift + horizontal flip).
# train() returns (accuracy history, loss history) in that order, so unpack accordingly.
train_X, train_y, test_X, test_y = get_dataset()
running_acc_standard, running_loss_standard = train(train_X, train_y, test_X, test_y, augmentation=standard_augmentation)

In [None]:
# All three augmentations chained, with mixup alpha = 0.2 and alpha = 0.4.
# train() returns (accuracy history, loss history) in that order, so unpack accordingly.
train_X, train_y, test_X, test_y = get_dataset()
running_acc_combined_2, running_loss_combined_2 = train(train_X, train_y, test_X, test_y, augmentation=lambda x, y: combined_augmentation(x, y, 0.2))
train_X, train_y, test_X, test_y = get_dataset()
running_acc_combined_4, running_loss_combined_4 = train(train_X, train_y, test_X, test_y, augmentation=lambda x, y: combined_augmentation(x, y, 0.4))