Credits:
* https://www.kaggle.com/rdizzl3/hpa-segmentation-masks-no-internet
* https://www.kaggle.com/frlemarchand/generate-masks-from-weak-image-level-labels/


# Installation

In [None]:
# Install the segmentation dependencies from wheels/sources attached as Kaggle
# datasets under ../input instead of PyPI (presumably because the notebook runs
# with internet disabled — confirm).
!pip install -q "../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"
!pip install -q "../input/hpapytorchzoozip/pytorch_zoo-master"
!pip install -q "../input/hpacellsegmentatormaster/HPA-Cell-Segmentation-master"

In [None]:
import os
# Making pretrained weights work without needing to find the default filename
if not os.path.exists('/root/.cache/torch/hub/checkpoints/'):
        os.makedirs('/root/.cache/torch/hub/checkpoints/')
# !cp '../input/resnet50/resnet50.pth' '/root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth'
!cp '../input/resnet34/resnet34.pth' '/root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth'

from fastai.vision.all import *
import pandas as pd
import numpy as np
from tqdm.autonotebook import tqdm
import imageio
from matplotlib import pyplot as plt

import sys

import os
from PIL import Image
import tensorflow as tf
import cv2

import hpacellseg.cellsegmentator as cellsegmentator
from hpacellseg.utils import label_cell, label_nuclei

# Helper functions

In [None]:
def build_image_names(image_id: str,
                      base_dir: str = '../input/hpa-single-cell-image-classification/test') -> list:
    """Return the four channel PNG paths of one image, each wrapped in its own list.

    Args:
        image_id: image identifier used as the filename prefix.
        base_dir: directory holding the ``{image_id}_{color}.png`` files
            (defaults to the competition test directory).

    Returns:
        ``[[red], [blue], [yellow], [green]]``, where each inner list holds a
        single path. In HPA data red is microtubules, blue is the nucleus,
        yellow is the endoplasmic reticulum and green is the protein of
        interest.
    """
    # Order matters: callers index the result positionally.
    channel_colors = ('red', 'blue', 'yellow', 'green')
    return [[f'{base_dir}/{image_id}_{color}.png'] for color in channel_colors]

In [None]:
import base64
import numpy as np
from pycocotools import _mask as coco_mask
import typing as t
import zlib


def encode_binary_mask(mask: np.ndarray) -> t.Text:
    """Convert a binary mask into OID challenge encoding ascii text.

    The mask is run-length encoded with the COCO API, zlib-compressed at the
    highest level and base64-encoded.

    Args:
        mask: boolean array; singleton dimensions are squeezed away and the
            result must be 2-D.

    Returns:
        The encoded mask as an ASCII string.

    Raises:
        ValueError: if ``mask`` is not boolean, or is not 2-D after squeezing.
    """
    # check input mask --
    # ``np.bool`` was only an alias of the builtin and was removed in NumPy
    # 1.24; comparing the dtype against ``bool`` matches ``np.bool_`` everywhere.
    if mask.dtype != bool:
        raise ValueError(
            "encode_binary_mask expects a binary mask, received dtype == %s" %
            mask.dtype)

    mask = np.squeeze(mask)
    if mask.ndim != 2:
        # Wrap the shape tuple: ``"%s" % shape`` would treat each dimension as
        # a separate format argument and raise TypeError instead.
        raise ValueError(
            "encode_binary_mask expects a 2d mask, received shape == %s" %
            (mask.shape,))

    # convert input mask to expected COCO API input --
    # (H, W, 1) uint8 in Fortran order, as the COCO encoder requires.
    mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
    mask_to_encode = mask_to_encode.astype(np.uint8)
    mask_to_encode = np.asfortranarray(mask_to_encode)

    # RLE encode mask --
    encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

    # compress and base64 encoding --
    binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
    base64_str = base64.b64encode(binary_str)
    return base64_str.decode('ascii')

In [None]:
# Input: list of single-path lists of channel PNGs (as from build_image_names)
# Output: list of the corresponding images as np.arrays
def image_name_to_numpy(path):
    """Load each referenced image into a NumPy array.

    Args:
        path: list of lists of file paths; only the first entry of each inner
            list is read.

    Returns:
        One array per inner list, in input order (empty input -> ``[]``).
    """
    image_arrays = []
    for image in path:
        # Context manager releases the file handle immediately instead of
        # leaving it open until garbage collection.
        with Image.open(image[0]) as img:
            image_arrays.append(np.asarray(img))
    return image_arrays


# Build a 3-channel PIL image from a list of channel arrays
# (e.g. the output of image_name_to_numpy).
def get_blended_image(images):
    """Stack every array except the last as the planes of one PIL image.

    With four arrays in build_image_names order (red, blue, yellow, green),
    red, blue and yellow become the R, G and B planes and green is dropped.
    Pixel values are cast to uint8, not rescaled.
    """
    planes = images[:-1]
    stacked = np.stack(planes, 2)
    return Image.fromarray(np.uint8(stacked))


def get_contour_from_mask(raw_mask):
    """Return the bounding box of all non-zero pixels in a 2-D mask.

    Args:
        raw_mask: 2-D array; any non-zero pixel counts as foreground.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as inclusive pixel coordinates
        (x = column, y = row), so the matching crop is
        ``img[y_min:y_max + 1, x_min:x_max + 1]``.

    Raises:
        ValueError: if the mask contains no foreground pixels.
    """
    # The extreme foreground pixels always lie on an object's outer contour,
    # so this equals the box of the union of all contours. Unlike taking only
    # the first contour, it covers every fragment of a multi-part mask.
    ys, xs = np.nonzero(raw_mask)
    if xs.size == 0:
        raise ValueError("get_contour_from_mask received an empty mask")
    return xs.min(), ys.min(), xs.max(), ys.max()

# Inference

In [None]:
# Competition data root; the sample submission provides every test image's ID
# and ImageWidth.
tpath = Path('../input/hpa-single-cell-image-classification')
sub = pd.read_csv(tpath / 'sample_submission.csv')
# sub = sub.sample(frac=0.03)  # uncomment for a quick run on a small subset

# One frame per distinct ImageWidth, each re-indexed from 0, so every batch in
# the inference loop contains images of a single size.
sub_dfs = [
    sub[sub['ImageWidth'] == width].copy().reset_index(drop=True)
    for width in sub.ImageWidth.unique()
]

# Pretrained HPA cell-segmentator weights from the attached dataset.
NUC_MODEL = '../input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth'
CELL_MODEL = '../input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth'

segmentator = cellsegmentator.CellSegmentator(
    NUC_MODEL,
    CELL_MODEL,
    device="cuda",
    multi_channel_model=True,
    padding=False,
    scale_factor=0.25,
)

In [None]:
# Number of images handed to the segmentator per call.
bs = 8


def _predstring_for_image(nuc_segmentation, cell_segmentation, green_image):
    """Build the PredictionString of one image from its segmentator outputs.

    Every labelled cell is emitted as ``"<label> <confidence> <mask>"`` with
    the placeholder label ``0`` and confidence ``1``; entries are joined with
    single spaces. An image without cells yields ``''``.
    """
    _, cell_mask = label_cell(nuc_segmentation, cell_segmentation)
    # 0 is the background label; every other value in cell_mask is one cell.
    cell_ids = np.unique(cell_mask)
    cell_ids = cell_ids[cell_ids != 0]

    parts = []
    for cell_id in cell_ids:
        bmask = cell_mask == cell_id
        x_min, y_min, x_max, y_max = get_contour_from_mask(bmask.astype(np.uint8))
        # Green-channel crop of this cell; bounds are inclusive, hence +1.
        # NOTE: not used yet - placeholder input for a per-cell classifier.
        isolated_cell = green_image[y_min:y_max + 1, x_min:x_max + 1]

        label = 0
        confidence = 1
        parts.append(f'{label} {confidence} {encode_binary_mask(bmask)}')
    return ' '.join(parts)


# Loop variable is not called `sub` so the full DataFrame is not shadowed.
for size_df in sub_dfs:
    print(f'Starting prediction for image size: {size_df.ImageWidth.iloc[0]}')
    for start in range(0, len(size_df), bs):
        end = min(start + bs, len(size_df))
        names = [build_image_names(image_id=image_id)
                 for image_id in size_df['ID'].iloc[start:end]]

        # build_image_names gives [[red], [blue], [yellow], [green]] per image.
        # The segmentator wants one path list per channel, ordered
        # [microtubules (red), ER (yellow), nuclei (blue)]. Building the lists
        # directly also keeps a 1-image batch as lists (np.squeeze did not).
        red = [n[0][0] for n in names]
        blue = [n[1][0] for n in names]
        yellow = [n[2][0] for n in names]
        green_images = image_name_to_numpy([n[3] for n in names])

        try:
            nuc_segmentations = segmentator.pred_nuclei(blue)
            cell_segmentations = segmentator.pred_cells([red, yellow, blue])
            if not (len(nuc_segmentations) == len(cell_segmentations) == end - start):
                raise RuntimeError(
                    f'segmentator returned {len(nuc_segmentations)} nuclei / '
                    f'{len(cell_segmentations)} cell masks for {end - start} images')
            predstrings = [
                _predstring_for_image(nuc, cell, green)
                for nuc, cell, green in tqdm(
                    zip(nuc_segmentations, cell_segmentations, green_images),
                    total=end - start)
            ]
        except Exception as err:
            # Best effort: a failing batch keeps the sample-submission
            # predictions, but the failure is reported instead of hidden.
            print(f'Batch {start}-{end - 1} failed, keeping default predictions: {err!r}')
            continue

        # Assign through .loc on the frame itself: chained indexing
        # (df[col].loc[...] = ...) may only modify a temporary copy.
        size_df.loc[start:end - 1, 'PredictionString'] = predstrings

In [None]:
# Reassemble the per-width chunks into a single frame and write the submission.
all_subs = pd.concat(objs=sub_dfs, sort=False, ignore_index=True)
all_subs.to_csv('submission.csv', index=False)

# Visualisations

In [None]:
# image_id = sub_dfs[0]['ID'].loc[2]
# image = build_image_names(image_id=image_id)
# arrays = image_name_to_numpy(image)
# nuclei = arrays[1]
# cell = arrays[:-1]

# # Nuclei segmentation
# nuc_segmentations = segmentator.pred_nuclei(image[1])

# f, ax = plt.subplots(1, 2, figsize=(16,16))
# ax[0].imshow(arrays[1])
# ax[0].set_title('Original Nucleis', size=20)
# ax[1].imshow(nuc_segmentations[0])
# ax[1].set_title('Segmented Nucleis', size=20)
# plt.show()

# # Cell segmentation
# inter_step = [i for i in image[:-1]]
# print(inter_step)
# cell_segmentations = segmentator.pred_cells(inter_step)

# f, ax = plt.subplots(1, 2, figsize=(16,16))
# ax[0].imshow(get_blended_image(arrays))
# ax[0].set_title('Original Cells', size=20)
# ax[1].imshow(cell_segmentations[0])
# ax[1].set_title('Segmented Cells', size=20)
# plt.show()

# # Nuclei mask
# nuclei_mask = label_nuclei(nuc_segmentations[0])
# # Cell masks
# cell_nuclei_mask, cell_mask = label_cell(nuc_segmentations[0], cell_segmentations[0])

# # Unique vector of cell_mask numbers
# numbers = set(np.ravel(cell_mask))
# numbers.remove(0)

# fig = plt.figure(figsize=(25,6*len(numbers)/4))
# index = 1

# ax = fig.add_subplot(len(numbers)//4+1, 4, index)
# ax.set_title("Complete Cell Mask", size=20)
# plt.imshow(cell_mask)

# index += 1
# for number in numbers:
#     isolated_cell = np.where(cell_mask==number, cell_mask, 0)
#     ax = fig.add_subplot(len(numbers)//4+1, 4, index)
#     ax.set_title("Segment {number}", size=20)
#     plt.imshow(isolated_cell)
#     index += 1
    
# for number in numbers:
#     isolated_cell = np.where(cell_mask==number, cell_mask, 0)
#     x = (cell_mask==number).astype(np.uint8)
#     x_min, y_min, x_max, y_max = get_contour_from_mask(x)
#     isolated_cell = isolated_cell[y_min:y_max, x_min:x_max]
#     plt.imshow(isolated_cell)
#     plt.show()