# Image segmentation

In [None]:
# Install/upgrade fastai quietly (-U upgrade, -q quiet output).
!pip install -Uq fastai

In [None]:
# Install/upgrade fastbook and pull its names into the notebook namespace.
!pip install -Uqq fastbook
from fastbook import *

In [None]:
# Mount Google Drive at /content/drive; the cells below change into the
# training folder under MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import numpy as np
import os
import time
import cv2
import pdb
import matplotlib
import fastai
from matplotlib import pyplot as plt
import nibabel as nib
from nibabel.testing import data_path
from PIL import Image
from fastai.vision.all import *
from fastai.callback.hook import *
from fastai.test_utils import *
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
%matplotlib inline
%matplotlib notebook
%reload_ext autoreload
%autoreload 2

In [None]:
%cd drive
%cd MyDrive
%cd train

In [None]:
path_lbl = 'labels/'
path_img = 'ct_images/'
path = Path('')

In [None]:
def n_codes(fnames, is_partial=True, n_sample=10):
  """Map class indices to the distinct pixel values found in mask files.

  Scans the masks in `fnames` and returns `{class_index: pixel_value}` with
  indices 0..k-1 assigned in ascending pixel-value order, so the mapping is
  deterministic for a given set of values and can be aligned with codes.txt.

  Args:
    fnames: iterable of mask file paths. It is NOT modified (the previous
      in-place shuffle reordered the caller's list).
    is_partial: if True, scan only a random sample of `n_sample` files. This
      is faster, but pixel values that occur only in unsampled masks are
      missed, so the result can differ between runs.
    n_sample: number of files scanned when `is_partial` is True.

  Returns:
    dict mapping class index -> pixel value (Python ints).
  """
  import random
  fnames = list(fnames)  # private copy: never reorder the caller's list
  if is_partial:
    fnames = random.sample(fnames, min(n_sample, len(fnames)))
  vals = set()
  for fname in fnames:
    msk = np.array(PILMask.create(fname))
    vals.update(np.unique(msk).tolist())
  # Sort so index order does not depend on set iteration order.
  return dict(enumerate(sorted(vals)))

In [None]:
def get_msk(fn, pix2class):
  """Load the mask for image `fn` and remap its pixel values to class indices.

  Args:
    fn: path of an input image; `get_y_fn` turns it into the mask path.
    pix2class: `{class_index: pixel_value}` dict (as built by `n_codes`);
      every pixel equal to `pixel_value` becomes `class_index`. Pixels whose
      value is not in the dict are left unchanged.

  Returns:
    a `PILMask` holding class indices.
  """
  # NOTE(review): `get_y_fn` (image path -> mask path) is not defined in the
  # visible notebook cells; it must exist before this function is called.
  fn = get_y_fn(fn)
  msk = np.array(PILMask.create(fn))
  # Write into a copy and always compare against the untouched original:
  # remapping in place lets a value written earlier be matched again by a
  # later entry (e.g. 255->1 then 1->2 would merge both classes into 2).
  out = msk.copy()
  for cls_idx, pix_val in pix2class.items():
    out[msk == pix_val] = cls_idx
  return PILMask.create(out)

In [None]:
# Label function for the DataBlock: load the mask matching an image and remap
# its pixel values to class indices using the global `p2c` (looked up at call
# time, so `p2c` may be defined after this cell).
# Defined with `def` rather than as a lambda so that objects referencing it can
# be pickled (the standard pickle module cannot serialise lambdas).
def get_y(o):
  return get_msk(o, p2c)

In [None]:
# textfile with the labels sorted in order of codes generated from mask alteration code
codes = np.loadtxt(path/'codes.txt', dtype=str); codes

In [None]:
# retrieve image and label files and import into the code
img_names = get_image_files(path_img)
lbl_names = get_image_files(path_lbl)

In [None]:
# Build the {class_index: pixel_value} map from *every* mask file.
# n_codes' default (is_partial=True) scans only a random sample of 10 masks.
# That can miss pixel values present only in unsampled masks, and it makes the
# mapping vary from run to run. Scanning all masks is slower but complete and
# deterministic.
p2c = n_codes(lbl_names, is_partial=False)
# The mask classes and the class names must line up one-to-one.
if len(p2c) != len(codes):
  print(f'Warning: {len(p2c)} distinct mask values but {len(codes)} class names in codes.txt')

In [None]:
# Data block describing how batches are built: images -> masks whose class
# vocabulary is `codes`. Items are all image files under the folder passed to
# `dataloaders`, and labels come from `get_y`. Inputs are normalised with
# ImageNet statistics to match the pretrained encoder.
# The random split is seeded so the same images form the 20% validation set on
# every run and after kernel restarts. Otherwise a model reloaded with
# `learn.load(...)` would be validated on images it was trained on.
train_data = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                       get_items=get_image_files,
                       splitter=RandomSplitter(valid_pct=0.2, seed=42),
                       get_y=get_y,
                       batch_tfms=[Normalize.from_stats(*imagenet_stats)])

In [None]:
# data loader creation with path to the CT images of the training data, batch size of 4 chosen
dls = train_data.dataloaders(path/'ct_images', bs=4)
dls.vocab = codes

In [None]:
# building the U-Net learner
learn = unet_learner(dls, resnet34, metrics=DiceMulti, self_attention=True, act_cls=Mish, opt_func=ranger)

In [None]:
# provides summary of the U-Net model
learn.summary()

# Stage 1

In [None]:
# used to find the ideal learning rate for the data
learn.lr_find()

In [None]:
lr = 1e-4
#lr denotes the maximum learning rate

In [None]:
# using a flat learning rate before later using cosine annealing, recommended by the ranger optimiser
# first parameter is the number of epochs that should be trained
# second paramter is the selection of learning rate
learn.fit_flat_cos(5, slice(lr))

In [None]:
# used to save the model trained up to this point
learn.save('stage-1');

In [None]:
# used to load the model -- uncomment only if unloading is required
#learn.load('stage-1');

In [None]:
# shows the results of the current model
learn.show_results(max_n=4, figsize=(20,20))

# Stage 2

In [None]:
# must unfreeze layers in the model for further training
learn.unfreeze()

In [None]:
# rule of thumb is to reduce learning rate as follows for second stage training
lrs = slice(lr/400,lr/4)

In [None]:
# using a flat learning rate before later using cosine annealing, recommended by the ranger optimiser
# first parameter is the number of epochs that should be trained
# second paramter is the selection of learning rate
learn.fit_flat_cos(6, lrs)

In [None]:
# used to save the model trained up to this point
learn.save('stage-2');

In [None]:
# used to load the model -- uncomment only if unloading is required
#learn.load('stage-2');

In [None]:
# show results after stage 2
learn.show_results(max_n=4, figsize=(20,20))

# Inference

In [None]:
# Test images live in MyDrive/test, next to the training folder. An absolute
# path is used instead of `%cd ..`, which would move the working directory up
# one level every time this cell runs.
test_dir = Path('/content/drive/MyDrive/test')
fp = get_image_files(test_dir)

# Predict on the first three test images. The first element of the tuple
# returned by get_preds holds the model outputs, indexed by image.
dl = learn.dls.test_dl(fp[:3])
preds = learn.get_preds(dl=dl)

# View the FIRST prediction, which is index 0 along the image axis.
# argmax over dim 0 picks the predicted class for each pixel (assumes each
# per-image output is (n_classes, H, W) -- TODO confirm).
pred_1 = preds[0][0]
pred_arx = pred_1.argmax(dim=0)
plt.imshow(pred_arx)