# STEP1：Import Libraries 

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

#use fastai v2
from fastai.vision.all import *  
from tqdm import tqdm
import cv2
import  os
import  zipfile

# STEP2：Utility

In [None]:
def rle_decode(rle, height, width, fill=255):
    """Decode one run-length-encoded string into a 2-D ``uint8`` mask.

    ``rle`` is a whitespace-separated sequence of ``start length`` pairs whose
    starts are 1-based. Each run is painted with ``fill`` into a flat buffer
    of ``height * width`` pixels. The buffer is reshaped to
    ``(width, height)`` and transposed, so the returned array has shape
    ``(height, width)`` with runs laid out column by column.
    """
    tokens = rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1  # 1-based -> 0-based
    run_lengths = np.asarray(tokens[1::2], dtype=int)

    flat = np.zeros(height * width, dtype=np.uint8)
    for begin, count in zip(run_starts, run_lengths):
        flat[begin:begin + count] = fill

    # The transpose is a non-contiguous view; hand back a C-ordered copy.
    return np.ascontiguousarray(flat.reshape(width, height).T)

def rle2mask(rles, class_names, height, width, class_dict):
    """Paint several RLE strings into a single multi-class ``uint16`` mask.

    ``rles`` and ``class_names`` are iterated in lockstep. Every run of an
    RLE string (single-space-separated ``start length`` pairs, 1-based
    starts) is filled with ``class_dict[class_name]``. Pixels not covered by
    any run stay 0, and where runs overlap the later entry wins.

    The flat buffer of ``height * width`` pixels is returned reshaped to
    ``(width, height)``.
    NOTE(review): confirm this orientation against how the caller derives
    ``height`` and ``width``; the two have to agree with each other.
    """
    flat = np.zeros(height * width, dtype=np.uint16)
    for encoded, label in zip(rles, class_names):
        tokens = encoded.split(' ')
        begins = np.asarray(tokens[0::2], dtype=int) - 1  # 1-based -> 0-based
        counts = np.asarray(tokens[1::2], dtype=int)
        for begin, stop in zip(begins, begins + counts):
            flat[begin:stop] = class_dict[label]

    return flat.reshape((width, height))

In [None]:
# Debug switch meant to save time: the (currently commented-out) training
# cell would train for 1 epoch instead of 12 when this is True.
debug = False

# STEP3：Data preparation

In [None]:
# Root folder of the competition data, kept as a plain string (used for
# string-built glob patterns) and as a Path (used for `/` joins).
ROOT_DIR = '../input/uw-madison-gi-tract-image-segmentation/'
home_path = Path(ROOT_DIR)

In [None]:
# Collect every entry that sits exactly four directory levels below train/.
file_paths = glob.glob(f'{ROOT_DIR}train/*/*/*/*')
file_paths[:3]

In [None]:
# Load the annotation table and keep only rows that carry a segmentation.
train_csv = pd.read_csv(home_path / 'train.csv').dropna(subset=['segmentation'])
train_csv.head(3)

In [None]:
# Build a lookup table (id, filename, filepath) from the training image paths.
# Rows are collected in a plain list and turned into a DataFrame once at the
# end; assigning file_csv.loc[idx] per file re-grows the frame on every
# iteration and is what made this cell take minutes.
file_records = []
for filepath in tqdm(file_paths):
    path_parts = filepath.split('/')
    # The glob pattern puts each file four levels below train/, so the
    # case/day folder is the third component from the end. Indexing from the
    # end keeps this correct if ROOT_DIR gets deeper or shallower.
    case_day_str = path_parts[-3]
    filename = path_parts[-1]
    # The second '_'-separated field of the filename is the slice number.
    slice_id = filename.split('_')[1]
    idstr = f'{case_day_str}_slice_{slice_id}'
    file_records.append((idstr, filename, filepath))

file_csv = pd.DataFrame(file_records, columns=['id', 'filename', 'filepath'])
file_csv.head(3)

In [None]:
# Attach filename/filepath to each annotation row by joining on the shared
# 'id' column (default inner join: rows without a matching file are dropped).
train_csv = train_csv.merge(file_csv, on='id')
train_csv.head(3)

In [None]:
#Fill table with other parameters
def get_img_height(row):
    """Return the third ``_``-separated field of ``row.filename`` as an int.

    The last four characters of the filename (presumably the extension, e.g.
    ``.png``) are dropped before splitting.
    NOTE(review): confirm against the dataset's naming convention that this
    field is the height; rle2mask's reshape order depends on it.
    """
    fields = row.filename[:-4].split('_')
    return int(fields[2])
def get_img_width(row):
    """Return the fourth ``_``-separated field of ``row.filename`` as an int.

    The last four characters of the filename (presumably the extension, e.g.
    ``.png``) are dropped before splitting.
    NOTE(review): confirm against the dataset's naming convention that this
    field is the width; rle2mask's reshape order depends on it.
    """
    fields = row.filename[:-4].split('_')
    return int(fields[3])
# Derive the per-row image size columns from the filename; the helpers take a
# row directly, so they are handed to apply without a wrapping lambda.
train_csv['img_height'] = train_csv.apply(get_img_height, axis=1)
train_csv['img_width'] = train_csv.apply(get_img_width, axis=1)
train_csv.head(3)

In [None]:
# Persist the merged table to the working directory.
# NOTE(review): to_csv writes the DataFrame index as an extra unnamed first
# column by default; a table reloaded with read_csv will carry that column.
train_csv.to_csv('train_csv.csv')
# To save time on later runs, load a previously saved table instead:
#train_csv = pd.read_csv('../input/df-train-csv/df_train.csv')



In [None]:
# Create the output directories. makedirs also creates the parent 'train'
# folder, and exist_ok=True lets this cell be re-run without raising
# FileExistsError when the folders are already there.
os.makedirs('train/images', exist_ok=True)
os.makedirs('train/labels', exist_ok=True)

In [None]:
# Map each class name to an integer mask label; 0 is left for background.
# Labels are handed out in the order classes first appear in train_csv,
# which is expected to give:
#   1: stomach
#   2: large_bowel
#   3: small_bowel
# NOTE(review): that order depends on the data; check the printed id2class.
class2id = {
    class_name: label
    for label, class_name in enumerate(train_csv['class'].unique(), start=1)
}
# Reverse lookup: integer label -> class name.
id2class = {label: class_name for class_name, label in class2id.items()}
id2class

In [None]:
# Write one image PNG and one multi-class mask PNG per slice id.
grouped = train_csv.groupby('id')
for name, df_select in grouped:
    # The loop already yields every annotation row of this slice; calling
    # train_csv.groupby('id').get_group(name) here redid the grouping work
    # on every iteration.
    filepath = df_select.filepath.values[0]
    # IMREAD_UNCHANGED loads the scan as stored instead of converting it.
    image = cv2.imread(filepath, cv2.IMREAD_UNCHANGED)
    # Combine the slice's per-class RLE strings into a single label mask.
    mask = rle2mask(df_select.segmentation.values,
                    df_select['class'].values,
                    df_select.img_height.values[0],
                    df_select.img_width.values[0],
                    class2id)
    slice_id = df_select.id.values[0]
    imgPath = "./train/images/" + slice_id + ".png"
    mskPath = "./train/labels/" + slice_id + "_mask.png"
    cv2.imwrite(imgPath, image)
    cv2.imwrite(mskPath, mask)

In [None]:
# Zip the generated ./train folder (images and labels) into ./train.zip.
# NOTE(review): the original comment said "if under debug", but nothing here
# checks `debug`; the archive is written on every run.
startdir = "./train"  # folder to compress
file_news = './' + 'train.zip'  # path of the archive to create
# The context manager closes the archive even if a write fails part-way.
with zipfile.ZipFile(file_news, 'w', zipfile.ZIP_DEFLATED) as z:
    for dirpath, dirnames, filenames in os.walk(startdir):
        for filename in filenames:
            full_path = os.path.join(dirpath, filename)
            # Store entries relative to startdir so the archive holds
            # 'images/...' and 'labels/...' rather than the on-disk path.
            z.write(full_path, os.path.relpath(full_path, startdir))
# Printed message means "compression succeeded".
print ('压缩成功')

# STEP4：TRAIN

In [None]:
#View the image and label files in pairs
# train_path = Path('./train/')
# fnames = get_image_files(train_path /'images')
# lbl_names = get_image_files(train_path /'labels')

In [None]:
#View file details for images and labels, and define the mask lookup function
# print (fnames[0],lbl_names[0])
# get_mask = lambda o:'./train/labels/'+str(o.stem)+'_mask.png'

In [None]:
#Check out a pair of images and labels
# img_fn = fnames[random.randint(0,len(fnames))]
# im = PILImage.create(img_fn)
# im.show(figsize=(5,5))
# print(len(fnames))

In [None]:
# mask_fn = get_mask(img_fn)
# msk = PILMask.create(mask_fn)
# msk.show(figsize=(5,5), alpha=1)
# print(im.shape,msk.shape)

In [None]:
#make DataBlock
# binary = DataBlock(blocks=(ImageBlock, MaskBlock( ['Background', 'stomach', 'large_bowel','small_bowel'])),    
#                    get_items=get_image_files,   
#                    splitter=RandomSplitter(),    
#                    get_y=get_mask,               
#                    item_tfms=Resize(128,ResizeMethod.Squish),       # Modify "128" may change results OR OOM
#                    batch_tfms=[Normalize.from_stats(*imagenet_stats)])  

In [None]:
#Read the picture and display the sample
# dls = binary.dataloaders(train_path /'images',bs=3)  # Modify "bs=3" may change train time OR OOM
# dls.show_batch( vmin=0, vmax=3)

In [None]:
#You can try other options
#pay attention to downloading the model first

#model   : resnet34
#metrics : DiceMulti

#p = Path("/root/.cache/torch/hub/checkpoints")
#p.mkdir(parents=True)
#!cp ../input/resnet34/resnet34-b627a593.pth /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth

# learn = unet_learner(dls,models.resnet34,metrics=DiceMulti)

In [None]:
# learn.lr_find()

In [None]:
#Perform training, where better results can be obtained by adjusting lr
# if debug:
#     learn.fit_flat_cos(1)
# else:
#     learn.fit_flat_cos(12)

# learn.recorder.plot_loss()


In [None]:
#show result
# learn.show_results(max_n=4, figsize=(12,6))