In [None]:
# Install extra packages from local paths inside the attached `hpamisc` Kaggle dataset
# (one source tree and two prebuilt cp37 Linux wheels).
!pip install "/kaggle/input/hpamisc/pytorch_zoo-master"
!pip install "/kaggle/input/hpamisc/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"
!pip install "/kaggle/input/hpamisc/faiss_gpu-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl"

In [None]:
import sys
import importlib
# `import importlib` on its own does not guarantee that the `util` submodule is
# loaded; import it explicitly so the attribute access below cannot fail.
import importlib.util

# Location of the vendored timm package and the name it is registered under.
MODULE_PATH = "/kaggle/input/timmilya/pytorch-image-models-master/timm/__init__.py"
MODULE_NAME = "timm"

# Build the module from its file location and register it in sys.modules *before*
# executing it, so that a later `import timm` resolves to this copy.
spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)

In [None]:
import os
import cv2
import math
import timm
import zlib
import faiss
import torch
import base64
import pickle
import torch.nn
import numpy as np
import pandas as pd
from tqdm import tqdm
import albumentations as A
import scipy.ndimage as ndi
import matplotlib.pyplot as plt
import torch.nn.functional as F
from scipy.sparse import csr_matrix
from pycocotools import _mask as coco_mask
from skimage import filters, measure, segmentation, transform, util
from skimage.morphology import closing, disk, remove_small_holes, remove_small_objects
from fastai.vision.all import cnn_learner, DataBlock, aug_transforms, Resize, Normalize, ImageBlock, MultiCategoryBlock, RandomSplitter, partial, PILImage, create_body, create_head, num_features_model

In [None]:
# Root folder of the competition dataset.
directory = '/kaggle/input/hpa-single-cell-image-classification'

# When True, a pickled `public_image_ids` object is loaded right after the
# sample submission is read.
PUBLIC_ONLY = True

# Down-scaling factor used for segmentation; passed to CellSegmentator as
# `scale_factor` and used to derive `img_size` in do_segmentation.
SEGMENTATION_SCALE = 0.25
# NOTE(review): not referenced in the visible part of this notebook — presumably
# a similarity threshold for the ArcFace/faiss lookup; confirm against later cells.
ARCFACE_THRESH = 0.9

# Number of output classes requested from every timm model built below.
num_classes = 19
# Side length (pixels) used for per-cell crops and for full-image inputs.
crop_input_size = 168
full_input_size = 512

In [None]:
# The sample submission enumerates the test images to run inference on.
test_df = pd.read_csv(os.path.join(directory, 'sample_submission.csv'))

# A template with exactly 559 rows is presumably the public test set (TODO confirm);
# in that case keep only rows 0, 2, 4 and 6 so the notebook finishes quickly.
if len(test_df) == 559:
    test_df = test_df[:8][::2]
    
if PUBLIC_ONLY:
    # NOTE: pickle.load can execute arbitrary code — only load trusted files.
    with open("/kaggle/input/hpamisc/public_image_ids.pickle", 'rb') as f:
        public_image_ids = pickle.load(f)

In [None]:
# Checkpoints of the per-cell crop classifiers, named
# '<timm architecture>_<input size>_<timestamp>/<epoch>'.
crop_weights = [
    'dm_nfnet_f0_168_2021_04_23__11_13_04/epoch_0',
    'dm_nfnet_f1_168_2021_04_26__14_52_16/epoch_0',
    'dm_nfnet_f2_168_2021_04_27__15_13_50/epoch_0',
    'dm_nfnet_f3_168_2021_04_30__18_30_44/epoch_0',
    'ecaresnet50d_168_2021_04_24__08_48_09/epoch_0',
    'ecaresnet50t_168_2021_05_01__20_25_39/epoch_0',
    'seresnet152d_168_2021_04_23__17_30_06/epoch_0',
    'efficientnet_v2s_168_2021_04_23__22_05_47/epoch_0',
    'seresnext50_32x4d_168_2021_04_24__08_50_52/epoch_0',
    'tf_efficientnet_b5_ns_168_2021_05_01__20_24_56/epoch_0',
]

# Recover the architecture name of every checkpoint: keep the text before the
# '_2021' timestamp, then drop the trailing input-size token.
crop_mdls = []
for w in crop_weights:
    run_tag = w.partition('_2021')[0]
    parts = run_tag.split('_')
    crop_mdl = '_'.join(parts[:-1])
    crop_mdls += [crop_mdl]
    print(crop_mdl)

In [None]:
# Checkpoints of the full-image classifiers, named
# '<timm architecture>_<input size>_<timestamp>/<epoch>'.
full_weights = [
    'eca_nfnet_l0_512_2021_04_29__16_56_08/epoch_9',
    'eca_nfnet_l0_512_2021_04_29__22_06_03/epoch_9',
    'eca_nfnet_l1_512_2021_04_30__00_35_08/epoch_9',
    'eca_nfnet_l1_512_2021_04_30__00_36_12/epoch_9',
    'ecaresnet50t_512_2021_04_30__15_18_41/epoch_9',
    'ecaresnet50t_512_2021_04_30__15_18_52/epoch_9',
    'efficientnet_v2s_512_2021_04_30__12_21_17/epoch_9',
    'efficientnet_v2s_512_2021_04_30__15_14_15/epoch_9',
    'tf_efficientnet_b5_ns_512_2021_05_01__14_58_14/epoch_9',
    'tf_efficientnet_b5_ns_512_2021_05_01__14_58_27/epoch_9',
]

# Recover the architecture name of every checkpoint: keep the text before the
# '_2021' timestamp, then drop the trailing input-size token.
full_mdls = []
for w in full_weights:
    run_tag = w.partition('_2021')[0]
    parts = run_tag.split('_')
    full_mdl = '_'.join(parts[:-1])
    full_mdls += [full_mdl]
    print(full_mdl)

In [None]:
# Checkpoint of the ArcFace embedding model.
arcface_weights = 'eca_nfnet_l0_512_2021_04_25__23_37_16/epoch_14'

# Same naming scheme as the other checkpoints: text before '_2021', minus the
# trailing input-size token, is the architecture name.
run_tag = arcface_weights.partition('_2021')[0]
parts = run_tag.split('_')
arcface_mdl = '_'.join(parts[:-1])

print(arcface_mdl)

In [None]:
# Load the precomputed embeddings that belong to the ArcFace checkpoint; the folder
# name is the checkpoint path with '/' replaced by '_'.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open(f"/kaggle/input/hpaembeddings/embeddings_{arcface_weights.replace('/', '_')}/embeddings.pickle", 'rb') as f:
    embeddings = pickle.load(f)

# One row per stored entry: a num_classes-wide label vector and a 512-wide feature vector.
embeddings_labels = np.zeros((len(embeddings), num_classes), dtype=np.float32)
embeddings_features = np.zeros((len(embeddings), 512), dtype=np.float32)

# Each value of the dict provides a 'labels' entry and an 'embeddings' entry.
for i, (k,v) in enumerate(embeddings.items()):
    embeddings_labels[i] = v['labels']
    embeddings_features[i] = v['embeddings']

# L2-normalise every feature row in place (an all-zero row would divide by zero).
embeddings_features /= np.linalg.norm(embeddings_features, axis=1, keepdims=True)

In [None]:
# Flat inner-product index over the stored features; since the rows were
# L2-normalised above, inner product equals cosine similarity.
gpu_index = faiss.IndexFlatIP(embeddings_features.shape[1])
# Move the index onto all available GPUs, then add the reference vectors.
gpu_index = faiss.index_cpu_to_all_gpus(gpu_index)
gpu_index.add(embeddings_features)

In [None]:
def read_img(image_id_path):
    """Read an image file as a single-channel (grayscale) array.

    Parameters:
    image_id_path -- path of the image file to read.

    Returns:
    The array returned by ``cv2.imread`` in grayscale mode.

    Raises:
    FileNotFoundError -- if OpenCV could not read the file. ``cv2.imread``
    signals failure by returning ``None`` instead of raising, which would
    otherwise surface later as an unrelated error on ``None``.
    """
    img = cv2.imread(image_id_path, cv2.IMREAD_GRAYSCALE)  # IMREAD_GRAYSCALE == 0
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_id_path}")
    return img

def load_RGBY_images(image_id_path):
    """Load the four channel images that share the prefix `image_id_path`.

    The files read are `<prefix>_red.png`, `<prefix>_green.png`,
    `<prefix>_blue.png` and `<prefix>_yellow.png`, in that order.

    Returns:
    A tuple (red_image, green_image, blue_image, yellow_image) as returned by
    `read_img`.
    """
    channel_suffixes = ("_red.png", "_green.png", "_blue.png", "_yellow.png")
    return tuple([read_img(image_id_path + suffix) for suffix in channel_suffixes])

def encode_binary_mask(mask):
    """Encode a boolean 2-D mask as a base64 string.

    The mask is run-length encoded with pycocotools, the resulting counts are
    zlib-compressed at the best compression level, and the compressed bytes
    are base64-encoded.

    Parameters:
    mask -- numpy boolean array; singleton dimensions are squeezed away and
            the result must be two-dimensional.

    Returns:
    The base64 text (str) of the compressed encoding.

    Raises:
    ValueError -- if `mask` is not of boolean dtype or is not 2-D after
                  squeezing.
    """
    # `np.bool` was only an alias of the builtin and no longer exists in
    # recent NumPy releases; `np.bool_` is the boolean dtype.
    if mask.dtype != np.bool_:
        raise ValueError("encode_binary_mask expects a binary mask, received dtype == %s" % mask.dtype)

    mask = np.squeeze(mask)
    if len(mask.shape) != 2:
        # Wrap the shape in a 1-tuple: `"%s" % some_tuple` treats the tuple as the
        # argument list and would raise TypeError instead of this ValueError.
        raise ValueError("encode_binary_mask expects a 2d mask, received shape == %s" % (mask.shape,))

    mask = mask.astype(np.uint8)

    # The encoder is given an (h, w, 1) Fortran-ordered array.
    mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
    mask_to_encode = np.asfortranarray(mask_to_encode)

    encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

    binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
    base64_str = base64.b64encode(binary_str)

    return base64_str.decode()

def compute_M(data):
    """Build a sparse matrix that groups flat element positions by value.

    Row `r` of the result stores, as its data, the flat positions at which
    `data` equals `r`. The matrix has `data.max() + 1` rows and `data.size`
    columns. `data` must hold non-negative integers, as its values are used
    as row indices.
    """
    row_ids = data.ravel()
    flat_positions = np.arange(data.size)
    n_rows = data.max() + 1
    return csr_matrix((flat_positions, (row_ids, flat_positions)), shape=(n_rows, data.size))

def get_indices_sparse(data):
    """Return, for every value 0..data.max(), the indices where `data` has that value.

    The result is a list indexed by value; each entry is the tuple of index
    arrays produced by `np.unravel_index` for the positions of that value.
    """
    grouped_positions = compute_M(data)
    indices_per_value = []
    for value_row in grouped_positions:
        indices_per_value.append(np.unravel_index(value_row.data, data.shape))
    return indices_per_value

class CellSegmentator(object):
    """Uses pretrained DPN-Unet models to segment cells from images.

    Two models are held: `nuclei_model` (used by `pred_nuclei`) and
    `cell_model` (used by `pred_cells`). Inputs are rescaled (and optionally
    padded) before inference; the softmax output is converted to uint8 and,
    unless `return_without_scale_restore` is set, resized back to the size of
    the input image.

    NOTE(review): `download_with_url`, the `*_MODEL_URL` constants and
    `imageio` are referenced below but are not defined/imported in the visible
    part of this file. Those code paths are only reached when a model file is
    missing or when images are passed as paths — confirm they are provided
    elsewhere before relying on them.
    """

    # Per-channel mean / std applied to the model inputs in both predictors.
    NORMALIZE = {"mean": [124 / 255, 117 / 255, 104 / 255], "std": [1 / (0.0167 * 255)] * 3}

    def __init__(
            self,
            nuclei_model="./nuclei_model.pth",
            cell_model="./cell_model.pth",
            model_width_height=None,
            device="cuda",
            multi_channel_model=True,
            return_without_scale_restore=False,
            scale_factor=0.25,
            padding=False
    ):
        """Load both models and store the preprocessing options.

        Keyword arguments:
        nuclei_model -- path to a saved nuclei model, or an already loaded model.
        cell_model -- path to a saved cell model, or an already loaded model.
        model_width_height -- if set, inputs are resized to this square size
                              instead of being rescaled by `scale_factor`.
        device -- "cpu" or a string containing "cuda"; falls back to "cpu"
                  (with a message on stderr) when CUDA is unavailable.
        multi_channel_model -- whether the cell model expects an ER channel.
        return_without_scale_restore -- if True, predictions are returned at
                                        the model's working resolution.
        scale_factor -- rescale factor used when `model_width_height` is unset.
        padding -- if True, inputs are reflect-padded before inference and the
                   padding is cropped from the predictions.

        Raises:
        ValueError -- if `device` is neither "cpu" nor a CUDA device string.
        """
        if device != "cuda" and device != "cpu" and "cuda" not in device:
            raise ValueError(f"{device} is not a valid device (cuda/cpu)")
        # Plain condition instead of assert/except: asserts are stripped under -O.
        if device != "cpu" and not torch.cuda.is_available():
            print("No GPU found, using CPU.", file=sys.stderr)
            device = "cpu"

        self.device = device

        if isinstance(nuclei_model, str):
            if not os.path.exists(nuclei_model):
                print(
                    f"Could not find {nuclei_model}. Downloading it now",
                    file=sys.stderr,
                )
                download_with_url(NUCLEI_MODEL_URL, nuclei_model)
            nuclei_model = torch.load(
                nuclei_model, map_location=torch.device(self.device)
            )
        # DataParallel wrappers are unwrapped when running on CPU.
        if isinstance(nuclei_model, torch.nn.DataParallel) and device == "cpu":
            nuclei_model = nuclei_model.module

        self.nuclei_model = nuclei_model.to(self.device)

        self.multi_channel_model = multi_channel_model
        if isinstance(cell_model, str):
            if not os.path.exists(cell_model):
                print(
                    f"Could not find {cell_model}. Downloading it now", file=sys.stderr
                )
                if self.multi_channel_model:
                    download_with_url(MULTI_CHANNEL_CELL_MODEL_URL, cell_model)
                else:
                    download_with_url(TWO_CHANNEL_CELL_MODEL_URL, cell_model)
            cell_model = torch.load(cell_model, map_location=torch.device(self.device))
        self.cell_model = cell_model.to(self.device)
        self.model_width_height = model_width_height
        self.return_without_scale_restore = return_without_scale_restore
        self.scale_factor = scale_factor
        self.padding = padding

    def _image_conversion(self, images):
        """Validate the channel lists and stack them into 3-channel images.

        Keyword arguments:
        images -- a sequence (microtubule_imgs, er_imgs, nuclei_imgs). Each
                  entry is a list of numpy arrays or of image paths; `er_imgs`
                  must be a list for the multi-channel model and None otherwise.

        Returns:
        A list with one `np.dstack((microtubule, er, nuclei))` array per image.
        When no ER images are given, an all-zero channel is used in their place.

        Raises:
        ValueError -- if the inputs have the wrong type or differing lengths.
        """
        microtubule_imgs, er_imgs, nuclei_imgs = images
        if self.multi_channel_model:
            if not isinstance(er_imgs, list):
                raise ValueError("Please speicify the image path(s) for er channels!")
        else:
            if er_imgs is not None:
                raise ValueError(
                    "second channel should be None for two channel model predition!"
                )

        if not isinstance(microtubule_imgs, list):
            raise ValueError("The microtubule images should be a list")
        if not isinstance(nuclei_imgs, list):
            # The message previously named the microtubule images (copy-paste slip).
            raise ValueError("The nuclei images should be a list")

        if er_imgs:
            if not len(microtubule_imgs) == len(er_imgs) == len(nuclei_imgs):
                raise ValueError("The lists of images needs to be the same length")
        else:
            if not len(microtubule_imgs) == len(nuclei_imgs):
                raise ValueError("The lists of images needs to be the same length")

        # Anything that is not already an array is treated as a path and read from disk.
        if not all(isinstance(item, np.ndarray) for item in microtubule_imgs):
            microtubule_imgs = [os.path.expanduser(item) for item in microtubule_imgs]
            nuclei_imgs = [os.path.expanduser(item) for item in nuclei_imgs]

            microtubule_imgs = [imageio.imread(item) for item in microtubule_imgs]
            nuclei_imgs = [imageio.imread(item) for item in nuclei_imgs]
            if er_imgs:
                er_imgs = [os.path.expanduser(item) for item in er_imgs]
                er_imgs = [imageio.imread(item) for item in er_imgs]

        if not er_imgs:
            er_imgs = [
                np.zeros(item.shape, dtype=item.dtype) for item in microtubule_imgs
            ]
        cell_imgs = [
            np.dstack((microtubule, er, nuclei))
            for microtubule, er, nuclei in zip(microtubule_imgs, er_imgs, nuclei_imgs)
        ]

        return cell_imgs

    def _pad(self, image):
        """Reflect-pad `image` and remember its unpadded size.

        32 pixels are added at the top/left and `32 - size % 32` pixels at the
        bottom/right. The unpadded (rows, cols) is stored in
        `self.scaled_shape`, which `_restore_scaling` uses to crop the padding;
        only the most recently padded image's size is kept.
        """
        rows, cols = image.shape[:2]
        self.scaled_shape = rows, cols
        img_pad = cv2.copyMakeBorder(
            image,
            32,
            (32 - rows % 32),
            32,
            (32 - cols % 32),
            cv2.BORDER_REFLECT,
        )

        return img_pad

    def pred_nuclei(self, images):
        """Run the nuclei model on a list of images.

        Keyword arguments:
        images -- list of numpy arrays. For arrays with three or more
                  dimensions, channel index 2 is used; 2-D arrays are used as is.

        Returns:
        A list with one uint8 prediction per image, post-processed by
        `_restore_scaling`.

        Raises:
        NotImplementedError -- if paths are passed instead of arrays.
        """

        def _preprocess(images):
            if isinstance(images[0], str):
                raise NotImplementedError('Currently the model requires images as numpy arrays, not paths.')
            self.target_shapes = [image.shape for image in images]
            # Resize as in the original implementation, see
            # https://scikit-image.org/docs/dev/api/skimage.transform.html#skimage.transform.resize
            if self.model_width_height:
                images = np.array([transform.resize(image, (self.model_width_height, self.model_width_height))
                                   for image in images])
            else:
                images = [transform.rescale(image, self.scale_factor) for image in images]

            if self.padding:
                images = [self._pad(image) for image in images]

            # The nuclei channel is replicated three times to form the model input.
            nuc_images = np.array([np.dstack((image[..., 2], image[..., 2], image[..., 2])) if len(image.shape) >= 3
                                   else np.dstack((image, image, image)) for image in images])

            # (batch, height, width, channel) -> (batch, channel, height, width)
            nuc_images = nuc_images.transpose([0, 3, 1, 2])

            return nuc_images

        def _segment_helper(imgs):
            with torch.no_grad():
                mean = torch.as_tensor(self.NORMALIZE["mean"], device=self.device)
                std = torch.as_tensor(self.NORMALIZE["std"], device=self.device)
                imgs = torch.tensor(imgs).float()
                imgs = imgs.to(self.device)
                imgs = imgs.sub_(mean[:, None, None]).div_(std[:, None, None])

                imgs = self.nuclei_model(imgs)
                imgs = F.softmax(imgs, dim=1)
                return imgs

        preprocessed_imgs = _preprocess(images)
        predictions = _segment_helper(preprocessed_imgs)
        predictions = predictions.to("cpu").numpy()
        predictions = [self._restore_scaling(util.img_as_ubyte(pred), target_shape)
                       for pred, target_shape in zip(predictions, self.target_shapes)]
        return predictions

    def _restore_scaling(self, n_prediction, target_shape):
        """Restore an image from scaling and padding.

        This method is intended for internal use. It takes one channel-first
        uint8 model output, makes it channel-last, crops the padding (if
        enabled), zeroes channel 0 and, unless `return_without_scale_restore`
        is set, resizes the result to `target_shape`, the (rows, cols, ...)
        shape of the original image.
        """
        n_prediction = n_prediction.transpose([1, 2, 0])
        if self.padding:
            n_prediction = n_prediction[
                32 : 32 + self.scaled_shape[0], 32 : 32 + self.scaled_shape[1], ...
            ]
        n_prediction[..., 0] = 0
        if not self.return_without_scale_restore:
            # cv2.resize expects dsize as (width, height) == (cols, rows); passing
            # (rows, cols) only worked for square images.
            n_prediction = cv2.resize(
                n_prediction,
                (target_shape[1], target_shape[0]),
                interpolation=cv2.INTER_AREA,
            )
        return n_prediction

    def pred_cells(self, images, precombined=False):
        """Run the cell model on a list of images.

        Keyword arguments:
        images -- if `precombined` is True, a list of 3-channel numpy arrays;
                  otherwise the (microtubule, er, nuclei) lists accepted by
                  `_image_conversion`.
        precombined -- whether the channels are already stacked.

        Returns:
        A list with one uint8 prediction per image, post-processed by
        `_restore_scaling`.

        Raises:
        ValueError -- if an image is not three-dimensional.
        """

        def _preprocess(images):
            self.target_shapes = [image.shape for image in images]
            for image in images:
                if not len(image.shape) == 3:
                    raise ValueError("image should has 3 channels")
            # Resize as in the original implementation, see
            # https://scikit-image.org/docs/dev/api/skimage.transform.html#skimage.transform.resize
            if self.model_width_height:
                images = np.array([transform.resize(image, (self.model_width_height, self.model_width_height))
                                   for image in images])
            else:
                # NOTE(review): newer scikit-image releases replace `multichannel`
                # with `channel_axis`; kept as is for the pinned environment — confirm.
                images = np.array([transform.rescale(image, self.scale_factor, multichannel=True) for image in images])

            if self.padding:
                images = np.array([self._pad(image) for image in images])

            # (batch, height, width, channel) -> (batch, channel, height, width)
            cell_images = images.transpose([0, 3, 1, 2])

            return cell_images

        def _segment_helper(imgs):
            with torch.no_grad():
                mean = torch.as_tensor(self.NORMALIZE["mean"], device=self.device)
                std = torch.as_tensor(self.NORMALIZE["std"], device=self.device)
                imgs = torch.tensor(imgs).float()
                imgs = imgs.to(self.device)
                imgs = imgs.sub_(mean[:, None, None]).div_(std[:, None, None])
                imgs = self.cell_model(imgs)
                imgs = F.softmax(imgs, dim=1)
                return imgs

        if not precombined:
            images = self._image_conversion(images)
        preprocessed_imgs = _preprocess(images)
        predictions = _segment_helper(preprocessed_imgs)
        predictions = predictions.to("cpu").numpy()
        predictions = [self._restore_scaling(util.img_as_ubyte(pred), target_shape)
                       for pred, target_shape in zip(predictions, self.target_shapes)]

        return predictions

def __fill_holes(image):
    """Fill holes in a labelled image and relabel the filled regions.

    Label boundaries are zeroed first so that touching regions stay separate,
    then holes in the remaining foreground are filled and the connected
    components are numbered afresh.
    """
    edge_pixels = segmentation.find_boundaries(image)
    interior_only = np.multiply(image, np.invert(edge_pixels))
    filled_foreground = ndi.binary_fill_holes(interior_only > 0)
    relabelled = ndi.label(filled_foreground)[0]
    return relabelled

def label_cell_vlad(nuclei_pred, cell_pred, img_size=512, return_nuclei_label=True):
    """Label the cells and the nuclei.

    Keyword arguments:
    nuclei_pred -- a 3D numpy array of a prediction from a nuclei image.
    cell_pred -- a 3D numpy array of a prediction from a cell image.
    img_size -- side length the size cut-offs are scaled for; every area
                threshold below is multiplied by (img_size / 512) ** 2.
    return_nuclei_label -- if False, only the cell mask is returned.

    Returns:
    A tuple containing:
    nuclei-label -- A nuclei mask data array.
    cell-label  -- A cell mask data array.
    (or only the cell mask when `return_nuclei_label` is False).

    0's in the data arrays indicate background while a continuous
    stretch of a specific number indicates the area for a specific
    cell.
    The same value in cell mask and nuclei mask refers to the identical cell.

    NOTE: The nuclei labeling from this function will be slightly
    different from the values in :func:`label_nuclei` as this version
    will use information from the cell-predictions to make better
    estimates.

    NOTE(review): the intermediate cell labels are cast to uint8 in the "V4"
    section, so label ids above 255 would wrap there — confirm this cannot
    happen for the images processed.
    """
    def __wsh(
        mask_img,
        threshold,
        border_img,
        seeds,
        threshold_adjustment=0.35,
        small_object_size_cutoff=10,
    ):
        # Watershed of the thresholded mask, seeded where `seeds * border_img`
        # exceeds the adjusted threshold.
        img_copy = np.zeros_like(mask_img)
        m = seeds * border_img  # * dt
        img_copy[m > threshold + threshold_adjustment] = 1
        # `np.bool` was an alias of the builtin `bool` and has been removed from
        # recent NumPy releases; the builtin is the exact equivalent.
        img_copy = img_copy.astype(bool)
        img_copy = remove_small_objects(img_copy, small_object_size_cutoff).astype(
            np.uint8
        )

        mask_img = np.where(mask_img <= threshold, 0, 1)
        mask_img = mask_img.astype(bool)
        
        ### New segmentation ###
        mask_img = remove_small_holes(mask_img, int(63 * (img_size / 512)**2))
        ########################
        
        mask_img = remove_small_objects(mask_img, 8).astype(np.uint8)
        markers = ndi.label(img_copy, output=np.uint32)[0]
        labeled_array = segmentation.watershed(
            mask_img, markers, mask=mask_img, watershed_line=True
        )
        return labeled_array

    nuclei_label = __wsh(
        nuclei_pred[..., 2] / 255.0,
        0.4,
        1 - (nuclei_pred[..., 1] + cell_pred[..., 1]) / 255.0 > 0.05,
        nuclei_pred[..., 2] / 255,
        threshold_adjustment=-0.25,
        small_object_size_cutoff=32,
    )

    # for hpa_image, to remove the small pseudo nuclei
    
    ### New segmentation ###
    nuclei_label = remove_small_objects(nuclei_label, int(157 * (img_size / 512)**2))
    ########################
    
    nuclei_label = measure.label(nuclei_label)
    # this is to remove the cell borders' signal from cell mask.
    # could use np.logical_and with some revision, to replace this func.
    # Tuned for segmentation hpa images
    threshold_value = max(0.22, filters.threshold_otsu(cell_pred[..., 2] / 255) * 0.5)
    # exclude the green area first
    cell_region = np.multiply(
        cell_pred[..., 2] / 255 > threshold_value,
        np.invert(np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8)),
    )
    sk = np.asarray(cell_region, dtype=np.int8)
    
    ################### CHANGES HERE ###################################
 
    ####################### V4 ####################
    # "Old" labelling: watershed from the nuclei over the thresholded cell region.
    distance_old = np.clip(cell_pred[..., 2], 255 * threshold_value, cell_pred[..., 2])
    cell_label_old = segmentation.watershed(-distance_old, nuclei_label, mask=sk)
    
    ### New segmentation ###
    cell_label_old = remove_small_objects(cell_label_old, int(344 * (img_size / 512)**2)).astype(np.uint8)
    ########################
    
    # "New" labelling: only pixels with a value of at least 225 are kept.
    distance = distance_old.copy()
    distance[distance<225] = 0
    
    cell_label = segmentation.watershed(-distance, nuclei_label, mask=distance)
    
    ### New segmentation ###
    cell_label = remove_small_objects(cell_label, int(344 * (img_size / 512)**2)).astype(np.uint8)
    ########################
    
    unqs = np.unique(cell_label)
    if 0 in unqs:
        unqs = unqs[1:]
        
    # Dilate every label separately with a 10x10 kernel, then paint the dilated
    # masks back in label order (later labels overwrite earlier ones on overlap).
    lst = [cv2.dilate((cell_label==unq).astype(np.uint8), kernel=np.ones((10, 10), np.uint8), iterations=1) for unq in unqs]
    cell_label = np.zeros_like(cell_label)
    
    for i, l in enumerate(lst):
        cell_label[l==True] = unqs[i]
                
    # For cells whose "old" region does not touch any image edge, the "old"
    # region replaces the dilated one.
    for unq in unqs:
        smth = cell_label_old==unq
        if True not in smth[0,:] and True not in smth[-1,:] and True not in smth[:,0] and True not in smth[:,-1]:
            cell_label[cell_label==unq] = 0
            cell_label[cell_label_old==unq] = unq
            
    ################### CHANGES HERE ###################################
    
#     selem = disk(max(1, int(6 * 2048 / img_size)))
    selem = disk(6)
    cell_label = closing(cell_label, selem)
    cell_label = __fill_holes(cell_label)
    # this part is to use green channel, and extend cell label to green channel
    # benefit is to exclude cells clear on border but without nucleus
    sk = np.asarray(
        np.add(
            np.asarray(cell_label > 0, dtype=np.int8),
            np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8),
        )
        > 0,
        dtype=np.int8,
    )
    cell_label = segmentation.watershed(-distance, cell_label, mask=sk)
    cell_label = __fill_holes(cell_label)
    cell_label = np.asarray(cell_label > 0, dtype=np.uint8)
    cell_label = measure.label(cell_label)
    
    ### New segmentation ###
    cell_label = remove_small_objects(cell_label, int(344 * (img_size / 512)**2))
    ########################
    
    cell_label = measure.label(cell_label)
    cell_label = np.asarray(cell_label, dtype=np.uint16)
    if not return_nuclei_label:
        return cell_label
    # Keep only nuclei that lie inside a cell and give them their cell's label.
    nuclei_label = np.multiply(cell_label > 0, nuclei_label) > 0
    nuclei_label = measure.label(nuclei_label)
    
    ### New segmentation ###
    nuclei_label = remove_small_objects(nuclei_label, int(157 * (img_size / 512)**2))
    ########################
    
    nuclei_label = np.multiply(cell_label, nuclei_label > 0)

    return nuclei_label, cell_label

def crop_net(crop_mdl, pretrained=False, **kwargs):
    """Create the timm architecture `crop_mdl` with `num_classes` outputs.

    Extra keyword arguments are accepted but not used.
    """
    return timm.create_model(crop_mdl, num_classes=num_classes, pretrained=pretrained)

def full_net(full_mdl, pretrained=False, **kwargs):
    """Create the timm architecture `full_mdl` with `num_classes` outputs.

    Extra keyword arguments are accepted but not used.
    """
    return timm.create_model(full_mdl, num_classes=num_classes, pretrained=pretrained)

def arcface_net(pretrained=False, **kwargs):
    """Create the 'eca_nfnet_l0' timm model with `num_classes` outputs.

    The architecture name is fixed here; extra keyword arguments are accepted
    but not used.
    """
    return timm.create_model('eca_nfnet_l0', num_classes=num_classes, pretrained=pretrained)

class ArcMarginProduct(torch.nn.Module):
    """Cosine-similarity layer: compares normalised inputs with normalised class weights.

    Parameters:
    in_features -- size of each input vector.
    out_features -- number of weight rows (one per class).
    """

    def __init__(self, in_features=512, out_features=11582):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the weights uniformly in [-1/sqrt(in_features), 1/sqrt(in_features)]."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x):
        """Return the cosine similarity between each row of `x` and each weight row.

        The weight is moved to the device of `x` rather than unconditionally to
        CUDA, so the layer also works on CPU-only machines.
        """
        weight = self.weight.to(x.device)
        cosine = F.linear(F.normalize(x), F.normalize(weight))
        return cosine
    
class Customhead(torch.nn.Module):
    """Head module for the ArcFace learner.

    Holds a fastai head that maps backbone features to `in_features` outputs
    plus an `ArcMarginProduct` layer. Only the fastai head is applied in
    `forward`; `arc_margin` is created but not called there.
    """

    def __init__(self, in_features=512, out_features=11582):
        super().__init__()

        # A throw-away backbone is built only to find out how many features it emits.
        backbone = create_body(arcface_net, n_in=3, pretrained=False)
        backbone_features = num_features_model(torch.nn.Sequential(*backbone.children()))

        self.head = create_head(backbone_features, n_out=in_features, concat_pool=True)
        self.arc_margin = ArcMarginProduct(in_features, out_features)

    def forward(self, features):
        """Apply the fastai head to `features` and return its output."""
        return self.head(features)

In [None]:
# Saved nuclei and cell segmentation models shipped in the `hpamisc` dataset.
NUC_MODEL = '/kaggle/input/hpamisc/HPA-Cell-Segmentation-weights/dpn_unet_nuclei_v1.pth'
CELL_MODEL = '/kaggle/input/hpamisc/HPA-Cell-Segmentation-weights/dpn_unet_cell_3ch_v1.pth'

# Segmentation runs on inputs rescaled by SEGMENTATION_SCALE with padding enabled;
# predictions are returned at that reduced resolution (no scale restore).
segmentator = CellSegmentator(
    NUC_MODEL,
    CELL_MODEL,
    device="cuda",
    multi_channel_model=True,
    scale_factor=SEGMENTATION_SCALE,
    padding=True,
    return_without_scale_restore=True
)

In [None]:
# Keep only the first training row; this frame is what the DataBlocks below build
# their DataLoaders from.
df = pd.read_csv('/kaggle/input/hpa-single-cell-image-classification/train.csv')[:1]
# Turn the image id into the path of its red-channel PNG.
# NOTE(review): this is a relative '../input' path while the rest of the notebook uses
# absolute '/kaggle/input' paths — assumes the working directory is /kaggle/working; confirm.
df['ID'] = df['ID'].map(lambda x: f'../input/hpa-single-cell-image-classification/train/{x}_red.png')

In [None]:
# Per-channel normalisation statistics applied to crop batches.
crop_batch_tfms = [Normalize.from_stats(mean=[0.135, 0.085, 0.100], std=[0.152, 0.099, 0.175])]
# Resizes a crop to crop_input_size with fastai's 'squish' method.
crop_resizer = Resize(crop_input_size, method='squish')

# DataBlock over `df`: x is the first column (the image path), y is the '|'-separated
# 'Label' column, with the 19 class ids as vocabulary.
crop_dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock(vocab=[str(i) for i in range(19)])),
                        splitter=RandomSplitter(),
                        get_x=lambda x: x[0],
                        get_y=lambda x: x['Label'].split('|'),
                        item_tfms=[],
                        batch_tfms=crop_batch_tfms)

# Presumably these DataLoaders exist only so that cnn_learner can be constructed — TODO confirm.
crop_dls = crop_dblock.dataloaders(df, bs=64, val_bs=64)

t = PILImage.create(np.zeros((crop_input_size,crop_input_size,3), dtype=np.uint8)) # Needed for the first iteration to initialize model on GPU

# One learner per crop checkpoint; each is loaded and run once on the blank image `t`.
crop_models = []
for crop_mdl, w in zip(crop_mdls, crop_weights):
    crop_model = cnn_learner(crop_dls, partial(crop_net, crop_mdl), pretrained=False).load(f"/kaggle/input/hpaweights/{w}")
    crop_model.get_preds(dl=crop_dls.test_dl([t]))[0] # Needed for the first iteration to initialize model on GPU
    crop_models.append(crop_model)

In [None]:
# Per-channel normalisation statistics applied to full-image batches.
full_batch_tfms = [Normalize.from_stats(mean=[0.110, 0.063, 0.066], std=[0.157, 0.098, 0.160])]
# Resizes an image to full_input_size with fastai's 'squish' method.
full_resizer = Resize(full_input_size, method='squish')

# DataBlock over `df`: x is the first column (the image path), y is the '|'-separated
# 'Label' column, with the 19 class ids as vocabulary.
full_dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock(vocab=[str(i) for i in range(19)])),
                        splitter=RandomSplitter(),
                        get_x=lambda x: x[0],
                        get_y=lambda x: x['Label'].split('|'),
                        item_tfms=[],
                        batch_tfms=full_batch_tfms)

# Presumably these DataLoaders exist only so that cnn_learner can be constructed — TODO confirm.
full_dls = full_dblock.dataloaders(df, bs=16, val_bs=16)

t = PILImage.create(np.zeros((full_input_size,full_input_size,3), dtype=np.uint8)) # Needed for the first iteration to initialize model on GPU

# One learner per full-image checkpoint; each is loaded and run once on the blank image `t`.
full_models = []
for full_mdl, w in zip(full_mdls, full_weights):
    full_model = cnn_learner(full_dls, partial(full_net, full_mdl), pretrained=False).load(f"/kaggle/input/hpaweights/{w}")
    full_model.get_preds(dl=full_dls.test_dl([t]))[0] # Needed for the first iteration to initialize model on GPU
    full_models.append(full_model)

In [None]:
# ArcFace learner: the eca_nfnet_l0 backbone from arcface_net with Customhead as its head,
# loaded from the ArcFace checkpoint. It reuses `full_dls` and the blank image `t`
# created in the full-model cell.
arcface_model = cnn_learner(full_dls, arcface_net, custom_head=Customhead(512, 11582), pretrained=False).load(f"/kaggle/input/hpaweights/{arcface_weights}")
arcface_model.get_preds(dl=full_dls.test_dl([t]))[0] # Needed for the first iteration to initialize model on GPU
# Presumably here so the cell does not display the prediction tensor — TODO confirm.
print()

In [None]:
def remove_small_nuclei_on_border(nuclei_mask, nuclei_area_thresh=0.5):
    """Return the labels of small nuclei that touch the image border.

    A nucleus is reported when its bounding box reaches an image edge and its
    area is below `nuclei_area_thresh` times the median area of the nuclei
    that do not touch the border. The mask itself is not modified.

    Parameters:
    nuclei_mask -- labelled nuclei mask; 0 is treated as background.
    nuclei_area_thresh -- fraction of the interior median area below which a
                          border nucleus is reported.

    Returns:
    Array with the labels to ignore.
    """
    height, width = nuclei_mask.shape[:2]

    labels = np.unique(nuclei_mask)
    if 0 in labels:
        labels = labels[1:]

    areas = []
    on_border = []
    for label in labels:
        coords = np.nonzero(nuclei_mask == label)
        first_row, last_row = coords[0].min(), coords[0].max()
        first_col, last_col = coords[1].min(), coords[1].max()
        areas.append(coords[0].size)
        on_border.append(
            bool(first_col == 0 or first_row == 0 or last_col == (width - 1) or last_row == (height - 1))
        )

    areas = np.array(areas)
    on_border = np.array(on_border, dtype=bool)

    # The reference size is the median area of the nuclei away from the border.
    interior_median = np.median(areas[~on_border])
    is_small = areas < interior_median * nuclei_area_thresh

    return labels[is_small & on_border]

def do_segmentation(ryb_image, blue_image):
    """Segment one image and return the cleaned cell mask.

    Parameters:
    ryb_image -- 3-channel image given to the cell model (already combined).
    blue_image -- image given to the nuclei model.

    Returns:
    A tuple (cell_mask, small_border_cell_mask):
    cell_mask -- uint8 labelled cell mask in which cells with a small border
                 nucleus and labels present in only one of the cell/nuclei
                 masks are set to 0.
    small_border_cell_mask -- uint8 mask that is 0 on the cells with a small
                              border nucleus and 1 everywhere else.
    """
    img_size = SEGMENTATION_SCALE * blue_image.shape[0]

    nuclei_pred = segmentator.pred_nuclei([blue_image])[0]
    cell_pred = segmentator.pred_cells([ryb_image], precombined=True)[0]

    nuclei_mask, cell_mask = label_cell_vlad(nuclei_pred, cell_pred, img_size, return_nuclei_label=True)
    cell_mask = cell_mask.astype(np.uint8)

    # Border nuclei whose area is below half the median interior nucleus area
    small_border_nuclei_idxs = remove_small_nuclei_on_border(nuclei_mask)

    # Labels that occur in only one of the two masks (e.g. cells without nuclei)
    cell_uniques = np.unique(cell_mask)
    if 0 in cell_uniques:
        cell_uniques = cell_uniques[1:]

    nuclei_uniques = np.unique(nuclei_mask)
    if 0 in nuclei_uniques:
        nuclei_uniques = nuclei_uniques[1:]

    in_either = np.union1d(cell_uniques, nuclei_uniques)
    in_both = np.intersect1d(cell_uniques, nuclei_uniques)
    cells_without_nuclei_idxs = np.setdiff1d(in_either, in_both)

    if len(small_border_nuclei_idxs) == 0:
        small_border_cell_mask = np.ones(cell_mask.shape, dtype=np.uint8)
    else:
        belongs_to_small_border_cell = np.isin(cell_mask, small_border_nuclei_idxs)
        small_border_cell_mask = (~belongs_to_small_border_cell).astype(np.uint8)

    ignored_labels = np.union1d(small_border_nuclei_idxs, cells_without_nuclei_idxs)
    cell_mask[np.isin(cell_mask, ignored_labels)] = 0

    return cell_mask, small_border_cell_mask

def get_bboxes(cell_mask, scale_factor):
    """Compute a scaled bounding box for every non-zero label of `cell_mask`.

    Parameters:
    cell_mask -- labelled mask; 0 is treated as background.
    scale_factor -- factor applied to the box coordinates.

    Returns:
    A dict mapping each label to
    [axis0_start, axis0_stop, axis1_start, axis1_stop], where the stop values
    are exclusive (max index + 1) and every value is scaled and truncated to int.
    """
    labels = np.unique(cell_mask)
    if 0 in labels:
        labels = labels[1:]

    indices_per_label = get_indices_sparse(cell_mask)

    cell_bboxes = {}
    for label in labels:
        axis0_idx, axis1_idx = indices_per_label[label]
        axis0_lo, axis0_hi = axis0_idx.min(), axis0_idx.max()
        axis1_lo, axis1_hi = axis1_idx.min(), axis1_idx.max()

        cell_bboxes[label] = [
            int(scale_factor*axis0_lo),
            int((axis0_hi+1)*scale_factor),
            int(axis1_lo*scale_factor),
            int((axis1_hi+1)*scale_factor),
        ]

    return cell_bboxes

def do_preprocessing(rgb_image, cell_bboxes, cell_mask, small_border_cell_mask):
    """Build the per-cell and whole-image classifier inputs for one image.

    Args:
        rgb_image: Image array whose first two axes line up with
            ``cell_mask`` and whose last axis holds the channels.
        cell_bboxes: Mapping ``label -> [start_0, stop_0, start_1, stop_1]``;
            the first pair slices axis 0 and the second pair slices axis 1.
        cell_mask: Label mask at the same resolution as ``rgb_image``.
        small_border_cell_mask: Mask multiplied into the resized image to
            blank out pixels (0 = blanked).

    Returns:
        tuple: ``(crops_batch, full_crops_batch, full_image, arcface_image)``.
            ``crops_batch`` holds one masked, bounding-box crop per label
            passed through ``crop_resizer``; ``full_crops_batch`` holds one
            copy of ``full_image`` per label with every other pixel zeroed;
            ``full_image`` is the resized image with the border mask applied;
            ``arcface_image`` is the resized image without any masking.
    """
    arcface_image = full_resizer(PILImage.create(rgb_image))
    full_image = arcface_image * cv2.resize(small_border_cell_mask, (full_input_size, full_input_size), interpolation=cv2.INTER_NEAREST)[..., None]

    resized_mask = cv2.resize(cell_mask, (full_input_size, full_input_size), interpolation=cv2.INTER_NEAREST)

    crops_batch = []
    full_crops_batch = []

    for cell_idx, (start_0, stop_0, start_1, stop_1) in cell_bboxes.items():

        # CROPS
        # Crop first, then mask: only the bounding box is multiplied instead
        # of the whole full-resolution image for every single cell.
        crop_mask = (cell_mask[start_0:stop_0, start_1:stop_1] == cell_idx).astype(np.uint8)
        temp_crop = rgb_image[start_0:stop_0, start_1:stop_1] * crop_mask[..., None]
        crops_batch.append(crop_resizer(PILImage.create(temp_crop)))

        # FULL CROPS
        masked_img_resized = full_image * (resized_mask == cell_idx).astype(np.uint8)[..., None]
        full_crops_batch.append(masked_img_resized)

    return crops_batch, full_crops_batch, full_image, arcface_image

In [None]:
# Walk a single test image through every stage of the pipeline so that the
# intermediate results can be inspected in the cells that follow.
idx = 0

image_id, ImageWidth, ImageHeight, PredictionString = test_df.iloc[idx]

# Load the four channel images of the sample
red_image, green_image, blue_image, yellow_image = load_RGBY_images(f"{directory}/test/{image_id}")

image_size = red_image.shape[0]
scale_factor = 1/SEGMENTATION_SCALE

# Segmentation (channels stacked along the last axis)
ryb_image = np.array([red_image, yellow_image, blue_image]).transpose(1, 2, 0)
cell_mask, small_border_cell_mask = do_segmentation(ryb_image, blue_image)

# Bounding boxes, scaled by scale_factor
cell_bboxes = get_bboxes(cell_mask, scale_factor)

# Bring both masks to image_size and encode one binary mask per cell
cell_mask = cv2.resize(cell_mask, (image_size, image_size), interpolation=cv2.INTER_NEAREST)
small_border_cell_mask = cv2.resize(small_border_cell_mask, (image_size, image_size), interpolation=cv2.INTER_NEAREST)
rles = [encode_binary_mask(cell_mask==cell_id) for cell_id in cell_bboxes.keys()]

# Classifier inputs
rgb_image = np.array([red_image, green_image, blue_image]).transpose(1, 2, 0)
crops_batch, full_crops_batch, full_image, arcface_image = do_preprocessing(rgb_image, cell_bboxes, cell_mask, small_border_cell_mask)

with torch.no_grad():

    # Crops classification: sigmoid per model, then mean over the ensemble
    crop_dl = crop_dls.test_dl(crops_batch)
    y_crops = torch.stack([
        torch.sigmoid(torch.cat([crop_model.model(batch[0]) for batch in crop_dl]))
        for crop_model in crop_models
    ]).mean(dim=0)

    # Full crops classification
    full_dl = full_dls.test_dl(full_crops_batch)
    y_full_crops = torch.stack([
        torch.sigmoid(torch.cat([full_model.model(batch[0]) for batch in full_dl]))
        for full_model in full_models
    ]).mean(dim=0)

    # Full image classification (single item, flattened to one vector)
    full_image_dl = full_dls.test_dl([full_image])
    y_full_image = torch.stack([
        torch.sigmoid(torch.cat([full_model.model(batch[0]) for batch in full_image_dl]))
        for full_model in full_models
    ]).mean(dim=0).flatten()

    # Arcface embedding, L2-normalised per row
    arcface_image_dl = full_dls.test_dl([arcface_image])
    y_arcface = torch.cat([arcface_model.model(batch[0]) for batch in arcface_image_dl]).cpu().numpy()
    y_arcface = y_arcface / np.linalg.norm(y_arcface, axis=1, keepdims=True)

# Nearest stored embedding; its labels replace the image-level prediction
# when the returned score reaches ARCFACE_THRESH.
all_dists, all_topk_idxs = gpu_index.search(x=y_arcface, k=1)

nearest_score = all_dists.item()
if nearest_score >= ARCFACE_THRESH:
    y_full_image = embeddings_labels[all_topk_idxs.item()]

# Equal-weight blend of the three predictions for every (cell, class) pair
n_cells = len(cell_bboxes)
for cell_id in range(n_cells):
    for class_id in range(num_classes):
        conf = y_crops[cell_id, class_id]/3 + y_full_crops[cell_id, class_id]/3 + y_full_image[class_id]/3

In [None]:
# Show the label mask of the sample image after it was resized to image_size.
plt.imshow(cell_mask)

In [None]:
# Grid of the per-cell crops, five per row, titled with the cell label.
n_crops = len(crops_batch)
whole_rows, leftover = divmod(n_crops, 5)
n_rows = whole_rows + 1 if leftover else whole_rows

fig, axs = plt.subplots(figsize=(15, 13), sharex=True, sharey=True, ncols=5, nrows=n_rows)
flat_axes = axs.reshape(-1)

for ax, crop, cell_key in zip(flat_axes, crops_batch, cell_bboxes.keys()):
    ax.imshow(crop)
    ax.set_title(f'Cell {cell_key}')

# Hide the unused axes of a partially filled last row
for ax in flat_axes[n_crops:]:
    ax.set_axis_off()

plt.tight_layout()
plt.show()

In [None]:
# Grid of the per-cell masked full images, five per row, titled with the cell label.
n_full_crops = len(full_crops_batch)
whole_rows, leftover = divmod(n_full_crops, 5)
n_rows = whole_rows + 1 if leftover else whole_rows

fig, axs = plt.subplots(figsize=(15, 13), sharex=True, sharey=True, ncols=5, nrows=n_rows)
flat_axes = axs.reshape(-1)

for ax, full_crop, cell_key in zip(flat_axes, full_crops_batch, cell_bboxes.keys()):
    ax.imshow(full_crop)
    ax.set_title(f'Cell {cell_key}')

# Hide the unused axes of a partially filled last row
for ax in flat_axes[n_full_crops:]:
    ax.set_axis_off()

plt.tight_layout()
plt.show()

In [None]:
# Show the resized sample image after the small-border-cell mask was applied.
plt.imshow(full_image)

In [None]:
# Show the resized, unmasked sample image that is fed to the arcface model.
plt.imshow(arcface_image)

In [None]:
# Run the whole pipeline over the test set and write one CSV row per image.
with open('submission.csv', 'w') as outf:

    print('ID,ImageWidth,ImageHeight,PredictionString', file=outf)

    # Loop-invariant factor used to scale the bounding boxes
    scale_factor = 1/SEGMENTATION_SCALE

    for idx in tqdm(range(len(test_df)), total=len(test_df)):

        image_id, ImageWidth, ImageHeight, PredictionString = test_df.iloc[idx]

        # Images that are not processed keep the PredictionString read from
        # test_df. public_image_ids is only consulted when PUBLIC_ONLY is set.
        if not PUBLIC_ONLY or image_id in public_image_ids:

            # Load images
            red_image, green_image, blue_image, yellow_image = load_RGBY_images(f"{directory}/test/{image_id}")

            image_size = red_image.shape[0]

            # Segmentation
            ryb_image = np.transpose(np.array([red_image, yellow_image, blue_image]), (1,2,0))
            cell_mask, small_border_cell_mask = do_segmentation(ryb_image, blue_image)

            # Get bboxes
            cell_bboxes = get_bboxes(cell_mask, scale_factor)

            if not cell_bboxes:
                # No cell survived segmentation: there is nothing to classify,
                # and stacking an empty prediction list would raise and abort
                # the whole run. Emit an empty prediction for this image.
                PredictionString = ''
            else:
                # RLE
                cell_mask = cv2.resize(cell_mask, (image_size, image_size), interpolation=cv2.INTER_NEAREST)
                small_border_cell_mask = cv2.resize(small_border_cell_mask, (image_size, image_size), interpolation=cv2.INTER_NEAREST)
                rles = [encode_binary_mask(cell_mask==cell_id) for cell_id in cell_bboxes.keys()]

                # Preprocessing
                rgb_image = np.transpose(np.array([red_image, green_image, blue_image]), (1,2,0))
                crops_batch, full_crops_batch, full_image, arcface_image = do_preprocessing(rgb_image, cell_bboxes, cell_mask, small_border_cell_mask)

                with torch.no_grad():

                    # Crops Classification: sigmoid per model, mean over the ensemble
                    crop_dl = crop_dls.test_dl(crops_batch)
                    y_crops = torch.stack([
                        torch.sigmoid(torch.cat([crop_model.model(batch[0]) for batch in crop_dl]))
                        for crop_model in crop_models
                    ]).mean(dim=0)

                    # Full Crops Classification
                    full_dl = full_dls.test_dl(full_crops_batch)
                    y_full_crops = torch.stack([
                        torch.sigmoid(torch.cat([full_model.model(batch[0]) for batch in full_dl]))
                        for full_model in full_models
                    ]).mean(dim=0)

                    # Arcface Prediction (embedding L2-normalised per row)
                    arcface_image_dl = full_dls.test_dl([arcface_image])
                    y_arcface = torch.cat([arcface_model.model(batch[0]) for batch in arcface_image_dl]).cpu().numpy()
                    y_arcface = y_arcface / np.linalg.norm(y_arcface, axis=1, keepdims=True)

                    all_dists, all_topk_idxs = gpu_index.search(x=y_arcface, k=1)

                    if all_dists.item() >= ARCFACE_THRESH:
                        # Close enough to a stored embedding: reuse its labels
                        y_full_image = embeddings_labels[all_topk_idxs.item()]
                    else:
                        # Full Image Classification
                        full_image_dl = full_dls.test_dl([full_image])
                        y_full_image = torch.stack([
                            torch.sigmoid(torch.cat([full_model.model(batch[0]) for batch in full_image_dl]))
                            for full_model in full_models
                        ]).mean(dim=0).flatten()

                # Submission: one "class confidence rle" triple per (cell, class),
                # blending the three predictions with equal weight.
                pred_strs = []

                for cell_id, rle in enumerate(rles):

                    for class_id in range(num_classes):

                        conf = y_crops[cell_id, class_id]/3 + y_full_crops[cell_id, class_id]/3 + y_full_image[class_id]/3
                        pred_strs.append(f"{class_id} {conf} {rle}")

                PredictionString = ' '.join(pred_strs)

        print(f"{image_id},{ImageWidth},{ImageHeight},{PredictionString}", file=outf)