In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data

import random, os, pathlib, time
from tqdm import tqdm
from sklearn import datasets

In [2]:
import nflib
from nflib.flows import SequentialFlow, NormalizingFlow, ActNorm, AffineConstantFlow
import nflib.coupling_flows as icf
import nflib.inn_flow as inn
import nflib.res_flow as irf

In [3]:
import dtnnlib as dtnn

In [4]:
device = torch.device("cuda:0")
# device = torch.device("cpu")

## MNIST dataset

In [5]:
import mylibrary.datasets as datasets

In [6]:
mnist = datasets.FashionMNIST()
# mnist.download_mnist()
# mnist.save_mnist()
train_data, train_label_, test_data, test_label_ = mnist.load()

train_data = train_data / 255.
test_data = test_data / 255.

# train_label = tnn.Logits.index_to_logit(train_label_)
train_size = len(train_label_)

In [7]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [8]:
input_size = 784
output_size = 10

In [9]:
class MNIST_Dataset(data.Dataset):
    
    def __init__(self, data, label):
        self.data = data
        self.label = label
        
#         self.label = mask.type(torch.float32).reshape(-1,1)
        self._shuffle_data_()
        
    def __len__(self):
        return len(self.data)
    
    def _shuffle_data_(self):
        randidx = random.sample(range(len(self.data)), k=len(self.data))
        self.data = self.data[randidx]
        self.label = self.label[randidx]
    
    def __getitem__(self, idx):
        img, lbl = self.data[idx], self.label[idx]
        return img, lbl

In [10]:
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [11]:
learning_rate = 0.0003
batch_size = 50

In [12]:
train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

In [254]:
class ParameterSelector(nn.Module):
    
    def __init__(self, input_dim, num_sets, parameter_shapes, inv_temp=1):
        super().__init__()
        
        self.input_dim = input_dim
        self.num_sets = num_sets
        self.parameter_shapes = parameter_shapes
#         self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        self.inv_temp = inv_temp
        
        self.dt = dtnn.DistanceTransformBase(input_dim, num_sets)
        
        self.parameter_list = []
        for ps in self.parameter_shapes:
            ps = (self.num_sets, *ps)
            if len(ps) > 1:
                param = torch.nn.init.xavier_uniform_(torch.empty(*ps), gain=1.0).reshape(self.num_sets, -1)
            else:
                param = torch.zeros(*ps)
#             print(ps, param)
            self.parameter_list.append(nn.Parameter(param))
        
        self.parameter_list = nn.ParameterList(self.parameter_list)

        ## class repeat sequentially
        self.cls_confidence = None
        
    def forward(self, x, hard=False):
        
        dists = self.dt(x)
        dists = dists/np.sqrt(self.input_dim) ### correction to make diagonal of unit square 1 in nD space
        if hard:
            x = torch.softmax(-dists*1e5, dim=1)
        else:
            x = torch.softmax(-dists*self.inv_temp, dim=1)
        
        self.cls_confidence = x
        
        ret_params = []
        for param, shape in zip(self.parameter_list, self.parameter_shapes):
#             print(x.shape, param.shape)
            p = x@param
#             print(p.shape)
            ret_params.append(p.reshape(-1, *shape))
        return ret_params

In [255]:
ps = ParameterSelector(10, 5, [(10,2), (2,)])

In [256]:
list(ps.parameters())

[Parameter containing:
 tensor([[-0.1313,  0.0607, -0.2348, -0.7330, -0.1623, -0.5914,  0.2529, -0.0696,
          -0.5902, -0.5673],
         [-0.5802, -0.4504, -0.1569, -0.1745,  0.1463,  0.2057,  0.0067, -0.3910,
          -0.1000, -0.0230],
         [ 0.0756,  0.2478,  0.1728, -0.3095,  0.1257, -0.1759,  0.2901,  0.1017,
           0.1039,  0.0228],
         [-1.0776,  0.1633, -0.1290,  0.0582,  0.3161,  0.2448,  0.4085,  0.4782,
          -0.4364, -0.3307],
         [ 0.2917,  0.6311,  0.0644, -0.1773, -0.5003, -0.1194, -0.5465,  0.2802,
           0.3461,  0.4844]], requires_grad=True),
 Parameter containing:
 tensor([[-0.1443, -0.0646, -0.0770, -0.2922,  0.3557,  0.1664, -0.1609, -0.3694,
           0.2793,  0.1567,  0.2841, -0.2629,  0.0418, -0.1899,  0.0156, -0.0733,
           0.2299, -0.2937, -0.3215, -0.2278],
         [ 0.2329,  0.2523, -0.0409,  0.4462, -0.1542,  0.4059, -0.0168,  0.1805,
           0.2208, -0.2575,  0.3665,  0.2329, -0.3225,  0.1412, -0.2589, -0.3646,
  

In [257]:
x = torch.randn(3, 10)
params = ps(x)

In [258]:
params[1].shape

torch.Size([3, 2])

In [259]:
def psLinear(x, weights, bias):
    return torch.bmm(x.unsqueeze(1), weights).squeeze(1) + bias

In [260]:
psLinear(x, params[0], params[1])

tensor([[-1.1387,  0.0128],
        [-0.1935, -0.3205],
        [-0.4597, -0.4354]], grad_fn=<AddBackward0>)

## Now the model

In [261]:
# actf = irf.Swish
# flows = [
#     ActNorm(784),
#     irf.ResidualFlow(784, [784], activation=actf),
#     ActNorm(784),
#     irf.ResidualFlow(784, [784], activation=actf),
#     ActNorm(784),
#         ]

# model = SequentialFlow(flows)

In [262]:
model = nn.Sequential(nn.Linear(784, 784, bias=False),
#                       nn.BatchNorm1d(784),
#                       nn.SELU(),
#                       nn.Linear(784, 784, bias=False),
#                       nn.BatchNorm1d(784),
#                       nn.SELU(),
                     )

In [263]:
model.to(device)

Sequential(
  (0): Linear(in_features=784, out_features=784, bias=False)
)

In [277]:
ps0 = ParameterSelector(784, 20, [(784, 10), (10,)])
ps0 = ps0.to(device)

## Model Train

In [278]:
# criterion = nn.NLLLoss()
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(list(model.parameters())+list(ps0.parameters()),
                       lr=0.0003) # todo tune WD
# optimizer = optim.SGD(model.parameters(), lr=0.1)

print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  614656


In [279]:
# for p in model.parameters():
#     print(torch.isnan(p).type(torch.float32).sum())

In [280]:
xx = iter(test_loader).next()[0]
xx.shape

torch.Size([50, 784])

In [281]:
model(xx.to(device))

tensor([[ 1.1920,  0.6853,  0.2957,  ...,  1.3573,  0.1067,  1.1073],
        [ 0.6517, -0.2408, -0.8800,  ..., -1.2877,  1.6682,  0.4547],
        [ 0.4065,  0.1481,  0.8354,  ..., -1.0897,  2.0784, -1.6912],
        ...,
        [-0.8900,  1.7951, -1.8275,  ..., -0.0817, -1.7312,  2.0230],
        [ 1.5230,  0.3577, -0.6214,  ...,  2.1484, -1.8978,  1.8731],
        [ 0.0172,  0.0642, -0.8129,  ..., -0.6853,  0.0893,  0.2986]],
       device='cuda:0', grad_fn=<MmBackward>)

In [282]:
losses = []
train_accs = []
test_accs = []
EPOCHS = 99

index = 0
for epoch in range(EPOCHS):
    if (epoch+1)%10 == 0:
        ps0.inv_temp *= np.e
    
    train_acc = 0
    train_count = 0
    for xx, yy in tqdm(train_loader):
        xx, yy = xx.to(device), yy.to(device)

        yout = model(xx)
        params = ps0(yout)
        yout = psLinear(yout, params[0], params[1])
        
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(float(loss))

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0

    print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    for xx, yy in tqdm(test_loader):
        xx, yy = xx.to(device), yy.to(device)
        with torch.no_grad():
            yout = model(xx)
            params = ps0(yout)
            yout = psLinear(yout, params[0], params[1])
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

100%|██████████| 1200/1200 [00:02<00:00, 489.60it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0:0,  Loss:0.47675830125808716


100%|██████████| 200/200 [00:00<00:00, 666.32it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:83.74%, Test Acc:82.21%



100%|██████████| 1200/1200 [00:02<00:00, 489.57it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1:0,  Loss:0.3606569170951843


100%|██████████| 200/200 [00:00<00:00, 561.32it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:85.56%, Test Acc:83.37%



100%|██████████| 1200/1200 [00:02<00:00, 495.51it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 2:0,  Loss:0.47073474526405334


100%|██████████| 200/200 [00:00<00:00, 604.04it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:85.80%, Test Acc:83.65%



100%|██████████| 1200/1200 [00:02<00:00, 495.49it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 3:0,  Loss:0.3458858132362366


100%|██████████| 200/200 [00:00<00:00, 557.31it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:85.92%, Test Acc:84.15%



100%|██████████| 1200/1200 [00:02<00:00, 485.23it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 4:0,  Loss:0.3445529639720917


100%|██████████| 200/200 [00:00<00:00, 588.94it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.19%, Test Acc:83.18%



100%|██████████| 1200/1200 [00:02<00:00, 487.62it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 5:0,  Loss:0.4272775948047638


100%|██████████| 200/200 [00:00<00:00, 614.99it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.17%, Test Acc:84.16%



100%|██████████| 1200/1200 [00:02<00:00, 496.39it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 6:0,  Loss:0.34665462374687195


100%|██████████| 200/200 [00:00<00:00, 625.84it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.56%, Test Acc:84.41%



100%|██████████| 1200/1200 [00:02<00:00, 493.92it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 7:0,  Loss:0.3635788857936859


100%|██████████| 200/200 [00:00<00:00, 580.26it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.83%, Test Acc:84.28%



100%|██████████| 1200/1200 [00:02<00:00, 494.89it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 8:0,  Loss:0.24878545105457306


100%|██████████| 200/200 [00:00<00:00, 590.98it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.02%, Test Acc:84.42%



100%|██████████| 1200/1200 [00:02<00:00, 495.05it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 9:0,  Loss:0.25621384382247925


100%|██████████| 200/200 [00:00<00:00, 590.61it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.83%, Test Acc:84.39%



100%|██████████| 1200/1200 [00:02<00:00, 491.11it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 10:0,  Loss:0.21614664793014526


100%|██████████| 200/200 [00:00<00:00, 561.16it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.48%, Test Acc:85.90%



100%|██████████| 1200/1200 [00:02<00:00, 486.88it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 11:0,  Loss:0.22510072588920593


100%|██████████| 200/200 [00:00<00:00, 550.39it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.99%, Test Acc:86.15%



100%|██████████| 1200/1200 [00:02<00:00, 491.60it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 12:0,  Loss:0.3283124566078186


100%|██████████| 200/200 [00:00<00:00, 618.12it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.28%, Test Acc:86.43%



100%|██████████| 1200/1200 [00:02<00:00, 494.42it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 13:0,  Loss:0.2865040898323059


100%|██████████| 200/200 [00:00<00:00, 625.90it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.61%, Test Acc:86.84%



100%|██████████| 1200/1200 [00:02<00:00, 496.72it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 14:0,  Loss:0.2653592526912689


100%|██████████| 200/200 [00:00<00:00, 569.40it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.08%, Test Acc:86.68%



100%|██████████| 1200/1200 [00:02<00:00, 489.29it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 15:0,  Loss:0.10752894729375839


100%|██████████| 200/200 [00:00<00:00, 571.70it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.29%, Test Acc:86.54%



100%|██████████| 1200/1200 [00:02<00:00, 497.09it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 16:0,  Loss:0.11117088049650192


100%|██████████| 200/200 [00:00<00:00, 617.26it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.44%, Test Acc:86.96%



100%|██████████| 1200/1200 [00:02<00:00, 492.49it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 17:0,  Loss:0.3045876622200012


100%|██████████| 200/200 [00:00<00:00, 573.81it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.65%, Test Acc:87.29%



100%|██████████| 1200/1200 [00:02<00:00, 495.81it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 18:0,  Loss:0.25309476256370544


100%|██████████| 200/200 [00:00<00:00, 609.55it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.79%, Test Acc:87.00%



100%|██████████| 1200/1200 [00:02<00:00, 492.76it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 19:0,  Loss:0.1673474907875061


100%|██████████| 200/200 [00:00<00:00, 627.76it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.81%, Test Acc:87.49%



100%|██████████| 1200/1200 [00:02<00:00, 494.32it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 20:0,  Loss:0.2232273668050766


100%|██████████| 200/200 [00:00<00:00, 569.07it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.52%, Test Acc:87.46%



100%|██████████| 1200/1200 [00:02<00:00, 487.98it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 21:0,  Loss:0.10517089813947678


100%|██████████| 200/200 [00:00<00:00, 615.22it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.71%, Test Acc:87.37%



100%|██████████| 1200/1200 [00:02<00:00, 492.97it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 22:0,  Loss:0.35364601016044617


100%|██████████| 200/200 [00:00<00:00, 641.13it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.97%, Test Acc:87.31%



100%|██████████| 1200/1200 [00:02<00:00, 495.24it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 23:0,  Loss:0.17272788286209106


100%|██████████| 200/200 [00:00<00:00, 591.41it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.13%, Test Acc:87.98%



100%|██████████| 1200/1200 [00:02<00:00, 488.76it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 24:0,  Loss:0.17974160611629486


100%|██████████| 200/200 [00:00<00:00, 616.02it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.37%, Test Acc:87.46%



100%|██████████| 1200/1200 [00:02<00:00, 497.46it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 25:0,  Loss:0.1240888237953186


100%|██████████| 200/200 [00:00<00:00, 634.47it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.55%, Test Acc:87.62%



100%|██████████| 1200/1200 [00:02<00:00, 483.27it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 26:0,  Loss:0.16366079449653625


100%|██████████| 200/200 [00:00<00:00, 597.14it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.69%, Test Acc:87.86%



100%|██████████| 1200/1200 [00:02<00:00, 483.09it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 27:0,  Loss:0.20633214712142944


100%|██████████| 200/200 [00:00<00:00, 639.29it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.83%, Test Acc:87.82%



100%|██████████| 1200/1200 [00:02<00:00, 494.45it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 28:0,  Loss:0.16878284513950348


100%|██████████| 200/200 [00:00<00:00, 621.34it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.87%, Test Acc:87.68%



100%|██████████| 1200/1200 [00:02<00:00, 499.72it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 29:0,  Loss:0.17518670856952667


100%|██████████| 200/200 [00:00<00:00, 579.11it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.44%, Test Acc:87.10%



100%|██████████| 1200/1200 [00:02<00:00, 494.36it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 30:0,  Loss:0.2774760127067566


100%|██████████| 200/200 [00:00<00:00, 608.54it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.31%, Test Acc:87.47%



100%|██████████| 1200/1200 [00:02<00:00, 496.54it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 31:0,  Loss:0.073199562728405


100%|██████████| 200/200 [00:00<00:00, 600.35it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.63%, Test Acc:87.65%



100%|██████████| 1200/1200 [00:02<00:00, 495.80it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 32:0,  Loss:0.159376323223114


100%|██████████| 200/200 [00:00<00:00, 606.62it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.64%, Test Acc:87.59%



100%|██████████| 1200/1200 [00:02<00:00, 489.85it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 33:0,  Loss:0.14734385907649994


100%|██████████| 200/200 [00:00<00:00, 553.87it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.83%, Test Acc:87.27%



100%|██████████| 1200/1200 [00:02<00:00, 490.88it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 34:0,  Loss:0.1507767289876938


100%|██████████| 200/200 [00:00<00:00, 603.81it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.94%, Test Acc:87.64%



100%|██████████| 1200/1200 [00:02<00:00, 494.05it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 35:0,  Loss:0.14482004940509796


100%|██████████| 200/200 [00:00<00:00, 585.39it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.13%, Test Acc:87.86%



100%|██████████| 1200/1200 [00:02<00:00, 488.96it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 36:0,  Loss:0.19169124960899353


100%|██████████| 200/200 [00:00<00:00, 607.07it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.22%, Test Acc:87.58%



100%|██████████| 1200/1200 [00:02<00:00, 493.87it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 37:0,  Loss:0.0540228933095932


100%|██████████| 200/200 [00:00<00:00, 621.13it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.30%, Test Acc:87.69%



100%|██████████| 1200/1200 [00:02<00:00, 489.87it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 38:0,  Loss:0.1589287370443344


100%|██████████| 200/200 [00:00<00:00, 623.53it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.43%, Test Acc:87.51%



100%|██████████| 1200/1200 [00:02<00:00, 497.89it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 39:0,  Loss:0.037339452654123306


100%|██████████| 200/200 [00:00<00:00, 594.98it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.54%, Test Acc:87.38%



100%|██████████| 1200/1200 [00:02<00:00, 493.63it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 40:0,  Loss:0.2356013059616089


100%|██████████| 200/200 [00:00<00:00, 602.11it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.70%, Test Acc:87.62%



100%|██████████| 1200/1200 [00:02<00:00, 499.35it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 41:0,  Loss:0.06436190754175186


100%|██████████| 200/200 [00:00<00:00, 573.39it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.67%, Test Acc:87.48%



100%|██████████| 1200/1200 [00:02<00:00, 497.32it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 42:0,  Loss:0.04429977014660835


100%|██████████| 200/200 [00:00<00:00, 603.55it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.84%, Test Acc:87.79%



100%|██████████| 1200/1200 [00:02<00:00, 486.86it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 43:0,  Loss:0.15090852975845337


100%|██████████| 200/200 [00:00<00:00, 642.46it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.87%, Test Acc:87.62%



100%|██████████| 1200/1200 [00:02<00:00, 489.65it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 44:0,  Loss:0.11595752835273743


100%|██████████| 200/200 [00:00<00:00, 608.96it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.95%, Test Acc:87.68%



100%|██████████| 1200/1200 [00:02<00:00, 498.02it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 45:0,  Loss:0.12410083413124084


100%|██████████| 200/200 [00:00<00:00, 576.05it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.18%, Test Acc:87.46%



100%|██████████| 1200/1200 [00:02<00:00, 494.69it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 46:0,  Loss:0.04987560585141182


100%|██████████| 200/200 [00:00<00:00, 587.21it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.20%, Test Acc:87.47%



100%|██████████| 1200/1200 [00:02<00:00, 497.78it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 47:0,  Loss:0.07835888117551804


100%|██████████| 200/200 [00:00<00:00, 603.06it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.34%, Test Acc:87.48%



100%|██████████| 1200/1200 [00:02<00:00, 490.45it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 48:0,  Loss:0.191451758146286


100%|██████████| 200/200 [00:00<00:00, 564.06it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.38%, Test Acc:87.33%



100%|██████████| 1200/1200 [00:02<00:00, 495.64it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 49:0,  Loss:0.025439441204071045


100%|██████████| 200/200 [00:00<00:00, 615.79it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.75%, Test Acc:86.93%



100%|██████████| 1200/1200 [00:02<00:00, 497.07it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 50:0,  Loss:0.1508156806230545


100%|██████████| 200/200 [00:00<00:00, 583.08it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.23%, Test Acc:87.19%



100%|██████████| 1200/1200 [00:02<00:00, 489.60it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 51:0,  Loss:0.08025030791759491


100%|██████████| 200/200 [00:00<00:00, 624.01it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.28%, Test Acc:87.25%



100%|██████████| 1200/1200 [00:02<00:00, 493.10it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 52:0,  Loss:0.14227469265460968


100%|██████████| 200/200 [00:00<00:00, 641.62it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.43%, Test Acc:87.11%



100%|██████████| 1200/1200 [00:02<00:00, 497.38it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 53:0,  Loss:0.08965343236923218


100%|██████████| 200/200 [00:00<00:00, 569.63it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.44%, Test Acc:87.14%



100%|██████████| 1200/1200 [00:02<00:00, 489.10it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 54:0,  Loss:0.09354958683252335


100%|██████████| 200/200 [00:00<00:00, 557.97it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.61%, Test Acc:87.16%



100%|██████████| 1200/1200 [00:02<00:00, 495.73it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 55:0,  Loss:0.16243623197078705


100%|██████████| 200/200 [00:00<00:00, 644.42it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.57%, Test Acc:87.40%



100%|██████████| 1200/1200 [00:02<00:00, 495.11it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 56:0,  Loss:0.12530793249607086


100%|██████████| 200/200 [00:00<00:00, 642.56it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.82%, Test Acc:87.54%



100%|██████████| 1200/1200 [00:02<00:00, 493.76it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 57:0,  Loss:0.11689524352550507


100%|██████████| 200/200 [00:00<00:00, 652.11it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.79%, Test Acc:87.03%



100%|██████████| 1200/1200 [00:02<00:00, 491.12it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 58:0,  Loss:0.09317535161972046


100%|██████████| 200/200 [00:00<00:00, 618.84it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.95%, Test Acc:87.48%



100%|██████████| 1200/1200 [00:02<00:00, 501.77it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 59:0,  Loss:0.22001659870147705


100%|██████████| 200/200 [00:00<00:00, 584.26it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.17%, Test Acc:86.27%



100%|██████████| 1200/1200 [00:02<00:00, 495.09it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 60:0,  Loss:0.37028244137763977


100%|██████████| 200/200 [00:00<00:00, 598.13it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.92%, Test Acc:87.02%



100%|██████████| 1200/1200 [00:02<00:00, 484.90it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 61:0,  Loss:0.16259059309959412


100%|██████████| 200/200 [00:00<00:00, 648.00it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.06%, Test Acc:86.82%



100%|██████████| 1200/1200 [00:02<00:00, 495.37it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 62:0,  Loss:0.3513183295726776


100%|██████████| 200/200 [00:00<00:00, 587.81it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.31%, Test Acc:86.59%



100%|██████████| 1200/1200 [00:02<00:00, 491.22it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 63:0,  Loss:0.14721189439296722


100%|██████████| 200/200 [00:00<00:00, 640.41it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.32%, Test Acc:86.87%



100%|██████████| 1200/1200 [00:02<00:00, 499.96it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 64:0,  Loss:0.10998351871967316


100%|██████████| 200/200 [00:00<00:00, 615.75it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.59%, Test Acc:86.57%



100%|██████████| 1200/1200 [00:02<00:00, 489.54it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 65:0,  Loss:0.14430207014083862


100%|██████████| 200/200 [00:00<00:00, 609.33it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.42%, Test Acc:86.54%



100%|██████████| 1200/1200 [00:02<00:00, 499.44it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 66:0,  Loss:0.14230819046497345


100%|██████████| 200/200 [00:00<00:00, 642.97it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.86%, Test Acc:86.56%



100%|██████████| 1200/1200 [00:02<00:00, 494.80it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 67:0,  Loss:0.0854414850473404


100%|██████████| 200/200 [00:00<00:00, 641.82it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.56%, Test Acc:87.22%



100%|██████████| 1200/1200 [00:02<00:00, 491.16it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 68:0,  Loss:0.22787360846996307


100%|██████████| 200/200 [00:00<00:00, 623.54it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.88%, Test Acc:86.87%



100%|██████████| 1200/1200 [00:02<00:00, 495.75it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 69:0,  Loss:0.14706678688526154


100%|██████████| 200/200 [00:00<00:00, 571.34it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.63%, Test Acc:85.75%



100%|██████████| 1200/1200 [00:02<00:00, 493.46it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 70:0,  Loss:0.1826787143945694


100%|██████████| 200/200 [00:00<00:00, 564.35it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.01%, Test Acc:86.49%



100%|██████████| 1200/1200 [00:02<00:00, 494.53it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 71:0,  Loss:0.2227826714515686


100%|██████████| 200/200 [00:00<00:00, 584.87it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.28%, Test Acc:86.62%



100%|██████████| 1200/1200 [00:02<00:00, 498.61it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 72:0,  Loss:0.1319226622581482


100%|██████████| 200/200 [00:00<00:00, 576.52it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.39%, Test Acc:86.69%



100%|██████████| 1200/1200 [00:02<00:00, 492.45it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 73:0,  Loss:0.13358904421329498


100%|██████████| 200/200 [00:00<00:00, 672.57it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.34%, Test Acc:85.91%



100%|██████████| 1200/1200 [00:02<00:00, 498.49it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 74:0,  Loss:0.1712982952594757


100%|██████████| 200/200 [00:00<00:00, 573.29it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.42%, Test Acc:86.74%



100%|██████████| 1200/1200 [00:02<00:00, 488.39it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 75:0,  Loss:0.08373043686151505


100%|██████████| 200/200 [00:00<00:00, 631.13it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.43%, Test Acc:86.28%



100%|██████████| 1200/1200 [00:02<00:00, 497.98it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 76:0,  Loss:0.17035728693008423


100%|██████████| 200/200 [00:00<00:00, 583.95it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.68%, Test Acc:85.96%



100%|██████████| 1200/1200 [00:02<00:00, 491.48it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 77:0,  Loss:0.1492016464471817


100%|██████████| 200/200 [00:00<00:00, 583.35it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.68%, Test Acc:86.35%



100%|██████████| 1200/1200 [00:02<00:00, 495.82it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 78:0,  Loss:0.19106394052505493


100%|██████████| 200/200 [00:00<00:00, 622.75it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.62%, Test Acc:86.28%



100%|██████████| 1200/1200 [00:02<00:00, 493.68it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 79:0,  Loss:0.32804176211357117


100%|██████████| 200/200 [00:00<00:00, 613.71it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.40%, Test Acc:85.29%



100%|██████████| 1200/1200 [00:02<00:00, 497.68it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 80:0,  Loss:0.20903177559375763


100%|██████████| 200/200 [00:00<00:00, 582.61it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.60%, Test Acc:84.58%



100%|██████████| 1200/1200 [00:02<00:00, 491.15it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 81:0,  Loss:0.24302244186401367


100%|██████████| 200/200 [00:00<00:00, 578.54it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.73%, Test Acc:84.39%



100%|██████████| 1200/1200 [00:02<00:00, 490.25it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 82:0,  Loss:0.07354529201984406


100%|██████████| 200/200 [00:00<00:00, 586.92it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.72%, Test Acc:85.92%



100%|██████████| 1200/1200 [00:02<00:00, 494.11it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 83:0,  Loss:0.29980725049972534


100%|██████████| 200/200 [00:00<00:00, 568.16it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.13%, Test Acc:84.51%



100%|██████████| 1200/1200 [00:02<00:00, 503.05it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 84:0,  Loss:0.3760280907154083


100%|██████████| 200/200 [00:00<00:00, 604.93it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.37%, Test Acc:85.61%



100%|██████████| 1200/1200 [00:02<00:00, 489.74it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 85:0,  Loss:0.08869103342294693


100%|██████████| 200/200 [00:00<00:00, 595.27it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.38%, Test Acc:86.37%



100%|██████████| 1200/1200 [00:02<00:00, 493.39it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 86:0,  Loss:0.2547784149646759


100%|██████████| 200/200 [00:00<00:00, 525.47it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.16%, Test Acc:85.56%



100%|██████████| 1200/1200 [00:02<00:00, 498.47it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 87:0,  Loss:0.07094269245862961


100%|██████████| 200/200 [00:00<00:00, 635.08it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.26%, Test Acc:85.70%



100%|██████████| 1200/1200 [00:02<00:00, 496.44it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 88:0,  Loss:0.29931533336639404


100%|██████████| 200/200 [00:00<00:00, 561.33it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.27%, Test Acc:85.59%



100%|██████████| 1200/1200 [00:02<00:00, 480.98it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 89:0,  Loss:0.2365560084581375


100%|██████████| 200/200 [00:00<00:00, 627.30it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.72%, Test Acc:85.20%



100%|██████████| 1200/1200 [00:02<00:00, 494.95it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 90:0,  Loss:0.28937116265296936


100%|██████████| 200/200 [00:00<00:00, 610.84it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.40%, Test Acc:85.78%



100%|██████████| 1200/1200 [00:02<00:00, 484.81it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 91:0,  Loss:0.2434096485376358


100%|██████████| 200/200 [00:00<00:00, 577.94it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.88%, Test Acc:86.33%



100%|██████████| 1200/1200 [00:02<00:00, 486.52it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 92:0,  Loss:0.24414007365703583


100%|██████████| 200/200 [00:00<00:00, 593.89it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.92%, Test Acc:85.90%



100%|██████████| 1200/1200 [00:02<00:00, 496.19it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 93:0,  Loss:0.19253669679164886


100%|██████████| 200/200 [00:00<00:00, 646.23it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.15%, Test Acc:84.04%



100%|██████████| 1200/1200 [00:02<00:00, 494.06it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 94:0,  Loss:0.26017794013023376


100%|██████████| 200/200 [00:00<00:00, 572.68it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.53%, Test Acc:86.06%



100%|██████████| 1200/1200 [00:02<00:00, 489.06it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 95:0,  Loss:0.14197514951229095


100%|██████████| 200/200 [00:00<00:00, 560.47it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.74%, Test Acc:85.65%



100%|██████████| 1200/1200 [00:02<00:00, 489.86it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 96:0,  Loss:0.24241207540035248


100%|██████████| 200/200 [00:00<00:00, 628.61it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.03%, Test Acc:85.51%



100%|██████████| 1200/1200 [00:02<00:00, 503.32it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 97:0,  Loss:0.41636422276496887


100%|██████████| 200/200 [00:00<00:00, 606.64it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.12%, Test Acc:86.33%



100%|██████████| 1200/1200 [00:02<00:00, 498.37it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 98:0,  Loss:0.21769079566001892


100%|██████████| 200/200 [00:00<00:00, 600.49it/s]

Train Acc:91.23%, Test Acc:85.58%

	-> Train Acc 95.94833333333334 ; Test Acc 87.98





In [None]:
##### when trained for 30 epochs
## sets = 10: -> -> Train Acc 95.145 ; Test Acc 89.25
## sets = 2: -> -> Train Acc 95.1367 ; Test Acc 89.26

In [283]:
ps0.inv_temp

8103.0839275753815

### Hard test accuracy with count per classifier

In [284]:
test_count = 0
test_acc = 0
set_count = torch.zeros(ps0.num_sets).to(device)
for xx, yy in tqdm(test_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = model(xx)
        params = ps0(yout, hard=True)
        yout = psLinear(yout, params[0], params[1])
        set_indx, count = torch.unique(torch.argmax(ps0.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Test Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 200/200 [00:00<00:00, 561.21it/s]

Hard Test Acc:85.54%
[149, 190, 30, 223, 240, 172, 118, 1304, 239, 85, 303, 852, 749, 638, 1396, 275, 1574, 774, 158, 531]





In [285]:
### 100 epochs
# 10 sets == Hard Test Acc:85.05%
# 20 sets == Hard Test Acc:85.54%

### Hard train accuracy with count per classifier

In [286]:
test_count = 0
test_acc = 0
set_count = torch.zeros(ps0.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = model(xx)
        params = ps0(yout, hard=True)
        yout = psLinear(yout, params[0], params[1])
        set_indx, count = torch.unique(torch.argmax(ps0.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 1200/1200 [00:01<00:00, 809.02it/s]

Hard Train Acc:92.04%
[904, 1165, 192, 1488, 1451, 1100, 690, 7644, 1281, 479, 1861, 4823, 4466, 4009, 8821, 1582, 9592, 4464, 963, 3025]





In [201]:
#### Classifiers that enclose any data
torch.count_nonzero(set_count)

tensor(2, device='cuda:0')

In [50]:
#### classifier with class representation
torch.argmax(classifier.cls_weight, dim=1)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
        4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
        8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
        2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
        6, 7, 8, 9], device='cuda:0')

In [51]:
# The class labels are same as that of initialized
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
#         4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
#         8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
#         2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
#         6, 7, 8, 9], device='cuda:0')

In [52]:
classifier.cls_weight

Parameter containing:
tensor([[6.2399e-01, 4.8636e-06, 9.7206e-05, 8.1379e-05, 4.9807e-05, 5.0686e-06,
         3.7574e-01, 9.6449e-06, 7.1354e-06, 8.5779e-06],
        [5.1824e-05, 9.9977e-01, 1.3589e-05, 2.3619e-05, 2.1795e-05, 2.3727e-05,
         2.6346e-05, 1.3267e-05, 3.7349e-05, 1.6489e-05],
        [1.8427e-04, 7.4690e-05, 6.3821e-01, 3.0832e-05, 1.6561e-01, 5.7689e-05,
         1.9566e-01, 4.7485e-05, 7.2066e-05, 4.7959e-05],
        [1.9669e-04, 1.7147e-04, 1.3287e-04, 9.9875e-01, 1.9478e-04, 1.0700e-04,
         7.7105e-05, 1.0586e-04, 1.7403e-04, 9.5043e-05],
        [9.4674e-05, 3.5800e-05, 2.7890e-01, 4.9266e-05, 6.5840e-01, 3.3135e-05,
         6.2412e-02, 2.1661e-05, 2.8074e-05, 2.7890e-05],
        [3.1471e-04, 3.0296e-04, 2.9542e-04, 2.8826e-04, 3.0731e-04, 9.9767e-01,
         2.7568e-04, 1.4418e-04, 2.7908e-04, 1.2719e-04],
        [3.7181e-05, 2.0084e-04, 2.1955e-01, 1.2258e-04, 2.1433e-01, 1.9574e-04,
         5.6498e-01, 2.0206e-04, 1.8711e-04, 2.0165e-04],
     

In [53]:
# torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True)

In [54]:
classifier.inv_temp

Parameter containing:
tensor([3.0942], device='cuda:0', requires_grad=True)

In [55]:
### example output per classifier
yout[5]

tensor([1.9399e-04, 1.7627e-04, 1.3068e-04, 9.9873e-01, 2.0075e-04, 1.0751e-04,
        7.0551e-05, 1.0562e-04, 1.8927e-04, 9.5756e-05], device='cuda:0')

In [134]:
asdfsdf ## to break the code

NameError: name 'asdfsdf' is not defined

### analyze per classifier accuracy

In [158]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
set_acc = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        
    cls_indx = torch.argmax(classifier.cls_confidence, dim=1)
    set_indx, count = torch.unique(cls_indx, return_counts=True) 
    set_count[set_indx] += count
    
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float)
    
    ### class_index has 100 possible values
    for i, c in enumerate(correct):
        set_acc[cls_indx[i]] += c
    
#     print(set_acc.sum(), set_count.sum())
#     break
    test_acc += correct.sum()
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 1200/1200 [00:19<00:00, 60.60it/s]

Hard Train Acc:94.22%
[4, 70, 5920, 0, 19, 0, 12, 3, 0, 2, 5, 0, 0, 6080, 0, 0, 0, 5984, 50, 0, 1, 7, 0, 8, 1, 1, 13, 0, 0, 14, 36, 1, 0, 1, 0, 971, 1876, 5, 7, 1, 71, 0, 12, 32, 0, 0, 5819, 5, 11, 0, 162, 0, 77, 6092, 37, 6243, 5, 343, 542, 205, 7, 1, 13, 16, 3, 5572, 3846, 7, 3197, 12, 0, 1, 1, 1, 0, 0, 6, 132, 4, 0, 32, 80, 25, 11, 0, 19, 37, 1, 0, 1, 0, 188, 0, 0, 48, 0, 3, 5971, 17, 0]





In [161]:
set_acc/set_count

tensor([1.0000, 0.5714, 0.9965,    nan, 0.5789,    nan, 1.0000, 0.6667,    nan,
        0.0000, 1.0000,    nan,    nan, 0.9704,    nan,    nan,    nan, 0.9883,
        0.5200,    nan, 1.0000, 0.8571,    nan, 1.0000, 1.0000, 0.0000, 0.1538,
           nan,    nan, 0.5714, 1.0000, 1.0000,    nan, 1.0000,    nan, 1.0000,
        0.7820, 0.2000, 0.2857, 0.0000, 1.0000,    nan, 1.0000, 1.0000,    nan,
           nan, 0.9443, 0.0000, 0.5455,    nan, 0.9877,    nan, 0.9870, 0.9703,
        0.2703, 0.9031, 0.6000, 0.9883, 1.0000, 0.7171, 0.4286, 0.0000, 1.0000,
        0.5000, 1.0000, 0.9248, 0.9280, 0.7143, 0.9984, 0.3333,    nan, 1.0000,
        0.0000, 1.0000,    nan,    nan, 0.6667, 0.5076, 0.2500,    nan, 0.5625,
        0.9750, 0.2800, 0.1818,    nan, 0.8947, 1.0000, 0.0000,    nan, 0.0000,
           nan, 0.9840,    nan,    nan, 0.9167,    nan, 0.3333, 0.8970, 0.2941,
           nan], device='cuda:0')

In [162]:
print(f"Index\tNumData\tClass\tAccuracy")
for i, (cnt, acc, cls) in enumerate(zip(set_count.type(torch.long).tolist(),
                                   (set_acc/set_count).tolist(),
                                   torch.argmax(classifier.cls_weight, dim=1).tolist())):
    if cnt == 0: continue
    print(f"{i}\t {cnt}\t {cls}\t {acc*100:.2f}%")

Index	NumData	Class	Accuracy
0	 4	 5	 100.00%
1	 70	 3	 57.14%
2	 5920	 1	 99.65%
4	 19	 4	 57.89%
6	 12	 5	 100.00%
7	 3	 4	 66.67%
9	 2	 6	 0.00%
10	 5	 6	 100.00%
13	 6080	 7	 97.04%
17	 5984	 8	 98.83%
18	 50	 3	 52.00%
20	 1	 4	 100.00%
21	 7	 5	 85.71%
23	 8	 5	 100.00%
24	 1	 6	 100.00%
25	 1	 7	 0.00%
26	 13	 6	 15.38%
29	 14	 6	 57.14%
30	 36	 5	 100.00%
31	 1	 4	 100.00%
33	 1	 4	 100.00%
35	 971	 5	 100.00%
36	 1876	 6	 78.20%
37	 5	 4	 20.00%
38	 7	 6	 28.57%
39	 1	 7	 0.00%
40	 71	 5	 100.00%
42	 12	 5	 100.00%
43	 32	 5	 100.00%
46	 5819	 3	 94.43%
47	 5	 8	 0.00%
48	 11	 4	 54.55%
50	 162	 5	 98.77%
52	 77	 5	 98.70%
53	 6092	 9	 97.03%
54	 37	 6	 27.03%
55	 6243	 0	 90.31%
56	 5	 4	 60.00%
57	 343	 5	 98.83%
58	 542	 5	 100.00%
59	 205	 3	 71.71%
60	 7	 6	 42.86%
61	 1	 4	 0.00%
62	 13	 5	 100.00%
63	 16	 6	 50.00%
64	 3	 5	 100.00%
65	 5572	 2	 92.48%
66	 3846	 6	 92.80%
67	 7	 4	 71.43%
68	 3197	 5	 99.84%
69	 12	 6	 33.33%
71	 1	 6	 100.00%
72	 1	 5	 0.00%
73	 1	 6	 