In [None]:
# Notebook setup: enable the autoreload extension (mode 2) and inline matplotlib output.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import warnings
# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation and runtime warnings — consider a narrower filter.
warnings.filterwarnings('ignore')

In [None]:
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *

In [None]:
# Root folder of the raw dataset; later cells search it recursively for images and JSON labels.
path_data = Path('/tmp/vtdata/')

In [None]:
import glob
import os
import shutil 
from fastai.vision import *
from fastai.metrics import accuracy
from fastai.basic_data import *

from pathlib import Path

from matplotlib import pyplot as plt
from PIL import Image, ImageOps


def resize_one(fn, i, pth, size, path, padding_color=(255, 255, 255), resample=None):
    """Resize one image to fit a ``size`` x ``size`` square and pad it to that square.

    The longer side is scaled to ``size`` pixels with the aspect ratio kept, and
    the result is centred on an RGB canvas filled with ``padding_color``. The
    output is saved under ``pth`` at the same path ``fn`` has relative to ``path``.

    Nothing is written and ``None`` is returned when the source file is empty,
    the destination already exists, the file cannot be opened as an image,
    either side is shorter than 10 pixels, or the scaled shorter side would
    round down to 0 pixels.

    Args:
        fn: ``Path`` of the source image; must be located under ``path``.
        i: Not used by this function (presumably the item index supplied by
            ``parallel`` — confirm against the caller).
        pth: Root directory of the output tree.
        size: Side length, in pixels, of the square output image.
        path: Root directory of the source tree.
        padding_color: RGB fill colour of the padding border.
        resample: PIL resampling filter. ``None`` (default) keeps the historic
            ``PIL.Image.BILINEAR``; pass ``PIL.Image.NEAREST`` for label masks
            so that interpolation cannot produce new pixel values.
    """
    if os.path.getsize(fn) == 0:
        return
    dest = pth/fn.relative_to(path)
    if dest.exists():
        return
    dest.parent.mkdir(parents=True, exist_ok=True)

    # Image.open never returns None; it raises when the file is not an image.
    # Skip such files so that one bad file does not abort the whole batch.
    try:
        im = PIL.Image.open(fn)
    except OSError as err:
        print(f"skipping unreadable image {fn}: {err}")
        return

    # The context manager releases the file handle once the pixels are resized.
    with im:
        width, height = im.size
        if width < 10 or height < 10:
            return
        scale = size / max(width, height)
        if width > height:
            targ_sz = (size, int(height*scale))
        else:
            targ_sz = (int(width*scale), size)
        # Extreme aspect ratios can round the shorter side down to nothing.
        if min(targ_sz) <= 0:
            return
        if resample is None:
            resample = PIL.Image.BILINEAR
        im_resize = im.resize(targ_sz, resample=resample)

    # Use PIL.Image explicitly: a bare ``Image`` may be rebound by the
    # ``from fastai.vision import *`` star import when cells are re-run.
    new_im = PIL.Image.new("RGB", (size, size), padding_color)
    offset = ((size - im_resize.size[0])//2, (size - im_resize.size[1])//2)
    new_im.paste(im_resize, offset)
    new_im.save(dest, quality=100)

In [None]:
# Build the resized, padded copy of the dataset. The output folder is deleted
# first, so the whole set is regenerated on every run of this cell.
! rm -rf /tmp/resize_pad
SZ = 1024
path_resize = Path("/tmp/resize_pad")
sets = [(path_resize, SZ)]
il = ImageList.from_folder(path_data)
for p,size in sets:
    os.makedirs(p, exist_ok=True)
    print(f"resizing to {size} into {p}")
    # Apply resize_one to every item found under path_data.
    parallel(partial(resize_one, pth=p, size=size, path=path_data), il.items)

In [None]:
import glob 
import os

path_data = Path('/tmp/vtdata/')

# Collect images in the same order as before: *.jpg, *.png, *.JPG, *.PNG.
path_images = [
    found
    for pattern in ("*.jpg", "*.png", "*.JPG", "*.PNG")
    for found in glob.glob(f"{path_data}/**/{pattern}", recursive=True)
]
path_labels = glob.glob(f"{path_data}/**/*.json", recursive=True)


def label_path_for(img_path):
    """Return the JSON label path that belongs to ``img_path``.

    'images' is swapped for 'labels' in the path and the file extension,
    whatever it is, is replaced by '.json'. Only the extension is touched, so
    a 'png' that appears elsewhere in the path is left alone, and .jpg/.JPG/.PNG
    images resolve to a .json label as well.
    """
    stem, _ext = os.path.splitext(img_path.replace('images', 'labels'))
    return stem + '.json'


# Maps every image path to the path where its label file is expected.
dict_label = {img_path: label_path_for(img_path) for img_path in path_images}

In [None]:
# Start from an empty output folder for the full-resolution masks.
!  rm -rf '/tmp/mask_org_labels'; mkdir '/tmp/mask_org_labels'

import PIL
import json
import numpy as np
import os
from matplotlib.pyplot import imshow


def create_label_images(dict_label, src_root="/tmp/vtdata", dst_root='/tmp/mask_org_labels'):
    """Rasterise the rectangular regions of each label file into a mask image.

    For every ``image path -> label path`` pair a single-channel ('L') mask
    with the image's dimensions is built: pixels inside any region rectangle
    are 1, all others 0. The mask is saved at the image's path with
    ``src_root`` replaced by ``dst_root``.

    A failure on one pair (missing file, unexpected JSON layout, ...) prints
    its traceback and processing continues with the next pair.

    Args:
        dict_label: Mapping of image path (str) to label JSON path (str).
        src_root: Path prefix of the images that is swapped out.
        dst_root: Path prefix the masks are written under.
    """
    for img_path, lbl_path in dict_label.items():
        try:
            # Only the dimensions are needed, so release the file immediately.
            with PIL.Image.open(img_path) as img:
                img_w, img_h = img.size
            with open(lbl_path) as json_file:
                label = json.load(json_file)

            mask = np.zeros((img_h, img_w), dtype=np.uint8)
            regions = label['attributes']['_via_img_metadata']['regions']
            for r in regions:
                shape = r['shape_attributes']
                x, y = int(shape['x']), int(shape['y'])
                w, h = int(shape['width']), int(shape['height'])
                # Clamp the start to the image: numpy reads a negative start
                # index as "from the far edge", which would silently drop or
                # misplace a box that begins outside the image. Ends beyond
                # the image are already clipped by slicing.
                x0, y0 = max(x, 0), max(y, 0)
                x1, y1 = max(x + w, 0), max(y + h, 0)
                mask[y0:y1, x0:x1] = 1

            mask_label_path = img_path.replace(src_root, dst_root)
            if mask_label_path == img_path:
                # Without this guard the mask would overwrite the source image.
                raise ValueError(f"{img_path} is not located under {src_root}")
            os.makedirs(os.path.dirname(mask_label_path), exist_ok=True)

            # NOTE(review): the mask keeps the image's file extension; for a
            # .jpg source this would presumably be written as lossy JPEG —
            # confirm that label values survive.
            PIL.Image.fromarray(mask, 'L').save(mask_label_path)
        except Exception:
            import traceback
            traceback.print_exc()
# Write a mask for every image/label pair held in dict_label.
create_label_images(dict_label)

In [None]:
! rm -rf /tmp/mask_resize_labels
# Build the resized, padded copy of the masks. The output folder is deleted
# above, so the whole set is regenerated on every run of this cell.
SZ = 1024
path_resize = Path("/tmp/mask_resize_labels")
sets = [(path_resize, SZ)]
il = ImageList.from_folder(Path('/tmp/mask_org_labels'))
for p,size in sets:
    os.makedirs(p, exist_ok=True)
    print(f"resizing to {size} into {p}")
    # Pad with black (0,0,0) so that the border carries label value 0.
    # NOTE(review): resize_one resizes with BILINEAR and saves an RGB file —
    # confirm the resized masks still contain only the values 0 and 1.
    parallel(partial(resize_one, pth=p, size=size, path=Path('/tmp/mask_org_labels'), padding_color=(0,0,0)), il.items)

In [None]:
from matplotlib.pyplot import imshow
%matplotlib inline
import cv2

# Sanity check: display one resized mask. Pixel values are multiplied by 255
# so that label value 1 becomes visible instead of near-black.
imcv = cv2.imread('/tmp/mask_resize_labels/C2_train_3/images/4580.png') * 255
imshow(imcv)
print(imcv.shape)

In [None]:
# Sanity check: display the resized image that belongs to the mask shown above.
# NOTE(review): cv2.imread presumably returns BGR while imshow expects RGB, so
# colours may look swapped — confirm if the colours matter here.
imcv = cv2.imread('/tmp/resize_pad/C2_train_3/images/4580.png') 
imshow(imcv)
print(imcv.shape)

In [None]:
def get_y_fn(x):
    """Return the mask path for image ``x``: its path with the image root swapped for the mask root."""
    image_path = str(x)
    return image_path.replace("/tmp/resize_pad", '/tmp/mask_resize_labels')
# Pair every resized image with its mask and hold out 20% for validation.
# The split is seeded so it is identical on every run: checkpoints are saved
# and reloaded further down, and an unseeded split would be re-drawn after a
# kernel restart, putting former training images into the validation set.
src = (SegmentationItemList.from_folder(Path('/tmp/resize_pad'))
       .split_by_rand_pct(0.2, seed=42)
       .label_from_func(get_y_fn, classes=['bg', 'vt'])
      )

In [None]:
# Target size used for the first training stage.
src_size = np.array([512, 512])
print(src_size)

In [None]:
# Stage-1 data: images resized to `size` with ResizeMethod.SQUISH, batch size 2.
# tfm_y=True requests that the transforms are applied to the masks as well;
# the batches are normalised with imagenet_stats.
size = src_size
bs = 2
data = (src.transform(get_transforms(max_rotate=3), size=size, resize_method=ResizeMethod.SQUISH, tfm_y=True)
        .databunch(bs=bs, num_workers=0)).normalize(imagenet_stats)

In [None]:
# Visual check of the stage-1 training samples.
data.show_batch(4, figsize=(10,7))

In [None]:
# Class names; the position of a name in this list is its class id.
codes = ['bg', 'vt']
name2id = {name: idx for idx, name in enumerate(codes)}
# Pixels labelled 'bg' are treated as void and left out of the metric.
void_code = name2id['bg']
print(name2id)


def acc_vt(input, target):
    """Accuracy over the pixels whose target is not ``void_code``.

    Args:
        input: Per-class scores with the class axis at dim 1.
        target: Class ids with a singleton axis at dim 1 (it is squeezed away).

    Returns:
        Scalar tensor: the fraction of non-void target pixels whose arg-max
        prediction equals the target. The mean is taken over an empty
        selection, and is therefore NaN, when the target has no non-void pixel.
    """
    labels = target.squeeze(1)
    keep = labels != void_code
    predicted = input.argmax(dim=1)
    hits = predicted[keep] == labels[keep]
    return hits.float().mean()

In [None]:
# Weight decay passed to both learners created in this notebook.
wd=1e-2

In [None]:
# Stage-1 learner: U-Net built on models.resnet50, reporting acc_vt as its metric.
learn = unet_learner(data, models.resnet50, metrics=acc_vt, wd=wd)

In [None]:
# Run the learning-rate finder and plot its recording to help pick `lr` below.
lr_find(learn)
learn.recorder.plot()

In [None]:
# First stage-1 fit: 5 epochs of one-cycle training up to lr.
# learn.unfreeze() is only called further down, so this presumably trains with
# the encoder still frozen — confirm against unet_learner's defaults.
lr=1e-4
learn.fit_one_cycle(5, slice(lr), pct_start=0.9)

In [None]:
# Checkpoint after the first stage-1 fit.
learn.save('stage-1-1');

In [None]:
# Visual check of predictions after the first stage-1 fit.
learn.show_results(rows=3, figsize=(9,11))

In [None]:
# Make all layers trainable for the second stage-1 fit.
learn.unfreeze()

In [None]:
# Learning-rate range for the unfrozen model, from lr/400 up to lr/4.
lrs = slice(lr/400,lr/4)

In [None]:
# Second stage-1 fit: 5 epochs over the whole (unfrozen) model.
learn.fit_one_cycle(5, lrs, pct_start=0.8)

In [None]:
# Visual check of predictions at the end of stage 1.
learn.show_results(rows=3, figsize=(9,11))

In [None]:
# Final stage-1 checkpoint; stage 2 loads it as its starting point.
learn.save('stage-1');

In [None]:
# Drop the stage-1 learner and force a garbage collection before the
# larger-resolution stage-2 learner is built.
learn.destroy()
learn=None
gc.collect()

In [None]:
# Stage-2 data: larger 960x960 inputs, batch size 1, milder augmentation.
# The batches must be normalised with imagenet_stats exactly like the stage-1
# data: the 'stage-1' weights loaded into this stage were trained on normalised
# inputs and would otherwise receive differently scaled pixel values.
size = np.array([960, 960])
bs = 1
data = (src.transform(get_transforms(max_rotate=1, max_warp=0, p_affine=1.0), size=size, resize_method=ResizeMethod.SQUISH, tfm_y=True)
        .databunch(bs=bs, num_workers=0)).normalize(imagenet_stats)

In [None]:
# Visual check of the stage-2 training samples.
data.show_batch(3, figsize=(10,7))

In [None]:
# Stage-2 learner: same architecture and metric as stage 1, built on the stage-2 data.
learn = unet_learner(data, models.resnet50, metrics=acc_vt, wd=wd)

In [None]:
# Start stage 2 from the checkpoint saved at the end of stage 1.
learn.load('stage-1')

In [None]:
# Run the learning-rate finder again for the larger input size.
lr_find(learn)
learn.recorder.plot()

In [None]:
# First stage-2 fit: 5 epochs of one-cycle training up to a much smaller lr.
lr=1e-6
learn.fit_one_cycle(5, slice(lr), pct_start=0.8)

In [None]:
# Checkpoint after the first stage-2 fit.
learn.save('stage-2-1');

In [None]:
# Visual check of predictions after the first stage-2 fit.
learn.show_results(rows=3, figsize=(9,11))

In [None]:
# Reload the 'stage-2-1' checkpoint before fine-tuning the whole model.
learn.load('stage-2-1')

In [None]:
# Make all layers trainable for the final fit.
learn.unfreeze()

In [None]:
# Learning-rate range for the final fit, from 1e-6 up to lr/10.
lr=1e-4
lrs = slice(1e-6,lr/10)

In [None]:
# Final fit: 5 epochs over the whole (unfrozen) model.
learn.fit_one_cycle(5, lrs)

In [None]:
# Final checkpoint of the notebook.
learn.save('stage-2-2');

In [None]:
# Visual check of the final model's predictions.
learn.show_results(rows=3, figsize=(9,11))