In [None]:
# Install pycocotools from the local wheel file under ../input (quiet mode).
!pip install -q ../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl

In [None]:
import sys
sys.path = [
    '../input/smp20210127/pytorch-image-models-master/pytorch-image-models-master',  # timm
    '../input/hpapytorchzoozip/pytorch_zoo-master/',
    '../input/hpa-seg/HPA-Cell-Segmentation/hpacellseg',
    '../input/hpafinal'
] + sys.path
import warnings
warnings.filterwarnings("ignore")

In [None]:
import os
import random
import numpy as np
import pandas as pd
from tqdm import tqdm

from skimage import transform, util
import glob
import cv2
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
import base64
from pycocotools import _mask as coco_mask
import typing as t
import zlib
import os.path
import urllib
import zipfile
import scipy.ndimage as ndi
from skimage import filters, measure, segmentation
from skimage.morphology import (binary_erosion, closing, disk,
                                remove_small_holes, remove_small_objects)

import albumentations 
import torch.nn.functional as F
import torch
from torch import nn
from torch.utils.data import DataLoader,Dataset
from torch.utils.data import TensorDataset, DataLoader, Dataset
import timm
from model import class_densenet121_dropout  # daishu
import PIL
import gc
import torch.cuda.amp as amp

# Prefer the GPU, but fall back to the CPU so the notebook can still be run
# (slowly) or debugged on a machine without CUDA instead of failing later.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# ---- segmentation settings ----
seg_size = 512                               # side length images are resized to before segmentation
seg_bs = 8388608 // (seg_size * seg_size)    # loader batch size shrinks quadratically with seg_size
seg_TTA = 8                                  # number of flip/transpose variants averaged in pred_cells
small_th_dict = {2048: 500, 1024: 125, 512: 32}
small_th = small_th_dict[seg_size]           # small-object cutoff used for nuclei seeds in label_cell
mask_dir = 'test_mask_npz_fullsize_cell_mask'

# ---- classification settings ----
model_dirs = [
    '../input/bo-hpa-models',  # bo
    '../input/bo-hpa-models-3d256',  # bo
    '../input/hpa-models',  # gary
    '../input/hpa-models-qishen',
]

TTA = dict(orig=2, masked=3, cells_128=8, cells_256=6, center_cells=3)
n_ch = 4
num_classes = 19
image_size = 512

# Per-channel divisors applied to the 4-channel images in HPADatasetTest.
orig_mean = [239.93038613, 246.05603962, 250.16871503, 250.50623682]

data_dir = '../input/hpa-single-cell-image-classification/test/'
df_sub = pd.read_csv('../input/hpa-single-cell-image-classification/sample_submission.csv')
# NOTE(review): 559 is presumably the size of the public test set, in which case
# only a 10-row sample is processed for a quick run — confirm.
if df_sub.shape[0] == 559:
    df_sub = df_sub.head(10)
df_sub.shape

In [None]:
# Create the output directory for the per-image cell masks (no error if it already exists).
os.makedirs(mask_dir, exist_ok=True)

In [None]:
# Per-channel mean/std used by CellSegmentator to normalise its 3-channel
# inputs (subtract mean, divide by std) before they are fed to the networks.
NORMALIZE = {"mean": [124 / 255, 117 / 255, 104 / 255], "std": [1 / (0.0167 * 255)] * 3}


def get_trans_seg(img, I, rev=False):
    """Apply (or undo) the I-th test-time augmentation to a 4-D tensor.

    ``I % 4`` selects the flips over dims 2 and 3 (0: none, 1: dim 2,
    2: dim 3, 3: both). For ``I >= 4`` dims 2 and 3 are additionally swapped:
    before the flips when ``rev`` is False, after them when ``rev`` is True,
    so that ``get_trans_seg(get_trans_seg(x, I), I, True)`` gives back ``x``.
    """
    swap_hw = I >= 4
    flip_dims = {1: (2,), 2: (3,), 3: (2, 3)}.get(I % 4, ())

    if swap_hw and not rev:
        img = img.transpose(2, 3)
    for dim in flip_dims:
        img = img.flip(dim)
    if swap_hw and rev:
        img = img.transpose(2, 3)
    return img


class CellSegmentator(object):
    """Uses pretrained DPN-Unet models to segment cells from images."""

    def __init__(
        self,
        nuclei_model="../input/hpa-seg/dpn_unet_nuclei_v1.pth",
        cell_model="../input/hpa-seg/dpn_unet_cell_3ch_v1.pth",
        scale_factor=1.0,
        device="cuda",
        multi_channel_model=True,
    ):
        """Class for segmenting nuclei and whole cells from confocal microscopy images.

        It takes lists of images and returns the raw output from the
        specified segmentation model. When working with images from the
        Human Protein Atlas, the outputs of this class' methods are meant to
        be combined with :func:`label_cell`.

        Keyword arguments:
        nuclei_model -- A loaded torch nuclei segmentation model or the
                        path to a file which contains such a model
                        (default: '../input/hpa-seg/dpn_unet_nuclei_v1.pth').
        cell_model -- A loaded torch cell segmentation model or the
                      path to a file which contains such a model.
                      May be None if only nuclei are to be segmented
                      (default: '../input/hpa-seg/dpn_unet_cell_3ch_v1.pth').
        scale_factor -- How much to scale images before they are fed to the
                        segmentation models (default: 1.0, i.e. no scaling).
        device -- The device on which to run the models: 'cpu', 'cuda' or a
                  specific cuda device such as 'cuda:0' (default: 'cuda').
                  Falls back to 'cpu' with a message on stderr when CUDA is
                  not available.
        multi_channel_model -- If True the cell model expects 3 channels
                               (microtubule, ER, nuclei); otherwise the ER
                               channel must be None (default: True).

        Raises:
        ValueError -- if `device` is not a recognised device string.
        FileNotFoundError -- if a model path does not exist (nothing is
                             downloaded).
        """
        if device != "cuda" and device != "cpu" and "cuda" not in device:
            raise ValueError(f"{device} is not a valid device (cuda/cpu)")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if device != "cpu" and not torch.cuda.is_available():
            print("No GPU found, using CPU.", file=sys.stderr)
            device = "cpu"
        self.device = device

        self.nuclei_model = self._load_model(nuclei_model).to(self.device).eval()

        self.multi_channel_model = multi_channel_model
        if cell_model is None:
            # Nuclei-only usage; pred_cells() will refuse to run.
            self.cell_model = None
        else:
            self.cell_model = self._load_model(cell_model).to(self.device).eval()
        self.scale_factor = scale_factor

    def _load_model(self, model):
        """Return a torch module from either a module or a checkpoint path.

        Raises FileNotFoundError when `model` is a path that does not exist.
        A DataParallel wrapper is unwrapped when running on the CPU.
        """
        if isinstance(model, str):
            if not os.path.exists(model):
                raise FileNotFoundError(f"Could not find model file {model}")
            # torch.load unpickles the file: only load checkpoints you trust.
            model = torch.load(model, map_location=torch.device(self.device))
        if isinstance(model, torch.nn.DataParallel) and self.device == "cpu":
            model = model.module
        return model

    def _image_conversion(self, images):
        """Convert/format images to a list of 3-channel arrays for cell prediction.

        Intended for internal use only.

        Keyword arguments:
        images -- list of lists of image paths/arrays. With an ER channel:
                 [
                     [microtubule_path0/image_array0, microtubule_path1/image_array1, ...],
                     [er_path0/image_array0, er_path1/image_array1, ...],
                     [nuclei_path0/image_array0, nuclei_path1/image_array1, ...]
                 ]
                 or without an ER channel:
                 [
                     [microtubule_path0/image_array0, microtubule_path1/image_array1, ...],
                     None,
                     [nuclei_path0/image_array0, nuclei_path1/image_array1, ...]
                 ]

        Returns:
        A list with one (H, W, 3) array per image, channels stacked in the
        order microtubule, ER (zeros when absent), nuclei.

        Raises:
        ValueError -- on wrong container types or mismatching list lengths.
        """
        microtubule_imgs, er_imgs, nuclei_imgs = images
        if self.multi_channel_model:
            if not isinstance(er_imgs, list):
                raise ValueError("Please speicify the image path(s) for er channels!")
        else:
            if er_imgs is not None:
                raise ValueError(
                    "second channel should be None for two channel model predition!"
                )

        if not isinstance(microtubule_imgs, list):
            raise ValueError("The microtubule images should be a list")
        if not isinstance(nuclei_imgs, list):
            raise ValueError("The nuclei images should be a list")

        if er_imgs:
            if not len(microtubule_imgs) == len(er_imgs) == len(nuclei_imgs):
                raise ValueError("The lists of images needs to be the same length")
        else:
            if not len(microtubule_imgs) == len(nuclei_imgs):
                raise ValueError("The lists of images needs to be the same length")

        if not all(isinstance(item, np.ndarray) for item in microtubule_imgs):
            # NOTE(review): `imageio` is not imported in the visible part of this
            # notebook, so loading from paths raises NameError unless it is
            # imported elsewhere. The notebook itself only calls
            # pred_cells(..., precombined=True) with arrays.
            microtubule_imgs = [
                imageio.imread(os.path.expanduser(item)) for item in microtubule_imgs
            ]
            nuclei_imgs = [
                imageio.imread(os.path.expanduser(item)) for item in nuclei_imgs
            ]
            if er_imgs:
                er_imgs = [
                    imageio.imread(os.path.expanduser(item)) for item in er_imgs
                ]

        if not er_imgs:
            # Two-channel model: fill the missing ER channel with zeros.
            er_imgs = [
                np.zeros(item.shape, dtype=item.dtype) for item in microtubule_imgs
            ]
        cell_imgs = [
            np.dstack((mt, er, nu))
            for mt, er, nu in zip(microtubule_imgs, er_imgs, nuclei_imgs)
        ]

        return cell_imgs

    def _predict_batches(self, model, preprocessed_imgs, n_tta, batch_size=24):
        """Run `model` over channel-first arrays in mini-batches.

        Each batch is normalised with NORMALIZE, passed through the model for
        the first `n_tta` variants of :func:`get_trans_seg`, softmaxed over
        dim 1 and averaged over the variants.

        Returns a list with one numpy prediction per input image.
        """
        mean = torch.as_tensor(NORMALIZE["mean"], device=self.device)
        std = torch.as_tensor(NORMALIZE["std"], device=self.device)
        outputs = []
        with torch.no_grad():
            for start in range(0, len(preprocessed_imgs), batch_size):
                # Stack in numpy first: building a tensor from a Python list of
                # arrays is far slower than converting one contiguous array.
                chunk = np.stack(preprocessed_imgs[start:start + batch_size])
                batch = torch.as_tensor(chunk).float().to(self.device)
                batch = batch.sub_(mean[:, None, None]).div_(std[:, None, None])
                probs = torch.stack(
                    [
                        get_trans_seg(model(get_trans_seg(batch, I)), I, True).softmax(1)
                        for I in range(n_tta)
                    ],
                    0,
                ).mean(0)
                outputs.append(probs.cpu().numpy())
        return list(np.concatenate(outputs, axis=0))

    def pred_nuclei(self, images):
        """Predict the nuclei segmentation.

        Keyword arguments:
        images -- A list of image arrays: either 2d nuclei images, or
                  3-channel images with the nuclei data in the last channel.

        Returns:
        predictions -- A list with one uint8 (H, W, C) prediction per image
                       (empty list for empty input).
        """

        def _preprocess(image):
            # Remembered for _restore_scaling_padding; holds the shape of the
            # most recently preprocessed image only.
            self.target_shape = image.shape
            if len(image.shape) == 2:
                image = np.dstack((image, image, image))
            image = transform.rescale(image, self.scale_factor, multichannel=True)
            nuc_image = np.dstack((image[..., 2], image[..., 2], image[..., 2]))
            nuc_image = nuc_image.transpose([2, 0, 1])
            return nuc_image

        preprocessed_imgs = list(map(_preprocess, images))
        if not preprocessed_imgs:
            return []
        # No test-time augmentation for nuclei: only the identity transform.
        predictions = self._predict_batches(self.nuclei_model, preprocessed_imgs, n_tta=1)
        predictions = map(util.img_as_ubyte, predictions)
        predictions = list(map(self._restore_scaling_padding, predictions))
        return predictions

    def _restore_scaling_padding(self, n_prediction):
        """Convert a channel-first prediction to (H, W, C) and undo the scaling.

        This method is intended for internal use. When `scale_factor` is not
        1, channel 0 is zeroed in place and the prediction is resized back to
        `self.target_shape`.
        """
        n_prediction = n_prediction.transpose([1, 2, 0])
        if not self.scale_factor == 1:
            n_prediction[..., 0] = 0
            # cv2.resize expects dsize as (width, height); target_shape is
            # (height, width, ...), so the two must be swapped.
            target_h, target_w = self.target_shape[0], self.target_shape[1]
            n_prediction = cv2.resize(
                n_prediction,
                (target_w, target_h),
                interpolation=cv2.INTER_AREA,
            )
        return n_prediction

    def pred_cells(self, images, precombined=False):
        """Predict the cell segmentation for a list of images.

        Keyword arguments:
        images -- list of lists of image paths/arrays in the layout described
                  in :meth:`_image_conversion` (ER channel required when
                  multi_channel_model is True, None otherwise). The images
                  need to be of the same size.
        precombined -- If True, `images` is instead a list of ready-made
                       3-channel numpy arrays (default: False).

        Returns:
        predictions -- a list with one uint8 (H, W, C) prediction per image
                       (empty list for empty input).

        Raises:
        ValueError -- if no cell model was loaded or an image is not 3-D.
        """

        def _preprocess(image):
            # Remembered for _restore_scaling_padding; holds the shape of the
            # most recently preprocessed image only.
            self.target_shape = image.shape
            if not len(image.shape) == 3:
                raise ValueError("image should has 3 channels")
            cell_image = transform.rescale(image, self.scale_factor, multichannel=True)
            cell_image = cell_image.transpose([2, 0, 1])
            return cell_image

        if self.cell_model is None:
            raise ValueError("No cell model was loaded; cannot predict cells")
        if not precombined:
            images = self._image_conversion(images)
        preprocessed_imgs = list(map(_preprocess, images))
        if not preprocessed_imgs:
            return []
        predictions = self._predict_batches(
            self.cell_model, preprocessed_imgs, n_tta=seg_TTA
        )
        predictions = map(self._restore_scaling_padding, predictions)
        predictions = list(map(util.img_as_ubyte, predictions))

        return predictions


def __fill_holes(image):
    """Fill holes in a labelled image and relabel the result.

    Boundary pixels between labels are cleared first, the remaining
    foreground is hole-filled as a binary mask, and connected components of
    that mask are labelled again.
    """
    edge_mask = segmentation.find_boundaries(image)
    interior = np.multiply(image, np.invert(edge_mask))
    filled = ndi.binary_fill_holes(interior > 0)
    relabelled, _ = ndi.label(filled)
    return relabelled


def label_cell(nuclei_pred, cell_pred):
    """Label the cells and the nuclei.

    Keyword arguments:
    nuclei_pred -- a 3D numpy array of a prediction from a nuclei image.
    cell_pred -- a 3D numpy array of a prediction from a cell image.

    Returns:
    A tuple containing:
    nuclei-label -- A nuclei mask data array.
    cell-label  -- A cell mask data array (uint16).
    0's in the data arrays indicate background while a continuous
    stretch of a specific number indicates the area for a specific
    cell.
    The same value in cell mask and nuclei mask refers to the identical cell.

    NOTE: The nuclei labelling uses information from the cell prediction as
    well, so it can differ from a labelling computed from the nuclei
    prediction alone.
    """
    def __wsh(
        mask_img,
        threshold,
        border_img,
        seeds,
        threshold_adjustment=0.35,
        small_object_size_cutoff=10,
    ):
        """Watershed `mask_img` from seed markers derived from `seeds * border_img`."""
        img_copy = np.copy(mask_img)
        m = seeds * border_img  # * dt
        img_copy[m <= threshold + threshold_adjustment] = 0
        img_copy[m > threshold + threshold_adjustment] = 1
        # Builtin `bool` instead of the `np.bool` alias, which newer NumPy
        # releases no longer provide.
        img_copy = img_copy.astype(bool)
        img_copy = remove_small_objects(img_copy, small_object_size_cutoff).astype(
            np.uint8
        )

        mask_img[mask_img <= threshold] = 0
        mask_img[mask_img > threshold] = 1
        mask_img = mask_img.astype(bool)
        mask_img = remove_small_holes(mask_img, 63)
        mask_img = remove_small_objects(mask_img, 1).astype(np.uint8)
        markers = ndi.label(img_copy, output=np.uint32)[0]
        labeled_array = segmentation.watershed(
            mask_img, markers, mask=mask_img, watershed_line=True
        )
        return labeled_array

    # NOTE(review): when both predictions are uint8 (pred_nuclei/pred_cells
    # convert with img_as_ubyte) the channel sum below can wrap around at 256.
    # Left unchanged to preserve existing results — confirm intent.
    nuclei_label = __wsh(
        nuclei_pred[..., 2] / 255.0,
        0.4,
        1 - (nuclei_pred[..., 1] + cell_pred[..., 1]) / 255.0 > 0.05,
        nuclei_pred[..., 2] / 255,
        threshold_adjustment=-0.25,
        small_object_size_cutoff=small_th,
    )

    # for hpa_image, to remove the small pseudo nuclei
    nuclei_label = remove_small_objects(nuclei_label, 157)
    nuclei_label = measure.label(nuclei_label)
    # this is to remove the cell borders' signal from cell mask.
    # could use np.logical_and with some revision, to replace this func.
    # Tuned for segmentation hpa images
    threshold_value = max(0.22, filters.threshold_otsu(cell_pred[..., 2] / 255) * 0.5)
    # exclude the green area first
    cell_region = np.multiply(
        cell_pred[..., 2] / 255 > threshold_value,
        np.invert(np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8)),
    )
    sk = np.asarray(cell_region, dtype=np.int8)
    distance = np.clip(cell_pred[..., 2], 255 * threshold_value, cell_pred[..., 2])
    cell_label = segmentation.watershed(-distance, nuclei_label, mask=sk)
    cell_label = remove_small_objects(cell_label, 344).astype(np.uint8)
    selem = disk(2)
    cell_label = closing(cell_label, selem)
    cell_label = __fill_holes(cell_label)
    # this part is to use green channel, and extend cell label to green channel
    # benefit is to exclude cells clear on border but without nucleus
    sk = np.asarray(
        np.add(
            np.asarray(cell_label > 0, dtype=np.int8),
            np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8),
        )
        > 0,
        dtype=np.int8,
    )
    cell_label = segmentation.watershed(-distance, cell_label, mask=sk)
    cell_label = __fill_holes(cell_label)
    cell_label = np.asarray(cell_label > 0, dtype=np.uint8)
    cell_label = measure.label(cell_label)
    cell_label = remove_small_objects(cell_label, 344)
    cell_label = measure.label(cell_label)
    cell_label = np.asarray(cell_label, dtype=np.uint16)
    # Keep only nuclei that lie inside a cell and give them their cell's label.
    nuclei_label = np.multiply(cell_label > 0, nuclei_label) > 0
    nuclei_label = measure.label(nuclei_label)
    nuclei_label = remove_small_objects(nuclei_label, 157)
    nuclei_label = np.multiply(cell_label, nuclei_label > 0)

    return nuclei_label, cell_label

In [None]:
## the cell seg model
# Instantiate the segmentator with its default nuclei/cell checkpoint paths and device.
cellsegmentor = CellSegmentator()

In [None]:
class HPADatasetSeg(Dataset):
    """Dataset yielding the segmentation inputs for each row of a submission frame."""

    def __init__(self, df, root='../input/hpa-single-cell-image-classification/test/'):
        """Store a re-indexed copy of `df` (needs an ``ID`` column) and the image folder."""
        self.root = root
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(gray_image, rgb_image, target_shape, ID)`` for one row.

        ``gray_image`` is the blue channel and ``rgb_image`` the stacked
        (red, yellow, blue) channels, both resized to ``seg_size`` squared;
        ``target_shape`` is the shape of the red channel as read from disk.
        """
        image_id = self.df.loc[index].ID
        red, yellow, blue = (
            cv2.imread(os.path.join(self.root, f'{image_id}_{color}.png'), 0)
            for color in ('red', 'yellow', 'blue')
        )
        target_shape = red.shape
        out_size = (seg_size, seg_size)
        gray_image = cv2.resize(blue, out_size)
        rgb_image = cv2.resize(np.stack((red, yellow, blue), axis=2), out_size)

        return gray_image, rgb_image, target_shape, image_id
    
    
def collate_fn(batch):
    """Regroup a list of 4-field samples into four parallel Python lists.

    Nothing is stacked or converted; the result is
    ``(gray, rgb_image, target_shape, IDs)``, each a list with one entry per
    sample.
    """
    gray, rgb_image, target_shape, IDs = (
        [sample[field] for sample in batch] for field in range(4)
    )
    return gray, rgb_image, target_shape, IDs


# Segmentation dataset/loader over the submission rows; collate_fn keeps every
# field as a plain Python list instead of stacking the samples into tensors.
dataset_seg = HPADatasetSeg(df_sub)
loader_seg = DataLoader(dataset_seg, batch_size=seg_bs, num_workers=2, collate_fn=collate_fn)

In [None]:
# cell_segmentations = cellsegmentor.pred_cells(rgb, precombined=True)
# plt.imshow(cell_segmentations[0][0].transpose(1,2,0))

In [None]:
# Segment every test image and store its labelled cell mask as a compressed .npz.
for gray, rgb, target_shapes, IDs in tqdm(loader_seg):
    nuc_segmentations = cellsegmentor.pred_nuclei(gray)
    cell_segmentations = cellsegmentor.pred_cells(rgb, precombined=True)
    for data_id, nuc_seg, cell_seg in zip(IDs, nuc_segmentations, cell_segmentations):
        _, cell = label_cell(nuc_seg, cell_seg)
        # label_cell returns uint16 labels. uint8 is kept for the usual case,
        # but a blind cast would wrap ids above 255 (256 -> background,
        # 257 -> cell 1, ...), so widen only when it is actually needed.
        mask_dtype = np.uint8 if cell.max() <= 255 else np.uint16
        np.savez_compressed(f'./{mask_dir}/{data_id}', cell.astype(mask_dtype))
print('---- finish mask write ----')

In [None]:
# Drop the segmentator object, run garbage collection, and ask PyTorch to
# release its cached CUDA memory.
del cellsegmentor
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Show current GPU usage.
!nvidia-smi

In [None]:
def encode_binary_mask(mask: np.ndarray) -> t.Text:
    """Converts a binary mask into OID challenge encoding ascii text.

    Args:
        mask: boolean array that is 2-D after ``np.squeeze``.

    Returns:
        The COCO RLE counts of the mask, zlib-compressed and base64-encoded,
        as an ASCII string.

    Raises:
        ValueError: if ``mask`` is not boolean or not 2-D after squeezing.
    """

    # check input mask --
    # np.bool_ rather than the `np.bool` alias, which newer NumPy releases
    # no longer provide.
    if mask.dtype != np.bool_:
        raise ValueError(
            "encode_binary_mask expects a binary mask, received dtype == %s" %
            mask.dtype)

    mask = np.squeeze(mask)
    if len(mask.shape) != 2:
        # Wrap the shape in a 1-tuple: `"%s" % mask.shape` would treat the
        # shape tuple as the argument list and raise TypeError instead.
        raise ValueError(
            "encode_binary_mask expects a 2d mask, received shape == %s" %
            (mask.shape,))

    # convert input mask to expected COCO API input --
    mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
    mask_to_encode = mask_to_encode.astype(np.uint8)
    mask_to_encode = np.asfortranarray(mask_to_encode)

    # RLE encode mask --
    encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

    # compress and base64 encoding --
    binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
    base64_str = base64.b64encode(binary_str)
    return base64_str.decode('ascii')

def binary_mask_to_ascii(mask, mask_val=1):
    """Converts one label of a mask into OID challenge encoding ascii text.

    Args:
        mask: array that is 2-D after ``np.squeeze``.
        mask_val: value that marks the foreground; every other value is
            treated as background.

    Returns:
        The COCO RLE counts of the binary mask, zlib-compressed and
        base64-encoded, as a string.

    Raises:
        ValueError: if the mask is not 2-D after squeezing.
    """
    # Builtin `bool` rather than the `np.bool` alias, which newer NumPy
    # releases no longer provide. The result is boolean by construction, so
    # no separate dtype check is needed.
    mask = np.where(mask == mask_val, 1, 0).astype(bool)

    mask = np.squeeze(mask)
    if len(mask.shape) != 2:
        raise ValueError(f"encode_binary_mask expects a 2d mask, received shape == {mask.shape}")

    # convert input mask to expected COCO API input --
    mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
    mask_to_encode = mask_to_encode.astype(np.uint8)
    mask_to_encode = np.asfortranarray(mask_to_encode)

    # RLE encode mask --
    encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

    # compress and base64 encoding --
    binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
    base64_str = base64.b64encode(binary_str)
    return base64_str.decode()

In [None]:
def read_img(image_id, color, train_or_test='test', image_size=None):
    """Read one colour channel of an image as a single-channel array.

    Args:
        image_id: image identifier used in the file name.
        color: channel suffix of the file name (e.g. "red").
        train_or_test: sub-folder of the competition data to read from.
        image_size: accepted for call compatibility; not used here.

    Returns:
        Whatever ``cv2.imread`` returns for the file in grayscale mode
        (no check is made that the file could be read).
    """
    channel_path = f'../input/hpa-single-cell-image-classification/{train_or_test}/{image_id}_{color}.png'
    return cv2.imread(channel_path, cv2.IMREAD_GRAYSCALE)

class HPADatasetTest(Dataset):
    """Dataset producing whole-image and per-cell classifier inputs for one image.

    Each item is ``(image_id, encs, tensors)`` where ``encs`` is the
    space-separated list of encoded cell masks and ``tensors`` maps
    '512', '768', 'masked', 'cells_128', 'cells_256' and 'center_cells' to
    float tensors. All channels are divided by ``orig_mean``.
    """

    def __init__(self, image_ids, mode='test'):
        self.image_ids = image_ids
        self.mode = mode

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        """Build the sample for ``self.image_ids[index]``.

        If anything fails, the error is reported on stderr and an all-zero
        placeholder sample with ``image_id == ''`` and ``encs == ''`` is
        returned, so a single bad image does not abort the run.
        """
        try:
            image_id = self.image_ids[index]
            red = read_img(image_id, "red", self.mode, 0)
            green = read_img(image_id, "green", self.mode, 0)
            blue = read_img(image_id, "blue", self.mode, 0)
            yellow = read_img(image_id, "yellow", self.mode, 0)
            image = np.stack([blue, green, red, yellow], axis=-1)
            img_h, img_w = image.shape[0], image.shape[1]

            image_512 = cv2.resize(image, (512, 512)).transpose(2,0,1).astype(np.float32)
            image_768 = cv2.resize(image, (768, 768)).transpose(2,0,1).astype(np.float32)
            # Close the .npz archive once the mask has been read.
            with np.load(f'{mask_dir}/{image_id}.npz') as npz:
                cell_mask = npz['arr_0']
            # cv2.resize expects dsize as (width, height).
            cell_mask = cv2.resize(cell_mask, (img_w, img_h), interpolation=cv2.INTER_NEAREST)

            enc_list = []
            masked_images = []
            cells_128 = []
            cells_256 = []
            center_cells = []
            # int(...) so the "+ 1" cannot overflow the mask's small integer dtype.
            for cell_id in range(1, int(cell_mask.max()) + 1):
                bmask = (cell_mask == cell_id).astype(np.uint8)
                if not bmask.any():
                    # Label id absent from the resized mask: nothing to crop.
                    continue
                enc = encode_binary_mask(bmask==1)
                x, y, w, h = cv2.boundingRect(bmask)

                # Square window around the bounding box, shifted to stay inside the image.
                max_l = max(w, h)
                cx = x + w // 2
                cy = y + h // 2
                x1 = max(0, cx - max_l // 2)
                x1 = min(x1, img_w - max_l)
                y1 = max(0, cy - max_l // 2)
                y1 = min(y1, img_h - max_l)
                tmp = image.copy()
                tmp[bmask==0] = 0

                cropped_cell_orig = tmp[y1:y1+max_l, x1:x1+max_l]
                cropped_cell_128 = cv2.resize(cropped_cell_orig, (128, 128))
                cells_128.append(cropped_cell_128)
                cropped_cell_256 = cv2.resize(cropped_cell_orig, (256, 256))
                cells_256.append(cropped_cell_256)
                masked = cv2.resize(tmp, (image_size, image_size))
                masked_images.append(masked)

                ### daishu part
                # Scale the tight crop as if the full image were 768 px, capped
                # at final_size, then paste it centred on a zero canvas.
                # max(1, ...) keeps very small cells from producing an empty size.
                cropped_cell = cv2.resize(tmp[y:y+h, x:x+w],
                                            (max(1, int(w / img_w * 768)),
                                             max(1, int(h / img_h * 768)))
                                         )
                final_size = 512
                new_h = min(cropped_cell.shape[0], final_size)
                new_w = min(cropped_cell.shape[1], final_size)
                cropped_cell = cv2.resize(cropped_cell, (new_w, new_h))

                center_cell = np.zeros((final_size, final_size, 4))
                center = final_size // 2
                h_start = max(0,center-cropped_cell.shape[0]//2)
                h_end = min(final_size,h_start+cropped_cell.shape[0])
                w_start = max(0,center-cropped_cell.shape[1]//2)
                w_end = min(final_size,w_start+cropped_cell.shape[1])

                center_cell[h_start:h_end,w_start:w_end,:] = cropped_cell
                center_cells.append(center_cell)
                ###

                enc_list.append(enc)

            encs = ' '.join(enc_list)

            if len(masked_images) > 0:
                masked_images = np.stack(masked_images).transpose(0, 3, 1, 2).astype(np.float32)
                cells_128 = np.stack(cells_128).transpose(0, 3, 1, 2).astype(np.float32)
                cells_256 = np.stack(cells_256).transpose(0, 3, 1, 2).astype(np.float32)
                center_cells = np.stack(center_cells).transpose(0, 3, 1, 2).astype(np.float32)
            else:
                # No cells found: zero placeholders. center_cells must be
                # defined here too, otherwise the normalisation below fails
                # and the image loses its id in the fallback path.
                masked_images = np.zeros((4, 4, image_size, image_size), dtype=np.float32)
                cells_128 = np.zeros((4, 4, 128, 128), dtype=np.float32)
                cells_256 = np.zeros((4, 4, 256, 256), dtype=np.float32)
                center_cells = np.zeros((4, 4, 512, 512), dtype=np.float32)

            for ch in range(4):
                image_512[ch] /= orig_mean[ch]
                image_768[ch] /= orig_mean[ch]
                masked_images[:, ch] /= orig_mean[ch]
                cells_128[:, ch] /= orig_mean[ch]
                cells_256[:, ch] /= orig_mean[ch]
                center_cells[:, ch] /= orig_mean[ch]

        except Exception as exc:
            # Best-effort fallback, but no longer silent and no longer
            # swallowing KeyboardInterrupt/SystemExit.
            print(f'HPADatasetTest: failed to build sample {index}: {exc!r}', file=sys.stderr)
            image_id = ''
            encs = ''
            image_512 = np.zeros((4, 512, 512), dtype=np.float32)
            image_768 = np.zeros((4, 768, 768), dtype=np.float32)
            masked_images = np.zeros((5, 4, image_size, image_size), dtype=np.float32)
            cells_128 = np.zeros((5, 4, 128, 128), dtype=np.float32)
            cells_256 = np.zeros((5, 4, 256, 256), dtype=np.float32)
            center_cells = np.zeros((5, 4, 512, 512), dtype=np.float32)

        return image_id, encs, {
            '512': torch.tensor(image_512),
            '768': torch.tensor(image_768),
            'masked': torch.tensor(masked_images),
            'cells_128': torch.tensor(cells_128),
            'cells_256': torch.tensor(cells_256),
            'center_cells': torch.tensor(center_cells)
        }


In [None]:
# Classification dataset over the submission IDs; one image (with all of its
# per-cell crops) per batch.
dataset = HPADatasetTest(df_sub.ID.values, mode='test')
dataloader = DataLoader(dataset, batch_size=1, num_workers=2)

In [None]:
# Visual sanity check of the first sample (skipped when the frame has more than 559 rows).
if df_sub.shape[0] <= 559:
    from pylab import rcParams
    rcParams['figure.figsize'] = 20,15

    f, axarr = plt.subplots(4,5)
    ID, enc, imgs = dataset[0]
    print(imgs['masked'].shape)
    # Show at most 5 cells; an image with fewer cells would otherwise raise IndexError.
    n_show = min(5, *(imgs[key].shape[0] for key in ('masked', 'cells_256', 'center_cells')))
    for p in range(n_show):
        # permute(1, 2, 0): channel-first tensor -> (H, W, C) for imshow.
        axarr[0, p].imshow(imgs['512'][:3].permute(1, 2, 0))
        axarr[1, p].imshow(imgs['masked'][p, :3].permute(1, 2, 0))
        axarr[2, p].imshow(imgs['cells_256'][p, :3].permute(1, 2, 0))
        axarr[3, p].imshow(imgs['center_cells'][p, [2,1,0]].permute(1, 2, 0))

# Model

In [None]:
class enetv2(nn.Module):
    """timm backbone with an `n_ch`-channel stem and a fresh linear head.

    The backbone is created with ``timm.create_model(enet_type, False)``. Its
    first convolution's weight is widened to ``n_ch`` input channels, its
    built-in classifier is replaced by ``nn.Identity`` and a new ``myfc``
    linear layer maps the backbone features to ``out_dim`` outputs.

    Args:
        enet_type: timm model name. The architecture family is detected by
            substring match, and the branch order below matters because some
            names overlap (e.g. 'vit_base_resnet50' also contains 'resnet').
        out_dim: number of outputs of the linear head.

    Raises:
        NotImplementedError: if ``enet_type`` matches no supported family.
    """

    def __init__(self, enet_type, out_dim=num_classes):
        super(enetv2, self).__init__()
        self.enet = timm.create_model(enet_type, False)
        if ('efficientnet' in enet_type) or ('mixnet' in enet_type):
            self._widen_input(self.enet.conv_stem)
            self.myfc = nn.Linear(self.enet.classifier.in_features, out_dim)
            self.enet.classifier = nn.Identity()
        elif ('resnet' in enet_type or 'resnest' in enet_type) and 'vit' not in enet_type:
            self._widen_input(self.enet.conv1[0])
            self.myfc = nn.Linear(self.enet.fc.in_features, out_dim)
            self.enet.fc = nn.Identity()
        elif 'rexnet' in enet_type or 'regnety' in enet_type or 'nf_regnet' in enet_type:
            self._widen_input(self.enet.stem.conv)
            self.myfc = nn.Linear(self.enet.head.fc.in_features, out_dim)
            self.enet.head.fc = nn.Identity()
        elif 'resnext' in enet_type:
            self._widen_input(self.enet.conv1)
            self.myfc = nn.Linear(self.enet.fc.in_features, out_dim)
            self.enet.fc = nn.Identity()
        elif 'hrnet_w32' in enet_type:
            self._widen_input(self.enet.conv1)
            self.myfc = nn.Linear(self.enet.classifier.in_features, out_dim)
            self.enet.classifier = nn.Identity()
        elif 'densenet' in enet_type:
            self._widen_input(self.enet.features.conv0)
            self.myfc = nn.Linear(self.enet.classifier.in_features, out_dim)
            self.enet.classifier = nn.Identity()
        elif 'ese_vovnet39b' in enet_type or 'xception41' in enet_type:
            self._widen_input(self.enet.stem[0].conv)
            self.myfc = nn.Linear(self.enet.head.fc.in_features, out_dim)
            self.enet.head.fc = nn.Identity()
        elif 'dpn' in enet_type:
            self._widen_input(self.enet.features.conv1_1.conv)
            # This branch reads `in_channels` (not `in_features`) from the classifier.
            self.myfc = nn.Linear(self.enet.classifier.in_channels, out_dim)
            self.enet.classifier = nn.Identity()
        elif 'inception' in enet_type:
            self._widen_input(self.enet.features[0].conv)
            self.myfc = nn.Linear(self.enet.last_linear.in_features, out_dim)
            self.enet.last_linear = nn.Identity()
        elif 'vit_base_resnet50' in enet_type:
            self._widen_input(self.enet.patch_embed.backbone.stem.conv)
            self.myfc = nn.Linear(self.enet.head.in_features, out_dim)
            self.enet.head = nn.Identity()
        else:
            # NotImplementedError subclasses RuntimeError, which is what the
            # previous bare `raise` produced, but now names the offending model.
            raise NotImplementedError(f'unsupported enet_type: {enet_type!r}')

    @staticmethod
    def _widen_input(conv):
        """Replace `conv.weight` with a copy that has `n_ch` input channels.

        The weight is tiled along the input-channel axis (dim 1)
        ``n_ch // 3 + 1`` times and then cut to the first ``n_ch`` channels.
        """
        conv.weight = nn.Parameter(conv.weight.repeat(1, n_ch // 3 + 1, 1, 1)[:, :n_ch])

    def forward(self, x):
        """Return the `myfc` outputs for the backbone features of `x`."""
        x = self.enet(x)
        h = self.myfc(x)
        return h
    

In [None]:
!ls ../input/bo-hpa-models-3d256

In [None]:
# One row per trained model:
#   (checkpoint name prefix, timm architecture, fold, dataset inputs it consumes)
# The prefix is combined with the fold to form the checkpoint file name when
# the models are loaded.
_kernel_specs = [
    ('resnet50d_512_multilabel_8flips_ss22rot45_co2_lr1e4_bs32_focal_ext_15epo',
     'resnet50d', 1, ('512', 'masked')),
    ('rex150_512_multilabel_8flips_ss22rot45_co7_lr3e4_bs32_ext_cellpseudo2full_15epo',
     'rexnet_150', 0, ('512', 'masked')),
    ('densenet121_512_multilabel_8flips_ss22rot45_co2_lr1e4_bs32_focal_ext_15epo',
     'densenet121', 2, ('512', 'masked')),
    ('b0_512_multilabel_8flips_ss22rot45_co7_lr1e4_bs32_focal_ext_15epo',
     'tf_efficientnet_b0_ns', 3, ('512', 'masked')),
    ('resnet101d_512_multilabel_8flips_ss22rot45_co7_lr1e4_bs32_focal_ext_15epo',
     'resnet101d', 4, ('512', 'masked')),
    ('dpn68b_512_multilabel_8flips_ss22rot45_co7_lr1e4_bs32_focal_ext_15epo',
     'dpn68b', 0, ('512', 'masked')),
    ('densenet169_512_multilabel_8flips_ss22rot45_co2_lr1e4_bs32_focal_ext_15epo',
     'densenet169', 1, ('512', 'masked')),
    ### 2.5d
    ('b0_3d128_multilabel_lw41_8flips_ss22rot45_lr1e4_bs32cell16_ext_2019_15epo',
     'tf_efficientnet_b0_ns', 2, ('cells_128',)),
    ('resnet50d_3d128_multilabel_lw41_8flips_ss22rot45_lr1e4_bs32cell16_ext_15epo',
     'resnet50d', 0, ('cells_128',)),
    ('mixnet_m_3d128_multilabel_lw41_8flips_ss22rot45_lr1e4_bs32cell16_ext_15epo',
     'mixnet_m', 0, ('cells_128',)),
    ('densenet121_3d128_multilabel_lw41_8flips_ss22rot45_lr1e4_bs32cell16_ext_2019_15epo',
     'densenet121', 3, ('cells_128',)),
    ### 2.5d 256
    ('b0_3d256_multilabel_lw41_8flips_ss22rot45_lr1e4_bs32cell16_ext_2019_15epo',
     'tf_efficientnet_b0_ns', 0, ('cells_256',)),
    ('b1_3d256_multilabel_lw41_8flips_ss22rot45_lr1e4_bs32cell16_ext_2019_15epo',
     'tf_efficientnet_b1_ns', 3, ('cells_256',)),
    ('densenet121_3d256_multilabel_lw41_8flips_ss22rot45_lr1e4_bs32cell16_ext_2019_15epo',
     'densenet121', 2, ('cells_256',)),
    ('dpn68b_3d256_multilabel_lw41_8flips_ss22rot45_lr1e4_bs32cell16_ext_2019_15epo',
     'dpn68b', 4, ('cells_256',)),
    ('mixnet_m_3d256_multilabel_lw41_8flips_ss22rot45_lr1e4_bs32cell16_ext_2019_15epo',
     'mixnet_m', 2, ('cells_256',)),
    ('resnet50d_3d256_multilabel_lw41_8flips_ss22rot45_lr1e4_bs32cell16_ext_2019_15epo',
     'resnet50d', 1, ('cells_256',)),
]

# Expand the table into the nested-dict layout the rest of the notebook reads.
# Every model uses the `enetv2` wrapper and a single fold.
kernel_types = {
    name: {
        'model_class': 'enetv2',
        'folds': [fold],
        'enet_type': enet_type,
        'input_type': list(inputs),
    }
    for name, enet_type, fold, inputs in _kernel_specs
}

In [None]:
def load_state_dict(model, model_file):
    """Load `model_file` from the first folder in `model_dirs` that contains it.

    Keys carrying a leading ``'module.'`` prefix have that prefix stripped
    before the strict load. The model is switched to eval mode.

    Args:
        model: module whose parameters are overwritten in place.
        model_file: checkpoint file name, searched for in each `model_dirs` entry.

    Returns:
        The same `model` object, loaded and in eval mode.

    Raises:
        FileNotFoundError: if no folder in `model_dirs` contains `model_file`.
    """
    for folder in model_dirs:
        model_path = os.path.join(folder, model_file)
        if not os.path.exists(model_path):
            continue
        # Deserialize on CPU: load_state_dict copies the values into the
        # model's own parameters, so no second copy needs to live on the
        # device the checkpoint happened to be saved from.
        state_dict = torch.load(model_path, map_location='cpu')
        state_dict = {
            (k[7:] if k.startswith('module.') else k): v
            for k, v in state_dict.items()
        }
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        return model
    raise FileNotFoundError(f'{model_file} not found in any of {model_dirs}')

# Explicit name -> class registry. Looking the class up here instead of calling
# eval() on the config string means a typo in `kernel_types` fails with a clear
# KeyError and no arbitrary expression is ever evaluated.
model_classes = {
    'enetv2': enetv2,
}

# `models` and `input_types` are parallel lists: input_types[i] names the
# dataset inputs that models[i] is run on.
models = []
input_types = []
for key, cfg in kernel_types.items():
    for fold in cfg['folds']:
        model = model_classes[cfg['model_class']](cfg['enet_type'])
        model = model.to(device)
        model_file = f'{key}_final_fold{fold}.pth'
        print(f'loading {model_file} ...')
        model = load_state_dict(model, model_file)
        models.append(model)

        input_types.append(cfg['input_type'])

n_models = len(models)
print('done!')
print('model count:', n_models)

In [None]:
def load_model(model_name,path):
    """Build the named model on the GPU, load its checkpoint and set eval mode.

    Args:
        model_name: architecture key; only 'densenet121' is supported.
        path: checkpoint file passed to ``torch.load``.

    Returns:
        The loaded model in eval mode.

    Raises:
        ValueError: if `model_name` is not supported. Previously this case fell
            through to ``return model`` with `model` unbound.
    """
    if model_name != 'densenet121':
        raise ValueError(f'unsupported model_name: {model_name!r}')

    state_dict = torch.load(path, torch.device('cuda'))
    model = class_densenet121_dropout(num_classes=19, in_channels=4, pretrained_file=None)
    model.cuda()
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [None]:
# Fold indices whose checkpoints are looked up under each model directory.
folds = list(range(5))
# Architecture name -> directory (with trailing slash) holding 'fold<N>.ckpt' files.
model_dic = {
    'densenet121': '../input/hpafinal/output/run_nn_20210504_000509/',
}

In [None]:
# Load every fold checkpoint that exists for each configured architecture.
# Folds without a checkpoint file are skipped silently.
rgby_models = []
for model_name, path in model_dic.items():
    for fold in folds:
        ckpt_path = path + 'fold%s.ckpt' % fold
        if not os.path.exists(ckpt_path):
            continue
        model = load_model(model_name, ckpt_path)
        rgby_models.append(model)
print('daishu model count:', len(rgby_models))

In [None]:
def get_trans(img, I, mode='bgry'):
    """Apply the I-th flip/transpose test-time augmentation to a 4-D batch.

    Args:
        img: tensor indexed as [batch, channel, dim2, dim3].
        I: augmentation index. ``I >= 4`` first swaps dims 2 and 3; ``I % 4``
            then selects no flip (0), flip of dim 2 (1), flip of dim 3 (2) or
            both flips (3).
        mode: when 'rgby', channels 0 and 2 are swapped before augmenting.

    Returns:
        The transformed tensor.
    """
    if mode == 'rgby':
        img = img[:, [2, 1, 0, 3]]
    if I >= 4:
        img = img.transpose(2, 3)

    remainder = I % 4
    if remainder == 0:
        return img
    # remainder -> axes to flip, applied in the listed order.
    for code, axes in ((1, (2,)), (2, (3,)), (3, (2, 3))):
        if remainder == code:
            for axis in axes:
                img = img.flip(axis)
            return img
    

# def get_trans_daishu(img, I, mode='bgry'):
#     if mode == 'rgby':
#         img = img[:, [2,1,0,3]]

#     if I == 0:
#         return img[:, :, :512, :512]
#     if I == 1:
#         return img[:, :, :512, 256:]
#     if I == 2:
#         return img[:, :, 256:, :512]
#     if I == 3:
#         return img[:, :, 256:, 256:]
#     if I == 4:
#         return img[:, :, 128:640, 128:640]
#     raise
def get_trans_daishu(img, I, mode='bgry'):
    """Crop/flip/transpose variant `I` of a 4-D batch, resized to 512x512.

    Indices 0-6 each select a 640x640 window (for a 768x768 input) plus an
    optional transpose of dims 2/3 and a flip. Any other index leaves the
    image uncropped. The result is always bilinearly resized to 512x512.

    Args:
        img: tensor indexed as [batch, channel, dim2, dim3].
        I: augmentation index.
        mode: when 'rgby', channels 0 and 2 are swapped first.

    Returns:
        The augmented tensor with spatial size 512x512.
    """
    if mode == 'rgby':
        img = img[:, [2, 1, 0, 3]]

    # Position in the tuple == augmentation index:
    # (dim-2 window, dim-3 window, swap dims 2/3?, axes flipped afterwards)
    variants = (
        (slice(64, 704), slice(64, 704), False, ()),
        (slice(None, 640), slice(None, 640), False, (2,)),
        (slice(None, 640), slice(128, None), False, (3,)),
        (slice(128, None), slice(128, None), False, (2, 3)),
        (slice(128, None), slice(None, 640), True, ()),
        (slice(32, 672), slice(96, 736), True, (2,)),
        (slice(96, 736), slice(32, 672), True, (3,)),
    )
    for code, (rows, cols, swap, flip_axes) in enumerate(variants):
        if I == code:
            img = img[:, :, rows, cols]
            if swap:
                img = img.transpose(2, 3)
            for axis in flip_axes:
                img = img.flip(axis)
            break

    return F.interpolate(img, size=[512, 512], mode="bilinear")

In [None]:
# Log how many test-time augmentations are drawn for each input type.
print('TTA', TTA)

In [None]:
# Three parallel per-cell accumulators. After the loop, row i of PRED_FINAL,
# IDs[i] and encs[i] must all describe the same cell.
IDs = []
encs = []
PRED_FINAL = []
little_bs = 16  # maximum number of cells sent through a model in one forward pass
with torch.no_grad():
    for ID, enc, images in tqdm(dataloader):
        try:
            # An empty encoding string marks a sample without masks; it is skipped.
            if len(enc[0]) > 0:
                with amp.autocast():
                    for k in images.keys():
                        images[k] = images[k].cuda()
                        # Per-cell stacks arrive as 5-D because of the batch
                        # axis added by the DataLoader; drop that axis.
                        if images[k].ndim == 5:
                            images[k] = images[k].squeeze(0)

                    # Whole-image ('orig') predictions are currently disabled,
                    # so preds['orig'] stays empty and only the cell-level
                    # predictions feed the final output.
                    preds = {
                        'orig': [],
                        'cells': [],
                    }

                    # Masked-image and single-cell inputs. The loop variable is
                    # deliberately not called `t`, which would rebind the
                    # notebook's `import typing as t` alias.
                    for m, inp_types in zip(models, input_types):
                        for inp_type in inp_types:
                            if inp_type in ['masked', 'cells_128', 'cells_256']:
                                for I in np.random.choice(range(8), TTA[inp_type], replace=False):
                                    this_pred = torch.cat([
                                        m(get_trans(images[inp_type][b:b+little_bs], I)).sigmoid()
                                        for b in range(0, images[inp_type].shape[0], little_bs)
                                    ])
                                    preds['cells'].append(this_pred)

                    # daishu models on centred single cells. Index [1] of the
                    # model output is what gets passed through the sigmoid.
                    for m in rgby_models:
                        for I in np.random.choice(range(8), TTA['center_cells'], replace=False):
                            this_pred = torch.cat([
                                m(get_trans(images['center_cells'][b:b+little_bs], I, 'rgby'))[1].sigmoid()
                                for b in range(0, images['center_cells'].shape[0], little_bs)
                            ])
                            preds['cells'].append(this_pred)

                    # Average over every model / augmentation pass.
                    for k in preds.keys():
                        if len(preds[k]) > 0:
                            preds[k] = torch.stack(preds[k], 0).mean(0)
                        else:
                            preds[k] = 0

                    pred_final = preds['cells']
                    if not torch.is_tensor(pred_final):
                        raise ValueError('no cell-level predictions were produced')

                    # Validate before touching any accumulator: a count
                    # mismatch would otherwise shift every later cell's mask
                    # relative to its prediction.
                    cell_encs = enc[0].split(' ')
                    n_cells = images['cells_128'].shape[0]
                    if not (pred_final.shape[0] == n_cells == len(cell_encs)):
                        raise ValueError(
                            f'cell count mismatch: {pred_final.shape[0]} predictions, '
                            f'{n_cells} cell crops, {len(cell_encs)} encodings'
                        )

                    PRED_FINAL.append(pred_final.cpu())
                    IDs += [ID[0]] * n_cells
                    encs += cell_encs

        except Exception as e:
            # Best effort: a failing image is reported and skipped. Catching
            # Exception (not everything) lets KeyboardInterrupt stop the run.
            print('error', ID[0], repr(e))

if not PRED_FINAL:
    # torch.cat would fail on an empty list with a far less helpful message.
    raise RuntimeError('no predictions were produced for any image; see the errors printed above')
PRED_FINAL = torch.cat(PRED_FINAL).float()

In [None]:
# Sanity check of the stacked predictions: shape plus max / min / mean.
# The values are averaged sigmoid outputs, so they should lie within [0, 1].
print(PRED_FINAL.shape, PRED_FINAL.max(), PRED_FINAL.min(), PRED_FINAL.mean())

In [None]:
# One string per cell: "<class id> <probability> <encoded mask>" repeated for
# every class and joined by single spaces.
# Converting the tensor to nested Python floats once avoids creating and
# formatting a 0-d tensor for every (cell, class) pair; the formatted text is
# identical because a 0-d tensor formats via its Python scalar value.
prob_rows = PRED_FINAL.tolist()
PredictionString = []
for i in tqdm(range(len(prob_rows))):
    enc = encs[i]
    PredictionString.append(
        ' '.join(f'{cid} {p:.5f} {enc}' for cid, p in enumerate(prob_rows[i]))
    )

In [None]:
# Collapse the per-cell rows into one row per image ID by space-joining that
# image's cell prediction strings.
df_pred = (
    pd.DataFrame({'ID': IDs, 'PredictionString': PredictionString})
    .groupby(['ID'])['PredictionString']
    .apply(lambda parts: ' '.join(parts))
    .reset_index()
)

In [None]:
# Attach the predictions to the submission frame. The left join keeps every
# image; images without predictions get an empty string instead of NaN.
df_sub = (
    df_sub[['ID', 'ImageWidth', 'ImageHeight']]
    .merge(df_pred, on='ID', how="left")
    .fillna('')
)
df_sub.to_csv('submission.csv', index=False)

In [None]:
# Display the (rows, columns) of the written submission frame.
df_sub.shape

In [None]:
!rm -rf {mask_dir}