<a href="https://colab.research.google.com/github/scancer-org/ml-pcam-classification/blob/main/notebooks/10_WSI_Tumour_Prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Step Plan

- [x] Setup all the requirements for the code
- [x] Mount a Google Drive that contains symbolic links to the CAMELYON16 dataset
- [x] Load a sample file
- [x] Compute the sliding-window grid and padding parameters
- [x] Pad the WSI
- [x] Segment the tissue (and create a mask)
- [x] Load up the PCam model from the stored weights
- [x] Predict tumour regions
- [x] Colourise the regions and save
- [x] Size tensors so the notebook doesn't crash by running out of memory


In [1]:
# Install the ASAP toolkit, whose Python bindings (multiresolutionimageinterface)
# are imported from /opt/ASAP/bin in the next cell.
# -nc makes re-runs skip the download when the .deb is already present.
!wget -nc -q https://github.com/computationalpathologygroup/ASAP/releases/download/1.9/ASAP-1.9-Linux-Ubuntu1804.deb
!sudo apt-get -qq -y install ./ASAP-1.9-Linux-Ubuntu1804.deb

In [2]:
import argparse
import gc
import os
import sys
from xml.etree import ElementTree as ET

import matplotlib
import matplotlib.cm as cm
import numpy as np
import torch
from google.colab import drive
from PIL import Image, ImageDraw
from skimage.color import rgb2hsv
from skimage.filters import gaussian
from skimage.transform import resize

# ASAP's Python bindings live outside site-packages, so extend the search path
# before importing them.
sys.path.append(r'/opt/ASAP/bin')
try:
    import multiresolutionimageinterface as mir
except ImportError:
    print("ASAP package not installed.")

In [3]:
# Whole-slide image to analyse. The commented-out path is an alternative
# (a training-set tumour slide) kept for quick switching.
# WSI_FILE = "/content/drive/MyDrive/FSDL Project/CAMELYON16/training/tumor/tumor_029.tif"
WSI_FILE = "/content/drive/MyDrive/FSDL Project/CAMELYON16/testing/images/test_001.tif"
# Pyramid level read by TIFReader; TIFReader.load_image requires >= 2.
WSI_LEVEL = 2 # Magnification 0=40x, 1=20x, 2=10x, ...
# TorchScript export of the PCam patch classifier (loaded with torch.jit.load).
PCAM_MODEL = "/content/drive/MyDrive/FSDL Project/PCam/pcam_cnn_v1.2.pt"
# Directory the colourised tumour map is saved to.
OUTPUT_PATH = "/content/drive/MyDrive/FSDL Project/"
# (height, width) of each sliding window, in WSI_LEVEL pixels -- presumably
# chosen to match the 96x96 patches the PCam model was trained on.
WINDOW = (96, 96)

In [4]:
# Mount Google Drive so the /content/drive paths in WSI_FILE, PCAM_MODEL and OUTPUT_PATH resolve.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
class TIFReader:
    """Read one pyramid level of a whole-slide image through ASAP.

    All coordinates and sizes accepted by this class are expressed in pixels
    of the selected ``level`` (level 0 is the highest magnification).
    """

    def __init__(self, file, level):
        """Open ``file`` and remember which pyramid ``level`` to read.

        Args:
            file: Path to the slide file.
            level: Pyramid level index used by every read.

        Raises:
            OSError: if ASAP could not open ``file``.
        """
        self.file = file
        self.reader = mir.MultiResolutionImageReader()
        self.mr_image = self.reader.open(self.file)
        # NOTE(review): ASAP's reader appears to return None when the file
        # cannot be opened; fail here with a clear message instead of an
        # AttributeError on the first method call.
        if self.mr_image is None:
            raise OSError(f"Could not open whole-slide image: {self.file}")
        self.level = level

    def get_shape(self):
        """Return the ``(width, height)`` (X, Y) of the selected level in pixels."""
        return self.mr_image.getLevelDimensions(self.level)

    def load_patch(self, x, y, width, height):
        """Read a ``width`` x ``height`` patch whose top-left corner is ``(x, y)``.

        ``x``/``y`` are level-space coordinates; they are scaled by this
        level's downsample factor before being passed to ``getUCharPatch``
        (which takes its start position in level-0 coordinates), while
        ``width``/``height`` are passed through unchanged.

        Returns:
            The array produced by ``getUCharPatch`` (presumably
            ``(height, width, channels)`` uint8 -- confirm against ASAP).
        """
        downsample = self.mr_image.getLevelDownsample(self.level)
        return self.mr_image.getUCharPatch(
            int(x * downsample), int(y * downsample), width, height, self.level
        )

    def load_image(self, full=False):
        """Load the selected level into memory.

        Args:
            full: When False (the default, preserving the original behaviour)
                only the bottom-left quadrant of the level is read -- a
                workaround presumably kept to limit memory use. When True the
                whole level is read.

        Raises:
            AssertionError: if ``self.level`` is below 2.
        """
        assert self.level >= 2, "load_image requires level >= 2"
        width, height = self.get_shape()
        if full:
            return self.load_patch(0, 0, width, height)
        # TODO: Remove this hack -- reads only the bottom-left quadrant.
        return self.load_patch(0, height // 2, width // 2, height // 2)

    @staticmethod
    def segment_tissue(image, downsample=16):
        """Return a per-pixel saturation map of ``image`` for tissue detection.

        The image is subsampled by ``downsample`` along both spatial axes
        (for speed), converted to HSV, and the saturation channel is resized
        back to ``image``'s height and width. Background (white) pixels have
        near-zero saturation, so the result can be thresholded to find tissue.

        Args:
            image: Array indexed ``[row, col, channel]``; assumed RGB.
            downsample: Subsampling step used before the HSV conversion.
        """
        small = image[::downsample, ::downsample, :].copy()
        hsv = rgb2hsv(small)
        return resize(hsv[:, :, 1], image.shape[:2], mode='constant', cval=0, anti_aliasing=False)

In [6]:
def _window_padding(length, window, step):
    """Pixels to append along one axis so windows of size ``window`` at ``step`` fit exactly."""
    if length <= window:
        # Pad a short non-empty axis up to one window; an empty axis stays empty.
        return (window - length) % window
    # Smallest pad that makes (length + pad - window) a multiple of step.
    return (-(length - window)) % step


def sliding_window(image_shape, window_shape, stride=None):
    """Compute the padding and the grid of windows that tile an image.

    Args:
        image_shape: Image shape; ``image_shape[0]`` is the height (rows) and
            ``image_shape[1]`` the width (columns). Trailing dims are ignored.
        window_shape: ``(height, width)`` of each window.
        stride: ``(row_step, col_step)`` between window origins. Defaults to
            ``window_shape`` (non-overlapping tiles).

    Returns:
        ``(windows, padding)``:

        - ``windows`` has shape ``(rows, cols, 4)``, holding ``[x1, y1, x2, y2]``
          per window (x = column, y = row, end-exclusive).
        - ``padding`` is ``{'x': pad_right, 'y': pad_bottom}``, the pixels to
          append so that every window lies fully inside the padded image.
    """
    if stride is None:
        stride = (window_shape[0], window_shape[1])

    padding_y = _window_padding(image_shape[0], window_shape[0], stride[0])
    padding_x = _window_padding(image_shape[1], window_shape[1], stride[1])
    padded_height = image_shape[0] + padding_y
    padded_width = image_shape[1] + padding_x

    # The last valid origin is padded_size - window, so windows never run past
    # the padded image even when stride < window (overlapping windows).
    xs = np.arange(0, padded_width - window_shape[1] + 1, stride[1])
    ys = np.arange(0, padded_height - window_shape[0] + 1, stride[0])

    x1, y1 = np.meshgrid(xs, ys)
    x2 = x1 + window_shape[1]
    y2 = y1 + window_shape[0]

    return np.stack([x1, y1, x2, y2], axis=2), {'x': padding_x, 'y': padding_y}

In [7]:
def predict_tumor_regions(wsi, tissue_mask, windows, tissue_threshold=0.075, model_path=None):
    """Score every tissue-bearing window of ``wsi`` with the PCam classifier.

    Args:
        wsi: Image array indexed ``[row, col, channel]`` (presumably uint8),
            large enough that every window in ``windows`` lies inside it.
        tissue_mask: Per-pixel tissue score with ``wsi``'s height and width.
        windows: ``(rows, cols, 4)`` array of ``[x1, y1, x2, y2]`` boxes, as
            returned by ``sliding_window``.
        tissue_threshold: A window is classified only if the mean of
            ``tissue_mask`` over it exceeds this value; other windows stay 0.
        model_path: TorchScript model file; defaults to ``PCAM_MODEL``.

    Returns:
        Float array with ``wsi``'s height and width. Each classified window is
        filled with its sigmoid output, and the map is then Gaussian-smoothed.
    """
    if model_path is None:
        model_path = PCAM_MODEL

    # Prefer the GPU but fall back to CPU so this still runs without one.
    # map_location puts the TorchScript weights on the same device as the
    # inputs; without it they stay on whichever device they were saved from.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.jit.load(model_path, map_location=device)
    model.eval()

    # One value per pixel; pixels in skipped windows stay 0.
    tumor = np.zeros(wsi.shape[:2])

    # Inference only: skipping autograd bookkeeping saves memory and time.
    with torch.no_grad():
        for i in range(windows.shape[0]):
            for j in range(windows.shape[1]):
                # x indexes columns, y indexes rows.
                x1, y1, x2, y2 = windows[i, j, :].reshape(-1)

                if tissue_mask[y1:y2, x1:x2].mean() > tissue_threshold:
                    # (H, W, C) patch -> batch of one: (1, H, W, C).
                    patch = np.expand_dims(wsi[y1:y2, x1:x2, :].copy(), axis=0)

                    # Transfer the raw pixels first (cheaper than float32 when
                    # wsi is uint8), then reorder to (1, C, H, W) and scale to [0, 1].
                    patch_tensor = torch.from_numpy(patch).to(device)
                    patch_tensor = patch_tensor.permute(0, 3, 1, 2).float() / 255.

                    tumor[y1:y2, x1:x2] = torch.sigmoid(model(patch_tensor)).squeeze().item()

    # Smooth the blocky per-window map.
    return gaussian(tumor, preserve_range=True)

In [8]:
# Open the slide; every read below comes from pyramid level WSI_LEVEL.
reader = TIFReader(file=WSI_FILE, level=WSI_LEVEL)

In [9]:
# Read the slide at WSI_LEVEL. NOTE: TIFReader.load_image currently returns only
# the bottom-left quadrant of the level (see the TODO in load_image).
wsi = reader.load_image()

In [10]:
# Window grid for the image, plus the bottom/right padding it requires.
windows, padding = sliding_window(image_shape=wsi.shape, window_shape=WINDOW)

In [11]:
# Pad the bottom and right edges with white (255) so every window fits inside the image.
wsi_padded = np.pad(
    wsi,
    pad_width=((0, padding['y']), (0, padding['x']), (0, 0)),
    mode='constant',
    constant_values=255,
)

In [12]:
# The unpadded copy is no longer needed; only wsi_padded is used from here on.
del wsi
# Trailing ';' keeps the collected-object count out of the cell output.
gc.collect();

417

In [13]:
# segment_tissue is a staticmethod, so call it on the class; the result is a
# saturation-based tissue score with wsi_padded's height and width.
tissue_mask = TIFReader.segment_tissue(wsi_padded)

In [14]:
# Classify each tissue-bearing window; the result is a smoothed per-pixel map of sigmoid scores.
tumor_map = predict_tumor_regions(wsi_padded, tissue_mask=tissue_mask, windows=windows)

In [15]:
# Release the large intermediates now that tumor_map has been computed.
del reader, wsi_padded, tissue_mask, windows
# Trailing ';' keeps the collected-object count out of the cell output.
gc.collect();

50

In [16]:
# Optional (disabled): also persist the raw probability map as a .npy file.
# NOTE(review): the file name below appears left over from another slide
# (normal_116) and does not match WSI_FILE; update it before enabling.
# np.save(OUTPUT_PATH + '/normal_116_2.npy', tumor_map)

In [17]:
# Map the scores (clipped to [0, 1]) through the 'plasma' colormap; the colormap
# returns RGBA floats, which are scaled to uint8 and saved as a PNG.
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9, so look the colormap up
# in the matplotlib.colormaps registry instead.
cmapper = matplotlib.colormaps['plasma']
colorized = Image.fromarray(np.uint8(cmapper(np.clip(tumor_map, 0, 1)) * 255))
# os.path.join avoids the doubled '/' produced by concatenating onto OUTPUT_PATH.
colorized.save(os.path.join(OUTPUT_PATH, 'test_001_4.png'))

In [18]:
# The PNG has been written; drop the in-memory image.
del colorized
# Trailing ';' keeps the collected-object count out of the cell output.
gc.collect();

50