In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy
from tqdm import tqdm

In [247]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (slowly) on a machine without a usable GPU instead of failing at the
# first `.to(device)` call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [248]:
import torch.optim as optim
from torch.utils import data

In [249]:
import nflib
from nflib.flows import SequentialFlow, NormalizingFlow, ActNorm, ActNorm2D, AffineConstantFlow
import nflib.coupling_flows as icf
import nflib.inn_flow as inn
import nflib.res_flow as irf

### Datasets

In [250]:
# cifar_train = transforms.Compose([
#     transforms.RandomCrop(size=32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
#         std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
#     ),
# ])

# cifar_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
#         std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
#     ),
# ])

# train_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=True, download=True, transform=cifar_train)
# test_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=False, download=True, transform=cifar_test)

In [343]:
# Normalisation constants (the cifar100 values) and the dataset location are
# named once so the train and test pipelines share exactly the same settings.
CIFAR100_MEAN = [0.5071, 0.4865, 0.4409]
CIFAR100_STD = [0.2009, 0.1984, 0.2023]
CIFAR100_ROOT = "../../../../../_Datasets/cifar100/"


def _cifar100_pipeline(augment):
    """Return ToTensor + Normalize, preceded by random crop/flip when `augment` is true."""
    steps = []
    if augment:
        steps.append(transforms.RandomCrop(size=32, padding=4))
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD))
    return transforms.Compose(steps)


# Augmentation is applied to the training split only.
cifar_train = _cifar100_pipeline(augment=True)
cifar_test = _cifar100_pipeline(augment=False)

train_dataset = datasets.CIFAR100(root=CIFAR100_ROOT, train=True, download=True, transform=cifar_train)
test_dataset = datasets.CIFAR100(root=CIFAR100_ROOT, train=False, download=True, transform=cifar_test)

Files already downloaded and verified
Files already downloaded and verified


In [344]:
# Same batch size and worker count for both splits; only the training split
# is reshuffled on every pass.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=128, shuffle=reshuffle, num_workers=2)
    for split, reshuffle in ((train_dataset, True), (test_dataset, False))
)

In [345]:
# Pull a single batch to sanity-check tensor shapes. The built-in next() is
# used because iterator objects are only required to implement __next__; the
# Python-2 style `.next()` method is not available on current DataLoader iterators.
xx, yy = next(iter(train_loader))

In [346]:
# Shape of the image batch sampled from train_loader.
xx.shape

torch.Size([128, 3, 32, 32])

### Model

In [407]:
actf = irf.Swish


def _norm_res(channels, hidden):
    """Return [BatchNorm2d, 5x5 ConvResidualFlow] operating on `channels` feature maps."""
    return [
        nn.BatchNorm2d(channels),
        irf.ConvResidualFlow(channels, hidden, kernels=5, activation=actf),
    ]


# Each InvertiblePooling(2) is followed by stages with 4x the channel count
# (3 -> 12 -> 48 -> 192). BatchNorm is used as the normalisation layer here;
# ActNorm2D / SequentialFlow are imported but not used in this configuration.
flows = [
    *_norm_res(3, [32, 32]),
    irf.InvertiblePooling(2),
    *_norm_res(12, [64, 64]),
    *_norm_res(12, [64, 64]),
    irf.InvertiblePooling(2),
    *_norm_res(48, [128, 128]),
    *_norm_res(48, [128, 128]),
    irf.InvertiblePooling(2),
    *_norm_res(192, [256, 256]),
    *_norm_res(192, [256, 256]),
    nn.BatchNorm2d(192),
    # Flatten the 192 x 4 x 4 feature map into a 3072-dim vector.
    irf.Flatten(img_size=(192, 4, 4)),
    nn.BatchNorm1d(3072),
]

backbone = nn.Sequential(*flows)

In [408]:
# Move the backbone's parameters and buffers to the selected device (nn.Module.to
# works in place and returns the module, so the cell also displays its structure).
backbone.to(device)

Sequential(
  (0): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): ConvResidualFlow(
    (resblock): ModuleList(
      (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): Swish()
      (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (3): Swish()
      (4): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
  (2): InvertiblePooling()
  (3): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (4): ConvResidualFlow(
    (resblock): ModuleList(
      (0): Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): Swish()
      (2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (3): Swish()
      (4): Conv2d(64, 12, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
  (5): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ConvResidualFlow(
    (resblock)

In [409]:
# Compare the backbone's output shape with the flattened input size 32*32*3.
backbone(xx.to(device)).shape, 32*32*3

(torch.Size([128, 3072]), 3072)

In [410]:
# Total number of scalar parameters in the backbone.
print("number of params: ", sum(p.numel() for p in backbone.parameters()))

number of params:  9947519


In [467]:
def get_children(module):
    """Return the leaf sub-modules of `module` in depth-first order.

    A module without children is itself a leaf and is returned as a
    one-element list.
    """
    child = list(module.children())
    if len(child) == 0:
        return [module]
    children = []
    for ch in child:
        children += get_children(ch)
    return children


def remove_spectral_norm(model):
    """Strip spectral normalisation from every leaf layer of `model` that has a `weight`.

    For each candidate layer a "Yes <layer>" line is printed, followed by
    "Success" when the re-parametrisation was removed or "Failed" when the
    layer carried no spectral norm. The model is modified in place.

    Returns:
        None.
    """
    for child in get_children(model):
        if hasattr(child, 'weight'):
            print("Yes", child)
            try:
                nn.utils.remove_spectral_norm(child)
                print("Success")
            except ValueError:
                # torch raises ValueError when no spectral norm is registered
                # on `weight`. Anything else (KeyboardInterrupt, real bugs)
                # is deliberately allowed to propagate.
                print("Failed")
    return

In [468]:
# remove_spectral_norm(backbone)

In [469]:
# Smoke test: push the first training batch through the backbone, print the
# input and output shapes, then stop after that single batch.
for xx, yy in train_loader:
    tt = backbone(xx.to(device))
    print(xx.shape, tt.shape)
    break

torch.Size([128, 3, 32, 32]) torch.Size([128, 3072])


In [470]:
class ConnectedClassifier_Linear(nn.Module):
    """Classifier head that softly routes inputs to `num_sets` linear gates.

    A linear layer scores every input against `num_sets` gates; a softmax over
    those scores gives a per-input confidence vector, and the output is that
    vector multiplied by a learnable `(num_sets, output_dim)` table of class
    scores. The class scores are raw (not normalised), so the output is a set
    of logits rather than a probability distribution.

    Args:
        input_dim: size of each input feature vector.
        num_sets: number of gates.
        output_dim: number of classes.
        inv_temp: initial value of the learnable log-scale applied to the gate
            scores (the scores are multiplied by `exp(inv_temp)`).

    Attributes:
        cls_confidence: gate confidences from the most recent forward pass
            (`None` before the first call).
    """

    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim, self.output_dim, self.num_sets = input_dim, output_dim, num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)

        self.linear = nn.Linear(input_dim, num_sets)

        # Random class scores, except that gate `k` starts with a strong
        # preference (score 5) for class `k mod output_dim`.
        class_scores = torch.randn(num_sets, output_dim)
        gate_ids = torch.arange(num_sets)
        class_scores[gate_ids, gate_ids % output_dim] = 5
        self.cls_weight = nn.Parameter(class_scores)

        self.cls_confidence = None

    def forward(self, x, hard=False):
        """Return class scores for `x`; `hard=True` sharpens the gate softmax by a factor of 1e5."""
        gate_logits = self.linear(x)*torch.exp(self.inv_temp)
        if hard:
            gate_logits = gate_logits*1e5
        # Kept on the module so callers can inspect the gate assignment afterwards.
        self.cls_confidence = torch.softmax(gate_logits, dim=1)
        return self.cls_confidence@self.cls_weight

In [471]:
class ConnectedClassifier_SoftKMeans(nn.Module):
    """Soft k-means style classifier head.

    Each of the `num_sets` learnable centres owns one row of class scores in
    `cls_weight`. An input is softly assigned to the centres by a softmax over
    the negative, layer-normalised Euclidean distances, and the output is the
    assignment-weighted mix of the class-score rows. The class scores are raw
    (not normalised), so the output is a set of logits.

    Args:
        input_dim: size of each input feature vector.
        num_sets: number of centres.
        output_dim: number of classes.
        inv_temp: initial value of the learnable log-scale applied to the
            normalised distances (they are multiplied by `exp(inv_temp)`).

    Attributes:
        cls_confidence: soft assignments from the most recent forward pass
            (`None` before the first call).
    """

    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)

        # Centres start uniformly in [-1, 1); see set_centroid_to_data_randomly
        # for a data-driven initialisation.
        self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)

        # Random class scores, except that centre `k` starts with a strong
        # preference (score 5) for class `k mod output_dim`.
        init_val = torch.randn(num_sets, output_dim)
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 5
        self.cls_weight = nn.Parameter(init_val)

        self.cls_confidence = None
        self.ln = nn.LayerNorm(self.num_sets, elementwise_affine=False)

    def forward(self, x, hard=False):
        """Return class scores for `x`; `hard=True` sharpens the assignment softmax by a factor of 1e5."""
        dists = torch.cdist(x, self.centers)
        ### correction to make diagonal of unit square 1 in nD space
        dists = dists/np.sqrt(self.input_dim)

        # Standardise each sample's distances across centres, then apply the
        # learnable scale.
        dists = self.ln(dists)
        dists = dists*torch.exp(self.inv_temp)

        if hard:
            x = torch.softmax(-dists*1e5, dim=1)
        else:
            x = torch.softmax(-dists, dim=1)
        # Kept on the module so callers can inspect the assignment afterwards.
        self.cls_confidence = x
        return x@self.cls_weight

    def set_centroid_to_data_randomly(self, data_loader, model):
        """Initialise the centres from the first `num_sets` samples of `data_loader`.

        Every centre is set to `model(x)` for one sample, and the matching row
        of `cls_weight` is re-drawn at random with the sample's label boosted
        to 5. Which samples are picked depends only on the loader's iteration
        order. `model` is called in whatever train/eval mode it is currently in.

        Args:
            data_loader: iterable of `(inputs, labels)` batches; labels must be
                integer class indices.
            model: callable mapping a batch of inputs to feature vectors with
                the same shape per sample as a centre.

        Raises:
            ValueError: if the loader yields fewer than `num_sets` samples. The
                module is left untouched in that case.
        """
        num_centers = self.centers.shape[0]

        # Feed the model on the device its own weights live on; a model with no
        # parameters (or a plain callable) is fed on this module's device.
        params = model.parameters() if isinstance(model, nn.Module) else iter(())
        first_param = next(params, None)
        model_device = self.centers.device if first_param is None else first_param.device

        feats, labels = [], []
        count = 0
        # No autograd graph is needed for initialisation.
        with torch.no_grad():
            for xx, yy in data_loader:
                feats.append(model(xx.to(model_device)).cpu())
                labels.append(yy)
                count += len(xx)
                if count >= num_centers:
                    break

        if count < num_centers:
            raise ValueError(
                f"data_loader yielded only {count} samples but {num_centers} centres need initialising"
            )

        yout = torch.cat(feats, dim=0)[:num_centers]
        yy = torch.cat(labels, dim=0)[:num_centers].long().cpu()

        init_val = torch.randn(self.num_sets, self.output_dim)
        init_val[torch.arange(num_centers), yy] = 5.

        # copy_ keeps the existing Parameter objects (and their device), and
        # rejects feature tensors whose shape does not match the centres.
        with torch.no_grad():
            self.centers.copy_(yout)
            self.cls_weight.copy_(init_val)

In [479]:
#### for cifar 10
# classifier = ConnectedClassifier_SoftKMeans(3072, 100, 10)
# classifier = ConnectedClassifier_Linear(3072, 100, 10)

#### for cifar 100
classifier = ConnectedClassifier_SoftKMeans(3072, 500, 100, inv_temp=0)
# classifier = ConnectedClassifier_Linear(3072, 500, 100, inv_temp=0)
# classifier = ConnectedClassifier_Linear(3072, 500, 100, )

#### for MLP based classification
# classifier = nn.Sequential(nn.Linear(3072, 3072), nn.SELU(), nn.Linear(3072, 100))

In [480]:
# Move the classifier head onto the same device as the backbone.
classifier = classifier.to(device)

In [481]:
# Data-driven initialisation: centres become backbone features of training
# samples. Only valid for the SoftKMeans head, which defines this method.
classifier.set_centroid_to_data_randomly(train_loader, backbone)

In [482]:
# Parameter counts: first line is the backbone, second line is the classifier head.
print("number of params: ", sum(p.numel() for p in backbone.parameters()))
print("number of params: ", sum(p.numel() for p in classifier.parameters()))

number of params:  9947519
number of params:  1586001


In [483]:
### Debug: run the classifier head on 10 random 3072-dim feature vectors
### (no backbone involved) so its outputs can be inspected below.
yout = classifier(torch.randn(10, 3072).to(device))

In [484]:
# Sorted class scores of debug sample `i`.
i = 0
yout[i].sort()

torch.return_types.sort(
values=tensor([-0.0832, -0.0626, -0.0566, -0.0523, -0.0478, -0.0411, -0.0366, -0.0301,
        -0.0286, -0.0270, -0.0177, -0.0161, -0.0138, -0.0136, -0.0134, -0.0091,
        -0.0086, -0.0075, -0.0073, -0.0069, -0.0068, -0.0068, -0.0063, -0.0063,
        -0.0039,  0.0025,  0.0041,  0.0045,  0.0094,  0.0156,  0.0160,  0.0193,
         0.0197,  0.0200,  0.0217,  0.0225,  0.0249,  0.0287,  0.0291,  0.0301,
         0.0318,  0.0321,  0.0332,  0.0349,  0.0355,  0.0362,  0.0370,  0.0374,
         0.0398,  0.0421,  0.0461,  0.0472,  0.0495,  0.0499,  0.0502,  0.0511,
         0.0517,  0.0535,  0.0546,  0.0562,  0.0580,  0.0618,  0.0651,  0.0663,
         0.0675,  0.0726,  0.0734,  0.0745,  0.0776,  0.0800,  0.0815,  0.0830,
         0.0860,  0.0866,  0.0882,  0.0900,  0.0922,  0.0929,  0.0939,  0.1003,
         0.1011,  0.1025,  0.1048,  0.1054,  0.1077,  0.1095,  0.1106,  0.1123,
         0.1184,  0.1209,  0.1233,  0.1417,  0.1441,  0.1453,  0.1536,  0.1640,
        

In [485]:
# Sorted soft assignments of debug sample `i`, as stored by the last forward pass.
classifier.cls_confidence[i].sort()

torch.return_types.sort(
values=tensor([3.0278e-05, 3.3608e-05, 3.8403e-05, 3.8621e-05, 4.5161e-05, 4.7087e-05,
        4.7437e-05, 5.1392e-05, 5.2420e-05, 7.6403e-05, 9.4347e-05, 1.0151e-04,
        1.0331e-04, 1.0731e-04, 1.2773e-04, 1.3065e-04, 1.4096e-04, 1.4781e-04,
        1.5580e-04, 1.5842e-04, 1.5904e-04, 1.6153e-04, 1.6203e-04, 1.7211e-04,
        1.7308e-04, 1.8065e-04, 1.8561e-04, 1.9515e-04, 2.0109e-04, 2.0505e-04,
        2.0999e-04, 2.1302e-04, 2.1433e-04, 2.2519e-04, 2.3412e-04, 2.4682e-04,
        2.5914e-04, 2.5943e-04, 2.6600e-04, 2.7200e-04, 2.8443e-04, 2.9012e-04,
        2.9746e-04, 2.9959e-04, 3.0043e-04, 3.1509e-04, 3.2080e-04, 3.2269e-04,
        3.3399e-04, 3.3474e-04, 3.4041e-04, 3.4125e-04, 3.5525e-04, 3.6157e-04,
        3.6179e-04, 3.7181e-04, 3.7595e-04, 3.7719e-04, 3.8021e-04, 3.8124e-04,
        3.8465e-04, 3.8831e-04, 3.9240e-04, 4.1332e-04, 4.1978e-04, 4.3027e-04,
        4.3379e-04, 4.5319e-04, 4.5752e-04, 4.5937e-04, 4.6889e-04, 4.7223e-04,
        

In [486]:
# Full network: backbone features fed into the classifier head. The two
# sub-modules are shared with `backbone` / `classifier`, not copied.
model = nn.Sequential(backbone, classifier).to(device)

In [487]:
# Total number of scalar parameters in backbone + classifier.
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  11533520


## Training

In [488]:
# Run identifier; test() and the resume cell use it as the checkpoint file
# name (./models/<model_name>.pth).
model_name = 'c100_inv_v5'

In [489]:
# model_name = 'c10_inv_v0'
# model_name = 'c10_ord_v0'
# model_name = 'c100_inv_v0'
# model_name = 'c100_ord_v0'

In [490]:
# Number of training epochs; also the period (T_max) of the cosine schedule.
EPOCHS = 200
criterion = nn.CrossEntropyLoss()
# Alternative optimiser kept for reference:
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
#                       momentum=0.9, weight_decay=5e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# The training loop calls scheduler.step() once per epoch, so the learning
# rate is annealed along a cosine curve over EPOCHS steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [491]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    """Run one optimisation pass over `train_loader`.

    Prints the mean per-batch loss and the running accuracy for the epoch.
    `epoch` is only used to label the printed line.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0
    for step, (images, labels) in enumerate(tqdm(train_loader)):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Accuracy is measured on the same forward pass that produced the loss.
        loss_sum += batch_loss.item()
        n_seen += labels.size(0)
        n_correct += logits.max(1).indices.eq(labels).sum().item()
    print(f"[Train] {epoch} Loss: {loss_sum/(step+1):.3f} | Acc: {100.*n_correct/n_seen:.3f} {n_correct}/{n_seen}")
    return

In [492]:
best_acc = -1
def test(epoch):
    """Evaluate on `test_loader` and checkpoint the model when accuracy improves.

    Prints the mean per-batch loss and the accuracy. If the accuracy beats the
    module-level `best_acc`, the model weights, accuracy and epoch are written
    to ./models/<model_name>.pth and `best_acc` is updated.
    """
    global best_acc
    model.eval()
    loss_sum = n_correct = n_seen = 0
    with torch.no_grad():
        for step, (images, labels) in enumerate(tqdm(test_loader)):
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)

            loss_sum += criterion(logits, labels).item()
            n_seen += labels.size(0)
            n_correct += logits.max(1).indices.eq(labels).sum().item()

    print(f"[Test] {epoch} Loss: {loss_sum/(step+1):.3f} | Acc: {100.*n_correct/n_seen:.3f} {n_correct}/{n_seen}")

    acc = 100.*n_correct/n_seen
    if acc <= best_acc:
        # No improvement: keep the previous checkpoint.
        return

    print('Saving..')
    os.makedirs('models', exist_ok=True)
    torch.save(
        {'model': model.state_dict(), 'acc': acc, 'epoch': epoch},
        f'./models/{model_name}.pth',
    )
    best_acc = acc

In [493]:
start_epoch = 0  # first epoch to run: 0, or the one after the checkpointed epoch
resume = False

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    # map_location lets a checkpoint written on another GPU (or on a machine
    # with a different device layout) load onto the device used in this run.
    checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    # test() stores the index of the epoch that had just finished training,
    # so training continues with the following epoch.
    start_epoch = checkpoint['epoch'] + 1

In [494]:
### Train the whole model: one training pass, one evaluation (which may write
### a checkpoint), then one learning-rate scheduler step per epoch.
# NOTE(review): the loop always runs EPOCHS epochs counted from start_epoch,
# while the scheduler was built with T_max=EPOCHS and is not checkpointed.
# After a resume the cosine schedule therefore restarts from its initial
# learning rate — confirm this is intended.

for epoch in range(start_epoch, start_epoch+EPOCHS): ## for 200 epochs
    train(epoch)
    test(epoch)
    scheduler.step()

100%|██████████| 391/391 [00:39<00:00,  9.93it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 0 Loss: 4.507 | Acc: 7.066 3533/50000


100%|██████████| 79/79 [00:02<00:00, 29.18it/s]


[Test] 0 Loss: 4.479 | Acc: 7.230 723/10000
Saving..


100%|██████████| 391/391 [00:39<00:00,  9.88it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 1 Loss: 4.469 | Acc: 7.284 3642/50000


100%|██████████| 79/79 [00:02<00:00, 28.92it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 1 Loss: 4.448 | Acc: 7.100 710/10000


100%|██████████| 391/391 [00:39<00:00,  9.87it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 2 Loss: 4.431 | Acc: 7.554 3777/50000


100%|██████████| 79/79 [00:02<00:00, 29.10it/s]


[Test] 2 Loss: 4.405 | Acc: 7.790 779/10000
Saving..


100%|██████████| 391/391 [00:39<00:00,  9.88it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 3 Loss: 4.401 | Acc: 8.034 4017/50000


100%|██████████| 79/79 [00:02<00:00, 29.10it/s]


[Test] 3 Loss: 4.395 | Acc: 7.870 787/10000
Saving..


100%|██████████| 391/391 [00:39<00:00,  9.80it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 4 Loss: 4.375 | Acc: 8.372 4186/50000


100%|██████████| 79/79 [00:02<00:00, 28.91it/s]


[Test] 4 Loss: 4.362 | Acc: 8.580 858/10000
Saving..


100%|██████████| 391/391 [00:39<00:00,  9.78it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 5 Loss: 4.339 | Acc: 8.674 4337/50000


100%|██████████| 79/79 [00:02<00:00, 28.71it/s]


[Test] 5 Loss: 4.314 | Acc: 8.960 896/10000
Saving..


100%|██████████| 391/391 [00:39<00:00,  9.78it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 6 Loss: 4.308 | Acc: 8.824 4412/50000


100%|██████████| 79/79 [00:02<00:00, 28.51it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 6 Loss: 4.294 | Acc: 8.800 880/10000


100%|██████████| 391/391 [00:40<00:00,  9.77it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 7 Loss: 4.270 | Acc: 9.164 4582/50000


100%|██████████| 79/79 [00:02<00:00, 28.89it/s]


[Test] 7 Loss: 4.246 | Acc: 9.540 954/10000
Saving..


100%|██████████| 391/391 [00:40<00:00,  9.76it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 8 Loss: 4.232 | Acc: 9.806 4903/50000


100%|██████████| 79/79 [00:02<00:00, 28.39it/s]


[Test] 8 Loss: 4.216 | Acc: 9.660 966/10000
Saving..


100%|██████████| 391/391 [00:40<00:00,  9.73it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 9 Loss: 4.199 | Acc: 10.254 5127/50000


100%|██████████| 79/79 [00:02<00:00, 28.97it/s]


[Test] 9 Loss: 4.197 | Acc: 10.230 1023/10000
Saving..


100%|██████████| 391/391 [00:40<00:00,  9.76it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 10 Loss: 4.162 | Acc: 10.942 5471/50000


100%|██████████| 79/79 [00:02<00:00, 28.85it/s]


[Test] 10 Loss: 4.139 | Acc: 11.010 1101/10000
Saving..


100%|██████████| 391/391 [00:40<00:00,  9.77it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 11 Loss: 4.120 | Acc: 11.504 5752/50000


100%|██████████| 79/79 [00:02<00:00, 28.57it/s]


[Test] 11 Loss: 4.088 | Acc: 12.350 1235/10000
Saving..


100%|██████████| 391/391 [00:40<00:00,  9.74it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 12 Loss: 4.073 | Acc: 12.218 6109/50000


100%|██████████| 79/79 [00:02<00:00, 28.89it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 12 Loss: 4.052 | Acc: 11.760 1176/10000


 55%|█████▍    | 215/391 [00:22<00:18,  9.72it/s]


KeyboardInterrupt: 

In [338]:
# Best test accuracy (in percent) recorded by test() or loaded from a checkpoint.
best_acc

55.54

In [33]:
######## C100
#### inv nn with MLP classifier: 59.3 Acc ;
#### non-inv nn with MLP classifier: 54.67 Acc ;

#### inv nn + ConnectedDist: 30.03 Acc;
#### inv nn + ConnectedLin: 48.79 Acc; -> v1

In [None]:
#### benchmark on non-inv + MLP/ConC has very low performance.. verifying here...

In [221]:
# Learned log-scale of the classifier head; forward() applies it as exp(inv_temp).
classifier.inv_temp

Parameter containing:
tensor([-0.1376], device='cuda:0', requires_grad=True)

In [113]:
# Read the accuracy and epoch stored in the saved checkpoint. map_location
# keeps the load working when the checkpoint was written on a different GPU
# than the one selected for this session.
checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

best_acc, start_epoch

(59.3, 197)

### Hard test accuracy with count per classifier

In [339]:
# Accuracy on the test split with hard (near one-hot) centre assignments,
# plus how many samples each centre/gate wins.
# Inference mode is set explicitly: otherwise the BatchNorm layers would use
# (and update their running statistics from) the evaluation batches whenever
# the previous cell left the model in training mode.
backbone.eval()
classifier.eval()

test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
with torch.no_grad():
    for xx, yy in tqdm(test_loader):
        xx, yy = xx.to(device), yy.to(device)
        yout = classifier(backbone(xx), hard=True)
        # cls_confidence holds the assignments of the batch just processed.
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True)
        set_count[set_indx] += count
        test_acc += (torch.argmax(yout, dim=1) == yy).sum().item()
        test_count += len(xx)

print(f'Hard Test Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 79/79 [00:02<00:00, 28.25it/s]

Hard Test Acc:54.97%
[0, 0, 4, 0, 0, 11, 30, 1, 0, 0, 3, 11, 0, 3, 3, 0, 0, 4, 0, 0, 0, 0, 3, 0, 0, 0, 0, 20, 0, 1, 3, 0, 0, 0, 0, 0, 0, 3, 0, 2, 0, 0, 0, 0, 5, 0, 3, 0, 0, 0, 2, 0, 0, 0, 0, 18, 0, 2, 1, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 86, 1, 0, 18, 0, 0, 0, 0, 0, 8, 0, 0, 0, 4, 0, 2, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 17, 33, 0, 0, 0, 3, 1, 0, 5, 2, 0, 0, 0, 0, 0, 0, 6, 15, 0, 31, 0, 20, 6, 0, 0, 0, 21, 0, 0, 3, 67, 0, 0, 7, 1, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 84, 0, 0, 0, 2, 0, 0, 0, 0, 18, 0, 0, 0, 0, 2, 0, 0, 0, 2, 5, 0, 0, 13, 0, 0, 0, 7, 0, 3, 0, 3, 11, 12, 0, 0, 0, 15, 0, 0, 0, 0, 7, 0, 0, 1, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 7, 0, 0, 44, 4, 0, 0, 51, 0, 22, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 15, 0, 0, 4, 0, 0, 1, 6, 0, 26, 0, 0, 0, 18, 0, 19, 0, 24, 0, 3, 98, 0, 0, 0, 0, 0, 0, 0, 73, 0, 16, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,




In [None]:
# Number of centres/gates that win at least one sample.
# NOTE(review): the recorded value 810 exceeds the 500 sets of the classifier
# configured in this notebook, so it presumably comes from a different (v3)
# configuration — confirm.
torch.count_nonzero(set_count) ## tensor(810) for v3

In [342]:
# Largest assignment weight of the first sample in the last processed batch;
# a value of 1 indicates the hard softmax produced a one-hot assignment.
classifier.cls_confidence[0].max()

tensor(1., device='cuda:0')

### Hard train accuracy with count per classifier

In [34]:
# Accuracy on the training split with hard (near one-hot) centre assignments,
# plus how many samples each centre/gate wins. train_loader applies the random
# crop/flip augmentation, so the figure is measured on augmented images.
# Inference mode is set explicitly: otherwise the BatchNorm layers would use
# (and update their running statistics from) these batches whenever the
# previous cell left the model in training mode.
backbone.eval()
classifier.eval()

test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
with torch.no_grad():
    for xx, yy in tqdm(train_loader):
        xx, yy = xx.to(device), yy.to(device)
        yout = classifier(backbone(xx), hard=True)
        # cls_confidence holds the assignments of the batch just processed.
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True)
        set_count[set_indx] += count
        test_acc += (torch.argmax(yout, dim=1) == yy).sum().item()
        test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 391/391 [00:29<00:00, 13.04it/s]

Hard Train Acc:91.31%
[0, 0, 0, 0, 4452, 25, 4597, 0, 4948, 0, 0, 0, 0, 162, 0, 0, 0, 0, 0, 4909, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5013, 4926, 0, 56, 0, 5022, 4125, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 25, 0, 5073, 38, 4, 158, 0, 0, 0, 0, 0, 60, 9, 84, 0, 0, 119, 4967, 0, 25, 0, 0, 0, 0, 0, 0, 0, 0, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 284, 0, 0, 0, 0, 0, 888, 0]





In [35]:
#### Number of centres/gates that win at least one sample in the pass above
torch.count_nonzero(set_count)

tensor(26, device='cuda:1')

In [36]:
#### Class represented by each centre/gate: the index of the largest class
#### score in its row of cls_weight
torch.argmax(classifier.cls_weight, dim=1)

tensor([6, 3, 7, 9, 5, 0, 2, 3, 3, 3, 7, 7, 5, 3, 5, 6, 3, 1, 8, 7, 3, 1, 5, 2,
        3, 3, 3, 1, 6, 2, 1, 1, 5, 6, 9, 9, 2, 3, 4, 8, 3, 9, 8, 3, 5, 9, 9, 8,
        3, 2, 2, 1, 8, 8, 8, 0, 6, 7, 2, 3, 5, 7, 0, 9, 1, 0, 6, 6, 5, 4, 1, 1,
        9, 8, 2, 5, 1, 1, 7, 7, 8, 9, 2, 2, 9, 0, 9, 5, 4, 3, 3, 1, 5, 1, 5, 4,
        5, 9, 8, 5], device='cuda:1')