## Setup

In [None]:
from fastai.vision.all import *

## Loading the DataFrame

In [None]:
# Read the competition's training CSV into `df` and display its last row as a quick sanity check.
df = pd.read_csv('../input/sartorius-cell-instance-segmentation/train.csv')
df.tail(1)

### Splitting into 5 folds based on image and cell type

In [None]:
from sklearn.model_selection import StratifiedKFold

# Keep one row per distinct (id, cell_type) pair; the fresh 0..n-1 index lets the
# positional indices returned by the splitter be used directly as `.loc` labels.
img_df = df[['id', 'cell_type']].drop_duplicates().reset_index(drop=True)
img_df['fold'] = -1  # placeholder until a fold number is assigned below

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_splits = skf.split(img_df['id'], img_df['cell_type'])
for fold, (train_index, test_index) in enumerate(fold_splits):
    # Only the held-out rows of each split are tagged with the fold number.
    img_df.loc[test_index, 'fold'] = fold

# Persist the assignment so later cells can reload it from disk.
img_df.to_csv('train_fold.csv', index=False)
img_df.tail()

In [None]:
# Count rows per cell type within each fold and transpose the result into a single
# wide row, making it easy to check that the folds are balanced.
img_df.groupby('fold')['cell_type'].value_counts().to_frame().T

### Building final DataFrame

In [None]:
# Reload the fold assignment written earlier (this rebinds `df`).
df = pd.read_csv('./train_fold.csv')

# Expand the integer `fold` column into boolean `fold_<n>` indicator columns,
# then drop the original integer column.
fold_flags = pd.get_dummies(df['fold'], prefix='fold', dtype=bool)
df = pd.concat([df, fold_flags], axis=1)
df = df.drop('fold', axis=1)
df.tail(1)

## Dataloaders

In [None]:
## Global variables
# Batch size handed to `dblock.dataloaders(..., bs=BS)`.
BS = 32
# Worker count handed to `dblock.dataloaders(..., num_workers=WORKERS)`.
WORKERS = 4
# Prefix prepended to each `id` value when reading the image column.
BASE_DIR = '../input/sartorius-cell-instance-segmentation/train/'
# Suffix appended to each `id` value when reading the image column.
FILE_EXT = '.png'
# Name of the boolean fold column given to `ColSplitter`.
FOLD = 'fold_0'

In [None]:
# Column readers: the image source is built from the `id` column wrapped with
# BASE_DIR / FILE_EXT; the label comes straight from the `cell_type` column.
image_path_reader = ColReader('id', pref=BASE_DIR, suff=FILE_EXT)
cell_type_reader = ColReader('cell_type')

dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_x=image_path_reader,
    get_y=cell_type_reader,
    splitter=ColSplitter(FOLD),  # split is driven by the boolean FOLD column
)

dls = dblock.dataloaders(df, bs=BS, num_workers=WORKERS)
dls.show_batch(figsize=(30, 22))

## Basic CNN Learner with resnet18

In [None]:
# Build a fastai learner on `dls` with a resnet18 architecture, reporting accuracy as the metric.
# NOTE(review): newer fastai releases appear to name this `vision_learner` — confirm
# against the installed version before renaming.
learn = cnn_learner(dls, resnet18, metrics = accuracy)

## Finding optimal LR and training for 5 epochs

In [None]:
# Run the learner's learning-rate finder to help choose the LR used for training.
learn.lr_find()

In [None]:
# Fine-tune for 5 epochs; 1e-3 is the second positional argument
# (presumably the base learning rate — confirm against fastai's `fine_tune` signature).
learn.fine_tune(5, 1e-3)

## Visualizing the results

In [None]:
# Display a sample of the learner's results in a large figure.
learn.show_results(figsize = (30, 22))

In [None]:
# Build an interpretation object from the trained learner; it feeds the confusion
# matrix and top-loss plots in the following cells.
interp = ClassificationInterpretation.from_learner(learn)

In [None]:
# Plot the confusion matrix held by the interpretation object.
interp.plot_confusion_matrix()

In [None]:
## Temporary fix for the broken version of plot_top_losses (until they update the Kaggle repo)
def plot_top_losses_fix(interp, k, largest=True, **kwargs):
    """Plot the `k` samples selected by `interp.top_losses`.

    Args:
        interp: Interpretation object exposing `top_losses`, `inputs`, `targs`,
            `decoded`, `preds` and a dataloader `dl`.
        k: Number of samples to fetch and show.
        largest: Passed straight through to `interp.top_losses`.
        **kwargs: Forwarded to `plot_top_losses` (e.g. `figsize`).

    Note:
        If `interp.inputs` is not already a tuple it is wrapped in one and
        stored back on `interp`.
    """
    losses, idx = interp.top_losses(k, largest)

    # Normalise the stored inputs to a tuple so they can be handled uniformly.
    if not isinstance(interp.inputs, tuple):
        interp.inputs = (interp.inputs,)

    if isinstance(interp.inputs[0], Tensor):
        # Tensor inputs can be indexed with the whole index set at once.
        inps = tuple(inp[idx] for inp in interp.inputs)
    else:
        # Otherwise pick the samples one by one and rebuild a batch through the dataloader.
        picked = [tuple(inp[i] for inp in interp.inputs) for i in idx]
        inps = interp.dl.create_batch(interp.dl.before_batch(picked))

    # Batch of (inputs, targets) prepared for display.
    targets = interp.targs if is_listy(interp.targs) else (interp.targs,)
    target_batch = inps + tuple(t[idx] for t in targets)
    x, y, its = interp.dl._pre_show_batch(target_batch, max_n=k)

    # Same inputs paired with the decoded predictions; only the shown outputs are needed.
    decoded = interp.decoded if is_listy(interp.decoded) else (interp.decoded,)
    decoded_batch = inps + tuple(d[idx] for d in decoded)
    _, _, outs = interp.dl._pre_show_batch(decoded_batch, max_n=k)

    if its is None:
        return

    # Strip the input part from each output so only the prediction part is passed on.
    shown_outs = outs.itemgot(slice(len(inps), None))
    plot_top_losses(x, y, its, shown_outs, interp.preds[idx], losses, **kwargs)

In [None]:
# Show the 3 top-loss samples with the patched helper; `figsize` is forwarded through **kwargs.
plot_top_losses_fix(interp, 3, figsize = (30, 8))