# Description
This kernel provides starter PyTorch inference code that divides the images into tiles ([based on this kernel](https://www.kaggle.com/iafoss/256x256-images)), selects the tiles that contain tissue, evaluates the predictions of multiple models with TTA, combines the tile masks back into image-level masks, and converts them into RLE. The inference is based on models trained in the [fast.ai starter kernel](https://www.kaggle.com/iafoss/hubmap-fast-ai-starter), which I also provided. I hope it helps you get started with this competition.

In [None]:
# Offline package setup. The `pip install --no-index` calls below never contact
# PyPI; they install only from the archives copied into this local cache dir.
!mkdir -p /tmp/pip/cache/
# The source archives are stored with an .xyz extension in the dataset; they are
# copied back under .tar.gz names, presumably so pip recognises them as sdists.
!cp ../input/segmentationmodelspytorch/segmentation_models/efficientnet_pytorch-0.6.3.xyz /tmp/pip/cache/efficientnet_pytorch-0.6.3.tar.gz
!cp ../input/segmentationmodelspytorch/segmentation_models/pretrainedmodels-0.7.4.xyz /tmp/pip/cache/pretrainedmodels-0.7.4.tar.gz
!cp ../input/segmentationmodelspytorch/segmentation_models/segmentation-models-pytorch-0.1.2.xyz /tmp/pip/cache/segmentation_models_pytorch-0.1.2.tar.gz
# The timm wheels are only copied, never installed explicitly; presumably pip
# resolves them as dependencies of the packages below — TODO confirm.
!cp ../input/segmentationmodelspytorch/segmentation_models/timm-0.1.20-py3-none-any.whl /tmp/pip/cache/
!cp ../input/segmentationmodelspytorch/segmentation_models/timm-0.2.1-py3-none-any.whl /tmp/pip/cache/
# Install from the local cache only (--no-index disables the package index).
!pip install --no-index --find-links /tmp/pip/cache/ efficientnet-pytorch
!pip install --no-index --find-links /tmp/pip/cache/ segmentation-models-pytorch
# Pre-populate the torch hub checkpoint cache with encoder weights, presumably so
# that building encoders with encoder_weights='imagenet' needs no download.
!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp ../input/pytorch-pretrained-models/se_resnext50_32x4d-a260b3a4.pth /root/.cache/torch/hub/checkpoints/
!cp ../input/pytorch-pretrained-models/resnet34-333f7ec4.pth /root/.cache/torch/hub/checkpoints/

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import tifffile as tiff
import cv2
import os
import gc
from tqdm.notebook import tqdm

from fastai.vision.all import *
from torch.utils.data import Dataset, DataLoader

import warnings
warnings.filterwarnings("ignore")
import segmentation_models_pytorch as smp

In [None]:
sz = 256    # tile size in pixels, measured on the downscaled image
reduce = 4  # downscale the original images by this factor before tiling
TH = 0.59   # threshold on the averaged sigmoid probability for a positive pixel

DATA = '../input/hubmap-kidney-segmentation/test/'
# Checkpoint paths of the models trained in the fast.ai starter kernel.
MODELS = [f'../input/hubmap-fast-ai-starter/model_{i}.pth' for i in range(4)]
df_sample = pd.read_csv('../input/hubmap-kidney-segmentation/sample_submission.csv')
bs = 64          # tiles per inference batch
NUM_WORKERS = 2  # DataLoader worker processes
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Display the torch device selected in the configuration cell (cuda or cpu).
device

# Data

In [None]:
# functions to convert encoding to mask and mask to encoding
def enc2mask(encs, shape):
    """Decode run-length encodings into a single label mask.

    Args:
        encs: iterable of RLE strings of the form "start length start length ..."
            with 1-based starts. Float NaN entries (e.g. empty CSV cells) are
            skipped.
        shape: ``(dim0, dim1)``; the flat buffer is reshaped to this and then
            transposed, so runs are laid out column-major in the result.

    Returns:
        uint8 array of shape ``(shape[1], shape[0])``. Pixels covered by the
        m-th encoding are set to ``m + 1``; background is 0.
    """
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for m, enc in enumerate(encs):
        # np.float was only an alias of the builtin float and was removed in
        # NumPy 1.24, so test against float directly (np.float64 subclasses it).
        if isinstance(enc, float) and np.isnan(enc):
            continue
        tokens = enc.split()
        # zip truncates a dangling unpaired token, like the former len(s)//2 loop.
        for start, length in zip(tokens[0::2], tokens[1::2]):
            begin = int(start) - 1  # RLE starts are 1-based
            img[begin:begin + int(length)] = 1 + m
    return img.reshape(shape).T

def mask2enc(mask, n=1):
    """Run-length encode each label 1..n of a label mask (column-major, 1-based).

    Returns a list with one entry per label: the "start length ..." string, or
    ``np.nan`` when the label does not occur in ``mask``.
    """
    flat = mask.T.flatten()
    result = []
    for label in range(1, n + 1):
        hits = (flat == label).astype(np.int8)
        if not hits.any():
            result.append(np.nan)
            continue
        # Zero-pad both ends so every run has a rising and a falling edge.
        padded = np.concatenate([[0], hits, [0]])
        edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
        # Turn (start, end) edge pairs into (start, length).
        edges[1::2] -= edges[::2]
        result.append(' '.join(map(str, edges)))
    return result

#https://www.kaggle.com/bguberfain/memory-aware-rle-encoding
#with bug fix
def rle_encode_less_memory(img):
    """Run-length encode a binary mask without padding the full array.

    Pixels are read column-major (``img.T.flatten()``) and starts are 1-based.
    The referenced kernel forced the first and last pixel to 0, which silently
    clipped runs touching either end. Here those two edges are handled
    explicitly by adding the missing run boundaries. This costs no extra
    full-size allocation and does not modify ``img``.

    Args:
        img: 2-D (or 1-D) array of 0/1 values.

    Returns:
        "start length start length ..." string, or "" for an empty/all-zero mask.
    """
    pixels = img.T.flatten()  # flatten() returns a copy, so img is untouched
    if pixels.size == 0:
        return ''
    # 1-based positions where the value changes (index i -> pixel i+1 -> i+2).
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
    if pixels[0]:
        # A run starts at the very first pixel: no transition marks it.
        runs = np.concatenate([[1], runs])
    if pixels[-1]:
        # A run reaches the last pixel: close it just past the end.
        runs = np.concatenate([runs, [pixels.size + 1]])
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

In [None]:
# https://www.kaggle.com/iafoss/256x256-images
# Per-channel normalization statistics for images scaled to [0, 1]; presumably
# computed on the training tiles of the kernel linked above (TODO confirm).
# HuBMAPTestDataset applies them as (img/255 - mean)/std.
mean = np.array([0.65459856,0.48386562,0.69428385])
std = np.array([0.15167958,0.23584107,0.13146145])

def img2tensor(img, dtype: np.dtype = np.float32):
    """Convert an (H, W) or (H, W, C) array into a (C, H, W) torch tensor.

    A 2-D input gets a trailing channel axis first. The array is cast to
    ``dtype`` (without copying when it already has that dtype).
    """
    hwc = img[:, :, None] if img.ndim == 2 else img
    chw = hwc.transpose(2, 0, 1)
    return torch.from_numpy(chw.astype(dtype, copy=False))

class HuBMAPTestDataset(Dataset):
    """Map-style dataset over pre-cut image tiles used for inference.

    Args:
        imgs: sequence of tiles indexable by position; each is scaled by 1/255,
            normalized with the module-level ``mean``/``std`` and converted
            with ``img2tensor``.
        idxs: tile positions in the source image; only their count is used
            here (stored as ``fnames``).
    """

    def __init__(self, imgs, idxs):
        self.imgs = imgs
        self.fnames = idxs

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, idx):
        scaled = self.imgs[idx] / 255.0
        normalized = (scaled - mean) / std
        return img2tensor(normalized)

# Model

In [None]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import glob

import torch
from torch.nn import Sigmoid
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader

from IPython.display import clear_output

from sklearn.model_selection import KFold

import matplotlib.pyplot as plt

import albumentations as A

In [None]:
def Unet(encoder_backbone='resnet18', encoder_weights='imagenet'):
    """Build a segmentation_models_pytorch U-Net with 3 input channels and 1 output class.

    Args:
        encoder_backbone: forwarded to smp as ``encoder_name``.
        encoder_weights: forwarded to smp as ``encoder_weights``.
    """
    unet_kwargs = dict(
        encoder_name=encoder_backbone,
        encoder_weights=encoder_weights,
        in_channels=3,
        classes=1,
    )
    return smp.Unet(**unet_kwargs)

In [None]:
# Build the U-Nets of the ensemble and load their trained weights.
model_1 = Unet('se_resnext50_32x4d')
model_path_1 = '../input/resnext50-2/best_model (4).pth'
model_3 = Unet('resnet34')
model_path_3 = '../input/resnet34-unet-lovazs-loss/best_model_resnet34_lovazsloss.pth'

# map_location lets checkpoints saved from a GPU load on a CPU-only machine too.
model_1.load_state_dict(torch.load(model_path_1, map_location=device))
model_3.load_state_dict(torch.load(model_path_3, map_location=device))

# Use the same `device` the input batches are moved to (instead of a hard-coded
# 'cuda'), and switch to eval mode: freshly built modules are in training mode,
# where BatchNorm would use per-batch statistics and dropout would be active.
model_1.to(device).eval()
model_3.to(device).eval()

models = [model_1, model_3]


# Prediction

In [None]:
#iterator like wrapper that returns predicted masks
class Model_pred:
    """Iterate over per-tile probability maps predicted by a model ensemble.

    For every batch from ``dl`` the models' sigmoid outputs are averaged
    (optionally also over flip TTA). The result is upsampled by the tile
    reduction factor and yielded one tile at a time as a float32 CPU tensor
    of shape ``(H, W, C)``.

    Args:
        models: sequence of ``nn.Module`` producing 4-D logits ``(N, C, h, w)``.
            They are switched to eval mode when iteration starts.
        dl: DataLoader yielding 4-D image batches; keep ``shuffle=False`` if the
            caller pairs outputs with tile indices by position.
        tta: also average over horizontal, vertical and both-axes flips.
        half: cast inputs to float16; presumably the models must already have
            been converted with ``.half()`` — TODO confirm before enabling.
        device: device to run inference on; defaults to the module-level
            ``device``.
        scale: upsampling factor for the predictions; defaults to the
            module-level ``reduce``.
    """

    def __init__(self, models, dl, tta: bool = True, half: bool = False,
                 device=None, scale=None):
        self.models = models
        self.dl = dl
        self.tta = tta
        self.half = half
        self.device = device
        self.scale = scale

    def _predict(self, x):
        """Return the sum over all models of the sigmoid probabilities for batch ``x``."""
        total = None
        for model in self.models:
            p = torch.sigmoid(model(x))
            total = p if total is None else total + p
        return total

    def __iter__(self):
        target = self.device if self.device is not None else device
        factor = self.scale if self.scale is not None else reduce
        # Inference must use BatchNorm running statistics and no dropout.
        for model in self.models:
            model.eval()
        # x, y and xy flips of the (N, C, H, W) batch as TTA.
        flips = [[-1], [-2], [-2, -1]] if self.tta else []
        with torch.no_grad():
            for x in self.dl:
                x = x.to(target)
                if self.half:
                    x = x.half()
                py = self._predict(x)
                for f in flips:
                    # Predict on the flipped batch, then flip back to align pixels.
                    py += torch.flip(self._predict(torch.flip(x, f)), f)
                # Average over every (view, model) prediction.
                py /= (1 + len(flips)) * len(self.models)

                # F.upsample is deprecated; F.interpolate is its replacement.
                py = F.interpolate(py, scale_factor=factor, mode="bilinear")
                py = py.permute(0, 2, 3, 1).float().cpu()
                yield from py

    def __len__(self):
        return len(self.dl.dataset)

In [None]:
# Somehow I cannot resolve the submission error when the private LB data is
# taken into account, and the submission error does not give an informative
# output. So, for now I share the notebook that makes a submission only
# to the public LB, and later I'll try to resolve the issue.
# IMPORTANT: This notebook doesn't perform predictions for the private LB.
names, preds = [], []
# Test image ids that are actually predicted below; every other id in the
# sample submission gets an empty (NaN) prediction right away.
samples = ['b9a3865fc', 'b2dc8411c', '26dc41664', 'c68fe75ea', 'afa5e8098']
samples_n = [img_id for img_id in df_sample.id if img_id not in samples]

names += samples_n
# np.NaN was removed in NumPy 2.0; np.nan is the canonical spelling.
preds += [np.nan] * len(samples_n)
df_sample = df_sample.loc[df_sample.id.isin(samples)]

In [None]:
# Per-image inference: read, pad, downscale and tile each selected test image,
# predict the kept tiles with the ensemble, stitch the tile masks back into a
# full-resolution mask and append its RLE to names/preds.
s_th = 40  #saturation blanking threshold: pixels with HSV saturation above this count as tissue
p_th = 200*sz//256 #threshold for the minimum number of pixels (scaled with the tile size relative to 256)
# names/preds are intentionally NOT reset here: they already hold the NaN
# placeholders appended for the ids that are not predicted.
for idx,row in tqdm(df_sample.iterrows(),total=len(df_sample)):
    # the DataFrame index yielded by iterrows is replaced by the image id string
    idx = row['id']
    #read image
    img = tiff.imread(os.path.join(DATA,idx+'.tiff'))
    # Some TIFFs load with 5 dims: squeeze the singleton axes and move axis 0
    # last. Assumes the squeezed array is (C, H, W) — TODO confirm.
    if len(img.shape) == 5: img = np.transpose(img.squeeze(), (1,2,0))

    #add padding to make the image dividable into tiles
    # pad0/pad1 make H and W multiples of reduce*sz (0 when already divisible);
    # the padding is split floor/ceil between both sides and cropped off below.
    img_shape = img.shape
    pad0 = (reduce*sz - img_shape[0]%(reduce*sz))%(reduce*sz)
    pad1 = (reduce*sz - img_shape[1]%(reduce*sz))%(reduce*sz)
    img = np.pad(img,[[pad0//2,pad0-pad0//2],[pad1//2,pad1-pad1//2],[0,0]],
                 constant_values=0)

    #split image into tiles using the reshape+transpose trick
    # first downscale by `reduce` (INTER_AREA), so each tile is sz x sz
    if reduce != 1:
        img = cv2.resize(img,(img.shape[1]//reduce,img.shape[0]//reduce),
                     interpolation = cv2.INTER_AREA)
    # downscaled, padded shape; needed later to stitch the mask back together
    img_shape_p = img.shape
    # (H, W, 3) -> (H/sz, sz, W/sz, sz, 3) -> (n_tiles, sz, sz, 3); tiles are
    # ordered row by row over the tile grid. Assumes 3 channels.
    img = img.reshape(img.shape[0]//sz,sz,img.shape[1]//sz,sz,3)
    img = img.transpose(0,2,1,3,4).reshape(-1,sz,sz,3)

    #select tiles for running the model
    imgs,idxs = [],[]
    for i,im in enumerate(img):
        #remove black or gray images based on saturation check
        # (the S channel depends only on the max/min of the channels, so
        # treating the tile as BGR instead of RGB does not change it)
        hsv = cv2.cvtColor(im, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(hsv)
        # keep a tile only if it has more than p_th saturated pixels and its
        # raw pixel sum exceeds p_th
        if (s>s_th).sum() <= p_th or im.sum() <= p_th: continue
        imgs.append(im)
        idxs.append(i)
    #tile dataset
    # shuffle=False keeps predictions in the same order as idxs, which the
    # zip below relies on
    ds = HuBMAPTestDataset(imgs,idxs)
    dl = DataLoader(ds,bs,num_workers=NUM_WORKERS,shuffle=False,pin_memory=True)
    mp = Model_pred(models,dl)

    #generate masks
    # one full-resolution (sz*reduce x sz*reduce) mask per tile (img.shape[0]
    # is now the tile count); skipped tiles stay 0
    mask = torch.zeros(img.shape[0],sz*reduce,sz*reduce,dtype=torch.int8)
    # binarize the averaged probability of each predicted tile at TH
    for i,p in zip(idxs,iter(mp)): mask[i] = p.squeeze(-1) > TH

    #reshape tiled masks into a single mask and crop padding
    # (rows, cols, T, T) -> (rows, T, cols, T) -> padded full-resolution image
    mask = mask.view(img_shape_p[0]//sz,img_shape_p[1]//sz,sz*reduce,sz*reduce).\
        permute(0,2,1,3).reshape(img_shape_p[0]*reduce,img_shape_p[1]*reduce)
    # when a pad is 0 the upper bound is the full length, because a -0 end
    # index would produce an empty slice
    mask = mask[pad0//2:-(pad0-pad0//2) if pad0 > 0 else img_shape_p[0]*reduce,
        pad1//2:-(pad1-pad1//2) if pad1 > 0 else img_shape_p[1]*reduce]

    #convert to rle
    #https://www.kaggle.com/bguberfain/memory-aware-rle-encoding
    rle = rle_encode_less_memory(mask.numpy())
    names.append(idx)
    preds.append(rle)
    # large arrays are allocated per image; trigger collection between images
    gc.collect()

In [None]:
# Assemble the submission: one row per image id with its RLE string (or NaN).
df = pd.DataFrame(data=dict(id=names, predicted=preds))
df.to_csv('submission.csv', index=False)