In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import os
from pathlib import Path
import gc
from tqdm.notebook import tqdm
import rasterio
from rasterio.windows import Window
import torch, torchvision
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import warnings
warnings.filterwarnings("ignore")

In [None]:
import fastai
fastai.__version__  # show the installed fastai version; the model code below relies on fastai.layers helpers

In [None]:
import fastai.layers as fastai_layers

In [None]:
sz = 480  # side length each WINDOW x WINDOW tile is resized to before it is fed to the models
TH = 0.438  # threshold applied to the averaged (ensemble + TTA) sigmoid probabilities
WINDOW = 1024  # side length of the square tiles read from the full-resolution image
OVERLAP = 308  # minimum overlap between neighbouring tiles (make_grid's min_overlap)
EDGE_IGNORE = 150  # border width of a tile whose predictions are discarded unless it touches the image edge
tta = True  # also average predictions over width, height and combined flips
DATA = '/kaggle/input/hubmap-kidney-segmentation/test/'
MODELS = [
         '/kaggle/input/hubmap-win1024-sz480-subds-fold0-8epochs/fullTraining_8epochs',
         '/kaggle/input/hubmap-win1024-sz480-subds-fold3-8epochs/fullTraining_8epochs',
         '/kaggle/input/hubmap-win1024-sz480-subds-fold1-8epochs/fullTraining_8epochs',
         '/kaggle/input/hubmap-sz480win1024-train-fold2-subds/model_checkpoints_fold_2/fullTraining',
         '/kaggle/input/hubmap-win1024-sz480-subds-fold4-8epochs/fullTraining_8epochs',
         ]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Interior tiles keep only their central WINDOW - 2*EDGE_IGNORE pixels per axis, while tile
# starts are at most WINDOW - OVERLAP apart. The kept regions therefore only tile the whole
# image without gaps when OVERLAP >= 2*EDGE_IGNORE.
assert 0 <= OVERLAP < WINDOW, "OVERLAP must be in [0, WINDOW)"
assert OVERLAP >= 2 * EDGE_IGNORE, "OVERLAP must be >= 2*EDGE_IGNORE, otherwise the mask has unpredicted stripes"

# Data

In [None]:
#functions to convert encoding to mask and mask to encoding
def enc2mask(encs, shape):
    """Decode a list of run-length encodings into one label mask.

    Args:
        encs: iterable of RLE strings of "start length" pairs with 1-based starts, pixels
            numbered column-major in the returned array. Missing entries (a float NaN or
            None) are skipped. The m-th encoding is painted with label m + 1; where
            encodings overlap, later ones overwrite earlier ones.
        shape: (width, height) of the image.

    Returns:
        uint8 array of shape (height, width).
    """
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for m, enc in enumerate(encs):
        # `np.float` (an alias of the builtin) was removed in NumPy 1.24; the builtin float
        # also matches np.float64, which is what pandas uses for missing values.
        if enc is None or (isinstance(enc, float) and np.isnan(enc)):
            continue
        s = enc.split()
        for i in range(len(s)//2):
            start = int(s[2*i]) - 1
            length = int(s[2*i+1])
            img[start:start+length] = 1 + m
    # reshape as (width, height) then transpose: flat index k lands in column k // height
    return img.reshape(shape).T

def mask2enc(mask, n=1):
    """Run-length encode labels 1..n of an integer label mask.

    Pixels are read column-major (through ``mask.T``). Returns a list with one entry per
    label: a "start length ..." string with 1-based starts, or NaN if the label is absent.
    """
    flat = mask.T.flatten()

    def _encode_label(label):
        hits = (flat == label).astype(np.int8)
        if not hits.any():
            return np.nan
        padded = np.concatenate([[0], hits, [0]])
        # positions where the value flips, shifted to 1-based pixel indices
        edges = np.where(padded[1:] != padded[:-1])[0] + 1
        edges[1::2] -= edges[::2]  # turn run ends into run lengths
        return ' '.join(map(str, edges))

    return [_encode_label(label) for label in range(1, n + 1)]

def rle_encode(img:np.ndarray):
    """Run-length encode a binary (0/1) mask for submission.

    Pixels are read column-major (through ``img.T``). Returns "start length ..." pairs with
    1-based starts, or '' for an all-zero mask.
    """
    flat = img.T.flatten()
    # entry i marks a value change between flat[i] and flat[i + 1]
    parts = [np.nonzero(flat[:-1] != flat[1:])[0]]
    if flat[0] == 1:
        parts.insert(0, [-1])               # a run that begins at the very first pixel
    if flat[-1] == 1:
        parts.append([len(flat) - 1])       # a run that ends at the very last pixel
    runs = np.concatenate(parts)
    runs[::2] += 1                          # 0-based first pixel of each run
    runs[1::2] = runs[1::2] - runs[::2] + 1 # odd entries held each run's last pixel -> length
    runs[::2] += 1                          # 1-based starts
    return ' '.join(map(str, runs))

In [None]:
def _tile_starts(length, window, min_overlap):
    """Start offsets of `window`-long tiles covering [0, length) with >= `min_overlap` overlap."""
    if window - min_overlap <= 0:
        raise ValueError(f"min_overlap ({min_overlap}) must be smaller than window ({window})")
    if length < window:
        raise ValueError(f"image side ({length}) is smaller than the tile window ({window})")
    n = length // (window - min_overlap) + 1
    starts = np.linspace(0, length, num=n, endpoint=False, dtype=np.int64)
    # The last tile is pinned to the far edge. For sides only slightly longer than a few
    # strides, the evenly spaced starts before it can overshoot that edge; clamp them so
    # every tile stays fully inside the image.
    starts[-1] = length - window
    return np.minimum(starts, length - window)


def make_grid(shape, window=256, min_overlap=32):
    """
        Return an int64 array of shape (N, 4), where N is the number of tiles.
        Each row is x1, x2, y1, y2: a window x window crop [x1:x2, y1:y2], with x along
        shape[0] and y along shape[1]. Rows are ordered with x outer and y inner.

        Raises ValueError if min_overlap >= window or either side is shorter than window.
    """
    x, y = shape
    x1 = _tile_starts(x, window, min_overlap)
    y1 = _tile_starts(y, window, min_overlap)
    x2 = x1 + window
    y2 = y1 + window
    assert np.all((x1 >= 0) & (x2 <= x)), "Row tiles must lie inside the image"
    assert np.all((y1 >= 0) & (y2 <= y)), "Column tiles must lie inside the image"
    nx, ny = len(x1), len(y1)
    # every x-range is paired with every y-range (x outer, y inner)
    rows = np.repeat(np.stack([x1, x2], axis=1), ny, axis=0)
    cols = np.tile(np.stack([y1, y2], axis=1), (nx, 1))
    return np.concatenate([rows, cols], axis=1)

In [None]:
# Per-channel statistics used to normalise tiles after scaling pixel values by 1/255.
mean = np.array([0.63468326, 0.48969275, 0.67348264])
std = np.array([0.20481194, 0.25495073, 0.1935244])

# A tile counts as blank when the std of its raw pixel values is <= std_th in every channel.
std_th = 7

def img2tensor(img,dtype:np.dtype=np.float32):
    """Turn an HxW or HxWxC numpy image into a CxHxW torch tensor of `dtype`."""
    hwc = img[:, :, np.newaxis] if img.ndim == 2 else img
    chw = hwc.transpose(2, 0, 1)
    return torch.from_numpy(chw.astype(dtype, copy=False))

class HuBMAPDataset(Dataset):
    """Tiles one test image into WINDOW x WINDOW crops for inference.

    Each item is ``(image, flag, coords)``:
      * image  -- the crop resized to ``sz`` x ``sz``, scaled by 1/255 and normalised with
                  ``mean``/``std``, as a float32 CxHxW tensor;
      * flag   -- -1 when every channel's pixel std is <= ``std_th`` (near-uniform tile,
                  meant to be skipped), 1 otherwise;
      * coords -- ``(x1, x2, y1, y2)`` bounds of the crop in full-image pixels.

    Call ``close()`` to release the rasterio handles when done.
    """

    def __init__(self, idx):
        # idx is the image id; the raster read is DATA/<idx>.tiff
        self.data = rasterio.open(os.path.join(DATA,idx+'.tiff'))
        # Images that are not a single 3-band raster are read band-by-band from their
        # subdatasets; if there are none, every tile stays all-zero.
        self.layers = [] if self.data.count == 3 else [rasterio.open(sub) for sub in self.data.subdatasets]
        self.slices = make_grid(self.data.shape, window=WINDOW, min_overlap=OVERLAP) #[num_slices,4]

    def __len__(self):
        return self.slices.shape[0]

    def _read_tile(self, x1, x2, y1, y2):
        """Read the raw crop [x1:x2, y1:y2] as an HxWx3 array."""
        window = Window.from_slices((x1,x2),(y1,y2))
        if self.data.count == 3:
            img = self.data.read(window=window)  # [C,H,W]
        else:
            img = np.zeros((3, x2 - x1, y2 - y1), np.uint8)
            for j, layer in enumerate(self.layers):
                img[j] = layer.read(window=window)[0]
        return np.moveaxis(img, 0, -1)

    def __getitem__(self, idx):
        x1,x2,y1,y2 = self.slices[idx]
        img = self._read_tile(x1, x2, y1, y2)
        # tiles with (almost) no content get flag -1 so the prediction loop can skip them
        is_blank = np.all(np.array([np.std(img[:,:,i]) for i in range(3)]) <= std_th)
        flag = -1 if is_blank else 1
        img = cv2.resize(img, (sz, sz), interpolation=cv2.INTER_AREA)
        return img2tensor((img/255.0 - mean)/std), torch.tensor(flag), torch.tensor((x1,x2,y1,y2))

    def close(self):
        """Close every rasterio dataset opened in __init__."""
        for layer in self.layers:
            layer.close()
        self.data.close()

# Model

In [None]:
class FPN(torch.nn.Module):
    """Feature-pyramid head: refine each input level, upsample it and concatenate.

    Level i of `xs` goes through conv3x3 -> ReLU -> BatchNorm -> conv3x3 and is then
    bilinearly upsampled by 2**(len(xs) - i). All levels are concatenated along the
    channel axis, followed by `last_layer`.
    """
    def __init__(self, input_channels:list, output_channels:list):
        super().__init__()

        def _branch(in_ch, out_ch):
            hidden = out_ch * 2
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_ch, hidden, kernel_size=3, padding=1),
                torch.nn.ReLU(inplace=True),
                torch.nn.BatchNorm2d(hidden),
                torch.nn.Conv2d(hidden, out_ch, kernel_size=3, padding=1))

        self.convs = torch.nn.ModuleList(
            _branch(in_ch, out_ch) for in_ch, out_ch in zip(input_channels, output_channels))

    def forward(self, xs:list, last_layer):
        depth = len(self.convs)
        upsampled = []
        for level, (branch, feat) in enumerate(zip(self.convs, xs)):
            upsampled.append(F.interpolate(branch(feat), scale_factor=2 ** (depth - level), mode='bilinear'))
        return torch.cat(upsampled + [last_layer], dim=1)

class PixelShuffle_ICNR(torch.nn.Sequential):
    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`."
    def __init__(self, ni, nf=None, scale=2, blur=False, norm_type=fastai_layers.NormType.Weight,
                 act_cls=fastai_layers.defaults.activation):
        out_ch = fastai_layers.ifnone(nf, ni)
        # 1x1 conv that produces scale**2 times the channels PixelShuffle will fold into space
        conv_layer = fastai_layers.ConvLayer(ni, out_ch * scale**2, ks=1, norm_type=norm_type,
                                             act_cls=act_cls, bias_std=0)
        shuffle = torch.nn.PixelShuffle(scale)
        conv = conv_layer[0]
        if norm_type == fastai_layers.NormType.Weight:
            # weight-normalised conv: ICNR-init the direction, then set the gain to its norm
            conv.weight_v.data.copy_(fastai_layers.icnr_init(conv.weight_v.data))
            row_norms = (conv.weight_v.data ** 2).sum(dim=[1, 2, 3]) ** 0.5
            conv.weight_g.data.copy_(row_norms[:, None, None, None])
        else:
            conv.weight.data.copy_(fastai_layers.icnr_init(conv.weight.data))

        modules = [conv_layer, shuffle]
        if blur:
            # shift-and-average to reduce checkerboard artefacts
            modules.extend([torch.nn.ReplicationPad2d((1, 0, 1, 0)), torch.nn.AvgPool2d(2, stride=1)])
        super().__init__(*modules)
        
class UnetBlock(torch.nn.Module):
    """U-Net decoder stage.

    Upsamples `up_in` 2x with PixelShuffle_ICNR (halving its channels), concatenates the
    batch-normed skip connection `left_in`, applies ReLU and then two ConvLayers producing
    `nf` channels (default: max(up_in_c // 2, 32)). `self_attention` is accepted but unused.
    """
    def __init__(self, up_in_c:int, x_in_c:int, nf:int=None, blur:bool=False,
                 self_attention:bool=False, **kwargs):
        super().__init__()
        half = up_in_c // 2
        self.shuf = PixelShuffle_ICNR(up_in_c, half, blur=blur, norm_type=None, **kwargs)
        self.bn = torch.nn.BatchNorm2d(x_in_c)
        if nf is None:
            nf = max(half, 32)
        self.conv1 = fastai_layers.ConvLayer(half + x_in_c, nf, norm_type=None, **kwargs)
        self.conv2 = fastai_layers.ConvLayer(nf, nf, norm_type=None, **kwargs)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, up_in:torch.Tensor, left_in:torch.Tensor) -> torch.Tensor:
        merged = torch.cat([self.shuf(up_in), self.bn(left_in)], dim=1)
        return self.conv2(self.conv1(self.relu(merged)))

class _ASPPModule(torch.nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, groups=1):
        super().__init__()
        self.atrous_conv = torch.nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                stride=1, padding=padding, dilation=dilation, bias=False, groups=groups)
        self.bn = torch.nn.BatchNorm2d(planes)
        self.relu = torch.nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, torch.nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(torch.nn.Module):
    """Atrous Spatial Pyramid Pooling over one feature map.

    Concatenates a 1x1 branch, one grouped dilated 3x3 branch per entry of `dilations` and a
    global max-pooled branch (each with `mid_c` channels). The result is projected to `out_c`
    channels (default `mid_c`) by a 1x1 conv + BatchNorm + ReLU.
    """
    def __init__(self, inplanes=512, mid_c=256, dilations=[6, 12, 18, 24], out_c=None):
        super().__init__()
        branches = [_ASPPModule(inplanes, mid_c, 1, padding=0, dilation=1)]
        branches.extend(_ASPPModule(inplanes, mid_c, 3, padding=rate, dilation=rate, groups=4)
                        for rate in dilations)
        self.aspps = torch.nn.ModuleList(branches)
        self.global_pool = torch.nn.Sequential(
            torch.nn.AdaptiveMaxPool2d((1, 1)),
            torch.nn.Conv2d(inplanes, mid_c, 1, stride=1, bias=False),
            torch.nn.BatchNorm2d(mid_c),
            torch.nn.ReLU())
        if out_c is None:
            out_c = mid_c
        n_branches = len(dilations) + 2  # 1x1 branch + dilated branches + pooled branch
        self.out_conv = torch.nn.Sequential(
            torch.nn.Conv2d(mid_c * n_branches, out_c, 1, bias=False),
            torch.nn.BatchNorm2d(out_c),
            torch.nn.ReLU(inplace=True))
        self._init_weight()

    def forward(self, x):
        pooled = self.global_pool(x)
        branch_outs = [branch(x) for branch in self.aspps]
        # broadcast the 1x1 pooled descriptor back to the feature-map size
        pooled = F.interpolate(pooled, size=branch_outs[0].size()[2:], mode='bilinear', align_corners=True)
        return self.out_conv(torch.cat([pooled, *branch_outs], dim=1))

    def _init_weight(self):
        """Kaiming-normal conv weights; BatchNorm scale 1 and shift 0 (applied to all submodules)."""
        for module in self.modules():
            if isinstance(module, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, torch.nn.BatchNorm2d):
                torch.nn.init.ones_(module.weight)
                torch.nn.init.zeros_(module.bias)

In [None]:
class UneXt50(torch.nn.Module):
    """U-Net with a ResNeXt-50 (32x4d) encoder, an ASPP bottleneck and an FPN-style head.

    `stride` scales the ASPP dilation rates (stride * 1..4). `forward` returns a 1-channel
    logit map, upsampled 2x from the head's resolution. The shape comments assume H and W
    are divisible by 32 -- NOTE(review): confirm for other input sizes.
    """
    def __init__(self, stride=1):
        super().__init__()
        # encoder: torchvision ResNeXt-50 built without pretrained weights
        backbone = torchvision.models.resnext50_32x4d(pretrained=False)
        self.enc0 = torch.nn.Sequential(backbone.conv1, backbone.bn1, torch.nn.ReLU(inplace=True))
        self.enc1 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1),
                                        backbone.layer1)  # 256 ch
        self.enc2 = backbone.layer2  # 512 ch
        self.enc3 = backbone.layer3  # 1024 ch
        self.enc4 = backbone.layer4  # 2048 ch
        # ASPP bottleneck with dilations scaled by `stride`
        self.aspp = ASPP(2048, 256, out_c=512, dilations=[stride * k for k in range(1, 5)])
        self.drop_aspp = torch.nn.Dropout2d(0.25)
        # decoder
        self.dec4 = UnetBlock(512, 1024, 256)
        self.dec3 = UnetBlock(256, 512, 128)
        self.dec2 = UnetBlock(128, 256, 64)
        self.dec1 = UnetBlock(64, 64, 32)
        self.fpn = FPN([512, 256, 128, 64], [16] * 4)
        self.drop = torch.nn.Dropout2d(0.1)
        self.final_conv = fastai_layers.ConvLayer(32 + 16 * 4, 1, ks=1, norm_type=None, act_cls=None)

    def forward(self, x):  # N,3,H,W
        skip0 = self.enc0(x)                  # N,64,H/2,W/2
        skip1 = self.enc1(skip0)              # N,256,H/4,W/4
        skip2 = self.enc2(skip1)              # N,512,H/8,W/8
        skip3 = self.enc3(skip2)              # N,1024,H/16,W/16
        deep = self.aspp(self.enc4(skip3))    # enc4: N,2048,H/32 -> aspp: N,512,H/32,W/32
        up3 = self.dec4(self.drop_aspp(deep), skip3)  # N,256,H/16,W/16
        up2 = self.dec3(up3, skip2)           # N,128,H/8,W/8
        up1 = self.dec2(up2, skip1)           # N,64,H/4,W/4
        up0 = self.dec1(up1, skip0)           # N,32,H/2,W/2
        head = self.fpn([deep, up3, up2, up1], up0)  # N,96,H/2,W/2
        logits = self.final_conv(self.drop(head))    # N,1,H/2,W/2
        return F.interpolate(logits, scale_factor=2, mode='bilinear')  # N,1,H,W

In [None]:
def _load_unext(path):
    """Build UneXt50(stride=2), load the state dict at `path` (via CPU), move it to `device`, eval mode."""
    net = UneXt50(stride=2)
    # NOTE: torch.load unpickles the file -- only point MODELS at trusted checkpoints.
    net.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
    return net.to(device).eval()

models = [_load_unext(path) for path in MODELS]

In [None]:
len(models)  # sanity check: one loaded network per checkpoint path in MODELS

# Prediction

In [None]:
SHOW_TILES = False  # set True to eyeball tiles together with their blank (-1) / predict (1) flags

def show_tiles(idx='c68fe75ea', num_ex=50, start=100, ncols=5):
    """Plot `num_ex` consecutive tiles of image `idx`, starting at tile `start`.

    Each panel is titled with the dataset flag: 1 = tile is predicted, -1 = skipped as blank.
    Returns the matplotlib Figure.
    """
    ds_vis = HuBMAPDataset(idx)
    nrows = -(-num_ex // ncols)  # ceil division so any num_ex fits the grid
    fig, axes = plt.subplots(nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    for ax, i in zip(axes.flat, range(num_ex)):
        img, flag, _ = ds_vis[start + i]
        img = np.transpose(img.numpy(), (1, 2, 0))
        # undo normalisation; clip before casting so float round-off cannot wrap around uint8
        img = np.clip((img * std + mean) * 255.0, 0, 255).astype(np.uint8)
        ax.imshow(img)
        ax.set_title(f'{flag.item()}')
    return fig

if SHOW_TILES:
    show_tiles()
    plt.show()

In [None]:
# Reuse DATA instead of repeating the path. '*.tiff' matches how HuBMAPDataset rebuilds
# DATA/<stem>.tiff, and sorting makes the processing/submission order deterministic.
test_paths = sorted(Path(DATA).glob('*.tiff'))
test_paths

In [None]:
names,preds_rle = [],[]
bs = 12
assert models, "no models loaded -- check MODELS"
# TTA flips as tensor dims: last (width), second-to-last (height), both; empty list disables TTA
flips = [[-1],[-2],[-2,-1]] if tta else []
n_passes = len(models) * (len(flips) + 1)  # number of sigmoid maps averaged per tile

with torch.no_grad():
    for test_path in tqdm(test_paths):
        idx = test_path.stem
        ds = HuBMAPDataset(idx)
        dl = DataLoader(ds, batch_size=bs, shuffle=False, drop_last=False, num_workers=0)
        mask = torch.zeros(ds.data.shape, dtype=torch.uint8)
        mask_h, mask_w = mask.shape
        for imgs, igns, coords in dl:
            keep = igns == 1  # tiles flagged -1 by the dataset stay background
            if not keep.any():
                continue
            imgs = imgs[keep].to(device)
            coords = coords[keep]
            preds = torch.zeros((imgs.shape[0],1,*imgs.shape[-2:]), dtype=torch.float32, device=device)
            for model in models:
                preds += torch.sigmoid(model(imgs))
                for f in flips:
                    # predict on the flipped batch, then flip back so it aligns with the input
                    preds += torch.flip(torch.sigmoid(model(torch.flip(imgs, f))), f)
            preds = preds / n_passes
            # the models work at sz x sz; bring probabilities back to the WINDOW x WINDOW tile size
            preds = F.interpolate(preds, size=(WINDOW,WINDOW), mode='bilinear')
            # threshold on the device and copy the whole batch to the CPU in one transfer
            binary = (preds[:, 0] > TH).to(dtype=torch.uint8, device='cpu')

            for tile, (x1,x2,y1,y2) in zip(binary, coords.tolist()):
                # drop EDGE_IGNORE px on each side, except where that side lies on the image border
                x1_ign = EDGE_IGNORE if x1 != 0 else 0
                x2_ign = EDGE_IGNORE if x2 != mask_h else 0
                y1_ign = EDGE_IGNORE if y1 != 0 else 0
                y2_ign = EDGE_IGNORE if y2 != mask_w else 0
                # OR instead of += keeps the mask binary however many tiles overlap (no uint8 overflow)
                mask[(x1+x1_ign):(x2-x2_ign), (y1+y1_ign):(y2-y2_ign)] |= tile[x1_ign:WINDOW-x2_ign, y1_ign:WINDOW-y2_ign]

        mask.clamp_(0,1)  # already binary thanks to |=; kept as a cheap guard
        mask_np = mask.numpy()

        rle = rle_encode(mask_np)
        names.append(idx)
        preds_rle.append(rle)

        del mask, mask_np, ds, dl
        gc.collect()

In [None]:
# one row per test image: image id and its run-length-encoded predicted mask
df = pd.DataFrame(dict(id=names, predicted=preds_rle))
df.to_csv('submission.csv', index=False)