### References:

https://www.kaggle.com/c/hubmap-kidney-segmentation/notebooks   
https://github.com/qubvel/segmentation_models.pytorch/blob/master/examples/cars%20segmentation%20(camvid).ipynb  


#### Data preprocessing (source of the 256x256 tiles used below):

https://www.kaggle.com/iafoss/256x256-images

In [None]:
!pip install segmentation_models_pytorch

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import os
import cv2
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader, Dataset
import albumentations as albu
import segmentation_models_pytorch as smp
from albumentations.pytorch import ToTensor
import torch
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
!mkdir data
!mkdir data/images
!unzip ../input/256x256-images/train.zip -d data/images

In [None]:
!mkdir data/masks
!unzip ../input/256x256-images/masks.zip -d data/masks

In [None]:
class config:
    """Notebook-wide paths and hyperparameters, read as class attributes."""

    # Directories the unzip cells extract the image and mask tiles into.
    images_path = './data/images'
    masks_path = './data/masks'

    # Model: encoder name for smp.Unet, output activation, pretrained weights.
    backbone = 'resnet34'
    ACTIVATION = 'sigmoid'
    ENCODER_WEIGHTS = 'imagenet'

    # Optimisation.
    lr = 1e-3
    epochs = 10
    batch_size = 8
    T_max = 500  # NOTE(review): not referenced in the visible notebook code.

    # Input size and data-loading parallelism.
    im_size = 256
    num_workers = 4

In [None]:
# Training augmentation, applied jointly to image and mask:
# flip -> (30%) one photometric jitter -> (30%) one non-rigid distortion
# -> shift/scale/rotate -> resize -> tensor conversion.
_photometric_jitter = albu.OneOf(
    [albu.RandomContrast(), albu.RandomGamma(), albu.RandomBrightness()],
    p=0.3,
)
_nonrigid_distortion = albu.OneOf(
    [
        albu.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
        albu.GridDistortion(),
        albu.OpticalDistortion(distort_limit=2, shift_limit=0.5),
    ],
    p=0.3,
)
# NOTE(review): the legacy albumentations.pytorch.ToTensor divides uint8
# masks by 255 — confirm the mask PNGs store 0/255 and not 0/1, otherwise
# the targets end up as 0/(1/255).
train_augmentation = albu.Compose([
    albu.HorizontalFlip(),
    _photometric_jitter,
    _nonrigid_distortion,
    albu.ShiftScaleRotate(),
    albu.Resize(config.im_size, config.im_size),
    ToTensor(),
])

# Validation: deterministic resize and tensor conversion only.
valid_augmentation = albu.Compose([
    albu.Resize(config.im_size, config.im_size),
    ToTensor(),
])


class HuBMAPDataset(Dataset):
    """Dataset pairing each image tile with the mask tile of the same file name.

    Images are read from ``config.images_path`` and masks from
    ``config.masks_path``.

    Args:
        ids: sequence of tile file names (entries of ``os.listdir`` here).
        transforms: optional albumentations-style callable, invoked as
            ``transforms(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` keys.

    Raises:
        FileNotFoundError: from ``__getitem__`` when an image or mask file
            cannot be read.
    """

    def __init__(self, ids, transforms=None):
        self.ids = ids
        self.transforms = transforms

    def __getitem__(self, idx):
        name = self.ids[idx]
        img_path = f"{config.images_path}/{name}"
        mask_path = f"{config.masks_path}/{name}"

        img = cv2.imread(img_path)
        # cv2.imread reports failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f"could not read image: {img_path}")
        # OpenCV decodes to BGR; the ImageNet-pretrained encoder and
        # plt.imshow both expect RGB channel order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"could not read mask: {mask_path}")

        if self.transforms is not None:
            augmented = self.transforms(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        return img, mask

    def __len__(self):
        return len(self.ids)

In [None]:
# Split tiles by source image so no image contributes tiles to both sets.
# The image id is taken as the file-name prefix before the first "_".
data = sorted(os.listdir(config.images_path))  # slice e.g. [:100] to debug quickly
# sorted(): iteration order of a set of str depends on the per-process hash
# seed (PYTHONHASHSEED), so without it the held-out images changed every run.
train_lsit = sorted({row.split("_")[0] for row in data})
# All but the last two image ids are used for training; a set gives O(1) lookups.
train_image_ids = set(train_lsit[:-2])
train_idx = [row for row in data if row.split("_")[0] in train_image_ids]
valid_idx = [row for row in data if row.split("_")[0] not in train_image_ids]
len(train_idx), len(valid_idx)

In [None]:
# Augmented training set vs. resize-only validation set; only training is shuffled.
train_datasets = HuBMAPDataset(train_idx, transforms=train_augmentation)
valid_datasets = HuBMAPDataset(valid_idx, transforms=valid_augmentation)
_loader_kwargs = dict(batch_size=config.batch_size, num_workers=config.num_workers)
train_loader = DataLoader(train_datasets, shuffle=True, **_loader_kwargs)
valid_loader = DataLoader(valid_datasets, shuffle=False, **_loader_kwargs)

In [None]:
# Sanity check: pull one augmented training sample and display its shapes.
x, y = train_datasets[1]
(x.shape, y.shape)

In [None]:
# Helper for displaying images side by side.
def visualize(**images):
    """Plot every keyword-argument image in a single row of subplots.

    Each keyword becomes its subplot's title, with underscores replaced by
    spaces and title-cased (``true_mask`` -> ``True Mask``).
    """
    count = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, count, position)
        # Hide axis ticks; only the pixels matter here.
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()

# Show one augmented training sample next to its mask.
# permute(1, 2, 0) moves the channel axis last, as plt.imshow expects.
image, mask = train_datasets[5]
visualize(image=image.permute(1, 2, 0), mask=mask.squeeze(0))

In [None]:
# U-Net with a pretrained encoder and one output channel; config.ACTIVATION
# ('sigmoid') is applied inside the model, so outputs are probabilities.
model = smp.Unet(
    config.backbone,
    encoder_weights=config.ENCODER_WEIGHTS,
    in_channels=3,
    classes=1,
    activation=config.ACTIVATION,
    decoder_use_batchnorm=False,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

# Dice loss on the model's sigmoid output. NOTE: switching to
# BCEWithLogitsLoss would also require activation=None, otherwise the
# sigmoid would be applied twice.
loss_fn = smp.utils.losses.DiceLoss()

# IoU computed after binarising predictions at 0.5.
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

In [None]:
# smp epoch runners sharing loss, metrics and device; only the training
# runner is given the optimizer.
_epoch_kwargs = dict(loss=loss_fn, metrics=metrics, device=DEVICE, verbose=True)
train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_epoch_kwargs)
valid_epoch = smp.utils.train.ValidEpoch(model, **_epoch_kwargs)

In [None]:
def savelogs(logs, name):
    """Append one line of ``key value`` pairs from *logs* to a text log file.

    Pairs are separated by ``", "`` so they stay readable.

    Args:
        logs: mapping of metric name to value (here, the dict returned by an
            smp epoch runner's ``run``).
        name: log file path; ``.txt`` is appended unless it is already there,
            so ``'train_logs'`` and ``'train_logs.txt'`` name the same file.
    """
    path = name if name.endswith('.txt') else f'{name}.txt'
    line = ', '.join(f'{k} {v}' for k, v in logs.items())
    with open(path, 'a', encoding='utf-8') as f:
        f.write(line + '\n')

In [None]:
# Per-epoch history, plotted after training.
losses = {'train': [], 'valid': []}
ious = {'train': [], 'valid': []}

# Lowest validation Dice loss seen so far (lower is better, despite the name).
max_score = float('inf')

# Epoch after which the learning rate drops to 1e-5. The previous hard-coded
# 15 could never fire with config.epochs = 10; cap it so the decay happens.
lr_decay_epoch = min(15, config.epochs // 2)

for i in range(config.epochs):
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # savelogs() adds the ".txt" extension itself, so pass the bare stem.
    savelogs(train_logs, 'train_logs')
    savelogs(valid_logs, 'valid_logs')

    losses['train'].append(train_logs['dice_loss'])
    losses['valid'].append(valid_logs['dice_loss'])
    ious['train'].append(train_logs['iou_score'])
    ious['valid'].append(valid_logs['iou_score'])

    # Checkpoint the whole model object whenever validation loss improves.
    if valid_logs['dice_loss'] < max_score:
        max_score = valid_logs['dice_loss']
        torch.save(model, 'best.pth')
        print('Model saved!')

    # The optimizer has a single param group covering the whole model;
    # update every group anyway so the decay never silently misses one.
    if i == lr_decay_epoch:
        for group in optimizer.param_groups:
            group['lr'] = 1e-5
        print('Decrease learning rate to 1e-5!')

In [None]:
# PLOT
def plot(scores, name):
    """Draw train and validation curves of one metric against epoch index.

    Args:
        scores: dict with ``"train"`` and ``"valid"`` lists of per-epoch values.
        name: metric name used in the legend, title and y-axis label.
    """
    plt.figure(figsize=(15, 5))
    for split, prefix in (("train", "train"), ("valid", "val")):
        values = scores[split]
        # Each curve uses its own length, so the two lists need not match.
        plt.plot(range(len(values)), values, label=f'{prefix} {name}')
    plt.title(f'{name} plot')
    plt.xlabel('Epoch')
    plt.ylabel(f'{name}')
    plt.legend()
    plt.show()


plot(losses, "loss")
plot(ious, "iou")