# Description

This notebook contains visualizations of Mixup, Cutmix, Augmix and GridMask. Feel free to change the parameters and play around to see what works best for you. 

In [None]:
!pip install ../input/pretrainedmodels/pretrainedmodels-0.7.4/pretrainedmodels-0.7.4/ > /dev/null

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import fastai
from fastai.vision import *
from fastai.callbacks import SaveModelCallback
from torch.nn.modules.normalization import GroupNorm
import os
from torch.nn.modules import Conv2d
from sklearn.model_selection import KFold
from over9000 import *
from csvlogger import *
import pretrainedmodels
from mish_activation import *
import warnings
from fastai.vision import Image as Img
warnings.filterwarnings("ignore")

fastai.__version__

In [None]:
sz = 128    # image side length, passed to `.transform(size=sz)`
bs = 16     # batch size
nfolds = 4  # keep the same split as the initial dataset
fold = 0    # index of the contiguous slice of `df` used as the validation set
SEED = 2019
TRAIN = '../input/grapheme-imgs-128x128/'  # image folder (presumably pre-resized 128x128 PNGs, per the name)
LABELS = '../input/bengaliai-cv19/train.csv'
arch = pretrainedmodels.__dict__['se_resnext50_32x4d']  # backbone constructor; default `arch` of Dnet_1ch
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects interpreters started after this point: hash randomisation of the running
    # process is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN autotune and pick algorithms per run, which undermines
    # `deterministic = True`; it must be off for reproducible results.
    torch.backends.cudnn.benchmark = False

seed_everything(SEED)  # seed all RNGs before the data and model below are constructed

# Data

In [None]:
df = pd.read_csv(LABELS)
# Class count per target, selected by column name rather than by position in the CSV.
# The order must match the label columns passed to `label_from_df` and the three model heads.
nunique = df[['grapheme_root', 'vowel_diacritic', 'consonant_diacritic']].nunique().tolist()
print(nunique)
df.head()

In [None]:
# (mean, std) handed to `normalize`; presumably measured on the single-channel training images.
stats = ([0.0692], [0.2051])
# Validation = the `fold`-th contiguous quarter of `df`; labels are the three target columns.
data = (ImageList
        .from_df(df, path='.', folder=TRAIN, suffix='.png', cols='image_id', convert_mode='L')
        .split_by_idx(range(fold*len(df)//nfolds, (fold+1)*len(df)//nfolds))
        .label_from_df(cols=['grapheme_root', 'vowel_diacritic', 'consonant_diacritic'])
        .transform(size=sz, padding_mode='zeros')
        .databunch(bs=bs))
data = data.normalize(stats)

data.show_batch()

In [None]:
class Head(nn.Module):
    """Classification head: concat-pool + Mish + flatten, then `bn_drop_lin` nc*2 -> 512 (Mish) -> n.

    Args:
        nc: number of channels in the incoming feature map.
        n: number of output classes.
        ps: dropout probability handed to both `bn_drop_lin` stages.
    """
    def __init__(self, nc, n, ps=0.5):
        super().__init__()
        pooling = [AdaptiveConcatPool2d(), Mish(), Flatten()]
        hidden = bn_drop_lin(nc*2, 512, True, ps, Mish())
        output = bn_drop_lin(512, n, True, ps)
        self.fc = nn.Sequential(*(pooling + hidden + output))
        self._init_weight()

    def _init_weight(self):
        """Kaiming-init Conv2d weights and reset BatchNorm2d to identity.

        NOTE(review): the layers built in `__init__` are pooling/flatten plus whatever `bn_drop_lin`
        returns (presumably BatchNorm1d/Dropout/Linear), so this loop likely matches nothing — confirm.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
                continue
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1.0)
                module.bias.data.zero_()

    def forward(self, x):
        out = self.fc(x)
        return out

# Change the first conv so that the network accepts 1-channel (grayscale) input.
class Dnet_1ch(nn.Module):
    """Single-channel backbone with three classification heads.

    Builds `arch` (a pretrainedmodels SENet-family constructor; module default `se_resnext50_32x4d`),
    replaces its 3-channel stem conv with a 1-channel conv whose kernel is the sum of the RGB kernels,
    reuses `layer0`..`layer4`, and attaches one `Head` per target.

    Args:
        arch: model constructor accepting `pretrained=` (defaults to the module-level `arch`).
        n: three class counts, one per head (defaults to the module-level `nunique`).
        pre: load ImageNet weights when True, random init otherwise.
        ps: dropout probability forwarded to every `Head`.
    """
    def __init__(self, arch=arch, n=nunique, pre=True, ps=0.5):
        super().__init__()
        m = arch(pretrained='imagenet') if pre else arch(pretrained=None)
        stem = m.layer0.conv1
        # Same geometry as the original stem. Summing the RGB kernels makes a 1-channel input x
        # respond exactly as x replicated on all three channels would.
        conv = nn.Conv2d(1, stem.out_channels, kernel_size=stem.kernel_size, stride=stem.stride,
                         padding=stem.padding, bias=False)
        conv.weight = nn.Parameter(stem.weight.sum(1).unsqueeze(1))

        self.layer0 = nn.Sequential(conv, m.layer0.bn1, m.layer0.relu1, m.layer0.pool)
        self.layer1 = m.layer1
        self.layer2 = m.layer2
        self.layer3 = m.layer3
        # Keep every block of layer4 (state-dict keys stay `layer4.0`, `layer4.1`, ...).
        self.layer4 = nn.Sequential(*m.layer4)

        nc = self.layer4[-1].se_module.fc2.out_channels  # changes as per architecture
        self.head1 = Head(nc, n[0], ps=ps)
        self.head2 = Head(nc, n[1], ps=ps)
        self.head3 = Head(nc, n[2], ps=ps)

    def forward(self, x):
        """Return the three head logits `(x1, x2, x3)` for input batch `x`."""
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x1 = self.head1(x)
        x2 = self.head2(x)
        x3 = self.head3(x)

        return x1,x2,x3

# Loss

Cross-entropy loss is applied independently to each of the three predicted components. The results are summed with weights 0.7 (grapheme root), 0.1 (vowel diacritic) and 0.2 (consonant diacritic).

In [None]:
class Loss_combine(nn.Module):
    """Weighted sum of per-head cross-entropy losses.

    Args:
        weights: weights for target columns 0, 1, 2 (grapheme_root, vowel_diacritic,
            consonant_diacritic); defaults to (0.7, 0.1, 0.2).
    """
    def __init__(self, weights=(0.7, 0.1, 0.2)):
        super().__init__()
        self.weights = tuple(weights)

    def forward(self, input, target,reduction='mean'):
        """Combine the three cross-entropy losses.

        Args:
            input: tuple of three logit tensors `(x1, x2, x3)`.
            target: tensor whose columns 0, 1, 2 hold the class indices for x1, x2, x3.
            reduction: forwarded to `F.cross_entropy` ('mean', 'sum' or 'none').
        """
        x1,x2,x3 = input
        x1,x2,x3 = x1.float(),x2.float(),x3.float()
        y = target.long()
        w1, w2, w3 = self.weights
        return w1*F.cross_entropy(x1,y[:,0],reduction=reduction) + w2*F.cross_entropy(x2,y[:,1],reduction=reduction) + \
          w3*F.cross_entropy(x3,y[:,2],reduction=reduction)

The code below computes the competition metric and the macro-averaged recall for each individual component of the prediction. It is partially borrowed from fast.ai.

In [None]:
class Metric_idx(Callback):
    """Epoch-level recall for one head of a multi-head model.

    Accumulates a confusion matrix (rows = target class, columns = predicted class) for model
    output `idx` against target column `idx`. Reports recall averaged according to `average`:
    'macro' (default), 'micro', 'weighted', 'binary', or None for per-class recall.

    Args:
        idx: index of the model output and of the target column to score.
        average: averaging mode, see above.
        pos_label: positive class (0 or 1), only used when `average == 'binary'`.
    """
    def __init__(self, idx, average='macro', pos_label=1):
        super().__init__()
        self.idx = idx
        self.n_classes = 0
        self.average = average
        self.pos_label = pos_label  # read by `_weights` in the binary case
        self.cm = None
        self.eps = 1e-9

    def on_epoch_begin(self, **kwargs):
        self.tp = 0
        self.fp = 0
        self.cm = None

    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        last_output = last_output[self.idx]
        last_target = last_target[:,self.idx]
        preds = last_output.argmax(-1).view(-1).cpu()
        targs = last_target.long().cpu()

        if self.n_classes == 0:
            self.n_classes = last_output.shape[-1]
            self.x = torch.arange(0, self.n_classes)
        # cm[i, j] = number of samples with target i and prediction j
        cm = ((preds==self.x[:, None]) & (targs==self.x[:, None, None])) \
          .sum(dim=2, dtype=torch.float32)
        if self.cm is None: self.cm =  cm
        else:               self.cm += cm

    def _weights(self, avg:str):
        """Per-class weights used to average the recall vector for mode `avg`."""
        if self.n_classes != 2 and avg == "binary":
            avg = self.average = "macro"
            warn("average=`binary` was selected for a non binary case. "
                 "Value for average has now been set to `macro` instead.")
        if avg == "binary":
            if self.pos_label not in (0, 1):
                self.pos_label = 1
                warn("Invalid value for pos_label. It has now been set to 1.")
            if self.pos_label == 1: return Tensor([0,1])
            else: return Tensor([1,0])
        elif avg == "micro": return self.cm.sum(dim=0) / self.cm.sum()
        elif avg == "macro": return torch.ones((self.n_classes,)) / self.n_classes
        elif avg == "weighted": return self.cm.sum(dim=1) / self.cm.sum()
        raise ValueError(f"Unsupported average: {avg!r}")

    def _recall(self):
        """Recall per class (average=None) or averaged with `_weights`."""
        rec = torch.diag(self.cm) / (self.cm.sum(dim=1) + self.eps)
        if self.average is None: return rec
        else:
            # Micro-averaged recall equals recall weighted by class support.
            if self.average == "micro": weights = self._weights(avg="weighted")
            else: weights = self._weights(avg=self.average)
            return (rec * weights).sum()

    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, self._recall())

Metric_grapheme = partial(Metric_idx,0)
Metric_vowel = partial(Metric_idx,1)
Metric_consonant = partial(Metric_idx,2)

class Metric_tot(Callback):
    "Competition score: 0.5 * grapheme + 0.25 * vowel + 0.25 * consonant macro recall."
    def __init__(self):
        super().__init__()
        self.grapheme = Metric_idx(0)
        self.vowel = Metric_idx(1)
        self.consonant = Metric_idx(2)

    def _components(self):
        "The three per-head recall trackers in (grapheme, vowel, consonant) order."
        return (self.grapheme, self.vowel, self.consonant)

    def on_epoch_begin(self, **kwargs):
        for tracker in self._components():
            tracker.on_epoch_begin(**kwargs)

    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        for tracker in self._components():
            tracker.on_batch_end(last_output, last_target, **kwargs)

    def on_epoch_end(self, last_metrics, **kwargs):
        grapheme, vowel, consonant = (tracker._recall() for tracker in self._components())
        return add_metrics(last_metrics, 0.5*grapheme + 0.25*vowel + 0.25*consonant)

In [None]:
# Customised copy of fastai's SaveModelCallback.
# NOTE(review): the stated goal was to write only the weights (files ~4x smaller). `learn.save` is
# called with fastai's default arguments, which (in fastai v1) also store the optimizer state.
# Pass `with_opt=False` if weight-only checkpoints are wanted.

class SaveModelCallback(TrackerCallback):
    """A `TrackerCallback` that saves the model when the monitored quantity improves.

    With `every='epoch'` it saves `{name}_{epoch}` at `save_epoch` only (every epoch when
    `save_epoch is None`). With `every='improvement'` it saves `{name}` on each improvement and
    reloads that checkpoint when training ends.
    """
    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto',
                 every:str='improvement', name:str='bestmodel', save_epoch=31):
        super().__init__(learn, monitor=monitor, mode=mode)
        self.every,self.name,self.save_epoch = every,name,save_epoch
        if self.every not in ['improvement', 'epoch']:
            warn(f'SaveModel every {self.every} is invalid, falling back to "improvement".')
            self.every = 'improvement'

    def jump_to_epoch(self, epoch:int)->None:
        "Resume from the checkpoint written at the end of `epoch-1`, if it exists."
        try:
            self.learn.load(f'{self.name}_{epoch-1}', purge=False)
            print(f"Loaded {self.name}_{epoch-1}")
        except Exception: print(f'Model {self.name}_{epoch-1} not found.')

    def on_epoch_end(self, epoch:int, **kwargs:Any)->None:
        "Compare the value monitored to its best score and maybe save the model."
        if self.every=="epoch":
            if self.save_epoch is None or epoch==self.save_epoch:
                self.learn.save(f'{self.name}_{epoch}')
        else: #every="improvement"
            current = self.get_monitor_value()
            if current is not None and self.operator(current, self.best):
                self.best = current
                self.learn.save(f'{self.name}')

    def on_train_end(self, **kwargs):
        "Load the best model."
        # `learn.save(name)` writes under `learn.path/learn.model_dir` (fastai v1), not the CWD.
        # `learn.load` also understands its checkpoint format, unlike a raw `load_state_dict`.
        best = self.learn.path/f'{self.learn.model_dir}/{self.name}.pth'
        if self.every=="improvement" and best.is_file():
            self.learn.load(f'{self.name}', purge=False)

In [None]:
class EpochCallback(TrackerCallback):
    """A `TrackerCallback` that checkpoints the model and optimizer, then stops training, at `stop_epoch`.

    Args:
        learn: the learner this callback is attached to.
        stop_epoch: epoch index after which training stops (default 31).
    """
    def __init__(self, learn:Learner, stop_epoch:int=31):
        super().__init__(learn)
        self.stop_epoch = stop_epoch

    def on_epoch_end(self, epoch:int, **kwargs:Any)->None:
        "Save `latest_optimizer_{epoch}.pth` / `latest_model_{epoch}.pth` and stop at `stop_epoch`."
        if epoch==self.stop_epoch:
            # Use the learner this callback is bound to, not a notebook-global `learn`.
            torch.save(self.learn.opt.state_dict(),f'latest_optimizer_{epoch}.pth')
            torch.save(self.learn.model.state_dict(),f'latest_model_{epoch}.pth')
            return {"stop_training":True}

# LookAhead

In [None]:
from collections import defaultdict

import torch
from torch.optim.optimizer import Optimizer


class Lookahead(Optimizer):
    r"""PyTorch implementation of the lookahead wrapper.
    Lookahead Optimizer: https://arxiv.org/abs/1907.08610

    Wraps an inner ("fast") optimizer. After every `la_steps` inner steps, each parameter is
    pulled towards its cached ("slow") copy: p <- la_alpha * p + (1 - la_alpha) * slow. The
    result becomes the new slow copy.
    """

    def __init__(self, optimizer, la_steps=5, la_alpha=0.8, pullback_momentum="none"):
        """optimizer: inner optimizer
        la_steps (int): number of lookahead steps
        la_alpha (float): linear interpolation factor. 1.0 recovers the inner optimizer.
        pullback_momentum (str): change to inner optimizer momentum on interpolation update;
            one of "none", "reset", "pullback" (case-insensitive). "pullback" reads and "reset"
            overwrites the inner optimizer's `state[p]["momentum_buffer"]` (torch SGD with momentum
            keeps such a buffer; optimizers without that key will fail with "pullback").
        """
        self.optimizer = optimizer
        self._la_step = 0  # counter for inner optimizer
        self.la_alpha = la_alpha
        self._total_la_steps = la_steps
        pullback_momentum = pullback_momentum.lower()
        assert pullback_momentum in ["reset", "pullback", "none"]
        self.pullback_momentum = pullback_momentum

        self.state = defaultdict(dict)

        # Cache the current optimizer parameters
        for group in optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['cached_params'] = torch.zeros_like(p.data)
                param_state['cached_params'].copy_(p.data)
                if self.pullback_momentum == "pullback":
                    param_state['cached_mom'] = torch.zeros_like(p.data)

    def __getstate__(self):
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'la_alpha': self.la_alpha,
            '_la_step': self._la_step,
            '_total_la_steps': self._total_la_steps,
            'pullback_momentum': self.pullback_momentum
        }

    def zero_grad(self):
        self.optimizer.zero_grad()

    def get_la_step(self):
        return self._la_step

    def state_dict(self):
        # Only the inner optimizer's state; the slow weights in `self.state` are not included.
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)

    def _backup_and_load_cache(self):
        """Useful for performing evaluation on the slow weights (which typically generalize better)
        """
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['backup_params'] = torch.zeros_like(p.data)
                param_state['backup_params'].copy_(p.data)
                p.data.copy_(param_state['cached_params'])

    def _clear_and_load_backup(self):
        """Restore the fast weights saved by `_backup_and_load_cache`."""
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                p.data.copy_(param_state['backup_params'])
                del param_state['backup_params']

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def step(self, closure=None):
        """Performs a single Lookahead optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = self.optimizer.step(closure)
        self._la_step += 1

        if self._la_step >= self._total_la_steps:
            self._la_step = 0
            # Lookahead and cache the current optimizer parameters
            for group in self.optimizer.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    # slow-weight update: p = alpha * p + (1 - alpha) * cached
                    p.data.mul_(self.la_alpha).add_(param_state['cached_params'] * (1.0 - self.la_alpha))
                    param_state['cached_params'].copy_(p.data)
                    if self.pullback_momentum == "pullback":
                        internal_momentum = self.optimizer.state[p]["momentum_buffer"]
                        self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.la_alpha).add_(
                            param_state["cached_mom"] * (1.0 - self.la_alpha))
                        # Snapshot, not alias: the inner optimizer updates its buffer in place, so an
                        # aliased cache would always equal the live buffer and pullback would do nothing.
                        param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"].clone()
                    elif self.pullback_momentum == "reset":
                        self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)

        return loss

# Augmix

In [None]:
def preprocess(image):
    "Wrap a NumPy array as a torch tensor without copying (the tensor shares `image`'s memory)."
    tensor = torch.from_numpy(image)
    return tensor

In [None]:
from albumentations import (Rotate, GaussNoise, GaussianBlur, MedianBlur, RandomScale, Compose, OneOf, DualTransform, RandomBrightness,
                            RandomContrast, MotionBlur, Solarize, Equalize, Posterize, ShiftScaleRotate,
                           IAASharpen, IAAAffine)


# Pool of single-op transforms that AugMix (`Choice.aug`) samples from via np.random.choice.
# Each has p=1 so that a sampled op is always applied.
aug_list = [GaussNoise(p=1), GaussianBlur(p=1),
           RandomBrightness(p=1), RandomContrast(p=1), ShiftScaleRotate(p=1, rotate_limit=20)]



# GridMask

In [None]:
from albumentations.augmentations import functional as Func

class GridMask(DualTransform):
    """GridMask augmentation for image classification and object detection.

    Author: Qishen Ha
    Email: haqishen@gmail.com
    2020/01/29

    Args:
        num_grid (int or (int, int)): number of grid cells in a row or column. A pair gives the
            inclusive range from which one grid size is picked per call.
        fill_value (int, float, list of int, list of float): value written into the dropped cells
            of the mask. The image is multiplied by the mask, so 0 blacks those cells out.
        rotate ((int, int) or int): range from which a random angle is picked. If rotate is a single int
            an angle is picked from [-rotate, rotate). Default: 0 (no rotation).
        mode (int):
            0 - cropout a quarter of the square of each grid (left top)
            1 - reserve a quarter of the square of each grid (left top)
            2 - cropout 2 quarter of the square of each grid (left top & right bottom)

    Targets:
        image, mask

    Image types:
        uint8, float32

    NOTE: masks are built once, for the size of the first image seen. Later images are assumed
    to have the same size.

    Reference:
    |  https://arxiv.org/abs/2001.04086
    |  https://github.com/akuxcw/GridMask
    """

    def __init__(self, num_grid=3, fill_value=0, rotate=0, mode=0, always_apply=False, p=0.5):
        super(GridMask, self).__init__(always_apply, p)
        if isinstance(num_grid, int):
            num_grid = (num_grid, num_grid)
        if isinstance(rotate, int):
            rotate = (-rotate, rotate)
        self.num_grid = num_grid
        self.fill_value = fill_value
        self.rotate = rotate
        self.mode = mode
        self.masks = None
        self.rand_h_max = []
        self.rand_w_max = []

    def init_masks(self, height, width):
        """Build (once) one mask per grid count in `num_grid`.

        Each mask is one grid cell larger than the image in both directions, so it can be
        shifted by up to one cell.
        """
        if self.masks is None:
            self.masks = []
            for n_g in range(self.num_grid[0], self.num_grid[1] + 1):
                grid_h = height / n_g
                grid_w = width / n_g
                this_mask = np.ones((int((n_g + 1) * grid_h), int((n_g + 1) * grid_w))).astype(np.uint8)
                for i in range(n_g + 1):
                    for j in range(n_g + 1):
                        this_mask[
                             int(i * grid_h) : int(i * grid_h + grid_h / 2),
                             int(j * grid_w) : int(j * grid_w + grid_w / 2)
                        ] = self.fill_value
                        if self.mode == 2:
                            this_mask[
                                 int(i * grid_h + grid_h / 2) : int(i * grid_h + grid_h),
                                 int(j * grid_w + grid_w / 2) : int(j * grid_w + grid_w)
                            ] = self.fill_value

                if self.mode == 1:
                    this_mask = 1 - this_mask

                self.masks.append(this_mask)
                self.rand_h_max.append(grid_h)
                self.rand_w_max.append(grid_w)

    def apply(self, image, mask, rand_h, rand_w, angle, **params):
        """Return `image` multiplied by the (optionally rotated) mask window at offset (rand_h, rand_w)."""
        h, w = image.shape[:2]
        mask = Func.rotate(mask, angle) if self.rotate[1] > 0 else mask
        mask = mask[:,:,np.newaxis] if image.ndim == 3 else mask
        # Out-of-place: `image *= ...` would overwrite the caller's array.
        return image * mask[rand_h:rand_h+h, rand_w:rand_w+w].astype(image.dtype)

    def get_params_dependent_on_targets(self, params):
        """Pick a random mask, a random shift within one grid cell, and a random angle."""
        img = params['image']
        height, width = img.shape[:2]
        self.init_masks(height, width)

        mid = np.random.randint(len(self.masks))
        mask = self.masks[mid]
        rand_h = np.random.randint(self.rand_h_max[mid])
        rand_w = np.random.randint(self.rand_w_max[mid])
        angle = np.random.randint(self.rotate[0], self.rotate[1]) if self.rotate[1] > 0 else 0

        return {'mask': mask, 'rand_h': rand_h, 'rand_w': rand_w, 'angle': angle}

    @property
    def targets_as_params(self):
        return ['image']

    def get_transform_init_args_names(self):
        return ('num_grid', 'fill_value', 'rotate', 'mode')

# MixUp

The code below adapts fast.ai's MixUp loss wrapper (`MixUpLoss`) so that it works with the three-component targets used here.

In [None]:
class MixUpLoss(Module):
    "Adapt the loss function `crit` to go with mixup."
    # A 2-D target with 7 columns is treated as stacked mixup targets:
    # [y_a (cols 0:3), y_b (cols 3:6), lambda (last col)]. Anything else goes straight to `crit`.

    def __init__(self, crit, reduction='mean'):
        super().__init__()
        if not hasattr(crit, 'reduction'):
            # Plain callable: bind reduction='none' on every call, remember the original.
            self.crit = partial(crit, reduction='none')
            self.old_crit = crit
        else:
            # Loss object: switch it to per-sample mode, remembering the setting to restore later.
            self.crit = crit
            self.old_red = crit.reduction
            self.crit.reduction = 'none'
        self.reduction = reduction

    def forward(self, output, target):
        stacked = target.dim() == 2 and target.size(1) == 7
        if stacked:
            lam = target[:, -1]
            loss_a = self.crit(output, target[:, 0:3].long())
            loss_b = self.crit(output, target[:, 3:6].long())
            per_sample = loss_a * lam + loss_b * (1 - lam)
        else:
            per_sample = self.crit(output, target)
        if self.reduction == 'mean':
            return per_sample.mean()
        if self.reduction == 'sum':
            return per_sample.sum()
        return per_sample

    def get_old(self):
        "Return the wrapped loss, restoring its original `reduction` if it had one."
        if hasattr(self, 'old_crit'):
            return self.old_crit
        if hasattr(self, 'old_red'):
            self.crit.reduction = self.old_red
            return self.crit

# CutMix

In [None]:
# https://github.com/oguiza/DataAugmentation/blob/master/ImageDataAugmentation.py

def rand_bbox(last_input_size, λ):
    '''Sample a CutMix box for an input whose last two dimensions are (H, W).

    The box sides are int(W*sqrt(1-λ)) by int(H*sqrt(1-λ)), centred uniformly at random and
    clipped to the image, so the box area is at most (1-λ)*W*H. For λ in [.5, 1] the side ratio
    sqrt(1-λ) lies in [0, ~.707].

    Returns:
        (bbx1, bby1, bbx2, bby2): x/y start (inclusive) and end (exclusive) indices.
    '''
    W = last_input_size[-1]
    H = last_input_size[-2]
    cut_rat = np.sqrt(1. - λ) # 0. - .707
    # builtin int: the `np.int` alias was deprecated in NumPy 1.20 and removed in 1.24
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

In [None]:
def cutout(x, n_holes:uniform_int=1, length:uniform_int=40):
    "Cut out `n_holes` number of square holes of size `length` in image at random locations."
    # `x` is indexed as (C, H, W); holes are set to 1 in place and `x` is returned.
    height, width = x.shape[1:]
    half = length / 2
    for _ in range(n_holes):
        centre_y = np.random.randint(0, height)
        centre_x = np.random.randint(0, width)
        top, bottom = (int(np.clip(edge, 0, height)) for edge in (centre_y - half, centre_y + half))
        left, right = (int(np.clip(edge, 0, width)) for edge in (centre_x - half, centre_x + half))
        x[:, top:bottom, left:right] = 1
    return x

# Choice Callback

In [None]:
class Choice(LearnerCallback):
    """Apply one batch-level augmentation to every training batch, chosen by epoch.

    Epochs are grouped into 10-epoch stages that cycle Mixup -> Cutmix -> Augmix -> GridMask
    (0-9, 10-19, 20-29, 30-39, then repeating every 40 epochs). Mixup/Cutmix wrap the loss in
    `MixUpLoss` and (with `stack_y`) stack targets as [y_a, y_b, lambda]. Augmix/GridMask keep the
    targets and use `Loss_combine`. Validation batches are never modified.

    Args:
        learn: the `Learner` this callback is attached to.
        auglist: albumentations transforms sampled by Augmix.
        mixup_alpha / cutmix_alpha: Beta-distribution parameters for Mixup / Cutmix.
        stack_x: Mixup only - pass `[x, x_shuffled, lambda]` instead of the blended input.
        stack_y: stack targets for `MixUpLoss` instead of blending them.
        prob: accepted for API compatibility; currently unused.
        aug_alpha, mix_depth, mixture_width: Augmix parameters (see `aug`).
    """
    # Augmentation used for each consecutive 10-epoch stage; the tuple repeats.
    _SCHEDULE = ('Mixup', 'Cutmix', 'Augmix', 'GridMask')

    def __init__(self, learn:Learner, auglist, mixup_alpha:float=0.4, stack_x:bool=False, stack_y:bool=True, prob=0.5, cutmix_alpha:float=0.4,
                aug_alpha:float=1, mix_depth:float=-1, mixture_width:int=3):
        super().__init__(learn)
        self.cut_alpha,self.alpha,self.stack_x,self.stack_y = cutmix_alpha, mixup_alpha,stack_x,stack_y
        self.aug_alpha,self.mix_depth, self.mixture_width, self.aug_list = aug_alpha, mix_depth, mixture_width, auglist

    def on_train_begin(self, **kwargs):
        if self.stack_y: self._use_mixup_loss()

    def _use_mixup_loss(self):
        "Wrap the learner's loss in `MixUpLoss` exactly once (re-wrapping per batch nests without bound)."
        if not isinstance(self.learn.loss_func, MixUpLoss):
            self.learn.loss_func = MixUpLoss(self.learn.loss_func)

    def aug(self,image,aug_list, alpha=1, mixture_width=3, mix_depth=3):
        """AugMix one HWC NumPy image and return a torch tensor.

        Builds `mixture_width` chains of randomly chosen ops from `aug_list` (depth `mix_depth`,
        or random 1-3 when `mix_depth <= 0`). It mixes them with Dirichlet(alpha) weights and
        returns (1 - m) * image + m * mix with m ~ Beta(alpha, alpha).
        """
        ws = np.float32(np.random.dirichlet([alpha] * mixture_width))
        m = np.float32(np.random.beta(alpha, alpha))

        mix = torch.zeros_like(preprocess(image))
        for i in range(mixture_width):
            image_aug = image.copy()
            depth = mix_depth if mix_depth > 0 else np.random.randint(1, 4)

            for _ in range(depth):
                op = np.random.choice(aug_list)
                image_aug = op(image = image_aug)['image']
            # Preprocessing commutes since all coefficients are convex
            mix += ws[i] * preprocess(image_aug)

        mixed = (1 - m) * preprocess(image) + m * mix
        return mixed

    def on_batch_begin(self, last_input, last_target, train, epoch,**kwargs):
        "Apply the augmentation scheduled for `epoch` to a training batch; other batches pass through."
        if not train: return
        name = self._SCHEDULE[(epoch // 10) % len(self._SCHEDULE)]
        handler = {'Mixup': self._mixup, 'Cutmix': self._cutmix, 'Augmix': self._augmix,
                   'GridMask': self._gridmask, 'Cutout': self._cutout}[name]
        return handler(last_input, last_target)

    def _mixup(self, last_input, last_target):
        "Blend each sample with a shuffled partner using per-sample lambda ~ Beta(alpha, alpha), folded to >= 0.5."
        self._use_mixup_loss()
        lambd = np.random.beta(self.alpha, self.alpha, last_target.size(0))
        lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1)
        lambd = last_input.new(lambd)
        shuffle = torch.randperm(last_target.size(0)).to(last_input.device)
        x1, y1 = last_input[shuffle], last_target[shuffle]
        if self.stack_x:
            new_input = [last_input, last_input[shuffle], lambd]
        else:
            out_shape = [lambd.size(0)] + [1 for _ in range(len(x1.shape) - 1)]
            new_input = (last_input * lambd.view(out_shape) + x1 * (1-lambd).view(out_shape))
        if self.stack_y:
            new_target = torch.cat([last_target.float(), y1.float(), lambd[:,None].float()], 1)
        else:
            if len(last_target.shape) == 2:
                lambd = lambd.unsqueeze(1).float()
            new_target = last_target.float() * lambd + y1.float() * (1-lambd)
        return {'last_input': new_input, 'last_target': new_target}

    def _cutmix(self, last_input, last_target):
        "Paste a random box from a shuffled partner into every image (one box per batch)."
        self._use_mixup_loss()
        lambd = np.random.beta(self.cut_alpha, self.cut_alpha)
        lambd = max(lambd, 1- lambd)
        shuffle = torch.randperm(last_target.size(0)).to(last_input.device)

        x1, y1 = last_input[shuffle], last_target[shuffle]
        last_input_size = last_input.shape
        bbx1, bby1, bbx2, bby2 = rand_bbox(last_input.size(), lambd)
        new_input = last_input.clone()
        new_input[:, ..., bby1:bby2, bbx1:bbx2] = last_input[shuffle, ..., bby1:bby2, bbx1:bbx2]
        # Re-derive lambda from the area actually pasted: the box is truncated to even sides and
        # clipped at the border, so the sampled lambda does not match the pixel ratio.
        lambd = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (last_input_size[-1] * last_input_size[-2]))
        lambd = last_input.new([lambd])
        if self.stack_y:
            new_target = torch.cat([last_target.float(), y1.float(),
                                    lambd.repeat(last_input_size[0]).unsqueeze(1).float()], 1)
        else:
            if len(last_target.shape) == 2:
                lambd = lambd.unsqueeze(1).float()
            new_target = last_target.float() * lambd + y1.float() * (1-lambd)
        return {'last_input': new_input, 'last_target': new_target}

    def _augmix(self, last_input, last_target):
        "AugMix every image independently; targets are unchanged and the plain `Loss_combine` is used."
        self.learn.loss_func = Loss_combine()
        new_input = last_input.clone()
        for i in range(new_input.size(0)):
            image = new_input[i].permute(1,2,0).cpu().numpy()
            mixed = self.aug(image, self.aug_list, alpha=self.aug_alpha, mixture_width=self.mixture_width, mix_depth=self.mix_depth)
            new_input[i] = mixed.permute(2,0,1)
        return {'last_input': new_input ,'last_target': last_target}

    def _gridmask(self, last_input, last_target):
        "Apply one of four GridMask variants (grid 10-15, rotate 10 or 0, mode 0 or 2) to every image."
        self.learn.loss_func = Loss_combine()
        tfms = Compose([OneOf([GridMask(num_grid=(10,15), rotate=10, mode=0, fill_value=0),
                               GridMask(num_grid=(10,15), rotate=10, mode=2, fill_value=0),
                               GridMask(num_grid=(10,15), rotate=0, mode=0, fill_value=0),
                               GridMask(num_grid=(10,15), rotate=0, mode=2, fill_value=0)], p=1)])
        new_input = last_input.clone()
        for i in range(new_input.size(0)):
            image = new_input[i].permute(1,2,0).cpu().numpy()
            image = tfms(image = image)['image']
            new_input[i] = torch.from_numpy(image).permute(2,0,1)
        return {'last_input': new_input, 'last_target': last_target}

    def _cutout(self, last_input, last_target):
        "Cut 4 or 8 square holes of side 10 into every image (not part of `_SCHEDULE`)."
        self.learn.loss_func = Loss_combine()
        new_input = last_input.clone()
        for i in range(last_input.size(0)):
            hole = np.random.choice([4,8])
            new_input[i] = cutout(new_input[i], n_holes=hole, length=10)
        return {'last_input': new_input, 'last_target': last_target}

    def on_train_end(self, **kwargs):
        "Restore the loss that `MixUpLoss` wrapped, if it is still wrapped."
        if self.stack_y and isinstance(self.learn.loss_func, MixUpLoss):
            self.learn.loss_func = self.learn.loss_func.get_old()

# Training

I compared different optimizers and schedules on a [similar task](https://www.kaggle.com/c/Kannada-MNIST/discussion/122430). The [Over9000 optimizer](https://github.com/mgrankin/over9000) with cosine annealing **without warm-up** worked best. Freezing the backbone at the initial stage of training gave no advantage in that test, so here I train straight away with a discriminative learning rate (a smaller lr for the backbone).

In [None]:
from torch.optim import Adam

def lookahead_adam(params, la_alpha=0.8, la_steps=5, *args, **kwargs):
    """Create `Adam(params, *args, **kwargs)` wrapped in `Lookahead`.

    Args:
        params: parameters or param-group dicts for Adam.
        la_alpha: slow-weight interpolation factor for `Lookahead`.
        la_steps: number of Adam steps between slow-weight updates.
        *args, **kwargs: forwarded to `torch.optim.Adam`.
    """
    adam = Adam(params, *args, **kwargs)
    # Pass by keyword: `Lookahead`'s positional order is (optimizer, la_steps, la_alpha).
    # A positional `Lookahead(adam, la_alpha, la_steps)` would give la_steps=0.8 and la_alpha=5.
    return Lookahead(adam, la_steps=la_steps, la_alpha=la_alpha)

In [None]:
def _make_choice(learn:Learner, auglist=None, default_auglist=None, **kwargs):
    "Build `Choice(learn, auglist, **kwargs)`, falling back to `default_auglist` when `auglist` is not given."
    return Choice(learn, default_auglist if auglist is None else auglist, **kwargs)

def choice(learn:Learner, auglist, mixup_alpha:float=0.4, stack_x:bool=False, stack_y:bool=True, prob=0.5, cutmix_alpha:float=1.0,
                aug_alpha:float=1, mix_depth:float=-1, mixture_width:int=3) -> Learner:
    """Attach the `Choice` augmentation-schedule callback to `learn` and return `learn`.

    Schedule: Mixup / CutMix (https://arxiv.org/pdf/1905.04899.pdf) / AugMix / GridMask.
    """
    # The entry is called as `cb_fn(learn)`, so `learn` must stay the first positional argument.
    # Binding `auglist` positionally (`partial(Choice, auglist, ...)`) would produce
    # `Choice(auglist, learn)`. It is bound by keyword instead, and `auglist` may still be
    # passed positionally after `learn`.
    learn.callback_fns.append(partial(_make_choice, default_auglist=auglist, mixup_alpha=mixup_alpha,
                                      stack_x=stack_x, stack_y=stack_y, prob=prob, cutmix_alpha=cutmix_alpha,
                                      aug_alpha=aug_alpha, mix_depth=mix_depth, mixture_width=mixture_width))
    return learn

setattr(choice, 'cb_fn', Choice)
Learner.choice = choice

In [None]:
model = Dnet_1ch(pre=False)  # pre=False: backbone is randomly initialised (no ImageNet weights)
learn = Learner(data, model, loss_func=Loss_combine(),opt_func=lookahead_adam,
        metrics=[Metric_grapheme(),Metric_vowel(),Metric_consonant(),Metric_tot()]).choice(aug_list)
logger = CSVLogger(learn,f'log{fold}')
# `clip_grad` is a fastai Learner method that registers gradient clipping. Assigning a number to it
# (`learn.clip_grad = 1.0`) would only shadow the method and clip nothing.
learn.clip_grad(1.0)
learn.split([model.head1])  # layer groups: backbone vs. heads (for discriminative learning rates)
learn.unfreeze()

In [None]:
model  # display the module tree of Dnet_1ch

In [None]:
learn.summary()  # fastai layer-by-layer summary of the learner

In [None]:
# Training call (disabled in this visualization notebook).
# NOTE(review): `.choice(aug_list)` above already registers a Choice callback, so also passing
# `Choice(learn, auglist=aug_list)` here would presumably apply the augmentation twice per batch — confirm.
# learn.fit_one_cycle(100,max_lr=slice(0.2e-2,1e-2),wd=[1e-3,0.1e-1], pct_start=0.0,
#                     div_factor=100, callbacks=[logger, SaveModelCallback(learn,monitor='metric_tot',
#     mode='max',name=f'seresnext', every='epoch'), Choice(learn,auglist=aug_list),EpochCallback(learn)])

In [None]:
xb, yb = learn.data.one_batch()  # a batch to visualise (presumably the training split, fastai's default)

In [None]:
# Fetch the callback constructor registered by `.choice(aug_list)`.
# NOTE(review): index 1 assumes callback_fns[0] is fastai's default Recorder and that nothing else was
# appended before `.choice` — confirm, or search the list instead.
cb = learn.callback_fns[1]
# Keep only the bound keyword arguments; `learn` and `aug_list` are passed explicitly at call time.
cb_fn = partial(cb.func, **cb.keywords)

In [None]:
# Reference: sample 0 of the batch without augmentation.
# NOTE(review): all 16 panels show the same image `xb[0]` (`i` is unused); use `xb[i]` if the whole batch was intended.
[Img(xb[0]).show(ax=ax) for i, ax in enumerate(plt.subplots(4, 4, figsize=(15,15))[1].flatten()) ]
plt.show()

In [None]:
# Epoch 5 -> Mixup stage of `Choice`. Each panel re-runs the callback on the whole batch and shows
# augmented sample 0, so the 16 panels are independent random draws.
[Img(cb_fn(learn, aug_list).on_batch_begin(xb, yb, True,5)['last_input'][0]).show(ax=ax)
 for i, ax in enumerate(plt.subplots(4, 4, figsize=(15,15))[1].flatten())]
plt.show()

In [None]:
# Epoch 15 -> Cutmix stage: 16 independent CutMix draws of sample 0.
[Img(cb_fn(learn, aug_list).on_batch_begin(xb, yb, True,15)['last_input'][0]).show(ax=ax) for i, ax in enumerate(plt.subplots(4, 4, figsize=(15,15))[1].flatten())]
plt.show()

In [None]:
# Epoch 25 -> Augmix stage: 16 independent AugMix draws of sample 0.
[Img(cb_fn(learn, aug_list).on_batch_begin(xb, yb, True,25)['last_input'][0]).show(ax=ax) for i, ax in enumerate(plt.subplots(4, 4, figsize=(15,15))[1].flatten())]
plt.show()

In [None]:
# Epoch 35 -> GridMask stage: 16 independent GridMask draws of sample 0.
[Img(cb_fn(learn, aug_list).on_batch_begin(xb, yb, True,35)['last_input'][0]).show(ax=ax) for i, ax in enumerate(plt.subplots(4, 4, figsize=(15,15))[1].flatten())]
plt.show()

In [None]:
# Draw a fresh batch and show its (unaugmented) sample 0 as the new reference.
# NOTE(review): all 16 panels show the same image `xb[0]` (`i` is unused).
xb, yb = learn.data.one_batch()
[Img(xb[0]).show(ax=ax) for i, ax in enumerate(plt.subplots(4, 4, figsize=(15,15))[1].flatten()) ]
plt.show()

In [None]:
# Epoch 45 -> Mixup stage (second cycle of the schedule).
[Img(cb_fn(learn, aug_list).on_batch_begin(xb, yb, True,45)['last_input'][0]).show(ax=ax)
 for i, ax in enumerate(plt.subplots(4, 4, figsize=(15,15))[1].flatten())]
plt.show()

In [None]:
# Epoch 55 -> Cutmix stage.
[Img(cb_fn(learn, aug_list).on_batch_begin(xb, yb, True,55)['last_input'][0]).show(ax=ax)
 for i, ax in enumerate(plt.subplots(4, 4, figsize=(15,15))[1].flatten())]
plt.show()

In [None]:
# Epoch 65 -> Augmix stage.
[Img(cb_fn(learn, aug_list).on_batch_begin(xb, yb, True,65)['last_input'][0]).show(ax=ax)
 for i, ax in enumerate(plt.subplots(4, 4, figsize=(15,15))[1].flatten())]
plt.show()

In [None]:
# Epoch 75 -> GridMask stage.
[Img(cb_fn(learn, aug_list).on_batch_begin(xb, yb, True,75)['last_input'][0]).show(ax=ax)
 for i, ax in enumerate(plt.subplots(4, 4, figsize=(15,15))[1].flatten())]
plt.show()

In [None]:
# Epoch 85 -> Mixup stage.
[Img(cb_fn(learn, aug_list).on_batch_begin(xb, yb, True,85)['last_input'][0]).show(ax=ax)
 for i, ax in enumerate(plt.subplots(4, 4, figsize=(15,15))[1].flatten())]
plt.show()

In [None]:
# Epoch 95 -> Cutmix stage.
[Img(cb_fn(learn, aug_list).on_batch_begin(xb, yb, True,95)['last_input'][0]).show(ax=ax)
 for i, ax in enumerate(plt.subplots(4, 4, figsize=(15,15))[1].flatten())]
plt.show()