# Config

In [None]:
# Hyperparameters of this run; this dictionary is what gets passed to wandb.init() as the run config.
WB_CONFIG = dict(
    epochs=40,
    batch_size=16,
    split_ratio=(0.8, 0.1, 0.1),  # (train, validation, test) fractions consumed by make_config()
    image_size=(512, 512),
    loss="binary_focal_dice_loss",  # key into LOSS_CONFIG_TO_LOSS
    architecture="resnet50",
    nerves_channel=True,  # when set, predicted nerve masks are appended as a 4th input channel
    learning_rate=0.0002,
)

# Working copy that is later extended with derived values (split sizes, step counts)
# without touching the dictionary reported to W&B.
INTERNAL_CONFIG = dict(WB_CONFIG)

# Buffer size used by the tf.data shuffle() calls in this notebook.
BUFFER_SIZE = 512

# Dataset initialization

## Prepare nerve masks for 4th input channel

In [None]:
!pip install -U -q segmentation-models

import os
import tensorflow as tf
import cv2 as cv
from kaggle_secrets import UserSecretsClient
import wandb

%env SM_FRAMEWORK=tf.keras
import segmentation_models as sm

# Client from kaggle_secrets; used so the W&B API key is read from a stored secret
# instead of being hardcoded in the notebook.
user_secrets = UserSecretsClient()

# Authenticate with Weights & Biases using the secret stored under "wandb-api-key".
wandb.login(key=user_secrets.get_secret("wandb-api-key"))

def load_image(image_path, image_name):
    """Read one JPEG file and prepare it as network input.

    The file at ``image_path`` is decoded as a 3-channel JPEG, converted to
    float32 with ``tf.image.convert_image_dtype`` and resized to
    ``INTERNAL_CONFIG["image_size"]``.

    Args:
        image_path: Path of the image file to read.
        image_name: File name of the image; passed through unchanged so the
            caller can tell which file the tensor came from.

    Returns:
        A ``(image, image_name)`` tuple.
    """
    raw_bytes = tf.io.read_file(image_path)
    decoded = tf.io.decode_jpeg(raw_bytes, channels=3)
    as_float = tf.image.convert_image_dtype(decoded, tf.float32)
    resized = tf.image.resize(as_float, INTERNAL_CONFIG["image_size"])
    return resized, image_name

# Pre-compute nerve masks with a previously trained nerve-segmentation model so that
# they can later be attached to the images as a 4th input channel.
if INTERNAL_CONFIG["nerves_channel"]:
    # Restore and load the model
    best_model = wandb.restore('model-best.h5', run_path="xjackor/nerve-segmentation/cfb3kp0w", replace=True)
    if best_model is None:
        raise FileNotFoundError("wandb.restore() could not find 'model-best.h5' in run xjackor/nerve-segmentation/cfb3kp0w")
    # load_image() resizes every input to INTERNAL_CONFIG["image_size"], so the network
    # input shape is derived from the same setting instead of a separate hard-coded size.
    model = sm.Unet("resnet50", input_shape=(*INTERNAL_CONFIG["image_size"], 3), classes=1, activation='sigmoid', encoder_weights=None)
    try:
        model.load_weights(best_model.name)
    finally:
        # wandb.restore() hands back an open file object; only its name is needed.
        best_model.close()

    for folder in ["nerves", "pni", "tumors"]:
        image_dir = f"/kaggle/input/masters-thesis-extended/data/pancreas/{folder}/images/"
        output_dir = f"./nerve_masks/{folder}/"

        # Sorted so the processing order does not depend on the filesystem.
        image_names = sorted(os.listdir(image_dir))
        image_paths = [image_dir + image_name for image_name in image_names]
        dataset = tf.data.Dataset.from_tensor_slices((tf.convert_to_tensor(image_paths, dtype=tf.string), tf.convert_to_tensor(image_names, dtype=tf.string)))
        dataset = dataset.map(load_image)
        print(f"Processing {folder} folder with {dataset.cardinality().numpy()} images...")

        # Create and store predicted nerve masks
        os.makedirs(output_dir, exist_ok=True)
        # Batch variables get their own names so they do not shadow the image_names list above.
        for image_batch, name_batch in dataset.batch(64):
            mask_batch = model.predict(image_batch, verbose=0)
            for mask, name in zip(mask_batch, name_batch):
                stem = os.path.splitext(name.numpy().decode('utf-8'))[0]
                # Sigmoid outputs lie in [0, 1]; convert to 8-bit grayscale explicitly
                # rather than relying on imwrite's implicit conversion of float arrays.
                mask_u8 = (mask * 255).round().astype("uint8")
                # imwrite reports failure through its return value instead of raising.
                if not cv.imwrite(f"{output_dir}{stem}.png", mask_u8):
                    raise OSError(f"Could not write nerve mask {output_dir}{stem}.png")

## Helpers

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random

def make_config(ds_size: int, split_ratio: tuple, batch_size: int) -> dict:
    """Derive split sizes and step counts from the dataset size.

    Args:
        ds_size: Total number of samples in the dataset.
        split_ratio: ``(train, validation, test)`` fractions; must sum to 1
            (within floating-point tolerance).
        batch_size: Batch size used to compute the number of steps.

    Returns:
        A dict with the keys ``train_size``, ``val_size``, ``test_size``,
        ``steps_per_epoch`` and ``val_steps``. The three sizes always add up
        to ``ds_size``.

    Raises:
        AssertionError: If the fractions in ``split_ratio`` do not sum to 1.
    """
    # Compare with a tolerance: sums of decimal fractions are often not exactly 1.0
    # in binary floating point, so "== 1" would reject valid ratios.
    assert abs(sum(split_ratio) - 1) < 1e-9, f"split_ratio must sum to 1, got {split_ratio}"
    train_size = int(ds_size * split_ratio[0])
    val_size = round(ds_size * split_ratio[1])
    # The test split is whatever remains after train and validation are taken, so
    # compute it as the remainder instead of rounding it independently (which could
    # make the three sizes disagree with ds_size).
    test_size = max(ds_size - train_size - val_size, 0)
    return {
        "train_size": train_size,
        "val_size": val_size,
        "test_size": test_size,
        "steps_per_epoch": train_size // batch_size,
        "val_steps": val_size // batch_size
    }

# Shuffles the dataset once and splits it into 2 disjoint parts: the first train_size elements and the rest
def split(dataset, train_size, seed=None):
    """Shuffle ``dataset`` and split it into two non-overlapping parts.

    Args:
        dataset: Dataset to split; must provide ``shuffle``, ``take`` and ``skip``.
        train_size: Number of elements that go into the first part.
        seed: Optional seed forwarded to ``shuffle`` for a reproducible order.

    Returns:
        A ``(first, rest)`` tuple: ``first`` holds ``train_size`` elements of the
        shuffled dataset, ``rest`` holds the remaining ones.
    """
    # reshuffle_each_iteration=False is essential: take() and skip() iterate the
    # shuffled dataset independently (and again every epoch). With the default
    # reshuffling each iteration sees a different order, so the two parts would
    # overlap and samples would leak between the splits.
    dataset = dataset.shuffle(buffer_size=BUFFER_SIZE, seed=seed, reshuffle_each_iteration=False)
    return dataset.take(train_size), dataset.skip(train_size)

def dilate_mask(image, mask):
    """Grow the foreground of ``mask`` with a 5x5 morphological dilation.

    Args:
        image: Passed through unchanged.
        mask: Single mask without a batch axis; a batch axis is added for the
            dilation and removed again afterwards.

    Returns:
        A ``(image, dilated_mask)`` tuple.
    """
    unit_steps = [1, 1, 1, 1]
    structuring_element = tf.ones((5, 5, 1), dtype=tf.float32)

    # dilation2d is called in NHWC layout, so prepend a batch axis of size 1.
    batched_mask = tf.expand_dims(mask, axis=0)
    grown = tf.nn.dilation2d(input=batched_mask,
                             filters=structuring_element,
                             strides=unit_steps,
                             padding='SAME',
                             data_format='NHWC',
                             dilations=unit_steps)

    # Grayscale dilation adds the filter values (all ones here) to the input before
    # taking the window maximum, so the result is shifted up by 1; subtract it again.
    rounded = tf.math.round(grown)
    return image, tf.squeeze(rounded, axis=0) - 1

# Collects the paths of all images and their corresponding masks and returns them as string tensors
def load_pni_image_and_mask_paths(data_path: str,  nerve_channel: bool):
    """Collect image and mask paths of the PNI dataset.

    Every file in ``<data_path>/images`` is paired with
    ``<data_path>/masks/<stem>.png`` and, when ``nerve_channel`` is set, with
    ``./nerve_masks/pni/<stem>.png``.

    Args:
        data_path: Dataset folder containing the ``images`` and ``masks`` sub-folders.
        nerve_channel: Whether to also return the predicted nerve-mask paths.

    Returns:
        ``(image_paths, mask_paths)`` or, with ``nerve_channel``,
        ``(image_paths, mask_paths, nerve_mask_paths)``; each a 1-D string tensor
        with entries in the same order.
    """
    # os.path.join works with or without a trailing slash on data_path.
    image_folder_path = os.path.join(data_path, "images")
    mask_folder_path = os.path.join(data_path, "masks")

    # Sorted so the sample order is reproducible (os.listdir order is arbitrary).
    image_files = sorted(os.listdir(image_folder_path))
    sample_ids = [os.path.splitext(image_file)[0] for image_file in image_files]

    # Use the real file names instead of assuming every image ends in ".jpg".
    image_paths = [os.path.join(image_folder_path, image_file) for image_file in image_files]
    mask_paths = [os.path.join(mask_folder_path, sample_id + ".png") for sample_id in sample_ids]

    image_paths = tf.convert_to_tensor(image_paths, dtype=tf.string)
    mask_paths = tf.convert_to_tensor(mask_paths, dtype=tf.string)

    if nerve_channel:
        nerve_mask_folder_path = "./nerve_masks/pni/"
        nerve_mask_paths = [nerve_mask_folder_path + sample_id + ".png" for sample_id in sample_ids]
        nerve_mask_paths = tf.convert_to_tensor(nerve_mask_paths, dtype=tf.string)
        return image_paths, mask_paths, nerve_mask_paths

    return image_paths, mask_paths

# Collects the paths of a random sample of images, each paired with the shared empty mask, as string tensors
def load_non_pni_image_and_mask_paths(data_path: str,  empty_mask_path: str, max_size: int, nerve_channel: bool):
    """Sample images without PNI annotations and pair them with an empty mask.

    Args:
        data_path: Dataset folder containing an ``images`` sub-folder.
        empty_mask_path: Path of the mask file used for every sampled image.
        max_size: Upper bound on the number of images to sample; if the folder
            holds fewer images, all of them are used.
        nerve_channel: Whether to also return the predicted nerve-mask paths
            (``./nerve_masks/<folder name of data_path>/<stem>.png``).

    Returns:
        ``(image_paths, empty_mask_paths)`` or, with ``nerve_channel``,
        ``(image_paths, empty_mask_paths, nerve_mask_paths)``; each a 1-D string
        tensor with entries in the same order.
    """
    image_folder_path = os.path.join(data_path, "images")

    # Sorted first so that, under a fixed random seed, the same files are sampled
    # regardless of the order os.listdir happens to return.
    image_files = sorted(os.listdir(image_folder_path))
    # Cap the sample size: random.sample raises if asked for more items than exist.
    sample_size = min(max_size, len(image_files))
    image_files = random.sample(image_files, sample_size)
    sample_ids = [os.path.splitext(image_file)[0] for image_file in image_files]

    image_paths = [os.path.join(image_folder_path, image_file) for image_file in image_files]
    empty_mask_paths = [empty_mask_path] * sample_size

    image_paths = tf.convert_to_tensor(image_paths, dtype=tf.string)
    empty_mask_paths = tf.convert_to_tensor(empty_mask_paths, dtype=tf.string)

    if nerve_channel:
        # Take the last path component as the folder name; this works for any name
        # length and with or without a trailing slash.
        folder_name = os.path.basename(os.path.normpath(data_path))
        nerve_mask_folder_path = f"./nerve_masks/{folder_name}/"
        nerve_mask_paths = [nerve_mask_folder_path + sample_id + ".png" for sample_id in sample_ids]
        nerve_mask_paths = tf.convert_to_tensor(nerve_mask_paths, dtype=tf.string)
        return image_paths, empty_mask_paths, nerve_mask_paths

    return image_paths, empty_mask_paths

# Returns the function that loads a single image and its corresponding mask from the provided paths
def load_images_and_masks(image_size: tuple = (512, 512)):
    """Build a loader suitable for ``tf.data.Dataset.map``.

    Args:
        image_size: ``(height, width)`` every loaded tensor is resized to.

    Returns:
        A function ``(image_path, mask_path, nerve_mask_path=None) -> (image, mask)``.
        ``image`` is a float32 RGB image; when ``nerve_mask_path`` is given the
        nerve mask (scaled to [0, 1]) is appended as a 4th channel. ``mask`` is a
        float32 single-channel tensor.
    """
    def load_images_func(image_path, mask_path, nerve_mask_path=None):
        image = tf.io.read_file(image_path)
        image = tf.io.decode_jpeg(image, channels=3)
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize(image, image_size)

        mask = tf.io.read_file(mask_path)
        # Integer division maps a stored value of 255 to 1; anything below 255 becomes 0.
        mask = tf.io.decode_png(mask, channels=1) // 255
        mask = tf.cast(mask, tf.float32)
        # Nearest-neighbour keeps the resized mask binary.
        mask = tf.image.resize(mask, image_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        if nerve_mask_path is not None:
            # Load predicted nerves mask and add it as a 4th channel
            nerve_mask = tf.io.read_file(nerve_mask_path)
            nerve_mask = tf.io.decode_png(nerve_mask, channels=1)
            nerve_mask = tf.cast(nerve_mask, tf.float32) / 255.
            # Bring the nerve mask to the same spatial size as the image; without
            # this the concat fails whenever image_size differs from the size the
            # nerve masks were stored at.
            nerve_mask = tf.image.resize(nerve_mask, image_size)
            image = tf.concat([image, nerve_mask], axis=-1)

        return image, mask

    return load_images_func

def augment(seed=37, nerve_channel=False):
    """Build an augmentation function for ``tf.data.Dataset.map``.

    Colour jitter (brightness, contrast, hue, saturation) is applied to the RGB
    channels only; random rotation and flipping are applied to image and mask.

    Args:
        seed: Seed shared by the image and mask augmentation layers and used to
            initialise the generator that seeds the colour jitter.
        nerve_channel: Whether the input image carries a nerve mask as a 4th
            channel. That channel is excluded from the colour jitter but goes
            through the geometric transforms together with the image.

    Returns:
        A function ``(image, mask) -> (augmented_image, augmented_mask)``.
    """
    # Image and mask go through two separate pipelines. They are built from the
    # same layers with the same seed so that both draw the same rotation/flip.
    augment_image_model = tf.keras.Sequential([
        tf.keras.layers.RandomRotation(0.15, seed=seed),
        tf.keras.layers.RandomFlip(mode="horizontal_and_vertical", seed=seed),
    ])
    augment_mask_model = tf.keras.Sequential([
        # Nearest-neighbour interpolation keeps the rotated mask binary; the default
        # bilinear interpolation would blend label values along object borders.
        tf.keras.layers.RandomRotation(0.15, seed=seed, interpolation="nearest"),
        tf.keras.layers.RandomFlip(mode="horizontal_and_vertical", seed=seed),
    ])
    rng = tf.random.Generator.from_seed(seed, alg='philox')

    def augment_func(image, mask):
        nerves_mask = None
        if nerve_channel:
            # Split off the nerve channel so colour jitter only touches RGB.
            image, nerves_mask = image[:,:,:3], image[:,:,3:]
        # One fresh seed per call, shared by all stateless colour operations below.
        rng_seed = rng.make_seeds(1)[:, 0]
        image = tf.image.stateless_random_brightness(image, max_delta=0.1, seed=rng_seed)
        image = tf.image.stateless_random_contrast(image, 0.5, 1.5, seed=rng_seed)
        image = tf.image.stateless_random_hue(image, max_delta=0.1, seed=rng_seed)
        image = tf.image.stateless_random_saturation(image, 0.8, 1.2, seed=rng_seed)
        # The jitter can push values outside the valid [0, 1] range.
        image = tf.clip_by_value(image, 0.0, 1.0)
        if nerve_channel:
            image = tf.concat([image, nerves_mask], axis=-1)
        return augment_image_model(image), augment_mask_model(mask)

    return augment_func

def draw_mask(image, pni_mask, nerves_mask=None):
    """Blend coloured mask overlays onto ``image``.

    Pixels where the rounded mask is non-zero are painted ``[255, 0, 0]`` for
    ``pni_mask`` and ``[0, 255, 0]`` for ``nerves_mask``; the PNI colour is
    painted last and therefore wins where both masks overlap. The painted copy
    is then blended with the original image (weights 0.7 and 0.3).

    Args:
        image: uint8 image with 3 channels.
        pni_mask: Mask with a single channel on axis 2.
        nerves_mask: Optional second mask with a single channel on axis 2.

    Returns:
        The blended image.
    """
    def _paint(canvas, mask, colour):
        # Repeat the single mask channel 3 times so it lines up with the colour channels.
        selector = np.repeat(np.round(mask), 3, axis=2)
        return np.where(selector, np.array(colour, dtype=np.uint8), canvas)

    overlay = image.copy()
    if nerves_mask is not None:
        overlay = _paint(overlay, nerves_mask, [0, 255, 0])
    overlay = _paint(overlay, pni_mask, [255, 0, 0])
    return cv.addWeighted(image, 0.3, overlay, 0.7, 0)

def show(image, pni_mask, prediction=None, nerve_channel=False):
    """Plot the input image, its true mask and optionally a predicted mask side by side.

    Args:
        image: Float image array; multiplied by 255 and cast to uint8 for display.
            With ``nerve_channel`` the channels from index 3 on are treated as
            the nerve mask and removed from the displayed image.
        pni_mask: Ground-truth mask drawn in the "True Mask" panel.
        prediction: Optional predicted mask; adds a "Predicted Mask" panel that
            also shows the nerve mask when one is available.
        nerve_channel: Whether ``image`` carries a nerve mask after the RGB channels.
    """
    nerves_mask = None
    if nerve_channel:
        nerves_mask = image[:, :, 3:]
        image = image[:, :, :3]
    image = (image * 255).astype('uint8')

    # (title, picture) pairs in display order.
    panels = [
        ("Input Image", image),
        ("True Mask", draw_mask(image, pni_mask)),
    ]
    if prediction is not None:
        panels.append(("Predicted Mask", draw_mask(image, prediction, nerves_mask)))

    plt.figure(figsize=(15, 15))

    for position, (title, picture) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.title(title)
        plt.imshow(picture)
        plt.axis("off")

    plt.show()


class DisplayCallback(tf.keras.callbacks.Callback):
    """Logs segmentation previews to Weights & Biases after every epoch.

    For each ``(image, mask)`` element of the given dataset the model being
    trained predicts a mask; the image is logged with the rounded prediction
    and the ground truth attached as mask overlays.
    """

    def __init__(self, dataset, nerve_channel):
        """
        Args:
            dataset: Iterable of unbatched ``(image, mask)`` pairs to preview.
            nerve_channel: Whether the images carry a nerve mask as a 4th channel
                (it is removed before the image is logged).
        """
        super().__init__()
        self._dataset = dataset
        self.nerve_channel = nerve_channel
        
    def _make_wandb_mask(self, image, mask, prediction):
        """Wrap one sample into a ``wandb.Image`` with prediction and ground-truth overlays."""
        if self.nerve_channel:
            # Only the RGB channels are shown.
            image = image[:,:,:3]
        return wandb.Image(image, masks={
            "prediction": {"mask_data": prediction.numpy()[:,:,0], "class_labels": {0: "tissue", 1:"nerve"}},
            "ground truth": {"mask_data": mask.numpy()[:,:,0], "class_labels": {0: "tissue", 1:"nerve"}}
        })
        
    def on_epoch_end(self, epoch, logs=None):
        """Predict every preview sample with the model under training and log the results."""
        imagelist = []
        for image, mask in self._dataset:
            # Use self.model (attached by Keras to the callback) rather than a
            # notebook-level global, so the callback always previews the model it
            # is registered on. verbose=0 avoids one progress bar per sample.
            probabilities = self.model.predict(tf.expand_dims(image, axis=0), verbose=0)[0]
            imagelist.append(self._make_wandb_mask(image, mask, tf.math.round(probabilities)))
        wandb.log({"segmentation_list": imagelist})

## Load dataset

In [None]:
# Load PNI paths
PNI_DATA_PATH = "/kaggle/input/masters-thesis-extended/data/pancreas/pni/"
pni_paths = load_pni_image_and_mask_paths(PNI_DATA_PATH, INTERNAL_CONFIG["nerves_channel"])

# Load some nerve and tumor/non-tumor paths with empty masks
NERVES_DATA_PATH = "/kaggle/input/masters-thesis-extended/data/pancreas/nerves/"
TUMORS_DATA_PATH = "/kaggle/input/masters-thesis-extended/data/pancreas/tumors/"
EMPTY_PNI_MASK_PATH = "/kaggle/working/empty_mask.png"
# Number of images sampled from each non-PNI folder.
# NOTE(review): presumably chosen to match the number of PNI images - confirm.
NON_PNI_SAMPLES_PER_FOLDER = 431
# Seed of the one-off shuffle applied to the path list below.
PATH_SHUFFLE_SEED = 37

# All-zero mask shared by every non-PNI image, written at the configured image size.
empty_mask = np.zeros((*INTERNAL_CONFIG["image_size"], 1), dtype=np.uint8)
cv.imwrite(EMPTY_PNI_MASK_PATH, empty_mask)

nerve_paths = load_non_pni_image_and_mask_paths(NERVES_DATA_PATH, EMPTY_PNI_MASK_PATH, NON_PNI_SAMPLES_PER_FOLDER, INTERNAL_CONFIG["nerves_channel"])
tumor_paths = load_non_pni_image_and_mask_paths(TUMORS_DATA_PATH, EMPTY_PNI_MASK_PATH, NON_PNI_SAMPLES_PER_FOLDER, INTERNAL_CONFIG["nerves_channel"])

# Concatenate the PNI, nerve and tumor entries component-wise
# (image paths, mask paths and, if enabled, nerve-mask paths).
paths = [tf.concat(p, axis=0) for p in zip(pni_paths, nerve_paths, tumor_paths)]

# Load the dataset
dataset = tf.data.Dataset.from_tensor_slices(tuple(paths))

# The paths are ordered PNI -> nerves -> tumors, and the later shuffles only use a
# buffer of BUFFER_SIZE elements, which may be smaller than the dataset. Shuffle the
# (cheap) path strings across the whole dataset once, in a fixed order, so that the
# train/val/test splits are not skewed towards particular folders.
dataset = dataset.shuffle(buffer_size=max(int(dataset.cardinality().numpy()), 1),
                          seed=PATH_SHUFFLE_SEED,
                          reshuffle_each_iteration=False)

dataset = dataset.map(load_images_and_masks(INTERNAL_CONFIG["image_size"]))
dataset = dataset.map(dilate_mask)
ds_size = int(dataset.cardinality().numpy())

INTERNAL_CONFIG.update(make_config(ds_size, INTERNAL_CONFIG["split_ratio"], INTERNAL_CONFIG["batch_size"]))

train_ds, val_ds = split(dataset, INTERNAL_CONFIG["train_size"])
val_ds, test_ds = split(val_ds, INTERNAL_CONFIG["val_size"])

# .numpy() so plain integers are printed instead of tensor representations.
print(f"TRAIN: {train_ds.cardinality().numpy()}\nVAL: {val_ds.cardinality().numpy()}\nTEST: {test_ds.cardinality().numpy()}")


# Only the training pipeline is shuffled, augmented and repeated.
train_batches = (train_ds.shuffle(BUFFER_SIZE)
                         .map(augment(nerve_channel=INTERNAL_CONFIG["nerves_channel"]))
                         .batch(INTERNAL_CONFIG["batch_size"])
                         .repeat()
                         .prefetch(buffer_size=tf.data.AUTOTUNE))
val_batches = val_ds.batch(INTERNAL_CONFIG["batch_size"])
test_batches = test_ds.batch(INTERNAL_CONFIG["batch_size"])

# PNI segmentation model

## Prepare model

In [None]:
# Maps the "loss" config value to the segmentation_models loss it names.
LOSS_CONFIG_TO_LOSS = dict(
    dice_loss=sm.losses.dice_loss,
    bce_dice_loss=sm.losses.bce_dice_loss,
    binary_focal_dice_loss=sm.losses.binary_focal_dice_loss,
)

# 3 input channels for RGB, plus one for the predicted nerve mask when enabled.
channels = 4 if INTERNAL_CONFIG["nerves_channel"] else 3
input_shape = (*INTERNAL_CONFIG["image_size"], channels)

# U-Net with the configured encoder, no pretrained encoder weights and a single sigmoid output channel.
model = sm.Unet(
    INTERNAL_CONFIG["architecture"],
    encoder_weights=None,
    input_shape=input_shape,
    classes=1,
    activation='sigmoid',
)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=INTERNAL_CONFIG["learning_rate"]),
    loss=LOSS_CONFIG_TO_LOSS[INTERNAL_CONFIG["loss"]],
    metrics=[
        sm.metrics.iou_score,
        sm.metrics.f1_score,
        sm.metrics.precision,
        tf.keras.metrics.Recall(),
    ],
)

# Training

In [None]:
wandb.init(project="pni-segmentation", entity="xjackor", config=WB_CONFIG)

# Preview samples for DisplayCallback. They are taken from the validation split:
# split() shuffles before partitioning, so skipping train_size elements of the
# unsplit dataset would not select held-out samples.
display_patches = val_ds.take(8)

try:
    model.fit(train_batches,
              epochs=INTERNAL_CONFIG["epochs"],
              steps_per_epoch=INTERNAL_CONFIG["steps_per_epoch"],
              validation_steps=INTERNAL_CONFIG["val_steps"],
              validation_data=val_batches,
              callbacks=[DisplayCallback(display_patches, INTERNAL_CONFIG["nerves_channel"]), wandb.keras.WandbCallback(mode="min")])
finally:
    # Close the W&B run even when training is interrupted or fails.
    wandb.run.finish()

# Evaluate

In [None]:
# Evaluate the trained model on the batched test split (test_batches).
model.evaluate(test_batches)

In [None]:
# Visualise ground truth and prediction for every image of the first test batch.
for images, masks in test_batches.take(1):
    # One forward pass for the whole batch instead of one predict() call per image;
    # verbose=0 keeps progress bars out of the cell output between the figures.
    predictions = model.predict(images, verbose=0)
    for image, mask, prediction in zip(images, masks, predictions):
        show(image.numpy(), mask.numpy(), prediction, INTERNAL_CONFIG["nerves_channel"])