# HuBMAP - Hacking the Kidney - Kaggle Competition

The [Kaggle competition page](https://www.kaggle.com/c/hubmap-kidney-segmentation)

Helpful Notebooks:
* [https://www.kaggle.com/markalavin/hubmap-tile-images-w-overlap-and-build-tfrecords](https://www.kaggle.com/markalavin/hubmap-tile-images-w-overlap-and-build-tfrecords)

## ToDO
* Look at impact of different affine matrices
* Look at impact of removing alpha channel on model size and performance
* Add DeepMind's architecture optimizer

## Package Downloads for Offline use

## Setup

In [None]:
# Toggle for the `if TEST:` sanity-check cell further down (builds a small dataset and plots one image).
TEST = False

In [None]:
#!conda update -n base conda

In [None]:
# ! conda config --set always_yes True
# ! conda install -c fastai -c pytorch fastai
# #! conda install pytorch torchvision torchaudio fastai -c pytorch
# #! conda update pytorch torchvision torchaudio cudatoolkit -c pytorch
# ! conda install pandas
# ! conda install -c conda-forge kaggle
# ! conda install -c conda-forge tifffile
# ! conda install -c conda-forge tqdm
# ! conda install -c conda-forge matplotlib
# ! conda install -c conda-forge pytorch-lightning
# ! conda install -c conda-forge wandb
# ! conda install -c conda-forge arrow
# !conda install -c conda-forge pickle5
# !conda install -n base -c conda-forge jupyterlab_widgets
# !conda install -c conda-forge ipywidgets

In [None]:
#! pip install arrow pickle5

In [None]:
# !pip install timm

In [None]:
# needed if running in wsl2
#! pip install pytorch-lightning wandb

In [None]:
# I have no idea why the conda-forge version doesn't work

#!python -m pip install opencv-python

# If you are running this notebook on a server (like Linux on WSL2), you need the headless version of OpenCV.
# The regular OpenCV package requires GUI libraries that servers don't have, and will raise an error.
#!python -m pip install opencv-python-headless

# temporary solution to use tab complete - something wrong with jupyter jedi - need to downgrade
#!pip install jedi==0.17.2

# !pip install torchio --upgrade

#!pip install pytorch-lightning-bolts

In [None]:
#!pip install --upgrade ssl

Check that the (finicky) local CUDA setup is working and visible to PyTorch.

In [None]:
# First, import PyTorch
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

# Check PyTorch version
torch.__version__
torch.cuda.is_available()

In [None]:
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, ModelCheckpoint

# prebuilt models
# from pl_bolts.models import UNet

import tensorboard as tb

# Need to put kaggle.json in /%USERS%/.kaggle folder (C:/Users/Craig/.kaggle)
try:
    import kaggle
except OSError:
    # `import kaggle` authenticates on import and (in the client versions checked; confirm for
    # yours) raises OSError when no credentials are found. Write ~/.kaggle/kaggle.json from the
    # KAGGLE_KEY environment variable rather than a shell `echo`: the previous echo wrote invalid
    # JSON (an unquoted, unexpanded `KAGGLEKEY` token), failed when ~/.kaggle did not exist, and a
    # bare `except:` hid unrelated import errors. Keeping the key in the environment also keeps it
    # out of the notebook source.
    import json as _json
    import os as _os
    from pathlib import Path as _Path

    _kaggle_dir = _Path.home() / ".kaggle"
    _kaggle_dir.mkdir(parents=True, exist_ok=True)
    _kaggle_json = _kaggle_dir / "kaggle.json"
    _kaggle_json.write_text(_json.dumps({"username": "canadarmy", "key": _os.environ["KAGGLE_KEY"]}))
    # The Kaggle client warns when the credentials file is readable by other users.
    _os.chmod(_kaggle_json, 0o600)
    import kaggle

from pathlib import Path
import random
import os
import shutil
from typing import Union

# Read tiff images
import tifffile
import cv2
from tqdm import tqdm
import torchio as tio

from PIL import Image
from IPython.display import display

import time
import wandb
import arrow

from tqdm import tqdm


import matplotlib.pyplot as plt

# Memory management tools
import gc

import timm

from fastprogress.fastprogress import master_bar, progress_bar

from fastai.vision.all import *
from fastai.imports import *
from fastai.callback.wandb import *
from fastai.metrics import Dice, Jaccard, JaccardCoeff

import pickle5 as pickle

In [None]:
# Root folder for all competition data. The commented-out calls below download the tiled datasets
# into it; they stay commented so the downloads are not repeated on every run.
data_path = Path("./data")
# kaggle.api.dataset_download_files("iafoss/hubmap-1024x1024", path=data_path)
# kaggle.api.dataset_download_files("baesiann/glomeruli-hubmap-external-1024x1024", path=data_path)
# kaggle.api.dataset_download_files("iafoss/hubmap-256x256", path=data_path)

Make sure you are about to download the data into the correct directory.

Unzip the data into the correct folders. The code is commented out so that the archives are not unzipped (and renamed) again on every run.

In [None]:
# import zipfile

# with zipfile.ZipFile(data_path/"glomeruli-hubmap-external-1024x1024.zip", 'r') as zipref:
#     zipref.extractall(data_path)
    
# (data_path/"masks_1024").rename(data_path/"masks_ext_1024")
# (data_path/"images_1024").rename(data_path/"images_ext_1024")
    
# with zipfile.ZipFile(data_path/"hubmap-1024x1024.zip", 'r') as zipref:
#     zipref.extractall(data_path)
    
# (data_path/"masks").rename(data_path/"masks_1024")
# (data_path/"train").rename(data_path/"images_1024")
    
# with zipfile.ZipFile(data_path/"hubmap-256x256.zip", 'r') as zipref:
#     zipref.extractall(data_path)

# (data_path/"masks").rename(data_path/"masks_256")
# (data_path/"train").rename(data_path/"images_256")

In [None]:
# Number of files in the 256x256 image folder.
len((data_path/"images_256").ls())

In [None]:
# Number of files in the 1024x1024 image folder.
len((data_path/"images_1024").ls())

Because these two datasets contain the same tiles (one is just a higher-resolution version of the other), most file names are identical, so we need to rename one of the sets.

I have also confirmed that in each folder the masks and the images have the same names. So we will rename the 256-pixel images (and their masks) with a `-256` suffix.

In [None]:
# for file in (data_path/"images_256").ls():
#     file.rename(data_path/"images_256"/(file.name.split(".")[0]+"-256"+".png"))
    
# for file in (data_path/"masks_256").ls():
#     file.rename(data_path/"masks_256"/(file.name.split(".")[0]+"-256"+".png"))

## Helper Functions

In [None]:
################
# Main Functions
################


def rle2mask(mask_rle, shape):
    '''
    Decode a run-length-encoded string into a binary mask.

    mask_rle: space-separated "start length start length ..." string with 1-based starts
    shape: (width, height) of the array to return
    Returns a uint8 numpy array with 1 for mask pixels and 0 for background. The runs are laid
    out over a `shape` buffer, which is then transposed.
    '''
    tokens = mask_rle.split()
    # Even positions are run starts, odd positions are run lengths.
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1  # encoding is 1-indexed
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1
    return flat.reshape(shape).T

def mask2rle(x):
    """
    Run-length encode the pixels of `x` that equal 1.

    Pixels are scanned in the order of ``x.T.flatten()``. Returns a flat list
    ``[start1, length1, start2, length2, ...]`` with 1-based starts (a list, not a string).
    """
    pixel_positions = np.where(x.T.flatten() == 1)[0]
    runs = []
    last_seen = -2
    for pos in pixel_positions:
        # A gap since the previous foreground pixel opens a new (start, length) pair.
        if pos > last_seen + 1:
            runs += [pos + 1, 0]
        runs[-1] = runs[-1] + 1
        last_seen = pos
    return runs


# def get_id_by_index(index, df=train_df):
#     return df.iloc[index]['img_id']

def get_single_img(id, folder="train"):
    """
    Load ``<path>/<folder>/<id>.tiff`` with tifffile.

    A 5D read is squeezed and transposed with (1, 2, 0); any other read is returned as-is.
    NOTE(review): relies on a module-level `path`, which is not defined in this notebook view
    (only `data_path` is) — confirm it is set before calling.
    """
    tiff_file = path/folder/(id+".tiff")
    pixels = tifffile.imread(tiff_file)
    if len(pixels.shape) != 5:
        return pixels
    return pixels.squeeze().transpose(1, 2, 0)

def show_single_img(id, **kwargs):
    """Draw the image `id` (loaded with `get_single_img`) via ``plt.imshow``; `kwargs` go to imshow."""
    image = get_single_img(id)
    return plt.imshow(image, **kwargs)

# def show_img_by_index(index, df=train_df):
#     return plt.imshow(tifffile.imread(path/"train"/(train_df.iloc[TEST_IMAGE_INDEX]['id']+".tiff")))

# def get_single_encs(id, df=train_df):
#     return df[df['img_id'] == id]['encoding'].array[0]

# def get_mask(id, df=train_df, folder="train"):
#     return rle2mask(
#         get_single_encs(id, df=df),
#         get_single_img(id, folder=folder).shape[::-1][1:]
#     )

def show_single_img_and_mask_by_id(id):
    """
    Plot image `id`, the image with its mask overlaid, and the mask alone, side by side.

    NOTE(review): calls `get_mask`, whose definition is commented out in this notebook.
    """
    plt.figure(figsize=(16, 10))

    mask = get_mask(id)
    img = get_single_img(id)

    # (subplot position, title, layers drawn in order as (data, imshow kwargs))
    panels = (
        (1, "Image", ((img, {}),)),
        (2, "Image + mask", ((img, {}), (mask, {"cmap": "hot", "alpha": 0.5}))),
        (3, "Mask", ((mask, {"cmap": "hot"}),)),
    )
    for position, title, layers in panels:
        plt.subplot(1, 3, position)
        for data, style in layers:
            plt.imshow(data, **style)
        plt.title(title, fontsize=18)

    return plt.show()

def show_single_img_and_mask(subject: tio.data.subject.Subject, resize_to=50):
    """
    Plot a TorchIO subject's image, the image with its mask overlaid, and the mask alone.

    Args:
        subject: `tio.Subject` with "img" and "mask" entries.
        resize_to: if truthy, both image and mask are shrunk to this percentage of their size
            with `resizer` before plotting.

    Raises:
        TypeError: if `subject` is not a `tio.Subject`.
    """
    # Validate before creating the figure so a bad argument doesn't leave an empty
    # (120 x 100 inch) figure open.
    if not isinstance(subject, tio.data.subject.Subject):
        raise TypeError(f"The subject is required to be of type torchio.data.subject.Subject but you provided {type(subject)}")

    plt.figure(figsize=(120, 100))

    img = subject["img"][tio.DATA].squeeze().permute(1,2,0)
    mask = subject["mask"][tio.DATA].squeeze().unsqueeze(2)

    if resize_to:
        img = resizer(img, scale=resize_to)
        mask = resizer(mask, scale=resize_to)

    # (subplot position, title, layers drawn in order as (data, imshow kwargs))
    panels = (
        (1, "Image", ((img, {}),)),
        (2, "Image + mask", ((img, {}), (mask, {"cmap": "hot", "alpha": 0.5}))),
        (3, "Mask", ((mask, {"cmap": "hot"}),)),
    )
    for position, title, layers in panels:
        plt.subplot(1, 3, position)
        for data, style in layers:
            plt.imshow(data, **style)
        plt.title(title, fontsize=18)

    return plt.show()

def to_4d(img, input_chan_first=False, output_chan_first=True):
    """
    Convert a 3D image array to 4D: channels first, plus one singleton axis.

    Args:
        img: 3D array, (C, H, W) if `input_chan_first` else (H, W, C).
        input_chan_first: whether the channel axis of `img` is already first.
        output_chan_first: if True the singleton axis is appended, giving (C, H, W, 1);
            if False it is prepended, giving (1, C, H, W).

    Raises:
        ValueError: if `img` is not 3D.
    """
    if len(img.shape) != 3:
        raise ValueError("Function only converts 3D array to 4D array")
    # Only move the channel axis when it is last; the old identity transpose was a no-op.
    channels_first = img if input_chan_first else np.transpose(img, (2, 0, 1))
    return np.expand_dims(channels_first, 3 if output_chan_first else 0)

def to_3d(img, input_chan_first=True, output_chan_first=False):
    """
    Convert a 4D image with a singleton axis to 3D.

    Args:
        img: 4D array/tensor, expected as (C, H, W, 1) if `input_chan_first` else (1, C, H, W).
        input_chan_first: which layout to expect; decides which singleton axis is dropped first.
        output_chan_first: return (C, H, W) if True, else (H, W, C).

    Raises:
        ValueError: if `img` is not 4D.
    """
    if len(img.shape) != 4:
        raise ValueError("Function only converts 4D array to 3D array")
    # Drop just the singleton depth axis. A blanket squeeze() also removed a size-1 channel axis
    # (e.g. single-channel masks), leaving a 2D array that the 3-axis transpose below rejected.
    if input_chan_first and img.shape[3] == 1:
        img3 = img[..., 0]
    elif img.shape[0] == 1:
        img3 = img[0]
    elif img.shape[3] == 1:
        img3 = img[..., 0]
    else:
        # Singleton somewhere else (or none): keep the previous behaviour.
        img3 = img.squeeze()
    return np.transpose(img3, (0,1,2) if output_chan_first else (1,2,0))

def to_3chan(x, dim=1):
    """Repeat tensor `x` three times along `dim` (e.g. turn a 1-channel mask into 3 channels)."""
    return torch.cat((x,) * 3, dim=dim)

def resizer(img, scale=5, show=False):
    """
    Resize an image to `scale` percent of its height and width with ``cv2.resize``.

    4D inputs are converted to 3D with `to_3d` before resizing and back to 4D with `to_4d`
    afterwards. Lower-dimensional inputs are resized directly.

    Args:
        img: torch tensor or numpy array whose first two axes (after any 4D->3D conversion)
            are height and width.
        scale: target size as a percentage of the original (5 -> 5%).
        show: if True, plot the resized image and return the ``plt.imshow`` handle instead.

    Returns:
        numpy array of the resized image (4D again if the input was 4D).
    """
    scale_percent = scale # percent of original size
    was_4d = (len(img.shape) == 4)
    if was_4d:
        img = to_3d(img)
    # Accept numpy arrays as well as tensors (the old `.numpy()` call crashed on arrays), and
    # detach/move tensors to the CPU before handing them to OpenCV.
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    # to_3d returns transposed views; give OpenCV a contiguous buffer.
    img = np.ascontiguousarray(img)
    # cv2.resize rejects a zero-sized target, so never shrink an axis below one pixel.
    width = max(1, int(img.shape[1] * scale_percent / 100))
    height = max(1, int(img.shape[0] * scale_percent / 100))
    img_reshaped = cv2.resize(img, (width, height))
    if show:
        return plt.imshow(img_reshaped)
    if was_4d:
        return to_4d(img_reshaped)
    return img_reshaped

def squeeze_and_reshape(img_tensor, remove_alpha=False):
    """
    Rearrange an image tensor into a 4D layout with a trailing singleton axis.

    A 5D tensor is squeezed and its axes reversed first. A tensor whose first axis has size 3
    has its axes reversed. Finally the three axes are reversed and a singleton axis is appended,
    i.e. (a, b, c) -> (c, b, a, 1).

    `remove_alpha` is accepted but currently unused.

    Raises:
        TypeError: if `img_tensor` is not a torch.Tensor.
    """
    if not isinstance(img_tensor, torch.Tensor):
        raise TypeError("Image needs to be a tensor")
    out = img_tensor
    if out.dim() == 5:
        out = out.squeeze().permute(2, 1, 0)
    if out.shape[0] == 3:
        out = out.permute(2, 1, 0)
    # Same element mapping as unsqueeze(2).permute(3, 1, 0, 2).
    return out.permute(2, 1, 0).unsqueeze(3)

def to_pil(image):
    """
    Display a tensor inline as a PIL image, followed by an empty printed line.

    The tensor is converted to numpy, squeezed, fully transposed (``.T`` reverses every axis)
    and cast to uint8 before display. Returns None.
    """
    pixels = image.numpy().squeeze().T.astype(np.uint8)
    display(Image.fromarray(pixels))
    print()
    
def normalize_array(x):
    """
    Min-max scale numeric data into the range [0, 1].

    Args:
        x: array-like of numbers, any shape.

    Returns:
        numpy float array with the same shape as `x`. A constant input (max == min) returns all
        zeros instead of dividing by zero and producing NaNs.
    """
    arr = np.asarray(x)
    lo, hi = np.min(arr), np.max(arr)
    span = hi - lo
    if span == 0:
        return np.zeros(arr.shape, dtype=float)
    return np.array((arr - lo) / span)

In [None]:
def remask(img, mask, tile, threshold=8, show=False):
    """
    Relabel low-saturation (background) tiles of `mask` with the value 2.

    `img` is split into a grid of `tile` x `tile` squares over its first two axes (edge squares
    may be smaller). Each square is converted with ``cv2.COLOR_BGR2HSV``. If its mean saturation
    is below `threshold`, the matching region of `mask` is set to 2.

    Args:
        img: image array indexed as img[row, col, channel], in a dtype cv2.cvtColor accepts.
        mask: array aligned with `img` on its first two axes. MODIFIED IN PLACE.
        tile: side length of each square, in pixels.
        threshold: mean-saturation cutoff below which a square is treated as background.
        show: if True, plot the image, the overlay and the mask with matplotlib.

    Returns:
        The same `mask` object (also modified in place).
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    # Ceiling division: a partial tile at the bottom/right edge is still visited, and no
    # empty crops are produced when the size is an exact multiple of `tile`.
    row_tiles = -(-n_rows // tile)
    col_tiles = -(-n_cols // tile)

    for r in range(row_tiles):
        row_slice = slice(tile * r, tile * (r + 1))
        for c in range(col_tiles):
            col_slice = slice(tile * c, tile * (c + 1))
            patch = img[row_slice, col_slice, :]
            if 0 in patch.shape:  # defensive: nothing to inspect in an empty crop
                continue
            # Saturation is independent of channel order, so BGR vs RGB input does not matter here.
            saturation = cv2.cvtColor(patch, cv2.COLOR_BGR2HSV)[:, :, 1]
            if saturation.mean() < threshold:
                mask[row_slice, col_slice] = 2

    if show:
        plt.figure(figsize=(10, 10))

        plt.subplot(1, 3, 1)
        plt.imshow(img.astype('uint8'))
        plt.title(f"Image", fontsize=18)

        plt.subplot(1, 3, 2)
        plt.imshow(img.astype('uint8'))
        plt.imshow(mask.astype('uint8'), cmap="hot", alpha=0.5)
        plt.title(f"Image + mask", fontsize=18)

        plt.subplot(1, 3, 3)
        plt.imshow(mask.astype('uint8'), cmap="hot")
        plt.title(f"Mask", fontsize=18)

        plt.show()

    return mask


#img_id = get_id_by_index(7)
#img_id = '095bf7a1f'
#with tifffile.TiffFile(path/"train"/(img_id+".tiff")) as tif:
#    imgg = tif.asarray()
#print(imgg.shape)
#mask = get_mask(img_id)
#new_mask = remask(to_3d(squeeze_and_reshape(imgg)), mask, 1000)

In [None]:
def create_mask_df(df, directory):
    """
    For every row of `df`: load the tiff, relabel background tiles of its mask, and cut it up.

    Reads ``path/directory/<img_id>.tiff``, reshapes it with `squeeze_and_reshape`, marks
    low-saturation 1000x1000 tiles of the mask with 2 via `remask`, prints intermediate
    shapes, then calls `cut_image`.

    NOTE(review): this looks stale and will not run as written:
      * `get_id_by_index`, `get_mask`, `path` and `write_path` are not defined in this notebook
        view (the helpers are commented out above);
      * the `cut_image` call passes (array, id, mask, folder), but `cut_image` takes
        (img_id, source_path, destination_path, mask_df=None);
      * `mask_list` is never filled and nothing is returned, despite the function's name.
    """
    mask_list = []  # NOTE(review): unused
    for idx,_ in tqdm(enumerate(df.iterrows()), total=len(df)):
        img_id = get_id_by_index(idx, df=df)
        with tifffile.TiffFile(path/directory/(img_id+".tiff")) as tif:
            base_im = tif.asarray()
            print(base_im.shape)
            # Rearrange into the 4D layout (trailing singleton axis) that to_3d expects.
            im_tensor = squeeze_and_reshape(torch.from_numpy(base_im)).numpy()
            print(im_tensor.shape)
            # Mark low-saturation 1000x1000 tiles of the mask as background (value 2).
            mask = remask(to_3d(im_tensor), get_mask(img_id), 1000)
            print(f"Mask shape is {mask.shape}")
            
            # NOTE(review): argument order does not match cut_image's signature.
            cut_image(im_tensor, img_id, mask, write_path/"smaller")
            
            # Free the large arrays before the next image is loaded.
            del base_im, im_tensor, mask
            gc.collect()

In [None]:
def cut_image(img_id, source_path:Path, destination_path: Path, mask_df=None):
    """
    Cut an image (and, if `mask_df` is supplied, its mask) into QUARTERS and save each one.

    Every quarter is padded with ``tio.Pad((512, 512, 0))`` and saved as
    ``destination_path/"imgs"/f"{img_id}_{n}.tiff"`` (masks as
    ``destination_path/"masks"/f"{img_id}_{n}_mask.tiff"``) for n = 1..4:
    1 = first half of axis 1 / first half of axis 2, 2 = second / first,
    3 = first / second, 4 = second / second.

    Args:
        img_id: image identifier; the source file is ``source_path/f"{img_id}.tiff"``.
        source_path: folder containing the full-size tiff.
        destination_path: folder that must already contain ``imgs`` (and ``masks``) sub-folders.
        mask_df: optional DataFrame with ``img_id`` and ``mask`` columns; the first row whose
            ``img_id`` matches supplies the mask array.

    Raises:
        ValueError: if the loaded tensor is not 4D or does not have 4 channels.
    """
    img = tio.Image(source_path/f"{img_id}.tiff").data
    if len(img.shape) != 4:
        raise ValueError("Tensor shape needs to have 4 dimensions")
    if img.shape[0] != 4:
        raise ValueError("First dimension must have 4 channels")
    vertical_tiles = img.shape[2] // 2
    horizontal_tiles = img.shape[1] // 2

    # (axis-1 slice, axis-2 slice) for quarters 1..4
    quarters = (
        (slice(None, horizontal_tiles), slice(None, vertical_tiles)),
        (slice(horizontal_tiles, None), slice(None, vertical_tiles)),
        (slice(None, horizontal_tiles), slice(vertical_tiles, None)),
        (slice(horizontal_tiles, None), slice(vertical_tiles, None)),
    )
    pad = tio.Pad((512, 512, 0))

    def save_quarters(tensor, folder, suffix):
        """Pad and save the four quarters of a 4D `tensor` under destination_path/folder."""
        for number, (rows, cols) in enumerate(quarters, start=1):
            quarter = pad(tensor[:, rows, cols, :])
            tio.Image(tensor=quarter).save(destination_path/folder/f"{img_id}_{number}{suffix}.tiff")
            # Free each quarter immediately: the full-size images are very large.
            del quarter
            gc.collect()

    save_quarters(img, "imgs", "")
    del img
    gc.collect()

    if mask_df is not None:
        mask = torch.from_numpy(mask_df[mask_df["img_id"]==img_id]["mask"].values[0]).unsqueeze(0).unsqueeze(3)
        # The stored masks have axes 1 and 2 swapped relative to the images (root cause not fixed),
        # so permute them back before cutting.
        mask = mask.permute(0,2,1,3)
        save_quarters(mask, "masks", "_mask")
        del mask
        gc.collect()
        
# Uncomment if these files dont exist
# [cut_image(item, path/"train", path/"smaller", mask_df=new_masks) for item in new_masks.img_id.tolist()]

In [None]:
def restitch_image(img_id, pred_mask=None):
    """
    Reassemble and display the four quarters that `cut_image` wrote for `img_id`.

    Quarters are read from ``path/"smaller/imgs"/f"{img_id}_{n}.tiff"``. Quarters 1 and 3 hold
    the first half of axis 1 and quarters 2 and 4 the second half; quarters 1 and 2 hold the
    first half of axis 2. The quarters still include the padding `cut_image` added, so the
    displayed image does too. `pred_mask` is currently unused.

    Raises:
        FileNotFoundError: if any of the four quarter files is missing.
    """
    quarter_dir = path/"smaller/imgs"
    quarters = {}
    for name in quarter_dir.glob(f"{img_id}_?.tiff"):
        # "<img_id>_<n>.tiff" -> "<n>"
        quarters[name.stem.rsplit("_", 1)[1]] = tio.Image(quarter_dir/name.name).data

    missing = [q for q in "1234" if q not in quarters]
    if missing:
        raise FileNotFoundError(f"Missing quarter(s) {missing} for {img_id} in {quarter_dir}")

    # Concatenate instead of writing into a pre-sized canvas: the old canvas added the wrong
    # quarters' sizes per axis, and its `//2 - 1` offsets broke assignment for even-sized images.
    top = torch.cat((quarters["1"], quarters["3"]), dim=2)
    bottom = torch.cat((quarters["2"], quarters["4"]), dim=2)
    whole_image = torch.cat((top, bottom), dim=1)

    to_pil(whole_image)

In [None]:
#restitch_image(get_id_by_index(1))

## Refinement

We are using a baseline model that looked to have good accuracy but now we want to refine it. We will do the following:

1. Update the colour normalisation
2. Add blur and artifacts periodically - also update the frequency with which these transformations take place
3. Update the loss function
4. Patch overlap for TorchIO [ONLY DONE IN INFERENCE]
5. Try something like TransUNET
6. Try training on smaller patch sizes

## TorchIO

### Subject Generation

In [None]:
# Collect every image file in the "images*" folders as (folder, file name) pairs.
# File names are identical in each images/masks folder pair, so a pair also locates the mask.

img_dirs = [d.name for d in data_path.ls() if "images" in d.name]

img_names = [
    (im.parent.name, im.name)
    for folder in img_dirs
    for im in (data_path/folder).ls()
]
        

In [None]:
# label all masks with glom

# holder = []
# for folder in master_bar(img_dirs):
#     for im in progress_bar((data_path/folder).ls()):
#         mask = tio.LabelMap(data_path/(im.parent.name).replace("images", "masks")/im.name)
#         holder.append((im.parent.name, im.name, 1 in mask.data))
        
# has_glom = pd.DataFrame(holder, columns=["dir", "name", "has_glom"])

# Cached per-tile flags (columns dir, name, has_glom), presumably saved from the commented-out loop above.
has_glom = pd.read_csv("./has_glom.csv")

In [None]:
has_glom["has_glom"].value_counts()

In [None]:
has_glom[has_glom["name"] == '0486052bb_0080-256.png']["has_glom"].values[0]

In [None]:
# Load a random image/mask pair to inspect shapes and orientation.
random_folder = random.choice(img_dirs)
random_image = random.choice((data_path/random_folder).ls()).name
rand_im = tio.ScalarImage(data_path/random_folder/random_image)
# The matching mask lives in the "masks*" folder with the same file name.
rand_mask = tio.LabelMap(data_path/(random_folder).replace("images", "masks")/random_image)

In [None]:
# Shape after moving the last axis first (the same permutation as `reshuffle`).
rand_im.data.permute(3,1,2,0).shape

In [None]:
# File name of the first collected image.
img_names[0][1]

In [None]:
def subject_creator(affine=None):
    """
    Build one `tio.Subject` per entry of `img_names`, subsampling tiles without glomeruli.

    Tiles whose `has_glom` flag is False are kept with probability 2/7
    (``random.choices`` with weights (2, 5)), so the result depends on the state of the
    `random` module. Tiles with glomeruli are always kept.

    Args:
        affine: 4x4 affine passed to every mask `tio.LabelMap`. Defaults to
            diag(-1, -1, 1, 1), built fresh on each call rather than shared as a
            default-argument tensor.

    Returns:
        list of `tio.Subject` with keys `img`, `mask` and `img_id` (the file name).

    Raises:
        KeyError: if a file name is missing from `has_glom`.
    """
    if affine is None:
        affine = torch.tensor([[-1.,  0.,  0.,  0.], [ 0., -1.,  0.,  0.], [ 0.,  0.,  1.,  0.], [ 0.,  0.,  0.,  1.]])

    # Build the name -> flag lookup once instead of filtering the whole DataFrame for every
    # image (quadratic). setdefault keeps the FIRST row per name, matching `.values[0]`.
    glom_flags = {}
    for name, flag in zip(has_glom["name"], has_glom["has_glom"]):
        glom_flags.setdefault(name, flag)

    subjects_list = []
    for folder, pic_name in img_names:
        # Only tiles without glomeruli draw a random number, exactly as before.
        if not glom_flags[pic_name] and not random.choices([True, False], weights=(2,5))[0]:
            continue
        subjects_list.append(tio.Subject(
            img = tio.ScalarImage(data_path/folder/pic_name),
            mask = tio.LabelMap(data_path/folder.replace("images", "masks")/pic_name, affine=affine),
            img_id = pic_name,
        ))
    return subjects_list

In [None]:
# def subject_creator(df, affine = torch.tensor([[-1.,  0.,  0.,  0.], [ 0., -1.,  0.,  0.], [ 0.,  0.,  1.,  0.], [ 0.,  0.,  0.,  1.]])):
#     subjects_list = []
#     for idx,_ in tqdm(enumerate(df.iterrows()), total=len(df)):
        
#         img_id = get_id_by_index(idx, df=df)
        
#         pic_list = [item for item in (path/"smaller/imgs").rglob("*") if not item.is_dir() and img_id in item.name]
        
#         for pic in pic_list:
#             pic_name = pic.name.split(".")[0]
#             im = tio.ScalarImage(path/"smaller/imgs"/(pic_name+".tiff"))
#             mask = tio.LabelMap(path/"smaller/masks"/(pic_name+"_mask.tiff"), affine=affine)
            
#             ########CRITICAL TO NOTE###########
#             # below is a check to see if any positives actually exist within the image quarters
#             # if we do not run this step, we get an error when we train only on patches with positive values
            
#             if 1 in mask.data:
#                 subjects_list.append(tio.Subject(
#                     img = im,
#                     mask = mask,
#                     img_id = pic_name
#                 ))
#             else:
#                 continue
#     return subjects_list

### Custom Transforms

In [None]:
# Quick sanity check: build a dataset, resample the first image to spacing (4, 4, 1) and plot it.
# NOTE(review): this cell looks stale. `subject_creator` now takes an `affine` argument, so
# `new_masks` would be passed as the affine; `new_masks` is not defined in this notebook; and
# `custom_reshape` / `custom_normalization` are only defined in later cells. It only runs when TEST is True.
if TEST:
    test_items = subject_creator(new_masks)
    transforms = tio.Compose([custom_reshape, custom_normalization])
    test_dataset = tio.SubjectsDataset(test_items, transform=transforms)
    
    test_img = test_dataset[0]
    
    downsized_img = tio.Resample((4,4,1))(test_img["img"][tio.DATA])
    
    downsized_img.shape
    
    plt.imshow(downsized_img.squeeze().permute(1,2,0))

In [None]:
#show_single_img_and_mask(test_img)

In [None]:
# Divide intensity images by 255 and cast to float (assumes 8-bit pixel values); label maps are untouched.
custom_normalization = tio.Lambda(lambda x: (x/255).float(), types_to_apply=[tio.INTENSITY])

In [None]:
def reshuffle(x):
    """
    Reorder a 4D tensor's axes as (3, 1, 2, 0), i.e. swap the first and last axes.

    NOTE: a later cell rebinds the name `reshuffle` to a `tio.Lambda`.
    """
    axis_order = (3, 1, 2, 0)
    return x.permute(*axis_order)

In [None]:
# custom_reshape = tio.Lambda(lambda x: x[:3,...], types_to_apply=[tio.INTENSITY])
# Drop every singleton axis of intensity images, then add a single leading axis back.
custom_reshape = tio.Lambda(lambda x: x.squeeze().unsqueeze(0), types_to_apply=[tio.INTENSITY])

In [None]:
# Repeat label maps three times along axis 0 (1 channel -> 3 channels).
custom_to3d = tio.Lambda(lambda x: to_3chan(x, 0), types_to_apply=[tio.LABEL])

In [None]:
# Same as `custom_reshape`, but for label maps: squeeze all singleton axes, then add one leading axis.
custom_squeeze = tio.Lambda(lambda x: x.squeeze().unsqueeze(0), types_to_apply=[tio.LABEL])

In [None]:
# Workaround: label maps come out with axes 1 and 2 swapped relative to the images. The root cause
# has not been investigated yet; this permute just gets model building unblocked.
def shuffle_axes(img_tensor):
    """Swap axes 1 and 2 of a 4D tensor (permutation (0, 2, 1, 3))."""
    return img_tensor.permute(0,2,1,3)
# NOTE: this rebinds `reshuffle`, shadowing the plain function of the same name from an earlier cell.
reshuffle = tio.Lambda(shuffle_axes, types_to_apply=[tio.LABEL])

In [None]:
#custom_shrink = tio.Lambda(lambda x: torch.tensor(resizer(x, 15)))

# Resample intensity images to spacing (2, 2, 1), i.e. a 2x in-plane downsample when the original spacing is 1.
resample_2x = tio.Lambda(lambda x: tio.Resample((2,2,1))(x), types_to_apply=[tio.INTENSITY])

## Training

### Setup

In [None]:
# patch_size = (512, 512, 1)
# patch_size = (224, 224, 1)
patch_size = (256, 256, 1)
# LabelSampler probability per label value: 0 = background, 1 = glomerulus,
# 2 = presumably the low-saturation background label written by `remask` (confirm).
sample_ratio = {0: 1, 1: 10, 2: 1}

# Seed BEFORE building the subjects: subject_creator randomly drops glomerulus-free tiles with
# `random.choices`. Seeding only afterwards (as before) produced a different subject list, and
# therefore a different train/valid split, on every run.
random.seed(22222)
#random.seed(57)

subjects_list = subject_creator()
subjects_list_copy = subjects_list[:]     # needed because shuffle does in place
random.shuffle(subjects_list_copy)

# 80/20 train/validation split
split_idx = round(len(subjects_list_copy)*0.8)
train_subjects = subjects_list_copy[:split_idx]
valid_subjects = subjects_list_copy[split_idx:]
#train_subjects = subjects_list_copy[:1]
#valid_subjects = subjects_list_copy[1:2]

#train_transforms = tio.Compose([tio.Resample((20,20,1)), custom_reshape, tio.RandomFlip(), tio.RandomAffine(), custom_normalization,])
#valid_transforms = tio.Compose([tio.Resample((20,20,1)), custom_reshape, custom_normalization,])
train_transforms = tio.Compose([
#     custom_reshape, 
#     custom_to3d, 
#     tio.RescaleIntensity(percentiles=(0.5, 99.5)), 
    # Apply exactly one of these augmentations per subject.
    tio.OneOf([
        tio.RandomFlip(),
        tio.RandomAffine(),
        tio.RandomNoise(
            mean=(0.5,0.5),
        ),
        tio.RandomBlur(),
        tio.RandomBiasField(),
        tio.RandomSwap(
            patch_size=(15,15,1)
        )
    ]),

    custom_normalization, 
])
valid_transforms = tio.Compose([
#     tio.RescaleIntensity(percentiles=(0.5, 99.5)),
#     custom_reshape, 
#     custom_to3d, 
    custom_normalization
])

train_dataset = tio.SubjectsDataset(train_subjects, transform=train_transforms)
valid_dataset = tio.SubjectsDataset(valid_subjects, transform=valid_transforms)

queue_length = 40
samples_per_volume = 4

sampler = tio.data.LabelSampler(patch_size, label_probabilities=sample_ratio)

train_queue = tio.Queue(
    train_dataset,
    queue_length,
    samples_per_volume,
    sampler,
    num_workers=0,
    shuffle_subjects=True,
    shuffle_patches=True,
)

valid_queue = tio.Queue(
    valid_dataset,
    queue_length,
    samples_per_volume,
    sampler,
    num_workers=0,
    shuffle_subjects=False,
    shuffle_patches=False,
)

# train_loader = DataLoader(train_queue, batch_size=16)
# valid_loader = DataLoader(valid_queue, batch_size=16)
# valid_loader = DataLoader(valid_queue, batch_size=16)

In [None]:
#train_img_ids = [i["img_id"] for i in train_subjects]
# File names (img_id) of the validation subjects.
valid_img_ids = [i["img_id"] for i in valid_subjects]

In [None]:
#plt.imshow(train_dataset[0]["mask"][tio.DATA].squeeze().unsqueeze(2), aspect='auto')
# test_mask = tifffile.imread(path/"smaller/masks/cb2d976f4_2_mask.tiff")
# plt.imshow(test_mask, aspect='auto')

In [None]:
# 1 in valid_dataset[-1]["mask"][tio.DATA].squeeze().unsqueeze(2)

In [None]:
def batch_creator(subjects_list):
    """
    Collate a list of TorchIO patch subjects into an ``(x, y)`` batch for fastai.

    Each subject's ``img`` / ``mask`` data is permuted with ``(3, 0, 1, 2)``, moving the last
    (depth) axis first, and the results are concatenated along that axis. Assuming depth 1
    (as with patch_size ``(..., 1)``), x has shape (B, C, *spatial) and y (B, 1, *spatial).

    Concatenating, rather than ``stack(...).squeeze()``, keeps the batch axis when the last batch
    holds a single patch, and keeps the channel axis for single-channel images.
    """
    x = torch.cat([img["img"][tio.DATA].permute(3,0,1,2) for img in subjects_list], 0)
    y = torch.cat([mask["mask"][tio.DATA].permute(3,0,1,2) for mask in subjects_list], 0)
    return (x, y)

# The loaders that read patches from a tio.Queue must use num_workers=0. The queue does its own
# subject loading (its `num_workers` argument); with several loader workers, each worker process
# would fill its own copy of the queue. To load subjects in parallel, raise `num_workers` on
# the tio.Queue objects instead.
dls = DataLoaders(
    TfmdDL(
        train_queue, 
        batch_size=8, 
        num_workers=0,
        #chunkify=lambda x: print(str(x)),
        # returns generator of indices (provided by sample attribute), length is provided by queue sample length
        #create_batches=lambda x: print(x),
        # passed a list of length batchsize and collates into a batch
        #create_batch=lambda x: print(x[1]["img"][tio.DATA].shape),
        create_batch=batch_creator,
        after_batch=[
            Normalize.from_stats(*imagenet_stats),
            Hue(p=0.1),
            Saturation(p=0.1),
            Brightness(p=0.1)
        ],
    ),
    TfmdDL(
        valid_queue, 
        batch_size=8, 
        num_workers=0, 
        create_batch=batch_creator,
        after_batch=[Normalize.from_stats(*imagenet_stats)],
    ),
).cuda()

In [None]:
#plt.imshow(to_3d(dls.train_ds[0]["img"][tio.DATA]))
dls.train_ds[50000]["img"][tio.DATA].shape

In [None]:
# Shape of the targets in one collated training batch.
dls.one_batch()[1].shape

In [None]:
#prebatched=False

#def create_batch(b): return (fa_collate,fa_convert)[prebatched](b)
#create_batch()

Custom function to enable gradients on `torch.where`

In [None]:
class ZeroOrOneFunction(torch.autograd.Function):
    """
    Threshold at 0.5 in the forward pass while passing gradients through unchanged.

    This is a straight-through estimator. The hard threshold has zero gradient almost
    everywhere, so backward treats it as the identity.
    """
    @staticmethod
    def forward(ctx, input):
        # Build the 0/1 values on the input's own device/dtype instead of hard-coding `.cuda()`.
        return torch.where(input > 0.5, torch.ones_like(input), torch.zeros_like(input))

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: return the upstream gradient. The previous version returned the
        # saved forward output (a tuple of 0/1 tensors), which ignored grad_output entirely.
        return grad_output

zero_or_one = ZeroOrOneFunction.apply

In [None]:
# fastai callbacks

class PrinterCallback(Callback):
    """
    Every `img_freq` batches, save a grid of (input, target, prediction) triples for the batch.

    Targets and predictions are repeated to 3 channels so they tile next to the inputs, one
    sample per row. The grid is min-max normalised and written to `path` as
    ``epoch{epoch}batch{iter}__{loss}.png``.
    """
    def __init__(self, path, img_freq=105):
        self.img_freq = img_freq
        self.path = path

    def after_batch(self):
        if self.iter % self.img_freq != 0:
            return
        with torch.no_grad():
            tiles = []
            for sample in range(self.pred.shape[0]):
                tiles.extend((
                    self.x[sample, ...],
                    to_3chan(self.y[sample, ...], 0),
                    to_3chan(self.pred[sample, ...], 0),
                ))
            grid = torchvision.utils.make_grid(tiles, nrow=3)
            self._save(self.path, grid, self.epoch, self.iter, round(self.loss.item(), 3))

    @staticmethod
    def _save(img_path, img, epoch, batch, loss):
        """Min-max scale a (C, H, W) tensor and save it channel-last as a PNG under `img_path`."""
        pixels = normalize_array(img.cpu().detach().float().numpy())
        filename = f"epoch{epoch}batch{batch}__{loss}.png"
        plt.imsave(img_path/filename, pixels.transpose(1, 2, 0))
        
class ConvertY(Callback):
    """
    Binarize the targets before each batch: pixels equal to 1 become 1, everything else 0.

    The TorchIO-sampled masks can carry labels other than 1 (such as the 2 that `remask`
    writes for background tiles), so they are folded into 0 here.
    """
    def before_batch(self):
        """
        NOTE: per the fastai docs you can only assign to `yb`, not `y`. Since `yb` is an
        (immutable) tuple, the whole `self.learn.yb` is replaced.
        """
        # The comparison keeps the targets on whatever device they already live on; the
        # previous hard-coded `.cuda()` constants failed on CPU and other GPUs.
        self.learn.yb = ((self.y == 1).long(),)
        

        
## NOTE: this may not be needed anymore now that our loss function combines this step
class AddSigmoidActivation(Callback):
    """
    Apply a sigmoid to the model output before it reaches the loss function.

    Needed because:
        a) the pretrained ResNet backbone doesn't support adding a final activation layer, and
        b) unlike `cnn_learner`, `unet_learner` has no `custom_head` parameter (which the forums
           suggest as a way to add a layer to a pretrained model).
    Open question: would `learner.model[-1].add_module` with an `nn.Module` subclass whose
    `forward()` applies the activation work instead?
    """
    def after_pred(self):
        """
        The `after_pred` hook runs after the forward pass and before the loss is computed,
        which makes it the place to post-process the predictions.
        """
        self.learn.pred = torch.sigmoid(self.pred)
        
class ProgressiveTransformsUpdateCallback(Callback):
    """
    Progressive resizing via the TorchIO dataset transforms, updated at the start of each epoch.

    Epochs 0-3 resample to spacing (4, 4, 1). Epochs 4-7 resample to (2, 2, 1) and set every
    optimizer param group's lr to 1e-5. Epoch 8 onwards uses full resolution (no Resample).
    Training additionally gets random flips and affine transforms.
    """
    def before_epoch(self):
        if self.epoch < 4:
            spacing = (4, 4, 1)
        elif self.epoch < 8:
            spacing = (2, 2, 1)
        else:
            spacing = None

        def resampling():
            # A fresh Resample instance for each pipeline (empty at full resolution).
            return [] if spacing is None else [tio.Resample(spacing)]

        self.dls.train.subjects_dataset.set_transform(
            tio.Compose([*resampling(), custom_reshape, tio.RandomFlip(), tio.RandomAffine(), custom_normalization])
        )
        self.dls.valid.subjects_dataset.set_transform(
            tio.Compose([*resampling(), custom_reshape, custom_normalization])
        )
        if spacing == (2, 2, 1):
            for group in self.opt.hypers:
                group["lr"] = 0.00001
        
class ModifyTransformsCallback(Callback):
    """
    Install fixed torchio pipelines on the train and valid datasets before every epoch.

    Despite their names, `train_callbacks` and `valid_callbacks` are transform pipelines
    (each a list of transforms wrapped in `tio.Compose`). They are passed unchanged to the
    corresponding dataset's `subjects_dataset.set_transform`.
    """
    def __init__(self, train_callbacks, valid_callbacks):
        self.train_callbacks, self.valid_callbacks = train_callbacks, valid_callbacks

    def before_epoch(self):
        # Train first, then valid, matching the stored pipeline for each split.
        for split in ("train", "valid"):
            loader = getattr(self.dls, split)
            pipeline = getattr(self, f"{split}_callbacks")
            loader.subjects_dataset.set_transform(pipeline)
        
class UpsamplePredCallback(Callback):
    """
    Upsample predictions back to the resolution the masks (y) expect.

    If the x-values were downsampled (e.g. with `tio.Resample`), the predictions come out
    at that reduced resolution. This callback rescales `pred` in `after_pred`, i.e. before
    the loss is computed.

    Args:
        upscale: spatial scale factor forwarded to `F.interpolate` (a number, or one
            factor per spatial dimension of `pred`). `None` or `1` leaves `pred` untouched.
    """
    def __init__(self, upscale):
        self.upscale = upscale

    def after_pred(self):
        # Previously this hook was a bare `pass`, so the callback silently did nothing.
        if self.upscale is None or self.upscale == 1:
            return
        pred = self.pred
        # NOTE(review): assumes pred is (batch, channels, *spatial) float - confirm against
        # the model output before adding this callback to `cbs`.
        modes = {3: "linear", 4: "bilinear", 5: "trilinear"}
        if pred.dim() not in modes:
            raise ValueError(f"Cannot upsample a prediction with {pred.dim()} dimensions")
        self.learn.pred = F.interpolate(
            pred, scale_factor=self.upscale, mode=modes[pred.dim()], align_corners=False
        )
    
    
class SwitchLossCallback(Callback):
    """
    Swap the learner's loss for `hh_dtloss` once the batch counter passes round(n_iter / 2).

    NOTE(review): in fastai `iter`/`n_iter` count batches of the current dataloader, and
    nothing switches the loss back. In practice the swap therefore happens midway through
    the first epoch and is permanent - confirm that is the intent.
    """
    def after_pred(self):
        midpoint = round(self.n_iter / 2)
        if self.iter <= midpoint:
            return
        self.learn.loss_func = hh_dtloss
            
class UpdateContourLoss(Callback):
    """
    Periodically shift AdaptiveLoss's weight from the dice term toward the contour term.

    Every `update_every` training batches (using the learner's `iter` counter, skipping
    iter 0), the learner's loss is replaced by a new `AdaptiveLoss` whose `lambda_coeff`
    is `lambda_step` larger, capped at `max_lambda`.

    Args:
        update_every: batch interval between updates (default 1000, as before).
        lambda_step: increment added to `lambda_coeff` per update (default 0.01).
        max_lambda: upper bound for `lambda_coeff`. AdaptiveLoss weights the dice term by
            `1 - lambda_coeff`, so values above 1 would make that weight negative.
    """
    def __init__(self, update_every=1000, lambda_step=0.01, max_lambda=1.0):
        super().__init__()
        self.update_every = update_every
        self.lambda_step = lambda_step
        self.max_lambda = max_lambda

    def after_pred(self):
        # `after_pred` also fires on validation batches. The loss weighting is a training
        # hyperparameter, so only advance it while training.
        if not self.training:
            return
        if self.iter == 0 or self.iter % self.update_every != 0:
            return
        prior_lambda_coeff = self.learn.loss_func.lambda_coeff
        new_lambda_coeff = min(prior_lambda_coeff + self.lambda_step, self.max_lambda)
        if new_lambda_coeff == prior_lambda_coeff:
            return  # already at the cap
        self.learn.loss_func = AdaptiveLoss(new_lambda_coeff)
        print(f"Lambda coefficient updated to {new_lambda_coeff}")
        

In [None]:
#learner = unet_learner(dls, resnet34, n_out=1, loss_func=dice_loss, lr=0.00001)

In [None]:
#learner.lr_find()

Note that none of our callbacks appear here because we pass them to `fit` rather than attaching them to the `learner`: for the most part we want callbacks active during training only. It may still be necessary to add callbacks at the `learner` level later on.

In [None]:
#learner.show_training_loop()

In [None]:
#learner.dls.dataset.subjects_dataset.set_transform
#new_train_transforms = tio.Compose([custom_reshape, tio.RandomFlip(), tio.RandomAffine(), custom_normalization,])
#new_valid_transforms = tio.Compose([custom_reshape, custom_normalization,])
#learner.dls.train.subjects_dataset.set_transform(new_train_transforms)
#learner.dls.valid.subjects_dataset.set_transform(new_valid_transforms)

## Experiments

### Baseline

In [None]:
# Run metadata for experiment tracking (wandb tags/group/notes/run name).
tags = ["resnet", "transforms", "custom_loss", "reduceLR", "150_samples"]
group, notes, name = "resnet", "Perimeter loss", "perimeter_loss"

In [None]:
# Hyperparameters for this run; `epochs` and `lr` are read by `learner.fit` below,
# and the dict would also be passed to the (commented-out) `wandb.init`.
config = dict(
    epochs=20,
    transforms="all",
    loss="perimeter_loss",
    lr=1e-5,
    model="resnet34",
    train_type="fit",
)

In [None]:
# wandb.init(project='HuBMAP_model_experiments', entity='stantonius', name=name, tags=tags, group=group, notes=notes, config=config)

In [None]:
# --- Callbacks passed to `learner.fit` -------------------------------------------
# Weights & Biases logging is disabled; uncomment to re-enable.
# os.environ['WANDB_MODE'] = 'dryrun'
# wandb_callback = WandbCallback(log='all')
path = Path()

# Checkpoint name: UTC date formatted as DDMMMYY + "_" + experiment name.
model_name = f"{arrow.utcnow().format('DDMMMYY')}_{name}"
save_model_callback = SaveModelCallback(fname=model_name, every_epoch=True)
save_image_path = path / "training_image_logs"

# Pipelines resampled to spacing (4, 4, 1); the train pipeline also flips and applies
# a random affine.
updated_train_transforms = tio.Compose(
    [
        tio.Resample((4, 4, 1)),
        custom_reshape,
        tio.RandomFlip(),
        tio.RandomAffine(),
        custom_normalization,
    ]
)
updated_valid_transforms = tio.Compose(
    [tio.Resample((4, 4, 1)), custom_reshape, custom_normalization]
)
# NOTE(review): `updated_transforms` is built but not included in `cbs` below, so these
# pipelines are not installed during `fit` - confirm this is intended.
updated_transforms = ModifyTransformsCallback(updated_train_transforms, updated_valid_transforms)
reduce_lr = ReduceLROnPlateau(min_delta=0.2, patience=5)

cbs = [
    PrinterCallback(save_image_path),
    ConvertY(),
    AddSigmoidActivation(),
    save_model_callback,
    UpdateContourLoss(),
    reduce_lr,
]


#### Loss Functions

In [None]:
class DiceLoss(nn.Module):
    """
    Soft Dice loss, ``1 - dice``, computed over all elements of the batch at once.

    Expects probabilities, so apply a sigmoid first if the model emits logits. Both
    tensors are flattened, so any pair of shapes with the same number of elements works.

    Args:
        weight, size_average: accepted for signature compatibility; unused.
    """
    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - (2|X∩Y| + smooth) / (|X| + |Y| + smooth)`` as a 0-dim tensor."""
        # `reshape` instead of `view` so non-contiguous tensors (permuted, transposed,
        # sliced) are accepted; `view(-1)` raises on them.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice

dice_loss = DiceLoss()
    
class DiceBCELoss(nn.Module):
    """
    Binary cross-entropy plus soft Dice loss, both over all elements of the batch.

    Expects probabilities, not logits: `F.binary_cross_entropy` is used, so a sigmoid must
    already have been applied to `inputs`.

    Args:
        weight, size_average: accepted for signature compatibility; unused.
    """
    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``BCE(inputs, targets) + (1 - dice)`` as a 0-dim tensor."""
        # `reshape` (not `view`) also accepts non-contiguous tensors.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        # Named `dice_term` so it does not shadow the module-level `dice_loss` instance.
        dice_term = 1 - (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        # Plain (non-logits) BCE because the sigmoid is applied upstream; BCE needs float
        # operands, so integer masks are cast.
        bce_term = F.binary_cross_entropy(inputs.float(), targets.float(), reduction="mean")
        return bce_term + dice_term

bce_dice_loss = DiceBCELoss()




In [None]:

# `scipy.ndimage.morphology` is a deprecated alias namespace; import from `scipy.ndimage`.
from scipy.ndimage import distance_transform_edt as edt
from scipy.ndimage import convolve

# Hausdorff loss implementation based on paper:
# https://arxiv.org/pdf/1904.10030.pdf


class HausdorffDTLoss(nn.Module):
    """
    Binary Hausdorff loss based on distance transform.

    loss = mean((pred - target) ** 2 * (dt(pred) ** alpha + dt(target) ** alpha)), where
    dt(.) is the per-sample sum of foreground and background Euclidean distance
    transforms of the mask thresholded at 0.5.

    Args:
        alpha: exponent applied to the distance fields (default 2.0).
        **kwargs: accepted and ignored.
    """

    def __init__(self, alpha=2.0, **kwargs):
        super().__init__()
        self.alpha = alpha

    @torch.no_grad()
    def distance_field(self, img: np.ndarray) -> np.ndarray:
        """
        Return, per sample, edt(foreground) + edt(background) of `img > 0.5`.

        Samples without any foreground get an all-zero field. The result has `img`'s shape
        and a floating dtype (at least float32).
        """
        # Not `zeros_like(img)`: for integer masks (uint8/int64) that would truncate the
        # fractional distances written below, or wrap them for uint8.
        field = np.zeros(img.shape, dtype=np.result_type(img.dtype, np.float32))

        for batch, sample in enumerate(img):
            fg_mask = sample > 0.5
            if not fg_mask.any():
                continue
            field[batch] = edt(fg_mask) + edt(~fg_mask)

        return field

    def forward(
        self, pred: torch.Tensor, target: torch.Tensor, debug=False
    ) -> torch.Tensor:
        """
        Uses one binary channel: 1 - fg, 0 - bg
        pred: (b, 1, x, y, z) or (b, 1, x, y)
        target: (b, 1, x, y, z) or (b, 1, x, y)

        Returns the scalar loss on `pred`'s device. With `debug=True`, returns
        `(loss, (dt_field, pred_error, distance, pred_dt, target_dt))` as NumPy arrays,
        with each map taken at sample 0, channel 0.
        """
        assert pred.dim() == 4 or pred.dim() == 5, "Only 2D and 3D supported"
        assert (
            pred.dim() == target.dim()
        ), "Prediction and target need to be of same dimension"

        # Distance fields are computed in NumPy on detached copies; they act as fixed
        # per-pixel weights, so gradients flow only through `pred_error`.
        pred_dt = torch.from_numpy(self.distance_field(pred.detach().cpu().numpy())).float()
        target_dt = torch.from_numpy(self.distance_field(target.detach().cpu().numpy())).float()

        # Move the (small) weight map to pred's device rather than the error map to CPU.
        distance = (pred_dt ** self.alpha + target_dt ** self.alpha).to(pred.device)
        pred_error = (pred - target) ** 2

        dt_field = pred_error * distance
        loss = dt_field.mean()

        if debug:
            # `.detach()` is required: `.numpy()` raises on tensors that require grad.
            return (
                loss.detach().cpu().numpy(),
                (
                    dt_field.detach().cpu().numpy()[0, 0],
                    pred_error.detach().cpu().numpy()[0, 0],
                    distance.detach().cpu().numpy()[0, 0],
                    pred_dt.cpu().numpy()[0, 0],
                    target_dt.cpu().numpy()[0, 0],
                ),
            )

        return loss


class HausdorffERLoss(nn.Module):
    """
    Binary Hausdorff loss based on morphological erosion.

    The squared error between prediction and target is repeatedly soft-eroded (convolution
    with a normalized cross kernel, threshold at 0.5, min-max rescale). Erosion k (1-based)
    is added to the loss with weight k ** alpha, so error regions that survive many
    erosions dominate.

    NOTE(review): erosion runs in NumPy on detached copies of `pred` and `target`, so the
    returned loss carries no gradient with respect to `pred`. It is usable as a metric, but
    backpropagating it will not update the model.

    Args:
        alpha: exponent for the per-erosion weight (default 2.0).
        erosions: number of erosion steps (default 10).
        **kwargs: accepted and ignored.
    """

    def __init__(self, alpha=2.0, erosions=10, **kwargs):
        super().__init__()
        self.alpha = alpha
        self.erosions = erosions
        self.prepare_kernels()

    def prepare_kernels(self):
        """Build normalized cross kernels with a leading size-1 (channel) axis."""
        # 3x3 cross, identical to cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)).
        cross = np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0]]])
        bound = np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])

        # Shape (1, 3, 3); the 5 ones are weighted so the kernel sums to 1.
        self.kernel2D = cross * 0.2
        # 6-connected 3D cross with shape (1, 3, 3, 3): the centre voxel on the outer
        # slices, the full cross on the middle slice, weights summing to 1. The former
        # np.array([bound, cross, bound]) had shape (3, 1, 3, 3). That put the 3-wide
        # axis on the channel dimension and never eroded along x.
        self.kernel3D = np.concatenate([bound, cross, bound])[np.newaxis] * (1 / 7)

    @torch.no_grad()
    def perform_erosion(
        self, pred: np.ndarray, target: np.ndarray, debug
    ) -> np.ndarray:
        """
        Accumulate weighted soft erosions of the squared error map.

        Args:
            pred, target: arrays of shape (b, 1, x, y) or (b, 1, x, y, z).
            debug: if truthy, also collect per-step snapshots (channel 0 of each sample).

        Returns:
            `eroted` (same shape as the inputs), or `(eroted, erosions)` in debug mode.

        Raises:
            ValueError: if the inputs are neither 4-D nor 5-D.
        """
        diff = pred - target
        # Floating working buffer: erosion values are fractional.
        bound = (diff ** 2).astype(np.result_type(diff.dtype, np.float32), copy=False)

        if bound.ndim == 5:
            kernel = self.kernel3D
        elif bound.ndim == 4:
            kernel = self.kernel2D
        else:
            raise ValueError(f"Dimension {bound.ndim} is not supported.")

        eroted = np.zeros_like(bound)
        erosions = []

        for batch in range(len(bound)):
            if debug:
                erosions.append(np.copy(bound[batch][0]))

            for k in range(self.erosions):
                # compute convolution with kernel
                dilation = convolve(bound[batch], kernel, mode="constant", cval=0.0)

                # apply soft thresholding at 0.5 and normalize
                erosion = dilation - 0.5
                erosion[erosion < 0] = 0

                # np.ptp: the ndarray.ptp method was removed in NumPy 2.0.
                spread = np.ptp(erosion)
                if spread != 0:
                    erosion = (erosion - erosion.min()) / spread

                # save erosion and add to loss
                bound[batch] = erosion
                eroted[batch] += erosion * (k + 1) ** self.alpha

                if debug:
                    erosions.append(np.copy(erosion[0]))

        if debug:
            return eroted, erosions
        return eroted

    def forward(
        self, pred: torch.Tensor, target: torch.Tensor, debug=False
    ) -> torch.Tensor:
        """
        Uses one binary channel: 1 - fg, 0 - bg
        pred: (b, 1, x, y, z) or (b, 1, x, y)
        target: (b, 1, x, y, z) or (b, 1, x, y)

        Returns the scalar loss as a CPU float tensor. With `debug=True`, returns
        `(mean as a NumPy scalar, list of erosion snapshots)`.
        """
        assert pred.dim() == 4 or pred.dim() == 5, "Only 2D and 3D supported"
        assert (
            pred.dim() == target.dim()
        ), "Prediction and target need to be of same dimension"

        # NumPy cannot consume CUDA tensors or tensors that require grad, so always hand
        # over detached CPU copies. The non-debug path used to pass raw tensors.
        pred_np = pred.detach().cpu().numpy()
        target_np = target.detach().cpu().numpy()

        if debug:
            eroted, erosions = self.perform_erosion(pred_np, target_np, debug)
            return eroted.mean(), erosions

        eroted = torch.from_numpy(self.perform_erosion(pred_np, target_np, debug)).float()
        return eroted.mean()
        
# Shared loss instances with default hyperparameters (alpha=2.0; 10 erosions for ER).
hh_erloss, hh_dtloss = HausdorffERLoss(), HausdorffDTLoss()

In [None]:
# from torch.autograd import Variable
# from torch.autograd.function import Function

# torch.backends.cudnn.deterministic = True

# def odd_flip(H):
#     '''
#     generate frequency map
#     when height or width of image is odd number,
#     creat a array concol [0,1,...,int(H/2)+1,int(H/2),...,0]
#     len(concol) = H
#     '''
#     m = int(H/2)
#     col = np.arange(0,m+1)
#     flipcol = col[m-1::-1]
#     concol = np.concatenate((col,flipcol),0)
#     return concol

# def even_flip(H):
#     '''
#     generate frequency map
#     when height or width of image is even number,
#     creat a array concol [0,1,...,int(H/2),int(H/2),...,0]
#     len(concol) = H
#     '''
#     m = int(H/2)
#     col = np.arange(0,m)
#     flipcol = col[m::-1]
#     concol = np.concatenate((col,flipcol),0)
#     return concol

# def dist(target):
#     '''
#     sqrt(m^2 + n^2) in eq(8)
#     '''

#     _,_,H,W = target.shape

#     if H%2 ==1:
#         concol = odd_flip(H)
#     else:
#         concol = even_flip(H)
        
#     if W%2 == 1:
#         conrow = odd_flip(W)
#     else:
#         conrow = even_flip(W)
        
#     m_col = concol[:,np.newaxis] 
#     m_row = conrow[np.newaxis,:]
#     dist = np.sqrt(m_col*m_col + m_row*m_row) # sqrt(m^2+n^2)
  
#     use_cuda = torch.cuda.is_available()
#     if use_cuda:
#         dist_ = torch.from_numpy(dist).float().cuda()
#     else:
#         dist_ = torch.from_numpy(dist).float()
#     return dist_

# class EnergyLoss(nn.Module):
#     def __init__(self,cuda,alpha,sigma):
#         super(EnergyLoss, self).__init__()
#         self.energylossfunc = EnergylossFunc.apply
#         self.alpha = alpha
#         self.cuda = cuda
#         self.sigma = sigma

#     def forward(self,feat,label):
#         return self.energylossfunc(self.cuda,feat, label,self.alpha,self.sigma)
    
# class EnergylossFunc(Function):
#     '''
#     target: ground truth 
#     feat: Z -0.5. Z：prob of your target class(here is vessel) with shape[B,H,W]. 
#     Z from softmax output of unet with shape [B,C,H,W]. C: number of classes
#     alpha: default 0.35
#     sigma: default 0.25
#     '''
#     @staticmethod
#     def forward(ctx,cuda,feat_levelset,target,alpha,sigma,Gaussian = False):
#         hardtanh = nn.Hardtanh(min_val=0, max_val=1, inplace=False)
#         target = target.float()
#         index_ = dist(target)
#         dim_ = target.shape[1]
#         target = torch.squeeze(target,1)
#         I1 = target + alpha*hardtanh(feat_levelset/sigma) # G_t + alpha*H(phi) in eq(5)
#         dmn = torch.rfft(I1,2,normalized = True, onesided = False)
#         dmn_r = dmn[:,:,:,0] # dmn's real part
#         dmn_i = dmn[:,:,:,1] # dmm's imagine part
#         dmn2 = dmn_r * dmn_r + dmn_i * dmn_i # dmn^2

#         ctx.save_for_backward(feat_levelset,target,dmn,index_)
            
#         F_energy = torch.sum(index_*dmn2)/feat_levelset.shape[0]/feat_levelset.shape[1]/feat_levelset.shape[2] # eq(8)
        
#         return F_energy

#     @staticmethod
#     def backward(ctx,grad_output):
#         feature,label,dmn,index_ = ctx.saved_tensors
#         index_ = torch.unsqueeze(index_,0)
#         index_ = torch.unsqueeze(index_,3)
#         F_diff = -0.5*index_*dmn # eq(9) 
#         diff = torch.irfft(F_diff,2,normalized = True, onesided = False)/feature.shape[0] # eq
#         return None,Variable(-grad_output*diff),None,None,None
    
    
# score1 = y_out[:,0,:,:] # prob for class target
# score2 = (score1-0.5) # for energyloss

# training_loss = self.loss(score2, y_batch)

In [None]:
# adapted from https://github.com/rosanajurdi/Perimeter_loss/blob/master/losses.py
# and https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py

def contour(x):
    """
    Soft morphological gradient: 3x3 dilation minus a plus-shaped erosion.

    Accepts anything `F.max_pool2d` accepts, e.g. (N, C, H, W) or (C, H, W). Stride 1 with
    matching padding keeps the spatial size, so the output has `x`'s shape. The result is
    non-negative (ReLU applied in place).
    """
    # Min-pooling written as max-pooling of the negated input, once along each axis;
    # the element-wise minimum of the two is an erosion with a plus-shaped footprint.
    vertical_min = -F.max_pool2d(-x, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
    horizontal_min = -F.max_pool2d(-x, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
    eroded = torch.min(vertical_min, horizontal_min)
    dilated = F.max_pool2d(x, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    return F.relu_(dilated - eroded)

class ContourLoss(nn.Module):
    """
    Squared difference in total contour "length" between binarized prediction and target.

    `pred` is binarized with `zero_or_one` (threshold 0.5). Both tensors go through
    `contour` and are summed over their last two (spatial) axes. The squared difference is
    divided by the number of pixels and averaged into a scalar.

    Accepts a single sample `(C, H, W)` (as AdaptiveLoss passes it) or a batch
    `(B, C, H, W)`. The previous version only handled the 3-D case.

    NOTE(review): a later cell defines a class that is also named `contour_loss`, which
    rebinds the module-level name below once it runs. AdaptiveLoss looks the name up at
    call time - confirm the cell run order.
    """
    def forward(self, pred, targ):
        pred = zero_or_one(pred).float()
        targ = targ.float()

        h, w = pred.shape[-2:]
        pred_length = contour(pred).sum(dim=(-2, -1))
        targ_length = contour(targ).sum(dim=(-2, -1))
        per_item = (pred_length - targ_length) ** 2 / (w * h)
        # For 3-D input this equals the former `.mean(axis=0)`; for 4-D it averages over
        # both batch and channel.
        return per_item.mean()

contour_loss = ContourLoss()

In [None]:
class AdaptiveLoss(nn.Module):
    """
    Per-sample loss that depends on whether the target tile contains foreground.

    For each sample:
      * target contains a 1 -> (1 - lambda_coeff) * dice_loss + lambda_coeff * contour_loss
      * otherwise (empty tile) -> binary cross-entropy
    The per-sample losses are averaged over the batch.

    Args:
        lambda_coeff: weight of the contour term; the dice term gets 1 - lambda_coeff.
    """
    def __init__(self, lambda_coeff=0.01):
        super().__init__()
        self.lambda_coeff = lambda_coeff

    def forward(self, inputs, targets, smooth=1):
        """Return the mean per-sample loss, on `inputs`' device. `smooth` is unused."""
        per_sample = []
        for inp, targ in zip(inputs, targets):
            if (targ == 1).any():
                # (prediction, target) order; `reshape` also accepts non-contiguous slices.
                dice = dice_loss(inp.reshape(-1), targ.reshape(-1))
                contour_term = contour_loss(inp, targ)
                sample_loss = (1 - self.lambda_coeff) * dice + self.lambda_coeff * contour_term
            else:
                sample_loss = F.binary_cross_entropy(inp.float(), targ.float(), reduction="mean")
            per_sample.append(sample_loss.float())

        if not per_sample:
            # Empty batch: keep the old result (mean of an empty tensor -> nan).
            return inputs.new_zeros((0,), dtype=torch.float32).mean()
        # Stacking keeps the loss on the input device; the old CPU accumulator forced a
        # device copy for every sample.
        return torch.stack(per_sample).mean()

adaptive_loss = AdaptiveLoss()

In [None]:
class CombinedLoss(nn.Module):
    """
    Weighted sum: 1e-4 * hh_dtloss + adaptive_loss + 1e-7 * perim_loss.

    NOTE(review): `perim_loss` is not defined in the visible part of this notebook -
    confirm it is defined before this loss is used.
    """
    def forward(self, pred, targ):
        # Put each term on pred's device. The previous hard-coded `.cuda()` failed on
        # CPU-only machines and is equivalent on GPU.
        device = pred.device
        hausdorff = hh_dtloss(pred, targ).to(device)
        adaptive = adaptive_loss(pred, targ).to(device)
        perimeter = perim_loss(pred, targ)
        return hausdorff * 0.0001 + adaptive + perimeter * 1e-7

combined_loss = CombinedLoss()

In [None]:
# Mixed precision (`.to_fp16()`) is not used: the losses call F.binary_cross_entropy,
# which cannot run under fp16 autocast.
learner = unet_learner(dls, resnet34, loss_func=adaptive_loss, n_out=1)
# Alternative: learner = Learner(dls, test_model, loss_func=dice_loss)

In [None]:
# learner.load("BASELINE_0.06-0.128")

In [None]:
# learner.model

In [None]:
# Train with the run config; `cbs` holds the training-only callbacks built above.
learner.fit(config["epochs"], cbs=cbs, lr=config["lr"])

In [None]:
# Pickle the whole module (not just its state_dict). Create the directory first so the
# save does not fail with FileNotFoundError.
Path("models").mkdir(parents=True, exist_ok=True)
torch.save(learner.model, "./models/test_export.pkl")

In [None]:
# Fresh learner using the combined Hausdorff + adaptive + perimeter loss.
learner = unet_learner(dls, resnet34, loss_func=combined_loss, n_out=1)

In [None]:
# Reload previously saved baseline weights into the learner.
# NOTE(review): fastai resolves this name under `learner.path/learner.model_dir` and
# appends ".pth" - confirm the checkpoint is stored there.
learner.load("./BASELINE_0.06-0.128")

In [None]:
# Pickle the whole module for upload. `to_upload/` is not created anywhere else here, so
# create it before saving.
Path("to_upload").mkdir(parents=True, exist_ok=True)
torch.save(learner.model, "./to_upload/baseline_model.pkl")

In [None]:
# (Removed a stray "sett ." line: it was not valid Python and stopped "Run All" here.)

## CPU

In [None]:
# Move the DataLoaders to the CPU and grab one batch for debugging the losses offline.
dls.cpu()
test_data = dls.one_batch()

In [None]:
# Split the batch into model inputs and target masks.
test_x, test_y = test_data[0], test_data[1]

In [None]:
# Move the trained network to the CPU to match the CPU-side batch.
model = learner.model.to("cpu")

In [None]:
# Forward pass followed by a sigmoid to get per-pixel probabilities (graph is kept).
test_preds = torch.sigmoid(model(test_x))

In [None]:
class ZeroOrOneFunction(torch.autograd.Function):
    """
    Hard 0/1 threshold at 0.5 with a straight-through gradient.

    Forward: 1 where the input is > 0.5, else 0, with the input's dtype and device.
    Backward: the incoming gradient is passed through unchanged, so the step function acts
    like the identity for backpropagation.
    """
    @staticmethod
    def forward(ctx, x):
        # *_like keeps the device, so CUDA inputs no longer mix with CPU scalar tensors.
        return torch.where(x > 0.5, torch.ones_like(x), torch.zeros_like(x))

    @staticmethod
    def backward(ctx, grad_output):
        # Previously this returned the saved forward output (the 0/1 mask) and ignored
        # grad_output, which breaks the chain rule.
        return grad_output

zero_or_one = ZeroOrOneFunction.apply

In [None]:
# def contour(x):
#     '''
#     Differenciable aproximation of contour extraction
    
#     '''   
#     min_pool_x = torch.nn.functional.max_pool2d(x*-1, (3, 3), 1, 1)*-1
#     max_min_pool_x = torch.nn.functional.max_pool2d(min_pool_x, (3, 3), 1, 1)
#     contour = torch.nn.functional.relu(max_min_pool_x - min_pool_x)
#     return contour


# def soft_skeletonize(x, thresh_width=10):
#     '''
#     Differenciable aproximation of morphological skelitonization operaton
#     thresh_width - maximal expected width of vessel
#     '''
#     for i in range(thresh_width):
#         min_pool_x = torch.nn.functional.max_pool2d(x*-1, (3, 3), 1, 1)*-1
#         max_min_pool_x = torch.nn.functional.max_pool2d(min_pool_x, (3, 3), 1, 1)
#         contour = torch.nn.functional.relu(max_min_pool_x - min_pool_x)
#         x = torch.nn.functional.relu(x - contour)
#     return x

def contour(x):
    """
    Soft morphological gradient: 3x3 dilation minus a plus-shaped erosion.

    Same definition as in the earlier loss cell. It accepts what `F.max_pool2d` accepts,
    keeps the input's shape, and returns a non-negative tensor.
    """
    negated = -x
    # Erosion = min over a vertical and a horizontal 3-tap window (plus-shaped footprint).
    col_min = -F.max_pool2d(negated, (3, 1), (1, 1), (1, 0))
    row_min = -F.max_pool2d(negated, (1, 3), (1, 1), (0, 1))
    gradient = F.max_pool2d(x, (3, 3), (1, 1), (1, 1)) - torch.min(col_min, row_min)
    return F.relu_(gradient)


class contour_loss():
    '''
    Contour-length loss between binarized predictions and targets (not clDice).

    Inputs have shape (batch, channel, height, width). The prediction is binarized with
    `zero_or_one`. `contour` is applied to both tensors and summed over the two spatial
    axes. The squared difference per (sample, channel) is divided by height * width and
    then averaged over the batch axis, giving one value per channel.

    NOTE(review): defining this class rebinds the name `contour_loss`, which an earlier
    cell bound to a `ContourLoss()` instance that AdaptiveLoss calls with two arguments.
    Once this cell runs, that call would hit this class's constructor instead - confirm.
    '''

    def __call__(self, probs: Tensor, target: Tensor, ) -> Tensor:
        binarized = zero_or_one(probs).float()
        reference = target.float()

        _, _, w, h = binarized.shape
        pixel_count = w * h
        pred_length = contour(binarized).sum(axis=(2, 3))
        target_length = contour(reference).sum(axis=(2, 3))
        return ((pred_length - target_length) ** 2 / pixel_count).mean(axis=0)

In [None]:
# Binarize the sigmoid outputs at 0.5 for the contour experiments below.
test_preds_formatted = zero_or_one(test_preds)

In [None]:
# Inspect the shape of the binarized predictions.
test_preds_formatted.shape

In [None]:
# Contour-length loss on the already-binarized batch (zero_or_one leaves 0/1 values as-is).
contour_loss()(test_preds_formatted, test_y)

Find a patch that contains a glomerulus ("glom").

In [None]:
# Raw contour maps of the binarized predictions.
contour(test_preds_formatted)

In [None]:
# Does the mask of sample 3 contain any foreground (value 1)?
1 in test_y[3]

In [None]:
# Predicted contour of sample 3. detach() because the predictions track gradients;
# permute(1,2,0) moves the first (channel) axis last for imshow.
plt.imshow(contour(test_preds_formatted[3]).detach().permute(1,2,0))

In [None]:
# Target contour of sample 3, for comparison with the predicted one.
plt.imshow(contour(test_y[3]).permute(1,2,0))

In [None]:
# MSE between total contour "length" of sample 1's (non-binarized) probabilities and its mask.
F.mse_loss(contour(test_preds[1]).sum(), contour(test_y[1]).sum())

In [None]:
# Same quantity computed by hand: MSE of two scalars is their squared difference.
(contour(test_preds[1]).sum() - contour(test_y[1]).sum())**2

In [None]:
# Total contour "length" of sample 1's target mask.
contour(test_y[1]).sum()