In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data

import random, os, pathlib, time
from tqdm import tqdm
from sklearn import datasets

In [2]:
import nflib
from nflib.flows import SequentialFlow, NormalizingFlow, ActNorm, AffineConstantFlow
import nflib.coupling_flows as icf
import nflib.inn_flow as inn
import nflib.res_flow as irf

In [3]:
from torch import distributions
from torch.distributions import MultivariateNormal

## MNIST dataset

In [4]:
import mylibrary.datasets as datasets
import mylibrary.nnlib as tnn

In [5]:
mnist = datasets.MNIST()
train_data, train_label_, test_data, test_label_ = mnist.load()

train_data = train_data / 255.
test_data = test_data / 255.

# train_label = tnn.Logits.index_to_logit(train_label_)
train_size = len(train_label_)

In [6]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [7]:
input_size = 784
output_size = 10

In [8]:
class MNIST_Dataset(data.Dataset):
    
    def __init__(self, data, label):
        self.data = data
        self.label = label
        
#         self.label = mask.type(torch.float32).reshape(-1,1)
        self._shuffle_data_()
        
    def __len__(self):
        return len(self.data)
    
    def _shuffle_data_(self):
        randidx = random.sample(range(len(self.data)), k=len(self.data))
        self.data = self.data[randidx]
        self.label = self.label[randidx]
    
    def __getitem__(self, idx):
        img, lbl = self.data[idx], self.label[idx]
        return img, lbl

In [9]:
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [10]:
class ConnectedClassifier_Softmax(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.linear = nn.Linear(input_dim, num_sets)
        self.linear.bias.data *= 0
        self.linear.weight.data *= 0.1
        self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)
        self.cls_confidence = None
        
        
    def forward(self, x, hard=True):
        x = self.linear(x)
        if hard:
            x = torch.softmax(-x*1e5, dim=1)
        else:
            x = torch.softmax(-x*self.inv_temp, dim=1)
        self.cls_confidence = x
        c = torch.softmax(self.cls_weight, dim=1)
#         c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized

In [11]:
class ConnectedClassifier_SoftKMeans(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)
        
#         self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)*0.01
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 10.
        self.cls_weight = nn.Parameter(init_val)

        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
        x = x[:, :self.input_dim]
        dists = torch.cdist(x, self.centers)
        dists = dists/np.sqrt(self.input_dim) ### correction to make diagonal of unit square 1 in nD space
        
        if hard:
            x = torch.softmax(-dists*1e5, dim=1)
        else:
            x = torch.softmax(-dists*self.inv_temp, dim=1)
        self.cls_confidence = x
        c = torch.softmax(self.cls_weight, dim=1)
#         c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized

In [12]:
actf = irf.Swish
flows = [
    ActNorm(784),
    irf.ResidualFlow(784, [784], activation=actf),
    ActNorm(784),
    irf.ResidualFlow(784, [784], activation=actf),
    ActNorm(784),
        ]

model = SequentialFlow(flows)

In [13]:
# model = nn.Sequential(nn.Linear(784, 784, bias=False),
#                       nn.BatchNorm1d(784),
#                       nn.SELU(),
#                       nn.Linear(784, 784, bias=False),
#                       nn.BatchNorm1d(784),
#                       nn.SELU(),
#                      )

In [14]:
model

SequentialFlow(
  (flows): ModuleList(
    (0): ActNorm()
    (1): ResidualFlow(
      (resblock): ModuleList(
        (0): Linear(in_features=784, out_features=784, bias=True)
        (1): Swish()
        (2): Linear(in_features=784, out_features=784, bias=True)
      )
    )
    (2): ActNorm()
    (3): ResidualFlow(
      (resblock): ModuleList(
        (0): Linear(in_features=784, out_features=784, bias=True)
        (1): Swish()
        (2): Linear(in_features=784, out_features=784, bias=True)
      )
    )
    (4): ActNorm()
  )
)

In [15]:
classifier = ConnectedClassifier_SoftKMeans(784, 100, 10)
# classifier = ConnectedClassifier_Softmax(784, 10, 10)

## Model Train

In [16]:
learning_rate = 0.001
batch_size = 50

In [17]:
train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

In [18]:
# criterion = nn.NLLLoss()
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(list(model.parameters())+list(classifier.parameters()),
                       lr=learning_rate, weight_decay=1e-15) # todo tune WD
# optimizer = optim.SGD(model.parameters(), lr=0.1)

print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  2466466


In [19]:
for p in model.parameters():
    print(torch.isnan(p).type(torch.float32).sum())

tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)


In [20]:
model(torch.randn(10, 784))

tensor([[ 5.6462e-01, -8.2422e-01,  8.1742e-01,  ...,  4.4810e-01,
          1.2630e+00, -4.7599e-01],
        [-5.1802e-01,  1.9383e-01, -1.3537e+00,  ...,  9.6932e-01,
          4.4705e-01, -8.1837e-01],
        [ 3.3639e-01,  5.7598e-01, -9.1273e-01,  ..., -1.2152e+00,
         -7.9978e-01,  4.9138e-01],
        ...,
        [ 3.1210e-01, -9.5864e-01,  3.9177e-01,  ...,  3.7570e-01,
          8.2501e-01,  6.3896e-01],
        [-1.1926e+00,  1.1075e+00, -1.1066e+00,  ..., -4.7953e-01,
          3.2403e-01,  1.3781e+00],
        [ 1.4601e-03,  2.0014e+00,  2.3595e-01,  ...,  8.9468e-01,
         -4.2698e-01, -3.1332e-01]], grad_fn=<AddBackward0>)

In [21]:
xx = iter(test_loader).next()[0]
xx.shape

torch.Size([50, 784])

In [22]:
losses = []
train_accs = []
test_accs = []
EPOCHS = 5

index = 0
for epoch in range(EPOCHS):
    train_acc = 0
    train_count = 0
    for xx, yy in tqdm(train_loader):
#     for xx, yy in tqdm(test_loader):

        yout = model(xx)
#         print(yout)
        yout = classifier(yout)    
#         print(yout)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(float(loss))

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)
#         break

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0

    print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    for xx, yy in tqdm(test_loader):
        with torch.no_grad():
            yout = classifier(model(xx))    
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

100%|██████████| 1200/1200 [01:17<00:00, 15.51it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0:0,  Loss:1.747867465019226


100%|██████████| 200/200 [00:02<00:00, 67.87it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:60.77%, Test Acc:92.35%



100%|██████████| 1200/1200 [01:16<00:00, 15.72it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1:0,  Loss:1.5182769298553467


100%|██████████| 200/200 [00:03<00:00, 66.13it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.60%, Test Acc:96.47%



100%|██████████| 1200/1200 [01:14<00:00, 16.19it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 2:0,  Loss:1.5134742259979248


100%|██████████| 200/200 [00:02<00:00, 68.87it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:96.92%, Test Acc:97.38%



100%|██████████| 1200/1200 [01:16<00:00, 15.67it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 3:0,  Loss:1.48512601852417


100%|██████████| 200/200 [00:03<00:00, 62.93it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:97.75%, Test Acc:97.39%



100%|██████████| 1200/1200 [01:19<00:00, 15.11it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 4:0,  Loss:1.5048245191574097


100%|██████████| 200/200 [00:03<00:00, 61.66it/s]

Train Acc:98.14%, Test Acc:97.78%

	-> Train Acc 98.14333333333335 ; Test Acc 97.78





In [23]:
# classifier.cls_weight

In [24]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets)
for xx, yy in tqdm(test_loader):
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Test Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 200/200 [00:03<00:00, 61.42it/s]

Hard Test Acc:97.77%
[0, 0, 0, 0, 961, 0, 0, 8, 0, 0, 0, 0, 10, 1, 0, 0, 0, 2, 2, 2, 0, 7, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1012, 0, 1, 0, 0, 0, 0, 1, 8, 0, 0, 1132, 1010, 0, 6, 0, 5, 0, 4, 0, 0, 0, 0, 3, 0, 2, 0, 3, 942, 5, 0, 0, 0, 1, 0, 893, 930, 1037, 0, 0, 0, 0, 0, 995, 0, 0, 0, 0, 0, 1006, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0]





In [25]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets)
for xx, yy in tqdm(train_loader):
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 1200/1200 [00:18<00:00, 65.55it/s]

Hard Train Acc:98.52%
[0, 0, 0, 0, 5787, 0, 0, 16, 0, 3, 0, 0, 30, 0, 0, 0, 0, 2, 13, 17, 0, 18, 0, 13, 0, 0, 0, 0, 86, 0, 0, 0, 0, 30, 0, 0, 0, 0, 2, 2, 5985, 0, 18, 1, 0, 0, 0, 3, 36, 1, 20, 6710, 5904, 1, 16, 0, 13, 1, 6, 0, 0, 0, 0, 14, 0, 13, 0, 12, 5676, 13, 0, 0, 0, 0, 0, 5408, 5858, 6328, 0, 0, 0, 0, 0, 5972, 0, 0, 0, 0, 0, 5969, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0]





In [26]:
torch.count_nonzero(set_count)

tensor(39)

In [27]:
torch.argmax(classifier.cls_weight, dim=1)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
        4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
        8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
        2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
        6, 7, 8, 9])

In [28]:
classifier.cls_weight

Parameter containing:
tensor([[ 3.2481, -1.9120, -1.8775, -1.8452, -2.0118, -1.8551, -1.8489, -1.9447,
         -1.9533, -2.0220],
        [-2.2413,  3.4531, -2.1003, -2.0808, -2.2142, -2.2052, -2.1769, -2.0775,
         -2.1191, -2.2117],
        [-2.2893, -1.9798,  3.5763, -2.0842, -2.3458, -2.3421, -2.1180, -2.2457,
         -2.1323, -2.3975],
        [-2.4627, -2.3062, -2.3605,  3.8263, -2.5696, -2.2584, -2.5298, -2.5112,
         -2.3591, -2.5930],
        [-3.9166, -3.7967, -3.8445, -3.8763,  5.1556, -3.8806, -3.7623, -3.6258,
         -3.8696, -3.1604],
        [-1.9905, -2.1742, -2.2445, -1.7629, -2.2293,  3.5023, -2.1339, -2.2450,
         -1.9059, -2.2329],
        [-2.0167, -2.0158, -1.9016, -2.0930, -2.0277, -2.0935,  3.4135, -2.1528,
         -2.0537, -2.1572],
        [-3.0335, -2.8421, -2.9603, -2.9363, -2.8616, -3.0288, -3.0203,  4.2774,
         -2.9650, -2.5665],
        [-2.1527, -1.9891, -2.0262, -1.9953, -2.2812, -2.0870, -2.1779, -2.2314,
          3.5292, -2.2495

In [29]:
# torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True)

In [30]:
classifier.inv_temp

Parameter containing:
tensor([4.9740], requires_grad=True)

In [31]:
yout[5]

tensor([1.3953e-04, 1.7997e-04, 9.9867e-01, 1.6720e-04, 1.2971e-04, 1.2782e-04,
        1.6456e-04, 1.3638e-04, 1.5945e-04, 1.2579e-04])

In [32]:
asdfsdf

NameError: name 'asdfsdf' is not defined

In [None]:
#### analyze per classifier accuracy

In [33]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets)
set_acc = torch.zeros(classifier.num_sets)
for xx, yy in tqdm(train_loader):
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        
    cls_indx = torch.argmax(classifier.cls_confidence, dim=1)
    set_indx, count = torch.unique(cls_indx, return_counts=True) 
    set_count[set_indx] += count
    
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float)
    
    ### class_index has 100 possible values
    for i, c in enumerate(correct):
        set_acc[cls_indx[i]] += c
    
#     print(set_acc.sum(), set_count.sum())
#     break
    test_acc += correct.sum()
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 1200/1200 [00:18<00:00, 64.51it/s]

Hard Train Acc:98.52%
[0, 0, 0, 0, 5787, 0, 0, 16, 0, 3, 0, 0, 30, 0, 0, 0, 0, 2, 13, 17, 0, 18, 0, 13, 0, 0, 0, 0, 86, 0, 0, 0, 0, 30, 0, 0, 0, 0, 2, 2, 5985, 0, 18, 1, 0, 0, 0, 3, 36, 1, 20, 6710, 5904, 1, 16, 0, 13, 1, 6, 0, 0, 0, 0, 14, 0, 13, 0, 12, 5676, 13, 0, 0, 0, 0, 0, 5408, 5858, 6328, 0, 0, 0, 0, 0, 5972, 0, 0, 0, 0, 0, 5969, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0]





In [34]:
set_acc/set_count

tensor([   nan,    nan,    nan,    nan, 0.9893,    nan,    nan, 0.3750,    nan,
        0.0000,    nan,    nan, 0.7000,    nan,    nan,    nan,    nan, 0.0000,
        0.8462, 0.3529,    nan, 0.3333,    nan, 0.6154,    nan,    nan,    nan,
           nan, 0.8721,    nan,    nan,    nan,    nan, 0.6667,    nan,    nan,
           nan,    nan, 0.0000, 0.5000, 0.9828,    nan, 0.6111, 0.0000,    nan,
           nan,    nan, 0.0000, 0.8333, 0.0000, 0.5500, 0.9931, 0.9861, 0.0000,
        0.5000,    nan, 0.8462, 1.0000, 0.3333,    nan,    nan,    nan,    nan,
        0.9286,    nan, 0.3077,    nan, 0.5000, 0.9919, 0.4615,    nan,    nan,
           nan,    nan,    nan, 0.9871, 0.9947, 0.9776,    nan,    nan,    nan,
           nan,    nan, 0.9950,    nan,    nan,    nan,    nan,    nan, 0.9776,
           nan,    nan, 1.0000,    nan,    nan,    nan,    nan,    nan, 0.5000,
           nan])

In [35]:
for i, (cnt, acc) in enumerate(zip(set_count.type(torch.long).tolist(), (set_acc/set_count).tolist())):
    if cnt == 0: continue
    print(f"{i},\t {cnt},\t {acc*100:.2f}%")

4,	 5787,	 98.93%
7,	 16,	 37.50%
9,	 3,	 0.00%
12,	 30,	 70.00%
17,	 2,	 0.00%
18,	 13,	 84.62%
19,	 17,	 35.29%
21,	 18,	 33.33%
23,	 13,	 61.54%
28,	 86,	 87.21%
33,	 30,	 66.67%
38,	 2,	 0.00%
39,	 2,	 50.00%
40,	 5985,	 98.28%
42,	 18,	 61.11%
43,	 1,	 0.00%
47,	 3,	 0.00%
48,	 36,	 83.33%
49,	 1,	 0.00%
50,	 20,	 55.00%
51,	 6710,	 99.31%
52,	 5904,	 98.61%
53,	 1,	 0.00%
54,	 16,	 50.00%
56,	 13,	 84.62%
57,	 1,	 100.00%
58,	 6,	 33.33%
63,	 14,	 92.86%
65,	 13,	 30.77%
67,	 12,	 50.00%
68,	 5676,	 99.19%
69,	 13,	 46.15%
75,	 5408,	 98.71%
76,	 5858,	 99.47%
77,	 6328,	 97.76%
83,	 5972,	 99.50%
89,	 5969,	 97.76%
92,	 1,	 100.00%
98,	 2,	 50.00%
