In [None]:
!pip install -q catalyst==20.12
!pip install -q pytorch-toolbelt==0.4.2
!pip install -q torch-optimizer==0.1.0
!pip install -q segmentation-models-pytorch==0.1.3
!pip install -q ttach==0.0.3
!pip install -q albumentations==0.5.2
!pip install timm==0.3.2
!pip install opencv-python-headless==4.1.2.30
!pip install wandb

In [None]:
# import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
import torch.nn.functional as F
import math
import warnings

from catalyst.contrib.nn import OneCycleLRWithWarmup
from torch.optim.lr_scheduler import (
    ExponentialLR,
    CyclicLR,
    MultiStepLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    ReduceLROnPlateau
)
from pytorch_toolbelt.losses import *
from pytorch_toolbelt.utils import image_to_tensor
from pytorch_toolbelt.utils.random import set_manual_seed
from torch.nn import KLDivLoss
from catalyst import utils
from catalyst.contrib.nn import OneCycleLRWithWarmup
from catalyst import dl
from catalyst.contrib.utils.cv import image as cata_image
from catalyst.contrib.callbacks.wandb_logger import WandbLogger
from catalyst.dl import (
    SupervisedRunner,
    CriterionCallback,
    EarlyStoppingCallback,
    SchedulerCallback,
    MetricAggregationCallback,
    IouCallback,
    DiceCallback,
    InferCallback, CheckpointCallback
)
from torch.optim.lr_scheduler import CyclicLR, ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader

import albumentations as A

import os
import cv2
import json
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, auc, average_precision_score
from pathlib import Path
from tqdm .auto import tqdm
import plotly.express as px
from sklearn.model_selection import train_test_split
from pathlib import Path
from collections import OrderedDict
from typing import List,  Optional, Dict

%matplotlib inline

In [None]:
# Root of the Google Drive mount used on Colab. Set the HUBMAP_DRIVE_ROOT
# environment variable to run the notebook from another location; when it is
# unset every path below is exactly the original Colab path.
DRIVE_ROOT = Path(os.environ.get('HUBMAP_DRIVE_ROOT', '/content/drive/MyDrive'))

# Dataset root (the directory name suggests 256x256 tiles) and its sub-folders.
MAIN_PATH = DRIVE_ROOT / 'hubmap256x256'
IMG_PATHS = MAIN_PATH / 'train'
MASK_PATHS = MAIN_PATH / 'masks'
# Kept as a plain string (not a Path) so existing consumers see the same type.
LABEL = str(DRIVE_ROOT / 'kidneysegmentation' / 'train.csv')
TEST_IMG_PATHS = MAIN_PATH / 'test'
TEST_MASK = MAIN_PATH / 'test_masks'

In [None]:
class BaseConfig:
    """Default hyper-parameters for the HuBMAP segmentation experiments.

    Settings are plain class attributes so an experiment variant can be
    declared by subclassing and overriding only what changes;
    ``get_all_attributes`` flattens them into the dict the notebook reads.
    """
    __basedir__ = MAIN_PATH  # dunder name: excluded by get_all_attributes
    train_img_path = IMG_PATHS
    train_mask_path = MASK_PATHS

    # Data config
    augmentation = 'medium'  # options: normal, easy, medium, advanced
    scale_size = 256

    # Train config
    num_epochs = 10
    batch_size = 32
    val_batch_size = 32
    learning_rate = 1e-5         # lr of the encoder param group (also given to get_scheduler)
    learning_rate_decode = 1e-3  # lr passed to get_optimizer for the remaining groups
    weight_decay = 2.5e-5
    is_fp16 = True

    # Model config
    model_name = None   # architecture name looked up in segmentation_models_pytorch
    model_params = None  # keyword arguments for that architecture

    # Metric config
    metric = "dice"
    mode = "max"

    # Optimize config
    criterion = {"bce": 0.8, 'log_dice': 0.2}  # loss name -> weight multiplier
    pos_weights = [200]  # only used by the 'wbce' loss
    optimizer = "madgrad"
    scheduler = "simple"

    resume_path = None  # checkpoint to resume training from (not read in the cells shown)

    @classmethod
    def get_all_attributes(cls):
        """Return the public configuration of ``cls`` as a plain dict.

        Attributes are collected along the MRO, base classes first, so a
        subclass inherits every default and only overrides what it redefines
        (reading ``cls.__dict__`` alone would drop all inherited settings).
        Dunder names and class/static methods are skipped.
        """
        settings = {}
        for klass in reversed(cls.__mro__):
            for key, value in vars(klass).items():
                if key.startswith('__') or key == 'get_all_attributes':
                    continue
                if isinstance(value, (classmethod, staticmethod)):
                    continue
                settings[key] = value
        return settings

In [None]:
# Flat, mutable dict of every BaseConfig setting; later cells read
# hyper-parameters from here rather than from the class itself.
train_config = dict(BaseConfig.get_all_attributes())

In [None]:
class HUBMAPSegmentation(Dataset):
    """Image/mask tile dataset for HuBMAP segmentation.

    Each item is an ``OrderedDict`` with keys:
      * ``'image'``    -- float tensor built by ``image_to_tensor``;
      * ``'mask'``     -- only when ``masks`` is given: float tensor with a
                          leading dummy channel and values in {0, 1};
      * ``'filename'`` -- base name of the image file.

    Args:
        images: image file paths.
        masks: mask file paths aligned index-by-index with ``images``, or
            ``None`` for unlabeled (test) data.
        transform: albumentations-style callable applied jointly to the image
            (and mask), returning a dict with the same keys.
        preprocessing_fn: callable applied to the image only, after
            ``transform``, as ``preprocessing_fn(image=...)['image']``.

    Raises:
        ValueError: if ``masks`` is given and its length differs from ``images``.
    """

    def __init__(self, images: List[Path], masks: Optional[List[Path]] = None, transform=None, preprocessing_fn=None):
        if masks is not None and len(masks) != len(images):
            raise ValueError(
                f"got {len(images)} images but {len(masks)} masks; they must be aligned")
        self.images = images
        self.masks = masks
        self.transform = transform
        self.preprocessing_fn = preprocessing_fn

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _read_mask(path) -> np.ndarray:
        """Load a mask file as a uint8 array with values in {0, 1}."""
        with Image.open(path) as raw:
            # Any non-zero grey level counts as foreground; mode '1' is binary.
            binary = raw.convert('L').point(lambda x: 255 if x > 0 else 0, '1')
            return np.asarray(binary).astype(np.uint8)

    def __getitem__(self, index: int) -> dict:
        image_path = self.images[index]
        image = cata_image.imread(image_path)
        mask = self._read_mask(self.masks[index]) if self.masks is not None else None

        if self.transform is not None:
            sample = {'image': image}
            if mask is not None:
                sample['mask'] = mask
            transformed = self.transform(**sample)
            image = transformed['image']
            if mask is not None:
                mask = transformed['mask']

        if self.preprocessing_fn is not None:
            image = self.preprocessing_fn(image=image)['image']

        result = OrderedDict()
        result['image'] = image_to_tensor(image).float()
        if mask is not None:
            # Convert even when no transform is configured, so 'mask' is
            # always a tensor shaped like the transformed case.
            result['mask'] = image_to_tensor(mask, dummy_channels_dim=True).float()
        # Path(...) also accepts plain string paths.
        result['filename'] = Path(image_path).name
        return result

In [None]:
def get_loader(
    images: List[Path],
    random_state: int,
    valid_size: float = 0.2,
    batch_size: int = 4,
    val_batch_size: int = 8,
    num_workers: int = 4,
    train_transforms_fn=None,
    valid_transforms_fn=None,
    preprocessing_fn=None,
    masks: Optional[List[Path]] = None,
):
    """Split ``images`` (and aligned ``masks``) into train/validation loaders.

    Args:
        images: image paths.
        random_state: seed for the shuffled train/validation split.
        valid_size: fraction of samples that go to validation.
        batch_size / val_batch_size: loader batch sizes.
        num_workers: DataLoader worker processes.
        train_transforms_fn / valid_transforms_fn: per-split joint transforms.
        preprocessing_fn: image-only preprocessing shared by both splits.
        masks: mask paths aligned index-by-index with ``images``, or ``None``.

    Returns:
        OrderedDict with ``'train'`` and ``'valid'`` DataLoaders and the
        ``'valid_dataset'`` itself.

    Raises:
        ValueError: if ``masks`` is given and its length differs from ``images``.
    """
    if masks is not None and len(masks) != len(images):
        raise ValueError(
            f"got {len(images)} images but {len(masks)} masks; they must be aligned")

    indices = np.arange(len(images))
    train_indices, valid_indices = train_test_split(
        indices, test_size=valid_size, random_state=random_state, shuffle=True)

    def _subset(subset_indices):
        """Return (images, masks) for the indices, sorted by image path with pairs kept together."""
        subset_images = [images[i] for i in subset_indices]
        if masks is None:
            return sorted(subset_images), None
        pairs = sorted(
            zip(subset_images, (masks[i] for i in subset_indices)),
            key=lambda pair: pair[0],
        )
        return [img for img, _ in pairs], [msk for _, msk in pairs]

    train_images, train_masks = _subset(train_indices)
    val_images, val_masks = _subset(valid_indices)

    train_dataset = HUBMAPSegmentation(
        train_images,
        masks=train_masks,
        transform=train_transforms_fn,
        preprocessing_fn=preprocessing_fn,
    )

    valid_dataset = HUBMAPSegmentation(
        val_images,
        masks=val_masks,
        transform=valid_transforms_fn,
        preprocessing_fn=preprocessing_fn,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=True,
        drop_last=True
    )

    # Validation must see every sample in a stable order: drop_last=True would
    # silently exclude up to val_batch_size - 1 tiles from the metrics.
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=val_batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
        drop_last=False
    )

    loaders = OrderedDict()
    loaders['train'] = train_loader
    loaders['valid'] = valid_loader
    loaders['valid_dataset'] = valid_dataset
    return loaders

In [None]:
class BaseTransform(object):
    """Skeleton for the albumentations pipelines used in this notebook.

    Subclasses supply three lists of albumentations transforms:
    ``resize_transforms`` (geometry), ``hard_transform`` (train-time
    augmentation) and ``pre_transform`` (validation/test pipeline). This base
    class turns them into ``A.Compose`` objects.

    Args:
        image_size: target side length used by subclass transforms.
        preprocessing_fn: image-only callable wrapped by ``get_preprocessing``.
    """

    def __init__(self, image_size: int = 1024, preprocessing_fn=None):
        self.image_size, self.preprocessing_fn = image_size, preprocessing_fn

    def pre_transform(self):
        """List of transforms for validation/test; subclasses must override."""
        raise NotImplementedError()

    def hard_transform(self):
        """List of train-time augmentations; subclasses must override."""
        raise NotImplementedError()

    def resize_transforms(self):
        """List of resize/pad transforms; subclasses must override."""
        raise NotImplementedError()

    def _get_compose(self, transform):
        """Flatten a list of transform lists into one ``A.Compose``."""
        flat = []
        for group in transform:
            flat.extend(group)
        return A.Compose(flat)

    def train_transform(self):
        """Resize first, then the random augmentations."""
        stages = [self.resize_transforms(), self.hard_transform()]
        return self._get_compose(stages)

    def validation_transform(self):
        """Evaluation pipeline built from ``pre_transform``."""
        stages = [self.pre_transform()]
        return self._get_compose(stages)

    def test_transform(self):
        """Test time uses exactly the validation pipeline."""
        return self.validation_transform()

    def get_preprocessing(self):
        """Wrap ``preprocessing_fn`` as an image-only ``A.Lambda`` step."""
        step = A.Lambda(image=self.preprocessing_fn)
        return A.Compose([step])

class NormalTransform(BaseTransform):
    """Light augmentation: square resize/pad plus random flips and 90-degree rotations.

    Validation/test use only the resize/pad step (``pre_transform``).
    Construction is inherited unchanged from ``BaseTransform``.
    """

    def resize_transforms(self):
        """Scale the longest side to ``image_size`` and pad (constant 0) to a square."""
        side = self.image_size
        fit_longest = A.LongestMaxSize(side)
        pad_square = A.PadIfNeeded(
            min_height=side,
            min_width=side,
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
        )
        return [fit_longest, pad_square]

    def pre_transform(self):
        """Evaluation pipeline: the same resize/pad as training."""
        return self.resize_transforms()

    def hard_transform(self):
        """Random flips and 90-degree rotations applied during training."""
        vertical = A.VerticalFlip(p=0.5)
        horizontal = A.HorizontalFlip(p=0.5)
        rotate = A.RandomRotate90(p=0.7)
        return [vertical, horizontal, rotate]

class MediumTransform(NormalTransform):
    """``NormalTransform`` with stronger train-time augmentation.

    Adds an optional geometric warp and photometric jitter on top of the
    flips/rotations; resize and evaluation pipelines are inherited.
    """

    def hard_transform(self):
        """Flips/rotations, then (p=0.5) one geometric warp, then colour changes."""
        flips_and_rotations = [
            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.RandomRotate90(p=0.7),
        ]
        # With probability 0.5, exactly one of these warps is applied.
        warp = A.OneOf(
            [
                A.ElasticTransform(alpha=120, sigma=120 * 0.05,
                                   alpha_affine=120 * 0.03, p=0.5),
                A.GridDistortion(p=0.5),
                A.OpticalDistortion(distort_limit=2, shift_limit=0.5, p=0.5),
            ],
            p=0.5,
        )
        photometric = [
            A.CLAHE(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.RandomGamma(p=0.5),
        ]
        return flips_and_rotations + [warp] + photometric


def get_transform(name):
    """Return the transform *class* registered under ``name``.

    Args:
        name: one of ``'normal'``, ``'easy'``, ``'medium'``, ``'advanced'``.

    Returns:
        The class (not an instance); callers instantiate it with the image
        size, e.g. ``get_transform('medium')(256)``.

    Raises:
        ValueError: for an unknown ``name``, instead of returning ``None``
            and failing later with "'NoneType' object is not callable".
    """
    # Each class name is resolved only inside its own branch: EasyTransform and
    # AdvancedTransform are not defined in the cells shown here, so an eager
    # lookup table would raise NameError even for 'normal' / 'medium'.
    if name == 'normal':
        return NormalTransform
    if name == 'easy':
        return EasyTransform
    if name == 'medium':
        return MediumTransform
    if name == 'advanced':
        return AdvancedTransform
    raise ValueError(
        f"Unknown augmentation {name!r}; expected one of "
        "'normal', 'easy', 'medium', 'advanced'")

In [None]:
# Test-time pipeline: use the same augmentation family as training
# (train_config['augmentation'], as prepare_everything does) instead of a
# hard-coded name; only its validation branch is applied via test_transform().
# The variable keeps its historical name `normal` for any later cells.
normal = get_transform(train_config['augmentation'])(train_config['scale_size'])
test_transform = normal.test_transform()
test_images = sorted(TEST_IMG_PATHS.glob('*.*'))
# test_masks / indices are not used in the cells shown; kept for later cells.
test_masks = sorted(TEST_MASK.glob('*.*'))
indices = np.arange(len(test_images))


In [None]:
# Unlabeled test set: with no masks, each item carries only 'image' and 'filename'.
test_dataset = HUBMAPSegmentation(
    images=sorted(test_images),
    masks=None,
    transform=test_transform,
)

In [None]:
# Inference loader: keep file order and every tile. Shuffling has no benefit at
# test time, and drop_last=True would silently discard up to batch_size - 1
# test images from the predictions.
test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    num_workers=2,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
)

In [None]:
def get_model(params, model_name):
    """Instantiate a ``segmentation_models_pytorch`` architecture by name.

    The model returns logits; the notebook's IoU/Dice callbacks apply the
    sigmoid themselves.

    Args:
        params: keyword arguments forwarded to the architecture constructor;
            ``None`` means no arguments.
        model_name: attribute name of the architecture in
            ``segmentation_models_pytorch`` (e.g. ``'Unet'``).

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``model_name`` is not set (``BaseConfig`` defaults it to ``None``).
    """
    if not model_name:
        raise ValueError(
            "model_name is not set; assign train_config['model_name'] "
            "(e.g. 'Unet') before building the model")

    # The module-level `import segmentation_models_pytorch as smp` is commented
    # out in the imports cell, so `smp` would otherwise be an unresolved name.
    # The package is installed (pinned) in the first cell.
    import segmentation_models_pytorch as smp

    return getattr(smp, model_name)(**(params or {}))

In [None]:
def _build_param_groups(model, encoder_lr):
    """Optimizer param groups: encoder at ``encoder_lr``, other groups at the optimizer default.

    Groups are made for ``encoder``, ``decoder`` and ``segmentation_head`` when
    present. Any remaining trainable parameters (e.g. an auxiliary head) go
    into one extra default-lr group so they are not silently left untrained.
    A model with none of those sub-modules falls back to a flat parameter list.
    """
    groups = []
    covered = set()

    def add_group(module, **options):
        params = [p for p in module.parameters() if p.requires_grad]
        covered.update(id(p) for p in params)
        groups.append({'params': params, **options})

    encoder = getattr(model, 'encoder', None)
    if encoder is not None:
        add_group(encoder, lr=encoder_lr)
    for name in ('decoder', 'segmentation_head'):
        module = getattr(model, name, None)
        if module is not None:
            add_group(module)

    if not groups:
        return list(model.parameters())

    leftovers = [p for p in model.parameters() if p.requires_grad and id(p) not in covered]
    if leftovers:
        groups.append({'params': leftovers})
    return groups


def _build_criterion(config):
    """Map each loss name in ``config['criterion']`` to its loss module."""
    criterion = {}
    for loss_name in config['criterion']:
        if loss_name == 'wbce':
            pos_weights = torch.tensor(config['pos_weights'], device=utils.get_device())
            criterion[loss_name] = WeightedBCEWithLogits(pos_weights=pos_weights)
        else:
            criterion[loss_name] = get_loss(loss_name)
    return criterion


def _build_callbacks(config, scheduler):
    """Loss, scheduler, early-stopping and metric callbacks for the catalyst runner."""
    callbacks = []
    loss_keys = []
    for loss_name, loss_weight in config['criterion'].items():
        criterion_callback = CriterionCallback(
            input_key="mask",
            output_key="logits",
            criterion_key=loss_name,
            prefix="loss_" + loss_name,
            multiplier=float(loss_weight),
        )
        callbacks.append(criterion_callback)
        loss_keys.append(criterion_callback.prefix)

    # The optimized "loss" is the sum of the weighted per-loss values above.
    callbacks.append(MetricAggregationCallback(prefix="loss", mode="sum", metrics=loss_keys))

    # Cyclic / one-cycle schedulers step per batch; plateau schedulers watch the metric.
    if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
        callbacks.append(SchedulerCallback(mode="batch"))
    elif isinstance(scheduler, ReduceLROnPlateau):
        callbacks.append(SchedulerCallback(reduced_metric=config['metric']))

    callbacks += [
        # Direction follows config['mode'] rather than being hard-coded, so a
        # "min" metric (e.g. a loss) still stops on the right condition.
        EarlyStoppingCallback(
            patience=10, metric=config['metric'], minimize=config['mode'] == 'min'),
        IouCallback(input_key="mask", activation="Sigmoid", threshold=0.5),
        DiceCallback(input_key="mask", activation="Sigmoid", threshold=0.5),
    ]
    return callbacks


def _save_config(config, log_dir):
    """Write ``config`` to ``<log_dir>/config.json``."""
    snapshot = dict(config)
    with open(os.path.join(log_dir, 'config.json'), 'w') as f:
        # default=str serialises Paths (train_img_path, train_mask_path) and
        # any other non-JSON value, e.g. inside model_params.
        json.dump(snapshot, f, default=str)


def prepare_everything(exp_name, log_root="/content/drive/MyDrive/training"):
    """Build model, loaders, optimizer, scheduler, losses and callbacks for one run.

    All settings come from the global ``train_config``. The function also uses
    ``preprocessing_fn``, ``SEED``, ``get_optimizer``, ``get_scheduler``,
    ``get_loss`` and ``WeightedBCEWithLogits``, which are not defined in the
    cells shown here (presumably in another cell).

    Args:
        exp_name: experiment name; used as the log sub-directory.
        log_root: parent directory for experiment logs.

    Returns:
        dict with keys ``model``, ``loader``, ``optimizer``, ``criterion``,
        ``callbacks``, ``scheduler`` and ``log_dir``.

    Side effects:
        Creates ``log_root/exp_name`` if missing and writes ``config.json`` there.
    """
    print("===> Get model")
    model = get_model(
        train_config['model_params'],
        train_config['model_name']
    )

    print("===> Get transformation")
    # Define transform (augmentation)
    Transform = get_transform(train_config['augmentation'])
    transforms = Transform(
        train_config['scale_size'],
        preprocessing_fn=preprocessing_fn
    )
    train_transform = transforms.train_transform()
    val_transform = transforms.validation_transform()
    preprocessing = transforms.get_preprocessing()

    print("===> Get data loader")
    loader = get_loader(
        images=sorted(train_config['train_img_path'].glob('*.*')),
        masks=sorted(train_config['train_mask_path'].glob('*.*')),
        random_state=SEED,
        valid_size=0.2,
        batch_size=train_config['batch_size'],
        val_batch_size=train_config['val_batch_size'],
        num_workers=2,
        train_transforms_fn=train_transform,
        valid_transforms_fn=val_transform,
        preprocessing_fn=preprocessing
    )

    print("===> Get optimizer")
    param_group = _build_param_groups(model, train_config['learning_rate'])

    total = int(sum(p.numel() for p in model.parameters()))
    trainable = int(sum(p.numel() for p in model.parameters() if p.requires_grad))
    count_parameters = {"total": total, "trainable": trainable}
    print(
        f'[INFO] total and trainable parameters in the model {count_parameters}'
    )

    # learning_rate_decode is passed as the optimizer lr; presumably it becomes
    # the default for groups without their own 'lr' (get_optimizer is defined elsewhere).
    optimizer = get_optimizer(
        train_config['optimizer'], param_group, train_config['learning_rate_decode'], train_config['weight_decay'])

    print("===> Get shceduler")
    scheduler = get_scheduler(
        train_config['scheduler'], optimizer, train_config['learning_rate'], train_config['num_epochs'],
        batches_in_epoch=len(loader['train']), mode=train_config['mode']
    )

    print("===> Get loss")
    criterion = _build_criterion(train_config)

    print("===> Get callbacks")
    callbacks = _build_callbacks(train_config, scheduler)

    log_dir = os.path.join(log_root, exp_name)
    os.makedirs(log_dir, exist_ok=True)

    print("===> Saving config setting")
    _save_config(train_config, log_dir)

    print("===> Done")

    return {
        'model': model,
        'loader': loader,
        'optimizer': optimizer,
        'criterion': criterion,
        'callbacks': callbacks,
        'scheduler': scheduler,
        'log_dir': log_dir
    }