# L03 Kata - Segmentation

In [None]:
# Notebook setup: auto-reload imported modules before each cell runs
# and render matplotlib figures inline in the notebook.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *

### Exploring the data

In [None]:
# Get the CAMVID_TINY sample dataset; `path` is the local folder used as
# the dataset root by the cells below.
path = untar_data(URLs.CAMVID_TINY)
# Last expression of the cell: list what the dataset root contains.
path.ls()

In [None]:
# Sub-folders of the dataset root: label masks and input images.
path_lbl = path/'labels'
path_img = path/'images'

In [None]:
# Collect the image files found in `path_img` and preview the first three.
fnames = get_image_files(path_img)
fnames[:3]

In [None]:
# Collect the mask files found in `path_lbl` and preview the first three,
# to compare their naming with the image files.
lbl_names = get_image_files(path_lbl)
lbl_names[:3]

In [None]:
# Use the first image file as a running example: open it and display it.
img_f = fnames[0]
img = open_image(img_f)
img.show(figsize=(5,5))

In [None]:
def get_y_fn(x):
    """Return the path of the label mask that belongs to image path `x`.

    The mask is looked up in `path_lbl` and carries the image's file name
    with `_P` inserted before the extension (e.g. `frame.png` ->
    `frame_P.png`).
    """
    return path_lbl/f'{x.stem}_P{x.suffix}'

In [None]:
# Open the mask that `get_y_fn` pairs with the sample image and display it;
# alpha=1 is passed so the mask is drawn fully opaque.
mask = open_mask(get_y_fn(img_f))
mask.show(figsize=(5,5), alpha=1)

In [None]:
# Mask shape without its first dimension, stored as an array so it can be
# scaled arithmetically later (presumably (height, width) with a leading
# channel axis dropped — confirm against open_mask).
src_size = np.array(mask.shape[1:])
# Display the size together with the raw mask data.
src_size,mask.data

In [None]:
# Read codes.txt from the dataset root as an array of strings; these are
# used further down as the segmentation class names.
codes = np.loadtxt(path/'codes.txt', dtype=str)
# Last expression of the cell: display the loaded names.
codes

### Loading the data

In [None]:
# Work at half the source mask size (integer division keeps it integral).
size = src_size//2
# Batch size handed to the databunch.
bs=4

In [None]:
# Seed NumPy's RNG right before the random split (presumably so the
# train/validation split is reproducible — confirm split_by_rand_pct draws
# from np.random).
np.random.seed(42)
# Item list built from the images in `path_img`, with 20% split off at
# random, labelled by the mask path from `get_y_fn` and the class names
# in `codes`.
src = (SegmentationItemList.from_folder(path_img)
       .split_by_rand_pct(0.2)
       .label_from_func(get_y_fn, classes=codes))

In [None]:
# Turn the labelled items into a databunch: apply `get_transforms()` at
# `size` (tfm_y=True asks for the targets to be transformed as well),
# batch with `bs`, and normalize using `imagenet_stats`.
data = (src.transform(get_transforms(), size=size, tfm_y=True)
        .databunch(bs=bs)
        .normalize(imagenet_stats))

In [None]:
# Visual sanity check of a batch (first argument 2; no ds_type given, so
# the library default dataset is shown).
data.show_batch(2, figsize=(10,7))

In [None]:
# Same visual check, this time explicitly on the validation set.
data.show_batch(2, figsize=(10,7), ds_type=DatasetType.Valid)

### Training

In [None]:
# Reverse lookup from class name to its position in `codes`, and the
# position of the 'Void' class, which the accuracy metric leaves out.
name2id = {name: idx for idx, name in enumerate(codes)}
void_code = name2id['Void']

def acc_camvid(input, target):
    """Accuracy of the predicted classes over all non-void target pixels.

    `input` is reduced with argmax over dim 1 to get a predicted class per
    position; `target` has its dim 1 squeezed away before comparison
    (assumes input is (batch, classes, H, W) and target is
    (batch, 1, H, W) — confirm against the databunch). Positions whose
    target equals the module-level `void_code` are excluded.
    """
    labels = target.squeeze(1)
    # Boolean selector for every position that is not labelled void.
    keep = labels != void_code
    predicted = input.argmax(dim=1)
    hits = predicted[keep] == labels[keep]
    return hits.float().mean()

In [None]:
# Metric and weight decay passed to the learner below.
metrics=acc_camvid
wd=1e-2

In [None]:
# U-Net learner over `data` built from the resnet34 architecture, with the
# metric and weight decay chosen above.
learn = unet_learner(data, models.resnet34, metrics=metrics, wd=wd)

In [None]:
# Run the learning-rate finder on the learner, then plot what the recorder
# captured so a learning rate can be picked from the curve.
learn.lr_find()
learn.recorder.plot()

In [None]:
# Base learning rate for both training stages (presumably read off the
# LR-finder plot above).
lr=3e-3

In [None]:
# Stage 1: one-cycle training for 10 epochs at `lr`, run before the model
# is unfrozen; pct_start=0.9 is passed to shape the schedule.
learn.fit_one_cycle(10, slice(lr), pct_start=0.9)

In [None]:
# Checkpoint the learner after stage 1 under the name 'stage-1'.
learn.save('stage-1')

In [None]:
# Visual check of the stage-1 model: show results for 3 rows of samples.
learn.show_results(rows=3, figsize=(10,10))

### Fine tune model

In [None]:
# Unfreeze the model so the fine-tuning run below can update it.
learn.unfreeze()

In [None]:
# Learning-rate range for fine-tuning, from lr/400 up to lr/4 (presumably
# spread across layer groups by the learner — confirm in the fastai docs).
lrs = slice(lr/400,lr/4)

In [None]:
# Stage 2: one-cycle fine-tuning of the unfrozen model for 12 epochs over
# the `lrs` range, with pct_start=0.8.
learn.fit_one_cycle(12, lrs, pct_start=0.8)

In [None]:
# Checkpoint the fine-tuned learner under the name 'stage-2'
# (the trailing `;` suppresses any cell output in the notebook).
learn.save('stage-2');

In [None]:
# Visual check after fine-tuning: show results for 3 rows of samples.
learn.show_results(rows=3, figsize=(10,10))

### Accuracy analysis

### Predict on CPU