# Human Protein Atlas - Single-Cell Classification

The Human Protein Atlas is an initiative based in Sweden that aims to map the proteins in all human cells, tissues, and organs. The data in the Human Protein Atlas database is freely accessible to scientists all around the world, which allows them to explore the cellular makeup of the human body. Solving the single-cell image classification challenge will help us characterize single-cell heterogeneity in our large collection of images by generating more accurate annotations of the subcellular localizations of thousands of human proteins in individual cells. Thanks to you, we will be able to model the spatial organization of the human cell more accurately and provide new open-access cellular data to the scientific community, which may accelerate our growing understanding of how human cells function and how diseases develop.

In [None]:
# Install keras-applications and the efficientnet source from attached
# /kaggle/input datasets rather than from PyPI; efficientnet is installed
# with --no-deps so pip does not try to resolve its dependencies.
!pip install /kaggle/input/kerasapplications -q
!pip install /kaggle/input/efficientnet-keras-source-code/ -q --no-deps

In [None]:
# Cell Segmentator Tool
print("\n... INSTALLING AND IMPORTING CELL-PROFILER TOOL (HPACELLSEG) ...\n")
try:
    import hpacellseg.cellsegmentator as cellsegmentator
    from hpacellseg.utils import label_cell
except:
    !pip install -q "/kaggle/input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"
    !pip install -q "/kaggle/input/hpapytorchzoozip/pytorch_zoo-master"
    !pip install -q "/kaggle/input/hpacellsegmentatormaster/HPA-Cell-Segmentation-master"
    import hpacellseg.cellsegmentator as cellsegmentator
    from hpacellseg.utils import label_cell

print("\n... OTHER IMPORTS STARTING ...\n")
print("\n\tVERSION INFORMATION")

# Machine Learning and Data Science Imports
import tensorflow as tf; print(f"\t\t– TENSORFLOW VERSION: {tf.__version__}");
import pandas as pd; pd.options.mode.chained_assignment = None;
import numpy as np; print(f"\t\t– NUMPY VERSION: {np.__version__}");
import torch

import pandas as pd
import os

import efficientnet.tfkeras as efn
import numpy as np
import pandas as pd
import tensorflow as tf

# Built In Imports
from collections import Counter
from datetime import datetime
import multiprocessing
from glob import glob
import warnings
import requests
import imageio
import IPython
import urllib
import zipfile
import pickle
import random
import shutil
import string
import math
import tqdm
import time
import gzip
import sys
import ast
import csv; csv.field_size_limit(sys.maxsize)
import io
import os
import gc
import re

# Visualization Imports
from matplotlib.colors import ListedColormap
import matplotlib.patches as patches
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import plotly.express as px
import seaborn as sns
from PIL import Image
import matplotlib; print(f"\t\t– MATPLOTLIB VERSION: {matplotlib.__version__}");
import plotly
import PIL
import cv2

# Submission Imports
from pycocotools import _mask as coco_mask
import typing as t
import base64
import zlib

# PRESETS
LBL_NAMES = ["Nucleoplasm", "Nuclear Membrane", "Nucleoli", "Nucleoli Fibrillar Center", "Nuclear Speckles", "Nuclear Bodies", "Endoplasmic Reticulum", "Golgi Apparatus", "Intermediate Filaments", "Actin Filaments", "Microtubules", "Mitotic Spindle", "Centrosome", "Plasma Membrane", "Mitochondria", "Aggresome", "Cytosol", "Vesicles", "Negative"]
# Class id <-> label name lookups. Ids are positions in LBL_NAMES and are plain
# Python ints; enumerate keeps them in sync with the list length, replacing the
# previously hard-coded 19.
INT_2_STR = dict(enumerate(LBL_NAMES))
# Lower-case, underscore-separated variant, e.g. "golgi_apparatus"
INT_2_STR_LOWER = {k: v.lower().replace(" ", "_") for k, v in INT_2_STR.items()}
STR_2_INT_LOWER = {v: k for k, v in INT_2_STR_LOWER.items()}
STR_2_INT = {v: k for k, v in INT_2_STR.items()}
# Default font settings for plotly figures
FIG_FONT = {"family": "Helvetica, Arial", "size": 14, "color": "#7f7f7f"}
# One plotly colour string per label, sampled from seaborn's "Spectral" palette
LABEL_COLORS = list(map(
    lambda rgb_frac: px.colors.label_rgb(px.colors.convert_to_RGB_255(rgb_frac)),
    sns.color_palette("Spectral", len(LBL_NAMES)),
))
# The same colours keyed by the stringified class index
LABEL_COL_MAP = {str(idx): colour for idx, colour in enumerate(LABEL_COLORS)}

print("\n\n... IMPORTS COMPLETE ...\n")

##### THIS IS FOR PROTOTYPING AND PUBLIC LB PROBING #####
# True  -> reuse the pre-computed masks/bboxes/RLEs from the public CSV (no segmentation)
# False -> run the cell segmentator on the test images
ONLY_PUBLIC = True
##### THIS IS FOR PROTOTYPING AND PUBLIC LB PROBING#####

if not ONLY_PUBLIC:
    # Stop TensorFlow from reserving all GPU memory up front (presumably so the
    # PyTorch-based segmentator can share the device).
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # memory growth has to be configured identically for every GPU
            for gpu_device in gpus:
                tf.config.experimental.set_memory_growth(gpu_device, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "... Physical GPUs,", len(logical_gpus), "Logical GPUs ...\n")
        except RuntimeError as err:
            # raised when the GPUs were already initialised before this point
            print(err)
else:
    print("\n... ONLY INFERRING ON PUBLIC TEST DATA (USING PRE-PROCESSED DF) ...\n")

# Notebook setup

In [None]:
# HPA CellSegmentator weights: nucleus model and 3-channel cell model
NUC_MODEL = '/kaggle/input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth'
CELL_MODEL = '/kaggle/input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth'

# Cell-level classifier checkpoint. The variable keeps its "B2" name because later
# cells use it, but the author notes this checkpoint is actually a B4 model.
# Earlier checkpoint: /kaggle/input/hpa-cellwise-classification-training/ebnet_b2_wdensehead/ckpt-0007-0.0901.ckpt
B2_CELL_CLSFR_DIR = "/kaggle/input/hpacellb4-ckp/ebnet_b2_wdensehead/ckpt-0006-0.0563.ckpt"

# Competition data: test images live in <DATA_DIR>/test
DATA_DIR = "/kaggle/input/hpa-single-cell-image-classification"
TEST_IMG_DIR = os.path.join(DATA_DIR, "test")

# Every per-channel PNG in the test folder; 4 files make up one image
TEST_IMG_PATHS = sorted(os.path.join(TEST_IMG_DIR, fname) for fname in os.listdir(TEST_IMG_DIR))
print(f"... The number of testing images is {len(TEST_IMG_PATHS)}"
      f"\n\t--> i.e. {len(TEST_IMG_PATHS)//4} 4-channel images ...")

# Public sample submission enriched with pre-computed masks, and the competition's
# own sample submission
PUB_SS_CSV = "/kaggle/input/hpa-sample-submission-with-extra-metadata/updated_sample_submission.csv"
SWAP_SS_CSV = os.path.join(DATA_DIR, "sample_submission.csv")
ss_df = pd.read_csv(SWAP_SS_CSV)

# Test-time augmentation: number of augmented copies added per batch
DO_TTA = True
TTA_REPEATS = 8

# helps us control whether this is the full submission or just the initial pass:
# a 559-row sample submission is treated as the public (demo) pass.
IS_DEMO = len(ss_df)==559


def _first_and_last_per_width(df):
    """Return the first and the last row of every distinct ImageWidth in `df`.

    All "first" rows come before all "last" rows. When a width has only one
    image, its first and last rows are the same row. That duplicate is dropped
    so the image is not processed (and submitted) twice.
    """
    subset = pd.concat([
        df.drop_duplicates("ImageWidth", keep="first"),
        df.drop_duplicates("ImageWidth", keep="last"),
    ])
    return subset[~subset.index.duplicated(keep="first")]


if IS_DEMO:
    ss_df = _first_and_last_per_width(ss_df)
    gc.collect()
print("\n\nSAMPLE SUBMISSION DATAFRAME\n\n")
display(ss_df)

# If demo-submission/display we only do a subset of the data
if ONLY_PUBLIC:
    pub_ss_df = pd.read_csv(PUB_SS_CSV)

    if IS_DEMO:
        pub_ss_df = _first_and_last_per_width(pub_ss_df)

    # These columns store Python-literal lists as strings in the CSV; parse them back
    for list_col in ["mask_rles", "mask_bboxes", "mask_sub_rles"]:
        pub_ss_df[list_col] = pub_ss_df[list_col].apply(ast.literal_eval)

    print("\n\nTEST DATAFRAME W/ MASKS\n\n")
    display(pub_ss_df)

# Helper functions

In [None]:
def binary_mask_to_ascii(mask, mask_val=1):
    """Encode the pixels of `mask` equal to `mask_val` as OID-challenge ascii text.

    The selected pixels form a 2D boolean mask. That mask is RLE-encoded with the
    COCO API (`pycocotools._mask.encode`), zlib-compressed at the best
    compression level, and base64-encoded.

    Args:
        mask (np.ndarray): Labelled or binary mask; singleton dimensions are squeezed.
        mask_val (int): Value marking the pixels to encode.

    Returns:
        str: base64 text of the compressed COCO RLE.

    Raises:
        ValueError: If the squeezed mask is not 2D.
    """
    # `np.bool` was removed in NumPy 1.24; the comparison already yields a bool array,
    # so no further dtype check is needed.
    mask = np.squeeze(np.asarray(mask) == mask_val)
    if mask.ndim != 2:
        raise ValueError(f"encode_binary_mask expects a 2d mask, received shape == {mask.shape}")

    # COCO API expects a Fortran-ordered (H, W, N) uint8 array
    mask_to_encode = np.asfortranarray(mask[..., np.newaxis].astype(np.uint8))

    # RLE encode, then compress and base64 encode
    encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]
    binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
    return base64.b64encode(binary_str).decode()


def rle_encoding(img, mask_val=1):
    """Run-length encode the pixels of `img` equal to `mask_val`.

    Pixels are scanned column by column (`img.T.flatten()`). Each run is written
    as "<1-based start> <length>", and all numbers are joined by single spaces.
    https://en.wikipedia.org/wiki/Run-length_encoding

    Args:
        img (np.array): Segmentation array
        mask_val (int): Which value to use to create the RLE

    Returns:
        RLE string ("" when no pixel matches)
    """
    hits = np.flatnonzero(img.T.flatten() == mask_val)
    if hits.size == 0:
        return ''
    # a new run begins wherever two consecutive hit positions are not adjacent
    gaps = np.flatnonzero(np.diff(hits) != 1)
    run_starts = np.concatenate((hits[:1], hits[gaps + 1]))
    run_ends = np.concatenate((hits[gaps], hits[-1:]))
    pairs = np.column_stack((run_starts + 1, run_ends - run_starts + 1)).ravel()
    return ' '.join(str(value) for value in pairs)


def rle_to_mask(rle_string, height, width):
    """Convert an RLE string into a 0/255 uint8 mask of shape (height, width).

    The RLE is a whitespace-separated list of "start length" pairs. Starts are
    1-based and counted in column-major order (the format `rle_encoding` writes).

    Args:
        rle_string (rle_string): Run length encoding containing
            segmentation mask information
        height (int): Height of the original image the map comes from
        width (int): Width of the original image the map comes from

    Returns:
        uint8 numpy array of shape (height, width). Cell pixels are 255 and
        background is 0. An empty RLE yields an all-zero mask.
    """
    img = np.zeros(height * width, dtype=np.uint8)
    # split() (not split(' ')) tolerates repeated spaces and the empty string
    rle_numbers = [int(token) for token in rle_string.split()]
    for start, length in np.array(rle_numbers, dtype=np.int64).reshape(-1, 2):
        img[start - 1:start - 1 + length] = 255
    # runs were counted column by column: fill as (width, height) and transpose
    return img.reshape(width, height).T


def decode_img(img, img_size=(224,224), testing=False):
    """Decode encoded PNG bytes (e.g. from `tf.io.read_file`) into one uint8 channel.

    Args:
        img: Encoded PNG bytes.
        img_size (tuple): Target size passed to `tf.image.resize`; ignored when `testing`.
        testing (bool): If True, return the decoded image at its native size.

    Returns:
        uint8 tensor with a trailing channel axis of size 1.
    """
    decoded = tf.image.decode_png(img, channels=1)
    if testing:
        return decoded
    # resize produces float values, so cast back to uint8
    return tf.cast(tf.image.resize(decoded, img_size), tf.uint8)
        

    
def preprocess_path_ds(rp, gp, bp, yp, lbl, n_classes=19, img_size=(224,224), combine=True, drop_yellow=True):
    """Read, decode and resize the channel PNGs of one sample, and one-hot its label.

    Args:
        rp, gp, bp, yp: Paths of the red, green, blue and yellow channel PNGs.
        lbl: Class index (or indices) passed to `tf.one_hot`.
        n_classes (int): Depth of the one-hot label.
        img_size (tuple): Size each channel is resized to (see `decode_img`).
        combine (bool): Stack the channels into a single channels-last tensor.
        drop_yellow (bool): Skip the yellow channel.

    Returns:
        `(image, one_hot)` when `combine`; otherwise the separate single-channel
        tensors followed by the one-hot label.
    """
    one_hot = tf.one_hot(lbl, n_classes, dtype=tf.uint8)
    # Like preprocess_row: yellow is only read and decoded when it is returned
    paths = [rp, gp, bp] if drop_yellow else [rp, gp, bp, yp]
    channels = [decode_img(tf.io.read_file(path), img_size) for path in paths]
    if combine:
        return tf.stack([channel[..., 0] for channel in channels], axis=-1), one_hot
    return (*channels, one_hot)
    
    
def create_pred_col(row):
    """Pick the prediction string for one row of the merged submission dataframe.

    After merging the sample submission (suffix `_x`) with our predictions
    (suffix `_y`), use our prediction when it exists. Otherwise fall back to
    the sample submission's string.

    Args:
        row (pd.Series): A row with `PredictionString_x` and `PredictionString_y`.

    Returns:
        The prediction string.
    """
    ours = row.PredictionString_y
    return row.PredictionString_x if pd.isnull(ours) else ours
    
    
def load_image(img_id, img_dir, testing=False, only_public=False):
    """Load the per-channel PNGs of one image and stack them into a single array.

    Files are expected at `{img_dir}/{img_id}_{colour}.png`.

    Args:
        img_id (str): Image id.
        img_dir (str): Directory containing the channel PNGs.
        testing (bool): If False, read red/green/blue/yellow with PIL and stack them
            channels-last (H, W, 4); `only_public` is ignored in this mode. If True,
            decode with `decode_img` (the format used for the cell segmentator).
        only_public (bool): With `testing=True`, return red/green/blue channels-last
            (H, W, 3) instead of all four channels channels-first (4, H, W).

    Returns:
        uint8 numpy array.
    """
    if not testing:
        rgby = []
        for c in ["red", "green", "blue", "yellow"]:
            # `with` closes the file handle PIL keeps open for lazy loading
            with Image.open(os.path.join(img_dir, img_id+f"_{c}.png")) as pil_img:
                rgby.append(np.asarray(pil_img, np.uint8))
        return np.stack(rgby, axis=-1)

    if only_public:
        return_axis, clr_list = -1, ["red", "green", "blue"]
    else:
        return_axis, clr_list = 0, ["red", "green", "blue", "yellow"]

    # This is for cellsegmentator
    channels = [
        np.asarray(decode_img(tf.io.read_file(os.path.join(img_dir, img_id+f"_{c}.png")), testing=True), np.uint8)[..., 0]
        for c in clr_list
    ]
    return np.stack(channels, axis=return_axis)
        


def plot_rgb(arr, figsize=(12,12)):
    """Show a 3-channel (RGB) microscopy composite in a new figure with hidden axes."""
    plt.figure(figsize=figsize)
    plt.imshow(arr)
    plt.title("RGB Composite Image", fontweight="bold")
    plt.axis(False)
    plt.show()
    
    
def convert_rgby_to_rgb(arr):
    """ Convert a 4 channel (RGBY) image to a 3 channel RGB image.

    Advice From Competition Host/User: lnhtrang

    For annotation (by experts) and for the model, I guess we agree that individual
    channels with full range px values are better.
    In annotation, we toggled the channels.
    For visualization purpose only, you can try blending the channels.
    For example,
        - red = red + yellow
        - green = green + yellow/2
        - blue=blue.

    This implementation only blends yellow into green (green + yellow/2). Red
    and blue are copied unchanged. For integer dtypes the blended green is
    saturated at the dtype's maximum instead of wrapping around.

    Args:
        arr (numpy array): The RGBY, 4 channel numpy array for a given image
            (channels last)

    Returns:
        RGB Image with the same dtype as `arr`
    """
    rgb_arr = np.zeros_like(arr[..., :-1])
    rgb_arr[..., 0] = arr[..., 0]
    # Compute in float so green + yellow/2 cannot overflow, then clip to the output
    # integer range (e.g. 0..255 for uint8) before storing.
    green = arr[..., 1].astype(np.float64) + arr[..., 3] / 2
    if np.issubdtype(rgb_arr.dtype, np.integer):
        dtype_info = np.iinfo(rgb_arr.dtype)
        green = np.clip(green, dtype_info.min, dtype_info.max)
    rgb_arr[..., 1] = green
    rgb_arr[..., 2] = arr[..., 2]

    return rgb_arr
    
    
def plot_ex(arr, figsize=(20,6), title=None, plot_merged=True, rgb_only=False):
    """Plot each channel of a microscopy image side by side, plus an optional merge.

    Args:
        arr (np.ndarray): Channels-last image: RGBY (4 channels), or RGB when `rgb_only`.
        figsize (tuple): Matplotlib figure size.
        title (str, optional): Figure super-title.
        plot_merged (bool): Add a final panel with the merged composite.
        rgb_only (bool): `arr` has only R, G, B channels, so no yellow panel is drawn.
    """
    channel_titles = ["Red Channel – Microtubles", "Green Channel – Protein of Interest", "Blue - Nucleus", "Yellow – Endoplasmic Reticulum"]
    # One panel per plotted channel, plus one for the merge. (Previously the
    # not-merged RGBY case got 3 slots for 4 plots and crashed.)
    n_channels = 3 if rgb_only else 4
    n_images = n_channels + (1 if plot_merged else 0)

    plt.figure(figsize=figsize)
    if isinstance(title, str):
        plt.suptitle(title, fontsize=20, fontweight="bold")

    for i, c in enumerate(channel_titles[:n_channels]):
        # 3-channel (RGB) display buffer holding only this channel
        ch_arr = np.zeros_like(arr if rgb_only else arr[..., :-1])
        if i < 3:
            ch_arr[..., i] = arr[..., i]
        else:
            # yellow is shown as red + green
            ch_arr[..., 0] = arr[..., i]
            ch_arr[..., 1] = arr[..., i]
        plt.subplot(1, n_images, i+1)
        plt.title(f"{c.title()}", fontweight="bold")
        plt.imshow(ch_arr)
        plt.axis(False)

    if plot_merged:
        plt.subplot(1, n_images, n_images)

        if rgb_only:
            plt.title("Merged RGB", fontweight="bold")
            plt.imshow(arr)
        else:
            plt.title("Merged RGBY into RGB", fontweight="bold")
            plt.imshow(convert_rgby_to_rgb(arr))
        plt.axis(False)

    plt.tight_layout(rect=[0, 0.2, 1, 0.97])
    plt.show()
    
    
def flatten_list_of_lists(l_o_l, to_string=False):
    """Flatten one level of nesting, optionally converting every item with `str`.

    Args:
        l_o_l (iterable of iterables): The nested collection (e.g. a list of lists
            or a zip of sequences).
        to_string (bool): Return the items as strings.

    Returns:
        list: The items in order.
    """
    flat = [element for inner in l_o_l for element in inner]
    if to_string:
        flat = [str(element) for element in flat]
    return flat


def create_segmentation_maps(list_of_image_lists, segmentator, batch_size=8):
    """Segment images with the CellSegmentator tool and RLE-encode every cell.

    Args:
        list_of_image_lists (list of lists): Three parallel lists of image paths:
            - [[micro-tubules(red)], [endoplasmic-reticulum(yellow)], [nucleus(blue)]]
            (the format returned by `get_img_list`).
        segmentator: Object exposing `pred_cells` and `pred_nuclei`.
        batch_size (int): Batch size to use in generating the segmentation masks

    Returns:
        dict: image id -> list of RLE strings, one per cell label 1..max label.
    """
    n_images = len(list_of_image_lists[0])
    all_mask_rles = {}
    # ceil so the progress bar also counts a final partial batch
    for i in tqdm(range(0, n_images, batch_size), total=math.ceil(n_images / batch_size)):

        # Get batch of images (same slice of every channel list)
        sub_images = [channel_paths[i:i+batch_size] for channel_paths in list_of_image_lists]

        # Do segmentation
        cell_segmentations = segmentator.pred_cells(sub_images)
        nuc_segmentations = segmentator.pred_nuclei(sub_images[2])

        # post-processing: label cells and RLE-encode each label
        for j, path in enumerate(sub_images[0]):
            img_id = os.path.basename(path).replace("_red.png", "")
            _, cell_mask = label_cell(nuc_segmentations[j], cell_segmentations[j])
            all_mask_rles[img_id] = [rle_encoding(cell_mask, mask_val=k) for k in range(1, np.max(cell_mask)+1)]
    return all_mask_rles


def get_img_list(img_dir, return_ids=False, sub_n=None):
    """List channel image paths in the layout the CellSegmentator tool expects.

    Args:
        img_dir (str): Directory containing `<id>_<colour>.png` files.
        return_ids (bool): Also return the image ids derived from the red paths.
        sub_n (int, optional): Keep only the first `sub_n` (sorted) images per channel.

    Returns:
        `[red_paths, yellow_paths, blue_paths]`, preceded by the id list when
        `return_ids` is set.
    """
    prefix = img_dir + '/'

    def _channel_paths(colour):
        return sorted(glob(prefix + f'*_{colour}.png'))

    if sub_n is None:
        sub_n = len(glob(prefix + '*_red.png'))
    images = [_channel_paths(colour)[:sub_n] for colour in ["red", "yellow", "blue"]]
    if not return_ids:
        return images
    ids = [p.replace("_red.png", "").rsplit("/", 1)[1] for p in images[0]]
    return ids, images
    
    
def get_contour_bbox_from_rle(rle, width, height, return_mask=True,):
    """Get the bounding box of the cell described by `rle` as `(xmin, ymin, xmax, ymax)`.

    The max bounds are exclusive (x + w), so the box can be used directly as
    slice ends. When the mask has several disconnected pieces, the box encloses
    all of them rather than only the first contour OpenCV lists.

    Args:
        rle (rle_string): Run length encoding containing
            segmentation mask information
        width (int): Width of the original image the map comes from
        height (int): Height of the original image the map comes from
        return_mask (bool): Also return the decoded 0/255 mask.

    Returns:
        `(bbox, mask)` when `return_mask`, else `bbox`.

    Raises:
        ValueError: If the RLE contains no foreground pixels.
    """
    # copy -> contiguous array for OpenCV (rle_to_mask returns a transposed array)
    mask = rle_to_mask(rle, height, width).copy()
    cnts = grab_contours(
        cv2.findContours(
            mask,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        ))
    if len(cnts) == 0:
        raise ValueError("RLE does not describe any foreground pixels")
    x, y, w, h = cv2.boundingRect(np.vstack(cnts))
    bbox = (x, y, x+w, y+h)

    if return_mask:
        return bbox, mask
    return bbox
    

def get_contour_bbox_from_raw(raw_mask):
    """Bounding boxes of all external contours in a mask, in reading order.

    Boxes are sorted by top edge, then left edge.

    NOTE(review): cv2.findContours treats every non-zero pixel as foreground, so
    in a labelled mask two touching cells yield a single box. The (y, x) order
    also differs from label order.

    Args:
        raw_mask (nparray): Mask accepted by `cv2.findContours`.

    Returns:
        list of `(xmin, ymin, xmax, ymax)` tuples with exclusive max bounds.
    """
    contours = grab_contours(
        cv2.findContours(raw_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    )
    boxes = []
    for contour in contours:
        left, top, box_w, box_h = cv2.boundingRect(contour)
        boxes.append((left, top, left + box_w, top + box_h))
    boxes.sort(key=lambda box: (box[1], box[0]))
    return boxes


def pad_to_square(a):
    """Zero-pad `a` symmetrically along its shorter spatial axis until it is square.

    The first two axes are treated as (height, width). Any trailing axes, such
    as channels, are left untouched, so 2D masks and HWC images both work. When
    the size difference is odd, the extra row or column goes to the bottom or
    right. An already-square array is returned unchanged (same object).
    """
    height, width = a.shape[:2]
    diff = abs(height - width)
    if diff == 0:
        return a
    before = diff // 2
    after = diff - before
    pad_width = [(0, 0)] * a.ndim
    # pad rows (height) when the image is wider than tall, columns otherwise
    pad_width[0 if width > height else 1] = (before, after)
    return np.pad(a, pad_width, mode='constant')


def cut_out_cells(rgby, rles, resize_to=(256,256), square_off=True, return_masks=False, from_raw=True):
    """ Cut out the cells as (optionally padded square) images

    Args:
        rgby (np.array): Channels-last image to be cut into tiles
        rles (list of RLE strings): List of run length encoding containing
            segmentation mask information
        resize_to (tuple of ints, optional): The dimension to resize each tile to
            (tiles become float32); None keeps the native crop size
        square_off (bool, optional): Whether to pad the tiles to a square or not
        return_masks (bool, optional): Also return the decoded 0/255 cell masks
        from_raw (bool, optional): Unused; kept for backward compatibility

    Returns:
        list of tile arrays, plus the list of masks when `return_masks`
    """
    # shape[:2] is (rows, cols) == (height, width)
    h, w = rgby.shape[:2]
    results = [get_contour_bbox_from_rle(rle, w, h, return_mask=return_masks) for rle in rles]
    if return_masks:
        # each result is ((xmin, ymin, xmax, ymax), mask)
        contour_bboxes = [bbox for bbox, _ in results]
        masks = [mask for _, mask in results]
    else:
        contour_bboxes = results

    arrs = [rgby[ymin:ymax, xmin:xmax, ...] for xmin, ymin, xmax, ymax in contour_bboxes]
    if square_off:
        arrs = [pad_to_square(arr) for arr in arrs]

    if resize_to is not None:
        # padding is governed by `square_off` only; tiles are not re-padded here
        arrs = [
            cv2.resize(arr.astype(np.float32), resize_to, interpolation=cv2.INTER_CUBIC)
            for arr in arrs
        ]
    if return_masks:
        return arrs, masks
    return arrs


def grab_contours(cnts):
    """Extract the contour list from a `cv2.findContours` result on any OpenCV version.

    OpenCV v2.4, v4-beta and v4-official return `(contours, hierarchy)`.
    OpenCV v3, v4-pre and v4-alpha return `(image, contours, hierarchy)`.

    Raises:
        Exception: If the tuple has neither 2 nor 3 elements.
    """
    n_items = len(cnts)
    if n_items == 2:
        return cnts[0]
    if n_items == 3:
        return cnts[1]
    # OpenCV changed the findContours return signature yet again
    raise Exception(("Contours tuple must have length 2 or 3, "
        "otherwise OpenCV changed their cv2.findContours return "
        "signature yet again. Refer to OpenCV's documentation "
        "in that case"))


def preprocess_row(img_id, img_w, img_h, combine=True, drop_yellow=True):
    """Decode the channel PNGs of test image `img_id` from TEST_IMG_DIR at native size.

    Args:
        img_id (str): Test image id.
        img_w, img_h (int): Forwarded to `decode_img` as `img_size`; ignored there
            because `testing=True`.
        combine (bool): Stack the channels into one channels-last tensor.
        drop_yellow (bool): Skip the yellow channel.

    Returns:
        A stacked tensor when `combine`, otherwise a tuple of single-channel tensors.
    """
    colours = ["red", "green", "blue"] if drop_yellow else ["red", "green", "blue", "yellow"]
    channels = [
        decode_img(tf.io.read_file(os.path.join(TEST_IMG_DIR, img_id + f"_{colour}.png")), (img_w, img_h), testing=True)
        for colour in colours
    ]
    if combine:
        return tf.stack([channel[..., 0] for channel in channels], axis=-1)
    return tuple(channels)

    
def plot_predictions(img, masks, preds, confs=None, fill_alpha=0.3, lbl_as_str=True):
    """Draw each cell's outline, translucent fill and predicted class(es), then show it.

    External contours are extracted from `masks` and sorted by the (y, x) of
    their bounding box's top-left corner. They are paired with `preds` in that order.

    Args:
        img (np.ndarray): RGB image to draw on (copies are drawn on, `img` is untouched).
        masks (np.ndarray): Single-channel mask accepted by `cv2.findContours`.
        preds (list of sequences of int): Class ids per cell; the first id sets the colour.
        confs (list, optional): Per-cell confidences; accepted but not drawn.
        fill_alpha (float): Opacity of the cell fill.
        lbl_as_str (bool): Label with class names (True) or "CLS=[ids]" (False).
    """
    # Initialize
    FONT = cv2.FONT_HERSHEY_SIMPLEX; FONT_SCALE = 0.7; FONT_THICKNESS = 2; FONT_LINE_TYPE = cv2.LINE_AA;
    COLORS = [[round(y*255) for y in x] for x in sns.color_palette("Spectral", len(LBL_NAMES))]
    to_plot = img.copy()
    cntr_img = img.copy()
    if confs is None:
        # one placeholder per prediction (len(masks) would be the number of mask rows)
        confs = [None] * len(preds)

    cnts = grab_contours(
        cv2.findContours(
            masks,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        ))
    rects = [cv2.boundingRect(cnt) for cnt in cnts]
    order = sorted(range(len(cnts)), key=lambda k: (rects[k][1], rects[k][0]))
    cnts = [cnts[k] for k in order]

    for c, pred, conf in zip(cnts, preds, confs):
        # We can only display one color so we pick the first
        color = COLORS[pred[0]]
        if not lbl_as_str:
            classes = "CLS=["+",".join([str(p) for p in pred])+"]"
        else:
            classes = ", ".join([INT_2_STR[p] for p in pred])

        M = cv2.moments(c)
        if M['m00'] != 0:
            cx = int(M['m10']/M['m00'])
            cy = int(M['m01']/M['m00'])
        else:
            # zero-area contour (line/point): fall back to its bounding-box centre
            x, y, w, h = cv2.boundingRect(c)
            cx, cy = x + w // 2, y + h // 2

        text_width, text_height = cv2.getTextSize(classes, FONT, FONT_SCALE, FONT_THICKNESS)[0]

        # Border and fill
        cv2.drawContours(to_plot, [c], contourIdx=-1, color=[max(0, x-40) for x in color], thickness=10)
        cv2.drawContours(cntr_img, [c], contourIdx=-1, color=(color), thickness=-1)

        # Text
        cv2.putText(to_plot, classes, (cx-text_width//2,cy-text_height//2),
                    FONT, FONT_SCALE, [min(255, x+40) for x in color], FONT_THICKNESS, FONT_LINE_TYPE)

    cv2.addWeighted(cntr_img, fill_alpha, to_plot, 1-fill_alpha, 0, to_plot)
    plt.figure(figsize=(16,16))
    plt.imshow(to_plot)
    plt.axis(False)
    plt.show()
    
def tta(original_img_batch, repeats=4):
    """Build test-time-augmented copies of a batch of images.

    Args:
        original_img_batch: Float tensor batch of channels-last RGB images.
        repeats (int): Number of augmented copies to add.

    Returns:
        list: `[original_img_batch, aug_1, ..., aug_repeats]`. Each copy gets:
        - random left-right and up-down flips;
        - one random multiple-of-90-degree rotation, shared by the whole batch;
        - small saturation, brightness and contrast jitter.
    """
    tta_img_batches = [original_img_batch]

    for _ in range(repeats):
        # Independent seed per op. Stateless ops given the same seed draw the same
        # random numbers, so a single shared seed made the LR and UD flips always
        # happen together (never a pure mirror) and coupled the colour jitters.
        seeds = tf.random.uniform((5, 2), minval=0, maxval=100, dtype=tf.dtypes.int32)
        k = tf.random.uniform((1,), minval=0, maxval=4, dtype=tf.dtypes.int32)[0]

        # tensors are immutable, so the original batch is never modified
        img_batch = tf.image.stateless_random_flip_left_right(original_img_batch, seeds[0])
        img_batch = tf.image.stateless_random_flip_up_down(img_batch, seeds[1])
        img_batch = tf.image.rot90(img_batch, k)

        img_batch = tf.image.stateless_random_saturation(img_batch, 0.9, 1.1, seeds[2])
        img_batch = tf.image.stateless_random_brightness(img_batch, 0.075, seeds[3])
        img_batch = tf.image.stateless_random_contrast(img_batch, 0.9, 1.1, seeds[4])
        tta_img_batches.append(img_batch)

    return tta_img_batches

# Load model and predict

In [None]:
# Load the cell-level classifier
inference_model = tf.keras.models.load_model(B2_CELL_CLSFR_DIR)

# Parameters
IMAGE_SIZES = [1728, 2048, 3072, 4096]  # distinct test image widths handled below
BATCH_SIZE = 8
# 0.0 keeps every class with a positive score (presumably intentional, so each
# class is submitted with its confidence) - confirm
CONF_THRESH = 0.0
TILE_SIZE = (224,224)

# Switch what we will be actually infering on
if ONLY_PUBLIC:
    _source_df = pub_ss_df
else:
    # Segmentator is only needed when masks are not pre-computed
    segmentator = cellsegmentator.CellSegmentator(NUC_MODEL, CELL_MODEL, scale_factor=0.25, padding=True)
    _source_df = ss_df

# One dataframe (and id list) per image width, in IMAGE_SIZES order
predict_df_1728, predict_df_2048, predict_df_3072, predict_df_4096 = [
    _source_df[_source_df.ImageWidth == img_width] for img_width in IMAGE_SIZES
]
del _source_df

predict_ids_1728, predict_ids_2048, predict_ids_3072, predict_ids_4096 = [
    size_df.ID.to_list()
    for size_df in (predict_df_1728, predict_df_2048, predict_df_3072, predict_df_4096)
]

# Inference

In [None]:
predictions = []
sub_df = pd.DataFrame(columns=["ID"], data=predict_ids_1728+predict_ids_2048+predict_ids_3072+predict_ids_4096)


def _cell_bboxes_from_labels(labelled_mask, cell_ids):
    """Return `(xmin, ymin, xmax, ymax)` for each id in `cell_ids`, in that order.

    Max bounds are exclusive, matching the boxes of `get_contour_bbox_from_raw`.
    Every id must be present in `labelled_mask`.
    """
    bboxes = []
    for cell_id in cell_ids:
        ys, xs = np.nonzero(labelled_mask == cell_id)
        bboxes.append((int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1))
    return bboxes


def _predict_cell_tiles(cell_tiles):
    """Classify one image's cell tiles, averaging over TTA copies when DO_TTA is set.

    Returns one row of class scores per tile. An image without cells yields an
    empty (0, len(LBL_NAMES)) array instead of an invalid model call.
    """
    if len(cell_tiles) == 0:
        return np.zeros((0, len(LBL_NAMES)), dtype=np.float32)
    tiles = tf.cast(cell_tiles, dtype=tf.float32)
    if DO_TTA:
        tta_preds = [inference_model.predict(t) for t in tta(tiles, repeats=TTA_REPEATS)]
        return tf.keras.layers.Average()(tta_preds).numpy()
    return inference_model.predict(tiles)


# #### STEP TIMING FOR 1728x1728 IMAGES FOR EFFNETB0 ON 128x128 CROPS ####
#  0:	 1.03042 seconds
#  1:	 8.14935 seconds
#  2:	 0.00002 seconds
#  3:	 29.9057 seconds
#  4:	 1.30675 seconds
#  5:	 0.01442 seconds
#  6:	 0.26723 seconds
#  7:	 4.10871 seconds
#  8:	 0.00108 seconds
#  9:	 0.00066 seconds
# 10:	 0.00015 seconds
for size_idx, submission_ids in enumerate([predict_ids_1728, predict_ids_2048, predict_ids_3072, predict_ids_4096]):
    size = IMAGE_SIZES[size_idx]
    if not submission_ids:
        print(f"\n...SKIPPING SIZE {size} AS THERE ARE NO IMAGE IDS ...\n")
        continue
    print(f"\n...WORKING ON IMAGE IDS FOR SIZE {size} ...\n")
    for i in tqdm(range(0, len(submission_ids), BATCH_SIZE), total=int(np.ceil(len(submission_ids)/BATCH_SIZE))):
        batch_ids = submission_ids[i:(i+BATCH_SIZE)]

        # Step 0: Get batch of images as numpy arrays
        batch_rgby_images = [
            load_image(ID, TEST_IMG_DIR, testing=True, only_public=ONLY_PUBLIC)
            for ID in batch_ids
        ]

        if ONLY_PUBLIC:
            # Pre-processed rows for this batch. The ids come from pub_ss_df in
            # the same relative order, so rows line up with batch_rgby_images.
            batch_df = pub_ss_df[pub_ss_df.ID.isin(batch_ids)]

            # Step 1: Get Bounding Boxes
            batch_cell_bboxes = batch_df.mask_bboxes.values

            # Step 2: Get RGB Images (which are actually just labelled as RGBY)
            batch_rgb_images = batch_rgby_images

            # Step 3: Get Submission RLEs
            submission_rles = batch_df.mask_sub_rles.values

            # Optional Step: Get the Masks
            if IS_DEMO:
                batch_masks = [
                    sum([rle_to_mask(mask, size, size) for mask in batch])
                    for batch in batch_df.mask_rles.values
                ]

        else:
            # Step 1: Do Prediction On Batch (cell model takes [red, yellow, blue] lists)
            cell_segmentations = segmentator.pred_cells([[rgby_image[j] for rgby_image in batch_rgby_images] for j in [0, 3, 2]])
            nuc_segmentations = segmentator.pred_nuclei([rgby_image[2] for rgby_image in batch_rgby_images])

            # Step 2: Perform Cell Labelling on Batch
            # NOTE(review): uint8 labels would wrap if an image had more than 255 cells.
            batch_masks = [label_cell(nuc_seg, cell_seg)[1].astype(np.uint8) for nuc_seg, cell_seg in zip(nuc_segmentations, cell_segmentations)]

            # Step 3: Reshape the RGBY Images so They Are Channels Last Across the Batch (drop yellow)
            batch_rgb_images = [rgby_image.transpose(1,2,0)[..., :-1] for rgby_image in batch_rgby_images]

            # Cell labels present in each mask. Steps 4 and 5 both iterate these, so
            # boxes (hence predictions) and RLEs line up index-for-index. Contour-
            # sorted boxes did not: they are in (y, x) order and merge touching cells.
            batch_cell_ids = [np.unique(mask[mask > 0]) for mask in batch_masks]

            # Step 4: Get Bounding Boxes For All Cells in All Images in Batch
            batch_cell_bboxes = [_cell_bboxes_from_labels(mask, ids) for mask, ids in zip(batch_masks, batch_cell_ids)]

            # Step 5: Generate Submission RLEs For the Batch
            submission_rles = [[binary_mask_to_ascii(mask, mask_val=cell_id) for cell_id in ids] for mask, ids in zip(batch_masks, batch_cell_ids)]

        # Step 6: Cut Out, Pad to Square, and Resize to 224x224
        batch_cell_tiles = [[
            cv2.resize(
                pad_to_square(
                    rgb_image[bbox[1]:bbox[3], bbox[0]:bbox[2], ...]),
                TILE_SIZE, interpolation=cv2.INTER_CUBIC) for bbox in bboxes]
            for bboxes, rgb_image in zip(batch_cell_bboxes, batch_rgb_images)
        ]

        # Steps 7 & 8: (OPTIONAL) Test Time Augmentation + Inference
        batch_o_preds = [_predict_cell_tiles(cell_tiles) for cell_tiles in batch_cell_tiles]

        # Step 9: Post-Process (cells with no class above threshold become "Negative")
        batch_confs = [[pred[np.where(pred>CONF_THRESH)] for pred in o_preds] for o_preds in batch_o_preds]
        batch_preds = [[np.where(pred>CONF_THRESH)[0] for pred in o_preds] for o_preds in batch_o_preds]

        for j, preds in enumerate(batch_preds):
            for k in range(len(preds)):
                if preds[k].size==0:
                    batch_preds[j][k]=np.array([18,])
                    batch_confs[j][k]=np.array([1-np.max(batch_o_preds[j][k]),])

        # Optional Viz Step
        if IS_DEMO:
            print("\n... DEMO IMAGES ...\n")
            for rgb_image, masks, bboxes, preds, confs in zip(batch_rgb_images, batch_masks, batch_cell_bboxes, batch_preds, batch_confs):
                # plot_predictions pairs predictions with contours sorted by the (y, x)
                # of their top-left corner, so present the predictions in that order.
                order = sorted(range(len(bboxes)), key=lambda k: (bboxes[k][1], bboxes[k][0]))
                plot_predictions(rgb_image, masks, [preds[k] for k in order],
                                 confs=[confs[k] for k in order], fill_alpha=0.2, lbl_as_str=True)

        # Step 10: Format Predictions To Create Prediction String Easily
        submission_rles = [flatten_list_of_lists([[m,]*len(p) for m, p in zip(masks, preds)]) for masks, preds in zip(submission_rles, batch_preds)]
        batch_preds = [flatten_list_of_lists(preds, to_string=True) for preds in batch_preds]
        batch_confs = [[f"{conf:.4f}" for cell_confs in confs for conf in cell_confs] for confs in batch_confs]

        # Step 11: Save Predictions to Be Added to Dataframe At The End
        predictions.extend([" ".join(flatten_list_of_lists(zip(*[preds,confs,masks]))) for preds, confs, masks in zip(batch_preds, batch_confs, submission_rles)])
sub_df["PredictionString"] = predictions

print("\n... TEST DATAFRAME ...\n")
display(sub_df.head(3))

# Submission

In [None]:
# Left-join our predictions onto the sample submission: the sample's strings get
# suffix _x and ours _y. Keep ours where present, otherwise the sample's.
ss_df = ss_df.merge(sub_df, how="left", on="ID")
ss_df["PredictionString"] = ss_df["PredictionString_y"].combine_first(ss_df["PredictionString_x"])
ss_df = ss_df.drop(columns=["PredictionString_x", "PredictionString_y"])
# Writing the CSV is currently disabled:
#ss_df.to_csv("/kaggle/working/submission.csv", index=False)
display(ss_df)

# Release cached GPU memory held by PyTorch
torch.cuda.empty_cache()

# Final submission

In [None]:
# Final submission: a Keras model (model_green.h5, presumably the EfficientNet-B7
# named by its input directory) predicts image-level class probabilities from
# the *_green.png test images. Those probabilities are blended into every cell
# confidence of the segmentation-based submission `ss_df` built by the previous
# cells, and the result is written to submission.csv.

COMPETITION_NAME = "hpa-single-cell-image-classification"
load_dir = f"/kaggle/input/{COMPETITION_NAME}/"

# NOTE(review): 559 is presumably the row count of the public
# sample_submission.csv; any other size means a different test set is scored.
PUBLIC_TEST_ROWS = 559
N_CLASSES = 19
# Blend weights: image-level probability vs. cell-level confidence.
IMAGE_WEIGHT = 0.6
CELL_WEIGHT = 0.4
# Prediction string that is passed through without blending.
SKIP_PREDICTION_STRING = '0 1 eNoLCAgIMAEABJkBdQ=='


def auto_select_accelerator():
    """Return a TPU distribution strategy when a TPU can be reached, else the default one."""
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
        print("Running on TPU:", tpu.master())
    except ValueError:
        # Presumably what the TPU resolver raises when no TPU is available.
        strategy = tf.distribute.get_strategy()
    print(f"Running on {strategy.num_replicas_in_sync} replicas")

    return strategy


def build_decoder(with_labels=True, target_size=(300, 300), ext='jpg'):
    """Return a tf.data map function that reads, decodes, scales and resizes an image.

    The decoded image has 3 channels, values scaled to [0, 1], and is resized to
    ``target_size``. With ``with_labels`` the function maps ``(path, label)`` to
    ``(image, label)``; otherwise it maps ``path`` to ``image``.

    Raises:
        ValueError: if ``ext`` is not one of 'png', 'jpg', 'jpeg'.
    """
    def decode(path):
        file_bytes = tf.io.read_file(path)
        if ext == 'png':
            img = tf.image.decode_png(file_bytes, channels=3)
        elif ext in ['jpg', 'jpeg']:
            img = tf.image.decode_jpeg(file_bytes, channels=3)
        else:
            raise ValueError("Image extension not supported")

        img = tf.cast(img, tf.float32) / 255.0
        img = tf.image.resize(img, target_size)

        return img

    def decode_with_labels(path, label):
        return decode(path), label

    return decode_with_labels if with_labels else decode


def build_augmenter(with_labels=True):
    """Return a tf.data map function applying random horizontal and vertical flips."""
    def augment(img):
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_flip_up_down(img)
        return img

    def augment_with_labels(img, label):
        return augment(img), label

    return augment_with_labels if with_labels else augment


def build_dataset(paths, labels=None, bsize=32, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, repeat=True, shuffle=1024,
                  cache_dir=""):
    """Build a batched, prefetched tf.data pipeline over image paths.

    Args:
        paths: image file paths.
        labels: optional labels aligned with ``paths``; when given, elements
            are ``(image, label)`` pairs.
        bsize: batch size.
        cache: cache decoded images (``cache_dir`` is passed to ``Dataset.cache``).
        decode_fn: map function for decoding; defaults to ``build_decoder``.
        augment_fn: map function for augmentation; defaults to ``build_augmenter``.
        augment: apply ``augment_fn`` after decoding (and caching).
        repeat: repeat the dataset indefinitely.
        shuffle: shuffle buffer size; a falsy value disables shuffling.
        cache_dir: cache location; created with ``os.makedirs`` when caching.
    """
    if cache_dir != "" and cache is True:
        os.makedirs(cache_dir, exist_ok=True)

    if decode_fn is None:
        decode_fn = build_decoder(labels is not None)

    if augment_fn is None:
        augment_fn = build_augmenter(labels is not None)

    AUTO = tf.data.experimental.AUTOTUNE
    slices = paths if labels is None else (paths, labels)

    dset = tf.data.Dataset.from_tensor_slices(slices)
    dset = dset.map(decode_fn, num_parallel_calls=AUTO)
    dset = dset.cache(cache_dir) if cache else dset
    dset = dset.map(augment_fn, num_parallel_calls=AUTO) if augment else dset
    dset = dset.repeat() if repeat else dset
    dset = dset.shuffle(shuffle) if shuffle else dset
    dset = dset.batch(bsize).prefetch(AUTO)

    return dset


def blend_prediction_string(pred_str, image_probs, image_weight=IMAGE_WEIGHT,
                            cell_weight=CELL_WEIGHT):
    """Blend each cell confidence in ``pred_str`` with its image-level class probability.

    ``pred_str`` is a space-separated sequence of ``label confidence mask``
    triplets, and ``image_probs[k]`` is the image-level probability of class ``k``.
    Each confidence becomes ``image_probs[label] * image_weight + conf * cell_weight``.
    Triplets whose label lies outside ``range(len(image_probs))``, and trailing
    tokens that do not form a full triplet, are left unchanged.

    Raises:
        ValueError: if a triplet's label is not an integer.
    """
    tokens = pred_str.split()
    n_classes = len(image_probs)
    # Only complete triplets are visited (trailing partial triplet is ignored).
    for start in range(0, len(tokens) - len(tokens) % 3, 3):
        label = int(tokens[start])
        if 0 <= label < n_classes:
            conf = float(tokens[start + 1])
            tokens[start + 1] = str(image_probs[label] * image_weight + conf * cell_weight)
    return ' '.join(tokens)


# Choose where the test image IDs come from; everything after this is shared.
sub_df = pd.read_csv('../input/hpa-single-cell-image-classification/sample_submission.csv')
if sub_df.shape[0] == PUBLIC_TEST_ROWS:
    # Reuse the IDs of the submission frame built by the previous cells.
    sub_df = ss_df.copy()

strategy = auto_select_accelerator()
BATCH_SIZE = strategy.num_replicas_in_sync * 16

IMSIZE = (224, 240, 260, 300, 380, 456, 528, 600)

# Keep only the first (ID) column and add one float column per class ('0'..'18').
sub_df = sub_df.drop(sub_df.columns[1:], axis=1)
for i in range(N_CLASSES):
    sub_df[f'{i}'] = 0.0
label_cols = sub_df.columns[1:]

test_paths = load_dir + "test/" + sub_df['ID'] + '_green.png'

# NOTE(review): ext is left at its 'jpg' default, so these PNG files go through
# tf.image.decode_jpeg (which also accepts PNG input). It is kept as-is to match
# the preprocessing the model was presumably trained with.
test_decoder = build_decoder(with_labels=False, target_size=(IMSIZE[7], IMSIZE[7]))
dtest = build_dataset(
    test_paths, bsize=BATCH_SIZE, repeat=False,
    shuffle=False, augment=False, cache=False,
    decode_fn=test_decoder
)

with strategy.scope():
    model = tf.keras.models.load_model(
        '../input/hpa-classification-efnb7-train/model_green.h5'
    )

model.summary()
sub_df[label_cols] = model.predict(dtest, verbose=1)

ss_df = pd.merge(ss_df, sub_df, on='ID', how='left')

for idx in ss_df.index:
    pred_str = ss_df.at[idx, 'PredictionString']
    # Non-string values (e.g. NaN from the left merge) have no triplets to blend.
    if not isinstance(pred_str, str) or pred_str == SKIP_PREDICTION_STRING:
        continue
    image_probs = [ss_df.at[idx, f'{k}'] for k in range(N_CLASSES)]
    ss_df.at[idx, 'PredictionString'] = blend_prediction_string(pred_str, image_probs)

ss_df = ss_df[['ID', 'ImageWidth', 'ImageHeight', 'PredictionString']]
ss_df.to_csv('submission.csv', index=False)

In [None]:
# Last expression of the cell: the notebook renders the final submission frame.
ss_df

# Please upvote if you found this notebook useful