**Solution overview:** https://www.kaggle.com/c/hpa-single-cell-image-classification/discussion/241637

This notebook contains 4 main stages:
1. Cell Segmentation
2. Image Level Prediction
3. Cell Level Prediction
4. Ensemble & Final Prediction

In [None]:
!pip install /kaggle/input/efficientnet-keras-source-code
!pip install /kaggle/input/pycocotools202/pycocotools-2.0.2-cp37-cp37m-linux_x86_64.whl
!pip install /kaggle/input/hpapytorchzoozip/pytorch_zoo-master
!pip install /kaggle/input/kerasapplications
!pip install /kaggle/input/efficientnet-keras-source-code/ -q --no-deps
!pip install /kaggle/input/hpacellsegmentatormaster/HPA-Cell-Segmentation-master/
!pip install /kaggle/input/gputil/

In [None]:
import sys
sys.path.append('../input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master')
import os
import gc
import shutil
from tqdm.notebook import tqdm
import pickle
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow as tfa
import tensorflow_addons as tfa
from tensorflow.keras.applications import densenet
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications import Xception
import efficientnet.tfkeras as efn
import torch
from fastai.vision.all import *
from efficientnet_pytorch import EfficientNet
from pycocotools import mask as coco_mask
import base64
import typing as t
import zlib
from GPUtil import showUtilization as gpu_usage

import os.path
import urllib
import zipfile

import scipy.ndimage as ndi
from skimage import filters, measure, segmentation
from skimage.morphology import (binary_erosion, closing, disk,
                                remove_small_holes, remove_small_objects)
from hpacellseg.cellsegmentator import *
from glob import glob

In [None]:
os.makedirs('/root/.cache/torch/hub/checkpoints/', exist_ok=True)

!cp '../input/resnet50/resnet50.pth' '/root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth'

!cp '../input/efficientnet-pytorch-pretrained/adv-efficientnet-b1-0f3ce85a.pth' \
'/root/.cache/torch/hub/checkpoints/'

!cp '../input/efficientnet-pytorch-pretrained/adv-efficientnet-b5-86493f6b.pth' \
'/root/.cache/torch/hub/checkpoints/'

!cp '../input/efficientnet-pytorch/efficientnet-b1-dbc7070a.pth' '/root/.cache/torch/hub/checkpoints/'

!cp '../input/efficientnet-pytorch/efficientnet-b5-586e6cc6.pth' '/root/.cache/torch/hub/checkpoints/'

# Global config

In [None]:
%%time
class Config:
    """Global run configuration shared by every stage of the notebook.

    All settings are class attributes and are read as ``Config.<name>``;
    the class is never instantiated. Both CSV files are read once, when this
    class statement executes.
    """

    # Seed consumed by ``seed_everything``.
    seed = 42

    # Batch sizes: segmentation, cell-level and image-level inference.
    segmentor_bs = 8
    cell_level_bs = 128
    image_level_bs = 64

    # Switches.
    fast_segment = True   # True: notebook-local CellSegmentator; False: hpacellseg's (scale 0.25, padded)
    save_preds = True     # pickle each checkpoint's TTA predictions
    hidden_only = False   # keep only test IDs that are absent from ``showed_df``

    # Submission whose IDs are excluded when ``hidden_only`` is set
    # (presumably the publicly visible test images).
    showed_df = pd.read_csv('../input/hpa-public-test-submission/my_empty_submission.csv')
    # Test-set listing that drives the whole pipeline.
    sub_df = pd.read_csv('../input/hpa-single-cell-image-classification/sample_submission.csv')
    sub_df = (
        sub_df.loc[~sub_df['ID'].isin(showed_df['ID'])].reset_index(drop=True)
        if hidden_only
        else sub_df
    )

In [None]:
# Peek at the last rows of the test listing that will be processed.
Config.sub_df.tail(n=5)

In [None]:
def seed_everything(seed=Config.seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and current
    CUDA device), and configures cuDNN for deterministic kernels. The
    ``random`` module matters here because fastai's random transforms (used
    again by ``learn.tta``) draw from it.

    Args:
        seed: integer seed; defaults to ``Config.seed``.
    """
    import random  # stdlib RNG; it was previously left unseeded

    # Note: hash randomisation of the *running* interpreter is fixed at start-up;
    # this only affects subprocesses started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # ``deterministic`` alone is not enough: with benchmark mode on, cuDNN may
    # autotune to different (non-deterministic) algorithms between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything()

In [None]:
# Reclaim unreferenced Python objects before the segmentation stage starts;
# the cell displays the number of objects collected.
gc.collect()

# 1. Cell segmentation

In [None]:
def encode_binary_mask(mask: np.ndarray) -> t.Text:
    """Converts a binary mask into OID challenge encoding ascii text.

    The mask is run-length encoded with the COCO API, the RLE counts are
    zlib-compressed at the highest level, and the result is base64-encoded.

    Args:
        mask: boolean array that is 2-D after ``np.squeeze``.

    Returns:
        ASCII string of the base64-encoded, compressed COCO RLE counts.

    Raises:
        ValueError: if ``mask`` is not boolean, or is not 2-D after squeezing.
    """

    # check input mask --
    # ``np.bool`` (an alias of the builtin) was removed in NumPy 1.24;
    # ``np.bool_`` is NumPy's actual boolean scalar type.
    if mask.dtype != np.bool_:
        raise ValueError(f"encode_binary_mask expects a binary mask, received dtype == {mask.dtype}")

    mask = np.squeeze(mask)
    if mask.ndim != 2:
        # f-string on purpose: ``"... %s" % mask.shape`` treats the shape tuple
        # as several format arguments and raised TypeError instead of this error.
        raise ValueError(f"encode_binary_mask expects a 2d mask, received shape == {mask.shape}")

    # convert input mask to expected COCO API input: Fortran-ordered (H, W, 1) uint8 --
    mask_to_encode = np.asfortranarray(mask.reshape(mask.shape[0], mask.shape[1], 1).astype(np.uint8))

    # RLE encode mask --
    encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

    # compress and base64 encoding --
    binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
    base64_str = base64.b64encode(binary_str)
    return base64_str.decode('ascii')

In [None]:
def load_images(df, root='../input/hpa-single-cell-image-classification/test/', size=512):
    """Load the channels the cell segmentator needs for every image in ``df``.

    Args:
        df: DataFrame with an ``ID`` column; one image is loaded per row.
        root: directory holding ``{ID}_red.png``, ``{ID}_yellow.png`` and
            ``{ID}_blue.png``.
        size: side length every channel is resized to (default 512).

    Returns:
        ``(blue, ryb, image_ids)`` — three lists aligned with the rows of
        ``df``: the resized blue channel (H, W), the resized red/yellow/blue
        stack (H, W, 3), and the image ID.

    Raises:
        FileNotFoundError: if a channel file is missing or cannot be decoded.
    """
    blue = []
    ryb = []
    image_ids = []

    def _read_channel(image_id, color):
        # Read one channel as grayscale; cv2.imread returns None instead of
        # raising, which previously surfaced as an opaque cv2.resize error.
        path = os.path.join(root, f'{image_id}_{color}.png')
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise FileNotFoundError(f'could not read {path}')
        return img

    for image_id in df['ID']:
        r = _read_channel(image_id, 'red')
        y = _read_channel(image_id, 'yellow')
        b = _read_channel(image_id, 'blue')
        blue.append(cv2.resize(b, (size, size)))
        ryb.append(cv2.resize(np.stack((r, y, b), axis=2), (size, size)))
        image_ids.append(image_id)

    return blue, ryb, image_ids

In [None]:
def get_cropped_cell(img, msk):
    """Crop ``img`` to the bounding box of ``msk`` and zero pixels outside the mask.

    Args:
        img: image array, (H, W, C) or (H, W).
        msk: (H, W) mask; non-zero entries mark the cell. The mask values are
            multiplied into the image as integers, so pass a boolean mask for
            a plain cut-out.

    Returns:
        The tight bounding-box crop of the non-zero mask pixels, with the
        same dtype as ``img * msk.astype(int)`` (integer for uint8 input).

    Raises:
        ValueError: if ``msk`` has no non-zero pixel.
    """
    rows, cols = np.nonzero(msk)
    if rows.size == 0:
        raise ValueError('get_cropped_cell received an empty mask')
    r0, r1 = rows.min(), rows.max() + 1
    c0, c1 = cols.min(), cols.max() + 1

    # Crop first, then mask: masking the full image for every cell used to
    # allocate a full-size integer array per call.
    mask_crop = msk[r0:r1, c0:c1].astype(int)
    if img.ndim == 3:
        mask_crop = mask_crop[..., None]  # broadcast over channels
    return img[r0:r1, c0:c1] * mask_crop

In [None]:
def read_img(image_id, color, root='../input/hpa-single-cell-image-classification/test', image_size=None):
    """Read one channel PNG of an image, optionally resized, reduced to 8 bit if needed.

    Args:
        image_id: image identifier; the file is ``{root}/{image_id}_{color}.png``.
        color: channel name used in the file name (e.g. 'red', 'green', 'blue').
        root: directory containing the PNG files.
        image_size: if given, resize to ``(image_size, image_size)``.

    Returns:
        The image array. When its maximum exceeds 255 (16-bit data), it is
        divided by 255 and returned as uint8, saturating at 255.

    Raises:
        AssertionError: if the file does not exist.
    """
    filename = f'{root}/{image_id}_{color}.png'
    assert os.path.exists(filename), f'not found {filename}'
    img = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
    if image_size is not None:
        img = cv2.resize(img, (image_size, image_size))
    if img.max() > 255:
        # Values above 255 * 255 exceed the uint8 range after dividing by 255;
        # casting them directly is undefined (in practice they wrap, turning
        # saturated pixels dark). Clip so they saturate at 255 instead.
        img = np.clip(img / 255, 0, 255).astype('uint8')
    return img

In [None]:
"""
This code is from host's public HPA Cell Segmentation github respository: https://github.com/CellProfiling/HPA-Cell-Segmentation
with modification from: https://www.kaggle.com/linshokaku/faster-hpa-cell-segmentation
"""

class CellSegmentator(object):
    """Uses pretrained DPN-Unet models to segment cells from images.

    Images are run through the models in batches of ``batch_size``.
    """

    def __init__(
        self,
        nuclei_model="../input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth",
        cell_model="../input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth",
        scale_factor=1.0,
        device="cuda",
        padding=False,
        multi_channel_model=True,
        batch_size=24,
    ):
        """Class for segmenting nuclei and whole cells from confocal microscopy images.

        It takes lists of images and returns the raw output from the
        specified segmentation model. Model files that do not exist on disk
        are downloaded to the given path. In this notebook the outputs of
        ``pred_nuclei``/``pred_cells`` are combined with ``label_cell``.
        Note that for cell segmentation, there are two possible models
        available. One that works with 2 channeled images and one that
        takes 3 channels.

        Keyword arguments:
        nuclei_model -- A loaded torch nuclei segmentation model or the
                        path to a file which contains such a model.
                        If the path points to a non-existent file, a
                        pretrained nuclei model is downloaded to it
                        (default: the Kaggle dataset path in the signature).
        cell_model -- A loaded torch cell segmentation model or the path to
                      a file which contains such a model (downloaded the
                      same way if missing). May be None if only nuclei are
                      to be segmented.
        scale_factor -- How much to scale images before they are fed to
                        segmentation models. When it is not 1, predictions
                        are resized back to the input image size
                        (default: 1.0).
        device -- The device on which to run the models: 'cpu', 'cuda' or a
                  specific device like 'cuda:0'. Falls back to 'cpu' when no
                  GPU is available (default: 'cuda').
        padding -- Whether to reflect-pad the images to a multiple of 32
                   before feeding them to the network (default: False).
        multi_channel_model -- If True, use the 3-channel cell model,
                               otherwise the 2-channel version (default: True).
        batch_size -- Number of images per forward pass (default: 24).
        """
        if device not in ("cuda", "cpu") and "cuda" not in device:
            raise ValueError(f"{device} is not a valid device (cuda/cpu)")
        # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
        if device != "cpu" and not torch.cuda.is_available():
            print("No GPU found, using CPU.", file=sys.stderr)
            device = "cpu"
        self.device = device
        self.multi_channel_model = multi_channel_model
        self.scale_factor = scale_factor
        self.padding = padding
        self.batch_size = batch_size

        if isinstance(nuclei_model, str):
            if not os.path.exists(nuclei_model):
                print(
                    f"Could not find {nuclei_model}. Downloading it now",
                    file=sys.stderr,
                )
                download_with_url(NUCLEI_MODEL_URL, nuclei_model)
            nuclei_model = torch.load(
                nuclei_model, map_location=torch.device(self.device)
            )
        self.nuclei_model = self._prepare_model(nuclei_model)

        if isinstance(cell_model, str):
            if not os.path.exists(cell_model):
                print(
                    f"Could not find {cell_model}. Downloading it now", file=sys.stderr
                )
                if self.multi_channel_model:
                    download_with_url(MULTI_CHANNEL_CELL_MODEL_URL, cell_model)
                else:
                    download_with_url(TWO_CHANNEL_CELL_MODEL_URL, cell_model)
            cell_model = torch.load(cell_model, map_location=torch.device(self.device))
        # The docstring allows None (nuclei-only use); previously this crashed.
        self.cell_model = None if cell_model is None else self._prepare_model(cell_model)

    def _prepare_model(self, model):
        """Unwrap ``DataParallel`` on CPU, then move ``model`` to ``self.device`` in eval mode.

        The unwrap is now applied to both models; it used to be done only for
        the nuclei model.
        """
        if isinstance(model, torch.nn.DataParallel) and self.device == "cpu":
            model = model.module
        return model.to(self.device).eval()

    def _image_conversion(self, images):
        """Convert/Format images to RGB image arrays list for cell predictions.
        Intended for internal use only.
        Keyword arguments:
        images -- list of lists of image paths/arrays. It should follow the
                 pattern if with er channel input,
                 [
                     [microtubule_path0/image_array0, microtubule_path1/image_array1, ...],
                     [er_path0/image_array0, er_path1/image_array1, ...],
                     [nuclei_path0/image_array0, nuclei_path1/image_array1, ...]
                 ]
                 or if without er input,
                 [
                     [microtubule_path0/image_array0, microtubule_path1/image_array1, ...],
                     None,
                     [nuclei_path0/image_array0, nuclei_path1/image_array1, ...]
                 ]
        Returns:
        A list of (H, W, 3) arrays stacking microtubule, er (zeros when
        absent) and nuclei channels.
        """
        microtubule_imgs, er_imgs, nuclei_imgs = images
        if self.multi_channel_model:
            if not isinstance(er_imgs, list):
                raise ValueError("Please speicify the image path(s) for er channels!")
        elif er_imgs is not None:
            raise ValueError(
                "second channel should be None for two channel model predition!"
            )

        if not isinstance(microtubule_imgs, list):
            raise ValueError("The microtubule images should be a list")
        if not isinstance(nuclei_imgs, list):
            # This message used to say "microtubule" (copy-paste slip).
            raise ValueError("The nuclei images should be a list")

        if er_imgs:
            if not len(microtubule_imgs) == len(er_imgs) == len(nuclei_imgs):
                raise ValueError("The lists of images needs to be the same length")
        elif not len(microtubule_imgs) == len(nuclei_imgs):
            raise ValueError("The lists of images needs to be the same length")

        # Paths were given instead of arrays: read them from disk.
        if not all(isinstance(item, np.ndarray) for item in microtubule_imgs):
            microtubule_imgs = [
                imageio.imread(os.path.expanduser(item)) for item in microtubule_imgs
            ]
            nuclei_imgs = [
                imageio.imread(os.path.expanduser(item)) for item in nuclei_imgs
            ]
            if er_imgs:
                er_imgs = [imageio.imread(os.path.expanduser(item)) for item in er_imgs]

        if not er_imgs:
            er_imgs = [np.zeros(item.shape, dtype=item.dtype) for item in microtubule_imgs]
        return [
            np.dstack((microtubule, er, nuclei))
            for microtubule, er, nuclei in zip(microtubule_imgs, er_imgs, nuclei_imgs)
        ]

    @staticmethod
    def _pad(image):
        """Reflect-pad 32 px top/left and enough bottom/right to reach a multiple of 32."""
        rows, cols = image.shape[:2]
        return cv2.copyMakeBorder(
            image,
            32,
            (32 - rows % 32),
            32,
            (32 - cols % 32),
            cv2.BORDER_REFLECT,
        )

    def _run_model(self, model, imgs):
        """Normalize a batch of CHW images with ``NORMALIZE`` and return the model's softmax over dim 1."""
        with torch.no_grad():
            mean = torch.as_tensor(NORMALIZE["mean"], device=self.device)
            std = torch.as_tensor(NORMALIZE["std"], device=self.device)
            # np.stack first: building a tensor from a list of ndarrays is slow.
            batch = torch.from_numpy(np.stack(imgs)).float().to(self.device)
            batch = batch.sub_(mean[:, None, None]).div_(std[:, None, None])
            return F.softmax(model(batch), dim=1)

    def _predict(self, model, preprocessed):
        """Run ``model`` over ``(chw_image, target_shape, scaled_shape)`` items in batches.

        Returns a list with one CHW numpy prediction per item ([] for no items).
        """
        if not preprocessed:
            return []
        batch_size = getattr(self, "batch_size", 24)
        predictions = []
        for start in range(0, len(preprocessed), batch_size):
            batch = [chw for chw, _, _ in preprocessed[start:start + batch_size]]
            predictions.append(self._run_model(model, batch).cpu().numpy())
        return list(np.concatenate(predictions, axis=0))

    def pred_nuclei(self, images):
        """Predict the nuclei segmentation.
        Keyword arguments:
        images -- A list of image arrays or a list of paths to images.
                  If as a list of image arrays, the images could be 2d images
                  of nuclei data array only, or must have the nuclei data in
                  the blue channel; If as a list of file paths, the images
                  could be RGB image files or gray scale nuclei image file
                  paths.
        Returns:
        predictions -- A list of uint8 (H, W, C) predictions of nuclei
                       segmentation, one per input image.
        """

        def _preprocess(image):
            if isinstance(image, str):
                image = imageio.imread(image)
            target_shape = image.shape
            if len(image.shape) == 2:
                image = np.dstack((image, image, image))
            image = transform.rescale(image, self.scale_factor, multichannel=True)
            nuc_image = np.dstack((image[..., 2], image[..., 2], image[..., 2]))
            scaled_shape = nuc_image.shape[:2]
            if self.padding:
                nuc_image = self._pad(nuc_image)
            # Kept for backward compatibility with code reading the last shapes.
            self.target_shape, self.scaled_shape = target_shape, scaled_shape
            return nuc_image.transpose([2, 0, 1]), target_shape, scaled_shape

        preprocessed = [_preprocess(image) for image in images]
        raw = self._predict(self.nuclei_model, preprocessed)
        # Nuclei: convert to uint8 first, then undo padding/scaling — each
        # image with its own shapes (previously the last image's were used).
        return [
            self._restore_scaling_padding(util.img_as_ubyte(pred), target, scaled)
            for pred, (_, target, scaled) in zip(raw, preprocessed)
        ]

    def _restore_scaling_padding(self, n_prediction, target_shape=None, scaled_shape=None):
        """Restore a CHW model output from scaling and padding; returns (H, W, C).

        This method is intended for internal use. ``target_shape`` is the
        original image shape and ``scaled_shape`` the (rows, cols) after
        rescaling; both default to the values last stored on the instance.
        """
        if target_shape is None:
            target_shape = getattr(self, "target_shape", None)
        if scaled_shape is None:
            scaled_shape = getattr(self, "scaled_shape", None)
        n_prediction = n_prediction.transpose([1, 2, 0])
        if self.padding:
            n_prediction = n_prediction[
                32 : 32 + scaled_shape[0], 32 : 32 + scaled_shape[1], ...
            ]
        if not self.scale_factor == 1:
            n_prediction[..., 0] = 0
            # cv2 expects dsize as (width, height), i.e. (cols, rows).
            n_prediction = cv2.resize(
                n_prediction,
                (target_shape[1], target_shape[0]),
                interpolation=cv2.INTER_AREA,
            )
        return n_prediction

    def pred_cells(self, images, precombined=False):
        """Predict the cell segmentation for a list of images.
        Keyword arguments:
        images -- list of lists of image paths/arrays. It should follow the
                  pattern if with er channel input,
                  [
                      [microtubule_path0/image_array0, microtubule_path1/image_array1, ...],
                      [er_path0/image_array0, er_path1/image_array1, ...],
                      [nuclei_path0/image_array0, nuclei_path1/image_array1, ...]
                  ]
                  or if without er input,
                  [
                      [microtubule_path0/image_array0, microtubule_path1/image_array1, ...],
                      None,
                      [nuclei_path0/image_array0, nuclei_path1/image_array1, ...]
                  ]
                  The ER channel is required when multichannel is True
                  and required to be None when multichannel is False.
                  The images need to be of the same size.
        precombined -- If precombined is True, the list of images is instead supposed to be
                       a list of RGB numpy arrays (default: False).
        Returns:
        predictions -- a list of uint8 (H, W, C) predictions of cell segmentations.
        """

        def _preprocess(image):
            target_shape = image.shape
            if not len(image.shape) == 3:
                raise ValueError("image should has 3 channels")
            cell_image = transform.rescale(image, self.scale_factor, multichannel=True)
            scaled_shape = cell_image.shape[:2]
            if self.padding:
                cell_image = self._pad(cell_image)
            # Kept for backward compatibility with code reading the last shapes.
            self.target_shape, self.scaled_shape = target_shape, scaled_shape
            return cell_image.transpose([2, 0, 1]), target_shape, scaled_shape

        if not precombined:
            images = self._image_conversion(images)
        preprocessed = [_preprocess(image) for image in images]
        raw = self._predict(self.cell_model, preprocessed)
        # Cells: undo padding/scaling on the float softmax first, then convert to uint8.
        return [
            util.img_as_ubyte(self._restore_scaling_padding(pred, target, scaled))
            for pred, (_, target, scaled) in zip(raw, preprocessed)
        ]


# Probability thresholds. Neither constant is referenced in the visible
# notebook code (label_cell below uses its own literal thresholds); presumably
# kept for parity with the upstream hpacellseg module — TODO confirm or remove.
HIGH_THRESHOLD = 0.4
LOW_THRESHOLD = HIGH_THRESHOLD - 0.25


def download_with_url(url_string, file_path, unzip=False):
    """Download ``url_string`` to ``file_path``, optionally extracting it as a zip.

    The response is streamed to disk in chunks rather than read into memory
    in one piece (model-weight files can be large).

    Args:
        url_string: any URL accepted by ``urllib.request.urlopen``.
        file_path: destination file; overwritten if it exists.
        unzip: if True, extract the downloaded archive into the directory
            that contains ``file_path``. Only use this with trusted archives:
            ``extractall`` does not guard against path traversal.
    """
    # ``import urllib`` alone does not load the ``request`` submodule, so the
    # original code only worked if another library had already imported it.
    import shutil
    import urllib.request

    with urllib.request.urlopen(url_string) as response, open(file_path, "wb") as out_file:
        shutil.copyfileobj(response, out_file)

    if unzip:
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            zip_ref.extractall(os.path.dirname(file_path))


def __fill_holes(image):
    """Fill holes in a labelled image and relabel each object with a unique number.

    Object boundaries are zeroed first, so touching objects are separated
    before the holes are filled and connected components are relabelled.
    """
    interior = np.multiply(image, np.invert(segmentation.find_boundaries(image)))
    filled = ndi.binary_fill_holes(interior > 0)
    relabelled, _num_objects = ndi.label(filled)
    return relabelled


def label_cell(nuclei_pred, cell_pred):
    """Label the cells and the nuclei.

    Keyword arguments:
    nuclei_pred -- a 3D numpy array of a prediction from a nuclei image.
    cell_pred -- a 3D numpy array of a prediction from a cell image.
    Both are indexed as ``[..., channel]`` and divided by 255 below, so they
    are expected as uint8 channel-last maps, as produced by
    ``CellSegmentator.pred_nuclei`` / ``pred_cells``.

    Returns:
    A tuple containing:
    nuclei-label -- A nuclei mask data array.
    cell-label  -- A cell mask data array (uint16).
    0's in the data arrays indicate background while a continuous
    stretch of a specific number indicates the area for a specific
    cell.
    The same value in cell mask and nuclei mask refers to the identical cell.
    NOTE: The nuclei labeling from this function will be slightly
    different from the values in upstream hpacellseg's ``label_nuclei`` as
    this version uses information from the cell-predictions to make better
    estimates.
    """
    def __wsh(
        mask_img,
        threshold,
        border_img,
        seeds,
        threshold_adjustment=0.35,
        small_object_size_cutoff=10,
    ):
        """Marker-controlled watershed of ``mask_img`` binarised at ``threshold``.

        Markers are connected components of ``seeds * border_img`` above
        ``threshold + threshold_adjustment``, with objects smaller than
        ``small_object_size_cutoff`` removed. ``mask_img`` is modified in place.
        """
        img_copy = np.copy(mask_img)
        m = seeds * border_img  # * dt
        img_copy[m <= threshold + threshold_adjustment] = 0
        img_copy[m > threshold + threshold_adjustment] = 1
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        img_copy = img_copy.astype(bool)
        img_copy = remove_small_objects(img_copy, small_object_size_cutoff).astype(
            np.uint8
        )

        mask_img[mask_img <= threshold] = 0
        mask_img[mask_img > threshold] = 1
        mask_img = mask_img.astype(bool)
        mask_img = remove_small_holes(mask_img, 63)
        mask_img = remove_small_objects(mask_img, 1).astype(np.uint8)
        markers = ndi.label(img_copy, output=np.uint32)[0]
        labeled_array = segmentation.watershed(
            mask_img, markers, mask=mask_img, watershed_line=True
        )
        return labeled_array

    nuclei_label = __wsh(
        nuclei_pred[..., 2] / 255.0,
        0.4,
        1 - (nuclei_pred[..., 1] + cell_pred[..., 1]) / 255.0 > 0.05,
        nuclei_pred[..., 2] / 255,
        threshold_adjustment=-0.25,
        small_object_size_cutoff=32,
    )

    # for hpa_image, to remove the small pseudo nuclei
    nuclei_label = remove_small_objects(nuclei_label, 157)
    nuclei_label = measure.label(nuclei_label)
    # this is to remove the cell borders' signal from cell mask.
    # could use np.logical_and with some revision, to replace this func.
    # Tuned for segmentation hpa images
    threshold_value = max(0.22, filters.threshold_otsu(cell_pred[..., 2] / 255) * 0.5)
    # exclude the green area first
    cell_region = np.multiply(
        cell_pred[..., 2] / 255 > threshold_value,
        np.invert(np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8)),
    )
    sk = np.asarray(cell_region, dtype=np.int8)
    distance = np.clip(cell_pred[..., 2], 255 * threshold_value, cell_pred[..., 2])
    cell_label = segmentation.watershed(-distance, nuclei_label, mask=sk)
    cell_label = remove_small_objects(cell_label, 344).astype(np.uint8)
    selem = disk(2)
    cell_label = closing(cell_label, selem)
    cell_label = __fill_holes(cell_label)
    # this part is to use green channel, and extend cell label to green channel
    # benefit is to exclude cells clear on border but without nucleus
    sk = np.asarray(
        np.add(
            np.asarray(cell_label > 0, dtype=np.int8),
            np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8),
        )
        > 0,
        dtype=np.int8,
    )
    cell_label = segmentation.watershed(-distance, cell_label, mask=sk)
    cell_label = __fill_holes(cell_label)
    cell_label = np.asarray(cell_label > 0, dtype=np.uint8)
    cell_label = measure.label(cell_label)
    cell_label = remove_small_objects(cell_label, 344)
    # Relabel so cell ids are consecutive after small objects were dropped.
    cell_label = measure.label(cell_label)
    cell_label = np.asarray(cell_label, dtype=np.uint16)
    # Keep only nuclei inside a cell and give each the id of its cell.
    nuclei_label = np.multiply(cell_label > 0, nuclei_label) > 0
    nuclei_label = measure.label(nuclei_label)
    nuclei_label = remove_small_objects(nuclei_label, 157)
    nuclei_label = np.multiply(cell_label, nuclei_label > 0)

    return nuclei_label, cell_label

In [None]:
# Output folders for the per-image label masks (saved as .npz).
cell_mask_dir = 'hpa_cell_mask'
nucl_mask_dir = 'hpa_nuclei_mask'
os.makedirs(cell_mask_dir, exist_ok=True)
os.makedirs(nucl_mask_dir, exist_ok=True)

data_df = Config.sub_df.copy()

# Distinct (width, height) pairs; the segmentation loop processes one size at a time.
sizes = [tuple(wh) for wh in np.unique(data_df[['ImageWidth', 'ImageHeight']].values, axis=0)]

if not Config.fast_segment:
    # Upstream hpacellseg segmentator: downscaled by 4 and padded.
    NUC_MODEL = "../input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth"
    CELL_MODEL = "../input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth"
    import hpacellseg.cellsegmentator as cellsegmentator
    cellsegmentor = cellsegmentator.CellSegmentator(
        NUC_MODEL,
        CELL_MODEL,
        scale_factor=0.25,
        device="cuda",
        padding=True,
        multi_channel_model=True,
    )
else:
    # Notebook-local segmentator with its default settings.
    cellsegmentor = CellSegmentator()

In [None]:
%%time
batch_size = Config.segmentor_bs

# Segment every test image, one image size at a time, and save the nucleus
# and cell label masks at the original resolution.
for size in tqdm(sizes):
    width, height = size
    # Rows whose width AND height both match. The previous
    # ``data_df[data_df[[...]].values == size]`` used a 2-D boolean array as a
    # row mask; pandas then took the row index of every True cell, so each
    # square image was selected twice and, for non-square sizes, rows matching
    # only one dimension were picked up as well.
    size_mask = (data_df['ImageWidth'] == width) & (data_df['ImageHeight'] == height)
    data_df_gr = data_df[size_mask]
    data_size = len(data_df_gr)
    dsize = (int(width), int(height))  # cv2.resize takes (width, height)

    for start in tqdm(range(0, data_size, batch_size), position=0, leave=True):
        end = min(data_size, start + batch_size)
        blue_batch, ryb_batch, image_ids = load_images(data_df_gr[start:end])
        nuc_segmentations = cellsegmentor.pred_nuclei(blue_batch)
        cell_segmentations = cellsegmentor.pred_cells(ryb_batch, precombined=True)

        # ``k`` instead of reusing the outer loop's ``i``.
        for k, image_id in enumerate(image_ids):
            nucl_mask, cell_mask = label_cell(nuc_segmentations[k], cell_segmentations[k])
            nucl_mask = cv2.resize(nucl_mask, dsize, interpolation=cv2.INTER_NEAREST)
            cell_mask = cv2.resize(cell_mask, dsize, interpolation=cv2.INTER_NEAREST)

            # ``len()`` of a 2-D mask is its height and is never 0; an image
            # without cells is one whose label mask is all background.
            if cell_mask.max() == 0:
                print('warning: no mask found!')
            np.savez_compressed(f'{cell_mask_dir}/{image_id}', cell_mask)
            np.savez_compressed(f'{nucl_mask_dir}/{image_id}', nucl_mask)

In [None]:
%%time
single_cells_save_dir = 'single_cells'
os.makedirs(single_cells_save_dir, exist_ok=True)

# Cut every labelled cell out of its image, save the crop as a JPEG and
# record its metadata together with the encoded mask for the submission.
lbls = []
num_files = len(data_df)
all_cells = []

for idx in tqdm(range(num_files)):
    image_id = data_df.iloc[idx].ID
    cell_mask = np.load(f'{cell_mask_dir}/{image_id}.npz')['arr_0']
    red, green, blue = (read_img(image_id, colour) for colour in ("red", "green", "blue"))
    # Stacked blue-green-red because cv2.imwrite treats the last axis as BGR.
    stacked_image = np.stack((blue, green, red), axis=-1)

    for cell_id in range(1, np.max(cell_mask) + 1):
        cell_bool = cell_mask == cell_id
        encoded = encode_binary_mask(cell_bool)
        cropped_cell = get_cropped_cell(stacked_image, cell_bool)
        fname = f'{image_id}_{cell_id}.jpg'
        cv2.imwrite(os.path.join(single_cells_save_dir, fname), cropped_cell)
        all_cells.append(dict(
            image_id=image_id,
            fname=fname,
            cell_id=cell_id,
            size1=cropped_cell.shape[0],
            size2=cropped_cell.shape[1],
            enc=encoded,
        ))

cell_df = pd.DataFrame(all_cells)
cell_df.to_csv('cell_df.csv', index=False)

In [None]:
if not len(cell_df):
    # No cell was segmented: build an empty frame that still has the columns
    # the later cells index into.
    cell_df = pd.DataFrame(
        {column: [] for column in ('image_id', 'fname', 'cell_id', 'size1', 'size2', 'enc')},
        dtype='object',
    )
cell_df.tail()

In [None]:
# Release the segmentation models before the classifiers are loaded.
# GPUtil's showUtilization prints its own table and returns None, so it is
# called directly instead of wrapped in print() (which printed "None").
print('before empty cache')
gpu_usage()
del cellsegmentor
gc.collect()
# empty_cache can only hand back cached blocks whose tensors are already
# freed, so it must run after the segmentor has been deleted and collected.
torch.cuda.empty_cache()
print('\nafter empty cache')
gpu_usage()
gc.collect()

# 2. Cell Level Prediction

In [None]:
def get_learner(model_name, lr=1e-3):
    """Build a mixed-precision fastai ``Learner`` around the requested backbone.

    Reads the module-level ``dls`` for the data and the number of classes
    (``dls.c``), so ``dls`` must be assigned before calling.

    Args:
        model_name: 'resnet50', 'resnet101', or 'b0'..'b7' for EfficientNet
            (ImageNet AdvProp weights).
        lr: learning rate of the Adam optimizer used for EfficientNet learners;
            the resnet learners keep ``cnn_learner``'s defaults.

    Returns:
        A fastai ``Learner`` after ``to_fp16()``.

    Raises:
        ValueError: for an unknown ``model_name``.
    """
    efficientnets = [f'b{i}' for i in range(8)]
    if model_name not in efficientnets + ['resnet50', 'resnet101']:
        raise ValueError(f'unknown model_name: {model_name!r}')

    metrics = [accuracy_multi, PrecisionMulti()]

    if model_name == 'resnet50':
        return cnn_learner(dls, resnet50, metrics=metrics).to_fp16()
    if model_name == 'resnet101':
        return cnn_learner(dls, resnet101, metrics=metrics).to_fp16()

    model = EfficientNet.from_pretrained(f'efficientnet-{model_name}', advprop=True)
    # Size the new head from the backbone itself. The hard-coded widths used
    # before gave b2 a 1280-input head although EfficientNet-B2 emits 1408 features.
    model._fc = nn.Linear(model._fc.in_features, dls.c)

    opt_func = partial(Adam, lr=lr, wd=0.01, eps=1e-8)
    return Learner(dls, model, opt_func=opt_func, metrics=metrics).to_fp16()

In [None]:
# File path of every saved crop, plus a dummy target: the fastai DataBlock
# below always calls get_y, even though only predictions are used.
cell_df['path'] = cell_df['fname'].radd(f'{single_cells_save_dir}/')
cell_df['image_labels'] = '0'
cell_df.tail()

In [None]:
# Cell-level classifiers:
#   {input colour mode: {architecture: [(checkpoint path, ensemble weight), ...]}}
model_dict = {
    'GREEN': {
        'resnet50': [('../input/single-cell-models/resnet50_green_model_3.pth', 2)],
        'b1': [('../input/single-cell-models/b1_green_model_1.pth', 3)],
        'b5': [('../input/single-cell-models/green_model_3.pth', 2)],
    },
    'RGB': {
        'b5': [
            ('../input/single-cell-models/fold0_b5_rgb_balanced_2.pth', 1),
            ('../input/single-cell-models/fold4_traindataset_b5_size128_bs128_3.pth', 2),
        ],
    },
}

# ImageNet channel mean / std used for batch normalisation.
sample_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
item_tfms = RandomResizedCrop(224, min_scale=0.75, ratio=(1., 1.))
batch_tfms = [
    *aug_transforms(flip_vert=True, size=128, max_warp=0),
    Normalize.from_stats(*sample_stats),
]
bs = Config.cell_level_bs

def get_y(r): 
    return r['image_labels'].split('|')

In [None]:
%%time
if len(cell_df) == 0:
    preds = []
else:
    ttas = []
    sum_weights = 0

    for color, v1 in model_dict.items():

        assert color in ['GREEN', 'RGB']

        if color=='GREEN':
            def get_x(r): 
                return cv2.imread(r['path'])[:,:,1]
        else:
            def get_x(r): 
                return r['path']

        dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock(vocab=[str(i) for i in range(19)])),        
                get_x=get_x,
                get_y=get_y,
                item_tfms=item_tfms,
                batch_tfms=batch_tfms
                )

        dls = dblock.dataloaders(cell_df, bs=bs)

        test_dl = dls.test_dl(cell_df)

        for model_name, v3 in v1.items():

            learn = get_learner(model_name)

            for model_path, weight in v3:

                save_name = os.path.split(model_path)[-1].split('.')[0]
                print(f'{save_name}\n')

                learn.model.load_state_dict(load_learner(model_path, cpu=False))
                tta, _ = learn.tta(dl=test_dl)

                if Config.save_preds:
                    with open(f'{save_name}.pickle', 'wb') as handle:
                        pickle.dump(tta, handle)

                ttas.append(tta * weight)
                sum_weights += weight

            del learn

        del dblock
        del dls
        del test_dl
        torch.cuda.empty_cache()

    preds = torch.sum(torch.stack(ttas, axis=0), axis=0) / sum_weights
    preds[:,-1] = torch.prod(1 - preds[:,:-1], axis=1)

In [None]:
# Per cell, keep every class whose probability exceeds `threshold` (0.0 keeps all classes, each
# with its own confidence); if none does, fall back to the arg-max class.
if type(preds) != torch.Tensor:
    # preds is the plain list [] when there were no cells.
    preds = torch.Tensor(preds)

threshold = 0.0

cls_per_cell = []
for row in preds:
    keep = torch.nonzero(row > threshold).flatten().tolist()
    if keep:
        cls_per_cell.append([(k, row[k].item()) for k in keep])
    else:
        cls_per_cell.append([(row.argmax().item(), row.max().item())])

# Assign the whole column at once, positionally (row i of preds <-> row i of cell_df).
# A chained `cell_df['cls'].loc[i] = ...` writes through an intermediate Series (dropped under
# pandas copy-on-write) and looks rows up by label rather than by position.
cell_df['cls'] = pd.Series(cls_per_cell, index=cell_df.index, dtype=object)

In [None]:
def combine(r):
    """Format one cell's class predictions as a submission substring.

    Parameters
    ----------
    r : pandas.Series or 2-item sequence
        Row of ``cell_df[['cls', 'enc']]``. The first item is a list of
        ``(class_index, confidence)`` pairs; the second is the cell's encoded mask string.

    Returns
    -------
    str
        ``"<class> <confidence> <enc>"`` for every pair, joined by spaces, or ``''`` when
        the pair list is empty.
    """
    # Unpack by position via iteration: integer keys on a label-indexed Series
    # (r[0], r[1]) are deprecated as positional lookups in pandas 2.x.
    cls, enc = list(r)[:2]
    return ' '.join(' '.join((str(c[0]), str(c[1]), enc)) for c in cls)

In [None]:
# One prediction string per cell, built from its selected classes and its encoded mask.
# Zero rows are handled separately, since apply(axis=1) would have no row to call combine on.
cell_df['pred'] = '' if len(cell_df) == 0 else cell_df[['cls', 'enc']].apply(combine, axis=1)
cell_df.tail()

In [None]:
# Concatenate the per-cell prediction strings of each image into a single string per image_id.
subm = (
    cell_df
    .groupby(['image_id'])['pred']
    .apply(' '.join)
    .reset_index()
)
subm.head()

In [None]:
sample_submission = Config.sub_df.copy()
# Drop the last column (presumably the sample PredictionString) so the merged predictions replace it.
sample_submission = sample_submission.drop(sample_submission.columns[-1:], axis=1)

ss_df = pd.merge(
    sample_submission,
    subm,
    how="left",
    left_on='ID',
    right_on='image_id',
)
# An image with no segmented cells has no row in subm, so the left merge leaves NaN there.
# Use an empty prediction string instead so later string handling (split/join) and the CSV stay valid.
ss_df['PredictionString'] = ss_df['pred'].fillna('')
ss_df = ss_df[['ID', 'ImageWidth', 'ImageHeight', 'PredictionString']]
ss_df.tail()

# 3. Image Level Prediction

In [None]:
# Free GPU memory left over from the PyTorch cell-level stage before the TensorFlow models load.
print('before empty cache')
gpu_usage()  # prints its own utilization table; wrapping it in print() only added a stray 'None'
# Collect unreachable Python objects first so their CUDA tensors go back to PyTorch's cache,
# then release the unoccupied cached blocks.
gc.collect()
torch.cuda.empty_cache()
print('\nafter empty cache')
gpu_usage()

In [None]:
"""
Code inherited from: https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96
"""

def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Build the 3x3 matrix that maps destination pixel indices back to source indices.

    rotation and shear are in degrees. Every argument must be a length-1 float32 tensor,
    because it is concatenated with the shape-[1] constants below.

    Returns (R . S) . (Z . T) for the rotation R, shear S, zoom Z and shift T matrices.
    """
    def as_3x3(entries):
        # Concatenate nine length-1 tensors and reshape them row-major into a 3x3 matrix.
        return tf.reshape(tf.concat(entries, axis=0), [3, 3])

    rot_rad = math.pi * rotation / 180.
    shear_rad = math.pi * shear / 180.

    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    cos_r = tf.math.cos(rot_rad)
    sin_r = tf.math.sin(rot_rad)
    rotation_matrix = as_3x3([cos_r, sin_r, zero,
                              -sin_r, cos_r, zero,
                              zero, zero, one])

    cos_s = tf.math.cos(shear_rad)
    sin_s = tf.math.sin(shear_rad)
    shear_matrix = as_3x3([one, sin_s, zero,
                           zero, cos_s, zero,
                           zero, zero, one])

    zoom_matrix = as_3x3([one / height_zoom, zero, zero,
                          zero, one / width_zoom, zero,
                          zero, zero, one])

    shift_matrix = as_3x3([one, zero, height_shift,
                           zero, one, width_shift,
                           zero, zero, one])

    left = K.dot(rotation_matrix, shear_matrix)
    right = K.dot(zoom_matrix, shift_matrix)
    return K.dot(left, right)

In [None]:
def auto_select_accelerator():
    """Return a TPU distribution strategy if a TPU can be resolved, else the default strategy.

    Any ValueError raised while resolving, connecting to or initializing the TPU selects the
    default strategy. The number of replicas in sync is printed in both cases.
    """
    try:
        tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
        chosen = tf.distribute.experimental.TPUStrategy(tpu_resolver)
        print("Running on TPU:", tpu_resolver.master())
    except ValueError:
        # No TPU available: fall back to the default (GPU/CPU) strategy.
        chosen = tf.distribute.get_strategy()

    print(f"Running on {chosen.num_replicas_in_sync} replicas")
    return chosen


def build_decoder(color, with_labels=True, target_size=(256, 256), res16=False, ext='png'):
    """Return a tf.data map function that loads and preprocesses one image.

    color : 'GREEN' or 'RGB'.
        'GREEN' - the input is the full path of one image file, decoded with 3 channels.
        'RGB'   - the input is a path prefix. '<prefix>_red.<ext>', '_green' and '_blue'
                  are decoded as single channels, stacked in that order, and normalized
                  with mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225).
    with_labels : when True the returned function maps (x, label) -> (image, label);
        otherwise it maps x -> image.
    target_size : (height, width) passed to tf.image.resize.
    res16 : for PNGs, decode as uint16 and divide by 65535 instead of uint8 / 255.
    ext : 'png', 'jpg' or 'jpeg'. Anything else raises ValueError when the returned
        function is called.

    Pixel values are scaled to [0, 1] before resizing.
    """
    assert color in ['GREEN', 'RGB']

    def _decode_bytes(file_bytes, channels):
        # Dispatch on the file extension (and, for PNG, on the bit depth).
        if ext == 'png':
            if res16:
                return tf.image.decode_png(file_bytes, channels=channels, dtype=tf.dtypes.uint16)
            return tf.image.decode_png(file_bytes, channels=channels)
        if ext in ['jpg', 'jpeg']:
            return tf.image.decode_jpeg(file_bytes, channels=channels)
        raise ValueError("Image extension not supported")

    def _scale_to_unit(img):
        # 16-bit PNGs span the full uint16 range; everything else is 8-bit.
        max_value = float(2**16 - 1) if (ext == 'png' and res16) else 255.0
        return tf.cast(img, tf.float32) / max_value

    if color == 'RGB':

        def decode(img_id):
            channel_files = [img_id + f'_{colour}.' + ext for colour in ['red', 'green', 'blue']]
            raw_files = [tf.io.read_file(channel_file) for channel_file in channel_files]
            img = tf.concat([_decode_bytes(raw, 1) for raw in raw_files], axis=-1)
            img = tf.image.resize(_scale_to_unit(img), target_size)

            mean = tf.convert_to_tensor([0.485, 0.456, 0.406])
            std = tf.convert_to_tensor([0.229, 0.224, 0.225])
            return (img - mean) / std

    else:

        def decode(path):
            raw = tf.io.read_file(path)
            img = _decode_bytes(raw, 3)
            return tf.image.resize(_scale_to_unit(img), target_size)

    def decode_with_labels(path, label):
        return decode(path), label

    return decode_with_labels if with_labels else decode

def build_augmenter(dim=600, extra_aug=True, with_labels=True):
    """Return a per-image random augmentation function for tf.data.

    Every variant applies a random left-right flip, a random up-down flip and a random
    brightness shift (max_delta=0.1). When extra_aug is True it then also applies a random
    affine warp (rotation, shear, zoom, shift), which expects a [dim, dim, 3] image.
    with_labels : return fn(img, label) -> (img, label) instead of fn(img) -> img.
    """

    def transform(image):
        """Randomly rotate, shear, zoom and shift one [dim, dim, 3] image (not a batch).

        Code inherited from: https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96
        """
        side = dim
        odd_fix = side % 2  # fix for odd sizes such as 331

        # Random parameters, drawn in this fixed order: rotation, shear, zooms, shifts.
        rot = 90. * tf.random.normal([1], dtype='float32')
        shr = 5. * tf.random.normal([1], dtype='float32')
        h_zoom = 1.0 + tf.random.normal([1], dtype='float32') / 10.
        w_zoom = 1.0 + tf.random.normal([1], dtype='float32') / 10.
        h_shift = 0.05 * side * tf.random.normal([1], dtype='float32')
        w_shift = 0.05 * side * tf.random.normal([1], dtype='float32')

        matrix = get_mat(rot, shr, h_zoom, w_zoom, h_shift, w_shift)

        # Centered destination pixel coordinates, plus a homogeneous row of ones.
        rows = tf.repeat(tf.range(side // 2, -side // 2, -1), side)
        cols = tf.tile(tf.range(-side // 2, side // 2), [side])
        ones = tf.ones([side * side], dtype='int32')
        dest = tf.stack([rows, cols, ones])

        # Map destination pixels back onto source pixels and clip them into the image.
        src = K.dot(matrix, tf.cast(dest, dtype='float32'))
        src = K.cast(src, dtype='int32')
        src = K.clip(src, -side // 2 + odd_fix + 1, side // 2)

        # Gather the source pixel values.
        gather_idx = tf.stack([side // 2 - src[0, ], side // 2 - 1 + src[1, ]])
        pixels = tf.gather_nd(image, tf.transpose(gather_idx))

        return tf.reshape(pixels, [side, side, 3])

    def augment(img):
        # NOTE(review): all three ops share the op-level seed Config.seed; in graph mode random
        # ops with identical seeds may draw correlated values - confirm this is intended.
        img = tf.image.random_flip_left_right(img, seed=Config.seed)
        img = tf.image.random_flip_up_down(img, seed=Config.seed)
        img = tf.image.random_brightness(img, max_delta=0.1, seed=Config.seed)
        if extra_aug:
            img = transform(img)
        return img

    def augment_with_labels(img, label):
        return augment(img), label

    return augment_with_labels if with_labels else augment

def build_dataset(paths, labels=None, bsize=128, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, repeat=True,
                  shuffle=1024, cache_dir=""):
    """Build a batched, prefetched tf.data pipeline over image paths.

    Pipeline order: decode -> cache -> augment -> repeat -> shuffle -> batch -> prefetch
    (each optional stage is skipped when disabled).

    paths, labels : passed to from_tensor_slices; with labels the elements are (path, label).
    decode_fn : map function such as one returned by build_decoder. Required, because
        build_decoder has no default color to build one from.
    augment_fn : map function; defaults to build_augmenter(with_labels=labels is not None).
    cache : cache the decoded images, to files under cache_dir when it is non-empty (the
        directory is created), otherwise in memory.
    shuffle : shuffle buffer size; a falsy value disables shuffling.

    Raises ValueError when decode_fn is None.
    """
    if cache_dir != "" and cache is True:
        os.makedirs(cache_dir, exist_ok=True)

    with_labels = labels is not None

    if decode_fn is None:
        # The old default called build_decoder(labels is not None), i.e. passed a bool as the
        # color, which always failed build_decoder's color check.
        raise ValueError("build_dataset needs a decode_fn, e.g. build_decoder('GREEN' or 'RGB', ...)")

    if augment_fn is None:
        # Pass with_labels by keyword; positionally the bool would land in `dim`.
        augment_fn = build_augmenter(with_labels=with_labels)

    AUTO = tf.data.experimental.AUTOTUNE
    slices = (paths, labels) if with_labels else paths

    dset = tf.data.Dataset.from_tensor_slices(slices)
    dset = dset.map(decode_fn, num_parallel_calls=AUTO)
    dset = dset.cache(cache_dir) if cache else dset
    dset = dset.map(augment_fn, num_parallel_calls=AUTO) if augment else dset
    dset = dset.repeat() if repeat else dset
    dset = dset.shuffle(shuffle) if shuffle else dset
    dset = dset.batch(bsize).prefetch(AUTO)

    return dset

In [None]:
strategy = auto_select_accelerator()
BATCH_SIZE = Config.image_level_bs
COMPETITION_NAME = '../input/hpa-single-cell-image-classification'

# Keep only the first column, which must be 'ID' for the merge in the ensemble cell; the
# image-level class probabilities are added as columns '0'..'18'.
sub_df = Config.sub_df.copy()
sub_df = sub_df.drop(sub_df.columns[1:], axis=1)
label_cols = [str(i) for i in range(19)]
# Broadcast a scalar instead of assigning pd.Series(np.zeros(n)): a fresh Series has a 0..n-1
# RangeIndex and is index-aligned, which leaves NaN wherever sub_df's index differs.
for col in label_cols:
    sub_df[col] = 0.0

In [None]:
# Image-level (TensorFlow/Keras) ensemble, consumed by the TTA cell below as
#   model_dict[color][res16][extra_aug][image_size] -> [(checkpoint path, ensemble weight), ...]
# res16 selects 16-bit PNG decoding in build_decoder; extra_aug adds the random affine warp
# from build_augmenter to the TTA passes. Commented-out tuples are checkpoints left out of
# this ensemble.
model_dict = {'GREEN':
              # NOTE(review): the res16=True branch below lists .../res8/... checkpoints, while the
              # res16=False branch holds the *_res16.h5 model, so the flag looks swapped relative to
              # the file names (the RGB entry matches its name). Confirm before changing it, since it
              # alters how the test PNGs are decoded for these models.
              {True: {True: {600: [#('../input/image-level-models/image_level_models/green/b0/res8/fold0_b0_GREEN_bs256_40epochs_unnormalized_augmented.h5', 1),
                                     #('../input/image-level-models/image_level_models/green/b0/res8/fold1_b0_GREEN_bs256_40epochs_unnormalized_augmented.h5', ),
                                     ('../input/image-level-models/image_level_models/green/b0/res8/fold2_b0_GREEN_bs256_50epochs_unnormalized_augmented.h5', 2),
                                     #('../input/image-level-models/image_level_models/green/b0/res8/fold3_b0_GREEN_bs256_50epochs_unnormalized_augmented.h5', ),

                                     ('../input/image-level-models/image_level_models/green/b1/fold0_b1_GREEN_bs256_unnormalized_augmented.h5', 2), # augment?
                                     #('../input/image-level-models/image_level_models/green/b1/fold1_b1_GREEN_bs256_50epochs_unnormalized_augmented.h5', ),
                                     #('../input/image-level-models/image_level_models/green/b1/fold4_b1_GREEN_bs256_50epochs_unnormalized_augmented.h5', ),

                                     #('../input/image-level-models/image_level_models/green/b2/fold1_b2_GREEN_bs256_50epochs_unnormalized_augmented.h5', ),

                                     #('../input/image-level-models/image_level_models/green/b3/40e/fold0_b3_GREEN_bs256_40epochs_unnormalized_augmented.h5', ),
                                     #('../input/image-level-models/image_level_models/green/b3/40e/fold1_b3_GREEN_bs256_40epochs_unnormalized_augmented.h5', ),
                                     #('../input/image-level-models/image_level_models/green/b3/50e/fold2_b3_GREEN_bs256_50epochs_unnormalized_augmented.h5', ),
                                     ('../input/image-level-models/image_level_models/green/b3/50e/fold3_b3_GREEN_bs256_50epochs_unnormalized_augmented.h5', 1.5),
                                     #('../input/image-level-models/image_level_models/green/b3/50e/fold4_b3_GREEN_bs256_50epochs_unnormalized_augmented.h5', ),

                                    #('../input/image-level-models/image_level_models/green/resnet/fold0_resnet_GREEN_bs256_50epochs_unnormalized_augmented.h5'),
                                    ('../input/image-level-models/image_level_models/green/resnet/fold1_resnet_GREEN_bs256_50epochs_unnormalized_augmented.h5', 1.5),
                                    ],

                               700: [('../input/image-level-models/image_level_models/green/densenet/fold2_densenet_GREEN_bs256_40epochs_size700_unnormalized_augmented.h5', 1.5)]
                      },
                       
                       # extra_aug=False: TTA passes use only flips and brightness.
                       False: {600: [
                                   ('../input/image-level-models/image_level_models/green/b2/fold0_b2_green_bs256_unnormalized.h5', 2),

                                   ('../input/image-level-models/image_level_models/green/b5/fold0_b5_green_public_unnormalized_bs16.h5', 2),

                                    #('../input/image-level-models/image_level_models/green/b7/model_green_fold0_b7.h5', ),
                                    #('../input/image-level-models/image_level_models/green/b7/model_green_fold1_b7.h5', ),
                                    #('../input/image-level-models/image_level_models/green/b7/model_green_fold2_b7.h5', ),
                                    #('../input/image-level-models/image_level_models/green/b7/model_green_fold3_b7.h5', ),
                                    ('../input/image-level-models/image_level_models/green/b7/model_green_fold4_b7.h5', 2)
                       ]
                       }},

              # res16=False branch for GREEN (see the NOTE above about the swapped-looking flag).
              False: {True: {720: [('../input/image-level-models/image_level_models/green/b0/res16/fold2_b0_GREEN_bs256_50epochs_size720_unnormalized_augmented_res16.h5', 4)],
                            },
                     False: {}}},

             # RGB: only the 16-bit, extra_aug, size-700 model is active.
             'RGB':
              {False: {False: {}, 
                       True: {}},

              True: {True: {700: [('../input/image-level-models/image_level_models/rgb/res16/fold3_b0_RGB_bs256_50epochs_size700_unnormalized_augmented_res16.h5', 4)]}, 
                     False: {}}}}

In [None]:
%%time

ttas = []

if len(sub_df) > 0:

    sum_weights = 0
    for color, v1 in model_dict.items():
        test_paths = COMPETITION_NAME + "/test/" + sub_df['ID']
        if color=='GREEN':
            test_paths += '_green.png'

        for res16, v2 in v1.items():

            for extra_aug, v3 in v2.items():

                for size, v4 in v3.items():

                    test_decoder = build_decoder(color=color, with_labels=False, target_size=(size, size), res16=res16)
                    test_augmenter = build_augmenter(dim=size, extra_aug=extra_aug, with_labels=False)

                    dtest_no_tta = build_dataset(
                            test_paths, bsize=BATCH_SIZE, repeat=False,
                            shuffle=False, augment=False, cache=False,
                            decode_fn=test_decoder)

                    dtest_tta = build_dataset(
                                test_paths, bsize=BATCH_SIZE, repeat=False,
                                shuffle=False, augment=True, cache=False,
                                decode_fn=test_decoder, augment_fn=test_augmenter)

                    for model_path, weight in v4:

                        print(f'color = {color}')
                        print(f'16bit = {res16}')
                        print(f'extra_aug = {extra_aug}')
                        print(f'image_size = {size}\n')

                        model_name = os.path.split(model_path)[-1].split('.')[0]

                        print(model_name)

                        with strategy.scope():
                            model = tf.keras.models.load_model(model_path)

                        tta = [model.predict(dtest_no_tta, verbose=1)]

                        num_steps = 4

                        for step in range(num_steps):
                            tta.append(model.predict(dtest_tta, verbose=1))

                        num_steps += 1

                        sum_weights += weight

                        tta = np.mean(np.stack(tta, axis=0), axis=0)

                        if Config.save_preds:
                            np.save(model_name, tta, allow_pickle=True)

                        ttas.append(tta * weight)

                        print('\n')

    ttas = np.sum(np.stack(ttas, axis=0), axis=0) / sum_weights

# 4. Ensemble & Final Prediction

In [None]:
sub_df[label_cols] = ttas

ss_df = pd.merge(ss_df, sub_df, on='ID', how ='left')

# Blend each cell-level confidence with the image-level probability of the same class:
# the average of their weighted arithmetic mean and their weighted geometric mean.
w = 0.5  # weight of the image-level probability; the cell-level one gets 1 - w

# Images without segmented cells have no prediction string (NaN after the earlier left merge);
# treat them as empty so .split() below works.
ss_df['PredictionString'] = ss_df['PredictionString'].fillna('')

n_classes = len(label_cols)
for i in range(ss_df.shape[0]):
    tokens = ss_df.at[i, 'PredictionString'].split()
    # Tokens come in triples: "<class> <confidence> <encoded mask>".
    for j in range(len(tokens) // 3):
        k = int(tokens[3 * j])
        if 0 <= k < n_classes:
            image_p = ss_df.at[i, f'{k}']
            cell_p = float(tokens[3 * j + 1])
            tokens[3 * j + 1] = str(0.5 * ((image_p * w + cell_p * (1 - w)) + (image_p ** w) * cell_p ** (1 - w)))

    ss_df.at[i, 'PredictionString'] = ' '.join(tokens)

In [None]:
# Optionally prepend Config.showed_df (presumably the rows not predicted in this run; confirm in Config).
if Config.hidden_only:
    ss_df = pd.concat([Config.showed_df, ss_df], ignore_index=True)

# Keep exactly the submission columns, in submission order, then write the CSV.
ss_df = ss_df.loc[:, ['ID', 'ImageWidth', 'ImageHeight', 'PredictionString']]
ss_df.to_csv('submission.csv', index=False)
ss_df.tail()

In [None]:
# List the working directory, e.g. to confirm that submission.csv was written.
!ls

In [None]:
# Delete the intermediate directories (single-cell crops, cell masks, nucleus masks).
# Directories that are already gone are skipped, so the cell can be re-run safely.
for tree in (single_cells_save_dir, cell_mask_dir, nucl_mask_dir):
    if os.path.isdir(tree):
        shutil.rmtree(tree)

In [None]:
# List the working directory again to confirm the intermediate directories were removed.
!ls