# Imports and hyperparameters

In [0]:
from fastai2.basics           import *
from fastai2.vision.all       import *
from fastai2.callback.tracker import *
from fastai2.callback.all     import *
import os
from sklearn.model_selection import KFold
from sklearn.metrics import recall_score
import warnings
warnings.filterwarnings("ignore")

In [0]:
# image side length and batch size
sz, bs = 128, 128
# cross-validation: number of folds and the fold held out for validation
nfolds, fold = 5, 0  # change `fold` manually
SEED = 1337

# data and weight locations
cnt = Path('../input')
md = cnt/'bengaliai-cv19'
# NOTE(review): 'xresent50' looks like a misspelling of 'xresnet50'; kept so existing files are still found — confirm
DRIVE = cnt/'bengali_weights/xresent50/'
LABELS = cnt/'folded_data/train_with_fold.csv'
TRAIN = cnt/'train_images'

# backbone constructor and the name used in checkpoint filenames
arch, arch_name = xresnet50, 'xresnet50'

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN into deterministic mode.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-select (possibly non-deterministic) kernels at runtime,
    # which defeats the `deterministic` flag above.
    torch.backends.cudnn.benchmark = False

# make the run reproducible before any data or model is created
seed_everything(seed=SEED)

In [12]:
df = pd.read_csv(LABELS)
# number of distinct values per column, without the first and last columns
counts = list(df.nunique())
nunique = counts[1:-1]
print(nunique)

[168, 11, 7, 1295, 200840]


# Preprocessing

In [0]:
class GridMask(RandTransform):
    """GridMask batch augmentation: multiply images by a grid of `fill_value` squares.

    One mask per grid size in `num_grid` is built on first use (from that batch's height/width)
    and cached. Each call picks one mask at random, optionally rotates it, and crops it at a
    random offset so the grid phase varies.

    Args:
        p: probability that the transform fires (handled by `RandTransform`).
        num_grid: int, or inclusive (min, max) range of grid cells per side.
        fill_value: value written into the masked squares.
        rotate: int `r` (meaning the range (-r, r)) or a (min, max) rotation range in degrees.
        mode: 0 masks the top-left quarter of every cell; 2 also masks the bottom-right quarter;
            1 inverts the mode-0 mask (`1 - mask`).
    """
    order = 101  # NOTE(review): presumably chosen so it sorts last in the batch pipeline; callbacks address it as `after_batch.fs[-1]`
    def __init__(self, p=0.5, num_grid=3, fill_value=0, rotate=0, mode=0):
        super().__init__(p=p)
        if isinstance(num_grid, int): num_grid = (num_grid, num_grid)
        if isinstance(rotate, int): rotate = (-rotate, rotate)
        self.num_grid = num_grid
        self.fill_value = fill_value
        self.rotate = rotate
        self.mode = mode
        self.masks = None
        self.rand_h_max = []
        self.rand_w_max = []

    def init_masks(self, height, width):
        "Build and cache one mask per grid size; a no-op once masks exist (later size changes are not detected)."
        if self.masks is not None: return
        self.masks = []
        for n_g in range(self.num_grid[0], self.num_grid[1] + 1):
            grid_h = height / n_g
            grid_w = width / n_g
            # one extra cell in each direction so a random (height, width) crop always fits
            this_mask = torch.ones((int((n_g + 1) * grid_h), int((n_g + 1) * grid_w)), dtype=torch.uint8)
            for i in range(n_g + 1):
                for j in range(n_g + 1):
                    this_mask[
                         int(i * grid_h) : int(i * grid_h + grid_h / 2),
                         int(j * grid_w) : int(j * grid_w + grid_w / 2)
                    ] = self.fill_value
                    if self.mode == 2:
                        this_mask[
                             int(i * grid_h + grid_h / 2) : int(i * grid_h + grid_h),
                             int(j * grid_w + grid_w / 2) : int(j * grid_w + grid_w)
                        ] = self.fill_value

            if self.mode == 1:
                this_mask = 1 - this_mask

            self.masks.append(this_mask)
            self.rand_h_max.append(grid_h)
            self.rand_w_max.append(grid_w)

    def get_params(self, img):
        "Pick a cached mask, a crop offset inside one grid cell and a rotation angle for `img`."
        height, width = img.shape[-2:]
        self.init_masks(height, width)

        mid = np.random.randint(len(self.masks))
        mask = self.masks[mid]
        # cell sizes are floats; pass explicit integer bounds to randint
        rand_h = np.random.randint(int(self.rand_h_max[mid]))
        rand_w = np.random.randint(int(self.rand_w_max[mid]))
        angle = np.random.randint(self.rotate[0], self.rotate[1]) if self.rotate[1] > 0 else 0

        return mask, rand_h, rand_w, angle

    def encodes(self, image:TensorImage, **params):
        "Multiply `image` in place by a randomly chosen, optionally rotated, randomly offset grid mask."
        mask, rand_h, rand_w, angle = self.get_params(image)
        h, w = image.shape[-2:]
        # NOTE(review): `afunc` is not defined in this notebook; this path only runs when `rotate` is non-zero
        mask = afunc.rotate(mask, angle) if self.rotate[1] > 0 else mask
        # A 2-D (H, W) mask broadcasts over leading (C,) or (B, C) dims of channel-first images;
        # adding a trailing axis would misalign it with a (C, H, W) image. Put it on the image's own device.
        mask = mask.float().to(image.device)
        image *= mask[rand_h:rand_h+h, rand_w:rand_w+w]
        return image

In [0]:
# x-side getter: a `df.values` row -> path of its PNG (file named after the row's first field)
get_image = lambda row: TRAIN/f'{row[0]}.png'

def get_labels(row):
    "Fields 1-3 of a `df.values` row (presumably root, vowel, consonant) as a uint8 tensor."
    target = row[1:4].astype('uint8')
    return tensor(target)

In [0]:
def _unique_labels(col):
    "Sorted distinct values of `df[col]`; the `v == v` test drops NaN (NaN != NaN)."
    return L(v for v in df[col].unique() if v == v).sorted()

gv, vv, cv = (_unique_labels(c) for c in ('grapheme_root', 'vowel_diacritic', 'consonant_diacritic'))

In [16]:
vocab = L(gv, vv, cv)
# number of classes for each of the three targets
cs = L(len(v) for v in vocab)
vocab

(#3) [(#168) [0,1,2,3,4,5,6,7,8,9...],(#11) [0,1,2,3,4,5,6,7,8,9...],(#7) [0,1,2,3,4,5,6]]

In [0]:
class MEMCategorize(Transform):
    """Target transform for the three-label task: encodes labels as a float `TensorCategory`.

    `vocab` holds one class list per target; `c` is the class count of each.
    """
    loss_func = CrossEntropyLossFlat()
    order = 1

    def __init__(self, vocab):
        self.vocab = vocab
        self.c = L(map(len, vocab))

    def encodes(self, o): return TensorCategory(tensor(o).float())

    def decodes(self, o): return MultiCategory(tensor(o))

In [0]:
# validation set = rows whose `fold` column equals the chosen fold
valid_idx = df.index[df.fold == fold]
split_idx = IndexSplitter(valid_idx)(df)

In [0]:
# per-sample pipelines: x = row -> PNG path -> grayscale PIL image -> tensor; y = row -> 3 labels
# NOTE(review): ToTensor appears both here and in item_tfms
x_tfms = [get_image, PILImageBW.create, ToTensor]
y_tfms = [get_labels, MEMCategorize(vocab=vocab)]
type_tfms = [x_tfms, y_tfms]
item_tfms = [ToTensor]
# GridMask and RandomErasing start disabled (p=0.); the augmentation callbacks change their `p` per batch
batch_tfms = [IntToFloatTensor,
              Brightness(max_lighting=0.3, p=0.5),
              Contrast(),
              GridMask(p=0., num_grid=(3,7)),
              RandomErasing(p=0., sh=0.15),
              Warp(magnitude=0.1),
              AffineCoordTfm(size=sz)]

In [0]:
dsrc = Datasets(df.values, type_tfms, splits=split_idx)
# loader over the whole dataset, only used below to measure batch statistics
tdl = TfmdDL(dsrc, bs=bs, device=default_device(),
             after_item=item_tfms, after_batch=batch_tfms)

In [21]:
xb, _ = tdl.one_batch()
# mean/std of a single batch, used below as the normalization statistics
mega_batch_stats = (xb.mean(), xb.std())
mega_batch_stats

(TensorImageBW(0.0716, device='cuda:0'),
 TensorImageBW(0.2045, device='cuda:0'))

In [0]:
# Append normalization only once so the cell is safe to re-run (a second Normalize would normalize twice).
if not any(isinstance(t, Normalize) for t in batch_tfms):
    batch_tfms += [Normalize.from_stats(*mega_batch_stats)]

In [0]:
# Training batches are reshuffled each epoch (TfmdDL does not shuffle unless asked). drop_last avoids a
# tiny final batch, which the batch-norm layers in the heads (presumably inside LinBnDrop) cannot handle in training mode.
tdl_train = TfmdDL(dsrc.train, bs=bs, shuffle=True, drop_last=True, after_item=item_tfms, after_batch=batch_tfms, device=default_device())
tdl_valid = TfmdDL(dsrc.valid, bs=bs, after_item=item_tfms, after_batch=batch_tfms, device=default_device())

In [0]:
# bundle the training and validation loaders for the Learner
dbch = DataLoaders(tdl_train, tdl_valid, device=default_device())

In [0]:
# dbch.show_batch(max_n=3)

# Constructing model

In [0]:
class MishFunction(torch.autograd.Function):
    """Mish activation `x * tanh(softplus(x))` with a hand-written backward pass.

    Only the input is saved; the gradient is recomputed from it in `backward`.
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * torch.tanh(F.softplus(x))   # x * tanh(ln(1 + exp(x)))

    @staticmethod
    def backward(ctx, grad_output):
        # `saved_tensors` is the supported accessor (`saved_variables` is a deprecated alias)
        x, = ctx.saved_tensors
        sigmoid = torch.sigmoid(x)
        tanh_sp = torch.tanh(F.softplus(x))
        # d/dx [x * tanh(sp(x))] = tanh(sp) + x * (1 - tanh(sp)^2) * sigmoid(x), since sp'(x) = sigmoid(x)
        return grad_output * (tanh_sp + x * sigmoid * (1 - tanh_sp * tanh_sp))

class Mish(Module):
    "Module wrapper around `MishFunction`."
    def forward(self, x): return MishFunction.apply(x)

def to_Mish(model):
    "Recursively replace every `nn.ReLU` submodule of `model` with `Mish`, in place."
    for name, module in model.named_children():
        if not isinstance(module, nn.ReLU):
            to_Mish(module)
            continue
        setattr(model, name, Mish())

In [0]:
class Head(Module):
    """Classification head producing `n` logits from a feature map with `nc` channels.

    Concat-pooling (hence `nc*2` features) -> Mish -> flatten -> LinBnDrop(512, Mish) -> LinBnDrop(n).

    Args:
        nc: channels of the incoming feature map.
        n: number of output classes.
        ps: dropout probability for both LinBnDrop layers.
    """
    def __init__(self, nc, n, ps=0.5):
        layers = [AdaptiveConcatPool2d(), Mish(), Flatten(),
                  LinBnDrop(nc*2, 512, True, ps, Mish()),
                  LinBnDrop(512, n, True, ps)]
        self.fc = nn.Sequential(*layers)
        self._init_weight()

    def _init_weight(self):
        "Kaiming-init Conv2d weights; reset BatchNorm2d to weight 1, bias 0."
        # NOTE(review): the head presumably contains no Conv2d/BatchNorm2d (only pooling, Flatten and
        # LinBnDrop), so this loop likely changes nothing — confirm whether Linear/BatchNorm1d were intended.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, x):
        return self.fc(x)

class CascadeModel(Module):
    """`arch` backbone without its last 4 children, adapted to 1-channel input, with one `Head` per target.

    Args:
        arch: backbone constructor, called as `arch(pre)`.
        n: class count for each head (one head per entry).
        pre: forwarded to `arch` (pretrained weights flag).
    """
    def __init__(self, arch=arch, n=dbch.c, pre=True):
        m = arch(pre)
        # drop the last 4 children (presumably the pooling/flatten/dropout/linear classifier) — confirm for other archs
        m = nn.Sequential(*children_and_parameters(m)[:-4])
        stem = m[0][0]
        # Grayscale stem: same geometry as the pretrained conv, 1 input channel, weights summed over RGB.
        conv = nn.Conv2d(1, stem.out_channels, kernel_size=stem.kernel_size, stride=stem.stride,
                         padding=stem.padding, bias=False)
        conv.weight = nn.Parameter(stem.weight.sum(1).unsqueeze(1))
        m[0][0] = conv
        nc = self._num_features(m)
        self.body = m
        self.heads = nn.ModuleList([Head(nc, c) for c in n])

    @staticmethod
    def _num_features(body):
        "Output channels of `body`, probed with a dummy 1-channel `sz`x`sz` batch."
        # eval + no_grad: a train-mode forward would update (and skew) the pretrained BatchNorm running stats
        was_training = body.training
        body.eval()
        with torch.no_grad():
            nc = body(torch.zeros(2, 1, sz, sz)).shape[1]
        body.train(was_training)
        return nc

    def forward(self, x):
        x = self.body(x)
        return [f(x) for f in self.heads]

In [0]:
class CascadeDnet(Module):
    """DenseNet-style backbone adapted to 1-channel input, with one `Head` per target.

    Expects `arch(pre)` to expose `features.conv0`, `norm0`, `denseblock1-4`, `transition1-3`
    and `norm5`. NOTE(review): the default `arch` is xresnet50, which presumably lacks `.features` —
    pass a DenseNet constructor explicitly.
    """
    def __init__(self, arch=arch, n=dbch.c, pre=True):
        net = arch(pre)
        feats = net.features
        stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # grayscale stem: sum the pretrained RGB kernels over the input-channel axis
        stem.weight = nn.Parameter(feats.conv0.weight.sum(1).unsqueeze(1))

        pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.layer0 = nn.Sequential(stem, feats.norm0, nn.ReLU(inplace=True))
        self.layer1 = nn.Sequential(pool, feats.denseblock1)
        self.layer2 = nn.Sequential(feats.transition1, feats.denseblock2)
        self.layer3 = nn.Sequential(feats.transition2, feats.denseblock3)
        self.layer4 = nn.Sequential(feats.transition3, feats.denseblock4, feats.norm5)
        # channels of the final feature map = length of norm5's affine weight
        nc = self.layer4[-1].weight.shape[0]
        self.body = nn.Sequential(self.layer0, self.layer1, self.layer2, self.layer3, self.layer4)
        self.heads = nn.ModuleList([Head(nc, c) for c in n])

    def forward(self, x):
        features = self.body(x)
        return [head(features) for head in self.heads]

In [0]:
class OHEM(Module):
    """Online hard example mining over a weighted sum of per-head cross-entropy losses.

    Per sample: `loss = sum_i weights[i] * CE(input[i], target[:, i])` (label -1 is ignored).
    Only the `top_k` fraction of samples with the largest loss is kept (at least one).

    Args:
        top_k: fraction of the batch kept (1 keeps every sample).
        weights: loss weight for each head, indexed like `input`.

    The selection of the last call (the `torch.topk` result) is stored in `self.index`. Passing it
    back via `index=` reuses the same samples; CustomMixUp relies on this.
    """
    def __init__(self, top_k=0.7, weights=(0.7, 0.1, 0.2)):
        super().__init__()
        self.loss = F.cross_entropy
        self.top_k = top_k
        self.weights = weights

    def forward(self, input, target, cb_reduction='mean', index=None):
        y, loss = target.long(), 0
        for idx, row in enumerate(input):
            gt = y[:, idx]
            loss += self.weights[idx] * self.loss(row, gt, reduction='none', ignore_index=-1)

        # keep at least one sample so the mean is never taken over an empty selection
        k = max(1, int(self.top_k * loss.size(0)))
        self.index = torch.topk(loss, k) if index is None else index
        valid_loss = loss[self.index[1]]

        return valid_loss.mean() if cb_reduction == 'mean' else valid_loss

# Metrics, callbacks and helper functions

In [0]:
class RecallPartial(Metric):
    """Macro-averaged recall for one target column, accumulated over a whole validation pass.

    Modelled on fastai's `AccumMetric`. Each batch, argmax predictions and targets are detached and
    concatenated; `value` scores them with `sklearn.metrics.recall_score`.

    Args:
        a: index of the target/head to score.
    """
    def __init__(self, a=0, **kwargs):
        self.a = a
        self.func = partial(recall_score, average='macro', zero_division=0)

    def reset(self):
        self.preds = tensor([])
        self.targs = tensor([])

    def accumulate(self, learn):
        fp, sp, tp = learn.pred
        # one predicted class per head, stacked column-wise like the targets
        batch_preds = torch.stack([head.argmax(-1) for head in (fp, sp, tp)], dim=-1).float()
        self.preds = torch.cat((self.preds, to_detach(batch_preds)))
        self.targs = torch.cat((self.targs, to_detach(learn.y)))

    @property
    def value(self):
        if not len(self.preds): return None
        col = self.a
        return self.func(self.targs[:, col], self.preds[:, col])

    @property
    def name(self): return df.columns[self.a+1]
    
class RecallCombine(Metric):
    """Weighted average of the per-target recalls reported by the first metrics of the learner.

    Args:
        weights: weight for each of `learn.metrics[0..len(weights)-1]` (default 7:1:2, matching
            the RecallPartial metrics for the three targets).
    """
    def __init__(self, weights=(7, 1, 2)):
        self.weights = list(weights)
        self.partials = None

    def reset(self): self.partials = None

    def accumulate(self, learn):
        # Only remember which metrics to combine. Scoring here would rerun sklearn's recall on the
        # growing prediction arrays three times per batch, although only the final value is read.
        self.partials = [learn.metrics[i] for i in range(len(self.weights))]

    @property
    def value(self):
        if self.partials is None: return None
        scores = [m.value for m in self.partials]
        return np.average(scores, weights=self.weights)

In [0]:
class MyTrackCallback(Callback):
    """Pick which augmentation is active for each training batch and store its name in `learn.condition`.

    Args:
        augs: augmentation names; the augmentation callbacks compare `learn.condition` against them.
        probs: sampling probability for each entry of `augs` (must sum to 1, as `np.random.choice` requires).
    """
    # NOTE(review): `Normalize` is a transform, not a callback, so this ordering constraint presumably has no effect
    run_after,run_valid = [Normalize],False
    def __init__(self, augs, probs): self.augs,self.probs = augs,probs

    def aug_tracker(self, augs, probs):
        "Draw one element of `augs` with probabilities `probs`."
        # size=None yields a scalar index (one uniform draw, as with size=1); calling int() on a
        # 1-element array is deprecated in NumPy >= 1.25.
        return augs[int(np.random.choice(len(augs), p=probs))]

    def begin_batch(self): self.learn.condition = self.aug_tracker(self.augs, self.probs)

In [0]:
from torch.distributions.beta import Beta
def NoLoss(*o):
    "Stand-in loss returning None; augmentation callbacks install it and set `learn.loss` in `after_loss`."
    pass

class CustomMixUp(Callback):
    """MixUp for training batches where `learn.condition == 'CustomMixUp'`, with per-sample weights.

    Each image is interpolated with a shuffled partner, keeping weight `lam >= 0.5` on the original.
    The loss is the same interpolation of the per-sample losses on the shuffled and the original labels.
    Both losses use the hard-example subset chosen by the first call (assumes the learner's loss is
    `OHEM`, which exposes `.index`).

    Args:
        alpha: both concentration parameters of the Beta distribution `lam` is drawn from.
    """
    run_after,run_valid = MyTrackCallback,False
    def __init__(self, alpha=0.4): self.distrib = Beta(tensor(alpha), tensor(alpha))

    def begin_fit(self): self.loss_func0 = self.learn.loss_func

    def begin_batch(self):
        if self.learn.condition != self.__class__.__name__: return
        # disable the last two batch transforms (GridMask / RandomErasing) while mixup is active
        self.dls.after_batch.fs[-1].p,self.dls.after_batch.fs[-2].p = 0.,0.
        self.learn.loss_func = NoLoss
        n = self.y[:, 0].size(0)
        # sample((n,)) is already 1-D; not squeezing keeps a batch of one as shape (1,)
        lam = self.distrib.sample((n,)).to(self.x[0].device)
        lam = torch.stack([lam, 1-lam], 1)
        # weight >= 0.5 on the original sample of each pair
        self.lam = lam.max(1)[0]
        shuffle = torch.randperm(n).to(self.x.device)
        xb1 = tuple(L(self.xb).itemgot(shuffle))
        yb1 = tuple([self.yb[i][shuffle] for i in range(len(self.yb))])
        nx_dims = len(self.x.size())
        # lerp(shuffled, original, lam) = (1 - lam) * shuffled + lam * original
        self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nx_dims-1)))
        self.learn.yb = yb1,self.yb

    def after_loss(self):
        if self.learn.condition != self.__class__.__name__: return
        loss0 = self.loss_func0(self.learn.pred, *self.learn.yb[0], cb_reduction='none')
        # reuse the samples selected for loss0 so both losses cover the same rows
        loss1 = self.loss_func0(self.learn.pred, *self.learn.yb[1], cb_reduction='none', index=self.loss_func0.index)
        self.learn.loss = torch.lerp(loss0, loss1, self.lam[self.loss_func0.index[1]]).mean()
        self.learn.loss_func = self.loss_func0

In [0]:
def rand_bbox(size, lam):
    """Random CutMix box covering roughly `1 - lam` of an image.

    Args:
        size: input shape whose last two entries are (height, width), e.g. an NCHW `x.shape`.
        lam: mixing ratio in [0, 1]; each box side is scaled by `sqrt(1 - lam)`.

    Returns:
        (x1, y1, x2, y2) as Python ints, clipped to the image, for use as `[..., y1:y2, x1:x2]`.
        The box is centred on a uniformly random pixel and may be truncated at the borders.
    """
    H, W = size[-2], size[-1]
    cut_rat = np.sqrt(1. - lam)
    # builtin int: the `np.int` alias was removed in NumPy 1.24
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniformly random centre
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = int(np.clip(cx - cut_w // 2, 0, W))
    bby1 = int(np.clip(cy - cut_h // 2, 0, H))
    bbx2 = int(np.clip(cx + cut_w // 2, 0, W))
    bby2 = int(np.clip(cy + cut_h // 2, 0, H))
    return bbx1,bby1,bbx2,bby2

In [0]:
class CutMix(Callback):
    """CutMix for training batches where `learn.condition == 'CutMix'`.

    A random box (from `rand_bbox`) of a shuffled copy of the batch is pasted into every image.
    With `stack_y=True` the targets become `[y | y_shuffled | lam]`, where lam is the fraction of
    each image left untouched. `after_loss` then computes `lam * loss(y) + (1 - lam) * loss(y_shuffled)`,
    reusing the hard-example selection of the first loss for the second (assumes the learner's
    loss is `OHEM`, which exposes `.index`).

    Args:
        alpha: both parameters of the Beta distribution the requested mix ratio is drawn from.
        stack_y: stack targets as above (True) or linearly interpolate them (False).
    """
    run_after,run_valid = MyTrackCallback,False
    def __init__(self, alpha=1., stack_y=True): self.alpha,self.stack_y = alpha,stack_y

    def begin_fit(self): self.loss_func0 = self.learn.loss_func

    def begin_batch(self):
        if self.learn.condition != self.__class__.__name__: return
        # disable the last two batch transforms (GridMask / RandomErasing) while CutMix is active
        self.dls.after_batch.fs[-1].p,self.dls.after_batch.fs[-2].p = 0.,0.
        # the learner's own loss computation is skipped; the mixed loss is set in after_loss
        self.learn.loss_func = NoLoss
        lam = np.random.beta(self.alpha, self.alpha)
        shuffle = torch.randperm(self.y[:, 0].size(0)).to(self.x.device)
        yb1 = TensorCategory(*[self.yb[i][shuffle] for i in range(len(self.yb))])
        last_input_size = self.x.shape
        bbx1, bby1, bbx2, bby2 = rand_bbox(last_input_size, lam)
        new_input = self.x.clone()
        new_input[:, ..., bby1:bby2, bbx1:bbx2] = self.x[shuffle, ..., bby1:bby2, bbx1:bbx2]
        self.learn.xb = tuple([new_input])
        lam = self.x.new([lam])
        if self.stack_y:
            # recompute lam from the (border-clipped) box that was actually pasted
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (last_input_size[-1] * last_input_size[-2]))
            lam = self.x.new([lam])
            self.learn.yb = tuple([torch.cat([self.y.float(), yb1.float(), lam.repeat(last_input_size[0]).unsqueeze(1).float()], 1)])
        else:
            # NOTE(review): interpolated class indices are not meaningful for a CE-based loss such as OHEM
            if len(self.y.shape) == 2:
                lam = lam.unsqueeze(1).float()
            self.learn.yb = tuple([self.y.float() * lam + yb1.float() * (1-lam)])

    def after_loss(self):
        if self.learn.condition != self.__class__.__name__: return
        if self.stack_y:
            stacked = self.learn.yb[0]
            # columns: [original labels | shuffled labels | lam]
            n_out = (stacked.shape[1] - 1) // 2
            y_orig, y_shuf, lam = stacked[:, :n_out], stacked[:, n_out:2*n_out], stacked[:, -1]
            loss0 = self.loss_func0(self.learn.pred, y_orig, cb_reduction='none')
            # reuse the samples selected for loss0 so both losses cover the same rows
            loss1 = self.loss_func0(self.learn.pred, y_shuf, cb_reduction='none', index=self.loss_func0.index)
            lam = lam[self.loss_func0.index[1]]
            self.learn.loss = (lam * loss0 + (1 - lam) * loss1).mean()
        else:
            self.learn.loss = self.loss_func0(self.learn.pred, *self.learn.yb, cb_reduction='mean')
        self.learn.loss_func = self.loss_func0

In [0]:
class EraseCallback(Callback):
    """Enable RandomErasing ('Cutout') or GridMask when `learn.condition` selects them.

    Only the transform probabilities are changed. The loss is left to the learner, or to CutMix /
    CustomMixUp, which set their own mixed loss. Recomputing it here from `learn.yb` would overwrite
    CutMix's mixed loss with a loss on the original labels only.

    NOTE(review): assumes `after_batch.fs[-1]` is GridMask and `fs[-2]` is RandomErasing once the
    pipeline is ordered. The batch transforms have presumably already run for the current batch
    when `begin_batch` fires, so these probabilities likely affect the batches produced afterwards — confirm.
    """
    run_after,run_valid = CustomMixUp,False

    def begin_batch(self):
        cond = self.learn.condition
        self.mu,self.gm,self.cutout = cond == 'CustomMixUp',cond == 'GridMask',cond == 'Cutout'
        fs = self.dls.after_batch.fs
        if self.gm:
            fs[-2].p = 0.
            fs[-1].p = 0.8
        elif self.cutout:
            fs[-1].p = 0.
            fs[-2].p = 0.8

In [0]:
def show(img, i=0):
    "Display item `i` of a batch of images."
    single = TensorImage(img[i, ...])
    return single.show()

In [0]:
# augmentation chosen per batch by MyTrackCallback, and the probability of each
al = ['Cutout', 'GridMask', 'CustomMixUp', 'CutMix']
probs = [0.1, 0.1, 0.5, 0.3]

# Training loop

In [38]:
# backbone choice: xresnet-based CascadeModel (CascadeDnet is the DenseNet alternative)
model = CascadeModel(pre=True)
# model = CascadeDnet()

Downloading: "https://s3.amazonaws.com/fast-ai-modelzoo/xrn50_940.pth" to /root/.cache/torch/checkpoints/xrn50_940.pth


HBox(children=(IntProgress(value=0, max=256198016), HTML(value='')))




In [0]:
# OHEM loss, per-batch augmentation callbacks, the per-target recalls plus their weighted combination,
# and two parameter groups (backbone, heads) so fit can use discriminative learning rates
learn_cbs = [MyTrackCallback(al, probs), CustomMixUp(), CutMix(), EraseCallback()]
learn_metrics = [RecallPartial(a=i) for i in range(len(dbch.c))] + [RecallCombine()]

def _split_body_heads(m): return [list(m.body.parameters()), list(m.heads.parameters())]

learn = Learner(dbch, model, loss_func=OHEM(), cbs=learn_cbs, metrics=learn_metrics,
                splitter=_split_body_heads, model_dir=DRIVE/'models')

In [40]:
# switch the learner to mixed-precision (fp16) training
learn.to_fp16()

<fastai2.learner.Learner at 0x7fb1ce6084a8>

In [0]:
# Learner.load resolves a bare name under learn.path/learn.model_dir (= DRIVE/'models'), which is where
# SaveModelCallback(fname=f'augs_xresnet50_{fold}') writes; passing DRIVE/'models/...' would nest the path twice.
learn.load(f'augs_xresnet50_{fold}')

In [0]:
# learn.recorder.lr_find()

In [0]:
# NOTE(review): fit_one_cycle's LR schedule may override changes made by ReduceLROnPlateau — confirm it has an effect
save_cb = SaveModelCallback(fname=f'augs_xresnet50_{fold}')
plateau_cb = ReduceLROnPlateau(patience=2)
learn.fit_one_cycle(50, slice(1e-3, 1e-2), cbs=[save_cb, plateau_cb])

epoch,train_loss,valid_loss,grapheme_root,vowel_diacritic,consonant_diacritic,recall_combine,time


In [0]:
# per-epoch losses and metrics recorded during fit
learn.recorder.values

In [0]:
# NOTE(review): hard-coded to the third-from-last epoch (presumably the best one of this run) — confirm
v = learn.recorder.values[-3]
# presumably v[1] is valid_loss and v[-1] is recall_combine, per the fit header columns
val_loss, rc = float(v[1]), float(v[-1])

In [0]:
# checkpoint name encoding arch, combined recall, validation loss and (1-based) fold
model_fn = f'{arch_name}_rc{rc:0.4f}_validloss{val_loss:0.4f}_fold{fold+1}_of_{nfolds}'
model_fn

In [0]:
# NOTE(review): learn.save presumably joins its argument onto learn.path/learn.model_dir (= DRIVE/'models'),
# so this path may end up nested twice; passing just `model_fn` would save next to SaveModelCallback's files — confirm
learn.save(DRIVE/f'models/{model_fn}')

In [0]:
# NOTE(review): resolved under learn.path/learn.model_dir like learn.save above; keep both calls in sync — confirm the file location
learn.load(DRIVE/'models/xresnet50_rc0.9653_validloss0.3888_fold1_of_5')

In [0]:
# fine-tune; keep the checkpoint with the best combined recall
best_cb = SaveModelCallback(monitor='recall_combine', fname=f'model_{fold}')
plateau_cb = ReduceLROnPlateau(patience=2)
learn.fit_one_cycle(14, slice(1e-3), cbs=[best_cb, plateau_cb])