In [36]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy
from tqdm import tqdm

In [37]:
import torch.optim as optim
from torch.utils import data

In [38]:
import nflib
from nflib.flows import SequentialFlow, NormalizingFlow, ActNorm, ActNorm2D, AffineConstantFlow
import nflib.coupling_flows as icf
import nflib.inn_flow as inn
import nflib.res_flow as irf

In [39]:
# Prefer the second GPU (the original choice), but fall back to the first GPU
# or the CPU so the notebook still runs on machines without two CUDA devices.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

### Datasets

In [40]:
# cifar_train = transforms.Compose([
#     transforms.RandomCrop(size=32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
#         std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
#     ),
# ])

# cifar_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
#         std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
#     ),
# ])

# train_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=True, download=True, transform=cifar_train)
# test_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=False, download=True, transform=cifar_test)

In [41]:
# Location of the CIFAR-100 files and the per-channel statistics used to
# standardise the images; both pipelines below share them.
_CIFAR100_ROOT = "../../../../../_Datasets/cifar100/"
_cifar100_stats = dict(
    mean=[0.5071, 0.4865, 0.4409],
    std=[0.2009, 0.1984, 0.2023],
)

# Training pipeline: padded random crop + horizontal flip, then tensor
# conversion and normalisation.
cifar_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(**_cifar100_stats),
])

# Evaluation pipeline: no augmentation, same normalisation.
cifar_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(**_cifar100_stats),
])

# Train split first, then test split (download happens on first use).
train_dataset, test_dataset = (
    datasets.CIFAR100(root=_CIFAR100_ROOT, train=is_train, download=True, transform=pipeline)
    for is_train, pipeline in ((True, cifar_train), (False, cifar_test))
)

Files already downloaded and verified
Files already downloaded and verified


In [42]:
# Both loaders use the same batch size and worker count; only the training
# loader shuffles.
_loader_kwargs = dict(batch_size=128, num_workers=2)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, **_loader_kwargs)

In [43]:
# Pull a single batch for shape inspection. The builtin next() is used because
# the Python-2 style `.next()` method is not available on DataLoader iterators
# in recent PyTorch releases.
xx, yy = next(iter(train_loader))

In [44]:
xx.shape

torch.Size([128, 3, 32, 32])

### Model

In [45]:
# Backbone definition: BatchNorm2d -> ConvResidualFlow blocks, separated by
# InvertiblePooling(2) layers. The BatchNorm channel counts grow
# 3 -> 12 -> 48 -> 192 across the pooling layers, and the final
# Flatten(img_size=(192, 4, 4)) yields 3072 features per image.
# The commented-out ActNorm / Linear entries are alternatives that are
# currently disabled in favour of BatchNorm.
actf = irf.Swish
flows = [
    # --- stage 1: 3 channels ---
#     ActNorm2D(3),
    nn.BatchNorm2d(3),
    irf.ConvResidualFlow(3, [32, 32], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    # --- stage 2: 12 channels ---
#     ActNorm2D(12),
    nn.BatchNorm2d(12),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
#     ActNorm2D(12),
    nn.BatchNorm2d(12),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    # --- stage 3: 48 channels ---
#     ActNorm2D(48),
    nn.BatchNorm2d(48),
    irf.ConvResidualFlow(48, [128, 128], kernels=5, activation=actf),
#     ActNorm2D(48),
    nn.BatchNorm2d(48),
    irf.ConvResidualFlow(48, [128, 128], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    # --- stage 4: 192 channels ---
#     ActNorm2D(192),
    nn.BatchNorm2d(192),
    irf.ConvResidualFlow(192, [256, 256], kernels=5, activation=actf),
#     ActNorm2D(192),
    nn.BatchNorm2d(192),
    irf.ConvResidualFlow(192, [256, 256], kernels=5, activation=actf),
    nn.BatchNorm2d(192),
    # --- head: flatten to a 3072-dimensional vector and normalise it ---
    irf.Flatten(img_size=(192, 4, 4)),
#     ActNorm(3072),
#     nn.BatchNorm1d(3072),
#     nn.Linear(3072, 3072, bias=False),
    nn.BatchNorm1d(3072),
        ]

# A plain nn.Sequential is used here; the SequentialFlow wrapper from nflib is
# the disabled alternative.
# backbone = SequentialFlow(flows)
backbone = nn.Sequential(*flows)

In [46]:
backbone.to(device)

Sequential(
  (0): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): ConvResidualFlow(
    (resblock): ModuleList(
      (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): Swish()
      (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (3): Swish()
      (4): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
  (2): InvertiblePooling()
  (3): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (4): ConvResidualFlow(
    (resblock): ModuleList(
      (0): Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): Swish()
      (2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (3): Swish()
      (4): Conv2d(64, 12, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
  (5): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ConvResidualFlow(
    (resblock)

In [47]:
# backbone(xx.to(device)).shape, 32*32*3

In [48]:
print("number of params: ", sum(p.numel() for p in backbone.parameters()))

number of params:  9947519


In [49]:
def get_children(module):
    """Collect the leaf sub-modules of ``module`` in depth-first order.

    A module that has no children is treated as a leaf and returned as a
    one-element list containing itself.
    """
    direct = list(module.children())
    if not direct:
        return [module]
    # Flatten the leaves found under each direct child, preserving order.
    return [leaf for sub in direct for leaf in get_children(sub)]

def remove_spectral_norm(model):
    """Strip spectral normalisation from every leaf of ``model`` that has a ``weight``.

    For each such leaf the module is printed, followed by "Success" when
    ``nn.utils.remove_spectral_norm`` succeeded or "Failed" when the layer had
    no spectral norm to remove. Leaves without a ``weight`` attribute are
    skipped silently. Returns ``None``.
    """
    for child in get_children(model):
        if not hasattr(child, 'weight'):
            continue
        print("Yes", child)
        try:
            nn.utils.remove_spectral_norm(child)
        except ValueError:
            # Raised when no spectral-norm hook is registered on the layer
            # (e.g. the BatchNorm layers). Catching only this keeps
            # KeyboardInterrupt and genuine errors visible.
            print("Failed")
        else:
            print("Success")
    return

In [50]:
remove_spectral_norm(backbone)

Yes BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Failed
Yes Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
Success
Yes Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
Success
Yes Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
Success
Yes BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Failed
Yes Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
Success
Yes Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
Success
Yes Conv2d(64, 12, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
Success
Yes BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Failed
Yes Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
Success
Yes Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
Success
Yes Conv2d(64, 12, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
Success
Yes BatchNorm2d(48, eps=1e-0

In [51]:
# Smoke test: push the first training batch through the backbone and print
# the input and output shapes, then stop.
# NOTE(review): this runs without torch.no_grad() and without switching the
# backbone to eval mode, so presumably the BatchNorm running statistics are
# touched by this batch — confirm that is acceptable.
for xx, yy in train_loader:
    tt = backbone(xx.to(device))
    print(xx.shape, tt.shape)
    break

torch.Size([128, 3, 32, 32]) torch.Size([128, 3072])


In [52]:
class ConnectedClassifier_Linear(nn.Module):
    """Classifier that softly routes each input to one of ``num_sets`` sets.

    A linear layer produces one logit per set; the softmax of those logits is
    the set confidence. The output is the confidence-weighted mix of the rows
    of ``cls_weight`` (one row of class scores per set). ``cls_weight`` is used
    as-is, so the returned scores are not normalised probabilities.
    """

    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim, self.output_dim, self.num_sets = input_dim, output_dim, num_sets
        # The set logits are multiplied by exp(inv_temp) in forward().
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)

        self.linear = nn.Linear(input_dim, num_sets)

        # Random class scores per set, with class (set index mod output_dim)
        # boosted to 5 so every set starts out favouring one class.
        set_scores = torch.randn(num_sets, output_dim)
        for row in range(num_sets):
            set_scores[row][row % output_dim] = 5
        self.cls_weight = nn.Parameter(set_scores)

        # Set confidences of the most recent forward pass.
        self.cls_confidence = None

    def forward(self, x, hard=False):
        """Return class scores for ``x``.

        With ``hard=True`` the logits are additionally scaled by 1e5 before
        the softmax, which pushes the confidence towards a one-hot assignment.
        The confidences are stored in ``self.cls_confidence``.
        """
        logits = self.linear(x)*torch.exp(self.inv_temp)
        if hard:
            logits = logits*1e5
        confidence = torch.softmax(logits, dim=1)
        self.cls_confidence = confidence
        return confidence@self.cls_weight

In [53]:
class ConnectedClassifier_SoftKMeans(nn.Module):
    """Classifier that softly assigns each input to its nearest learnable centre.

    Distances from the input to ``num_sets`` centres are turned into set
    confidences with a softmax over the negative (scaled) distances. The output
    is the confidence-weighted mix of the rows of ``cls_weight`` (one row of
    class scores per set). ``cls_weight`` is used as-is, so the returned scores
    are not normalised probabilities.
    """

    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        # The scaled distances are multiplied by exp(inv_temp) in forward().
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)

        # Centres start uniformly distributed in [-1, 1).
        self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)

        # Random class scores per set, with class (set index mod output_dim)
        # boosted to 5 so every set starts out favouring one class.
        init_val = torch.randn(num_sets, output_dim)
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 5
        self.cls_weight = nn.Parameter(init_val)

        # Set confidences of the most recent forward pass.
        self.cls_confidence = None

    def forward(self, x, hard=False):
        """Return class scores for ``x``.

        With ``hard=True`` the negative distances are additionally scaled by
        1e5 before the softmax, which pushes the confidence towards a one-hot
        assignment to the nearest centre. The confidences are stored in
        ``self.cls_confidence``.
        """
        dists = torch.cdist(x, self.centers)
        ### correction to make diagonal of unit square 1 in nD space
        dists = dists/np.sqrt(self.input_dim)
        dists = dists*torch.exp(self.inv_temp)
        if hard:
            x = torch.softmax(-dists*1e5, dim=1)
        else:
            x = torch.softmax(-dists, dim=1)
        self.cls_confidence = x
        return x@self.cls_weight

    def set_centroid_to_data_randomly(self, data_loader, model):
        """Initialise the centres from the first embedded samples of ``data_loader``.

        Batches are passed through ``model`` until at least as many samples as
        there are centres have been collected. The first ``num_centers``
        embeddings become the centres, and ``cls_weight`` is re-drawn at random
        with the score of each sample's label set to 5.

        Args:
            data_loader: iterable of ``(inputs, labels)`` batches.
            model: callable mapping an input batch to embeddings. Inputs are
                moved to the device of its first parameter, or to the device
                of ``self.centers`` when it has no parameters.

        Raises:
            ValueError: if ``data_loader`` yields fewer samples than centres.
        """
        num_centers = self.centers.shape[0]

        # Use the model's own device instead of a notebook-level global.
        try:
            model_device = next(model.parameters()).device
        except (StopIteration, AttributeError):
            model_device = self.centers.device

        xxs, yys = [], []
        count = 0
        # No gradients are needed for initialisation; avoid building graphs.
        with torch.no_grad():
            for xx, yy in data_loader:
                xxs.append(model(xx.to(model_device)).detach().cpu())
                yys.append(yy)
                count += len(xx)
                if count >= num_centers:
                    break

        if count < num_centers:
            # Otherwise the centres would silently shrink and the label loop
            # below would index past the collected samples.
            raise ValueError(
                f"data_loader provided {count} samples but {num_centers} centers are required"
            )

        yout = torch.cat(xxs, dim=0)[:num_centers]
        yy = torch.cat(yys, dim=0)[:num_centers]

        self.centers.data = yout.to(device=self.centers.device, dtype=self.centers.dtype)

        init_val = torch.randn(self.num_sets, self.output_dim)
        for ns in range(num_centers):
            init_val[ns, yy[ns]] = 5.
        self.cls_weight.data = init_val.to(self.cls_weight.device)

In [54]:
# Classifier head selection. Exactly one assignment should be active; the
# constructor arguments are (input_dim, num_sets, output_dim).

#### for cifar 10
# classifier = ConnectedClassifier_SoftKMeans(3072, 100, 10)
# classifier = ConnectedClassifier_Linear(3072, 100, 10)

#### for cifar 100
# classifier = ConnectedClassifier_SoftKMeans(3072, 500, 100, inv_temp=0.8)
# Active choice: 500 sets over the 3072-dim backbone output, 100 classes.
# inv_temp=0 means the set logits start with a scale of exp(0) = 1.
classifier = ConnectedClassifier_Linear(3072, 500, 100, inv_temp=0)
# classifier = ConnectedClassifier_Linear(3072, 500, 100, )

#### for MLP based classification
# classifier = nn.Sequential(nn.Linear(3072, 500), nn.SELU(), nn.Linear(500, 100))

In [55]:
classifier = classifier.to(device)

In [56]:
print("number of params: ", sum(p.numel() for p in backbone.parameters()))
print("number of params: ", sum(p.numel() for p in classifier.parameters()))

number of params:  9947519
number of params:  1586501


In [57]:
### debug linear classifier
yout = classifier(torch.randn(10, 3072).to(device))

In [58]:
i = 0
yout[i].sort()

torch.return_types.sort(
values=tensor([-5.1136e-02, -3.9579e-02, -3.8149e-02, -3.8140e-02, -3.6784e-02,
        -3.5998e-02, -3.3601e-02, -3.0321e-02, -2.8377e-02, -2.2965e-02,
        -1.9536e-02, -1.7964e-02, -1.7334e-02, -1.4202e-02,  3.8575e-05,
         1.0551e-03,  1.8266e-03,  3.8670e-03,  7.4088e-03,  7.6089e-03,
         7.9190e-03,  7.9490e-03,  9.0643e-03,  9.1031e-03,  1.0729e-02,
         1.6859e-02,  1.7508e-02,  1.7537e-02,  1.8479e-02,  1.8817e-02,
         2.5470e-02,  2.6183e-02,  2.6191e-02,  2.6932e-02,  2.7852e-02,
         3.1330e-02,  3.1937e-02,  3.4138e-02,  3.5556e-02,  3.5566e-02,
         3.7072e-02,  4.0023e-02,  4.0107e-02,  4.1044e-02,  4.4740e-02,
         4.7292e-02,  4.7450e-02,  4.9267e-02,  5.1015e-02,  5.1817e-02,
         5.2204e-02,  5.2969e-02,  5.3234e-02,  5.4429e-02,  5.6425e-02,
         5.7116e-02,  5.7443e-02,  5.8323e-02,  6.1053e-02,  6.1151e-02,
         6.1378e-02,  6.1871e-02,  6.2040e-02,  6.2043e-02,  6.2455e-02,
         6.2472e-02

In [59]:
classifier.cls_confidence[i].sort()

torch.return_types.sort(
values=tensor([0.0002, 0.0003, 0.0003, 0.0003, 0.0003, 0.0004, 0.0004, 0.0004, 0.0004,
        0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0006,
        0.0006, 0.0006, 0.0006, 0.0006, 0.0006, 0.0006, 0.0006, 0.0006, 0.0006,
        0.0006, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007,
        0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0008, 0.0008,
        0.0008, 0.0008, 0.0008, 0.0008, 0.0008, 0.0008, 0.0008, 0.0008, 0.0008,
        0.0008, 0.0008, 0.0008, 0.0008, 0.0008, 0.0008, 0.0008, 0.0008, 0.0008,
        0.0008, 0.0008, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009,
        0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009,
        0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        

In [60]:
# classifier.set_centroid_to_data_randomly(train_loader, backbone)

In [61]:
model = nn.Sequential(backbone, classifier).to(device)

In [62]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  11534020


## Training

In [63]:
 ## debugging to find the good classifier/output distribution.
# model_name = 'c100_inv_v1'## using linear+500 units
# model_name = 'c100_inv_v2' ## using dists+500 units
# model_name = 'c100_inv_v3' ## using dists+3072 units
# model_name = 'c100_inv_v4' ## using linear+3072+unnormalized output units
# model_name = 'c100_inv_v5' ## using dists+500+unnormalized output units

# model_name = 'c100_inv_v6' ## using linear+500+unnormalized output units
model_name = 'c100_ord_v6' ## using linear+500+unnormalized output units

In [64]:
# model_name = 'c10_inv_v0'
# model_name = 'c10_ord_v0'
# model_name = 'c100_inv_v0'
# model_name = 'c100_ord_v0'

In [65]:
# Training configuration.
EPOCHS = 200  # number of epochs run by the training loop; also the cosine schedule length
criterion = nn.CrossEntropyLoss()
# Disabled alternative: SGD with momentum and weight decay.
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
#                       momentum=0.9, weight_decay=5e-4)
# Adam over all parameters of the combined model (backbone + classifier).
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) ## lr=0.0001
# Cosine learning-rate annealing over EPOCHS scheduler steps (stepped once per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [66]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``device``. Prints the mean batch loss and the accuracy for ``epoch``;
    returns ``None``.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0
    # Counting from 1 makes `n_batches` the number of batches processed so far.
    for n_batches, (inputs, targets) in enumerate(tqdm(train_loader), start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Running statistics for the end-of-epoch report.
        loss_sum += loss.item()
        predicted = outputs.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

    print(f"[Train] {epoch} Loss: {loss_sum/n_batches:.3f} | Acc: {100.*n_correct/n_seen:.3f} {n_correct}/{n_seen}")
    return

In [67]:
# Best test accuracy seen so far; any real accuracy beats the initial -1.
best_acc = -1
def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and checkpoint it on improvement.

    Prints the mean batch loss and the accuracy for ``epoch``. When the
    accuracy exceeds the module-level ``best_acc``, the model weights, the
    accuracy and the epoch are written to ``./models/{model_name}.pth`` and
    ``best_acc`` is updated.
    """
    global best_acc
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    with torch.no_grad():
        # Counting from 1 makes `n_batches` the number of batches processed so far.
        for n_batches, (inputs, targets) in enumerate(tqdm(test_loader), start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)

            loss_sum += criterion(outputs, targets).item()
            predicted = outputs.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

    print(f"[Test] {epoch} Loss: {loss_sum/n_batches:.3f} | Acc: {100.*n_correct/n_seen:.3f} {n_correct}/{n_seen}")

    # Save checkpoint only when this epoch is a new best.
    acc = 100.*n_correct/n_seen
    if not acc > best_acc:
        return
    print('Saving..')
    os.makedirs('models', exist_ok=True)
    state = {
        'model': model.state_dict(),
        'acc': acc,
        'epoch': epoch,
    }
    torch.save(state, f'./models/{model_name}.pth')
    best_acc = acc

In [68]:
start_epoch = 0  # first epoch index to run; moved past the checkpoint's epoch when resuming
resume = False   # set to True to continue from ./models/{model_name}.pth

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    # map_location lets a checkpoint saved on another device (e.g. a GPU that
    # is not present here) be loaded onto the device used by this session.
    checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    # The checkpoint is written after its epoch has been trained and tested,
    # so training continues with the following epoch instead of repeating it.
    start_epoch = checkpoint['epoch'] + 1

In [69]:
### Train the whole damn thing

for epoch in range(start_epoch, start_epoch+EPOCHS): ## for 200 epochs
    train(epoch)
    test(epoch)
    scheduler.step()

100%|███████████████████████████████████████████████████| 391/391 [00:38<00:00, 10.20it/s]


[Train] 0 Loss: 4.490 | Acc: 4.420 2210/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.83it/s]


[Test] 0 Loss: 4.363 | Acc: 7.490 749/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:38<00:00, 10.17it/s]


[Train] 1 Loss: 4.342 | Acc: 7.664 3832/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.84it/s]


[Test] 1 Loss: 4.261 | Acc: 9.720 972/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:38<00:00, 10.05it/s]


[Train] 2 Loss: 4.267 | Acc: 9.412 4706/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.65it/s]


[Test] 2 Loss: 4.198 | Acc: 11.460 1146/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00, 10.02it/s]


[Train] 3 Loss: 4.210 | Acc: 10.912 5456/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.58it/s]


[Test] 3 Loss: 4.153 | Acc: 12.590 1259/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00, 10.00it/s]


[Train] 4 Loss: 4.163 | Acc: 12.230 6115/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.59it/s]


[Test] 4 Loss: 4.104 | Acc: 13.910 1391/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.99it/s]


[Train] 5 Loss: 4.125 | Acc: 13.208 6604/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.44it/s]


[Test] 5 Loss: 4.071 | Acc: 14.720 1472/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 6 Loss: 4.095 | Acc: 14.020 7010/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.48it/s]


[Test] 6 Loss: 4.043 | Acc: 15.480 1548/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 7 Loss: 4.059 | Acc: 14.656 7328/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.44it/s]


[Test] 7 Loss: 4.015 | Acc: 15.800 1580/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.94it/s]


[Train] 8 Loss: 4.037 | Acc: 15.100 7550/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.46it/s]


[Test] 8 Loss: 3.998 | Acc: 16.440 1644/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 9 Loss: 4.008 | Acc: 16.026 8013/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.50it/s]


[Test] 9 Loss: 3.974 | Acc: 16.970 1697/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.95it/s]


[Train] 10 Loss: 3.988 | Acc: 16.310 8155/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.41it/s]


[Test] 10 Loss: 3.959 | Acc: 17.310 1731/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.97it/s]


[Train] 11 Loss: 3.965 | Acc: 16.696 8348/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.43it/s]


[Test] 11 Loss: 3.938 | Acc: 17.860 1786/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.95it/s]


[Train] 12 Loss: 3.942 | Acc: 17.144 8572/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.42it/s]


[Test] 12 Loss: 3.922 | Acc: 18.150 1815/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.97it/s]


[Train] 13 Loss: 3.928 | Acc: 17.462 8731/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.47it/s]


[Test] 13 Loss: 3.907 | Acc: 18.370 1837/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.95it/s]


[Train] 14 Loss: 3.902 | Acc: 17.974 8987/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.49it/s]


[Test] 14 Loss: 3.897 | Acc: 18.750 1875/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.97it/s]


[Train] 15 Loss: 3.885 | Acc: 18.348 9174/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.44it/s]


[Test] 15 Loss: 3.885 | Acc: 18.760 1876/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 16 Loss: 3.873 | Acc: 18.526 9263/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.49it/s]


[Test] 16 Loss: 3.881 | Acc: 18.840 1884/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 17 Loss: 3.851 | Acc: 19.182 9591/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.53it/s]


[Test] 17 Loss: 3.867 | Acc: 19.100 1910/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.95it/s]


[Train] 18 Loss: 3.836 | Acc: 19.338 9669/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.44it/s]


[Test] 18 Loss: 3.851 | Acc: 19.540 1954/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.97it/s]


[Train] 19 Loss: 3.816 | Acc: 19.834 9917/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.44it/s]


[Test] 19 Loss: 3.831 | Acc: 20.020 2002/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 20 Loss: 3.799 | Acc: 20.184 10092/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.44it/s]


[Test] 20 Loss: 3.827 | Acc: 19.890 1989/10000


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 21 Loss: 3.782 | Acc: 20.418 10209/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.40it/s]


[Test] 21 Loss: 3.820 | Acc: 20.030 2003/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.95it/s]


[Train] 22 Loss: 3.767 | Acc: 20.886 10443/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.25it/s]


[Test] 22 Loss: 3.815 | Acc: 19.910 1991/10000


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.97it/s]


[Train] 23 Loss: 3.759 | Acc: 21.020 10510/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.47it/s]


[Test] 23 Loss: 3.796 | Acc: 20.480 2048/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.95it/s]


[Train] 24 Loss: 3.738 | Acc: 21.296 10648/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.43it/s]


[Test] 24 Loss: 3.788 | Acc: 20.390 2039/10000


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 25 Loss: 3.731 | Acc: 21.532 10766/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.51it/s]


[Test] 25 Loss: 3.782 | Acc: 20.630 2063/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.95it/s]


[Train] 26 Loss: 3.715 | Acc: 21.712 10856/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.40it/s]


[Test] 26 Loss: 3.772 | Acc: 20.930 2093/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.97it/s]


[Train] 27 Loss: 3.694 | Acc: 22.302 11151/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.45it/s]


[Test] 27 Loss: 3.762 | Acc: 20.840 2084/10000


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.95it/s]


[Train] 28 Loss: 3.681 | Acc: 22.468 11234/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.35it/s]


[Test] 28 Loss: 3.751 | Acc: 21.120 2112/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.97it/s]


[Train] 29 Loss: 3.668 | Acc: 22.826 11413/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.37it/s]


[Test] 29 Loss: 3.747 | Acc: 21.270 2127/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.95it/s]


[Train] 30 Loss: 3.659 | Acc: 22.982 11491/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.32it/s]


[Test] 30 Loss: 3.743 | Acc: 21.260 2126/10000


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 31 Loss: 3.645 | Acc: 23.376 11688/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.50it/s]


[Test] 31 Loss: 3.737 | Acc: 21.080 2108/10000


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.95it/s]


[Train] 32 Loss: 3.633 | Acc: 23.282 11641/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.44it/s]


[Test] 32 Loss: 3.732 | Acc: 21.840 2184/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.95it/s]


[Train] 33 Loss: 3.617 | Acc: 23.778 11889/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.45it/s]


[Test] 33 Loss: 3.726 | Acc: 21.500 2150/10000


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 34 Loss: 3.601 | Acc: 23.996 11998/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.42it/s]


[Test] 34 Loss: 3.714 | Acc: 21.840 2184/10000


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 35 Loss: 3.593 | Acc: 24.250 12125/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.42it/s]


[Test] 35 Loss: 3.713 | Acc: 22.130 2213/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.94it/s]


[Train] 36 Loss: 3.577 | Acc: 24.476 12238/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.41it/s]


[Test] 36 Loss: 3.689 | Acc: 22.570 2257/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 37 Loss: 3.562 | Acc: 24.966 12483/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.50it/s]


[Test] 37 Loss: 3.694 | Acc: 22.350 2235/10000


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.95it/s]


[Train] 38 Loss: 3.554 | Acc: 24.986 12493/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.55it/s]


[Test] 38 Loss: 3.688 | Acc: 22.700 2270/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 39 Loss: 3.545 | Acc: 25.160 12580/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.51it/s]


[Test] 39 Loss: 3.694 | Acc: 21.970 2197/10000


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.94it/s]


[Train] 40 Loss: 3.529 | Acc: 25.454 12727/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.40it/s]


[Test] 40 Loss: 3.685 | Acc: 22.520 2252/10000


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 41 Loss: 3.525 | Acc: 25.618 12809/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.41it/s]


[Test] 41 Loss: 3.682 | Acc: 22.670 2267/10000


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 42 Loss: 3.505 | Acc: 25.918 12959/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.44it/s]


[Test] 42 Loss: 3.664 | Acc: 23.190 2319/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:39<00:00,  9.96it/s]


[Train] 43 Loss: 3.495 | Acc: 26.254 13127/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 27.45it/s]


[Test] 43 Loss: 3.667 | Acc: 22.730 2273/10000


 92%|██████████████████████████████████████████████▉    | 360/391 [00:36<00:03,  9.93it/s]


KeyboardInterrupt: 

In [159]:
best_acc

57.25

In [157]:
torch.count_nonzero(a)

tensor(810)

In [None]:
######## C10

#### non-inv nn with MLP classifier: 85.73 Acc ; using no-spectral init
#### non-inv nn with connected classifier:   Acc ; using no-spectral init
#### inv nn with connected classifier: 84.25 Acc ; spectral normalized

In [33]:
######## C100
#### inv nn with MLP classifier: 59.3 Acc ;
#### non-inv nn with MLP classifier: 54.67 Acc ;

#### inv nn + ConnectedDist: 30.03 Acc; -> v2
#### inv nn + ConnectedLin: 48.79 Acc; -> v1

#### inv nn + ConnectedLin-3072: 46.82 Acc; v3
#### inv nn + ConnectedLin-3072-unnormalized: 55.54 Acc / 54.97 (Hard); v4

#### inv nn + ConnectedLin-500-unnormalized: 57.25 Acc / 56.83 (Hard); v6

In [160]:
classifier.inv_temp

Parameter containing:
tensor([0.3377], device='cuda:1', requires_grad=True)

In [161]:
# Reload the best checkpoint written by test() and show its accuracy and the
# epoch at which it was saved. map_location keeps the load working when the
# checkpoint was saved on a device that is not available in this session.
checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

best_acc, start_epoch

(57.25, 160)

In [163]:
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [None]:
backbone, classifier = model[0], model[1]

### Hard test accuracy with count per classifier

In [165]:
# Hard-assignment evaluation on the test set.
# The classifier is called with hard=True (logits scaled by 1e5 before the
# softmax), and set_count tallies how many samples land in each set.
test_count = 0   # samples seen
test_acc = 0     # correctly classified samples (accumulated as a float)
set_count = torch.zeros(classifier.num_sets).to(device)
model.eval()
for xx, yy in tqdm(test_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(backbone(xx), hard=True)
        # Winning set per sample, reduced to (set index, occurrences) pairs.
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Test Acc:{float(test_acc)/test_count*100:.2f}%')
# Per-set sample counts, in set-index order.
print(set_count.type(torch.long).tolist())

100%|██████████| 79/79 [00:02<00:00, 26.92it/s]

Hard Test Acc:56.83%
[98, 0, 1, 20, 77, 0, 3, 2, 0, 1, 48, 12, 79, 0, 0, 94, 14, 0, 4, 10, 0, 21, 1, 98, 94, 13, 72, 56, 72, 15, 0, 0, 0, 1, 2, 7, 86, 0, 0, 0, 4, 0, 1, 0, 3, 67, 4, 42, 1, 0, 1, 82, 0, 107, 0, 31, 0, 0, 90, 102, 1, 1, 6, 0, 10, 43, 4, 0, 105, 2, 86, 102, 21, 0, 0, 0, 1, 0, 25, 80, 1, 1, 2, 8, 61, 0, 0, 0, 3, 101, 1, 8, 82, 49, 1, 0, 0, 6, 1, 66, 6, 2, 1, 1, 20, 1, 0, 0, 66, 77, 3, 2, 1, 83, 0, 3, 2, 0, 8, 0, 2, 0, 57, 8, 0, 24, 3, 0, 4, 56, 0, 3, 5, 66, 81, 0, 8, 92, 102, 16, 5, 1, 10, 0, 50, 0, 0, 48, 110, 0, 4, 5, 0, 3, 0, 26, 0, 6, 4, 12, 0, 13, 90, 54, 0, 1, 48, 3, 0, 0, 0, 0, 28, 3, 14, 0, 108, 1, 2, 0, 82, 9, 0, 1, 1, 6, 4, 0, 0, 12, 98, 45, 2, 3, 0, 38, 0, 63, 3, 1, 0, 1, 1, 9, 5, 87, 0, 0, 0, 1, 29, 3, 0, 9, 8, 1, 75, 0, 0, 0, 11, 0, 2, 0, 6, 3, 0, 1, 20, 0, 80, 3, 56, 6, 18, 8, 0, 3, 0, 1, 4, 84, 0, 0, 4, 9, 0, 7, 1, 85, 86, 7, 95, 0, 2, 1, 97, 4, 1, 0, 94, 0, 0, 14, 3, 72, 2, 0, 0, 94, 3, 0, 41, 10, 73, 102, 18, 0, 9, 2, 20, 1, 0, 93, 25, 0, 5, 42, 74, 1, 1, 




In [None]:
#### Classifiers that enclose any data
torch.count_nonzero(set_count)

### Hard train accuracy with count per classifier

In [34]:
# Hard-assignment evaluation on the training set.
# The classifier is called with hard=True (logits scaled by 1e5 before the
# softmax), and set_count tallies how many samples land in each set.
# NOTE(review): train_loader applies the random training augmentations, so
# the reported accuracy is measured on augmented images.
test_count = 0   # samples seen
test_acc = 0     # correctly classified samples (accumulated as a float)
set_count = torch.zeros(classifier.num_sets).to(device)
# Switch to eval mode explicitly instead of relying on an earlier cell having
# done so; otherwise the BatchNorm layers would run in training mode here.
model.eval()
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(backbone(xx), hard=True)
        # Winning set per sample, reduced to (set index, occurrences) pairs.
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True)
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
# Per-set sample counts, in set-index order.
print(set_count.type(torch.long).tolist())

100%|██████████| 391/391 [00:29<00:00, 13.04it/s]

Hard Train Acc:91.31%
[0, 0, 0, 0, 4452, 25, 4597, 0, 4948, 0, 0, 0, 0, 162, 0, 0, 0, 0, 0, 4909, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5013, 4926, 0, 56, 0, 5022, 4125, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 25, 0, 5073, 38, 4, 158, 0, 0, 0, 0, 0, 60, 9, 84, 0, 0, 119, 4967, 0, 25, 0, 0, 0, 0, 0, 0, 0, 0, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 284, 0, 0, 0, 0, 0, 888, 0]





In [35]:
#### Classifiers that enclose any data
torch.count_nonzero(set_count)

tensor(26, device='cuda:1')

In [36]:
#### classifier with class representation
torch.argmax(classifier.cls_weight, dim=1)

tensor([6, 3, 7, 9, 5, 0, 2, 3, 3, 3, 7, 7, 5, 3, 5, 6, 3, 1, 8, 7, 3, 1, 5, 2,
        3, 3, 3, 1, 6, 2, 1, 1, 5, 6, 9, 9, 2, 3, 4, 8, 3, 9, 8, 3, 5, 9, 9, 8,
        3, 2, 2, 1, 8, 8, 8, 0, 6, 7, 2, 3, 5, 7, 0, 9, 1, 0, 6, 6, 5, 4, 1, 1,
        9, 8, 2, 5, 1, 1, 7, 7, 8, 9, 2, 2, 9, 0, 9, 5, 4, 3, 3, 1, 5, 1, 5, 4,
        5, 9, 8, 5], device='cuda:1')