In [None]:
# !pip install git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
import os
import copy
import cv2
import random
import pydicom
import torch
import time
import math
import shutil
import rasterio

import pandas as pd
import numpy as np
import PIL as pil
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models
import pytorch_lightning as pl
import albumentations as A
from pytorch_lightning import loggers as pl_loggers
from scipy.ndimage.interpolation import zoom
from albumentations.pytorch import ToTensorV2
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate 
)

In [None]:
# def seed_everything(seed):
#     random.seed(seed)
#     os.environ['PYTHONHASHSEED'] = str(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = True

# seed_everything(cfg['seed'])

In [None]:
# Dataset root; reconstruct_img opens '<input_dir>/train/<id>.tiff' from here.
input_dir = '../input/hubmap-kidney-segmentation'
# Directory with the per-tile error table and the tile images referenced by it.
errors_dir = '../input/errors'
errors_csv = f'{errors_dir}/errors.csv'

# Tile edge length in pixels; reconstruct_img divides it by `scale` to size each pasted tile.
# NOTE(review): presumably the size used when the tiles were cut — confirm against the producer.
img_size = 1024

In [None]:
# Per-tile error table; later cells read its id, i, j, tp, fp and fn columns.
df = pd.read_csv(errors_csv)
df.head()

In [None]:
# eps is added to numerator and denominator alike, so a tile whose counts are
# all zero evaluates to 1.0 instead of dividing by zero.
eps = 1e-7
# Image file name of each tile, built from its id and its (i, j) grid position.
df['filename'] = df.apply( lambda row: f'{row.id}-{row.i}-{row.j}.jpeg', axis=1)
df['precision'] = (df.tp + eps) / (df.tp + df.fp + eps)
df['recall'] = (df.tp + eps) / (df.tp + df.fn + eps)
df['dice'] = (2 * df.tp + eps) / (2 * df.tp + df.fp + df.fn + eps)
df.sample(5)

In [None]:
def plot_imgs(imgs_df, columns = 4):
    """Show the tile images listed in ``imgs_df`` on a grid of subplots.

    Parameters
    ----------
    imgs_df : DataFrame
        Must have a ``filename`` column; every name is opened from
        ``{errors_dir}/imgs/``.
    columns : int
        Number of subplot columns.

    The number of rows is rounded up so that every image is displayed; cells
    of the last row that have no image are hidden. Nothing is drawn when
    ``imgs_df`` has no rows.
    """
    imgs = imgs_df.filename.tolist()
    if not imgs:
        # plt.subplots cannot build a grid with zero rows.
        return

    # Ceiling division: floor division would silently drop the trailing images
    # whenever len(imgs) is not a multiple of `columns`.
    rows = (len(imgs) + columns - 1) // columns

    # squeeze=False keeps `axs` two-dimensional even for a 1x1 or 1xN grid,
    # so flatten() is always available.
    fig, axs = plt.subplots(rows, columns, figsize=(20, rows*5), squeeze=False)
    axes = axs.flatten()

    for ax, img in zip(axes, imgs):
        ax.imshow(pil.Image.open(f'{errors_dir}/imgs/{img}'))
        # Label the tile with its file name inside the picture.
        ax.text(500,500,img, color='blue')

    # Hide the unused cells at the end of the last row.
    for ax in axes[len(imgs):]:
        ax.axis('off')

    plt.show()

# Plot worst predictions

In [None]:
# For each metric keep the `count` tiles with the smallest value.
# Index 0 is precision, 1 is recall, 2 is dice.
count = 32
worst = [df.nsmallest(count, typ) for typ in ['precision', 'recall', 'dice']]

**1. Precision**

In [None]:
# worst[0] holds the tiles with the lowest precision.
plot_imgs(worst[0])

**2. Recall**

In [None]:
# worst[1] holds the tiles with the lowest recall; drawn three per row.
plot_imgs(worst[1], columns=3)

**3. Dice**

In [None]:
# worst[2] holds the tiles with the lowest dice score.
plot_imgs(worst[2])

**4. False negatives (FN)**

In [None]:
# The 32 tiles with the largest false-negative count.
worst_fn = df.nlargest(32, 'fn')
worst_fn.head()

In [None]:
# Show the tiles with the most false negatives.
plot_imgs(worst_fn)

# Plot tiffs

In [None]:
# The 32 tiles with the largest false-positive count.
df.nlargest(32, 'fp')

In [None]:
# How many of the worst-recall and worst-fn tiles come from each source image id.
print('worst recall; worst fn')
worst[1].id.value_counts(), worst_fn.id.value_counts()

In [None]:
# Totals of the fp and fn columns over all tiles.
df.fp.sum(), df.fn.sum(), 

In [None]:
# fn and fp totals per source image id, largest first.
df.groupby(['id']).fn.sum().sort_values(ascending=False), df.groupby(['id']).fp.sum().sort_values(ascending=False),

In [None]:
# Shapes of the sub-tables whose recall / precision falls below `thresh`;
# the first entry of each shape is the number of such tiles.
thresh = 0.4
df[df.recall < thresh].shape, df[df.precision < thresh].shape

In [None]:
# Identity affine transform handed to rasterio.open for the source TIFFs.
identity = rasterio.Affine(1, 0, 0, 0, 1, 0)

def reconstruct_img(tiff_id, tiffs, scale=10, typ='imgs'):
    """Paste downscaled tiles onto a gray canvas shaped like the source TIFF.

    Parameters
    ----------
    tiff_id : str
        Image id; ``{input_dir}/train/{tiff_id}.tiff`` is opened only to read
        its height and width.
    tiffs : iterable of str
        Tile file names of the form ``<id>-<row>-<column>.<ext>``, read from
        ``{errors_dir}/{typ}/``.
    scale : int
        Downscale factor applied to both the canvas and the tiles.
    typ : str
        Sub-directory of ``errors_dir`` the tiles are read from.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(height // scale, width // scale, 3)``;
        areas not covered by any tile keep the gray fill value.
    """
    # Only the raster dimensions are needed; the context manager makes sure
    # the dataset handle is closed again.
    with rasterio.open(os.path.join(input_dir, 'train', f'{tiff_id}.tiff'), transform=identity) as tif:
        height, width = tif.shape

    shape = (height // scale, width // scale, 3)

    gray = 122
    patched = np.full(shape, gray, dtype=np.uint8)
    print(tiff_id, shape, patched.shape)

    # Edge length of one tile on the downscaled canvas (same for every tile).
    sz = img_size//scale

    for f in tiffs:
        # Split from the right so an id that itself contains '-' still parses.
        _, row, column = os.path.splitext(f)[0].rsplit('-', 2)

        xstart = int(row)*sz
        ystart = int(column)*sz

        pt = os.path.join(errors_dir, typ, f)

        patch = np.array(pil.Image.open(pt))
        patch = cv2.resize(patch,(sz,sz))
        if patch.ndim == 2:
            # Single-channel tile: add a channel axis so it broadcasts over
            # the three canvas channels.
            patch = patch[:, :, None]

        # Tiles on the bottom/right border may reach past the canvas; crop the
        # tile to the part that fits instead of failing on a shape mismatch.
        target = patched[xstart:xstart+sz, ystart:ystart+sz]
        target[...] = patch[:target.shape[0], :target.shape[1]]

    return patched

def reconstruct_imgs(df, scale=10, typ='imgs'):
    """Build one reconstructed overview per image id found in ``df``.

    For every distinct value of ``df.id`` the file names of its rows are
    handed to ``reconstruct_img`` together with ``scale`` and ``typ``.

    Returns a dict mapping image id to the array ``reconstruct_img`` produced.
    """
    return {
        tiff_id: reconstruct_img(
            tiff_id,
            df.loc[df.id == tiff_id, 'filename'].tolist(),
            scale=scale,
            typ=typ,
        )
        for tiff_id in df.id.unique().tolist()
    }

In [None]:
# One overview per image id, downscaled by a factor of 5.
imgs = reconstruct_imgs(df,scale=5)

In [None]:
def save_result(imgs, dr='results'):
    """Write every array in ``imgs`` as a JPEG and pack them into ``{dr}.zip``.

    Parameters
    ----------
    imgs : dict
        Maps a name to an image array; each entry is stored as
        ``<name>.jpeg`` at the top level of the archive.
    dr : str
        Base name of the archive; the file ``{dr}.zip`` is created.

    The JPEGs are staged in a temporary directory that is removed even when
    saving fails, so a directory that happens to be called ``dr`` is neither
    required to be absent nor deleted, and the function can be re-run.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as staging:
        for k, v in imgs.items():
            image = pil.Image.fromarray(v)
            image.save(os.path.join(staging, f'{k}.jpeg'))

        # root_dir=staging stores the files without any leading directory.
        shutil.make_archive(dr, 'zip', staging)

In [None]:
# Archive the reconstructed overviews under the default name 'results'.
save_result(imgs)

In [None]:
# One large figure per reconstructed image, titled with its id.
for k, v in imgs.items():
    plt.figure(figsize = (20,20))
    plt.imshow(v)
    plt.title(k)
    
plt.show()