In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data

import random, os, pathlib, time
from tqdm import tqdm
from sklearn import datasets

In [2]:
import nflib
from nflib.flows import SequentialFlow, NormalizingFlow, ActNorm, AffineConstantFlow
import nflib.coupling_flows as icf
import nflib.inn_flow as inn
import nflib.res_flow as irf

In [3]:
from torch import distributions
from torch.distributions import MultivariateNormal

In [4]:
device = torch.device("cuda:0")
# device = torch.device("cpu")

## MNIST dataset

In [5]:
import mylibrary.datasets as datasets
# import mylibrary.nnlib as tnn

In [6]:
mnist = datasets.FashionMNIST()
# mnist.download_mnist()
# mnist.save_mnist()
train_data, train_label_, test_data, test_label_ = mnist.load()

train_data = train_data / 255.
test_data = test_data / 255.

# train_label = tnn.Logits.index_to_logit(train_label_)
train_size = len(train_label_)

In [7]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [8]:
input_size = 784
output_size = 10

In [9]:
class MNIST_Dataset(data.Dataset):
    
    def __init__(self, data, label):
        self.data = data
        self.label = label
        
#         self.label = mask.type(torch.float32).reshape(-1,1)
        self._shuffle_data_()
        
    def __len__(self):
        return len(self.data)
    
    def _shuffle_data_(self):
        randidx = random.sample(range(len(self.data)), k=len(self.data))
        self.data = self.data[randidx]
        self.label = self.label[randidx]
    
    def __getitem__(self, idx):
        img, lbl = self.data[idx], self.label[idx]
        return img, lbl

In [10]:
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [80]:
class ConnectedClassifier_Softmax(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.linear = nn.Linear(input_dim, num_sets)
#         self.linear.bias.data *= 0
#         self.linear.weight.data *= 0.1
#         self.cls_weight = nn.Parameter(torch.randn(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)*0.01
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 10.
        self.cls_weight = nn.Parameter(init_val)
        
        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
#         self.cls_weight.data = torch.abs(self.cls_weight.data/self.cls_weight.data.sum(dim=1, keepdim=True))
        
        x = self.linear(x)
        if hard:
            x = torch.softmax(x*1e5, dim=1)
        else:
            x = torch.softmax(x*self.inv_temp, dim=1)
        self.cls_confidence = x
        c = torch.softmax(self.cls_weight, dim=1)
#         c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized

In [81]:
# class ConnectedClassifier_SoftKMeans(nn.Module):
    
#     def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
#         super().__init__()
#         self.input_dim = input_dim
#         self.output_dim = output_dim
#         self.num_sets = num_sets
#         self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
#         self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)
        
# #         self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)

#         init_val = torch.randn(num_sets, output_dim)*0.01
#         for ns in range(num_sets):
#             init_val[ns, ns%output_dim] = 10.
#         self.cls_weight = nn.Parameter(init_val)

#         self.cls_confidence = None
        
        
#     def forward(self, x, hard=False):
#         x = x[:, :self.input_dim]
#         dists = torch.cdist(x, self.centers)
#         dists = dists/np.sqrt(self.input_dim) ### correction to make diagonal of unit square 1 in nD space
        
#         if hard:
#             x = torch.softmax(-dists*1e5, dim=1)
#         else:
#             x = torch.softmax(-dists*self.inv_temp, dim=1)
#         self.cls_confidence = x
#         c = torch.softmax(self.cls_weight, dim=1)
# #         c = self.cls_weight
#         return x@c ## since both are normalized, it is also normalized

In [82]:
class ConnectedClassifier_SoftKMeans(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)
        
#         self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)*0.01
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 10.
        self.cls_weight = nn.Parameter(init_val)

        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
#         self.cls_weight.data = torch.abs(self.cls_weight.data/self.cls_weight.data.sum(dim=1, keepdim=True))
        
        x = x[:, :self.input_dim]
        dists = torch.cdist(x, self.centers)
        dists = dists/np.sqrt(self.input_dim) ### correction to make diagonal of unit square 1 in nD space
        
        if hard:
            x = torch.softmax(-dists*1e5, dim=1)
        else:
            x = torch.softmax(-dists*self.inv_temp, dim=1)
        self.cls_confidence = x
        c = torch.softmax(self.cls_weight, dim=1)
#         c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized

In [83]:
# actf = irf.Swish
# flows = [
#     ActNorm(784),
#     irf.ResidualFlow(784, [784], activation=actf),
#     ActNorm(784),
#     irf.ResidualFlow(784, [784], activation=actf),
#     ActNorm(784),
#         ]

# model = SequentialFlow(flows)

In [115]:
model = nn.Sequential(nn.Linear(784, 784, bias=False),
                      nn.BatchNorm1d(784),
                      nn.SELU(),
                      nn.Linear(784, 784, bias=False),
                      nn.BatchNorm1d(784),
                      nn.SELU(),
                     )

In [116]:
model.to(device)

Sequential(
  (0): Linear(in_features=784, out_features=784, bias=False)
  (1): BatchNorm1d(784, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): SELU()
  (3): Linear(in_features=784, out_features=784, bias=False)
  (4): BatchNorm1d(784, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): SELU()
)

In [117]:
list(model.flows[0].parameters())

AttributeError: 'Sequential' object has no attribute 'flows'

In [118]:
classifier = ConnectedClassifier_SoftKMeans(784, 100, 10)
# classifier = ConnectedClassifier_Softmax(784, 100, 10)
classifier = classifier.to(device)

## Model Train

In [119]:
learning_rate = 0.0003
batch_size = 50

In [120]:
train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

In [121]:
# criterion = nn.NLLLoss()
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(list(model.parameters())+list(classifier.parameters()),
                       lr=learning_rate, weight_decay=1e-15) # todo tune WD
# optimizer = optim.SGD(model.parameters(), lr=0.1)

print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  1232448


In [122]:
for p in model.parameters():
    print(torch.isnan(p).type(torch.float32).sum())

tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')


In [123]:
model(torch.randn(10, 784).to(device))

tensor([[-1.4210,  0.2018, -1.2675,  ..., -1.2747, -0.7042, -0.9518],
        [-0.5687,  0.4812, -0.9916,  ..., -0.7586, -0.0931, -1.0872],
        [ 0.6926, -0.7393,  0.0140,  ...,  1.2545,  0.7303, -0.6991],
        ...,
        [ 0.3149,  1.7224, -0.9930,  ...,  2.4332,  1.4887,  1.5481],
        [-0.8335, -0.2494, -1.1215,  ..., -0.2140, -0.9818, -0.8321],
        [-0.9755,  0.7651,  0.7195,  ..., -1.0379, -1.1725, -1.2073]],
       device='cuda:0', grad_fn=<EluBackward0>)

In [124]:
xx = iter(test_loader).next()[0]
xx.shape

torch.Size([50, 784])

In [125]:
losses = []
train_accs = []
test_accs = []
EPOCHS = 50

index = 0
for epoch in range(EPOCHS):
    train_acc = 0
    train_count = 0
    for xx, yy in tqdm(train_loader):
        xx, yy = xx.to(device), yy.to(device)
#     for xx, yy in tqdm(test_loader):

        yout = model(xx)
#         print(yout)
        yout = classifier(yout)    
#         print(yout)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(float(loss))

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)
#         break

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0

    print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    for xx, yy in tqdm(test_loader):
        xx, yy = xx.to(device), yy.to(device)
        with torch.no_grad():
            yout = classifier(model(xx))    
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

100%|██████████| 1200/1200 [00:03<00:00, 391.66it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0:0,  Loss:2.217250108718872


100%|██████████| 200/200 [00:00<00:00, 559.66it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:81.55%, Test Acc:82.90%



100%|██████████| 1200/1200 [00:03<00:00, 398.56it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1:0,  Loss:1.9976089000701904


100%|██████████| 200/200 [00:00<00:00, 568.69it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:84.91%, Test Acc:84.49%



100%|██████████| 1200/1200 [00:03<00:00, 399.94it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 2:0,  Loss:1.7570583820343018


100%|██████████| 200/200 [00:00<00:00, 569.54it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.35%, Test Acc:85.37%



100%|██████████| 1200/1200 [00:02<00:00, 400.41it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 3:0,  Loss:1.693103313446045


100%|██████████| 200/200 [00:00<00:00, 590.52it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.37%, Test Acc:86.59%



100%|██████████| 1200/1200 [00:03<00:00, 395.36it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 4:0,  Loss:1.649895191192627


100%|██████████| 200/200 [00:00<00:00, 559.84it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.16%, Test Acc:86.67%



100%|██████████| 1200/1200 [00:02<00:00, 403.36it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 5:0,  Loss:1.5554869174957275


100%|██████████| 200/200 [00:00<00:00, 547.70it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.74%, Test Acc:87.18%



100%|██████████| 1200/1200 [00:03<00:00, 399.71it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 6:0,  Loss:1.5285190343856812


100%|██████████| 200/200 [00:00<00:00, 567.80it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.19%, Test Acc:87.14%



100%|██████████| 1200/1200 [00:02<00:00, 400.79it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 7:0,  Loss:1.586421012878418


100%|██████████| 200/200 [00:00<00:00, 563.06it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.48%, Test Acc:86.60%



100%|██████████| 1200/1200 [00:02<00:00, 405.45it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 8:0,  Loss:1.5765680074691772


100%|██████████| 200/200 [00:00<00:00, 580.01it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.71%, Test Acc:87.32%



100%|██████████| 1200/1200 [00:02<00:00, 400.60it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 9:0,  Loss:1.5783597230911255


100%|██████████| 200/200 [00:00<00:00, 564.60it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.90%, Test Acc:87.16%



100%|██████████| 1200/1200 [00:03<00:00, 398.83it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 10:0,  Loss:1.6235712766647339


100%|██████████| 200/200 [00:00<00:00, 565.36it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.16%, Test Acc:87.64%



100%|██████████| 1200/1200 [00:02<00:00, 403.75it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 11:0,  Loss:1.5906027555465698


100%|██████████| 200/200 [00:00<00:00, 600.54it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.53%, Test Acc:87.90%



100%|██████████| 1200/1200 [00:03<00:00, 398.89it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 12:0,  Loss:1.5108877420425415


100%|██████████| 200/200 [00:00<00:00, 566.38it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.71%, Test Acc:87.81%



100%|██████████| 1200/1200 [00:03<00:00, 398.64it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 13:0,  Loss:1.5244624614715576


100%|██████████| 200/200 [00:00<00:00, 571.43it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.72%, Test Acc:87.51%



100%|██████████| 1200/1200 [00:02<00:00, 403.96it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 14:0,  Loss:1.5436559915542603


100%|██████████| 200/200 [00:00<00:00, 572.53it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.91%, Test Acc:87.65%



100%|██████████| 1200/1200 [00:03<00:00, 397.22it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 15:0,  Loss:1.5871174335479736


100%|██████████| 200/200 [00:00<00:00, 573.17it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.12%, Test Acc:87.95%



100%|██████████| 1200/1200 [00:03<00:00, 395.43it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 16:0,  Loss:1.5098756551742554


100%|██████████| 200/200 [00:00<00:00, 568.82it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.16%, Test Acc:87.41%



100%|██████████| 1200/1200 [00:03<00:00, 395.84it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 17:0,  Loss:1.5842382907867432


100%|██████████| 200/200 [00:00<00:00, 576.73it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.47%, Test Acc:87.56%



100%|██████████| 1200/1200 [00:02<00:00, 403.41it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 18:0,  Loss:1.599894404411316


100%|██████████| 200/200 [00:00<00:00, 549.09it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.45%, Test Acc:88.02%



100%|██████████| 1200/1200 [00:03<00:00, 392.72it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 19:0,  Loss:1.5045396089553833


100%|██████████| 200/200 [00:00<00:00, 597.13it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.65%, Test Acc:88.15%



100%|██████████| 1200/1200 [00:03<00:00, 396.10it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 20:0,  Loss:1.5181281566619873


100%|██████████| 200/200 [00:00<00:00, 582.24it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.64%, Test Acc:88.23%



100%|██████████| 1200/1200 [00:03<00:00, 396.07it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 21:0,  Loss:1.494056224822998


100%|██████████| 200/200 [00:00<00:00, 573.26it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.96%, Test Acc:88.31%



100%|██████████| 1200/1200 [00:02<00:00, 401.55it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 22:0,  Loss:1.560639500617981


100%|██████████| 200/200 [00:00<00:00, 552.50it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.89%, Test Acc:87.81%



100%|██████████| 1200/1200 [00:02<00:00, 401.33it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 23:0,  Loss:1.4787487983703613


100%|██████████| 200/200 [00:00<00:00, 589.75it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.05%, Test Acc:88.47%



100%|██████████| 1200/1200 [00:03<00:00, 399.10it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 24:0,  Loss:1.5368199348449707


100%|██████████| 200/200 [00:00<00:00, 566.70it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.14%, Test Acc:88.41%



100%|██████████| 1200/1200 [00:02<00:00, 401.43it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 25:0,  Loss:1.540602207183838


100%|██████████| 200/200 [00:00<00:00, 574.22it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.33%, Test Acc:88.10%



100%|██████████| 1200/1200 [00:02<00:00, 412.31it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 26:0,  Loss:1.5456085205078125


100%|██████████| 200/200 [00:00<00:00, 581.84it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.31%, Test Acc:88.21%



100%|██████████| 1200/1200 [00:02<00:00, 406.35it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 27:0,  Loss:1.5572260618209839


100%|██████████| 200/200 [00:00<00:00, 585.65it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.45%, Test Acc:87.83%



100%|██████████| 1200/1200 [00:02<00:00, 401.59it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 28:0,  Loss:1.521071434020996


100%|██████████| 200/200 [00:00<00:00, 583.68it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.43%, Test Acc:88.38%



100%|██████████| 1200/1200 [00:02<00:00, 400.66it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 29:0,  Loss:1.5693799257278442


100%|██████████| 200/200 [00:00<00:00, 569.24it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.57%, Test Acc:87.80%



100%|██████████| 1200/1200 [00:02<00:00, 403.09it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 30:0,  Loss:1.480139970779419


100%|██████████| 200/200 [00:00<00:00, 563.38it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.71%, Test Acc:88.62%



100%|██████████| 1200/1200 [00:02<00:00, 400.02it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 31:0,  Loss:1.4967584609985352


100%|██████████| 200/200 [00:00<00:00, 574.66it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.64%, Test Acc:88.24%



100%|██████████| 1200/1200 [00:03<00:00, 396.82it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 32:0,  Loss:1.4894963502883911


100%|██████████| 200/200 [00:00<00:00, 540.88it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.80%, Test Acc:88.64%



100%|██████████| 1200/1200 [00:02<00:00, 402.02it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 33:0,  Loss:1.575962781906128


100%|██████████| 200/200 [00:00<00:00, 563.31it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.88%, Test Acc:88.53%



100%|██████████| 1200/1200 [00:03<00:00, 392.88it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 34:0,  Loss:1.6194770336151123


100%|██████████| 200/200 [00:00<00:00, 571.69it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.94%, Test Acc:88.63%



100%|██████████| 1200/1200 [00:02<00:00, 408.35it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 35:0,  Loss:1.5123810768127441


100%|██████████| 200/200 [00:00<00:00, 552.51it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.08%, Test Acc:88.62%



100%|██████████| 1200/1200 [00:02<00:00, 400.69it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 36:0,  Loss:1.5655328035354614


100%|██████████| 200/200 [00:00<00:00, 576.41it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.13%, Test Acc:88.49%



100%|██████████| 1200/1200 [00:03<00:00, 389.35it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 37:0,  Loss:1.4848432540893555


100%|██████████| 200/200 [00:00<00:00, 563.89it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.22%, Test Acc:88.56%



100%|██████████| 1200/1200 [00:03<00:00, 393.95it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 38:0,  Loss:1.50151789188385


100%|██████████| 200/200 [00:00<00:00, 582.14it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.25%, Test Acc:88.74%



100%|██████████| 1200/1200 [00:03<00:00, 397.07it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 39:0,  Loss:1.5363467931747437


100%|██████████| 200/200 [00:00<00:00, 570.44it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.27%, Test Acc:88.76%



100%|██████████| 1200/1200 [00:02<00:00, 403.69it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 40:0,  Loss:1.5949763059616089


100%|██████████| 200/200 [00:00<00:00, 545.70it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.42%, Test Acc:88.61%



100%|██████████| 1200/1200 [00:03<00:00, 395.56it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 41:0,  Loss:1.5739184617996216


100%|██████████| 200/200 [00:00<00:00, 574.47it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.38%, Test Acc:88.79%



100%|██████████| 1200/1200 [00:03<00:00, 398.20it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 42:0,  Loss:1.6237308979034424


100%|██████████| 200/200 [00:00<00:00, 587.96it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.51%, Test Acc:88.97%



100%|██████████| 1200/1200 [00:02<00:00, 402.82it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 43:0,  Loss:1.5362839698791504


100%|██████████| 200/200 [00:00<00:00, 585.60it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.63%, Test Acc:88.42%



100%|██████████| 1200/1200 [00:03<00:00, 393.13it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 44:0,  Loss:1.6309760808944702


100%|██████████| 200/200 [00:00<00:00, 573.09it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.50%, Test Acc:88.92%



100%|██████████| 1200/1200 [00:03<00:00, 395.95it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 45:0,  Loss:1.5014489889144897


100%|██████████| 200/200 [00:00<00:00, 583.95it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.59%, Test Acc:88.94%



100%|██████████| 1200/1200 [00:03<00:00, 397.26it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 46:0,  Loss:1.504794955253601


100%|██████████| 200/200 [00:00<00:00, 589.43it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.81%, Test Acc:88.77%



100%|██████████| 1200/1200 [00:03<00:00, 399.04it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 47:0,  Loss:1.508610486984253


100%|██████████| 200/200 [00:00<00:00, 575.09it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.72%, Test Acc:88.27%



100%|██████████| 1200/1200 [00:02<00:00, 404.70it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 48:0,  Loss:1.5181329250335693


100%|██████████| 200/200 [00:00<00:00, 519.42it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.80%, Test Acc:88.83%



100%|██████████| 1200/1200 [00:03<00:00, 399.39it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 49:0,  Loss:1.504277229309082


100%|██████████| 200/200 [00:00<00:00, 572.70it/s]

Train Acc:93.92%, Test Acc:88.84%

	-> Train Acc 93.91833333333334 ; Test Acc 88.97





In [None]:
## LinearSoftmax	-> Train Acc 90.44 ; Test Acc 87.58
## DistanceSoftmax	-> Train Acc 93.91833333333334 ; Test Acc 88.97

In [126]:
torch.softmax(classifier.cls_weight, dim=1)[0].sum()

tensor(1.0000, device='cuda:0', grad_fn=<SumBackward0>)

In [76]:
# classifier.cls_weight[0].sum()

tensor(-4.4816, device='cuda:0', grad_fn=<SumBackward0>)

### Hard test accuracy with count per classifier

In [127]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(test_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Test Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 200/200 [00:00<00:00, 552.34it/s]

Hard Test Acc:88.87%
[203, 2, 3, 75, 291, 79, 42, 803, 29, 164, 400, 579, 40, 124, 55, 54, 0, 0, 217, 0, 23, 0, 137, 1, 6, 20, 188, 0, 226, 0, 1, 0, 18, 50, 265, 12, 412, 3, 2, 2, 282, 102, 192, 427, 0, 11, 109, 4, 184, 0, 20, 225, 5, 23, 68, 464, 113, 0, 59, 20, 42, 0, 8, 276, 0, 310, 4, 107, 22, 609, 2, 1, 169, 54, 0, 3, 39, 7, 14, 3, 49, 63, 534, 0, 9, 33, 10, 4, 98, 214, 1, 0, 21, 17, 192, 4, 26, 70, 148, 3]





### Hard train accuracy with count per classifier

In [128]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 1200/1200 [00:01<00:00, 745.66it/s]

Hard Train Acc:94.22%
[1386, 9, 20, 449, 2100, 542, 218, 4756, 208, 946, 2124, 3747, 175, 742, 316, 354, 3, 0, 1400, 1, 180, 0, 788, 3, 23, 165, 1100, 0, 1282, 8, 3, 0, 68, 258, 1600, 64, 2619, 12, 4, 11, 1747, 622, 1135, 2568, 0, 66, 693, 48, 1088, 0, 164, 1278, 29, 127, 413, 2754, 719, 0, 382, 135, 220, 0, 27, 1553, 1, 1837, 28, 666, 110, 3660, 16, 5, 872, 364, 1, 31, 215, 43, 101, 34, 293, 273, 3185, 0, 62, 145, 43, 14, 565, 1227, 3, 0, 107, 139, 1044, 37, 115, 433, 859, 20]





In [129]:
#### Classifiers that enclose any data
torch.count_nonzero(set_count)

tensor(90, device='cuda:0')

In [130]:
#### classifier with class representation
torch.argmax(classifier.cls_weight, dim=1)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
        4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
        8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
        2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
        6, 7, 8, 9], device='cuda:0')

In [131]:
# The class labels are same as that of initialized
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
#         4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
#         8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
#         2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
#         6, 7, 8, 9], device='cuda:0')

In [132]:
classifier.cls_weight

Parameter containing:
tensor([[15.4343, -2.5752, -2.5682, -2.5366, -2.5784, -2.5840, -2.3904, -2.5766,
         -2.5715, -2.5801],
        [-2.6027, 15.3726, -2.6042, -2.5841, -2.5935, -2.6008, -2.6042, -2.6043,
         -2.6027, -2.5945],
        [-2.4773, -2.4911, 15.4963, -2.4824, -2.3633, -2.4851, -2.3626, -2.5023,
         -2.4834, -2.5012],
        [-2.5534, -2.5674, -2.5833, 15.5599, -2.5612, -2.5921, -2.5602, -2.5967,
         -2.5961, -2.5909],
        [-2.5420, -2.5470, -2.3712, -2.5110, 15.5261, -2.5400, -2.4244, -2.5571,
         -2.5331, -2.5470],
        [-2.5564, -2.5525, -2.5579, -2.5592, -2.5588, 15.4203, -2.5596, -2.5064,
         -2.5450, -2.5383],
        [-2.3042, -2.4408, -2.2975, -2.4075, -2.2925, -2.4537, 15.5461, -2.4535,
         -2.4328, -2.4592],
        [-2.6714, -2.6751, -2.6732, -2.6675, -2.6710, -2.6144, -2.6725, 15.5935,
         -2.6630, -2.6189],
        [-2.6257, -2.6414, -2.6287, -2.6392, -2.6213, -2.6197, -2.6252, -2.6359,
         15.4365, -2.6230

In [133]:
# torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True)

In [134]:
classifier.inv_temp

Parameter containing:
tensor([12.8763], device='cuda:0', requires_grad=True)

In [135]:
### example output per classifier
yout[5]

tensor([1.4948e-08, 1.4584e-08, 1.4847e-08, 1.4654e-08, 1.4749e-08, 1.4790e-08,
        1.4930e-08, 1.4765e-08, 1.0000e+00, 1.4702e-08], device='cuda:0')

In [None]:
asdfsdf ## to break the code

### analyze per classifier accuracy

In [136]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
set_acc = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        
    cls_indx = torch.argmax(classifier.cls_confidence, dim=1)
    set_indx, count = torch.unique(cls_indx, return_counts=True) 
    set_count[set_indx] += count
    
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float)
    
    ### class_index has 100 possible values
    for i, c in enumerate(correct):
        set_acc[cls_indx[i]] += c
    
#     print(set_acc.sum(), set_count.sum())
#     break
    test_acc += correct.sum()
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 1200/1200 [00:03<00:00, 338.48it/s]

Hard Train Acc:94.20%
[1423, 8, 24, 431, 2139, 560, 200, 4761, 222, 964, 2139, 3743, 164, 776, 304, 336, 4, 0, 1404, 4, 189, 0, 791, 0, 23, 167, 1091, 0, 1267, 11, 1, 0, 64, 286, 1590, 66, 2641, 17, 6, 8, 1686, 617, 1095, 2522, 0, 67, 696, 52, 1089, 0, 138, 1283, 15, 129, 383, 2726, 714, 2, 413, 129, 233, 0, 34, 1555, 0, 1839, 22, 674, 98, 3696, 17, 6, 858, 367, 3, 28, 222, 37, 101, 39, 299, 290, 3252, 0, 65, 145, 50, 15, 574, 1167, 3, 0, 126, 137, 1043, 45, 103, 431, 830, 16]





In [137]:
set_acc/set_count

tensor([0.9452, 0.8750, 0.9583, 0.9211, 0.9158, 0.9982, 0.8650, 0.9853, 0.9775,
        0.9948, 0.8672, 0.9989, 0.8293, 0.8235, 0.9309, 0.9940, 0.7500,    nan,
        0.9957, 1.0000, 0.9524,    nan, 0.9039,    nan, 0.9565, 1.0000, 0.7562,
           nan, 0.9897, 0.9091, 1.0000,    nan, 0.7031, 0.9126, 0.9220, 0.9848,
        0.9023, 0.9412, 1.0000, 1.0000, 0.9419, 0.9984, 0.9461, 0.9623,    nan,
        1.0000, 0.9210, 1.0000, 0.9917,    nan, 0.9203, 0.9977, 0.8000, 0.8527,
        0.9399, 0.9938, 0.8151, 1.0000, 0.9831, 0.9922, 0.9185,    nan, 0.5000,
        0.9350,    nan, 0.9956, 0.9091, 0.9614, 1.0000, 0.9811, 0.9412, 1.0000,
        0.7401, 0.9619, 0.3333, 1.0000, 0.9414, 1.0000, 0.9307, 1.0000, 0.8060,
        0.9966, 0.8598,    nan, 0.9692, 1.0000, 0.7800, 0.9333, 1.0000, 0.9709,
        1.0000,    nan, 0.7302, 0.9562, 0.8897, 1.0000, 0.8155, 0.9536, 0.9892,
        1.0000], device='cuda:0')

In [138]:
for i, (cnt, acc, cls) in enumerate(zip(set_count.type(torch.long).tolist(),
                                   (set_acc/set_count).tolist(),
                                   torch.argmax(classifier.cls_weight, dim=1).tolist())):
    if cnt == 0: continue
    print(f"{i},\t {cnt},\t {cls}\t {acc*100:.2f}%")

0,	 1423,	 0	 94.52%
1,	 8,	 1	 87.50%
2,	 24,	 2	 95.83%
3,	 431,	 3	 92.11%
4,	 2139,	 4	 91.58%
5,	 560,	 5	 99.82%
6,	 200,	 6	 86.50%
7,	 4761,	 7	 98.53%
8,	 222,	 8	 97.75%
9,	 964,	 9	 99.48%
10,	 2139,	 0	 86.72%
11,	 3743,	 1	 99.89%
12,	 164,	 2	 82.93%
13,	 776,	 3	 82.35%
14,	 304,	 4	 93.09%
15,	 336,	 5	 99.40%
16,	 4,	 6	 75.00%
18,	 1404,	 8	 99.57%
19,	 4,	 9	 100.00%
20,	 189,	 0	 95.24%
22,	 791,	 2	 90.39%
24,	 23,	 4	 95.65%
25,	 167,	 5	 100.00%
26,	 1091,	 6	 75.62%
28,	 1267,	 8	 98.97%
29,	 11,	 9	 90.91%
30,	 1,	 0	 100.00%
32,	 64,	 2	 70.31%
33,	 286,	 3	 91.26%
34,	 1590,	 4	 92.20%
35,	 66,	 5	 98.48%
36,	 2641,	 6	 90.23%
37,	 17,	 7	 94.12%
38,	 6,	 8	 100.00%
39,	 8,	 9	 100.00%
40,	 1686,	 0	 94.19%
41,	 617,	 1	 99.84%
42,	 1095,	 2	 94.61%
43,	 2522,	 3	 96.23%
45,	 67,	 5	 100.00%
46,	 696,	 6	 92.10%
47,	 52,	 7	 100.00%
48,	 1089,	 8	 99.17%
50,	 138,	 0	 92.03%
51,	 1283,	 1	 99.77%
52,	 15,	 2	 80.00%
53,	 129,	 3	 85.27%
54,	 383,	 4	 93.99%
5