In [None]:
# Install dependencies from attached Kaggle input datasets (prebuilt wheel and
# local source trees) rather than from PyPI — presumably because the kernel
# runs offline; confirm before switching to pinned `%pip install pkg==x.y`.
!pip install "../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"
!pip install "../input/hpacellsegmentatorraman/HPA-Cell-Segmentation"
!pip install "../input/hpapytorchzoozip/pytorch_zoo-master"
!pip install "../input/localhpapackage/hpa-single-cell/"

In [None]:
from copy import deepcopy
import os

import albumentations as A
from albumentations.pytorch import ToTensorV2
from cv2 import resize, INTER_NEAREST
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torch.nn import Conv2d, Sequential, ReLU, AdaptiveMaxPool2d, Flatten
from tqdm.notebook import tqdm

### Local `hpa` package imports

In [None]:
from hpa.data import N_CLASSES, CHANNEL_MEANS, CHANNEL_STDS
from hpa.data.dataset import NEGATIVE_LABEL, load_channels
from hpa.data.misc import parse_string_label, remove_empty_masks
from hpa.data.transforms import ToCellMasks
from hpa.infer.cells import get_cells
from hpa.infer.label import *
from hpa.model.bestfitting.densenet import DensenetClass
from hpa.model.localizers import *
from hpa.segment import HPACellSegmenter
from hpa.utils.plot import *

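### Inference

Configure the model ensemble, segment each test image, run every model on the segmented cells, and keep a class for a cell only when enough models agree.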

In [None]:
# Ensemble voting thresholds for per-cell labels: a model "votes" for a class
# when its probability exceeds PROB_CUTOFF, and the class is kept for a cell
# when at least MIN_AGREEMENT models vote for it.
PROB_CUTOFF, MIN_AGREEMENT = 0.05, 3

In [None]:
# --- Model input geometry ---------------------------------------------------
IMG_DIM = 1536  # side length (pixels) every test image is resized to
DOWNSIZE_SCALE = 16  # presumed total stride of the backbone — TODO confirm
# A non-integer ratio would be silently truncated and the resized cell masks
# would no longer line up with the backbone's feature map.
if IMG_DIM % DOWNSIZE_SCALE != 0:
    raise ValueError(
        f'IMG_DIM ({IMG_DIM}) must be divisible by DOWNSIZE_SCALE ({DOWNSIZE_SCALE})'
    )
# Side length of the feature map; segmentation maps are shrunk to this size.
FEATURE_MAP_DIM = IMG_DIM // DOWNSIZE_SCALE

# --- Cell-transformer hyperparameters (must match the trained checkpoints) ---
FEATURE_ROI_METHOD = 'max_and_avg'
POSITION_ENCODING = True
POSITION_ENC_SHAPE = 8
NUM_ENCODERS = 4
EMB_DIM = 1024
NUM_HEADS = 4

# Per-cell feature width passed to CellTransformer. The backbone is assumed to
# emit BACKBONE_FEATURE_DIM channels; 'max_and_avg' pooling doubles the width,
# and position encoding adds POSITION_ENC_SHAPE**2 more features.
BACKBONE_FEATURE_DIM = 1024
if FEATURE_ROI_METHOD == 'max_and_avg':
    cell_feature_dim = 2 * BACKBONE_FEATURE_DIM
else:
    cell_feature_dim = BACKBONE_FEATURE_DIM
if POSITION_ENCODING:
    cell_feature_dim += POSITION_ENC_SHAPE * POSITION_ENC_SHAPE
print(f'Features extracted per cell = {cell_feature_dim}')

# --- Paths and device -------------------------------------------------------
ROOT_DIR = '/kaggle/input/hpa-single-cell-image-classification'
IMG_DIR = os.path.join(ROOT_DIR, 'test')

DEVICE = 'cuda'
# Ensemble members; each is built with the hyperparameters above.
MODEL_PATHS = [
    '/kaggle/input/hparoimodels/roi15-model9.pth',
    '/kaggle/input/hparoimodels/roi16-model9.pth',
    '/kaggle/input/hparoimodels/roi17-model9.pth'
]

# Weights for HPACellSegmenter (nuclei and cell segmentation networks).
NUCLEI_PATH = '/kaggle/input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth'
CELL_PATH = '/kaggle/input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth'

In [None]:
# The sample submission lists every test image ID, plus its dimensions, that
# must receive a prediction string.
SAMPLE_SUBMISSION_PATH = os.path.join(ROOT_DIR, 'sample_submission.csv')
sub_df = pd.read_csv(SAMPLE_SUBMISSION_PATH)

In [None]:
def load_model(model_path):
    """Build one CellTransformer ensemble member and load its trained weights.

    The architecture is assembled from the module-level hyperparameters, so
    they must match the ones used when the checkpoint was trained.

    Args:
        model_path: path to a ``state_dict`` file readable by ``torch.load``.

    Returns:
        The model moved to ``DEVICE`` and switched to eval mode.
    """
    # Chain the DenseNet's encoder stages and finish with a ReLU.
    densenet = DensenetClass(in_channels=4, dropout=True)
    stages = [
        densenet.conv1,
        densenet.encoder2,
        densenet.encoder3,
        densenet.encoder4,
        densenet.encoder5,
        ReLU(),
    ]
    backbone = Sequential(*stages)

    model = CellTransformer(
        backbone=backbone,
        feature_roi=RoIPool(method=FEATURE_ROI_METHOD,
                            positions=POSITION_ENCODING,
                            tgt_shape=POSITION_ENC_SHAPE),
        num_encoders=NUM_ENCODERS,
        emb_dim=EMB_DIM,
        num_heads=NUM_HEADS,
        upsample=Upsample(scale_factor=2, mode='nearest'),
        cell_feature_dim=cell_feature_dim,
        device=DEVICE,
    )

    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    return model.to(DEVICE).eval()

In [None]:
# One loaded, eval-mode model per checkpoint in MODEL_PATHS (the ensemble).
models = list(map(load_model, MODEL_PATHS))

In [None]:
# Cell segmenter built from the NUCLEI_PATH and CELL_PATH weights; it is called
# per image in the inference loop with the red, yellow and blue channels.
segmenter = HPACellSegmenter(NUCLEI_PATH, CELL_PATH, device=DEVICE)

In [None]:
# Peek at the last rows of the submission template (tail() defaults to 5 rows).
sub_df.tail()

In [None]:
def assign_cell_labels(cells, cell_probs, prob_cutoff):
    """Attach single-model predictions to each cell.

    Every class whose probability strictly exceeds ``prob_cutoff`` is added to
    the cell with that probability. A cell with no such class gets
    ``NEGATIVE_LABEL`` with a fixed confidence of 0.5.

    Args:
        cells: objects exposing ``add_prediction(label, prob)``, paired
            positionally with ``cell_probs`` (``zip`` drops any surplus).
        cell_probs: per-cell numpy arrays of class probabilities.
        prob_cutoff: strict lower bound for keeping a class.

    Returns:
        The same ``cells`` list; each cell is mutated in place.
    """
    for cell, probs in zip(cells, cell_probs):
        passing = np.where(probs > prob_cutoff)[0]
        if not len(passing):
            cell.add_prediction(NEGATIVE_LABEL, 0.5)
            continue
        for label_id in passing:
            cell.add_prediction(label_id, probs[label_id])
    return cells

In [None]:
def assign_cell_labels_ensemble(cells, ensemble_probs, prob_cutoff, min_agreement=2):
    """Attach ensemble-voted predictions to each cell.

    For each class, every model "votes" when its probability strictly exceeds
    ``prob_cutoff``. The class is added to the cell, with the mean probability
    across all models, when at least ``min_agreement`` models vote for it.
    A cell that receives no class gets ``NEGATIVE_LABEL`` with confidence 0.5.

    Args:
        cells: objects exposing ``add_prediction(label, prob)``, paired
            positionally with ``ensemble_probs`` (``zip`` drops any surplus).
        ensemble_probs: per-cell numpy arrays indexed ``[class, model]``; the
            inference loop builds (num_cells, n_classes, n_models) via a
            transpose — TODO confirm for other callers.
        prob_cutoff: per-model probability threshold for a vote.
        min_agreement: minimum number of votes needed to keep a class.

    Returns:
        The same ``cells`` list; each cell is mutated in place.
    """
    for cell, cell_probs in zip(cells, ensemble_probs):
        assigned = False
        for label_id, label_probs in enumerate(cell_probs):
            # Use the caller-supplied threshold, not a module-level constant,
            # so the ``prob_cutoff`` argument actually takes effect.
            num_votes = np.count_nonzero(label_probs > prob_cutoff)
            if num_votes >= min_agreement:
                cell.add_prediction(label_id, label_probs.mean())
                assigned = True
        if not assigned:
            cell.add_prediction(NEGATIVE_LABEL, 0.5)
    return cells

In [None]:
# Splits a (resized) label map into per-cell masks; the loop uses len() of the
# result as the cell count — presumably one mask per cell ID.
cell_mask_fn = ToCellMasks()
# Channel-wise normalisation of the 4-channel image with dataset statistics.
normalize_fn = A.Normalize(mean=CHANNEL_MEANS, std=CHANNEL_STDS, max_pixel_value=255)
# Shrinks a label map to the feature-map grid; nearest-neighbour interpolation
# keeps integer cell IDs from being blended together.
resize_seg_fn = A.Resize(height=FEATURE_MAP_DIM, width=FEATURE_MAP_DIM, interpolation=INTER_NEAREST)

In [None]:
predictions = []
num_missed_cells = 0
# The tensor-conversion transform is the same for every image; build it once.
to_tensor_fn = ToTensorV2()
for img_id, img_width, img_height in tqdm(
    zip(sub_df['ID'], sub_df['ImageWidth'], sub_df['ImageHeight']), total=len(sub_df)
):

    # Load the four stain channels and resize to the square model input size.
    channels = load_channels(img_id, IMG_DIR)
    img_full = np.dstack([channels['red'], channels['green'], channels['blue'], channels['yellow']])
    img_reduced = resize(img_full, (IMG_DIM, IMG_DIM))
    # cv2.resize takes dsize as (width, height); use the real height instead of
    # assuming the image is square, so masks match the original image shape.
    img_shape = (img_width, img_height)

    # Segment at reduced resolution (red, yellow, blue channels), then upsample
    # the label map back to the original size; nearest keeps cell IDs intact.
    seg = segmenter(img_reduced[..., 0], img_reduced[..., 3], img_reduced[..., 2])
    seg = resize(seg, img_shape, interpolation=INTER_NEAREST)
    cells = get_cells(seg)

    # Normalise and convert the HWC image into a float tensor on the device.
    x = normalize_fn(image=img_reduced)['image']
    x = to_tensor_fn(image=x)['image']
    x = x.float().to(DEVICE)

    # Shrink the label map to the feature-map grid and build per-cell masks.
    subseg = resize_seg_fn(image=seg)['image']
    cell_masks = cell_mask_fn(image=subseg)['image']
    cell_masks = torch.from_numpy(cell_masks).to(DEVICE)

    # The models take the cell count as a 1-element LongTensor.
    num_cells = torch.LongTensor([len(cell_masks)]).to(DEVICE)

    # Run every ensemble member; keep image-level and per-cell probabilities.
    with torch.no_grad():
        class_probs_ensemble = []
        cell_probs_ensemble = []
        for model in models:
            logits, cell_logits = model(x.unsqueeze(0), cell_masks, num_cells, return_cells=True)

            class_probs = torch.sigmoid(logits).cpu().numpy().squeeze()
            class_probs_ensemble.append(class_probs)

            cell_probs = torch.sigmoid(cell_logits).cpu().numpy()
            cell_probs_ensemble.append(cell_probs)

    # Move the model axis last. Assuming each model returns
    # (num_cells, n_classes), this yields (num_cells, n_classes, n_models).
    ensemble_probs = np.stack(cell_probs_ensemble).transpose((1, 2, 0))

    # Cell IDs that vanish when the label map is shrunk presumably get no mask
    # (and no probabilities), so drop them before pairing cells with
    # ensemble_probs, and keep a running count.
    missed_cell_ids = set(np.unique(seg)).difference(set(np.unique(subseg)))
    cells = [cell for cell in cells if cell.cell_id not in missed_cell_ids]
    num_missed_cells += len(missed_cell_ids)

    # Vote across the ensemble to label each remaining cell.
    cells = assign_cell_labels_ensemble(cells, ensemble_probs, PROB_CUTOFF, min_agreement=MIN_AGREEMENT)

    # One space-joined prediction string per image, in sub_df row order.
    predictions.append(' '.join(cell.get_prediction_string() for cell in cells))

In [None]:
# Total number of segmented cells dropped across all test images because their
# IDs disappeared when the label map was shrunk to the feature-map grid.
print(num_missed_cells)

In [None]:
# Attach one prediction string per test image (same order as sub_df rows) and
# write the submission file. ``index`` is documented as a bool, so pass False
# rather than relying on None being falsy.
sub_df['PredictionString'] = predictions
sub_df.to_csv('submission.csv', index=False)
sub_df.head()

In [None]:
# Diagnostic for the LAST processed image only: `cells` holds the labelled
# cells left over from the final iteration of the inference loop.
get_percent_labeled_cells(cells)

In [None]:
# Visual check of the last processed image: a red/yellow/blue reference
# composite and the green target channel, together with the segmentation.
# NOTE(review): seg was resized to the original image size, while these images
# are IMG_DIM x IMG_DIM — confirm plot_example copes with the size mismatch.
ref_channels = img_reduced[..., [0, 3, 2]]
tgt_channel = img_reduced[..., 1]
ref_img = Image.fromarray(ref_channels.astype(np.uint8))
tgt_img = Image.fromarray(tgt_channel.astype(np.uint8))
plot_example(ref_img, tgt_img, seg)

In [None]:
# Image-level class probabilities from each ensemble member for the last image,
# highlighting the classes above PROB_CUTOFF.
for class_probs in class_probs_ensemble:
    tgt_class_idx = [i for i, p in enumerate(class_probs) if p > PROB_CUTOFF]
    ax = plot_predicted_probs(class_probs, tgt_class_idx)
    plt.show()

In [None]:
# Per-cell probabilities for the last image, grouped by class: flattening a
# (n_classes, n_models) array class-major puts each class's model outputs side
# by side, and one tick labels the middle of each group.
cell_probs_by_class = np.stack(cell_probs_ensemble).transpose((1, 2, 0))
n_classes, n_models = cell_probs_by_class.shape[1:]
# Derive the tick layout from the data instead of hard-coding 3 models and 18
# classes, so it stays correct if MODEL_PATHS or the class count changes.
# Assumes plot_predicted_probs draws element i at x = i — TODO confirm.
xticks = np.arange(n_classes) * n_models + (n_models - 1) / 2
xtick_labels = range(n_classes)

for cell_idx, probs in enumerate(cell_probs_by_class):
    probs = probs.ravel()
    cell_class_idx = [i for i, p in enumerate(probs) if p > PROB_CUTOFF]
    ax = plot_predicted_probs(probs, cell_class_idx)

    ax.set_xticks(xticks)
    ax.set_xticklabels(xtick_labels)

    # Position in ensemble_probs (1-based), not the segmentation cell_id.
    ax.set_title(f'Cell {cell_idx + 1}')
    plt.show()

In [None]:
# Overlay the label assignments of the last processed image's cells on its
# segmentation map (resized back to the original image size in the loop).
overlay_cell_assignments(cells, seg)