In [1]:
import numpy as np
import copy
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [2]:
import torch
import torch.nn as nn
import dtnnlib as dtnn
import resnet_cifar

from torchvision import datasets, transforms as T
from torch.utils import data

In [3]:
from tqdm import tqdm
import os, time, sys

In [4]:
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [5]:
# Per-channel normalisation statistics for CIFAR-10.
# (For CIFAR-100 use mean=[0.5071, 0.4865, 0.4409], std=[0.2009, 0.1984, 0.2023].)
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]
# Directory the CIFAR-10 archive is downloaded to / read from.
CIFAR10_ROOT = "../../../../../_Datasets/cifar10/"

# Training pipeline: padded random crop and random horizontal flip, then tensor conversion + normalisation.
cifar_train = T.Compose([
    T.RandomCrop(size=32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
cifar_test = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

train_dataset = datasets.CIFAR10(
    root=CIFAR10_ROOT, train=True, download=True, transform=cifar_train
)
test_dataset = datasets.CIFAR10(
    root=CIFAR10_ROOT, train=False, download=True, transform=cifar_test
)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# train_dataset.data = train_dataset.data.view(-1, 28*28)
# test_dataset.data = test_dataset.data.view(-1, 28*28)

In [7]:
batch_size = 128
# Both loaders share batch size and worker count; only the training split is shuffled.
_loader_kwargs = dict(batch_size=batch_size, num_workers=4)
train_loader = data.DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = data.DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [8]:
# Prefer the second GPU ("cuda:1"), as originally configured. Fall back to the
# first GPU, then to the CPU, so that later `.to(device)` calls do not fail on
# machines with fewer than two CUDA devices.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [9]:
# Loss function read as a module-level name by train() and test();
# the experiment loop rebinds it to a fresh instance for every run.
criterion = nn.CrossEntropyLoss()

In [10]:
# Fetch a single batch to check tensor shapes and the transfer to `device`.
for batch in train_loader:
    xx, yy = (tensor.to(device) for tensor in batch)
    print(xx.shape, yy.shape)
    break

torch.Size([128, 3, 32, 32]) torch.Size([128])


In [11]:
# net = resnet_cifar.cifar_resnet20(num_classes=10, distance=0.5)
# net

In [12]:
# asdasd

## Any function as a metric

In [13]:
# class FunctionDT(nn.Module):
    
#     def __init__(self, input_dim, num_centers, func, inv_temp=0.):
#         '''
#         func [input_dim -> 1]
#         '''
#         super().__init__()
#         self.input_dim = input_dim
#         self.num_centers = num_centers
#         self.func = func
        
#         self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
#         self.centers = torch.randn(num_centers, input_dim)/3.
#         self.centers = nn.Parameter(self.centers)
    
#     def forward(self, x):
#         z = x.unsqueeze(1) - self.centers.unsqueeze(0)
#         dists = self.func(z).squeeze(-1)
#         dists = -dists*torch.exp(self.inv_temp)
#         return dists

In [14]:
# from classes import DistanceRegressor, ConvexNN
# from nflib.flows import SequentialFlow, ActNorm
# import nflib.res_flow as irf

## Try different metrics for the CNN

In [15]:
## The following is adapted from
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch, model, optimizer):
    """Run one optimisation pass over ``train_loader`` and print a summary line.

    Reads the notebook-level names ``train_loader``, ``device`` and
    ``criterion``.

    Args:
        epoch: Epoch index; only used to label the printed summary.
        model: Network to optimise; it is switched to training mode.
        optimizer: Optimizer whose gradients are zeroed and which is stepped
            once per batch.

    Returns:
        None. Prints the mean of the per-batch losses and the accuracy over
        all samples seen.
    """
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    for inputs, targets in tqdm(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # Predicted class = index of the largest output along dim 1.
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    if total == 0:
        # An empty loader used to crash on the unbound loop index (and would
        # divide by zero); report it instead.
        print(f"[Train] {epoch} no samples in train_loader; nothing to report")
        return
    print(f"[Train] {epoch} Loss: {train_loss/num_batches:.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")

In [16]:
best_acc = -1
def test(epoch, model, model_name):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('models'):
            os.mkdir('models')
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc

In [17]:
# Number of train/test epochs per experiment; also passed as T_max to the cosine LR schedule.
EPOCHS = 200

In [None]:
# Train one `cifar_resnet20` network per `distance` setting and record the best
# test accuracy reached in each run.
acc_dict = {}
for key in ("stereographic", 1):
    print("_________________________")
    print(f"Experimenting for {key} ;")

    net = resnet_cifar.cifar_resnet20(num_classes=10, distance=key)
    net = net.to(device)
    # Left disabled: net = torch.compile(net), optionally with
    # mode="reduce-overhead" or mode="max-autotune".

    model_name = f"00.3_c10_{str(key)}"

    # train()/test() read `criterion` and `best_acc` as globals, so both are
    # (re)bound at the start of every run.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=0.1,
        momentum=0.9,
        weight_decay=5e-4,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    best_acc = -1
    for epoch in range(EPOCHS):
        train(epoch, net, optimizer)
        test(epoch, net, model_name)
        scheduler.step()

    acc_dict[key] = float(best_acc)

_________________________
Experimenting for stereographic ;


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 0 Loss: 1.673 | Acc: 38.060 19030/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.08it/s]


[Test] 0 Loss: 1.733 | Acc: 41.350 4135/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 1 Loss: 1.151 | Acc: 58.652 29326/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.11it/s]


[Test] 1 Loss: 1.358 | Acc: 51.580 5158/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 2 Loss: 1.013 | Acc: 63.876 31938/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.09it/s]


[Test] 2 Loss: 3.291 | Acc: 10.060 1006/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 3 Loss: 0.862 | Acc: 69.934 34967/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.07it/s]


[Test] 3 Loss: 3.992 | Acc: 10.090 1009/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 4 Loss: 0.780 | Acc: 72.904 36452/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.05it/s]


[Test] 4 Loss: 2.996 | Acc: 14.640 1464/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 5 Loss: 0.722 | Acc: 74.752 37376/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.10it/s]


[Test] 5 Loss: 2.743 | Acc: 25.940 2594/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 6 Loss: 0.685 | Acc: 76.370 38185/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.09it/s]


[Test] 6 Loss: 2.305 | Acc: 35.890 3589/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 7 Loss: 0.664 | Acc: 77.098 38549/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.06it/s]


[Test] 7 Loss: 1.108 | Acc: 64.180 6418/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 8 Loss: 0.641 | Acc: 77.932 38966/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.07it/s]


[Test] 8 Loss: 0.905 | Acc: 68.440 6844/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 9 Loss: 0.622 | Acc: 78.716 39358/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.06it/s]


[Test] 9 Loss: 2.137 | Acc: 54.040 5404/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 10 Loss: 0.608 | Acc: 79.314 39657/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.08it/s]


[Test] 10 Loss: 0.897 | Acc: 70.640 7064/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 11 Loss: 0.600 | Acc: 79.170 39585/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.00it/s]


[Test] 11 Loss: 1.617 | Acc: 54.180 5418/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 12 Loss: 0.598 | Acc: 79.352 39676/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 11.97it/s]


[Test] 12 Loss: 1.396 | Acc: 62.650 6265/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.16it/s]


[Train] 13 Loss: 0.575 | Acc: 80.172 40086/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 11.89it/s]


[Test] 13 Loss: 0.942 | Acc: 70.590 7059/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.16it/s]


[Train] 14 Loss: 0.588 | Acc: 79.836 39918/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 11.95it/s]


[Test] 14 Loss: 0.930 | Acc: 69.650 6965/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.16it/s]


[Train] 15 Loss: 0.568 | Acc: 80.456 40228/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 11.96it/s]


[Test] 15 Loss: 1.033 | Acc: 67.850 6785/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.16it/s]


[Train] 16 Loss: 0.559 | Acc: 80.686 40343/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 11.92it/s]


[Test] 16 Loss: 0.882 | Acc: 73.110 7311/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 17 Loss: 0.547 | Acc: 81.112 40556/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 11.80it/s]


[Test] 17 Loss: 1.462 | Acc: 52.160 5216/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.16it/s]


[Train] 18 Loss: 0.547 | Acc: 81.176 40588/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 11.94it/s]


[Test] 18 Loss: 0.753 | Acc: 74.610 7461/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 19 Loss: 0.533 | Acc: 81.626 40813/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.01it/s]


[Test] 19 Loss: 1.219 | Acc: 64.390 6439/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 20 Loss: 0.530 | Acc: 81.740 40870/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.06it/s]


[Test] 20 Loss: 1.172 | Acc: 66.410 6641/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 21 Loss: 0.528 | Acc: 81.932 40966/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.10it/s]


[Test] 21 Loss: 0.812 | Acc: 73.590 7359/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 22 Loss: 0.524 | Acc: 81.836 40918/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.08it/s]


[Test] 22 Loss: 1.201 | Acc: 63.610 6361/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 23 Loss: 0.509 | Acc: 82.472 41236/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.08it/s]


[Test] 23 Loss: 0.749 | Acc: 75.250 7525/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 24 Loss: 0.506 | Acc: 82.630 41315/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.10it/s]


[Test] 24 Loss: 0.823 | Acc: 73.530 7353/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 25 Loss: 0.502 | Acc: 82.722 41361/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.07it/s]


[Test] 25 Loss: 0.770 | Acc: 74.050 7405/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 26 Loss: 0.497 | Acc: 83.006 41503/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.08it/s]


[Test] 26 Loss: 0.937 | Acc: 74.530 7453/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 27 Loss: 0.493 | Acc: 83.080 41540/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.07it/s]


[Test] 27 Loss: 1.275 | Acc: 65.190 6519/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 28 Loss: 0.646 | Acc: 77.820 38910/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.05it/s]


[Test] 28 Loss: 1.500 | Acc: 60.730 6073/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 29 Loss: 0.600 | Acc: 79.376 39688/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.05it/s]


[Test] 29 Loss: 1.089 | Acc: 64.630 6463/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 30 Loss: 0.575 | Acc: 80.078 40039/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.08it/s]


[Test] 30 Loss: 2.811 | Acc: 49.160 4916/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 31 Loss: 0.548 | Acc: 81.000 40500/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.10it/s]


[Test] 31 Loss: 0.883 | Acc: 73.280 7328/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 32 Loss: 0.576 | Acc: 80.282 40141/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.05it/s]


[Test] 32 Loss: 1.229 | Acc: 59.330 5933/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 33 Loss: 0.569 | Acc: 80.314 40157/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.04it/s]


[Test] 33 Loss: 1.137 | Acc: 60.790 6079/10000


100%|███████████████████████████████████████████████████| 391/391 [02:09<00:00,  3.03it/s]


[Train] 34 Loss: 0.537 | Acc: 81.634 40817/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.06it/s]


[Test] 34 Loss: 0.691 | Acc: 76.200 7620/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.17it/s]


[Train] 35 Loss: 0.556 | Acc: 80.894 40447/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.06it/s]


[Test] 35 Loss: 0.901 | Acc: 68.740 6874/10000


100%|███████████████████████████████████████████████████| 391/391 [02:03<00:00,  3.15it/s]


[Train] 36 Loss: 0.533 | Acc: 81.796 40898/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.06it/s]


[Test] 36 Loss: 1.091 | Acc: 63.640 6364/10000


 42%|█████████████████████▍                             | 164/391 [01:01<01:32,  2.46it/s]

In [None]:
# Best test accuracy (in percent) per distance key; an entry is only added
# after that key's full EPOCHS-long run has finished.
acc_dict

In [None]:
# net[0].centers.shape