In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data

import random, os, pathlib, time
from tqdm import tqdm
from sklearn import datasets

In [2]:
import nflib
from nflib.flows import SequentialFlow, NormalizingFlow, ActNorm, AffineConstantFlow
import nflib.coupling_flows as icf
import nflib.inn_flow as inn
import nflib.res_flow as irf

In [3]:
from torch import distributions
from torch.distributions import MultivariateNormal

In [4]:
device = torch.device("cuda:0")
# device = torch.device("cpu")

## MNIST dataset

In [5]:
import mylibrary.datasets as datasets
# import mylibrary.nnlib as tnn

In [6]:
mnist = datasets.FashionMNIST()
# mnist.download_mnist()
# mnist.save_mnist()
train_data, train_label_, test_data, test_label_ = mnist.load()

train_data = train_data / 255.
test_data = test_data / 255.

# train_label = tnn.Logits.index_to_logit(train_label_)
train_size = len(train_label_)

In [7]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [8]:
input_size = 784
output_size = 10

In [9]:
class MNIST_Dataset(data.Dataset):
    
    def __init__(self, data, label):
        self.data = data
        self.label = label
        
#         self.label = mask.type(torch.float32).reshape(-1,1)
        self._shuffle_data_()
        
    def __len__(self):
        return len(self.data)
    
    def _shuffle_data_(self):
        randidx = random.sample(range(len(self.data)), k=len(self.data))
        self.data = self.data[randidx]
        self.label = self.label[randidx]
    
    def __getitem__(self, idx):
        img, lbl = self.data[idx], self.label[idx]
        return img, lbl

In [10]:
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [11]:
class ConnectedClassifier_Softmax(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.linear = nn.Linear(input_dim, num_sets)
        self.linear.bias.data *= 0
        self.linear.weight.data *= 0.1
        self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)
        self.cls_confidence = None
        
        
    def forward(self, x, hard=True):
        x = self.linear(x)
        if hard:
            x = torch.softmax(-x*1e5, dim=1)
        else:
            x = torch.softmax(-x*self.inv_temp, dim=1)
        self.cls_confidence = x
        c = torch.softmax(self.cls_weight, dim=1)
#         c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized

In [12]:
# class ConnectedClassifier_SoftKMeans(nn.Module):
    
#     def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
#         super().__init__()
#         self.input_dim = input_dim
#         self.output_dim = output_dim
#         self.num_sets = num_sets
#         self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
#         self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)
        
# #         self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)

#         init_val = torch.randn(num_sets, output_dim)*0.01
#         for ns in range(num_sets):
#             init_val[ns, ns%output_dim] = 10.
#         self.cls_weight = nn.Parameter(init_val)

#         self.cls_confidence = None
        
        
#     def forward(self, x, hard=False):
#         x = x[:, :self.input_dim]
#         dists = torch.cdist(x, self.centers)
#         dists = dists/np.sqrt(self.input_dim) ### correction to make diagonal of unit square 1 in nD space
        
#         if hard:
#             x = torch.softmax(-dists*1e5, dim=1)
#         else:
#             x = torch.softmax(-dists*self.inv_temp, dim=1)
#         self.cls_confidence = x
#         c = torch.softmax(self.cls_weight, dim=1)
# #         c = self.cls_weight
#         return x@c ## since both are normalized, it is also normalized

In [13]:
class ConnectedClassifier_SoftKMeans(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)
        
#         self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)*0.01
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 10.
        self.cls_weight = nn.Parameter(init_val)

        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
        self.cls_weight.data = torch.abs(self.cls_weight.data/self.cls_weight.data.sum(dim=1, keepdim=True))
        
        x = x[:, :self.input_dim]
        dists = torch.cdist(x, self.centers)
        dists = dists/np.sqrt(self.input_dim) ### correction to make diagonal of unit square 1 in nD space
        
        if hard:
            x = torch.softmax(-dists*1e5, dim=1)
        else:
            x = torch.softmax(-dists*self.inv_temp, dim=1)
        self.cls_confidence = x
#         c = torch.softmax(self.cls_weight, dim=1)
        c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized

In [15]:
actf = irf.Swish
flows = [
    ActNorm(784),
    irf.ResidualFlow(784, [784], activation=actf),
    ActNorm(784),
    irf.ResidualFlow(784, [784], activation=actf),
    ActNorm(784),
        ]

model = SequentialFlow(flows)
model = model.to(device)

In [16]:
# model = nn.Sequential(nn.Linear(784, 784, bias=False),
#                       nn.BatchNorm1d(784),
#                       nn.SELU(),
#                       nn.Linear(784, 784, bias=False),
#                       nn.BatchNorm1d(784),
#                       nn.SELU(),
#                      )

In [17]:
model.to(device)

SequentialFlow(
  (flows): ModuleList(
    (0): ActNorm()
    (1): ResidualFlow(
      (resblock): ModuleList(
        (0): Linear(in_features=784, out_features=784, bias=True)
        (1): Swish()
        (2): Linear(in_features=784, out_features=784, bias=True)
      )
    )
    (2): ActNorm()
    (3): ResidualFlow(
      (resblock): ModuleList(
        (0): Linear(in_features=784, out_features=784, bias=True)
        (1): Swish()
        (2): Linear(in_features=784, out_features=784, bias=True)
      )
    )
    (4): ActNorm()
  )
)

In [18]:
list(model.flows[0].parameters())

[Parameter containing:
 tensor([[-5.2329e-01, -9.5173e-01, -1.3168e+00,  7.7942e-01, -3.3448e-01,
          -4.0439e-01, -6.2763e-01,  4.5768e-01, -3.0654e-01,  9.8062e-01,
          -1.9224e-01,  3.6289e-01,  7.4330e-01,  1.4523e-01, -4.1951e-01,
          -3.4426e-01, -4.4905e-02, -1.7280e+00, -6.9065e-01,  1.3470e+00,
           1.7937e+00,  2.1665e-01, -3.2041e-02,  5.4881e-01,  1.3522e+00,
          -2.1197e+00, -1.6294e+00,  1.9281e-02, -9.1860e-01, -1.0233e+00,
           2.0823e+00,  7.1021e-01, -1.1690e+00,  1.5567e+00,  5.1556e-01,
           2.7652e-01,  4.5479e-01, -2.8915e-01,  1.0738e+00,  2.4312e+00,
           2.7813e+00,  3.4386e-02, -3.5578e-01, -1.4994e+00,  4.6217e-01,
           1.2910e+00, -5.1789e-01,  6.8852e-01,  1.9942e+00,  2.9899e-02,
           1.0201e+00, -4.8417e-01,  1.1400e+00,  1.5406e+00,  2.8544e-01,
          -2.3441e-01, -6.9238e-01,  1.2394e+00,  7.7584e-01, -7.7021e-01,
          -1.0340e+00,  4.4855e-01,  1.0274e+00, -5.4802e-01,  1.2045e+00,
  

In [19]:
classifier = ConnectedClassifier_SoftKMeans(784, 100, 10)
# classifier = ConnectedClassifier_Softmax(784, 10, 10)
classifier = classifier.to(device)

## Model Train

In [20]:
learning_rate = 0.0003
batch_size = 50

In [21]:
train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

In [22]:
# criterion = nn.NLLLoss()
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(list(model.parameters())+list(classifier.parameters()),
                       lr=learning_rate, weight_decay=1e-15) # todo tune WD
# optimizer = optim.SGD(model.parameters(), lr=0.1)

print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  2466466


In [23]:
for p in model.parameters():
    print(torch.isnan(p).type(torch.float32).sum())

tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')


In [24]:
model(torch.randn(10, 784).to(device))

tensor([[-1.5607,  0.6105,  1.1704,  ..., -0.5931, -1.0189, -1.0891],
        [ 1.3151, -2.1658, -1.2809,  ..., -1.5512,  0.4107, -0.0355],
        [-0.3232,  0.2452,  1.1155,  ...,  0.0886,  0.7967, -0.5107],
        ...,
        [ 0.5892, -0.5161,  0.4618,  ...,  0.5948,  1.3824,  1.2422],
        [-1.6883, -0.1353, -1.9031,  ..., -0.7540,  0.3301, -1.4200],
        [ 0.0261,  0.1167,  0.7527,  ...,  1.6670, -1.4482,  0.7953]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [25]:
xx = iter(test_loader).next()[0]
xx.shape

torch.Size([50, 784])

In [28]:
losses = []
train_accs = []
test_accs = []
EPOCHS = 50

index = 0
for epoch in range(EPOCHS):
    train_acc = 0
    train_count = 0
    for xx, yy in tqdm(train_loader):
        xx, yy = xx.to(device), yy.to(device)
#     for xx, yy in tqdm(test_loader):

        yout = model(xx)
#         print(yout)
        yout = classifier(yout)    
#         print(yout)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(float(loss))

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)
#         break

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0

    print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    for xx, yy in tqdm(test_loader):
        xx, yy = xx.to(device), yy.to(device)
        with torch.no_grad():
            yout = classifier(model(xx))    
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

100%|██████████| 1200/1200 [00:10<00:00, 114.19it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0:0,  Loss:1.5872070789337158


100%|██████████| 200/200 [00:00<00:00, 231.65it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.99%, Test Acc:87.39%



100%|██████████| 1200/1200 [00:10<00:00, 113.54it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1:0,  Loss:1.6091848611831665


100%|██████████| 200/200 [00:00<00:00, 239.62it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.61%, Test Acc:87.27%



100%|██████████| 1200/1200 [00:10<00:00, 114.26it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 2:0,  Loss:1.617140769958496


100%|██████████| 200/200 [00:00<00:00, 223.47it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.89%, Test Acc:87.63%



100%|██████████| 1200/1200 [00:10<00:00, 113.08it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 3:0,  Loss:1.575707197189331


100%|██████████| 200/200 [00:00<00:00, 234.16it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.24%, Test Acc:87.70%



100%|██████████| 1200/1200 [00:10<00:00, 113.65it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 4:0,  Loss:1.5634993314743042


100%|██████████| 200/200 [00:00<00:00, 244.43it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.48%, Test Acc:87.75%



100%|██████████| 1200/1200 [00:10<00:00, 114.11it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 5:0,  Loss:1.5652235746383667


100%|██████████| 200/200 [00:00<00:00, 254.21it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.65%, Test Acc:88.03%



100%|██████████| 1200/1200 [00:10<00:00, 114.32it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 6:0,  Loss:1.4846620559692383


100%|██████████| 200/200 [00:00<00:00, 234.28it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.92%, Test Acc:88.39%



100%|██████████| 1200/1200 [00:10<00:00, 113.90it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 7:0,  Loss:1.5491163730621338


100%|██████████| 200/200 [00:00<00:00, 252.35it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.97%, Test Acc:88.44%



100%|██████████| 1200/1200 [00:10<00:00, 112.83it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 8:0,  Loss:1.5592498779296875


100%|██████████| 200/200 [00:00<00:00, 236.85it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.33%, Test Acc:88.31%



100%|██████████| 1200/1200 [00:10<00:00, 113.49it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 9:0,  Loss:1.4895471334457397


100%|██████████| 200/200 [00:00<00:00, 243.68it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.42%, Test Acc:88.60%



100%|██████████| 1200/1200 [00:10<00:00, 113.68it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 10:0,  Loss:1.544619083404541


100%|██████████| 200/200 [00:00<00:00, 252.61it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.66%, Test Acc:88.49%



100%|██████████| 1200/1200 [00:10<00:00, 114.06it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 11:0,  Loss:1.5180134773254395


100%|██████████| 200/200 [00:00<00:00, 238.11it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.73%, Test Acc:89.04%



100%|██████████| 1200/1200 [00:10<00:00, 112.83it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 12:0,  Loss:1.5188182592391968


100%|██████████| 200/200 [00:00<00:00, 231.23it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.94%, Test Acc:88.52%



100%|██████████| 1200/1200 [00:10<00:00, 113.40it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 13:0,  Loss:1.6488500833511353


100%|██████████| 200/200 [00:00<00:00, 244.69it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.10%, Test Acc:88.53%



100%|██████████| 1200/1200 [00:10<00:00, 114.49it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 14:0,  Loss:1.5558513402938843


100%|██████████| 200/200 [00:00<00:00, 233.67it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.20%, Test Acc:88.98%



100%|██████████| 1200/1200 [00:10<00:00, 113.23it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 15:0,  Loss:1.5284711122512817


100%|██████████| 200/200 [00:00<00:00, 248.26it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.48%, Test Acc:88.38%



100%|██████████| 1200/1200 [00:10<00:00, 114.75it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 16:0,  Loss:1.5698734521865845


100%|██████████| 200/200 [00:00<00:00, 224.44it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.57%, Test Acc:88.27%



100%|██████████| 1200/1200 [00:10<00:00, 113.76it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 17:0,  Loss:1.6316076517105103


100%|██████████| 200/200 [00:00<00:00, 236.24it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.63%, Test Acc:88.76%



100%|██████████| 1200/1200 [00:10<00:00, 114.30it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 18:0,  Loss:1.4893118143081665


100%|██████████| 200/200 [00:00<00:00, 232.25it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.81%, Test Acc:88.75%



100%|██████████| 1200/1200 [00:10<00:00, 114.48it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 19:0,  Loss:1.5801801681518555


100%|██████████| 200/200 [00:00<00:00, 238.74it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.89%, Test Acc:88.58%



100%|██████████| 1200/1200 [00:10<00:00, 113.30it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 20:0,  Loss:1.5155861377716064


100%|██████████| 200/200 [00:00<00:00, 227.27it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.08%, Test Acc:88.88%



100%|██████████| 1200/1200 [00:10<00:00, 113.30it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 21:0,  Loss:1.5979125499725342


100%|██████████| 200/200 [00:00<00:00, 238.53it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.14%, Test Acc:88.28%



100%|██████████| 1200/1200 [00:10<00:00, 113.73it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 22:0,  Loss:1.5524548292160034


100%|██████████| 200/200 [00:00<00:00, 243.28it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.22%, Test Acc:88.63%



100%|██████████| 1200/1200 [00:10<00:00, 114.54it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 23:0,  Loss:1.4927555322647095


100%|██████████| 200/200 [00:00<00:00, 245.30it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.44%, Test Acc:88.50%



100%|██████████| 1200/1200 [00:10<00:00, 112.27it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 24:0,  Loss:1.5481436252593994


100%|██████████| 200/200 [00:00<00:00, 230.24it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.36%, Test Acc:89.42%



100%|██████████| 200/200 [00:00<00:00, 239.63it/s]]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.30%, Test Acc:89.26%



100%|██████████| 1200/1200 [00:10<00:00, 112.96it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 34:0,  Loss:1.460972547531128


100%|██████████| 200/200 [00:00<00:00, 234.89it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.39%, Test Acc:88.87%



100%|██████████| 1200/1200 [00:10<00:00, 112.53it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 35:0,  Loss:1.551367163658142


100%|██████████| 200/200 [00:00<00:00, 244.10it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.40%, Test Acc:89.22%



100%|██████████| 1200/1200 [00:10<00:00, 113.90it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 36:0,  Loss:1.4859247207641602


100%|██████████| 200/200 [00:00<00:00, 247.59it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.57%, Test Acc:88.80%



100%|██████████| 1200/1200 [00:10<00:00, 114.10it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 37:0,  Loss:1.5142561197280884


100%|██████████| 200/200 [00:00<00:00, 244.24it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.68%, Test Acc:88.95%



100%|██████████| 1200/1200 [00:10<00:00, 114.14it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 38:0,  Loss:1.4878987073898315


100%|██████████| 200/200 [00:00<00:00, 243.87it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.77%, Test Acc:89.16%



100%|██████████| 1200/1200 [00:10<00:00, 112.95it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 39:0,  Loss:1.502936601638794


100%|██████████| 200/200 [00:00<00:00, 229.54it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.73%, Test Acc:89.21%



100%|██████████| 1200/1200 [00:10<00:00, 113.57it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 40:0,  Loss:1.500788688659668


100%|██████████| 200/200 [00:00<00:00, 242.98it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.83%, Test Acc:89.11%



100%|██████████| 1200/1200 [00:10<00:00, 114.76it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 41:0,  Loss:1.4805999994277954


100%|██████████| 200/200 [00:00<00:00, 242.36it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.98%, Test Acc:89.52%



100%|██████████| 1200/1200 [00:10<00:00, 112.86it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 42:0,  Loss:1.4990063905715942


100%|██████████| 200/200 [00:00<00:00, 234.57it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.10%, Test Acc:89.65%



100%|██████████| 1200/1200 [00:10<00:00, 113.04it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 43:0,  Loss:1.5205912590026855


100%|██████████| 200/200 [00:00<00:00, 245.66it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.13%, Test Acc:89.46%



100%|██████████| 1200/1200 [00:10<00:00, 114.04it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 44:0,  Loss:1.4804770946502686


100%|██████████| 200/200 [00:00<00:00, 235.74it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.36%, Test Acc:89.29%



100%|██████████| 1200/1200 [00:10<00:00, 114.09it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 45:0,  Loss:1.460913896560669


100%|██████████| 200/200 [00:00<00:00, 227.44it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.27%, Test Acc:89.51%



100%|██████████| 1200/1200 [00:10<00:00, 114.39it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 46:0,  Loss:1.5688625574111938


100%|██████████| 200/200 [00:00<00:00, 235.75it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.40%, Test Acc:89.05%



100%|██████████| 1200/1200 [00:10<00:00, 113.08it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 47:0,  Loss:1.491915225982666


100%|██████████| 200/200 [00:00<00:00, 234.27it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.34%, Test Acc:89.50%



100%|██████████| 1200/1200 [00:10<00:00, 113.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 48:0,  Loss:1.5005404949188232


100%|██████████| 200/200 [00:00<00:00, 239.39it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.47%, Test Acc:89.51%



100%|██████████| 1200/1200 [00:10<00:00, 112.42it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 49:0,  Loss:1.5203367471694946


100%|██████████| 200/200 [00:00<00:00, 201.03it/s]

Train Acc:95.49%, Test Acc:88.74%

	-> Train Acc 95.48833333333333 ; Test Acc 89.64999999999999





In [30]:
# classifier.cls_weight

### Hard test accuracy with count per classifier

In [31]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(test_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Test Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 200/200 [00:00<00:00, 211.54it/s]

Hard Test Acc:88.76%
[0, 0, 1040, 15, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 138, 0, 0, 92, 7, 0, 2, 0, 149, 949, 0, 0, 0, 626, 15, 0, 1, 0, 134, 0, 1, 0, 955, 71, 422, 0, 0, 0, 0, 0, 850, 0, 0, 3, 0, 0, 0, 0, 101, 0, 0, 0, 0, 69, 0, 0, 0, 0, 1, 0, 0, 0, 0, 51, 0, 910, 0, 0, 0, 0, 0, 0, 0, 10, 0, 1195, 968, 0, 2, 0, 0, 0, 0, 5, 618, 0, 6, 0, 539, 0, 0, 0, 0, 35, 0]





### Hard train accuracy with count per classifier

In [32]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 1200/1200 [00:04<00:00, 254.65it/s]

Hard Train Acc:94.73%
[0, 0, 6122, 87, 0, 0, 0, 0, 161, 0, 0, 0, 0, 0, 0, 782, 0, 0, 460, 50, 0, 4, 0, 820, 5862, 0, 0, 0, 3782, 46, 0, 4, 0, 875, 0, 1, 0, 5767, 433, 2409, 0, 2, 0, 0, 0, 5186, 0, 0, 15, 0, 0, 0, 0, 600, 0, 0, 0, 0, 500, 0, 0, 1, 0, 23, 0, 0, 0, 0, 331, 0, 5684, 1, 0, 14, 0, 0, 0, 0, 67, 0, 6490, 5909, 0, 27, 0, 0, 0, 0, 42, 3759, 0, 17, 0, 3480, 0, 0, 0, 0, 187, 0]





In [33]:
#### Classifiers that enclose any data
torch.count_nonzero(set_count)

tensor(37, device='cuda:0')

In [34]:
#### classifier with class representation
torch.argmax(classifier.cls_weight, dim=1)

tensor([6, 6, 2, 3, 4, 5, 6, 7, 8, 6, 6, 1, 2, 3, 4, 5, 6, 7, 8, 9, 6, 1, 2, 3,
        4, 5, 6, 7, 8, 9, 6, 1, 6, 3, 4, 5, 6, 7, 8, 9, 6, 1, 2, 3, 4, 5, 6, 7,
        8, 6, 6, 1, 2, 3, 4, 5, 6, 7, 8, 9, 6, 1, 6, 3, 4, 5, 6, 7, 8, 9, 0, 1,
        2, 3, 4, 5, 6, 7, 8, 9, 6, 1, 2, 3, 4, 5, 6, 7, 8, 9, 6, 1, 2, 3, 4, 5,
        6, 7, 8, 9], device='cuda:0')

In [66]:
# The class labels are same as that of initialized
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
#         4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
#         8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
#         2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
#         6, 7, 8, 9], device='cuda:0')

In [67]:
classifier.cls_weight

Parameter containing:
tensor([[14.5033, -2.3390, -2.3409, -2.2740, -2.3565, -2.3720, -2.1097, -2.3840,
         -2.3706, -2.3809],
        [-2.3772, 14.4716, -2.3956, -2.3275, -2.3879, -2.3858, -2.3834, -2.3961,
         -2.4096, -2.4031],
        [-2.3144, -2.3321, 14.5638, -2.3289, -2.1009, -2.3546, -2.1033, -2.3621,
         -2.3181, -2.3527],
        [-2.2622, -2.2628, -2.3509, 14.4797, -2.3197, -2.3655, -2.3083, -2.3762,
         -2.3589, -2.3743],
        [-2.4062, -2.4268, -2.1114, -2.3799, 14.5947, -2.4376, -2.1948, -2.4477,
         -2.3933, -2.4408],
        [-2.3914, -2.3885, -2.3987, -2.3950, -2.4002, 14.5619, -2.3902, -2.2499,
         -2.3637, -2.3111],
        [-2.0379, -2.2405, -1.9522, -2.1715, -1.9584, -2.2554, 14.4938, -2.2768,
         -2.2024, -2.2734],
        [-2.4377, -2.4237, -2.4369, -2.4340, -2.4470, -2.2871, -2.4310, 14.6102,
         -2.4123, -2.3418],
        [-2.4166, -2.4554, -2.4106, -2.4334, -2.4156, -2.4106, -2.3998, -2.4082,
         14.5393, -2.4187

In [68]:
# torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True)

In [69]:
classifier.inv_temp

Parameter containing:
tensor([7.8020], device='cuda:0', requires_grad=True)

In [70]:
### example output per classifier
yout[5]

tensor([3.8328e-08, 3.8267e-08, 3.8325e-08, 3.8312e-08, 3.8171e-08, 4.4385e-08,
        3.8185e-08, 1.0000e+00, 3.9245e-08, 4.1537e-08], device='cuda:0')

In [71]:
asdfsdf ## to break the code

NameError: name 'asdfsdf' is not defined

### analyze per classifier accuracy

In [73]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
set_acc = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        
    cls_indx = torch.argmax(classifier.cls_confidence, dim=1)
    set_indx, count = torch.unique(cls_indx, return_counts=True) 
    set_count[set_indx] += count
    
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float)
    
    ### class_index has 100 possible values
    for i, c in enumerate(correct):
        set_acc[cls_indx[i]] += c
    
#     print(set_acc.sum(), set_count.sum())
#     break
    test_acc += correct.sum()
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 1200/1200 [00:05<00:00, 233.51it/s]

Hard Train Acc:91.65%
[409, 0, 318, 18, 3561, 2, 368, 86, 44, 60, 2173, 17, 588, 65, 242, 238, 788, 66, 19, 0, 59, 0, 213, 250, 27, 1823, 11, 595, 1173, 3158, 265, 0, 1260, 125, 72, 23, 7, 49, 11, 397, 457, 2083, 2972, 701, 157, 31, 2750, 5, 1994, 5, 116, 1421, 196, 891, 3, 507, 1342, 707, 47, 1212, 130, 76, 88, 338, 91, 222, 68, 4301, 373, 39, 0, 1870, 117, 79, 292, 126, 3, 108, 18, 997, 1614, 0, 8, 1717, 243, 20, 389, 59, 2211, 96, 14, 452, 84, 2009, 2009, 2892, 381, 47, 109, 133]





In [74]:
set_acc/set_count

tensor([0.8973,    nan, 0.8711, 0.5000, 0.8110, 1.0000, 0.6168, 0.8953, 0.8864,
        0.9000, 0.9609, 0.9412, 0.8197, 0.6769, 0.8430, 0.9874, 0.5723, 0.8333,
        0.8421,    nan, 0.8644,    nan, 0.9624, 0.7560, 0.6667, 0.9863, 0.3636,
        0.9580, 0.9838, 0.9725, 0.9472,    nan, 0.8754, 0.7040, 0.5000, 1.0000,
        0.7143, 0.8980, 1.0000, 0.9673, 0.8709, 0.9981, 0.8782, 0.8488, 0.6624,
        0.9355, 0.7967, 1.0000, 0.9940, 0.4000, 0.8707, 0.9887, 0.8418, 0.9484,
        0.0000, 0.9961, 0.8860, 0.9760, 0.8298, 0.9645, 0.7769, 0.9737, 0.8068,
        0.9379, 0.6264, 0.9820, 0.5882, 0.9772, 0.9786, 0.9231,    nan, 0.9968,
        0.7350, 0.6709, 0.7774, 0.9365, 0.6667, 0.8241, 0.8889, 0.9609, 0.8990,
           nan, 0.1250, 0.9278, 0.6132, 1.0000, 0.8869, 0.8136, 0.9851, 0.9062,
        0.9286, 0.9978, 0.8095, 0.9512, 0.8706, 0.9969, 0.7087, 0.7447, 0.9083,
        0.7820], device='cuda:0')

In [76]:
for i, (cnt, acc, cls) in enumerate(zip(set_count.type(torch.long).tolist(),
                                   (set_acc/set_count).tolist(),
                                   torch.argmax(classifier.cls_weight, dim=1).tolist())):
    if cnt == 0: continue
    print(f"{i},\t {cnt},\t {cls}\t {acc*100:.2f}%")

0,	 409,	 0	 89.73%
2,	 318,	 2	 87.11%
3,	 18,	 3	 50.00%
4,	 3561,	 4	 81.10%
5,	 2,	 5	 100.00%
6,	 368,	 6	 61.68%
7,	 86,	 7	 89.53%
8,	 44,	 8	 88.64%
9,	 60,	 9	 90.00%
10,	 2173,	 0	 96.09%
11,	 17,	 1	 94.12%
12,	 588,	 2	 81.97%
13,	 65,	 3	 67.69%
14,	 242,	 4	 84.30%
15,	 238,	 5	 98.74%
16,	 788,	 6	 57.23%
17,	 66,	 7	 83.33%
18,	 19,	 8	 84.21%
20,	 59,	 0	 86.44%
22,	 213,	 2	 96.24%
23,	 250,	 3	 75.60%
24,	 27,	 4	 66.67%
25,	 1823,	 5	 98.63%
26,	 11,	 6	 36.36%
27,	 595,	 7	 95.80%
28,	 1173,	 8	 98.38%
29,	 3158,	 9	 97.25%
30,	 265,	 0	 94.72%
32,	 1260,	 2	 87.54%
33,	 125,	 3	 70.40%
34,	 72,	 4	 50.00%
35,	 23,	 5	 100.00%
36,	 7,	 6	 71.43%
37,	 49,	 7	 89.80%
38,	 11,	 8	 100.00%
39,	 397,	 9	 96.73%
40,	 457,	 0	 87.09%
41,	 2083,	 1	 99.81%
42,	 2972,	 2	 87.82%
43,	 701,	 3	 84.88%
44,	 157,	 4	 66.24%
45,	 31,	 5	 93.55%
46,	 2750,	 6	 79.67%
47,	 5,	 7	 100.00%
48,	 1994,	 8	 99.40%
49,	 5,	 9	 40.00%
50,	 116,	 0	 87.07%
51,	 1421,	 1	 98.87%
52,	 196,	