# Intro
In this notebook, I'd like to share how to leverage pre-trained 3-channel Keras models to initialize a 4-channel model.

In the discussion forums the competition hosts have stressed the potential importance of all 4 colors, e.g. "All images have all the four channels, and signals from the markers (blue, yellow, red) are present in all cells in the image, independent of the green channel that you are classifying, in order to help you identify where the cells are, as well as where certain structures and regions within the cells are. This can, in turn, help you to segment the cells and to classify each cell to one or more label(s) according to the signal in the green channel." [link to the post](https://www.kaggle.com/c/hpa-single-cell-image-classification/discussion/215736#1184158).

Considering the size of the training data, training a deep 4-channel model from randomly initialized weights might be problematic. However, all ImageNet-pre-trained models have 3 input channels.

This notebook demonstrates how to initialize a 4-channel EfficientNet with weights reused from a pre-trained 3-channel model.

### Notes:
* For PyTorch models, see [the notebook by Iafoss](https://www.kaggle.com/iafoss/pretrained-resnet34-with-rgby-0-460-public-lb) from the previous competition.
* Based on [this summary](https://www.kaggle.com/c/hpa-single-cell-image-classification/discussion/215986) kindly shared by Darek Kłeczek, it might be safe to drop the yellow image: "In the previous HPA competition, the majority of the top competitors dropped the yellow “channel” and used RGB instead of RGBY, without affecting their scores. Theory: microtubules and ER are typically in the same location, so it didn’t add much information and was safe to skip."
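If you follow that advice and drop the yellow channel, an RGBY array only needs its last channel removed. A minimal sketch, assuming images are stacked in the red, green, blue, yellow order used by `open_rgby` later in this notebook; `RGBY_STAINS`, `RGBY_ORDER` and `drop_yellow` are hypothetical names introduced for illustration.

```python
import numpy as np

# Stain imaged in each channel (HPA competition data description).
RGBY_STAINS = {
    "red": "microtubules",
    "green": "protein of interest",
    "blue": "nucleus",
    "yellow": "endoplasmic reticulum",
}
# Channel order produced by `open_rgby` in this notebook.
RGBY_ORDER = ["red", "green", "blue", "yellow"]


def drop_yellow(img_rgby: np.ndarray) -> np.ndarray:
    """Return the (H, W, 3) RGB part of an (H, W, 4) RGBY image."""
    if img_rgby.ndim != 3 or img_rgby.shape[-1] != 4:
        raise ValueError(f"expected an (H, W, 4) array, got shape {img_rgby.shape}")
    # yellow is the last channel, so keeping the first three drops it
    return img_rgby[..., :3]


rgb = drop_yellow(np.zeros((512, 512, 4), dtype=np.uint8))
print(rgb.shape)  # (512, 512, 3)
```

The resulting 3-channel input matches ImageNet models directly, so no stem-layer surgery would be needed in that case.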

### Credits:
* [PyTorch RGBY model](https://www.kaggle.com/iafoss/pretrained-resnet34-with-rgby-0-460-public-lb)
* [Analogous task for a TensorFlow VGG model](https://stackoverflow.com/questions/53251827/pretrained-tensorflow-model-rgb-rgby-channel-extension)
* To segment cells offline, I use [this notebook by RDizzl3](https://www.kaggle.com/rdizzl3/hpa-segmentation-masks-no-internet) together with the corresponding datasets. I also checked out the batched version from [this notebook by Darek Kłeczek](https://www.kaggle.com/thedrcat/hpa-baseline-cell-segmentation).
* [A notebook by Darien Schettler](https://www.kaggle.com/dschettler8845/hpa-cellwise-classification-inference) suggested that it might be possible to predict the test set cell by cell within the time limit.

# Plan
1. [Libraries](#Libraries)
2. [4-channel classifier init](#4-channel-classifier-init)
3. [Check using explainability](#Check-using-explainability)
4. [Cell-level predictions](#Cell-level-predictions)

# Libraries

In [None]:
!pip install "../input/keras-application/Keras_Applications-1.0.8-py3-none-any.whl"
!pip install "../input/efficientnet111/efficientnet-1.1.1-py3-none-any.whl"
!pip install "../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"
!pip install "../input/hpapytorchzoozip/pytorch_zoo-master"
!pip install "../input/hpacellsegmentatormaster/HPA-Cell-Segmentation-master"
!pip install "../input/tfexplainforoffline/tf_explain-0.2.1-py3-none-any.whl"

In [None]:
import os, glob
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
# tf.compat.v1.disable_eager_execution()
import random
from sklearn.model_selection import train_test_split
import cv2
import numpy as np
import pandas as pd
import multiprocessing
from copy import deepcopy
import keras
import keras.backend as K
from keras.optimizers import Adam
from keras.callbacks import Callback
# please note: locally I trained a keras.efficientnet model, but using tensorflow.keras.applications.EfficientNetB0 should lead to the same results
from efficientnet.keras import EfficientNetB0
from keras.layers import Dense, Flatten
from keras.models import Model, load_model
from keras.utils import Sequence
from albumentations import Compose, VerticalFlip, HorizontalFlip, Rotate, GridDistortion
import matplotlib.pyplot as plt
from IPython.display import Image, display
from numpy.random import seed
seed(10)
from tensorflow.python.framework import ops
import gc
from numba import cuda 
import hpacellseg.cellsegmentator as cellsegmentator
from hpacellseg.utils import label_cell, label_nuclei
from tqdm.auto import tqdm
import base64
from pycocotools import _mask as coco_mask
import typing as t
import zlib
import warnings
from tf_explain.core.integrated_gradients import IntegratedGradients
warnings.filterwarnings('ignore')

tf.random.set_seed(10)
%matplotlib inline

# 4-channel classifier init

In [None]:
TEST_IMGS_FOLDER = '../input/hpa-single-cell-image-classification/test/'
TRAIN_IMGS_FOLDER = '../input/hpa-single-cell-image-classification/train/'
IMG_HEIGHT = IMG_WIDTH = 512
BATCH_SIZE = 16
FAST_PUBLIC_RUN = True

# internet access must be enabled to download the ImageNet weights
DOWNLOAD_PRETRAINED_WEIGHTS = False

CHECKPOINT_NAME = 'classifier_effnetb0_rgby_512.h5'

num_cores = multiprocessing.cpu_count()

In [None]:
# from https://www.kaggle.com/c/hpa-single-cell-image-classification/data

specified_class_names = """0. Nucleoplasm
1. Nuclear membrane
2. Nucleoli
3. Nucleoli fibrillar center
4. Nuclear speckles
5. Nuclear bodies
6. Endoplasmic reticulum
7. Golgi apparatus
8. Intermediate filaments
9. Actin filaments 
10. Microtubules
11. Mitotic spindle
12. Centrosome
13. Plasma membrane
14. Mitochondria
15. Aggresome
16. Cytosol
17. Vesicles and punctate cytosolic patterns
18. Negative"""

class_names = [class_name.split('. ')[1].strip() for class_name in specified_class_names.split('\n')]

## A model with weights pre-trained on ImageNet

In [None]:
# downloading ImageNet weights requires an internet connection;
# otherwise, for illustration, a randomly initialized RGB model is used
weights_init = 'imagenet' if DOWNLOAD_PRETRAINED_WEIGHTS else None

imagenet_model = EfficientNetB0(weights=weights_init, include_top=False, pooling='avg',
                               input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))
rgb_model_output = Dense(len(class_names) - 1, activation='sigmoid')(imagenet_model.output)
model_rgb = Model(inputs=imagenet_model.input, outputs=rgb_model_output)

## An RGBY model

In [None]:
four_channel_effnet = EfficientNetB0(weights=None, include_top=False, pooling='avg', 
                                     input_shape=(IMG_HEIGHT, IMG_WIDTH, 4))
model_rgby_output = Dense(len(class_names) - 1, activation='sigmoid')(four_channel_effnet.output)
model_rgby = Model(inputs=four_channel_effnet.input, outputs=model_rgby_output)

## Copying ImageNet weights
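The stem layer of EfficientNet (`stem_conv`) requires special care: its kernel has 3 input channels, so we copy the blue-channel weights to the newly introduced yellow channel. All other layers have the same shapes in both models, so their weights are copied as-is.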

In [None]:
for layer in tqdm(model_rgby.layers, desc='Copying the pre-trained net weights..'):
    # the input layer has no weights, and the dense head is newly added in both models
    if 'input' in layer.name or 'dense' in layer.name:
        continue
    elif layer.name == 'stem_conv':
        # kernel shape (kh, kw, 3, filters), no bias; append a copy of the blue slice
        # so that the new kernel has shape (kh, kw, 4, filters)
        kernels = model_rgb.get_layer('stem_conv').get_weights()[0]
        kernels_extra_channel = np.concatenate((kernels, kernels[:, :, -1:, :]), axis=-2)
        layer.set_weights([kernels_extra_channel])
    else:
        # every other layer has identical weight shapes in both models
        pretrained_weights = model_rgb.get_layer(layer.name).get_weights()
        layer.set_weights(pretrained_weights)
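Because a convolution is linear in its input channels, copying the blue slice means that, at initialization, the yellow channel contributes exactly what the blue filters would produce on it, added on top of the RGB response. The standalone NumPy sketch below checks this on a 1x1 kernel (the 3x3 stem kernel behaves the same way, only with extra spatial sums); `extend_kernel_with_copy` and `conv1x1` are hypothetical helpers, not part of Keras or `efficientnet`.

```python
import numpy as np


def extend_kernel_with_copy(kernel: np.ndarray, source_channel: int = 2) -> np.ndarray:
    """Turn a (kh, kw, 3, filters) conv kernel into a (kh, kw, 4, filters) one.

    The new 4th input channel reuses the weights of `source_channel`
    (index 2 = blue, mirroring `kernels[:, :, -1:, :]` in the cell above).
    """
    if kernel.ndim != 4 or kernel.shape[2] != 3:
        raise ValueError(f"expected a (kh, kw, 3, filters) kernel, got {kernel.shape}")
    extra = kernel[:, :, [source_channel], :]  # list index keeps the channel axis
    return np.concatenate([kernel, extra], axis=2)


def conv1x1(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """1x1 convolution without bias: x is (H, W, C_in), kernel is (1, 1, C_in, C_out)."""
    return np.einsum("hwc,cf->hwf", x, kernel[0, 0])


rng = np.random.default_rng(0)
k3 = rng.normal(size=(1, 1, 3, 8))
k4 = extend_kernel_with_copy(k3)

x_rgb = rng.normal(size=(5, 5, 3))
yellow = rng.normal(size=(5, 5, 1))
x_rgby = np.concatenate([x_rgb, yellow], axis=-1)

# linearity: 4-channel response = RGB response + blue filters applied to yellow
expected = conv1x1(x_rgb, k3) + conv1x1(yellow, k3[:, :, [2], :])
assert np.allclose(conv1x1(x_rgby, k4), expected)
print(k4.shape)  # (1, 1, 4, 8)
```

One consequence of this choice is that the stem's pre-activations are larger than on RGB inputs whenever the yellow channel is non-zero; fine-tuning on RGBY data is what adapts the network to that.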

When `FAST_PUBLIC_RUN` is enabled, the RGBY model fine-tuned locally is loaded from the checkpoint instead.

In [None]:
if FAST_PUBLIC_RUN:
    model_rgby = load_model(f'../input/cell-models/{CHECKPOINT_NAME}')

# Check using explainability

In [None]:
sub_df = pd.read_csv('../input/hpa-single-cell-image-classification/sample_submission.csv')
test_ids = sub_df['ID'].values

In [None]:
class DataGenenerator(Sequence):
    def __init__(self, id_list, id_2_ohe_vector=None, folder_imgs=TRAIN_IMGS_FOLDER, 
                 batch_size=BATCH_SIZE, shuffle=True, augmentation=None, resize=False,
                 resized_height=IMG_HEIGHT, resized_width=IMG_WIDTH, num_channels=4):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augmentation = augmentation
        self.id_list = deepcopy(id_list)
        self.folder_imgs = folder_imgs
        self.len = len(self.id_list) // self.batch_size
        self.resized_height = resized_height
        self.resized_width = resized_width
        self.num_channels = num_channels
        self.id_2_ohe_vector = id_2_ohe_vector
        self.is_test = 'train' not in folder_imgs
        if not self.is_test:       
            self.num_classes = len(next(iter(id_2_ohe_vector.values())))
        if not shuffle and not self.is_test:
            self.labels = [id_2_ohe_vector[img] for img in self.id_list[:self.len*self.batch_size]]
        self.resize = resize

    def __len__(self):
        return self.len
    
    # Keras calls on_epoch_end() after every epoch; Sequence has no on_epoch_start() hook
    def on_epoch_end(self):
        if self.shuffle:
            random.shuffle(self.id_list)
            
    # open_rgby adapted from https://www.kaggle.com/iafoss/pretrained-resnet34-with-rgby-0-460-public-lb
    def open_rgby(self, image_id): #a function that reads RGBY image
        colors = ['red','green','blue','yellow']
        img = [cv2.imread(os.path.join(self.folder_imgs, f'{image_id}_{color}.png'), cv2.IMREAD_GRAYSCALE)
               for color in colors]
        img = np.stack(img, axis=-1)
        if img.shape[0] == self.resized_height and img.shape[1] == self.resized_width:
            return img
        # cv2.resize expects dsize as (width, height)
        img_resized = cv2.resize(img, (self.resized_width, self.resized_height))
        return img_resized

    def __getitem__(self, idx):
        current_batch = self.id_list[idx * self.batch_size: (idx + 1) * self.batch_size]
        X = np.empty((self.batch_size, self.resized_height, self.resized_width, self.num_channels))

        if not self.is_test:
            y = np.empty((self.batch_size, self.num_classes))

        for i, image_id in enumerate(current_batch):
            img = self.open_rgby(image_id)
            if self.augmentation is not None:
                augmented = self.augmentation(image=img)
                img = augmented['image']
            X[i, :, :, :] = img.astype(np.float32)/255.0
            if not self.is_test:
                y[i, :] = self.id_2_ohe_vector[image_id]
        if not self.is_test:
            return X, y
        return X

    def get_labels(self):
        if self.shuffle:
            images_current = self.id_list[:self.len*self.batch_size]
            labels = [self.id_2_ohe_vector[img] for img in images_current]
        else:
            labels = self.labels
        return np.array(labels)
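The generator expects `id_2_ohe_vector`, a mapping from image ID to a multi-hot label vector, which this notebook does not build. A minimal sketch, assuming the labels come as pipe-separated class indices such as `'8|5|0'` (to my knowledge, the format of the `Label` column in the competition's `train.csv`); `build_id_2_ohe_vector` is a hypothetical helper.

```python
import numpy as np


def build_id_2_ohe_vector(image_ids, label_strings, num_classes=19):
    """Map each image ID to a multi-hot float32 vector of length `num_classes`.

    Assumes each label string holds pipe-separated class indices, e.g. "8|5|0".
    """
    id_2_ohe_vector = {}
    for image_id, label_str in zip(image_ids, label_strings):
        vector = np.zeros(num_classes, dtype=np.float32)
        for class_idx in str(label_str).split("|"):
            vector[int(class_idx)] = 1.0  # raises IndexError for out-of-range labels
        id_2_ohe_vector[image_id] = vector
    return id_2_ohe_vector


labels = build_id_2_ohe_vector(["img_a", "img_b"], ["8|5|0", "18"])
print(np.flatnonzero(labels["img_a"]))  # [0 5 8]
```

With the CSV loaded as `train_df`, the call would be `build_id_2_ohe_vector(train_df['ID'], train_df['Label'])`; the column names are again an assumption. The RGBY model here has `len(class_names) - 1` outputs, so matching it would mean keeping the first 18 entries of each vector; how the Negative label was handled during the original training is not shown in this notebook.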

In [None]:
is_public_test_run = len(sub_df)==559 and FAST_PUBLIC_RUN
if is_public_test_run:
    test_ids = test_ids[:10]

In [None]:
explainer = IntegratedGradients()
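For intuition about what `IntegratedGradients` computes: it averages the model's gradients along a straight path from a baseline input to the actual input and multiplies the result by the input difference, so the attributions approximately sum to f(x) - f(baseline). The sketch below reproduces this in NumPy for a logistic model, whose gradient is known in closed form. tf_explain applies the idea to a Keras model; the zero baseline and the midpoint Riemann sum used here are choices of this sketch and may differ from tf_explain's internals.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def integrated_gradients_logistic(x, w, b=0.0, baseline=None, n_steps=50):
    """Integrated gradients of f(x) = sigmoid(w . x + b).

    The path integral from `baseline` to `x` is approximated with a midpoint
    Riemann sum of `n_steps` terms; the gradient is computed analytically.
    """
    x = np.asarray(x, dtype=np.float64)
    w = np.asarray(w, dtype=np.float64)
    baseline = np.zeros_like(x) if baseline is None else np.asarray(baseline, dtype=np.float64)
    alphas = (np.arange(n_steps) + 0.5) / n_steps          # midpoints of [0, 1]
    path = baseline + alphas[:, None] * (x - baseline)     # (n_steps, d) interpolated inputs
    p = sigmoid(path @ w + b)                              # model output along the path
    grads = (p * (1.0 - p))[:, None] * w                   # df/dx = f (1 - f) w at each point
    return (x - baseline) * grads.mean(axis=0)


rng = np.random.default_rng(0)
w, x = rng.normal(size=6), rng.normal(size=6)
attributions = integrated_gradients_logistic(x, w, b=0.3, n_steps=200)
# completeness: attributions sum to f(x) - f(baseline), up to the Riemann-sum error
print(np.isclose(attributions.sum(), sigmoid(x @ w + 0.3) - sigmoid(0.3), atol=1e-3))  # True
```

For an image model, `x` is the whole RGBY tensor, which is why the explainer returns a pixel-level map that can be compared against cell masks.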

## Cell segmentation

In [None]:
NUC_MODEL = '../input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth'
CELL_MODEL = '../input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth'

segmentator = cellsegmentator.CellSegmentator(
    NUC_MODEL,
    CELL_MODEL,
    scale_factor=0.25,
    device='cuda',
    padding=False,
    multi_channel_model=True
)

In [None]:
def get_masks(imgs):
    """Segment cells in a batch of RGBY images (H, W, 4) and return one label mask per image."""
    try:
        # the HPA segmentator expects [microtubules (red), ER (yellow), nuclei (blue)]
        images = [[img[:, :, 0] for img in imgs],
                  [img[:, :, 3] for img in imgs],
                  [img[:, :, 2] for img in imgs]]

        nuc_segmentations = segmentator.pred_nuclei(images[2])
        cell_segmentations = segmentator.pred_cells(images)
        cell_masks = []
        for i in tqdm(range(len(cell_segmentations)), desc='Labeling cells..'):
            _, cell_mask = label_cell(nuc_segmentations[i], cell_segmentations[i])
            cell_masks.append(cell_mask)
        return cell_masks
    except Exception as e:
        raise ValueError('Segmentation failed') from e

## Cell-level predictions using explainability

In [None]:
def vis_integrated_gradients_masks_test(img_idx, conf_threshold=0.01, mask_height=2048, mask_width=2048, 
                                        max_cell_level_conf_2_image_level_conf=0.005, test_ids=test_ids,
                                        model=model_rgby, quantile_level=0.9, figsize=7):
    image_id = test_ids[img_idx]
    img = [cv2.resize(cv2.imread(os.path.join(TEST_IMGS_FOLDER, f'{image_id}_{color}.png'), cv2.IMREAD_GRAYSCALE),
                      (mask_height, mask_width))
           for color in ['red','green','blue','yellow']]
    img = np.stack(img, axis=-1)
    mask = get_masks([img])[0]
    n_cells = mask.max()
    cell_2_max_conf = dict()   
    
    img = cv2.resize(img, (IMG_HEIGHT, IMG_WIDTH)).astype(np.float32)/255.
    predictions_test = model.predict(np.expand_dims(img, 0))
    
    for class_i, class_name in enumerate(class_names[:-1]):
        class_conf_score = predictions_test[0][class_i]
        if class_conf_score > conf_threshold:
            try:
                explanation = explainer.explain(([img], None), model, class_i, n_steps=15)
                explanation_img = cv2.resize(explanation, (mask_height, mask_width))
                explanation_total_level = np.quantile(explanation_img.flatten(), quantile_level)
            except:
                continue

            plt.figure(figsize=(figsize, figsize))
            plt.imshow(mask)
            plt.imshow(explanation_img, alpha=0.7)
            plt.xticks([])
            plt.yticks([])
            plt.title(f'{test_ids[img_idx]}\n{class_name} ({class_conf_score:.2f}): raw integrated gradients', fontsize=22)
            plt.show()

            masks_all = np.zeros((mask_height, mask_width))
            coord_2_conf = dict()
            for cell_i in range(1, n_cells + 1):
                cell_mask_bool = mask == cell_i
                cell_explanation_perc = np.quantile(explanation_img[cell_mask_bool], quantile_level)
                cell_conf = np.clip(cell_explanation_perc*class_conf_score/explanation_total_level, 0, class_conf_score)
                if cell_conf/class_conf_score >= max_cell_level_conf_2_image_level_conf and cell_conf > 1e-3:
                    masks_all[cell_mask_bool] = 1
                    mask_pixels_x, mask_pixels_y = np.where(cell_mask_bool)
                    coord_2_conf[(int(mask_pixels_y.mean()), int(mask_pixels_x.mean()))] = cell_conf
                    if not cell_i in cell_2_max_conf:
                        cell_2_max_conf[cell_i] = cell_conf
                    else:
                        cell_2_max_conf[cell_i] = max(cell_conf, cell_2_max_conf[cell_i])

            plt.figure(figsize=(figsize, figsize)) 
            plt.imshow(masks_all)
            for coords, conf in coord_2_conf.items():
                conf_rounded = np.round(conf*100)/100
                plt.scatter(*coords, s=700, color='red', marker=r"$ {} $".format(conf_rounded))
            plt.xticks([])
            plt.yticks([])
            plt.title(f'{test_ids[img_idx]}\n{class_name}: cell-level predictions', fontsize=22)
            plt.show()

    masks_all = np.zeros((mask_height, mask_width))
    coord_2_conf = dict()
    for cell_i in range(1, n_cells + 1):
        if not cell_i in cell_2_max_conf:
            cell_conf = 0.99
        else:
            cell_conf = 1 - cell_2_max_conf[cell_i]
        if cell_conf >= conf_threshold:
            cell_mask_bool = mask == cell_i
            masks_all[cell_mask_bool] = 1
            mask_pixels_x, mask_pixels_y = np.where(cell_mask_bool)
            coord_2_conf[(int(mask_pixels_y.mean()), int(mask_pixels_x.mean()))] = cell_conf

    plt.figure(figsize=(9, 9)) 
    plt.imshow(masks_all)
    for coords, conf in coord_2_conf.items():
        conf_rounded = np.round(conf*100)/100
        plt.scatter(*coords, s=700, color='red', marker=r"$ {} $".format(conf_rounded))
    plt.xticks([])
    plt.yticks([])
    plt.title(f'{test_ids[img_idx]}\n{class_names[-1]}: cell-level predictions', fontsize=22)
    plt.show()
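The per-cell scoring rule in the function above can be isolated from the plotting code. A minimal sketch, assuming an attribution map already resized to the mask resolution and a label mask in which 0 is background: each cell's `quantile_level` quantile of attribution is divided by the image-wide quantile, multiplied by the image-level probability, and capped at that probability. `cell_confidences_from_attribution` is a hypothetical name; the early return for a non-positive image-level quantile is an addition that guards against division by zero. Note that the visualization uses a ratio threshold of 0.005 while the submission routine uses 0.01.

```python
import numpy as np


def cell_confidences_from_attribution(attribution, label_mask, image_conf,
                                      quantile_level=0.9, min_ratio=0.005, min_conf=1e-3):
    """Score each labelled cell for one class from an (H, W) attribution map.

    label_mask: (H, W) int array, 0 = background, 1..n = cell ids.
    image_conf: image-level probability of the class (assumed > 0).
    Returns {cell_id: confidence} for the cells that pass both thresholds.
    """
    image_level = np.quantile(attribution, quantile_level)
    if image_level <= 0:
        # added guard (not in the notebook): avoid dividing by a non-positive level
        return {}
    scores = {}
    for cell_id in range(1, int(label_mask.max()) + 1):
        cell_pixels = attribution[label_mask == cell_id]
        if cell_pixels.size == 0:  # label missing, e.g. lost when the mask was resized
            continue
        cell_level = np.quantile(cell_pixels, quantile_level)
        # scale the image-level probability by the cell's relative attribution, capped at it
        conf = float(np.clip(cell_level * image_conf / image_level, 0, image_conf))
        if conf / image_conf >= min_ratio and conf > min_conf:
            scores[cell_id] = conf
    return scores


mask = np.zeros((4, 4), dtype=int)
mask[:2, :2] = 1           # cell 1: top-left block
mask[2:, 2:] = 2           # cell 2: bottom-right block
attribution = np.zeros((4, 4))
attribution[:2, :2] = 1.0  # all attribution falls on cell 1
print(cell_confidences_from_attribution(attribution, mask, image_conf=0.8))  # {1: 0.8}
```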

In [None]:
if is_public_test_run:
    for test_img_id in range(10):
        vis_integrated_gradients_masks_test(test_img_id)

## Submission routines

In [None]:
def encode_binary_mask(mask: np.ndarray) -> t.Text:
    """Converts a binary mask into OID challenge encoding ascii text."""

    # check input mask --
    # np.bool was removed in NumPy 1.24; compare against the builtin bool instead
    if mask.dtype != bool:
        raise ValueError(
            "encode_binary_mask expects a binary mask, received dtype == %s" %
            mask.dtype)

    mask = np.squeeze(mask)
    if len(mask.shape) != 2:
        raise ValueError(
            "encode_binary_mask expects a 2d mask, received shape == %s" %
            (mask.shape,))

    # convert input mask to expected COCO API input --
    mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
    mask_to_encode = mask_to_encode.astype(np.uint8)
    mask_to_encode = np.asfortranarray(mask_to_encode)

    # RLE encode mask --
    encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

    # compress and base64 encoding --
    binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
    base64_str = base64.b64encode(binary_str)
    return base64_str.decode()
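`coco_mask.encode` run-length encodes the mask before the result is compressed with zlib and base64-encoded. To make that first step concrete, here is a sketch of *uncompressed* COCO-style run-length counts: pixels are read column by column (Fortran order, which is why the mask is made Fortran-contiguous above) and the counts alternate between runs of 0s and 1s, always starting with a possibly empty run of 0s. pycocotools additionally packs these counts into a compact string, which the hypothetical `rle_counts_column_major` helper does not reproduce.

```python
import numpy as np


def rle_counts_column_major(mask: np.ndarray) -> list:
    """Uncompressed COCO-style run-length counts of a 2-D binary mask."""
    if mask.ndim != 2:
        raise ValueError(f"expected a 2-D mask, got shape {mask.shape}")
    flat = mask.astype(bool).ravel(order="F")  # column-major pixel order
    counts = []
    current, run = False, 0  # counting always starts with a run of zeros
    for value in flat:
        if value == current:
            run += 1
        else:
            counts.append(run)
            current, run = value, 1
    counts.append(run)
    return counts


print(rle_counts_column_major(np.array([[0, 1], [1, 1]], dtype=bool)))  # [1, 3]
```

The counts always sum to the number of pixels, which is a quick consistency check when decoding masks by hand.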

In [None]:
test_id_2_order_idx = {test_id: idx for idx, test_id in enumerate(test_ids)}

# Cell-level predictions

In [None]:
def get_predictions_string_classification(img_ids, mask_heights, mask_widths,
                                          classifier_img_height=IMG_HEIGHT, classifier_img_width=IMG_WIDTH,
                                          classifier=model_rgby, conf_threshold=0.1, 
                                          batch_size=BATCH_SIZE, vis=False, class_names=class_names): 
    results_list = []
    img_idx = 0
    data_gen = DataGenenerator(img_ids, folder_imgs=TEST_IMGS_FOLDER, shuffle=False, batch_size=batch_size,
                               resized_height=2048, resized_width=2048, resize=True)
    
    def get_cell_only(cell_bool_mask, img, background_val=0, vis_cell=False):
        cell_img = img.copy()
        cell_img[np.logical_not(cell_bool_mask)] = background_val
        if vis_cell:
            plt.figure(figsize=(9, 9))
            plt.imshow(cell_img[:, :, :3])
            plt.xticks([])
            plt.yticks([])
            plt.title(f'Cell only', fontsize=22)
            plt.show()
        return cell_img
    

    for batch_i in range(len(img_ids)//batch_size + (1 if len(img_ids)%batch_size != 0 else 0)):
        img_batch_i = 0

        images_batch = data_gen.__getitem__(batch_i)[:len(img_ids) - batch_i*batch_size, :, :]
        img_batch_ids = img_ids[batch_i*batch_size:(batch_i + 1)*batch_size]
        try:
            masks_batch = get_masks(images_batch)
            images_batch = np.stack([cv2.resize(img, (IMG_HEIGHT, IMG_WIDTH)) for img in images_batch])
            predictions_batch = classifier.predict(images_batch)
        except ValueError:
            current_batch_size = images_batch.shape[0]
            results_list.extend(['' for _ in range(current_batch_size)])
            continue
        
        for mask_i, mask_init in enumerate(masks_batch):
            try:
                cell_2_max_conf = dict()
                results_list_img = []
                mask_height, mask_width = mask_heights[img_idx], mask_widths[img_idx]
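                # cv2.resize takes (width, height); nearest-neighbour keeps the integer cell labels intact
                mask = cv2.resize(mask_init, (mask_width, mask_height), interpolation=cv2.INTER_NEAREST)
                mask_classification = cv2.resize(mask_init, (classifier_img_width, classifier_img_height),
                                                 interpolation=cv2.INTER_NEAREST)
                n_cells = mask_classification.max()
                if n_cells == 0:
                    # no cells found: emit an empty prediction and move on to the next image
                    results_list.append('')
                    img_idx += 1
                    img_batch_i += 1
                    continue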
                img_current = images_batch[mask_i]
                img_background_mean = img_current[mask_classification == 0].mean()

                cell_2_predictions_list = []
                classifier_batch_next = []
                for cell_i in range(1, n_cells + 1):
                    cell_mask_bool = mask_classification == cell_i
                    cell_masked_img = get_cell_only(cell_mask_bool, img_current, 
                                                    background_val=img_background_mean, vis_cell=vis and mask_i==0)
                    classifier_batch_next.append(cell_masked_img)
                    if len(classifier_batch_next) == batch_size:
                        try:
                            cell_predictions_batch = classifier.predict(np.stack(classifier_batch_next))
                        except:
                            cell_predictions_batch = np.zeros((batch_size, len(class_names) - 1))
                        classifier_batch_next = []
                        cell_2_predictions_list.append(cell_predictions_batch)
                # last incomplete batch
                if len(classifier_batch_next) > 0:
                    if len(classifier_batch_next) > 1:
                        cell_imgs_last = np.stack(classifier_batch_next)
                    else:
                        cell_imgs_last = np.expand_dims(classifier_batch_next[0], 0)

                    try:
                        cell_predictions_batch = classifier.predict(cell_imgs_last)
                    except:
                        cell_predictions_batch = np.zeros((cell_imgs_last.shape[0], len(class_names) - 1))
                    cell_2_predictions_list.append(cell_predictions_batch)
                cell_2_predictions_np = np.concatenate(cell_2_predictions_list) if len(cell_2_predictions_list) > 1 else cell_2_predictions_list[0]
                cell_2_rle = dict()
                for class_i, class_name in enumerate(class_names[:-1]):
                    class_conf_score = predictions_batch[img_batch_i][class_i]
                    if class_conf_score > conf_threshold:
                        for cell_i in range(n_cells):
                            cell_conf = cell_2_predictions_np[cell_i, class_i]
                            cell_conf = np.clip(cell_conf, 0, class_conf_score)
                            if cell_conf > conf_threshold:
                                if cell_i in cell_2_rle:
                                    mask_rle = cell_2_rle[cell_i]
                                else:
                                    cell_mask_bool = mask == cell_i + 1
                                    mask_rle = encode_binary_mask(cell_mask_bool)
                                    cell_2_rle[cell_i] = mask_rle
                                results_list_img.extend([str(class_i), f'{cell_conf:.4f}', mask_rle])
                                if not cell_i in cell_2_max_conf:
                                    cell_2_max_conf[cell_i] = cell_conf
                                else:
                                    cell_2_max_conf[cell_i] = max(cell_conf, cell_2_max_conf[cell_i])

                # nothing interesting there
                for cell_i in range(n_cells):
                    if not cell_i in cell_2_max_conf:
                        cell_conf = 0.99
                    else:
                        cell_conf = 1 - cell_2_max_conf[cell_i]
                    if cell_conf > conf_threshold:
                        if cell_i in cell_2_rle:
                            mask_rle = cell_2_rle[cell_i]
                        else:
                            cell_mask_bool = mask == cell_i + 1
                            mask_rle = encode_binary_mask(cell_mask_bool)
                        results_list_img.extend([str(len(class_names) - 1), f'{cell_conf:.4f}', mask_rle])


                results_list.append(' '.join(results_list_img))
                img_idx += 1
                img_batch_i += 1
            except:
                results_list.append('')
                img_idx += 1
                img_batch_i += 1

    return results_list


# sanity check
sub_df_head = sub_df.head(2).copy()  # copy to avoid writing into a slice of sub_df
# classifier_preds
inference_step = 1
for next_start_block_i in range(0, sub_df_head.shape[0], inference_step):
    sub_df_head.iloc[next_start_block_i: next_start_block_i+inference_step,
                     sub_df_head.columns.get_loc('PredictionString')] = get_predictions_string_classification(sub_df_head['ID'].values[next_start_block_i: next_start_block_i+inference_step], 
                                                                                                              sub_df_head['ImageHeight'].values[next_start_block_i: next_start_block_i+inference_step],
                                                                                                              sub_df_head['ImageWidth'].values[next_start_block_i: next_start_block_i+inference_step], vis=True)
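In `get_predictions_string_classification`, each masked cell is classified on its own and its probability is capped by the image-level probability for the same class; a cell with no surviving class gets a Negative score of 0.99, otherwise 1 minus its best class score. A vectorized NumPy sketch of that combination step, assuming the probabilities are already computed; `combine_cell_and_image_probs` is hypothetical, and the caller would still apply `conf_threshold` to the Negative scores as the function above does.

```python
import numpy as np


def combine_cell_and_image_probs(cell_probs, image_probs, conf_threshold=0.1):
    """Combine per-cell and whole-image class probabilities.

    cell_probs:  (n_cells, n_classes) outputs for the masked single-cell images.
    image_probs: (n_classes,) output for the whole image.
    Returns (kept, negative): capped cell probabilities with rejected entries set
    to 0, and a per-cell Negative score.
    """
    cell_probs = np.asarray(cell_probs, dtype=np.float64)
    image_probs = np.asarray(image_probs, dtype=np.float64)
    # a cell is never more confident than the whole image for the same class
    capped = np.clip(cell_probs, 0.0, image_probs[None, :])
    # keep a class only if both the image-level and the capped cell-level scores pass
    keep = (image_probs[None, :] > conf_threshold) & (capped > conf_threshold)
    kept = np.where(keep, capped, 0.0)
    # Negative: 0.99 for cells with no kept class, else 1 - best kept score
    best = kept.max(axis=1)
    negative = np.where(keep.any(axis=1), 1.0 - best, 0.99)
    return kept, negative


kept, negative = combine_cell_and_image_probs(
    cell_probs=[[0.5, 0.05], [0.02, 0.03]],
    image_probs=[0.4, 0.9],
)
print(kept.round(2).tolist(), negative.round(2).tolist())  # [[0.4, 0.0], [0.0, 0.0]] [0.6, 0.99]
```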

In [None]:
def get_predictions_string_integrated_grads(img_ids, mask_heights, mask_widths,
                                            classifier_img_height=IMG_HEIGHT, classifier_img_width=IMG_WIDTH,
                                            max_cell_level_conf_2_image_level_conf=0.01, 
                                            model=model_rgby, quantile_level=0.9, conf_threshold=0.005, 
                                            batch_size=BATCH_SIZE, class_names=class_names): 
    results_list = []
    img_idx = 0
    data_gen = DataGenenerator(img_ids, folder_imgs=TEST_IMGS_FOLDER, shuffle=False, batch_size=batch_size,
                               resized_height=2048, resized_width=2048, resize=True)
    
    def get_cell_only(cell_bool_mask, img, background_val=0, vis_cell=False):
        cell_img = img.copy()
        cell_img[np.logical_not(cell_bool_mask)] = background_val
        if vis_cell:
            plt.figure(figsize=(9, 9))
            plt.imshow(cell_img[:, :, :3])
            plt.xticks([])
            plt.yticks([])
            plt.title(f'Cell only', fontsize=22)
            plt.show()
        return cell_img
    

    for batch_i in range(len(img_ids)//batch_size + (1 if len(img_ids)%batch_size != 0 else 0)):
        img_batch_i = 0

        images_batch = data_gen.__getitem__(batch_i)[:len(img_ids) - batch_i*batch_size, :, :]
        img_batch_ids = img_ids[batch_i*batch_size:(batch_i + 1)*batch_size]
        try:
            masks_batch = get_masks(images_batch)
            images_batch = np.stack([cv2.resize(img, (IMG_HEIGHT, IMG_WIDTH)) for img in images_batch])
            predictions_batch = model.predict(images_batch)
        except ValueError:
            current_batch_size = images_batch.shape[0]
            results_list.extend(['' for _ in range(current_batch_size)])
            continue
        
        for mask_i, mask_init in enumerate(masks_batch):
            cell_2_max_conf = dict()
            results_list_img = []
            mask_height, mask_width = mask_heights[img_idx], mask_widths[img_idx]
            # cv2.resize takes (width, height); nearest-neighbour keeps the integer cell labels intact
            mask = cv2.resize(mask_init, (mask_width, mask_height), interpolation=cv2.INTER_NEAREST)
            n_cells = mask.max()
            if n_cells == 0:
                # no cells found: emit an empty prediction and move on to the next image
                results_list.append('')
                img_idx += 1
                img_batch_i += 1
                continue
            img_current = images_batch[mask_i]

            cell_2_rle = dict()
            cell_2_mask = dict()
            for class_i, class_name in enumerate(class_names[:-1]):
                class_conf_score = predictions_batch[img_batch_i][class_i]
                if class_conf_score > conf_threshold:
                    try:
                        explanation = explainer.explain(([img_current], None), model, class_i, n_steps=15)
                        explanation_img = cv2.resize(explanation, (mask_width, mask_height))
                        explanation_total_level = np.quantile(explanation_img.flatten(), quantile_level)
                    except Exception:
                        continue
                    for cell_i in range(n_cells):
                        if cell_i in cell_2_mask:
                            cell_mask_bool = cell_2_mask[cell_i]
                        else:
                            # cell labels start at 1 (0 is background), as in the Negative loop below
                            cell_mask_bool = mask == cell_i + 1
                            cell_2_mask[cell_i] = cell_mask_bool
                        cell_explanation_perc = np.quantile(explanation_img[cell_mask_bool], quantile_level)
                        cell_conf = np.clip(cell_explanation_perc*class_conf_score/explanation_total_level, 0, class_conf_score)
                        if cell_conf/class_conf_score >= max_cell_level_conf_2_image_level_conf and cell_conf > 1e-3:  
                            if cell_i in cell_2_rle:
                                mask_rle = cell_2_rle[cell_i]
                            else:
                                mask_rle = encode_binary_mask(cell_mask_bool)
                                cell_2_rle[cell_i] = mask_rle
                            results_list_img.extend([str(class_i), f'{cell_conf:.4f}', mask_rle])
                            if not cell_i in cell_2_max_conf:
                                cell_2_max_conf[cell_i] = cell_conf
                            else:
                                cell_2_max_conf[cell_i] = max(cell_conf, cell_2_max_conf[cell_i])

            # nothing interesting there
            for cell_i in range(n_cells):
                if not cell_i in cell_2_max_conf:
                    cell_conf = 0.99
                else:
                    cell_conf = 1 - cell_2_max_conf[cell_i]
                if cell_conf > conf_threshold:
                    if cell_i in cell_2_rle:
                        mask_rle = cell_2_rle[cell_i]
                    else:
                        cell_mask_bool = mask == cell_i + 1
                        mask_rle = encode_binary_mask(cell_mask_bool)
                    results_list_img.extend([str(len(class_names) - 1), f'{cell_conf:.4f}', mask_rle])


            results_list.append(' '.join(results_list_img))
            img_idx += 1
            img_batch_i += 1
#             except:
#                 results_list.append('')
#                 img_idx += 1
#                 img_batch_i += 1

    return results_list

In [None]:
# sanity check
sub_df_head = sub_df.head(2).copy()  # copy to avoid writing into a slice of sub_df
inference_step = 1
for next_start_block_i in range(0, sub_df_head.shape[0], inference_step):
    sub_df_head.iloc[next_start_block_i: next_start_block_i+inference_step,
                     sub_df_head.columns.get_loc('PredictionString')] = get_predictions_string_integrated_grads(sub_df_head['ID'].values[next_start_block_i: next_start_block_i+inference_step], 
                                                                                         sub_df_head['ImageHeight'].values[next_start_block_i: next_start_block_i+inference_step],
                                                                                         sub_df_head['ImageWidth'].values[next_start_block_i: next_start_block_i+inference_step])

In [None]:
del sub_df_head
del model_rgb
gc.collect()

In [None]:
# on the public test-set run, write the sample submission unchanged to save time
if is_public_test_run:
    sub_df.to_csv('submission.csv', index=None)
else:   
    sub_df['PredictionString'] = ''
    gc.collect()
    
    inference_step = 16
    for next_start_block_i in range(0, sub_df.shape[0], inference_step):
        sub_df.iloc[next_start_block_i: next_start_block_i+inference_step,
                    sub_df.columns.get_loc('PredictionString')] = get_predictions_string_integrated_grads(sub_df['ID'].values[next_start_block_i: next_start_block_i+inference_step], 
                                                                                   sub_df['ImageHeight'].values[next_start_block_i: next_start_block_i+inference_step],
                                                                                   sub_df['ImageWidth'].values[next_start_block_i: next_start_block_i+inference_step])
    sub_df.to_csv('submission.csv', index=None)
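Each `PredictionString` built above is a flat, space-separated sequence of `class confidence mask` triplets. A small parser such as the hypothetical `parse_prediction_string` below can be used to sanity-check `submission.csv` before submitting, for instance that every class index lies between 0 and 18; the encoded masks in the example are placeholder strings, not real masks.

```python
def parse_prediction_string(prediction_string: str) -> list:
    """Split 'class conf mask class conf mask ...' into (class_id, confidence, mask) triplets."""
    tokens = prediction_string.split()
    if len(tokens) % 3 != 0:
        raise ValueError(f"expected a multiple of 3 tokens, got {len(tokens)}")
    return [(int(tokens[i]), float(tokens[i + 1]), tokens[i + 2])
            for i in range(0, len(tokens), 3)]


triplets = parse_prediction_string("0 0.5000 MASK_A 18 0.9900 MASK_B")
print(triplets)  # [(0, 0.5, 'MASK_A'), (18, 0.99, 'MASK_B')]
```

Base64 output contains no spaces, so splitting on whitespace is safe for the encoded masks produced by `encode_binary_mask`.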