In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# (The default Kaggle snippet that listed every file under the input directory was removed; `os` is kept for path handling below.)

import os


# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Sanity check: show the contents of the attached HPA-Cell-Segmentation source dataset.
!ls /kaggle/input/hpacellsegmentatorraman/HPA-Cell-Segmentation

In [None]:
# Install dependencies from attached Kaggle datasets (local wheels / source dirs, no index lookup).
# NOTE(review): the pycocotools wheel is tagged cp37 — confirm the kernel runs CPython 3.7.
!pip install  "/kaggle/input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"
# !pip install -q "/kaggle/input/hpapytorchzoozip/pytorch_zoo-master"
#!pip install -q "/kaggle/input/hpacellsegmentatormaster/HPA-Cell-Segmentation-master"
!pip install "/kaggle/input/efficientnetpytorch"

In [None]:
# Make the bundled source trees importable without installing them.
import sys
sys.path.append("/kaggle/input/hpacellsegmentatorraman/HPA-Cell-Segmentation")
sys.path.append("/kaggle/input/hpapytorchzoozip/pytorch_zoo-master")

In [None]:
# Weights of the three cell-level labeling models ensembled by SegmentLabelPredictor.
labeling_model_path = "/kaggle/input/hpalabeliingmodel/3ch-no-mask-multiflag-mixup.pth"
keepres_labeling_model_path = "/kaggle/input/hpalabeliingmodel/3ch-keep-resolution-fine-tune-472.pth"
labeling_model_path_472 =  "/kaggle/input/hpalabeliingmodel/3ch-multiflag-mixup-fine-tune-472.pth"

# Competition test images and the sample submission that lists the test IDs.
test_path = "/kaggle/input/hpa-single-cell-image-classification/test/"
test_df_path=  "/kaggle/input/hpa-single-cell-image-classification/sample_submission.csv"
# Segmentation model weights loaded by CellPredictionModel.
nuclei_model_path = "/kaggle/input/hpa-cell-segmentation-weights/nuclei-model.pth"
cell_model_path = "/kaggle/input/hpa-cell-segmentation-weights/cell-model.pth"
# DataLoader settings: test images per batch and worker processes.
batch_size=8
num_workers=2


In [None]:

# Maps the colour suffix of each channel file (`<id>_<suffix>.png`) to the name of the
# CellSample field that stores that channel.
SUFFIX2TYPE = {
    "red": "microtubule",
    "yellow": "er",
    "blue": "nuclei",
    "green": "protein",
}

# Channel names in SUFFIX2TYPE order.
# NOTE(review): not referenced anywhere else in this notebook.
IMAGE_TYPES = ["microtubule", "er", "nuclei", "protein"]


from collections import namedtuple
from dataclasses import dataclass
from typing import Tuple

import numpy as np

# One raw test sample: its ID, the four channel images and the image size
# (CellDataset fills image_size from the number of rows of the protein channel).
CellSample = namedtuple("CellSample", ("id", "microtubule", "er", "nuclei", "protein", "image_size"))
# Preprocessed inputs of the nuclei and the cell segmentation model.
CellInput = namedtuple("CellInput", ("nuc_input", "cell_input"))
# Softmax outputs of the nuclei and the cell segmentation model.
CellOutput = namedtuple("CellOutput", ("nuc_output", "cell_output"))


@dataclass
class CellLabelInput:
    """Per-image input of CellLabelCalculator.

    ``image_size`` is the side length the predictions are resized to, and
    ``nuc_prediction`` / ``cell_prediction`` are the nuclei and cell model outputs
    for this image. InputPreprocessor.run replaces both predictions in place.
    """

    image_size: int
    nuc_prediction: np.ndarray
    cell_prediction: np.ndarray


@dataclass
class SegmentedImage:
    image: np.array
    sample_id: str
    segment_id: int
    image_size: Tuple[int, int] = None
    encoded_mask: str = None
    pred_score: np.array = None

    def __post_init__(self):
        if self.image_size is None:
            self.image_size = self.image.shape[-2:]

    def file_name(self, prefix):
        return f"{self.sample_id}_{self.segment_id}.{prefix}"

    @property
    def text_attributes(self):
        return {
            "sample_id": self.sample_id,
            "segment_id": self.segment_id,
            "image_size": self.image_size,
        }

    def format(self):
        return " ".join([f"{i} {score} {self.encoded_mask}" for i, score in enumerate(self.pred_score)])



In [None]:
import numpy as np
from skimage import transform
from skimage.util import img_as_float64

# Per-channel statistics for the segmentation models' inputs (pixel values in [0, 1]).
NORMALIZE = {
    "mean": np.array([124 / 255, 117 / 255, 104 / 255]),
    "std": np.array([1 / (0.0167 * 255)] * 3),
}

# Side length the segmentation inputs are rescaled to.
IMAGE_SIZE = 512


def normalize(image):
    """Standardise a channel-first image channel by channel using NORMALIZE."""
    channel_mean = NORMALIZE["mean"].reshape(-1, 1, 1)
    channel_std = NORMALIZE["std"].reshape(-1, 1, 1)
    centered = image - channel_mean
    return centered / channel_std


class SamplePreprocessor:
    """Builds the (nuclei_input, cell_input) pair that the segmentation models consume."""

    @staticmethod
    def run(sample: CellSample):
        return (
            NucleiPreprocessor.run(sample),
            CellPreprocessor.run(sample),
        )


class NucleiPreprocessor:
    """Turns the nuclei (blue) channel into the 3-channel input of the nuclei model.

    The channel is rescaled so its side becomes IMAGE_SIZE, replicated three
    times and normalised with ``normalize``.
    """

    @classmethod
    def run(cls, sample: CellSample):
        image = cls._scale_image(img_as_float64(sample.nuclei), sample.image_size)
        image = cls._stack_image(image)
        image = normalize(image)
        # float32 matches CellPreprocessor. The collate function converts the batch
        # to float32 anyway, so float64 only doubled the memory of every sample.
        return image.astype(np.float32)

    @staticmethod
    def _scale_image(image, image_size):
        """Rescale a square image of side ``image_size`` to side IMAGE_SIZE."""
        return transform.rescale(
            image,
            scale=IMAGE_SIZE / image_size,
            anti_aliasing=True,
        )

    @staticmethod
    def _stack_image(image):
        """Replicate a single channel three times along a new leading axis."""
        return np.stack([image, image, image])


class CellPreprocessor:
    """Stacks the microtubule, ER and nuclei channels into the cell model's float32 input."""

    @classmethod
    def run(cls, sample: CellSample):
        channels = [sample.microtubule, sample.er, sample.nuclei]
        as_float = [img_as_float64(channel) for channel in channels]
        scaled = cls._scale_images(as_float, sample.image_size)
        return normalize(np.stack(scaled)).astype(np.float32)

    @staticmethod
    def _scale_images(images, image_size):
        """Rescale every image from side ``image_size`` to side IMAGE_SIZE."""
        factor = IMAGE_SIZE / image_size
        scaled = []
        for image in images:
            scaled.append(transform.rescale(image, scale=factor, anti_aliasing=True))
        return scaled


In [None]:
import os
from logging import getLogger

import imageio
import numpy as np
from PIL import ImageFile
from torch.utils.data import Dataset

from skimage.util import img_as_ubyte

# Let PIL decode truncated image files instead of raising.
# NOTE(review): images are read through imageio below; presumably this flag takes
# effect via imageio's PIL backend — confirm.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Module logger; only referenced by the commented-out log calls in CellDataset.
logger = getLogger(__name__)


class CellDataset(Dataset):
    """Map-style dataset yielding ``(CellSample, CellInput)`` for each row of ``sample_df``.

    ``sample_df`` needs an ``ID`` column. The four channel PNGs are read from
    ``image_root_path`` as ``<ID>_<colour>.png`` (colours from SUFFIX2TYPE).
    """

    def __init__(self, sample_df, image_root_path):
        self.sample_df = sample_df
        self.image_root_path = image_root_path

    def __len__(self):
        return len(self.sample_df)

    def __getitem__(self, index):
        sample_id = self.sample_df.iloc[index].ID
        cell_sample = self._load_sample(sample_id)
        nuc_input, cell_input = SamplePreprocessor.run(cell_sample)
        return cell_sample, CellInput(nuc_input, cell_input)

    def _load_sample(self, sample_id):
        """Read all channels of one sample as uint8 arrays."""
        channels = {"id": sample_id}
        for colour, channel_name in SUFFIX2TYPE.items():
            path = os.path.join(self.image_root_path, f"{sample_id}_{colour}.png")
            channels[channel_name] = img_as_ubyte(np.array(imageio.imread(path)))
        return CellSample(**channels, image_size=channels["protein"].shape[0])

    
    
import torch


def my_collate(batch):
    """Collate ``(CellSample, CellInput)`` pairs into one batched sample and one batched input.

    Sample fields are gathered into per-field tuples, so images may have
    different sizes. Input fields are stacked into float32 tensors.
    """
    samples = [item[0] for item in batch]
    inputs = [item[1] for item in batch]
    return collate_by_tuple(samples), collate_by_tensor(inputs)


def collate_by_tuple(tuples):
    """Transpose a list of namedtuples into one namedtuple of per-field tuples."""
    clazz = type(tuples[0])
    return clazz(*[tuple(getattr(t, field) for t in tuples) for field in clazz._fields])


def collate_by_tensor(tuples):
    """Stack each field of a list of namedtuples into one float32 tensor.

    ``np.stack`` followed by ``torch.as_tensor`` builds the batch in one
    contiguous copy. ``torch.Tensor(list_of_ndarrays)`` converts element by
    element, which PyTorch warns is extremely slow. The values are the same,
    because both paths cast to float32.
    """
    clazz = type(tuples[0])
    return clazz(
        *[
            torch.as_tensor(np.stack([getattr(t, field) for t in tuples]), dtype=torch.float32)
            for field in clazz._fields
        ]
    )


In [None]:
import sys

import torch
import torch.nn.functional as F


class CellPredictionModel:
    """Runs the nuclei and the cell segmentation model on a batched CellInput."""

    def __init__(self, nuclei_model_path, cell_model_path, device):
        self.nuclei_model = SegmentationModel(nuclei_model_path, device)
        self.cell_model = SegmentationModel(cell_model_path, device)

    def __call__(self, input: CellInput):
        nuc_batch = torch.FloatTensor(input.nuc_input)
        cell_batch = torch.FloatTensor(input.cell_input)
        nuc_probs = self.nuclei_model(nuc_batch)
        cell_probs = self.cell_model(cell_batch)
        return CellOutput(nuc_probs, cell_probs)


class SegmentationModel:
    """Wraps a pickled segmentation network and returns per-pixel softmax probabilities."""

    def __init__(self, model_path, device):
        model = torch.load(model_path, map_location=torch.device(device))
        # Inference only: put BatchNorm/Dropout layers into evaluation mode. A pickled
        # module keeps whatever train/eval flag it was saved with, so do not rely on it.
        # The attribute keeps its historical name even though this wrapper is used
        # for both the nuclei and the cell model.
        self.nuclei_model = model.eval()
        self.device = device

    def __call__(self, batch):
        """Return ``softmax(model(batch), dim=1)`` as a NumPy array on the CPU."""
        with torch.no_grad():
            preds = self.nuclei_model(batch.to(self.device))
            preds = F.softmax(preds, dim=1)
            return preds.cpu().numpy()

class CellPredicitonDecomposer:
    """Splits a batched (CellSample, CellOutput) pair into per-image (CellSample, CellLabelInput) pairs."""

    @staticmethod
    def run(cell_sample, cell_output):
        label_inputs = [
            CellLabelInput(size, nuc_pred, cell_pred)
            for size, nuc_pred, cell_pred in zip(
                cell_sample.image_size,
                cell_output.nuc_output,
                cell_output.cell_output,
            )
        ]
        columns = [getattr(cell_sample, field) for field in cell_sample._fields]
        samples = [CellSample(*values) for values in zip(*columns)]
        return list(zip(samples, label_inputs))

In [None]:
from typing import List

import cv2
import numpy as np
from skimage import util

class InputPreprocessor:
    """Post-processes both predictions of a CellLabelInput in place."""

    @staticmethod
    def run(input: CellLabelInput):
        size = input.image_size
        input.cell_prediction = OutputPrprocessor.run(input.cell_prediction, size)
        input.nuc_prediction = OutputPrprocessor.run(input.nuc_prediction, size)
        # The same (mutated) object is returned so calls can be chained.
        return input


class OutputPrprocessor:
    @classmethod
    def run(cls, pred, image_size) -> List[np.array]:
        """
        Drop channel 0, quantise to uint8 and resize to (image_size, image_size).

        The result has 2 channels:
        * pred[0]: border
        * pred[1]: detected segments
        """
        trimmed = cls._remove_first_channel(pred)
        quantized = cls._format(trimmed)
        return cls._resize(quantized, image_size)

    @staticmethod
    def _resize(pred, image_size):
        # cv2 works channels-last; move the channel axis back to the front afterwards.
        channels_last = pred.transpose(1, 2, 0)
        resized = cv2.resize(channels_last, (image_size, image_size), interpolation=cv2.INTER_AREA)
        return resized.transpose(2, 0, 1)

    @staticmethod
    def _format(pred):
        """Quantise float probabilities to uint8 (0-255)."""
        return util.img_as_ubyte(pred)

    @staticmethod
    def _remove_first_channel(pred):
        # Channel 0 is not used downstream; keep the remaining channels.
        return pred[1:]

In [None]:
class LabelPredictor:
    """Turns post-processed predictions into a cell label image."""

    @staticmethod
    def run(input: CellLabelInput):
        nuclei_labels = NucleiLabelPredictor.run(input.nuc_prediction, input.cell_prediction)
        cell_labels = CellLabelPredictor.run(input.cell_prediction, nuclei_labels)
        return SmallLabelRemover.run(cell_labels)

import numpy as np
import scipy.ndimage as ndi
from skimage import segmentation
from skimage.morphology import remove_small_holes, remove_small_objects

# Thresholds on uint8-quantised probabilities (0-255 scale).
HIGH_THRESHOLD = 0.4 * 255
LOW_THRESHOLD = HIGH_THRESHOLD - 0.25 * 255


class NucleiLabelPredictor:
    """Labels individual nuclei with a marker-controlled watershed.

    Markers are nucleus pixels above LOW_THRESHOLD that lie away from nucleus/cell
    borders. The flooding mask is every nucleus pixel above HIGH_THRESHOLD.
    """

    @classmethod
    def run(cls, nuclei_pred, cell_pred):
        nuc_borders, nuc_segments = nuclei_pred
        cell_borders, _ = cell_pred

        markers = cls._build_markers(nuc_segments, nuc_borders, cell_borders)
        mask = cls._build_mask(nuc_segments)
        labels = segmentation.watershed(mask, markers, mask=mask, watershed_line=True)

        # Relabel connected components, then drop nuclei smaller than 2500 px.
        labels = ndi.label(labels)[0]
        labels = remove_small_objects(labels, 2500)

        return labels

    @classmethod
    def _build_markers(cls, nuc_segments, nuc_borders, cell_borders):
        borders = cls._build_border_for_marker(nuc_borders, cell_borders)
        marker_base = nuc_segments * borders

        marker_mask = np.zeros_like(nuc_segments).astype(bool)
        marker_mask[marker_base > LOW_THRESHOLD] = 1
        marker_mask = remove_small_objects(marker_mask, 500).astype(np.uint8)
        markers = ndi.label(marker_mask, output=np.uint32)[0]

        return markers

    @staticmethod
    def _build_border_for_marker(nuc_borders, cell_borders):
        """True where the combined nucleus + cell border score stays below 95%."""
        # Add in float64: uint8 + uint8 wraps around (150 + 150 -> 44), which turned
        # strong borders into "no border" pixels.
        combined = np.asarray(nuc_borders, dtype=np.float64) + cell_borders
        return 1 - combined / 255.0 > 0.05

    @staticmethod
    def _build_mask(nuc_segments):
        mask = nuc_segments > HIGH_THRESHOLD
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        mask = mask.astype(bool)
        mask = remove_small_holes(mask, 1000)
        mask = remove_small_objects(mask, 8).astype(np.uint8)
        return mask


import numpy as np
import scipy.ndimage as ndi
from skimage import filters, segmentation
from skimage.morphology import closing, disk, remove_small_objects


class CellLabelPredictor:
    """Grows nucleus labels into whole-cell labels with a watershed on the cell prediction."""

    @classmethod
    def run(cls, cell_pred, nuclei_label):
        _, cell_segments = cell_pred

        threshold = cls._calc_distance_thereshold(cell_segments)
        elevation_map = cls._calc_cell_elevation(cell_segments)

        cell_label = segmentation.watershed(
            elevation_map,
            markers=nuclei_label,
            mask=cls._calc_cell_mask(cell_segments, threshold),
            watershed_line=True,
        )

        # Keep a wide label dtype. Casting to uint8 wrapped labels modulo 256, so
        # cells labelled 256, 512, ... became background and were silently lost.
        cell_label = remove_small_objects(cell_label, 5500).astype(np.int32)
        cell_label = ndi.label(cell_label)[0]
        return cell_label

    @staticmethod
    def _calc_border_mask(cell_borders):
        # NOTE(review): not called anywhere in this notebook.
        return np.asarray(cell_borders / 255 > 0.05, dtype=np.int8)

    @staticmethod
    def _calc_distance_thereshold(cell_segments):
        """Mask threshold: 0.22 (on the 0-255 scale) or half of Otsu's threshold, whichever is higher."""
        return max(0.22 * 255, filters.threshold_otsu(cell_segments) * 0.5)

    @staticmethod
    def _calc_cell_elevation(cell_segments):
        """Elevation map for the watershed: the negated segment scores."""
        segments = np.asarray(cell_segments)
        # Negate in a signed dtype: ``-x`` on uint8 wraps around (e.g. -10 -> 246).
        if np.issubdtype(segments.dtype, np.unsignedinteger):
            segments = segments.astype(np.int64)
        return -segments

    @staticmethod
    def _calc_cell_mask(cell_segments,  threshold):
        cell_mask = cell_segments > threshold
        return cell_mask

    @staticmethod
    def _fix_broken_shapes(cell_label):
        # NOTE(review): not called anywhere in this notebook.
        return closing(cell_label, disk(6))

import numpy as np


class SmallLabelRemover:
    """Zeroes out labels whose pixel count is below 0.3 x the std of all label sizes."""

    @classmethod
    def run(cls, label):
        doomed = cls._get_remove_label_indices(label)
        keep = ~np.isin(label, doomed)
        return label * keep

    @staticmethod
    def _get_remove_label_indices(label):
        # Pixel count per label id, skipping id 0 (background).
        sizes = np.bincount(label.ravel())[1:]
        cutoff = sizes.std() * 0.3
        return np.flatnonzero(sizes < cutoff) + 1

    
class CellLabelCalculator:
    """Full post-processing: segmentation outputs of one image -> cell label image."""

    @staticmethod
    def run(input: CellLabelInput):
        prepared = InputPreprocessor.run(input)
        return LabelPredictor.run(prepared)

In [None]:
import base64
import zlib

import numpy as np
from pycocotools import _mask as coco_mask
from skimage import img_as_ubyte


class ImageSegmentator:
    """Cuts every labelled cell out of a sample and wraps it in a SegmentedImage."""

    @classmethod
    def run(cls, sample: CellSample, labels: np.array):
        stacked = cls._stack_image(sample)
        return [
            SegmentedImage(
                image=ImageCropper.run(stacked, labels, segment_id),
                segment_id=segment_id,
                sample_id=sample.id,
                encoded_mask=BinaryMaskEncoder.run(labels == segment_id),
            )
            for segment_id in cls._find_segment_ids(labels)
        ]

    @staticmethod
    def _find_segment_ids(nuclei_label):
        # Every distinct label value except the background (0).
        all_ids = set(np.unique(nuclei_label))
        return all_ids - {0}

    @staticmethod
    def _stack_image(sample):
        # Red, blue and green channels; the yellow (ER) channel is ignored.
        channel_names = ["microtubule", "nuclei", "protein"]
        return np.stack([getattr(sample, name) for name in channel_names])


class ImageCropper:
    """Crops a channel-first image to one label's bounding box, zeroing pixels outside the label."""

    @classmethod
    def run(cls, image: np.array, labels: np.array, segment_id: int):
        mask = cls._get_specified_label_mask(labels, segment_id)
        coordinates = cls._get_crop_coordinates(mask)

        cropped_mask = cls._crop_mask(mask, *coordinates)
        cropped_images = cls._crop_images(image, *coordinates)
        # Blank out pixels of neighbouring cells that fall inside the bounding box.
        cropped_images = cropped_images * cropped_mask
        cropped_images = img_as_ubyte(cropped_images)
        return cropped_images

    @staticmethod
    def _get_specified_label_mask(labels, segment_id):
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        return np.asarray(labels == segment_id, dtype=bool)

    @staticmethod
    def _get_crop_coordinates(mask):
        """Inclusive (top, bottom, left, right) bounds of the True pixels."""
        true_points = np.argwhere(mask)
        top, left = true_points.min(axis=0)
        bottom, right = true_points.max(axis=0)
        return top, bottom, left, right

    @staticmethod
    def _crop_images(images, top, bottom, left, right):
        return images[
            :,
            top : bottom + 1,
            left : right + 1,
        ]

    @staticmethod
    def _crop_mask(mask, top, bottom, left, right):
        return mask[
            top : bottom + 1,
            left : right + 1,
        ]


class BinaryMaskEncoder:
    """Encodes a 2-D binary mask as base64(zlib(COCO RLE counts)).

    This is the ``encoded_mask`` text that SegmentedImage.format places in the
    submission's PredictionString.
    """

    @staticmethod
    def run(mask: np.ndarray) -> str:
        mask = np.squeeze(mask)
        if mask.ndim != 2:
            # Wrap the shape tuple: ``"%s" % (2, 3, 4)`` treats it as three format
            # arguments and raises TypeError instead of this ValueError.
            raise ValueError("encode_binary_mask expects a 2d mask, received shape == %s" % (mask.shape,))

        # convert input mask to expected COCO API input --
        mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
        mask_to_encode = mask_to_encode.astype(np.uint8)
        mask_to_encode = np.asfortranarray(mask_to_encode)

        # RLE encode mask --
        encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

        # compress and base64 encoding --
        binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
        base64_str = base64.b64encode(binary_str)
        return base64_str.decode("utf-8")

In [None]:
import torch
from torch.nn import ZeroPad2d

# Per-channel mean / std used by the labeling models' Normalize transforms. The crops are
# cast from uint8 to float32 before normalising, so the values are on the 0-255 scale.
# NOTE(review): presumably computed on the labeling models' training crops — confirm.
MEAN = [20.503461399549288, 13.71165825908858, 13.297466514210495]
STD = [33.8227572424431, 36.53488215123118, 22.72992170966404]


class SymmetricPad(torch.nn.Module):
    """Zero-pads the last two dims of an image up to ``pad_size``, split evenly on both sides.

    Dimensions that are already at least as large as ``pad_size`` are left unchanged
    (follow with a crop to enforce an exact size).
    """

    def __init__(self, pad_size):
        # Module.__init__ must run first. Without it the instance lacks ``_modules``,
        # so ``.eval()``, ``.to()`` or ``repr()`` on an enclosing nn.Sequential
        # raised AttributeError.
        super().__init__()
        self.pad_size = pad_size

    def forward(self, img):
        pad_shape = self._get_padding(img)

        if pad_shape is None:
            return img
        return ZeroPad2d(pad_shape)(img)

    def _get_padding(self, image):
        # NOTE: ``width``/``height`` are really shape[-2]/shape[-1], and the returned
        # tuple is (t, b, l, r) while ZeroPad2d expects (left, right, top, bottom).
        # The two swaps cancel: shape[-2] is padded towards pad_size[0] and
        # shape[-1] towards pad_size[1].
        width, height = image.shape[-2:]
        horizontal_pad_length = max(0, self.pad_size[0] - width)
        vertical_pad_length = max(0, self.pad_size[1] - height)

        if horizontal_pad_length == 0 and vertical_pad_length == 0:
            return None

        # The odd pixel, if any, goes to the right/bottom side.
        l_pad = horizontal_pad_length // 2
        r_pad = horizontal_pad_length // 2 + int(horizontal_pad_length % 2 != 0)
        t_pad = vertical_pad_length // 2
        b_pad = vertical_pad_length // 2 + int(vertical_pad_length % 2 != 0)

        return (t_pad, b_pad, l_pad, r_pad)

class SquarePad(torch.nn.Module):
    """Zero-pads the last two dims of an image so they become equal (a square).

    The padding is split evenly, and the odd pixel goes to the bottom/right side.
    """

    def forward(self, img):
        return ZeroPad2d(self._get_padding(img))(img)

    @staticmethod
    def _get_padding(image):
        rows, cols = image.shape[-2:]
        side = max(rows, cols)
        row_pad = side - rows
        col_pad = side - cols
        # The old ``x % 1 == 0`` test was always true for ints, so odd differences
        # were under-padded by one pixel and the result was not square.
        top, bottom = row_pad // 2, row_pad - row_pad // 2
        left, right = col_pad // 2, col_pad - col_pad // 2
        # ZeroPad2d order: (left, right) for the last dim, then (top, bottom) for the one before.
        return (left, right, top, bottom)


In [None]:
import torch
from torch.nn import ZeroPad2d
from torchvision.transforms import (
    CenterCrop,
    Normalize,
    RandomCrop,
    RandomHorizontalFlip,
    RandomResizedCrop,
    RandomVerticalFlip,
    Resize,
)

SEG_IMG_SIZE_KEEP_RES = (472, 472)


class SegmentPreprocessorKeepRes:
    """Deterministic preprocessing for the keep-resolution labeling model.

    Crops are zero-padded (SymmetricPad) and center-cropped to
    SEG_IMG_SIZE_KEEP_RES without rescaling, then normalised with MEAN / STD.
    """

    resize = torch.nn.Sequential(
        SymmetricPad(SEG_IMG_SIZE_KEEP_RES),
        CenterCrop(SEG_IMG_SIZE_KEEP_RES),
    )
    trans = Normalize(
        mean=MEAN,
        std=STD,
    )

    @classmethod
    def run(cls, img, device=0):
        """Preprocess one channel-first uint8 crop on ``device`` (default 0: the first CUDA GPU)."""
        img = torch.Tensor(img).type(torch.uint8).to(device)
        img = cls.resize(img).type(torch.float32)
        img = cls.trans(img)
        return img

    @classmethod
    def batch_run(cls, imgs, device=0):
        """Preprocess every crop and stack them along a new batch axis."""
        return torch.stack([cls.run(img, device=device) for img in imgs])


class AugmentedSegmentPreprocessorKeepRes:
    """Random TTA preprocessing for the keep-resolution labeling model.

    Crops are zero-padded to SEG_IMG_SIZE_KEEP_RES, randomly cropped to that size,
    randomly flipped horizontally/vertically, and normalised with MEAN / STD.
    """

    resize = torch.nn.Sequential(
        SymmetricPad(SEG_IMG_SIZE_KEEP_RES),
        RandomCrop(SEG_IMG_SIZE_KEEP_RES),
    )
    trans = torch.nn.Sequential(
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        Normalize(
            mean=MEAN,
            std=STD,
        ),
    )

    @classmethod
    def run(cls, img, device=0):
        """Preprocess one channel-first uint8 crop on ``device`` (default 0: the first CUDA GPU)."""
        img = torch.Tensor(img).type(torch.uint8).to(device)
        img = cls.resize(img).type(torch.float32)
        img = cls.trans(img)
        return img

    @classmethod
    def batch_run(cls, imgs, device=0):
        """Preprocess every crop and stack them along a new batch axis."""
        return torch.stack([cls.run(img, device=device) for img in imgs])


In [None]:
import torch
from torch.nn import ZeroPad2d
from torchvision.transforms import (
    CenterCrop,
    Normalize,
    RandomCrop,
    RandomHorizontalFlip,
    RandomResizedCrop,
    RandomVerticalFlip,
    Resize,
)


class SegmentPreprocessor472:
    """Deterministic preprocessing for the 472 labeling model.

    Crops are zero-padded to a square (SquarePad), resized to
    SEG_IMG_SIZE_KEEP_RES and normalised with MEAN / STD.
    """

    resize = torch.nn.Sequential(
        SquarePad(),
        Resize(SEG_IMG_SIZE_KEEP_RES),
    )
    trans = Normalize(
        mean=MEAN,
        std=STD,
    )

    @classmethod
    def run(cls, img, device=0):
        """Preprocess one channel-first uint8 crop on ``device`` (default 0: the first CUDA GPU)."""
        img = torch.Tensor(img).type(torch.uint8).to(device)
        img = cls.resize(img).type(torch.float32)
        img = cls.trans(img)
        return img

    @classmethod
    def batch_run(cls, imgs, device=0):
        """Preprocess every crop and stack them along a new batch axis."""
        return torch.stack([cls.run(img, device=device) for img in imgs])


class AugmentedSegmentPreprocessor472:
    """Random TTA preprocessing for the 472 labeling model.

    NOTE(review): unlike SegmentPreprocessor472 there is no SquarePad here, so
    non-square crops are stretched by Resize. The RandomCrop at the size Resize
    already produced is a no-op. Confirm this matches the training augmentation.
    """

    resize = torch.nn.Sequential(
        Resize(SEG_IMG_SIZE_KEEP_RES),
        RandomCrop(SEG_IMG_SIZE_KEEP_RES),
    )
    trans = torch.nn.Sequential(
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        Normalize(
            mean=MEAN,
            std=STD,
        ),
    )

    @classmethod
    def run(cls, img, device=0):
        """Preprocess one channel-first uint8 crop on ``device`` (default 0: the first CUDA GPU)."""
        img = torch.Tensor(img).type(torch.uint8).to(device)
        img = cls.resize(img).type(torch.float32)
        img = cls.trans(img)
        return img

    @classmethod
    def batch_run(cls, imgs, device=0):
        """Preprocess every crop and stack them along a new batch axis."""
        return torch.stack([cls.run(img, device=device) for img in imgs])


In [None]:
import torch
from torchvision.transforms import Normalize, RandomHorizontalFlip, RandomResizedCrop, RandomVerticalFlip, Resize

SEG_IMG_SIZE = (300, 300)


class SegmentPreprocessor:
    """Deterministic preprocessing for the base labeling model.

    Crops are zero-padded to a square (SquarePad), resized to SEG_IMG_SIZE and
    normalised with MEAN / STD.
    """

    resize = torch.nn.Sequential(
        SquarePad(),
        Resize(SEG_IMG_SIZE),
    )
    trans = Normalize(
        mean=MEAN,
        std=STD,
    )

    @classmethod
    def run(cls, img, device=0):
        """Preprocess one channel-first uint8 crop on ``device`` (default 0: the first CUDA GPU)."""
        img = torch.Tensor(img).type(torch.uint8).to(device)
        img = cls.resize(img).type(torch.float32)
        img = cls.trans(img)
        return img

    @classmethod
    def batch_run(cls, imgs, device=0):
        """Preprocess every crop and stack them along a new batch axis."""
        return torch.stack([cls.run(img, device=device) for img in imgs])


class AugmentedSegmentPreprocessor:
    """Random TTA preprocessing for the base labeling model.

    Each crop is randomly resized-cropped to SEG_IMG_SIZE (keeping 50-100% of
    its area), randomly flipped horizontally/vertically, and normalised with MEAN / STD.
    """

    resize = RandomResizedCrop(SEG_IMG_SIZE, scale=(0.5, 1.0))
    trans = torch.nn.Sequential(
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        Normalize(
            mean=MEAN,
            std=STD,
        ),
    )

    @classmethod
    def run(cls, img, device=0):
        """Preprocess one channel-first uint8 crop on ``device`` (default 0: the first CUDA GPU)."""
        img = torch.Tensor(img).type(torch.uint8).to(device)
        img = cls.resize(img).type(torch.float32)
        img = cls.trans(img)
        return img

    @classmethod
    def batch_run(cls, imgs, device=0):
        """Preprocess every crop and stack them along a new batch axis."""
        return torch.stack([cls.run(img, device=device) for img in imgs])


In [None]:
from itertools import chain

import numpy as np
import torch


class SegmentLabelPredictor:
    """Scores segmented cells with an ensemble of three labeling models plus test-time augmentation.

    For each model, every crop is scored ``n_augments`` times with random
    augmentation and once with deterministic preprocessing. The sigmoid outputs
    are averaged per model, and the three averages are blended with ``weights``
    (order: ``model``, ``keepres_model``, ``model_472``).
    """

    def __init__(self, model, keepres_model, model_472, weights=(0.575, 0.2, 0.225)):
        self.model = model
        self.keepres_model = keepres_model
        self.model_472 = model_472
        self.weights = weights

    def run(self, segments, batch_size, n_augments=4):
        """Set ``pred_score`` on every segment (in place) and return ``segments``."""
        scores = self._predict(segments, batch_size, n_augments)

        for segment, score in zip(segments, scores):
            segment.pred_score = score

        return segments

    def _predict(self, segments, batch_size, n_augments):
        """Return one blended score vector per segment, in input order."""
        # An image with no detected cell has nothing to score. Previously
        # np.array_split produced one empty chunk and torch.stack([]) raised,
        # which aborted the whole submission loop.
        if len(segments) == 0:
            return []

        w_model, w_keepres, w_472 = self.weights
        scores = []
        for chunk in np.array_split(segments, len(segments) // batch_size + 1):
            # array_split yields empty chunks when there are more chunks than
            # segments (e.g. batch_size=1); skip them instead of crashing.
            if len(chunk) == 0:
                continue
            model_score = self._predict_for_each_chunk(chunk, n_augments)
            keepres_score = self._predict_for_each_chunk_keepres(chunk, n_augments)
            score_472 = self._predict_for_each_chunk_472(chunk, n_augments)
            scores.append(model_score * w_model + keepres_score * w_keepres + score_472 * w_472)

        return list(chain.from_iterable(scores))

    @staticmethod
    def _predict_with_tta(chunk, n_augments, predict_fn, augmented_preprocessor, preprocessor):
        """Average ``predict_fn`` over ``n_augments`` augmented views plus one plain view."""
        imgs = [segment.image for segment in chunk]
        # All views are built before any prediction, so random augmentations are drawn in a fixed order.
        views = [augmented_preprocessor.batch_run(imgs) for _ in range(n_augments)]
        views.append(preprocessor.batch_run(imgs))
        return np.mean([predict_fn(view) for view in views], axis=0)

    @staticmethod
    def _sigmoid_scores(model, imgs):
        """Run ``model`` without autograd and return sigmoid probabilities as a NumPy array."""
        with torch.no_grad():
            scores = torch.sigmoid(model(imgs))
        return scores.cpu().numpy()

    def _predict_for_each_chunk_keepres(self, chunk, n_augments):
        return self._predict_with_tta(
            chunk, n_augments, self._predict_core_keepres,
            AugmentedSegmentPreprocessorKeepRes, SegmentPreprocessorKeepRes,
        )

    def _predict_core_keepres(self, imgs):
        return self._sigmoid_scores(self.keepres_model, imgs)

    def _predict_for_each_chunk(self, chunk, n_augments):
        return self._predict_with_tta(
            chunk, n_augments, self._predict_core,
            AugmentedSegmentPreprocessor, SegmentPreprocessor,
        )

    def _predict_core(self, imgs):
        return self._sigmoid_scores(self.model, imgs)

    def _predict_for_each_chunk_472(self, chunk, n_augments):
        return self._predict_with_tta(
            chunk, n_augments, self._predict_472,
            AugmentedSegmentPreprocessor472, SegmentPreprocessor472,
        )

    def _predict_472(self, imgs):
        return self._sigmoid_scores(self.model_472, imgs)


In [None]:
import torch
from efficientnet_pytorch.model import MemoryEfficientSwish
from torch.nn import AdaptiveMaxPool2d, BatchNorm1d, Dropout, Linear
from torch.nn.modules import Module


class CustomEfficientNet(Module):
    """EfficientNet backbone with a frozen feature extractor and an avg+max pooled MLP head.

    NOTE(review): this class is never instantiated in this notebook. Presumably it
    must exist so that torch.load can unpickle the saved labeling models in
    load_labeling_model — confirm before renaming it or changing its attributes.
    """

    def __init__(self, model_path, device):
        super().__init__()
        self.device = device
        # Whole pickled EfficientNet module (not a state dict).
        self.model = torch.load(model_path).to(device).train()

        self._max_pool = AdaptiveMaxPool2d(output_size=1)

        # Input is the concatenation of average- and max-pooled features (hence * 2).
        # NOTE(review): 1536 is presumably the backbone's head width — confirm.
        # The 19 outputs are one logit per class.
        self._output = torch.nn.Sequential(
            BatchNorm1d(num_features=1536 * 2),
            MemoryEfficientSwish(),
            Linear(in_features=1536 * 2, out_features=500, bias=True),
            BatchNorm1d(num_features=500),
            MemoryEfficientSwish(),
            Dropout(p=0.3),
            Linear(in_features=500, out_features=19),
        )

    def forward(self, inp):
        # The backbone up to _conv_head runs without autograd, so it is not trained.
        with torch.no_grad():
            x = self._extract_fixed_features(inp)
    
        x = self.model._swish(self.model._bn1(x))
        x1 = self.model._avg_pooling(x).flatten(start_dim=1)
        x2 = self._max_pool(x).flatten(start_dim=1)
        x = torch.cat([x1, x2], axis=1)
        x = self._output(x)

        return x

    def _extract_fixed_features(self, inputs):
        """Stem + MBConv blocks + _conv_head of the backbone (stops before _bn1)."""
        model = self.model
        x = model._swish(model._bn0(model._conv_stem(inputs)))
        # Blocks
        for idx, block in enumerate(model._blocks):
            drop_connect_rate = model._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(model._blocks)  # scale drop connect_rate
            x = block(x, drop_connect_rate=drop_connect_rate)

        x = model._conv_head(x)

        return x

    def finetune_params(self):
        """Parameters of the head and of the backbone's _bn1 layer."""
        return list(self._output.parameters()) + list(self.model._bn1.parameters())

    def base_params(self):
        """Parameters of the backbone's stem, _bn0 and blocks.

        NOTE(review): these run under torch.no_grad in forward and _conv_head is not
        included — confirm how this group is used during training.
        """
        return (
            list(self.model._blocks.parameters())
            + list(self.model._conv_stem.parameters())
            + list(self.model._bn0.parameters())
        )


In [None]:
from efficientnet_pytorch import EfficientNet


def load_labeling_model(labeling_model_path, device):
    """Load a pickled labeling model, switch it to eval mode and move it to ``device``.

    ``map_location`` places the weights directly on ``device``. A checkpoint saved
    on another GPU, or a CPU-only run with ``device="cpu"``, therefore loads
    without error.
    NOTE(review): the file holds a whole pickled module rather than a state dict,
    so the class it was saved from presumably has to be defined when this runs.
    """
    model = torch.load(labeling_model_path, map_location=torch.device(device))
    return model.eval().to(device)


In [None]:
def format_sample(cell_sample, segments):
    """Build one submission CSV row: ``ID,ImageWidth,ImageHeight,PredictionString`` plus a newline.

    Width is the number of columns (``shape[1]``) and height the number of rows
    (``shape[0]``) of the protein channel, matching the header written by the
    inference loops. The previous version wrote them the other way round, which
    is only harmless for square images.
    """
    height, width = cell_sample.protein.shape[:2]
    prediction = " ".join(segment.format() for segment in segments)
    return ",".join([str(cell_sample.id), str(width), str(height), prediction]) + "\n"

In [None]:
# Nuclei + cell segmentation models on CUDA device 0.
segment_model = CellPredictionModel(nuclei_model_path, cell_model_path, device=0)


In [None]:
# The three labeling models ensembled per cell.
labeling_model_keepres = load_labeling_model(keepres_labeling_model_path, device=0)
labeling_model = load_labeling_model(labeling_model_path, device=0)
labeling_model_472 = load_labeling_model(labeling_model_path_472, device=0)


In [None]:
# Argument order must match SegmentLabelPredictor(model, keepres_model, model_472).
individual_label_predictor = SegmentLabelPredictor(labeling_model, labeling_model_keepres, labeling_model_472)

In [None]:
# One dataset item per test ID in sample_submission order.
test_df = pd.read_csv(test_df_path)
dataset = CellDataset(test_df, test_path)

from torch.utils.data import DataLoader
# No shuffle argument, so DataLoader keeps dataset order. my_collate keeps raw samples
# as per-field tuples and stacks the model inputs into tensors.
data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            collate_fn=my_collate,
        )

In [None]:
from tqdm import tqdm

label_path = "/kaggle/input/calc-label-for-publiclb/labels"


def _write_base_submission(iter_labelled_samples):
    """Write submission_base.csv: for every (sample, label image) pair, crop cells, score them, write a row.

    ``iter_labelled_samples(batch)`` yields ``(CellSample, labels)`` pairs for one DataLoader batch.
    """
    with open("/kaggle/working/submission_base.csv", "w") as file:
        file.write(",".join(["ID","ImageWidth", "ImageHeight", "PredictionString"]) + "\n")
        for batch in tqdm(data_loader):
            for cell_sample, labels in iter_labelled_samples(batch):
                segments = ImageSegmentator.run(cell_sample, labels)
                segments = individual_label_predictor.run(segments, batch_size=16, n_augments=4)
                file.write(format_sample(cell_sample, segments))


def _precomputed_labels(batch):
    """Yield each sample of the batch with its label image loaded from ``label_path``."""
    cell_samples, _ = batch
    columns = [getattr(cell_samples, field) for field in cell_samples._fields]
    for values in zip(*columns):
        cell_sample = CellSample(*values)
        yield cell_sample, np.load(os.path.join(label_path, f"{cell_sample.id}.npy"))


def _predicted_labels(batch):
    """Run the segmentation models on the batch and yield each sample with its computed label image."""
    cell_samples, cell_inputs = batch
    cell_outputs = segment_model(cell_inputs)
    for cell_sample, label_input in CellPredicitonDecomposer.run(cell_samples, cell_outputs):
        yield cell_sample, CellLabelCalculator.run(label_input)


def inference_with_precomputed_labels():
    _write_base_submission(_precomputed_labels)


def inference():
    _write_base_submission(_predicted_labels)


In [None]:
import random
import numpy
import torch


def set_random_seed(seed):
    """Seed the PyTorch, Python and NumPy global RNGs (e.g. for the random TTA transforms)."""
    for seeder in (torch.manual_seed, random.seed, numpy.random.seed):
        seeder(seed)


set_random_seed(1)

In [None]:
# 559 rows presumably means the public test set, whose segmentation labels were
# precomputed offline (label_path). Any other size segments the images from scratch.
if len(test_df) == 559:
    inference_with_precomputed_labels()
else:
    inference()

In [None]:
# Drop every PyTorch model and the data pipeline before TensorFlow loads its model.
del segment_model
del dataset
del data_loader
del individual_label_predictor
del labeling_model
del labeling_model_keepres
# Previously forgotten: this global kept the 472 model (and its GPU memory) alive.
del labeling_model_472

In [None]:
import gc
# Collect the objects released by the deletions above.
gc.collect()
# Return PyTorch's cached CUDA blocks so other GPU users (e.g. TensorFlow below) can use them.
torch.cuda.empty_cache()

## Calculate additional whole-image scores

In [None]:
def build_decoder(with_labels=True, target_size=(300, 300), ext='jpg'):
    """Return a tf.data map function that reads an image file into float32 [0, 1], resized to ``target_size``.

    With ``with_labels`` the function maps ``(path, label) -> (image, label)``,
    otherwise ``path -> image``. ``ext`` selects the PNG or JPEG decoder.
    """

    def decode(path):
        raw = tf.io.read_file(path)
        if ext == 'png':
            decoded = tf.image.decode_png(raw, channels=3)
        elif ext in ['jpg', 'jpeg']:
            decoded = tf.image.decode_jpeg(raw, channels=3)
        else:
            raise ValueError("Image extension not supported")
        scaled = tf.cast(decoded, tf.float32) / 255.0
        return tf.image.resize(scaled, target_size)

    def decode_with_labels(path, label):
        return decode(path), label

    if with_labels:
        return decode_with_labels
    return decode

def build_augmenter(with_labels=True):
    """Return a tf.data map function applying random left-right and up-down flips.

    With ``with_labels`` it maps ``(image, label) -> (flipped, label)``.
    """

    def augment(img):
        return tf.image.random_flip_up_down(tf.image.random_flip_left_right(img))

    def augment_with_labels(img, label):
        return augment(img), label

    if with_labels:
        return augment_with_labels
    return augment


def build_dataset(paths, labels=None, bsize=32, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, repeat=True, shuffle=1024, 
                  cache_dir=""):
    """Build a batched, prefetched tf.data pipeline over ``paths`` (and optional ``labels``).

    Stage order: decode -> cache (optional) -> augment (optional) -> repeat (optional)
    -> shuffle with buffer ``shuffle`` (optional) -> batch(``bsize``) -> prefetch.
    """
    has_labels = labels is not None
    if cache is True and cache_dir != "":
        os.makedirs(cache_dir, exist_ok=True)
    if decode_fn is None:
        decode_fn = build_decoder(has_labels)
    if augment_fn is None:
        augment_fn = build_augmenter(has_labels)

    autotune = tf.data.experimental.AUTOTUNE
    source = (paths, labels) if has_labels else paths

    dset = tf.data.Dataset.from_tensor_slices(source).map(decode_fn, num_parallel_calls=autotune)
    if cache:
        dset = dset.cache(cache_dir)
    if augment:
        dset = dset.map(augment_fn, num_parallel_calls=autotune)
    if repeat:
        dset = dset.repeat()
    if shuffle:
        dset = dset.shuffle(shuffle)
    return dset.batch(bsize).prefetch(autotune)

In [None]:
# Make the bundled Keras EfficientNet sources and keras_applications importable.
sys.path.append("/kaggle/input/efficientnet-keras-source-code/")
sys.path.append("/kaggle/input/kerasapplications")

In [None]:
import tensorflow as tf
from efficientnet.tfkeras import EfficientNetB4

# Whole-image classifier scored on the green (protein) channel.
# NOTE(review): EfficientNetB4 is unused. Importing efficientnet.tfkeras is presumably
# what registers the custom objects load_model needs — confirm before removing it.
whole_image_model = tf.keras.models.load_model('../input/hpa-models/HPA classification efnb7 train 13cc0d 20/model_green.h5')

In [None]:
# One row per test ID (first column of the sample submission) plus 19 zero-filled
# class-score columns named "0" .. "18".
additional_score_df = pd.read_csv(test_df_path).iloc[:, :1].copy()
for class_id in range(19):
    additional_score_df[str(class_id)] = np.zeros(len(additional_score_df))

In [None]:
load_dir = "/kaggle/input/hpa-single-cell-image-classification/"

# Green (protein) channel file of every test image, in additional_score_df order.
test_paths = load_dir + "/test/" + additional_score_df['ID'] + '_' + 'green' + '.png'
# The 19 class-score columns.
label_cols = additional_score_df.columns[1:]
# NOTE(review): ext keeps its 'jpg' default, so these PNG files are decoded with
# tf.image.decode_jpeg — confirm this matches how the model was trained.
test_decoder = build_decoder(with_labels=False, target_size=(720,720))
dtest = build_dataset(
    test_paths,
    bsize=8,
    repeat=False,
    shuffle=False,
    augment=False,
    cache=False,
    decode_fn=test_decoder,
)

In [None]:
# One row of 19 class scores per test image. dtest is neither shuffled nor repeated,
# so the rows line up with additional_score_df.
additional_score_df[label_cols] = whole_image_model.predict(dtest, verbose=1)


In [None]:
# Attach the 19 whole-image scores to every row of the cell-level submission.
base_df = pd.read_csv("submission_base.csv").merge(additional_score_df, on='ID', how='left')

In [None]:
def blend_prediction_string(pred, class_scores, n_classes=19):
    """Blend each cell-level score in a PredictionString with the whole-image score of its class.

    ``pred`` is a space-separated sequence of ``<class> <score> <mask>`` triples.
    Each score becomes ``0.6 * class_scores[class] + 0.4 * score``. Non-string
    values are returned unchanged: an image without detected cells gets an empty
    PredictionString, which ``pd.read_csv`` reads back as NaN and which
    previously crashed on ``.split()``.
    """
    if not isinstance(pred, str):
        return pred
    tokens = pred.split()
    for j in range(len(tokens) // 3):
        class_id = int(tokens[3 * j])
        # Only classes 0 .. n_classes-1 have a whole-image score.
        if 0 <= class_id < n_classes:
            tokens[3 * j + 1] = str(class_scores[class_id] * 0.6 + float(tokens[3 * j + 1]) * 0.4)
    return ' '.join(tokens)


for row in range(base_df.shape[0]):
    row_scores = [base_df.loc[row, f'{k}'] for k in range(19)]
    base_df.loc[row, 'PredictionString'] = blend_prediction_string(
        base_df.loc[row, 'PredictionString'], row_scores
    )

In [None]:
# Keep only the competition's submission columns (drops the 19 helper score columns).
base_df = base_df[['ID','ImageWidth','ImageHeight','PredictionString']]
base_df.to_csv('submission.csv',index = False)