This version adds support for Cifar 100, 
1. input channel padding
2. Global pooling for manifold

In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy
from tqdm import tqdm

In [2]:
device = torch.device("cuda:0")
# device = torch.device("cpu")

In [3]:
import torch.optim as optim
from torch.utils import data

In [4]:
import nflib
from nflib.flows import SequentialFlow, NormalizingFlow, ActNorm, ActNorm2D, AffineConstantFlow
import nflib.coupling_flows as icf
import nflib.inn_flow as inn
import nflib.res_flow as irf

### Datasets

In [5]:
# cifar_train = transforms.Compose([
#     transforms.RandomCrop(size=32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
#         std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
#     ),
# ])

# cifar_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
#         std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
#     ),
# ])

# train_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=True, download=True, transform=cifar_train)
# test_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=False, download=True, transform=cifar_test)

In [6]:
cifar_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5071, 0.4865, 0.4409],
        std=[0.2009, 0.1984, 0.2023],
    ),
])

cifar_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5071, 0.4865, 0.4409],
        std=[0.2009, 0.1984, 0.2023],
    ),
])

train_dataset = datasets.CIFAR100(root="../../../../../_Datasets/cifar100/", train=True, download=True, transform=cifar_train)
test_dataset = datasets.CIFAR100(root="../../../../../_Datasets/cifar100/", train=False, download=True, transform=cifar_test)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=2)

In [8]:
xx, yy = iter(train_loader).next()

In [9]:
xx.shape

torch.Size([128, 3, 32, 32])

### Model

In [10]:
class PadChannel(nn.Module):
    
    def __init__(self, channel_pad):
        super().__init__()
        self.channel_pad = channel_pad
        
    def forward(self, x):
        y = torch.zeros(x.shape[0], self.channel_pad, x.shape[2], x.shape[3], device=x.device)
        x = torch.cat([x, y], dim=1)
        return x

In [11]:
# actf = irf.Swish
# flows = [
#     PadChannel(8-3),
#     nn.BatchNorm2d(8),
#     irf.ConvResidualFlow(8, [32, 32], kernels=5, activation=actf),
#     irf.InvertiblePooling(2),
#     nn.BatchNorm2d(32),
#     irf.ConvResidualFlow(32, [128, 128], kernels=5, activation=actf),
#     nn.BatchNorm2d(32),
#     irf.ConvResidualFlow(32, [128, 128], kernels=5, activation=actf),
#     irf.InvertiblePooling(2),
#     nn.BatchNorm2d(128),
#     irf.ConvResidualFlow(128, [256, 256], kernels=5, activation=actf),
#     nn.BatchNorm2d(128),
#     irf.ConvResidualFlow(128, [256, 256], kernels=5, activation=actf),
#     irf.InvertiblePooling(2),
#     nn.BatchNorm2d(512),
#     irf.ConvResidualFlow(512, [512, 512], kernels=5, activation=actf),
#     nn.BatchNorm2d(512),
#     irf.ConvResidualFlow(512, [512, 512], kernels=5, activation=actf),
#     nn.BatchNorm2d(512),
#     irf.Flatten(img_size=(512, 4, 4)),
#     nn.BatchNorm1d(8192),
#     nn.Linear(3072, 2),
#     nn.BatchNorm1d(2),
#         ]

# # backbone = SequentialFlow(flows)
# backbone = nn.Sequential(*flows)

In [12]:
actf = irf.Swish
flows = [
    nn.BatchNorm2d(3),
    irf.ConvResidualFlow(3, [32, 32], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    nn.BatchNorm2d(12),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
    nn.BatchNorm2d(12),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    nn.BatchNorm2d(48),
    irf.ConvResidualFlow(48, [128, 128], kernels=5, activation=actf),
    nn.BatchNorm2d(48),
    irf.ConvResidualFlow(48, [128, 128], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    nn.BatchNorm2d(192),
    irf.ConvResidualFlow(192, [256, 256], kernels=5, activation=actf),
    nn.BatchNorm2d(192),
    irf.ConvResidualFlow(192, [256, 256], kernels=5, activation=actf),
    nn.BatchNorm2d(192),
    nn.AdaptiveAvgPool2d(1),
    irf.Flatten(img_size=(192, 1, 1)),
    nn.BatchNorm1d(192),
#     irf.Flatten(img_size=(192, 4, 4)),
#     nn.BatchNorm1d(3072),
#     nn.Linear(3072, 2),
#     nn.BatchNorm1d(2),
        ]

# backbone = SequentialFlow(flows)
backbone = nn.Sequential(*flows)

In [13]:
backbone.to(device)

Sequential(
  (0): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): ConvResidualFlow(
    (resblock): ModuleList(
      (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): Swish()
      (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (3): Swish()
      (4): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
  (2): InvertiblePooling()
  (3): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (4): ConvResidualFlow(
    (resblock): ModuleList(
      (0): Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): Swish()
      (2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (3): Swish()
      (4): Conv2d(64, 12, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
  (5): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ConvResidualFlow(
    (resblock)

In [14]:
backbone(xx.to(device)).shape

torch.Size([128, 192])

In [15]:
def get_children(module):
    child = list(module.children())
    if len(child) == 0:
        return [module]
    children = []
    for ch in child:
        grand_ch = get_children(ch)
        children+=grand_ch
    return children

def remove_spectral_norm(model):
    for child in get_children(model):
        if hasattr(child, 'weight'):
            print("Yes", child)
            try:
                nn.utils.remove_spectral_norm(child)
                print("Success")
            except:
                print("Failed")
    return

In [16]:
# remove_spectral_norm(backbone)

In [17]:
backbone(xx.to(device)).shape, 32*32*8

(torch.Size([128, 192]), 8192)

In [18]:
print("number of params: ", sum(p.numel() for p in backbone.parameters()))

number of params:  9941759


In [19]:
for xx, yy in train_loader:
    tt = backbone(xx.to(device))
    print(xx.shape, tt.shape)
    break

torch.Size([128, 3, 32, 32]) torch.Size([128, 192])


In [20]:
class ConnectedClassifier_Linear(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.linear = nn.Linear(input_dim, num_sets)
#         self.linear.bias.data *= 0
#         self.linear.weight.data *= 0.1
#         self.cls_weight = nn.Parameter(torch.randn(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 5
        self.cls_weight = nn.Parameter(init_val)
        
        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
#         self.cls_weight.data = torch.abs(self.cls_weight.data/self.cls_weight.data.sum(dim=1, keepdim=True))
        
        x = self.linear(x)
        x = x*torch.exp(self.inv_temp)

        if hard:
            x = torch.softmax(x*1e5, dim=1)
        else:
            x = torch.softmax(x, dim=1)
        self.cls_confidence = x
#         c = torch.softmax(self.cls_weight, dim=1)
        c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized

In [21]:
class ConnectedClassifier_SoftKMeans(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)
        
#         self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 5
        self.cls_weight = nn.Parameter(init_val)

        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
#         self.cls_weight.data = torch.abs(self.cls_weight.data/self.cls_weight.data.sum(dim=1, keepdim=True))
        
        dists = torch.cdist(x, self.centers)
        dists = dists/np.sqrt(self.input_dim) ### correction to make diagonal of unit square 1 in nD space
        dists = dists*torch.exp(self.inv_temp)

        if hard:
            x = torch.softmax(-dists*1e5, dim=1)
        else:
            x = torch.softmax(-dists, dim=1)
        self.cls_confidence = x
#         c = torch.softmax(self.cls_weight, dim=1)
        c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized
#         return torch.softmax(x@self.cls_weight, dim=1)

    def set_centroid_to_data_randomly(self, data_loader, model):
        num_centers = self.centers.shape[0]
        xxs, yys = [], []
        count = 0
        for xx, yy in data_loader:
            yout = model(xx.to(device)).data.cpu()
            xxs.append(yout)
            yys.append(yy)
            count += len(xx)
            if count >= num_centers:
                break
        
        yout = torch.cat(xxs, dim=0)
        yy = torch.cat(yys, dim=0)
        
        yout = yout[:num_centers].to(self.centers.device)
        yy = yy[:num_centers].to(self.centers.device)
        
        self.centers.data = yout
        
        init_val = torch.randn(self.num_sets, self.output_dim)
        for ns in range(num_centers):
            init_val[ns, yy[ns]] = 5
        self.cls_weight.data = init_val.to(self.cls_weight.device)
        pass

In [22]:
train_loader.dataset

Dataset CIFAR100
    Number of datapoints: 50000
    Root location: ../../../../../_Datasets/cifar100/
    Split: Train
    StandardTransform
Transform: Compose(
               RandomCrop(size=(32, 32), padding=4)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2009, 0.1984, 0.2023])
           )

In [23]:
train_loader.dataset.transforms

StandardTransform
Transform: Compose(
               RandomCrop(size=(32, 32), padding=4)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2009, 0.1984, 0.2023])
           )

In [24]:
#### C10
# classifier = ConnectedClassifier_SoftKMeans(2, 20, 10)
# classifier = ConnectedClassifier_Linear(2, 20, 10)

#### C100
# classifier = ConnectedClassifier_SoftKMeans(8192, 500, 100)
# classifier = ConnectedClassifier_SoftKMeans(192, 500, 100, inv_temp=0)
classifier = ConnectedClassifier_Linear(192, 500, 100, inv_temp=0)

#### for MLP based classification
# classifier = nn.Sequential(nn.Linear(2, 500), nn.SELU(), nn.Linear(500, 100))

classifier = classifier.to(device)

In [25]:
# classifier.set_centroid_to_data_randomly(train_loader, backbone)

In [26]:
print("number of params: ", sum(p.numel() for p in backbone.parameters()))
print("number of params: ", sum(p.numel() for p in classifier.parameters()))

number of params:  9941759
number of params:  146501


In [27]:
model = nn.Sequential(backbone, classifier).to(device)

In [28]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  10088260


## Training

In [29]:
model_name = 'c100_GAP_multi_inv_linear_v0'

In [30]:
# model_name = 'c10_2d_multi_inv_v0'
# model_name = 'c10_2d_multi_ord_v0'
# model_name = 'c100_2d_multi_inv_v1'
# model_name = 'c100_2d_multi_ord_v0'

In [31]:
EPOCHS = 200
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
#                       momentum=0.9, weight_decay=5e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [32]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    print(f"[Train] {epoch} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    return

In [33]:
best_acc = -1
def test(epoch):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('models'):
            os.mkdir('models')
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc

In [34]:
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
resume = False

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load(f'./models/{model_name}.pth')
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

In [35]:
### Train the whole damn thing

for epoch in range(start_epoch, start_epoch+EPOCHS): ## for 200 epochs
    train(epoch)
    test(epoch)
    scheduler.step()

100%|██████████| 391/391 [01:10<00:00,  5.54it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 0 Loss: 4.466 | Acc: 4.736 2368/50000


100%|██████████| 79/79 [00:02<00:00, 29.01it/s]


[Test] 0 Loss: 4.349 | Acc: 6.820 682/10000
Saving..


100%|██████████| 391/391 [00:55<00:00,  7.06it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 1 Loss: 4.260 | Acc: 9.774 4887/50000


100%|██████████| 79/79 [00:05<00:00, 14.21it/s]


[Test] 1 Loss: 4.153 | Acc: 11.650 1165/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 2 Loss: 4.109 | Acc: 13.606 6803/50000


100%|██████████| 79/79 [00:05<00:00, 14.10it/s]


[Test] 2 Loss: 4.039 | Acc: 15.180 1518/10000
Saving..


100%|██████████| 391/391 [01:22<00:00,  4.71it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 3 Loss: 3.984 | Acc: 16.804 8402/50000


100%|██████████| 79/79 [00:05<00:00, 13.98it/s]


[Test] 3 Loss: 3.910 | Acc: 18.840 1884/10000
Saving..


100%|██████████| 391/391 [01:22<00:00,  4.71it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 4 Loss: 3.854 | Acc: 20.084 10042/50000


100%|██████████| 79/79 [00:05<00:00, 14.13it/s]


[Test] 4 Loss: 3.783 | Acc: 21.660 2166/10000
Saving..


100%|██████████| 391/391 [01:22<00:00,  4.71it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 5 Loss: 3.729 | Acc: 23.354 11677/50000


100%|██████████| 79/79 [00:05<00:00, 13.98it/s]


[Test] 5 Loss: 3.656 | Acc: 25.180 2518/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 6 Loss: 3.604 | Acc: 26.422 13211/50000


100%|██████████| 79/79 [00:05<00:00, 13.97it/s]


[Test] 6 Loss: 3.546 | Acc: 27.700 2770/10000
Saving..


100%|██████████| 391/391 [01:22<00:00,  4.71it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 7 Loss: 3.492 | Acc: 29.120 14560/50000


100%|██████████| 79/79 [00:05<00:00, 14.11it/s]


[Test] 7 Loss: 3.438 | Acc: 30.110 3011/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.69it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 8 Loss: 3.393 | Acc: 31.600 15800/50000


100%|██████████| 79/79 [00:05<00:00, 13.83it/s]


[Test] 8 Loss: 3.369 | Acc: 31.620 3162/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 9 Loss: 3.297 | Acc: 33.748 16874/50000


100%|██████████| 79/79 [00:05<00:00, 14.14it/s]


[Test] 9 Loss: 3.303 | Acc: 33.280 3328/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 10 Loss: 3.198 | Acc: 36.262 18131/50000


100%|██████████| 79/79 [00:05<00:00, 14.08it/s]


[Test] 10 Loss: 3.249 | Acc: 35.090 3509/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 11 Loss: 3.110 | Acc: 38.406 19203/50000


100%|██████████| 79/79 [00:05<00:00, 14.09it/s]


[Test] 11 Loss: 3.159 | Acc: 36.790 3679/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 12 Loss: 3.030 | Acc: 40.200 20100/50000


100%|██████████| 79/79 [00:05<00:00, 14.00it/s]


[Test] 12 Loss: 3.115 | Acc: 38.060 3806/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.71it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 13 Loss: 2.947 | Acc: 41.906 20953/50000


100%|██████████| 79/79 [00:05<00:00, 13.85it/s]


[Test] 13 Loss: 3.074 | Acc: 38.760 3876/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 14 Loss: 2.872 | Acc: 43.612 21806/50000


100%|██████████| 79/79 [00:05<00:00, 13.98it/s]


[Test] 14 Loss: 2.998 | Acc: 40.500 4050/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 15 Loss: 2.802 | Acc: 45.092 22546/50000


100%|██████████| 79/79 [00:05<00:00, 14.06it/s]


[Test] 15 Loss: 2.971 | Acc: 40.820 4082/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.68it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 16 Loss: 2.735 | Acc: 46.740 23370/50000


100%|██████████| 79/79 [00:05<00:00, 14.09it/s]


[Test] 16 Loss: 2.909 | Acc: 42.260 4226/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 17 Loss: 2.670 | Acc: 48.146 24073/50000


100%|██████████| 79/79 [00:05<00:00, 14.02it/s]


[Test] 17 Loss: 2.889 | Acc: 42.860 4286/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.68it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 18 Loss: 2.598 | Acc: 49.702 24851/50000


100%|██████████| 79/79 [00:05<00:00, 14.09it/s]


[Test] 18 Loss: 2.872 | Acc: 43.700 4370/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.68it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 19 Loss: 2.535 | Acc: 50.968 25484/50000


100%|██████████| 79/79 [00:05<00:00, 14.06it/s]


[Test] 19 Loss: 2.834 | Acc: 43.970 4397/10000
Saving..


100%|██████████| 391/391 [01:22<00:00,  4.71it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 20 Loss: 2.475 | Acc: 52.452 26226/50000


100%|██████████| 79/79 [00:05<00:00, 13.88it/s]


[Test] 20 Loss: 2.788 | Acc: 45.180 4518/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 21 Loss: 2.413 | Acc: 53.734 26867/50000


100%|██████████| 79/79 [00:05<00:00, 14.02it/s]


[Test] 21 Loss: 2.755 | Acc: 45.370 4537/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.71it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 22 Loss: 2.362 | Acc: 54.930 27465/50000


100%|██████████| 79/79 [00:05<00:00, 14.00it/s]


[Test] 22 Loss: 2.728 | Acc: 46.280 4628/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.68it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 23 Loss: 2.309 | Acc: 56.002 28001/50000


100%|██████████| 79/79 [00:05<00:00, 14.03it/s]


[Test] 23 Loss: 2.702 | Acc: 47.530 4753/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 24 Loss: 2.255 | Acc: 57.108 28554/50000


100%|██████████| 79/79 [00:05<00:00, 14.11it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 24 Loss: 2.692 | Acc: 47.190 4719/10000


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 25 Loss: 2.201 | Acc: 58.194 29097/50000


100%|██████████| 79/79 [00:05<00:00, 13.94it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 25 Loss: 2.678 | Acc: 47.430 4743/10000


100%|██████████| 391/391 [01:23<00:00,  4.69it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 26 Loss: 2.157 | Acc: 59.224 29612/50000


100%|██████████| 79/79 [00:05<00:00, 14.10it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 26 Loss: 2.676 | Acc: 47.460 4746/10000


100%|██████████| 391/391 [01:23<00:00,  4.67it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 27 Loss: 2.110 | Acc: 60.114 30057/50000


100%|██████████| 79/79 [00:05<00:00, 14.06it/s]


[Test] 27 Loss: 2.634 | Acc: 48.220 4822/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.69it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 28 Loss: 2.056 | Acc: 61.406 30703/50000


100%|██████████| 79/79 [00:05<00:00, 14.06it/s]


[Test] 28 Loss: 2.612 | Acc: 49.080 4908/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.71it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 29 Loss: 2.018 | Acc: 62.076 31038/50000


100%|██████████| 79/79 [00:05<00:00, 14.02it/s]


[Test] 29 Loss: 2.593 | Acc: 49.180 4918/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.66it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 30 Loss: 1.977 | Acc: 62.898 31449/50000


100%|██████████| 79/79 [00:05<00:00, 14.05it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 30 Loss: 2.594 | Acc: 48.710 4871/10000


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 31 Loss: 1.943 | Acc: 63.538 31769/50000


100%|██████████| 79/79 [00:05<00:00, 14.02it/s]


[Test] 31 Loss: 2.573 | Acc: 49.520 4952/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.69it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 32 Loss: 1.894 | Acc: 64.648 32324/50000


100%|██████████| 79/79 [00:05<00:00, 14.13it/s]


[Test] 32 Loss: 2.563 | Acc: 49.630 4963/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 33 Loss: 1.859 | Acc: 65.336 32668/50000


100%|██████████| 79/79 [00:05<00:00, 13.97it/s]


[Test] 33 Loss: 2.549 | Acc: 50.060 5006/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.67it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 34 Loss: 1.811 | Acc: 66.354 33177/50000


100%|██████████| 79/79 [00:05<00:00, 14.13it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 34 Loss: 2.554 | Acc: 49.980 4998/10000


100%|██████████| 391/391 [01:23<00:00,  4.69it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 35 Loss: 1.775 | Acc: 67.192 33596/50000


100%|██████████| 79/79 [00:05<00:00, 14.19it/s]


[Test] 35 Loss: 2.540 | Acc: 50.160 5016/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.67it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 36 Loss: 1.736 | Acc: 67.868 33934/50000


100%|██████████| 79/79 [00:05<00:00, 14.05it/s]


[Test] 36 Loss: 2.534 | Acc: 50.390 5039/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.69it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 37 Loss: 1.714 | Acc: 68.338 34169/50000


100%|██████████| 79/79 [00:05<00:00, 13.95it/s]


[Test] 37 Loss: 2.497 | Acc: 51.290 5129/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.69it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 38 Loss: 1.678 | Acc: 68.906 34453/50000


100%|██████████| 79/79 [00:05<00:00, 14.04it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 38 Loss: 2.512 | Acc: 50.670 5067/10000


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 39 Loss: 1.645 | Acc: 69.694 34847/50000


100%|██████████| 79/79 [00:05<00:00, 13.99it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 39 Loss: 2.509 | Acc: 51.020 5102/10000


100%|██████████| 391/391 [01:23<00:00,  4.66it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 40 Loss: 1.609 | Acc: 70.480 35240/50000


100%|██████████| 79/79 [00:05<00:00, 14.01it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 40 Loss: 2.504 | Acc: 50.870 5087/10000


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 41 Loss: 1.586 | Acc: 70.894 35447/50000


100%|██████████| 79/79 [00:05<00:00, 14.05it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 41 Loss: 2.498 | Acc: 51.170 5117/10000


100%|██████████| 391/391 [01:23<00:00,  4.71it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 42 Loss: 1.556 | Acc: 71.402 35701/50000


100%|██████████| 79/79 [00:05<00:00, 13.87it/s]


[Test] 42 Loss: 2.468 | Acc: 51.650 5165/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.69it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 43 Loss: 1.521 | Acc: 72.190 36095/50000


100%|██████████| 79/79 [00:05<00:00, 14.15it/s]


[Test] 43 Loss: 2.472 | Acc: 51.800 5180/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 44 Loss: 1.498 | Acc: 72.584 36292/50000


100%|██████████| 79/79 [00:05<00:00, 14.11it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 44 Loss: 2.487 | Acc: 51.110 5111/10000


100%|██████████| 391/391 [01:23<00:00,  4.68it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 45 Loss: 1.470 | Acc: 73.332 36666/50000


100%|██████████| 79/79 [00:05<00:00, 14.10it/s]


[Test] 45 Loss: 2.465 | Acc: 51.920 5192/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.69it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 46 Loss: 1.437 | Acc: 73.712 36856/50000


100%|██████████| 79/79 [00:05<00:00, 13.97it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 46 Loss: 2.468 | Acc: 51.740 5174/10000


100%|██████████| 391/391 [01:23<00:00,  4.67it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 47 Loss: 1.426 | Acc: 73.964 36982/50000


100%|██████████| 79/79 [00:05<00:00, 13.92it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 47 Loss: 2.456 | Acc: 51.580 5158/10000


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 48 Loss: 1.391 | Acc: 74.646 37323/50000


100%|██████████| 79/79 [00:05<00:00, 13.80it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 48 Loss: 2.449 | Acc: 51.800 5180/10000


100%|██████████| 391/391 [01:23<00:00,  4.69it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 50 Loss: 1.353 | Acc: 75.396 37698/50000


100%|██████████| 79/79 [00:05<00:00, 13.97it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 50 Loss: 2.451 | Acc: 51.840 5184/10000


100%|██████████| 391/391 [01:23<00:00,  4.69it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 51 Loss: 1.313 | Acc: 76.278 38139/50000


100%|██████████| 79/79 [00:05<00:00, 14.22it/s]


[Test] 51 Loss: 2.414 | Acc: 52.710 5271/10000
Saving..


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 52 Loss: 1.290 | Acc: 76.704 38352/50000


100%|██████████| 79/79 [00:05<00:00, 13.92it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 52 Loss: 2.438 | Acc: 52.320 5232/10000


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 53 Loss: 1.270 | Acc: 77.104 38552/50000


100%|██████████| 79/79 [00:05<00:00, 13.85it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 53 Loss: 2.432 | Acc: 52.400 5240/10000


100%|██████████| 391/391 [01:23<00:00,  4.69it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 54 Loss: 1.245 | Acc: 77.656 38828/50000


100%|██████████| 79/79 [00:05<00:00, 14.11it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 54 Loss: 2.435 | Acc: 52.160 5216/10000


100%|██████████| 391/391 [01:22<00:00,  4.74it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 55 Loss: 1.222 | Acc: 77.996 38998/50000


100%|██████████| 79/79 [00:05<00:00, 14.23it/s]


[Test] 55 Loss: 2.399 | Acc: 53.080 5308/10000
Saving..


100%|██████████| 391/391 [01:22<00:00,  4.74it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 56 Loss: 1.203 | Acc: 78.536 39268/50000


100%|██████████| 79/79 [00:05<00:00, 14.32it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 56 Loss: 2.408 | Acc: 52.740 5274/10000


 13%|█▎        | 50/391 [00:10<01:11,  4.74it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 391/391 [01:22<00:00,  4.74it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 90 Loss: 0.714 | Acc: 87.650 43825/50000


100%|██████████| 79/79 [00:05<00:00, 14.26it/s]


[Test] 90 Loss: 2.377 | Acc: 54.310 5431/10000
Saving..


100%|██████████| 391/391 [01:22<00:00,  4.74it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 91 Loss: 0.713 | Acc: 87.666 43833/50000


100%|██████████| 79/79 [00:05<00:00, 14.28it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 91 Loss: 2.353 | Acc: 54.310 5431/10000


100%|██████████| 391/391 [01:22<00:00,  4.74it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 92 Loss: 0.704 | Acc: 87.858 43929/50000


100%|██████████| 79/79 [00:05<00:00, 14.31it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 92 Loss: 2.372 | Acc: 54.200 5420/10000


100%|██████████| 391/391 [01:21<00:00,  4.77it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 93 Loss: 0.699 | Acc: 87.898 43949/50000


100%|██████████| 79/79 [00:05<00:00, 14.31it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 93 Loss: 2.401 | Acc: 53.700 5370/10000


100%|██████████| 391/391 [01:22<00:00,  4.76it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 94 Loss: 0.683 | Acc: 88.226 44113/50000


100%|██████████| 79/79 [00:05<00:00, 14.29it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 94 Loss: 2.380 | Acc: 54.260 5426/10000


100%|██████████| 391/391 [01:22<00:00,  4.74it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 95 Loss: 0.681 | Acc: 88.232 44116/50000


100%|██████████| 79/79 [00:05<00:00, 14.29it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 95 Loss: 2.386 | Acc: 54.190 5419/10000


100%|██████████| 391/391 [01:22<00:00,  4.74it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 96 Loss: 0.666 | Acc: 88.456 44228/50000


100%|██████████| 79/79 [00:05<00:00, 14.26it/s]


[Test] 96 Loss: 2.379 | Acc: 54.400 5440/10000
Saving..


 61%|██████    | 238/391 [00:50<00:32,  4.72it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 391/391 [01:22<00:00,  4.74it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 131 Loss: 0.456 | Acc: 92.242 46121/50000


100%|██████████| 79/79 [00:05<00:00, 14.29it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 131 Loss: 2.361 | Acc: 55.170 5517/10000


100%|██████████| 391/391 [01:22<00:00,  4.74it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 132 Loss: 0.453 | Acc: 92.298 46149/50000


100%|██████████| 79/79 [00:05<00:00, 14.28it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 132 Loss: 2.358 | Acc: 55.020 5502/10000


100%|██████████| 391/391 [01:21<00:00,  4.77it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 133 Loss: 0.452 | Acc: 92.310 46155/50000


100%|██████████| 79/79 [00:05<00:00, 14.31it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 133 Loss: 2.364 | Acc: 55.120 5512/10000


100%|██████████| 391/391 [01:22<00:00,  4.76it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 134 Loss: 0.447 | Acc: 92.394 46197/50000


100%|██████████| 79/79 [00:05<00:00, 14.32it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 134 Loss: 2.362 | Acc: 54.980 5498/10000


100%|██████████| 391/391 [01:22<00:00,  4.74it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 135 Loss: 0.443 | Acc: 92.472 46236/50000


100%|██████████| 79/79 [00:05<00:00, 14.28it/s]


[Test] 135 Loss: 2.349 | Acc: 55.410 5541/10000
Saving..


100%|██████████| 391/391 [01:22<00:00,  4.76it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 136 Loss: 0.436 | Acc: 92.626 46313/50000


100%|██████████| 79/79 [00:05<00:00, 14.34it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 136 Loss: 2.352 | Acc: 55.090 5509/10000


100%|██████████| 391/391 [01:22<00:00,  4.76it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 137 Loss: 0.436 | Acc: 92.554 46277/50000


100%|██████████| 79/79 [00:05<00:00, 14.37it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 137 Loss: 2.355 | Acc: 55.050 5505/10000


 33%|███▎      | 129/391 [00:27<00:55,  4.71it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 391/391 [01:22<00:00,  4.72it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 172 Loss: 0.360 | Acc: 93.892 46946/50000


100%|██████████| 79/79 [00:04<00:00, 16.66it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 172 Loss: 2.349 | Acc: 55.690 5569/10000


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 173 Loss: 0.357 | Acc: 93.930 46965/50000


100%|██████████| 79/79 [00:05<00:00, 15.18it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 173 Loss: 2.349 | Acc: 55.730 5573/10000


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 174 Loss: 0.356 | Acc: 93.928 46964/50000


100%|██████████| 79/79 [00:05<00:00, 15.17it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 174 Loss: 2.343 | Acc: 55.690 5569/10000


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 175 Loss: 0.356 | Acc: 93.920 46960/50000


100%|██████████| 79/79 [00:05<00:00, 15.20it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 175 Loss: 2.346 | Acc: 55.650 5565/10000


100%|██████████| 391/391 [01:22<00:00,  4.73it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 176 Loss: 0.357 | Acc: 93.942 46971/50000


100%|██████████| 79/79 [00:05<00:00, 15.80it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 176 Loss: 2.351 | Acc: 55.510 5551/10000


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 177 Loss: 0.353 | Acc: 93.986 46993/50000


100%|██████████| 79/79 [00:05<00:00, 15.07it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 177 Loss: 2.347 | Acc: 55.650 5565/10000


100%|██████████| 391/391 [01:22<00:00,  4.75it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 178 Loss: 0.351 | Acc: 94.008 47004/50000


100%|██████████| 79/79 [00:05<00:00, 15.00it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 178 Loss: 2.344 | Acc: 55.770 5577/10000


 30%|██▉       | 116/391 [00:24<00:58,  4.71it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [36]:
best_acc

55.97

In [37]:
######## C100 - 2D
#### inv nn with MLP classifier: 31.58 Acc ;
#### non-inv nn with MLP classifier: 11.89 Acc ;

######## C100 - GAP with connected classifier : 28.59 Acc;

In [38]:
classifier.inv_temp

Parameter containing:
tensor([1.3865], device='cuda:0', requires_grad=True)

In [39]:
checkpoint = torch.load(f'./models/{model_name}.pth')
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

best_acc, start_epoch

(55.97, 192)

In [40]:
checkpoint.keys()

dict_keys(['model', 'acc', 'epoch'])

In [41]:
model_name

'c100_GAP_multi_inv_linear_v0'

In [42]:
asdasd

NameError: name 'asdasd' is not defined

In [None]:
model.load_state_dict(checkpoint['model'])

### Hard test accuracy with count per classifier

In [None]:
model.eval()
print("Testing")

In [None]:
backbone, classifier = model[0], model[1]

In [None]:
classifier

In [None]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(test_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(backbone(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Test Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

In [None]:
#### Classifiers that enclose any data
torch.count_nonzero(set_count)

In [None]:
embeddings, labels, ilabels = [], [], []
model.eval()
for xx, yy in tqdm(test_loader):
    ilabels.append(yy)
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        embs = backbone(xx)
        embeddings.append(embs.cpu())
        
#         yout = classifier(embs, hard=True)
        yout = classifier(embs)
        labels.append(torch.argmax(yout, dim=1).cpu())

embeddings = torch.cat(embeddings, dim=0)
labels = torch.cat(labels, dim=0)
ilabels = torch.cat(ilabels, dim=0)

In [None]:
plt.figure(figsize=(12,9))
plt.scatter(embeddings[:,0], embeddings[:, 1], c=labels, s=2, cmap='tab20')

cbar = plt.colorbar(ticks=range(10), #label='classes', 
                    boundaries=np.arange(11)-0.5)
cbar.set_ticks(np.arange(10))
cbar.set_ticklabels(list(range(10)))

plt.savefig('./output/01_prediction_c10.pdf')

In [None]:
plt.figure(figsize=(12,9))
plt.scatter(embeddings[:,0], embeddings[:, 1], c=ilabels, s=2, cmap='tab10')

cbar = plt.colorbar(ticks=range(10), #label='classes', 
                    boundaries=np.arange(11)-0.5)
cbar.set_ticks(np.arange(10))
cbar.set_ticklabels(list(range(10)))

plt.savefig('./output/02_ground_truth_c10.pdf')

In [None]:
### Error Nodes
plt.figure(figsize=(12,9))
plt.scatter(embeddings[:,0], embeddings[:, 1], c=(labels==ilabels), s=2, cmap='PiYG')

cbar = plt.colorbar(ticks=range(10), #label='classes', 
                    boundaries=np.arange(3)-0.5)
cbar.set_ticks(np.arange(2))
cbar.set_ticklabels(['incorrect', 'correct', ])
cbar.ax.tick_params(rotation=90)

plt.savefig('./output/03_errors_c10.pdf')

### plot on a 2D map, the decision boundary

In [None]:
ng = 1000
_a,_b,_c,_d = embeddings[:,0].min()-0.1, embeddings[:,0].max()+0.1, embeddings[:,1].min()-0.1, embeddings[:,1].max()+0.1
xg, yg = torch.linspace(_a, _b, ng), torch.linspace(_c, _d, ng)
xg, yg = torch.meshgrid(xg, yg)
xyg = torch.stack([xg.reshape(-1), yg.reshape(-1)], dim=-1)

In [None]:
BS = 1000
output = []
for i in range(0, len(xyg), BS):
    xx = xyg[i:i+BS].to(device)
    with torch.no_grad():
#         yout = classifier(xx, hard=True)
        yout = classifier(xx)
        output.append(torch.argmax(yout, dim=1).cpu())
        
output = torch.cat(output, dim=0)

In [None]:
plt.figure(figsize=(12,9))
plt.scatter(embeddings[:,0], embeddings[:, 1], c=ilabels, s=2, cmap='tab10')

cbar = plt.colorbar(ticks=range(10), #label='classes', 
                    boundaries=np.arange(11)-0.5)
cbar.set_ticks(np.arange(10))
cbar.set_ticklabels(list(range(10)))

plt.imshow(output.reshape(xg.shape).t(), interpolation='nearest',
           extent=(_a, _b, _c, _d),
           alpha=0.5, cmap='tab10',
           aspect='auto', origin='lower')

plt.savefig('./output/04_DecisionBoundary_Class_c10.pdf')

In [None]:
## display different class boundary
BS = 1000
output2 = []
for i in range(0, len(xyg), BS):
    xx = xyg[i:i+BS].to(device)
    with torch.no_grad():
        yout = classifier(xx, hard=True)
        output2.append(torch.argmax(classifier.cls_confidence, dim=1).cpu())
output2 = torch.cat(output2, dim=0)

In [None]:
plt.figure(figsize=(12,9))
plt.scatter(embeddings[:,0], embeddings[:, 1], c=ilabels, s=2, cmap='tab10')

plt.imshow(output2.reshape(xg.shape).t(), interpolation='nearest',
           extent=(_a, _b, _c, _d),
           alpha=0.9, cmap='tab20',
           aspect='auto', origin='lower')

cbar = plt.colorbar(ticks=range(20), #label='classes', 
                    boundaries=np.arange(21)-0.5)
cbar.set_ticks(np.arange(20))
cbar.set_ticklabels(list(range(20)))

### plot centroids as well
# plt.scatter(*classifier[1].centers.data.cpu().numpy().T, marker='*', c='k')
plt.scatter(*classifier.centers.data.cpu().numpy().T, marker='*', c='k')

plt.xlim(xg.min(), xg.max())
plt.ylim(yg.min(), yg.max()) 
plt.savefig('./output/05_DecisionBoundary_Sets_c10.pdf')

### Hard train accuracy with count per classifier

In [None]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(backbone(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

In [None]:
#### Classifiers that enclose any data
torch.count_nonzero(set_count)

In [None]:
#### classifier with class representation
torch.argmax(classifier.cls_weight, dim=1)

### analyze per classifier accuracy

In [None]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
set_acc = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(backbone(xx), hard=True)
        
    cls_indx = torch.argmax(classifier.cls_confidence, dim=1)
    set_indx, count = torch.unique(cls_indx, return_counts=True) 
    set_count[set_indx] += count
    
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float)
    
    ### class_index has 100 possible values
    for i, c in enumerate(correct):
        set_acc[cls_indx[i]] += c
    
#     print(set_acc.sum(), set_count.sum())
#     break
    test_acc += correct.sum()
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

In [None]:
# set_acc/set_count

In [None]:
for i, (cnt, acc, cls) in enumerate(zip(set_count.type(torch.long).tolist(),
                                   (set_acc/set_count).tolist(),
                                   torch.argmax(classifier.cls_weight, dim=1).tolist())):
    if cnt == 0: continue
    print(f"{i},\t {cnt},\t {cls}\t {acc*100:.2f}%")