In [None]:
import warnings
warnings.filterwarnings('ignore')
import os
import gc
import cv2
import json
import time
import random
import numpy as np
import pandas as pd 
import tifffile as tiff
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import albumentations as albu
from sklearn.model_selection import train_test_split, KFold, GroupKFold, StratifiedKFold
import tensorflow.keras.backend as K
from tensorflow.keras import Model, Sequential
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import Sequence
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import *
from tqdm import tqdm
import segmentation_models as sm
from segmentation_models import Unet, FPN, Linknet
from segmentation_models.losses import bce_jaccard_loss
# Report the TensorFlow version and every GPU device TensorFlow lists.
print('tensorflow version:', tf.__version__)
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if gpu_devices:
    for gpu_device in gpu_devices:
        print('device available:', gpu_device)
# Mixed precision is currently disabled: the two lines below are commented out.
#policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
#tf.keras.mixed_precision.experimental.set_policy(policy)
# Do not truncate columns when a DataFrame is displayed.
pd.set_option('display.max_columns', None)

In [None]:
VER = 'v61'
# Experiment configuration. Options marked "truthy" are tested with a plain
# `if`, so disable them with None/False, not with the strings 'None'/'False'.
PARAMS = {
    'version': VER,
    'folds': 4,
    'img_size': 256,
    'resize': 4,
    'batch_size': 20,
    'epochs': 1000,
    'patience': 30,
    'decay': False,  # True: warmup + exponential decay schedule; False: ReduceLROnPlateau
    'backbone': 'efficientnetb1',  # efficientnetbX, resnet34/50, resnext50, seresnet34, seresnext50
    'bce_weight': 1,
    'loss': 'bce_jaccard_loss',  # 'bce_jaccard_loss' or 'bce_dice'
    'seed': 42,
    'split': 'strat',  # 'kfold', 'group' or 'strat'
    'mirror': False,
    'aughard': True,
    'umodel': 'link',  # 'unet', 'fpn' or 'link'
    'pseudo': 'v55',  # truthy: tag 'vXX' appended to the data directory name; None/False to omit
    'lr': .0005,
    'shift': True,
    'external': 'ext',  # truthy: tag appended to the data directory name; None to omit
    'downsample': None,  # None keeps every empty-mask tile; a float p keeps each one with probability p
    'comments': 'new data'
}
DATA_PATH = './data2'
resize = PARAMS['resize']
size = PARAMS['img_size']
ext = PARAMS['external']
pseudo = PARAMS['pseudo']

# Each enabled option contributes one tag to the directory name, always in the
# order shift -> external -> pseudo.
dir_suffix = ''
if PARAMS['shift']:
    dir_suffix += '_shft'
if ext:
    dir_suffix += f'_{ext}'
if pseudo:
    dir_suffix += f'_{pseudo}'
IMGS_PATH = f'{DATA_PATH}/tiles_r{resize}_s{size}{dir_suffix}/'
MSKS_PATH = f'{DATA_PATH}/masks_r{resize}_s{size}{dir_suffix}/'
# Per-version output directory; checkpoints, histories and params.json are written here.
MDLS_PATH = f'./models_{VER}'
# exist_ok avoids the check-then-create race and makes the cell safely re-runnable.
os.makedirs(MDLS_PATH, exist_ok=True)
# Persist the configuration next to the models it produces.
with open(f'{MDLS_PATH}/params.json', 'w') as file:
    json.dump(PARAMS, file)
if PARAMS['mirror']:
    # Multi-GPU: replicate the model across the visible devices.
    STRATEGY = tf.distribute.MirroredStrategy()
else:
    # NOTE(review): TensorFlow has already listed the GPUs in the first cell, so
    # this variable may be set too late to restrict device visibility - confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    STRATEGY = tf.distribute.get_strategy()
    
def seed_all(seed):
    """Seed the random number generators used by this notebook.

    Exports ``PYTHONHASHSEED`` to the environment and seeds Python's
    ``random`` module, NumPy's global generator and TensorFlow's global
    generator with the same value.

    Args:
        seed: value passed to every seeding call (stringified for the
            environment variable).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)

# Seed all generators from the configured seed before any data is sampled.
seed_all(PARAMS['seed'])
# Wall-clock reference; the elapsed time is printed after the cross-validation loop.
start_time = time.time()

In [None]:
# Table read from train.csv, indexed by its 'id' column; rows are decoded with enc2mask below.
df_masks = pd.read_csv(f'{DATA_PATH}/train.csv').set_index('id')
df_masks

# Utils

In [None]:
def enc2mask(encs, shape):
    """Decode run-length encodings into a 2-D label mask.

    Args:
        encs: iterable of RLE strings of the form
            ``"start length start length ..."`` with 1-based starts.
            Float NaN entries are skipped.
        shape: 2-tuple; a flat buffer of ``shape[0] * shape[1]`` pixels is
            filled, reshaped to ``shape`` and transposed.

    Returns:
        ``np.uint8`` array of shape ``(shape[1], shape[0])`` in which the
        pixels of the m-th encoding hold ``m + 1`` and all others hold 0.
    """
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for m, enc in enumerate(encs):
        # `np.float` was removed in NumPy 1.24; test against the builtin float
        # (np.float64 subclasses it) and NumPy's floating scalars instead.
        if isinstance(enc, (float, np.floating)) and np.isnan(enc):
            continue
        tokens = enc.split()
        # Tokens alternate start, length; a trailing unpaired token is ignored.
        for start, length in zip(tokens[0::2], tokens[1::2]):
            begin = int(start) - 1  # RLE starts are 1-based
            img[begin : begin + int(length)] = 1 + m
    return img.reshape(shape).T

def show_img_n_mask(df, img_num, resize):
    """Plot one downscaled training image with its decoded mask overlaid.

    Args:
        df: DataFrame whose index holds image ids and whose rows hold the
            encodings handed to ``enc2mask``.
        img_num: positional row number of the image to show.
        resize: integer divisor applied to both image dimensions.
    """
    tiff_path = os.path.join(f'{DATA_PATH}/train', df.index[img_num] + '.tiff')
    img = tiff.imread(tiff_path)
    if img.ndim == 5:
        # 5-D files are squeezed and reordered with axes (1, 2, 0) -
        # presumably channel-first to channel-last; verify against the data.
        img = np.transpose(img.squeeze(), (1, 2, 0))
    height, width = img.shape[0], img.shape[1]
    mask = enc2mask(df.iloc[img_num], (width, height))
    print(img.shape, mask.shape)

    # cv2.resize takes (width, height); the mask uses nearest-neighbour so
    # label values are not blended.
    img = cv2.resize(img, (width // resize, height // resize),
                     interpolation=cv2.INTER_AREA)
    mask_h, mask_w = mask.shape[0], mask.shape[1]
    mask = cv2.resize(mask, (mask_w // resize, mask_h // resize),
                      interpolation=cv2.INTER_NEAREST)

    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(img)
    plt.imshow(mask, alpha=.4)
    plt.show()

In [None]:
# Visual sanity check: row 4 of the table, downscaled by the configured resize factor.
show_img_n_mask(df=df_masks, img_num=4, resize=PARAMS['resize'])

In [None]:
# Both pipelines start with the same colour jitter: one of brightness,
# contrast or gamma, applied half of the time.
aug_steps = [
    albu.OneOf([
        albu.RandomBrightness(limit=.2, p=1),
        albu.RandomContrast(limit=.2, p=1),
        albu.RandomGamma(p=1)
    ], p=.5)
]
if PARAMS['aughard']:
    # Heavy pipeline: blur, noise/affine, geometric distortion, flips,
    # cutout holes sized at 10% of the tile side, and shift/scale/rotate.
    hole_side = int(.1 * size)
    aug_steps += [
        albu.OneOf([
            albu.Blur(blur_limit=3, p=1),
            albu.MedianBlur(blur_limit=3, p=1)
        ], p=.25),
        albu.OneOf([
            albu.GaussNoise(0.002, p=.5),
            albu.IAAAffine(p=.5),
        ], p=.25),
        albu.OneOf([
            albu.ElasticTransform(alpha=120, sigma=120 * .05, alpha_affine=120 * .03, p=.5),
            albu.GridDistortion(p=.5),
            albu.OpticalDistortion(distort_limit=2, shift_limit=.5, p=1)
        ], p=.25),
        albu.RandomRotate90(p=.5),
        albu.HorizontalFlip(p=.5),
        albu.VerticalFlip(p=.5),
        albu.Cutout(num_holes=10,
                    max_h_size=hole_side, max_w_size=hole_side,
                    p=.25),
        albu.ShiftScaleRotate(p=.5)
    ]
else:
    # Light pipeline: colour jitter plus occasional 90-degree rotations and flips.
    aug_steps += [
        albu.RandomRotate90(p=.25),
        albu.HorizontalFlip(p=.25),
        albu.VerticalFlip(p=.25)
    ]
aug = albu.Compose(aug_steps)

In [None]:
class DataGenKid(Sequence):
    """Keras ``Sequence`` that loads PNG tiles (and masks in 'fit' mode) in batches.

    Args:
        imgs_path: directory containing ``<id>.png`` image tiles.
        msks_path: directory containing ``<id>.png`` masks (read in 'fit' mode only).
        imgs_idxs: list of tile ids (file names without the extension).
        img_size: side of the square arrays allocated per batch; tiles must
            match it after the optional resize.
        batch_size: maximum number of tiles per batch.
        mode: 'fit' yields ``(X, y)``; 'predict' yields ``X`` only.
        shuffle: draw a new random tile order at construction and after each epoch.
        aug: optional callable used as ``aug(image=..., mask=...)`` in 'fit'
            mode or ``aug(image=...)`` in 'predict' mode; must return a dict.
        resize: optional divisor applied to the tile width and height.
    """

    def __init__(self, imgs_path, msks_path, imgs_idxs, img_size,
                 batch_size=32, mode='fit', shuffle=False,
                 aug=None, resize=None):
        self.imgs_path = imgs_path
        self.msks_path = msks_path
        self.imgs_idxs = imgs_idxs
        self.img_size = img_size
        self.batch_size = batch_size
        self.mode = mode
        self.shuffle = shuffle
        self.aug = aug
        self.resize = resize
        # Builds self.indexes (and shuffles it when requested).
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch, including a final partial batch."""
        # Ceil, not floor: __getitem__ serves a shorter last batch, so no tile
        # is silently excluded from training or validation.
        return int(np.ceil(len(self.imgs_idxs) / self.batch_size))

    def on_epoch_end(self):
        """Reset the tile order and reshuffle it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.imgs_idxs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        """Return batch number ``index``: ``(X, y)`` in 'fit' mode, ``X`` in 'predict' mode.

        Raises:
            AttributeError: if ``mode`` is neither 'fit' nor 'predict'.
        """
        # Go through self.indexes so the order drawn in on_epoch_end is
        # actually used; slicing imgs_idxs directly ignored the shuffle.
        positions = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        imgs_batch = [self.imgs_idxs[pos] for pos in positions]
        batch_size = len(imgs_batch)
        X = np.zeros((batch_size, self.img_size, self.img_size, 3), dtype=np.float32)
        if self.mode == 'fit':
            y = np.zeros((batch_size, self.img_size, self.img_size), dtype=np.float32)
            for i, img_idx in enumerate(imgs_batch):
                X[i, ], y[i] = self.get_tile(img_idx)
            return X, y
        elif self.mode == 'predict':
            for i, img_idx in enumerate(imgs_batch):
                X[i, ] = self.get_tile(img_idx)
            return X
        else:
            raise AttributeError('fit mode parameter error')

    def get_tile(self, img_idx):
        """Load one tile (and its mask in 'fit' mode), scale the image to [0, 1] and augment.

        Returns:
            ``(img, msk)`` in 'fit' mode, otherwise ``img``; float32 arrays.

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
        """
        img_path = f'{self.imgs_path}/{img_idx}.png'
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None; fail with the path
            # instead of crashing later on a None attribute access.
            raise FileNotFoundError(f'error load image: {img_path}')
        if self.resize:
            img = cv2.resize(img, (int(img.shape[1] / self.resize), int(img.shape[0] / self.resize)))
        img = img.astype(np.float32) / 255
        if self.mode == 'fit':
            msk_path = f'{self.msks_path}/{img_idx}.png'
            msk = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
            if msk is None:
                raise FileNotFoundError(f'error load mask: {msk_path}')
            if self.resize:
                # NOTE(review): uses cv2's default interpolation, which blends
                # label values at mask borders - confirm this is intended.
                msk = cv2.resize(msk, (int(msk.shape[1] / self.resize), int(msk.shape[0] / self.resize)))
            # Mask values are not rescaled - assumes the PNGs store 0/1 labels; TODO confirm.
            msk = msk.astype(np.float32)
            if self.aug:
                augmented = self.aug(image=img, mask=msk)
                img = augmented['image']
                msk = augmented['mask']
            return img, msk
        else:
            if self.aug:
                img = self.aug(image=img)['image']
            return img

In [None]:
# Collect the tile ids to train on. os.listdir returns entries in arbitrary
# order, so the ids are sorted: the seeded fold split (and the random
# down-sampling below) then selects the same tiles on every machine and run.
tile_ids = sorted(os.path.splitext(x)[0]
                  for x in os.listdir(IMGS_PATH)
                  if x.endswith('.png'))
imgs_idxs = []
for img_idx in tqdm(tile_ids):
    msk_path = f'{MSKS_PATH}/{img_idx}.png'
    msk = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
    if msk is None:
        # cv2.imread returns None for an unreadable file; fail here rather
        # than inside the data generator in the middle of training.
        raise FileNotFoundError(f'error load mask: {msk_path}')
    if np.sum(msk) == 0 and PARAMS['downsample']:
        # Empty-mask tile with down-sampling enabled: keep it with probability `downsample`.
        if np.random.uniform(0, 1) < PARAMS['downsample']:
            imgs_idxs.append(img_idx)
    else:
        # Tiles with any mask pixels are always kept, as are empty ones
        # when down-sampling is off.
        imgs_idxs.append(img_idx)

# Two generators over the complete tile list, used by the preview cells
# below: one shuffled and augmented, one in file order without augmentation.
preview_kwargs = dict(
    imgs_path=IMGS_PATH,
    msks_path=MSKS_PATH,
    imgs_idxs=imgs_idxs,
    img_size=PARAMS['img_size'],
    batch_size=PARAMS['batch_size'],
    mode='fit',
    resize=None
)
train_datagen = DataGenKid(shuffle=True, aug=aug, **preview_kwargs)
val_datagen = DataGenKid(shuffle=False, aug=None, **preview_kwargs)

In [None]:
# Show up to 8 tiles of training batch 3: images on the top row, masks below.
bsize = min(8, PARAMS['batch_size'])
Xt, yt = train_datagen[3]
print('test X: ', Xt.shape)
print('test y: ', yt.shape)
fig, axes = plt.subplots(figsize=(16, 4), nrows=2, ncols=bsize)
for j in range(bsize):
    img_ax, msk_ax = axes[:, j]
    img_ax.imshow(Xt[j])
    img_ax.set_title(j)
    img_ax.axis('off')
    msk_ax.imshow(yt[j])
    msk_ax.axis('off')
plt.show()

In [None]:
# Show up to 8 tiles of validation batch 5: images on the top row, masks below.
bsize = min(8, PARAMS['batch_size'])
Xt, yt = val_datagen[5]
print('test X: ', Xt.shape)
print('test y: ', yt.shape)
fig, axes = plt.subplots(figsize=(16, 4), nrows=2, ncols=bsize)
for j in range(bsize):
    img_ax, msk_ax = axes[:, j]
    img_ax.imshow(Xt[j])
    img_ax.set_title(j)
    img_ax.axis('off')
    msk_ax.imshow(yt[j])
    msk_ax.axis('off')
plt.show()

In [None]:
def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient computed over all elements of both tensors.

    Both inputs are flattened, so the score is a single value for the whole
    batch rather than a per-sample average.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor with the same number of elements.
        smooth: constant added to numerator and denominator.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)
    return (2 * overlap + smooth) / (total + smooth)

def dice_loss(y_true, y_pred, smooth=1):
    """Dice loss: one minus ``dice_coef`` evaluated with the same ``smooth``."""
    score = dice_coef(y_true, y_pred, smooth)
    return 1 - score

def bce_dice_loss(y_true, y_pred):
    """Blend of binary cross-entropy and Dice loss.

    ``PARAMS['bce_weight']`` weights the cross-entropy term and its
    complement weights the Dice term.
    """
    bce_weight = PARAMS['bce_weight']
    bce_term = bce_weight * binary_crossentropy(y_true, y_pred)
    dice_term = (1 - bce_weight) * dice_loss(y_true, y_pred)
    return bce_term + dice_term

def get_model(backbone, input_shape, loss_type='bce_dice', 
              umodel='unet', classes=1, lr=.001):
    """Build and compile a segmentation model inside ``STRATEGY.scope()``.

    Args:
        backbone: encoder name passed to the segmentation_models constructor.
        input_shape: input shape passed to the constructor.
        loss_type: 'bce_dice' or 'bce_jaccard_loss'.
        umodel: 'unet', 'fpn' or 'link' (Linknet).
        classes: number of output channels.
        lr: learning rate of the Adam optimizer wrapped by Lookahead.

    Returns:
        The compiled model, with ``dice_coef`` as its metric.

    Raises:
        AttributeError: for an unknown ``loss_type`` or ``umodel``.
    """
    loss_by_name = {
        'bce_dice': bce_dice_loss,
        'bce_jaccard_loss': bce_jaccard_loss,
    }
    builder_by_name = {'unet': Unet, 'fpn': FPN, 'link': Linknet}
    with STRATEGY.scope():
        # The loss name is validated before the architecture name.
        if loss_type not in loss_by_name:
            raise AttributeError('loss mode parameter error')
        if umodel not in builder_by_name:
            raise AttributeError('umodel mode parameter error')
        build = builder_by_name[umodel]
        model = build(backbone_name=backbone, encoder_weights='imagenet',
                      input_shape=input_shape,
                      classes=classes, activation='sigmoid')
        # Lookahead's sync period is a quarter of the configured patience, but never below 6.
        optimizer = tfa.optimizers.Lookahead(
            tf.keras.optimizers.Adam(learning_rate=lr),
            sync_period=max(6, int(PARAMS['patience'] / 4))
        )
        model.compile(optimizer=optimizer,
                      loss=loss_by_name[loss_type],
                      metrics=[dice_coef])
    return model

In [None]:
def get_lr_callback(batch_size=10, epochs=100, warmup=5, plot=False):
    """Create a linear-warmup / exponential-decay learning-rate schedule.

    The rate climbs linearly from 1e-5 to 1e-3 during the first ``warmup``
    epochs, then decays by a factor of 0.95 per epoch towards 1e-7.

    Args:
        batch_size: accepted but not used by the schedule.
        epochs: accepted but not used by the schedule.
        warmup: number of linear ramp-up epochs.
        plot: when True return the plain ``epoch -> lr`` function, otherwise
            wrap it in a Keras ``LearningRateScheduler`` callback.
    """
    lr_start, lr_max = 1e-5, 1e-3
    lr_min = lr_start / 100
    lr_decay = .95
    sustain_epochs = 0  # epochs held at lr_max between ramp-up and decay
    decay_start = warmup + sustain_epochs

    def lr_scheduler(epoch):
        if epoch < warmup:
            return (lr_max - lr_start) / warmup * epoch + lr_start
        if epoch < decay_start:
            return lr_max
        return (lr_max - lr_min) * lr_decay ** (epoch - decay_start) + lr_min

    if plot:
        return lr_scheduler
    return tf.keras.callbacks.LearningRateScheduler(lr_scheduler, verbose=False)
    
# When the decay schedule is enabled, plot the learning rate of every configured epoch.
if PARAMS['decay']:
    lr_scheduler_plot = get_lr_callback(
        batch_size=PARAMS['batch_size'],
        epochs=PARAMS['epochs'],
        plot=True
    )
    xs = list(range(PARAMS['epochs']))
    y = list(map(lr_scheduler_plot, xs))
    plt.plot(xs, y)
    plt.title(f'lr schedule from {y[0]:.5f} to {max(y):.3f} to {y[-1]:.8f}')
    plt.show()

# Train

In [None]:
def train_model(mparams, n_fold, train_datagen, val_datagen):
    """Train one fold, save its history and return the best-checkpoint model.

    Args:
        mparams: configuration dict; reads 'backbone', 'img_size', 'loss',
            'umodel', 'lr', 'patience', 'decay', 'batch_size' and 'epochs'.
        n_fold: fold number used in the checkpoint and history file names.
        train_datagen: training data passed to ``model.fit``.
        val_datagen: validation data passed to ``model.fit``.

    Returns:
        ``(model, history)``: the model with the checkpointed weights loaded,
        and the Keras ``History`` returned by ``fit``.
    """
    side = mparams['img_size']
    model = get_model(
        mparams['backbone'],
        input_shape=(side, side, 3),
        loss_type=mparams['loss'],
        umodel=mparams['umodel'],
        lr=mparams['lr']
    )
    checkpoint_path = f'{MDLS_PATH}/model_{n_fold}.hdf5'

    # Every callback watches the same quantity: validation Dice, higher is better.
    watch = dict(monitor='val_dice_coef', mode='max', verbose=0)
    callbacks = [
        EarlyStopping(patience=mparams['patience'],
                      restore_best_weights=True, **watch),
        ModelCheckpoint(checkpoint_path, save_best_only=True,
                        save_weights_only=True, **watch),
    ]
    if mparams['decay']:
        callbacks.append(get_lr_callback(mparams['batch_size']))
        print('lr warmup and decay')
    else:
        # Cut the rate tenfold after half the early-stopping patience without improvement.
        callbacks.append(ReduceLROnPlateau(factor=.1,
                                           patience=int(mparams['patience'] / 2),
                                           min_lr=1e-7, **watch))
        print('lr reduce on plateau')

    history = model.fit(
        train_datagen,
        validation_data=val_datagen,
        callbacks=callbacks,
        epochs=mparams['epochs'],
        verbose=1
    )

    # Values are stored as positional-notation strings in the history file.
    history_file = f'{MDLS_PATH}/history_{n_fold}.json'
    dict_to_save = {
        name: [np.format_float_positional(x) for x in values]
        for name, values in history.history.items()
    }
    with open(history_file, 'w') as file:
        json.dump(dict_to_save, file)

    model.load_weights(checkpoint_path)
    return model, history

In [None]:
# Assign every tile a stratification class from its mask pixel sum:
# 0 = empty, 1..4 = below 1% / 10% / 20% / 30% of the tile area, 5 = the rest.
# NOTE(review): comparing the pixel sum with an area assumes 0/1 mask values - confirm.
msks_strat = []
msk_max = PARAMS['img_size'] ** 2
area_fracs = (.01, .1, .2, .3)
for img_idx in tqdm(imgs_idxs):
    msk_path = f'{MSKS_PATH}/{img_idx}.png'
    msk = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
    msk_sum = np.sum(msk)
    if msk_sum == 0:
        msk_cls = 0
    else:
        # First bound the sum falls under decides the class; 5 when it is under none.
        msk_cls = next(
            (cls for cls, frac in enumerate(area_fracs, start=1)
             if msk_sum < msk_max * frac),
            5
        )
    msks_strat.append(msk_cls)
plt.figure(figsize=(8, 4))
plt.hist(msks_strat, bins=20)
plt.show()

In [None]:
# Tiles are grouped by the first 9 characters of their id (presumably the
# source-image id - confirm). Count each group in one pass, in sorted order.
grp_names, grp_sizes = np.unique([x[:9] for x in imgs_idxs], return_counts=True)
for iname, n_tiles in zip(grp_names, grp_sizes):
    print('img name:', iname,
          '| imgs number:', n_tiles)

if PARAMS['split'] == 'kfold':
    kfold = KFold(n_splits=PARAMS['folds'],
                  random_state=PARAMS['seed'],
                  shuffle=True).split(imgs_idxs)
elif PARAMS['split'] == 'group':
    # All tiles sharing a prefix stay in the same fold.
    grps = [x[:9] for x in imgs_idxs]
    kfold = GroupKFold(n_splits=PARAMS['folds']).split(imgs_idxs, imgs_idxs, grps)
elif PARAMS['split'] == 'strat':
    # Keep the mask-coverage class proportions equal across folds.
    kfold = StratifiedKFold(n_splits=PARAMS['folds'],
                            random_state=PARAMS['seed'],
                            shuffle=True).split(imgs_idxs, msks_strat)
else:
    raise AttributeError('split mode parameter error')
# .split() returns a one-shot generator; store the folds in a list so the
# training cell can be re-run without silently iterating over nothing.
kfold = list(kfold)

In [None]:
# Cross-validation: train one model per fold and record its best validation epoch.
epoch_by_folds = []
loss_by_folds = []
dice_coef_by_folds = []

# Generator arguments that are identical for every fold.
fold_gen_kwargs = dict(
    imgs_path=IMGS_PATH,
    msks_path=MSKS_PATH,
    img_size=PARAMS['img_size'],
    batch_size=PARAMS['batch_size'],
    mode='fit',
    resize=None
)

for n, (tr, te) in enumerate(kfold):
    print('=' * 10, f'FOLD {n}', '=' * 10)
    X_tr = [imgs_idxs[i] for i in tr]
    X_val = [imgs_idxs[i] for i in te]
    # Class 0 means an empty mask, so "> 0" flags tiles that contain any mask.
    M_tr = [msks_strat[i] > 0 for i in tr]
    M_val = [msks_strat[i] > 0 for i in te]
    print('train:', len(X_tr), '| val:', len(X_val))
    print('masks in train:', sum(M_tr) / len(X_tr),
          '| masks in val:', sum(M_val) / len(X_val))
    print('groups train:', set([x[:9] for x in X_tr]),
          '\ngroups val:', set([x[:9] for x in X_val]))

    # Training tiles are shuffled and augmented; validation tiles are neither.
    train_datagen = DataGenKid(imgs_idxs=X_tr, shuffle=True, aug=aug, **fold_gen_kwargs)
    val_datagen = DataGenKid(imgs_idxs=X_val, shuffle=False, aug=None, **fold_gen_kwargs)
    model, history = train_model(PARAMS, n, train_datagen, val_datagen)

    # Keep only the metrics dict; the History object itself references the model.
    fold_hist = history.history
    plt.plot(fold_hist['loss'], label='loss')
    plt.plot(fold_hist['val_loss'], label='val_loss')
    plt.legend()
    plt.show()
    plt.plot(fold_hist['dice_coef'], label='dice_coef')
    plt.plot(fold_hist['val_dice_coef'], label='val_dice_coef')
    plt.legend()
    plt.show()

    best_epoch = np.argmax(fold_hist['val_dice_coef'])
    best_loss = fold_hist['val_loss'][best_epoch]
    best_dice_coef = fold_hist['val_dice_coef'][best_epoch]
    print('best epoch:', best_epoch,
          '| best loss:', best_loss,
          '| best dice coef:', best_dice_coef)
    epoch_by_folds.append(best_epoch)
    loss_by_folds.append(best_loss)
    dice_coef_by_folds.append(best_dice_coef)

    # Release the fold's model before the next one is built. `history` holds
    # a reference to the model, so it must be deleted too, and Keras keeps
    # global state that only clear_session() resets.
    del train_datagen, val_datagen, model, history
    K.clear_session()
    gc.collect()

elapsed_time = time.time() - start_time
print(f'time elapsed: {elapsed_time // 60:.0f} min {elapsed_time % 60:.0f} sec')

In [None]:
# Summarise the run: the configuration plus fold-averaged best epoch, loss and Dice.
result = PARAMS.copy()
result['bavg_epoch'] = np.mean(epoch_by_folds)
result['bavg_loss'] = np.mean(loss_by_folds)
result['bavg_dice_coef'] = np.mean(dice_coef_by_folds)
result['dice_by_folds'] = ' '.join([f'{x:.4f}' for x in dice_coef_by_folds])
# Overwrite the parameter dump so it also carries the scores.
with open(f'{MDLS_PATH}/params.json', 'w') as file:
    json.dump(result, file)

# Append this run as one row to the cumulative, tab-separated results log.
df_save = pd.DataFrame(result, index=[0])
if os.path.exists('results.csv'):
    df_old = pd.read_csv('results.csv', sep='\t', index_col=0)
    # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
    df_save = pd.concat([df_old, df_save], ignore_index=True)
df_save.to_csv('results.csv', sep='\t')

In [None]:
# Display the cumulative results log, one row per recorded run.
pd.read_csv('results.csv', sep='\t', index_col=0)

# Predict

In [None]:
# Rebuild every fold's model for an input `larger` times the training tile
# side and load that fold's checkpointed weights into it.
larger = 4
lrg_side = PARAMS['img_size'] * larger
test_models = []
for n_fold in range(PARAMS['folds']):
    checkpoint_path = f'{MDLS_PATH}/model_{n_fold}.hdf5'
    print(checkpoint_path)
    model_lrg = get_model(
        PARAMS['backbone'],
        input_shape=(lrg_side, lrg_side, 3),
        loss_type=PARAMS['loss'],
        umodel=PARAMS['umodel']
    )
    # Alternative: .set_weights(model.get_weights()) from the smaller model.
    model_lrg.load_weights(checkpoint_path)
    test_models.append(model_lrg)

In [None]:
# Cut one wnd x wnd window out of a downscaled training image (and its mask)
# to try the enlarged models on.
img_num = 0
resize = PARAMS['resize']
shft = .6  # window origin as a fraction of each image dimension
wnd = PARAMS['img_size'] * larger
# Read from DATA_PATH like show_img_n_mask and train.csv do, instead of a
# separately hard-coded './data/train' directory.
img = tiff.imread(os.path.join(f'{DATA_PATH}/train', df_masks.index[img_num] + '.tiff'))
if len(img.shape) == 5: img = np.transpose(img.squeeze(), (1, 2, 0))
mask = enc2mask(df_masks.iloc[img_num], (img.shape[1], img.shape[0]))
print(img.shape, mask.shape)
img = cv2.resize(img,
                 (img.shape[1] // resize, img.shape[0] // resize),
                 interpolation=cv2.INTER_AREA)
mask = cv2.resize(mask,
                  (mask.shape[1] // resize, mask.shape[0] // resize),
                  interpolation=cv2.INTER_NEAREST)
# Window origins, computed once per array.
img_r0, img_c0 = int(img.shape[0] * shft), int(img.shape[1] * shft)
msk_r0, msk_c0 = int(mask.shape[0] * shft), int(mask.shape[1] * shft)
img = img[img_r0 : img_r0 + wnd, img_c0 : img_c0 + wnd, :]
mask = mask[msk_r0 : msk_r0 + wnd, msk_c0 : msk_c0 + wnd]
plt.figure(figsize=(4, 4))
plt.axis('off')
plt.imshow(img)
plt.imshow(mask, alpha=.4)
plt.show()

In [None]:
def dice_np(pred, true, k=1):
    """Dice score between a prediction and a ground-truth array.

    Args:
        pred: numeric prediction array.
        true: ground-truth array of the same shape.
        k: label value in ``true`` that counts as foreground.

    Returns:
        ``2 * sum(pred where true == k) / (sum(pred) + sum(true))``, or 1.0
        when both sums are zero (two empty masks agree perfectly).
    """
    denominator = np.sum(pred) + np.sum(true)
    if denominator == 0:
        # Avoid 0/0 -> NaN, which np.argmax would otherwise treat as the maximum.
        return 1.0
    intersection = np.sum(pred[true == k]) * 2
    return intersection / denominator

def get_dice(mask, mask_lrg, th):
    """Dice score of a thresholded prediction against the ground-truth mask.

    Args:
        mask: ground-truth mask.
        mask_lrg: predicted probabilities; singleton axes are squeezed away.
        th: threshold; values strictly above it count as foreground.
    """
    mask_pred = np.squeeze(mask_lrg > th).astype(int)
    # dice_np takes (pred, true); the arguments used to be passed the other
    # way round. The score is unchanged for 0/1 masks.
    return dice_np(mask_pred, mask)

def get_best_th_dice(mask, mask_lrg, n=100, plot=False):
    """Scan ``n`` evenly spaced thresholds in [0, 1] for the best Dice score.

    Args:
        mask: ground-truth mask.
        mask_lrg: predicted probabilities.
        n: number of thresholds to evaluate.
        plot: when True, plot Dice against threshold with the optimum in the title.

    Returns:
        ``(best_threshold, best_dice)``.
    """
    thresholds = np.linspace(0, 1, n)
    dices = []
    for th in thresholds:
        dices.append(get_dice(mask, mask_lrg, th))
    best = np.argmax(dices)
    best_th, best_dice = thresholds[best], dices[best]
    if plot:
        plt.plot(thresholds, dices)
        plt.title(f'th: {best_th:.2f} dice: {best_dice:.2f}')
        plt.show()
    return best_th, best_dice

In [None]:
# Predict the preview window with every fold model; show each probability map
# and print its best (threshold, dice) pair.
fig, axes = plt.subplots(figsize=(16, 4), nrows=1, ncols=len(test_models))
for j, fold_model in enumerate(test_models):
    # Add a batch axis and scale to [0, 1], as the training generator does.
    mask_lrg = fold_model.predict(img[np.newaxis, ] / 255)
    ax = axes[j]
    ax.imshow(np.squeeze(mask_lrg))
    ax.set_title(f'img {j}: {np.min(mask_lrg):.2f}-{np.max(mask_lrg):.2f}')
    ax.axis('off')
    print(get_best_th_dice(mask, mask_lrg))
plt.show()

In [None]:
# Dice-vs-threshold curve for mask_lrg, i.e. the prediction of the last model in the loop above.
get_best_th_dice(mask, mask_lrg, n=100, plot=True)

In [None]:
# Histogram of every predicted probability in mask_lrg.
plt.figure(figsize=(14, 4))
plt.hist(mask_lrg.reshape(-1), bins=100)
plt.show()

In [None]:
# Same histogram without the near-zero predictions. Boolean indexing selects
# the values >= 1e-3 directly (the result is already 1-D), so no NaNs are fed
# to plt.hist and the cell still works when every value is below the cut-off.
plt.figure(figsize=(14, 4))
plt.hist(mask_lrg[mask_lrg >= 1e-3], bins=100)
plt.show()