# Import

In [None]:
import sys
sys.path.append('../input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master')
sys.path.append('../input/pytorch-image-models/pytorch-image-models-master')
sys.path.append('../input/pytorch-images-seresnet')

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

from tqdm.notebook import tqdm
import os, gc
import random
import math
from PIL import Image
import tifffile as tiff
import cv2
import zipfile

import timm
from efficientnet_pytorch import EfficientNet
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

import torchvision
from torchvision import transforms
import albumentations as A
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import CosineAnnealingLR

import warnings
warnings.filterwarnings("ignore")
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def seed_everything(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device), exports ``PYTHONHASHSEED`` and configures cuDNN for
    reproducible kernels.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may pick different kernels between runs, so switch it off
    # alongside the deterministic flag.
    torch.backends.cudnn.benchmark = False

In [None]:
# Global configuration shared by the dataset, the inference function and the
# main inference loop below.
seed = 2020
seed_everything(seed)
print(device)

# Directory passed to HPADataset as the root of the per-channel test PNGs
TEST_ROOT = '../input/hpa-single-cell-image-classification/test/'

sz = 256  # side length (pixels) every cell tile is resized to
bs = 64  # number of images processed per outer inference batch
TH = 1e-15  # labels scoring above this are written out; also used as the filler confidence for unmatched cells

# ImageNet channel statistics, shaped (1, 1, 3) so they broadcast over H*W*3 arrays
mean = np.array([[[0.485, 0.456, 0.406]]])
std = np.array([[[0.229, 0.224, 0.225]]])

In [None]:
# Author's note on which IDs this file holds in each run mode:
#   commit run: public test set only
#   submit run: public + private test sets
test_df = pd.read_csv('../input/hpa-single-cell-image-classification/sample_submission.csv')

In [None]:
# Author's note: this stored copy holds the public test IDs in both commit and submit runs
public = pd.read_csv('../input/hpasubmission009/sample_submission.csv')

In [None]:
# Stack both frames so IDs present in both appear twice (author's note):
#   commit run: public + public
#   submit run: public + private + public
test_df_ = pd.concat([test_df, public]).reset_index(drop=True)

In [None]:
# Choose the rows that will actually be run through inference.
#   commit run (test_df has 559 rows): only the last 20 rows of the stacked frame
#   submit run: rows that occur exactly once in the stacked frame
#               (the private IDs, per the author's note)
public = (
    test_df_[-20:]
    if len(test_df) == 559
    else test_df_.drop_duplicates(keep=False).reset_index(drop=True)
)

# Functions

In [None]:
def collate_fn(batch):
    """Regroup a list of per-sample tuples into one tuple per field.

    No stacking is attempted, so samples may carry variable-length members.

    Args:
        batch: Sequence of equally long per-sample tuples.

    Returns:
        A tuple whose k-th entry is the tuple of every sample's k-th field.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:
def load_RGB_image(image_id_path):
    """Read the red, green and blue channel PNGs of one image.

    Args:
        image_id_path: Path prefix of the image; ``_<channel>.png`` is appended.

    Returns:
        Array built from the three grayscale reads, in red, green, blue order
        (channel-first).
    """
    channels = [
        cv2.imread(image_id_path + f"_{color}.png", cv2.IMREAD_GRAYSCALE)  # each H*W
        for color in ("red", "green", "blue")
    ]
    # C*H*W
    return np.array(channels)

In [None]:
def load_GGG_image(image_id_path):
    """Read the green channel PNG and repeat it three times.

    Args:
        image_id_path: Path prefix of the image; ``_green.png`` is appended.

    Returns:
        Array holding the same grayscale read three times (channel-first).
    """
    green_channel = cv2.imread(image_id_path + "_green.png", cv2.IMREAD_GRAYSCALE)
    # C*H*W
    return np.array([green_channel] * 3)

# Dataset

In [None]:
class HPADataset(Dataset):
    """Dataset yielding, per image, its cell tiles, nucleus centres and a 512x512 input.

    Args:
        path: Directory holding the channel PNGs.
        df: DataFrame whose first column is the image ID; one row per item.
        nuc_masks_batch: Labelled nucleus masks (H*W), indexed like ``df``.
        cell_masks_batch: Labelled cell masks (H*W), indexed like ``df``.
        transform: Stored but not applied anywhere in this class.
    """

    def __init__(self, path, df, nuc_masks_batch, cell_masks_batch, transform=None):
        self.path = path
        self.df = df
        self.nuc_masks_batch = nuc_masks_batch
        self.cell_masks_batch = cell_masks_batch
        self.transform = transform

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _bbox(mask, label):
        """Return ``(x1, x2, y1, y2)`` of ``mask == label``, or None if the label is absent."""
        ys, xs = np.where(mask == label)
        if ys.size == 0:
            return None
        return xs.min(), xs.max(), ys.min(), ys.max()

    @staticmethod
    def _crop_centered(image, xc, yc, half):
        """Cut a C*(2*half)*(2*half) window around (xc, yc), zero-filled outside the image.

        Produces the same pixels as zero-padding the whole image and slicing
        it, without allocating a padded copy of the full image for every cell.
        """
        side = 2 * half
        channels, height, width = image.shape
        tile = np.zeros((channels, side, side), dtype=image.dtype)
        y0, x0 = yc - half, xc - half
        src_y0, src_y1 = max(y0, 0), min(y0 + side, height)
        src_x0, src_x1 = max(x0, 0), min(x0 + side, width)
        if src_y0 < src_y1 and src_x0 < src_x1:
            tile[:, src_y0 - y0:src_y1 - y0, src_x0 - x0:src_x1 - x0] = \
                image[:, src_y0:src_y1, src_x0:src_x1]
        return tile

    def __getitem__(self, idx):  # idx: position inside the current batch
        """Build the model inputs for image ``idx``.

        Returns:
            Tuple ``(img_tiles, img_centers, img_512)``: a list of normalised
            3*sz*sz tile tensors, an array of ``[xc, yc]`` nucleus bounding-box
            centres (one per tile) and the normalised 3*512*512 green-only image.
        """
        img_path = os.path.join(self.path, self.df.iloc[idx, 0])
        image = load_RGB_image(img_path).astype(np.float32)  # 3*H*W
        green = load_GGG_image(img_path).astype(np.float32)  # 3*H*W

        img_512 = cv2.resize(np.transpose(green, (1, 2, 0)), (512, 512))  # 512*512*3
        img_512 = (img_512 / 255.0 - mean) / std  # normalisation
        img_512 = torch.from_numpy(np.transpose(img_512, (2, 0, 1)))  # 3*512*512
        del green

        # Use the masks handed to the constructor, not same-named module globals.
        nuc_mask = self.nuc_masks_batch[idx]  # H*W
        cell_mask = self.cell_masks_batch[idx]  # H*W
        img_tiles = []
        img_centers = []

        # int() keeps "+ 1" from wrapping around when the mask dtype is narrow.
        for label in range(1, int(np.max(cell_mask)) + 1):
            cell_box = self._bbox(cell_mask, label)
            if cell_box is None:  # no cell mask: nothing to predict for
                continue
            nuc_box = self._bbox(nuc_mask, label)
            if nuc_box is None:  # no nucleus mask: the cell is skipped here
                continue

            cx1, cx2, cy1, cy2 = cell_box
            nx1, nx2, ny1, ny2 = nuc_box

            # The tile side is the short side of the cell bounding box.
            half = int(min(cx2 - cx1, cy2 - cy1)) // 2
            if half == 0:
                # A degenerate box would give an empty tile that cv2.resize
                # rejects; skip it so tiles and centres stay aligned.
                continue

            # Centre of the nucleus bounding box
            xc = int(nx1 + nx2) // 2
            yc = int(ny1 + ny2) // 2
            img_centers.append([xc, yc])

            tile = self._crop_centered(image, xc, yc, half)  # 3*a*a
            rtile = cv2.resize(np.transpose(tile, (1, 2, 0)), (sz, sz))  # sz*sz*3
            rtile = (rtile / 255.0 - mean) / std  # normalisation
            rtile = np.transpose(rtile, (2, 0, 1))  # 3*sz*sz
            img_tiles.append(torch.from_numpy(rtile))

        img_centers = np.array(img_centers)

        del image, nuc_mask, cell_mask
        gc.collect()

        return img_tiles, img_centers, img_512

# Tile Model

In [None]:
#EfficientNet B5
tile_model1 = EfficientNet.from_name('efficientnet-b5')
tile_model1._fc = nn.Linear(in_features=2048, out_features=19)
# map_location keeps the checkpoint loadable when `device` fell back to 'cpu'
tile_model1.load_state_dict(torch.load('../input/hpa-vol4tox-upsampling-6enspl-efb5-weight/efficientnetb5_seed_2020_single_fold.pth', map_location=device))
tile_model1.to(device)

In [None]:
#EfficientNet B6
tile_model2 = EfficientNet.from_name('efficientnet-b6')
tile_model2._fc = nn.Linear(in_features=2304, out_features=19)
# map_location keeps the checkpoint loadable when `device` fell back to 'cpu'
tile_model2.load_state_dict(torch.load('../input/hpa-vol4tox-upsampling-6enspl-efb6-weight/efficientnetb6_seed_2020_single_fold.pth', map_location=device))
tile_model2.to(device)

In [None]:
#EfficientNet B7
tile_model3 = EfficientNet.from_name('efficientnet-b7')
tile_model3._fc = nn.Linear(in_features=2560, out_features=19)
# map_location keeps the checkpoint loadable when `device` fell back to 'cpu'
tile_model3.load_state_dict(torch.load('../input/hpa-vol4tox-upsampling-6enspl-efb7-weight/efficientnetb7_seed_2020_single_fold.pth', map_location=device))
tile_model3.to(device)

In [None]:
#SEResNeXt50
class SRNX50(nn.Module):
    """timm backbone (SE-ResNeXt50 by default) whose ``fc`` layer is swapped for a 19-way linear head.

    Args:
        model_name: Architecture name handed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name='seresnext50_32x4d', pretrained=True):
        super(SRNX50, self).__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Same input width as the stock classifier, 19 outputs.
        backbone.fc = nn.Linear(backbone.fc.in_features, 19)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)
    
tile_model4 = SRNX50(pretrained=False)
# map_location keeps the checkpoint loadable when `device` fell back to 'cpu'
tile_model4.load_state_dict(torch.load('../input/hpa-vol4tox-upsampling-6enspl-srnx50-weight/seresnext50_seed_2020_single_fold.pth', map_location=device))
tile_model4.to(device)

In [None]:
#CSPResNeXt50
class CSPNetModel(nn.Module):
    """timm backbone (CSPResNeXt50 by default) with ``head.fc`` replaced by a linear classifier.

    Args:
        num_classes: Number of output units of the new head.
        model_name: Architecture name handed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, num_classes=19, model_name='cspresnext50', pretrained=True):
        super(CSPNetModel, self).__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        # Honour num_classes instead of a hard-coded 19 (the default is unchanged).
        self.model.head.fc = nn.Linear(self.model.head.fc.in_features, num_classes)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)
    
tile_model5 = CSPNetModel(pretrained=False)
# map_location keeps the checkpoint loadable when `device` fell back to 'cpu'
tile_model5.load_state_dict(torch.load('../input/hpa-vol4tox-upsampling-6enspl-csprnx50-weight/cspresnext50_seed_2020_single_fold.pth', map_location=device))
tile_model5.to(device)

In [None]:
#NFNet F1
class NFNet(nn.Module):
    """timm backbone (NFNet-F1 by default) with ``head.fc`` replaced by a two-layer MLP head.

    Args:
        output_features: Number of output units of the new head.
        model_name: Architecture name handed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, output_features=19, model_name='nfnet_f1', pretrained=True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        in_features = backbone.head.fc.in_features
        # Linear -> ReLU -> Linear with a 512-unit hidden layer
        backbone.head.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, output_features),
        )
        self.model = backbone

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

tile_model6 = NFNet(pretrained=False)
# map_location keeps the checkpoint loadable when `device` fell back to 'cpu'
tile_model6.load_state_dict(torch.load('../input/hpa-vol4tox-upsampling-6enspl-nfnetf1-weight/nfnetf1_seed_2020_single_fold.pth', map_location=device))
tile_model6.to(device)

# Image Model

In [None]:
#EfficientNet B7
img_model = EfficientNet.from_name('efficientnet-b7')
img_model._fc = nn.Linear(in_features=2560, out_features=19)
# map_location keeps the checkpoint loadable when `device` fell back to 'cpu'
img_model.load_state_dict(torch.load('../input/hpa-ill-ggg-efb7-weight/efficientnetb7_seed_2020_single_fold.pth', map_location=device))
img_model.to(device)

# Inference function

In [None]:
def inference_per_batch(data_loader,
                        tile_model1,
                        tile_model2,
                        tile_model3,
                        tile_model4,
                        tile_model5,
                        tile_model6,
                        img_model,
                        device):
    """Predict per-cell label scores for the images delivered by ``data_loader``.

    Each tile is scored by the six tile models (sigmoid outputs averaged) and
    blended 0.75 / 0.25 with the whole-image prediction, whose classes 11 and
    18 are zeroed first.

    Args:
        data_loader: Iterable of ``(img_tiles, img_centers, img)`` batches.
            The caller is expected to supply a single batch; if several are
            supplied only the last one's results are returned.
        tile_model1..tile_model6: Tile-level classifiers.
        img_model: Whole-image classifier.
        device: Device the inputs are moved to.

    Returns:
        Tuple ``(preds, img_centers)``. ``preds`` holds one array of shape
        (n_tiles, 1, n_classes) per image, with ``n_tiles == 0`` for images
        that produced no tiles. ``img_centers`` is passed through from the batch.
    """
    tile_models = (tile_model1, tile_model2, tile_model3,
                   tile_model4, tile_model5, tile_model6)
    for model in (*tile_models, img_model):
        model.eval()

    # Defined up front so an empty loader returns empty results instead of
    # failing on unbound names.
    preds = []
    img_centers = ()

    # Inference only: no autograd graph is needed for any of the models.
    with torch.no_grad():
        for img_tiles, img_centers, img in data_loader:
            preds = []

            for tiles, whole_img in zip(img_tiles, img):
                img_in = whole_img.to(device, dtype=torch.float).unsqueeze(0)
                img_pred = torch.sigmoid(img_model(img_in))  # 1*n_classes
                img_pred[0][11] = 0.0
                img_pred[0][18] = 0.0

                preds_per_image = []
                for tile in tiles:
                    tile = tile.unsqueeze(0).to(device, dtype=torch.float)  # add batch dim
                    tile_pred = sum(torch.sigmoid(m(tile)) for m in tile_models) / len(tile_models)
                    pred = tile_pred * 0.75 + img_pred * 0.25
                    preds_per_image.append(pred.cpu().numpy())

                if preds_per_image:
                    preds.append(np.stack(preds_per_image))
                else:
                    # np.stack rejects an empty list; keep the rank consistent
                    # so callers can still take len() and iterate.
                    preds.append(np.empty((0,) + tuple(img_pred.shape), dtype=np.float32))

                del img_in, img_pred
                gc.collect()

    return preds, img_centers

# HPA-cellsegmentator

In [None]:
!pip install "../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"
!pip install "../input/hpapytorchzoozip/pytorch_zoo-master"
!pip install "../input/hpacellsegmentatorraman/HPA-Cell-Segmentation/"

In [None]:
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import numpy as np
    import pandas as pd
    import os
    import gc
    import os.path
    import urllib
    import zipfile
    from hpacellseg.cellsegmentator import *
    from hpacellseg import cellsegmentator, utils
    import cv2
    import scipy.ndimage as ndi
    from skimage import filters, measure, segmentation, transform, util
    from skimage.morphology import (binary_erosion, closing, disk, remove_small_holes, remove_small_objects)
    from PIL import Image
    import matplotlib.pyplot as plt

In [None]:
# Weights of the HPA cell segmentator (nuclei model and 3-channel cell model)
NUC_MODEL = "../input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth"
CELL_MODEL = "../input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth"
segmentator_even_faster = cellsegmentator.CellSegmentator(
    NUC_MODEL,
    CELL_MODEL,
    # Follow the device detected at import time instead of assuming a GPU.
    device=device,
    multi_channel_model=True,
)

In [None]:
def load_images(df, size, root='../input/hpa-single-cell-image-classification/test/'):
    """Load quarter-scale segmentation inputs for every ID in ``df``.

    Args:
        df: DataFrame with an ``ID`` column.
        size: Original image side length; images are resized to ``int(size*0.25)``.
        root: Directory holding the ``<ID>_<channel>.png`` files.

    Returns:
        Tuple ``(blue_scaled, rgb_scaled)`` of lists, one entry per ID: the
        resized blue channel, and the resized red/yellow/blue stack (H*W*3),
        both divided by 255.
    """
    blue_scaled, rgb_scaled = [], []
    for image_id in list(df.ID):
        side = int(size*0.25)
        red = cv2.imread(os.path.join(root, f'{image_id}_red.png'), cv2.IMREAD_GRAYSCALE)
        yellow = cv2.imread(os.path.join(root, f'{image_id}_yellow.png'), cv2.IMREAD_GRAYSCALE)
        blue = cv2.imread(os.path.join(root, f'{image_id}_blue.png'), cv2.IMREAD_GRAYSCALE)

        small_blue = cv2.resize(blue, (side, side))
        stacked = np.stack((red, yellow, blue), axis=2)
        small_rgb = cv2.resize(stacked, (side, side))

        blue_scaled.append(small_blue/255.)
        rgb_scaled.append(small_rgb/255.)

        # Drop the full-resolution arrays before the next image is read.
        del red, yellow, blue, stacked, small_blue, small_rgb
        gc.collect()
    return blue_scaled, rgb_scaled

In [None]:
import base64
import numpy as np
from pycocotools import _mask as coco_mask
import typing as t
import zlib


def encode_binary_mask(mask: np.ndarray) -> t.Text:
  """Converts a binary mask into OID challenge encoding ascii text.

  The mask is RLE-encoded through the COCO API, zlib-compressed and
  base64-encoded.

  Args:
    mask: Boolean array that is 2-D after squeezing singleton axes.

  Returns:
    The encoded mask as an ASCII string.

  Raises:
    ValueError: If ``mask`` is not boolean or is not 2-D after squeezing.
  """

  # check input mask --
  # np.bool_ instead of the np.bool alias, which newer NumPy releases removed.
  if mask.dtype != np.bool_:
    raise ValueError(
        "encode_binary_mask expects a binary mask, received dtype == %s" %
        mask.dtype)

  mask = np.squeeze(mask)
  if len(mask.shape) != 2:
    # Wrap the shape in a 1-tuple: "%" would otherwise unpack it and raise
    # TypeError instead of the intended ValueError.
    raise ValueError(
        "encode_binary_mask expects a 2d mask, received shape == %s" %
        (mask.shape,))

  # convert input mask to expected COCO API input --
  mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
  mask_to_encode = mask_to_encode.astype(np.uint8)
  mask_to_encode = np.asfortranarray(mask_to_encode)

  # RLE encode mask --
  encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

  # compress and base64 encoding --
  binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
  base64_str = base64.b64encode(binary_str)
  return base64_str.decode('ascii')

# Inference

In [None]:
# Split the rows to process by image size; every sub-frame gets a fresh
# 0-based index before being stored in sub_dfs.
sub_dfs = [
    public[public['ImageWidth'] == dim].copy().reset_index(drop=True)
    for dim in public.ImageWidth.unique()
]

# Number of images sent through the segmentator per call
SEG_BATCH_SIZE = 24


def _build_predstring(cell_preds, cell_centers, cell_mask):
    """Assemble the PredictionString of one image.

    Args:
        cell_preds: Per-tile score arrays (each 1*19), aligned with ``cell_centers``.
        cell_centers: Per-tile ``[xc, yc]`` nucleus centres.
        cell_mask: Labelled cell mask (H*W) of the image.

    Returns:
        Concatenation of ``"<label> <confidence> <encoded mask> "`` entries.
    """
    entries = []
    matched = set()

    for cell_pred, (xc, yc) in zip(cell_preds, cell_centers):
        cell_id = cell_mask[yc, xc]  # cell-mask value under the nucleus centre
        if cell_id == 0:  # the centre happened to fall on background
            continue
        matched.add(int(cell_id))
        scores = cell_pred.flatten()
        enc = encode_binary_mask(cell_mask == cell_id)
        classes = np.where(scores > TH)[0]  # labels scoring above TH
        if len(classes) == 0:
            # Nothing cleared the threshold: fall back to class 18.
            entries.append(f'18 {scores[18]} {enc} ')
        else:
            entries.extend(f'{c} {scores[c]} {enc} ' for c in classes)

    # Cells that were never matched to a nucleus centre: emit every class at
    # confidence TH.
    for cell_id in sorted(set(range(1, int(np.max(cell_mask)) + 1)) - matched):
        enc = encode_binary_mask(cell_mask == cell_id)
        # Write the class index itself, not a literal letter.
        entries.extend(f'{c} {TH} {enc} ' for c in range(19))

    # Join once instead of growing one long string with repeated "+=".
    return ''.join(entries)


for sub in sub_dfs:
    width = int(sub.ImageWidth.loc[0])
    pred_col = sub.columns.get_loc('PredictionString')
    print(f'<<<<<<<<<<Inference for image size: {width}>>>>>>>>>>')

    for start in tqdm(range(0, len(sub), bs)):
        # Per batch: cell segmentation -> label inference -> mask/label matching.
        # start: 0, bs, 2*bs, ...; img_num: number of images (IDs) in this batch.
        img_num = min(bs, len(sub) - start)
        batch_no = start // bs + 1
        batch_df = sub.iloc[start:start + img_num]

        ############################################################################

        # fast cell segmentation
        print(f'Image {width} Batch {batch_no}: Cell Segmentation')
        blue_scaled, rgb_scaled = load_images(df=batch_df, size=width)
        nuc_masks_batch = []
        cell_masks_batch = []
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for s in range(0, img_num, SEG_BATCH_SIZE):
                e = min(img_num, s + SEG_BATCH_SIZE)
                nuc_segmentations = segmentator_even_faster.pred_nuclei(blue_scaled[s:e])
                cell_segmentations = segmentator_even_faster.pred_cells(rgb_scaled[s:e], precombined=True)
                for nuc_seg, cell_seg in zip(nuc_segmentations, cell_segmentations):
                    nuc_mask, cell_mask = utils.label_cell(nuc_seg, cell_seg)
                    # Nearest-neighbour interpolation keeps label values intact.
                    # NOTE(review): uint8 can hold at most 255 cell labels per image.
                    nuc_masks_batch.append(
                        cv2.resize(nuc_mask.astype(np.uint8), (width, width), interpolation=cv2.INTER_NEAREST))
                    cell_masks_batch.append(
                        cv2.resize(cell_mask.astype(np.uint8), (width, width), interpolation=cv2.INTER_NEAREST))
                    del nuc_mask, cell_mask
                del nuc_segmentations, cell_segmentations
                gc.collect()
        del blue_scaled, rgb_scaled
        gc.collect()

        ############################################################################

        # label inference
        print(f'Image {width} Batch {batch_no}: Label inference')
        test_ds = HPADataset(path=TEST_ROOT,
                             df=batch_df,
                             nuc_masks_batch=nuc_masks_batch,
                             cell_masks_batch=cell_masks_batch,
                             transform=None)
        # batch_size=img_num: the loader yields the whole batch in one step
        test_dl = DataLoader(dataset=test_ds,
                             batch_size=img_num,
                             shuffle=False,
                             collate_fn=collate_fn,
                             num_workers=0)
        preds, centers = inference_per_batch(test_dl,
                                             tile_model1,
                                             tile_model2,
                                             tile_model3,
                                             tile_model4,
                                             tile_model5,
                                             tile_model6,
                                             img_model,
                                             device)
        del test_ds, test_dl, nuc_masks_batch
        gc.collect()

        ############################################################################

        # mask and label matching
        print(f'Image {width} Batch {batch_no}: Mask and Label matching')
        predstrings = [
            _build_predstring(preds[i], centers[i], cell_masks_batch[i])
            for i in range(img_num)
        ]

        # One positional assignment on the frame itself; chained
        # sub[col].iloc[...] = ... may write to a temporary copy.
        sub.iloc[start:start + img_num, pred_col] = predstrings
        del predstrings, cell_masks_batch, preds, centers, batch_df
        gc.collect()

In [None]:
# Merge the per-image-size frames back into a single frame with a fresh index
all_subs = pd.concat(sub_dfs, ignore_index=True, sort=False)

In [None]:
# Display the first 20 merged rows for a visual check
all_subs.head(20)

In [None]:
# Display the fourth column of the first row (presumably PredictionString — verify the column order)
all_subs.iloc[0, 3]

In [None]:
# Lookup table: image ID -> predicted PredictionString
private_dict = dict(zip(all_subs['ID'], all_subs['PredictionString']))

In [None]:
# Fill in predictions by ID; IDs missing from private_dict keep their existing PredictionString
test_df['PredictionString'] = test_df['ID'].map(private_dict).fillna(test_df['PredictionString'])

In [None]:
# Write the final submission file without the index column
test_df.to_csv('submission.csv', index=False)