Official Code : https://github.com/aanna0701/SPT_LSA_ViT <br>
This code is to train CIFAR10 dataset with Swin Transformer. <br>
If you need another transformer model, you should add some more sources.

In [1]:
!pip install timm
!pip install einops
!pip install colorama
!pip install torchsummary

Collecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
[K     |████████████████████████████████| 431 kB 4.2 MB/s 
Installing collected packages: timm
Successfully installed timm-0.5.4
Collecting einops
  Downloading einops-0.4.0-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.0
Collecting colorama
  Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB)
Installing collected packages: colorama
Successfully installed colorama-0.4.4


In [2]:
import numpy as np
import torch

def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int(W * cut_rat)
    cut_h = np.int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def mixup_data(x, y, args):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if args.alpha > 0:
        lam = np.random.beta(args.alpha, args.alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    
    index = torch.randperm(batch_size).cuda()

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def cutmix_data(x, y, args):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if args.beta > 0:
        lam = np.random.beta(args.beta, args.beta)
    else:
        lam = 1

    batch_size = x.size()[0]
    
    index = torch.randperm(batch_size).cuda()

    y_a, y_b = y, y[index]

    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x_sliced = x[index, :, bbx1:bbx2, bby1:bby2]
    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    
    return [bbx1, bby1, bbx2, bby2 ], y_a, y_b, lam, x_sliced

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


In [3]:
# Thanks to rwightman's timm package
# github.com:rwightman/pytorch-image-models
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingCrossEntropy(nn.Module):
    """
    NLL loss with label smoothing.
    """
    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def _compute_losses(self, x, target):
        log_prob = F.log_softmax(x, dim=-1)
        nll_loss = -log_prob.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -log_prob.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss

    def forward(self, x, target):
        return self._compute_losses(x, target).mean()

In [4]:
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
#
from torch.utils.data.sampler import BatchSampler
import torch
import numpy as np
from torch.utils.data.dataloader import default_collate
from collections.abc import Mapping, Sequence


class RASampler(torch.utils.data.Sampler):
    """
    Batch Sampler with Repeated Augmentations (RA)
    - dataset_len: original length of the dataset
    - batch_size
    - repetitions: instances per image
    - len_factor: multiplicative factor for epoch size
    """

    def __init__(self, dataset_len, batch_size, repetitions=1, len_factor=3.0, shuffle=False, drop_last=False):
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.repetitions = repetitions
        self.len_images = int(dataset_len * len_factor)
        self.shuffle = shuffle
        self.drop_last = drop_last

    def shuffler(self):
        if self.shuffle:
            new_perm = lambda: iter(np.random.permutation(self.dataset_len))
        else:
            new_perm = lambda: iter(np.arange(self.dataset_len))
        shuffle = new_perm()
        while True:
            try:
                index = next(shuffle)
            except StopIteration:
                shuffle = new_perm()
                index = next(shuffle)
            for repetition in range(self.repetitions):
                yield index

    def __iter__(self):
        shuffle = iter(self.shuffler())
        seen = 0
        batch = []
        for _ in range(self.len_images):
            index = next(shuffle)
            batch.append(index)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return self.len_images // self.batch_size
        else:
            return (self.len_images + self.batch_size - 1) // self.batch_size


def list_collate(batch):
    """
    Collate into a list instead of a tensor to deal with variable-sized inputs
    """
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        return batch
    elif elem_type.__module__ == 'numpy':
        if elem_type.__name__ == 'ndarray':
            return list_collate([torch.from_numpy(b) for b in batch])
    elif isinstance(batch[0], Mapping):
        return {key: list_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], Sequence):
        transposed = zip(*batch)
        return [list_collate(samples) for samples in transposed]
    return default_collate(batch)

In [5]:
import os
import sys
import time

term_width = int(20)

TOTAL_BAR_LENGTH = 65.

def progress_bar(current, total, msg=None):
  
    L = []
    if msg:
        L.append(msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f

In [6]:
from collections import OrderedDict
import csv
from genericpath import exists
import os

keys = ['T Loss', 'T Top-1', 'V Loss', 'V Top-1', 'ParameterScale']

class Logger_dict():
    def __init__(self, logger, save_path):
        self.dict = OrderedDict()
        self.logger = logger
        self.savepath = save_path
        if os.path.exists(os.path.join(self.savepath, 'log.csv')):
            self.init_csv()
        self.write_csv(keys)
                
    def update(self, key, value):
        self.dict[key] = value
        
    def init_csv(self):
        fileVariable = open(os.path.join(self.savepath, 'log.csv'), 'r+')
        fileVariable.truncate(0)
        fileVariable.close()
        
    def write_csv(self, x):
        with open(os.path.join(self.savepath, 'log.csv'), "a") as outfile:
            csvwriter = outfile
            csvwriter = csv.writer(outfile)
            csvwriter.writerow(x)
        
    def print(self):
        i = 0
        values = []
        for key, value in self.dict.items():
            print(f'{key}' +'\t'+ f'{value}')
            i += 1
            values.append(value)
        self.write_csv(values)
        print()

In [7]:
class EarlyStopping:
    def __init__(self, patience=0, verbose=0, mode='max'):
        self._step = 0
        self._loss = 0.0
        self.patience = patience
        self.verbose = verbose
        self.best_value = 0.0
        if mode == 'max':
            self.mode = 1
        else:
            self.mode = -1 

    def validate(self, value):
        if self._loss * self.mode >= value * self.mode:
            self._step += 1
            if self._step > self.patience:
                if self.verbose:
                    print(f'Training process is stopped early....\n\n')
                return self.best_value
        else:
            if self.best_value * self.mode < value * self.mode:
                self.best_value = value
            self._step = 0
        
        self._loss = value

        return 0
    
def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].flatten().float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [8]:
import math
import torch
from torch.optim.lr_scheduler import _LRScheduler

class CosineAnnealingWarmupRestarts(_LRScheduler):
    """
        optimizer (Optimizer): Wrapped optimizer.
        first_cycle_steps (int): First cycle step size.
        cycle_mult(float): Cycle steps magnification. Default: -1.
        max_lr(float): First cycle's max learning rate. Default: 0.1.
        min_lr(float): Min learning rate. Default: 0.001.
        warmup_steps(int): Linear warmup step size. Default: 0.
        gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
        last_epoch (int): The index of last epoch. Default: -1.
    """
    
    def __init__(self,
                 optimizer : torch.optim.Optimizer,
                 first_cycle_steps : int,
                 cycle_mult : float = 1.,
                 max_lr : float = 0.1,
                 min_lr : float = 0.001,
                 warmup_steps : int = 0,
                 gamma : float = 1.,
                 last_epoch : int = -1
        ):
        assert warmup_steps < first_cycle_steps
        
        self.first_cycle_steps = first_cycle_steps # first cycle step size
        self.cycle_mult = cycle_mult # cycle steps magnification
        self.base_max_lr = max_lr # first max learning rate
        self.max_lr = max_lr # max learning rate in the current cycle
        self.min_lr = min_lr # min learning rate
        self.warmup_steps = warmup_steps # warmup step size
        self.gamma = gamma # decrease rate of max learning rate by cycle
        
        self.cur_cycle_steps = first_cycle_steps # first cycle step size
        self.cycle = 0 # cycle count
        self.step_in_cycle = last_epoch # step size of the current cycle
        
        super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)
        
        # set learning rate min_lr
        self.init_lr()
    
    def init_lr(self):
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)
    
    def get_lr(self):
        if self.step_in_cycle == -1:
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
        else:
            return [base_lr + (self.max_lr - base_lr) \
                    * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \
                                    / (self.cur_cycle_steps - self.warmup_steps))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
        else:
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.:
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                    self.cycle = n
                    self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
                    self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
            else:
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch
                
        self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

def build_scheduler(args, optimizer, n_iter_per_epoch):
    num_steps = int(args.epochs * n_iter_per_epoch)
    warmup_steps = int(args.warmup * n_iter_per_epoch)

    lr_scheduler = CosineAnnealingWarmupRestarts(
        optimizer,
        first_cycle_steps=num_steps,
        cycle_mult=1.,
        max_lr = args.lr,
        min_lr = 1e-6,
        warmup_steps=warmup_steps
        )
    return lr_scheduler


In [9]:
import os
from colorama import Fore, Style
import torchvision.transforms as transforms
import torchvision.datasets as datasets

def datainfo(logger, args):
    if args.dataset == 'CIFAR10':
        print(Fore.YELLOW+'*'*80)
        logger.debug('CIFAR10')
        print('*'*80 + Style.RESET_ALL)
        n_classes = 10
        img_mean, img_std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
        img_size = 32        
        
    elif args.dataset == 'CIFAR100':
        print(Fore.YELLOW+'*'*80)
        logger.debug('CIFAR100')
        print('*'*80 + Style.RESET_ALL)
        n_classes = 100
        img_mean, img_std = (0.5070, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762) 
        img_size = 32        
        
    elif args.dataset == 'SVHN':
        print(Fore.YELLOW+'*'*80)
        logger.debug('SVHN')
        print('*'*80 + Style.RESET_ALL)
        n_classes = 10
        img_mean, img_std = (0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970) 
        img_size = 32
        
    elif args.dataset == 'T-IMNET':
        print(Fore.YELLOW+'*'*80)
        logger.debug('T-IMNET')
        print('*'*80 + Style.RESET_ALL)
        n_classes = 200
        img_mean, img_std = (0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821)
        img_size = 64
        
    data_info = dict()
    data_info['n_classes'] = n_classes
    data_info['stat'] = (img_mean, img_std)
    data_info['img_size'] = img_size
    
    return data_info

def dataload(args, augmentations, normalize, data_info):
    if args.dataset == 'CIFAR10':
        train_dataset = datasets.CIFAR10(
            root=args.data_path, train=True, download=True, transform=augmentations)
        val_dataset = datasets.CIFAR10(
            root=args.data_path, train=False, download=False, transform=transforms.Compose([
            transforms.Resize(data_info['img_size']),
            transforms.ToTensor(),
            *normalize]))
        
    elif args.dataset == 'CIFAR100':

        train_dataset = datasets.CIFAR100(
            root=args.data_path, train=True, download=True, transform=augmentations)
        val_dataset = datasets.CIFAR100(
            root=args.data_path, train=False, download=False, transform=transforms.Compose([
            transforms.Resize(data_info['img_size']),
            transforms.ToTensor(),
            *normalize]))
        
    elif args.dataset == 'SVHN':

        train_dataset = datasets.SVHN(
            root=args.data_path, split='train', download=True, transform=augmentations)
        val_dataset = datasets.SVHN(
            root=args.data_path, split='test', download=True, transform=transforms.Compose([
            transforms.Resize(data_info['img_size']),
            transforms.ToTensor(),
            *normalize]))
        
    elif args.dataset == 'T-IMNET':
        train_dataset = datasets.ImageFolder(
            root=os.path.join(args.data_path, 'tiny_imagenet', 'train'), transform=augmentations)
        val_dataset = datasets.ImageFolder(
            root=os.path.join(args.data_path, 'tiny_imagenet', 'val'), 
            transform=transforms.Compose([
            transforms.Resize(data_info['img_size']), transforms.ToTensor(), *normalize]))
    
    return train_dataset, val_dataset

In [10]:
import logging as log
import torch
from torch.nn import Module

import numpy as np
import random
import logging as log
import torch
import torch.nn as nn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from colorama import Fore, Style
from torchsummary import summary
import os
import argparse
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", category=Warning)

In [11]:
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Obtained from: github.com:rwightman/pytorch-image-models
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(Module):
    """
    Obtained from: github.com:rwightman/pytorch-image-models
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

In [12]:
import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange
import math

class ShiftedPatchTokenization(nn.Module):
    def __init__(self, in_dim, dim, merging_size=2, exist_class_t=False, is_pe=False):
        super().__init__()
        
        self.exist_class_t = exist_class_t
        
        self.patch_shifting = PatchShifting(merging_size)
        
        patch_dim = (in_dim*5) * (merging_size**2) 
        if exist_class_t:
            self.class_linear = nn.Linear(in_dim, dim)

        self.is_pe = is_pe
        
        self.merging = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = merging_size, p2 = merging_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim)
        )

    def forward(self, x):
        
        if self.exist_class_t:
            visual_tokens, class_token = x[:, 1:], x[:, (0,)]
            reshaped = rearrange(visual_tokens, 'b (h w) d -> b d h w', h=int(math.sqrt(x.size(1))))
            out_visual = self.patch_shifting(reshaped)
            out_visual = self.merging(out_visual)
            out_class = self.class_linear(class_token)
            out = torch.cat([out_class, out_visual], dim=1)
        
        else:
            out = x if self.is_pe else rearrange(x, 'b (h w) d -> b d h w', h=int(math.sqrt(x.size(1))))
            out = self.patch_shifting(out)
            out = self.merging(out)    
        
        return out
        
class PatchShifting(nn.Module):
    def __init__(self, patch_size):
        super().__init__()
        self.shift = int(patch_size * (1/2))
        
    def forward(self, x):
     
        x_pad = torch.nn.functional.pad(x, (self.shift, self.shift, self.shift, self.shift))
        # if self.is_mean:
        #     x_pad = x_pad.mean(dim=1, keepdim = True)
        
        """ 4 cardinal directions """
        #############################
        # x_l2 = x_pad[:, :, self.shift:-self.shift, :-self.shift*2]
        # x_r2 = x_pad[:, :, self.shift:-self.shift, self.shift*2:]
        # x_t2 = x_pad[:, :, :-self.shift*2, self.shift:-self.shift]
        # x_b2 = x_pad[:, :, self.shift*2:, self.shift:-self.shift]
        # x_cat = torch.cat([x, x_l2, x_r2, x_t2, x_b2], dim=1) 
        #############################
        
        """ 4 diagonal directions """
        # #############################
        x_lu = x_pad[:, :, :-self.shift*2, :-self.shift*2]
        x_ru = x_pad[:, :, :-self.shift*2, self.shift*2:]
        x_lb = x_pad[:, :, self.shift*2:, :-self.shift*2]
        x_rb = x_pad[:, :, self.shift*2:, self.shift*2:]
        x_cat = torch.cat([x, x_lu, x_ru, x_lb, x_rb], dim=1) 
        # #############################
        
        """ 8 cardinal directions """
        #############################
        # x_l2 = x_pad[:, :, self.shift:-self.shift, :-self.shift*2]
        # x_r2 = x_pad[:, :, self.shift:-self.shift, self.shift*2:]
        # x_t2 = x_pad[:, :, :-self.shift*2, self.shift:-self.shift]
        # x_b2 = x_pad[:, :, self.shift*2:, self.shift:-self.shift]
        # x_lu = x_pad[:, :, :-self.shift*2, :-self.shift*2]
        # x_ru = x_pad[:, :, :-self.shift*2, self.shift*2:]
        # x_lb = x_pad[:, :, self.shift*2:, :-self.shift*2]
        # x_rb = x_pad[:, :, self.shift*2:, self.shift*2:]
        # x_cat = torch.cat([x, x_l2, x_r2, x_t2, x_b2, x_lu, x_ru, x_lb, x_rb], dim=1) 
        #############################
        
        # out = self.out(x_cat)
        out = x_cat
        
        return out

In [13]:
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import to_2tuple, trunc_normal_
# from utils.drop_path import DropPath
import torch
from einops.layers.torch import Rearrange
# from .SPT import ShiftedPatchTokenization
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., is_LSA=False):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.is_LSA = is_LSA
        if is_LSA:
            self.scale = nn.Parameter(self.scale*torch.ones(self.num_heads))
            self.mask = torch.eye((window_size[0]**2), (window_size[0]**2))
            self.mask = torch.nonzero((self.mask == 1), as_tuple=False)
            self.inf = float('-inf')

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)


        if not self.is_LSA:
            q = q * self.scale
        
        else:
            scale = self.scale
            q = torch.mul(q, scale.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand((B_, self.num_heads, 1, 1)))

        
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

            if self.is_LSA:
                attn[:, :, self.mask[:, 0], self.mask[:, 1]] = self.inf
            attn = self.softmax(attn)
            
        else:
            if self.is_LSA:
                attn[:, :, self.mask[:, 0], self.mask[:, 1]] = self.inf
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, is_LSA=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, is_LSA=is_LSA)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # N_w^2, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)   # N_w^2, window_size, window_size
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)   # (N_w^2, 1, window_size, window_size) - (N_w^2, window_size, 1, window_size)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)    # No parameter

    def forward(self, x):
        # H, W = self.input_resolution
        B, L, C = x.shape
        H = int(math.sqrt(L))
#        assert L == H * W, "input feature has wrong size"
#        print(H, W, B, L, C)

        
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, H, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, H)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, L, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"



class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
#       assert L == H * W, "input feature has wrong size"
#       assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim    # layer norm
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim  # reduction
        return flops


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size, 
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=False, use_checkpoint=False,
                 is_LSA=False, is_SPT=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        
        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer, is_LSA=is_LSA)
            for i in range(depth)])

        # patch merging layer
        if downsample:
            if not is_SPT:
                self.downsample = PatchMerging(input_resolution, dim=dim, norm_layer=norm_layer)
            else:
                self.downsample = ShiftedPatchTokenization(dim, dim*2, 2)
                    
        else:
            self.downsample = None



    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                #print(x.shape)
                x = blk(x)
        if self.downsample is not None:
             x = self.downsample(x)
            
             
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"



class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding
    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030
    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, is_LSA=False, is_SPT=False,
                 **kwargs):
        super().__init__()
           

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        
        """ Base """
        if not is_SPT:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
                norm_layer=norm_layer if self.patch_norm else None)     
            self.img_resolution = self.patch_embed.patches_resolution

        else:
            self.patch_embed = ShiftedPatchTokenization(3, embed_dim, patch_size, is_pe=True)
            self.img_resolution = (img_size//patch_size, img_size//patch_size)  
        
        # absolute position embedding
        self.absolute_pos_embed = nn.Parameter(torch.zeros(1, self.img_resolution[0]**2, embed_dim))
        trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        self.pool_idx = list()
        

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            is_first = i_layer == 0
            layer = BasicLayer(
                               dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(self.img_resolution[0] // (2 ** i_layer),
                                                 self.img_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer, is_LSA=is_LSA, is_SPT=is_SPT,
                               downsample=True if (i_layer < self.num_layers - 1) else False,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)
            
        self.img_resolution = [self.img_resolution[0] // (2**(self.num_layers-1)), 
                               self.img_resolution[1] // (2**(self.num_layers-1))]

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    
    def forward_features(self, x):
    
        k = 0        
        
        x = self.patch_embed(x)   
        
        x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        
        for i, layer in enumerate(self.layers):
            x = layer(x)

                
        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
               
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

In [14]:
def create_model(img_size, n_classes, args):
    if args.model =='swin':
        depths = [2, 6, 4]
        num_heads = [3, 6, 12]
        mlp_ratio = 2
        window_size = 4
        patch_size = 2 if img_size == 32 else 4
            
        model = SwinTransformer(img_size=img_size, window_size=window_size, drop_path_rate=args.sd, 
                                patch_size=patch_size, mlp_ratio=mlp_ratio, depths=depths, num_heads=num_heads, num_classes=n_classes, 
                                is_SPT=args.is_SPT, is_LSA=args.is_LSA)
        
    return model

In [22]:
import easydict
def init_parser(data_path='./dataset', dataset='CIFAR10', workers=4, print_freq=1, epochs=100, warmup=10):
    args = easydict.EasyDict({'data_path':data_path, 'dataset':dataset, 'workers':workers, \
                              'print_freq':print_freq, 'epochs':epochs, 'warmup':warmup, \
                              'batch_size':128, 'lr':0.001, 'weight_decay':5e-2, 'model':'swin', \
                              'disable_cos':False, 'enable_aug':False, 'gpu':0, 'no_cuda':False, \
                              'ls':False, 'seed':0, 'tag':'', 'sd':0.1, 'resume':False,\
                              'smoothing':0.1, 'beta':1.0, 'alpha':1.0, 'mix_prob':0.5, \
                              'ra':3, 're':0.25, 're_sh':0.4, 're_r1':0.3, 'is_LSA':True, 'is_SPT':True, \
                              'cm':False, 'mu':False, 'aa':False
                              })

    # Optimization hyperparams
    # parser.add_argument('--channel', type=int, help='disable cuda')
    # parser.add_argument('--heads', type=int, help='disable cuda')
    # parser.add_argument('--depth', type=int, help='disable cuda')

    return args


In [30]:
from __future__ import absolute_import

from torchvision.transforms import *

import random
import math

class RandomErasing(object):
    def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]):
        self.EPSILON = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1
       
    def __call__(self, img):

        if random.uniform(0, 1) > self.EPSILON:
            return img

        for _ in range(100):
            area = img.size()[1] * img.size()[2]
       
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1/self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < img.size()[2] and h < img.size()[1]:
                x1 = random.randint(0, img.size()[1] - h)
                y1 = random.randint(0, img.size()[2] - w)
                if img.size()[0] == 3:
                    #img[0, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
                    #img[1, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
                    #img[2, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]
                    img[1, x1:x1+h, y1:y1+w] = self.mean[1]
                    img[2, x1:x1+h, y1:y1+w] = self.mean[2]
                    #img[:, x1:x1+h, y1:y1+w] = torch.from_numpy(np.random.rand(3, h, w))
                else:
                    img[0, x1:x1+h, y1:y1+w] = self.mean[1]
                    # img[0, x1:x1+h, y1:y1+w] = torch.from_numpy(np.random.rand(1, h, w))
                return img

        return img

In [32]:
def train(train_loader, model, criterion, optimizer, epoch, scheduler,  args):
    model.train()
    loss_val, acc1_val = 0, 0
    n = 0
        
    
    for i, (images, target) in enumerate(train_loader):
        if (not args.no_cuda) and torch.cuda.is_available():
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
                
        # Cutmix only
        if args.cm and not args.mu:
            r = np.random.rand(1)
            if r < args.mix_prob:
                slicing_idx, y_a, y_b, lam, sliced = cutmix_data(images, target, args)
                images[:, :, slicing_idx[0]:slicing_idx[2], slicing_idx[1]:slicing_idx[3]] = sliced
                output = model(images)
                
                loss =  mixup_criterion(criterion, output, y_a, y_b, lam)
                
                   
            else:
                output = model(images)
                
                loss = criterion(output, target)
                               
                
        # Mixup only
        elif not args.cm and args.mu:
            r = np.random.rand(1)
            if r < args.mix_prob:
                images, y_a, y_b, lam = mixup_data(images, target, args)
                output = model(images)
                
                loss =  mixup_criterion(criterion, output, y_a, y_b, lam)
                
                
            
            else:
                output = model(images)
                
                loss =  criterion(output, target)
                 
                
        # Both Cutmix and Mixup
        elif args.cm and args.mu:
            r = np.random.rand(1)
            if r < args.mix_prob:
                switching_prob = np.random.rand(1)
                
                # Cutmix
                if switching_prob < 0.5:
                    slicing_idx, y_a, y_b, lam, sliced = cutmix_data(images, target, args)
                    images[:, :, slicing_idx[0]:slicing_idx[2], slicing_idx[1]:slicing_idx[3]] = sliced
                    output = model(images)
                    
                    loss =  mixup_criterion(criterion, output, y_a, y_b, lam)
                    
                    
                # Mixup
                else:
                    images, y_a, y_b, lam = mixup_data(images, target, args)
                    output = model(images)
                    
                    loss = mixup_criterion(criterion, output, y_a, y_b, lam) 
                    
            else:
                output = model(images)
                
                loss = criterion(output, target) 
          
        # No Mix
        else:
            output = model(images)
                                
            loss = criterion(output, target)
            
        acc = accuracy(output, target, (1,))
        acc1 = acc[0]
        n += images.size(0)
        loss_val += float(loss.item() * images.size(0))
        acc1_val += float(acc1[0] * images.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        lr = optimizer.param_groups[0]["lr"]

        if args.print_freq >= 0 and i % args.print_freq == 0:
            avg_loss, avg_acc1 = (loss_val / n), (acc1_val / n)
            progress_bar(i, len(train_loader),f'[Epoch {epoch+1}/{args.epochs}][T][{i}]   Loss: {avg_loss:.4e}   Top-1: {avg_acc1:6.2f}   LR: {lr:.7f}'+' '*10)

    logger_dict.update(keys[0], avg_loss)
    logger_dict.update(keys[1], avg_acc1)
    writer.add_scalar("Loss/train", avg_loss, epoch)
    writer.add_scalar("Acc/train", avg_acc1, epoch)
    
    return lr


In [36]:
def validate(val_loader, model, criterion, lr, args, epoch=None):
    model.eval()
    loss_val, acc1_val = 0, 0
    n = 0
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            if (not args.no_cuda) and torch.cuda.is_available():
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            
            output = model(images)
            loss = criterion(output, target)
            
            acc = accuracy(output, target, (1, 5))
            acc1 = acc[0]
            n += images.size(0)
            loss_val += float(loss.item() * images.size(0))
            acc1_val += float(acc1[0] * images.size(0))

            if args.print_freq >= 0 and i % args.print_freq == 0:
                avg_loss, avg_acc1 = (loss_val / n), (acc1_val / n)
                progress_bar(i, len(val_loader), f'[Epoch {epoch+1}][V][{i}]   Loss: {avg_loss:.4e}   Top-1: {avg_acc1:6.2f}   LR: {lr:.6f}')
    print()        

    print(Fore.BLUE)
    print('*'*80)
    
    logger_dict.update(keys[2], avg_loss)
    logger_dict.update(keys[3], avg_acc1)
    
    writer.add_scalar("Loss/val", avg_loss, epoch)
    writer.add_scalar("Acc/val", avg_acc1, epoch)

    
    return avg_acc1

In [38]:
args = init_parser(epochs=50)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)  # if you are using multi-GPU.
np.random.seed(args.seed)  # Numpy module.
random.seed(args.seed)  # Python random module.
torch.manual_seed(args.seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

model_name = args.model
best_acc1 = 0

if not args.is_SPT:
    model_name += "-Base"
else:
    model_name += "-SPT"

if args.is_LSA:
    model_name += "-LSA"
    
model_name += f"-{args.tag}-{args.dataset}-LR[{args.lr}]-Seed{args.seed}"
writer = SummaryWriter(os.path.join(os.getcwd(), 'tensorboard', model_name))
save_path = os.path.join(os.getcwd(), 'save', model_name)
if save_path:
    os.makedirs(save_path, exist_ok=True)

torch.cuda.set_device(args.gpu)

log_dir = os.path.join(save_path, 'history.csv')
logger = log.getLogger(__name__)
formatter = log.Formatter('%(message)s')
streamHandler = log.StreamHandler()
fileHandler = log.FileHandler(log_dir, 'a')
streamHandler.setFormatter(formatter)
fileHandler.setFormatter(formatter)
logger.addHandler(streamHandler)
logger.addHandler(fileHandler)
logger.setLevel(level=log.DEBUG)


global logger_dict
global keys

logger_dict = Logger_dict(logger, save_path)
keys = ['T Loss', 'T Top-1', 'V Loss', 'V Top-1']

data_info = datainfo(logger, args)

model = create_model(data_info['img_size'], data_info['n_classes'], args)
model.cuda(args.gpu)

print(Fore.GREEN+'*'*80)
logger.debug(f"Creating model: {model_name}")    
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.debug(f'Number of params: {format(n_parameters, ",")}')
logger.debug(f'Initial learning rate: {args.lr:.6f}')
logger.debug(f"Start training for {args.epochs} epochs")
print('*'*80+Style.RESET_ALL)


if args.ls:
    print(Fore.YELLOW + '*'*80)
    logger.debug('label smoothing used')
    print('*'*80+Style.RESET_ALL)
    criterion = LabelSmoothingCrossEntropy()

else:
    criterion = nn.CrossEntropyLoss()    
    
if args.sd > 0.:
    print(Fore.YELLOW + '*'*80)
    logger.debug(f'Stochastic depth({args.sd}) used ')
    print('*'*80+Style.RESET_ALL)         

criterion = criterion.cuda(args.gpu)

normalize = [transforms.Normalize(mean=data_info['stat'][0], std=data_info['stat'][1])]


if args.cm:
    print(Fore.YELLOW+'*'*80)
    logger.debug('Cutmix used')
    print('*'*80 + Style.RESET_ALL)
if args.mu:
    print(Fore.YELLOW+'*'*80)
    logger.debug('Mixup used')
    print('*'*80 + Style.RESET_ALL)
if args.ra > 1:        
    
    print(Fore.YELLOW+'*'*80)
    logger.debug(f'Repeated Aug({args.ra}) used')
    print('*'*80 + Style.RESET_ALL)

'''
    Data Augmentation
'''
augmentations = []

if args.aa == True:
    print(Fore.YELLOW+'*'*80)
    logger.debug('Autoaugmentation used')      
    
    if 'CIFAR' in args.dataset:
        print("CIFAR Policy")
        from utils.autoaug import CIFAR10Policy
        augmentations += [
            
            transforms.RandomCrop(data_info['img_size'], padding=4),
            transforms.RandomHorizontalFlip(),
            CIFAR10Policy()
        ]
        
    elif 'SVHN' in args.dataset:
        print("SVHN Policy")    
        from utils.autoaug import SVHNPolicy
        augmentations += [
            
          transforms.RandomCrop(data_info['img_size'], padding=4),
            transforms.RandomHorizontalFlip(),
            SVHNPolicy()
        ]
                
    else:
        from utils.autoaug import ImageNetPolicy
        augmentations += [                
          transforms.RandomCrop(data_info['img_size'], padding=4),
            transforms.RandomHorizontalFlip(),
            ImageNetPolicy()
        ]
        
    print('*'*80 + Style.RESET_ALL)
    

if args.re > 0:
    print(Fore.YELLOW + '*'*80)
    logger.debug(f'Random erasing({args.re}) used ')
    print('*'*80+Style.RESET_ALL)    
    
    
    augmentations += [                
        transforms.ToTensor(),
        *normalize,
        RandomErasing(probability = args.re, sh = args.re_sh, r1 = args.re_r1, mean=data_info['stat'][0])]

else:
    augmentations += [                
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(data_info['img_size'], padding=4),
        transforms.ToTensor(),
        *normalize]


augmentations = transforms.Compose(augmentations)
  
train_dataset, val_dataset = dataload(args, augmentations, normalize, data_info)

train_loader = torch.utils.data.DataLoader(
    train_dataset,  num_workers=args.workers, pin_memory=True,
    batch_sampler=RASampler(len(train_dataset), args.batch_size, 1, args.ra, shuffle=True, drop_last=True))
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)
'''
    Training
'''

optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
scheduler = build_scheduler(args, optimizer, len(train_loader))

summary(model, (3, data_info['img_size'], data_info['img_size']))

print()
print("Beginning training")
print()

lr = optimizer.param_groups[0]["lr"]

if args.resume:
    checkpoint = torch.load(args.resume)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    final_epoch = args.epochs
    args.epochs = final_epoch - (checkpoint['epoch'] + 1)


for epoch in tqdm(range(args.epochs)):
    lr = train(train_loader, model, criterion, optimizer, epoch, scheduler, args)
    acc1 = validate(val_loader, model, criterion, lr, args, epoch=epoch)
    torch.save({
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(), 
        }, 
        os.path.join(save_path, 'checkpoint.pth'))
    
    logger_dict.print()
    
    if acc1 > best_acc1:
        print('* Best model upate *')
        best_acc1 = acc1
        
        torch.save({
                'model_state_dict': model.state_dict(),
                'epoch': epoch,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }, os.path.join(save_path, 'best.pth'))         
    
    print(f'Best acc1 {best_acc1:.2f}')
    print('*'*80)
    print(Style.RESET_ALL)        
    
    writer.add_scalar("Learning Rate", lr, epoch)
    
    
print(Fore.RED+'*'*80)
logger.debug(f'best top-1: {best_acc1:.2f}, final top-1: {acc1:.2f}')
print('*'*80+Style.RESET_ALL)
torch.save(model.state_dict(), os.path.join(save_path, 'checkpoint.pth'))

CIFAR10
CIFAR10
CIFAR10
CIFAR10
CIFAR10
CIFAR10
CIFAR10
CIFAR10
Creating model: swin-SPT-LSA--CIFAR10-LR[0.001]-Seed0
Creating model: swin-SPT-LSA--CIFAR10-LR[0.001]-Seed0
Creating model: swin-SPT-LSA--CIFAR10-LR[0.001]-Seed0
Creating model: swin-SPT-LSA--CIFAR10-LR[0.001]-Seed0
Creating model: swin-SPT-LSA--CIFAR10-LR[0.001]-Seed0
Creating model: swin-SPT-LSA--CIFAR10-LR[0.001]-Seed0
Creating model: swin-SPT-LSA--CIFAR10-LR[0.001]-Seed0
Creating model: swin-SPT-LSA--CIFAR10-LR[0.001]-Seed0
Number of params: 8,562,166
Number of params: 8,562,166
Number of params: 8,562,166
Number of params: 8,562,166
Number of params: 8,562,166
Number of params: 8,562,166
Number of params: 8,562,166
Number of params: 8,562,166
Initial learning rate: 0.001000
Initial learning rate: 0.001000
Initial learning rate: 0.001000
Initial learning rate: 0.001000
Initial learning rate: 0.001000
Initial learning rate: 0.001000
Initial learning rate: 0.001000
Initial learning rate: 0.001000
Start training for 50 ep

[33m********************************************************************************
********************************************************************************[0m
[32m********************************************************************************
********************************************************************************[0m
[33m********************************************************************************


Stochastic depth(0.1) used 
Stochastic depth(0.1) used 
Stochastic depth(0.1) used 
Stochastic depth(0.1) used 
Stochastic depth(0.1) used 
Repeated Aug(3) used
Repeated Aug(3) used
Repeated Aug(3) used
Repeated Aug(3) used
Repeated Aug(3) used
Repeated Aug(3) used
Repeated Aug(3) used
Repeated Aug(3) used
Random erasing(0.25) used 
Random erasing(0.25) used 
Random erasing(0.25) used 
Random erasing(0.25) used 
Random erasing(0.25) used 
Random erasing(0.25) used 
Random erasing(0.25) used 
Random erasing(0.25) used 


********************************************************************************[0m
[33m********************************************************************************
********************************************************************************[0m
[33m********************************************************************************
********************************************************************************[0m
Files already downloaded and verified
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
     PatchShifting-1           [-1, 15, 32, 32]               0
         Rearrange-2              [-1, 256, 60]               0
         LayerNorm-3              [-1, 256, 60]             120
            Linear-4              [-1, 256, 96]           5,856
ShiftedPatchTokenization-5              [-1, 256, 96]               0
           Dropout-6              [-1, 256, 96]               0
         LayerNo

  0%|          | 0/50 [00:00<?, ?it/s]

[Epoch 1/50][T][1170]   Loss: 1.5751e+00   Top-1:  43.11   LR: 0.0001009          
[Epoch 1][V][78]   Loss: 1.2404e+00   Top-1:  55.42   LR: 0.000101

[34m
********************************************************************************
T Loss	1.5751214263583737
T Top-1	43.11285760034159
V Loss	1.2403974447250365
V Top-1	55.42

* Best model upate *


  2%|▏         | 1/50 [03:15<2:39:51, 195.75s/it]

Best acc1 55.42
********************************************************************************
[0m
[Epoch 2/50][T][1170]   Loss: 1.0347e+00   Top-1:  63.11   LR: 0.0002008          
[Epoch 2][V][78]   Loss: 9.1944e-01   Top-1:  67.05   LR: 0.000201

[34m
********************************************************************************
T Loss	1.0346717726563308
T Top-1	63.113791631084545
V Loss	0.9194403228759765
V Top-1	67.05

* Best model upate *


  4%|▍         | 2/50 [06:31<2:36:42, 195.88s/it]

Best acc1 67.05
********************************************************************************
[0m
[Epoch 3/50][T][1170]   Loss: 7.2815e-01   Top-1:  74.33   LR: 0.0003007          
[Epoch 3][V][78]   Loss: 8.5674e-01   Top-1:  70.91   LR: 0.000301

[34m
********************************************************************************
T Loss	0.7281523625327418
T Top-1	74.32883219470538
V Loss	0.8567358511924744
V Top-1	70.91

* Best model upate *


  6%|▌         | 3/50 [09:47<2:33:22, 195.79s/it]

Best acc1 70.91
********************************************************************************
[0m
[Epoch 4/50][T][1170]   Loss: 5.4348e-01   Top-1:  80.76   LR: 0.0004006          
[Epoch 4][V][78]   Loss: 8.0762e-01   Top-1:  73.84   LR: 0.000401

[34m
********************************************************************************
T Loss	0.5434779106654437
T Top-1	80.76363684884714
V Loss	0.8076164762496948
V Top-1	73.84

* Best model upate *


  8%|▊         | 4/50 [13:03<2:30:05, 195.77s/it]

Best acc1 73.84
********************************************************************************
[0m
[Epoch 5/50][T][1170]   Loss: 4.4301e-01   Top-1:  84.38   LR: 0.0005005          
[Epoch 5][V][78]   Loss: 7.8685e-01   Top-1:  75.10   LR: 0.000501

[34m
********************************************************************************
T Loss	0.4430058037211608
T Top-1	84.38300597779676
V Loss	0.7868496666908265
V Top-1	75.1

* Best model upate *


 10%|█         | 5/50 [16:19<2:26:51, 195.82s/it]

Best acc1 75.10
********************************************************************************
[0m
[Epoch 6/50][T][1170]   Loss: 3.7996e-01   Top-1:  86.61   LR: 0.0006004          
[Epoch 6][V][78]   Loss: 7.8576e-01   Top-1:  75.92   LR: 0.000600

[34m
********************************************************************************
T Loss	0.3799643236841532
T Top-1	86.6106692997438
V Loss	0.7857618550300598
V Top-1	75.92

* Best model upate *


 12%|█▏        | 6/50 [19:35<2:23:38, 195.88s/it]

Best acc1 75.92
********************************************************************************
[0m
[Epoch 7/50][T][1170]   Loss: 3.3991e-01   Top-1:  88.09   LR: 0.0007003          
[Epoch 7][V][78]   Loss: 7.5987e-01   Top-1:  76.47   LR: 0.000700

[34m
********************************************************************************
T Loss	0.3399131581072722
T Top-1	88.0851035439795
V Loss	0.7598741766929626
V Top-1	76.47

* Best model upate *


 14%|█▍        | 7/50 [22:51<2:20:36, 196.20s/it]

Best acc1 76.47
********************************************************************************
[0m
[Epoch 8/50][T][1170]   Loss: 3.0053e-01   Top-1:  89.48   LR: 0.0008002          
[Epoch 8][V][78]   Loss: 8.3968e-01   Top-1:  76.30   LR: 0.000800

[34m
********************************************************************************


 16%|█▌        | 8/50 [26:08<2:17:25, 196.33s/it]

T Loss	0.30052570488389796
T Top-1	89.48147950469684
V Loss	0.8396818264007568
V Top-1	76.3

Best acc1 76.47
********************************************************************************
[0m
[Epoch 9/50][T][1170]   Loss: 2.8188e-01   Top-1:  90.14   LR: 0.0009001          
[Epoch 9][V][78]   Loss: 8.0244e-01   Top-1:  77.42   LR: 0.000900

[34m
********************************************************************************
T Loss	0.28187640494866617
T Top-1	90.14063834329633
V Loss	0.802444375705719
V Top-1	77.42

* Best model upate *


 18%|█▊        | 9/50 [29:25<2:14:23, 196.66s/it]

Best acc1 77.42
********************************************************************************
[0m
[Epoch 10/50][T][1170]   Loss: 2.5950e-01   Top-1:  90.95   LR: 0.0010000          
[Epoch 10][V][78]   Loss: 8.1955e-01   Top-1:  76.83   LR: 0.001000

[34m
********************************************************************************


 20%|██        | 10/50 [32:41<2:10:58, 196.45s/it]

T Loss	0.25949996101561007
T Top-1	90.95391225448334
V Loss	0.8195484823226928
V Top-1	76.83

Best acc1 77.42
********************************************************************************
[0m
[Epoch 11/50][T][1170]   Loss: 2.3225e-01   Top-1:  91.90   LR: 0.0009985          
[Epoch 11][V][78]   Loss: 8.3621e-01   Top-1:  78.21   LR: 0.000998

[34m
********************************************************************************
T Loss	0.2322549119680789
T Top-1	91.9032877882152
V Loss	0.8362138164520264
V Top-1	78.21

* Best model upate *


 22%|██▏       | 11/50 [35:58<2:07:48, 196.62s/it]

Best acc1 78.21
********************************************************************************
[0m
[Epoch 12/50][T][1170]   Loss: 2.0431e-01   Top-1:  92.95   LR: 0.0009939          
[Epoch 12][V][78]   Loss: 8.0516e-01   Top-1:  79.18   LR: 0.000994

[34m
********************************************************************************
T Loss	0.20431010228487084
T Top-1	92.9534052092229
V Loss	0.8051638414382934
V Top-1	79.18

* Best model upate *


 24%|██▍       | 12/50 [39:15<2:04:26, 196.50s/it]

Best acc1 79.18
********************************************************************************
[0m
[Epoch 13/50][T][1170]   Loss: 1.7753e-01   Top-1:  93.82   LR: 0.0009862          
[Epoch 13][V][78]   Loss: 8.3982e-01   Top-1:  79.45   LR: 0.000986

[34m
********************************************************************************
T Loss	0.17752519815231266
T Top-1	93.81871797608882
V Loss	0.8398152425289154
V Top-1	79.45

* Best model upate *


 26%|██▌       | 13/50 [42:31<2:01:06, 196.40s/it]

Best acc1 79.45
********************************************************************************
[0m
[Epoch 14/50][T][1170]   Loss: 1.6219e-01   Top-1:  94.41   LR: 0.0009756          
[Epoch 14][V][78]   Loss: 8.0275e-01   Top-1:  79.71   LR: 0.000976

[34m
********************************************************************************
T Loss	0.16219010728353317
T Top-1	94.41449615713066
V Loss	0.8027450355529785
V Top-1	79.71

* Best model upate *


 28%|██▊       | 14/50 [45:46<1:57:37, 196.03s/it]

Best acc1 79.71
********************************************************************************
[0m
[Epoch 15/50][T][1170]   Loss: 1.4751e-01   Top-1:  94.89   LR: 0.0009620          
[Epoch 15][V][78]   Loss: 8.3898e-01   Top-1:  79.32   LR: 0.000962

[34m
********************************************************************************


 30%|███       | 15/50 [49:02<1:54:20, 196.00s/it]

T Loss	0.14751232905197256
T Top-1	94.89018467122118
V Loss	0.8389796092987061
V Top-1	79.32

Best acc1 79.71
********************************************************************************
[0m
[Epoch 16/50][T][1170]   Loss: 1.3645e-01   Top-1:  95.29   LR: 0.0009456          
[Epoch 16][V][78]   Loss: 8.4559e-01   Top-1:  80.43   LR: 0.000946

[34m
********************************************************************************
T Loss	0.13644599856651307
T Top-1	95.28648057216054
V Loss	0.8455919489860535
V Top-1	80.43

* Best model upate *


 32%|███▏      | 16/50 [52:18<1:51:04, 196.02s/it]

Best acc1 80.43
********************************************************************************
[0m
[Epoch 17/50][T][1170]   Loss: 1.2511e-01   Top-1:  95.64   LR: 0.0009264          
[Epoch 17][V][78]   Loss: 8.7760e-01   Top-1:  80.12   LR: 0.000926

[34m
********************************************************************************


 34%|███▍      | 17/50 [55:33<1:47:42, 195.83s/it]

T Loss	0.12511115925375413
T Top-1	95.64274658411614
V Loss	0.8775976937055587
V Top-1	80.12

Best acc1 80.43
********************************************************************************
[0m
[Epoch 18/50][T][1170]   Loss: 1.1924e-01   Top-1:  95.92   LR: 0.0009046          
[Epoch 18][V][78]   Loss: 8.2483e-01   Top-1:  81.00   LR: 0.000905

[34m
********************************************************************************
T Loss	0.11923522895709457
T Top-1	95.92028714773697
V Loss	0.8248301813602448
V Top-1	81.0

* Best model upate *


 36%|███▌      | 18/50 [58:49<1:44:22, 195.70s/it]

Best acc1 81.00
********************************************************************************
[0m
[Epoch 19/50][T][1170]   Loss: 1.0725e-01   Top-1:  96.32   LR: 0.0008803          
[Epoch 19][V][78]   Loss: 8.3510e-01   Top-1:  81.32   LR: 0.000880

[34m
********************************************************************************
T Loss	0.10725101610179592
T Top-1	96.31658304867635
V Loss	0.8351039840698242
V Top-1	81.32

* Best model upate *


 38%|███▊      | 19/50 [1:02:05<1:41:11, 195.85s/it]

Best acc1 81.32
********************************************************************************
[0m
[Epoch 20/50][T][1170]   Loss: 1.0257e-01   Top-1:  96.46   LR: 0.0008537          
[Epoch 20][V][78]   Loss: 8.7195e-01   Top-1:  79.76   LR: 0.000854

[34m
********************************************************************************


 40%|████      | 20/50 [1:05:21<1:37:56, 195.88s/it]

T Loss	0.1025734158045112
T Top-1	96.45935631938514
V Loss	0.8719472801208497
V Top-1	79.76

Best acc1 81.32
********************************************************************************
[0m
[Epoch 21/50][T][1170]   Loss: 9.6317e-02   Top-1:  96.70   LR: 0.0008249          
[Epoch 21][V][78]   Loss: 8.1789e-01   Top-1:  81.88   LR: 0.000825

[34m
********************************************************************************
T Loss	0.09631675768589536
T Top-1	96.70020281810419
V Loss	0.8178937817573547
V Top-1	81.88

* Best model upate *


 42%|████▏     | 21/50 [1:08:38<1:34:47, 196.12s/it]

Best acc1 81.88
********************************************************************************
[0m
[Epoch 22/50][T][1170]   Loss: 8.7536e-02   Top-1:  97.00   LR: 0.0007941          
[Epoch 22][V][78]   Loss: 8.3012e-01   Top-1:  82.10   LR: 0.000794

[34m
********************************************************************************
T Loss	0.08753557986222736
T Top-1	96.9984254910333
V Loss	0.8301187068939209
V Top-1	82.1

* Best model upate *


 44%|████▍     | 22/50 [1:11:54<1:31:33, 196.19s/it]

Best acc1 82.10
********************************************************************************
[0m
[Epoch 23/50][T][1170]   Loss: 8.3180e-02   Top-1:  97.15   LR: 0.0007615          
[Epoch 23][V][78]   Loss: 8.5118e-01   Top-1:  81.18   LR: 0.000761

[34m
********************************************************************************


 46%|████▌     | 23/50 [1:15:10<1:28:17, 196.20s/it]

T Loss	0.08317957479064532
T Top-1	97.15187339880444
V Loss	0.8511788854599
V Top-1	81.18

Best acc1 82.10
********************************************************************************
[0m
[Epoch 24/50][T][1170]   Loss: 7.8228e-02   Top-1:  97.35   LR: 0.0007273          
[Epoch 24][V][78]   Loss: 8.6806e-01   Top-1:  81.56   LR: 0.000727

[34m
********************************************************************************


 48%|████▊     | 24/50 [1:18:27<1:25:03, 196.30s/it]

T Loss	0.07822835396889384
T Top-1	97.34535119555935
V Loss	0.868059076499939
V Top-1	81.56

Best acc1 82.10
********************************************************************************
[0m
[Epoch 25/50][T][1170]   Loss: 7.2416e-02   Top-1:  97.53   LR: 0.0006917          
[Epoch 25][V][78]   Loss: 8.8393e-01   Top-1:  82.00   LR: 0.000692

[34m
********************************************************************************


 50%|█████     | 25/50 [1:21:43<1:21:47, 196.29s/it]

T Loss	0.07241597971031928
T Top-1	97.53415883859948
V Loss	0.8839295742988587
V Top-1	82.0

Best acc1 82.10
********************************************************************************
[0m
[Epoch 26/50][T][1170]   Loss: 6.8475e-02   Top-1:  97.66   LR: 0.0006549          
[Epoch 26][V][78]   Loss: 8.9025e-01   Top-1:  82.15   LR: 0.000655

[34m
********************************************************************************
T Loss	0.06847528055275255
T Top-1	97.66492314261315
V Loss	0.8902476426124573
V Top-1	82.15

* Best model upate *


 52%|█████▏    | 26/50 [1:24:59<1:18:31, 196.33s/it]

Best acc1 82.15
********************************************************************************
[0m
[Epoch 27/50][T][1170]   Loss: 6.3984e-02   Top-1:  97.81   LR: 0.0006171          
[Epoch 27][V][78]   Loss: 8.8684e-01   Top-1:  81.11   LR: 0.000617

[34m
********************************************************************************


 54%|█████▍    | 27/50 [1:28:16<1:15:14, 196.29s/it]

T Loss	0.06398353419114734
T Top-1	97.81436806148591
V Loss	0.886838108253479
V Top-1	81.11

Best acc1 82.15
********************************************************************************
[0m
[Epoch 28/50][T][1170]   Loss: 6.0155e-02   Top-1:  97.96   LR: 0.0005786          
[Epoch 28][V][78]   Loss: 8.8872e-01   Top-1:  82.21   LR: 0.000579

[34m
********************************************************************************
T Loss	0.06015491392542978
T Top-1	97.96181148590948
V Loss	0.8887209238529206
V Top-1	82.21

* Best model upate *


 56%|█████▌    | 28/50 [1:31:32<1:12:02, 196.47s/it]

Best acc1 82.21
********************************************************************************
[0m
[Epoch 29/50][T][1170]   Loss: 5.5933e-02   Top-1:  98.11   LR: 0.0005397          
[Epoch 29][V][78]   Loss: 8.6492e-01   Top-1:  82.47   LR: 0.000540

[34m
********************************************************************************
T Loss	0.055932854198437916
T Top-1	98.11325789923143
V Loss	0.864921402835846
V Top-1	82.47

* Best model upate *


 58%|█████▊    | 29/50 [1:34:50<1:08:53, 196.83s/it]

Best acc1 82.47
********************************************************************************
[0m
[Epoch 30/50][T][1170]   Loss: 4.9721e-02   Top-1:  98.32   LR: 0.0005005          
[Epoch 30][V][78]   Loss: 7.9878e-01   Top-1:  82.88   LR: 0.000501

[34m
********************************************************************************
T Loss	0.049721016787568365
T Top-1	98.31807749786508
V Loss	0.7987843943595886
V Top-1	82.88

* Best model upate *


 60%|██████    | 30/50 [1:38:08<1:05:40, 197.02s/it]

Best acc1 82.88
********************************************************************************
[0m
[Epoch 31/50][T][1170]   Loss: 4.7319e-02   Top-1:  98.36   LR: 0.0004613          
[Epoch 31][V][78]   Loss: 8.6971e-01   Top-1:  82.32   LR: 0.000461

[34m
********************************************************************************


 62%|██████▏   | 31/50 [1:41:25<1:02:23, 197.05s/it]

T Loss	0.04731949617454733
T Top-1	98.36344470538002
V Loss	0.869708578491211
V Top-1	82.32

Best acc1 82.88
********************************************************************************
[0m
[Epoch 32/50][T][1170]   Loss: 4.5249e-02   Top-1:  98.46   LR: 0.0004224          
[Epoch 32][V][78]   Loss: 8.9106e-01   Top-1:  82.86   LR: 0.000422

[34m
********************************************************************************


 64%|██████▍   | 32/50 [1:44:42<59:06, 197.01s/it]  

T Loss	0.04524936696596406
T Top-1	98.45684777967548
V Loss	0.8910640858650207
V Top-1	82.86

Best acc1 82.88
********************************************************************************
[0m
[Epoch 33/50][T][1170]   Loss: 3.9929e-02   Top-1:  98.67   LR: 0.0003839          
[Epoch 33][V][78]   Loss: 9.5703e-01   Top-1:  82.17   LR: 0.000384

[34m
********************************************************************************


 66%|██████▌   | 33/50 [1:47:58<55:48, 196.94s/it]

T Loss	0.03992925162183957
T Top-1	98.6683390264731
V Loss	0.9570272714614868
V Top-1	82.17

Best acc1 82.88
********************************************************************************
[0m
[Epoch 34/50][T][1170]   Loss: 3.6033e-02   Top-1:  98.81   LR: 0.0003461          
[Epoch 34][V][78]   Loss: 9.4037e-01   Top-1:  82.30   LR: 0.000346

[34m
********************************************************************************


 68%|██████▊   | 34/50 [1:51:16<52:32, 197.04s/it]

T Loss	0.03603274080916057
T Top-1	98.80577497865073
V Loss	0.9403681758880615
V Top-1	82.3

Best acc1 82.88
********************************************************************************
[0m
[Epoch 35/50][T][1170]   Loss: 3.2748e-02   Top-1:  98.90   LR: 0.0003093          
[Epoch 35][V][78]   Loss: 8.7819e-01   Top-1:  83.37   LR: 0.000309

[34m
********************************************************************************
T Loss	0.03274776446530643
T Top-1	98.8985108881298
V Loss	0.8781851999759674
V Top-1	83.37

* Best model upate *


 70%|███████   | 35/50 [1:54:34<49:20, 197.35s/it]

Best acc1 83.37
********************************************************************************
[0m
[Epoch 36/50][T][1170]   Loss: 3.1566e-02   Top-1:  98.90   LR: 0.0002737          
[Epoch 36][V][78]   Loss: 8.8928e-01   Top-1:  83.08   LR: 0.000274

[34m
********************************************************************************


 72%|███████▏  | 36/50 [1:57:51<46:04, 197.44s/it]

T Loss	0.03156575571644628
T Top-1	98.8991780529462
V Loss	0.8892827708244324
V Top-1	83.08

Best acc1 83.37
********************************************************************************
[0m
[Epoch 37/50][T][1170]   Loss: 2.8207e-02   Top-1:  99.04   LR: 0.0002395          
[Epoch 37][V][78]   Loss: 9.1197e-01   Top-1:  83.49   LR: 0.000240

[34m
********************************************************************************
T Loss	0.028207453379516287
T Top-1	99.03794833475662
V Loss	0.9119747738838195
V Top-1	83.49

* Best model upate *


 74%|███████▍  | 37/50 [2:01:09<42:48, 197.59s/it]

Best acc1 83.49
********************************************************************************
[0m
[Epoch 38/50][T][1170]   Loss: 2.4040e-02   Top-1:  99.16   LR: 0.0002069          
[Epoch 38][V][78]   Loss: 8.8895e-01   Top-1:  83.80   LR: 0.000207

[34m
********************************************************************************
T Loss	0.024039503738720942
T Top-1	99.16337532023911
V Loss	0.8889488147735596
V Top-1	83.8

* Best model upate *


 76%|███████▌  | 38/50 [2:04:27<39:31, 197.63s/it]

Best acc1 83.80
********************************************************************************
[0m
[Epoch 39/50][T][1170]   Loss: 2.3334e-02   Top-1:  99.19   LR: 0.0001761          
[Epoch 39][V][78]   Loss: 9.4361e-01   Top-1:  83.55   LR: 0.000176

[34m
********************************************************************************


 78%|███████▊  | 39/50 [2:07:44<36:12, 197.52s/it]

T Loss	0.023334075774981367
T Top-1	99.19339773697695
V Loss	0.9436096293449402
V Top-1	83.55

Best acc1 83.80
********************************************************************************
[0m
[Epoch 40/50][T][1170]   Loss: 2.0256e-02   Top-1:  99.33   LR: 0.0001473          
[Epoch 40][V][78]   Loss: 9.1123e-01   Top-1:  84.03   LR: 0.000147

[34m
********************************************************************************
T Loss	0.02025614267506519
T Top-1	99.32616353543979
V Loss	0.9112305102348328
V Top-1	84.03

* Best model upate *


 80%|████████  | 40/50 [2:11:03<32:57, 197.75s/it]

Best acc1 84.03
********************************************************************************
[0m
[Epoch 41/50][T][1170]   Loss: 1.7446e-02   Top-1:  99.41   LR: 0.0001207          
[Epoch 41][V][78]   Loss: 9.0484e-01   Top-1:  84.05   LR: 0.000121

[34m
********************************************************************************
T Loss	0.01744635159411079
T Top-1	99.41022630230572
V Loss	0.904835481262207
V Top-1	84.05

* Best model upate *


 82%|████████▏ | 41/50 [2:14:21<29:41, 197.89s/it]

Best acc1 84.05
********************************************************************************
[0m
[Epoch 42/50][T][1170]   Loss: 1.5982e-02   Top-1:  99.46   LR: 0.0000964          
[Epoch 42][V][78]   Loss: 8.7572e-01   Top-1:  84.44   LR: 0.000096

[34m
********************************************************************************
T Loss	0.01598197420006254
T Top-1	99.45959649871904
V Loss	0.8757159769058227
V Top-1	84.44

* Best model upate *


 84%|████████▍ | 42/50 [2:17:39<26:24, 198.10s/it]

Best acc1 84.44
********************************************************************************
[0m
[Epoch 43/50][T][1170]   Loss: 1.3670e-02   Top-1:  99.52   LR: 0.0000746          
[Epoch 43][V][78]   Loss: 8.8086e-01   Top-1:  84.90   LR: 0.000075

[34m
********************************************************************************
T Loss	0.013670090946265381
T Top-1	99.51964133219471
V Loss	0.880859190940857
V Top-1	84.9

* Best model upate *


 86%|████████▌ | 43/50 [2:20:58<23:07, 198.28s/it]

Best acc1 84.90
********************************************************************************
[0m
[Epoch 44/50][T][1170]   Loss: 1.2712e-02   Top-1:  99.58   LR: 0.0000554          
[Epoch 44][V][78]   Loss: 8.9581e-01   Top-1:  84.64   LR: 0.000055

[34m
********************************************************************************


 88%|████████▊ | 44/50 [2:24:17<19:50, 198.48s/it]

T Loss	0.012711574440279023
T Top-1	99.5750160119556
V Loss	0.895806789970398
V Top-1	84.64

Best acc1 84.90
********************************************************************************
[0m
[Epoch 45/50][T][1170]   Loss: 1.1603e-02   Top-1:  99.61   LR: 0.0000390          
[Epoch 45][V][78]   Loss: 9.1074e-01   Top-1:  84.27   LR: 0.000039

[34m
********************************************************************************


 90%|█████████ | 45/50 [2:27:36<16:33, 198.63s/it]

T Loss	0.01160347413959346
T Top-1	99.61304440649018
V Loss	0.9107387508392334
V Top-1	84.27

Best acc1 84.90
********************************************************************************
[0m
[Epoch 46/50][T][1170]   Loss: 9.8585e-03   Top-1:  99.67   LR: 0.0000254          
[Epoch 46][V][78]   Loss: 9.2343e-01   Top-1:  84.45   LR: 0.000025

[34m
********************************************************************************


 92%|█████████▏| 46/50 [2:30:54<13:14, 198.56s/it]

T Loss	0.009858457048498994
T Top-1	99.66775192143467
V Loss	0.9234324228286743
V Top-1	84.45

Best acc1 84.90
********************************************************************************
[0m
[Epoch 47/50][T][1170]   Loss: 9.5008e-03   Top-1:  99.67   LR: 0.0000148          
[Epoch 47][V][78]   Loss: 9.0768e-01   Top-1:  84.90   LR: 0.000015

[34m
********************************************************************************


 94%|█████████▍| 47/50 [2:34:13<09:55, 198.43s/it]

T Loss	0.009500830177395352
T Top-1	99.67042058070025
V Loss	0.9076848116874695
V Top-1	84.9

Best acc1 84.90
********************************************************************************
[0m
[Epoch 48/50][T][1170]   Loss: 9.3981e-03   Top-1:  99.68   LR: 0.0000071          
[Epoch 48][V][78]   Loss: 9.0215e-01   Top-1:  84.81   LR: 0.000007

[34m
********************************************************************************


 96%|█████████▌| 48/50 [2:37:31<06:37, 198.53s/it]

T Loss	0.00939814007344496
T Top-1	99.67842655849701
V Loss	0.9021532170295715
V Top-1	84.81

Best acc1 84.90
********************************************************************************
[0m
[Epoch 49/50][T][1170]   Loss: 9.0563e-03   Top-1:  99.69   LR: 0.0000025          
[Epoch 49][V][78]   Loss: 9.0026e-01   Top-1:  84.88   LR: 0.000003

[34m
********************************************************************************


 98%|█████████▊| 49/50 [2:40:50<03:18, 198.71s/it]

T Loss	0.009056332364848842
T Top-1	99.69310418445772
V Loss	0.900261706829071
V Top-1	84.88

Best acc1 84.90
********************************************************************************
[0m
[Epoch 50/50][T][1170]   Loss: 8.4135e-03   Top-1:  99.71   LR: 0.0000010          
[Epoch 50][V][78]   Loss: 8.9727e-01   Top-1:  84.88   LR: 0.000001

[34m
********************************************************************************


100%|██████████| 50/50 [2:44:10<00:00, 197.00s/it]
best top-1: 84.90, final top-1: 84.88
best top-1: 84.90, final top-1: 84.88
best top-1: 84.90, final top-1: 84.88
best top-1: 84.90, final top-1: 84.88
best top-1: 84.90, final top-1: 84.88
best top-1: 84.90, final top-1: 84.88
best top-1: 84.90, final top-1: 84.88
best top-1: 84.90, final top-1: 84.88


T Loss	0.008413474368789985
T Top-1	99.71045046968403
V Loss	0.897269169330597
V Top-1	84.88

Best acc1 84.90
********************************************************************************
[0m
[31m********************************************************************************
********************************************************************************[0m


## Download model

In [41]:
from google.colab import files
files.download(os.path.join(save_path, 'checkpoint.pth')) 
files.download(os.path.join(save_path, 'best.pth')) 
files.download(os.path.join(save_path, 'log.csv')) 

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>