# fastai training loop with the data-block API
fastai is a great tool for creating a strong baseline quickly. I use a pretty much out-of-the-box approach for multilabel classification, with a resnet50 backbone, one-cycle training, the LR finder, etc. The data block API is a great way to prepare the data, and it comes with a default set of augmentations that I use as well.

This notebook is based on the following:
Solution overview: https://www.kaggle.com/c/hpa-single-cell-image-classification/discussion/221550

It was forked to illustrate an issue with CutMix.

In [None]:
!pip install /kaggle/input/iterative-stratification/iterative-stratification-master/

In [None]:
import pandas as pd
import numpy as np
from fastai.vision.all import *
import pickle
import os

In [None]:
# Making pretrained weights work without needing to find the default filename
# Create the torch hub checkpoint cache if it is missing. exist_ok=True makes
# the call a no-op when the directory is already there, so no separate
# existence check (and no gap between checking and creating) is needed.
os.makedirs('/root/.cache/torch/hub/checkpoints/', exist_ok=True)
!cp '../input/resnet50/resnet50.pth' '/root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth'
# !cp '../input/resnet34/resnet34.pth' '/root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth'

In [None]:
# Root of the Kaggle input dataset; later cells read cell_df.csv and cells/*.jpg from it.
path = Path('../input/hpa-cell-tiles-sample-balanced-dataset')

In [None]:
# Metadata table; this notebook uses its image_id, cell_id and image_labels columns.
df = pd.read_csv(path/'cell_df.csv')

In [None]:
df.head()

In [None]:
len(df)

In [None]:
# Class ids "0".."18" as strings, matching the '|'-separated tokens in image_labels.
labels = list(map(str, range(19)))

# Split every label string once, then add one 0/1 indicator column per class id.
label_tokens = df['image_labels'].map(lambda s: s.split('|'))
for lbl in labels:
    df[lbl] = label_tokens.map(lambda tokens: int(lbl in tokens))

In [None]:
# Shuffle every row with a fixed seed and renumber the index from 0.
dfs = df.sample(frac=1, random_state=42).reset_index(drop=True)
len(dfs)

In [None]:
# Label frequency table over the shuffled cells.
#   unique_count: rows whose image_labels string is exactly that single label
#   full_count:   rows whose '|'-separated label list contains the label
unique_counts = {lbl: int((dfs['image_labels'] == lbl).sum()) for lbl in labels}

# Split each row once instead of once per (label, row) pair.
row_label_sets = [set(row_label.split('|')) for row_label in dfs['image_labels']]
full_counts = {
    lbl: sum(lbl in label_set for label_set in row_label_sets)
    for lbl in labels
}

# Build the frame straight from the sorted tuples: routing them through
# np.array would coerce the mixed (str, int, int) rows to strings, leaving
# the count columns non-numeric.
counts = sorted(
    zip(full_counts.keys(), full_counts.values(), unique_counts.values()),
    key=lambda row: -row[1],  # most frequent label first
)
counts = pd.DataFrame(counts, columns=['label', 'full_count', 'unique_count'])
counts.set_index('label').T


In [None]:
len(dfs)

In [None]:
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

nfold = 5
seed = 42

# The multi-hot label columns are the stratification target; the id columns
# are passed as the sample array.
y = dfs[labels].values
X = dfs[['image_id', 'cell_id']].values

# NOTE(review): recent scikit-learn releases reject a random_state when
# shuffle is left at False in KFold-style splitters; if this constructor
# raises, pass shuffle=True as well -- confirm against the installed versions.
mskf = MultilabelStratifiedKFold(n_splits=nfold, random_state=seed)

# Collect the fold ids in a positional array and assign the column by name.
# Writing through `dfs.iloc[test_index, -1]` is only correct while 'fold'
# happens to be the last column, which stops being true once later cells add
# columns and this cell is re-run.
fold_ids = np.full(len(dfs), -1, dtype='int')
for i, (_, test_index) in enumerate(mskf.split(X, y)):
    fold_ids[test_index] = i

dfs['fold'] = fold_ids

In [None]:
# Fold 0 is the validation split. The flag is built in one vectorised
# assignment; the chained form `dfs['is_valid'][mask] = True` writes through an
# intermediate Series, which pandas warns about and which has no effect when
# copy-on-write is enabled.
dfs['is_valid'] = dfs['fold'] == 0

In [None]:
dfs.is_valid.value_counts()

In [None]:
def get_x(r):
    "Image file for row `r`: `path/cells/{image_id}_{cell_id}.jpg`."
    tile_name = r['image_id'] + '_' + str(r['cell_id']) + '.jpg'
    return path / 'cells' / tile_name
# Smoke test on an arbitrary row: resolve the tile path, load it, display it.
img = get_x(dfs.loc[12])
img = PILImage.create(img)  # `img` is rebound from the path to the loaded image
img.show();

In [None]:
def get_y(r):
    "List of label ids for row `r`, split from its '|'-separated `image_labels`."
    label_string = r['image_labels']
    return label_string.split('|')
get_y(dfs.loc[12])

In [None]:
# Two 3-element lists that are unpacked into Normalize.from_stats further down.
# Presumably per-channel (mean, std) measured on a sample of this dataset -- TODO confirm.
sample_stats = ([0.07237246, 0.04476176, 0.07661699], [0.17179589, 0.10284516, 0.14199627])

In [None]:
bs = 256

# Item-level crop, applied to each image before batching.
item_tfms = RandomResizedCrop(224, min_scale=0.75, ratio=(1., 1.))

# Batch-level augmentations followed by normalisation with `sample_stats`.
batch_tfms = [
    *aug_transforms(
        size=128,
        flip_vert=True,
        max_rotate=60,
        max_lighting=0.5,
        max_warp=0.2,
    ),
    Normalize.from_stats(*sample_stats),
]

In [None]:
def get_y_bce(r, n_classes=None):
    """Return the multi-hot target vector for row `r`.

    Args:
        r: row-like object whose 'image_labels' entry is a '|'-separated
            string of integer class ids.
        n_classes: length of the returned vector. Defaults to `len(labels)`,
            the notebook-level label list.

    Returns:
        A float numpy array with 1 at every listed class id and 0 elsewhere.

    Raises:
        ValueError: if a token in 'image_labels' is not an integer.
        IndexError: if a class id does not fit in the vector.
    """
    if n_classes is None:
        n_classes = len(labels)
    arr = np.zeros(n_classes)
    # One index-list assignment sets every listed class at once.
    arr[[int(c) for c in r['image_labels'].split('|')]] = 1
    return arr

get_y_bce(dfs.loc[12])

In [None]:
# Images in, a float vector with one slot per label out; rows flagged in the
# `is_valid` column form the validation set.
dblock = DataBlock(
    blocks=(ImageBlock, RegressionBlock(n_out=len(labels))),
    get_x=get_x,
    get_y=get_y_bce,
    splitter=ColSplitter(col='is_valid'),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
)
dls = dblock.dataloaders(dfs, bs=bs)

In [None]:
# dblock.summary(dfs)

In [None]:
dls.show_batch(nrows=9, ncols=1)

In [None]:
# When the notebook runs top to bottom, `CutMix` here resolves to the class
# brought in by the fastai star import; the notebook's own replacement is only
# defined in a later cell.
cutmix = CutMix(0.1)

In [None]:
# Learner on the resnet50 backbone with Spearman correlation as the metric.
# No loss_func is passed, so the library's default for these data blocks applies.
learn = cnn_learner(dls, resnet50, metrics=[SpearmanCorrCoef()]).to_fp16()

# Learning-rate range test; the suggestion recorded from one run is kept below.
learn.lr_find()
# SuggestedLRs(lr_min=0.03630780577659607, lr_steep=0.02754228748381138)

In [None]:
# Base learning rate, close to the lr_find suggestions recorded above.
lr=3e-2

# Callback error

In [None]:
# Fine-tune with the CutMix callback created above. Per the heading of this
# section, this is the run that reproduces the callback error.
learn.fine_tune(1,base_lr=lr,cbs=cutmix)

In [None]:
learn.recorder.plot_loss()

# Workaround

In [None]:
class CutMix(MixHandler):
    "Implementation of `https://arxiv.org/abs/1905.04899`"
    def __init__(self, alpha=1.):
        super().__init__(alpha)

    def before_batch(self):
        "Paste a random box from a shuffled copy of the batch into `xb` and blend the targets by the untouched area"
        device = self.x.device
        n_items, _, height, width = self.x.size()
        # Random draws happen in this order: lam, permutation, then the box centre.
        self.lam = self.distrib.sample((1,)).to(device)
        perm = torch.randperm(n_items).to(device)
        shuffled_x = self.x[perm]
        self.yb1 = (self.y[perm],)
        left, top, right, bottom = self.rand_bbox(width, height, self.lam)
        self.learn.xb[0][..., top:bottom, left:right] = shuffled_x[..., top:bottom, left:right]
        # Replace the sampled lam with the fraction of the image left untouched,
        # computed from the box that was actually pasted.
        box_area = (right - left) * (bottom - top)
        self.lam = (1 - box_area / float(width * height))
        if self.stack_y: return
        weight = unsqueeze(self.lam, n=self.y.dim() - 1)
        self.learn.yb = tuple(L(self.yb1, self.yb).map_zip(torch.lerp, weight=weight))

    def rand_bbox(self, W, H, lam):
        "Corners `(x1, y1, x2, y2)` of a random box with sides scaled by `sqrt(1 - lam)`, clamped to `W` x `H`"
        device = self.x.device
        cut_rat = torch.sqrt(1. - lam).to(device)
        half_w = torch.round(W * cut_rat).type(torch.long).to(device) // 2
        half_h = torch.round(H * cut_rat).type(torch.long).to(device) // 2
        # uniform box centre; x is drawn before y
        cx = torch.randint(0, W, (1,)).to(device)
        cy = torch.randint(0, H, (1,)).to(device)
        x1 = torch.clamp(cx - half_w, 0, W)
        y1 = torch.clamp(cy - half_h, 0, H)
        x2 = torch.clamp(cx + half_w, 0, W)
        y2 = torch.clamp(cy + half_h, 0, H)
        return x1, y1, x2, y2

In [None]:
# Rebuild the callback from the CutMix class defined in the previous cell,
# which now shadows the star-imported one.
cutmix = CutMix(0.5)

In [None]:
# Fresh learner for the workaround run: same data, backbone and metric as
# before, but with an explicit BCE-with-logits loss.
learn = cnn_learner(
    dls,
    resnet50,
    loss_func=torch.nn.BCEWithLogitsLoss(),
    metrics=[SpearmanCorrCoef()],
).to_fp16()

In [None]:
# Same fine_tune call as the earlier run, now with the locally defined CutMix.
learn.fine_tune(1,base_lr=lr,cbs=cutmix)

In [None]:
learn.recorder.plot_loss()