In [None]:
import numpy as np
import copy
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [None]:
import torch
import torch.nn as nn
import dtnnlib as dtnn
import resnet_cifar

from torchvision import datasets, transforms as T
from torch.utils import data

In [None]:
# Set PyTorch's float32 matrix-multiplication precision mode to 'high' for the whole session.
# NOTE(review): per the PyTorch docs this permits faster, reduced-precision internal
# math for float32 matmuls — confirm that is acceptable for this benchmark.
torch.set_float32_matmul_precision('high')

In [None]:
from tqdm import tqdm
import os, time, sys, random, json

In [None]:
# Per-channel normalisation statistics shared by the train and test pipelines.
# For cifar100 use mean=[0.5071, 0.4865, 0.4409] and std=[0.2009, 0.1984, 0.2023].
_CIFAR_MEAN = [0.4914, 0.4822, 0.4465]
_CIFAR_STD = [0.2023, 0.1994, 0.2010]

# Training pipeline: random padded crop + random horizontal flip, then tensor
# conversion and normalisation.
_train_steps = [
    T.RandomCrop(size=32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=_CIFAR_MEAN, std=_CIFAR_STD),
]
cifar_train = T.Compose(_train_steps)

# Test pipeline: no augmentation, only tensor conversion and normalisation.
_test_steps = [
    T.ToTensor(),
    T.Normalize(mean=_CIFAR_MEAN, std=_CIFAR_STD),
]
cifar_test = T.Compose(_test_steps)

# Both splits live under the same root and are downloaded on demand.
_dataset_args = dict(root="data/", download=True)
train_dataset = datasets.CIFAR10(train=True, transform=cifar_train, **_dataset_args)
test_dataset = datasets.CIFAR10(train=False, transform=cifar_test, **_dataset_args)

In [None]:
batch_size = 64

# Settings common to both loaders; only the dataset and shuffling differ.
_loader_args = dict(num_workers=4, batch_size=batch_size)
train_loader = data.DataLoader(dataset=train_dataset, shuffle=True, **_loader_args)
test_loader = data.DataLoader(dataset=test_dataset, shuffle=False, **_loader_args)

In [None]:
# Use the first CUDA device when one is available; otherwise fall back to the CPU so
# the notebook still runs (slowly) on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Cross-entropy loss; train() and test() read this as a module-level global.
criterion = nn.CrossEntropyLoss()

In [None]:
# Sanity check: pull a single batch from the training loader, move it to the
# target device and print the tensor shapes.
sample_batch = next(iter(train_loader), None)
if sample_batch is not None:
    xx, yy = sample_batch
    xx, yy = xx.to(device), yy.to(device)
    print(xx.shape, yy.shape)

In [None]:
# Instantiate `cifar_resnet20` for 10 classes with distance=2 (not moved to `device`)
# and echo its first conv layer as the cell output for inspection.
# NOTE(review): `distance` presumably selects the metric used inside the layers —
# confirm against resnet_cifar.
net = resnet_cifar.cifar_resnet20(num_classes=10, distance=2)
net.conv1

## Try different distance metrics for the CNN

In [None]:
def train(epoch, model, optimizer):
    """Run one optimisation pass over the module-level ``train_loader``.

    Args:
        epoch: Index of the current epoch. Not used by this function.
        model: Network to optimise; it is switched to training mode first.
        optimizer: Optimiser whose gradients are zeroed and stepped once per batch.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        fraction of samples whose highest-scoring class equals the target.

    Reads the module-level ``train_loader``, ``device`` and ``criterion``.
    """
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0
    # `step` is 1-based, so after the loop it equals the number of batches processed.
    for step, (inputs, targets) in enumerate(train_loader, start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.max(1)[1]  # index of the largest score along dim 1
        n_seen += targets.size(0)
        n_correct += (predicted == targets).sum().item()
    return running_loss / step, n_correct / n_seen

In [None]:
def test(epoch, model, model_name):
    """Evaluate ``model`` on the module-level ``test_loader`` without gradients.

    Args:
        epoch: Index of the current epoch. Not used by this function.
        model: Network to evaluate; it is switched to eval mode first.
        model_name: Name of the run. Not used by this function.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        fraction of samples whose highest-scoring class equals the target.

    Reads the module-level ``test_loader``, ``device`` and ``criterion``.
    """
    model.eval()
    running_loss, n_correct, n_seen = 0, 0, 0
    with torch.no_grad():
        # `step` is 1-based, so after the loop it equals the number of batches processed.
        for step, (inputs, targets) in enumerate(test_loader, start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            running_loss += criterion(outputs, targets).item()

            predicted = outputs.max(1)[1]  # index of the largest score along dim 1
            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()

    return running_loss / step, n_correct / n_seen

In [None]:
# Number of training epochs per (seed, metric) run. The benchmark loop also uses it
# as T_max of the cosine LR schedule and as the "run is complete" criterion when
# deciding whether to skip a run.
EPOCHS = 200
# EPOCHS = 1  # presumably a quick smoke-test setting

In [None]:
# Resume support: load results from a previous run if they exist, otherwise start
# with an empty results dict. Only file-related failures are handled, so unrelated
# errors (and KeyboardInterrupt) are no longer silently swallowed.
try:
    with open("./outputs/04_bench_metrics_c10_res20.json") as f:
        accs_bench = json.load(f)
except FileNotFoundError:
    # No earlier run: nothing to resume.
    accs_bench = {}
except (OSError, json.JSONDecodeError) as err:
    # The file exists but cannot be read or parsed; keep the best-effort behaviour
    # of starting fresh, but say so instead of hiding it.
    print(f"Could not load previous benchmark results ({err}); starting fresh.")
    accs_bench = {}

In [None]:
# Show the results loaded from a previous run (an empty dict on a fresh start).
accs_bench

In [None]:
# Benchmark: train one network per (seed, metric) pair and record the per-epoch
# train/test loss and accuracy in `accs_bench`, saving to disk after every run so
# that an interrupted benchmark can be resumed.
SEEDS = [852, 963, 159, 147]

# Create the output folders up front so a missing directory cannot make the
# checkpoint/result writes fail after a long training run.
os.makedirs("models", exist_ok=True)
os.makedirs("outputs", exist_ok=True)

for seed in SEEDS:
    for key in ["stereographic", "linear", 2]:
        # JSON object keys are always strings, so results are indexed by the string
        # forms of seed and metric everywhere. Storing under the raw int values
        # would put e.g. both 852 and "852" in the dict after a resume, which
        # json.dump writes out as a duplicated key.
        _s = str(seed)
        _k = str(key)
        recorded_epochs = len(accs_bench.get(_s, {}).get(_k, {}).get("test_acc", []))
        if recorded_epochs == EPOCHS:
            print(f"Completed for {_k}; seed {_s}")
            continue

        print("_________________________")
        print(f"Experimenting for {key}; seed {seed}")

        # Seed every RNG used here before building the loaders and the model.
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        # train() and test() read these loaders (and `criterion` below) as
        # module-level globals, so they are rebound here for each run.
        train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
        test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

        net = resnet_cifar.cifar_resnet20(num_classes=10, distance=key).to(device)
        net = torch.compile(net)

        model_name = f"04_c10_{str(key)}_s{seed}"

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(net.parameters(), lr=0.1,
                              momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
        best_acc = -1

        train_losses, train_accs = [], []
        test_losses, test_accs = [], []
        for epoch in tqdm(range(EPOCHS)):
            tr_loss, tr_acc = train(epoch, net, optimizer)
            train_losses.append(tr_loss)
            train_accs.append(tr_acc)
            te_loss, te_acc = test(epoch, net, model_name)
            test_losses.append(te_loss)
            test_accs.append(te_acc)

            ######## Save checkpoint whenever the test accuracy improves.
            if te_acc > best_acc:
                # NOTE(review): `net` is the torch.compile wrapper at this point;
                # confirm the saved state_dict key names are what the code that
                # loads these checkpoints expects.
                state = {
                    'model': net.state_dict(),
                    'acc': te_acc,
                    'epoch': epoch,
                }
                torch.save(state, f'./models/{model_name}.pth')
                best_acc = te_acc
            #######################

            scheduler.step()

        ##### after full training
        # Merge into the existing entry for this seed instead of replacing it, so
        # metrics finished in an earlier session are kept.
        accs_bench.setdefault(_s, {})[_k] = {
            "train_acc": train_accs,
            "train_loss": train_losses,
            "test_acc": test_accs,
            "test_loss": test_losses,
        }
        ## Save after every completed run so progress survives an interruption.
        with open("./outputs/04_bench_metrics_c10_res20.json", "w") as f:
            json.dump(accs_bench, f, indent=3)

In [None]:
# Inspect the accumulated results after the benchmark loop.
accs_bench

## Plot the training curves

In [None]:
def array_on_duplicate_keys(ordered_pairs):
    """Build a dict from JSON ``(key, value)`` pairs, tolerating repeated keys.

    Used as ``object_pairs_hook`` for ``json.load``. When a key occurs more than
    once within the same JSON object:

    * if both values are objects (dicts), they are merged into one dict, with
      entries from the later occurrence overriding same-named earlier entries;
    * otherwise the later value replaces the earlier one (the stock ``json``
      behaviour).

    Merging keeps every metric of a seed available as ``result[seed][metric]``
    even if the results file lists that seed more than once.

    Args:
        ordered_pairs: List of ``(key, value)`` tuples for one JSON object, in
            document order.

    Returns:
        dict with one entry per distinct key.
    """
    merged = {}
    for name, value in ordered_pairs:
        previous = merged.get(name)
        if isinstance(previous, dict) and isinstance(value, dict):
            previous.update(value)
        else:
            merged[name] = value
    return merged


# Reload the saved benchmark from disk so plotting works from the file alone.
with open("./outputs/04_bench_metrics_c10_res20.json") as f:
    benchmark = json.load(f, object_pairs_hook=array_on_duplicate_keys)

In [None]:
# List, for every seed in the loaded benchmark, which metrics have results.
for seed, per_metric in benchmark.items():
    print(seed, per_metric.keys())

In [None]:
# For each metric, find the seed whose run reached the highest test accuracy.
# `all_accs_per_seed` maps metric name -> that best seed.
all_accs_per_seed = {}
for m in ["stereographic", "linear", "2"]:
    maxk, maxv = None, -1
    print(m)
    for seed, per_metric in benchmark.items():
        # A results file written mid-benchmark may not contain every metric for
        # every seed yet; skip the missing (or empty) ones instead of failing.
        test_acc = per_metric.get(m, {}).get("test_acc")
        if not test_acc:
            continue
        max_test_acc = np.max(test_acc)
        print(seed, max_test_acc)
        if max_test_acc > maxv:
            maxv = max_test_acc
            maxk = seed
    if maxk is None:
        # No seed has results for this metric: leave it out rather than storing
        # None, which could not be used to index `benchmark` later.
        print("\t\tno results")
        continue
    all_accs_per_seed[m] = maxk
    print(f"\t\t{maxk} : {maxv}")

In [None]:
# all_accs_per_seed

### Plot for the highest accuracies

In [None]:
# One column per metric, showing the run of its best seed:
# top row = accuracy curves, bottom row = loss curves.
fig, axs = plt.subplots(2, 3, figsize=(12,6))
axs[0][0].set_ylabel("accuracy")
axs[1][0].set_ylabel("loss")

# Display names used in the x-axis labels.
names = {'stereographic': 'istereo', 'linear': 'linear', '2': r'$l^2$'}

for i, (m, s) in enumerate(all_accs_per_seed.items()):

    # Deliberately not called `data`: that name is the `torch.utils.data` module
    # used by the DataLoader cells, and rebinding it here would break re-running them.
    curves = benchmark[s][m]

    axs[0][i].plot(curves["train_acc"], label="train", color='tab:red')
    axs[0][i].plot(curves["test_acc"], label="test", color='tab:green')
    axs[0][i].set_xticks([])  # the epoch axis is labelled on the bottom row only
    axs[0][i].legend()

    axs[1][i].plot(curves["train_loss"], label="train", color='tab:pink')
    axs[1][i].plot(curves["test_loss"], label="test", color='tab:blue')
    axs[1][i].set_xlabel(r"epochs $\to$ "+names[m])
    axs[1][i].legend()

plt.savefig("./outputs/04_bench_metrics_c10_res20_best_plot.pdf", bbox_inches="tight")