In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy, json
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prefer the second GPU (the setting this notebook was originally run with),
# but fall back to the first GPU or to the CPU so that the notebook still runs
# on machines that expose fewer devices.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# SEED = 147
# SEED = 258
SEED = 369

# Seed every random number generator the notebook imports (Python's `random`,
# NumPy and PyTorch) so that a fresh "Restart & Run All" starts from one
# well-defined state.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
import torch.optim as optim
from torch.utils import data

In [5]:
# cifar_train = transforms.Compose([
#     transforms.RandomCrop(size=32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
#         std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
#     ),
# ])

# cifar_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
#         std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
#     ),
# ])

# train_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=True, download=True, transform=cifar_train)
# test_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=False, download=True, transform=cifar_test)

In [6]:
# train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=2)
# test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=2)

In [7]:
# Per-channel normalisation statistics for CIFAR-100 and the dataset location.
CIFAR100_MEAN = [0.5071, 0.4865, 0.4409]
CIFAR100_STD = [0.2009, 0.1984, 0.2023]
CIFAR100_ROOT = "../../../../../_Datasets/cifar100/"

# The same normalisation step is applied to training and test images.
cifar100_normalize = transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD)

# Training pipeline: random crop (with 4 px padding) and horizontal flip as
# augmentation, followed by tensor conversion and normalisation.
cifar_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    cifar100_normalize,
])

# Test pipeline: no augmentation, only tensor conversion and normalisation.
cifar_test = transforms.Compose([
    transforms.ToTensor(),
    cifar100_normalize,
])

train_dataset = datasets.CIFAR100(
    root=CIFAR100_ROOT, train=True, download=True, transform=cifar_train
)
test_dataset = datasets.CIFAR100(
    root=CIFAR100_ROOT, train=False, download=True, transform=cifar_test
)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Notebook-level loaders: the training split is reshuffled every epoch, the
# test split is read in a fixed order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

# Model

In [9]:
class MlpBLock(nn.Module):
    """Fully connected block whose input and output width are both ``input_dim``.

    The hidden widths are given as multiples of ``input_dim``. An activation is
    placed after every linear layer except the last one.

    Args:
        input_dim: width of the input (and of the output).
        hidden_layers_ratio: a single ratio (int or float) or any sequence of
            ratios (list, tuple, ...). Hidden layer ``k`` has
            ``int(ratio[k] * input_dim)`` units.
        actf: zero-argument callable returning the activation module.
    """

    def __init__(self, input_dim, hidden_layers_ratio=(2,), actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim

        # A bare number means "one hidden layer with this ratio".
        if isinstance(hidden_layers_ratio, (int, float)):
            hidden_layers_ratio = [hidden_layers_ratio]

        # Copy into a fresh list: this accepts tuples and other sequences and
        # never aliases the caller's (or the default) object.
        self.hlr = [1] + list(hidden_layers_ratio) + [1]

        # Width of every layer boundary, input and output included.
        widths = [int(ratio * self.input_dim) for ratio in self.hlr]

        layers = []
        ### for 1 hidden layer, we iterate 2 times
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(actf())

        # Drop the trailing activation so the block ends with a linear layer.
        self.mlp = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        return self.mlp(x)

In [10]:
MlpBLock(2, [3,4])

MlpBLock(
  (mlp): Sequential(
    (0): Linear(in_features=2, out_features=6, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=6, out_features=8, bias=True)
    (3): GELU(approximate='none')
    (4): Linear(in_features=8, out_features=2, bias=True)
  )
)

## MLP-Mixer 

In [11]:
class MixerBlock(nn.Module):
    """One mixer layer: an MLP across patches followed by an MLP across channels.

    Both sub-layers are preceded by a LayerNorm over the channel axis and are
    wrapped in a residual connection.

    Args:
        patch_dim: number of patches (size of the second-to-last axis).
        channel_dim: number of channels per patch (size of the last axis).
    """

    def __init__(self, patch_dim, channel_dim):
        super().__init__()

        self.ln0 = nn.LayerNorm(channel_dim)
        self.mlp_patch = MlpBLock(patch_dim, [2])
        self.ln1 = nn.LayerNorm(channel_dim)
        self.mlp_channel = MlpBLock(channel_dim, [2])

    def forward(self, x):
        """Mix ``x``; per the original note it has shape (N, patches, channels)."""
        # Patch mixing: normalise over channels, then move the patch axis last
        # so the MLP acts across patches, and move it back afterwards.
        patch_last = self.ln0(x).transpose(-1, -2)
        x = x + self.mlp_patch(patch_last).transpose(-1, -2)

        # Channel mixing: the channel axis is already last.
        return x + self.mlp_channel(self.ln1(x))

In [12]:
class MlpMixer(nn.Module):
    """MLP-Mixer classifier operating on non-overlapping image patches.

    The input is resized to ``image_dim``, cut into patches, each flattened
    patch is linearly projected to ``hidden_expansion`` times its size, the
    result is passed through ``num_blocks`` ``MixerBlock`` layers, and the
    flattened output is classified by one linear layer.

    Args:
        image_dim: ``(C, H, W)`` or ``(H, W)``; a two-element value is treated
            as a single-channel image.
        patch_size: ``(patch_h, patch_w)``; must divide ``H`` and ``W`` exactly.
        hidden_expansion: multiplier applied to the flattened patch size to get
            the number of channels used inside the mixer blocks.
        num_blocks: number of ``MixerBlock`` layers.
        num_classes: size of the output layer.

    Raises:
        AssertionError: if the image size is not a multiple of the patch size.
    """

    def __init__(self, image_dim:tuple, patch_size:tuple, hidden_expansion:float, num_blocks:int, num_classes:int):
        super().__init__()

        self.img_dim = image_dim ### must contain (C, H, W) or (H, W)
        height, width = image_dim[-2], image_dim[-1]
        # With only (H, W) given there is no channel entry to read, so assume
        # one channel instead of mistaking H for the channel count.
        in_channels = 1 if len(image_dim) == 2 else image_dim[0]

        self.scaler = nn.UpsamplingBilinear2d(size=(height, width))

        ### number of patches along each axis
        d0 = height // patch_size[0]
        d1 = width // patch_size[1]
        assert d0*patch_size[0]==height, "Image must be divisible into patch size"
        assert d1*patch_size[1]==width, "Image must be divisible into patch size"

        ### number of values in one flattened patch, before and after projection
        init_dim = patch_size[0]*patch_size[1]*in_channels
        final_dim = int(init_dim*hidden_expansion)

        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        #### rescale the patches (patch wise image non preserving transform, unlike bilinear interpolation)
        self.channel_change = nn.Linear(init_dim, final_dim)
        print(f"MLP Mixer : Channes per patch -> Initial:{init_dim} Final:{final_dim}")

        self.channel_dim = final_dim
        self.patch_dim = d0*d1 ## number of patches

        self.mixer_blocks = nn.Sequential(
            *[MixerBlock(self.patch_dim, self.channel_dim) for _ in range(num_blocks)]
        )

        self.linear = nn.Linear(self.patch_dim*self.channel_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for an image batch."""
        x = self.scaler(x)                    # resize to the configured H x W
        x = self.unfold(x).swapaxes(-1, -2)   # (N, patches, flattened patch)
        x = self.channel_change(x)            # project every patch
        x = self.mixer_blocks(x)
        return self.linear(x.flatten(1))      # one feature vector per sample

In [13]:
# class MlpMixer(nn.Module):
    
#     def __init__(self, input_channels:int, hidden_image_dim:tuple, patch_size:tuple, num_blocks:int, num_classes:int, use1x1scale=True):
#         super().__init__()
        
#         self.img_dim = hidden_image_dim ### must contain (C, H, W) or (H, W)
        
#         ### find patch dim
#         d0 = int(self.img_dim[-2]/patch_size[0])
#         d1 = int(self.img_dim[-1]/patch_size[1])
#         assert d0*patch_size[0]==self.img_dim[-2], "Image must be divisible into patch size"
#         assert d1*patch_size[1]==self.img_dim[-1], "Image must be divisible into patch size"
        
#         ### find channel dim
#         channel_size = d0*d1 ## number of patches
        
#         ### after the number of channels are changed
#         init_dim = patch_size[0]*patch_size[1]*input_channels ## number of channels in each patch
#         final_dim = patch_size[0]*patch_size[1]*self.img_dim[0]

#         self.scaler = nn.UpsamplingBilinear2d(size=(self.img_dim[-2], self.img_dim[-1]))
        
#         if use1x1scale:
#             self.conv1x1 = nn.Conv2d(input_channels, self.img_dim[0], kernel_size=1, stride=1)
#             if input_channels == self.img_dim[0]:
#                 self.conv1x1 = nn.Identity()
#         else:
#             #### rescale the patches (patch wise image non preserving transform, unlike bilinear interpolation)
#             self.channel_change = nn.Linear(init_dim, final_dim) ## apply after unfold
            
#         self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

#         print(f"MLP Mixer : Channes per patch -> Initial:{init_dim} Final:{final_dim}")
        
#         ### after the axis are swapped
#         self.channel_dim = final_dim
#         self.patch_dim = channel_size
        
#         self.mixer_blocks = []
#         for i in range(num_blocks):
#             self.mixer_blocks.append(MixerBlock(self.patch_dim, self.channel_dim))
#         self.mixer_blocks = nn.Sequential(*self.mixer_blocks)
        
#         self.linear = nn.Linear(self.patch_dim*self.channel_dim, num_classes)
        
        
#     def forward(self, x):
#         bs = x.shape[0]
#         x = self.scaler(x) ## scale to high dimension
#         x = self.conv1x1(x)
#         x = self.unfold(x).swapaxes(-1, -2).contiguous() ## convert to patches, bring channels first
# #         x = self.channel_change(x) ## change the number of channels
#         x = self.mixer_blocks(x) ## the mixer architecture
#         x = self.linear(x.view(bs, -1)) ## classify
#         return x

In [14]:
mixer = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.5, num_blocks=1, num_classes=10)
mixer

MLP Mixer : Channes per patch -> Initial:48 Final:120


MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(32, 32), mode=bilinear)
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=120, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((120,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=64, out_features=128, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=128, out_features=64, bias=True)
        )
      )
      (ln1): LayerNorm((120,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=120, out_features=240, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=240, out_features=120, bias=True)
        )
      )
    )
  )
  (linear): Linear(in_features=7680, out_features=10, bias=True)
)

In [15]:
print("number of params: ", sum(p.numel() for p in mixer.parameters())) 

number of params:  157706


In [16]:
mixer(torch.randn(1, 3, 32, 32))

tensor([[-0.1648, -0.0510, -0.0486, -0.1505, -0.3137, -0.5226,  0.0023,  0.5233,
         -0.2587, -0.5398]], grad_fn=<AddmmBackward0>)

## Patch Mixer

In [17]:
class PatchMixerBlock(nn.Module):
    """Residual MLP applied independently to every non-overlapping image patch.

    Args:
        patch_size: side length (int) or ``(height, width)`` of a patch.
        num_channel: number of channels of the incoming feature map.
    """

    def __init__(self, patch_size, num_channel):
        super().__init__()
        self.patch_size = patch_size

        # Number of pixels in one patch, for square and rectangular patches.
        if isinstance(patch_size, int):
            pixels_per_patch = patch_size**2
        else:
            pixels_per_patch = patch_size[0]*patch_size[1]
        # Length of one flattened patch across all channels.
        patch_features = pixels_per_patch*num_channel

        self.ln0 = nn.LayerNorm(patch_features)
        self.mlp_patch = MlpBLock(patch_features, [2])

    def forward(self, x):
        """Mix inside every patch of ``x`` (N, C, H, W) and add the result to ``x``."""
        height, width = x.shape[-2], x.shape[-1]

        # Cut the image into patches and move the flattened-patch axis last.
        patches = nn.functional.unfold(
            x, kernel_size=self.patch_size, stride=self.patch_size
        ).transpose(-1, -2)

        #### mix per patch
        mixed = self.mlp_patch(self.ln0(patches)).transpose(-1, -2)

        # Reassemble the patches into an image of the original spatial size.
        y = nn.functional.fold(
            mixed, (height, width),
            kernel_size=self.patch_size, stride=self.patch_size,
        )
        return x + y

In [18]:
pmb = PatchMixerBlock(8, 3)
pmb

PatchMixerBlock(
  (ln0): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  (mlp_patch): MlpBLock(
    (mlp): Sequential(
      (0): Linear(in_features=192, out_features=384, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=384, out_features=192, bias=True)
    )
  )
)

In [19]:
# pmb(torch.randn(1, 3, 35, 35)).shape

In [20]:
def get_factors(n):
    """Return every divisor of ``n`` from 2 up to ``n`` itself, in ascending order."""
    return [candidate for candidate in range(2, n + 1) if n % candidate == 0]

class PatchMlpMixer(nn.Module):
    """Classifier built only from per-patch MLPs with alternating patch sizes.

    The input is resized to a square of side ``prod(patch_sizes)``, optionally
    mapped to ``hidden_channels`` with a 1x1 convolution, passed through
    ``num_blocks`` repetitions of one ``PatchMixerBlock`` per entry of
    ``patch_sizes``, then flattened and classified by one linear layer.

    Args:
        image_dim: ``(C, H, W)``; only ``C`` is read here, since the spatial
            size is always set to ``prod(patch_sizes)``.
        patch_sizes: patch side lengths used, in order, inside every repetition.
        hidden_channels: number of channels of the internal feature map.
        num_blocks: how many times the sequence of patch sizes is repeated.
        num_classes: size of the output layer.
    """

    def __init__(self, image_dim:tuple, patch_sizes:tuple, hidden_channels:int, num_blocks:int, num_classes:int):
        super().__init__()

        self.img_dim = image_dim ### must contain (C, H, W)
        # np.prod returns a numpy scalar; keep a built-in int so the sizes
        # handed to nn.Linear and to interpolate are plain Python integers.
        self.target_dim = int(np.prod(patch_sizes))

        ### number of channels of the input image
        num_channel = image_dim[0]

        # NOTE: the convolution is always constructed first and only then
        # replaced when no channel change is needed; this keeps the order of
        # parameter initialisation the same as in earlier runs.
        self.conv1x1 = nn.Conv2d(num_channel, hidden_channels, kernel_size=1, stride=1)
        if num_channel == hidden_channels:
            self.conv1x1 = nn.Identity()

        # One block per patch size, and the whole sequence num_blocks times.
        self.mixer_blocks = nn.Sequential(*[
            PatchMixerBlock(ps, hidden_channels)
            for _ in range(num_blocks)
            for ps in patch_sizes
        ])
        self.linear = nn.Linear(self.target_dim*self.target_dim*hidden_channels, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for an image batch."""
        bs = x.shape[0]

        # Resize so that every configured patch size tiles the image exactly.
        x = nn.functional.interpolate(x, size=self.target_dim, mode='bilinear', align_corners=True)

        x = self.conv1x1(x)
        x = self.mixer_blocks(x)
        x = self.linear(x.view(bs, -1))
        return x

In [21]:
4*3*5

60

In [22]:
patch_mixer = PatchMlpMixer((3, 35, 35), patch_sizes=[5, 7], hidden_channels=3, num_blocks=1, num_classes=10)

In [23]:
patch_mixer

PatchMlpMixer(
  (conv1x1): Identity()
  (mixer_blocks): Sequential(
    (0): PatchMixerBlock(
      (ln0): LayerNorm((75,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=75, out_features=150, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=150, out_features=75, bias=True)
        )
      )
    )
    (1): PatchMixerBlock(
      (ln0): LayerNorm((147,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=147, out_features=294, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=294, out_features=147, bias=True)
        )
      )
    )
  )
  (linear): Linear(in_features=3675, out_features=10, bias=True)
)

In [24]:
print("number of params: ", sum(p.numel() for p in patch_mixer.parameters())) 

number of params:  146806


In [25]:
patch_mixer(torch.randn(1, 3, 32, 32)).shape

torch.Size([1, 10])

#### Final Model

In [26]:
### Model 1 : balanced with (1) of PatchOnly ; higher image scaling, low channel scaling for similar params
#### it has large dimension of final hidden unit.
## 81, 150
# model = MlpMixer((3, 5*9, 5*9), (5, 5), hidden_expansion=2, num_blocks=10, num_classes=10)
### balanced model with 4*9 expansion as well
## 81, 144
model = MlpMixer((3, 4*9, 4*9), (4, 4), hidden_expansion=3.0, num_blocks=10, num_classes=10)

### Model 2 : use simple method of expanding only hidden dimension/channel
## 64, 135
# model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=3.2, num_blocks=10, num_classes=10)


model = model.to(device)

MLP Mixer : Channes per patch -> Initial:48 Final:144


In [27]:
model

MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(36, 36), mode=bilinear)
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=144, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=81, out_features=162, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=162, out_features=81, bias=True)
        )
      )
      (ln1): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=144, out_features=288, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=288, out_features=144, bias=True)
        )
      )
    )
    (1): MixerBlock(
      (ln0): LayerNorm((144,), eps=1e-05, elementwise_affine=True)

In [28]:
# print("number of params: ", sum(p.numel() for p in model.parameters())) 
print("number of params: ", sum(p.numel() for p in model.mixer_blocks.parameters())) 

number of params:  1104390


In [29]:
model(torch.randn(1, 3, 32, 32).to(device)).shape

torch.Size([1, 10])

In [30]:
### Model 1 : balanced using best settings for default patch mixer without scaling hidden channels.
model = PatchMlpMixer((3, 35, 35), patch_sizes=[5,7], hidden_channels=3, num_blocks=10, num_classes=10)

model = model.to(device)

In [31]:
model

PatchMlpMixer(
  (conv1x1): Identity()
  (mixer_blocks): Sequential(
    (0): PatchMixerBlock(
      (ln0): LayerNorm((75,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=75, out_features=150, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=150, out_features=75, bias=True)
        )
      )
    )
    (1): PatchMixerBlock(
      (ln0): LayerNorm((147,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=147, out_features=294, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=294, out_features=147, bias=True)
        )
      )
    )
    (2): PatchMixerBlock(
      (ln0): LayerNorm((75,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=75, out_features=150, bias=True)
          (1): GELU(a

In [32]:
print("number of params: ", sum(p.numel() for p in model.mixer_blocks.parameters())) 

number of params:  1100460


In [33]:
model(torch.randn(1, 3, 32, 32).to(device)).shape

torch.Size([1, 10])

In [34]:
print("number of params: ", sum(p.numel() for p in model.parameters())) 
## Patch ||  1137220
## Mixer ||  1141703

number of params:  1137220


## Training

In [35]:
# model_name = f'mlp_mixer_c10_s{SEED}'
# model_name = f'temp_c10_s{SEED}'

In [36]:
# EPOCHS = 200
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [37]:
# STAT ={'train_stat':[], 'test_stat':[]}

In [38]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and record its statistics.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``. Appends ``(epoch, mean loss, accuracy %)``
    to ``STAT['train_stat']`` and prints a one-line summary.
    """
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0

    for step, (inputs, targets) in enumerate(tqdm(train_loader), start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        predicted = logits.max(1)[1]  # index of the largest logit per sample
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

    # `step` now holds the number of batches that were processed.
    mean_loss = running_loss/step
    accuracy = 100.*n_correct/n_seen

    STAT['train_stat'].append((epoch, mean_loss, accuracy)) ### (Epochs, Loss, Acc)
    print(f"[Train] {epoch} Loss: {mean_loss:.3f} | Acc: {accuracy:.3f} {n_correct}/{n_seen}")

In [39]:
best_acc = -1
def test(epoch):
    """Evaluate ``model`` on ``test_loader``, checkpoint on improvement, dump stats.

    Uses the notebook-level ``model``, ``criterion``, ``test_loader``,
    ``device``, ``model_name`` and ``STAT``.

    Side effects:
        * appends ``(epoch, mean loss, accuracy %, mean forward time in s)`` to
          ``STAT['test_stat']`` and prints a one-line summary;
        * when the accuracy exceeds the global ``best_acc``, saves the model
          state to ``./models/{model_name}.pth`` and updates ``best_acc``;
        * rewrites ``./output/{model_name}_data.json`` with ``STAT`` on every
          call.
    """
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    time_taken = []
    # CUDA work is queued asynchronously, so the clock is only meaningful once
    # the device has finished; synchronise around the forward pass on GPU.
    on_cuda = device.type == 'cuda'
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)

            if on_cuda:
                torch.cuda.synchronize(device)
            start = time.perf_counter()

            outputs = model(inputs)

            if on_cuda:
                torch.cuda.synchronize(device)
            time_taken.append(time.perf_counter()-start)

            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100.*correct/total
    mean_loss = test_loss/(batch_idx+1)
    # Plain float so the value serialises to JSON like any other number.
    mean_time = float(np.mean(time_taken))

    STAT['test_stat'].append((epoch, mean_loss, acc, mean_time)) ### (Epochs, Loss, Acc, time)
    print(f"[Test] {epoch} Loss: {mean_loss:.3f} | Acc: {acc:.3f} {correct}/{total}")

    # Save checkpoint.
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch
        }
        os.makedirs('models', exist_ok=True)
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc

    # The statistics file lives in ./output, which may not exist yet on a
    # fresh checkout; create it like the checkpoint directory above.
    os.makedirs('./output', exist_ok=True)
    with open(f"./output/{model_name}_data.json", 'w') as f:
        json.dump(STAT, f, indent=0)

In [40]:
# start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# resume = False

# if resume:
#     # Load checkpoint.
#     print('==> Resuming from checkpoint..')
#     assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
#     checkpoint = torch.load(f'./models/{model_name}.pth')
#     model.load_state_dict(checkpoint['model'])
#     best_acc = checkpoint['acc']
#     start_epoch = checkpoint['epoch']

In [41]:
# ### Train the whole damn thing

# for epoch in range(start_epoch, start_epoch+EPOCHS): ## for 200 epochs
#     train(epoch)
#     test(epoch)
#     scheduler.step()

In [42]:
# best_acc

In [43]:
# checkpoint = torch.load(f'./models/{model_name}.pth')
# best_acc = checkpoint['acc']
# start_epoch = checkpoint['epoch']

# best_acc, start_epoch

In [44]:
# model.load_state_dict(checkpoint['model'])

In [45]:
# model

In [46]:
# STAT

In [47]:
# train_stat = np.array(STAT['train_stat'])
# test_stat = np.array(STAT['test_stat'])

In [48]:
# plt.plot(train_stat[:,1], label='train')
# plt.plot(test_stat[:,1], label='test')
# plt.ylabel("Loss")
# plt.legend()
# plt.savefig(f"./output/plots/{model_name}_loss.svg")
# plt.show()

In [49]:
# plt.plot(train_stat[:,2], label='train')
# plt.plot(test_stat[:,2], label='test')
# plt.ylabel("Accuracy")
# plt.legend()
# plt.savefig(f"./output/plots/{model_name}_accs.svg")
# plt.show()

## Benchmark Training

In [50]:
def get_data_loaders(seed, ds):
    """Build fresh train/test loaders over the notebook-level datasets.

    Args:
        seed: value passed to ``torch.manual_seed`` before the loaders are made.
        ds: dataset tag; ``'c100'`` selects a batch size of 128, anything else
            64. It does not choose the dataset itself: the loaders always wrap
            the global ``train_dataset`` and ``test_dataset``.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    batch_size = 128 if ds == 'c100' else 64
    torch.manual_seed(seed)

    shared = dict(batch_size=batch_size, num_workers=2)
    training = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **shared)
    testing = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, **shared)
    return training, testing

In [51]:
def benchmark():
    """Train the three model variants for every configured seed and depth.

    For each seed and each number of layers, three experiments are run in
    order: ``MlpMixer`` on 36x36 inputs (index 0), ``MlpMixer`` on 32x32 inputs
    (index 1) and ``PatchMlpMixer`` (index 2). Each experiment trains for 200
    epochs with Adam and a cosine learning-rate schedule, calling ``train`` and
    ``test`` once per epoch.

    ``train`` and ``test`` read their state from notebook globals, so this
    function rebinds ``model``, ``optimizer``, ``train_loader``,
    ``test_loader``, ``model_name``, ``criterion``, ``STAT`` and ``best_acc``.

    Returns:
        0 once all experiments have finished.
    """
    global model, optimizer, train_loader, test_loader, model_name, criterion, STAT, best_acc

    n_epochs = 200
    learning_rate = 0.001
    dataset_tag = 'c100'
#     dataset_tag = 'c10'
    criterion = nn.CrossEntropyLoss()
    num_cls = 100 if dataset_tag == 'c100' else 10

    for seed in [369]:
        for num_layers in [7, 10]:
            for i in range(3): ## 3 models training
                print(f"Experiment index: {i}")
                train_loader, test_loader = get_data_loaders(seed, dataset_tag)

                # Pick the architecture for this experiment index. A `continue`
                # can be placed in a branch to skip a finished configuration.
                if i == 0:
                    ### FOR ORIGINAL MIXER V0
                    model = MlpMixer((3, 4*9, 4*9), (4, 4), hidden_expansion=3.0,
                                     num_blocks=num_layers, num_classes=num_cls)
                    model_name = f'original_mixer0_l{num_layers}_{dataset_tag}_s{seed}'
                elif i == 1:
                    ### FOR ORIGINAL MIXER V1
                    model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=3.2,
                                     num_blocks=num_layers, num_classes=num_cls)
                    model_name = f'original_mixer1_l{num_layers}_{dataset_tag}_s{seed}'
                elif i == 2:
                    model = PatchMlpMixer((3, 35, 35), patch_sizes=[5,7], hidden_channels=3,
                                          num_blocks=num_layers, num_classes=num_cls)
                    model_name = f'patchonly_mixer0_l{num_layers}_{dataset_tag}_s{seed}'
                else:
                    print("JPT........!!!!")

                model = model.to(device)
                optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

                num_params = sum(p.numel() for p in model.parameters())

                # Prefix used in the checkpoint and statistics file names.
                model_name = "03.0_" + model_name
                print(f"EXPERIMENTING FOR : {model_name} | params: {num_params}  .......\n.......")

                # Fresh statistics and best-accuracy tracker for this run.
                STAT = {'train_stat':[], 'test_stat':[], 'num_params':num_params}
                best_acc = -1
                for epoch in range(n_epochs): ## for 200 epochs
                    train(epoch)
                    test(epoch)
                    scheduler.step()
                print("Training finished\n")
    return 0

In [None]:
benchmark()

Experiment index: 0
MLP Mixer : Channes per patch -> Initial:48 Final:144
EXPERIMENTING FOR : 03.0_original_mixer0_l7_c100_s369 | params: 1946629  .......
.......


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.19it/s]


[Train] 0 Loss: 3.658 | Acc: 16.148 8074/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.64it/s]


[Test] 0 Loss: 3.103 | Acc: 25.030 2503/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.96it/s]


[Train] 1 Loss: 3.023 | Acc: 26.358 13179/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.07it/s]


[Test] 1 Loss: 2.806 | Acc: 31.160 3116/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.66it/s]


[Train] 2 Loss: 2.744 | Acc: 31.588 15794/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.41it/s]


[Test] 2 Loss: 2.569 | Acc: 35.420 3542/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.68it/s]


[Train] 3 Loss: 2.539 | Acc: 35.934 17967/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 68.36it/s]


[Test] 3 Loss: 2.425 | Acc: 38.790 3879/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.75it/s]


[Train] 4 Loss: 2.383 | Acc: 39.342 19671/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.08it/s]


[Test] 4 Loss: 2.310 | Acc: 41.750 4175/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.62it/s]


[Train] 5 Loss: 2.234 | Acc: 42.852 21426/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.50it/s]


[Test] 5 Loss: 2.249 | Acc: 43.180 4318/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.56it/s]


[Train] 6 Loss: 2.098 | Acc: 45.506 22753/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.77it/s]


[Test] 6 Loss: 2.201 | Acc: 44.660 4466/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.50it/s]


[Train] 7 Loss: 1.989 | Acc: 47.862 23931/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.64it/s]


[Test] 7 Loss: 2.185 | Acc: 46.100 4610/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.35it/s]


[Train] 8 Loss: 1.874 | Acc: 50.466 25233/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.99it/s]


[Test] 8 Loss: 2.158 | Acc: 46.830 4683/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.34it/s]


[Train] 9 Loss: 1.772 | Acc: 52.770 26385/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.54it/s]


[Test] 9 Loss: 2.081 | Acc: 48.760 4876/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.40it/s]


[Train] 10 Loss: 1.668 | Acc: 55.178 27589/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.83it/s]


[Test] 10 Loss: 2.080 | Acc: 49.420 4942/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.39it/s]


[Train] 11 Loss: 1.573 | Acc: 57.478 28739/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.46it/s]


[Test] 11 Loss: 2.085 | Acc: 50.110 5011/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.45it/s]


[Train] 12 Loss: 1.488 | Acc: 59.444 29722/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.86it/s]


[Test] 12 Loss: 2.122 | Acc: 50.440 5044/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.47it/s]


[Train] 13 Loss: 1.407 | Acc: 61.320 30660/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.81it/s]


[Test] 13 Loss: 2.132 | Acc: 50.120 5012/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.37it/s]


[Train] 14 Loss: 1.316 | Acc: 63.606 31803/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.36it/s]


[Test] 14 Loss: 2.192 | Acc: 51.200 5120/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.41it/s]


[Train] 15 Loss: 1.233 | Acc: 65.586 32793/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.34it/s]


[Test] 15 Loss: 2.247 | Acc: 51.100 5110/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.49it/s]


[Train] 16 Loss: 1.164 | Acc: 67.290 33645/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.15it/s]


[Test] 16 Loss: 2.357 | Acc: 51.220 5122/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.45it/s]


[Train] 17 Loss: 1.090 | Acc: 69.422 34711/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.54it/s]


[Test] 17 Loss: 2.323 | Acc: 52.340 5234/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.42it/s]


[Train] 18 Loss: 1.015 | Acc: 71.126 35563/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.13it/s]


[Test] 18 Loss: 2.477 | Acc: 50.730 5073/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.39it/s]


[Train] 19 Loss: 0.949 | Acc: 72.732 36366/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.11it/s]


[Test] 19 Loss: 2.537 | Acc: 50.510 5051/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.29it/s]


[Train] 20 Loss: 0.876 | Acc: 74.720 37360/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.36it/s]


[Test] 20 Loss: 2.589 | Acc: 51.770 5177/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.28it/s]


[Train] 21 Loss: 0.825 | Acc: 76.242 38121/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.53it/s]


[Test] 21 Loss: 2.773 | Acc: 50.950 5095/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.21it/s]


[Train] 22 Loss: 0.778 | Acc: 77.588 38794/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.19it/s]


[Test] 22 Loss: 2.878 | Acc: 50.980 5098/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.08it/s]


[Train] 23 Loss: 0.731 | Acc: 79.046 39523/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.60it/s]


[Test] 23 Loss: 2.872 | Acc: 50.980 5098/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.14it/s]


[Train] 24 Loss: 0.688 | Acc: 80.038 40019/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.45it/s]


[Test] 24 Loss: 3.066 | Acc: 50.360 5036/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.28it/s]


[Train] 25 Loss: 0.649 | Acc: 81.322 40661/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.31it/s]


[Test] 25 Loss: 3.178 | Acc: 51.420 5142/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.28it/s]


[Train] 26 Loss: 0.614 | Acc: 82.282 41141/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.05it/s]


[Test] 26 Loss: 3.232 | Acc: 51.510 5151/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.33it/s]


[Train] 27 Loss: 0.589 | Acc: 83.286 41643/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.07it/s]


[Test] 27 Loss: 3.363 | Acc: 51.330 5133/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.29it/s]


[Train] 28 Loss: 0.578 | Acc: 83.726 41863/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.08it/s]


[Test] 28 Loss: 3.369 | Acc: 52.040 5204/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.33it/s]


[Train] 29 Loss: 0.545 | Acc: 84.656 42328/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.45it/s]


[Test] 29 Loss: 3.549 | Acc: 51.000 5100/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.19it/s]


[Train] 30 Loss: 0.505 | Acc: 85.600 42800/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.38it/s]


[Test] 30 Loss: 3.575 | Acc: 51.670 5167/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.23it/s]


[Train] 31 Loss: 0.506 | Acc: 85.990 42995/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.44it/s]


[Test] 31 Loss: 3.777 | Acc: 52.000 5200/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.15it/s]


[Train] 32 Loss: 0.486 | Acc: 86.512 43256/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.09it/s]


[Test] 32 Loss: 3.819 | Acc: 51.750 5175/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.30it/s]


[Train] 33 Loss: 0.469 | Acc: 87.078 43539/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.98it/s]


[Test] 33 Loss: 3.877 | Acc: 52.790 5279/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.23it/s]


[Train] 34 Loss: 0.431 | Acc: 88.198 44099/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.09it/s]


[Test] 34 Loss: 3.946 | Acc: 52.150 5215/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.28it/s]


[Train] 35 Loss: 0.433 | Acc: 88.016 44008/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.98it/s]


[Test] 35 Loss: 4.208 | Acc: 51.590 5159/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.22it/s]


[Train] 36 Loss: 0.442 | Acc: 88.214 44107/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.97it/s]


[Test] 36 Loss: 4.270 | Acc: 51.630 5163/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.28it/s]


[Train] 37 Loss: 0.415 | Acc: 88.964 44482/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.98it/s]


[Test] 37 Loss: 4.217 | Acc: 52.670 5267/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.28it/s]


[Train] 38 Loss: 0.397 | Acc: 89.374 44687/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.87it/s]


[Test] 38 Loss: 4.421 | Acc: 52.630 5263/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.06it/s]


[Train] 39 Loss: 0.390 | Acc: 89.628 44814/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.52it/s]


[Test] 39 Loss: 4.475 | Acc: 52.590 5259/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.14it/s]


[Train] 40 Loss: 0.378 | Acc: 90.012 45006/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.78it/s]


[Test] 40 Loss: 4.513 | Acc: 52.110 5211/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.20it/s]


[Train] 41 Loss: 0.356 | Acc: 90.674 45337/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.71it/s]


[Test] 41 Loss: 4.484 | Acc: 53.170 5317/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.05it/s]


[Train] 42 Loss: 0.373 | Acc: 90.202 45101/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.95it/s]


[Test] 42 Loss: 4.796 | Acc: 52.360 5236/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.16it/s]


[Train] 43 Loss: 0.341 | Acc: 91.174 45587/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.78it/s]


[Test] 43 Loss: 4.822 | Acc: 52.930 5293/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.88it/s]


[Train] 44 Loss: 0.352 | Acc: 90.968 45484/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.96it/s]


[Test] 44 Loss: 4.974 | Acc: 52.320 5232/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.06it/s]


[Train] 45 Loss: 0.346 | Acc: 91.268 45634/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.48it/s]


[Test] 45 Loss: 5.029 | Acc: 52.070 5207/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.08it/s]


[Train] 46 Loss: 0.329 | Acc: 91.720 45860/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.18it/s]


[Test] 46 Loss: 5.157 | Acc: 52.340 5234/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.27it/s]


[Train] 47 Loss: 0.331 | Acc: 91.738 45869/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.97it/s]


[Test] 47 Loss: 5.146 | Acc: 52.780 5278/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.19it/s]


[Train] 48 Loss: 0.317 | Acc: 92.122 46061/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.48it/s]


[Test] 48 Loss: 5.319 | Acc: 53.030 5303/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.05it/s]


[Train] 49 Loss: 0.304 | Acc: 92.444 46222/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.93it/s]


[Test] 49 Loss: 5.406 | Acc: 52.030 5203/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.01it/s]


[Train] 50 Loss: 0.309 | Acc: 92.272 46136/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.98it/s]


[Test] 50 Loss: 5.440 | Acc: 52.000 5200/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.03it/s]


[Train] 51 Loss: 0.290 | Acc: 92.828 46414/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.68it/s]


[Test] 51 Loss: 5.514 | Acc: 52.710 5271/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.00it/s]


[Train] 52 Loss: 0.288 | Acc: 93.050 46525/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.44it/s]


[Test] 52 Loss: 5.511 | Acc: 52.890 5289/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.01it/s]


[Train] 53 Loss: 0.261 | Acc: 93.570 46785/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.52it/s]


[Test] 53 Loss: 5.700 | Acc: 53.120 5312/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.97it/s]


[Train] 54 Loss: 0.265 | Acc: 93.596 46798/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.16it/s]


[Test] 54 Loss: 5.816 | Acc: 51.940 5194/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.10it/s]


[Train] 55 Loss: 0.288 | Acc: 93.206 46603/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.95it/s]


[Test] 55 Loss: 5.725 | Acc: 52.170 5217/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.87it/s]


[Train] 56 Loss: 0.268 | Acc: 93.738 46869/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.94it/s]


[Test] 56 Loss: 5.811 | Acc: 52.790 5279/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.85it/s]


[Train] 57 Loss: 0.261 | Acc: 93.890 46945/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.39it/s]


[Test] 57 Loss: 6.033 | Acc: 51.640 5164/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.90it/s]


[Train] 58 Loss: 0.248 | Acc: 94.008 47004/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.35it/s]


[Test] 58 Loss: 6.142 | Acc: 52.070 5207/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.97it/s]


[Train] 59 Loss: 0.254 | Acc: 94.072 47036/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.59it/s]


[Test] 59 Loss: 6.023 | Acc: 52.660 5266/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.08it/s]


[Train] 60 Loss: 0.244 | Acc: 94.334 47167/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.28it/s]


[Test] 60 Loss: 6.186 | Acc: 51.790 5179/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.81it/s]


[Train] 61 Loss: 0.234 | Acc: 94.584 47292/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.79it/s]


[Test] 61 Loss: 6.240 | Acc: 52.940 5294/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.75it/s]


[Train] 62 Loss: 0.232 | Acc: 94.640 47320/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.44it/s]


[Test] 62 Loss: 6.228 | Acc: 52.420 5242/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.97it/s]


[Train] 63 Loss: 0.233 | Acc: 94.718 47359/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.36it/s]


[Test] 63 Loss: 6.347 | Acc: 52.770 5277/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.00it/s]


[Train] 64 Loss: 0.223 | Acc: 94.998 47499/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.97it/s]


[Test] 64 Loss: 6.281 | Acc: 53.940 5394/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.11it/s]


[Train] 65 Loss: 0.213 | Acc: 95.004 47502/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.47it/s]


[Test] 65 Loss: 6.503 | Acc: 52.190 5219/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.05it/s]


[Train] 66 Loss: 0.213 | Acc: 95.134 47567/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.47it/s]


[Test] 66 Loss: 6.640 | Acc: 52.390 5239/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.09it/s]


[Train] 67 Loss: 0.210 | Acc: 95.250 47625/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.44it/s]


[Test] 67 Loss: 6.681 | Acc: 53.060 5306/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.10it/s]


[Train] 68 Loss: 0.211 | Acc: 95.158 47579/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.77it/s]


[Test] 68 Loss: 6.653 | Acc: 52.670 5267/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.94it/s]


[Train] 69 Loss: 0.200 | Acc: 95.614 47807/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.84it/s]


[Test] 69 Loss: 6.818 | Acc: 52.580 5258/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.91it/s]


[Train] 70 Loss: 0.180 | Acc: 95.938 47969/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 63.09it/s]


[Test] 70 Loss: 6.720 | Acc: 52.990 5299/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.92it/s]


[Train] 71 Loss: 0.196 | Acc: 95.614 47807/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 60.90it/s]


[Test] 71 Loss: 6.919 | Acc: 52.640 5264/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.03it/s]


[Train] 72 Loss: 0.192 | Acc: 95.748 47874/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.15it/s]


[Test] 72 Loss: 6.965 | Acc: 52.430 5243/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.10it/s]


[Train] 73 Loss: 0.185 | Acc: 95.836 47918/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.57it/s]


[Test] 73 Loss: 6.898 | Acc: 53.150 5315/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.54it/s]


[Train] 74 Loss: 0.169 | Acc: 96.234 48117/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.57it/s]


[Test] 74 Loss: 6.922 | Acc: 53.840 5384/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.13it/s]


[Train] 75 Loss: 0.167 | Acc: 96.264 48132/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.39it/s]


[Test] 75 Loss: 6.953 | Acc: 53.020 5302/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.04it/s]


[Train] 76 Loss: 0.167 | Acc: 96.240 48120/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.66it/s]


[Test] 76 Loss: 7.242 | Acc: 53.080 5308/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.09it/s]


[Train] 77 Loss: 0.163 | Acc: 96.378 48189/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.49it/s]


[Test] 77 Loss: 7.260 | Acc: 53.080 5308/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.20it/s]


[Train] 78 Loss: 0.160 | Acc: 96.466 48233/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.44it/s]


[Test] 78 Loss: 7.187 | Acc: 52.900 5290/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.02it/s]


[Train] 79 Loss: 0.150 | Acc: 96.768 48384/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.21it/s]


[Test] 79 Loss: 7.362 | Acc: 53.560 5356/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.95it/s]


[Train] 80 Loss: 0.157 | Acc: 96.612 48306/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.85it/s]


[Test] 80 Loss: 7.405 | Acc: 53.350 5335/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.94it/s]


[Train] 81 Loss: 0.149 | Acc: 96.766 48383/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.32it/s]


[Test] 81 Loss: 7.310 | Acc: 53.090 5309/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.01it/s]


[Train] 82 Loss: 0.140 | Acc: 96.996 48498/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.02it/s]


[Test] 82 Loss: 7.404 | Acc: 53.470 5347/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.26it/s]


[Train] 83 Loss: 0.134 | Acc: 97.010 48505/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.53it/s]


[Test] 83 Loss: 7.524 | Acc: 53.520 5352/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.07it/s]


[Train] 84 Loss: 0.131 | Acc: 97.072 48536/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.49it/s]


[Test] 84 Loss: 7.493 | Acc: 53.710 5371/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.03it/s]


[Train] 85 Loss: 0.135 | Acc: 97.144 48572/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.86it/s]


[Test] 85 Loss: 7.621 | Acc: 53.990 5399/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.82it/s]


[Train] 86 Loss: 0.135 | Acc: 97.106 48553/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.54it/s]


[Test] 86 Loss: 7.601 | Acc: 54.320 5432/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.02it/s]


[Train] 87 Loss: 0.122 | Acc: 97.412 48706/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.44it/s]


[Test] 87 Loss: 7.666 | Acc: 53.770 5377/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.67it/s]


[Train] 88 Loss: 0.116 | Acc: 97.582 48791/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.17it/s]


[Test] 88 Loss: 7.674 | Acc: 53.760 5376/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.96it/s]


[Train] 89 Loss: 0.132 | Acc: 97.324 48662/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.76it/s]


[Test] 89 Loss: 7.600 | Acc: 53.670 5367/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.91it/s]


[Train] 90 Loss: 0.118 | Acc: 97.466 48733/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.41it/s]


[Test] 90 Loss: 7.620 | Acc: 54.710 5471/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.99it/s]


[Train] 91 Loss: 0.117 | Acc: 97.548 48774/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.03it/s]


[Test] 91 Loss: 7.793 | Acc: 53.860 5386/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.04it/s]


[Train] 92 Loss: 0.121 | Acc: 97.520 48760/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.89it/s]


[Test] 92 Loss: 7.526 | Acc: 53.750 5375/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.96it/s]


[Train] 93 Loss: 0.096 | Acc: 97.840 48920/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 67.29it/s]


[Test] 93 Loss: 7.605 | Acc: 54.020 5402/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.87it/s]


[Train] 94 Loss: 0.104 | Acc: 97.804 48902/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.86it/s]


[Test] 94 Loss: 7.718 | Acc: 53.760 5376/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.86it/s]


[Train] 95 Loss: 0.105 | Acc: 97.740 48870/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.45it/s]


[Test] 95 Loss: 7.769 | Acc: 54.110 5411/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 31.12it/s]


[Train] 96 Loss: 0.096 | Acc: 97.912 48956/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.82it/s]


[Test] 96 Loss: 7.852 | Acc: 54.680 5468/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.99it/s]


[Train] 97 Loss: 0.099 | Acc: 97.992 48996/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 65.22it/s]


[Test] 97 Loss: 7.964 | Acc: 54.260 5426/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.78it/s]


[Train] 98 Loss: 0.088 | Acc: 98.122 49061/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 64.77it/s]


[Test] 98 Loss: 7.853 | Acc: 53.960 5396/10000


100%|███████████████████████████████████████████████████| 391/391 [00:12<00:00, 30.27it/s]


[Train] 99 Loss: 0.093 | Acc: 98.082 49041/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 66.58it/s]


[Test] 99 Loss: 8.113 | Acc: 54.000 5400/10000


100%|███████████████████████████████████████████████████| 391/391 [00:22<00:00, 17.22it/s]


[Train] 100 Loss: 0.094 | Acc: 98.004 49002/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.49it/s]


[Test] 100 Loss: 7.850 | Acc: 54.030 5403/10000


100%|███████████████████████████████████████████████████| 391/391 [00:24<00:00, 16.23it/s]


[Train] 101 Loss: 0.090 | Acc: 98.130 49065/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 40.20it/s]


[Test] 101 Loss: 8.059 | Acc: 54.430 5443/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.57it/s]


[Train] 102 Loss: 0.089 | Acc: 98.166 49083/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 43.12it/s]


[Test] 102 Loss: 7.930 | Acc: 53.690 5369/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.63it/s]


[Train] 103 Loss: 0.081 | Acc: 98.268 49134/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 37.72it/s]


[Test] 103 Loss: 7.902 | Acc: 54.010 5401/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.44it/s]


[Train] 104 Loss: 0.084 | Acc: 98.246 49123/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.41it/s]


[Test] 104 Loss: 7.741 | Acc: 54.650 5465/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.41it/s]


[Train] 105 Loss: 0.073 | Acc: 98.406 49203/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 39.50it/s]


[Test] 105 Loss: 8.037 | Acc: 54.110 5411/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.54it/s]


[Train] 106 Loss: 0.068 | Acc: 98.644 49322/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 37.46it/s]


[Test] 106 Loss: 7.826 | Acc: 54.230 5423/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.26it/s]


[Train] 107 Loss: 0.069 | Acc: 98.638 49319/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 45.66it/s]


[Test] 107 Loss: 8.024 | Acc: 53.720 5372/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.33it/s]


[Train] 108 Loss: 0.072 | Acc: 98.498 49249/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 41.37it/s]


[Test] 108 Loss: 8.070 | Acc: 54.360 5436/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.26it/s]


[Train] 109 Loss: 0.068 | Acc: 98.532 49266/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 35.44it/s]


[Test] 109 Loss: 8.075 | Acc: 54.630 5463/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.44it/s]


[Train] 110 Loss: 0.063 | Acc: 98.718 49359/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 35.97it/s]


[Test] 110 Loss: 8.148 | Acc: 53.960 5396/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.46it/s]


[Train] 111 Loss: 0.063 | Acc: 98.620 49310/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 37.65it/s]


[Test] 111 Loss: 8.120 | Acc: 54.890 5489/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.62it/s]


[Train] 112 Loss: 0.058 | Acc: 98.752 49376/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.13it/s]


[Test] 112 Loss: 8.311 | Acc: 54.480 5448/10000


100%|███████████████████████████████████████████████████| 391/391 [00:24<00:00, 15.65it/s]


[Train] 113 Loss: 0.060 | Acc: 98.860 49430/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 37.84it/s]


[Test] 113 Loss: 8.186 | Acc: 54.680 5468/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.10it/s]


[Train] 114 Loss: 0.052 | Acc: 98.858 49429/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 44.23it/s]


[Test] 114 Loss: 8.242 | Acc: 54.420 5442/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.39it/s]


[Train] 115 Loss: 0.055 | Acc: 98.882 49441/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.40it/s]


[Test] 115 Loss: 8.252 | Acc: 54.890 5489/10000


100%|███████████████████████████████████████████████████| 391/391 [00:20<00:00, 19.49it/s]


[Train] 116 Loss: 0.054 | Acc: 98.894 49447/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 53.93it/s]


[Test] 116 Loss: 8.137 | Acc: 54.920 5492/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:19<00:00, 19.89it/s]


[Train] 117 Loss: 0.054 | Acc: 98.828 49414/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 45.28it/s]


[Test] 117 Loss: 8.287 | Acc: 54.010 5401/10000


100%|███████████████████████████████████████████████████| 391/391 [00:19<00:00, 20.46it/s]


[Train] 118 Loss: 0.054 | Acc: 98.976 49488/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 45.46it/s]


[Test] 118 Loss: 8.076 | Acc: 54.600 5460/10000


100%|███████████████████████████████████████████████████| 391/391 [00:19<00:00, 19.81it/s]


[Train] 119 Loss: 0.051 | Acc: 98.986 49493/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.49it/s]


[Test] 119 Loss: 8.120 | Acc: 54.910 5491/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.58it/s]


[Train] 120 Loss: 0.037 | Acc: 99.112 49556/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.89it/s]


[Test] 120 Loss: 8.097 | Acc: 54.150 5415/10000


100%|███████████████████████████████████████████████████| 391/391 [00:24<00:00, 15.69it/s]


[Train] 121 Loss: 0.045 | Acc: 99.026 49513/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 41.59it/s]


[Test] 121 Loss: 8.092 | Acc: 54.880 5488/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.35it/s]


[Train] 122 Loss: 0.038 | Acc: 99.200 49600/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.00it/s]


[Test] 122 Loss: 8.049 | Acc: 55.250 5525/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:20<00:00, 19.05it/s]


[Train] 123 Loss: 0.040 | Acc: 99.184 49592/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 42.61it/s]


[Test] 123 Loss: 7.839 | Acc: 55.110 5511/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.33it/s]


[Train] 124 Loss: 0.035 | Acc: 99.240 49620/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 43.14it/s]


[Test] 124 Loss: 8.145 | Acc: 55.230 5523/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.59it/s]


[Train] 125 Loss: 0.034 | Acc: 99.302 49651/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 42.08it/s]


[Test] 125 Loss: 8.100 | Acc: 54.590 5459/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.21it/s]


[Train] 126 Loss: 0.037 | Acc: 99.212 49606/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 45.24it/s]


[Test] 126 Loss: 8.238 | Acc: 54.360 5436/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.29it/s]


[Train] 127 Loss: 0.034 | Acc: 99.340 49670/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 41.14it/s]


[Test] 127 Loss: 8.153 | Acc: 54.760 5476/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.38it/s]


[Train] 128 Loss: 0.034 | Acc: 99.292 49646/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 37.51it/s]


[Test] 128 Loss: 8.088 | Acc: 55.160 5516/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.38it/s]


[Train] 129 Loss: 0.034 | Acc: 99.316 49658/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 41.49it/s]


[Test] 129 Loss: 8.201 | Acc: 54.550 5455/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.52it/s]


[Train] 130 Loss: 0.035 | Acc: 99.266 49633/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.59it/s]


[Test] 130 Loss: 8.215 | Acc: 54.750 5475/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.44it/s]


[Train] 131 Loss: 0.033 | Acc: 99.384 49692/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 42.58it/s]


[Test] 131 Loss: 8.074 | Acc: 54.890 5489/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.48it/s]


[Train] 132 Loss: 0.025 | Acc: 99.400 49700/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 42.13it/s]


[Test] 132 Loss: 8.101 | Acc: 54.810 5481/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.19it/s]


[Train] 133 Loss: 0.032 | Acc: 99.374 49687/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 40.69it/s]


[Test] 133 Loss: 7.981 | Acc: 54.920 5492/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.47it/s]


[Train] 134 Loss: 0.028 | Acc: 99.452 49726/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 32.29it/s]


[Test] 134 Loss: 8.034 | Acc: 54.960 5496/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.53it/s]


[Train] 135 Loss: 0.022 | Acc: 99.560 49780/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 39.27it/s]


[Test] 135 Loss: 8.137 | Acc: 54.790 5479/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.24it/s]


[Train] 136 Loss: 0.026 | Acc: 99.596 49798/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.86it/s]


[Test] 136 Loss: 7.889 | Acc: 54.980 5498/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.36it/s]


[Train] 137 Loss: 0.021 | Acc: 99.504 49752/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 37.84it/s]


[Test] 137 Loss: 7.943 | Acc: 55.110 5511/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.60it/s]


[Train] 138 Loss: 0.020 | Acc: 99.576 49788/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 41.61it/s]


[Test] 138 Loss: 8.135 | Acc: 54.840 5484/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.45it/s]


[Train] 139 Loss: 0.021 | Acc: 99.600 49800/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 43.16it/s]


[Test] 139 Loss: 7.919 | Acc: 55.420 5542/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.45it/s]


[Train] 140 Loss: 0.018 | Acc: 99.622 49811/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 39.70it/s]


[Test] 140 Loss: 7.950 | Acc: 55.240 5524/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.54it/s]


[Train] 141 Loss: 0.019 | Acc: 99.614 49807/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 43.76it/s]


[Test] 141 Loss: 7.965 | Acc: 55.260 5526/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.47it/s]


[Train] 142 Loss: 0.019 | Acc: 99.614 49807/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 43.47it/s]


[Test] 142 Loss: 7.917 | Acc: 55.500 5550/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.40it/s]


[Train] 143 Loss: 0.017 | Acc: 99.644 49822/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 37.46it/s]


[Test] 143 Loss: 7.788 | Acc: 55.540 5554/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:24<00:00, 15.65it/s]


[Train] 144 Loss: 0.019 | Acc: 99.652 49826/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 39.58it/s]


[Test] 144 Loss: 7.873 | Acc: 55.170 5517/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.44it/s]


[Train] 145 Loss: 0.016 | Acc: 99.712 49856/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.86it/s]


[Test] 145 Loss: 7.889 | Acc: 55.500 5550/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.48it/s]


[Train] 146 Loss: 0.016 | Acc: 99.696 49848/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 37.78it/s]


[Test] 146 Loss: 7.876 | Acc: 55.590 5559/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.42it/s]


[Train] 147 Loss: 0.015 | Acc: 99.750 49875/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.33it/s]


[Test] 147 Loss: 7.856 | Acc: 55.570 5557/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.19it/s]


[Train] 148 Loss: 0.016 | Acc: 99.730 49865/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 41.27it/s]


[Test] 148 Loss: 7.800 | Acc: 55.600 5560/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.37it/s]


[Train] 149 Loss: 0.016 | Acc: 99.748 49874/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 37.37it/s]


[Test] 149 Loss: 7.664 | Acc: 55.700 5570/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.39it/s]


[Train] 150 Loss: 0.014 | Acc: 99.768 49884/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 39.25it/s]


[Test] 150 Loss: 7.620 | Acc: 55.820 5582/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.35it/s]


[Train] 151 Loss: 0.014 | Acc: 99.788 49894/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 39.19it/s]


[Test] 151 Loss: 7.614 | Acc: 55.870 5587/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.49it/s]


[Train] 152 Loss: 0.011 | Acc: 99.772 49886/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.19it/s]


[Test] 152 Loss: 7.669 | Acc: 56.250 5625/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.34it/s]


[Train] 153 Loss: 0.010 | Acc: 99.804 49902/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 44.03it/s]


[Test] 153 Loss: 7.859 | Acc: 55.520 5552/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.06it/s]


[Train] 154 Loss: 0.010 | Acc: 99.806 49903/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 42.20it/s]


[Test] 154 Loss: 7.729 | Acc: 55.460 5546/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.30it/s]


[Train] 155 Loss: 0.013 | Acc: 99.806 49903/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 40.07it/s]


[Test] 155 Loss: 7.676 | Acc: 55.370 5537/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.32it/s]


[Train] 156 Loss: 0.012 | Acc: 99.800 49900/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 39.24it/s]


[Test] 156 Loss: 7.575 | Acc: 55.610 5561/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.30it/s]


[Train] 157 Loss: 0.009 | Acc: 99.832 49916/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 39.95it/s]


[Test] 157 Loss: 7.591 | Acc: 55.940 5594/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.33it/s]


[Train] 158 Loss: 0.010 | Acc: 99.796 49898/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.29it/s]


[Test] 158 Loss: 7.596 | Acc: 56.090 5609/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.54it/s]


[Train] 159 Loss: 0.010 | Acc: 99.840 49920/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 37.09it/s]


[Test] 159 Loss: 7.578 | Acc: 56.130 5613/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.61it/s]


[Train] 160 Loss: 0.008 | Acc: 99.868 49934/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 42.41it/s]


[Test] 160 Loss: 7.509 | Acc: 56.280 5628/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.49it/s]


[Train] 161 Loss: 0.010 | Acc: 99.844 49922/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.33it/s]


[Test] 161 Loss: 7.499 | Acc: 55.990 5599/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.54it/s]


[Train] 162 Loss: 0.009 | Acc: 99.860 49930/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 37.84it/s]


[Test] 162 Loss: 7.436 | Acc: 55.890 5589/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.61it/s]


[Train] 163 Loss: 0.007 | Acc: 99.878 49939/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.35it/s]


[Test] 163 Loss: 7.442 | Acc: 56.080 5608/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.44it/s]


[Train] 164 Loss: 0.008 | Acc: 99.874 49937/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 41.03it/s]


[Test] 164 Loss: 7.420 | Acc: 55.700 5570/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.23it/s]


[Train] 165 Loss: 0.006 | Acc: 99.882 49941/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 45.11it/s]


[Test] 165 Loss: 7.426 | Acc: 55.970 5597/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.31it/s]


[Train] 166 Loss: 0.006 | Acc: 99.896 49948/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 39.30it/s]


[Test] 166 Loss: 7.420 | Acc: 56.060 5606/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.58it/s]


[Train] 167 Loss: 0.005 | Acc: 99.890 49945/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 37.38it/s]


[Test] 167 Loss: 7.405 | Acc: 55.920 5592/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.38it/s]


[Train] 168 Loss: 0.006 | Acc: 99.882 49941/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 43.82it/s]


[Test] 168 Loss: 7.362 | Acc: 56.350 5635/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.21it/s]


[Train] 169 Loss: 0.006 | Acc: 99.908 49954/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 40.00it/s]


[Test] 169 Loss: 7.339 | Acc: 56.280 5628/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.40it/s]


[Train] 170 Loss: 0.006 | Acc: 99.906 49953/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 37.12it/s]


[Test] 170 Loss: 7.320 | Acc: 56.110 5611/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.25it/s]


[Train] 171 Loss: 0.005 | Acc: 99.932 49966/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 39.95it/s]


[Test] 171 Loss: 7.312 | Acc: 56.310 5631/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.34it/s]


[Train] 172 Loss: 0.004 | Acc: 99.920 49960/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 35.58it/s]


[Test] 172 Loss: 7.241 | Acc: 56.390 5639/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.52it/s]


[Train] 173 Loss: 0.005 | Acc: 99.922 49961/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 32.52it/s]


[Test] 173 Loss: 7.258 | Acc: 56.230 5623/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.56it/s]


[Train] 174 Loss: 0.005 | Acc: 99.920 49960/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 40.67it/s]


[Test] 174 Loss: 7.247 | Acc: 56.300 5630/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.24it/s]


[Train] 175 Loss: 0.004 | Acc: 99.920 49960/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.17it/s]


[Test] 175 Loss: 7.192 | Acc: 56.250 5625/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.27it/s]


[Train] 176 Loss: 0.004 | Acc: 99.928 49964/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 41.63it/s]


[Test] 176 Loss: 7.191 | Acc: 56.210 5621/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.33it/s]


[Train] 177 Loss: 0.004 | Acc: 99.930 49965/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 40.91it/s]


[Test] 177 Loss: 7.200 | Acc: 56.340 5634/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.50it/s]


[Train] 178 Loss: 0.005 | Acc: 99.912 49956/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 44.94it/s]


[Test] 178 Loss: 7.244 | Acc: 56.290 5629/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.21it/s]


[Train] 179 Loss: 0.003 | Acc: 99.946 49973/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 43.50it/s]


[Test] 179 Loss: 7.225 | Acc: 56.460 5646/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.26it/s]


[Train] 180 Loss: 0.003 | Acc: 99.942 49971/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 35.70it/s]


[Test] 180 Loss: 7.234 | Acc: 56.350 5635/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.18it/s]


[Train] 181 Loss: 0.003 | Acc: 99.954 49977/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.73it/s]


[Test] 181 Loss: 7.223 | Acc: 56.310 5631/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.18it/s]


[Train] 182 Loss: 0.003 | Acc: 99.946 49973/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 39.41it/s]


[Test] 182 Loss: 7.230 | Acc: 56.330 5633/10000


100%|███████████████████████████████████████████████████| 391/391 [00:24<00:00, 15.68it/s]


[Train] 183 Loss: 0.003 | Acc: 99.942 49971/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 33.36it/s]


[Test] 183 Loss: 7.219 | Acc: 56.270 5627/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.52it/s]


[Train] 184 Loss: 0.003 | Acc: 99.944 49972/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.11it/s]


[Test] 184 Loss: 7.223 | Acc: 56.380 5638/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.59it/s]


[Train] 185 Loss: 0.003 | Acc: 99.944 49972/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.47it/s]


[Test] 185 Loss: 7.221 | Acc: 56.320 5632/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.56it/s]


[Train] 186 Loss: 0.002 | Acc: 99.946 49973/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 43.11it/s]


[Test] 186 Loss: 7.230 | Acc: 56.410 5641/10000


100%|███████████████████████████████████████████████████| 391/391 [00:25<00:00, 15.31it/s]


[Train] 187 Loss: 0.002 | Acc: 99.958 49979/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.18it/s]


[Test] 187 Loss: 7.218 | Acc: 56.350 5635/10000


100%|███████████████████████████████████████████████████| 391/391 [00:24<00:00, 15.68it/s]


[Train] 188 Loss: 0.002 | Acc: 99.966 49983/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 35.74it/s]


[Test] 188 Loss: 7.222 | Acc: 56.330 5633/10000


100%|███████████████████████████████████████████████████| 391/391 [00:19<00:00, 20.22it/s]


[Train] 189 Loss: 0.002 | Acc: 99.958 49979/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 57.94it/s]


[Test] 189 Loss: 7.217 | Acc: 56.370 5637/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 26.06it/s]


[Train] 190 Loss: 0.002 | Acc: 99.946 49973/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 58.03it/s]


[Test] 190 Loss: 7.208 | Acc: 56.440 5644/10000


100%|███████████████████████████████████████████████████| 391/391 [00:14<00:00, 26.08it/s]


[Train] 191 Loss: 0.003 | Acc: 99.938 49969/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 59.05it/s]


[Test] 191 Loss: 7.212 | Acc: 56.440 5644/10000


100%|███████████████████████████████████████████████████| 391/391 [00:14<00:00, 26.25it/s]


[Train] 192 Loss: 0.002 | Acc: 99.940 49970/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 58.05it/s]


[Test] 192 Loss: 7.212 | Acc: 56.500 5650/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 25.79it/s]


[Train] 193 Loss: 0.002 | Acc: 99.942 49971/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 57.31it/s]


[Test] 193 Loss: 7.210 | Acc: 56.490 5649/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 25.86it/s]


[Train] 194 Loss: 0.002 | Acc: 99.950 49975/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 58.73it/s]


[Test] 194 Loss: 7.208 | Acc: 56.510 5651/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 25.80it/s]


[Train] 195 Loss: 0.001 | Acc: 99.958 49979/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 57.87it/s]


[Test] 195 Loss: 7.208 | Acc: 56.500 5650/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 25.59it/s]


[Train] 196 Loss: 0.002 | Acc: 99.948 49974/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 57.51it/s]


[Test] 196 Loss: 7.206 | Acc: 56.510 5651/10000


100%|███████████████████████████████████████████████████| 391/391 [00:23<00:00, 16.68it/s]


[Train] 197 Loss: 0.001 | Acc: 99.966 49983/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 35.73it/s]


[Test] 197 Loss: 7.205 | Acc: 56.540 5654/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:30<00:00, 12.68it/s]


[Train] 198 Loss: 0.002 | Acc: 99.956 49978/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.08it/s]


[Test] 198 Loss: 7.205 | Acc: 56.510 5651/10000


100%|███████████████████████████████████████████████████| 391/391 [00:30<00:00, 12.71it/s]


[Train] 199 Loss: 0.002 | Acc: 99.968 49984/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 35.10it/s]


[Test] 199 Loss: 7.205 | Acc: 56.510 5651/10000
Training finished

Experiment index: 1
MLP Mixer : Channes per patch -> Initial:48 Final:153
EXPERIMENTING FOR : 03.0_original_mixer1_l7_c100_s369 | params: 1765778  .......
.......


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.20it/s]


[Train] 0 Loss: 3.631 | Acc: 16.192 8096/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 32.73it/s]


[Test] 0 Loss: 3.039 | Acc: 26.490 2649/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.16it/s]


[Train] 1 Loss: 2.961 | Acc: 27.516 13758/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 35.66it/s]


[Test] 1 Loss: 2.709 | Acc: 32.490 3249/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.23it/s]


[Train] 2 Loss: 2.648 | Acc: 33.300 16650/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 40.20it/s]


[Test] 2 Loss: 2.524 | Acc: 36.370 3637/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.42it/s]


[Train] 3 Loss: 2.428 | Acc: 38.076 19038/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 39.65it/s]


[Test] 3 Loss: 2.387 | Acc: 40.080 4008/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.41it/s]


[Train] 4 Loss: 2.254 | Acc: 41.818 20909/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 35.46it/s]


[Test] 4 Loss: 2.239 | Acc: 43.140 4314/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.19it/s]


[Train] 5 Loss: 2.098 | Acc: 45.300 22650/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 35.48it/s]


[Test] 5 Loss: 2.135 | Acc: 45.920 4592/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.44it/s]


[Train] 6 Loss: 1.969 | Acc: 48.026 24013/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 41.56it/s]


[Test] 6 Loss: 2.014 | Acc: 47.980 4798/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.34it/s]


[Train] 7 Loss: 1.844 | Acc: 50.642 25321/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.93it/s]


[Test] 7 Loss: 2.041 | Acc: 47.840 4784/10000


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.13it/s]


[Train] 8 Loss: 1.732 | Acc: 53.236 26618/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.25it/s]


[Test] 8 Loss: 1.978 | Acc: 49.600 4960/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.43it/s]


[Train] 9 Loss: 1.642 | Acc: 55.542 27771/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 39.63it/s]


[Test] 9 Loss: 1.936 | Acc: 51.370 5137/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.10it/s]


[Train] 10 Loss: 1.539 | Acc: 58.006 29003/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.06it/s]


[Test] 10 Loss: 1.998 | Acc: 50.720 5072/10000


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.26it/s]


[Train] 11 Loss: 1.456 | Acc: 59.750 29875/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 33.97it/s]


[Test] 11 Loss: 1.966 | Acc: 51.820 5182/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.25it/s]


[Train] 12 Loss: 1.378 | Acc: 61.612 30806/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 35.60it/s]


[Test] 12 Loss: 1.979 | Acc: 52.330 5233/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.09it/s]


[Train] 13 Loss: 1.293 | Acc: 63.668 31834/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.32it/s]


[Test] 13 Loss: 2.010 | Acc: 52.370 5237/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.18it/s]


[Train] 14 Loss: 1.205 | Acc: 66.008 33004/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 36.51it/s]


[Test] 14 Loss: 2.029 | Acc: 52.710 5271/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.34it/s]


[Train] 15 Loss: 1.125 | Acc: 68.036 34018/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.28it/s]


[Test] 15 Loss: 2.068 | Acc: 53.880 5388/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.30it/s]


[Train] 16 Loss: 1.066 | Acc: 69.570 34785/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.34it/s]


[Test] 16 Loss: 2.108 | Acc: 53.380 5338/10000


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.17it/s]


[Train] 17 Loss: 0.982 | Acc: 71.624 35812/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 34.95it/s]


[Test] 17 Loss: 2.147 | Acc: 53.770 5377/10000


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.23it/s]


[Train] 18 Loss: 0.917 | Acc: 73.350 36675/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 33.70it/s]


[Test] 18 Loss: 2.287 | Acc: 53.120 5312/10000


100%|███████████████████████████████████████████████████| 391/391 [00:27<00:00, 14.42it/s]


[Train] 19 Loss: 0.858 | Acc: 74.690 37345/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:02<00:00, 38.39it/s]


[Test] 19 Loss: 2.284 | Acc: 52.840 5284/10000


 19%|█████████▉                                          | 75/391 [00:05<00:20, 15.78it/s]

In [None]:
# NOTE(review): bare reference to a name that is not defined anywhere in the visible notebook;
# evaluating it raises NameError, which halts "Run All" at this cell. Presumably a deliberate
# barrier between the experiment loop above and the scratch cells below — confirm, and
# consider replacing it with an explicit `raise` carrying a message.
asdfasdfwdf

# Models -- New Test 
1. Scaling works differently for the original MLP-Mixer, so it is not useful for benchmarking.
2. MLP-Mixer uses channel-wise mixing; here the channels are the number of patches, which changes differently from the patch size when scaling.

In [None]:
class MlpBLock(nn.Module):
    """Fully connected block whose layer widths are multiples of ``input_dim``.

    The layer widths follow ``[1, *hidden_layers_ratio, 1]`` scaled by
    ``input_dim``, so the block maps ``input_dim`` features back to
    ``input_dim`` features. An activation follows every linear layer except
    the last one.

    Args:
        input_dim: Size of the last dimension of the input (and of the output).
        hidden_layers_ratio: Width multiplier(s) of the hidden layers. Either a
            single number or any sequence of numbers (list, tuple, ...).
            Each hidden width is ``int(ratio * input_dim)``.
        actf: Zero-argument callable returning an activation module
            (e.g. ``nn.GELU``); it is called once per linear layer.
    """

    def __init__(self, input_dim, hidden_layers_ratio=(2,), actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim

        # Accept a bare number as shorthand for a single hidden layer.
        if isinstance(hidden_layers_ratio, (int, float)):
            hidden_layers_ratio = [hidden_layers_ratio]

        # list(...) makes tuples and other sequences work; the original
        # ``[1] + hidden_layers_ratio`` only accepted lists.
        self.hlr = [1] + list(hidden_layers_ratio) + [1]

        layers = []
        # For k hidden layers there are k+1 linear layers.
        for ratio_in, ratio_out in zip(self.hlr[:-1], self.hlr[1:]):
            in_features = int(ratio_in * self.input_dim)
            out_features = int(ratio_out * self.input_dim)
            layers.append(nn.Linear(in_features, out_features))
            layers.append(actf())
        # Drop the activation that would follow the final linear layer.
        layers = layers[:-1]

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        return self.mlp(x)

In [None]:
# Sanity check: build a small block and show its layer structure.
MlpBLock(input_dim=2, hidden_layers_ratio=[3, 4])

## MLP-Mixer 

In [None]:
class MixerBlock(nn.Module):
    """One Mixer layer: a patch-mixing MLP followed by a channel-mixing MLP.

    Both sub-layers are residual and are preceded by a ``LayerNorm`` over the
    last (channel) dimension.

    Args:
        patch_dim: Size of the second-to-last input dimension (number of patches).
        channel_dim: Size of the last input dimension (features per patch).
    """

    def __init__(self, patch_dim, channel_dim):
        super().__init__()

        # Sub-modules are created in the same order as before so that
        # parameter ordering and seeded initialisation are unchanged.
        self.ln0 = nn.LayerNorm(channel_dim)
        self.mlp_patch = MlpBLock(input_dim=patch_dim, hidden_layers_ratio=[2])
        self.ln1 = nn.LayerNorm(channel_dim)
        self.mlp_channel = MlpBLock(input_dim=channel_dim, hidden_layers_ratio=[2])

    def forward(self, x):
        """Mix across patches, then across channels.

        ``x`` is expected to have shape (N, patch_dim, channel_dim).
        """
        # Patch mixing: normalise over channels, move the patch axis last so
        # the MLP acts along it, then move it back.
        across_patches = self.mlp_patch(self.ln0(x).transpose(-1, -2)).transpose(-1, -2)
        x = x + across_patches

        # Channel mixing: the MLP acts directly on the last axis.
        across_channels = self.mlp_channel(self.ln1(x))
        return x + across_channels

In [None]:
class MlpMixer(nn.Module):
    """MLP-Mixer style classifier operating on non-overlapping image patches.

    Pipeline (see ``forward``):
      1. Bilinearly resize the input to ``hidden_image_dim[-2:]``.
      2. Bring each patch to ``patch_h * patch_w * hidden_image_dim[0]`` features,
         either with a 1x1 convolution before patch extraction
         (``use1x1scale=True``) or with a per-patch linear layer after it
         (``use1x1scale=False``).
      3. Apply ``num_blocks`` ``MixerBlock`` layers.
      4. Flatten and classify with a single linear layer.

    Args:
        input_channels: Number of channels of the incoming images.
        hidden_image_dim: ``(C, H, W)`` of the internal image representation.
            ``hidden_image_dim[0]`` is used as the hidden channel count, so the
            channel entry is required.
        patch_size: ``(patch_h, patch_w)``; must divide ``H`` and ``W`` exactly.
        num_blocks: Number of mixer blocks.
        num_classes: Size of the output logits.
        use1x1scale: Selects how the channel count is changed (see step 2).
    """

    def __init__(self, input_channels:int, hidden_image_dim:tuple, patch_size:tuple, num_blocks:int, num_classes:int, use1x1scale=True):
        super().__init__()

        self.img_dim = hidden_image_dim ### (C, H, W)
        # Remembered so that forward() can pick the matching rescaling path.
        self.use1x1scale = use1x1scale

        ### number of patches along each spatial axis
        d0 = int(self.img_dim[-2]/patch_size[0])
        d1 = int(self.img_dim[-1]/patch_size[1])
        assert d0*patch_size[0]==self.img_dim[-2], "Image must be divisible into patch size"
        assert d1*patch_size[1]==self.img_dim[-1], "Image must be divisible into patch size"

        channel_size = d0*d1 ## number of patches

        ### features per flattened patch, before and after the channel change
        init_dim = patch_size[0]*patch_size[1]*input_channels
        final_dim = patch_size[0]*patch_size[1]*self.img_dim[0]

        self.scaler = nn.UpsamplingBilinear2d(size=(self.img_dim[-2], self.img_dim[-1]))

        if use1x1scale:
            # The conv is constructed first exactly as before, so seeded
            # initialisation of the later layers is unchanged; it is replaced
            # by a no-op when the channel count already matches.
            self.conv1x1 = nn.Conv2d(input_channels, self.img_dim[0], kernel_size=1, stride=1)
            if input_channels == self.img_dim[0]:
                self.conv1x1 = nn.Identity()
        else:
            #### rescale the patches (patch wise image non preserving transform, unlike bilinear interpolation)
            self.channel_change = nn.Linear(init_dim, final_dim) ## applied after unfold

        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

        print(f"MLP Mixer : Channes per patch -> Initial:{init_dim} Final:{final_dim}")

        ### after the axes are swapped in forward()
        self.channel_dim = final_dim
        self.patch_dim = channel_size

        self.mixer_blocks = []
        for i in range(num_blocks):
            self.mixer_blocks.append(MixerBlock(self.patch_dim, self.channel_dim))
        self.mixer_blocks = nn.Sequential(*self.mixer_blocks)

        self.linear = nn.Linear(self.patch_dim*self.channel_dim, num_classes)


    def forward(self, x):
        """Return class logits of shape (N, num_classes) for images ``x`` of shape (N, input_channels, h, w)."""
        bs = x.shape[0]
        x = self.scaler(x) ## resize to the hidden spatial size
        if self.use1x1scale:
            x = self.conv1x1(x) ## change the channel count before patching
        ## convert to patches -> (N, num_patches, features_per_patch)
        x = self.unfold(x).swapaxes(-1, -2).contiguous()
        if not self.use1x1scale:
            # Previously this path crashed: conv1x1 does not exist here and the
            # linear rescaling was never applied.
            x = self.channel_change(x) ## init_dim -> final_dim per patch
        x = self.mixer_blocks(x) ## the mixer architecture
        x = self.linear(x.view(bs, -1)) ## classify
        return x

In [None]:
# Small mixer that rescales patches with a linear layer instead of a 1x1 conv.
mixer = MlpMixer(
    input_channels=3,
    hidden_image_dim=(5, 35, 35),
    patch_size=(5, 5),
    num_blocks=1,
    num_classes=10,
    use1x1scale=False,
)
mixer

In [None]:
# Smoke test: forward pass on one random 3x32x32 image.
mixer(torch.randn((1, 3, 32, 32)))

## Patch Mixer

In [None]:
class PatchMixerBlock(nn.Module):
    """Residual MLP applied independently to every non-overlapping image patch.

    Each patch (all channels included) is flattened, layer-normalised, passed
    through an ``MlpBLock`` and written back to its original location.

    Args:
        patch_size: Side length of a square patch (int) or ``(height, width)``.
        num_channel: Number of channels of the input feature map.
    """

    def __init__(self, patch_size, num_channel):
        super().__init__()
        self.patch_size = patch_size

        # Pixels covered by one patch, then features once all channels are flattened in.
        pixels_per_patch = (
            patch_size**2 if isinstance(patch_size, int) else patch_size[0]*patch_size[1]
        )
        patch_features = pixels_per_patch*num_channel

        self.ln0 = nn.LayerNorm(patch_features)
        self.mlp_patch = MlpBLock(patch_features, [2])

    def forward(self, x):
        """Mix within each patch; ``x`` is expected to have shape (N, C, H, W)."""
        height, width = x.shape[-2], x.shape[-1]

        # (N, C*ph*pw, L) -> (N, L, C*ph*pw): one row per patch.
        patches = nn.functional.unfold(
            x, kernel_size=self.patch_size, stride=self.patch_size
        ).transpose(-1, -2)

        mixed = self.mlp_patch(self.ln0(patches))

        # Back to (N, C*ph*pw, L) and reassemble the image grid.
        restored = nn.functional.fold(
            mixed.transpose(-1, -2),
            (height, width),
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
        return x + restored

In [None]:
# 7x7 patches over a 5-channel feature map.
pmb = PatchMixerBlock(patch_size=7, num_channel=5)
pmb

In [None]:
# Smoke test: show the output shape for a 5x35x35 input.
pmb(torch.randn((1, 5, 35, 35))).shape

In [None]:
def get_factors(n):
    """Return, in ascending order, every integer in ``[2, n]`` that divides ``n``.

    ``1`` is never included; ``n`` itself is included when ``n >= 2``.
    """
    return [divisor for divisor in range(2, n+1) if n % divisor == 0]

class PatchMlpMixer(nn.Module):
    """Image classifier built from a stack of ``PatchMixerBlock`` layers.

    Pipeline: bilinear resize to the hidden resolution -> 1x1 convolution to
    the hidden channel count (skipped when the count is unchanged) ->
    ``num_blocks`` repetitions of one ``PatchMixerBlock`` per entry of
    ``patch_sizes`` -> flatten -> linear classifier.

    Args:
        input_channels: channel count of the images fed to ``forward``.
        hidden_image_dim: ``(C, H, W)`` of the internal feature map.
        patch_sizes: patch sizes used in order inside every repetition; their
            product must equal both ``H`` and ``W``.
        num_blocks: how many times the full ``patch_sizes`` sequence is stacked.
        num_classes: number of output logits.

    Raises:
        AssertionError: if the product of ``patch_sizes`` differs from the
            hidden height or width.
    """

    def __init__(self, input_channels:int, hidden_image_dim:tuple, patch_sizes:list, num_blocks:int, num_classes:int):
        super().__init__()

        self.img_dim = hidden_image_dim ### must contain (C, H, W)
        ## nn.UpsamplingBilinear2d is deprecated; this is its documented equivalent
        ## (bilinear interpolation with align_corners=True).
        self.scaler = nn.Upsample(size=(self.img_dim[-2], self.img_dim[-1]),
                                  mode="bilinear", align_corners=True)

        assert np.prod(patch_sizes) == hidden_image_dim[-1], "The product of patches must equal image dimension"
        assert np.prod(patch_sizes) == hidden_image_dim[-2], "The product of patches must equal image dimension"

        ## Build the 1x1 projection only when the channel count actually changes,
        ## instead of constructing a Conv2d and immediately discarding it.
        if input_channels == self.img_dim[0]:
            self.conv1x1 = nn.Identity()
        else:
            self.conv1x1 = nn.Conv2d(input_channels, self.img_dim[0], kernel_size=1, stride=1)

        ## Collect blocks in a local list so the attribute is only ever bound to a Module.
        blocks = [
            PatchMixerBlock(ps, self.img_dim[0])
            for _ in range(num_blocks)
            for ps in patch_sizes
        ]
        self.mixer_blocks = nn.Sequential(*blocks)

        ## np.prod returns a NumPy scalar; store in_features as a plain Python int.
        self.linear = nn.Linear(int(np.prod(self.img_dim)), num_classes)

    def forward(self, x):
        """Map a batch of images ``(N, input_channels, H_in, W_in)`` to logits ``(N, num_classes)``."""
        x = self.scaler(x)
        x = self.conv1x1(x)
        x = self.mixer_blocks(x)
        ## flatten (unlike view) also accepts non-contiguous activations, e.g. channels_last inputs
        x = self.linear(x.flatten(1))
        return x

In [None]:
## Small PatchMlpMixer for a smoke test: hidden (C, H, W) = (5, 35, 35); 5*7 = 35 satisfies the patch-size assertion.
patch_mixer = PatchMlpMixer(3, (5, 35, 35), patch_sizes=[5, 7], num_blocks=1, num_classes=10)

In [None]:
## Display the module tree of the smoke-test model.
patch_mixer

In [None]:
## The 32x32 input is resized to 35x35 by the model's scaler; with num_classes=10 the expected shape is (1, 10).
patch_mixer(torch.randn(1, 3, 32, 32)).shape

#### Final Model

In [None]:
## Candidate 1: MlpMixer with 10 blocks, moved to `device`.
## NOTE: `model` is rebound by the PatchMlpMixer cell further down; run only the variant you intend to use.
model = MlpMixer(3, (9, 32, 32), patch_size=(4, 4), num_blocks=10, num_classes=10)
model = model.to(device)

In [None]:
## Total number of parameter elements in whatever `model` currently refers to.
print("number of params: ", sum(p.numel() for p in model.parameters())) 

In [None]:
## Display the module tree of the current `model`.
model

In [None]:
## Candidate 2: PatchMlpMixer with 10 repetitions of the [5, 7] patch sequence, moved to `device`.
## NOTE: this rebinds `model`, replacing the MlpMixer built in the earlier cell.
model = PatchMlpMixer(3, (3, 35, 35), patch_sizes=[5, 7], num_blocks=10, num_classes=10)
model = model.to(device)

In [None]:
## Total number of parameter elements in whatever `model` currently refers to.
print("number of params: ", sum(p.numel() for p in model.parameters())) 

In [None]:
## Display the module tree of the current `model`.
model

In [None]:
## Parameter count of the current `model`, followed by counts noted for comparison.
print("number of params: ", sum(p.numel() for p in model.parameters())) 
## Noted counts -- presumably "Patch" = PatchMlpMixer and "Mixer" = MlpMixer as configured in the cells above (verify by re-running):
## Patch ||  1137220
## Mixer ||  1141703