In [9]:
from factorisation.densenet import DenseNet, Bottleneck, GroupedCroissantBottleneck, Transition, get_standard_densnet121
from utils import count_nonzero_parameters

In [31]:
def get_student_densenet(
    growth_rate: int = 24,
    block_config: list[int] | tuple[int, ...] = (4, 8, 16, 8),
) -> DenseNet:
    """Build the student DenseNet that is trained against the teacher in this notebook.

    Args:
        growth_rate: Forwarded to ``DenseNet`` as ``growth_rate``.
        block_config: First positional argument of ``DenseNet`` (presumably the
            number of ``Bottleneck`` layers in each dense block — confirm in
            ``factorisation.densenet``). The default ``(4, 8, 16, 8)`` reproduces
            the original hard-coded student architecture.

    Returns:
        A freshly constructed ``DenseNet`` using ``Bottleneck`` and ``Transition``.
    """
    # Copy into a list so callers may pass any sequence (and the default tuple
    # is never shared with / mutated by the model).
    return DenseNet(list(block_config), Bottleneck, Transition, growth_rate=growth_rate)

In [46]:
from _typing_ import DataLoader, Optimizer
from torch import nn
import torch

from utils import calculate_mix_up_loss

def train(
    train_loader: DataLoader,
    net: nn.Module,
    optimiser: Optimizer,
    criterion,
    device: str = "cuda",
    half: bool = False,
    clip: bool = False,
    mixup: bool = False,
    teacher: nn.Module | None = None,
) -> tuple[float, float]:
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        if half:
            inputs = inputs.half()

        optimiser.zero_grad()

        if mixup:
            loss, outputs = calculate_mix_up_loss(net, criterion, inputs, targets, device)
        elif teacher is not None:
            loss, outputs = calculate_distillation_loss(net, teacher, inputs, targets)
        else:
            outputs = net(inputs)
            loss = criterion(outputs, targets)

        loss.backward()
        optimiser.step()

        if clip:
            net.clip()

        train_loss += loss.item()
        _, predicted = outputs.max(1)

        if not mixup:
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100. * correct / total if total > 0 else 0.0
    return acc, train_loss

def test(
        test_loader: DataLoader,
        net: nn.Module,
        criterion = nn.CrossEntropyLoss(),
        device: str = "cuda",
        half: bool = False,
    ) -> tuple[float]:
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            if half:
                inputs = inputs.half()

            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100.* correct / total
    return acc, test_loss

def run_epochs(
    net: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    hyperparams: dict,
    n_epochs: int = 200,
    start_epoch: int = 0,
    device: str = "cuda",
    half: bool = False,
    clip: bool = False,
    mixup: bool = False,
    teacher: nn.Module | None = None,
    checkpoint_file_name: str = "train_ckpt",
) -> tuple[float, list[float], list[float]]:
    """Run ``n_epochs`` epochs, each a training pass followed by an evaluation pass.

    Args:
        net: Model being trained and evaluated.
        train_loader: Batches for ``train``.
        test_loader: Batches for ``test``.
        hyperparams: Must provide ``"optimiser"`` and ``"criterion"``; the same
            criterion is used for training and evaluation.
        n_epochs: Number of epochs to run.
        start_epoch: Index of the first epoch (only affects the printed label).
        device: Device passed through to ``train`` and ``test``.
        half: Passed through to ``train`` and ``test``.
        clip: Passed through to ``train``.
        mixup: Passed through to ``train``.
        teacher: Optional distillation teacher, passed through to ``train``.
        checkpoint_file_name: Currently unused — checkpoint saving is disabled.

    Returns:
        ``(best_test_acc, train_accs, test_accs)``: the highest test accuracy
        seen (0.0 if no epoch ran) and one accuracy per epoch in each list.
    """
    best_acc = 0.0
    train_accs: list[float] = []
    test_accs: list[float] = []

    for epoch in range(start_epoch, start_epoch + n_epochs):
        print(f"\nEpoch {epoch}")

        train_acc, _ = train(
            train_loader,
            net,
            hyperparams["optimiser"],
            hyperparams["criterion"],
            device,
            half=half,
            clip=clip,
            mixup=mixup,
            teacher=teacher,
        )
        test_acc, _ = test(
            test_loader,
            net,
            hyperparams["criterion"],
            device,
            half,
        )
        train_accs.append(train_acc)
        test_accs.append(test_acc)

        # Track the best accuracy even while checkpointing is disabled
        # (previously best_acc was only updated inside the commented-out
        # checkpoint branch, so the function always returned 0).
        # TODO: re-enable `save_checkpoint(net, test_acc, epoch,
        # checkpoint_file_name)` on improvement once save_checkpoint is imported.
        best_acc = max(best_acc, test_acc)

    return best_acc, train_accs, test_accs

In [47]:
from utils import (
    get_best_transformations,
    get_cifar10_train_val_loaders,
    get_test_cifar10_dataloader,
    load_trained_model,
    load_untrained_model,
)

DEVICE = "cuda"

train_transforms = get_best_transformations()
train_loader, _ = get_cifar10_train_val_loaders(transform=train_transforms)
test_loader = get_test_cifar10_dataloader()

teacher_model, _ = load_trained_model()
teacher_model.to(DEVICE)

# Create the student and move it to the device *before* building its training
# details: PyTorch recommends placing a model on its device before constructing
# an optimiser for it.
student_model = get_student_densenet()
student_model.to(DEVICE)

# The hyper-parameters (including the optimiser that run_epochs steps) must be
# built from the model being trained. This previously passed `teacher_model`;
# NOTE(review): assuming load_untrained_model builds the optimiser from its
# argument, that stepped the teacher's parameters instead of the student's.
train_details = load_untrained_model(student_model)

best_acc, train_accs, test_accs = run_epochs(
    student_model,
    train_loader,
    test_loader,
    train_details,
    n_epochs=2,
    device=DEVICE,
    teacher=teacher_model,
)


Epoch 0


KeyboardInterrupt: 