# HuBMAP - Hacking the Kidney - Kaggle Competition

The [Kaggle competition page](https://www.kaggle.com/c/hubmap-kidney-segmentation)

Helpful Notebooks:
* [https://www.kaggle.com/markalavin/hubmap-tile-images-w-overlap-and-build-tfrecords](https://www.kaggle.com/markalavin/hubmap-tile-images-w-overlap-and-build-tfrecords)

## To-Do
* Look at impact of different affine matrices
* Look at impact of removing alpha channel on model size and performance
* Add Deepmind's architecture optimizer

## Package Downloads for Offline use

## Setup

In [1]:
# Toggle for the exploratory cell further down (guarded by `if TEST:`); False skips it.
TEST = False

In [2]:
#!conda update -n base conda

In [3]:
#! conda config --set always_yes True
#! conda install -c fastai -c pytorch fastai
#! conda install pytorch torchvision torchaudio fastai -c pytorch
#! conda update pytorch torchvision torchaudio cudatoolkit -c pytorch
#! conda install pandas
#! conda install -c conda-forge kaggle
#! conda install -c conda-forge tifffile
#! conda install -c conda-forge tqdm
# !conda install -c conda-forge matplotlib
#! conda install -c conda-forge pytorch-lightning
#! conda install -c conda-forge wandb
#! conda install -c conda-forge arrow
#!conda install -c conda-forge pickle5

In [4]:
#! pip install arrow pickle5

In [5]:
# needed if running in wsl2
#! pip install pytorch-lightning wandb

In [6]:
# I have no idea why the conda-forge version doesn't work

#!python -m pip install opencv-python

# If you are running this notebook on a server (like Linux on WSL2), you need the headless version of opencv.
# The regular opencv requires GUI packages that servers don't have, and importing it will raise an error.
#!python -m pip install opencv-python-headless

# temporary solution to use tab complete - something wrong with jupyter jedi - need to downgrade
#!pip install jedi==0.17.2

#!pip install torchio --upgrade

#!pip install pytorch-lightning-bolts

In [7]:
#!pip install --upgrade ssl

Check that the (finicky) local CUDA setup is available

In [8]:
# First, import PyTorch
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

# Check PyTorch version
torch.__version__
torch.cuda.is_available()

True

In [9]:
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, ModelCheckpoint

# prebuilt models
from pl_bolts.models import UNet

import tensorboard as tb

# Need to put kaggle.json in /%USERS%/.kaggle folder (C:/Users/Craig/.kaggle)
#import kaggle

from pathlib import Path
import random
import os
import shutil
from typing import Union

# Read tiff images
#import tifffile
import cv2
from tqdm import tqdm
import torchio as tio

from PIL import Image
from IPython.display import display

import time
import wandb
import arrow


import matplotlib.pyplot as plt

# Memory management tools
import gc

from fastai.vision.all import *
from fastai.imports import *
from fastai.callback.wandb import *

import pickle5 as pickle

  f' install it with `pip install {pypi_name}`.' + extra_text


In [10]:
path = Path()  # data root: relative to the current working directory
#kaggle.api.competition_download_files("hubmap-kidney-segmentation", path=path)

Ensure you are about to download the data into the correct directory

In [11]:
#path.ls()

Unzip the data into the correct folder (commented out so the archive is not unzipped again)

In [12]:
import zipfile

#with zipfile.ZipFile(path/"hubmap-kidney-segmentation.zip", 'r') as zipref:
#    zipref.extractall(path)

In [13]:
# Training metadata: one row per image with its RLE mask in the "encoding" column.
# "id" is renamed to "img_id" because the helpers below look images up by that column.
train_df = pd.read_csv(path / "train.csv")
train_df = train_df.rename(columns={"id": "img_id"})

In [14]:
# Re-bind the data root to the current directory (same value as set earlier).
path = Path()
#path.ls()

## Helper Functions

In [15]:
################
# Main Functions
################


def rle2mask(mask_rle, shape):
    '''
    Decode a run-length-encoded mask string into a binary numpy array.

    mask_rle: space-separated "start length start length ..." string; starts are 1-indexed
    shape: (width,height) of array to return
    Returns a uint8 numpy array of shape (height, width): 1 - mask, 0 - background
    '''
    tokens = mask_rle.split()
    starts = np.asarray(tokens[::2], dtype=int)
    lengths = np.asarray(tokens[1::2], dtype=int)
    starts -= 1  # RLE starts are 1-indexed
    stops = starts + lengths
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, end in zip(starts, stops):
        flat[begin:end] = 1
    # pixels are numbered column-major: reshape to `shape`, then transpose to (height, width)
    return flat.reshape(shape).T

def mask2rle(x):
    """
    Run-length encode the pixels equal to 1 in a 2D mask.

    Pixels are scanned column-major (via x.T), matching rle2mask, and starts are 1-indexed.
    Returns a flat list [start1, length1, start2, length2, ...] (not a joined string).
    """
    positions = np.where(x.T.flatten() == 1)[0]
    runs = []
    last = -2
    for pos in positions:
        if pos > last + 1:
            # gap since the previous foreground pixel: open a new run
            runs += [pos + 1, 0]
        runs[-1] += 1
        last = pos
    return runs


def get_id_by_index(index, df=train_df):
    """Return the `img_id` of the row at positional `index` in `df`."""
    row = df.iloc[index]
    return row['img_id']

def get_single_img(id, folder="train"):
    """
    Load `<path>/<folder>/<id>.tiff` as a numpy array.

    5D arrays are squeezed, and their first axis is moved last (transpose(1, 2, 0)).
    That presumably turns (C, H, W) into (H, W, C); confirm against the data.
    NOTE(review): the `import tifffile` at the top of the notebook is commented out.
    This raises NameError unless tifffile is imported elsewhere.
    """
    img_path = path / folder / (id + ".tiff")
    img = tifffile.imread(img_path)
    if img.ndim == 5:
        img = img.squeeze().transpose(1, 2, 0)
    return img

def show_single_img(id, **kwargs):
    """Display training image `id` with matplotlib; extra keyword arguments go to plt.imshow."""
    img = get_single_img(id)
    return plt.imshow(img, **kwargs)

def show_img_by_index(index, df=train_df):
    """
    Display the image whose id is stored at positional `index` of `df`.

    The previous version ignored both arguments and referenced an undefined TEST_IMAGE_INDEX.
    It also read an "id" column, which no longer exists after train_df's rename to "img_id".
    """
    return show_single_img(get_id_by_index(index, df=df))

def get_single_encs(id, df=train_df):
    """Return the RLE `encoding` of the first row of `df` whose img_id equals `id` (IndexError if none)."""
    matches = df.loc[df['img_id'] == id, 'encoding']
    return matches.array[0]

def get_mask(id, df=train_df, folder="train"):
    """
    Decode the binary mask for image `id` from its RLE encoding in `df`.

    The image is loaded only to get its size. Its shape is reversed and the first entry dropped.
    For an (H, W, C) image this gives (W, H), the `shape` argument rle2mask expects.
    """
    encoding = get_single_encs(id, df=df)
    img_shape = get_single_img(id, folder=folder).shape
    width_height = img_shape[::-1][1:]
    return rle2mask(encoding, width_height)

def show_single_img_and_mask_by_id(id):
    """
    Plot training image `id` in three panels: the image, the image with its mask overlaid, and the mask.
    """
    plt.figure(figsize=(16, 10))

    mask = get_mask(id)
    img = get_single_img(id)

    # (title, [(array, imshow kwargs), ...]) for each panel, left to right
    panels = (
        ("Image", [(img, {})]),
        ("Image + mask", [(img, {}), (mask, {"cmap": "hot", "alpha": 0.5})]),
        ("Mask", [(mask, {"cmap": "hot"})]),
    )
    for position, (title, layers) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        for data, options in layers:
            plt.imshow(data, **options)
        plt.title(title, fontsize=18)

    return plt.show()

def show_single_img_and_mask(subject: tio.data.subject.Subject, resize_to=50):
    """
    Plot a torchio Subject in three panels: its "img", the image with its "mask" overlaid, and the mask.

    subject: torchio Subject holding "img" and "mask" images.
    resize_to: percentage passed to `resizer` to shrink both arrays before plotting.
        A falsy value (0/None) plots at full size.
    Raises TypeError if `subject` is not a torchio Subject.
    """
    # Validate before opening a figure, so a bad argument doesn't leave an empty figure behind.
    if not isinstance(subject, tio.data.subject.Subject):
        raise TypeError(f"The subject is required to be of type torchio.data.subject.Subject but you provided {type(subject)}")

    # NOTE(review): (120, 100) inches is very large; the by-id variant uses (16, 10). Confirm intent.
    plt.figure(figsize=(120, 100))

    # squeeze the singleton axes and move the channel axis last for imshow
    img = subject["img"][tio.DATA].squeeze().permute(1,2,0)
    mask = subject["mask"][tio.DATA].squeeze().unsqueeze(2)

    if resize_to:
        img = resizer(img, scale=resize_to)
        mask = resizer(mask, scale=resize_to)

    plt.subplot(1, 3, 1)
    plt.imshow(img)
    plt.title("Image", fontsize=18)

    plt.subplot(1, 3, 2)
    plt.imshow(img)
    plt.imshow(mask, cmap="hot", alpha=0.5)
    plt.title("Image + mask", fontsize=18)

    plt.subplot(1, 3, 3)
    plt.imshow(mask, cmap="hot")
    plt.title("Mask", fontsize=18)

    return plt.show()

def to_4d(img, input_chan_first=False, output_chan_first=True):
    """
    Convert a 3D array into a 4D channels-first array with an extra singleton axis.

    input_chan_first: if False, `img` is channels-last and is transposed (2, 0, 1) to channels-first.
    output_chan_first: if True, the singleton axis is appended at the end (C, H, W, 1).
        Otherwise it is prepended (1, C, H, W).
    Raises ValueError if `img` is not 3D.
    """
    if len(img.shape) != 3:
        raise ValueError("Function only converts 3D arrayto 4D array")
    axes = (0, 1, 2) if input_chan_first else (2, 0, 1)
    new_axis = 3 if output_chan_first else 0
    return np.expand_dims(np.transpose(img, axes), new_axis)

def to_3d(img, input_chan_first=True, output_chan_first=False):
    """
    Convert a 4D array with singleton axes back to 3D.

    Every size-1 axis is squeezed. The remaining (C, H, W) array is returned channels-last
    (H, W, C) unless `output_chan_first` is True.
    NOTE(review): `input_chan_first` is accepted but unused. Because every size-1 axis is
    squeezed, a single-channel input would no longer be 3D and the transpose would fail.
    Raises ValueError if `img` is not 4D.
    """
    if len(img.shape) != 4:
        raise ValueError("Function only converts 4D arrayto 3D array")
    squeezed = img.squeeze()
    if output_chan_first:
        return np.transpose(squeezed, (0, 1, 2))
    return np.transpose(squeezed, (1, 2, 0))

def to_3chan(x, dim=1):
    """Repeat tensor `x` three times along `dim`, e.g. turning a 1-channel mask into 3 channels."""
    return torch.cat([x] * 3, dim=dim)

def resizer(img, scale=5, show=False):
    """
    Return `img` resized to `scale` percent of its height and width using cv2.resize.

    img: torch tensor or numpy array, either 3D (H, W, C) or 4D in the (C, H, W, 1)
        layout produced by to_4d. 4D input is converted with to_3d before resizing and
        back with to_4d afterwards, so the result has as many dimensions as the input.
    scale: target size as a percentage of the original (each side is kept >= 1 pixel).
    show: if True, display the resized image and return the matplotlib image instead.
    Returns a numpy array (or the imshow result when `show` is True).
    """
    # Convert to numpy up front. Previously `.numpy()` was called after to_3d; for 4D numpy
    # input to_3d already returns a numpy array, so that call crashed. Plain numpy input now works too.
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    else:
        img = np.asarray(img)
    is_4d = img.ndim == 4
    if is_4d:
        img = to_3d(img)
    # cv2.resize rejects a zero-sized target, so never round a side down to 0 pixels
    width = max(1, int(img.shape[1] * scale / 100))
    height = max(1, int(img.shape[0] * scale / 100))
    img_reshaped = cv2.resize(img, (width, height))
    if show:
        return plt.imshow(img_reshaped)
    if is_4d:
        return to_4d(img_reshaped)
    return img_reshaped

def squeeze_and_reshape(img_tensor, remove_alpha=False):
    """
    Reorder a raw image tensor into the 4D layout consumed by to_3d.

    5D input is first squeezed and its axes reversed (permute(2, 1, 0)). A singleton axis is
    then added and the axes reordered so that (A, B, C) becomes (C, B, A, 1).
    NOTE(review): `remove_alpha` is accepted but currently ignored.
    Raises TypeError if `img_tensor` is not a torch.Tensor.
    """
    if not isinstance(img_tensor, torch.Tensor):
        raise TypeError("Image needs to be a tensor")
    if img_tensor.dim() == 5:
        img_tensor = img_tensor.squeeze().permute(2, 1, 0)
    with_depth = img_tensor.unsqueeze(2)
    return with_depth.permute(3, 1, 0, 2)

def to_pil(image):
    """
    Display a torch tensor in the notebook as a PIL image, followed by a blank line.

    The tensor is moved to CPU and squeezed, then `.T` reverses all axes
    (e.g. (C, H, W) -> (W, H, C)) and the values are cast to uint8 without rescaling.
    Values are therefore assumed to already lie in 0-255 (TODO confirm for callers).
    Returns None.
    """
    # detach/cpu so tensors on GPU or requiring grad can be shown too
    data = image.detach().cpu().numpy().squeeze().T
    pil_image = Image.fromarray(data.astype(np.uint8))
    display(pil_image)
    print()

In [16]:
def remask(img, mask, tile, threshold=8, show=False, fill_value=2):
    """
    Label low-saturation (background) tiles of an image in its mask.

    The image is walked in `tile` x `tile` blocks over its first two axes. For each block,
    OpenCV's BGR->HSV conversion gives a saturation channel. If its mean is below
    `threshold`, the matching region of `mask` is set to `fill_value`. The default 2 is the
    extra label used later only to steer patch sampling.

    img: 3D array with the channel axis last (assumed uint8 - confirm).
    mask: 2D array aligned with img's first two axes; it is MODIFIED IN PLACE.
    tile: tile edge length in pixels.
    threshold: mean-saturation cut-off below which a tile counts as background.
    show: if True, plot the image, the overlay and the mask.
    fill_value: label written into background tiles.
    Returns `mask` (the same object that was passed in).
    """
    n_rows, n_cols = img.shape[0], img.shape[1]

    for top in range(0, n_rows, tile):
        for left in range(0, n_cols, tile):
            block = img[top:top + tile, left:left + tile, :]
            if 0 in block.shape:
                # e.g. an image without channels: nothing to convert
                continue
            saturation = cv2.cvtColor(block, cv2.COLOR_BGR2HSV)[:, :, 1]
            if saturation.mean() < threshold:
                # As before, the fill is built with the tile's own shape. A mask/image size
                # mismatch therefore still surfaces as a broadcasting error.
                mask[top:top + tile, left:left + tile] = np.full(block.shape[:2], fill_value)

    if show:
        plt.figure(figsize=(10, 10))

        plt.subplot(1, 3, 1)
        plt.imshow(img.astype('uint8'))
        plt.title("Image", fontsize=18)

        plt.subplot(1, 3, 2)
        plt.imshow(img.astype('uint8'))
        plt.imshow(mask.astype('uint8'), cmap="hot", alpha=0.5)
        plt.title("Image + mask", fontsize=18)

        plt.subplot(1, 3, 3)
        plt.imshow(mask.astype('uint8'), cmap="hot")
        plt.title("Mask", fontsize=18)

        plt.show()

    return mask


#img_id = get_id_by_index(7)
#img_id = '095bf7a1f'
#with tifffile.TiffFile(path/"train"/(img_id+".tiff")) as tif:
#    imgg = tif.asarray()
#print(imgg.shape)
#mask = get_mask(img_id)
#new_mask = remask(to_3d(squeeze_and_reshape(imgg)), mask, 1000)

In [17]:
def create_mask_df(df, directory):
    """
    Build background-aware masks for every image listed in `df`.

    For each img_id, `<path>/<directory>/<img_id>.tiff` is loaded and reordered with
    squeeze_and_reshape/to_3d. It is passed to `remask` (tile size 1000) together with the
    mask decoded from the same `df`.

    df: DataFrame with "img_id" and "encoding" columns.
    directory: sub-folder of `path` that holds the "<img_id>.tiff" files.
    Returns a DataFrame with columns ["img_id", "mask"].
    NOTE(review): relies on `tifffile`, whose import is commented out at the top of the notebook.
    """
    mask_list = []
    for img_id in tqdm(df["img_id"], total=len(df)):
        with tifffile.TiffFile(path / directory / (img_id + ".tiff")) as tif:
            base_im = tif.asarray()
        im_array = squeeze_and_reshape(torch.from_numpy(base_im)).numpy()
        # Decode from the df/folder being processed. Previously get_mask always used its
        # train_df / "train" defaults, whatever `df` and `directory` were.
        mask = remask(to_3d(im_array), get_mask(img_id, df=df, folder=directory), 1000)
        mask_list.append((img_id, mask))
    return pd.DataFrame(mask_list, columns=["img_id", "mask"])

In [18]:
# Don't recreate the dataset every time - load it from the local pickle file if available.
# Pickles aren't safe: only do this in your local env and never open an unfamiliar pickle file.
new_masks_path = path / "new_masks.pkl"
if new_masks_path.exists():
    with open(new_masks_path, "rb") as fh:
        new_masks = pickle.load(fh)
else:
    # build once and cache; the freshly built DataFrame is used directly (no re-read from disk)
    new_masks = create_mask_df(train_df, "train")
    new_masks.to_pickle(new_masks_path)

In [19]:
def cut_image(img_id, source_path:Path, destination_path: Path, mask_df=None):
    """
    Cut an image (and, if `mask_df` is supplied, its mask) into QUARTERS and save them.

    Reads `<source_path>/<img_id>.tiff` with torchio and splits its spatial dims 1 and 2 at
    their midpoints. It writes `<img_id>_<k>.tiff` (k = 1..4) to `destination_path`, and
    masks as `<img_id>_<k>_mask.tiff` in the same folder.
    Quarter k covers (dim1 half, dim2 half): 1 = (first, first), 2 = (second, first),
    3 = (first, second), 4 = (second, second). For odd sizes the second half is the larger.

    mask_df: optional DataFrame with "img_id" and "mask" (2D array) columns.
    Raises ValueError if the loaded tensor is not 4D or does not have 4 channels.
    """
    img = tio.Image(source_path / f"{img_id}.tiff").data
    if len(img.shape) != 4:
        raise ValueError("Tensor shape needs to have 4 dimensions")
    if img.shape[0] != 4:
        raise ValueError("First dimension must have 4 channels")
    mid1 = img.shape[1] // 2
    mid2 = img.shape[2] // 2
    quarters = (
        (slice(None, mid1), slice(None, mid2)),
        (slice(mid1, None), slice(None, mid2)),
        (slice(None, mid1), slice(mid2, None)),
        (slice(mid1, None), slice(mid2, None)),
    )

    # `if mask_df:` raised "truth value of a DataFrame is ambiguous" for any DataFrame,
    # so the mask branch could never run; compare against None instead.
    mask = None
    if mask_df is not None:
        mask = torch.from_numpy(mask_df[mask_df["img_id"] == img_id]["mask"].values[0]).unsqueeze(0).unsqueeze(3)
        # The mask's spatial axes are stored swapped relative to the image, so swap them back.
        mask = mask.permute(0, 2, 1, 3)

    for k, (s1, s2) in enumerate(quarters, start=1):
        tio.Image(tensor=img[:, s1, s2, :]).save(destination_path / f"{img_id}_{k}.tiff")

    if mask is not None:
        for k, (s1, s2) in enumerate(quarters, start=1):
            tio.Image(tensor=mask[:, s1, s2, :]).save(destination_path / f"{img_id}_{k}_mask.tiff")


#[cut_image(item, new_masks) for item in new_masks.img_id.tolist()]

In [20]:
def restitch_image(img_id, pred_mask=None):
    """
    Reassemble the four quarters written by cut_image for `img_id` and display the result.

    Reads `<path>/smaller/imgs/<img_id>_<k>.tiff` (k = 1..4). They are concatenated back in
    cut_image's layout: quarters 1|3 form the first half of dim 1 and 2|4 the second half,
    joined along dim 2. The result is shown with to_pil.
    pred_mask: currently unused; kept for API compatibility.
    Raises FileNotFoundError if any of the four quarters is missing.
    """
    quarter_dir = path / "smaller/imgs"
    quarters = {}
    for name in quarter_dir.glob(f"{img_id}_?.tiff"):
        quarter = name.stem.rsplit("_", 1)[1]
        quarters[quarter] = tio.Image(name).data

    missing = [k for k in "1234" if k not in quarters]
    if missing:
        raise FileNotFoundError(f"Missing quarter(s) {missing} for {img_id} in {quarter_dir}")

    # Concatenate instead of writing into a pre-sized canvas. The old canvas size came from
    # the wrong quarters, and quarters 2/4 were placed one row too early, which crashed or
    # misplaced a row depending on whether the original size was even or odd.
    top = torch.cat([quarters["1"], quarters["3"]], dim=2)
    bottom = torch.cat([quarters["2"], quarters["4"]], dim=2)
    whole_image = torch.cat([top, bottom], dim=1).float()  # float, like the old zeros canvas

    to_pil(whole_image)

In [21]:
#restitch_image(get_id_by_index(1))

## Custom Model

In [22]:
class CustomUnet(nn.Module):
    """
    A small 3-level U-Net.

    The user specifies the number of input channels and the number of output channels.
    The downblock and upblock helpers also take input and output sizes, but these are PER
    CONVOLUTION and do not necessarily match the values specified by the user.

    The architecture is 3 down blocks (each halving H and W by max-pooling) followed by
    3 up blocks (each doubling H and W with a transposed convolution). The outputs of the
    first two down blocks are concatenated onto the inputs of the matching up blocks.
    The output is (N, channel_out, H, W) after a sigmoid. It is NOT squeezed, so a
    1-channel model keeps its singleton channel axis.
    With the default stride=1, H and W should be divisible by 8 so the skip connections
    line up. `ks` should be odd.
    """

    def __init__(self, channel_in, channel_out, stride=1, ks=3):
        super().__init__()
        self.down_conv1 = self._downblock(channel_in, 16, stride=stride, ks=ks)
        self.down_conv2 = self._downblock(16, 32, stride=stride, ks=ks)
        self.down_conv3 = self._downblock(32, 64, stride=stride, ks=ks)
        self.up_conv3 = self._upblock(64, 32, stride=stride, ks=ks)
        self.up_conv2 = self._upblock(32*2, 16, stride=stride, ks=ks) # input doubled by the skip concatenation
        self.up_conv1 = self._upblock(16*2, channel_out, stride=stride, ks=ks)

    # downward (contracting) block: two conv-BN-ReLU layers, then max-pool halving H and W
    def _downblock(self, n_in, n_out, stride, ks):
        down_conv = nn.Sequential(
            nn.Conv2d(n_in, n_out, stride=stride, kernel_size=ks, padding=ks//2),
            nn.BatchNorm2d(n_out),
            nn.ReLU(),
            nn.Conv2d(n_out, n_out, stride=stride, kernel_size=ks, padding=ks//2),
            nn.BatchNorm2d(n_out),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=ks, stride=2, padding=ks//2) # 256/2 = 128
        )
        return down_conv

    # upward (expanding) block: two conv-BN-ReLU layers, then a transposed conv doubling H and W
    def _upblock(self, n_in, n_out, stride, ks):
        up_conv = nn.Sequential(
            nn.Conv2d(n_in, n_out, stride=stride, kernel_size=ks, padding=ks//2),
            nn.BatchNorm2d(n_out),
            nn.ReLU(),
            nn.Conv2d(n_out, n_out, stride=stride, kernel_size=ks, padding=ks//2),
            nn.BatchNorm2d(n_out),
            nn.ReLU(),
            # With an odd kernel, padding ks//2 and stride 2, output_padding=1 gives exactly
            # twice the input size. The old output_padding=ks//2 equals 1 only for ks=3;
            # PyTorch rejects it for ks >= 5 because it must be smaller than the stride.
            nn.ConvTranspose2d(n_out, n_out, stride=2, kernel_size=ks, padding=ks//2, output_padding=1),
        )
        return up_conv

    def forward(self, x):
        down_conv1 = self.down_conv1(x)
        down_conv2 = self.down_conv2(down_conv1)
        down_conv3 = self.down_conv3(down_conv2)

        up_conv3 = self.up_conv3(down_conv3)

        # skip connections: concatenate along the channel axis
        up_conv2 = self.up_conv2(torch.cat([up_conv3, down_conv2], 1))
        up_conv1 = self.up_conv1(torch.cat([up_conv2, down_conv1], 1))

        return torch.sigmoid(up_conv1)

In [23]:
class DiceLoss(nn.Module):
    """
    Soft Dice loss: 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth).

    Expects `inputs` to already be probabilities, because CustomUnet ends in a sigmoid.
    `weight` and `size_average` are accepted for API compatibility but are unused.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        # reshape (not view) also accepts non-contiguous tensors. Casting the targets lets
        # integer label tensors be combined in the prediction's dtype.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1).to(inputs.dtype)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice

dice_loss = DiceLoss()
    
class DiceBCELoss(nn.Module):
    """
    Sum of binary cross-entropy and soft Dice loss; lower is better.

    Expects `inputs` to already be probabilities, because CustomUnet ends in a sigmoid.
    `weight` and `size_average` are accepted for API compatibility but are unused.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = inputs.reshape(-1)
        # F.binary_cross_entropy needs floating-point targets of the same dtype as the inputs
        targets = targets.reshape(-1).to(inputs.dtype)

        intersection = (inputs * targets).sum()
        dice_term = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        bce = F.binary_cross_entropy(inputs, targets, reduction='mean')

        # Return the loss itself. The previous `1 - (BCE + dice)` decreased as predictions
        # got worse, so minimising it trained the model in the wrong direction.
        return bce + dice_term

bce_dice_loss = DiceBCELoss()

In [24]:
def acc_metric(predb, yb):
    """Pixel accuracy: fraction of rounded predictions equal to the targets, as a 0-dim float tensor."""
    hits = torch.round(predb).eq(yb)
    return hits.float().mean()

## Pytorch Lightning

In [25]:
#sample_loader = next(iter(train_loader))

In [26]:
class LitUNETDataLoader(pl.LightningDataModule):
    """
    LightningDataModule wrapping pre-built train and validation DataLoaders of torchio batches.

    transfer_batch_to_device turns a torchio batch dict into an (x, y) tuple on `device`.
    The trailing depth axis is removed, and the mask is binarised to 1 where the label is 1
    and 0 elsewhere (label 2 only exists to steer patch sampling).
    """

    def __init__(self, train_loader, valid_loader):
        super().__init__()
        self.train_loader = train_loader
        self.valid_loader = valid_loader

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.valid_loader

    def transfer_batch_to_device(self, batch, device):
        # Drop only the trailing depth axis (batches presumably (B, C, W, H, 1) - confirm).
        # A bare .squeeze() also removed the batch axis when a batch held a single patch.
        x = batch['img'][tio.DATA].squeeze(-1)
        # masks have a single channel; drop it too so y is (B, W, H) as before
        y = batch['mask'][tio.DATA].squeeze(-1).squeeze(1)
        # need to override the labels to ensure we are only guessing presence of glomeruli
        # remember the additional category was introduced just to improve sampling selection area
        y = (y == 1).long()
        return x.to(device), y.to(device)

class LitUNET(pl.LightningModule):
    """
    LightningModule wrapping a segmentation `model` and a `loss_fxn(y_hat, y)`.

    Logs 'train_loss', 'valid_loss' and 'test_loss', and optimises with Adam at
    `learning_rate`. `bs` and `num_workers` are only recorded by save_hyperparameters and
    are not otherwise used here.
    Raises NotImplementedError if `manual_optimization` is True.
    """

    def __init__(self, model, loss_fxn, bs=32, learning_rate=3e-3, num_workers=0, manual_optimization=False):
        super().__init__()
        self.save_hyperparameters()
        self.model = model
        self.loss_fxn = loss_fxn
        self.lr = learning_rate
        self._manual_optimization = manual_optimization
        if self._manual_optimization:
            # This used to assign `self.training_step_manual`, which is not defined anywhere,
            # so it failed with an opaque AttributeError.
            raise NotImplementedError("manual_optimization=True is not implemented for LitUNET")

    def forward(self, x):
        return self.model(x)

    def _step_loss(self, batch):
        """Unpack an (x, y) batch, run the model and return the loss."""
        x, y = batch
        y_hat = self.model(x)
        return self.loss_fxn(y_hat, y)

    def training_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log('valid_loss', loss, logger=True, on_step=True, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        # learning rate from the constructor (also recorded in self.hparams by save_hyperparameters)
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [27]:
class LossSpikeAnalyzer(Callback):
    """
    Lightning callback that tracks per-batch losses and periodically saves prediction grids.

    - Every training batch: records trainer.callback_metrics["train_loss_step"] when present.
    - Every 600th training batch: runs the model on the batch in eval mode and saves a grid of
      (image, binarised mask, prediction) rows to path/"training_image_logs".
    - At each validation epoch end, once more than 10 train losses are recorded: prints the
      epoch validation loss and logs it to wandb, together with the mean train loss over
      all, the last 50% and the last 15% of recorded batches. It then resets the buffers.
    """

    def __init__(self):
        super().__init__()
        self.train_batch_losses = []
        self.valid_batch_losses = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        callback_stats = trainer.callback_metrics
        if "train_loss_step" not in callback_stats:
            return
        step_loss = callback_stats["train_loss_step"].item()
        self.train_batch_losses.append(step_loss)
        if batch_idx % 600 != 0:
            return

        # NOTE(review): `batch` is indexed as the raw torchio dict, i.e. before
        # LitUNETDataLoader.transfer_batch_to_device - confirm for your Lightning version.
        images = batch['img'][tio.DATA].squeeze()
        masks = batch['mask'][tio.DATA].squeeze()
        with torch.no_grad():
            pl_module.eval()
            try:
                # use the module's own device instead of assuming CUDA
                preds = pl_module(images.to(pl_module.device))
            finally:
                # always restore training mode, even if the forward pass fails
                pl_module.train()

        img_list = []
        for i, img in enumerate(images):
            # binarise like the training targets: 1 where the label is 1, else 0
            mask = torch.where(masks[i, ...].unsqueeze(0) != torch.tensor(1), torch.tensor(0), torch.tensor(1))
            img_list.append(img)
            img_list.append(to_3chan(mask, 0))
            img_list.append(to_3chan(preds[i, ...], 0).cpu())
        grid = torchvision.utils.make_grid(
            img_list,
            nrow=3,
        )
        self._save(grid, batch_idx, step_loss)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self.valid_batch_losses.append(outputs.cpu())

    def on_validation_epoch_end(self, trainer, pl_module):
        if len(self.train_batch_losses) > 10:
            losses = np.asarray(self.train_batch_losses)
            train_loss_total = losses.mean()
            train_loss_last50 = losses[round(len(losses) * (1 - 0.5)):].mean()
            train_loss_last15 = losses[round(len(losses) * (1 - 0.15)):].mean()
            valid_loss_epoch = trainer.callback_metrics["valid_loss_epoch"].item()
            print(f"The validation loss at VALIDATION END is {valid_loss_epoch}")
            wandb.log({
                "valid_loss_epoch": valid_loss_epoch,
                "valid_loss": trainer.callback_metrics["valid_loss"].item(),
                "train_loss_total_epoch": train_loss_total,
                "train_loss_last50_epoch": train_loss_last50,
                "train_loss_last15_epoch": train_loss_last15,
            })

            self.train_batch_losses = []
        # Per-batch validation outputs are never read back. Clear them every epoch so the
        # list doesn't grow for the whole run.
        self.valid_batch_losses = []

    @staticmethod
    def _show(img):
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')

    @staticmethod
    def _save(img, batch_idx, loss):
        npimg = img.numpy()
        loss = str(round(loss,2)).replace(".", "_")
        log_dir = path / "training_image_logs"
        log_dir.mkdir(parents=True, exist_ok=True)  # plt.imsave does not create missing folders
        plt.imsave(log_dir / f"{batch_idx}__{loss}.png", np.transpose(npimg, (1,2,0)))

## TorchIO

### Subject Generation

In [28]:
def subject_creator(df, affine=None):
    """
    Build a torchio Subject for every image quarter belonging to the ids in `df`.

    For each img_id, every file under path/"smaller/imgs" whose name contains the id is
    paired with path/"smaller/masks"/"<name>_mask.tiff".

    affine: affine matrix given to the mask LabelMaps. It defaults to diag(-1, -1, 1, 1).
        The default is built per call; previously one tensor created at definition time
        was shared by every call.
    Returns a list of tio.Subject with entries "img", "mask" and "img_id".
    """
    if affine is None:
        affine = torch.tensor([[-1.,  0.,  0.,  0.], [ 0., -1.,  0.,  0.], [ 0.,  0.,  1.,  0.], [ 0.,  0.,  0.,  1.]])
    img_dir = path / "smaller/imgs"
    mask_dir = path / "smaller/masks"
    # walk the image folder once instead of once per id
    all_files = [item for item in img_dir.rglob("*") if not item.is_dir()]

    subjects_list = []
    for img_id in tqdm(df["img_id"], total=len(df)):
        for pic in (item for item in all_files if img_id in item.name):
            pic_name = pic.name.split(".")[0]
            im = tio.ScalarImage(img_dir / (pic_name + ".tiff"))
            mask = tio.LabelMap(mask_dir / (pic_name + "_mask.tiff"), affine=affine)

            subjects_list.append(tio.Subject(
                img=im,
                mask=mask,
                img_id=pic_name,
            ))
    return subjects_list

### Custom Transforms

In [29]:
# Optional sanity check of the torchio pipeline on the first subject; skipped unless TEST is True.
# NOTE(review): custom_reshape and custom_normalization are only defined in later cells
# (In [32] / In [31]), so running top-to-bottom with TEST = True raises NameError here.
if TEST:
    test_items = subject_creator(new_masks)
    transforms = tio.Compose([custom_reshape, custom_normalization])
    test_dataset = tio.SubjectsDataset(test_items, transform=transforms)
    
    test_img = test_dataset[0]
    
    # resample to (4, 4, 1) spacing - presumably a 4x in-plane downsample; confirm
    downsized_img = tio.Resample((4,4,1))(test_img["img"][tio.DATA])
    
    downsized_img.shape  # no effect inside an `if` block (only a cell's last expression is displayed)
    
    plt.imshow(downsized_img.squeeze().permute(1,2,0))

In [30]:
#show_single_img_and_mask(test_img)

In [31]:
def _scale_intensity_by_255(x):
    """Divide intensities by 255 (presumably 0-255 image data) and cast to float32."""
    return (x / 255).float()

# Applied to intensity images only; label maps are left untouched.
custom_normalization = tio.Lambda(_scale_intensity_by_255, types_to_apply=[tio.INTENSITY])

In [32]:
def _keep_first_three_channels(x):
    """Keep channels 0-2 and drop any further channel (presumably alpha)."""
    return x[:3, ...]

# Applied to intensity images only; label maps are left untouched.
custom_reshape = tio.Lambda(_keep_first_three_channels, types_to_apply=[tio.INTENSITY])

In [33]:
def _repeat_label_three_times(x):
    """Repeat a label map three times along its channel axis (dim 0)."""
    return to_3chan(x, 0)

# Applied to label maps only.
custom_to3d = tio.Lambda(_repeat_label_three_times, types_to_apply=[tio.LABEL])

In [34]:
# Workaround: label maps end up with their two spatial axes swapped relative to the images.
# The root cause hasn't been tracked down yet, so swap them back here.
def shuffle_axes(img_tensor):
    """Swap dims 1 and 2 of a 4D tensor (the same as permute(0, 2, 1, 3))."""
    return img_tensor.transpose(1, 2)
reshuffle = tio.Lambda(shuffle_axes, types_to_apply=[tio.LABEL])

In [35]:
def _shrink_to_15_percent(x):
    """Resize to 15% of the original height and width via `resizer`, returned as a tensor."""
    return torch.tensor(resizer(x, 15))

# NOTE(review): no types_to_apply, so label maps are resized too. resizer calls cv2.resize
# without an interpolation flag (bilinear by default), which can blend label values - confirm.
custom_shrink = tio.Lambda(_shrink_to_15_percent)

## WandB Sweep Setup

In [36]:
# models
# pl_unet: pre-built pl_bolts U-Net (first argument presumably num_classes=1 - confirm).
# custom_unet: the hand-written CustomUnet, 3 input channels -> 1 output channel.
pl_unet = UNet(1)
custom_unet = CustomUnet(3, 1)

In [37]:
# Fresh CustomUnet (3 input channels -> 1 output channel); replaces the instance created in the previous cell.
custom_unet = CustomUnet(3, 1)

In [38]:
# sweep config for the WandB workflow: grid search minimising the epoch validation loss

################
# TODO - include other transforms, hyperparameters (dropout, nadam), different loss, different architecture
################

# hyperparameter -> list of values tried by the grid search
_sweep_grid = {
    "sample_ratio": [1, 2, 3],
    "patch_size": ["256", "512"],
    "lr": [3e-3, 3e-2, 3e-4],
    "gradient_clipping": [0.5, 1],
}

sweep_config = {
    "name": "custom_unet",
    "method": "grid",
    "metric": {"name": "valid_loss_epoch", "goal": "minimize"},
    "parameters": {name: {"values": values} for name, values in _sweep_grid.items()},
}

#sweep_id = wandb.sweep(sweep_config, entity="stantonius", project="kidneys-cv")

## Training

In [39]:
class MasterTrainer:
    """
    One WandB sweep trial: builds the torchio data pipeline from the run config and trains a model.

    On construction it:
    - starts a wandb run (defaults in `sweep_defaults`, overridden by the sweep);
    - maps the `sample_ratio` / `patch_size` config values to LabelSampler settings;
    - splits the subjects 80/20 into train/valid after a seeded shuffle;
    - builds the datasets, patch queues and DataLoaders via `set_transforms`.
    """

    def __init__(self):
        self.sweep_defaults = {
            "sample_ratio": 1,
            "patch_size": "256",
            "lr": 3e-3,
            "gradient_clipping": 0.5,
            "epochs": 10,
        }
        # sweep value -> LabelSampler label probabilities {label: weight}
        self.sample_ratio_options = {
            1: {0: 8, 1: 3, 2: 2},
            2: {0: 1, 1: 1, 2: 1},
            3: {0: 3, 1: 3, 2: 1}
        }
        # sweep value -> patch size passed to LabelSampler
        self.patch_size_options = {
            "256": (256, 256, 1),
            "512": (512, 512, 1)
        }
        self.run = wandb.init(entity="stantonius", project="kidneys-cv", config=self.sweep_defaults)
        self.config = wandb.config

        self.sample_ratio = self.sample_ratio_options.get(self.config.sample_ratio)
        self.patch_size = self.patch_size_options.get(self.config.patch_size)

        self.subjects_list = subject_creator(new_masks)
        self.subjects_list_copy = self.subjects_list[:]     # copy, because shuffle works in place

        random.seed(57)  # fixed seed -> reproducible train/valid split
        random.shuffle(self.subjects_list_copy)

        split = round(len(self.subjects_list_copy) * 0.8)
        self.train_subjects = self.subjects_list_copy[:split]
        self.valid_subjects = self.subjects_list_copy[split:]

        # __init__ used to duplicate the whole pipeline construction from set_transforms
        self.set_transforms(
            tio.Compose([tio.Resample((20,20,1)), custom_reshape, tio.RandomFlip(), tio.RandomAffine(), custom_normalization,]),
            tio.Compose([tio.Resample((20,20,1)), custom_reshape, custom_normalization,]),
        )

    def set_transforms(self, train_transforms, valid_transforms):
        """
        Replace the train/valid transforms and rebuild the datasets, sampler, queues and loaders.
        """
        self.train_transforms = train_transforms
        self.valid_transforms = valid_transforms
        self.train_dataset = tio.SubjectsDataset(self.train_subjects, transform=self.train_transforms)
        self.valid_dataset = tio.SubjectsDataset(self.valid_subjects, transform=self.valid_transforms)

        self.queue_length = 10
        self.samples_per_volume = 10

        self.sampler = tio.data.LabelSampler(self.patch_size, label_probabilities=self.sample_ratio)

        self.train_queue = self._make_queue(self.train_dataset)
        self.valid_queue = self._make_queue(self.valid_dataset)

        self.train_loader = DataLoader(self.train_queue, batch_size=16, pin_memory=True)
        self.valid_loader = DataLoader(self.valid_queue, batch_size=16, pin_memory=True)

    def _make_queue(self, dataset):
        """Patch queue over `dataset` using the current sampler and queue settings."""
        return tio.Queue(
            dataset,
            self.queue_length,
            self.samples_per_volume,
            self.sampler,
            num_workers=0,
        )

    def train(self, model, loss, epochs, dryrun = True):
        """
        Fit a LitUNET on this trainer's loaders.

        model / loss: network and loss function to train. Previously these arguments were
            ignored in favour of the globals `custom_unet` / `dice_loss`. Those globals are
            still used when None is passed.
        epochs: max_epochs for the Lightning Trainer.
        dryrun: accepted for compatibility; currently unused.
        """
        data = LitUNETDataLoader(self.train_loader, self.valid_loader)
        net = custom_unet if model is None else model
        loss_fxn = dice_loss if loss is None else loss
        lit_model = LitUNET(net, loss_fxn, num_workers=16, learning_rate=self.config.lr)
        checkpoint_callback = ModelCheckpoint(monitor='valid_loss_epoch', dirpath=path/"models", save_top_k=3, mode='min', save_weights_only=True, prefix=arrow.now().format("MMM_DD_YY_H_mm"))
        trainer = pl.Trainer(gpus=1, callbacks=[LossSpikeAnalyzer(), checkpoint_callback], gradient_clip_val=self.config.gradient_clipping, max_epochs=epochs)
        trainer.fit(lit_model, data)
        trainer.fit(model, data)

In [40]:
# Spatial size of every sampled patch; the trailing 1 is the singleton third spatial axis of these 2D slides.
patch_size = (256, 256, 1)
# Passed to tio LabelSampler as `label_probabilities`: relative weight of centring a patch on each
# mask label value. NOTE(review): what labels 0/1/2 mean in `new_masks` is defined elsewhere - confirm.
sample_ratio = {0: 4, 1: 4, 2: 1}

subjects_list = subject_creator(new_masks)
subjects_list_copy = subjects_list[:]     # shallow copy so the in-place shuffle below leaves subjects_list untouched

# Fixed seed -> reproducible shuffle and therefore a reproducible train/valid split.
random.seed(57)
random.shuffle(subjects_list_copy)

# 80/20 split of the shuffled subjects.
train_subjects = subjects_list_copy[:round(len(subjects_list_copy)*0.8)]
valid_subjects = subjects_list_copy[round(len(subjects_list_copy)*0.8):]
#train_subjects = subjects_list_copy[:1]
#valid_subjects = subjects_list_copy[1:2]

# Training adds random flips/affines; validation only gets the deterministic reshape + normalisation.
#train_transforms = tio.Compose([tio.Resample((20,20,1)), custom_reshape, tio.RandomFlip(), tio.RandomAffine(), custom_normalization,])
#valid_transforms = tio.Compose([tio.Resample((20,20,1)), custom_reshape, custom_normalization,])
train_transforms = tio.Compose([custom_reshape, tio.RandomFlip(), tio.RandomAffine(), custom_normalization,])
valid_transforms = tio.Compose([custom_reshape, custom_normalization,])

train_dataset = tio.SubjectsDataset(train_subjects, transform=train_transforms)
valid_dataset = tio.SubjectsDataset(valid_subjects, transform=valid_transforms)

# tio.Queue settings: max patches held in the queue, and patches drawn per loaded subject.
queue_length = 100
samples_per_volume = 100

# One label-weighted sampler shared by both queues.
sampler = tio.data.LabelSampler(patch_size, label_probabilities=sample_ratio)

train_queue = tio.Queue(
    train_dataset,
    queue_length,
    samples_per_volume,
    sampler,
    num_workers=0,
    shuffle_subjects=True,
    shuffle_patches=True,
)

# NOTE(review): LabelSampler still picks patch locations at random, so validation patches (and the
# valid loss) can vary between epochs even with shuffling disabled here.
valid_queue = tio.Queue(
    valid_dataset,
    queue_length,
    samples_per_volume,
    sampler,
    num_workers=0,
    shuffle_subjects=False,
    shuffle_patches=False,
)

def printer(x):
    """Debug helper: echo ``x`` to stdout and hand back ``None`` (handy as a throwaway hook)."""
    return print(x)

# Plain PyTorch loaders over the patch queues (default num_workers=0, as the queues manage their own
# workers). The fastai `dls` built in the next cell wraps the queues directly instead of these loaders.
train_loader = DataLoader(train_queue, batch_size=16)
valid_loader = DataLoader(valid_queue, batch_size=16)

100%|███████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 275.85it/s]


In [41]:
def batch_creator(subjects_list):
    """
    Collate a list of TorchIO patch subjects into an ``(x, y)`` batch for fastai.

    Each subject holds an ``"img"`` and a ``"mask"`` image whose data tensors are TorchIO 4D tensors
    ``(C, W, H, D)`` with ``D == 1`` (the patches are sampled with ``patch_size = (256, 256, 1)``).

    Returns a 2-tuple ``(x, y)``:
        x: the image patches stacked on a new batch axis with the depth axis dropped -> ``(B, C, W, H)``
        y: the mask patches stacked the same way -> ``(B, 1, W, H)`` (assumes single-channel masks)

    Only the trailing depth axis is squeezed, so a batch of size 1 (e.g. the last, partial batch)
    keeps its batch axis instead of collapsing to ``(C, W, H)``.
    """
    x = torch.stack([subject["img"][tio.DATA] for subject in subjects_list], 0).squeeze(-1)
    y = torch.stack([subject["mask"][tio.DATA] for subject in subjects_list], 0).squeeze(-1)
    return (x, y)

# fastai DataLoaders built directly on the TorchIO patch queues. TfmdDL hands each batch-sized list
# of patch subjects to `batch_creator`, which stacks them into (x, y); `.cuda()` targets the GPU.
# num_workers=0 because the tio.Queue manages its own workers (per the torchio Queue docs).
dls = DataLoaders(
    TfmdDL(
        train_queue, 
        batch_size=16, 
        pin_memory=True,
        num_workers=0,
        #chunkify=lambda x: print(str(x)),
        # returns generator of indices (provided by sample attribute), length is provided by queue sample length
        #create_batches=lambda x: print(x),
        # passed a list of length batchsize and collates into a batch
        #create_batch=lambda x: print(x[1]["img"][tio.DATA].shape),
        create_batch=batch_creator,
    ),
    TfmdDL(
        valid_queue, 
        batch_size=16, 
        pin_memory=True, 
        num_workers=0, 
        create_batch=batch_creator
    ),
).cuda()

In [42]:
#prebatched=False

#def create_batch(b): return (fa_collate,fa_convert)[prebatched](b)
#create_batch()

In [43]:
# fastai callbacks

class PrinterCallback(Callback):
    """
    Periodically save a visual snapshot of a batch: input, target and prediction side by side.

    Whenever fastai's batch index `self.iter` is a multiple of `img_freq`, every sample in the batch
    becomes one grid row ``[x, y, pred]`` and the grid is written to `path` as
    ``epoch{epoch}batch{iter}__{loss}.png``. The output directory is created on demand.

    NOTE(review): `make_grid` needs all tiles to have the same channel count; y and pred are expanded
    with `to_3chan`, so this assumes `x` is 3-channel after `custom_reshape` - confirm.
    """
    def __init__(self, path, img_freq=105):
        self.img_freq = img_freq
        self.path = path

    def after_batch(self):
        if self.iter % self.img_freq != 0:
            return
        with torch.no_grad():
            img_list = []
            for i in range(self.pred.shape[0]):
                # One row per sample: raw input, target and prediction (both expanded to 3 channels).
                img_list += [
                    self.x[i, ...],
                    to_3chan(self.y[i, ...], 0),
                    to_3chan(self.pred[i, ...], 0),
                ]
            grid = torchvision.utils.make_grid(
                img_list,
                nrow=3,
            )
            self._save(self.path, grid, self.epoch, self.iter, round(self.loss.item(), 3))

    @staticmethod
    def _save(img_path, img, epoch, batch, loss):
        """Write a (C, H, W) image tensor into directory `img_path` as a PNG, creating it if needed."""
        # plt.imsave fails if the directory is missing, which would abort training mid-epoch.
        img_path.mkdir(parents=True, exist_ok=True)
        npimg = img.cpu().detach().float().numpy()
        # Channels-last for matplotlib. NOTE(review): matplotlib expects float RGB values in [0, 1];
        # check that custom_normalization keeps x in that range.
        plt.imsave(img_path/f"epoch{epoch}batch{batch}__{loss}.png", np.transpose(npimg, (1,2,0)))
        
class ConvertY(Callback):
    """
    Binarise the targets produced by the TorchIO sampler before each batch.

    Mask values equal to 1 become 1; every other value becomes 0. The result is an int64 tensor on
    the same device as the original target.
    """
    def before_batch(self):
        """
        NOTE: as per the fastai docs, only `yb` (not `y`) can be reassigned. It must be set on
        `self.learn`, because assigning `self.yb` would only create an attribute on this callback.
        """
        # Same values/dtype as torch.where(y != 1, 0, 1), but device-agnostic and without
        # allocating three CUDA constants on every batch.
        self.learn.yb = ((self.y == 1).long(),)
        
class AddSigmoidActivation(Callback):
    """
    Apply a sigmoid to the model output before it reaches the loss function.

    Needed since:
        a) we use a pretrained ResNet model that doesn't support adding a final activation layer
        b) unlike `cnn_learner`, `unet_learner` has no `custom_head` parameter (which the forums
           suggest as a way to effectively add a layer to a pretrained model)
    Note: it may be possible to use `learner.model[-1].add_module` with an `nn.Module` subclass
    whose `forward()` adds this activation - not yet checked.
    """
    def after_pred(self):
        """
        Per the fastai docs, `after_pred` runs after the forward pass and BEFORE the loss is computed,
        so it is the place to post-process predictions. Reassign on `self.learn` so the loss sees it.
        """
        # Functional sigmoid: same result as nn.Sigmoid()(pred) without building a module per batch.
        self.learn.pred = torch.sigmoid(self.pred)
        
class ProgressiveTransformsUpdateCallback(Callback):
    """
    Progressive resizing: at the start of every epoch, swap the TorchIO transforms of the train/valid
    datasets so that early epochs train on downsampled images.

        epoch 0-3 : tio.Resample((4, 4, 1)) in front of the usual pipeline
        epoch 4-7 : tio.Resample((2, 2, 1)) in front of the usual pipeline, and the lr of every
                    optimizer param group is set to 1e-5
        epoch 8+  : the usual pipeline at native resolution (the lr is left as it is)

    Usual pipeline - train: custom_reshape -> RandomFlip -> RandomAffine -> custom_normalization;
                     valid: custom_reshape -> custom_normalization.

    NOTE(review): fastai's ParamScheduler (used by fit_one_cycle / fit_flat_cos / fine_tune) presumably
    rewrites the lr before every batch, in which case the 1e-5 override only takes effect under a
    plain `fit` - confirm.
    """
    @staticmethod
    def _pipelines(spacing):
        """Return (train, valid) tio.Compose pipelines, each prefixed by tio.Resample(spacing) unless spacing is None."""
        def resample():
            # A fresh Resample instance per pipeline so train and valid never share a transform object.
            return [tio.Resample(spacing)] if spacing is not None else []
        train_tfms = tio.Compose([*resample(), custom_reshape, tio.RandomFlip(), tio.RandomAffine(), custom_normalization,])
        valid_tfms = tio.Compose([*resample(), custom_reshape, custom_normalization,])
        return train_tfms, valid_tfms

    def before_epoch(self):
        if self.epoch < 4:
            spacing, lr = (4, 4, 1), None
        elif self.epoch < 8:
            spacing, lr = (2, 2, 1), 0.00001
        else:
            spacing, lr = None, None

        train_tfms, valid_tfms = self._pipelines(spacing)
        # dls.train / dls.valid presumably forward attribute access to the wrapped tio.Queue, whose
        # `subjects_dataset` owns the per-subject transform.
        self.dls.train.subjects_dataset.set_transform(train_tfms)
        self.dls.valid.subjects_dataset.set_transform(valid_tfms)

        if lr is not None:
            for h in self.opt.hypers:
                h["lr"] = lr

In [44]:
# fastai U-Net with a resnet34 encoder and a single output channel, trained with dice_loss.
# The sigmoid on the output is applied by the AddSigmoidActivation callback, not by the model.
learner = unet_learner(dls, resnet34, n_out=1, loss_func=dice_loss, lr=0.00015)

In [45]:
#learner.lr_find()

In [46]:
# Print which callbacks fastai runs at each event of the training loop (output below).
learner.show_training_loop()

Start Fit
   - before_fit     : [TrainEvalCallback, Recorder, ProgressCallback]
  Start Epoch Loop
     - before_epoch   : [Recorder, ProgressCallback]
    Start Train
       - before_train   : [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - before_batch   : []
         - after_pred     : []
         - after_loss     : []
         - before_backward: []
         - before_step    : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback, Recorder, ProgressCallback]
      End Batch Loop
    End Train
     - after_cancel_train: [Recorder]
     - after_train    : [Recorder, ProgressCallback]
    Start Valid
       - before_validate: [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: [Recorder]
     - after_validate : [Recorder, ProgressCallback]
  End Epoch Loop
   - after_

In [47]:
# Native-resolution pipelines for swapping the datasets' transforms by hand; they match
# train_transforms / valid_transforms from the data-setup cell. The set_transform calls are disabled.
#learner.dls.dataset.subjects_dataset.set_transform
new_train_transforms = tio.Compose([custom_reshape, tio.RandomFlip(), tio.RandomAffine(), custom_normalization,])
new_valid_transforms = tio.Compose([custom_reshape, custom_normalization,])
#learner.dls.train.subjects_dataset.set_transform(new_train_transforms)
#learner.dls.valid.subjects_dataset.set_transform(new_valid_transforms)

### Logging

According to the fastai docs, models are saved in `learner.path/learner.model_dir/name.pth` so if this isn't set we need to pick a location.

In [48]:
# can save only when a certain improvement happens but will do this later
# every_epoch=True: write a checkpoint after every epoch rather than only on improvement.

save_model_callback = SaveModelCallback(every_epoch=True)

To utilise WandB:

In [53]:
# Start a Weights & Biases run and build fastai's WandbCallback for it: log='all' (gradients and
# parameters), sample predictions (n_preds=36) and the model are logged; the dataset is not.
wandb.init()
wandb_callback = WandbCallback(log='all', log_preds=True, log_model=True, log_dataset=False, dataset_name=None, valid_dl=None, n_preds=36, seed=12345, reorder=True)

0,1
epoch,3.78528
train_loss,0.56998
raw_loss,0.63942
wd_0,0.01
sqr_mom_0,0.99
lr_0,1e-05
mom_0,0.86282
eps_0,1e-05
wd_1,0.01
sqr_mom_1,0.99


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▆▃▂▂▁▁▁▁▁▁▁▂▁▁▁▂▂▂▂▂▂▁▁▂▂▂▂▁▁▁▁▂▂▂▂▂▁▁▁
raw_loss,█▃▄▃▁▄▆▄▄▆▄▇▃▃▅▃▃▆▃▆▃▃▄▅▄▅▇▃▄▄▃█▄▆▇▃▃▄▄▄
wd_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sqr_mom_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_0,▁▁▂▃▃▄▆▆▇██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mom_0,██▇▇▆▅▄▃▂▁▁███▇▇▆▆▅▅▄▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▂▂▂
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wd_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sqr_mom_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁


### Model Run

In [54]:
# if a model has already been created and you want to continue training 
# note you should create the original learner first, and then run the code below
#learner.load("could_do_beter_0.25")

In [None]:
# Train with standard resolution
# ConvertY binarises the targets and AddSigmoidActivation squashes the predictions before the loss,
# so both must be in `cbs`. GradientAccumulation() uses fastai's default accumulation size (TODO
# confirm the value for the installed fastai version).
# fine_tune presumably runs a frozen epoch before the 6 unfrozen ones (two result tables below).

save_image_path = path/"training_image_logs"
cbs=[PrinterCallback(save_image_path), ConvertY(), AddSigmoidActivation(), save_model_callback, wandb_callback, GradientAccumulation()]
learner.fine_tune(6, cbs=cbs, base_lr=1e-5)

epoch,train_loss,valid_loss,time
0,0.594227,0.676311,16:08


epoch,train_loss,valid_loss,time


In [None]:
# Progressive transformation
# 12 epochs span all three stages of ProgressiveTransformsUpdateCallback (epochs 0-3, 4-7, 8-11).
# freeze_to(-1) leaves only the last parameter group trainable; fit_flat_cos keeps the lr flat for
# the first 40% of training (pct_start=0.4) and then anneals it.

save_image_path = path/"training_image_logs"
cbs=[PrinterCallback(save_image_path), ConvertY(), AddSigmoidActivation(), save_model_callback, wandb_callback, ProgressiveTransformsUpdateCallback(), GradientAccumulation()]
learner.freeze_to(-1)
learner.fit_flat_cos(12, lr=1e-5, pct_start=0.4, cbs=cbs)
#learner.fit_flat_cos(5, lr=1e-5, cbs=cbs)

In [73]:
# Inspect the optimizer factory in use (fastai Adam and its default hyper-parameters, see output).
learner.opt_func

<function fastai.optimizer.Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-05, wd=0.01, decouple_wd=True)>

In [None]:
# Train with lower resolution and then upsampling at the very end
# NOTE(review): no ProgressiveTransformsUpdateCallback here, so the resolution is whatever transform
# the datasets currently hold.

save_image_path = path/"training_image_logs"
cbs=[PrinterCallback(save_image_path), ConvertY(), AddSigmoidActivation(), save_model_callback, wandb_callback]
# The callbacks must be passed in: without ConvertY / AddSigmoidActivation the loss would see the
# raw sampler labels and un-squashed outputs.
learner.fine_tune(6, cbs=cbs)

In [None]:
# Train with standard resolution
# Only the last parameter group is trainable (freeze_to(-1)); one-cycle schedule peaking at 1e-5
# and starting from lr_max / div.

save_image_path = path/"training_image_logs"
cbs=[PrinterCallback(save_image_path), ConvertY(), AddSigmoidActivation(), save_model_callback, wandb_callback]
learner.freeze_to(-1)
learner.fit_one_cycle(12, lr_max=0.00001, div=100, cbs=cbs)
#learner.fit_flat_cos(5, lr=1e-5, cbs=cbs)

In [None]:
#learner.save("can_still_do_better_0.115")

In [None]:
# Scratch cell - intentionally empty (a bare identifier here would raise NameError when run).

Some notes on training:

* Initial LR was 0.001 - I think this was too high because we quickly (after 300 batches) got to a low loss, then the **loss explosion** happened (quickly improved loss, then loss deteriorates very fast and never comes back down - plateaus in the opposite direction)
    * NOTE: Resampling doesn't incur this loss explosion - loss always trends down and then plateaus at the min
* Running the `learner.lr_find()` suggests my LR was 10x too high
    * However running `learner.fine_tune()` with this LR did not massively improve (loss around 0.79)
        * One theory is that `lr_find()` only probes the LR over the first few batches, when we expect the loss to improve. So it finds the optimal LR for the early stages of training, but that LR may go "stale" later in training (which `lr_find()` has no insight into).



TODO:
- Fine TUNE [DONE]
- Shorten the size of each epoch by reducing `samples_per_volume`
- fp16
- Train on resized images and then scale up progressively by adjusting transforms through callbacks. Or just train, then fine-tune (including unfreezing a few layers) on the full-size data
- TTA
- Train with Resize(2), then upscale using a callback - try to round the edges where the hand-drawn lines are sharp and inaccurate
- Try with bigger patch size
- Try with different sample ratio
- Plateau callback to adjust LR 
- Use an open-source histology pretrained model and fine tune with this

In [None]:
#os.environ['WANDB_MODE'] = 'dryrun'
#master_trainer = MasterTrainer()
#master_trainer.train(custom_unet, dice_loss, 10)

In [None]:
#next(iter(master_trainer.train_loader))

In [None]:
#wandb.agent(wandb.sweep(sweep_config, entity="stantonius", project="kidneys-cv"), master_trainer)

In [None]:
#wandb login --relogin
#master_trainer()

In [None]:
#training_log_path = (path/"lightning_logs/version_126/checkpoints/")
#training_log_path.as_posix()

In [None]:
#!tensorboard --logdir lightning_logs/version_3/checkpoints

### Update model without resized images

In [None]:
# Restore a LitUNET from a saved Lightning checkpoint to continue training it.
load_model = LitUNET.load_from_checkpoint(path/"models"/"21_Feb_18_16_34-epoch=9-step=16249.ckpt")

In [None]:
# Continue training at native resolution (no Resample in either the train or the valid pipeline).
# NOTE(review): `master_trainer` is only created in a commented-out cell above
# (`master_trainer = MasterTrainer()`); run that first or this cell raises NameError.
master_trainer.set_transforms(
    tio.Compose([custom_reshape, tio.RandomFlip(), tio.RandomAffine(), custom_normalization,]),
    tio.Compose([custom_reshape, custom_normalization,])
)
master_trainer.train(load_model, dice_loss, 10)

## Inference

In [None]:
# Restore the LitUNET checkpoint used for inference below.
load_model = LitUNET.load_from_checkpoint(path/"models"/"21_Feb_20_8_15-epoch=4-step=8124.ckpt")

In [None]:
# Load a raw checkpoint dict to inspect what was stored with it (e.g. its 'hyper_parameters').
checkpoint_model_details = torch.load(path/"models"/"21_Feb_19_16_56-epoch=8-step=14624.ckpt")
#print(checkpoint_model_details['hyper_parameters'])

In [None]:
#[cut_image(item, path/"test", path/"test/smaller") for item in [img.name.split(".")[0] for img in (path/"test").glob("*.tiff")]]

**An aside on inference/test transforms**: for some reason I can't find a way to apply transforms to inference/test data in the TorchIO library, which makes me question whether you are supposed to do this at all. Either way, the evaluation doesn't work if you don't do it.

I looked into how to create a *batch* transform, but from what I read, iterating over each item in the batch is fast enough because a) the patches are small, b) the transforms run in C (so they are already optimised), and c) GPU memory is limited - transferring data to the GPU to perform transforms that are already optimised for the CPU doesn't make much sense, and the transfer itself takes time.


In [None]:
def get_test_images(imgs: Union[Path, list], transforms):
    """
    Build a TorchIO dataset of test images to run inference on.

    Args:
        imgs: either a directory (``Path``), in which case every file found under it recursively is
            used, or a list of image file paths (``Path`` or ``str``).
        transforms: transform applied to each subject when it is loaded, e.g. a ``tio.Compose`` of
            the deterministic (non-random) training transforms.

    Returns:
        tio.SubjectsDataset whose subjects hold ``img`` (a ``tio.ScalarImage``) and ``img_id``
        (the file name up to its first ``"."``).

    Raises:
        TypeError: if ``imgs`` is neither a ``Path`` nor a list.
    """
    if isinstance(imgs, Path):
        # Keep full paths: ScalarImage needs them to load the file, and `.name` is read below.
        image_paths = [item for item in imgs.rglob("*") if not item.is_dir()]
    elif isinstance(imgs, list):
        image_paths = [Path(item) for item in imgs]
    else:
        raise TypeError(f"imgs must be a Path or a list of paths, got {type(imgs).__name__}")

    images = [
        tio.Subject(
            img=tio.ScalarImage(image_path),
            img_id=image_path.name.split(".")[0],
        )
        for image_path in image_paths
    ]
    return tio.SubjectsDataset(images, transforms)

In [None]:
def subject_pred(subject, model, patch_size=(256,256,1), batch_size=4):
    """
    Predict a whole subject patch by patch and stitch the patch outputs back together.

    Args:
        subject: tio.Subject with an ``"img"`` image and an ``"img_id"`` entry (as built by
            ``get_test_images``); it should already carry the deterministic test transforms.
        model: callable mapping a ``(B, C, W, H)`` batch to ``(B, K, W, H)`` outputs, presumably the
            trained U-Net in eval mode (TODO confirm whether it applies its own final activation).
        patch_size: grid-sampler patch size; its last entry must be 1 because the trailing depth
            axis is dropped before calling ``model``.
        batch_size: number of patches per forward pass.

    Returns:
        ``(img_id, output_tensor)``: ``img_id`` is the collated ``"img_id"`` entry of the last patch
        batch (presumably a list of identical ids), ``output_tensor`` the aggregated prediction
        from ``tio.inference.GridAggregator``.
    """
    grid_sampler = tio.inference.GridSampler(subject, patch_size)
    patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size)
    aggregator = tio.inference.GridAggregator(grid_sampler)

    with torch.no_grad():
        for patches_batch in patch_loader:
            img_id = patches_batch['img_id']
            # Drop only the trailing depth axis: (B, C, W, H, 1) -> (B, C, W, H). A bare squeeze()
            # would also remove the batch axis when B == 1 or the channel axis when C == 1.
            input_tensor = patches_batch['img'][tio.DATA].squeeze(-1)
            locations = patches_batch[tio.LOCATION]
            logits = model(input_tensor)
            # Restore the depth axis the aggregator expects: (B, K, W, H) -> (B, K, W, H, 1).
            aggregator.add_batch(logits.unsqueeze(-1), locations)
        output_tensor = aggregator.get_output_tensor()
    return (img_id, output_tensor)

In [None]:
# Run inference on one test tile. test_transforms mirror the deterministic part of training
# (no random flips/affines). NOTE(review): Resample((2,2,1)) means predictions are made on
# resampled images - they presumably need mapping back to the original resolution before RLE.
test_img_path = [path/"test/smaller"/"26dc41664_1.tiff"]
test_transforms = tio.Compose([tio.Resample((2,2,1)), custom_reshape, custom_normalization,])
model = load_model.eval()

test_preds = [subject_pred(item, model) for item in get_test_images(test_img_path, test_transforms)]

In [None]:
# Shape of the stitched prediction for the first test image.
test_preds[0][1].shape

In [None]:
# Largest predicted value - a quick sanity check on the output range.
torch.max(test_preds[0][1])

In [None]:
# Visualise the stitched prediction: expand to 3 channels along dim 0, drop singleton axes and move
# channels last for imshow. NOTE: matplotlib ignores `cmap` for RGB(A) input.
plt.imshow(to_3chan(test_preds[0][1],0).squeeze().permute(1,2,0), cmap="hot")

In [None]:
# RLE-encode the stitched prediction of the first test image; `output_tensor` only exists inside
# subject_pred, whose result is the second element of each test_preds entry.
rle_encoding(test_preds[0][1])

## Cleanup & Controls

In [None]:
# Cleanup helpers - all disabled (commented out, or wrapped in a no-op string literal). Enable deliberately.
# Delete saved training snapshot images (keeps .ipynb_checkpoints).
#[img.unlink() for img in (path/"training_image_logs").glob("*") if img.name != ".ipynb_checkpoints"]

############################################################ 
# BE CAREFUL WITH BELOW - MAKE SURE THE DIRECTORY IS CORRECT
############################################################

# For lightning logs
#[shutil.rmtree(folder) for folder in (path/"lightning_logs").glob("*")]

# For saved model checkpoints
#[model.unlink() for model in (path/"models").glob("*")]

# For wandb
# The triple-quoted string below is an unused expression, i.e. this loop is disabled.
"""
for folder in (path/"wandb").glob("*"):
    try: 
        shutil.rmtree(folder)
    except: 
        continue
"""

## Submission

* Run test set through above code and include in the model
    * Utilise the inference from TorchIO
* Add the stitch function to take inferences

In [None]:
# Inspect the type of the first entry of `affine_options` (defined elsewhere in the notebook).
type(affine_options[0])