In [None]:
import warnings
warnings.filterwarnings('ignore')
import os
import cv2
import json
import time
import random
import numpy as np
import pandas as pd 
import tifffile as tiff
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import albumentations as albu
from sklearn.model_selection import train_test_split, KFold
import tensorflow.keras.backend as K
from tensorflow.keras import Model, Sequential
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import Sequence
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import *
import segmentation_models as sm
from segmentation_models import Unet
# Report the TensorFlow build and restrict this process to the first GPU.
print('tensorflow version:', tf.__version__)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
# An empty device list simply prints nothing.
for gpu_device in (gpu_devices or []):
    print('device available:', gpu_device)
# Show every column when a DataFrame is displayed.
pd.set_option('display.max_columns', None)

In [None]:
# Experiment version tag; it names the output directory below.
VER = 'v5'
# Data locations; tiles and masks live under DATA_PATH.
DATA_PATH = './data'
IMGS_PATH = f'{DATA_PATH}/tiles'
MSKS_PATH = f'{DATA_PATH}/masks'
MDLS_PATH = f'./models_{VER}'
# makedirs(exist_ok=True) is safe to re-run and has no check-then-create race.
os.makedirs(MDLS_PATH, exist_ok=True)
# Hyper-parameters for this run; a copy is written next to the checkpoints.
PARAMS = {
    'version': VER,
    'folds': 4,
    'img_size': 256,
    'resize': 4,
    'batch_size': 8,
    'epochs': 1000,
    'patience': 16,
    'decay': False,
    'backbone': 'efficientnetb5',
    'bce_weight': .5,
    'seed': 2020
}
with open(f'{MDLS_PATH}/params.json', 'w') as file:
    json.dump(PARAMS, file)
    
def seed_all(seed):
    """Seed Python's `random`, NumPy and TensorFlow with the same value.

    Also exports PYTHONHASHSEED.
    NOTE(review): that variable is read at interpreter start-up, so setting it
    here presumably only affects child processes — confirm that is intended.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)

# Seed all RNGs from the configured seed before any data or model work.
seed_all(PARAMS['seed'])
# Wall-clock reference; the elapsed time is reported after cross-validation.
start_time = time.time()

In [None]:
# Load the training annotations, indexed by image id, and display them.
df_masks = pd.read_csv(f'{DATA_PATH}/train.csv', index_col='id')
df_masks

In [None]:
def enc2mask(encs, shape):
    """Decode run-length encodings into a single labelled mask.

    Args:
        encs: iterable of RLE strings ("start length start length ...", with
            1-based starts). Entries that are float NaN are skipped.
        shape: two-element shape used for the intermediate reshape; the
            returned array is the transpose of that.

    Returns:
        np.uint8 array of shape (shape[1], shape[0]). Pixels covered by the
        m-th encoding hold the value m + 1; everything else is 0.
    """
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for m, enc in enumerate(encs):
        # `np.float` was removed in NumPy 1.24. It was only an alias of the
        # builtin `float`, which np.float64 (pandas' missing value) subclasses.
        if isinstance(enc, float) and np.isnan(enc):
            continue
        tokens = enc.split()
        # Tokens alternate start/length; zip ignores a dangling start token,
        # exactly like iterating over `len(tokens) // 2` pairs.
        for start, length in zip(tokens[0::2], tokens[1::2]):
            begin = int(start) - 1  # RLE starts are 1-based
            img[begin : begin + int(length)] = 1 + m
    return img.reshape(shape).T

def show_img_n_mask(df, img_num, resize):
    """Plot one training slide with its decoded mask overlaid.

    Args:
        df: DataFrame indexed by image id; row `img_num` holds the RLE
            encoding(s) handed to `enc2mask`.
        img_num: positional row of the slide to display.
        resize: integer factor by which image and mask are downscaled.
    """
    # Build the path from DATA_PATH so it follows the notebook configuration
    # instead of a second hard-coded copy of the data directory.
    img = tiff.imread(os.path.join(DATA_PATH, 'train', df.index[img_num] + '.tiff'))
    # 5-D TIFFs are squeezed and moved to channels-last.
    # NOTE(review): assumes the squeezed layout is (channels, H, W) — confirm.
    if len(img.shape) == 5:
        img = np.transpose(img.squeeze(), (1, 2, 0))
    # enc2mask takes (width, height) and returns a (height, width) array.
    mask = enc2mask(df.iloc[img_num], (img.shape[1], img.shape[0]))
    print(img.shape, mask.shape)
    img = cv2.resize(img,
                     (img.shape[1] // resize, img.shape[0] // resize),
                     interpolation=cv2.INTER_AREA)
    # Nearest-neighbour keeps the mask's label values intact.
    mask = cv2.resize(mask,
                      (mask.shape[1] // resize, mask.shape[0] // resize),
                      interpolation=cv2.INTER_NEAREST)
    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(img)
    plt.imshow(mask, alpha=.4)
    plt.show()

In [None]:
# Preview the slide at position 4 of df_masks with its mask overlaid,
# downscaled by the configured PARAMS['resize'] factor.
show_img_n_mask(df=df_masks, img_num=4, resize=PARAMS['resize'])

In [None]:
# Photometric jitter: the OneOf group as a whole fires with p=.33.
colour_jitter = albu.OneOf(
    [
        albu.RandomBrightness(limit=.15),
        albu.RandomContrast(limit=.15),
        albu.RandomGamma(),
    ],
    p=.33,
)
# Training-time augmentation: colour jitter followed by geometric transforms,
# each applied independently with probability .33.
aug = albu.Compose([
    colour_jitter,
    albu.RandomRotate90(p=.33),
    albu.HorizontalFlip(p=.33),
    albu.VerticalFlip(p=.33),
    albu.ShiftScaleRotate(shift_limit=.1, scale_limit=.1, rotate_limit=20, p=.33),
])

In [None]:
class DataGenKid(Sequence):
    """Keras `Sequence` serving batches of PNG tiles and, in 'fit' mode, masks.

    Tiles are read from `{imgs_path}/{id}.png` and scaled to float32 in
    [0, 1]. Masks are read as grayscale from `{msks_path}/{id}.png` and cast
    to float32 without rescaling.
    NOTE(review): masks are presumably stored with values 0/1 — confirm.
    NOTE(review): cv2.imread yields BGR channel order — confirm this matches
    what the model is fed at inference time.
    """

    def __init__(self, imgs_path, msks_path, imgs_idxs, img_size,
                 batch_size=32, mode='fit', shuffle=False,
                 aug=None, resize=None):
        """Store the configuration and build the initial sample order.

        Args:
            imgs_path: directory containing the tile PNGs.
            msks_path: directory containing the mask PNGs (used in 'fit' mode).
            imgs_idxs: sequence of tile ids (file names without '.png').
            img_size: side length of the square tiles placed in a batch.
            batch_size: maximum number of tiles per batch.
            mode: 'fit' to yield (X, y), 'predict' to yield X only.
            shuffle: reshuffle the sample order at the end of every epoch.
            aug: optional albumentations-style callable taking `image=` (and
                `mask=` in 'fit' mode) and returning a dict.
            resize: optional factor by which loaded tiles/masks are divided.
        """
        self.imgs_path = imgs_path
        self.msks_path = msks_path
        self.imgs_idxs = imgs_idxs
        self.img_size = img_size
        self.batch_size = batch_size
        self.mode = mode
        self.shuffle = shuffle
        self.aug = aug
        self.resize = resize
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch, counting a final partial batch.

        Rounding up keeps the tail samples in play; flooring would silently
        drop them from validation and prediction.
        """
        return int(np.ceil(len(self.imgs_idxs) / self.batch_size))

    def on_epoch_end(self):
        """Rebuild (and optionally shuffle) the positional sample order."""
        self.indexes = np.arange(len(self.imgs_idxs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        """Return batch number `index`.

        Returns:
            (X, y) in 'fit' mode, X in 'predict' mode, where X has shape
            (n, img_size, img_size, 3) and y has shape (n, img_size, img_size).

        Raises:
            AttributeError: if `mode` is neither 'fit' nor 'predict'.
        """
        # Go through self.indexes so that shuffling actually changes which
        # tiles end up in a batch; with shuffle=False it is the identity.
        positions = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        imgs_batch = [self.imgs_idxs[pos] for pos in positions]
        batch_size = len(imgs_batch)  # smaller than self.batch_size for the tail
        X = np.zeros((batch_size, self.img_size, self.img_size, 3), dtype=np.float32)
        if self.mode == 'fit':
            y = np.zeros((batch_size, self.img_size, self.img_size), dtype=np.float32)
            for i, img_idx in enumerate(imgs_batch):
                X[i, ], y[i] = self.get_tile(img_idx)
            return X, y
        elif self.mode == 'predict':
            for i, img_idx in enumerate(imgs_batch):
                X[i, ] = self.get_tile(img_idx)
            return X
        else:
            raise AttributeError('mode parameter error')

    def get_tile(self, img_idx):
        """Load one tile (and its mask in 'fit' mode), optionally augmented.

        Returns:
            (img, msk) in 'fit' mode, otherwise img. Both are float32.

        Raises:
            OSError: if the tile or mask file cannot be read.
        """
        img_path = f'{self.imgs_path}/{img_idx}.png'
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None; stop with a clear
            # error instead of crashing later on an attribute of None.
            raise OSError(f'error load image: {img_path}')
        if self.resize:
            img = cv2.resize(img, (int(img.shape[1] / self.resize), int(img.shape[0] / self.resize)))
        img = img.astype(np.float32) / 255
        if self.mode == 'fit':
            msk_path = f'{self.msks_path}/{img_idx}.png'
            msk = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
            if msk is None:
                raise OSError(f'error load mask: {msk_path}')
            if self.resize:
                # Nearest-neighbour so label values are not blended.
                msk = cv2.resize(
                    msk,
                    (int(msk.shape[1] / self.resize), int(msk.shape[0] / self.resize)),
                    interpolation=cv2.INTER_NEAREST
                )
            msk = msk.astype(np.float32)
            if self.aug:
                augmented = self.aug(image=img, mask=msk)
                img = augmented['image']
                msk = augmented['mask']
            return img, msk
        else:
            if self.aug:
                img = self.aug(image=img)['image']
            return img

In [None]:
# Tile ids are the PNG file names without their extension. They are sorted
# because os.listdir returns entries in arbitrary order, and the seeded KFold
# split indexes into this list by position; an unsorted list would make the
# folds differ between runs and machines.
imgs_idxs = sorted(
    os.path.splitext(x)[0] for x in os.listdir(IMGS_PATH) if x.endswith('.png')
)
# Demo generators over the full tile set, used only to eyeball batches.
train_datagen = DataGenKid(
    imgs_path=IMGS_PATH,
    msks_path=MSKS_PATH,
    imgs_idxs=imgs_idxs,
    img_size=PARAMS['img_size'],
    batch_size=PARAMS['batch_size'],
    mode='fit',
    shuffle=True,
    aug=aug,
    resize=None
)
val_datagen = DataGenKid(
    imgs_path=IMGS_PATH,
    msks_path=MSKS_PATH,
    imgs_idxs=imgs_idxs,
    img_size=PARAMS['img_size'],
    batch_size=PARAMS['batch_size'],
    mode='fit',
    shuffle=False,
    aug=None,
    resize=None
)

In [None]:
# Eyeball one augmented training batch: tiles on top, masks underneath.
bsize = 8
Xt, yt = train_datagen[3]
print('test X: ', Xt.shape)
print('test y: ', yt.shape)
fig, axes = plt.subplots(nrows=2, ncols=bsize, figsize=(16, 4))
for j, (ax_img, ax_msk) in enumerate(zip(axes[0], axes[1])):
    ax_img.imshow(Xt[j])
    ax_img.set_title(j)
    ax_img.axis('off')
    ax_msk.imshow(yt[j])
    ax_msk.axis('off')
plt.show()

In [None]:
# Eyeball one un-augmented validation batch: tiles on top, masks underneath.
bsize = 8
Xt, yt = val_datagen[5]
print('test X: ', Xt.shape)
print('test y: ', yt.shape)
fig, axes = plt.subplots(nrows=2, ncols=bsize, figsize=(16, 4))
for j, (ax_img, ax_msk) in enumerate(zip(axes[0], axes[1])):
    ax_img.imshow(Xt[j])
    ax_img.set_title(j)
    ax_img.axis('off')
    ax_msk.imshow(yt[j])
    ax_msk.axis('off')
plt.show()

In [None]:
def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient over every element of both tensors.

    Both inputs are flattened, so the score is computed over the whole batch
    at once. `smooth` is added to numerator and denominator.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)
    return (2 * overlap + smooth) / (total + smooth)

def dice_loss(y_true, y_pred, smooth=1):
    """Dice loss: one minus the smoothed Dice coefficient."""
    score = dice_coef(y_true, y_pred, smooth)
    return 1 - score

def bce_dice_loss(y_true, y_pred):
    """Blend of binary cross-entropy and Dice loss.

    PARAMS['bce_weight'] weights the cross-entropy term; the Dice term gets
    the complementary weight.
    """
    w_bce = PARAMS['bce_weight']
    bce_term = w_bce * binary_crossentropy(y_true, y_pred)
    dice_term = (1 - w_bce) * dice_loss(y_true, y_pred)
    return bce_term + dice_term

def get_model(backbone, input_shape, classes=1, learning_rate=.001):
    """Build and compile a U-Net segmentation model.

    Args:
        backbone: encoder name passed to `Unet` as `backbone_name`.
        input_shape: input shape passed to `Unet`.
        classes: number of output channels.
        learning_rate: learning rate of the inner Adam optimizer.

    Returns:
        The compiled model (sigmoid output, `bce_dice_loss`, `dice_coef`
        metric, Adam wrapped in Lookahead with sync_period=10).
    """
    model = Unet(
        backbone_name=backbone,
        input_shape=input_shape,
        classes=classes,
        activation='sigmoid',
    )
    inner_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    optimizer = tfa.optimizers.Lookahead(inner_optimizer, sync_period=10)
    model.compile(optimizer=optimizer, loss=bce_dice_loss, metrics=[dice_coef])
    return model

In [None]:
def get_lr_callback(batch_size=10, epochs=100, warmup=5, plot=False):
    """Linear warm-up followed by exponential decay of the learning rate.

    Args:
        batch_size: accepted for call compatibility; not used by the schedule.
        epochs: accepted for call compatibility; not used by the schedule.
        warmup: number of epochs over which the rate ramps from 1e-5 to 1e-3.
        plot: if True return the raw `epoch -> lr` function, otherwise wrap it
            in a Keras `LearningRateScheduler` callback.
    """
    lr_start = 1e-5
    lr_max = 1e-3
    lr_min = lr_start / 100
    lr_ramp_ep = warmup
    lr_sus_ep = 0      # epochs held at lr_max before decay starts
    lr_decay = .95     # per-epoch multiplicative decay towards lr_min

    def lr_scheduler(epoch):
        # Phase 1: linear ramp.
        if epoch < lr_ramp_ep:
            return (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start
        # Phase 2: plateau (empty while lr_sus_ep is 0).
        if epoch < lr_ramp_ep + lr_sus_ep:
            return lr_max
        # Phase 3: exponential decay towards lr_min.
        return (lr_max - lr_min) * lr_decay ** (epoch - lr_ramp_ep - lr_sus_ep) + lr_min

    if plot:
        return lr_scheduler
    return tf.keras.callbacks.LearningRateScheduler(lr_scheduler, verbose=False)
    
# Only relevant when the warm-up/decay schedule is enabled: plot it.
if PARAMS['decay']:
    lr_scheduler_plot = get_lr_callback(
        batch_size=PARAMS['batch_size'],
        epochs=PARAMS['epochs'],
        plot=True,
    )
    xs = list(range(PARAMS['epochs']))
    y = list(map(lr_scheduler_plot, xs))
    plt.plot(xs, y)
    plt.title(f'lr schedule from {y[0]:.5f} to {max(y):.3f} to {y[-1]:.8f}')
    plt.show()

In [None]:
def train_model(mparams, n_fold, train_datagen, val_datagen):
    """Train one fold and persist its best weights and training history.

    Args:
        mparams: parameter dict; reads 'backbone', 'img_size', 'patience',
            'decay', 'batch_size' and 'epochs'.
        n_fold: fold number, used in the checkpoint and history file names.
        train_datagen: training data passed to `model.fit`.
        val_datagen: validation data passed to `model.fit`.

    Returns:
        (model, history): the model with the checkpointed weights loaded, and
        the Keras History object.
    """
    side = mparams['img_size']
    model = get_model(mparams['backbone'], input_shape=(side, side, 3))
    checkpoint_path = f'{MDLS_PATH}/model_{n_fold}.hdf5'

    # All three callbacks track the same quantity in the same direction.
    watch = dict(monitor='val_dice_coef', mode='max', verbose=0)
    earlystopper = EarlyStopping(
        patience=mparams['patience'],
        restore_best_weights=True,
        **watch
    )
    lrreducer = ReduceLROnPlateau(
        factor=.1,
        patience=int(mparams['patience'] / 2),
        min_lr=1e-7,
        **watch
    )
    checkpointer = ModelCheckpoint(
        checkpoint_path,
        save_best_only=True,
        save_weights_only=True,
        **watch
    )

    # Learning-rate policy: fixed warm-up/decay schedule, or reduce on plateau.
    callbacks = [earlystopper, checkpointer]
    if mparams['decay']:
        callbacks.append(get_lr_callback(mparams['batch_size']))
        print('lr warmup and decay')
    else:
        callbacks.append(lrreducer)
        print('lr reduce on plateau')

    history = model.fit(
        train_datagen,
        validation_data=val_datagen,
        callbacks=callbacks,
        epochs=mparams['epochs'],
        verbose=1
    )

    # The history values are stored as strings in the JSON file.
    history_file = f'{MDLS_PATH}/history_{n_fold}.json'
    dict_to_save = {
        name: [np.format_float_positional(value) for value in values]
        for name, values in history.history.items()
    }
    with open(history_file, 'w') as file:
        json.dump(dict_to_save, file)

    model.load_weights(checkpoint_path)
    return model, history

In [None]:
# K-fold cross-validation over the tile ids. KFold.split yields positional
# index arrays into imgs_idxs.
kfold = KFold(n_splits=PARAMS['folds'],
              random_state=PARAMS['seed'],
              shuffle=True).split(imgs_idxs)
# Running averages over folds of the best epoch and its validation metrics.
bavg_epoch = 0
bavg_loss = 0
bavg_dice_coef = 0

for n, (tr, te) in enumerate(kfold):
    print('=' * 10, f'FOLD {n}', '=' * 10)
    X_tr = [imgs_idxs[i] for i in tr]
    X_val = [imgs_idxs[i] for i in te]
    print('train:', len(X_tr), '| test:', len(X_val))
    train_datagen = DataGenKid(
        imgs_path=IMGS_PATH,
        msks_path=MSKS_PATH,
        imgs_idxs=X_tr,
        img_size=PARAMS['img_size'],
        batch_size=PARAMS['batch_size'],
        mode='fit',
        shuffle=True,
        aug=aug,
        resize=None
    )
    val_datagen = DataGenKid(
        imgs_path=IMGS_PATH,
        msks_path=MSKS_PATH,
        imgs_idxs=X_val,
        img_size=PARAMS['img_size'],
        batch_size=PARAMS['batch_size'],
        mode='fit',
        shuffle=False,
        aug=None,
        resize=None
    )
    model, history = train_model(PARAMS, n, train_datagen, val_datagen)
    plt.plot(history.history['loss'], label='loss')
    plt.plot(history.history['val_loss'], label='val_loss')
    plt.legend()
    plt.show()
    plt.plot(history.history['dice_coef'], label='dice_coef')
    plt.plot(history.history['val_dice_coef'], label='val_dice_coef')
    plt.legend()
    plt.show()
    # Report the epoch with the highest val_dice_coef. That is the quantity
    # the early-stopping and checkpoint callbacks in train_model maximise, so
    # these numbers describe the weights that were actually kept for the fold
    # (the epoch with the lowest val_loss can be a different one).
    best_epoch = int(np.argmax(history.history['val_dice_coef']))
    best_loss = history.history['val_loss'][best_epoch]
    best_dice_coef = history.history['val_dice_coef'][best_epoch]
    print('best epoch:', best_epoch,
          '| best loss:', best_loss,
          '| best dice coef:', best_dice_coef)
    bavg_epoch += best_epoch / PARAMS['folds']
    bavg_loss += best_loss / PARAMS['folds']
    bavg_dice_coef += best_dice_coef / PARAMS['folds']


elapsed_time = time.time() - start_time
print(f'time elapsed: {elapsed_time // 60:.0f} min {elapsed_time % 60:.0f} sec')

In [None]:
# Append this run (hyper-parameters plus fold-averaged metrics) to results.csv.
result = PARAMS.copy()
result['bavg_epoch'] = bavg_epoch
result['bavg_loss'] = bavg_loss
result['bavg_dice_coef'] = bavg_dice_coef
df_save = pd.DataFrame(result, index=[0])
if os.path.exists('results.csv'):
    df_old = pd.read_csv('results.csv', sep='\t', index_col=0)
    # DataFrame.append was removed in pandas 2.0; pd.concat with
    # ignore_index=True renumbers the rows exactly as append(...) did.
    df_save = pd.concat([df_old, df_save], ignore_index=True)
df_save.to_csv('results.csv', sep='\t')

In [None]:
# Display the accumulated experiment log (tab-separated; first column is the index).
pd.read_csv('results.csv', sep='\t', index_col=0)

In [None]:
# Rebuild the network at the training resolution (larger = 1) and load the
# weights saved for fold 0.
larger = 1
n_fold = 0
checkpoint_path = f'{MDLS_PATH}/model_{n_fold}.hdf5'
model_lrg = get_model(
    PARAMS['backbone'],
    input_shape=(PARAMS['img_size'] * larger,) * 2 + (3,),
)
# Alternatively: .set_weights(model.get_weights()) from the smaller model.
model_lrg.load_weights(checkpoint_path)
model_lrg.summary()

In [None]:
# Predict on the first un-augmented batch of 16 tiles.
bsize = 16
val_datagen = DataGenKid(
    imgs_path=IMGS_PATH, msks_path=MSKS_PATH, imgs_idxs=imgs_idxs,
    img_size=PARAMS['img_size'], batch_size=bsize,
    mode='fit', shuffle=False, aug=None, resize=None,
)
Xt, yt = val_datagen[0]
y_pred = model_lrg.predict(Xt)

In [None]:
# Three rows per column: input tile, ground-truth mask, predicted mask.
print('test X: ', Xt.shape)
print('test y: ', yt.shape)
print('pred y: ', y_pred.shape)
fig, axes = plt.subplots(nrows=3, ncols=bsize, figsize=(16, 4))
for j, (ax_img, ax_true, ax_pred) in enumerate(zip(axes[0], axes[1], axes[2])):
    ax_img.imshow(Xt[j])
    ax_img.set_title(j)
    ax_img.axis('off')
    ax_true.set_title('true')
    ax_true.imshow(yt[j])
    ax_true.axis('off')
    ax_pred.set_title('pred')
    ax_pred.imshow(np.squeeze(y_pred[j]))
    ax_pred.axis('off')
plt.show()

In [None]:
# Rebuild the network with a 4x larger input side and load the weights saved
# for fold 0.
larger = 4
n_fold = 0
checkpoint_path = f'{MDLS_PATH}/model_{n_fold}.hdf5'
model_lrg = get_model(
    PARAMS['backbone'],
    input_shape=(PARAMS['img_size'] * larger,) * 2 + (3,),
)
# Alternatively: .set_weights(model.get_weights()) from the smaller model.
model_lrg.load_weights(checkpoint_path)
model_lrg.summary()

In [None]:
# Cut one wnd x wnd window out of a downscaled slide (and its mask), starting
# at the fraction `shft` of each dimension.
img_num = 0
resize = 4
shft = .6
wnd = PARAMS['img_size'] * larger
img = tiff.imread(os.path.join('./data/train', df_masks.index[img_num] + '.tiff'))
# 5-D TIFFs are squeezed and moved to channels-last.
if len(img.shape) == 5:
    img = np.transpose(img.squeeze(), (1, 2, 0))
mask = enc2mask(df_masks.iloc[img_num], (img.shape[1], img.shape[0]))
print(img.shape, mask.shape)
img = cv2.resize(
    img,
    (img.shape[1] // resize, img.shape[0] // resize),
    interpolation=cv2.INTER_AREA,
)
mask = cv2.resize(
    mask,
    (mask.shape[1] // resize, mask.shape[0] // resize),
    interpolation=cv2.INTER_NEAREST,
)
# Window origin is computed from each array's own (already resized) shape.
row0, col0 = int(img.shape[0] * shft), int(img.shape[1] * shft)
img = img[row0 : row0 + wnd, col0 : col0 + wnd, :]
row0, col0 = int(mask.shape[0] * shft), int(mask.shape[1] * shft)
mask = mask[row0 : row0 + wnd, col0 : col0 + wnd]
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(img)
plt.imshow(mask, alpha=.4)
plt.show()

In [None]:
# Predict on the single large window: add a batch axis and scale to [0, 1].
mask_lrg = model_lrg.predict(np.expand_dims(img, axis=0) / 255)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(mask_lrg.squeeze())
plt.show()

In [None]:
# Distribution of the predicted per-pixel values for the large window.
plt.figure(figsize=(14, 4))
plt.hist(mask_lrg.ravel(), bins=100)
plt.show()