In [1]:
# !pip install /kaggle/input/save-smp/segmentation_models_pytorch/{segmentation_models_pytorch-0.3.3-py3-none-any.whl,pretrainedmodels-0.7.4-py3-none-any.whl,efficientnet_pytorch-0.7.1-py3-none-any.whl,timm-0.9.2-py3-none-any.whl,munch-4.0.0-py2.py3-none-any.whl}

In [1]:
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from glob import glob
import os
import logging
from tqdm import tqdm

import tifffile
import numpy as np
import pandas as pd

import albumentations as A
import albumentations.pytorch as AT
from torchvision import transforms
import gc

ModuleNotFoundError: No module named 'segmentation_models_pytorch'

In [3]:
class UnetUpscale(nn.Module):
    """U-Net that runs on a bilinearly enlarged copy of its input.

    The input is resized to ``upscale_factor`` times its height and width,
    passed through an ``smp.Unet``, and the network output is resized back
    down by the same factor (integer division of the output size).
    """

    def __init__(
        self,
        encoder_name,
        decoder_use_batchnorm,
        in_channels,
        classes,
        encoder_weights,
        upscale_factor,
    ):
        """Build the wrapped ``smp.Unet`` and remember the resize factor.

        All arguments except ``upscale_factor`` are forwarded unchanged to
        ``smp.Unet``.
        """
        super().__init__()
        self.upscale_factor = upscale_factor

        # The attribute name ``model`` is part of the state-dict key prefix.
        unet_kwargs = dict(
            encoder_name=encoder_name,
            decoder_use_batchnorm=decoder_use_batchnorm,
            in_channels=in_channels,
            classes=classes,
            encoder_weights=encoder_weights,
        )
        self.model = smp.Unet(**unet_kwargs)

    @staticmethod
    def _resize(tensor, height, width):
        """Bilinearly resize the last two dimensions to ``(height, width)``."""
        return torch.nn.functional.interpolate(tensor, (height, width), mode="bilinear")

    def forward(self, x):
        """Upscale ``x``, apply the U-Net, and downscale the result."""
        factor = self.upscale_factor

        in_height, in_width = x.shape[-2], x.shape[-1]
        enlarged = self._resize(x, in_height * factor, in_width * factor)

        output = self.model(enlarged)

        out_height, out_width = output.shape[-2], output.shape[-1]
        return self._resize(output, out_height // factor, out_width // factor)
    
    

In [4]:
class Dataset2DMultiPlanesTest(Dataset):
    """Sliding-window 2D crops of a 3D volume along the xy, xz and/or yz planes.

    Every item is one 2D crop (zero-padded up to ``crop_size`` x ``crop_size``
    and normalised by its own mean/std), together with the volume coordinates
    it was cut from and the padding that was added, so that a prediction can
    be un-padded and written back into the volume.

    Args:
        full_image: 3D array indexed as ``[depth, height, width]``.
        crop_size: Side length of the square crop handed to the model.
        overlap_size: Overlap between neighbouring crops;
            ``0 <= overlap_size < crop_size``.
        planes: Container holding any of ``'xy'``, ``'xz'``, ``'yz'``.
        transform: Optional callable invoked as ``transform(image=crop)`` and
            returning a mapping with key ``'image'``. When it is ``None`` (or
            returns a non-tensor) the 2D crop is converted to a ``(1, H, W)``
            tensor.

    Raises:
        ValueError: If ``overlap_size`` is negative or not smaller than
            ``crop_size``.
    """

    def __init__(
        self,
        full_image,
        crop_size,
        overlap_size,
        planes,
        transform=None,
    ):
        super().__init__()
        if not 0 <= overlap_size < crop_size:
            raise ValueError('overlap_size must satisfy 0 <= overlap_size < crop_size')

        step_size = crop_size - overlap_size
        self.crop_size = crop_size
        self.image = full_image

        self.depth, self.height, self.width = self.image.shape

        def window_starts(length):
            # Crops must keep starting until one of them reaches the end of
            # the axis. The previous stop (length - step_size) produced no
            # crop at all for axes not longer than step_size and left the
            # tail uncovered whenever overlap_size < step_size. For
            # overlap_size >= step_size the original tiling is unchanged.
            stop = max(length - min(step_size, overlap_size), 1)
            return range(0, stop, step_size)

        # calculate XY coordinates
        xy_coordinates = []
        if 'xy' in planes:
            for z in range(self.depth):
                for y in window_starts(self.height):
                    for x in window_starts(self.width):
                        crop_end_y = min(y + crop_size, self.height)
                        crop_end_x = min(x + crop_size, self.width)

                        xy_coordinates.append((z, z + 1, y, crop_end_y, x, crop_end_x))

        # calculate XZ coordinates
        xz_coordinates = []
        if 'xz' in planes:
            for z in window_starts(self.depth):
                for y in range(self.height):
                    for x in window_starts(self.width):
                        crop_end_z = min(z + crop_size, self.depth)
                        crop_end_x = min(x + crop_size, self.width)

                        xz_coordinates.append((z, crop_end_z, y, y + 1, x, crop_end_x))

        # calculate YZ coordinates
        yz_coordinates = []
        if 'yz' in planes:
            for z in window_starts(self.depth):
                for y in window_starts(self.height):
                    for x in range(self.width):
                        crop_end_z = min(z + crop_size, self.depth)
                        crop_end_y = min(y + crop_size, self.height)

                        yz_coordinates.append((z, crop_end_z, y, crop_end_y, x, x + 1))

        print(f'num xy slices: {len(xy_coordinates)} num xz slices: {len(xz_coordinates)} num yz slices: {len(yz_coordinates)}')
        self.coordinates = xy_coordinates + xz_coordinates + yz_coordinates
        print(f'total num of coordinates across 3 planes: {len(self.coordinates)}')

        # Axis of thickness 1 for each entry of ``self.coordinates``; used to
        # drop exactly that axis instead of every length-1 axis.
        self.squeeze_axes = (
            [0] * len(xy_coordinates)
            + [1] * len(xz_coordinates)
            + [2] * len(yz_coordinates)
        )

        self.transform = transform

    def __len__(self):
        """Number of crops over all requested planes."""
        return len(self.coordinates)

    def __getitem__(self, idx):
        """Return ``(crop, coordinates, paddings)`` for crop number ``idx``.

        ``coordinates`` is ``[z1, z2, y1, y2, x1, x2]`` and ``paddings`` is
        ``[height_pad_before, height_pad_after, width_pad_before,
        width_pad_after]``, both as integer tensors.
        """
        z1, z2, y1, y2, x1, x2 = self.coordinates[idx]

        # Remove only the slicing axis so a crop that happens to be one
        # voxel long in another direction stays two-dimensional.
        image_crop = np.squeeze(
            self.image[z1:z2, y1:y2, x1:x2], axis=self.squeeze_axes[idx]
        ).copy()

        height_pad_before = height_pad_after = width_pad_before = width_pad_after = 0
        if image_crop.shape[0] != self.crop_size:
            height_pad_size = self.crop_size - image_crop.shape[0]
            height_pad_before = height_pad_size // 2
            height_pad_after = height_pad_size - height_pad_before

        if image_crop.shape[1] != self.crop_size:
            width_pad_size = self.crop_size - image_crop.shape[1]
            width_pad_before = width_pad_size // 2
            width_pad_after = width_pad_size - width_pad_before

        image_crop = np.pad(
            image_crop,
            ((height_pad_before, height_pad_after), (width_pad_before, width_pad_after)),
            mode="constant",
            constant_values=0,
        )

        if self.transform:
            sample = self.transform(image=image_crop)
            image_crop = sample['image']

        # Without a tensor-producing transform the statistics below would
        # fail on a NumPy array; fall back to a channel-first tensor.
        # NOTE(review): assumes a non-tensor crop is still 2D (H, W).
        if not torch.is_tensor(image_crop):
            image_crop = torch.from_numpy(np.ascontiguousarray(image_crop)).unsqueeze(0)

        # Per-crop normalisation; the zero padding is part of the statistics.
        image_mean = torch.mean(image_crop.float())
        image_std = torch.std(image_crop.float())

        image_crop = (image_crop - image_mean) / (image_std + 1e-4)

        return (
            image_crop,
            torch.tensor([z1, z2, y1, y2, x1, x2]),
            torch.tensor([height_pad_before, height_pad_after, width_pad_before, width_pad_after]),
        )

In [5]:
def create_dataset(dataset_root):
    """Stack every ``*.tif`` slice found in ``dataset_root`` into one volume.

    Files are read in sorted path order; slice ``i`` of the result comes from
    the ``i``-th path. Each slice is divided by 256 and cast to ``uint8``
    (presumably the inputs are 16-bit images - confirm against the data).

    Args:
        dataset_root: Directory that directly contains the ``.tif`` slices.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` and shape
        ``(num_files, height, width)``.

    Raises:
        FileNotFoundError: If the directory contains no ``.tif`` file.
    """
    paths = sorted(glob(f'{dataset_root}/*.tif'))
    if not paths:
        # Fail with a clear message instead of an IndexError on paths[0].
        raise FileNotFoundError(f'No .tif files found in {dataset_root!r}')

    # Only the shape of the first slice is needed to allocate the volume.
    height, width = tifffile.memmap(paths[0], mode='r').shape

    full_image = np.zeros((len(paths), height, width), dtype=np.uint8)

    for path_index, path in enumerate(paths):
        full_image[path_index] = (tifffile.imread(path) / 256).astype(np.uint8)

    return full_image

In [6]:
def predict(
    batch_size,
    num_workers,
    dataset_params,
    model_name,
    model_params,
    test_kidney=None,
):
    """Run one checkpoint over a volume and return the averaged raw outputs.

    Args:
        batch_size: Batch size for the inference ``DataLoader``.
        num_workers: Worker count for the inference ``DataLoader``.
        dataset_params: Mapping with ``'crop_size'``, ``'overlap_size'`` and
            ``'planes'`` for ``Dataset2DMultiPlanesTest``.
        model_name: ``'unet'`` or ``'unet_upscale'``.
        model_params: Mapping with ``'encoder_name'``,
            ``'decoder_use_batchnorm'``, ``'checkpoint_path'`` and, for
            ``'unet_upscale'``, ``'upscale_factor'``.
        test_kidney: 3D volume passed to the dataset as ``full_image``.

    Returns:
        ``float16`` tensor shaped like the volume: the per-voxel sum of the
        model outputs of all crops covering a voxel, divided by the number of
        such crops. No sigmoid is applied here.

    Raises:
        ValueError: If ``model_name`` is not one of the two supported names.
    """
    if model_name not in ('unet', 'unet_upscale'):
        raise ValueError('Wrong model_name')

    # Both architectures share these constructor arguments.
    common_kwargs = dict(
        encoder_name=model_params['encoder_name'],
        decoder_use_batchnorm=model_params['decoder_use_batchnorm'],
        in_channels=1,
        classes=1,
        encoder_weights=None,
    )
    if model_name == 'unet':
        model = smp.Unet(**common_kwargs)
    else:
        model = UnetUpscale(upscale_factor=model_params['upscale_factor'], **common_kwargs)

    checkpoint = torch.load(model_params['checkpoint_path'], map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=True)
    model.cuda().eval()

    dataset = Dataset2DMultiPlanesTest(
        full_image=test_kidney,
        crop_size=dataset_params['crop_size'],
        overlap_size=dataset_params['overlap_size'],
        planes=dataset_params['planes'],
        transform=A.Compose([AT.ToTensorV2()]),
    )
    loader = DataLoader(
        dataset=dataset,
        shuffle=False,
        drop_last=False,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )

    volume_shape = (loader.dataset.depth, loader.dataset.height, loader.dataset.width)
    y_pred = torch.zeros(volume_shape, dtype=torch.float16)
    # NOTE(review): uint8 counter - it would wrap if a voxel were covered by
    # more than 255 crops.
    y_stats = torch.zeros(volume_shape, dtype=torch.uint8)

    for batch, batch_coordinates, batch_paddings in tqdm(loader):
        batch = batch.cuda()

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
            batch_preds = model(batch)

            for crop_coordinates, crop_paddings, crop_pred in zip(batch_coordinates, batch_paddings, batch_preds):
                z1, z2, y1, y2, x1, x2 = crop_coordinates.tolist()
                pad_top, pad_bottom, pad_left, pad_right = crop_paddings.tolist()

                # Strip the zero padding the dataset added around the crop.
                _, padded_height, padded_width = crop_pred.shape
                crop_pred = crop_pred[
                    :,
                    pad_top:padded_height - pad_bottom,
                    pad_left:padded_width - pad_right,
                ]

                region = (slice(z1, z2), slice(y1, y2), slice(x1, x2))
                region_shape = y_pred[region].shape

                # The view moves the length-1 axis to where the crop's plane
                # has thickness one in the volume.
                y_pred[region] += crop_pred.view(region_shape).cpu()
                y_stats[region] += 1

    # Average over the number of crops that contributed to each voxel.
    y_pred /= y_stats

    del model, y_stats
    gc.collect()

    return y_pred


def rle_encode(img):
    """Run-length encode a binary mask.

    Args:
        img: Array with 1 for mask and 0 for background; it is flattened
            before encoding.

    Returns:
        Space-separated string of ``start length`` pairs with 1-based starts,
        or an empty string when the mask has no foreground.
    """
    # Zero sentinels on both sides make every run begin and end with a change.
    padded = np.concatenate([[0], img.flatten(), [0]])
    changes = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Even slots hold run starts, odd slots run ends; turn ends into lengths.
    changes[1::2] = changes[1::2] - changes[::2]
    return ' '.join(map(str, changes))

In [7]:
# Cut-off applied to sigmoid(ensemble output): values above it count as mask.
th = 0.025

In [8]:
# Ensemble member 1: plain U-Net with a MaxViT encoder.
# The keys match the keyword arguments of ``predict``.
model1 = dict(
    batch_size=2,
    num_workers=2,
    dataset_params=dict(
        crop_size=512,
        overlap_size=256,
        planes=['xy', 'xz', 'yz'],
    ),
    model_name='unet',
    model_params=dict(
        encoder_name='tu-maxvit_base_tf_512.in21k_ft_in1k',
        decoder_use_batchnorm=False,
        checkpoint_path='./weights/maxvit_base.pt/epoch_33_surface_dice_at_mean_0.8023.pt',
    ),
)

In [9]:
# Ensemble member 2: upscaling U-Net with an EfficientNetV2 encoder.
# NOTE(review): the encoder name says "v2_s" while the checkpoint folder says
# "v2_m" - confirm the folder name is only a label.
model2 = dict(
    batch_size=4,
    num_workers=2,
    dataset_params=dict(
        crop_size=512,
        overlap_size=256,
        planes=['xy', 'xz', 'yz'],
    ),
    model_name='unet_upscale',
    model_params=dict(
        encoder_name='tu-tf_efficientnetv2_s.in21k_ft_in1k',
        decoder_use_batchnorm=False,
        checkpoint_path='./weights/effnet_v2_m.pt/epoch_23_surface_dice_at_mean_0.8133.pt',
        upscale_factor=2,
    ),
)

In [10]:
# Ensemble member 3: upscaling U-Net with a DPN-68b encoder.
# The keys match the keyword arguments of ``predict``.
model3 = dict(
    batch_size=2,
    num_workers=2,
    dataset_params=dict(
        crop_size=512,
        overlap_size=256,
        planes=['xy', 'xz', 'yz'],
    ),
    model_name='unet_upscale',
    model_params=dict(
        encoder_name='tu-dpn68b',
        decoder_use_batchnorm=False,
        checkpoint_path='./weights/dpn_68.pt/epoch_38_surface_dice_at_mean_0.80779.pt',
        upscale_factor=2,
    ),
)

In [11]:
# Ensemble members; each dict is unpacked into predict(**model) and the
# resulting predictions are averaged with equal weight.
models = [model1, model2, model3]

In [12]:
# Submission rows, filled slice by slice for every test kidney.
ids, rles = [], []

for test_kidney in [6,5]:
    # Sorted slice paths; create_dataset sorts the same directory listing, so
    # volume slice i corresponds to images_paths[i].
    images_paths = sorted(glob(f'/teradata/hra_data/k4_data/competition-data/test/kidney_{test_kidney}/images/*.tif')) 
    test_kidney_image = create_dataset(
        dataset_root=f'/teradata/hra_data/k4_data/competition-data/test/kidney_{test_kidney}/images/'
    )

    if test_kidney == 6:
        # kidney_6 is resampled by private_res / public_res before inference.
        # NOTE(review): presumably these are voxel resolutions of the private
        # and public test data - confirm units and source.
        private_res = 63.08
        public_res = 50.0

        scale = private_res / public_res

        # Remember the native size so predictions can be resampled back to it.
        d_original, h_original, w_original = test_kidney_image.shape
        test_kidney_image = torch.tensor(test_kidney_image).view(1, 1, d_original, h_original, w_original)
        test_kidney_image = test_kidney_image.to(dtype=torch.float32)
        # From here on the volume is a float32 array instead of uint8.
        test_kidney_image = torch.nn.functional.interpolate(test_kidney_image, (
            int(d_original*scale),
            int(h_original*scale),
            int(w_original*scale),
        ), mode='trilinear').squeeze().numpy()

    for model_index, model in enumerate(models):
        preds = predict(
            **model,
            test_kidney=test_kidney_image,
        )

        # The first prediction becomes the accumulator; later ones are added
        # to it in place.
        if model_index == 0:
            preds_ensemble = preds
        else:
            preds_ensemble += preds

        # Drops only the name; for the first model the tensor stays alive as
        # preds_ensemble.
        del preds
        gc.collect()

    del test_kidney_image
    gc.collect()

    # Equal-weight average of the per-model outputs.
    preds_ensemble /= len(models)
    if test_kidney == 6:
        # Resample the averaged outputs back to the native kidney_6 size.
        d_preds, h_preds, w_preds = preds_ensemble.shape 
        preds_ensemble = preds_ensemble.view(1, 1, d_preds, h_preds, w_preds)
        preds_ensemble = preds_ensemble.to(dtype=torch.float32)

        preds_ensemble = torch.nn.functional.interpolate(preds_ensemble, (
            d_original,
            h_original,
            w_original,
        ), mode='trilinear').squeeze()

    # Sigmoid is computed on the GPU; the boolean mask lives on the CPU.
    preds_ensemble_th = torch.sigmoid(preds_ensemble.cuda()).cpu() > th
    for pred_index, pred in enumerate(preds_ensemble_th):
        # id = "kidney_<n>_<file name up to its first dot>".
        ids.append(f'kidney_{test_kidney}_{images_paths[pred_index].split("/")[-1].split(".")[0]}')
        rle = rle_encode(pred)
        # An empty mask is written as the placeholder '1 0'.
        if rle == '':
            rle = '1 0'
        rles.append(rle)

    del preds_ensemble, preds_ensemble_th
    gc.collect()


# One row per slice: its id and its run-length encoded mask.
submission = pd.DataFrame({
    'id': ids,
    'rle': rles,
})

num xy slices: 18960 num xz slices: 15552 num yz slices: 15850
total num of coordinates across 3 planes: 50362


100%|██████████| 25181/25181 [25:47<00:00, 16.27it/s]


num xy slices: 18960 num xz slices: 15552 num yz slices: 15850
total num of coordinates across 3 planes: 50362


100%|██████████| 12591/12591 [12:58<00:00, 16.17it/s]


num xy slices: 18960 num xz slices: 15552 num yz slices: 15850
total num of coordinates across 3 planes: 50362


100%|██████████| 25181/25181 [14:43<00:00, 28.51it/s]


num xy slices: 30360 num xz slices: 23940 num yz slices: 23970
total num of coordinates across 3 planes: 78270


100%|██████████| 39135/39135 [39:16<00:00, 16.61it/s]


num xy slices: 30360 num xz slices: 23940 num yz slices: 23970
total num of coordinates across 3 planes: 78270


100%|██████████| 19568/19568 [20:02<00:00, 16.28it/s]


num xy slices: 30360 num xz slices: 23940 num yz slices: 23970
total num of coordinates across 3 planes: 78270


100%|██████████| 39135/39135 [22:29<00:00, 29.00it/s]


In [13]:
# Write the predictions to disk; the metric cell reads this file back.
submission.to_csv('submission-validation.csv', index=False)

In [16]:
# !pip install numba

Collecting numba
  Using cached numba-0.59.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.7 kB)
Collecting llvmlite<0.43,>=0.42.0dev0 (from numba)
  Using cached llvmlite-0.42.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.8 kB)
Using cached numba-0.59.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.7 MB)
Using cached llvmlite-0.42.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (43.8 MB)
Installing collected packages: llvmlite, numba
Successfully installed llvmlite-0.42.0 numba-0.59.0


In [17]:
import sys,os 
sys.path.append(f'{os.getcwd()}/sennet-metrics')
sys.path.append(f'{os.getcwd()}/sennet-metrics/src')

from sennet_metrices import *

In [41]:
# Compute competition metric.

submit_df = pd.read_csv('submission-validation.csv')
label_df = pd.read_csv('/teradata/hra_data/k4_data/competition-data/solution.csv')


def _rows_for_kidney(frame, kidney_tag):
    """Return the rows of ``frame`` whose ``id`` column contains ``kidney_tag``."""
    return frame[frame['id'].str.contains(kidney_tag)]


# Split predictions and labels by kidney using the "kidney_5" / "kidney_6"
# substring of the id column.
kidney_5_submit_df = _rows_for_kidney(submit_df, 'kidney_5')
kidney_6_submit_df = _rows_for_kidney(submit_df, 'kidney_6')
print(f'kidney_5_submit_df shape: {kidney_5_submit_df.shape}')
print(f'kidney_6_submit_df shape: {kidney_6_submit_df.shape}')

kidney_5_label_df = _rows_for_kidney(label_df, 'kidney_5')
kidney_6_label_df = _rows_for_kidney(label_df, 'kidney_6')
print(f'kidney_5_label_df shape: {kidney_5_label_df.shape}')
print(f'kidney_6_label_df shape: {kidney_6_label_df.shape}')

# Renumber rows from zero; the previous index is kept as an "index" column,
# exactly as the in-place reset did.
kidney_5_submit_df = kidney_5_submit_df.reset_index()
kidney_6_submit_df = kidney_6_submit_df.reset_index()
kidney_5_label_df = kidney_5_label_df.reset_index()
kidney_6_label_df = kidney_6_label_df.reset_index()

## -------------- Surface Dice --------------
surface_dice_kidney_5 = compute_surface_dice_score(kidney_5_submit_df, kidney_5_label_df)
print(f'Surface dice for public test (kidney_5) set is: {surface_dice_kidney_5}')

surface_dice_kidney_6 = compute_surface_dice_score(kidney_6_submit_df, kidney_6_label_df)
print(f'Surface dice for private test (kidney_6) set is: {surface_dice_kidney_6}')

kidney_5_submit_df shape: (1012, 2)
kidney_6_submit_df shape: (501, 2)
kidney_5_label_df shape: (1012, 7)
kidney_6_label_df shape: (501, 7)
Surface dice for public test (kidney_5) set is: 0.8855258822441101
Surface dice for private test (kidney_6) set is: 0.691794216632843
