In [None]:
!pip install -q "/kaggle/input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"
!pip install -q "/kaggle/input/hpapytorchzoozip/pytorch_zoo-master"
!pip install -q "/kaggle/input/hpacellsegmentatormaster/HPA-Cell-Segmentation-master"

In [None]:
import sys
sys.path.insert(0, "../input/hpa-script")
sys.path.insert(0, "../input/timm-pytorch-image-models/pytorch-image-models-master")

In [None]:
import base64
import glob
import typing as t
import zlib

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pycocotools import _mask as coco_mask
import pytorch_lightning as pl

import torch
from torch.utils.data import Dataset, DataLoader

from models import HPAClassifier
from dataset import HPA_RGB_MEAN, HPA_RGB_STD

In [None]:
import hpacellseg.cellsegmentator as cellsegmentator

In [None]:
"""Utility functions for the HPA Cell Segmentation package."""
import scipy.ndimage as ndi
from skimage import filters, measure, segmentation
from skimage.morphology import (binary_erosion, closing, disk,
                                remove_small_holes, remove_small_objects)

# Threshold constants. Nothing visible in this notebook reads them (label_cell
# below uses its own literals).
# NOTE(review): presumably carried over from the upstream segmentation
# utilities -- confirm before removing.
HIGH_THRESHOLD = 0.4
LOW_THRESHOLD = HIGH_THRESHOLD - 0.25

def __fill_holes(image):
    """Fill holes in a labelled image and relabel the result.

    The boundaries between labelled regions are zeroed first, so touching
    regions stay separate. The remaining foreground is then hole-filled as a
    binary mask and labelled again with ``scipy.ndimage.label``.

    Returns the freshly labelled array (0 is background).
    """
    edge_mask = segmentation.find_boundaries(image)
    interior = np.multiply(image, np.invert(edge_mask))
    filled = ndi.binary_fill_holes(interior > 0)
    relabelled, _ = ndi.label(filled)
    return relabelled

def label_cell(nuclei_pred, cell_pred):
    """Label the cells and the nuclei.

    Keyword arguments:
    nuclei_pred -- a 3D numpy array of a prediction from a nuclei image.
    cell_pred -- a 3D numpy array of a prediction from a cell image.

    Returns:
    A tuple containing:
    nuclei-label -- A nuclei mask data array.
    cell-label  -- A cell mask data array.

    0's in the data arrays indicate background while a continuous
    stretch of a specific number indicates the area for a specific
    cell.
    The same value in cell mask and nuclei mask refers to the identical cell.

    NOTE: The nuclei labeling from this function will be slightly
    different from the values in :func:`label_nuclei` as this version
    will use information from the cell-predictions to make better
    estimates.

    The size cutoffs marked "CHECK n/8" differ from the original values noted
    next to them.
    NOTE(review): they look scaled by roughly 1/16 in area to match images
    downscaled to 1/4 per side -- confirm.
    """
    def __wsh(
        mask_img,
        threshold,
        border_img,
        seeds,
        threshold_adjustment=0.35,
        small_object_size_cutoff=10,
    ):
        # Build the watershed markers: seed pixels that are not on a border
        # and exceed the adjusted threshold.
        img_copy = np.copy(mask_img)
        m = seeds * border_img  # * dt
        img_copy[m <= threshold + threshold_adjustment] = 0
        img_copy[m > threshold + threshold_adjustment] = 1
        # builtin bool instead of np.bool: the alias was deprecated in
        # NumPy 1.20 and removed in 1.24.
        img_copy = img_copy.astype(bool)
        img_copy = remove_small_objects(img_copy, small_object_size_cutoff).astype(
            np.uint8
        )

        # Binarise the mask itself and clean it up before the watershed.
        mask_img[mask_img <= threshold] = 0
        mask_img[mask_img > threshold] = 1
        mask_img = mask_img.astype(bool)
        mask_img = remove_small_holes(mask_img, 63) # CHECK 1/8 ORIGINAL VALUE: 1000
        mask_img = remove_small_objects(mask_img, 1).astype(np.uint8) # CHECK 2/8 ORIGINAL VALUE: 8
        markers = ndi.label(img_copy, output=np.uint32)[0]
        labeled_array = segmentation.watershed(
            mask_img, markers, mask=mask_img, watershed_line=True
        )
        return labeled_array

    nuclei_label = __wsh(
        nuclei_pred[..., 2] / 255.0,
        0.4,
        1 - (nuclei_pred[..., 1] + cell_pred[..., 1]) / 255.0 > 0.05,
        nuclei_pred[..., 2] / 255,
        threshold_adjustment=-0.25,
        small_object_size_cutoff=32, # CHECK 3/8 ORIGINAL VALUE: 500
    )

    # for hpa_image, to remove the small pseudo nuclei
    nuclei_label = remove_small_objects(nuclei_label, 157) # CHECK 4/8 ORIGINAL VALUE: 2500
    nuclei_label = measure.label(nuclei_label)
    # this is to remove the cell borders' signal from cell mask.
    # could use np.logical_and with some revision, to replace this func.
    # Tuned for segmentation hpa images
    threshold_value = max(0.22, filters.threshold_otsu(cell_pred[..., 2] / 255) * 0.5)
    # exclude the green area first
    cell_region = np.multiply(
        cell_pred[..., 2] / 255 > threshold_value,
        np.invert(np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8)),
    )
    sk = np.asarray(cell_region, dtype=np.int8)
    distance = np.clip(cell_pred[..., 2], 255 * threshold_value, cell_pred[..., 2])
    cell_label = segmentation.watershed(-distance, nuclei_label, mask=sk)
    # NOTE(review): the uint8 cast wraps label values above 255; kept as in
    # the original -- confirm images never contain that many nuclei.
    cell_label = remove_small_objects(cell_label, 344).astype(np.uint8) # CHECK 5/8 ORIGINAL VALUE: 5500
    selem = disk(2) # CHECK 6/8 ORIGINAL VALUE: 6
    cell_label = closing(cell_label, selem)
    cell_label = __fill_holes(cell_label)
    # this part is to use green channel, and extend cell label to green channel
    # benefit is to exclude cells clear on border but without nucleus
    sk = np.asarray(
        np.add(
            np.asarray(cell_label > 0, dtype=np.int8),
            np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8),
        )
        > 0,
        dtype=np.int8,
    )
    cell_label = segmentation.watershed(-distance, cell_label, mask=sk)
    cell_label = __fill_holes(cell_label)
    cell_label = np.asarray(cell_label > 0, dtype=np.uint8)
    cell_label = measure.label(cell_label)
    cell_label = remove_small_objects(cell_label, 344) # CHECK 7/8 ORIGINAL VALUE: 5500
    cell_label = measure.label(cell_label)
    cell_label = np.asarray(cell_label, dtype=np.uint16)
    # Keep only nuclei that lie inside a surviving cell, then give each
    # nucleus pixel the label of the cell that contains it.
    nuclei_label = np.multiply(cell_label > 0, nuclei_label) > 0
    nuclei_label = measure.label(nuclei_label)
    nuclei_label = remove_small_objects(nuclei_label, 157) # CHECK 8/8 ORIGINAL VALUE: 2500
    nuclei_label = np.multiply(cell_label, nuclei_label > 0)

    return nuclei_label, cell_label

In [None]:
class HPATestDataset(Dataset):
    """Dataset over the four channel images of every sample in a directory.

    A sample is discovered through its ``*_red.png`` file. The green, blue and
    yellow paths are derived by swapping only that trailing channel suffix, so
    a "red" substring elsewhere in the path (e.g. a ``predicted/`` directory)
    is left untouched. The file list is sorted so the sample order is
    reproducible across runs.
    """

    _RED_SUFFIX = "_red.png"

    def __init__(self, dataset_dir):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.make_file_list()

    @staticmethod
    def _read_gray(path):
        """Read ``path`` as a grayscale image, failing loudly if unreadable."""
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"could not read image: {path}")
        return image

    def __getitem__(self, index):
        """Return a dict with the "red", "green", "blue", "yellow" images."""
        sample = {
            "red": self._read_gray(self.red[index]),
            "green": self._read_gray(self.green[index]),
            "blue": self._read_gray(self.blue[index]),
            "yellow": self._read_gray(self.yellow[index]),
        }
        return sample

    def __len__(self):
        return len(self.red)

    def make_file_list(self):
        """Populate ``self.red/green/blue/yellow`` with parallel path lists."""
        suffix = self._RED_SUFFIX
        self.red = sorted(glob.glob(self.dataset_dir + "/" + "*" + suffix))
        # Every globbed path ends with the red suffix; strip exactly that.
        stems = [f[: -len(suffix)] for f in self.red]
        self.green = [stem + "_green.png" for stem in stems]
        self.blue = [stem + "_blue.png" for stem in stems]
        self.yellow = [stem + "_yellow.png" for stem in stems]

In [None]:
class HPACellTestDataset(HPATestDataset):
    """Variant of HPATestDataset that returns segmentation-ready samples.

    Each sample is a dict with:
    "nuc"      -- ``[blue]``
    "cell"     -- ``[[red], [yellow], [blue]]``
    "rgb"      -- red, green and blue stacked along the last axis
    "image_id" -- the red file's name without its channel suffix
    """

    def __getitem__(self, index):
        red = cv2.imread(self.red[index], cv2.IMREAD_GRAYSCALE)
        green = cv2.imread(self.green[index], cv2.IMREAD_GRAYSCALE)
        blue = cv2.imread(self.blue[index], cv2.IMREAD_GRAYSCALE)
        yellow = cv2.imread(self.yellow[index], cv2.IMREAD_GRAYSCALE)

        # File names look like "<id>_red.png". Split on the LAST underscore so
        # an id that itself contains underscores is not truncated.
        file_name = self.red[index].split("/")[-1]
        image_id = file_name.rsplit("_", 1)[0]

        sample = {
            "nuc": [blue],
            "cell": [
                [red], [yellow], [blue]
            ],
            "rgb": np.dstack((red, green, blue)),
            "image_id": image_id,
        }

        return sample

In [None]:
def make_batch(samples):
    """Collate per-sample dicts into one batch dict of plain Python lists.

    "nuc" gathers the first entry of each sample's "nuc" list. "cell" becomes
    three parallel lists, one per entry of each sample's "cell", each holding
    that entry's first element. "rgb" and "image_id" are gathered as-is.
    """
    batch = {"nuc": [], "cell": [[], [], []], "rgb": [], "image_id": []}
    for sample in samples:
        batch["nuc"].append(sample["nuc"][0])
        for channel_index in range(3):
            batch["cell"][channel_index].append(sample["cell"][channel_index][0])
        batch["rgb"].append(sample["rgb"])
        batch["image_id"].append(sample["image_id"])
    return batch

In [None]:
def encode_binary_mask(mask: np.ndarray) -> bytes:
    """Converts a binary mask into OID challenge encoding.

    Args:
        mask: boolean array that is 2-D after ``np.squeeze``.

    Returns:
        The base64-encoded, zlib-compressed COCO RLE counts as ``bytes``
        (ASCII only); callers decode it to ``str`` when they need text.

    Raises:
        ValueError: if ``mask`` is not boolean or is not 2-D after squeezing.
    """

    # check input mask --
    # builtin bool instead of np.bool: the alias was deprecated in NumPy 1.20
    # and removed in 1.24.
    if mask.dtype != bool:
        raise ValueError(
            "encode_binary_mask expects a binary mask, received dtype == %s" %
            mask.dtype)

    mask = np.squeeze(mask)
    if len(mask.shape) != 2:
        # Wrap the shape in a 1-tuple: "%s" % (a, b, c) would raise TypeError
        # ("not all arguments converted") instead of the intended ValueError.
        raise ValueError(
            "encode_binary_mask expects a 2d mask, received shape == %s" %
            (mask.shape,))

    # convert input mask to expected COCO API input --
    mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
    mask_to_encode = mask_to_encode.astype(np.uint8)
    mask_to_encode = np.asfortranarray(mask_to_encode)

    # RLE encode mask --
    encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

    # compress and base64 encoding --
    binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
    base64_str = base64.b64encode(binary_str)
    return base64_str

In [None]:
def get_bbox_from_mask(mask):
    """Return ``[x_min, x_max, y_min, y_max]`` of the truthy pixels in ``mask``.

    x is taken from axis 1 and y from axis 0. Both maxima are the index of the
    last truthy pixel, i.e. they are inclusive bounds.
    """
    coords = np.argwhere(mask)
    lower = coords.min(axis=0)
    upper = coords.max(axis=0)
    return [lower[1], upper[1], lower[0], upper[0]]

In [None]:
def plot_cell_mask_bbox(mask, mask_id):
    """Overlay one labelled cell and its bounding box on a new figure.

    Args:
        mask: labelled mask array; pixels equal to ``mask_id`` form the cell.
        mask_id: label value of the cell to show.
    """
    _, ax = plt.subplots()
    mask = mask==mask_id
    x_min, x_max, y_min, y_max = get_bbox_from_mask(mask)
    bbox_mask = np.zeros(mask.shape)
    # The bbox maxima are inclusive, so the slice end needs +1 to cover the
    # last row and column of the cell.
    bbox_mask[y_min:y_max + 1, x_min:x_max + 1] = 1
    ax.imshow(mask, alpha=0.5)
    ax.imshow(bbox_mask, alpha=0.3)

In [None]:
def plot_single_cell(single_cell_mask, rgb_image):
    """Plot a cell's mask and bounding box over the image, plus its crop.

    The top axes show ``rgb_image`` overlaid with the cell mask and the
    bounding-box mask; the bottom axes show the image cropped to the box.

    Args:
        single_cell_mask: boolean mask of one cell.
        rgb_image: image array indexed as ``[row, column, channel]``.
    """
    _, (ax1, ax2) = plt.subplots(2, figsize=(10, 20))

    x_min, x_max, y_min, y_max = get_bbox_from_mask(single_cell_mask)
    # The bbox maxima are inclusive, so slice ends need +1 to keep the last
    # row and column of the cell.
    rows = slice(y_min, y_max + 1)
    cols = slice(x_min, x_max + 1)

    bbox_mask = np.zeros(single_cell_mask.shape)
    bbox_mask[rows, cols] = 1

    pad_rgb_image = rgb_image[rows, cols, :]

    ax1.imshow(rgb_image, alpha=1.0)
    ax1.imshow(single_cell_mask, alpha=0.5)
    ax1.imshow(bbox_mask, alpha=0.3)

    ax2.imshow(pad_rgb_image)

In [None]:
def inference(image, model, device, height, width):
    """Run the classifier(s) on a single image without test-time augmentation.

    The image is resized to ``height`` x ``width``, normalised with
    ``HPA_RGB_MEAN`` / ``HPA_RGB_STD``, converted to a tensor and given a
    leading batch dimension of 1.

    Args:
        image: image array accepted by the albumentations pipeline.
        model: a single model, or a list of models whose outputs are averaged.
        device: device the input batch is moved to before each forward pass.
        height, width: target size of the resize step.

    Returns:
        The model output with the leading batch dimension squeezed away.
        Prints tensor sizes along the way when the global ``DEBUG`` is truthy.
    """
    preprocess = A.Compose(
        [
            A.Resize(height=height, width=width),
            A.Normalize(mean=HPA_RGB_MEAN, std=HPA_RGB_STD),
            ToTensorV2(),
        ]
    )
    batch = torch.unsqueeze(preprocess(image=image)["image"], 0)

    if DEBUG:
        print(f"clf inference resized image.size(): {batch.size()}")

    if not isinstance(model, list):
        pred = model(batch.to(device))
        if DEBUG: print(f"clf inference pred.size(): {pred.size()}")
    else:
        # Ensemble: stack every member's output, then average over members.
        member_outputs = [member(batch.to(device)) for member in model]
        pred = torch.stack(member_outputs, dim=0)
        if DEBUG: print(f"clf inference pred.size(): {pred.size()}")
        pred = torch.mean(pred, 0)
        if DEBUG: print(f"clf inference pred.size(): {pred.size()}")

    pred = torch.squeeze(pred, 0)
    if DEBUG: print(f"clf inference pred.size(): {pred.size()}")

    return pred

In [None]:
def inference_tta(image, model, device, height, width):
    """Run the classifier(s) on five augmented views of one image and average.

    The views are: plain resize, horizontal flip, vertical flip, a 0.8x resize
    zero-padded back to ``height`` x ``width``, and a 1.2x resize centre-cropped
    to ``height`` x ``width``. Every view is normalised with ``HPA_RGB_MEAN`` /
    ``HPA_RGB_STD`` and the views are stacked into one batch.

    Args:
        image: image array accepted by the albumentations pipelines.
        model: a single model, or a list of models. With a list, each model's
            output is averaged over the views, then averaged over the models.
        device: device the stacked batch is moved to before each forward pass.
        height, width: size of every view fed to the model.

    Returns:
        The output averaged over dimension 0 (views, or models for a list).
        Prints tensor sizes along the way when the global ``DEBUG`` is truthy.
    """

    def _pipeline(resize_height, resize_width, *geometric):
        # Resize -> optional geometric step(s) -> normalise -> tensor.
        return A.Compose(
            [
                A.Resize(height=resize_height, width=resize_width),
                *geometric,
                A.Normalize(mean=HPA_RGB_MEAN, std=HPA_RGB_STD),
                ToTensorV2(),
            ]
        )

    pipelines = [
        _pipeline(height, width),
        _pipeline(height, width, A.HorizontalFlip(p=1.0)),
        _pipeline(height, width, A.VerticalFlip(p=1.0)),
        _pipeline(
            int(height*0.8),
            int(width*0.8),
            A.PadIfNeeded(min_height=height, min_width=width, border_mode=cv2.BORDER_CONSTANT, value=0),
        ),
        _pipeline(
            int(height*1.2),
            int(width*1.2),
            A.CenterCrop(height=height, width=width, p=1.0),
        ),
    ]

    # Apply the pipelines in order: original, hflip, vflip, small, large.
    views = [pipeline(image=image)["image"] for pipeline in pipelines]
    image = torch.stack(views, dim=0)

    if DEBUG:
        print(f"clf inference resized image.size(): {image.size()}")

    if not isinstance(model, list):
        pred = model(image.to(device))
        if DEBUG: print(f"clf inference pred.size(): {pred.size()}")
    else:
        per_model = []
        for member in model:
            member_pred = member(image.to(device))
            if DEBUG: print(f"clf inference pred.size(): {member_pred.size()}")
            # Average this member's output over the augmented views.
            member_pred = torch.mean(member_pred, 0)
            if DEBUG: print(f"clf inference pred.size(): {member_pred.size()}")
            per_model.append(member_pred)

        pred = torch.stack(per_model, dim=0)
        if DEBUG: print(f"clf inference pred.size(): {pred.size()}")

    # Single model: average over views. Ensemble: average over members.
    pred = torch.mean(pred, 0)
    if DEBUG: print(f"clf inference pred.size(): {pred.size()}")

    return pred

In [None]:
def test_inference(test):
    pred = [x * 0.01 for x in range(19)]
    return pred

In [None]:
def get_pred_string(cell_mask, image, clf, CHECK_PLOT=False, clf_image_height=1024, clf_image_width=1024):
    """Classify every labelled cell of one image and build its prediction text.

    For each label ``1 .. cell_mask.max()`` (inclusive) the cell's bounding box
    is cropped from ``image``, zero-padded up to the full image size and passed
    to ``inference_tta`` or ``inference`` depending on the global ``TTA``.
    Every class whose score exceeds the global ``PROB_THR`` contributes one
    ``"<label> <score> <encoded mask>"`` entry.

    Args:
        cell_mask: 2-D labelled mask with the same height and width as
            ``image``; 0 is background.
        image: image array indexed as ``[row, column, channel]``.
        clf: model or list of models, forwarded to the inference function.
        CHECK_PLOT: when truthy, plot the first processed cell only.
        clf_image_height, clf_image_width: classifier input size.

    Returns:
        ``(pred_string, pred_string_check)``: the space-joined prediction
        entries, and a shorter human-readable version without the masks
        (one line per cell, ending in ``[mask_id]``).

    Also reads the globals ``device`` and ``DEBUG``.
    """
    num_mask = int(cell_mask.max())
    pred_string = []
    pred_string_check = []

    # Pad target comes from the image itself rather than from notebook globals
    # that only exist once the main loop cell has run. The transform does not
    # depend on the cell, so it is built once.
    full_height, full_width = image.shape[:2]
    pad_to_image = A.PadIfNeeded(
        min_height=full_height,
        min_width=full_width,
        border_mode=cv2.BORDER_CONSTANT,
        value=0)

    # Labels run from 1 to num_mask inclusive; range() excludes its end, so +1
    # is needed or the last cell would never be classified.
    for mask_id in range(1, num_mask + 1):
        # single cell mask
        single_cell_mask = cell_mask==mask_id
        if not single_cell_mask.any():
            # Label value absent from the mask: nothing to crop or encode.
            continue
        x_min, x_max, y_min, y_max = get_bbox_from_mask(single_cell_mask)

        # single cell rgb image; bbox maxima are inclusive, hence +1 (this
        # also keeps one-pixel-wide cells from producing an empty crop)
        cell_image = image[y_min:y_max + 1, x_min:x_max + 1, :]

        # pad cell rgb image
        pad_cell_image = pad_to_image(image=cell_image)["image"]

        # check single cell
        if CHECK_PLOT:
            plot_single_cell(single_cell_mask, image)
            CHECK_PLOT = False

        # inference
        if TTA:
            pred = inference_tta(pad_cell_image, clf, device, height=clf_image_height, width=clf_image_width)
        else:
            pred = inference(pad_cell_image, clf, device, height=clf_image_height, width=clf_image_width)

        if DEBUG:
            print(f"single_cell_mask.shape: {single_cell_mask.shape}")

        encoded_mask = encode_binary_mask(single_cell_mask)
        encoded_mask = encoded_mask.decode("utf-8")

        for label in range(19):
            # Plain Python float: formats as a bare number whether pred holds
            # tensors or floats.
            conf = float(pred[label])
            if conf > PROB_THR:
                pred_string.append(f"{label} {conf} {encoded_mask}")
                pred_string_check.append(f"{label} {conf:.2f} /")
        pred_string_check.append(f"[{mask_id}]\n")

    pred_string = " ".join(pred_string)
    pred_string_check = " ".join(pred_string_check).replace("\n ", "\n")
    return pred_string, pred_string_check

In [None]:
# ----------
# debug mode
# ----------
# When True: the segmentation batch size drops to 4, shape/size diagnostics are
# printed, one cell per image is plotted, and the main loop stops after the
# first batch.
DEBUG = False

In [None]:
# ----------
# settings
# ----------
# Seed the random number generators through PyTorch Lightning.
pl.seed_everything(0)

# Directory holding the test channel images.
TEST_DIR = "../input/hpa-single-cell-image-classification/test"

# Images per segmentation batch (smaller while debugging).
seg_batch_size = 4 if DEBUG else 16

# Side length of the classifier input.
CLF_IMAGE_SIZE = 1024

# A class is written to the prediction string only when its score exceeds this.
PROB_THR = 0.001

# Use test-time augmentation (inference_tta) instead of plain inference.
TTA = True

In [None]:
# ----------
# checkpoint
# ----------
# checkpoints[i] is loaded with the architecture named in model_names[i];
# the two lists must stay the same length and in the same order.
checkpoints = [
    "../input/hpa-checkpoint-final/b0-gce-full-base-aug-bce-fold-2_HPA-279_checkpoints_hpa-clf-epoch004-valid_loss0.036291.ckpt",
    "../input/hpa-checkpoint-final/b0-gce-full-base-aug-bce-fold-0_HPA-281_checkpoints_hpa-clf-epoch004-valid_loss0.036503.ckpt",
    "../input/hpa-checkpoint-final/b0-extra-rare-aug-bce-loss_HPA-267_checkpoints_hpa-clf-epoch009-valid_loss0.030838.ckpt",
    "../input/hpa-checkpoint-final/seresnext26d_32x4d-full-base-aug-focal-fold-1_HPA-282_checkpoints_last.ckpt",
]

model_names = [
    "tf_efficientnet_b0",
    "tf_efficientnet_b0",
    "tf_efficientnet_b0",
    "seresnext26d_32x4d",
]

In [None]:
# ----------
# device
# ----------
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"device {device}")

# Inference only: switch autograd off globally for the rest of the notebook.
torch.set_grad_enabled(False)

In [None]:
# ----------
# TestDataset
# ----------
ds_cell = HPACellTestDataset(TEST_DIR)

# Samples are lists/arrays of varying content, so make_batch collates them
# into plain Python lists. Order is kept and loading runs in the main process.
loader = DataLoader(
    dataset=ds_cell,
    collate_fn=make_batch,
    batch_size=seg_batch_size,
    num_workers=0,
    shuffle=False,
    pin_memory=False,
)

In [None]:
# ----------
# Model SEG
# ----------
NUC_MODEL = "../input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth"
CELL_MODEL = "../input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth"

# scale_factor=1.0: the main loop downsizes the images itself before calling
# the segmentator.
segmentator = cellsegmentator.CellSegmentator(
    NUC_MODEL,
    CELL_MODEL,
    device="cuda",
    multi_channel_model=True,
    scale_factor=1.0,
    padding=True,
)

# Put both segmentation networks into evaluation mode.
for _model_attr in ("nuclei_model", "cell_model"):
    setattr(segmentator, _model_attr, getattr(segmentator, _model_attr).eval())

In [None]:
# ----------
# Model CLF
# ----------
def _load_classifier(checkpoint, model_name):
    """Load one classifier checkpoint, move it to `device`, set eval mode."""
    model = HPAClassifier.load_from_checkpoint(checkpoint, pretrained=False, model_name=model_name)
    model.to(device)
    model.eval()
    return model


# One classifier per (checkpoint, architecture) pair, in checkpoint order.
clf_list = [
    _load_classifier(checkpoint, model_name)
    for checkpoint, model_name in zip(checkpoints, model_names)
]

In [None]:
# Per-image results, filled in parallel by the loop below.
image_id_list = []
image_width_list = []
image_height_list = []
pred_string_list = []
pred_string_check_list = []


def _resize_to_quarter(images):
    """Downscale every 2-D image to 1/4 of its width and height."""
    # cv2.resize takes dsize as (width, height), i.e. (shape[1], shape[0]).
    return [
        cv2.resize(
            x,
            (int(x.shape[1] / 4), int(x.shape[0] / 4)),
            interpolation=cv2.INTER_AREA,
        )
        for x in images
    ]


for index, batch in enumerate(loader):

    batch_nuc = batch["nuc"]
    batch_mt, batch_er, batch_nu = batch["cell"]

    # get original image size as (height, width)
    batch_image_size = [x.shape for x in batch_nuc]

    # resize to 1/4
    batch_nuc = _resize_to_quarter(batch_nuc)
    batch_mt = _resize_to_quarter(batch_mt)
    batch_er = _resize_to_quarter(batch_er)
    batch_nu = _resize_to_quarter(batch_nu)
    batch_cell = [batch_mt, batch_er, batch_nu]

    # run segmentation
    nuc_segmentations = segmentator.pred_nuclei(batch_nuc)
    cell_segmentations = segmentator.pred_cells(batch_cell)

    # label_cell returns (nuclei_label, cell_label); only the cell mask is used
    batch_cell_mask = [
        label_cell(nuc_seg, cell_seg)[1]
        for nuc_seg, cell_seg in zip(nuc_segmentations, cell_segmentations)
    ]

    if DEBUG:
        for cell_mask in batch_cell_mask:
            print(f"cell_mask.shape: {cell_mask.shape}")

    # resize cell_mask to original image size; nearest-neighbour keeps label
    # values intact, and dsize must be (width, height), not the (h, w) shape
    batch_cell_mask = [
        cv2.resize(
            cell_mask,
            (image_size[1], image_size[0]),
            interpolation=cv2.INTER_NEAREST,
        )
        for (cell_mask, image_size)
        in zip(batch_cell_mask, batch_image_size)
    ]
    if DEBUG:
        for cell_mask in batch_cell_mask:
            print(f"cell_mask.shape: {cell_mask.shape}")

    # single sample from batch
    for cell_mask, image, image_id in zip(batch_cell_mask, batch["rgb"], batch["image_id"]):

        image_height = image.shape[0]
        image_width = image.shape[1]

        if DEBUG:
            print(f"image.shape: {image.shape}, cell_mask.shape: {cell_mask.shape}")

        # plot one example cell per image only while debugging
        CHECK_PLOT = bool(DEBUG)
        pred_string, pred_string_check = get_pred_string(
            cell_mask,
            image,
            clf_list,
            clf_image_height=CLF_IMAGE_SIZE,
            clf_image_width=CLF_IMAGE_SIZE,
            CHECK_PLOT=CHECK_PLOT
        )

        image_id_list.append(image_id)
        image_width_list.append(image_width)
        image_height_list.append(image_height)
        pred_string_list.append(pred_string)
        pred_string_check_list.append(pred_string_check)

    # debug runs stop after the first batch
    if DEBUG:
        break

In [None]:
# Column name -> per-image values, in output column order.
_result_columns = [
    ("ID", image_id_list),
    ("ImageWidth", image_width_list),
    ("ImageHeight", image_height_list),
    ("PredictionString", pred_string_list),
    ("PredictionStringCheck", pred_string_check_list),
]

# One row per processed image.
df = pd.DataFrame(
    data=zip(*(values for _, values in _result_columns)),
    columns=[name for name, _ in _result_columns],
)

In [None]:
# Preview the first rows of the full results table.
df.head()

In [None]:
# Human-readable per-cell predictions (no masks) for the first image.
print(df.PredictionStringCheck[0])

In [None]:
# Same for the second image; requires at least two processed images.
print(df.PredictionStringCheck[1])

In [None]:
# Keep only the columns written to submission.csv (drops the check column).
sub_df = df[["ID", "ImageWidth", "ImageHeight", "PredictionString"]]

In [None]:
# Preview the submission table.
sub_df.head()

In [None]:
# Write the submission file without the DataFrame index column.
sub_df.to_csv("submission.csv", index=False)

In [None]:
# Companion table holding the mask-free, human-readable prediction strings.
check_df = df[["ID", "ImageWidth", "ImageHeight", "PredictionStringCheck"]]

In [None]:
# Preview the check table.
check_df.head()

In [None]:
# Write the check table next to the submission, without the index column.
check_df.to_csv("check.csv", index=False)