In [None]:
# Install third-party packages from local wheels / source trees under ../input.
# efficientnet_pytorch, pycocotools and hpacellseg are imported in later cells.
! pip install "../input/keras-application/Keras_Applications-1.0.8-py3-none-any.whl"
! pip install "../input/efficientnet111/efficientnet-1.1.1-py3-none-any.whl"
! pip install "../input/efficientnetpytorch"
! pip install "../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"
! pip install "../input/hpapytorchzoozip/pytorch_zoo-master"
! pip install "../input/hpacellsegmentatormaster/HPA-Cell-Segmentation-master"
! pip install "../input/tfexplainforoffline/tf_explain-0.2.1-py3-none-any.whl"

In [None]:
# NOTE(review): not referenced in this section; presumably toggles a final/submission run — confirm.
IS_FINAL = True

In [None]:
import numpy as np 
import pandas as pd
import os, gc, cv2

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as nnf
import sklearn
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import gc

from torch.autograd import Variable
from tqdm.notebook import tqdm

from efficientnet_pytorch import EfficientNet

import zlib
import base64
from pycocotools import _mask as coco_mask
import random
import tensorflow as tf


    
import os.path
import urllib
import zipfile

import scipy.ndimage as ndi
from skimage import filters, measure, segmentation
from skimage.morphology import (binary_erosion, closing, disk,
                                remove_small_holes, remove_small_objects)

In [None]:
# --- Training hyper-parameters -------------------------------------------
NUM_CL = 19      # presumably the number of label classes
BATCH = 16
EPOCHS = 15
LR = 1e-4
IM_SIZE = 256

# First GPU when one is visible, otherwise the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# --- Input locations (Kaggle layout) ---------------------------------------
PATH = '/kaggle/input/'
TRAIN_DIR = f"{PATH}segmented-train/"
# NOTE(review): TEST_DIR is not defined here, but read_test_image /
# read_test_image_seg read it; define it before calling those.
# TEST_DIR = PATH + 'test/'


In [None]:
import numpy as np
import pandas as pd
import os
from tqdm import tqdm

from hpacellseg.cellsegmentator import *


class CellSegmentator(object):
    """Uses pretrained DPN-Unet models to segment cells from images.

    This definition shadows the ``CellSegmentator`` made available by
    ``from hpacellseg.cellsegmentator import *``; its default weight paths
    point at a local ``../input`` dataset. The helper names it uses
    (``imageio``, ``transform``, ``util``, ``F``, ``NORMALIZE``, ``sys`` and
    the ``*_URL`` constants) are presumably provided by that star import —
    TODO confirm.
    """

    def __init__(
        self,
        nuclei_model="../input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth",
        cell_model="../input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth",
        scale_factor=1.0,
        device="cuda",
        padding=False,
        multi_channel_model=True,
    ):
        """Load the nuclei and cell segmentation models.

        It takes lists of images and returns the raw output from the
        specified segmentation model. The outputs are meant to be combined
        with ``label_cell``.

        Keyword arguments:
        nuclei_model -- a loaded torch nuclei segmentation model, or a path
                        to a file containing one. A path that does not exist
                        triggers a download of the pretrained model to it.
        cell_model -- a loaded torch cell segmentation model, a path to one
                      (downloaded if missing), or None when only nuclei are
                      to be segmented.
        scale_factor -- factor applied to images before inference;
                        predictions are resized back to the input size
                        (default: 1.0, i.e. no rescaling).
        device -- 'cpu', 'cuda' or a specific device such as 'cuda:0'
                  (default: 'cuda'). Falls back to 'cpu' without a GPU.
        padding -- reflect-pad images before inference so both spatial
                   dims become multiples of 32 (default: False).
        multi_channel_model -- True for the 3-channel (microtubule, ER,
                               nuclei) cell model, False for the 2-channel
                               one (default: True).

        Raises:
        ValueError -- if ``device`` is not a cpu/cuda device string.
        """
        if device != "cuda" and device != "cpu" and "cuda" not in device:
            raise ValueError(f"{device} is not a valid device (cuda/cpu)")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if device != "cpu" and not torch.cuda.is_available():
            print("No GPU found, using CPU.", file=sys.stderr)
            device = "cpu"
        self.device = device
        self.multi_channel_model = multi_channel_model
        self.scale_factor = scale_factor
        self.padding = padding

        # URLs are resolved lazily (lambdas) so the *_URL names are only
        # needed when a download actually happens.
        self.nuclei_model = self._load_model(nuclei_model, lambda: NUCLEI_MODEL_URL)
        if cell_model is None:
            # Documented as allowed: nuclei-only segmentation.
            self.cell_model = None
        else:
            self.cell_model = self._load_model(
                cell_model,
                lambda: MULTI_CHANNEL_CELL_MODEL_URL
                if self.multi_channel_model
                else TWO_CHANNEL_CELL_MODEL_URL,
            )

    def _load_model(self, model, get_url):
        """Return `model` as an eval-mode module on self.device.

        `model` may be a loaded module or a path; a missing path is first
        downloaded from `get_url()`.
        """
        if isinstance(model, str):
            if not os.path.exists(model):
                print(f"Could not find {model}. Downloading it now", file=sys.stderr)
                download_with_url(get_url(), model)
            model = torch.load(model, map_location=torch.device(self.device))
        # A DataParallel wrapper pinned to CUDA devices cannot run on CPU.
        if isinstance(model, torch.nn.DataParallel) and self.device == "cpu":
            model = model.module
        return model.to(self.device).eval()

    def _image_conversion(self, images):
        """Convert/format images into a list of RGB arrays for cell predictions.

        Intended for internal use only.

        Keyword arguments:
        images -- list of lists of image paths/arrays, following the pattern
                  [[microtubule...], [er...], [nuclei...]] with an ER channel,
                  or [[microtubule...], None, [nuclei...]] without one.

        Returns:
        list of (H, W, 3) arrays stacked as (microtubule, er, nuclei); a
        missing ER channel is replaced by zeros.
        """
        microtubule_imgs, er_imgs, nuclei_imgs = images
        if self.multi_channel_model:
            if not isinstance(er_imgs, list):
                raise ValueError("Please speicify the image path(s) for er channels!")
        elif er_imgs is not None:
            raise ValueError(
                "second channel should be None for two channel model predition!"
            )

        if not isinstance(microtubule_imgs, list):
            raise ValueError("The microtubule images should be a list")
        if not isinstance(nuclei_imgs, list):
            raise ValueError("The nuclei images should be a list")

        if er_imgs:
            if not len(microtubule_imgs) == len(er_imgs) == len(nuclei_imgs):
                raise ValueError("The lists of images needs to be the same length")
        elif len(microtubule_imgs) != len(nuclei_imgs):
            raise ValueError("The lists of images needs to be the same length")

        if not all(isinstance(item, np.ndarray) for item in microtubule_imgs):
            # Paths were given: expand "~" and read them from disk.
            microtubule_imgs = [imageio.imread(os.path.expanduser(p)) for p in microtubule_imgs]
            nuclei_imgs = [imageio.imread(os.path.expanduser(p)) for p in nuclei_imgs]
            if er_imgs:
                er_imgs = [imageio.imread(os.path.expanduser(p)) for p in er_imgs]

        if not er_imgs:
            er_imgs = [np.zeros(item.shape, dtype=item.dtype) for item in microtubule_imgs]
        return [
            np.dstack((microtubule, er, nuclei))
            for microtubule, er, nuclei in zip(microtubule_imgs, er_imgs, nuclei_imgs)
        ]

    def _pad(self, image):
        """Reflect-pad `image` when self.padding is set.

        Adds a 32px margin on top/left and pads bottom/right so both spatial
        dims become multiples of 32. Returns (image, (rows, cols) before
        padding), or (image, None) when padding is disabled.
        """
        if not self.padding:
            return image, None
        rows, cols = image.shape[:2]
        self.scaled_shape = rows, cols  # kept for backward compatibility
        padded = cv2.copyMakeBorder(
            image, 32, (32 - rows % 32), 32, (32 - cols % 32), cv2.BORDER_REFLECT
        )
        return padded, (rows, cols)

    def _run_model(self, model, preprocessed_imgs, batch_size=24):
        """Normalize, run `model` in batches and softmax over dim 1.

        Returns a list with one numpy array per preprocessed image
        (an empty list for empty input).
        """
        if not preprocessed_imgs:
            return []
        mean = torch.as_tensor(NORMALIZE["mean"], device=self.device)[:, None, None]
        std = torch.as_tensor(NORMALIZE["std"], device=self.device)[:, None, None]
        outputs = []
        with torch.no_grad():
            for start in range(0, len(preprocessed_imgs), batch_size):
                batch = np.stack(preprocessed_imgs[start:start + batch_size])
                batch = torch.from_numpy(batch).float().to(self.device)
                batch = batch.sub_(mean).div_(std)
                outputs.append(F.softmax(model(batch), dim=1).cpu().numpy())
        return list(np.concatenate(outputs, axis=0))

    def pred_nuclei(self, images):
        """Predict the nuclei segmentation.

        Keyword arguments:
        images -- A list of image arrays or a list of paths to images.
                  Arrays may be 2d nuclei images, or must carry the nuclei
                  data in the blue (index 2) channel; paths may point to RGB
                  or grayscale nuclei images.

        Returns:
        predictions -- one uint8, channels-last prediction per input image.
        """

        def _preprocess(image):
            if isinstance(image, str):
                image = imageio.imread(image)
            target_shape = image.shape
            self.target_shape = target_shape  # kept for backward compatibility
            if len(image.shape) == 2:
                image = np.dstack((image, image, image))
            image = transform.rescale(image, self.scale_factor, multichannel=True)
            nuc_image = np.dstack((image[..., 2], image[..., 2], image[..., 2]))
            nuc_image, scaled_shape = self._pad(nuc_image)
            return nuc_image.transpose([2, 0, 1]), target_shape, scaled_shape

        prepared = [_preprocess(image) for image in images]
        raw = self._run_model(self.nuclei_model, [item[0] for item in prepared])
        # Restore each prediction with its own image's shapes.
        return [
            self._restore_scaling_padding(util.img_as_ubyte(pred), target_shape, scaled_shape)
            for pred, (_, target_shape, scaled_shape) in zip(raw, prepared)
        ]

    def _restore_scaling_padding(self, n_prediction, target_shape=None, scaled_shape=None):
        """Undo padding and rescaling on one channels-first model output.

        Intended for internal use; returns a channels-last array.

        Keyword arguments:
        n_prediction -- one model output, channels first.
        target_shape -- shape of the original input image, used only when
                        scale_factor != 1 (default: self.target_shape).
        scaled_shape -- (rows, cols) before padding, used only when padding
                        (default: self.scaled_shape).
        """
        n_prediction = n_prediction.transpose([1, 2, 0])
        if self.padding:
            rows, cols = self.scaled_shape if scaled_shape is None else scaled_shape
            n_prediction = n_prediction[32:32 + rows, 32:32 + cols, ...]
        if self.scale_factor != 1:
            shape = self.target_shape if target_shape is None else target_shape
            n_prediction[..., 0] = 0
            # cv2.resize takes dsize as (width, height).
            n_prediction = cv2.resize(
                n_prediction,
                (shape[1], shape[0]),
                interpolation=cv2.INTER_AREA,
            )
        return n_prediction

    def pred_cells(self, images, precombined=False):
        """Predict the cell segmentation for a list of images.

        Keyword arguments:
        images -- list of lists of image paths/arrays:
                  [[microtubule...], [er...], [nuclei...]] with an ER channel,
                  or [[microtubule...], None, [nuclei...]] without one.
                  The ER channel is required when multi_channel_model is True
                  and must be None when it is False. Images must share a size.
        precombined -- If True, `images` is instead a list of RGB numpy
                       arrays (default: False).

        Returns:
        predictions -- one uint8, channels-last prediction per image.

        Raises:
        ValueError -- if no cell model was loaded, or an image is not 3-d.
        """
        if self.cell_model is None:
            raise ValueError("No cell model loaded; pass cell_model to the constructor.")

        def _preprocess(image):
            self.target_shape = image.shape  # kept for backward compatibility
            if not len(image.shape) == 3:
                raise ValueError("image should has 3 channels")
            cell_image = transform.rescale(image, self.scale_factor, multichannel=True)
            cell_image, scaled_shape = self._pad(cell_image)
            return cell_image.transpose([2, 0, 1]), image.shape, scaled_shape

        if not precombined:
            images = self._image_conversion(images)
        prepared = [_preprocess(image) for image in images]
        raw = self._run_model(self.cell_model, [item[0] for item in prepared])
        return [
            util.img_as_ubyte(self._restore_scaling_padding(pred, target_shape, scaled_shape))
            for pred, (_, target_shape, scaled_shape) in zip(raw, prepared)
        ]
    


# Segmentation thresholds; neither is referenced in the visible cells.
HIGH_THRESHOLD = 0.4
_THRESHOLD_GAP = 0.25  # distance between the high and low thresholds
LOW_THRESHOLD = HIGH_THRESHOLD - _THRESHOLD_GAP


def download_with_url(url_string, file_path, unzip=False):
    """Download ``url_string`` to ``file_path``, optionally extracting it.

    Arguments:
    url_string -- URL to fetch (anything ``urllib.request.urlopen`` accepts).
    file_path -- destination file; overwritten if it already exists.
    unzip -- if True, extract the downloaded zip archive into the
             directory containing ``file_path``.
    """
    # `import urllib` alone does not guarantee that the `urllib.request`
    # submodule is loaded, so import it explicitly here.
    import shutil
    import urllib.request

    with urllib.request.urlopen(url_string) as response, open(file_path, "wb") as out_file:
        # Stream in chunks rather than holding the whole (possibly large)
        # model file in memory.
        shutil.copyfileobj(response, out_file)

    if unzip:
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            zip_ref.extractall(os.path.dirname(file_path))


def __fill_holes(image):
    """Fill holes in a labelled image and relabel the result.

    Boundaries between labels are zeroed first; the remaining foreground is
    hole-filled and each connected component gets its own new label.
    """
    edges = segmentation.find_boundaries(image)
    without_edges = np.multiply(image, np.invert(edges))
    filled, _num_features = ndi.label(ndi.binary_fill_holes(without_edges > 0))
    return filled





def label_cell(nuclei_pred, cell_pred):
    """Label the cells and the nuclei.

    Keyword arguments:
    nuclei_pred -- a 3D numpy array of a prediction from a nuclei image.
                   Channels 1 and 2 are used; values are divided by 255,
                   so a 0-255 scale is assumed.
    cell_pred -- a 3D numpy array of a prediction from a cell image
                 (channels 1 and 2 are used, same scale assumption).

    Returns:
    A tuple containing:
    nuclei-label -- A nuclei mask data array.
    cell-label  -- A cell mask data array (uint16).

    0's in the data arrays indicate background while a continuous stretch
    of a specific number indicates the area for a specific cell. The same
    value in cell mask and nuclei mask refers to the identical cell.

    NOTE: The nuclei labeling from this function will be slightly
    different from the values in :func:`label_nuclei` as this version
    uses information from the cell predictions to make better estimates.
    """
    def __wsh(
        mask_img,
        threshold,
        border_img,
        seeds,
        threshold_adjustment=0.35,
        small_object_size_cutoff=10,
    ):
        """Marker-based watershed of `mask_img`.

        Markers are the connected components of `seeds * border_img` above
        `threshold + threshold_adjustment` (small ones removed); flooding is
        restricted to `mask_img > threshold` with small holes filled.
        """
        img_copy = np.copy(mask_img)
        m = seeds * border_img  # * dt
        img_copy[m <= threshold + threshold_adjustment] = 0
        img_copy[m > threshold + threshold_adjustment] = 1
        # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
        img_copy = img_copy.astype(bool)
        img_copy = remove_small_objects(img_copy, small_object_size_cutoff).astype(
            np.uint8
        )

        mask_img[mask_img <= threshold] = 0
        mask_img[mask_img > threshold] = 1
        mask_img = mask_img.astype(bool)
        mask_img = remove_small_holes(mask_img, 63)
        mask_img = remove_small_objects(mask_img, 1).astype(np.uint8)
        markers = ndi.label(img_copy, output=np.uint32)[0]
        labeled_array = segmentation.watershed(
            mask_img, markers, mask=mask_img, watershed_line=True
        )
        return labeled_array

    nuclei_label = __wsh(
        nuclei_pred[..., 2] / 255.0,
        0.4,
        1 - (nuclei_pred[..., 1] + cell_pred[..., 1]) / 255.0 > 0.05,
        nuclei_pred[..., 2] / 255,
        threshold_adjustment=-0.25,
        small_object_size_cutoff=32,
    )

    # for hpa_image, to remove the small pseduo nuclei
    nuclei_label = remove_small_objects(nuclei_label, 157)
    nuclei_label = measure.label(nuclei_label)
    # this is to remove the cell borders' signal from cell mask.
    # could use np.logical_and with some revision, to replace this func.
    # Tuned for segmentation hpa images
    threshold_value = max(0.22, filters.threshold_otsu(cell_pred[..., 2] / 255) * 0.5)
    # exclude the green area first
    # NOTE(review): np.invert on an int8 0/1 array yields -1/-2 (both
    # non-zero), so this product is non-zero wherever the first condition
    # holds and the green-area exclusion has no effect. Left unchanged to
    # preserve the existing segmentation output; confirm before fixing.
    cell_region = np.multiply(
        cell_pred[..., 2] / 255 > threshold_value,
        np.invert(np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8)),
    )
    sk = np.asarray(cell_region, dtype=np.int8)
    distance = np.clip(cell_pred[..., 2], 255 * threshold_value, cell_pred[..., 2])
    cell_label = segmentation.watershed(-distance, nuclei_label, mask=sk)
    cell_label = remove_small_objects(cell_label, 344).astype(np.uint8)
    selem = disk(2)
    cell_label = closing(cell_label, selem)
    cell_label = __fill_holes(cell_label)
    # this part is to use green channel, and extend cell label to green channel
    # benefit is to exclude cells clear on border but without nucleus
    sk = np.asarray(
        np.add(
            np.asarray(cell_label > 0, dtype=np.int8),
            np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8),
        )
        > 0,
        dtype=np.int8,
    )
    cell_label = segmentation.watershed(-distance, cell_label, mask=sk)
    cell_label = __fill_holes(cell_label)
    cell_label = np.asarray(cell_label > 0, dtype=np.uint8)
    cell_label = measure.label(cell_label)
    cell_label = remove_small_objects(cell_label, 344)
    cell_label = measure.label(cell_label)
    cell_label = np.asarray(cell_label, dtype=np.uint16)
    nuclei_label = np.multiply(cell_label > 0, nuclei_label) > 0
    nuclei_label = measure.label(nuclei_label)
    nuclei_label = remove_small_objects(nuclei_label, 157)
    nuclei_label = np.multiply(cell_label, nuclei_label > 0)

    return nuclei_label, cell_label

In [None]:
# Read one training sample from disk.
def read_sample_image(filename):

    '''
    Read the individual R, G, B, Y channel images of a training sample
    and stack them along the last axis.
    ---------------------------------
    Arguments:
    filename -- sample ID; files are read from
                TRAIN_DIR + filename + "_<color>.png"

    Returns:
    stacked_images -- stacked (RGBY) image; (H, W, 4) assuming each PNG
                      is single-channel

    Raises:
    FileNotFoundError -- if a channel cannot be read (cv2.imread returns
                         None instead of raising)
    '''
    channels = []
    for color in ("red", "green", "blue", "yellow"):
        path = TRAIN_DIR + filename + "_" + color + ".png"
        channel = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if channel is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        channels.append(channel)

    stacked_images = np.transpose(np.array(channels), (1, 2, 0))
    return stacked_images

def read_test_image(filename):
    '''
    Read the individual R, G, B, Y channel images of a test sample and
    stack them along the last axis.
    ---------------------------------
    Arguments:
    filename -- sample ID; files are read from
                TEST_DIR + filename + "_<color>.png"
                NOTE(review): TEST_DIR must be defined at module level
                before calling; the config cell leaves it commented out.

    Returns:
    stacked_images -- stacked (RGBY) image; (H, W, 4) assuming each PNG
                      is single-channel

    Raises:
    FileNotFoundError -- if a channel cannot be read (cv2.imread returns
                         None instead of raising)
    '''
    channels = []
    for color in ("red", "green", "blue", "yellow"):
        path = TEST_DIR + filename + "_" + color + ".png"
        channel = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if channel is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        channels.append(channel)

    stacked_images = np.transpose(np.array(channels), (1, 2, 0))
    return stacked_images


def plot_all(im, label):
    '''
    Show the composite of the first three channels next to each of the
    four channels (R, G, B, Y) drawn with its own colormap.
    --------------------------
    Argument:
    im - (H, W, 4) stacked image
    label - unused
    '''
    panels = (
        (im[:, :, :3], None, 'RGBY Image'),
        (im[:, :, 0], 'Reds', 'Microtubule channels'),
        (im[:, :, 1], 'Greens', 'Protein of Interest'),
        (im[:, :, 2], 'Blues', 'Nucleus'),
        (im[:, :, 3], 'Oranges', 'Endoplasmic Reticulum'),
    )

    plt.figure(figsize=(15, 15))
    for position, (data, colormap, caption) in enumerate(panels, start=1):
        plt.subplot(1, 5, position)
        plt.imshow(data, cmap=colormap)
        plt.title(caption)
        plt.axis('off')
    plt.show()
    
# Build the per-channel file paths of a training sample for segmentation.
def read_sample_image_seg(filename):

    '''
    Collect the per-channel PNG paths of a training sample.
    ---------------------------------
    Arguments:
    filename -- sample ID; paths are TRAIN_DIR + filename + "_<color>.png"

    Returns:
    stacked_images -- [[red], [yellow], [blue]] path lists, i.e. the
                      (microtubule, ER, nuclei) layout that
                      CellSegmentator.pred_cells expects
    red, blue, yellow, green -- the individual channel paths
    '''
    prefix = TRAIN_DIR + filename
    red, blue, yellow, green = (
        prefix + "_" + color + ".png" for color in ("red", "blue", "yellow", "green")
    )
    return [[red], [yellow], [blue]], red, blue, yellow, green

def read_test_image_seg(filename):

    '''
    Collect the per-channel PNG paths of a test sample.
    ---------------------------------
    Arguments:
    filename -- sample ID; paths are TEST_DIR + filename + "_<color>.png"

    Returns:
    stacked_images -- [[red], [yellow], [blue]] path lists, i.e. the
                      (microtubule, ER, nuclei) layout that
                      CellSegmentator.pred_cells expects
    red, blue, yellow, green -- the individual channel paths
    '''
    prefix = TEST_DIR + filename
    red, blue, yellow, green = (
        prefix + "_" + color + ".png" for color in ("red", "blue", "yellow", "green")
    )
    return [[red], [yellow], [blue]], red, blue, yellow, green


# Segment cells from channel image paths.
def segmentCell(image, segmentator):

    '''
    Segment cells and nuclei from the microtubule, endoplasmic
    reticulum and nuclei (R, Y, B) channels.
    ------------------------------------
    Argument:
    image -- [[red], [yellow], [blue]] lists of image paths/arrays
             (microtubule, ER, nuclei), e.g. as returned by
             read_sample_image_seg
    segmentator -- CellSegmentator class object

    Returns:
    cell_mask -- segmented cell mask of the first image
    '''
    nuc_segmentations = segmentator.pred_nuclei(image[2])
    cell_segmentations = segmentator.pred_cells(image)
    nuclei_mask, cell_mask = label_cell(nuc_segmentations[0], cell_segmentations[0])

    # Drop the references first so gc.collect() can actually reclaim them.
    del nuc_segmentations, cell_segmentations, nuclei_mask
    gc.collect()

    return cell_mask

def faster_segmentCell(blue, image, segmentator):

    '''
    Segment cells and nuclei from one nuclei image and one precombined
    (microtubule, ER, nuclei) RGB array.
    ------------------------------------
    Argument:
    blue -- nuclei image (array or path) passed to pred_nuclei
    image -- precombined RGB array passed to pred_cells(precombined=True)
    segmentator -- CellSegmentator class object

    Returns:
    cell_mask -- segmented cell mask
    nuclei_mask -- segmented nuclei mask
    '''
    nuc_segmentations = segmentator.pred_nuclei([blue])
    cell_segmentations = segmentator.pred_cells([image], precombined=True)
    nuclei_mask, cell_mask = label_cell(nuc_segmentations[0], cell_segmentations[0])

    # Drop the references first so gc.collect() can actually reclaim them.
    del nuc_segmentations, cell_segmentations
    gc.collect()

    return cell_mask, nuclei_mask


# plot segmented cells mask, image

def plot_cell_segments(mask, r, b, y):

    '''
    Show the (red, blue, yellow) composite, the cell mask, and the mask
    overlaid on the composite side by side.
    ---------------------
    Arguments:
    mask -- cell mask
    r -- red filter image path
    b -- blue filter image path
    y -- yellow filter image path
    '''
    composite = np.dstack(tuple(plt.imread(path) for path in (r, b, y)))

    plt.figure(figsize=(15, 15))
    panels = (
        ('Image', (composite,)),
        ('Mask', (mask,)),
        ('Image + Mask', (composite, mask)),
    )
    for position, (caption, layers) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(layers[0])
        if len(layers) > 1:
            plt.imshow(layers[1], alpha=0.6)
        plt.title(caption)
        plt.axis('off')
    plt.show()

# plot single segmented cells mask, image
def plot_single_cell(mask, red, blue, yellow):

    '''
    Plot the largest segmented cell: its image crop, mask crop and overlay.
    ---------------------
    Arguments:
    mask -- cell mask (0 = background)
    red -- red filter image path
    blue -- blue filter image path
    yellow -- yellow filter image path

    Raises:
    ValueError -- if the mask has no foreground
    '''
    # Previously the body read undefined names r/b/y instead of the parameters.
    microtubule = plt.imread(red)
    nuclei = plt.imread(blue)
    endoplasmicrec = plt.imread(yellow)
    img = np.dstack((microtubule, nuclei, endoplasmicrec))

    # Binarise before casting: astype('uint8') alone wraps labels modulo 256.
    # [0] picks the contour list from OpenCV 4's (contours, hierarchy) result.
    contours = cv2.findContours((mask > 0).astype('uint8'),
                                cv2.RETR_TREE,
                                cv2.CHAIN_APPROX_SIMPLE)[0]
    if len(contours) == 0:
        raise ValueError("mask contains no foreground to plot")

    largest = max(contours, key=cv2.contourArea)
    x, yc, w, h = cv2.boundingRect(largest)
    img_crop = img[yc:yc + h, x:x + w]
    mask_crop = mask[yc:yc + h, x:x + w]

    plt.figure(figsize=(15, 15))
    plt.subplot(1, 3, 1)
    plt.imshow(img_crop)
    plt.title('Cell Image')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(mask_crop)
    plt.title('Cell Mask')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.imshow(img_crop)
    plt.imshow(mask_crop, alpha=0.6)
    plt.title('Cell Image + Mask')
    plt.axis('off')
    plt.show()
    
# return all segmented cell images in an original image
def images_single_cell(mask, red, blue, yellow, green):

    '''
    Crop every segmented cell out of the stacked channel images.
    ---------------------
    Arguments:
    mask -- cell mask (0 = background)
    red -- red filter image path
    blue -- blue filter image path
    yellow -- yellow filter image path
    green -- green filter image path

    Returns:
    images -- list of crops, one per contour found in the mask, with
              channels in (red, blue, yellow, green) order
    '''
    microtubule = plt.imread(red)
    nuclei = plt.imread(blue)
    endoplasmicrec = plt.imread(yellow)
    protein = plt.imread(green)
    img = np.dstack((microtubule, nuclei, endoplasmicrec, protein))

    # Binarise before casting: astype('uint8') alone wraps label values
    # modulo 256, so e.g. label 256 would become background and be lost.
    # NOTE(review): touching cells form one connected foreground region and
    # thus yield a single crop; RETR_TREE also returns any inner contours.
    contours = cv2.findContours((mask > 0).astype('uint8'),
                                cv2.RETR_TREE,
                                cv2.CHAIN_APPROX_SIMPLE)[0]

    images = list()
    for cnt in contours:
        x, yc, w, h = cv2.boundingRect(cnt)
        images.append(img[yc:yc + h, x:x + w])

    del contours, microtubule, endoplasmicrec, nuclei, protein, img
    gc.collect()

    return images

# return all segmented cell images in an original image
def faster_images_single_cell(mask, red, blue, yellow, green):

    '''
    Crop every segmented cell out of already-loaded channel arrays.
    ---------------------
    Arguments:
    mask -- cell mask (0 = background)
    red -- red channel array
    blue -- blue channel array
    yellow -- yellow channel array
    green -- green channel array

    Returns:
    images -- list of crops, one per contour found in the mask, with
              channels in (red, blue, yellow, green) order
    '''
    img = np.dstack((red, blue, yellow, green))

    # Binarise before casting: astype('uint8') alone wraps label values
    # modulo 256, so e.g. label 256 would become background and be lost.
    # NOTE(review): touching cells form one connected foreground region and
    # thus yield a single crop; RETR_TREE also returns any inner contours.
    contours = cv2.findContours((mask > 0).astype('uint8'),
                                cv2.RETR_TREE,
                                cv2.CHAIN_APPROX_SIMPLE)[0]

    images = list()
    for cnt in contours:
        x, yc, w, h = cv2.boundingRect(cnt)
        images.append(img[yc:yc + h, x:x + w])

    del contours, img
    gc.collect()

    return images


# Extend the size of an image by padding
def add_margin(image, size):
    '''
    Zero-pad an image so its spatial size becomes size x size, centred.
    Height * Width * Channel -> size * size * Channel
    Any odd leftover pixel goes to the bottom / right side.
    '''
    height, width, _channels = image.shape

    def _split(extra):
        before = extra // 2
        return before, extra - before

    pad_spec = [_split(size - height), _split(size - width), (0, 0)]
    return np.pad(image, pad_spec, 'constant')

# Make the image square by padding
def resize_to_square(image):
    '''
    Pad an image into a square whose side is its longer edge.
    Height * Width * Channel -> max(H, W) * max(H, W) * Channel
    '''
    height, width, _channels = image.shape
    return add_margin(image, max(height, width))


def binary_mask(rgby_images):
    '''
    Placeholder: intended to generate binary masks from RGBY cell images.
    --------------------
    Arguments:
    rgby_images -- RGBY cell images

    Return:
    None -- not implemented yet.
    '''
    return None



In [None]:
# Dataset class for cell-level classification
# For training dataset
class GetData_single_cell(Dataset):
    """Dataset of pre-cropped single-cell images for cell-level classification.

    Sample ``ID`` is assembled from four ``.npy`` files,
    ``path + "<color>/" + str(ID).zfill(6) + "_<color>.npy"``, for the
    colors red, blue, yellow and green (stacked in that order).
    """

    # Channel order of the stacked image.
    _CHANNELS = ("red", "blue", "yellow", "green")

    def __init__(self, path, list_IDs, df_labels, img_size, Transform='None'):
        """
        Arguments:
        path -- root directory string (ending in a separator) holding the
                red/, blue/, yellow/ and green/ sub-directories
        list_IDs -- sequence of sample IDs; dataset index i maps to list_IDs[i]
        df_labels -- object whose ``.loc[ID]`` gives the label vector
                     (presumably a DataFrame indexed by ID — TODO confirm)
        img_size -- side length of the square output image
        Transform -- stored as ``self.transform``.
                     NOTE(review): never applied in __getitem__.
        """
        self.path = path
        self.list_IDs = list_IDs
        self.labels = df_labels
        self.img_size = img_size
        self.transform = Transform

    def __len__(self):
        return len(self.list_IDs)

    def _load_channel(self, ID, color):
        """Load one channel of sample ``ID`` from ``<path><color>/<ID:06>_<color>.npy``."""
        return np.load(self.path + color + "/" + str(ID).zfill(6) + "_" + color + ".npy")

    def __getitem__(self, index):
        """Return ``(X, y)`` for the sample at ``index``.

        X -- channels-first (4, img_size, img_size) array divided by 255,
             padded to square before resizing
        y -- float32 tensor of labels
        """
        ID = self.list_IDs[index]
        img = np.dstack([self._load_channel(ID, color) for color in self._CHANNELS])

        img = resize_to_square(img)
        img = cv2.resize(img, (self.img_size, self.img_size))
        X = np.transpose(img / 255., (2, 0, 1))

        # Convert explicitly so torch receives a plain float32 array whatever
        # `.loc` returns (Series, list or scalar).
        y = np.asarray(self.labels.loc[ID], dtype=np.float32)
        return X, torch.tensor(y, dtype=torch.float)


In [None]:

def encode_binary_mask(mask, mask_val):
    """Encode the pixels of `mask` equal to `mask_val` as OID-challenge ASCII text.

    The selected pixels are run-length encoded with the COCO API. The RLE
    counts are then zlib-compressed and the result is base64 encoded.

    Args:
        mask: label image; after squeezing it must be 2-D.
        mask_val: label value whose pixels form the binary mask.

    Returns:
        ASCII string holding the encoded mask.

    Raises:
        ValueError: if the squeezed mask is not 2-D.
    """
    # Comparing against mask_val always yields a boolean array. The former dtype
    # check used `np.bool`, an alias removed in NumPy 1.24, and could never fail.
    binary = np.squeeze(np.asarray(mask) == mask_val)
    if binary.ndim != 2:
        # Wrap the shape in a tuple: `"%s" % shape` raises TypeError for 3-D shapes.
        raise ValueError(
            "encode_binary_mask expects a 2d mask, received shape == %s" % (binary.shape,))

    # convert to the COCO API input: Fortran-ordered (H, W, 1) uint8
    mask_to_encode = np.asfortranarray(
        binary.reshape(binary.shape[0], binary.shape[1], 1).astype(np.uint8))

    # RLE encode mask
    encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

    # compress and base64 encode
    binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
    return base64.b64encode(binary_str).decode()


def is_border_nuclei(contour_points, image_shape=None):
    """Return True if a contour touches the image border.

    Args:
        contour_points: array of (x, y) contour coordinates, e.g. one contour
            returned by cv2.findContours.
        image_shape: optional (height, width, ...) of the source image. Without
            it, only the top/left border (a coordinate equal to 0) can be
            detected, which is the original behaviour. With it, the bottom/right
            borders (x == width - 1 or y == height - 1) are checked as well.

    Returns:
        bool
    """
    points = np.asarray(contour_points)
    if np.any(points == 0):
        return True
    if image_shape is None:
        return False
    height, width = image_shape[:2]
    xy = points.reshape(-1, 2)
    return bool(np.any(xy[:, 0] >= width - 1) or np.any(xy[:, 1] >= height - 1))

def clean_nuclei_mask_vals(nuclei_mask):
    """Return the labels of nuclei that do not touch any image border.

    Args:
        nuclei_mask: 2-D integer label image; label 0 is treated as background.

    Returns:
        List of label values (in np.unique order, i.e. ascending). None of
        their pixels lie in the first/last row or the first/last column.
    """
    nuclei_list = []
    for nucleus in np.unique(nuclei_mask):
        if nucleus == 0:
            continue  # background is not a nucleus
        # individual nucleus mask
        nucleus_mask = nuclei_mask == nucleus
        # Check all four borders directly on the pixels. The previous contour-based
        # test looked only at the first contour and only caught coordinate 0, so
        # nuclei on the bottom/right edges were kept.
        touches_border = (
            nucleus_mask[0, :].any() or nucleus_mask[-1, :].any()
            or nucleus_mask[:, 0].any() or nucleus_mask[:, -1].any()
        )
        if not touches_border:
            nuclei_list.append(nucleus)
    return nuclei_list

def decode_img(img, img_size=(224,224), testing=False):
    """Decode an encoded PNG into a single-channel uint8 TensorFlow tensor.

    Args:
        img: encoded PNG contents, e.g. the result of tf.io.read_file.
        img_size: (height, width) to resize to; ignored when testing is True.
        testing: if True, return the decoded image at its native size.

    Returns:
        (H, W, 1) uint8 tensor, resized to img_size unless testing is True.
    """
    decoded = tf.image.decode_png(img, channels=1)
    if testing:
        return decoded
    # tf.image.resize produces floats, so cast back to uint8
    return tf.cast(tf.image.resize(decoded, img_size), tf.uint8)
    

def load_image(img_id, img_dir, testing=False):
    """Load the four channel PNGs of one image and stack them.

    Files are read as ``<img_dir>/<img_id>_<colour>.png`` in the channel order
    red, green, blue, yellow.

    Args:
        img_id: image identifier (string).
        img_dir: directory holding the per-channel PNG files.
        testing: selects both the loader and the output layout (see Returns).

    Returns:
        testing=False: (H, W, 4) uint8 array (channels last), read via `Image.open`.
        testing=True:  (4, H, W) uint8 array (channels first), decoded with
            TensorFlow; this is the layout the cell segmentator is fed with.
    """
    channel_paths = [
        os.path.join(img_dir, img_id + f"_{colour}.png")
        for colour in ["red", "green", "blue", "yellow"]
    ]
    if testing:
        # decode_img(..., testing=True) yields (H, W, 1); keep just the 2-D plane
        planes = [
            np.asarray(decode_img(tf.io.read_file(p), testing=True), np.uint8)[..., 0]
            for p in channel_paths
        ]
        return np.stack(planes, axis=0)
    # NOTE(review): `Image` (presumably PIL.Image) is not imported in the visible
    # cells — confirm it is in scope before using testing=False.
    planes = [np.asarray(Image.open(p), np.uint8) for p in channel_paths]
    return np.stack(planes, axis=-1)

def get_contour_bbox_from_rle(rle, width, height, return_mask=True,):
    """Bounding box of the first external contour of an RLE-encoded mask.

    Args:
        rle (str): space-separated run-length encoding of the mask.
        width (int): width of the image the mask comes from.
        height (int): height of the image the mask comes from.
        return_mask (bool): also return the decoded mask.

    Returns:
        ``((xmin, ymin, xmax, ymax), mask)`` when return_mask is True, otherwise
        only the ``(xmin, ymin, xmax, ymax)`` tuple. Only the first contour
        reported by cv2.findContours is used.
    """
    mask = rle_to_mask(rle, height, width).copy()
    contours = grab_contours(
        cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE))
    x, y, w, h = cv2.boundingRect(contours[0])
    bbox = (x, y, x + w, y + h)
    if not return_mask:
        return bbox
    return bbox, mask

    
def rgby_reshape(fourchannel):
    """Keep the first three planes of a channels-last 4-channel image.

    The input is transposed from (H, W, 4) to (4, H, W), the same convention as
    `for_predictions`, and planes 0, 1, 2 are re-stacked channels-last. The
    names r/b/y assume the (red, blue, yellow, green) order used by
    GetData_single_cell — NOTE(review): confirm for other callers.

    Args:
        fourchannel: array of shape (H, W, 4).

    Returns:
        Array of shape (H, W, 3) holding input planes 0, 1 and 2.
    """
    # The original body read an undefined name `im`, so every call raised
    # NameError unless a global `im` happened to exist.
    im = np.asarray(fourchannel).transpose(2, 0, 1)
    r, b, y = im[0], im[1], im[2]
    return np.stack((r, b, y), axis=2)

def create_pred_col(row):
    """Choose the prediction string for one row of a merged submission frame.

    After ``submission_df.merge(final_df, ...)`` each row carries
    ``PredictionString_x`` (left frame) and ``PredictionString_y`` (right frame).
    The right-hand value is used when present. A null there falls back to the
    left-hand value.

    Args:
        row (pd.Series): a row of the merged dataframe.

    Returns:
        The selected prediction string.
    """
    right_value = row.PredictionString_y
    if not pd.isnull(right_value):
        return right_value
    return row.PredictionString_x

def pad_to_square(a):
    """Zero-pad an array evenly along its shorter spatial axis until it is square.

    The first two axes are treated as (height, width); any trailing axes
    (e.g. channels) are left untouched. This works for 2-D masks as well as
    (H, W, C) images; the previous version only accepted 3-D arrays. When the
    sides differ by an odd amount, the extra row/column goes to the
    bottom/right.

    Args:
        a: array with ndim >= 2.

    Returns:
        The padded array, or `a` itself if it is already square.
    """
    height, width = a.shape[:2]
    if height == width:
        return a
    n_to_add = abs(height - width)
    before = n_to_add // 2
    after = n_to_add - before
    pad_width = [(0, 0)] * a.ndim
    # wider than tall -> pad rows (height); taller than wide -> pad columns (width)
    pad_width[0 if width > height else 1] = (before, after)
    return np.pad(a, pad_width, mode='constant')

def flatten_list_of_lists(l_o_l, to_string=False):
    """Flatten one level of nesting, optionally converting every item to str.

    Args:
        l_o_l: iterable of iterables.
        to_string: if True, each item is passed through str().

    Returns:
        A flat list of the items in iteration order.
    """
    flat = []
    for sublist in l_o_l:
        for item in sublist:
            flat.append(str(item) if to_string else item)
    return flat

def get_contour_bbox_from_raw(raw_mask):
    """Bounding boxes of all external contours in a raw mask.

    Args:
        raw_mask (np.ndarray): single-channel mask accepted by cv2.findContours.

    Returns:
        List of ``(xmin, ymin, xmax, ymax)`` tuples sorted by (ymin, xmin).
    """
    contours = grab_contours(
        cv2.findContours(raw_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE))
    boxes = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        boxes.append((x, y, x + w, y + h))
    # top-to-bottom, then left-to-right
    boxes.sort(key=lambda box: (box[1], box[0]))
    return boxes

def grab_contours(cnts):
    """Extract the contour list from a cv2.findContours result on any OpenCV version.

    A 2-tuple (OpenCV 2.4, v4-beta, v4-official) holds the contours first. A
    3-tuple (OpenCV 3, v4-pre, v4-alpha) holds them in the middle.

    Raises:
        Exception: if the tuple has neither 2 nor 3 elements.
    """
    n_items = len(cnts)
    if n_items not in (2, 3):
        raise Exception(("Contours tuple must have length 2 or 3, "
            "otherwise OpenCV changed their cv2.findContours return "
            "signature yet again. Refer to OpenCV's documentation "
            "in that case"))
    return cnts[0] if n_items == 2 else cnts[1]

def get_torchd(tile, DEVICE):
    """Turn channels-last image tiles into a float tensor batch on DEVICE.

    Args:
        tile: sequence of N equally-shaped (H, W, C) arrays, or an (N, H, W, C) array.
        DEVICE: torch device the batch is moved to.

    Returns:
        float32 tensor of shape (N, C, H, W).
    """
    batch = np.array(tile)
    # (N, H, W, C) -> (N, C, H, W). The previous `swapaxes(1, 3)` produced
    # (N, C, W, H), so every tile was spatially transposed compared with the
    # (C, H, W) layout the training dataset builds via transpose(2, 0, 1).
    batch = np.ascontiguousarray(batch.transpose(0, 3, 1, 2))
    return torch.from_numpy(batch).to(DEVICE).float()

def rle_to_mask(rle_string, height, width):
    """Decode a run-length-encoded string into a binary (0/255) mask.

    The RLE is a space-separated list of ``start length`` pairs with 1-based
    starts, counted over the image flattened column by column.

    Args:
        rle_string (str): run-length encoding of the mask.
        height (int): height of the image the mask comes from.
        width (int): width of the image the mask comes from.

    Returns:
        uint8 array of shape (height, width) with 255 inside the runs, 0 elsewhere.
    """
    numbers = [int(token) for token in rle_string.split(' ')]
    pairs = np.array(numbers).reshape(-1, 2)
    flat = np.zeros(height * width, dtype=np.uint8)
    for start, run in pairs:
        begin = start - 1  # RLE starts are 1-based
        flat[begin:begin + run] = 255
    # column-major layout: fill (width, height), then transpose
    return flat.reshape(width, height).T



def for_predictions(fourchannel):
    """Build a 3-channel image from input planes 3, 1 and 2 of a channels-last image.

    (H, W, 4) -> (H, W, 3). Under the naming used here the output channels are
    (green, yellow, blue); input plane 0 is dropped.
    """
    planes = fourchannel.transpose(2, 0, 1)
    green, yellow, blue = planes[3], planes[1], planes[2]
    return np.stack((green, yellow, blue), axis=2)

In [None]:
import seaborn as sns

import plotly.express as px

LBL_NAMES = [
    "Nucleoplasm", "Nuclear Membrane", "Nucleoli", "Nucleoli Fibrillar Center", "Nuclear Speckles", "Nuclear Bodies",
    "Endoplasmic Reticulum", "Golgi Apparatus", "Intermediate Filaments", "Actin Filaments", "Microtubules", "Mitotic Spindle",
    "Centrosome", "Plasma Membrane", "Mitochondria", "Aggresome", "Cytosol", "Vesicles", "Negative",
]

# class id <-> name lookups (the ids are numpy integers from np.arange)
INT_2_STR = dict(zip(np.arange(len(LBL_NAMES)), LBL_NAMES))
INT_2_STR_LOWER = {idx: name.lower().replace(" ", "_") for idx, name in INT_2_STR.items()}
STR_2_INT_LOWER = {name: idx for idx, name in INT_2_STR_LOWER.items()}
STR_2_INT = {name: idx for idx, name in INT_2_STR.items()}
FIG_FONT = dict(family="Helvetica, Arial", size=14, color="#000000")
# one plotly "rgb(...)" colour string per class, from seaborn's Spectral palette
LABEL_COLORS = [
    px.colors.label_rgb(px.colors.convert_to_RGB_255(rgb))
    for rgb in sns.color_palette("Spectral", len(LBL_NAMES))
]
LABEL_COL_MAP = {str(idx): colour for idx, colour in enumerate(LABEL_COLORS)}

def plot_predictions(img, masks, preds, confs=None, fill_alpha=0.3, lbl_as_str=True):
    """Draw cell contours, translucent fills and class labels on `img` and show it.

    Args:
        img: image to annotate. It is copied and never modified in place.
        masks: single-channel mask passed to cv2.findContours.
        preds: one sequence of integer class ids per contour, ordered like the
            contours after sorting by bounding-box (y, x). The first id picks the colour.
        confs: optional per-contour confidences. They are only paired with the
            contours and are not drawn.
        fill_alpha: weight of the filled-contour layer in the final blend.
        lbl_as_str: draw class names (True) or "CLS=[...]" id lists (False).
    """
    FONT = cv2.FONT_HERSHEY_SIMPLEX
    FONT_SCALE = 1.5
    FONT_THICKNESS = 2
    FONT_LINE_TYPE = cv2.LINE_AA
    COLORS = [[round(y * 255) for y in x] for x in sns.color_palette("Spectral", len(LBL_NAMES))]
    to_plot = img.copy()
    cntr_img = img.copy()

    cnts = grab_contours(
        cv2.findContours(
            masks,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        ))
    cnts = sorted(cnts, key=lambda x: (cv2.boundingRect(x)[1], cv2.boundingRect(x)[0]))

    # `confs == None` is elementwise (and then ambiguous) for numpy arrays; use `is`.
    # Size the placeholder by the contour count; len(masks) is the mask's row count.
    if confs is None:
        confs = [None] * len(cnts)

    for c, pred, conf in zip(cnts, preds, confs):
        # Only one colour can be displayed, so use the first class
        color = COLORS[pred[0]]
        if not lbl_as_str:
            classes = "CLS=[" + ",".join([str(p) for p in pred]) + "]"
        else:
            classes = ", ".join([INT_2_STR[p] for p in pred])

        M = cv2.moments(c)
        if M['m00'] == 0:
            # Zero-area contour: the centroid is undefined, so use the bbox centre
            x, y, w, h = cv2.boundingRect(c)
            cx, cy = x + w // 2, y + h // 2
        else:
            cx = int(M['m10'] / M['m00'])
            cy = int(M['m01'] / M['m00'])

        text_width, text_height = cv2.getTextSize(classes, FONT, FONT_SCALE, FONT_THICKNESS)[0]

        # Border (darker) and fill
        cv2.drawContours(to_plot, [c], contourIdx=-1, color=[max(0, x - 40) for x in color], thickness=10)
        cv2.drawContours(cntr_img, [c], contourIdx=-1, color=(color), thickness=-1)

        # Text (lighter), centred on the contour
        cv2.putText(to_plot, classes, (cx - text_width // 2, cy - text_height // 2),
                    FONT, FONT_SCALE, [min(255, x + 40) for x in color], FONT_THICKNESS, FONT_LINE_TYPE)

    cv2.addWeighted(cntr_img, fill_alpha, to_plot, 1 - fill_alpha, 0, to_plot)
    plt.figure(figsize=(16, 16))
    plt.imshow(to_plot)
    plt.axis(False)
    plt.show()

In [None]:
# 4 input channels, NUM_CL (19) output classes — must match the saved state dict
model = EfficientNet.from_name('efficientnet-b0', num_classes=NUM_CL, in_channels=4)
# map_location: if the checkpoint was saved from GPU tensors, torch.load would
# otherwise fail on a CPU-only kernel.
model.load_state_dict(torch.load('../input/pretrained-net/state_dict.pth', map_location=DEVICE))
model = model.to(DEVICE)
model.eval()

In [None]:
DATA_DIR = os.path.join(PATH, 'hpa-single-cell-image-classification')
TEST_DIR = os.path.join(DATA_DIR, 'test')

# Each channel is a separate file, so four paths make up one RGBY test image.
TEST_IMG_PATHS = sorted(os.path.join(TEST_DIR, f_name) for f_name in os.listdir(TEST_DIR))
print(
    f"... The number of testing images is {len(TEST_IMG_PATHS)}"
    f"\n\t--> i.e. {len(TEST_IMG_PATHS)//4} 4-channel images ..."
)

# Official sample submission: left side of the final merge.
SUB_PATH = os.path.join(DATA_DIR, 'sample_submission.csv')
submission_df = pd.read_csv(SUB_PATH)

# Pre-computed submission (presumably from an earlier inference notebook) that
# supplies the IDs / image sizes and the fallback prediction strings.
SWAP_PATH = os.path.join(PATH, 'efficientnet-inference/submission.csv')
ss_df = pd.read_csv(SWAP_PATH)

print("\n\nSAMPLE SUBMISSION DATAFRAME\n\n")
display(ss_df)

In [None]:
import gc
IMAGE_SIZES = [1728, 2048, 3072, 4096]  # test image widths, processed group by group
BATCH_SIZE = 1
CONF_THRESH = 0.0  # with softmax outputs this keeps (almost) every class
TILE_SIZE = (256,256)
SAMPLES = 4
segmentator = CellSegmentator()

# Split the submission rows (and their IDs) by image width.
(predict_df_1728, predict_df_2048,
 predict_df_3072, predict_df_4096) = [ss_df[ss_df.ImageWidth == width] for width in IMAGE_SIZES]
(predict_ids_1728, predict_ids_2048,
 predict_ids_3072, predict_ids_4096) = [
    frame.ID.tolist()
    for frame in (predict_df_1728, predict_df_2048, predict_df_3072, predict_df_4096)
]

In [None]:
BATCH_SIZE = 1
final_df = pd.DataFrame(columns = ['ID', 'ImageWidth', 'ImageHeight'], data = ss_df)
datas = [predict_ids_1728, predict_ids_2048, predict_ids_3072, predict_ids_4096]
flatten = lambda t: [item for sublist in t for item in sublist]

# The training dataset (GetData_single_cell) stacks channels as
# (red, blue, yellow, green) and scales by 1/255. load_image(..., testing=True)
# returns uint8 planes in (red, green, blue, yellow) order, so reorder and rescale.
# NOTE(review): assumes state_dict.pth was trained with GetData_single_cell.
MODEL_CHANNELS = [0, 2, 3, 1]
INPUT_SCALE = 1.0 / 255.0
NEGATIVE_CLASS = 18  # "Negative" in LBL_NAMES; used when no class clears CONF_THRESH


def _image_prediction_string(mask, rgby_image):
    """Classify every labelled cell of one image and build its PredictionString.

    Args:
        mask: 2-D integer cell-label image (0 = background).
        rgby_image: (4, H, W) uint8 array in load_image(..., testing=True) order.

    Returns:
        "label conf rle label conf rle ..." covering every cell, or "" when no
        cell was segmented.
    """
    # One (row_slice, col_slice) per label 1..max, or None for absent labels.
    cell_slices = ndi.find_objects(mask)
    cell_ids = [idx + 1 for idx, sl in enumerate(cell_slices) if sl is not None]
    if not cell_ids:
        return ""

    model_image = np.ascontiguousarray(rgby_image[MODEL_CHANNELS].transpose(1, 2, 0))
    tiles = [
        cv2.resize(pad_to_square(model_image[cell_slices[cell_id - 1]]),
                   TILE_SIZE, interpolation=cv2.INTER_CUBIC)
        for cell_id in cell_ids
    ]
    # (N, H, W, C) -> (N, C, H, W), the layout GetData_single_cell produces.
    batch = torch.from_numpy(np.ascontiguousarray(np.stack(tiles).transpose(0, 3, 1, 2)))
    with torch.no_grad():
        probs = nnf.softmax(model(batch.to(DEVICE).float() * INPUT_SCALE), dim=1).cpu().numpy()

    tokens = []
    # Tiles, probabilities and RLEs are all indexed by the same cell_ids, so each
    # RLE is paired with the prediction for that exact cell. Previously the RLEs
    # (label order) were zipped with predictions in sorted-contour order.
    for cell_id, cell_probs in zip(cell_ids, probs):
        rle = encode_binary_mask(mask, mask_val=cell_id)
        labels = np.where(cell_probs > CONF_THRESH)[0]
        if labels.size == 0:
            labels = np.array([NEGATIVE_CLASS])
            confs = np.array([1 - np.max(cell_probs)])
        else:
            confs = cell_probs[labels]
        for label, conf in zip(labels, confs):
            tokens.extend([str(label), f"{conf:.4f}", rle])
    return " ".join(tokens)


if IS_FINAL:
    with torch.no_grad():
        for img_size, submission_ids in zip(IMAGE_SIZES, datas):
            if len(submission_ids) == 0:
                # skip sizes that have no test images
                print('skipping')
                continue

            print(f"...WORKING... \n")
            predictions = []
            n_batches = int(np.ceil(len(submission_ids) / BATCH_SIZE))
            for start in tqdm(range(0, len(submission_ids), BATCH_SIZE), total=n_batches):
                batch_ids = submission_ids[start:start + BATCH_SIZE]
                # (4, H, W) uint8 arrays in (red, green, blue, yellow) order
                batch_rgby_images = [load_image(img_id, TEST_DIR, testing=True) for img_id in batch_ids]

                cell_segmentations = segmentator.pred_cells(
                    [[rgby_image[j] for rgby_image in batch_rgby_images] for j in [0, 2, 3]])
                nuc_segmentations = segmentator.pred_nuclei(
                    [rgby_image[2] for rgby_image in batch_rgby_images])
                # int32 labels: uint8 would wrap around for images with more than 255 cells
                batch_masks = [label_cell(nuc_seg, cell_seg)[1].astype(np.int32)
                               for nuc_seg, cell_seg in zip(nuc_segmentations, cell_segmentations)]
                # free segmentation outputs before classification
                del cell_segmentations, nuc_segmentations
                gc.collect()

                predictions.extend(
                    _image_prediction_string(mask, rgby_image)
                    for mask, rgby_image in zip(batch_masks, batch_rgby_images))

            sub_df = pd.DataFrame({"ID": submission_ids, f"PredictionString{img_size}": predictions})
            final_df = final_df.merge(sub_df, how='left', on='ID')
            display(final_df.head())

    # Each ID has a single width, so at most one per-size column is filled per row.
    # Combine only the columns that were produced (sizes may be skipped), including
    # 4096. Rows left NaN fall back to the swap submission in create_pred_col.
    pred_cols = [f"PredictionString{size}" for size in IMAGE_SIZES
                 if f"PredictionString{size}" in final_df.columns]
    combined = pd.Series(np.nan, index=final_df.index, dtype=object)
    for col in pred_cols:
        combined = final_df[col].combine_first(combined)
    final_df['PredictionString'] = combined
    final_df = final_df.drop(columns=pred_cols)
    display(final_df.head(10))
else:
    final_df = ss_df
    display(final_df.head())


In [None]:
# Take the new prediction (right frame) where available, else keep the sample
# submission's string, then restore the original column names.
ss_df = (
    submission_df
    .merge(final_df, how="left", on="ID")
    .assign(PredictionString=lambda merged: merged.apply(create_pred_col, axis=1))
    .drop(columns=["PredictionString_x", "PredictionString_y", "ImageWidth_y", "ImageHeight_y"])
    .rename(columns={'ImageWidth_x': 'ImageWidth', 'ImageHeight_x': 'ImageHeight'})
)
ss_df

In [None]:
ss_df.to_csv("/kaggle/working/submission.csv", index=False)  # no index column in the submission