In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy
from tqdm import tqdm

In [2]:
device = torch.device("cuda:0")
# device = torch.device("cpu")

In [3]:
import torch.optim as optim
from torch.utils import data

In [4]:
import nflib
from nflib.flows import SequentialFlow, NormalizingFlow, ActNorm, ActNorm2D, AffineConstantFlow
import nflib.coupling_flows as icf
import nflib.inn_flow as inn
import nflib.res_flow as irf

### Datasets

In [5]:
cifar_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
        std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
    ),
])

cifar_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
        std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
    ),
])

train_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=True, download=True, transform=cifar_train)
test_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=False, download=True, transform=cifar_test)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=2)

In [7]:
xx, yy = iter(train_loader).next()

In [8]:
xx.shape

torch.Size([128, 3, 32, 32])

### Model

In [10]:
actf = irf.Swish
flows = [
    ActNorm2D(3),
    irf.ConvResidualFlow(3, [32, 32], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    ActNorm2D(12),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
    ActNorm2D(12),
    irf.ConvResidualFlow(12, [64, 64], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    ActNorm2D(48),
    irf.ConvResidualFlow(48, [128, 128], kernels=5, activation=actf),
    ActNorm2D(48),
    irf.ConvResidualFlow(48, [128, 128], kernels=5, activation=actf),
    irf.InvertiblePooling(2),
    ActNorm2D(192),
    irf.ConvResidualFlow(192, [256, 256], kernels=5, activation=actf),
    ActNorm2D(192),
    irf.ConvResidualFlow(192, [256, 256], kernels=5, activation=actf),
    irf.Flatten(img_size=(192, 4, 4)),
    ActNorm(3072),
        ]

# backbone = SequentialFlow(flows)
backbone = nn.Sequential(*flows)

In [11]:
backbone(xx).shape, 32*32*3

(torch.Size([128, 3072]), 3072)

In [12]:
print("number of params: ", sum(p.numel() for p in backbone.parameters()))

number of params:  35113480


In [13]:
def get_children(module):
    child = list(module.children())
    if len(child) == 0:
        return [module]
    children = []
    for ch in child:
        grand_ch = get_children(ch)
        children+=grand_ch
    return children

def remove_spectral_norm(model):
    for child in get_children(model):
        if hasattr(child, 'weight'):
            print("Yes", child)
            try:
                nn.utils.remove_spectral_norm(child)
                print("Success")
            except:
                print("Failed")
    return

In [14]:
# remove_spectral_norm(backbone)

In [15]:
backbone.to(device)

SequentialFlow(
  (flows): ModuleList(
    (0): ActNorm2D()
    (1): ConvResidualFlow(
      (resblock): ModuleList(
        (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (1): Swish()
        (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (3): Swish()
        (4): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      )
    )
    (2): InvertiblePooling()
    (3): ActNorm2D()
    (4): ConvResidualFlow(
      (resblock): ModuleList(
        (0): Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (1): Swish()
        (2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (3): Swish()
        (4): Conv2d(64, 12, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      )
    )
    (5): ConvResidualFlow(
      (resblock): ModuleList(
        (0): Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (1): Swish()
        (2): Conv2d(64, 64, kernel_

In [16]:
for xx, yy in train_loader:
    tt = backbone(xx.to(device))
    print(xx.shape, tt.shape)
    break

torch.Size([128, 3, 32, 32]) torch.Size([128, 3072])


In [17]:
class ConnectedClassifier_Linear(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.linear = nn.Linear(input_dim, num_sets)
#         self.linear.bias.data *= 0
#         self.linear.weight.data *= 0.1
#         self.cls_weight = nn.Parameter(torch.randn(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)*0.01
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 0.1
        self.cls_weight = nn.Parameter(init_val)
        
        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
#         self.cls_weight.data = torch.abs(self.cls_weight.data/self.cls_weight.data.sum(dim=1, keepdim=True))
        
        x = self.linear(x)
        if hard:
            x = torch.softmax(x*1e5, dim=1)
        else:
            x = torch.softmax(x*self.inv_temp, dim=1)
        self.cls_confidence = x
        c = torch.softmax(self.cls_weight, dim=1)
#         c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized

In [18]:
class ConnectedClassifier_SoftKMeans(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)
        
#         self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)

        init_val = torch.randn(num_sets, output_dim)*0.01
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 0.1
        self.cls_weight = nn.Parameter(init_val)

        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
#         self.cls_weight.data = torch.abs(self.cls_weight.data/self.cls_weight.data.sum(dim=1, keepdim=True))
        
#         x = x[:, :self.input_dim]
        dists = torch.cdist(x, self.centers)
        dists = dists/np.sqrt(self.input_dim) ### correction to make diagonal of unit square 1 in nD space
        
        if hard:
            x = torch.softmax(-dists*1e5, dim=1)
        else:
            x = torch.softmax(-dists*self.inv_temp, dim=1)
        self.cls_confidence = x
        c = torch.softmax(self.cls_weight, dim=1)
#         c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized
#         return torch.softmax(x@self.cls_weight, dim=1)

In [19]:
# classifier = ConnectedClassifier_SoftKMeans(3072, 100, 10)
classifier = ConnectedClassifier_Linear(3072, 100, 10)
classifier = classifier.to(device)

In [20]:
print("number of params: ", sum(p.numel() for p in backbone.parameters()))
print("number of params: ", sum(p.numel() for p in classifier.parameters()))

number of params:  35113480
number of params:  308301


In [21]:
model = nn.Sequential(backbone, classifier).to(device)

In [22]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  35421781


## Training

In [5]:
# model_name = 'c10_inv_v0'
model_name = 'c10_ord_v0'

In [24]:
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
#                       momentum=0.9, weight_decay=5e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

In [25]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    print(f"[Train] {epoch} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    return

In [26]:
best_acc = -1
def test(epoch):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('models'):
            os.mkdir('models')
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc

In [27]:
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
resume = False

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load(f'./models/{model_name}.pth')
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

In [None]:
### Train the whole damn thing

for epoch in range(start_epoch, start_epoch+200): ## for 200 epochs
    train(epoch)
    test(epoch)
    scheduler.step()

100%|██████████| 391/391 [00:53<00:00,  7.34it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 0 Loss: 2.299 | Acc: 27.474 13737/50000


100%|██████████| 79/79 [00:03<00:00, 23.68it/s]


[Test] 0 Loss: 2.297 | Acc: 36.950 3695/10000
Saving..


100%|██████████| 391/391 [00:53<00:00,  7.29it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 1 Loss: 2.295 | Acc: 36.060 18030/50000


100%|██████████| 79/79 [00:03<00:00, 23.39it/s]


[Test] 1 Loss: 2.293 | Acc: 39.470 3947/10000
Saving..


100%|██████████| 391/391 [00:53<00:00,  7.25it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 2 Loss: 2.291 | Acc: 38.894 19447/50000


100%|██████████| 79/79 [00:03<00:00, 23.23it/s]


[Test] 2 Loss: 2.288 | Acc: 43.500 4350/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.22it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 3 Loss: 2.286 | Acc: 41.882 20941/50000


100%|██████████| 79/79 [00:03<00:00, 23.17it/s]


[Test] 3 Loss: 2.284 | Acc: 44.310 4431/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.21it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 4 Loss: 2.281 | Acc: 43.316 21658/50000


100%|██████████| 79/79 [00:03<00:00, 23.14it/s]


[Test] 4 Loss: 2.279 | Acc: 44.390 4439/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.21it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 5 Loss: 2.276 | Acc: 44.310 22155/50000


100%|██████████| 79/79 [00:03<00:00, 23.33it/s]


[Test] 5 Loss: 2.272 | Acc: 47.110 4711/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.21it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 6 Loss: 2.270 | Acc: 44.986 22493/50000


100%|██████████| 79/79 [00:03<00:00, 23.35it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 6 Loss: 2.268 | Acc: 45.540 4554/10000


100%|██████████| 391/391 [00:54<00:00,  7.21it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 7 Loss: 2.264 | Acc: 45.920 22960/50000


100%|██████████| 79/79 [00:03<00:00, 23.39it/s]


[Test] 7 Loss: 2.260 | Acc: 48.640 4864/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 8 Loss: 2.258 | Acc: 46.802 23401/50000


100%|██████████| 79/79 [00:03<00:00, 23.34it/s]


[Test] 8 Loss: 2.253 | Acc: 49.480 4948/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 9 Loss: 2.251 | Acc: 47.556 23778/50000


100%|██████████| 79/79 [00:03<00:00, 23.24it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 9 Loss: 2.247 | Acc: 48.670 4867/10000


100%|██████████| 391/391 [00:54<00:00,  7.21it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 10 Loss: 2.244 | Acc: 48.276 24138/50000


100%|██████████| 79/79 [00:03<00:00, 23.15it/s]


[Test] 10 Loss: 2.239 | Acc: 50.440 5044/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 11 Loss: 2.237 | Acc: 49.142 24571/50000


100%|██████████| 79/79 [00:03<00:00, 23.38it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 11 Loss: 2.233 | Acc: 49.620 4962/10000


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 12 Loss: 2.230 | Acc: 49.696 24848/50000


100%|██████████| 79/79 [00:03<00:00, 23.25it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 12 Loss: 2.225 | Acc: 50.040 5004/10000


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 13 Loss: 2.222 | Acc: 50.280 25140/50000


100%|██████████| 79/79 [00:03<00:00, 23.06it/s]


[Test] 13 Loss: 2.217 | Acc: 51.700 5170/10000
Saving..


100%|██████████| 391/391 [00:53<00:00,  7.24it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 14 Loss: 2.214 | Acc: 50.830 25415/50000


100%|██████████| 79/79 [00:03<00:00, 23.46it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 14 Loss: 2.210 | Acc: 50.920 5092/10000


100%|██████████| 391/391 [00:53<00:00,  7.26it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 15 Loss: 2.206 | Acc: 51.406 25703/50000


100%|██████████| 79/79 [00:03<00:00, 23.46it/s]


[Test] 15 Loss: 2.200 | Acc: 52.480 5248/10000
Saving..


100%|██████████| 391/391 [00:53<00:00,  7.26it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 16 Loss: 2.198 | Acc: 51.418 25709/50000


100%|██████████| 79/79 [00:03<00:00, 23.43it/s]


[Test] 16 Loss: 2.192 | Acc: 52.860 5286/10000
Saving..


100%|██████████| 391/391 [00:53<00:00,  7.25it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 17 Loss: 2.188 | Acc: 52.270 26135/50000


100%|██████████| 79/79 [00:03<00:00, 23.51it/s]


[Test] 17 Loss: 2.182 | Acc: 53.260 5326/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.23it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 18 Loss: 2.180 | Acc: 52.482 26241/50000


100%|██████████| 79/79 [00:03<00:00, 21.96it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 18 Loss: 2.176 | Acc: 52.110 5211/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 19 Loss: 2.170 | Acc: 53.088 26544/50000


100%|██████████| 79/79 [00:03<00:00, 23.34it/s]


[Test] 19 Loss: 2.164 | Acc: 53.750 5375/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.22it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 20 Loss: 2.162 | Acc: 53.402 26701/50000


100%|██████████| 79/79 [00:03<00:00, 23.40it/s]


[Test] 20 Loss: 2.155 | Acc: 54.600 5460/10000
Saving..


100%|██████████| 391/391 [00:53<00:00,  7.25it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 21 Loss: 2.152 | Acc: 53.800 26900/50000


100%|██████████| 79/79 [00:03<00:00, 23.37it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 21 Loss: 2.151 | Acc: 54.130 5413/10000


100%|██████████| 391/391 [00:54<00:00,  7.22it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 22 Loss: 2.143 | Acc: 54.212 27106/50000


100%|██████████| 79/79 [00:03<00:00, 23.34it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 22 Loss: 2.140 | Acc: 53.870 5387/10000


100%|██████████| 391/391 [00:54<00:00,  7.21it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 23 Loss: 2.133 | Acc: 54.734 27367/50000


100%|██████████| 79/79 [00:03<00:00, 23.39it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 23 Loss: 2.130 | Acc: 54.470 5447/10000


100%|██████████| 391/391 [00:54<00:00,  7.21it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 24 Loss: 2.124 | Acc: 54.958 27479/50000


100%|██████████| 79/79 [00:03<00:00, 23.24it/s]


[Test] 24 Loss: 2.120 | Acc: 55.490 5549/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.23it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 25 Loss: 2.114 | Acc: 55.442 27721/50000


100%|██████████| 79/79 [00:03<00:00, 23.36it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 25 Loss: 2.111 | Acc: 55.210 5521/10000


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 26 Loss: 2.105 | Acc: 55.448 27724/50000


100%|██████████| 79/79 [00:03<00:00, 23.30it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 26 Loss: 2.106 | Acc: 54.440 5444/10000


100%|██████████| 391/391 [00:54<00:00,  7.21it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 27 Loss: 2.096 | Acc: 55.784 27892/50000


100%|██████████| 79/79 [00:03<00:00, 23.33it/s]


[Test] 27 Loss: 2.091 | Acc: 56.360 5636/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.22it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 28 Loss: 2.086 | Acc: 56.114 28057/50000


100%|██████████| 79/79 [00:03<00:00, 23.27it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 28 Loss: 2.083 | Acc: 55.870 5587/10000


100%|██████████| 391/391 [00:54<00:00,  7.21it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 29 Loss: 2.078 | Acc: 56.020 28010/50000


100%|██████████| 79/79 [00:03<00:00, 23.34it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 29 Loss: 2.076 | Acc: 56.110 5611/10000


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 30 Loss: 2.070 | Acc: 56.346 28173/50000


100%|██████████| 79/79 [00:03<00:00, 23.16it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 30 Loss: 2.069 | Acc: 55.840 5584/10000


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 31 Loss: 2.060 | Acc: 56.726 28363/50000


100%|██████████| 79/79 [00:03<00:00, 23.28it/s]


[Test] 31 Loss: 2.058 | Acc: 56.650 5665/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 32 Loss: 2.052 | Acc: 56.820 28410/50000


100%|██████████| 79/79 [00:03<00:00, 23.19it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 32 Loss: 2.050 | Acc: 56.390 5639/10000


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 33 Loss: 2.044 | Acc: 56.902 28451/50000


100%|██████████| 79/79 [00:03<00:00, 23.23it/s]


[Test] 33 Loss: 2.041 | Acc: 56.810 5681/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.21it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 34 Loss: 2.036 | Acc: 57.200 28600/50000


100%|██████████| 79/79 [00:03<00:00, 23.34it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 34 Loss: 2.034 | Acc: 56.720 5672/10000


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 35 Loss: 2.030 | Acc: 57.120 28560/50000


100%|██████████| 79/79 [00:03<00:00, 23.22it/s]


[Test] 35 Loss: 2.027 | Acc: 57.360 5736/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 36 Loss: 2.019 | Acc: 57.928 28964/50000


100%|██████████| 79/79 [00:03<00:00, 23.22it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 36 Loss: 2.020 | Acc: 56.940 5694/10000


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 37 Loss: 2.011 | Acc: 58.038 29019/50000


100%|██████████| 79/79 [00:03<00:00, 23.29it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 37 Loss: 2.012 | Acc: 57.210 5721/10000


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 38 Loss: 2.005 | Acc: 57.954 28977/50000


100%|██████████| 79/79 [00:03<00:00, 23.32it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 38 Loss: 2.011 | Acc: 56.800 5680/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 39 Loss: 1.998 | Acc: 58.192 29096/50000


100%|██████████| 79/79 [00:03<00:00, 22.93it/s]


[Test] 39 Loss: 2.000 | Acc: 57.540 5754/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.21it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 40 Loss: 1.991 | Acc: 58.428 29214/50000


100%|██████████| 79/79 [00:03<00:00, 23.21it/s]


[Test] 40 Loss: 1.995 | Acc: 57.590 5759/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 41 Loss: 1.984 | Acc: 58.604 29302/50000


100%|██████████| 79/79 [00:03<00:00, 23.24it/s]


[Test] 41 Loss: 1.987 | Acc: 57.940 5794/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 42 Loss: 1.977 | Acc: 58.874 29437/50000


100%|██████████| 79/79 [00:03<00:00, 23.20it/s]


[Test] 42 Loss: 1.981 | Acc: 58.050 5805/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 43 Loss: 1.971 | Acc: 58.910 29455/50000


100%|██████████| 79/79 [00:03<00:00, 23.23it/s]


[Test] 43 Loss: 1.972 | Acc: 58.710 5871/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 44 Loss: 1.966 | Acc: 59.018 29509/50000


100%|██████████| 79/79 [00:03<00:00, 23.05it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 44 Loss: 1.970 | Acc: 58.190 5819/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 45 Loss: 1.960 | Acc: 59.058 29529/50000


100%|██████████| 79/79 [00:03<00:00, 23.26it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 45 Loss: 1.965 | Acc: 58.410 5841/10000


100%|██████████| 391/391 [00:54<00:00,  7.20it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 46 Loss: 1.955 | Acc: 59.280 29640/50000


100%|██████████| 79/79 [00:03<00:00, 23.16it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 46 Loss: 1.962 | Acc: 58.230 5823/10000


100%|██████████| 391/391 [00:54<00:00,  7.18it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 47 Loss: 1.950 | Acc: 59.530 29765/50000


100%|██████████| 79/79 [00:03<00:00, 23.33it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 47 Loss: 1.962 | Acc: 58.220 5822/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 48 Loss: 1.943 | Acc: 59.736 29868/50000


100%|██████████| 79/79 [00:03<00:00, 23.28it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 48 Loss: 1.954 | Acc: 58.440 5844/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 49 Loss: 1.939 | Acc: 59.732 29866/50000


100%|██████████| 79/79 [00:03<00:00, 23.25it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 49 Loss: 1.950 | Acc: 58.170 5817/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 50 Loss: 1.933 | Acc: 59.996 29998/50000


100%|██████████| 79/79 [00:03<00:00, 23.13it/s]


[Test] 50 Loss: 1.941 | Acc: 59.220 5922/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 51 Loss: 1.929 | Acc: 60.082 30041/50000


100%|██████████| 79/79 [00:03<00:00, 23.25it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 51 Loss: 1.938 | Acc: 59.070 5907/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 52 Loss: 1.923 | Acc: 60.264 30132/50000


100%|██████████| 79/79 [00:03<00:00, 23.26it/s]


[Test] 52 Loss: 1.928 | Acc: 59.480 5948/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.18it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 53 Loss: 1.920 | Acc: 60.330 30165/50000


100%|██████████| 79/79 [00:03<00:00, 23.19it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 53 Loss: 1.929 | Acc: 59.380 5938/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 54 Loss: 1.916 | Acc: 60.326 30163/50000


100%|██████████| 79/79 [00:03<00:00, 23.32it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 54 Loss: 1.926 | Acc: 59.440 5944/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 55 Loss: 1.912 | Acc: 60.476 30238/50000


100%|██████████| 79/79 [00:03<00:00, 23.28it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 55 Loss: 1.929 | Acc: 58.140 5814/10000


100%|██████████| 391/391 [00:54<00:00,  7.18it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 56 Loss: 1.908 | Acc: 60.694 30347/50000


100%|██████████| 79/79 [00:03<00:00, 23.07it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 56 Loss: 1.918 | Acc: 59.410 5941/10000


100%|██████████| 391/391 [00:54<00:00,  7.18it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 57 Loss: 1.903 | Acc: 61.058 30529/50000


100%|██████████| 79/79 [00:03<00:00, 23.21it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 57 Loss: 1.915 | Acc: 59.370 5937/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 58 Loss: 1.900 | Acc: 60.914 30457/50000


100%|██████████| 79/79 [00:03<00:00, 23.19it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 58 Loss: 1.914 | Acc: 59.330 5933/10000


100%|██████████| 391/391 [00:54<00:00,  7.18it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 59 Loss: 1.896 | Acc: 61.084 30542/50000


100%|██████████| 79/79 [00:03<00:00, 23.21it/s]


[Test] 59 Loss: 1.908 | Acc: 59.530 5953/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.18it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 60 Loss: 1.895 | Acc: 61.118 30559/50000


100%|██████████| 79/79 [00:03<00:00, 23.21it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 60 Loss: 1.910 | Acc: 59.340 5934/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 61 Loss: 1.889 | Acc: 61.346 30673/50000


100%|██████████| 79/79 [00:03<00:00, 23.18it/s]


[Test] 61 Loss: 1.903 | Acc: 59.930 5993/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 62 Loss: 1.887 | Acc: 61.482 30741/50000


100%|██████████| 79/79 [00:03<00:00, 23.21it/s]


[Test] 62 Loss: 1.898 | Acc: 60.430 6043/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 63 Loss: 1.884 | Acc: 61.406 30703/50000


100%|██████████| 79/79 [00:03<00:00, 23.23it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 63 Loss: 1.901 | Acc: 60.000 6000/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 64 Loss: 1.882 | Acc: 61.446 30723/50000


100%|██████████| 79/79 [00:03<00:00, 23.19it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 64 Loss: 1.904 | Acc: 58.870 5887/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 65 Loss: 1.880 | Acc: 61.480 30740/50000


100%|██████████| 79/79 [00:03<00:00, 23.16it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 65 Loss: 1.896 | Acc: 59.620 5962/10000


100%|██████████| 391/391 [00:54<00:00,  7.18it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 66 Loss: 1.877 | Acc: 61.676 30838/50000


100%|██████████| 79/79 [00:03<00:00, 23.08it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 66 Loss: 1.891 | Acc: 60.040 6004/10000


100%|██████████| 391/391 [00:54<00:00,  7.18it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 67 Loss: 1.872 | Acc: 61.960 30980/50000


100%|██████████| 79/79 [00:03<00:00, 23.08it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 67 Loss: 1.895 | Acc: 59.230 5923/10000


100%|██████████| 391/391 [00:54<00:00,  7.17it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 68 Loss: 1.871 | Acc: 61.874 30937/50000


100%|██████████| 79/79 [00:03<00:00, 23.12it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 68 Loss: 1.890 | Acc: 59.970 5997/10000


100%|██████████| 391/391 [00:54<00:00,  7.18it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 69 Loss: 1.869 | Acc: 61.984 30992/50000


100%|██████████| 79/79 [00:03<00:00, 23.16it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 69 Loss: 1.884 | Acc: 60.380 6038/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 70 Loss: 1.865 | Acc: 62.246 31123/50000


100%|██████████| 79/79 [00:03<00:00, 23.27it/s]


[Test] 70 Loss: 1.880 | Acc: 60.590 6059/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 71 Loss: 1.863 | Acc: 62.272 31136/50000


100%|██████████| 79/79 [00:03<00:00, 22.95it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 71 Loss: 1.885 | Acc: 59.970 5997/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 72 Loss: 1.861 | Acc: 62.362 31181/50000


100%|██████████| 79/79 [00:03<00:00, 23.23it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 72 Loss: 1.882 | Acc: 60.110 6011/10000


100%|██████████| 391/391 [00:54<00:00,  7.18it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 73 Loss: 1.858 | Acc: 62.486 31243/50000


100%|██████████| 79/79 [00:03<00:00, 23.16it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 73 Loss: 1.880 | Acc: 60.450 6045/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 74 Loss: 1.857 | Acc: 62.532 31266/50000


100%|██████████| 79/79 [00:03<00:00, 23.19it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 74 Loss: 1.875 | Acc: 60.520 6052/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 75 Loss: 1.855 | Acc: 62.666 31333/50000


100%|██████████| 79/79 [00:03<00:00, 23.19it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 75 Loss: 1.875 | Acc: 60.430 6043/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 76 Loss: 1.854 | Acc: 62.684 31342/50000


100%|██████████| 79/79 [00:03<00:00, 23.16it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 76 Loss: 1.876 | Acc: 60.150 6015/10000


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 77 Loss: 1.852 | Acc: 62.792 31396/50000


100%|██████████| 79/79 [00:03<00:00, 23.05it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 77 Loss: 1.875 | Acc: 60.160 6016/10000


100%|██████████| 391/391 [00:54<00:00,  7.17it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 78 Loss: 1.849 | Acc: 62.952 31476/50000


100%|██████████| 79/79 [00:03<00:00, 23.20it/s]


[Test] 78 Loss: 1.871 | Acc: 60.700 6070/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 79 Loss: 1.848 | Acc: 62.936 31468/50000


100%|██████████| 79/79 [00:03<00:00, 23.33it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 79 Loss: 1.872 | Acc: 60.310 6031/10000


100%|██████████| 391/391 [00:54<00:00,  7.17it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 80 Loss: 1.847 | Acc: 62.892 31446/50000


100%|██████████| 79/79 [00:03<00:00, 23.21it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 80 Loss: 1.870 | Acc: 60.280 6028/10000


100%|██████████| 391/391 [00:54<00:00,  7.17it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 81 Loss: 1.845 | Acc: 63.074 31537/50000


100%|██████████| 79/79 [00:03<00:00, 23.11it/s]


[Test] 81 Loss: 1.866 | Acc: 60.940 6094/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.18it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 82 Loss: 1.843 | Acc: 63.178 31589/50000


100%|██████████| 79/79 [00:03<00:00, 23.17it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 82 Loss: 1.867 | Acc: 60.720 6072/10000


100%|██████████| 391/391 [00:54<00:00,  7.18it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 83 Loss: 1.840 | Acc: 63.346 31673/50000


100%|██████████| 79/79 [00:03<00:00, 23.16it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 83 Loss: 1.865 | Acc: 60.760 6076/10000


100%|██████████| 391/391 [00:54<00:00,  7.18it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 84 Loss: 1.840 | Acc: 63.354 31677/50000


100%|██████████| 79/79 [00:03<00:00, 23.08it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 84 Loss: 1.870 | Acc: 59.950 5995/10000


100%|██████████| 391/391 [00:54<00:00,  7.17it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 85 Loss: 1.837 | Acc: 63.536 31768/50000


100%|██████████| 79/79 [00:03<00:00, 23.26it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 85 Loss: 1.866 | Acc: 60.440 6044/10000


100%|██████████| 391/391 [00:54<00:00,  7.17it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 86 Loss: 1.837 | Acc: 63.490 31745/50000


100%|██████████| 79/79 [00:03<00:00, 23.17it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 86 Loss: 1.864 | Acc: 60.620 6062/10000


100%|██████████| 391/391 [00:54<00:00,  7.17it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 87 Loss: 1.834 | Acc: 63.720 31860/50000


100%|██████████| 79/79 [00:03<00:00, 23.07it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 87 Loss: 1.861 | Acc: 60.620 6062/10000


100%|██████████| 391/391 [00:54<00:00,  7.16it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 88 Loss: 1.833 | Acc: 63.632 31816/50000


100%|██████████| 79/79 [00:03<00:00, 23.16it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 88 Loss: 1.859 | Acc: 60.890 6089/10000


100%|██████████| 391/391 [00:54<00:00,  7.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 89 Loss: 1.832 | Acc: 63.734 31867/50000


100%|██████████| 79/79 [00:03<00:00, 23.17it/s]


[Test] 89 Loss: 1.859 | Acc: 61.090 6109/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.16it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 90 Loss: 1.831 | Acc: 63.830 31915/50000


100%|██████████| 79/79 [00:03<00:00, 23.14it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 90 Loss: 1.858 | Acc: 60.780 6078/10000


100%|██████████| 391/391 [00:54<00:00,  7.16it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 91 Loss: 1.830 | Acc: 63.926 31963/50000


100%|██████████| 79/79 [00:03<00:00, 23.05it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 91 Loss: 1.859 | Acc: 60.830 6083/10000


100%|██████████| 391/391 [00:54<00:00,  7.17it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 92 Loss: 1.829 | Acc: 63.886 31943/50000


100%|██████████| 79/79 [00:03<00:00, 22.92it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 92 Loss: 1.858 | Acc: 60.930 6093/10000


100%|██████████| 391/391 [00:54<00:00,  7.17it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 93 Loss: 1.826 | Acc: 64.144 32072/50000


100%|██████████| 79/79 [00:03<00:00, 23.07it/s]


[Test] 93 Loss: 1.854 | Acc: 61.340 6134/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.16it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 94 Loss: 1.824 | Acc: 64.368 32184/50000


100%|██████████| 79/79 [00:03<00:00, 23.10it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 94 Loss: 1.858 | Acc: 60.620 6062/10000


100%|██████████| 391/391 [00:54<00:00,  7.16it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 95 Loss: 1.823 | Acc: 64.376 32188/50000


100%|██████████| 79/79 [00:03<00:00, 23.13it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 95 Loss: 1.855 | Acc: 61.000 6100/10000


100%|██████████| 391/391 [00:54<00:00,  7.17it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 96 Loss: 1.824 | Acc: 64.210 32105/50000


100%|██████████| 79/79 [00:03<00:00, 23.10it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 96 Loss: 1.855 | Acc: 60.990 6099/10000


100%|██████████| 391/391 [00:54<00:00,  7.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 97 Loss: 1.824 | Acc: 64.186 32093/50000


100%|██████████| 79/79 [00:03<00:00, 23.00it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 97 Loss: 1.853 | Acc: 61.210 6121/10000


100%|██████████| 391/391 [00:54<00:00,  7.16it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 98 Loss: 1.821 | Acc: 64.432 32216/50000


100%|██████████| 79/79 [00:03<00:00, 22.92it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 98 Loss: 1.857 | Acc: 60.700 6070/10000


100%|██████████| 391/391 [00:54<00:00,  7.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 99 Loss: 1.819 | Acc: 64.552 32276/50000


100%|██████████| 79/79 [00:03<00:00, 23.03it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 99 Loss: 1.855 | Acc: 60.750 6075/10000


100%|██████████| 391/391 [00:54<00:00,  7.16it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 100 Loss: 1.820 | Acc: 64.482 32241/50000


100%|██████████| 79/79 [00:03<00:00, 23.04it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 100 Loss: 1.854 | Acc: 60.810 6081/10000


100%|██████████| 391/391 [00:54<00:00,  7.16it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 101 Loss: 1.820 | Acc: 64.416 32208/50000


100%|██████████| 79/79 [00:03<00:00, 22.86it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 101 Loss: 1.852 | Acc: 61.020 6102/10000


100%|██████████| 391/391 [00:54<00:00,  7.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 102 Loss: 1.817 | Acc: 64.684 32342/50000


100%|██████████| 79/79 [00:03<00:00, 23.03it/s]


[Test] 102 Loss: 1.849 | Acc: 61.430 6143/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.16it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 103 Loss: 1.815 | Acc: 64.854 32427/50000


100%|██████████| 79/79 [00:03<00:00, 23.00it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 103 Loss: 1.849 | Acc: 61.220 6122/10000


100%|██████████| 391/391 [00:54<00:00,  7.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 104 Loss: 1.816 | Acc: 64.620 32310/50000


100%|██████████| 79/79 [00:03<00:00, 23.05it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 104 Loss: 1.850 | Acc: 61.100 6110/10000


100%|██████████| 391/391 [00:54<00:00,  7.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 105 Loss: 1.813 | Acc: 64.912 32456/50000


100%|██████████| 79/79 [00:03<00:00, 22.96it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 105 Loss: 1.848 | Acc: 61.400 6140/10000


100%|██████████| 391/391 [00:54<00:00,  7.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 106 Loss: 1.815 | Acc: 64.776 32388/50000


100%|██████████| 79/79 [00:03<00:00, 23.08it/s]


[Test] 106 Loss: 1.846 | Acc: 61.530 6153/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.16it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 107 Loss: 1.814 | Acc: 64.798 32399/50000


100%|██████████| 79/79 [00:03<00:00, 22.98it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 107 Loss: 1.848 | Acc: 61.280 6128/10000


100%|██████████| 391/391 [00:54<00:00,  7.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 108 Loss: 1.812 | Acc: 64.954 32477/50000


100%|██████████| 79/79 [00:03<00:00, 23.06it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 108 Loss: 1.846 | Acc: 61.490 6149/10000


100%|██████████| 391/391 [00:54<00:00,  7.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 109 Loss: 1.811 | Acc: 65.064 32532/50000


100%|██████████| 79/79 [00:03<00:00, 23.15it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 109 Loss: 1.847 | Acc: 61.350 6135/10000


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 110 Loss: 1.812 | Acc: 64.960 32480/50000


100%|██████████| 79/79 [00:03<00:00, 22.95it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 110 Loss: 1.849 | Acc: 61.100 6110/10000


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 111 Loss: 1.809 | Acc: 65.222 32611/50000


100%|██████████| 79/79 [00:03<00:00, 23.09it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 111 Loss: 1.845 | Acc: 61.530 6153/10000


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 112 Loss: 1.808 | Acc: 65.254 32627/50000


100%|██████████| 79/79 [00:03<00:00, 22.83it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 112 Loss: 1.846 | Acc: 61.480 6148/10000


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 113 Loss: 1.808 | Acc: 65.232 32616/50000


100%|██████████| 79/79 [00:03<00:00, 23.03it/s]


[Test] 113 Loss: 1.843 | Acc: 61.670 6167/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 114 Loss: 1.806 | Acc: 65.384 32692/50000


100%|██████████| 79/79 [00:03<00:00, 22.80it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 114 Loss: 1.845 | Acc: 61.460 6146/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 115 Loss: 1.805 | Acc: 65.496 32748/50000


100%|██████████| 79/79 [00:03<00:00, 23.05it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 115 Loss: 1.843 | Acc: 61.510 6151/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 116 Loss: 1.805 | Acc: 65.442 32721/50000


100%|██████████| 79/79 [00:03<00:00, 22.92it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 116 Loss: 1.843 | Acc: 61.670 6167/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 117 Loss: 1.806 | Acc: 65.440 32720/50000


100%|██████████| 79/79 [00:03<00:00, 23.01it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 117 Loss: 1.845 | Acc: 61.390 6139/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 118 Loss: 1.805 | Acc: 65.482 32741/50000


100%|██████████| 79/79 [00:03<00:00, 22.87it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 118 Loss: 1.844 | Acc: 61.480 6148/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 119 Loss: 1.803 | Acc: 65.646 32823/50000


100%|██████████| 79/79 [00:03<00:00, 22.84it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 119 Loss: 1.844 | Acc: 61.520 6152/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 120 Loss: 1.803 | Acc: 65.608 32804/50000


100%|██████████| 79/79 [00:03<00:00, 22.88it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 120 Loss: 1.844 | Acc: 61.390 6139/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 121 Loss: 1.803 | Acc: 65.588 32794/50000


100%|██████████| 79/79 [00:03<00:00, 23.00it/s]


[Test] 121 Loss: 1.842 | Acc: 61.700 6170/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 122 Loss: 1.802 | Acc: 65.726 32863/50000


100%|██████████| 79/79 [00:03<00:00, 23.12it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 122 Loss: 1.842 | Acc: 61.490 6149/10000


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 123 Loss: 1.800 | Acc: 65.838 32919/50000


100%|██████████| 79/79 [00:03<00:00, 23.01it/s]


[Test] 123 Loss: 1.840 | Acc: 61.790 6179/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 124 Loss: 1.799 | Acc: 65.832 32916/50000


100%|██████████| 79/79 [00:03<00:00, 22.92it/s]


[Test] 124 Loss: 1.840 | Acc: 61.840 6184/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 125 Loss: 1.800 | Acc: 65.846 32923/50000


100%|██████████| 79/79 [00:03<00:00, 22.94it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 125 Loss: 1.839 | Acc: 61.750 6175/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 126 Loss: 1.800 | Acc: 65.784 32892/50000


100%|██████████| 79/79 [00:03<00:00, 22.93it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 126 Loss: 1.842 | Acc: 61.610 6161/10000


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 127 Loss: 1.798 | Acc: 66.030 33015/50000


100%|██████████| 79/79 [00:03<00:00, 23.07it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 127 Loss: 1.843 | Acc: 61.350 6135/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 128 Loss: 1.798 | Acc: 66.012 33006/50000


100%|██████████| 79/79 [00:03<00:00, 22.97it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 128 Loss: 1.840 | Acc: 61.810 6181/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 129 Loss: 1.798 | Acc: 66.008 33004/50000


100%|██████████| 79/79 [00:03<00:00, 22.98it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 129 Loss: 1.843 | Acc: 61.410 6141/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 130 Loss: 1.796 | Acc: 66.134 33067/50000


100%|██████████| 79/79 [00:03<00:00, 23.00it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 130 Loss: 1.839 | Acc: 61.810 6181/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 131 Loss: 1.796 | Acc: 66.142 33071/50000


100%|██████████| 79/79 [00:03<00:00, 22.96it/s]


[Test] 131 Loss: 1.839 | Acc: 61.860 6186/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 132 Loss: 1.796 | Acc: 66.092 33046/50000


100%|██████████| 79/79 [00:03<00:00, 22.86it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 132 Loss: 1.841 | Acc: 61.480 6148/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 133 Loss: 1.795 | Acc: 66.148 33074/50000


100%|██████████| 79/79 [00:03<00:00, 22.94it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 133 Loss: 1.839 | Acc: 61.780 6178/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 134 Loss: 1.794 | Acc: 66.258 33129/50000


100%|██████████| 79/79 [00:03<00:00, 22.92it/s]


[Test] 134 Loss: 1.838 | Acc: 62.070 6207/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 135 Loss: 1.794 | Acc: 66.210 33105/50000


100%|██████████| 79/79 [00:03<00:00, 23.10it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 135 Loss: 1.838 | Acc: 61.850 6185/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 136 Loss: 1.793 | Acc: 66.364 33182/50000


100%|██████████| 79/79 [00:03<00:00, 22.93it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 136 Loss: 1.838 | Acc: 61.870 6187/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 137 Loss: 1.792 | Acc: 66.414 33207/50000


100%|██████████| 79/79 [00:03<00:00, 22.73it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 137 Loss: 1.839 | Acc: 61.810 6181/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 138 Loss: 1.793 | Acc: 66.332 33166/50000


100%|██████████| 79/79 [00:03<00:00, 22.90it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 138 Loss: 1.843 | Acc: 61.290 6129/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 139 Loss: 1.793 | Acc: 66.300 33150/50000


100%|██████████| 79/79 [00:03<00:00, 22.91it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 139 Loss: 1.838 | Acc: 61.750 6175/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 140 Loss: 1.792 | Acc: 66.434 33217/50000


100%|██████████| 79/79 [00:03<00:00, 22.90it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 140 Loss: 1.839 | Acc: 61.660 6166/10000


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 141 Loss: 1.791 | Acc: 66.490 33245/50000


100%|██████████| 79/79 [00:03<00:00, 22.91it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 141 Loss: 1.838 | Acc: 61.870 6187/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 142 Loss: 1.791 | Acc: 66.506 33253/50000


100%|██████████| 79/79 [00:03<00:00, 22.91it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 142 Loss: 1.837 | Acc: 61.910 6191/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 143 Loss: 1.791 | Acc: 66.548 33274/50000


100%|██████████| 79/79 [00:03<00:00, 23.03it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 143 Loss: 1.837 | Acc: 61.890 6189/10000


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 144 Loss: 1.790 | Acc: 66.606 33303/50000


100%|██████████| 79/79 [00:03<00:00, 22.99it/s]


[Test] 144 Loss: 1.835 | Acc: 62.130 6213/10000
Saving..


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 145 Loss: 1.789 | Acc: 66.680 33340/50000


100%|██████████| 79/79 [00:03<00:00, 22.99it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 145 Loss: 1.837 | Acc: 61.880 6188/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 146 Loss: 1.789 | Acc: 66.616 33308/50000


100%|██████████| 79/79 [00:03<00:00, 22.98it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 146 Loss: 1.836 | Acc: 62.090 6209/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 147 Loss: 1.789 | Acc: 66.722 33361/50000


100%|██████████| 79/79 [00:03<00:00, 23.02it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 147 Loss: 1.837 | Acc: 61.810 6181/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 148 Loss: 1.788 | Acc: 66.716 33358/50000


100%|██████████| 79/79 [00:03<00:00, 22.78it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 148 Loss: 1.835 | Acc: 61.990 6199/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 149 Loss: 1.789 | Acc: 66.672 33336/50000


100%|██████████| 79/79 [00:03<00:00, 23.07it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 149 Loss: 1.836 | Acc: 62.030 6203/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 150 Loss: 1.788 | Acc: 66.668 33334/50000


100%|██████████| 79/79 [00:03<00:00, 22.85it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 150 Loss: 1.837 | Acc: 61.860 6186/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 151 Loss: 1.788 | Acc: 66.668 33334/50000


100%|██████████| 79/79 [00:03<00:00, 22.95it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 151 Loss: 1.834 | Acc: 62.120 6212/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 152 Loss: 1.787 | Acc: 66.840 33420/50000


100%|██████████| 79/79 [00:03<00:00, 22.84it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 152 Loss: 1.836 | Acc: 61.880 6188/10000


100%|██████████| 391/391 [00:54<00:00,  7.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 153 Loss: 1.788 | Acc: 66.722 33361/50000


100%|██████████| 79/79 [00:03<00:00, 23.12it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 153 Loss: 1.836 | Acc: 61.850 6185/10000


100%|██████████| 391/391 [00:54<00:00,  7.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

[Train] 154 Loss: 1.787 | Acc: 66.794 33397/50000


100%|██████████| 79/79 [00:03<00:00, 22.85it/s]
  0%|          | 0/391 [00:00<?, ?it/s]

[Test] 154 Loss: 1.834 | Acc: 62.060 6206/10000


 22%|██▏       | 86/391 [00:12<00:42,  7.20it/s]

In [None]:
best_acc

In [None]:
classifier.inv_temp

In [6]:
checkpoint = torch.load(f'./models/{model_name}.pth')
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

best_acc, start_epoch

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

### Hard test accuracy with count per classifier

In [None]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(test_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(backbone(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Test Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

### Hard train accuracy with count per classifier

In [None]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(backbone(xx), hard=True)
        set_indx, count = torch.unique(torch.argmax(classifier.cls_confidence, dim=1), return_counts=True) 
        set_count[set_indx] += count
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
    test_acc += correct
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

In [None]:
#### Classifiers that enclose any data
torch.count_nonzero(set_count)

In [None]:
#### classifier with class representation
torch.argmax(classifier.cls_weight, dim=1)

### analyze per classifier accuracy

In [None]:
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
set_acc = torch.zeros(classifier.num_sets).to(device)
for xx, yy in tqdm(train_loader):
    xx, yy = xx.to(device), yy.to(device)
    with torch.no_grad():
        yout = classifier(backbone(xx), hard=True)
        
    cls_indx = torch.argmax(classifier.cls_confidence, dim=1)
    set_indx, count = torch.unique(cls_indx, return_counts=True) 
    set_count[set_indx] += count
    
    outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
    correct = (outputs == yy.data.cpu().numpy()).astype(float)
    
    ### class_index has 100 possible values
    for i, c in enumerate(correct):
        set_acc[cls_indx[i]] += c
    
#     print(set_acc.sum(), set_count.sum())
#     break
    test_acc += correct.sum()
    test_count += len(xx)

print(f'Hard Train Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

In [None]:
# set_acc/set_count

In [None]:
for i, (cnt, acc, cls) in enumerate(zip(set_count.type(torch.long).tolist(),
                                   (set_acc/set_count).tolist(),
                                   torch.argmax(classifier.cls_weight, dim=1).tolist())):
    if cnt == 0: continue
    print(f"{i},\t {cnt},\t {cls}\t {acc*100:.2f}%")