# Step 1: Import libraries

In [None]:
!pip install --user torch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 torchtext==0.10.0

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
!pip install segmentation-models-pytorch

#use fastai v2
from fastai.vision.all import *  
import segmentation_models_pytorch as smp
from tqdm import tqdm
import cv2
import  os
import  zipfile

# Step 2: Utility functions

In [None]:
def rle_decode(rle, height, width, fill=255):
    """Decode a run-length encoded string into a 2-D uint8 mask.

    ``rle`` holds whitespace-separated ``start length`` pairs with 1-based
    starts. Each run is painted with ``fill`` into a flat buffer of
    ``height * width`` pixels. The buffer is reshaped to ``(width, height)``
    and transposed, so runs are laid out column by column.

    Returns a C-contiguous uint8 array of shape ``(height, width)``.
    """
    tokens = rle.split()
    offsets = np.asarray(tokens[0::2], dtype=int) - 1  # 1-based -> 0-based
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    flat = np.zeros(height * width, dtype=np.uint8)
    for offset, run_length in zip(offsets, run_lengths):
        flat[offset:offset + run_length] = fill
    return np.ascontiguousarray(flat.reshape(width, height).T)

def rle2mask(rles, class_names, height, width, class_dict):
    """Paint several run-length encoded masks into one integer label image.

    Args:
        rles: iterable of RLE strings, each holding whitespace-separated
            ``start length`` pairs with 1-based starts.
        class_names: class name for each RLE, parallel to ``rles``.
        height, width: mask dimensions; the flat buffer has height*width pixels.
        class_dict: maps a class name to the integer label written into its runs.

    Returns:
        uint16 array of shape ``(width, height)``. Where runs overlap, the
        later RLE in ``rles`` wins.
    """
    img = np.zeros(height * width, dtype=np.uint16)
    for rle, class_name in zip(rles, class_names):
        # split() instead of split(' '): repeated/trailing spaces or an empty
        # string would otherwise yield '' tokens that cannot be parsed as int.
        s = rle.split()
        starts, lengths = [np.asarray(x, dtype=int) for x in (s[0::2], s[1::2])]
        starts -= 1  # RLE starts are 1-based
        ends = starts + lengths
        for lo, hi in zip(starts, ends):
            img[lo:hi] = class_dict[class_name]

    # NOTE(review): unlike rle_decode, the flat buffer is reshaped to
    # (width, height) without a transpose, i.e. runs are read row by row.
    # Confirm which orientation the labels are expected in.
    return img.reshape((width, height))

In [None]:
# Debug switch: set to True to keep only the first 100 merged rows later on,
# so that quick experiments run faster.
debug = False

# Step 3: Data preparation

In [None]:
# Location of the competition data, as a string and as a pathlib.Path.
ROOT_DIR = '../input/uw-madison-gi-tract-image-segmentation/'
home_path = Path(ROOT_DIR)

In [None]:
from zipfile import ZipFile

# Unpack the archive produced by the data-preparation notebook into the working
# directory (presumably the ./images and ./labels folders read below).
PREPARED_ZIP = '../input/uwmgit-datapreparation-pytorch-fastai/train.zip'
with ZipFile(PREPARED_ZIP, 'r') as archive:
    archive.extractall('')

In [None]:
image_path = home_path / 'train'
# Every file exactly four directory levels below train/.
file_paths = glob.glob(f'{ROOT_DIR}train/*/*/*/*')
file_paths[:3]

In [None]:
# Load the annotations and keep only rows that carry a segmentation.
train_csv = pd.read_csv(home_path / 'train.csv').dropna(subset=['segmentation'])
train_csv.head(3)

In [None]:
# Build a table mapping each slice id to its image filename and path.
# Rows are collected in a list and turned into a DataFrame once; growing the
# frame row by row with .loc[idx] = ... is much slower on ~38k files.
records = []
for filepath in tqdm(file_paths):
    parts = Path(filepath).parts
    # glob pattern is train/*/*/*/*, so the case_day folder is always 3rd from the end,
    # regardless of how many components ROOT_DIR has.
    case_day_str = parts[-3]
    filename = parts[-1]
    slice_id = filename.split('_')[1]
    records.append({
        'id': f'{case_day_str}_slice_{slice_id}',
        'filename': filename,
        'filepath': filepath,
    })

file_csv = pd.DataFrame(records, columns=['id', 'filename', 'filepath'])
file_csv.head(3)

In [None]:
# Attach filename/filepath to every annotated row (inner join on id).
train_csv = train_csv.merge(file_csv, on='id')
train_csv.head(3)

In [None]:
# Derive the image size from the filename: the 3rd and 4th '_'-separated fields
# (after dropping the 4-character extension) are integers in pixels.
# NOTE(review): the competition's data description lists these two numbers as
# width then height; the names below may be swapped — verify before relying on them.
def _filename_int_field(row, position):
    """Return field `position` of row.filename (extension stripped) as an int."""
    stem = row.filename[:-4]
    return int(stem.split('_')[position])

def get_img_height(row):
    """Integer stored in the 3rd '_'-separated field of the filename."""
    return _filename_int_field(row, 2)

def get_img_width(row):
    """Integer stored in the 4th '_'-separated field of the filename."""
    return _filename_int_field(row, 3)

train_csv['img_height'] = train_csv.apply(get_img_height, axis=1)
train_csv['img_width'] = train_csv.apply(get_img_width, axis=1)
train_csv.head(3)

In [None]:
# Optionally cache / restore the prepared table to skip the steps above:
#train_csv.to_csv('train_csv.csv')
#train_csv = pd.read_csv('../input/df-train-csv/df_train.csv')

if debug:
    # Small subset for quick experiments.
    train_csv = train_csv.head(100)

In [None]:
# Integer label per class; 0 is reserved for the background.
# Labels follow the order in which classes first appear in train_csv, so the
# expected mapping below should be checked against the printed id2class:
#0: background
#1: stomach
#2: large_bowel
#3: small_bowel
class2id = {name: label for label, name in enumerate(train_csv['class'].unique(), start=1)}
id2class = {label: name for name, label in class2id.items()}
id2class

# Step 4: Training

In [None]:
# Collect the extracted image files and their label files.
train_path = Path('./')
images_dir, labels_dir = train_path / 'images', train_path / 'labels'
fnames = get_image_files(images_dir)
lbl_names = get_image_files(labels_dir)

In [None]:
# Show one image and one label filename, and define how an image maps to its mask file.
print(fnames[0], lbl_names[0])

def get_mask(o):
    """Return the mask path for image path `o`: ./labels/<image stem>_mask.png."""
    return f'./labels/{o.stem}_mask.png'

In [None]:
# Display one randomly chosen training image (its mask is shown in the next cell).
# random.randint includes its upper bound, so randint(0, len(fnames)) could return
# len(fnames) and raise IndexError; randrange excludes the upper bound.
img_fn = fnames[random.randrange(len(fnames))]
im = PILImage.create(img_fn)
im.show(figsize=(5, 5))
print(len(fnames))

In [None]:
# Load the label that pairs with img_fn and compare its size with the image's.
mask_fn = get_mask(img_fn)
msk = PILMask.create(mask_fn)
msk.show(figsize=(5, 5), alpha=1)
print(im.shape, msk.shape)

In [None]:
# Segmentation DataBlock: RGB images in, integer masks out, one code per label value.
# (Despite its name, `binary` is a 4-class problem: background + 3 organs.)
# NOTE(review): the order of mask_codes must match the label values written into the
# mask PNGs by the data-preparation notebook — confirm.
mask_codes = ['Background', 'stomach', 'large_bowel', 'small_bowel']
binary = DataBlock(blocks=(ImageBlock, MaskBlock(mask_codes)),
                   get_items=get_image_files,
                   # Fixed seed so the 80/20 train/valid split is reproducible across runs.
                   splitter=RandomSplitter(valid_pct=0.2, seed=42),
                   get_y=get_mask,
                   # Changing the size may change results or run out of GPU memory;
                   # ResizeMethod.Squish is an alternative to the default crop.
                   item_tfms=Resize(224),
                   batch_tfms=[Normalize.from_stats(*imagenet_stats)])

In [None]:
# Build the DataLoaders and show a sample batch.
# bs=4: changing it affects training time, and a larger value may run out of GPU memory.
dls = binary.dataloaders(train_path / 'images', bs=4)
# Mask values range over 0..3 (the four codes of the MaskBlock).
dls.show_batch(vmin=0, vmax=3)

In [None]:
def build_model(encoder_name, classes=4, in_channels=3, encoder_weights="imagenet", device='cuda'):
    """Build a UNet++ segmentation model from segmentation_models_pytorch.

    Args:
        encoder_name: smp encoder identifier, e.g. "resnet34", "mobilenet_v2"
            or "efficientnet-b7".
        classes: number of output channels; the default 4 matches the four
            mask codes used in this notebook (background + 3 organs).
        in_channels: number of input channels (1 for grayscale, 3 for RGB).
        encoder_weights: pre-trained encoder weights to load ("imagenet"), or
            None for random initialisation.
        device: device the model is moved to before being returned.

    Returns:
        The model on ``device``. ``activation=None`` means no activation is
        applied to the output, so the loss has to work on raw scores.
    """
    model = smp.UnetPlusPlus(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        activation=None,
    )
    model.to(device)
    return model

In [None]:
# Wrap a UNet++/ResNet34 in a fastai Learner (smp is already imported in the imports cell).
# No loss_func is passed, so fastai uses the DataLoaders' default loss
# (presumably the flattened cross-entropy set up by MaskBlock — confirm).
# to_fp16() enables mixed-precision training.
smpNet = build_model('resnet34')
learn_smp = Learner(dls, smpNet, metrics=DiceMulti).to_fp16()

In [None]:
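# Alternative: fastai's built-in unet_learner (commented out below).
# Copy the pretrained resnet34 weights into the torch hub cache first (see the
# commands below), presumably because the weights cannot be downloaded here.

#model   : resnet34
#metrics : DiceMulti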

# p = Path("/root/.cache/torch/hub/checkpoints")
# p.mkdir(parents=True)
# !cp ../input/resnet34/resnet34-b627a593.pth /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
#learn = unet_learner(dls,resnet34,metrics=DiceMulti)
#learn = unet_learner(dls, resnet34, metrics=DiceMulti, self_attention=True, act_cls=Mish, opt_func=ranger)

In [None]:
#learn.lr_find()

In [None]:
#learn_smp.lr_find()

In [None]:
lr = 0.001  # base learning rate; tune with the lr_find cells above

In [None]:
#learn_smp.fit(1)

In [None]:
# Discriminative learning rates from lr/400 (earliest layers) up to lr/4.
# NOTE(review): learn_smp was created without a `splitter`, so it presumably has a
# single parameter group; fastai would then use only the upper bound lr/4 — confirm.
lrs = slice(lr / 400, lr / 4)

In [None]:
# Train all layers for 20 epochs with fastai's flat-then-cosine schedule.
learn_smp.unfreeze()
learn_smp.fit_flat_cos(n_epoch=20, lr=lrs)

In [None]:
# Save a training checkpoint under the name 'model_learn_smp'.
learn_smp.save('model_learn_smp')

In [None]:
import dill

# Export the whole Learner for inference. dill is used as the pickler because it can
# serialize objects (e.g. lambdas) that the standard pickle module cannot.
learn_smp.export('learn_smp_resnet_20epoch.pkl', pickle_module=dill)

In [None]:
!rm images -rf
!rm labels -rf

learn.fit_flat_cos(12)
11	0.020275	0.029574	0.885125	04:57
