In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data

import random, os, pathlib, time
from tqdm import tqdm
from sklearn import datasets

In [2]:
import nflib
from nflib.flows import SequentialFlow, NormalizingFlow, ActNorm, AffineConstantFlow
import nflib.coupling_flows as icf
import nflib.inn_flow as inn
import nflib.res_flow as irf

In [3]:
from torch import distributions
from torch.distributions import MultivariateNormal

In [4]:
device = torch.device("cuda:0")
# device = torch.device("cpu")

## MNIST dataset

In [5]:
import mylibrary.datasets as datasets
# import mylibrary.nnlib as tnn

In [6]:
mnist = datasets.FashionMNIST()
# mnist.download_mnist()
# mnist.save_mnist()
train_data, train_label_, test_data, test_label_ = mnist.load()

train_data = train_data / 255.
test_data = test_data / 255.

# train_label = tnn.Logits.index_to_logit(train_label_)
train_size = len(train_label_)

In [7]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [8]:
input_size = 784
output_size = 10

In [9]:
class MNIST_Dataset(data.Dataset):
    
    def __init__(self, data, label):
        self.data = data
        self.label = label
        
#         self.label = mask.type(torch.float32).reshape(-1,1)
        self._shuffle_data_()
        
    def __len__(self):
        return len(self.data)
    
    def _shuffle_data_(self):
        randidx = random.sample(range(len(self.data)), k=len(self.data))
        self.data = self.data[randidx]
        self.label = self.label[randidx]
    
    def __getitem__(self, idx):
        img, lbl = self.data[idx], self.label[idx]
        return img, lbl

In [10]:
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [142]:
class ConnectedClassifier_Softmax(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.linear = nn.Linear(input_dim, num_sets)
#         self.linear.bias.data *= 0
#         self.linear.weight.data *= 0.1
#         self.cls_weight = nn.Parameter(torch.randn(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)*0.01
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 10.
        self.cls_weight = nn.Parameter(init_val)
        
        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
        self.cls_weight.data = torch.abs(self.cls_weight.data/self.cls_weight.data.sum(dim=1, keepdim=True))
        
        x = self.linear(x)
        if hard:
            x = torch.softmax(x*1e5, dim=1)
        else:
            x = torch.softmax(x*self.inv_temp, dim=1)
        self.cls_confidence = x
#         c = torch.softmax(self.cls_weight, dim=1)
        c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized

In [143]:
# class ConnectedClassifier_SoftKMeans(nn.Module):
    
#     def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
#         super().__init__()
#         self.input_dim = input_dim
#         self.output_dim = output_dim
#         self.num_sets = num_sets
#         self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
#         self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)
        
# #         self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)

#         init_val = torch.randn(num_sets, output_dim)*0.01
#         for ns in range(num_sets):
#             init_val[ns, ns%output_dim] = 10.
#         self.cls_weight = nn.Parameter(init_val)

#         self.cls_confidence = None
        
        
#     def forward(self, x, hard=False):
#         x = x[:, :self.input_dim]
#         dists = torch.cdist(x, self.centers)
#         dists = dists/np.sqrt(self.input_dim) ### correction to make diagonal of unit square 1 in nD space
        
#         if hard:
#             x = torch.softmax(-dists*1e5, dim=1)
#         else:
#             x = torch.softmax(-dists*self.inv_temp, dim=1)
#         self.cls_confidence = x
#         c = torch.softmax(self.cls_weight, dim=1)
# #         c = self.cls_weight
#         return x@c ## since both are normalized, it is also normalized

In [144]:
class ConnectedClassifier_SoftKMeans(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)
        
#         self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)*0.01
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 10.
        self.cls_weight = nn.Parameter(init_val)

        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
        self.cls_weight.data = torch.abs(self.cls_weight.data/self.cls_weight.data.sum(dim=1, keepdim=True))
        
        x = x[:, :self.input_dim]
        dists = torch.cdist(x, self.centers)
        dists = dists/np.sqrt(self.input_dim) ### correction to make diagonal of unit square 1 in nD space
        
        if hard:
            x = torch.softmax(-dists*1e5, dim=1)
        else:
            x = torch.softmax(-dists*self.inv_temp, dim=1)
        self.cls_confidence = x
#         c = torch.softmax(self.cls_weight, dim=1)
        c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized

In [145]:
# actf = irf.Swish
# flows = [
#     ActNorm(784),
#     irf.ResidualFlow(784, [784], activation=actf),
#     ActNorm(784),
#     irf.ResidualFlow(784, [784], activation=actf),
#     ActNorm(784),
#         ]

# model = SequentialFlow(flows)

In [146]:
model = nn.Sequential(nn.Linear(784, 784, bias=False),
                      nn.BatchNorm1d(784),
                      nn.SELU(),
                      nn.Linear(784, 784, bias=False),
                      nn.BatchNorm1d(784),
                      nn.SELU(),
                     )

In [147]:
model.to(device)

Sequential(
  (0): Linear(in_features=784, out_features=784, bias=False)
  (1): BatchNorm1d(784, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): SELU()
  (3): Linear(in_features=784, out_features=784, bias=False)
  (4): BatchNorm1d(784, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): SELU()
)

In [148]:
list(model.flows[0].parameters())

AttributeError: 'Sequential' object has no attribute 'flows'

In [149]:
# classifier = ConnectedClassifier_SoftKMeans(784, 100, 10)
classifier = ConnectedClassifier_Softmax(784, 100, 10)
classifier = classifier.to(device)

## Model Train

In [150]:
learning_rate = 0.0003
batch_size = 50

In [151]:
train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

In [152]:
# criterion = nn.NLLLoss()
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(list(model.parameters())+list(classifier.parameters()),
                       lr=learning_rate, weight_decay=1e-15) # todo tune WD
# optimizer = optim.SGD(model.parameters(), lr=0.1)

print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  1232448


In [153]:
for p in model.parameters():
    print(torch.isnan(p).type(torch.float32).sum())

tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')


In [154]:
model(torch.randn(10, 784).to(device))

tensor([[ 0.9458,  0.9715,  2.2747,  ..., -1.3463,  0.0900,  0.0901],
        [ 1.8275, -1.1361, -1.2798,  ..., -1.4258,  0.0946, -0.9902],
        [-0.6070,  1.7841, -0.7394,  ...,  0.8922,  0.9638, -1.0139],
        ...,
        [ 1.3297, -0.6568, -0.8720,  ...,  1.0454, -1.4354, -0.5707],
        [-1.4546, -1.1686,  0.9964,  ...,  1.6013,  1.6338, -1.0110],
        [-0.5842,  1.3439,  0.1981,  ..., -0.0883, -0.8458,  2.3987]],
       device='cuda:0', grad_fn=<EluBackward0>)

In [155]:
xx = iter(test_loader).next()[0]
xx.shape

torch.Size([50, 784])

In [156]:
losses = []
train_accs = []
test_accs = []
EPOCHS = 50

index = 0
for epoch in range(EPOCHS):
    train_acc = 0
    train_count = 0
    for xx, yy in tqdm(train_loader):
        xx, yy = xx.to(device), yy.to(device)
#     for xx, yy in tqdm(test_loader):

        yout = model(xx)
#         print(yout)
        yout = classifier(yout)    
#         print(yout)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(float(loss))

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)
#         break

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0

    print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    for xx, yy in tqdm(test_loader):
        xx, yy = xx.to(device), yy.to(device)
        with torch.no_grad():
            yout = classifier(model(xx))    
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

100%|██████████| 1200/1200 [00:02<00:00, 412.58it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0:0,  Loss:1.5859712362289429


100%|██████████| 200/200 [00:00<00:00, 596.03it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:81.79%, Test Acc:81.96%



100%|██████████| 1200/1200 [00:02<00:00, 423.91it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1:0,  Loss:1.65738844871521


100%|██████████| 200/200 [00:00<00:00, 630.44it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:84.17%, Test Acc:83.46%



100%|██████████| 1200/1200 [00:02<00:00, 418.46it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 2:0,  Loss:1.7071442604064941


100%|██████████| 200/200 [00:00<00:00, 552.96it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:85.16%, Test Acc:82.76%



100%|██████████| 1200/1200 [00:02<00:00, 417.68it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 3:0,  Loss:1.5982575416564941


100%|██████████| 200/200 [00:00<00:00, 601.69it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:85.52%, Test Acc:84.12%



100%|██████████| 1200/1200 [00:02<00:00, 415.14it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 4:0,  Loss:1.6208800077438354


100%|██████████| 200/200 [00:00<00:00, 619.14it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:85.83%, Test Acc:84.39%



100%|██████████| 1200/1200 [00:02<00:00, 414.63it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 5:0,  Loss:1.630985975265503


100%|██████████| 200/200 [00:00<00:00, 538.57it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.18%, Test Acc:84.08%



100%|██████████| 1200/1200 [00:02<00:00, 419.68it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 6:0,  Loss:1.5958607196807861


100%|██████████| 200/200 [00:00<00:00, 552.79it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.48%, Test Acc:84.42%



100%|██████████| 1200/1200 [00:02<00:00, 413.91it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 7:0,  Loss:1.6782994270324707


100%|██████████| 200/200 [00:00<00:00, 650.57it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.75%, Test Acc:84.78%



100%|██████████| 1200/1200 [00:02<00:00, 421.04it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 8:0,  Loss:1.6436071395874023


100%|██████████| 200/200 [00:00<00:00, 594.52it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.79%, Test Acc:84.75%



100%|██████████| 1200/1200 [00:02<00:00, 420.86it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 9:0,  Loss:1.5888993740081787


100%|██████████| 200/200 [00:00<00:00, 595.60it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.93%, Test Acc:85.17%



100%|██████████| 1200/1200 [00:02<00:00, 421.25it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 10:0,  Loss:1.5864720344543457


100%|██████████| 200/200 [00:00<00:00, 593.15it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.24%, Test Acc:85.15%



100%|██████████| 1200/1200 [00:02<00:00, 414.47it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 11:0,  Loss:1.611083984375


100%|██████████| 200/200 [00:00<00:00, 558.19it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.33%, Test Acc:85.45%



100%|██████████| 1200/1200 [00:02<00:00, 418.30it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 12:0,  Loss:1.634911060333252


100%|██████████| 200/200 [00:00<00:00, 630.88it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.61%, Test Acc:85.45%



100%|██████████| 1200/1200 [00:02<00:00, 415.12it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 13:0,  Loss:1.5660384893417358


100%|██████████| 200/200 [00:00<00:00, 513.79it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.47%, Test Acc:85.79%



100%|██████████| 1200/1200 [00:02<00:00, 419.94it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 14:0,  Loss:1.603417158126831


100%|██████████| 200/200 [00:00<00:00, 581.75it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.71%, Test Acc:86.00%



100%|██████████| 1200/1200 [00:02<00:00, 414.95it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 15:0,  Loss:1.547838568687439


100%|██████████| 200/200 [00:00<00:00, 545.05it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.94%, Test Acc:85.87%



100%|██████████| 1200/1200 [00:02<00:00, 423.87it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 16:0,  Loss:1.5640358924865723


100%|██████████| 200/200 [00:00<00:00, 598.73it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.83%, Test Acc:85.69%



100%|██████████| 1200/1200 [00:02<00:00, 424.58it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 17:0,  Loss:1.57367742061615


100%|██████████| 200/200 [00:00<00:00, 604.00it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.02%, Test Acc:86.06%



100%|██████████| 1200/1200 [00:02<00:00, 418.01it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 18:0,  Loss:1.546586275100708


100%|██████████| 200/200 [00:00<00:00, 561.15it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.27%, Test Acc:85.52%



100%|██████████| 1200/1200 [00:02<00:00, 421.09it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 19:0,  Loss:1.6344590187072754


100%|██████████| 200/200 [00:00<00:00, 625.81it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.38%, Test Acc:85.33%



100%|██████████| 1200/1200 [00:02<00:00, 418.05it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 20:0,  Loss:1.585923671722412


100%|██████████| 200/200 [00:00<00:00, 605.65it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.32%, Test Acc:86.04%



100%|██████████| 1200/1200 [00:02<00:00, 414.51it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 21:0,  Loss:1.6701271533966064


100%|██████████| 200/200 [00:00<00:00, 571.57it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.50%, Test Acc:85.74%



100%|██████████| 1200/1200 [00:02<00:00, 406.96it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 22:0,  Loss:1.6410032510757446


100%|██████████| 200/200 [00:00<00:00, 607.81it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.39%, Test Acc:86.28%



100%|██████████| 1200/1200 [00:02<00:00, 417.39it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 23:0,  Loss:1.562125563621521


100%|██████████| 200/200 [00:00<00:00, 564.76it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.67%, Test Acc:85.97%



100%|██████████| 1200/1200 [00:02<00:00, 416.36it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 24:0,  Loss:1.5440185070037842


100%|██████████| 200/200 [00:00<00:00, 563.80it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.61%, Test Acc:86.19%



100%|██████████| 1200/1200 [00:02<00:00, 418.75it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 25:0,  Loss:1.6632070541381836


100%|██████████| 200/200 [00:00<00:00, 621.00it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.02%, Test Acc:85.57%



100%|██████████| 1200/1200 [00:02<00:00, 417.82it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 26:0,  Loss:1.5481547117233276


100%|██████████| 200/200 [00:00<00:00, 604.74it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.92%, Test Acc:86.02%



100%|██████████| 1200/1200 [00:02<00:00, 415.42it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 27:0,  Loss:1.5643914937973022


100%|██████████| 200/200 [00:00<00:00, 628.15it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.06%, Test Acc:85.98%



100%|██████████| 1200/1200 [00:02<00:00, 415.68it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 28:0,  Loss:1.5512384176254272


100%|██████████| 200/200 [00:00<00:00, 578.80it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.83%, Test Acc:86.58%



100%|██████████| 1200/1200 [00:02<00:00, 416.21it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 29:0,  Loss:1.5781817436218262


100%|██████████| 200/200 [00:00<00:00, 576.77it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.90%, Test Acc:86.31%



100%|██████████| 1200/1200 [00:02<00:00, 413.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 30:0,  Loss:1.5602346658706665


100%|██████████| 200/200 [00:00<00:00, 619.27it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.10%, Test Acc:85.72%



100%|██████████| 1200/1200 [00:02<00:00, 417.13it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 31:0,  Loss:1.584736943244934


100%|██████████| 200/200 [00:00<00:00, 633.64it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.15%, Test Acc:86.11%



100%|██████████| 1200/1200 [00:02<00:00, 412.05it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 32:0,  Loss:1.5752341747283936


100%|██████████| 200/200 [00:00<00:00, 594.28it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.04%, Test Acc:86.40%



100%|██████████| 1200/1200 [00:02<00:00, 412.92it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 33:0,  Loss:1.5528117418289185


100%|██████████| 200/200 [00:00<00:00, 521.76it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.08%, Test Acc:86.00%



100%|██████████| 1200/1200 [00:02<00:00, 419.44it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 34:0,  Loss:1.573508858680725


100%|██████████| 200/200 [00:00<00:00, 576.49it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.31%, Test Acc:86.98%



100%|██████████| 1200/1200 [00:02<00:00, 415.82it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 35:0,  Loss:1.5635652542114258


100%|██████████| 200/200 [00:00<00:00, 570.88it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.33%, Test Acc:85.69%



100%|██████████| 1200/1200 [00:02<00:00, 416.34it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 36:0,  Loss:1.6178604364395142


100%|██████████| 200/200 [00:00<00:00, 580.73it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.27%, Test Acc:85.92%



100%|██████████| 1200/1200 [00:02<00:00, 414.74it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 37:0,  Loss:1.666319727897644


100%|██████████| 200/200 [00:00<00:00, 562.52it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.43%, Test Acc:85.64%



100%|██████████| 1200/1200 [00:02<00:00, 414.47it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 38:0,  Loss:1.6217929124832153


100%|██████████| 200/200 [00:00<00:00, 596.56it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.22%, Test Acc:86.50%



100%|██████████| 1200/1200 [00:02<00:00, 422.17it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 39:0,  Loss:1.6287568807601929


100%|██████████| 200/200 [00:00<00:00, 573.66it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.34%, Test Acc:86.90%



100%|██████████| 1200/1200 [00:02<00:00, 413.14it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 40:0,  Loss:1.5452221632003784


100%|██████████| 200/200 [00:00<00:00, 630.18it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.63%, Test Acc:86.51%



100%|██████████| 1200/1200 [00:02<00:00, 415.87it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 41:0,  Loss:1.5304566621780396


100%|██████████| 200/200 [00:00<00:00, 632.33it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.47%, Test Acc:86.59%



100%|██████████| 1200/1200 [00:02<00:00, 413.78it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 42:0,  Loss:1.5750335454940796


100%|██████████| 200/200 [00:00<00:00, 580.88it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.55%, Test Acc:86.34%



100%|██████████| 1200/1200 [00:02<00:00, 416.04it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 43:0,  Loss:1.5068165063858032


100%|██████████| 200/200 [00:00<00:00, 601.64it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.41%, Test Acc:86.32%



 13%|█▎        | 158/1200 [00:00<00:03, 329.26it/s]


KeyboardInterrupt: 

In [None]:
# classifier.cls_weight

### Hard test accuracy with count per classifier

In [157]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(test_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Test Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 200/200 [00:00<00:00, 489.74it/s]

Hard Test Acc:86.32%
[0, 0, 0, 0, 0, 0, 0, 0, 499, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 1023, 0, 0, 0, 0, 939, 1047, 0, 0, 0, 0, 0, 982, 0, 0, 15, 0, 940, 0, 0, 0, 971, 0, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 102, 0, 0, 0, 0, 0, 930, 0, 0, 0, 0, 0, 497, 0, 0, 0, 1, 1095, 0, 0, 0, 947, 0, 0]





### Hard train accuracy with count per classifier

In [158]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 1200/1200 [00:01<00:00, 776.87it/s]

Hard Train Acc:89.58%
[0, 0, 0, 7, 0, 0, 0, 0, 3026, 0, 0, 0, 1, 0, 3, 1, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 0, 0, 1, 0, 0, 0, 6094, 0, 0, 2, 0, 5836, 6313, 0, 0, 0, 0, 0, 5920, 0, 0, 104, 0, 5485, 0, 0, 0, 5945, 0, 25, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 669, 0, 0, 0, 0, 0, 5514, 1, 0, 0, 0, 0, 2952, 0, 0, 0, 4, 6413, 0, 0, 0, 5640, 0, 0]





In [159]:
#### Classifiers that enclose any data
torch.count_nonzero(set_count)

tensor(25, device='cuda:0')

In [160]:
#### classifier with class representation
torch.argmax(classifier.cls_weight, dim=1)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
        4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
        8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
        2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
        6, 7, 8, 9], device='cuda:0')

In [161]:
# The class labels are same as that of initialized
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
#         4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
#         8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
#         2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
#         6, 7, 8, 9], device='cuda:0')

In [162]:
classifier.cls_weight

Parameter containing:
tensor([[9.4619e-01, 4.4570e-14, 1.5639e-02, 1.1144e-02, 6.0348e-03, 2.1456e-03,
         1.6702e-02, 5.9290e-09, 2.1414e-03, 4.5562e-14],
        [8.6018e-05, 9.5355e-01, 1.1300e-02, 6.9282e-03, 6.2897e-03, 7.4113e-03,
         1.3287e-02, 1.9646e-09, 1.1510e-03, 1.5429e-09],
        [7.2009e-09, 7.4185e-10, 9.2392e-01, 1.2708e-03, 4.6695e-03, 1.3651e-02,
         4.5858e-02, 1.2355e-03, 9.3988e-03, 7.0837e-10],
        [4.5053e-02, 5.3718e-02, 2.1376e-04, 8.7840e-01, 2.3060e-04, 3.7111e-03,
         1.8112e-02, 1.5335e-04, 2.1481e-04, 1.9436e-04],
        [1.8616e-03, 1.2641e-14, 2.8892e-02, 1.0584e-02, 9.3146e-01, 1.2568e-14,
         2.6190e-02, 3.5428e-09, 1.0090e-03, 1.1771e-14],
        [1.1155e-09, 4.4267e-16, 3.9969e-03, 1.0296e-06, 1.7499e-10, 9.8900e-01,
         7.0006e-03, 2.6296e-16, 2.1079e-12, 4.4521e-16],
        [3.2145e-03, 6.9887e-14, 2.6743e-02, 3.9855e-03, 7.9626e-03, 7.0100e-03,
         9.5065e-01, 4.5230e-09, 4.3349e-04, 7.0137e-14],
     

In [163]:
# torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True)

In [164]:
classifier.inv_temp

Parameter containing:
tensor([2.2525], device='cuda:0', requires_grad=True)

In [165]:
### example output per classifier
yout[5]

tensor([2.0711e-04, 9.9817e-01, 2.0902e-04, 1.5644e-04, 2.0877e-04, 2.0955e-04,
        2.0827e-04, 2.1023e-04, 2.0954e-04, 2.0947e-04], device='cuda:0')

In [166]:
asdfsdf ## to break the code

NameError: name 'asdfsdf' is not defined

### analyze per classifier accuracy

In [None]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
set_acc = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        
    cls_indx = torch.argmax(classifier.cls_confidence, dim=1)
    set_indx, count = torch.unique(cls_indx, return_counts=True) 
    set_count[set_indx] += count
    
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float)
    
    ### class_index has 100 possible values
    for i, c in enumerate(correct):
        set_acc[cls_indx[i]] += c
    
#     print(set_acc.sum(), set_count.sum())
#     break
    test_acc += correct.sum()
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

In [None]:
set_acc/set_count

In [None]:
for i, (cnt, acc, cls) in enumerate(zip(set_count.type(torch.long).tolist(),
                                   (set_acc/set_count).tolist(),
                                   torch.argmax(classifier.cls_weight, dim=1).tolist())):
    if cnt == 0: continue
    print(f"{i},\t {cnt},\t {cls}\t {acc*100:.2f}%")