# Intro
In this notebook, I create a simple baseline. I'll build a classifier on top of image-level labels (multi-label classification), then use an explainability technique (Integrated Gradients) to extract the regions responsible for a particular class prediction, and then assign the segmented cells to particular classes based on their overlap with the Integrated Gradients outputs.

**Credits**:
To segment cells offline, I'll use [this notebook by RDizzl3](https://www.kaggle.com/rdizzl3/hpa-segmentation-masks-no-internet) and the corresponding datasets. I also checked out the batched version in [this notebook by Darek Kłeczek](https://www.kaggle.com/thedrcat/hpa-baseline-cell-segmentation).


# Plan
1. [Libraries](#Libraries)
2. [Data Generators](#Data-Generators)
  * [One-hot encoding classes](#One-hot-encoding-classes)
  * [Stratified split into train/val](#Stratified-split-into-train/val)
  * [Generator class](#Generator-class)
3. [PR-AUC-based Callback](#PR-AUC-based-Callback)
4. [Classifier](#Classifier)
  * [Defining a model](#Defining-a-model)
  * [Initial tuning of the added fully-connected layer](#Initial-tuning-of-the-added-fully-connected-layer)
  * [Training the whole model](#Training-the-whole-model)
  * [Visualizing train and val PR AUC](#Visualizing-train-and-val-PR-AUC)
5. [Extracting Integrated gradients](#Extracting-Integrated-gradients)
6. [Cell segmentation](#Cell-segmentation)
7. [Cell level predictions](#Cell-level-predictions)
8. [Submission](#Submission)

# Libraries

In [None]:
!pip install "../input/keras-application/Keras_Applications-1.0.8-py3-none-any.whl"
!pip install "../input/efficientnet111/efficientnet-1.1.1-py3-none-any.whl"
!pip install "../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"
!pip install "../input/hpapytorchzoozip/pytorch_zoo-master"
!pip install "../input/hpacellsegmentatormaster/HPA-Cell-Segmentation-master"
!pip install "../input/tfexplainforoffline/tf_explain-0.2.1-py3-none-any.whl"

In [None]:
import os, glob
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
# tf.compat.v1.disable_eager_execution()
import random
from sklearn.model_selection import train_test_split
import cv2
import numpy as np
import pandas as pd
import multiprocessing
from copy import deepcopy
from sklearn.metrics import precision_recall_curve, auc
import keras
import keras.backend as K
from keras.optimizers import Adam
from keras.callbacks import Callback
# please note, that locally I've trained a keras.efficientnet model, but using tensorflow.keras.applications.EfficientNetB0 should lead to the same results
from efficientnet.keras import EfficientNetB0
from keras.layers import Dense, Flatten
from keras.models import Model, load_model
from keras.utils import Sequence
from albumentations import Compose, VerticalFlip, HorizontalFlip, Rotate, GridDistortion
import matplotlib.pyplot as plt
from IPython.display import Image, display
from numpy.random import seed
seed(10)
from tensorflow.python.framework import ops
import gc
from numba import cuda 
import hpacellseg.cellsegmentator as cellsegmentator
from hpacellseg.utils import label_cell, label_nuclei
from tqdm.auto import tqdm
import base64
import numpy as np
from pycocotools import _mask as coco_mask
import typing as t
import zlib
import warnings
from tf_explain.core.integrated_gradients import IntegratedGradients
warnings.filterwarnings('ignore')

tf.random.set_seed(10)
%matplotlib inline

In [None]:
# Folders holding the competition's test and train images
TEST_IMGS_FOLDER = '../input/hpa-single-cell-image-classification/test/'
TRAIN_IMGS_FOLDER = '../input/hpa-single-cell-image-classification/train/'
# Size (in pixels) the images are resized to before they are fed to the classifier
IMG_HEIGHT = IMG_WIDTH = 512
BATCH_SIZE = 8

# Run-mode switches: load the stored checkpoint, run the long training stage,
# and shortcut the run on the public test set
LOAD_PRETRAINED_MODEL = True
TRAIN_MODEL = False
FAST_PUBLIC_RUN = True

# File name of the pretrained checkpoint inside ../input/cell-models/
CHECKPOINT_NAME = 'classifier_effnetb0_512.h5'

# Number of CPU cores; passed as `workers` to fit/predict
num_cores = multiprocessing.cpu_count()

# Data Generators

## One-hot encoding classes

In [None]:
# Image-level training labels: an 'ID' column and a '|'-separated 'Label' string per image
train_df = pd.read_csv('../input/hpa-single-cell-image-classification/train.csv')
train_df.head()

In [None]:
# from https://www.kaggle.com/c/hpa-single-cell-image-classification/data

specified_class_names = """0. Nucleoplasm
1. Nuclear membrane
2. Nucleoli
3. Nucleoli fibrillar center
4. Nuclear speckles
5. Nuclear bodies
6. Endoplasmic reticulum
7. Golgi apparatus
8. Intermediate filaments
9. Actin filaments
10. Microtubules
11. Mitotic spindle
12. Centrosome
13. Plasma membrane
14. Mitochondria
15. Aggresome
16. Cytosol
17. Vesicles and punctate cytosolic patterns
18. Negative"""

# Drop the "<index>. " prefix of every line. Splitting only on the first separator keeps
# names intact even if they contain '. ', and stripping removes stray whitespace so the
# names are clean when used as column names and plot titles.
class_names = [line.split('. ', 1)[1].strip() for line in specified_class_names.splitlines()]
class_names

In [None]:
# Turn every '|'-separated label string into a set of integer class ids
train_df['Label'] = train_df['Label'].map(lambda raw: {int(token) for token in raw.split('|')})
# One indicator column per class: 1 when the class id is among the image's labels, else 0
for class_idx, column_name in enumerate(class_names):
    train_df[column_name] = train_df['Label'].map(lambda present: int(class_idx in present))
train_df.head()

In [None]:
# Lookup table: image ID -> label vector. The slice skips the first two columns
# ('ID', 'Label') and the last column, i.e. the last class is not part of the vector.
id_2_ohe_vector = dict(zip(train_df['ID'], train_df.iloc[:, 2:-1].values))

## Stratified split into train/val
Let's stratify based on combination of labels. The unique combinations will be put into train.

In [None]:
# Canonical string key per image: the sorted list of its labels, so equal sets give equal keys
label_combinations = train_df['Label'].map(lambda labels: str(sorted(labels)))
f'There are {(label_combinations.value_counts() == 1).sum()} images with unique label combinations out of {len(label_combinations)}.'

In [None]:
# Label combinations that occur exactly once in the training set
label_combinations_counts = label_combinations.value_counts()
unique_label_combs = label_combinations_counts[label_combinations_counts == 1].index

In [None]:
# IDs of the images whose label combination is one of a kind
train_ids_unique_combs = train_df.loc[label_combinations.isin(unique_label_combs), 'ID']

In [None]:
# Only images whose label combination occurs more than once can be split in a stratified way
non_unique_combo_bool_idx = ~label_combinations.isin(unique_label_combs)
train_ids, val_ids = train_test_split(
    train_df.loc[non_unique_combo_bool_idx, 'ID'].values,
    test_size=0.2,
    # the combination keys were built from sorted label lists, so equal sets share one key
    stratify=label_combinations[non_unique_combo_bool_idx],
    random_state=42,
)

In [None]:
# Images with one-of-a-kind label combinations were left out of the split; add them all to train
train_ids = np.concatenate((train_ids, train_ids_unique_combs))

In [None]:
# The sample submission lists the test image IDs to predict
sub_df = pd.read_csv('../input/hpa-single-cell-image-classification/sample_submission.csv')
test_ids = sub_df['ID'].values

## Generator class

Using green filter, as "the green filter should be used to predict the label, and the other filters are used as references." ([from the data page](https://www.kaggle.com/c/hpa-single-cell-image-classification/data))

In [None]:
class DataGenenerator(Sequence):
    """Keras ``Sequence`` serving batches of green-channel images (and labels).

    Images are read from ``{folder_imgs}/{image_id}_green.png``, resized, optionally
    augmented and divided by 255. Whether labels are served is decided from the folder
    name: a folder whose path does not contain ``'train'`` is treated as a test folder
    and only the image batch is returned.

    Args:
        id_list: image IDs to serve; copied, so shuffling never touches the caller's data.
        id_2_ohe_vector: mapping image ID -> label vector (required for train folders).
        folder_imgs: folder holding the ``*_green.png`` files.
        batch_size: number of images per batch.
        shuffle: reshuffle the ID order after every epoch.
        augmentation: optional callable used as ``augmentation(image=img)['image']``.
        resized_height: height of the returned images.
        resized_width: width of the returned images.
        num_channels: number of channels of the returned images.
    """

    def __init__(self, id_list, id_2_ohe_vector=None, folder_imgs=TRAIN_IMGS_FOLDER,
                 batch_size=BATCH_SIZE, shuffle=True, augmentation=None,
                 resized_height=IMG_HEIGHT, resized_width=IMG_WIDTH, num_channels=3):
        super().__init__()
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augmentation = augmentation
        self.id_list = deepcopy(id_list)
        self.folder_imgs = folder_imgs
        # only complete batches are counted; a trailing incomplete batch is not served by Keras
        self.len = len(self.id_list) // self.batch_size
        self.resized_height = resized_height
        self.resized_width = resized_width
        self.num_channels = num_channels
        self.id_2_ohe_vector = id_2_ohe_vector
        self.is_test = 'train' not in folder_imgs
        if not self.is_test:
            self.num_classes = len(next(iter(id_2_ohe_vector.values())))
        if not shuffle and not self.is_test:
            # the order never changes, so the labels can be cached once
            self.labels = [id_2_ohe_vector[img] for img in self.id_list[:self.len*self.batch_size]]

    def __len__(self):
        """Return the number of complete batches."""
        return self.len

    def on_epoch_end(self):
        """Reshuffle the IDs when shuffling is enabled.

        Keras calls ``on_epoch_end`` on a ``Sequence`` after every epoch; a hook with any
        other name is never invoked, which would leave the data in its initial order.
        """
        if self.shuffle:
            random.shuffle(self.id_list)

    # kept so that code calling the previous hook name explicitly keeps working
    on_epoch_start = on_epoch_end

    def __getitem__(self, idx):
        """Return batch ``idx``: ``(X, y)`` for train folders, ``X`` for test folders.

        A trailing incomplete batch (reachable only by direct indexing) is returned with
        as many rows as there are images, never padded with uninitialised memory.

        Raises:
            FileNotFoundError: if an image file cannot be read.
        """
        current_batch = self.id_list[idx * self.batch_size: (idx + 1) * self.batch_size]
        n_items = len(current_batch)
        X = np.empty((n_items, self.resized_height, self.resized_width, self.num_channels))

        if not self.is_test:
            y = np.empty((n_items, self.num_classes))

        for i, image_id in enumerate(current_batch):
            path = os.path.join(self.folder_imgs, f'{image_id}_green.png')
            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising
                raise FileNotFoundError(f'Could not read image: {path}')
            # cv2.resize expects the target size as (width, height)
            img = cv2.resize(img, (self.resized_width, self.resized_height)).astype(np.float32)
            if self.augmentation is not None:
                augmented = self.augmentation(image=img)
                img = augmented['image']
            X[i, :, :, :] = img/255.0
            if not self.is_test:
                y[i, :] = self.id_2_ohe_vector[image_id]
        if not self.is_test:
            return X, y
        return X

    def get_labels(self):
        """Return the label vectors of all served images, in the current ID order."""
        if self.shuffle:
            images_current = self.id_list[:self.len*self.batch_size]
            labels = [self.id_2_ohe_vector[img] for img in images_current]
        else:
            labels = self.labels
        return np.array(labels)

In [None]:
# Training-time augmentations: flips, rotation limited to 50 and grid distortion
albumentations_train = Compose([
    VerticalFlip(), HorizontalFlip(), Rotate(limit=50), GridDistortion()
], p=1)

Generator instances

In [None]:
# Shortcut for the public run: with FAST_PUBLIC_RUN on and a 559-row sample submission
# (presumably the size of the public test set — confirm), only 10 test images are kept.
is_public_test_run = len(sub_df)==559 and FAST_PUBLIC_RUN
if is_public_test_run:
    test_ids = test_ids[:10]
# to speed up private set submission, if no training is expected
if not TRAIN_MODEL:
    train_ids = train_ids[:2000]
    val_ids = val_ids[:1000]
# Training generator: shuffling enabled (default) and augmented
data_generator_train = DataGenenerator(train_ids, id_2_ohe_vector, augmentation=albumentations_train)
# Same training IDs in fixed order and without augmentation, used to score the train PR AUC
data_generator_train_eval = DataGenenerator(train_ids, id_2_ohe_vector, shuffle=False)
data_generator_val = DataGenenerator(val_ids, id_2_ohe_vector, shuffle=False)
# Test folder path does not contain 'train', so this generator yields images only
data_generator_test = DataGenenerator(test_ids, folder_imgs=TEST_IMGS_FOLDER, shuffle=False)

# PR-AUC-based Callback

The callback is used:
1. to estimate the area under the precision-recall curve (PR AUC) for each class,
2. to stop early after 5 epochs of no improvement in mean PR AUC,
3. to save the model with the best validation PR AUC,
4. to reduce the learning rate on a PR AUC plateau.

In [None]:
class PrAucCallback(Callback):
    """Keras callback that tracks per-class and mean PR AUC after every epoch.

    In the ``'val'`` stage it additionally stops training early, saves the best model and
    reduces the learning rate when the mean PR AUC stops improving.

    Args:
        data_generator: generator used for prediction; must provide ``get_labels()``.
        class_names: names of the predicted classes, in model-output order.
        num_workers: ``workers`` value passed to ``model.predict``.
        early_stopping_patience: epochs without improvement before training is stopped.
        plateau_patience: epochs without improvement before the learning rate is reduced.
        reduction_rate: factor the learning rate is multiplied by on a plateau.
        stage: ``'train'`` (metrics only) or ``'val'`` (metrics plus training control).
        checkpoints_path: folder the best model is saved to (created if missing).
        model_name: name fragment used in checkpoint file names.
    """

    def __init__(self, data_generator, class_names, num_workers=num_cores,
                 early_stopping_patience=5,
                 plateau_patience=3, reduction_rate=0.5,
                 stage='train', checkpoints_path='checkpoints/', model_name='effnetb0'):
        # zero-argument super() runs Callback.__init__; super(Callback, self) would skip it
        super().__init__()
        self.data_generator = data_generator
        self.num_workers = num_workers
        self.class_names = class_names
        self.history = [[] for _ in range(len(self.class_names) + 1)] # to store per each class and also mean PR AUC
        self.early_stopping_patience = early_stopping_patience
        self.plateau_patience = plateau_patience
        self.reduction_rate = reduction_rate
        self.stage = stage
        self.best_pr_auc = -float('inf')
        # exist_ok avoids a failure when several callbacks share one checkpoint folder
        os.makedirs(checkpoints_path, exist_ok=True)
        self.checkpoints_path = checkpoints_path
        self.model_name = model_name

    def compute_pr_auc(self, y_true, y_pred):
        """Print and record the PR AUC of every class; return their mean."""
        n_classes = len(self.class_names)
        pr_auc_mean = 0
        print(f"\n{'#'*30}\n")
        for class_i in range(n_classes):
            precision, recall, _ = precision_recall_curve(y_true[:, class_i], y_pred[:, class_i])
            pr_auc = auc(recall, precision)
            pr_auc_mean += pr_auc/n_classes
            print(f"PR AUC {self.class_names[class_i]}, {self.stage}: {pr_auc:.3f}\n")
            self.history[class_i].append(pr_auc)
        print(f"\n{'#'*20}\n PR AUC mean, {self.stage}: {pr_auc_mean:.3f}\n{'#'*20}\n")
        self.history[-1].append(pr_auc_mean)
        return pr_auc_mean

    def is_patience_lost(self, patience):
        """Return True when the mean PR AUC has not improved for ``patience`` epochs.

        The ``patience`` epochs before the latest one form a window; patience is lost when
        the window's best value is its first entry and the latest epoch does not beat it.
        """
        mean_history = self.history[-1]
        if len(mean_history) <= patience:
            # not enough epochs recorded yet
            return False
        window = mean_history[-(patience + 1):-1]
        best_performance = max(window)
        return best_performance == window[0] and best_performance >= mean_history[-1]

    def early_stopping_check(self, pr_auc_mean):
        """Ask Keras to stop training once the early-stopping patience is exhausted."""
        if self.is_patience_lost(self.early_stopping_patience):
            self.model.stop_training = True

    def model_checkpoint(self, pr_auc_mean, epoch):
        """Save the model when ``pr_auc_mean`` is the best seen so far."""
        if pr_auc_mean > self.best_pr_auc:
            # remove previous checkpoints to save space
            for checkpoint in glob.glob(os.path.join(self.checkpoints_path, f'classifier_{self.model_name}_epoch_*')):
                os.remove(checkpoint)
            self.best_pr_auc = pr_auc_mean
            self.model.save(os.path.join(self.checkpoints_path, f'classifier_{self.model_name}_epoch_{epoch}_val_pr_auc_{pr_auc_mean}.h5'))
            print(f"\n{'#'*20}\nSaved new checkpoint\n{'#'*20}\n")

    def reduce_lr_on_plateau(self):
        """Multiply the learning rate by ``reduction_rate`` once the plateau patience is exhausted."""
        if self.is_patience_lost(self.plateau_patience):
            new_lr = float(keras.backend.get_value(self.model.optimizer.lr)) * self.reduction_rate
            keras.backend.set_value(self.model.optimizer.lr, new_lr)
            print(f"\n{'#'*20}\nReduced learning rate to {new_lr}.\n{'#'*20}\n")

    def on_epoch_end(self, epoch, logs=None):
        """Score the generator's data and, in the 'val' stage, steer the training."""
        y_pred = self.model.predict(self.data_generator, workers=self.num_workers)
        y_true = self.data_generator.get_labels()
        # estimate AUC under precision recall curve for each class
        pr_auc_mean = self.compute_pr_auc(y_true, y_pred)

        if self.stage == 'val':
            # early stop after early_stopping_patience epochs of no improvement in mean PR AUC
            self.early_stopping_check(pr_auc_mean)

            # save a model with the best PR AUC in validation
            self.model_checkpoint(pr_auc_mean, epoch)

            # reduce learning rate on PR AUC plateau
            self.reduce_lr_on_plateau()

    def get_pr_auc_history(self):
        """Return the recorded history: one list per class plus the mean as the last list."""
        return self.history

Callback instances

In [None]:
# class_names[:-1] leaves out the last class, matching the label vectors of the generators.
# The default stage ('train') only records metrics; the 'val' stage also controls training.
train_metric_callback = PrAucCallback(data_generator_train_eval, class_names[:-1])
val_callback = PrAucCallback(data_generator_val, class_names[:-1], stage='val')

# Classifier

## Defining a model

In [None]:
def get_model(class_names):
    """Build the classifier: an EfficientNetB0 backbone followed by a sigmoid Dense layer.

    The backbone is loaded with ImageNet weights, without its top and with average
    pooling; the Dense layer has ``len(class_names) - 1`` units. The Keras session is
    cleared before the model is built.
    """
    K.clear_session()
    backbone = EfficientNetB0(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), include_top=False,
                              weights='imagenet', pooling='avg')
    n_outputs = len(class_names) - 1
    class_scores = Dense(n_outputs, activation='sigmoid')(backbone.output)
    return Model(inputs=backbone.input, outputs=class_scores)


# A fresh model is only needed when no pretrained checkpoint is going to be loaded
if not LOAD_PRETRAINED_MODEL:
    model = get_model(class_names)

## Initial tuning of the added fully-connected layer

In [None]:
# Start from the stored checkpoint instead of a freshly built model
if LOAD_PRETRAINED_MODEL:
    model = load_model(f'../input/cell-models/{CHECKPOINT_NAME}')

In [None]:
# Freeze everything except the last layer so that only the classification head is tuned
for base_layer in model.layers[:-1]:
    base_layer.trainable = False

# The model built by get_model ends in independent sigmoid units (multi-label), so every
# class is its own binary problem and binary cross-entropy is the matching loss.
# Categorical cross-entropy assumes a single-label distribution and yields zero loss for
# images whose target vector is all zeros.
model.compile(optimizer=Adam(lr=1e-2),  loss='binary_crossentropy')
history_0 = model.fit(data_generator_train,
                      validation_data=data_generator_val,
                      epochs=1,
                      callbacks=[train_metric_callback, val_callback],
                      workers=num_cores,
                      verbose=1)

## Training the whole model

In [None]:
if TRAIN_MODEL:
    # Unfreeze all layers to fine-tune the whole network with a smaller learning rate
    for base_layer in model.layers:
        base_layer.trainable = True

    # The model built by get_model ends in independent sigmoid units (multi-label), so
    # binary cross-entropy is the matching loss; categorical cross-entropy assumes a
    # single-label distribution and yields zero loss for all-zero target vectors.
    model.compile(optimizer=Adam(lr=3e-4),  loss='binary_crossentropy')
    # initial_epoch=1 continues the epoch numbering after the one-epoch warm-up stage
    history_1 = model.fit(data_generator_train,
                          validation_data=data_generator_val,
                          epochs=30,
                          callbacks=[train_metric_callback, val_callback],
                          workers=num_cores,
                          verbose=1,
                          initial_epoch=1)

## Visualizing train and val PR AUC

In [None]:
def plot_with_dots(ax, np_array):
    """Draw ``np_array`` on ``ax`` as dots joined by a line; x positions start at 1."""
    positions = list(range(1, len(np_array) + 1))
    ax.scatter(positions, np_array, s=50)
    ax.plot(positions, np_array)

In [None]:
if TRAIN_MODEL:
    pr_auc_history_train = train_metric_callback.get_pr_auc_history()
    pr_auc_history_val = val_callback.get_pr_auc_history()

    plt.figure(figsize=(10, 7))
    # the last list of each history holds the mean PR AUC per epoch; train is drawn first
    for mean_pr_auc in (pr_auc_history_train[-1], pr_auc_history_val[-1]):
        plot_with_dots(plt, mean_pr_auc)

    plt.title('Training and Validation PR AUC', fontsize=22)
    plt.xlabel('Epoch', fontsize=17)
    plt.ylabel('Mean PR AUC', fontsize=17)
    plt.legend(['Train', 'Val'])
    plt.savefig('pr_auc_hist.png')

I trained the model for longer on my local GPU and then uploaded the best model together with the plots from that training run.

In [None]:
# Without the long training stage, reload the stored checkpoint so that the one-epoch
# warm-up fit run earlier in the notebook does not affect the model used from here on
if LOAD_PRETRAINED_MODEL and not TRAIN_MODEL:
    model = load_model(f'../input/cell-models/{CHECKPOINT_NAME}')

In [None]:
#  (for the 260x260 effnetb0)
# Show the PR AUC history plot stored next to the pretrained checkpoint
if LOAD_PRETRAINED_MODEL:
    display(Image("../input/cell-models/pr_auc_hist.png"))    

# Extracting Integrated gradients

Using the awesome tf-explain library.

In [None]:
# tf-explain Integrated Gradients explainer, reused for every image and class below
explainer = IntegratedGradients()

In [None]:
# Prepare the first test image like the generator does: green channel, resized, divided by 255
image_id = test_ids[0]
path = os.path.join(TEST_IMGS_FOLDER, f'{image_id}_green.png')
img = cv2.resize(cv2.imread(path), (IMG_HEIGHT, IMG_WIDTH)).astype(np.float32)/255.0
# presumably an (images, labels) pair as expected by tf-explain; no labels are passed — confirm
data = ([img], None)

In [None]:
# Explain the model output at index 1 for the sample image, using n_steps=15
grid = explainer.explain(data, model, 1, n_steps=15)
plt.imshow(grid)
# the title shows the largest value in the returned explanation grid
plt.title(grid.max())

# Cell segmentation

In [None]:
def build_image_names(image_id: str, test: bool = False, height: int = 2048, width: int = 2048) -> t.Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load the three reference channels of one image as resized grayscale arrays.

    Args:
        image_id: ID of the image; the files ``{image_id}_red.png``,
            ``{image_id}_yellow.png`` and ``{image_id}_blue.png`` are read.
        test: read from the competition's ``test`` folder instead of ``train``.
        height: height of the returned images.
        width: width of the returned images.

    Returns:
        The tuple ``(mt_img, er_img, nu_img)`` with the red, yellow and blue channel.

    Raises:
        FileNotFoundError: if one of the channel files cannot be read.
    """

    def read_img(path: str, height: int = height, width: int = width):
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f'Could not read image: {path}')
        # cv2.resize expects the target size as (width, height)
        return cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), (width, height))

    folder = f'../input/hpa-single-cell-image-classification/{"test" if test else "train"}'
    # mt is the microtubules channel (red)
    mt = f'{folder}/{image_id}_red.png'
    # er is the endoplasmic reticulum channel (yellow)
    er = f'{folder}/{image_id}_yellow.png'
    # nu is the nuclei channel (blue)
    nu = f'{folder}/{image_id}_blue.png'

    mt_img = read_img(mt)
    er_img = read_img(er)
    nu_img = read_img(nu)
    return mt_img, er_img, nu_img

# Weight files of the HPA cell segmentator: the nuclei model and the 3-channel cell model
NUC_MODEL = '../input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth'
CELL_MODEL = '../input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth'

# Shared segmentator instance used by get_masks; configured for the 'cuda' device
segmentator = cellsegmentator.CellSegmentator(
    NUC_MODEL,
    CELL_MODEL,
    scale_factor=0.25,
    device='cuda',
    padding=False,
    multi_channel_model=True
)

In [None]:
def get_masks(img_ids, test=True):
    """Segment the cells of the given images.

    Args:
        img_ids: IDs of the images to segment.
        test: read the images from the test folder (default) instead of the train folder.

    Returns:
        A list with one labelled cell mask per image, an empty list for empty input, or
        ``None`` when the segmentation fails (best effort: the error is printed, not raised).
    """
    if len(img_ids) == 0:
        return []

    mt_imgs, er_imgs, nu_imgs = [], [], []
    for img_id in img_ids:
        mt_img, er_img, nu_img = build_image_names(image_id=img_id, test=test)
        mt_imgs.append(mt_img)
        er_imgs.append(er_img)
        nu_imgs.append(nu_img)
    images = [mt_imgs, er_imgs, nu_imgs]

    try:
        nuc_segmentations = segmentator.pred_nuclei(nu_imgs)
        cell_segmentations = segmentator.pred_cells(images)
        cell_masks = []
        for i, cell_segmentation in enumerate(tqdm(cell_segmentations, desc='Labeling cells..')):
            _, cell_mask = label_cell(nuc_segmentations[i], cell_segmentation)
            cell_masks.append(cell_mask)
        return cell_masks
    except Exception as error:
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through;
        # callers treat None as "segmentation failed for this batch"
        print(f'Cell segmentation failed: {error!r}')
        return None

# Cell-level predictions

Let's use the following heuristics:

We'll use the 95th percentile of the Integrated Gradients output over the whole image as the value that roughly corresponds to the image-level confidence. The image-level confidence is the class confidence predicted by our classifier.

Next, we'll focus on the Integrated Gradients output within each cell region. Since a true-positive cell need not be lit up entirely by Integrated Gradients, we'll take the 95th percentile of the cell's values and compare it with the value corresponding to the image-level confidence. From these two levels we'll estimate the cell-level confidence: the higher the Integrated Gradients output, the higher the confidence.

Finally, we'll output only cells with confidence values reaching at least 40% of the image-level confidence (the `max_cell_level_conf_2_image_level_conf=0.4` setting below).

In [None]:
# Minimum ratio of cell-level to image-level confidence for a cell to be assigned a class
max_cell_level_conf_2_image_level_conf=0.4
# Quantile of the explanation values used as the image-level and the cell-level summary
quantile_level=0.95


def _cell_centroid(cell_mask_bool):
    """Return the (x, y) centre of a non-empty boolean cell mask in plotting coordinates."""
    rows, cols = np.where(cell_mask_bool)
    return int(cols.mean()), int(rows.mean())


def _show_cell_predictions(masks_all, coord_2_conf, title, figsize):
    """Plot the selected cells and print every cell's confidence at its centre."""
    plt.figure(figsize=(figsize, figsize))
    plt.imshow(masks_all)
    for coords, conf in coord_2_conf.items():
        conf_rounded = np.round(conf*100)/100
        plt.scatter(*coords, s=700, color='red', marker=r"$ {} $".format(conf_rounded))
    plt.xticks([])
    plt.yticks([])
    plt.title(title, fontsize=22)
    plt.show()


def vis_masks_test(img_idx, conf_threshold=0.05, mask_height=2048, mask_width=2048, 
                   max_cell_level_conf_2_image_level_conf=max_cell_level_conf_2_image_level_conf, 
                   test_ids=test_ids, figsize=7, quantile_level=quantile_level):
    """Visualize the cell-level predictions for one test image.

    For every class whose image-level confidence exceeds ``conf_threshold`` two figures
    are shown: the explanation on top of the cell mask, and the cells selected for that
    class with their confidences. A last figure shows the cells assigned to the last class
    of ``class_names``.

    Args:
        img_idx: index into ``test_ids`` of the image to show.
        conf_threshold: minimum confidence for a class (or a last-class cell) to be shown.
        mask_height: height the explanation is resized to; must match the cell mask.
        mask_width: width the explanation is resized to; must match the cell mask.
        max_cell_level_conf_2_image_level_conf: minimum ratio of cell-level to
            image-level confidence for a cell to be assigned a class.
        test_ids: IDs of the test images.
        figsize: side length of the (square) figures.
        quantile_level: quantile of the explanation values used as image/cell summary.
    """
    image_id = test_ids[img_idx]
    masks = get_masks([image_id])
    if not masks:
        # get_masks returns None when the segmentation fails
        print(f'{image_id}: no cell masks available, nothing to visualize')
        return
    mask = masks[0]
    n_cells = mask.max()
    cell_2_max_conf = dict()

    path = os.path.join(TEST_IMGS_FOLDER, f'{image_id}_green.png')
    img = cv2.resize(cv2.imread(path), (IMG_WIDTH, IMG_HEIGHT)).astype(np.float32)/255.0
    predictions_test = model.predict(np.expand_dims(img, 0))

    for class_i, class_name in enumerate(class_names[:-1]):
        class_conf_score = predictions_test[0][class_i]
        if class_conf_score <= conf_threshold:
            continue
        try:
            explanation = explainer.explain(([img], None), model, class_i, n_steps=15)
            # cv2.resize expects the target size as (width, height)
            explanation_img = cv2.resize(explanation, (mask_width, mask_height))
            explanation_total_level = np.quantile(explanation_img.flatten(), quantile_level)
        except Exception as error:
            print(f'explanation issue: {error!r}')
            continue

        plt.figure(figsize=(figsize, figsize))
        plt.imshow(mask)
        plt.imshow(explanation_img, alpha=0.7)
        plt.xticks([])
        plt.yticks([])
        plt.title(f'{test_ids[img_idx]}\n{class_name} ({class_conf_score:.2f}): raw Grad-CAMs', fontsize=22)
        plt.show()

        masks_all = np.zeros((mask_height, mask_width))
        coord_2_conf = dict()
        for cell_i in range(1, n_cells + 1):
            cell_mask_bool = mask == cell_i
            if not cell_mask_bool.any():
                # label ids need not be contiguous; a quantile of nothing would raise
                continue
            cell_explanation_perc = np.quantile(explanation_img[cell_mask_bool], quantile_level)
            # scale the image-level confidence by the cell's share of the image-level explanation
            cell_conf = np.clip(cell_explanation_perc*class_conf_score/explanation_total_level, 0, class_conf_score)
            if cell_conf/class_conf_score >= max_cell_level_conf_2_image_level_conf:
                masks_all[cell_mask_bool] = 1
                coord_2_conf[_cell_centroid(cell_mask_bool)] = cell_conf
                cell_2_max_conf[cell_i] = max(cell_conf, cell_2_max_conf.get(cell_i, cell_conf))

        _show_cell_predictions(masks_all, coord_2_conf,
                               f'{test_ids[img_idx]}\n{class_name}: cell-level predictions', figsize)

    # nothing interesting there

    masks_all = np.zeros((mask_height, mask_width))
    coord_2_conf = dict()
    for cell_i in range(1, n_cells + 1):
        # cells without any class get 0.99, others the complement of their best class confidence
        if cell_i not in cell_2_max_conf:
            cell_conf = 0.99
        else:
            cell_conf = 1 - cell_2_max_conf[cell_i]
        if cell_conf >= conf_threshold:
            cell_mask_bool = mask == cell_i
            if not cell_mask_bool.any():
                continue
            masks_all[cell_mask_bool] = 1
            coord_2_conf[_cell_centroid(cell_mask_bool)] = cell_conf

    _show_cell_predictions(masks_all, coord_2_conf,
                           f'{test_ids[img_idx]}\n{class_names[-1]}: cell-level predictions', figsize)

In [None]:
# Visualize the first 10 test images, or all of them when the test list is shorter
for test_img_id in range(min(10, len(test_ids))):
    vis_masks_test(test_img_id)

# Submission

In [None]:
def encode_binary_mask(mask: np.ndarray) -> t.Text:
    """Converts a binary mask into OID challenge encoding ascii text.

    Args:
        mask: boolean array that is 2-D after squeezing singleton dimensions.

    Returns:
        The COCO RLE counts of the mask, zlib-compressed and base64-encoded, as text.

    Raises:
        ValueError: if ``mask`` is not boolean or is not 2-D after squeezing.
    """

    # check input mask --
    # np.bool_ is used because the np.bool alias no longer exists in recent NumPy versions
    if mask.dtype != np.bool_:
        raise ValueError(
            "encode_binary_mask expects a binary mask, received dtype == %s" %
            mask.dtype)

    mask = np.squeeze(mask)
    if len(mask.shape) != 2:
        # the shape is wrapped in a 1-tuple: '%' would otherwise unpack it as format arguments
        raise ValueError(
            "encode_binary_mask expects a 2d mask, received shape == %s" %
            (mask.shape,))

    # convert input mask to expected COCO API input --
    mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
    mask_to_encode = mask_to_encode.astype(np.uint8)
    mask_to_encode = np.asfortranarray(mask_to_encode)

    # RLE encode mask --
    encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

    # compress and base64 encoding --
    binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
    base64_str = base64.b64encode(binary_str)
    return base64_str.decode()

In [None]:
# Maps every test image ID to its position in test_ids
test_id_2_order_idx = {test_id: idx for idx, test_id in enumerate(test_ids)}

In [None]:
# PredictionString of the first sample-submission row; used as the fallback prediction
# for images whose cell masks could not be computed
dummy_prediction = sub_df['PredictionString'].values[0]

In [None]:
def get_predictions_string(img_ids, mask_heights, mask_widths, conf_threshold=0.05, 
                           max_cell_level_conf_2_image_level_conf=max_cell_level_conf_2_image_level_conf, 
                           quantile_level=quantile_level, batch_size=8): 
    """Build the submission prediction string of every given test image.

    Each image is classified, its cells are segmented, and each cell gets the classes
    whose explanation inside the cell is strong enough relative to the whole image. Every
    cell may additionally get the last class of ``class_names`` with the complement of
    its best class confidence (0.99 when it got no class at all).

    Args:
        img_ids: IDs of the test images.
        mask_heights: per-image height the masks are resized to before encoding.
        mask_widths: per-image width the masks are resized to before encoding.
        conf_threshold: minimum confidence for a class (or a last-class cell) to be used.
        max_cell_level_conf_2_image_level_conf: minimum ratio of cell-level to
            image-level confidence for a cell to be assigned a class.
        quantile_level: quantile of the explanation values used as image/cell summary.
        batch_size: number of images classified and segmented at once.

    Returns:
        A list with one string per image, in the order of ``img_ids``; each string holds
        space-separated ``class confidence encoded_mask`` triples. Images of a batch
        whose segmentation failed get ``dummy_prediction``.
    """
    results_list = []
    img_idx = 0
    # the generator must serve exactly the requested images, in the requested order
    data_gen = DataGenenerator(img_ids, folder_imgs=TEST_IMGS_FOLDER, shuffle=False, batch_size=batch_size)
    # ceiling division: no empty extra batch when len(img_ids) is a multiple of batch_size
    n_batches = (len(img_ids) + batch_size - 1) // batch_size
    for batch_i in range(n_batches):
        img_batch_ids = img_ids[batch_i*batch_size:(batch_i + 1)*batch_size]
        current_batch_size = len(img_batch_ids)

        # the last batch may be incomplete; keep only the rows that hold real images
        images_batch = data_gen[batch_i][:current_batch_size]

        predictions_test = model.predict(images_batch)

        masks_batch = get_masks(img_batch_ids)
        if masks_batch is None or len(masks_batch) != current_batch_size:
            # one fallback entry per image, and keep img_idx aligned with the following batches
            results_list.extend([dummy_prediction] * current_batch_size)
            img_idx += current_batch_size
            print('batch masks were None')
            continue
        for img_batch_i, mask in enumerate(masks_batch):
            n_cells = mask.max()
            cell_2_max_conf = dict()
            results_list_img = []
            mask_height, mask_width = int(mask_heights[img_idx]), int(mask_widths[img_idx])
            # cv2.resize expects (width, height); nearest-neighbour keeps the cell labels
            # intact, other interpolations blend neighbouring labels into spurious ids
            mask = cv2.resize(mask, (mask_width, mask_height), interpolation=cv2.INTER_NEAREST)
            for class_i, class_name in enumerate(class_names[:-1]):
                class_conf_score = predictions_test[img_batch_i][class_i]
                if class_conf_score <= conf_threshold:
                    continue
                try:
                    explanation = explainer.explain(([images_batch[img_batch_i]], None), model, class_i, n_steps=15)
                    explanation_img = cv2.resize(explanation, (mask_width, mask_height))
                    explanation_total_level = np.quantile(explanation_img.flatten(), quantile_level)
                except Exception as error:
                    print(f'explanation issue: {error!r}')
                    continue
                for cell_i in range(1, n_cells + 1):
                    cell_mask_bool = mask == cell_i
                    if not cell_mask_bool.any():
                        # the label may be absent (e.g. lost when resizing); nothing to score
                        continue
                    cell_explanation_perc = np.quantile(explanation_img[cell_mask_bool], quantile_level)
                    # scale the image-level confidence by the cell's share of the image-level explanation
                    cell_conf = np.clip(cell_explanation_perc*class_conf_score/explanation_total_level, 0, class_conf_score)
                    if cell_conf/class_conf_score >= max_cell_level_conf_2_image_level_conf:
                        mask_rle = encode_binary_mask(cell_mask_bool)
                        results_list_img.extend([class_i, cell_conf, mask_rle])
                        cell_2_max_conf[cell_i] = max(cell_conf, cell_2_max_conf.get(cell_i, cell_conf))

            # nothing interesting there
            for cell_i in range(1, n_cells + 1):
                if cell_i not in cell_2_max_conf:
                    cell_conf = 0.99
                else:
                    cell_conf = 1 - cell_2_max_conf[cell_i]
                if cell_conf >= conf_threshold:
                    cell_mask_bool = mask == cell_i
                    if not cell_mask_bool.any():
                        continue
                    mask_rle = encode_binary_mask(cell_mask_bool)
                    results_list_img.extend([len(class_names) - 1, cell_conf, mask_rle])

            results_list.append(' '.join([str(item) for item in results_list_img]))
            img_idx += 1

    return results_list

In [None]:
# to save the time for the public test set run, the sample submission is written unchanged;
# otherwise every image gets its real cell-level predictions first
if not is_public_test_run:
    sub_df['PredictionString'] = get_predictions_string(sub_df['ID'].values,
                                                        sub_df['ImageHeight'].values,
                                                        sub_df['ImageWidth'].values)
sub_df.to_csv('submission.csv', index=None)