In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data

import random, os, pathlib, time
from tqdm import tqdm
from sklearn import datasets

In [2]:
import nflib
from nflib.flows import SequentialFlow, NormalizingFlow, ActNorm, ActNorm2D, AffineConstantFlow
import nflib.coupling_flows as icf
import nflib.inn_flow as inn
import nflib.res_flow as irf

In [3]:
from torch import distributions
from torch.distributions import MultivariateNormal

In [4]:
device = torch.device("cuda:1")
# device = torch.device("cpu")

## MNIST dataset

In [5]:
import mylibrary.datasets as datasets
# import mylibrary.nnlib as tnn

In [6]:
mnist = datasets.FashionMNIST()
# mnist.download_mnist()
# mnist.save_mnist()
train_data, train_label_, test_data, test_label_ = mnist.load()

train_data = train_data / 255.
test_data = test_data / 255.

# train_label = tnn.Logits.index_to_logit(train_label_)
train_size = len(train_label_)

In [7]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [8]:
input_size = 784
output_size = 10

In [9]:
class MNIST_Dataset(data.Dataset):
    
    def __init__(self, data, label):
        self.data = data.reshape(-1, 1, 28, 28)
        self.label = label
        
#         self.label = mask.type(torch.float32).reshape(-1,1)
        self._shuffle_data_()
        
    def __len__(self):
        return len(self.data)
    
    def _shuffle_data_(self):
        randidx = random.sample(range(len(self.data)), k=len(self.data))
        self.data = self.data[randidx]
        self.label = self.label[randidx]
    
    def __getitem__(self, idx):
        img, lbl = self.data[idx], self.label[idx]
        return img, lbl

In [10]:
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [11]:
class ConnectedClassifier_Linear(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.linear = nn.Linear(input_dim, num_sets)
#         self.linear.bias.data *= 0
#         self.linear.weight.data *= 0.1
#         self.cls_weight = nn.Parameter(torch.randn(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)*0.01
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 0.1
        self.cls_weight = nn.Parameter(init_val)
        
        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
#         self.cls_weight.data = torch.abs(self.cls_weight.data/self.cls_weight.data.sum(dim=1, keepdim=True))
        
        x = self.linear(x)
        if hard:
            x = torch.softmax(x*1e5, dim=1)
        else:
            x = torch.softmax(x*self.inv_temp, dim=1)
        self.cls_confidence = x
        c = torch.softmax(self.cls_weight, dim=1)
#         c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized

In [12]:
class ConnectedClassifier_SoftKMeans(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)
        
#         self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)*0.01
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 0.1
        self.cls_weight = nn.Parameter(init_val)

        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
#         self.cls_weight.data = torch.abs(self.cls_weight.data/self.cls_weight.data.sum(dim=1, keepdim=True))
        
        x = x[:, :self.input_dim]
        dists = torch.cdist(x, self.centers)
        dists = dists/np.sqrt(self.input_dim) ### correction to make diagonal of unit square 1 in nD space
        
        if hard:
            x = torch.softmax(-dists*1e5, dim=1)
        else:
            x = torch.softmax(-dists*self.inv_temp, dim=1)
        self.cls_confidence = x
        c = torch.softmax(self.cls_weight, dim=1)
#         c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized
#         return torch.softmax(x@self.cls_weight, dim=1)

In [13]:
# actf = irf.Swish
# flows = [
#     ActNorm(784),
#     irf.ResidualFlow(784, [784], activation=actf),
#     ActNorm(784),
#     irf.ResidualFlow(784, [784], activation=actf),
#     ActNorm(784),
#         ]

# model = SequentialFlow(flows)

In [14]:
actf = irf.Swish
flows = [
    ActNorm2D(1),
    irf.ConvResidualFlow(1, [16], activation=actf),
    irf.InvertiblePooling(2),
    ActNorm2D(4),
    irf.ConvResidualFlow(4, [64], activation=actf),
    irf.InvertiblePooling(2),
    ActNorm2D(16),
    irf.ConvResidualFlow(16, [64, 64], activation=actf),
    irf.Flatten(img_size=(16, 7, 7))
        ]

model = SequentialFlow(flows)

In [15]:
# model = nn.Sequential(nn.Linear(784, 784, bias=False),
#                       nn.BatchNorm1d(784),
#                       nn.SELU(),
#                       nn.Linear(784, 784, bias=False),
#                       nn.BatchNorm1d(784),
#                       nn.SELU(),
#                      )

In [16]:
# model(torch.randn(3, 1, 28, 28)).shape

In [17]:
model.to(device)

SequentialFlow(
  (flows): ModuleList(
    (0): ActNorm2D()
    (1): ConvResidualFlow(
      (resblock): ModuleList(
        (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Swish()
        (2): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (2): InvertiblePooling()
    (3): ActNorm2D()
    (4): ConvResidualFlow(
      (resblock): ModuleList(
        (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Swish()
        (2): Conv2d(64, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (5): InvertiblePooling()
    (6): ActNorm2D()
    (7): ConvResidualFlow(
      (resblock): ModuleList(
        (0): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Swish()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): Swish()
        (4): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )


In [18]:
classifier = ConnectedClassifier_SoftKMeans(784, 100, 10)
# classifier = ConnectedClassifier_Softmax(784, 100, 10)
classifier = classifier.to(device)

## Model Train

In [19]:
learning_rate = 0.0003
batch_size = 50

In [20]:
train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

In [21]:
# criterion = nn.NLLLoss()
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(list(model.parameters())+list(classifier.parameters()), lr=learning_rate)
# optimizer = optim.SGD(model.parameters(), lr=0.1)

print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  60467


In [22]:
# for p in model.parameters():
#     print(torch.isnan(p).type(torch.float32).sum())

In [23]:
# model(torch.randn(10, 784).to(device)).shape

In [24]:
xx = iter(test_loader).next()[0]
xx.shape

torch.Size([50, 1, 28, 28])

In [26]:
losses = []
train_accs = []
test_accs = []
EPOCHS = 50

In [27]:
index = 0
for epoch in range(EPOCHS):
    train_acc = 0
    train_count = 0
    for xx, yy in tqdm(train_loader):
        xx, yy = xx.to(device), yy.to(device)
#     for xx, yy in tqdm(test_loader):

        yout = model(xx)
#         print(yout)
        yout = classifier(yout)    
#         print(yout)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(float(loss))

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)
#         break

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0

    print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    for xx, yy in tqdm(test_loader):
        xx, yy = xx.to(device), yy.to(device)
        with torch.no_grad():
            yout = classifier(model(xx))    
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

100%|██████████| 1200/1200 [00:12<00:00, 99.27it/s] 
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0:0,  Loss:2.2965710163116455


100%|██████████| 200/200 [00:00<00:00, 222.92it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:22.62%, Test Acc:35.84%



100%|██████████| 1200/1200 [00:10<00:00, 110.00it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1:0,  Loss:2.2472329139709473


100%|██████████| 200/200 [00:00<00:00, 208.36it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:32.99%, Test Acc:31.31%



100%|██████████| 1200/1200 [00:10<00:00, 109.46it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 2:0,  Loss:2.1716580390930176


100%|██████████| 200/200 [00:00<00:00, 209.41it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:40.92%, Test Acc:47.55%



100%|██████████| 1200/1200 [00:10<00:00, 109.87it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 3:0,  Loss:2.082125186920166


100%|██████████| 200/200 [00:00<00:00, 208.14it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:55.36%, Test Acc:62.44%



100%|██████████| 1200/1200 [00:10<00:00, 109.99it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 4:0,  Loss:1.9444466829299927


100%|██████████| 200/200 [00:00<00:00, 209.52it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:67.30%, Test Acc:71.46%



100%|██████████| 1200/1200 [00:10<00:00, 110.02it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 5:0,  Loss:1.7906897068023682


100%|██████████| 200/200 [00:00<00:00, 204.11it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:77.04%, Test Acc:78.55%



100%|██████████| 1200/1200 [00:10<00:00, 109.88it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 6:0,  Loss:1.7427773475646973


100%|██████████| 200/200 [00:00<00:00, 205.85it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:80.02%, Test Acc:79.83%



100%|██████████| 1200/1200 [00:10<00:00, 110.11it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 7:0,  Loss:1.7181074619293213


100%|██████████| 200/200 [00:00<00:00, 205.95it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:81.14%, Test Acc:80.78%



100%|██████████| 1200/1200 [00:10<00:00, 110.17it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 8:0,  Loss:1.6762293577194214


100%|██████████| 200/200 [00:00<00:00, 203.93it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:81.71%, Test Acc:81.22%



100%|██████████| 1200/1200 [00:10<00:00, 110.13it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 9:0,  Loss:1.748274326324463


100%|██████████| 200/200 [00:00<00:00, 212.85it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:82.28%, Test Acc:81.61%



100%|██████████| 1200/1200 [00:10<00:00, 109.60it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 10:0,  Loss:1.611384391784668


100%|██████████| 200/200 [00:00<00:00, 207.50it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:82.74%, Test Acc:82.53%



100%|██████████| 1200/1200 [00:10<00:00, 109.93it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 11:0,  Loss:1.6488560438156128


100%|██████████| 200/200 [00:00<00:00, 208.45it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:83.18%, Test Acc:82.71%



100%|██████████| 1200/1200 [00:10<00:00, 109.95it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 12:0,  Loss:1.6494359970092773


100%|██████████| 200/200 [00:00<00:00, 211.04it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:85.09%, Test Acc:85.27%



100%|██████████| 1200/1200 [00:10<00:00, 109.83it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 13:0,  Loss:1.6051931381225586


100%|██████████| 200/200 [00:00<00:00, 204.05it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.27%, Test Acc:86.00%



100%|██████████| 1200/1200 [00:10<00:00, 109.31it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 14:0,  Loss:1.595191478729248


100%|██████████| 200/200 [00:00<00:00, 206.73it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.88%, Test Acc:86.54%



100%|██████████| 1200/1200 [00:10<00:00, 109.85it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 15:0,  Loss:1.5706875324249268


100%|██████████| 200/200 [00:00<00:00, 212.47it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.27%, Test Acc:86.38%



100%|██████████| 1200/1200 [00:10<00:00, 109.39it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 16:0,  Loss:1.6784493923187256


100%|██████████| 200/200 [00:00<00:00, 201.77it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.54%, Test Acc:86.83%



100%|██████████| 1200/1200 [00:10<00:00, 109.76it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 17:0,  Loss:1.5549932718276978


100%|██████████| 200/200 [00:00<00:00, 202.23it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.06%, Test Acc:87.81%



100%|██████████| 1200/1200 [00:10<00:00, 109.93it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 18:0,  Loss:1.5489935874938965


100%|██████████| 200/200 [00:00<00:00, 207.71it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.48%, Test Acc:88.27%



100%|██████████| 1200/1200 [00:10<00:00, 109.66it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 19:0,  Loss:1.5223302841186523


100%|██████████| 200/200 [00:00<00:00, 208.31it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.95%, Test Acc:88.60%



100%|██████████| 1200/1200 [00:10<00:00, 109.85it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 20:0,  Loss:1.6003624200820923


100%|██████████| 200/200 [00:00<00:00, 209.67it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.59%, Test Acc:88.97%



100%|██████████| 1200/1200 [00:10<00:00, 109.70it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 21:0,  Loss:1.5831204652786255


100%|██████████| 200/200 [00:00<00:00, 206.52it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.98%, Test Acc:89.59%



100%|██████████| 1200/1200 [00:10<00:00, 109.58it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 22:0,  Loss:1.5807907581329346


100%|██████████| 200/200 [00:00<00:00, 214.78it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.16%, Test Acc:89.02%



100%|██████████| 1200/1200 [00:10<00:00, 110.10it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 23:0,  Loss:1.5413240194320679


100%|██████████| 200/200 [00:00<00:00, 205.87it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.45%, Test Acc:89.77%



100%|██████████| 1200/1200 [00:10<00:00, 109.95it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 24:0,  Loss:1.5877796411514282


100%|██████████| 200/200 [00:00<00:00, 207.82it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.66%, Test Acc:89.63%



100%|██████████| 1200/1200 [00:11<00:00, 109.02it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 25:0,  Loss:1.5305677652359009


100%|██████████| 200/200 [00:00<00:00, 207.22it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.86%, Test Acc:89.26%



100%|██████████| 1200/1200 [00:10<00:00, 109.80it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 26:0,  Loss:1.5596085786819458


100%|██████████| 200/200 [00:00<00:00, 204.66it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.94%, Test Acc:89.92%



100%|██████████| 1200/1200 [00:10<00:00, 109.62it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 27:0,  Loss:1.5271220207214355


100%|██████████| 200/200 [00:01<00:00, 193.11it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.25%, Test Acc:90.36%



100%|██████████| 1200/1200 [00:10<00:00, 110.09it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 28:0,  Loss:1.550388216972351


100%|██████████| 200/200 [00:00<00:00, 209.18it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.40%, Test Acc:90.46%



100%|██████████| 1200/1200 [00:10<00:00, 110.22it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 29:0,  Loss:1.5624693632125854


100%|██████████| 200/200 [00:00<00:00, 206.94it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.59%, Test Acc:90.57%



100%|██████████| 1200/1200 [00:10<00:00, 109.65it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 30:0,  Loss:1.4967538118362427


100%|██████████| 200/200 [00:00<00:00, 208.42it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.71%, Test Acc:90.43%



100%|██████████| 1200/1200 [00:10<00:00, 109.86it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 31:0,  Loss:1.6004054546356201


100%|██████████| 200/200 [00:00<00:00, 202.28it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.87%, Test Acc:90.41%



100%|██████████| 1200/1200 [00:10<00:00, 109.51it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 32:0,  Loss:1.582597255706787


100%|██████████| 200/200 [00:01<00:00, 198.31it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.97%, Test Acc:90.15%



100%|██████████| 1200/1200 [00:10<00:00, 109.34it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 33:0,  Loss:1.5071167945861816


100%|██████████| 200/200 [00:00<00:00, 202.82it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.14%, Test Acc:90.85%



100%|██████████| 1200/1200 [00:10<00:00, 109.92it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 34:0,  Loss:1.5001825094223022


100%|██████████| 200/200 [00:00<00:00, 209.40it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.21%, Test Acc:90.19%



100%|██████████| 1200/1200 [00:10<00:00, 109.80it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 35:0,  Loss:1.537244439125061


100%|██████████| 200/200 [00:00<00:00, 201.67it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.42%, Test Acc:91.20%



100%|██████████| 1200/1200 [00:10<00:00, 109.65it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 36:0,  Loss:1.5314174890518188


100%|██████████| 200/200 [00:00<00:00, 211.60it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.49%, Test Acc:90.97%



100%|██████████| 1200/1200 [00:10<00:00, 109.99it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 37:0,  Loss:1.5352544784545898


100%|██████████| 200/200 [00:01<00:00, 198.37it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.73%, Test Acc:91.21%



100%|██████████| 1200/1200 [00:10<00:00, 111.45it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 38:0,  Loss:1.5279300212860107


100%|██████████| 200/200 [00:00<00:00, 219.58it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.70%, Test Acc:91.23%



100%|██████████| 1200/1200 [00:10<00:00, 114.27it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 39:0,  Loss:1.4634608030319214


100%|██████████| 200/200 [00:00<00:00, 220.70it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.85%, Test Acc:91.14%



100%|██████████| 1200/1200 [00:10<00:00, 114.86it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 40:0,  Loss:1.5826213359832764


100%|██████████| 200/200 [00:00<00:00, 218.49it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.03%, Test Acc:91.17%



100%|██████████| 1200/1200 [00:10<00:00, 112.42it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 41:0,  Loss:1.526139259338379


100%|██████████| 200/200 [00:00<00:00, 218.94it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.18%, Test Acc:91.51%



100%|██████████| 1200/1200 [00:10<00:00, 113.53it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 42:0,  Loss:1.5915091037750244


100%|██████████| 200/200 [00:00<00:00, 216.31it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.29%, Test Acc:91.64%



100%|██████████| 1200/1200 [00:10<00:00, 114.47it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 43:0,  Loss:1.5755573511123657


100%|██████████| 200/200 [00:00<00:00, 215.48it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.34%, Test Acc:91.14%



100%|██████████| 1200/1200 [00:10<00:00, 113.77it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 44:0,  Loss:1.53287672996521


100%|██████████| 200/200 [00:00<00:00, 221.28it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.41%, Test Acc:91.40%



100%|██████████| 1200/1200 [00:10<00:00, 113.66it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 45:0,  Loss:1.546945333480835


100%|██████████| 200/200 [00:00<00:00, 221.24it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.59%, Test Acc:91.75%



100%|██████████| 1200/1200 [00:10<00:00, 114.01it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 46:0,  Loss:1.5353235006332397


100%|██████████| 200/200 [00:00<00:00, 218.91it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.69%, Test Acc:91.16%



100%|██████████| 1200/1200 [00:10<00:00, 112.69it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 47:0,  Loss:1.573116421699524


100%|██████████| 200/200 [00:00<00:00, 224.68it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.88%, Test Acc:91.46%



100%|██████████| 1200/1200 [00:10<00:00, 113.29it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 48:0,  Loss:1.5196865797042847


100%|██████████| 200/200 [00:00<00:00, 214.36it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.88%, Test Acc:91.47%



100%|██████████| 1200/1200 [00:10<00:00, 112.80it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 49:0,  Loss:1.509291410446167


100%|██████████| 200/200 [00:00<00:00, 218.67it/s]

Train Acc:93.91%, Test Acc:91.69%

	-> Train Acc 93.91166666666668 ; Test Acc 91.75





In [28]:
classifier.inv_temp

Parameter containing:
tensor([16.7558], device='cuda:1', requires_grad=True)

### Hard test accuracy with count per classifier

In [29]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(test_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Test Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 200/200 [00:00<00:00, 219.00it/s]

Hard Test Acc:91.72%
[1010, 963, 1022, 945, 0, 965, 0, 0, 0, 0, 0, 0, 0, 10, 0, 12, 4, 0, 0, 0, 20, 0, 0, 0, 0, 0, 897, 1041, 954, 982, 0, 0, 0, 59, 942, 0, 0, 0, 12, 0, 18, 0, 0, 0, 0, 0, 0, 0, 6, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 3, 0, 0, 38, 0, 0, 0, 36, 0, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 30, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0]





### Hard train accuracy with count per classifier

In [30]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 1200/1200 [00:05<00:00, 233.51it/s]

Hard Train Acc:94.45%
[6032, 5834, 6053, 5754, 0, 5814, 0, 0, 0, 1, 0, 0, 0, 59, 0, 65, 27, 0, 0, 0, 124, 0, 0, 4, 0, 0, 5477, 6150, 5696, 5926, 0, 0, 4, 293, 5722, 0, 0, 0, 52, 0, 85, 0, 0, 0, 0, 0, 0, 0, 33, 0, 10, 2, 2, 0, 0, 9, 0, 0, 0, 22, 0, 0, 1, 0, 0, 12, 0, 0, 211, 0, 0, 0, 199, 0, 1, 0, 0, 0, 1, 0, 1, 117, 0, 0, 0, 0, 166, 0, 0, 0, 0, 0, 0, 0, 39, 1, 1, 0, 0, 0]





In [31]:
#### Classifiers that enclose any data
torch.count_nonzero(set_count)

tensor(38, device='cuda:1')

In [32]:
#### classifier with class representation
torch.argmax(classifier.cls_weight, dim=1)

tensor([0, 1, 2, 3, 4, 5, 6, 5, 8, 9, 6, 3, 6, 3, 4, 5, 6, 7, 8, 9, 0, 3, 6, 3,
        6, 5, 6, 7, 8, 9, 0, 3, 6, 3, 4, 5, 6, 5, 8, 9, 0, 1, 6, 3, 6, 5, 6, 5,
        8, 5, 0, 1, 2, 3, 6, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 5, 8, 9, 0, 1,
        2, 6, 4, 5, 6, 5, 8, 9, 0, 1, 6, 3, 6, 5, 6, 5, 8, 9, 0, 1, 6, 3, 4, 5,
        6, 5, 8, 9], device='cuda:1')

In [33]:
# The class labels are same as that of initialized
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
#         4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
#         8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
#         2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
#         6, 7, 8, 9], device='cuda:0')

In [34]:
classifier.cls_weight

Parameter containing:
tensor([[ 9.9651, -8.3586, -8.0780, -8.1773, -8.0812, -8.5657, -7.4315, -8.5668,
         -8.5246, -8.5689],
        [-8.4009,  9.9437, -8.5594, -7.6172, -8.5123, -8.5249, -8.4975, -8.5251,
         -8.5438, -8.5429],
        [-8.1342, -8.5138, 10.2070, -8.4267, -7.0783, -8.5771, -7.0886, -8.5765,
         -8.5024, -8.5772],
        [-8.1059, -7.3575, -8.3937, 10.0796, -8.3084, -8.4469, -8.2969, -8.4449,
         -8.4572, -8.4559],
        [-2.6310, -3.4729, -0.8183, -3.0468,  1.9774, -3.7064,  1.8816, -3.7827,
         -3.3317, -3.7476],
        [-8.5420, -8.5369, -8.5448, -8.5381, -8.5461,  9.9999, -8.5429, -7.5907,
         -8.1362, -7.6427],
        [-1.7174, -4.6637, -2.9627, -4.1755, -3.7274, -5.0109,  5.3939, -5.0871,
         -4.7791, -5.0709],
        [-1.7704, -1.9785, -2.0841, -1.7862, -2.1846,  2.6268, -1.5301, -0.0152,
         -0.7611, -0.4236],
        [-3.9767, -4.2292, -4.1630, -4.1689, -4.2116, -3.6602, -3.8590, -3.7507,
          5.6222, -3.7014

In [35]:
# torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True)

In [36]:
classifier.inv_temp

Parameter containing:
tensor([16.7558], device='cuda:1', requires_grad=True)

In [37]:
### example output per classifier
yout[5]

tensor([1.5403e-08, 8.6201e-09, 1.4322e-07, 9.9440e-09, 6.6119e-08, 8.0251e-09,
        1.0000e+00, 7.8943e-09, 8.5586e-09, 8.0861e-09], device='cuda:1')

In [43]:
asdfsdf ## to break the code

NameError: name 'asdfsdf' is not defined

### analyze per classifier accuracy

In [44]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
set_acc = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(model(xx), hard=True)
        
    cls_indx = torch.argmax(classifier.cls_confidence, dim=1)
    set_indx, count = torch.unique(cls_indx, return_counts=True) 
    set_count[set_indx] += count
    
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float)
    
    ### class_index has 100 possible values
    for i, c in enumerate(correct):
        set_acc[cls_indx[i]] += c
    
#     print(set_acc.sum(), set_count.sum())
#     break
    test_acc += correct.sum()
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 1200/1200 [00:06<00:00, 176.71it/s]

Hard Train Acc:94.52%
[0, 127, 413, 0, 0, 0, 5391, 0, 0, 0, 0, 5835, 0, 68, 0, 0, 0, 0, 0, 0, 7, 0, 5506, 0, 0, 1, 1, 6138, 126, 0, 7, 0, 11, 0, 0, 0, 41, 0, 0, 5715, 3, 0, 147, 0, 4089, 2, 0, 0, 0, 1, 0, 3, 0, 7, 1610, 0, 0, 0, 0, 0, 83, 0, 0, 543, 227, 6, 0, 0, 4, 0, 5752, 0, 0, 10, 0, 5986, 0, 0, 2, 0, 278, 0, 0, 5409, 0, 0, 397, 0, 5731, 156, 0, 0, 0, 0, 0, 0, 39, 0, 128, 0]





In [45]:
set_acc/set_count

tensor([   nan, 0.9370, 0.8741,    nan,    nan,    nan, 0.8611,    nan,    nan,
           nan,    nan, 0.9962,    nan, 0.8382,    nan,    nan,    nan,    nan,
           nan,    nan, 1.0000,    nan, 0.9130,    nan,    nan, 1.0000, 1.0000,
        0.9653, 0.9444,    nan, 1.0000,    nan, 0.7273,    nan,    nan,    nan,
        0.9024,    nan,    nan, 0.9890, 1.0000,    nan, 0.8367,    nan, 0.9359,
        1.0000,    nan,    nan,    nan, 1.0000,    nan, 1.0000,    nan, 0.8571,
        0.8522,    nan,    nan,    nan,    nan,    nan, 0.8072,    nan,    nan,
        0.9503, 0.9163, 1.0000,    nan,    nan, 1.0000,    nan, 0.9080,    nan,
           nan, 0.9000,    nan, 0.9932,    nan,    nan, 0.5000,    nan, 0.7914,
           nan,    nan, 0.9440,    nan,    nan, 0.7456,    nan, 0.9949, 0.9936,
           nan,    nan,    nan,    nan,    nan,    nan, 0.7436,    nan, 0.8750,
           nan], device='cuda:0')

In [46]:
for i, (cnt, acc, cls) in enumerate(zip(set_count.type(torch.long).tolist(),
                                   (set_acc/set_count).tolist(),
                                   torch.argmax(classifier.cls_weight, dim=1).tolist())):
    if cnt == 0: continue
    print(f"{i},\t {cnt},\t {cls}\t {acc*100:.2f}%")

1,	 127,	 1	 93.70%
2,	 413,	 2	 87.41%
6,	 5391,	 6	 86.11%
11,	 5835,	 1	 99.62%
13,	 68,	 3	 83.82%
20,	 7,	 0	 100.00%
22,	 5506,	 2	 91.30%
25,	 1,	 5	 100.00%
26,	 1,	 6	 100.00%
27,	 6138,	 7	 96.53%
28,	 126,	 8	 94.44%
30,	 7,	 0	 100.00%
32,	 11,	 2	 72.73%
36,	 41,	 6	 90.24%
39,	 5715,	 9	 98.90%
40,	 3,	 0	 100.00%
42,	 147,	 2	 83.67%
44,	 4089,	 4	 93.59%
45,	 2,	 5	 100.00%
49,	 1,	 9	 100.00%
51,	 3,	 1	 100.00%
53,	 7,	 3	 85.71%
54,	 1610,	 4	 85.22%
60,	 83,	 0	 80.72%
63,	 543,	 3	 95.03%
64,	 227,	 4	 91.63%
65,	 6,	 5	 100.00%
68,	 4,	 8	 100.00%
70,	 5752,	 0	 90.80%
73,	 10,	 3	 90.00%
75,	 5986,	 5	 99.32%
78,	 2,	 8	 50.00%
80,	 278,	 0	 79.14%
83,	 5409,	 3	 94.40%
86,	 397,	 6	 74.56%
88,	 5731,	 8	 99.49%
89,	 156,	 9	 99.36%
96,	 39,	 6	 74.36%
98,	 128,	 8	 87.50%


### Benchmark - Robustness

In [50]:
import foolbox as fb
import foolbox.attacks as fa

In [59]:
epsilons = [0.001, 0.01, 0.03, 0.1, 0.3, 0.5, 1.0]
# epsilons = [0.0005, 0.001, 0.0015, 0.002, 0.003, 0.005, 0.01, 0.02, 0.03, 0.1, 0.3, 0.5, 1.0,]

#### Benchmark on full dataset

In [60]:
attacks = [
    fa.FGSM(),
    fa.LinfPGD(),
    fa.LinfBasicIterativeAttack(),
#     fa.LinfAdditiveUniformNoiseAttack(),
#     fa.LinfDeepFoolAttack(),
]
atk_names = [
    "FGSM", 
    "LinfPGD", 
    "LinfBasicIterativeAttack", 
#     "LinfAdditiveUniformNoiseAttack", 
#     "LinfDeepFoolAttack",
]

In [61]:
def get_attack_success(model, attack, dataloader):
    fmodel = fb.PyTorchModel(model.eval(), bounds=(-1, 1))   
    success_per_eps = []

    for eps in epsilons:
        count = 0
        total = 0
        print(f"Running one epoch attack for eps: {eps}")
        for i, (images, labels) in enumerate(tqdm(dataloader)):
            images, labels = images.to(device), labels.to(device)
            _, _, success = attack(fmodel, images, labels, epsilons=[eps])
            count += int(torch.count_nonzero(success))
            total += torch.numel(success)
#             break
        success_per_eps.append(count/total)
    return success_per_eps

In [62]:
# get_attack_success(model, attacks[0], test_loader)

### Compute the robustness

In [63]:
outputs = {}

model.eval()
### load model and compute attack_success
for j in range(len(attacks)):
    attack = attacks[j]
    atname = atk_names[j]

    print(f"Attacking on model using {atname}")
    succ_eps = get_attack_success(model, attack, test_loader)
    outputs[atname] = succ_eps
    print(f"Success rate is : {succ_eps}")

  0%|          | 0/200 [00:00<?, ?it/s]

Attacking on model using FGSM
Running one epoch attack for eps: 0.001


100%|██████████| 200/200 [00:01<00:00, 117.17it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.01


100%|██████████| 200/200 [00:01<00:00, 117.00it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.03


100%|██████████| 200/200 [00:01<00:00, 117.42it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.1


100%|██████████| 200/200 [00:01<00:00, 117.23it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.3


100%|██████████| 200/200 [00:01<00:00, 116.27it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.5


100%|██████████| 200/200 [00:01<00:00, 116.64it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 1.0


100%|██████████| 200/200 [00:01<00:00, 115.87it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Success rate is : [0.9988, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
Attacking on model using LinfPGD
Running one epoch attack for eps: 0.001


100%|██████████| 200/200 [00:49<00:00,  4.08it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.01


100%|██████████| 200/200 [00:49<00:00,  4.07it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.03


100%|██████████| 200/200 [00:49<00:00,  4.07it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.1


100%|██████████| 200/200 [00:49<00:00,  4.07it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.3


100%|██████████| 200/200 [00:49<00:00,  4.07it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.5


100%|██████████| 200/200 [00:49<00:00,  4.07it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 1.0


100%|██████████| 200/200 [00:49<00:00,  4.07it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Success rate is : [0.9987, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
Attacking on model using LinfBasicIterativeAttack
Running one epoch attack for eps: 0.001


100%|██████████| 200/200 [00:12<00:00, 16.18it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.01


100%|██████████| 200/200 [00:12<00:00, 16.20it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.03


100%|██████████| 200/200 [00:12<00:00, 16.20it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.1


100%|██████████| 200/200 [00:12<00:00, 16.20it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.3


100%|██████████| 200/200 [00:12<00:00, 16.22it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 0.5


100%|██████████| 200/200 [00:12<00:00, 16.22it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Running one epoch attack for eps: 1.0


100%|██████████| 200/200 [00:12<00:00, 16.20it/s]

Success rate is : [0.9988, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]



