In [11]:
import random
import torch
from torch.utils.data import DataLoader
from torch import nn
from torch.optim import Optimizer
from utils import (
    load_untrained_model,
    get_best_transformations,
    get_cifar10_train_val_loaders,
    get_test_cifar10_dataloader,
    save_checkpoint,
    test,
    pickle_dump
)
from _typing_ import (
    Dataset,
    Optimizer,
    LRScheduler,
    Module,
    Transform,
)

In [None]:
def train(
    train_loader: DataLoader,
    net: nn.Module,
    optimiser: Optimizer,
    criterion,
    device: str = "cuda",
    half: bool = False,
    clip: bool = False,
    mixup: bool = False,
) -> tuple[float, float]:
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        if half:
            inputs = inputs.half()

        optimiser.zero_grad()

        if mixup:
            # Shuffle the batch
            index = torch.randperm(inputs.size(0)).to(device)
            inputs_perm = inputs[index]
            targets_perm = targets[index]

            # Sample lambda
            lam = random.random()

            # Mix the inputs
            inputs_mix = lam * inputs + (1 - lam) * inputs_perm
            outputs = net(inputs_mix)

            # Mixup loss
            loss = lam * criterion(outputs, targets) + (1 - lam) * criterion(outputs, targets_perm)
        else:
            outputs = net(inputs)
            loss = criterion(outputs, targets)

        loss.backward()
        optimiser.step()

        if clip:
            net.clip()

        train_loss += loss.item()
        _, predicted = outputs.max(1)

        if mixup:
            # Can't compute real accuracy with mixed labels; just skip
            total += 0
            correct += 0
        else:
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100. * correct / total if total > 0 else 0.0
    return acc, train_loss

def run_epochs(
    train_loader: DataLoader,
    test_loader: DataLoader,
    hyperparams: dict,
    n_epochs: int = 200,
    start_epoch: int = 0,
    device: str = "cuda",
    half: bool = False,
    clip: bool = False,
    mixup: bool = False,
):
    best_acc = 0
    train_accs = []
    test_accs = []

    for epoch in range(start_epoch, start_epoch + n_epochs):
        print(f"\nEpoch {epoch}")

        train_acc, train_loss = train(
            train_loader,
            net,
            hyperparams["optimiser"],
            hyperparams["criterion"],
            device,
            half=half,
            clip=clip,
            mixup=mixup
        )

        test_acc, test_loss = test(
            test_loader,
            net,
            hyperparams["criterion"],
            device,
            half
        )

        train_accs.append(train_acc)
        test_accs.append(test_acc)

        if test_acc > best_acc:
            save_checkpoint(net, test_acc, epoch)
            best_acc = test_acc

    return best_acc, train_accs, test_accs


In [13]:
model_dict = load_untrained_model("DenseNet121")
model = model_dict["model"]

In [14]:
train_transforms = get_best_transformations()
train_loader, val_loader = get_cifar10_train_val_loaders(transform=train_transforms)
test_loader = get_test_cifar10_dataloader()

In [15]:
best_acc, train_accs, test_accs = run_epochs(
    model,
    train_loader,
    val_loader,
    model_dict,
    n_epochs=5,
    mixup=True
)


Epoch 0
Train Loss: 2219.8968 | Train Acc: 0.00%
Test  Loss: 380.8304 | Test  Acc: 58.54%
Saving..

Epoch 1
Train Loss: 1945.3437 | Train Acc: 0.00%
Test  Loss: 353.3369 | Test  Acc: 62.17%
Saving..

Epoch 2
Train Loss: 1822.6729 | Train Acc: 0.00%
Test  Loss: 275.8569 | Test  Acc: 71.81%
Saving..

Epoch 3
Train Loss: 1751.1444 | Train Acc: 0.00%
Test  Loss: 249.4704 | Test  Acc: 75.64%
Saving..

Epoch 4
Train Loss: 1689.3967 | Train Acc: 0.00%
Test  Loss: 228.6286 | Test  Acc: 77.71%
Saving..


In [None]:
test_acc, test_loss = test(
    test_loader,
    model,
)

res = {
    'net': model.state_dict(),
    'acc': test_acc,
    'train_accs': train_accs,
    "val_accs": test_accs,
}

pickle_dump(res, "train_results")