In [None]:
import warnings
warnings.filterwarnings('ignore')
import os
import gc
import cv2
import sys
import json
import time
import pickle
import shutil
import numpy as np
import pandas as pd 
import tifffile as tiff
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow.keras.backend as K
from tensorflow.keras import Model, Sequential
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import Sequence
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import *
import segmentation_models as sm
from segmentation_models import Unet
from tqdm import tqdm
# Report the TensorFlow build this notebook is running against.
print('tensorflow version:', tf.__version__)
# Restrict CUDA to device index 0 for this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
# Looping over an empty list is a no-op, so no emptiness guard is required.
for gpu_device in gpu_devices:
    print('device available:', gpu_device)
# Show every column when a DataFrame is displayed.
pd.set_option('display.max_columns', None)

In [None]:
# --- Run configuration -------------------------------------------------------
# TEST   : predict on the test images (True) or on the train images (False).
# KAGGLE : switch between Kaggle dataset paths and local paths.
# VER    : version tag used to locate the trained models directory.
TEST = False
KAGGLE = False
VER = 'v2'
if KAGGLE:
    DATA_PATH = '../input/hubmap-kidney-segmentation'
    MDLS_PATH = f'../input/kidney-models-{VER}'
else:
    DATA_PATH = './data'
    MDLS_PATH = f'./models_{VER}'
# Probability cut-off applied to the averaged fold predictions.
THRESHOLD = .5
# Base prediction batch size; it is divided by the TTA scale before use.
PRED_BATCH_SIZE = 256
# Tile-size multipliers used as test-time augmentations.
TTAS = [1.25, 1.5, 2]
SUB_PATH = f'{DATA_PATH}/test' if TEST else f'{DATA_PATH}/train'
CACHE_PATH = './cache_kidney'
# exist_ok avoids the check-then-create race and makes the cell safely re-runnable.
os.makedirs(CACHE_PATH, exist_ok=True)

# Reference point for the "time elapsed" reports printed by later cells.
start_time = time.time()

In [None]:
def enc2mask(encs, shape):
    """Decode run-length encodings into a single labelled mask.

    Args:
        encs: iterable of RLE strings ("start length start length ...", with
            1-based starts). A float NaN entry is treated as "no mask" and
            skipped.
        shape: two-element sequence; the flat buffer has
            ``shape[0] * shape[1]`` pixels and is reshaped to ``shape`` before
            being transposed.

    Returns:
        ``np.uint8`` array of shape ``(shape[1], shape[0])`` where the pixels
        of the k-th encoding (0-based) hold the value ``k + 1``.
    """
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for label, enc in enumerate(encs, start=1):
        # `np.float` was removed from NumPy (1.24); test against the builtin
        # float and NumPy floating scalars instead.
        if isinstance(enc, (float, np.floating)) and np.isnan(enc):
            continue
        tokens = enc.split()
        # Tokens come in (start, length) pairs; a dangling token is ignored.
        for start, length in zip(tokens[0::2], tokens[1::2]):
            begin = int(start) - 1  # RLE starts are 1-based
            flat[begin : begin + int(length)] = label
    return flat.reshape(shape).T

def rle_encode_less_memory(img):
    """Run-length encode a mask into a "start length ..." string.

    The mask is transposed and flattened into a copy, so the caller's array is
    left untouched. The first and last flattened pixels are forced to zero
    before the value changes are located. Starts are 1-based.
    """
    flat = img.T.flatten()
    flat[0], flat[-1] = 0, 0
    # Index (1-based, +1) of every position whose value differs from its predecessor.
    edges = np.flatnonzero(flat[1:] != flat[:-1]) + 2
    # Turn each (start, end) pair into (start, length), in place.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

In [None]:
def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient computed over all flattened elements.

    Returns ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)`` where
    ``t`` and ``p`` are the flattened ``y_true`` and ``y_pred``.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    numerator = 2 * overlap + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return numerator / denominator

def dice_loss(y_true, y_pred, smooth=1):
    """Dice loss: one minus the smoothed Dice coefficient."""
    coef = dice_coef(y_true, y_pred, smooth)
    return 1 - coef

def bce_dice_loss(y_true, y_pred):
    """Sum of binary cross-entropy and Dice loss (Dice uses its default smoothing)."""
    bce = binary_crossentropy(y_true, y_pred)
    dice = dice_loss(y_true, y_pred)
    return bce + dice

def get_model(backbone, path, input_shape, classes=1, learning_rate=.001):
    """Build and compile a U-Net with an EfficientNet encoder.

    Args:
        backbone: one of 'efficientnetb0', 'efficientnetb1', 'efficientnetb2'.
        path: directory holding the matching
            ``efficientnet-bN_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5``
            encoder weights file.
        input_shape: input shape passed to ``Unet``.
        classes: number of output channels.
        learning_rate: learning rate for the Adam optimizer.

    Returns:
        The model compiled with ``bce_dice_loss`` and the ``dice_coef`` metric.

    Raises:
        AttributeError: if ``backbone`` is not supported. The exception type
            is kept for backward compatibility with existing callers.
    """
    variants = {
        'efficientnetb0': 'b0',
        'efficientnetb1': 'b1',
        'efficientnetb2': 'b2',
    }
    if backbone not in variants:
        raise AttributeError(
            f'unsupported backbone {backbone!r}; expected one of {sorted(variants)}'
        )
    weights = (
        f'{path}/efficientnet-{variants[backbone]}'
        '_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5'
    )
    # `lr` is a deprecated alias that newer Keras releases reject;
    # `learning_rate` is the supported keyword.
    optimizer = Adam(learning_rate=learning_rate)
    model = Unet(backbone_name=backbone,
                 input_shape=input_shape,
                 classes=classes,
                 activation='sigmoid',
                 encoder_weights=weights)
    model.compile(optimizer=optimizer,
                  loss=bce_dice_loss,
                  metrics=[dice_coef])
    return model

In [None]:
def get_tiles(img_name, path, params, mode=1):
    """Read a TIFF image, pad and downscale it, and cut it into square tiles.

    Args:
        img_name: file name without the '.tiff' extension.
        path: directory containing the image.
        params: dict providing 'img_size' (base tile side) and 'resize'
            (downscale factor).
        mode: multiplier applied to 'img_size' to get the tile side.

    Returns:
        Tuple ``(tiles, shape, pad0, pad1)``: the tiles as an array of shape
        ``(-1, tile_size, tile_size, 3)``, the image shape before padding, and
        the total padding added along axis 0 and axis 1.
    """
    img = tiff.imread(os.path.join(path, img_name + '.tiff'))
    if len(img.shape) == 5:
        # Drop singleton axes and move the leading axis last
        # (presumably channels-first data -- confirm against the dataset).
        img = np.transpose(img.squeeze(), (1, 2, 0))
    print(img_name, 'read:', img.shape)
    shape = img.shape

    resize = params['resize']
    tile_size = int(params['img_size'] * mode)
    # One tile covers `resize * tile_size` source pixels per side; pad each
    # axis up to the next multiple of that.
    block = resize * tile_size
    pad0 = (block - shape[0] % block) % block
    pad1 = (block - shape[1] % block) % block
    top, left = pad0 // 2, pad1 // 2
    pad_widths = [[top, pad0 - top], [left, pad1 - left], [0, 0]]
    img = np.pad(img, pad_widths, constant_values=0)
    print(img_name, 'padded:', img.shape)

    # cv2.resize takes the target size as (width, height).
    target_size = (img.shape[1] // resize, img.shape[0] // resize)
    img = cv2.resize(img, target_size, interpolation=cv2.INTER_AREA)
    n_rows = img.shape[0] // tile_size
    n_cols = img.shape[1] // tile_size
    img = img.reshape(n_rows, tile_size, n_cols, tile_size, 3)
    print(img_name, 'resized and reshaped:', img.shape)
    # Bring the two grid axes together, then flatten the grid into a batch axis.
    img = img.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, 3)
    print(img_name, 'finally reshaped:', img.shape)
    return img, shape, pad0, pad1

def restore_mask(msk, shape, pad0, pad1, params, mode=1):
    """Reassemble tile predictions into a binary mask of the original image size.

    Args:
        msk: per-tile predictions; singleton axes are squeezed away first.
        shape: image shape before padding (only the first two entries are used).
        pad0, pad1: total padding that was added along axis 0 and axis 1.
        params: dict providing 'img_size' and 'resize'.
        mode: multiplier applied to 'img_size' to get the tile side.

    Returns:
        ``np.int8`` mask thresholded with the module-level ``THRESHOLD``,
        upscaled by 'resize' and cropped to remove the padding.
    """
    tile_size = int(params['img_size'] * mode)
    resize = params['resize']
    padded_h = shape[0] + pad0
    padded_w = shape[1] + pad1

    msk = np.squeeze(msk)
    print('mask init:', msk.shape)
    # (tiles,) -> (grid rows, grid cols, tile, tile)
    grid_rows = int(padded_h / tile_size / resize)
    grid_cols = int(padded_w / tile_size / resize)
    msk = msk.reshape(grid_rows, grid_cols, tile_size, tile_size)
    msk = msk.transpose(0, 2, 1, 3)
    print('mask reshaped and transposed:', msk.shape)
    msk = msk.reshape(int(padded_h / resize), int(padded_w / resize))
    print('mask reshaped:', msk.shape)

    msk = (msk > THRESHOLD).astype(np.int8)
    # cv2.resize takes the target size as (width, height).
    full_size = (msk.shape[1] * resize, msk.shape[0] * resize)
    msk = cv2.resize(msk, full_size, interpolation=cv2.INTER_NEAREST)
    print('mask resized:', msk.shape)

    # A negative stop strips the trailing padding; with no padding the stop is
    # the full extent of the axis.
    row_stop = -(pad0 - pad0 // 2) if pad0 > 0 else (shape[0] + pad0)
    col_stop = -(pad1 - pad1 // 2) if pad1 > 0 else (shape[1] + pad1)
    msk = msk[pad0 // 2 : row_stop, pad1 // 2 : col_stop]
    print('mask finally un-padded:', msk.shape)
    return msk

In [None]:
# Load the parameters stored next to the trained models.
with open(f'{MDLS_PATH}/params.json') as file:
    params = json.load(file)
print('loaded params:', params)

# Keep only names that END with '.tiff' and strip exactly that suffix; a
# substring test with replace() would mangle names containing '.tiff' elsewhere.
imgs_idxs = [x[:-len('.tiff')] for x in os.listdir(SUB_PATH) if x.endswith('.tiff')]
print('images idxs:', imgs_idxs)

# The prediction cell removes the cache directory when it finishes, so make
# sure it exists whenever this cell is (re-)run.
os.makedirs(CACHE_PATH, exist_ok=True)

# Tile every image once per TTA scale and pickle the result, so the prediction
# cell can load tiles one image at a time.
for i_mode, mode in enumerate(TTAS):
    for img_idx in imgs_idxs:
        print('-' * 20, img_idx, '-' * 20)
        img, shape, pad0, pad1 = get_tiles(img_idx, SUB_PATH, params, mode)
        cache_dict = {'img': img, 'shape': shape, 'pad0': pad0, 'pad1': pad1}
        file_name = f'{CACHE_PATH}/{img_idx}_m{i_mode}.pickle'
        with open(file_name, 'wb') as handle:
            pickle.dump(cache_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        # cache_dict also references the tile array, so it must be dropped too
        # for the memory to actually be released.
        del img, shape, pad0, pad1, cache_dict; gc.collect()
        print(img_idx, 'done | saved to', file_name)

elapsed_time = time.time() - start_time
print(f'time elapsed: {elapsed_time // 60:.0f} min {elapsed_time % 60:.0f} sec')

In [None]:
# For every TTA scale, sub_list receives three parallel lists in this order:
# image ids, RLE-encoded predictions, original image shapes.
sub_list = []
for i_mode, mode in enumerate(TTAS):
    print('=' * 20, 'tta:', i_mode, '=' * 20)
    img_indexes = []
    img_rles = []
    img_shapes = []
    # Larger tiles get a proportionally smaller batch.
    batch = int(PRED_BATCH_SIZE / mode)
    tile_size = int(params['img_size'] * mode)
    for img_idx in imgs_idxs:
        print('-' * 20, img_idx, '-' * 20)
        file_name = f'{CACHE_PATH}/{img_idx}_m{i_mode}.pickle'
        # NOTE: pickle executes arbitrary code on load; only read cache files
        # written by this notebook.
        with open(file_name, 'rb') as handle:
            cache_dict = pickle.load(handle)
        img, shape = cache_dict['img'], cache_dict['shape']
        pad0, pad1 = cache_dict['pad0'], cache_dict['pad1']
        del cache_dict
        msk_pred = np.zeros((len(img), img.shape[1], img.shape[2], 1))
        folds = list(range(params['folds']))
        for n_fold in folds:
            checkpoint_path = f'{MDLS_PATH}/model_{n_fold}.hdf5'
            model = get_model(
                params['backbone'],
                MDLS_PATH,
                input_shape=(tile_size, tile_size, 3)
            )
            model.load_weights(checkpoint_path)
            print('loaded:', checkpoint_path, end=' ')
            msk_pred_fold = np.zeros((len(img), img.shape[1], img.shape[2], 1))
            for start in range(0, len(img), batch):
                stop = min(start + batch, len(img))
                imgs_batch = np.array([
                    cv2.cvtColor(img[j, ], cv2.COLOR_BGR2RGB) / 255
                    for j in range(start, stop)
                ])
                msk_pred_fold[start:stop, ] = model.predict(imgs_batch)
                # Released inside the loop so an image with no tiles cannot
                # trigger a NameError on an unbound name.
                del imgs_batch
            print('done')
            # Average the folds; in-place add avoids another full-size temporary.
            msk_pred += msk_pred_fold / len(folds)
            del model, msk_pred_fold
            # A model is built per image/fold/scale; without clearing the Keras
            # session its global state keeps growing across iterations.
            K.clear_session()
            gc.collect()
        del img; gc.collect()
        msk_pred = restore_mask(msk_pred, shape, pad0, pad1, params, mode)
        rle = rle_encode_less_memory(msk_pred)
        img_indexes.append(img_idx)
        img_rles.append(rle)
        img_shapes.append(shape)
        del msk_pred, rle; gc.collect()
        print(img_idx, 'tta', i_mode, 'done')
    sub_list.append(img_indexes)
    sub_list.append(img_rles)
    sub_list.append(img_shapes)
# The tile cache is no longer needed once every scale has been predicted.
shutil.rmtree(CACHE_PATH)

elapsed_time = time.time() - start_time
print(f'time elapsed: {elapsed_time // 60:.0f} min {elapsed_time % 60:.0f} sec')

In [None]:
# Combine the per-scale predictions of each image by pixel-wise voting.
# A pixel is kept when at least `len(TTAS) - 1` scales marked it, but never
# fewer than one vote: with a single scale the raw threshold would be 0 and
# every pixel would be reported as positive.
tta_masks = []
n_ttas = len(TTAS)
min_votes = max(1, n_ttas - 1)
# Each row holds, per TTA scale, three consecutive columns: id, RLE, shape.
for _, row in pd.DataFrame(sub_list).T.iterrows():
    print('processing', row[0], end=' ')
    shape = [int(row[2][0]), int(row[2][1])]
    mask = np.zeros(shape, dtype=np.int8)
    for i_tta in range(n_ttas):
        # enc2mask transposes its result, hence the swapped shape argument.
        mask = mask + enc2mask([row[i_tta * 3 + 1]], (shape[1], shape[0]))
    tta_masks.append(rle_encode_less_memory((mask >= min_votes).astype(np.int8)))
    del mask; gc.collect()
    print('done')

elapsed_time = time.time() - start_time
print(f'time elapsed: {elapsed_time // 60:.0f} min {elapsed_time % 60:.0f} sec')

In [None]:
# Build the submission table straight from the image ids (first list in
# sub_list) and the merged TTA masks. sub_list itself is left unmodified, so
# re-running this cell always yields the same frame.
df_sub = pd.DataFrame({'id': sub_list[0], 'predicted': tta_masks})
df_sub

In [None]:
# Visual sanity check on train data: overlay the predicted mask and the
# reference mask from train.csv on the first submitted image.
if not TEST:
    df_masks = pd.read_csv(f'{DATA_PATH}/train.csv').set_index('id')
    first_row = df_sub.iloc[0]
    idx = first_row.id
    img = tiff.imread(os.path.join(SUB_PATH, idx + '.tiff'))
    if len(img.shape) == 5:
        img = np.transpose(img.squeeze(), (1, 2, 0))
    # enc2mask transposes its result, so it is given (width, height).
    wh = (img.shape[1], img.shape[0])
    msk_p = enc2mask([first_row.predicted], wh)
    msk = enc2mask([df_masks.loc[idx, 'encoding']], wh)
    for arr in (img, msk_p, msk):
        print(arr.shape)

    plt.figure(figsize=(16, 16))
    plt.imshow(img)
    plt.imshow(msk_p, alpha=.4)
    plt.imshow(msk, alpha=.2)
    plt.title(idx)
    plt.show()

In [None]:
# Write the id/predicted table to submission.csv, omitting the index column.
df_sub.to_csv('submission.csv', index=False)