In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy
from tqdm import tqdm

In [2]:
device = torch.device("cuda:1")
# device = torch.device("cpu")

In [3]:
import torch.optim as optim
from torch.utils import data

In [4]:
import nflib
from nflib.flows import SequentialFlow, NormalizingFlow, ActNorm, ActNorm2D, AffineConstantFlow
import nflib.coupling_flows as icf
import nflib.inn_flow as inn
import nflib.res_flow as irf

### Datasets

In [5]:
cifar_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
        std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
    ),
])

cifar_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
        std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
    ),
])

train_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=True, download=True, transform=cifar_train)
test_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=False, download=True, transform=cifar_test)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=2)

In [7]:
xx, yy = iter(train_loader).next()

In [8]:
xx.shape

torch.Size([128, 3, 32, 32])

### Model

In [10]:
actf = irf.Swish
flows = [
#     ActNorm2D(3),
    nn.BatchNorm2d(3),
    irf.ConvResidualFlow(3, [32, 32], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
#     ActNorm2D(12),
    nn.BatchNorm2d(12),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
#     ActNorm2D(12),
    nn.BatchNorm2d(12),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
#     ActNorm2D(48),
    nn.BatchNorm2d(48),
    irf.ConvResidualFlow(48, [128, 128], kernels=5, activation=actf),
#     ActNorm2D(48),
    nn.BatchNorm2d(48),
    irf.ConvResidualFlow(48, [128, 128], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
#     ActNorm2D(192),
    nn.BatchNorm2d(192),
    irf.ConvResidualFlow(192, [256, 256], kernels=5, activation=actf),
#     ActNorm2D(192),
    nn.BatchNorm2d(192),
    irf.ConvResidualFlow(192, [256, 256], kernels=5, activation=actf),
    nn.BatchNorm2d(192),
    irf.Flatten(img_size=(192, 4, 4)),
#     ActNorm(3072),
#     nn.BatchNorm1d(3072),
#     nn.Linear(3072, 3072, bias=False),
    nn.BatchNorm1d(3072),
        ]

# backbone = SequentialFlow(flows)
backbone = nn.Sequential(*flows)

In [11]:
backbone.to(device)

Sequential(
  (0): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): ConvResidualFlow(
    (resblock): ModuleList(
      (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): Swish()
      (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (3): Swish()
      (4): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
  (2): InvertiblePooling()
  (3): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (4): ConvResidualFlow(
    (resblock): ModuleList(
      (0): Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): Swish()
      (2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (3): Swish()
      (4): Conv2d(64, 12, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
  (5): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ConvResidualFlow(
    (resblock)

In [12]:
backbone(xx.to(device)).shape, 32*32*3

(torch.Size([128, 3072]), 3072)

In [13]:
print("number of params: ", sum(p.numel() for p in backbone.parameters()))

number of params:  9947519


In [14]:
def get_children(module):
    child = list(module.children())
    if len(child) == 0:
        return [module]
    children = []
    for ch in child:
        grand_ch = get_children(ch)
        children+=grand_ch
    return children

def remove_spectral_norm(model):
    for child in get_children(model):
        if hasattr(child, 'weight'):
            print("Yes", child)
            try:
                nn.utils.remove_spectral_norm(child)
                print("Success")
            except:
                print("Failed")
    return

In [15]:
# remove_spectral_norm(backbone)

In [16]:
for xx, yy in train_loader:
    tt = backbone(xx.to(device))
    print(xx.shape, tt.shape)
    break

torch.Size([128, 3, 32, 32]) torch.Size([128, 3072])


In [17]:
class ConnectedClassifier_Linear(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.linear = nn.Linear(input_dim, num_sets)
#         self.linear.bias.data *= 0
#         self.linear.weight.data *= 0.1
#         self.cls_weight = nn.Parameter(torch.randn(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)*0.01
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 0.1
        self.cls_weight = nn.Parameter(init_val)
        
        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
#         self.cls_weight.data = torch.abs(self.cls_weight.data/self.cls_weight.data.sum(dim=1, keepdim=True))
        
        x = self.linear(x)
        if hard:
            x = torch.softmax(x*1e5, dim=1)
        else:
            x = torch.softmax(x*self.inv_temp, dim=1)
        self.cls_confidence = x
        c = torch.softmax(self.cls_weight, dim=1)
#         c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized

In [18]:
class ConnectedClassifier_SoftKMeans(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)
        
#         self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)*0.01
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 0.1
        self.cls_weight = nn.Parameter(init_val)

        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
#         self.cls_weight.data = torch.abs(self.cls_weight.data/self.cls_weight.data.sum(dim=1, keepdim=True))
        
        dists = torch.cdist(x, self.centers)
        dists = dists/np.sqrt(self.input_dim) ### correction to make diagonal of unit square 1 in nD space
        
        if hard:
            x = torch.softmax(-dists*1e5, dim=1)
        else:
            x = torch.softmax(-dists*self.inv_temp, dim=1)
        self.cls_confidence = x
        c = torch.softmax(self.cls_weight, dim=1)
#         c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized
#         return torch.softmax(x@self.cls_weight, dim=1)

    def set_centroid_to_data_randomly(self, data_loader, model):
        num_centers = self.centers.shape[0]
        xxs, yys = [], []
        count = 0
        for xx, yy in data_loader:
            yout = model(xx.to(device)).data.cpu()
            xxs.append(yout)
            yys.append(yy)
            count += len(xx)
            if count >= num_centers:
                break
        
        yout = torch.cat(xxs, dim=0)
        yy = torch.cat(yys, dim=0)
        
        yout = yout[:num_centers].to(self.centers.device)
        yy = yy[:num_centers].to(self.centers.device)
        
        self.centers.data = yout
        
        init_val = torch.ones(self.num_sets, self.output_dim)/self.output_dim
        for ns in range(num_centers):
            init_val[ns, yy[ns]] = 1.
        self.cls_weight.data = init_val.to(self.cls_weight.device)
        pass

In [19]:
classifier = ConnectedClassifier_SoftKMeans(3072, 100, 10)
# classifier = ConnectedClassifier_Linear(3072, 100, 10)
classifier = classifier.to(device)

In [20]:
print("number of params: ", sum(p.numel() for p in backbone.parameters()))
print("number of params: ", sum(p.numel() for p in classifier.parameters()))

number of params:  9947519
number of params:  308201


In [21]:
classifier.set_centroid_to_data_randomly(train_loader, backbone)

In [22]:
model = nn.Sequential(backbone, classifier).to(device)

In [23]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  10255720


## Training

In [24]:
# model_name = 'c10_inv_v0'
# model_name = 'c10_ord_v0'

In [25]:
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
#                       momentum=0.9, weight_decay=5e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

In [26]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    print(f"[Train] {epoch} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    return

In [27]:
best_acc = -1
def test(epoch):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('models'):
            os.mkdir('models')
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc

In [28]:
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
resume = False

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load(f'./models/{model_name}.pth')
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

In [29]:
### Train the whole damn thing

for epoch in range(start_epoch, start_epoch+200): ## for 200 epochs
    train(epoch)
    test(epoch)
    scheduler.step()

100%|██████████| 391/391 [00:42<00:00,  9.31it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 0 Loss: 2.301 | Acc: 23.886 11943/50000


100%|██████████| 79/79 [00:02<00:00, 26.64it/s]


[Test] 0 Loss: 2.300 | Acc: 30.170 3017/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 1 Loss: 2.299 | Acc: 33.232 16616/50000


100%|██████████| 79/79 [00:02<00:00, 26.40it/s]


[Test] 1 Loss: 2.299 | Acc: 38.990 3899/10000
Saving..


100%|██████████| 391/391 [00:51<00:00,  7.55it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 2 Loss: 2.298 | Acc: 40.338 20169/50000


100%|██████████| 79/79 [00:06<00:00, 12.74it/s]


[Test] 2 Loss: 2.297 | Acc: 45.740 4574/10000
Saving..


100%|██████████| 391/391 [01:22<00:00,  4.76it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 3 Loss: 2.296 | Acc: 45.350 22675/50000


100%|██████████| 79/79 [00:02<00:00, 26.33it/s]


[Test] 3 Loss: 2.295 | Acc: 48.100 4810/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 4 Loss: 2.294 | Acc: 51.340 25670/50000


100%|██████████| 79/79 [00:03<00:00, 26.02it/s]


[Test] 4 Loss: 2.293 | Acc: 53.770 5377/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 5 Loss: 2.291 | Acc: 54.334 27167/50000


100%|██████████| 79/79 [00:03<00:00, 26.29it/s]


[Test] 5 Loss: 2.289 | Acc: 56.450 5645/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 6 Loss: 2.288 | Acc: 56.636 28318/50000


100%|██████████| 79/79 [00:03<00:00, 26.28it/s]


[Test] 6 Loss: 2.286 | Acc: 58.240 5824/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 7 Loss: 2.284 | Acc: 59.504 29752/50000


100%|██████████| 79/79 [00:03<00:00, 26.30it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 7 Loss: 2.282 | Acc: 57.310 5731/10000


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 8 Loss: 2.279 | Acc: 61.698 30849/50000


100%|██████████| 79/79 [00:02<00:00, 26.35it/s]


[Test] 8 Loss: 2.276 | Acc: 61.650 6165/10000
Saving..


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 9 Loss: 2.273 | Acc: 62.888 31444/50000


100%|██████████| 79/79 [00:03<00:00, 26.15it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 9 Loss: 2.270 | Acc: 61.450 6145/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 10 Loss: 2.266 | Acc: 63.956 31978/50000


100%|██████████| 79/79 [00:02<00:00, 26.34it/s]


[Test] 10 Loss: 2.262 | Acc: 64.300 6430/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 11 Loss: 2.258 | Acc: 65.146 32573/50000


100%|██████████| 79/79 [00:03<00:00, 26.29it/s]


[Test] 11 Loss: 2.253 | Acc: 65.000 6500/10000
Saving..


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 12 Loss: 2.248 | Acc: 65.880 32940/50000


100%|██████████| 79/79 [00:03<00:00, 26.01it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 12 Loss: 2.243 | Acc: 64.330 6433/10000


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 13 Loss: 2.237 | Acc: 66.798 33399/50000


100%|██████████| 79/79 [00:02<00:00, 26.38it/s]


[Test] 13 Loss: 2.231 | Acc: 66.020 6602/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 14 Loss: 2.225 | Acc: 67.588 33794/50000


100%|██████████| 79/79 [00:03<00:00, 26.32it/s]


[Test] 14 Loss: 2.216 | Acc: 67.280 6728/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 15 Loss: 2.210 | Acc: 68.422 34211/50000


100%|██████████| 79/79 [00:03<00:00, 26.20it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 15 Loss: 2.204 | Acc: 66.360 6636/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 16 Loss: 2.195 | Acc: 68.960 34480/50000


100%|██████████| 79/79 [00:02<00:00, 26.40it/s]


[Test] 16 Loss: 2.186 | Acc: 68.290 6829/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 17 Loss: 2.178 | Acc: 69.634 34817/50000


100%|██████████| 79/79 [00:03<00:00, 26.30it/s]


[Test] 17 Loss: 2.170 | Acc: 68.310 6831/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 18 Loss: 2.160 | Acc: 70.192 35096/50000


100%|██████████| 79/79 [00:03<00:00, 26.11it/s]


[Test] 18 Loss: 2.151 | Acc: 69.380 6938/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 19 Loss: 2.142 | Acc: 70.692 35346/50000


100%|██████████| 79/79 [00:03<00:00, 26.23it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 19 Loss: 2.136 | Acc: 67.590 6759/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 20 Loss: 2.123 | Acc: 71.104 35552/50000


100%|██████████| 79/79 [00:03<00:00, 26.32it/s]


[Test] 20 Loss: 2.113 | Acc: 70.100 7010/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 21 Loss: 2.103 | Acc: 71.420 35710/50000


100%|██████████| 79/79 [00:02<00:00, 26.35it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 21 Loss: 2.094 | Acc: 70.090 7009/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 22 Loss: 2.082 | Acc: 72.202 36101/50000


100%|██████████| 79/79 [00:02<00:00, 26.39it/s]


[Test] 22 Loss: 2.072 | Acc: 71.260 7126/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 23 Loss: 2.062 | Acc: 72.370 36185/50000


100%|██████████| 79/79 [00:03<00:00, 26.22it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 23 Loss: 2.053 | Acc: 70.960 7096/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 24 Loss: 2.042 | Acc: 72.686 36343/50000


100%|██████████| 79/79 [00:03<00:00, 26.25it/s]


[Test] 24 Loss: 2.032 | Acc: 71.660 7166/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 25 Loss: 2.023 | Acc: 72.924 36462/50000


100%|██████████| 79/79 [00:03<00:00, 26.27it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 25 Loss: 2.016 | Acc: 71.110 7111/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 26 Loss: 2.004 | Acc: 73.176 36588/50000


100%|██████████| 79/79 [00:03<00:00, 26.25it/s]


[Test] 26 Loss: 1.996 | Acc: 72.330 7233/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 27 Loss: 1.983 | Acc: 73.990 36995/50000


100%|██████████| 79/79 [00:03<00:00, 26.31it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 27 Loss: 1.982 | Acc: 71.990 7199/10000


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 28 Loss: 1.967 | Acc: 74.236 37118/50000


100%|██████████| 79/79 [00:03<00:00, 26.27it/s]


[Test] 28 Loss: 1.960 | Acc: 73.430 7343/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 29 Loss: 1.950 | Acc: 74.810 37405/50000


100%|██████████| 79/79 [00:03<00:00, 26.26it/s]


[Test] 29 Loss: 1.945 | Acc: 73.780 7378/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 30 Loss: 1.931 | Acc: 75.560 37780/50000


100%|██████████| 79/79 [00:02<00:00, 26.40it/s]


[Test] 30 Loss: 1.929 | Acc: 74.290 7429/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 31 Loss: 1.916 | Acc: 76.124 38062/50000


100%|██████████| 79/79 [00:03<00:00, 26.09it/s]


[Test] 31 Loss: 1.915 | Acc: 74.760 7476/10000
Saving..


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 32 Loss: 1.901 | Acc: 76.582 38291/50000


100%|██████████| 79/79 [00:03<00:00, 26.25it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 32 Loss: 1.901 | Acc: 74.730 7473/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 33 Loss: 1.883 | Acc: 77.198 38599/50000


100%|██████████| 79/79 [00:03<00:00, 26.26it/s]


[Test] 33 Loss: 1.879 | Acc: 76.310 7631/10000
Saving..


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 34 Loss: 1.868 | Acc: 77.770 38885/50000


100%|██████████| 79/79 [00:02<00:00, 26.38it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 34 Loss: 1.873 | Acc: 75.900 7590/10000


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 35 Loss: 1.854 | Acc: 78.262 39131/50000


100%|██████████| 79/79 [00:02<00:00, 26.36it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 35 Loss: 1.860 | Acc: 76.150 7615/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 36 Loss: 1.843 | Acc: 78.140 39070/50000


100%|██████████| 79/79 [00:03<00:00, 26.24it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 36 Loss: 1.854 | Acc: 75.430 7543/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 37 Loss: 1.833 | Acc: 78.224 39112/50000


100%|██████████| 79/79 [00:03<00:00, 26.28it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 37 Loss: 1.841 | Acc: 75.910 7591/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 38 Loss: 1.817 | Acc: 79.016 39508/50000


100%|██████████| 79/79 [00:02<00:00, 26.36it/s]


[Test] 38 Loss: 1.824 | Acc: 77.110 7711/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 39 Loss: 1.807 | Acc: 79.166 39583/50000


100%|██████████| 79/79 [00:03<00:00, 26.24it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 39 Loss: 1.819 | Acc: 76.810 7681/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 40 Loss: 1.797 | Acc: 79.380 39690/50000


100%|██████████| 79/79 [00:03<00:00, 26.21it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 40 Loss: 1.815 | Acc: 75.870 7587/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 41 Loss: 1.790 | Acc: 79.206 39603/50000


100%|██████████| 79/79 [00:02<00:00, 26.35it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 41 Loss: 1.809 | Acc: 75.740 7574/10000


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 42 Loss: 1.780 | Acc: 79.572 39786/50000


100%|██████████| 79/79 [00:03<00:00, 26.22it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 42 Loss: 1.800 | Acc: 76.030 7603/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 43 Loss: 1.773 | Acc: 79.558 39779/50000


100%|██████████| 79/79 [00:03<00:00, 26.22it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 43 Loss: 1.790 | Acc: 76.850 7685/10000


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 44 Loss: 1.763 | Acc: 79.936 39968/50000


100%|██████████| 79/79 [00:03<00:00, 26.19it/s]


[Test] 44 Loss: 1.775 | Acc: 77.710 7771/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 45 Loss: 1.756 | Acc: 80.136 40068/50000


100%|██████████| 79/79 [00:02<00:00, 26.34it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 45 Loss: 1.772 | Acc: 77.360 7736/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 46 Loss: 1.748 | Acc: 80.354 40177/50000


100%|██████████| 79/79 [00:03<00:00, 26.15it/s]


[Test] 46 Loss: 1.762 | Acc: 78.290 7829/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 47 Loss: 1.743 | Acc: 80.350 40175/50000


100%|██████████| 79/79 [00:02<00:00, 26.40it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 47 Loss: 1.766 | Acc: 76.990 7699/10000


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 48 Loss: 1.735 | Acc: 80.672 40336/50000


100%|██████████| 79/79 [00:03<00:00, 26.32it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 48 Loss: 1.762 | Acc: 77.030 7703/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 49 Loss: 1.731 | Acc: 80.506 40253/50000


100%|██████████| 79/79 [00:03<00:00, 26.30it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 49 Loss: 1.749 | Acc: 78.010 7801/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 50 Loss: 1.724 | Acc: 80.890 40445/50000


100%|██████████| 79/79 [00:03<00:00, 26.29it/s]


[Test] 50 Loss: 1.741 | Acc: 78.610 7861/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 51 Loss: 1.719 | Acc: 81.036 40518/50000


100%|██████████| 79/79 [00:03<00:00, 26.15it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 51 Loss: 1.752 | Acc: 77.010 7701/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 52 Loss: 1.713 | Acc: 81.236 40618/50000


100%|██████████| 79/79 [00:03<00:00, 26.20it/s]


[Test] 52 Loss: 1.729 | Acc: 78.990 7899/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 53 Loss: 1.711 | Acc: 81.034 40517/50000


100%|██████████| 79/79 [00:03<00:00, 26.07it/s]


[Test] 53 Loss: 1.724 | Acc: 79.270 7927/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 54 Loss: 1.705 | Acc: 81.290 40645/50000


100%|██████████| 79/79 [00:03<00:00, 26.30it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 54 Loss: 1.723 | Acc: 79.120 7912/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 55 Loss: 1.703 | Acc: 81.202 40601/50000


100%|██████████| 79/79 [00:03<00:00, 26.31it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 55 Loss: 1.739 | Acc: 76.790 7679/10000


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 56 Loss: 1.699 | Acc: 81.316 40658/50000


100%|██████████| 79/79 [00:03<00:00, 26.28it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 56 Loss: 1.723 | Acc: 78.430 7843/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 57 Loss: 1.691 | Acc: 81.910 40955/50000


100%|██████████| 79/79 [00:02<00:00, 26.34it/s]


[Test] 57 Loss: 1.711 | Acc: 79.400 7940/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 58 Loss: 1.689 | Acc: 81.840 40920/50000


100%|██████████| 79/79 [00:03<00:00, 26.21it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 58 Loss: 1.715 | Acc: 78.680 7868/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 59 Loss: 1.686 | Acc: 81.864 40932/50000


100%|██████████| 79/79 [00:03<00:00, 26.01it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 59 Loss: 1.708 | Acc: 79.380 7938/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 60 Loss: 1.683 | Acc: 81.964 40982/50000


100%|██████████| 79/79 [00:02<00:00, 26.40it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 60 Loss: 1.710 | Acc: 78.680 7868/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 61 Loss: 1.682 | Acc: 81.788 40894/50000


100%|██████████| 79/79 [00:02<00:00, 26.35it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 61 Loss: 1.704 | Acc: 79.110 7911/10000


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 62 Loss: 1.676 | Acc: 82.278 41139/50000


100%|██████████| 79/79 [00:03<00:00, 26.25it/s]


[Test] 62 Loss: 1.694 | Acc: 80.080 8008/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 63 Loss: 1.673 | Acc: 82.350 41175/50000


100%|██████████| 79/79 [00:03<00:00, 26.25it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 63 Loss: 1.704 | Acc: 78.710 7871/10000


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 64 Loss: 1.670 | Acc: 82.562 41281/50000


100%|██████████| 79/79 [00:03<00:00, 26.27it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 64 Loss: 1.694 | Acc: 79.790 7979/10000


100%|██████████| 391/391 [00:42<00:00,  9.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 65 Loss: 1.665 | Acc: 82.798 41399/50000


100%|██████████| 79/79 [00:02<00:00, 26.35it/s]


[Test] 65 Loss: 1.687 | Acc: 80.310 8031/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 66 Loss: 1.668 | Acc: 82.312 41156/50000


100%|██████████| 79/79 [00:03<00:00, 26.28it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 66 Loss: 1.694 | Acc: 79.420 7942/10000


100%|██████████| 391/391 [00:42<00:00,  9.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 67 Loss: 1.661 | Acc: 82.934 41467/50000


100%|██████████| 79/79 [00:02<00:00, 26.37it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 67 Loss: 1.691 | Acc: 79.580 7958/10000


100%|██████████| 391/391 [00:42<00:00,  9.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 68 Loss: 1.662 | Acc: 82.702 41351/50000


100%|██████████| 79/79 [00:03<00:00, 26.13it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 68 Loss: 1.693 | Acc: 79.310 7931/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 69 Loss: 1.663 | Acc: 82.422 41211/50000


100%|██████████| 79/79 [00:02<00:00, 26.34it/s]


[Test] 69 Loss: 1.681 | Acc: 80.350 8035/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 70 Loss: 1.658 | Acc: 82.840 41420/50000


100%|██████████| 79/79 [00:03<00:00, 26.05it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 70 Loss: 1.686 | Acc: 79.730 7973/10000


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 71 Loss: 1.652 | Acc: 83.284 41642/50000


100%|██████████| 79/79 [00:03<00:00, 26.32it/s]


[Test] 71 Loss: 1.675 | Acc: 80.760 8076/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 72 Loss: 1.652 | Acc: 83.110 41555/50000


100%|██████████| 79/79 [00:02<00:00, 26.35it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 72 Loss: 1.678 | Acc: 80.400 8040/10000


100%|██████████| 391/391 [00:42<00:00,  9.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 73 Loss: 1.652 | Acc: 83.088 41544/50000


100%|██████████| 79/79 [00:03<00:00, 26.02it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 73 Loss: 1.678 | Acc: 80.270 8027/10000


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 74 Loss: 1.646 | Acc: 83.554 41777/50000


100%|██████████| 79/79 [00:02<00:00, 26.34it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 74 Loss: 1.680 | Acc: 80.030 8003/10000


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 75 Loss: 1.644 | Acc: 83.694 41847/50000


100%|██████████| 79/79 [00:03<00:00, 26.13it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 75 Loss: 1.675 | Acc: 80.360 8036/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 76 Loss: 1.645 | Acc: 83.472 41736/50000


100%|██████████| 79/79 [00:02<00:00, 26.38it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 76 Loss: 1.672 | Acc: 80.700 8070/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 77 Loss: 1.642 | Acc: 83.700 41850/50000


100%|██████████| 79/79 [00:03<00:00, 26.21it/s]


[Test] 77 Loss: 1.670 | Acc: 80.860 8086/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 78 Loss: 1.640 | Acc: 83.848 41924/50000


100%|██████████| 79/79 [00:03<00:00, 26.33it/s]


[Test] 78 Loss: 1.667 | Acc: 81.090 8109/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 79 Loss: 1.639 | Acc: 83.850 41925/50000


100%|██████████| 79/79 [00:03<00:00, 26.28it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 79 Loss: 1.668 | Acc: 80.830 8083/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 80 Loss: 1.637 | Acc: 83.980 41990/50000


100%|██████████| 79/79 [00:03<00:00, 26.24it/s]


[Test] 80 Loss: 1.664 | Acc: 81.230 8123/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 81 Loss: 1.636 | Acc: 84.076 42038/50000


100%|██████████| 79/79 [00:03<00:00, 26.16it/s]


[Test] 81 Loss: 1.662 | Acc: 81.300 8130/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 82 Loss: 1.632 | Acc: 84.392 42196/50000


100%|██████████| 79/79 [00:03<00:00, 26.24it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 82 Loss: 1.667 | Acc: 80.690 8069/10000


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 83 Loss: 1.630 | Acc: 84.478 42239/50000


100%|██████████| 79/79 [00:03<00:00, 26.29it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 83 Loss: 1.666 | Acc: 80.620 8062/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 84 Loss: 1.630 | Acc: 84.446 42223/50000


100%|██████████| 79/79 [00:03<00:00, 26.26it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 84 Loss: 1.664 | Acc: 80.810 8081/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 85 Loss: 1.634 | Acc: 84.022 42011/50000


100%|██████████| 79/79 [00:03<00:00, 26.32it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 85 Loss: 1.667 | Acc: 80.550 8055/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 86 Loss: 1.629 | Acc: 84.478 42239/50000


100%|██████████| 79/79 [00:02<00:00, 26.40it/s]


[Test] 86 Loss: 1.659 | Acc: 81.340 8134/10000
Saving..


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 87 Loss: 1.627 | Acc: 84.574 42287/50000


100%|██████████| 79/79 [00:03<00:00, 26.32it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 87 Loss: 1.662 | Acc: 80.930 8093/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 88 Loss: 1.621 | Acc: 85.158 42579/50000


100%|██████████| 79/79 [00:03<00:00, 26.27it/s]


[Test] 88 Loss: 1.655 | Acc: 81.720 8172/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 89 Loss: 1.622 | Acc: 85.024 42512/50000


100%|██████████| 79/79 [00:03<00:00, 26.07it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 89 Loss: 1.658 | Acc: 81.340 8134/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 90 Loss: 1.621 | Acc: 85.026 42513/50000


100%|██████████| 79/79 [00:03<00:00, 26.17it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 90 Loss: 1.655 | Acc: 81.520 8152/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 91 Loss: 1.619 | Acc: 85.294 42647/50000


100%|██████████| 79/79 [00:03<00:00, 25.91it/s]


[Test] 91 Loss: 1.651 | Acc: 81.780 8178/10000
Saving..


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 92 Loss: 1.619 | Acc: 85.162 42581/50000


100%|██████████| 79/79 [00:02<00:00, 26.37it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 92 Loss: 1.658 | Acc: 81.130 8113/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 93 Loss: 1.616 | Acc: 85.478 42739/50000


100%|██████████| 79/79 [00:03<00:00, 26.28it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 93 Loss: 1.658 | Acc: 81.130 8113/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 94 Loss: 1.615 | Acc: 85.584 42792/50000


100%|██████████| 79/79 [00:03<00:00, 26.26it/s]


[Test] 94 Loss: 1.651 | Acc: 81.930 8193/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 95 Loss: 1.616 | Acc: 85.370 42685/50000


100%|██████████| 79/79 [00:03<00:00, 26.25it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 95 Loss: 1.657 | Acc: 81.320 8132/10000


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 96 Loss: 1.615 | Acc: 85.518 42759/50000


100%|██████████| 79/79 [00:03<00:00, 26.30it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 96 Loss: 1.654 | Acc: 81.470 8147/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 97 Loss: 1.612 | Acc: 85.678 42839/50000


100%|██████████| 79/79 [00:03<00:00, 26.21it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 97 Loss: 1.657 | Acc: 81.130 8113/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 98 Loss: 1.610 | Acc: 85.890 42945/50000


100%|██████████| 79/79 [00:03<00:00, 26.09it/s]


[Test] 98 Loss: 1.647 | Acc: 82.160 8216/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 99 Loss: 1.612 | Acc: 85.736 42868/50000


100%|██████████| 79/79 [00:03<00:00, 26.25it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 99 Loss: 1.650 | Acc: 81.780 8178/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 100 Loss: 1.608 | Acc: 86.096 43048/50000


100%|██████████| 79/79 [00:03<00:00, 26.17it/s]


[Test] 100 Loss: 1.645 | Acc: 82.320 8232/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 101 Loss: 1.611 | Acc: 85.740 42870/50000


100%|██████████| 79/79 [00:03<00:00, 26.30it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 101 Loss: 1.645 | Acc: 82.280 8228/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 102 Loss: 1.606 | Acc: 86.242 43121/50000


100%|██████████| 79/79 [00:03<00:00, 26.29it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 102 Loss: 1.648 | Acc: 81.840 8184/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 103 Loss: 1.608 | Acc: 86.038 43019/50000


100%|██████████| 79/79 [00:03<00:00, 26.13it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 103 Loss: 1.645 | Acc: 82.220 8222/10000


100%|██████████| 391/391 [00:43<00:00,  9.07it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 104 Loss: 1.605 | Acc: 86.288 43144/50000


100%|██████████| 79/79 [00:03<00:00, 26.27it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 104 Loss: 1.654 | Acc: 81.390 8139/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 105 Loss: 1.605 | Acc: 86.236 43118/50000


100%|██████████| 79/79 [00:03<00:00, 25.92it/s]


[Test] 105 Loss: 1.644 | Acc: 82.350 8235/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 106 Loss: 1.600 | Acc: 86.722 43361/50000


100%|██████████| 79/79 [00:03<00:00, 26.26it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 106 Loss: 1.648 | Acc: 81.770 8177/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 107 Loss: 1.600 | Acc: 86.690 43345/50000


100%|██████████| 79/79 [00:03<00:00, 26.27it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 107 Loss: 1.650 | Acc: 81.620 8162/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 108 Loss: 1.602 | Acc: 86.500 43250/50000


100%|██████████| 79/79 [00:02<00:00, 26.34it/s]


[Test] 108 Loss: 1.640 | Acc: 82.720 8272/10000
Saving..


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 109 Loss: 1.597 | Acc: 87.002 43501/50000


100%|██████████| 79/79 [00:03<00:00, 26.15it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 109 Loss: 1.648 | Acc: 81.730 8173/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 110 Loss: 1.597 | Acc: 86.940 43470/50000


100%|██████████| 79/79 [00:03<00:00, 26.33it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 110 Loss: 1.645 | Acc: 82.140 8214/10000


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 111 Loss: 1.597 | Acc: 86.886 43443/50000


100%|██████████| 79/79 [00:03<00:00, 26.08it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 111 Loss: 1.639 | Acc: 82.670 8267/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 112 Loss: 1.596 | Acc: 87.020 43510/50000


100%|██████████| 79/79 [00:03<00:00, 26.21it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 112 Loss: 1.644 | Acc: 82.210 8221/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 113 Loss: 1.597 | Acc: 86.956 43478/50000


100%|██████████| 79/79 [00:03<00:00, 25.89it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 113 Loss: 1.643 | Acc: 82.260 8226/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 114 Loss: 1.596 | Acc: 86.998 43499/50000


100%|██████████| 79/79 [00:03<00:00, 26.26it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 114 Loss: 1.647 | Acc: 81.790 8179/10000


100%|██████████| 391/391 [00:43<00:00,  9.07it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 115 Loss: 1.595 | Acc: 87.058 43529/50000


100%|██████████| 79/79 [00:03<00:00, 26.33it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 115 Loss: 1.643 | Acc: 82.300 8230/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 116 Loss: 1.592 | Acc: 87.346 43673/50000


100%|██████████| 79/79 [00:03<00:00, 26.32it/s]


[Test] 116 Loss: 1.637 | Acc: 82.880 8288/10000
Saving..


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 117 Loss: 1.592 | Acc: 87.358 43679/50000


100%|██████████| 79/79 [00:03<00:00, 26.18it/s]


[Test] 117 Loss: 1.637 | Acc: 82.930 8293/10000
Saving..


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 118 Loss: 1.591 | Acc: 87.412 43706/50000


100%|██████████| 79/79 [00:03<00:00, 26.32it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 118 Loss: 1.636 | Acc: 82.850 8285/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 119 Loss: 1.589 | Acc: 87.696 43848/50000


100%|██████████| 79/79 [00:03<00:00, 26.30it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 119 Loss: 1.637 | Acc: 82.770 8277/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 120 Loss: 1.587 | Acc: 87.850 43925/50000


100%|██████████| 79/79 [00:03<00:00, 26.20it/s]


[Test] 120 Loss: 1.633 | Acc: 83.180 8318/10000
Saving..


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 121 Loss: 1.589 | Acc: 87.578 43789/50000


100%|██████████| 79/79 [00:03<00:00, 26.28it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 121 Loss: 1.639 | Acc: 82.610 8261/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 122 Loss: 1.589 | Acc: 87.624 43812/50000


100%|██████████| 79/79 [00:03<00:00, 26.23it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 122 Loss: 1.637 | Acc: 82.810 8281/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 123 Loss: 1.585 | Acc: 87.980 43990/50000


100%|██████████| 79/79 [00:03<00:00, 26.29it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 123 Loss: 1.639 | Acc: 82.610 8261/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 124 Loss: 1.583 | Acc: 88.216 44108/50000


100%|██████████| 79/79 [00:03<00:00, 26.28it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 124 Loss: 1.637 | Acc: 82.810 8281/10000


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 125 Loss: 1.585 | Acc: 88.002 44001/50000


100%|██████████| 79/79 [00:03<00:00, 26.20it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 125 Loss: 1.634 | Acc: 83.070 8307/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 126 Loss: 1.582 | Acc: 88.214 44107/50000


100%|██████████| 79/79 [00:03<00:00, 26.13it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 126 Loss: 1.637 | Acc: 82.680 8268/10000


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 127 Loss: 1.581 | Acc: 88.384 44192/50000


100%|██████████| 79/79 [00:03<00:00, 26.23it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 127 Loss: 1.634 | Acc: 83.010 8301/10000


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 128 Loss: 1.581 | Acc: 88.370 44185/50000


100%|██████████| 79/79 [00:03<00:00, 26.24it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 128 Loss: 1.633 | Acc: 83.150 8315/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 129 Loss: 1.581 | Acc: 88.376 44188/50000


100%|██████████| 79/79 [00:03<00:00, 26.18it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 129 Loss: 1.637 | Acc: 82.660 8266/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 130 Loss: 1.579 | Acc: 88.578 44289/50000


100%|██████████| 79/79 [00:03<00:00, 26.16it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 130 Loss: 1.635 | Acc: 82.810 8281/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 131 Loss: 1.577 | Acc: 88.714 44357/50000


100%|██████████| 79/79 [00:03<00:00, 26.23it/s]


[Test] 131 Loss: 1.631 | Acc: 83.300 8330/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 132 Loss: 1.578 | Acc: 88.616 44308/50000


100%|██████████| 79/79 [00:03<00:00, 25.97it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 132 Loss: 1.631 | Acc: 83.270 8327/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 133 Loss: 1.576 | Acc: 88.866 44433/50000


100%|██████████| 79/79 [00:03<00:00, 26.16it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 133 Loss: 1.632 | Acc: 83.100 8310/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 134 Loss: 1.576 | Acc: 88.772 44386/50000


100%|██████████| 79/79 [00:03<00:00, 26.29it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 134 Loss: 1.631 | Acc: 83.270 8327/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 135 Loss: 1.576 | Acc: 88.830 44415/50000


100%|██████████| 79/79 [00:03<00:00, 26.20it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 135 Loss: 1.633 | Acc: 83.130 8313/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 136 Loss: 1.577 | Acc: 88.752 44376/50000


100%|██████████| 79/79 [00:03<00:00, 26.21it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 136 Loss: 1.632 | Acc: 83.280 8328/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 137 Loss: 1.575 | Acc: 88.912 44456/50000


100%|██████████| 79/79 [00:03<00:00, 26.22it/s]


[Test] 137 Loss: 1.627 | Acc: 83.630 8363/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 138 Loss: 1.574 | Acc: 88.994 44497/50000


100%|██████████| 79/79 [00:03<00:00, 26.16it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 138 Loss: 1.634 | Acc: 83.050 8305/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 139 Loss: 1.575 | Acc: 88.852 44426/50000


100%|██████████| 79/79 [00:03<00:00, 26.14it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 139 Loss: 1.634 | Acc: 83.070 8307/10000


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 140 Loss: 1.575 | Acc: 88.904 44452/50000


100%|██████████| 79/79 [00:03<00:00, 26.28it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 140 Loss: 1.630 | Acc: 83.420 8342/10000


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 141 Loss: 1.573 | Acc: 89.126 44563/50000


100%|██████████| 79/79 [00:03<00:00, 26.23it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 141 Loss: 1.631 | Acc: 83.320 8332/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 142 Loss: 1.572 | Acc: 89.208 44604/50000


100%|██████████| 79/79 [00:03<00:00, 25.79it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 142 Loss: 1.629 | Acc: 83.570 8357/10000


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 143 Loss: 1.570 | Acc: 89.408 44704/50000


100%|██████████| 79/79 [00:02<00:00, 26.37it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 143 Loss: 1.631 | Acc: 83.290 8329/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 144 Loss: 1.570 | Acc: 89.384 44692/50000


100%|██████████| 79/79 [00:03<00:00, 26.21it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 144 Loss: 1.628 | Acc: 83.530 8353/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 145 Loss: 1.570 | Acc: 89.446 44723/50000


100%|██████████| 79/79 [00:03<00:00, 26.21it/s]


[Test] 145 Loss: 1.626 | Acc: 83.760 8376/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 146 Loss: 1.571 | Acc: 89.304 44652/50000


100%|██████████| 79/79 [00:03<00:00, 26.16it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 146 Loss: 1.626 | Acc: 83.660 8366/10000


100%|██████████| 391/391 [00:43<00:00,  9.07it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 147 Loss: 1.569 | Acc: 89.492 44746/50000


100%|██████████| 79/79 [00:03<00:00, 26.31it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 147 Loss: 1.627 | Acc: 83.760 8376/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 148 Loss: 1.567 | Acc: 89.688 44844/50000


100%|██████████| 79/79 [00:03<00:00, 26.19it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 148 Loss: 1.627 | Acc: 83.650 8365/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 149 Loss: 1.567 | Acc: 89.684 44842/50000


100%|██████████| 79/79 [00:03<00:00, 26.30it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 149 Loss: 1.629 | Acc: 83.480 8348/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 150 Loss: 1.567 | Acc: 89.622 44811/50000


100%|██████████| 79/79 [00:03<00:00, 26.25it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 150 Loss: 1.627 | Acc: 83.610 8361/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 151 Loss: 1.566 | Acc: 89.738 44869/50000


100%|██████████| 79/79 [00:03<00:00, 25.92it/s]


[Test] 151 Loss: 1.625 | Acc: 83.850 8385/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 152 Loss: 1.567 | Acc: 89.658 44829/50000


100%|██████████| 79/79 [00:02<00:00, 26.36it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 152 Loss: 1.628 | Acc: 83.540 8354/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 153 Loss: 1.564 | Acc: 89.928 44964/50000


100%|██████████| 79/79 [00:03<00:00, 26.14it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 153 Loss: 1.626 | Acc: 83.810 8381/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 154 Loss: 1.563 | Acc: 90.100 45050/50000


100%|██████████| 79/79 [00:03<00:00, 26.24it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 154 Loss: 1.626 | Acc: 83.680 8368/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 155 Loss: 1.565 | Acc: 89.870 44935/50000


100%|██████████| 79/79 [00:03<00:00, 26.28it/s]


[Test] 155 Loss: 1.623 | Acc: 84.090 8409/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 156 Loss: 1.563 | Acc: 90.098 45049/50000


100%|██████████| 79/79 [00:03<00:00, 26.18it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 156 Loss: 1.625 | Acc: 83.740 8374/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 157 Loss: 1.564 | Acc: 90.028 45014/50000


100%|██████████| 79/79 [00:03<00:00, 26.25it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 157 Loss: 1.625 | Acc: 83.830 8383/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 158 Loss: 1.563 | Acc: 90.060 45030/50000


100%|██████████| 79/79 [00:03<00:00, 26.32it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 158 Loss: 1.626 | Acc: 83.710 8371/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 159 Loss: 1.563 | Acc: 90.066 45033/50000


100%|██████████| 79/79 [00:03<00:00, 26.22it/s]


[Test] 159 Loss: 1.623 | Acc: 84.110 8411/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 160 Loss: 1.562 | Acc: 90.156 45078/50000


100%|██████████| 79/79 [00:03<00:00, 26.33it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 160 Loss: 1.623 | Acc: 84.060 8406/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 161 Loss: 1.561 | Acc: 90.270 45135/50000


100%|██████████| 79/79 [00:03<00:00, 26.20it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 161 Loss: 1.624 | Acc: 83.980 8398/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 162 Loss: 1.561 | Acc: 90.266 45133/50000


100%|██████████| 79/79 [00:03<00:00, 26.24it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 162 Loss: 1.624 | Acc: 83.840 8384/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 163 Loss: 1.561 | Acc: 90.266 45133/50000


100%|██████████| 79/79 [00:02<00:00, 26.44it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 163 Loss: 1.626 | Acc: 83.710 8371/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 164 Loss: 1.562 | Acc: 90.138 45069/50000


100%|██████████| 79/79 [00:03<00:00, 26.30it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 164 Loss: 1.623 | Acc: 84.110 8411/10000


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 165 Loss: 1.559 | Acc: 90.414 45207/50000


100%|██████████| 79/79 [00:03<00:00, 26.25it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 165 Loss: 1.623 | Acc: 84.030 8403/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 166 Loss: 1.559 | Acc: 90.466 45233/50000


100%|██████████| 79/79 [00:03<00:00, 26.20it/s]


[Test] 166 Loss: 1.622 | Acc: 84.160 8416/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 167 Loss: 1.559 | Acc: 90.450 45225/50000


100%|██████████| 79/79 [00:03<00:00, 26.18it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 167 Loss: 1.622 | Acc: 84.150 8415/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 168 Loss: 1.559 | Acc: 90.476 45238/50000


100%|██████████| 79/79 [00:02<00:00, 26.37it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 168 Loss: 1.623 | Acc: 83.920 8392/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 169 Loss: 1.557 | Acc: 90.620 45310/50000


100%|██████████| 79/79 [00:02<00:00, 26.40it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 169 Loss: 1.624 | Acc: 83.850 8385/10000


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 170 Loss: 1.559 | Acc: 90.460 45230/50000


100%|██████████| 79/79 [00:03<00:00, 26.18it/s]


[Test] 170 Loss: 1.622 | Acc: 84.180 8418/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 171 Loss: 1.557 | Acc: 90.618 45309/50000


100%|██████████| 79/79 [00:03<00:00, 26.31it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 171 Loss: 1.622 | Acc: 84.050 8405/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 172 Loss: 1.558 | Acc: 90.556 45278/50000


100%|██████████| 79/79 [00:03<00:00, 26.30it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 172 Loss: 1.623 | Acc: 84.080 8408/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 173 Loss: 1.557 | Acc: 90.680 45340/50000


100%|██████████| 79/79 [00:03<00:00, 26.27it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 173 Loss: 1.623 | Acc: 84.040 8404/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 174 Loss: 1.557 | Acc: 90.664 45332/50000


100%|██████████| 79/79 [00:03<00:00, 26.15it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 174 Loss: 1.623 | Acc: 84.040 8404/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 175 Loss: 1.556 | Acc: 90.748 45374/50000


100%|██████████| 79/79 [00:02<00:00, 26.35it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 175 Loss: 1.622 | Acc: 84.030 8403/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 176 Loss: 1.556 | Acc: 90.754 45377/50000


100%|██████████| 79/79 [00:03<00:00, 26.14it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 176 Loss: 1.622 | Acc: 84.160 8416/10000


100%|██████████| 391/391 [00:42<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 177 Loss: 1.556 | Acc: 90.742 45371/50000


100%|██████████| 79/79 [00:03<00:00, 25.99it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 177 Loss: 1.623 | Acc: 84.010 8401/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 178 Loss: 1.556 | Acc: 90.716 45358/50000


100%|██████████| 79/79 [00:03<00:00, 26.25it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 178 Loss: 1.622 | Acc: 84.100 8410/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 179 Loss: 1.555 | Acc: 90.830 45415/50000


100%|██████████| 79/79 [00:03<00:00, 26.29it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 179 Loss: 1.621 | Acc: 84.180 8418/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 180 Loss: 1.557 | Acc: 90.676 45338/50000


100%|██████████| 79/79 [00:03<00:00, 26.17it/s]


[Test] 180 Loss: 1.621 | Acc: 84.250 8425/10000
Saving..


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 181 Loss: 1.556 | Acc: 90.794 45397/50000


100%|██████████| 79/79 [00:03<00:00, 26.29it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 181 Loss: 1.622 | Acc: 84.220 8422/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 182 Loss: 1.556 | Acc: 90.734 45367/50000


100%|██████████| 79/79 [00:02<00:00, 26.35it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 182 Loss: 1.621 | Acc: 84.250 8425/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 183 Loss: 1.556 | Acc: 90.746 45373/50000


100%|██████████| 79/79 [00:03<00:00, 26.16it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 183 Loss: 1.621 | Acc: 84.220 8422/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 184 Loss: 1.555 | Acc: 90.820 45410/50000


100%|██████████| 79/79 [00:03<00:00, 26.25it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 184 Loss: 1.621 | Acc: 84.210 8421/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 185 Loss: 1.556 | Acc: 90.714 45357/50000


100%|██████████| 79/79 [00:03<00:00, 26.23it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 185 Loss: 1.622 | Acc: 84.090 8409/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 186 Loss: 1.555 | Acc: 90.888 45444/50000


100%|██████████| 79/79 [00:03<00:00, 25.74it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 186 Loss: 1.622 | Acc: 84.130 8413/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 187 Loss: 1.554 | Acc: 90.926 45463/50000


100%|██████████| 79/79 [00:03<00:00, 26.10it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 187 Loss: 1.621 | Acc: 84.180 8418/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 188 Loss: 1.556 | Acc: 90.744 45372/50000


100%|██████████| 79/79 [00:03<00:00, 26.15it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 188 Loss: 1.621 | Acc: 84.160 8416/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 189 Loss: 1.555 | Acc: 90.874 45437/50000


100%|██████████| 79/79 [00:02<00:00, 26.40it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 189 Loss: 1.622 | Acc: 84.100 8410/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 190 Loss: 1.556 | Acc: 90.712 45356/50000


100%|██████████| 79/79 [00:03<00:00, 26.22it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 190 Loss: 1.622 | Acc: 84.090 8409/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 191 Loss: 1.555 | Acc: 90.812 45406/50000


100%|██████████| 79/79 [00:02<00:00, 26.39it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 191 Loss: 1.622 | Acc: 84.060 8406/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 192 Loss: 1.555 | Acc: 90.824 45412/50000


100%|██████████| 79/79 [00:02<00:00, 26.37it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 192 Loss: 1.622 | Acc: 84.130 8413/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 193 Loss: 1.555 | Acc: 90.882 45441/50000


100%|██████████| 79/79 [00:03<00:00, 26.21it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 193 Loss: 1.622 | Acc: 84.130 8413/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 194 Loss: 1.556 | Acc: 90.750 45375/50000


100%|██████████| 79/79 [00:03<00:00, 26.32it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 194 Loss: 1.622 | Acc: 84.150 8415/10000


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 195 Loss: 1.553 | Acc: 91.054 45527/50000


100%|██████████| 79/79 [00:03<00:00, 26.33it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 195 Loss: 1.622 | Acc: 84.050 8405/10000


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 196 Loss: 1.555 | Acc: 90.814 45407/50000


100%|██████████| 79/79 [00:03<00:00, 26.32it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 196 Loss: 1.622 | Acc: 84.110 8411/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 197 Loss: 1.555 | Acc: 90.770 45385/50000


100%|██████████| 79/79 [00:03<00:00, 25.99it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 197 Loss: 1.622 | Acc: 84.130 8413/10000


100%|██████████| 391/391 [00:43<00:00,  9.06it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 198 Loss: 1.554 | Acc: 90.930 45465/50000


100%|██████████| 79/79 [00:03<00:00, 26.29it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 198 Loss: 1.621 | Acc: 84.160 8416/10000


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 199 Loss: 1.555 | Acc: 90.846 45423/50000


100%|██████████| 79/79 [00:03<00:00, 26.19it/s]

[Test] 199 Loss: 1.622 | Acc: 84.140 8414/10000





In [30]:
best_acc

84.25

In [None]:
'''
flows = [
    ActNorm2D(3),
    irf.ConvResidualFlow(3, [32, 32], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    ActNorm2D(12),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    ActNorm2D(48),
    irf.ConvResidualFlow(48, [128], activation=actf),
    irf.ConvResidualFlow(48, [128], activation=actf),
    irf.InvertiblePooling(2),
    ActNorm2D(192),
    irf.ConvResidualFlow(192, [128], activation=actf),
    irf.ConvResidualFlow(192, [128], activation=actf),
    irf.Flatten(img_size=(192, 4, 4))
        ]
'''
### ACC: 76.6? | 69.95 --> Inv + Connected Linear
### ACC: 67.03 | 67.61 --> Ord + Connected Linear


'''
flows = [
    ActNorm2D(3),
    irf.ConvResidualFlow(3, [32, 32], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    ActNorm2D(12),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    ActNorm2D(48),
    irf.ConvResidualFlow(48, [128, 128], kernels=5, activation=actf),
    irf.ConvResidualFlow(48, [128, 128], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    ActNorm2D(192),
    irf.ConvResidualFlow(192, [256, 256], kernels=5, activation=actf),
    irf.ConvResidualFlow(192, [256, 256], kernels=5, activation=actf),
    irf.Flatten(img_size=(192, 4, 4)),
    irf.ResidualFlow(3072, [4096], activation=actf),
        ]
'''
### ACC:  --> Inv + Connected Linear
### ACC:  --> Ord + Connected Linear

'''
actf = irf.Swish
flows = [
#     ActNorm2D(3),
    nn.BatchNorm2d(3),
    irf.ConvResidualFlow(3, [32, 32], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
#     ActNorm2D(12),
    nn.BatchNorm2d(12),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
#     ActNorm2D(12),
    nn.BatchNorm2d(12),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
#     ActNorm2D(48),
    nn.BatchNorm2d(48),
    irf.ConvResidualFlow(48, [128, 128], kernels=5, activation=actf),
#     ActNorm2D(48),
    nn.BatchNorm2d(48),
    irf.ConvResidualFlow(48, [128, 128], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
#     ActNorm2D(192),
    nn.BatchNorm2d(192),
    irf.ConvResidualFlow(192, [256, 256], kernels=5, activation=actf),
#     ActNorm2D(192),
    nn.BatchNorm2d(192),
    irf.ConvResidualFlow(192, [256, 256], kernels=5, activation=actf),
    irf.Flatten(img_size=(192, 4, 4)),
#     ActNorm(3072),
    nn.BatchNorm1d(3072),
    nn.Linear(3072, 3072, bias=False),
    nn.BatchNorm1d(3072),
        ]

backbone = nn.Sequential(*flows)
'''

### ACC:  --> Inv + Connected Distance
### ACC:  --> Ord + Connected Distance

In [None]:
#### non-inv nn with MLP classifier: 85.73 Acc ; using no-spectral init
#### non-inv nn with connected classifier:   Acc ; using no-spectral init
#### inv nn with connected classifier: 84.25 Acc ; spectral normalized

In [31]:
classifier.inv_temp

Parameter containing:
tensor([4.5950], device='cuda:1', requires_grad=True)

In [32]:
checkpoint = torch.load(f'./models/{model_name}.pth')
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

best_acc, start_epoch

(84.25, 180)

### Hard test accuracy with count per classifier

In [33]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(test_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(backbone(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Test Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 79/79 [00:02<00:00, 26.83it/s]

Hard Test Acc:84.18%
[0, 0, 0, 0, 838, 3, 913, 0, 933, 0, 0, 0, 0, 60, 0, 0, 0, 0, 0, 940, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1021, 1002, 0, 13, 0, 1010, 640, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 1051, 10, 1, 37, 0, 0, 0, 0, 0, 19, 0, 27, 0, 0, 30, 971, 0, 8, 0, 0, 0, 0, 0, 1, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 103, 0, 0, 0, 0, 0, 361, 0]





### Hard train accuracy with count per classifier

In [34]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(backbone(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████| 391/391 [00:29<00:00, 13.04it/s]

Hard Train Acc:91.31%
[0, 0, 0, 0, 4452, 25, 4597, 0, 4948, 0, 0, 0, 0, 162, 0, 0, 0, 0, 0, 4909, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5013, 4926, 0, 56, 0, 5022, 4125, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 25, 0, 5073, 38, 4, 158, 0, 0, 0, 0, 0, 60, 9, 84, 0, 0, 119, 4967, 0, 25, 0, 0, 0, 0, 0, 0, 0, 0, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 284, 0, 0, 0, 0, 0, 888, 0]





In [35]:
#### Classifiers that enclose any data
torch.count_nonzero(set_count)

tensor(26, device='cuda:1')

In [36]:
#### classifier with class representation
torch.argmax(classifier.cls_weight, dim=1)

tensor([6, 3, 7, 9, 5, 0, 2, 3, 3, 3, 7, 7, 5, 3, 5, 6, 3, 1, 8, 7, 3, 1, 5, 2,
        3, 3, 3, 1, 6, 2, 1, 1, 5, 6, 9, 9, 2, 3, 4, 8, 3, 9, 8, 3, 5, 9, 9, 8,
        3, 2, 2, 1, 8, 8, 8, 0, 6, 7, 2, 3, 5, 7, 0, 9, 1, 0, 6, 6, 5, 4, 1, 1,
        9, 8, 2, 5, 1, 1, 7, 7, 8, 9, 2, 2, 9, 0, 9, 5, 4, 3, 3, 1, 5, 1, 5, 4,
        5, 9, 8, 5], device='cuda:1')

### analyze per classifier accuracy

In [None]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
set_acc = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(backbone(xx), hard=True)
        
    cls_indx = torch.argmax(classifier.cls_confidence, dim=1)
    set_indx, count = torch.unique(cls_indx, return_counts=True) 
    set_count[set_indx] += count
    
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float)
    
    ### class_index has 100 possible values
    for i, c in enumerate(correct):
        set_acc[cls_indx[i]] += c
    
#     print(set_acc.sum(), set_count.sum())
#     break
    test_acc += correct.sum()
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

In [None]:
# set_acc/set_count

In [None]:
for i, (cnt, acc, cls) in enumerate(zip(set_count.type(torch.long).tolist(),
                                   (set_acc/set_count).tolist(),
                                   torch.argmax(classifier.cls_weight, dim=1).tolist())):
    if cnt == 0: continue
    print(f"{i},\t {cnt},\t {cls}\t {acc*100:.2f}%")