In [None]:
import torch
import torchvision
from torch import nn, Tensor
from torch.nn import functional as F
from torch.jit.annotations import List, Tuple, Dict, Optional


import re
import pickle
import math
import numpy as np

from torchvision import utils
from PIL import Image

import os
from os.path import join
import copy

import zlib
import base64
import albumentations as A

import random
import math

from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.roi_heads import paste_masks_in_image

In [None]:
!pip install "/kaggle/input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"

import pycocotools
from pycocotools import mask as cocomask

In [None]:
# Number of classes the detection heads are built for.
num_classes = 20

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

In [None]:
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision import models

def create_backbone(num_classes):
    """Build a ResNet-50 FPN backbone whose first convolution takes 4 input channels.

    The backbone is created with ``resnet_fpn_backbone('resnet50', False)``. Its
    3-input-channel ``conv1`` is replaced by a 4-input-channel convolution: the
    first three input channels receive the original ``conv1`` weights and the
    fourth input channel is a copy of the first one.

    Args:
        num_classes: Not used by this function; kept in the signature for callers.

    Returns:
        The backbone module with ``body.conv1`` swapped for the 4-channel layer.
    """
    backbone_wide = resnet_fpn_backbone('resnet50', False)
    body_modules = backbone_wide._modules['body']._modules

    # Snapshot the 3-channel stem weights before the layer is replaced.
    three_channel_weight = body_modules['conv1'].weight.clone()

    wide_stem = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
    body_modules['conv1'] = wide_stem

    with torch.no_grad():
        wide_stem.weight[:, :3] = three_channel_weight
        # The extra (4th) input channel starts out as a duplicate of channel 0.
        wide_stem.weight[:, 3] = wide_stem.weight[:, 0]
    return backbone_wide

In [None]:
def get_model_instance_segmentation(num_classes):
    """Create a Mask R-CNN (ResNet-50 FPN) with new box and mask predictors.

    The model is built with ``pretrained=False`` and ``pretrained_backbone=False``,
    then both ROI-head predictors are replaced by fresh ``FastRCNNPredictor`` /
    ``MaskRCNNPredictor`` instances sized for ``num_classes``.

    Args:
        num_classes: Number of classes passed to the model and to both predictors.

    Returns:
        The constructed Mask R-CNN model.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        pretrained=False,
        progress=False,
        num_classes=num_classes,
        pretrained_backbone=False,
    )
    roi_heads = model.roi_heads

    # Box head: keep the input width of the existing classifier layer.
    box_in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Mask head: keep the input channel count of the existing mask predictor.
    mask_in_channels = roi_heads.mask_predictor.conv5_mask.in_channels
    mask_hidden_channels = 256
    roi_heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, mask_hidden_channels, num_classes)
    return model

In [None]:
def _resize_image_and_masks(image, self_min_size, self_max_size, target):
    # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
    im_shape = torch.tensor(image.shape[-2:])
    min_size = float(torch.min(im_shape))
    max_size = float(torch.max(im_shape))
    scale_factor = self_min_size / min_size
    if max_size * scale_factor > self_max_size:
        scale_factor = self_max_size / max_size
        
    image = torch.nn.functional.interpolate(
        image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
        align_corners=False)[0]

    if target is None:
        return image, target

    if "masks" in target:
        mask = target["masks"]
        
        mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte()
        target["masks"] = mask
    return image, target

class GeneralizedRCNNTransform_(nn.Module):
    """
    Performs input / target transformation before feeding the data to a GeneralizedRCNN
    model.

    The transformations it performs are:
        - input normalization (mean subtraction and std division)
        - input / target resizing to match min_size / max_size

    It returns an ImageList for the inputs, and a List[Dict[Tensor]] for the targets.

    ``image_mean`` and ``image_std`` are broadcast over the channel dimension, so they
    need one entry per image channel.
    """

    def __init__(self, min_size, max_size, image_mean, image_std):
        """
        Args:
            min_size: Target size(s) for the shorter image side; a scalar is wrapped
                into a 1-tuple.
            max_size: Upper bound for the longer image side after resizing.
            image_mean: Per-channel means subtracted during normalization.
            image_std: Per-channel standard deviations used as divisors.
        """
        super(GeneralizedRCNNTransform_, self).__init__()
        if not isinstance(min_size, (list, tuple)):
            min_size = (min_size,)
        self.min_size = min_size
        self.max_size = max_size
        self.image_mean = image_mean
        self.image_std = image_std

    def forward(self,
                images,       # type: List[Tensor]
                targets=None  # type: Optional[List[Dict[str, Tensor]]]
                ):
        # type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
        """Normalize, resize and batch ``images`` (and resize ``targets`` to match).

        The caller's ``images`` list and ``targets`` dicts are not modified: the list
        is copied and each target dict is shallow-copied first. No check of the
        image rank is performed.

        Returns:
            ``(image_list, targets)`` where ``image_list`` holds the padded batch
            together with each image's size after resizing (before padding).
        """
        images = [img for img in images]
        if targets is not None:
            # make a copy of targets to avoid modifying it in-place
            # once torchscript supports dict comprehension
            # this can be simplified as follows
            # targets = [{k: v for k,v in t.items()} for t in targets]
            targets_copy: List[Dict[str, Tensor]] = []
            for t in targets:
                data: Dict[str, Tensor] = {}
                for k, v in t.items():
                    data[k] = v
                targets_copy.append(data)
            targets = targets_copy
        for i in range(len(images)):
            image = images[i]
            target_index = targets[i] if targets is not None else None

            image = self.normalize(image)
            image, target_index = self.resize(image, target_index)

            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        # Record the per-image sizes before padding to a common batch shape.
        image_sizes = [img.shape[-2:] for img in images]
        images = self.batch_images(images)
        image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], [])
        for image_size in image_sizes:
            assert len(image_size) == 2
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)

        return image_list, targets

    def normalize(self, image):
        """Return ``(image - mean) / std`` with mean/std broadcast over H and W."""
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        return (image - mean[:, None, None]) / std[:, None, None]

    def torch_choice(self, k):
        # type: (List[int]) -> int
        """
        Implements `random.choice` via torch ops so it can be compiled with
        TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
        is fixed.
        """
        index = int(torch.empty(1).uniform_(0., float(len(k))).item())
        return k[index]

    def resize(self, image, target):
        # type: (Tensor, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
        """Resize ``image`` and rescale the boxes / masks / keypoints in ``target``.

        In training mode the target size is drawn at random from ``self.min_size``;
        otherwise the last entry of ``self.min_size`` is used.
        """
        h, w = image.shape[-2:]
        if self.training:
            size = float(self.torch_choice(self.min_size))
        else:
            # FIXME assume for now that testing uses the largest scale
            size = float(self.min_size[-1])

        # The same resize routine is used whether or not the model is being traced:
        # no tracing-specific variant is defined in this file, so dispatching to one
        # would fail with a NameError.
        image, target = _resize_image_and_masks(image, size, float(self.max_size), target)

        if target is None:
            return image, target

        bbox = target["boxes"]
        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
        target["boxes"] = bbox

        if "keypoints" in target:
            keypoints = target["keypoints"]
            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
            target["keypoints"] = keypoints
        return image, target

    # _onnx_batch_images() is an implementation of
    # batch_images() that is supported by ONNX tracing.
    @torch.jit.unused
    def _onnx_batch_images(self, images, size_divisible=32):
        # type: (List[Tensor], int) -> Tensor
        max_size = []
        for i in range(images[0].dim()):
            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
            max_size.append(max_size_i)
        stride = size_divisible
        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size = tuple(max_size)

        # work around for
        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        # which is not yet supported in onnx
        padded_imgs = []
        for img in images:
            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
            padded_imgs.append(padded_img)

        return torch.stack(padded_imgs)

    def max_by_axis(self, the_list):
        # type: (List[List[int]]) -> List[int]
        """Return the element-wise maximum over a non-empty list of equally long lists."""
        # Work on a copy so the caller's first sublist is not modified in place.
        maxes = list(the_list[0])
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    def batch_images(self, images, size_divisible=32):
        # type: (List[Tensor], int) -> Tensor
        """Zero-pad ``images`` to one common shape and stack them into a single tensor.

        The common height and width are the per-axis maxima rounded up to a multiple
        of ``size_divisible``; each image is copied into the top-left corner.
        """
        if torchvision._is_tracing():
            # batch_images() does not export well to ONNX
            # call _onnx_batch_images() instead
            return self._onnx_batch_images(images, size_divisible)

        max_size = self.max_by_axis([list(img.shape) for img in images])
        stride = float(size_divisible)
        max_size = list(max_size)
        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)

        batch_shape = [len(images)] + max_size
        batched_imgs = images[0].new_full(batch_shape, 0)
        for img, pad_img in zip(images, batched_imgs):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

        return batched_imgs

    def postprocess(self,
                    result,               # type: List[Dict[str, Tensor]]
                    image_shapes,         # type: List[Tuple[int, int]]
                    original_image_sizes  # type: List[Tuple[int, int]]
                    ):
        # type: (...) -> List[Dict[str, Tensor]]
        """Map predictions from the resized images back to the original image sizes.

        In training mode ``result`` is returned untouched. Otherwise the entries of
        ``result`` are updated in place: boxes and keypoints are rescaled and masks
        are pasted into images of the original size.
        """
        if self.training:
            return result
        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
            boxes = pred["boxes"]
            boxes = resize_boxes(boxes, im_s, o_im_s)
            result[i]["boxes"] = boxes
            if "masks" in pred:
                masks = pred["masks"]
                masks = paste_masks_in_image(masks, boxes, o_im_s)
                result[i]["masks"] = masks
            if "keypoints" in pred:
                keypoints = pred["keypoints"]
                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                result[i]["keypoints"] = keypoints
        return result

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        _indent = '\n    '
        format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
        format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(_indent, self.min_size,
                                                                                         self.max_size)
        format_string += '\n)'
        return format_string


def resize_keypoints(keypoints, original_size, new_size):
    # type: (Tensor, List[int], List[int]) -> Tensor
    """Rescale keypoint coordinates from ``original_size`` to ``new_size``.

    Both sizes are ``(height, width)``. The first entry of the last dimension is
    scaled by the width ratio and the second by the height ratio; any further
    entries are left unchanged. A new tensor is returned and the input is not
    modified.
    """
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=keypoints.device) /
        torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_height, ratio_width = ratios
    resized = keypoints.clone()
    resized[..., 0] *= ratio_width
    resized[..., 1] *= ratio_height
    return resized
    
    
def resize_boxes(boxes, original_size, new_size):
    # type: (Tensor, List[int], List[int]) -> Tensor
    """Rescale ``(xmin, ymin, xmax, ymax)`` boxes from ``original_size`` to ``new_size``.

    Both sizes are ``(height, width)`` pairs. x coordinates are multiplied by the
    width ratio and y coordinates by the height ratio; the ratios are computed as
    float32 tensors on the boxes' device.

    Returns:
        A new tensor with the rescaled boxes stacked along dimension 1.
    """
    def _scalar(value):
        return torch.tensor(value, dtype=torch.float32, device=boxes.device)

    scale_h, scale_w = [
        _scalar(new_len) / _scalar(old_len)
        for new_len, old_len in zip(new_size, original_size)
    ]

    left, top, right, bottom = boxes.unbind(1)
    scaled_columns = (left * scale_w, top * scale_h, right * scale_w, bottom * scale_h)
    return torch.stack(scaled_columns, dim=1)

In [None]:
def collate_fn(batch):
    """Transpose a batch of samples into a tuple of per-field tuples.

    For example ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:
# means = {'red': 41.33620795268639, 'green': 25.602606457251067, 'blue': 27.628134909278703, 'yellow': 41.59902456019701}
# stds = {'red': 68.24592496125408, 'green': 43.672423320044, 'blue': 74.81053636746947, 'yellow': 66.31413028835794}
# counts = {'red': 515, 'green': 492, 'blue': 516, 'yellow': 501}

In [None]:
# Build the Mask R-CNN and replace its backbone with the 4-input-channel variant.
net = get_model_instance_segmentation(num_classes)
net._modules['backbone'] = create_backbone(num_classes)

# Four-channel transform with mean 0 / std 1, i.e. normalization leaves values unchanged.
transform_ = GeneralizedRCNNTransform_((800,), 1333, [0, 0, 0, 0], [1, 1, 1, 1])
# transform_ = GeneralizedRCNNTransform_((800,), 1333, [means['red']/counts['red'], means['green']/counts['green'], means['blue']/counts['blue'], means['yellow']/counts['yellow']], 
#                     [stds['red']/counts['red'], stds['green']/counts['green'], stds['blue']/counts['blue'], stds['yellow']/counts['yellow']])

net._modules['transform'] = transform_

# NOTE(security): pickle.load can execute arbitrary code from the file it reads;
# only point this at a weights file you trust.
# The handle is deliberately not called `input`, which would shadow the builtin
# for every later cell in the notebook.
with open('/kaggle/input/maskrcnn-full-dataset/maskrcnn_32percent_fulldataset4ch.pth', 'rb') as weights_file:
    model_weights = pickle.load(weights_file)

# The file is closed as soon as it has been read; loading into the model happens afterwards.
net.load_state_dict(model_weights)

In [None]:
# Move the model onto the selected device. The trailing semicolon keeps the
# notebook from echoing the returned module's repr.
net.to(device);

In [None]:
from os import walk

# Take only the first entry yielded by walk(): the files directly inside the test directory.
_, _, filenames = next(walk('/kaggle/input/hpa-single-cell-image-classification/test'))

# The image id is the part of each file name before the first underscore.
img_ids = [name.split('_')[0] for name in filenames]

# Several files share one id, so collapse the ids to a set.
unique_img_ids = set(img_ids)

In [None]:
from albumentations.pytorch.transforms import ToTensorV2

# Test-time pipeline: ToFloat followed by conversion to a torch tensor.
transforms_test = A.Compose(
    [
        A.ToFloat(),
        ToTensorV2(transpose_mask=True, always_apply=True, p=1.0),
    ]
)

In [None]:
class HPADatasetTest(object):
    """Test-time dataset that loads one 4-channel image per image id.

    For every id, the files ``<id>_red.png``, ``<id>_green.png``, ``<id>_blue.png``
    and ``<id>_yellow.png`` are read from ``img_dir``, converted to uint8 and stacked
    along the last axis in that order before ``transforms`` is applied.
    """

    # Channel suffixes, in the order they are stacked along the last image axis.
    CHANNELS = ('red', 'green', 'blue', 'yellow')
    # Directory used when the caller does not pass ``img_dir``.
    DEFAULT_IMG_DIR = '/kaggle/input/hpa-single-cell-image-classification/test'

    def __init__(self, unique_img_ids, transforms, img_dir=DEFAULT_IMG_DIR):
        """
        Args:
            unique_img_ids: Iterable of image ids; materialised into a list.
            transforms: Callable invoked as ``transforms(image=array)`` that returns
                a mapping with an ``'image'`` entry.
            img_dir: Directory holding the per-channel PNG files.
        """
        self.transforms = transforms
        self.img_list = list(unique_img_ids)
        self.img_dir = img_dir

    @staticmethod
    def _to_uint8(channel):
        """Convert one channel to uint8, rescaling from the 16-bit range when needed.

        A channel whose maximum exceeds 255 is treated as 16-bit data and multiplied
        by 255/65535; values are then floored. Channels whose maximum is at most 255
        are passed through unscaled.
        """
        channel = np.asarray(channel).astype(np.float32)
        if channel.size and np.max(channel) > 255:
            channel = channel * 255 / 65535
        return np.floor(channel).astype(np.ubyte)

    def __getitem__(self, idx):
        """Return ``(transformed_image, image_id)`` for the sample at ``idx``."""
        image_id = self.img_list[idx]
        planes = []
        for channel in self.CHANNELS:
            path = os.path.join(self.img_dir, image_id + "_" + channel + ".png")
            with Image.open(path) as img:
                # Materialise the pixels while the file is still open.
                planes.append(self._to_uint8(np.array(img)))

        # np.stack needs a real sequence; a generator is rejected by newer NumPy versions.
        stacked = np.stack(planes, axis=2)
        augmented = self.transforms(image=stacked)
        return augmented['image'], image_id

    def __len__(self):
        """Number of image ids in the dataset."""
        return len(self.img_list)

In [None]:
def collate_fn_cust(batch):
    """Identity collate function: hand the list of samples back unchanged.

    Passing this to a DataLoader keeps each batch as the list of items produced
    by the dataset instead of merging them.
    """
    return batch

In [None]:
# Samples are served in dataset order, one per batch, and left uncollated by collate_fn_cust.
batch_size = 1
testset = HPADatasetTest(unique_img_ids, transforms_test)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=False,
    collate_fn=collate_fn_cust,
)

In [None]:
# Only detections whose score is strictly greater than this are written out.
min_score = 0.0

net.eval()
# The context manager closes submission.csv even if inference fails part-way through.
with open("submission.csv", "w") as f, torch.no_grad():
    f.write("ID,ImageWidth,ImageHeight,PredictionString\n")
    for images in testloader:
        # Each batch is a list of (image_tensor, image_id) pairs.
        IDs = [sample[1] for sample in images]
        images = [sample[0].to(device) for sample in images]
        preds = net(images)
        for pred_id, pred in enumerate(preds):
            image_id = IDs[pred_id]
            # Masks are laid out as [N, 1, H, W]: dimension 2 is the height and
            # dimension 3 is the width.
            image_height = pred['masks'].shape[2]
            image_width = pred['masks'].shape[3]

            line = image_id + ',' + str(image_width) + ',' + str(image_height) + ','

            for score_id, score in enumerate(pred['scores']):
                if score > min_score:
                    label = pred['labels'][score_id].item()
                    # Label 19 is written out as 0 (presumably a class-index remap
                    # from training; confirm against the training notebook).
                    if label == 19:
                        label = 0

                    # Binarise the soft mask at 0.5, then encode it as
                    # COCO RLE -> zlib -> base64.
                    mask = pred['masks'][score_id, 0].cpu()
                    binary_mask = (mask > 0.5).to(torch.uint8).numpy()
                    rle_mask = cocomask.encode(np.asfortranarray(binary_mask))['counts']
                    rle_mask_zlib_64 = base64.b64encode(zlib.compress(rle_mask))
                    line += str(label) + ' ' + str(score.item()) + ' ' + rle_mask_zlib_64.decode('ascii') + ' '
            line += '\n'
            f.write(line)
            # Flush after every image so finished rows survive an interrupted run.
            f.flush()