In [None]:
!pip install -U -q segmentation-models

import os

import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
import wandb
import tensorflow as tf
from tqdm import tqdm

%env SM_FRAMEWORK=tf.keras
import segmentation_models as sm
from kaggle_secrets import UserSecretsClient

# Authenticate with Weights & Biases. The API key is read from Kaggle Secrets,
# so it never has to be written into the notebook source.
user_secrets = UserSecretsClient()
wandb.login(
    key=user_secrets.get_secret("wandb-api-key"),
)

# Load the nerve segmentation, tumor classification, and PNI segmentation models

In [None]:
def _restore_wandb_weights(run_path: str, filename: str = "model-best.h5") -> str:
    """Fetch a weights file from a W&B run and return its local path.

    ``wandb.restore`` hands back an open file object, or ``None`` when the file
    cannot be found. ``load_weights`` only needs the path, so the handle is
    closed right away instead of being leaked.

    Args:
        run_path: W&B run path, e.g. ``"xjackor/nerve-segmentation/nkxxxti0"``.
        filename: Name of the file stored in that run.

    Returns:
        Local filesystem path of the restored file.

    Raises:
        FileNotFoundError: If W&B could not find ``filename`` in the run.
    """
    restored = wandb.restore(filename, run_path=run_path, replace=True)
    if restored is None:
        raise FileNotFoundError(f"{filename!r} not found in W&B run {run_path!r}")
    restored.close()
    return restored.name


def _build_tumor_classifier(input_shape=(512, 512, 3)) -> tf.keras.Model:
    """Rebuild the tumor classifier: ResNet50V2 backbone + small sigmoid head.

    The layer stack presumably mirrors the training-time architecture, and
    ``load_weights`` needs an identical structure. Note that the 128-unit
    ``Dense`` is applied to the 4-D feature map *before* global pooling; keep
    that order or the checkpoint will not load.
    """
    backbone = tf.keras.applications.ResNet50V2(input_shape=input_shape, include_top=False, weights=None)
    features = tf.keras.layers.Dense(128, activation='relu')(backbone.output)
    pooled = tf.keras.layers.GlobalAveragePooling2D()(features)
    tumor_probability = tf.keras.layers.Dense(1, activation="sigmoid")(pooled)
    return tf.keras.models.Model(inputs=backbone.input, outputs=tumor_probability)


# All three checkpoints are named model-best.h5 and restored with replace=True.
# They may land at the same local path, so each one is loaded immediately,
# before the next restore can overwrite it.
nerve_seg_model = sm.Unet("resnet50", input_shape=(512, 512, 3), classes=1, activation='sigmoid', encoder_weights=None)
nerve_seg_model.load_weights(_restore_wandb_weights("xjackor/nerve-segmentation/nkxxxti0"))

tumor_class_model = _build_tumor_classifier()
tumor_class_model.load_weights(_restore_wandb_weights("xjackor/tumor-detection/aq3xfs6y"))

# The PNI model takes 4 channels: the 3-channel image patch plus the 1-channel nerve mask.
pni_seg_model = sm.Unet("resnet50", input_shape=(512, 512, 4), classes=1, activation='sigmoid', encoder_weights=None)
pni_seg_model.load_weights(_restore_wandb_weights("xjackor/pni-segmentation/p6swtzxo"))

# Overall model: patch-wise inference pipeline combining the three models

In [None]:
class ConditionalInterpolationLayer(tf.keras.layers.Layer):
    """Damp mask pixels that exceed the tumor probability.

    Every pixel whose value is greater than the tumor probability is replaced
    by the mean of the two. Every other pixel passes through unchanged. The
    layer has no trainable weights.
    """

    def call(self, inputs):
        """Apply the conditional interpolation.

        Args:
            inputs: Pair ``(segmentation_mask, tumor_probability)``. The mask is
                expected to be ``(batch, height, width, channels)``, as produced
                by the PNI U-Net. The probability may be ``(batch, 1)``,
                ``(batch,)`` or a scalar.

        Returns:
            Tensor with the mask's shape and dtype.
        """
        segmentation_mask, tumor_probability = inputs
        # Reshape to (batch, 1, 1, 1) so each sample's probability broadcasts over
        # that sample's spatial axes. A raw (batch, 1) tensor would right-align
        # against (width, channels). For batch > 1 it would then broadcast
        # along the wrong axis or fail.
        tumor_probability = tf.reshape(
            tf.cast(tumor_probability, segmentation_mask.dtype), (-1, 1, 1, 1)
        )
        average = (segmentation_mask + tumor_probability) / 2.0
        # Replace pixels with higher values than tumor_probability
        # with the average of tumor_probability and that pixel.
        return tf.where(segmentation_mask > tumor_probability, average, segmentation_mask)


class Image:
    """In-memory image that supports cropped and/or resized block reads.

    Args:
        image: Either an already-decoded ``np.ndarray``, used as-is and not
            copied, or a filesystem path (``str`` or ``os.PathLike``) read with
            ``cv.imread``. OpenCV decodes files in BGR channel order.

    Raises:
        TypeError: If ``image`` is neither an array nor a path.
        FileNotFoundError: If the path cannot be read as an image.
    """

    def __init__(self, image):
        if isinstance(image, np.ndarray):
            raw_image = image
        elif isinstance(image, (str, os.PathLike)):
            path = os.fspath(image)
            raw_image = cv.imread(path)
            # cv.imread signals failure by returning None instead of raising.
            if raw_image is None:
                raise FileNotFoundError(f"Could not read image from {path!r}")
        else:
            raise TypeError(
                f"image must be a numpy array or a path, got {type(image).__name__}"
            )
        self._raw_image = raw_image
        self._width = raw_image.shape[1]
        self._height = raw_image.shape[0]

    @property
    def width(self):
        """Image width in pixels (number of columns)."""
        return self._width

    @property
    def height(self):
        """Image height in pixels (number of rows)."""
        return self._height

    def read_block(
        self,
        rect=None,  # x, y, w, h
        size=None,  # w, h
    ) -> np.ndarray:
        """Return a cropped and/or resized copy of the image.

        Args:
            rect: Crop rectangle ``(x, y, w, h)`` in pixels. ``None`` means the
                whole image.
            size: Output size ``(w, h)``, in OpenCV's order, resized with
                nearest-neighbour interpolation. ``None`` keeps the cropped size.

        Returns:
            A new array. Modifying it never alters the stored image.
        """
        block = self._raw_image
        if rect is not None:
            x, y, w, h = rect
            block = block[y:y + h, x:x + w]
        if size is not None:
            # cv.resize allocates a fresh array, so no explicit copy is needed.
            return cv.resize(block, size, 0, 0, interpolation=cv.INTER_NEAREST)
        return block.copy()


class _PredictionMask:
    def __init__(self, width: int, height: int):
        self._width = width
        self._height = height
        self._total = np.zeros(shape=(height, width, 1))
        self._count = np.zeros(shape=(height, width, 1))

    def add_patch(self, mask_patch: np.ndarray, x: int, y: int):
        height, width, _ = mask_patch.shape
        self._total[y:y + height, x:x + width, :] += mask_patch
        self._count[y:y + height, x:x + width, :] += 1

    @property
    def result_mask(self) -> np.ndarray:
        return self._total / self._count


class Preprocessor:
    """Tiles an image into overlapping windows for patch-wise inference.

    Args:
        patch_size: ``(height, width)`` of each window cut from the source image.
        output_patch_size: ``(width, height)`` each window is resized to. It is
            passed to ``Image.read_block`` as ``size``.
        overlap: Fraction of the window height shared by neighbouring windows.
            The resulting stride is used on both axes and must be at least 1 px.

    Raises:
        ValueError: If ``overlap`` leaves a stride smaller than 1 px.
    """

    def __init__(self, patch_size: tuple, output_patch_size: tuple, overlap: float):
        self._patch_size = patch_size
        self._output_patch_size = output_patch_size
        self._overlap = overlap
        # One stride for both axes, derived from the window height.
        self._step_size = int(patch_size[0] * (1 - overlap))
        if self._step_size < 1:
            raise ValueError(
                f"overlap={overlap} gives a stride of {self._step_size} px for "
                f"patch size {patch_size}; the stride must be at least 1 px"
            )

    @staticmethod
    def _axis_positions(length: int, patch: int, step: int) -> list:
        """Start offsets of ``patch``-sized windows along an axis of ``length``.

        Windows advance by ``step``. If that leaves a strip at the far edge
        uncovered, one more window, flush with the edge, is appended so the
        whole axis is processed. Returns ``[]`` when ``length < patch``.
        """
        if length < patch:
            return []
        positions = list(range(0, length - patch + 1, step))
        last = length - patch
        if positions[-1] != last:
            positions.append(last)
        return positions

    def _window_positions(self, image: "Image"):
        """Return ``(xs, ys)``: window start offsets along width and height."""
        h, w = self._patch_size[0], self._patch_size[1]
        xs = self._axis_positions(image.width, w, self._step_size)
        ys = self._axis_positions(image.height, h, self._step_size)
        return xs, ys

    def extract_patches(self, image: "Image"):
        """Iterate over ``(x, y, patch)`` for every window, with a tqdm progress bar.

        ``x``/``y`` is the window's top-left corner in source-image pixels.
        ``patch`` is the window resized to ``output_patch_size`` and divided by 255.
        """
        xs, ys = self._window_positions(image)
        total_patches = len(xs) * len(ys)
        print(f"Processing image {image.height}x{image.width}")
        print(f" - Step size: {self._step_size}")
        print(f" - Total patches: {total_patches}")
        return tqdm(self._generate_patches(image), total=total_patches)

    def _generate_patches(self, image: "Image"):
        """Generator behind ``extract_patches``; walks windows row by row."""
        w, h = self._patch_size[1], self._patch_size[0]
        xs, ys = self._window_positions(image)
        for y in ys:
            for x in xs:
                image_block = image.read_block(rect=(x, y, w, h), size=self._output_patch_size)
                yield x, y, image_block / 255.


class Postprocessor:
    """Stitches per-patch masks into full-image masks and renders an overlay.

    Args:
        patch_size: ``(height, width)`` of the masks the models output.
        original_patch_size: ``(height, width)`` of the source-image window each
            mask corresponds to. Their ratio maps source coordinates onto mask
            coordinates.
    """

    # Mask values strictly above this are drawn. For values in [0, 1] this is
    # equivalent to truth-testing np.round(mask), but NaN compares False here
    # instead of counting as positive.
    _THRESHOLD = 0.5

    def __init__(self, patch_size, original_patch_size):
        self._ratio_height = patch_size[0] / original_patch_size[0]
        self._ratio_width = patch_size[1] / original_patch_size[1]
        self._image = None

    def init_postprocessing(self, image: "Image"):
        """Start a new image: allocate empty nerve and PNI accumulators."""
        self._image = image
        mask_width = int(image.width * self._ratio_width)
        mask_height = int(image.height * self._ratio_height)
        self._nerve_mask = _PredictionMask(mask_width, mask_height)
        self._pni_mask = _PredictionMask(mask_width, mask_height)

    def collect(self, nerve_mask_patch: np.ndarray, pni_mask_patch: np.ndarray, x: int, y: int):
        """Accumulate one patch's masks. ``x``/``y`` is its top-left corner in source pixels."""
        mask_x, mask_y = int(x * self._ratio_width), int(y * self._ratio_height)
        self._nerve_mask.add_patch(nerve_mask_patch, mask_x, mask_y)
        self._pni_mask.add_patch(pni_mask_patch, mask_x, mask_y)

    @classmethod
    def _binarize(cls, mask: np.ndarray) -> np.ndarray:
        """Boolean mask of pixels strictly above the threshold (NaN counts as False)."""
        return mask > cls._THRESHOLD

    def display(self):
        """Show the source image with nerve and PNI masks blended on top.

        Raises:
            RuntimeError: If ``init_postprocessing`` has not been called.
        """
        if self._image is None:
            raise RuntimeError("call init_postprocessing() before display()")
        image = self._image.read_block()
        full_size = (self._image.width, self._image.height)  # cv.resize expects (w, h)
        nerve_mask = cv.resize(self._nerve_mask.result_mask, full_size, 0, 0)[..., None]
        pni_mask = cv.resize(self._pni_mask.result_mask, full_size, 0, 0)[..., None]

        masked_image = np.where(self._binarize(nerve_mask), np.array([0, 255, 0], dtype=np.uint8), image)
        # PNI is painted after the nerves, so it wins where the two masks overlap.
        masked_image = np.where(self._binarize(pni_mask), np.array([255, 0, 0], dtype=np.uint8), masked_image)
        # Blend 50/50 so the underlying tissue stays visible beneath the masks.
        masked_image = cv.addWeighted(image, 0.5, masked_image, 0.5, 0)

        # NOTE(review): images loaded from a path via cv.imread are BGR, while
        # plt.imshow assumes RGB, so the tissue colors appear channel-swapped.
        # Confirm whether a BGR->RGB conversion is wanted here.
        plt.figure(figsize=(10, 10))
        plt.imshow(masked_image)
        plt.axis("off")
        plt.show()
    

class Processor:
    """Patch-wise inference pipeline: nerve mask, optional tumor score, then PNI mask.

    Source windows of ``_preprocessor_patch_size`` are resized to ``_patch_size``
    before being fed to the models. The resulting masks are stitched together
    and displayed by ``Postprocessor``.
    """

    _preprocessor_patch_size = (1024, 1024)
    _patch_size = (512, 512)

    def __init__(self, pni_seg_model: tf.keras.Model, nerve_seg_model: tf.keras.Model, tumor_class_model: tf.keras.Model = None, overlap=0.5):
        """
        Args:
            pni_seg_model: Segments PNI from the image patch stacked with the
                nerve mask (4 channels).
            nerve_seg_model: Segments nerves from the image patch.
            tumor_class_model: Optional per-patch tumor classifier. When given,
                PNI mask values above its probability are damped by
                ``ConditionalInterpolationLayer``.
            overlap: Fraction of overlap between neighbouring source windows.
        """
        self._preprocessor = Preprocessor(self._preprocessor_patch_size, self._patch_size, overlap)
        self._nerve_segmenter = nerve_seg_model
        self._tumor_classifier = tumor_class_model
        self._pni_segmenter = pni_seg_model
        self._pni_segmenter_with_tumor = (
            self._make_pni_seg_model_with_tumor_probability(pni_seg_model)
            if tumor_class_model is not None
            else None
        )
        self._postprocessor = Postprocessor(self._patch_size, self._preprocessor_patch_size)

    def process(self, image: Image):
        """Run every window of ``image`` through the models and display the overlay.

        Returns ``None``; the result is rendered by ``Postprocessor.display``.
        """
        self._postprocessor.init_postprocessing(image)

        for x, y, patch in self._preprocessor.extract_patches(image):
            nerve_mask = self.segment_nerves(patch)
            tumor_probability = (
                self.classify_tumor(patch) if self._tumor_classifier is not None else None
            )
            pni_mask = self.segment_pni(patch, nerve_mask, tumor_probability)
            self._postprocessor.collect(nerve_mask_patch=nerve_mask, pni_mask_patch=pni_mask, x=x, y=y)

        self._postprocessor.display()

    def segment_nerves(self, image_patch: np.ndarray) -> np.ndarray:
        """Predict the nerve mask for one un-batched patch (batch axis removed)."""
        input_image = tf.expand_dims(image_patch, axis=0)
        prediction = self._nerve_segmenter.predict(input_image, verbose=0)
        return prediction[0]

    def classify_tumor(self, image_patch: np.ndarray) -> float:
        """Return the classifier's single output for one un-batched patch."""
        input_image = tf.expand_dims(image_patch, axis=0)
        prediction = self._tumor_classifier.predict(input_image, verbose=0)
        return prediction[0][0]

    def segment_pni(self, image_patch: np.ndarray, nerve_mask_patch: np.ndarray, tumor_probability: float | None) -> np.ndarray:
        """Predict the PNI mask from a patch and its nerve mask.

        If a tumor probability is given and the tumor-aware wrapper exists, the
        wrapper is used. Otherwise the plain PNI model is used.
        """
        # The nerve mask becomes an extra channel after the image channels.
        input_image = tf.concat([image_patch, nerve_mask_patch], axis=-1)
        input_image = tf.expand_dims(input_image, axis=0)
        if tumor_probability is not None and self._pni_segmenter_with_tumor is not None:
            # Shape (1, 1) matches the (batch, 1) "tumor_probability" Input exactly,
            # rather than relying on Keras to expand a rank-1 tensor.
            input_probability = tf.reshape(tf.constant(tumor_probability, dtype=tf.float32), (1, 1))
            prediction = self._pni_segmenter_with_tumor.predict({"image": input_image, "tumor_probability": input_probability}, verbose=0)
        else:
            prediction = self._pni_segmenter.predict(input_image, verbose=0)
        return prediction[0]

    def _make_pni_seg_model_with_tumor_probability(self, pni_seg_model: tf.keras.Model):
        """Wrap the PNI model so its mask is damped by a per-patch tumor probability.

        The inputs are named ``"image"`` (patch + nerve mask, 4 channels) and
        ``"tumor_probability"`` (shape ``(batch, 1)``), so ``predict`` can be fed a dict.
        """
        image_input = tf.keras.layers.Input(shape=(*self._patch_size, 4), name="image")
        tumor_prob_input = tf.keras.layers.Input(shape=(1,), name="tumor_probability")
        segmentation_output = pni_seg_model(inputs=image_input)
        interpolation_layer = ConditionalInterpolationLayer(name="conditional_interpolation")
        interpolation = interpolation_layer([segmentation_output, tumor_prob_input])
        return tf.keras.models.Model(inputs=[image_input, tumor_prob_input], outputs=interpolation)

In [None]:
TEST_IMAGE_PATH = "/kaggle/input/test-data/test_data/test_image.jpg"

image = Image(TEST_IMAGE_PATH)
processor = Processor(pni_seg_model, nerve_seg_model, tumor_class_model, overlap=0.5)
# process() renders the overlay itself and returns None, so there is nothing to keep.
processor.process(image)

# Grad-CAM

In [None]:
def compute_gradcam(img, gradcam_model):
    """Compute a Grad-CAM heatmap for the first sample of ``img``.

    Args:
        img: Batched input tensor. Only the first batch element is explained.
        gradcam_model: Model returning ``(conv_output, prediction)``, where
            ``conv_output`` is the 4-D activation map to explain.

    Returns:
        ``(heatmap, pred)``. ``heatmap`` is a 2-D tensor scaled to ``[0, 1]``;
        it is all zeros when no location has positive evidence. ``pred`` is the
        raw model output.
    """
    with tf.GradientTape() as tape:
        tape.watch(img)
        conv_output, pred = gradcam_model(img, training=False)
        # Target is the first sample's prediction. For a segmentation model this
        # is a whole mask; the tape differentiates the sum of its elements.
        class_channel = pred[0]

    grads = tape.gradient(class_channel, conv_output)

    # One importance weight per channel: gradients averaged over batch and space.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    conv_output = conv_output[0]

    heatmap = conv_output @ pooled_grads[..., tf.newaxis]
    # Drop only the trailing matmul axis; tf.squeeze could also collapse
    # spatial axes of size 1.
    heatmap = heatmap[..., 0]
    heatmap = tf.maximum(heatmap, 0)
    # Returns 0 instead of NaN when the max is 0 (no positive evidence anywhere).
    heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))

    return (heatmap, pred)

def visualize_gradcam(image, model, layer_name='stage4_unit3_conv3'):
    """Display a Grad-CAM heatmap for ``model`` over ``image``.

    Args:
        image: Single un-batched input ``(height, width, channels)``. Only the
            first three channels are drawn, so 4-channel PNI inputs work too.
        model: Keras model to explain.
        layer_name: Layer whose *input* activations are explained. The default
            presumably names a layer in the ResNet50 encoder of the
            segmentation_models U-Net. Verify it exists for other backbones.
    """
    gradcam_model = tf.keras.Model(
        [model.input],
        [model.get_layer(name=layer_name).input, model.output],
    )
    image_tensor = tf.expand_dims(image, axis=0)
    heat_map, _ = compute_gradcam(image_tensor, gradcam_model)
    # Upsample the coarse heatmap to the image's own resolution for the overlay.
    heat_map = tf.image.resize(heat_map[..., tf.newaxis], tuple(image.shape[:2]))

    plt.figure(figsize=(5, 5))
    plt.imshow(image[:, :, :3])
    plt.imshow(heat_map[..., 0], cmap='jet', alpha=0.3, vmax=1, vmin=0)
    plt.axis("off")
    plt.show()