In [None]:
!pip install git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
import gc
import os
import random
import time
import warnings
warnings.simplefilter("ignore")


from albumentations import *
from albumentations.pytorch import ToTensor
import cv2
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
from sklearn.model_selection import KFold
import tifffile as tiff
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, sampler
from tqdm import tqdm_notebook as tqdm

In [None]:
def seed_everything(seed=2**3):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (CPU and every CUDA device), then puts cuDNN into deterministic mode.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may select a different convolution algorithm on each run,
    # which undermines the deterministic flag above.
    torch.backends.cudnn.benchmark = False

seed_everything(121)

In [None]:
# --- Cross-validation / tiling settings ---
# `fold` is the default fold index bound into HuBMAPDataset.__init__;
# `nfolds` is the number of KFold splits and training-loop iterations.
fold = 0
nfolds = 5
# NOTE(review): `reduce` and `sz` are not referenced in the visible cells —
# presumably the downscale factor and tile size used at inference; confirm.
reduce = 4
sz = 1024

# --- Training hyper-parameters ---
BATCH_SIZE = 16
# Device string passed to `.to(...)` for the model and every batch.
DEVICE = ('cuda' if torch.cuda.is_available() else 'cpu')
NUM_WORKERS = 4
NUM_EPOCHS = 60
# random_state of the KFold split in HuBMAPDataset.
SEED = 2020
# NOTE(review): not used in the visible cells — presumably the probability
# threshold for binarising predictions; confirm.
TH = 0.39

# --- Input locations ---
DATA = '../input/hubmap-kidney-segmentation/test/'
# CSV whose `id` column is split into folds by HuBMAPDataset.
LABELS = '../input/hubmap-kidney-segmentation/train.csv'
# Mask and image tiles; HuBMAPDataset reads both under the same file name.
MASKS = '../input/hubmap-256x256/masks'
TRAIN = '../input/hubmap-256x256/train'
# Loaded eagerly when this cell runs.
df_sample = pd.read_csv('../input/hubmap-kidney-segmentation/sample_submission.csv')


In [None]:
def rle_encode_less_memory(img):
    """Run-length encode a binary mask in column-major order.

    The mask is transposed and flattened, then encoded as space-separated
    ``start length`` pairs with 1-based starts.

    Unlike the previous version, the first and last pixels are no longer
    forced to zero, so foreground touching either end of the flattened mask
    is encoded instead of being silently dropped. Boundaries are handled by
    adding at most two entries to the run array rather than padding the whole
    pixel buffer. The input array is never modified.

    Args:
        img: 2-D array; non-zero entries are treated as foreground.

    Returns:
        The RLE string, or ``''`` when the mask is empty or all background.
    """
    pixels = img.T.flatten()
    if pixels.size == 0:
        return ''
    # 1-based positions at which the value changes from the previous pixel.
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
    if pixels[0]:
        # A run that starts on the very first pixel has no preceding change.
        runs = np.concatenate(([1], runs))
    if pixels[-1]:
        # A run that reaches the last pixel ends one past the end.
        runs = np.concatenate((runs, [pixels.size + 1]))
    # Turn (start, end) pairs into (start, length) pairs.
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

In [None]:
# Per-channel normalisation statistics. HuBMAPDataset.__getitem__ applies
# (img / 255 - mean) / std to RGB images, and the preview cell inverts it.
# NOTE(review): presumably computed over the training tiles — confirm.
mean = np.array([0.65459856,0.48386562,0.69428385])
std = np.array([0.15167958,0.23584107,0.13146145])

def img2tensor(img, dtype:np.dtype=np.float32):
    """Convert an H x W (x C) NumPy image into a C x H x W torch tensor.

    A 2-D input gets a trailing channel axis of size one before the axes are
    reordered. The array is cast to ``dtype`` without copying when it already
    has that dtype.
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    channels_first = img.transpose(2, 0, 1)
    return torch.from_numpy(channels_first.astype(dtype, copy=False))

class HuBMAPDataset(Dataset):
    """Tile dataset restricted to one side of a K-fold split over image ids.

    The ids in the ``LABELS`` csv are split with a shuffled, seeded ``KFold``.
    Every file in ``TRAIN`` whose name prefix (the text before the first
    ``_``) belongs to the selected side of fold ``fold`` is kept. The mask
    for a tile is read from ``MASKS`` under the same file name.

    Args:
        fold: index of the fold to use.
        train: pick the training ids when truthy, the held-out ids otherwise.
        tfms: optional callable invoked as ``tfms(image=..., mask=...)`` that
            returns a mapping with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, fold=fold, train=True, tfms=None):
        all_ids = pd.read_csv(LABELS).id.values
        splitter = KFold(n_splits=nfolds, random_state=SEED, shuffle=True)
        train_idx, valid_idx = list(splitter.split(all_ids))[fold]
        selected_ids = set(all_ids[train_idx if train else valid_idx])
        self.fnames = [
            name for name in os.listdir(TRAIN)
            if name.split('_')[0] in selected_ids
        ]
        self.train = train
        self.tfms = tfms

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` float tensors in channel-first layout."""
        fname = self.fnames[idx]
        # OpenCV decodes to BGR; the normalisation statistics are RGB.
        bgr = cv2.imread(os.path.join(TRAIN, fname))
        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(os.path.join(MASKS, fname), cv2.IMREAD_GRAYSCALE)
        if self.tfms is not None:
            transformed = self.tfms(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        normalised = (image/255.0-mean)/std
        return img2tensor(normalised), img2tensor(mask)

In [None]:
def get_augmentation(p=1.0):
    """Build the training-time augmentation pipeline.

    The pipeline chains flips and a 90-degree rotation (constructed with the
    library's default probabilities), a shift/scale/rotate with reflected
    borders, one optional geometric distortion and one optional colour
    perturbation.

    Args:
        p: probability of applying the composed pipeline as a whole.

    Returns:
        The ``Compose`` object wrapping all transforms.
    """
    steps = [HorizontalFlip(), VerticalFlip(), RandomRotate90()]
    steps.append(
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15,
                         p=0.9, border_mode=cv2.BORDER_REFLECT)
    )

    # At most one geometric distortion is drawn, 30% of the time.
    distortions = [
        OpticalDistortion(p=0.3),
        GridDistortion(p=.1),
        IAAPiecewiseAffine(p=0.3),
    ]
    steps.append(OneOf(distortions, p=0.3))

    # At most one colour perturbation is drawn, 30% of the time.
    colour_changes = [
        HueSaturationValue(10, 15, 10),
        CLAHE(clip_limit=2),
        RandomBrightnessContrast(),
    ]
    steps.append(OneOf(colour_changes, p=0.3))

    return Compose(steps, p=p)

In [None]:
# Sanity check: draw one augmented batch and overlay each mask on its image.
ds = HuBMAPDataset(tfms=get_augmentation())
# batch_size=16 matches the 4 x 4 subplot grid used below.
dl = DataLoader(ds, batch_size=16, shuffle=True, num_workers=NUM_WORKERS)
imgs, masks = next(iter(dl))
print(imgs.shape)
print(masks.shape)

plt.figure(figsize=(16, 16))
for i, (img, mask) in enumerate(zip(imgs, masks)):
    # Undo the dataset normalisation: channels last, then * std + mean, then back to 0-255.
    img = ((img.permute(1, 2, 0)*std + mean) * 255.0).numpy().astype(np.uint8)
    plt.subplot(4, 4, i+1)
    plt.imshow(img, vmin=0, vmax=255)
    # Draw the mask semi-transparently on top of the image.
    plt.imshow(mask.squeeze().numpy(), alpha=0.2)
    plt.axis('off')
    plt.subplots_adjust(wspace=None, hspace=None)
plt.show()

# Drop the preview objects so they do not linger in the kernel namespace.
del ds, dl, imgs, masks

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed on raw logits.

    ``weight`` and ``size_average`` are accepted only for signature
    compatibility; neither is used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - dice`` over all elements of the batch.

        Args:
            inputs: logits of any shape; a sigmoid is applied here.
            targets: ground-truth tensor with the same number of elements.
            smooth: additive term in numerator and denominator that keeps the
                ratio defined when both prediction and target are empty.
        """
        # torch.sigmoid replaces the deprecated F.sigmoid. reshape (unlike
        # view) also accepts non-contiguous tensors.
        probs = torch.sigmoid(inputs).reshape(-1)
        truth = targets.reshape(-1)
        # The element-wise product sums to the soft intersection.
        intersection = (probs * truth).sum()
        dice_score = (2*intersection + smooth) / (probs.sum() + truth.sum() + smooth)
        return 1 - dice_score

In [None]:
"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""

from __future__ import print_function, division

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
try:
    from itertools import  ifilterfalse
except ImportError: # py3k
    from itertools import  filterfalse


def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovasz extension with respect to sorted errors
    (Alg. 1 in the paper).

    gt_sorted: ground-truth labels ordered by decreasing error.
    Returns the successive increments of the Jaccard loss along that order.
    """
    n_pixels = len(gt_sorted)
    total_fg = gt_sorted.sum()
    fg_seen = gt_sorted.float().cumsum(0)
    bg_seen = (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - (total_fg - fg_seen) / (total_fg + bg_seen)
    if n_pixels > 1:  # a single pixel has no predecessor to difference against
        increments = jaccard[1:n_pixels] - jaccard[0:-1]
        jaccard[1:n_pixels] = increments
    return jaccard


def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    Foreground IoU in percent (binary case: 1 foreground, 0 background).

    With per_image the score is averaged over the (pred, label) pairs
    obtained by iterating preds and labels; otherwise both are treated as a
    single pair. EMPTY is the score assigned when the union is empty.
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    scores = []
    for pred, label in zip(preds, labels):
        label_fg = (label == 1)
        pred_fg = (pred == 1)
        intersection = (label_fg & pred_fg).sum()
        # Predicted foreground only counts where the label is not ignored.
        union = (label_fg | (pred_fg & (label != ignore))).sum()
        scores.append(float(intersection) / union if union else EMPTY)
    return 100 * cal_mean(scores)


def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU (in percent) for each non-ignored class.

    preds, labels: predictions and ground truth; iterated pairwise when
        per_image is set, otherwise treated as a single pair.
    C: number of classes.
    EMPTY: score assigned to a class whose union is empty.
    ignore: class id to skip (also excluded from the union).

    Returns a NumPy array with one entry per non-ignored class.
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    per_image_ious = []
    for pred, label in zip(preds, labels):
        # Named so that the function's own name is not shadowed.
        class_ious = []
        for i in range(C):
            if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
                intersection = ((label == i) & (pred == i)).sum()
                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
                if not union:
                    class_ious.append(EMPTY)
                else:
                    class_ious.append(float(intersection) / union)
        per_image_ious.append(class_ious)
    # On Python 3 `map` is lazy and np.array(map(...)) yields a 0-d object
    # array, so the per-class means are materialised into a list first.
    class_means = [cal_mean(scores) for scores in zip(*per_image_ious)]
    return 100 * np.array(class_means)


# --------------------------- BINARY LOSSES ---------------------------


def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
      logits: [B, H, W] logits at each pixel (between -inf and +inf)
      labels: [B, H, W] binary ground truth masks (0 or 1)
      per_image: average the loss over images instead of computing it once
                 over the whole batch
      ignore: void class id
    """
    if not per_image:
        return lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))

    # One loss per image, consumed lazily by cal_mean.
    image_losses = (
        lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
        for log, lab in zip(logits, labels)
    )
    return cal_mean(image_losses)


def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss on flattened predictions
      logits: [P] logits at each prediction (between -inf and +inf)
      labels: [P] binary ground truth labels (0 or 1)

    The sorted errors pass through ELU + 1 (not ReLU) before being weighted
    by the Lovasz gradient.
    """
    if len(labels) == 0:
        # Only void pixels: return a zero that is still connected to the graph.
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    margins = (1. - logits * Variable(signs))
    errors_sorted, order = torch.sort(margins, dim=0, descending=True)
    gt_sorted = labels[order.data]
    weights = lovasz_grad(gt_sorted)
    return torch.dot(F.elu(errors_sorted)+1, Variable(weights))


def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flatten scores and labels of a batch to 1-D (binary case), dropping the
    positions whose label equals 'ignore' when one is given.
    """
    flat_scores, flat_labels = scores.view(-1), labels.view(-1)
    if ignore is not None:
        keep = (flat_labels != ignore)
        return flat_scores[keep], flat_labels[keep]
    return flat_scores, flat_labels


class StableBCELoss(torch.nn.modules.Module):
    """Binary cross-entropy on logits, written to avoid overflow in exp()."""

    def __init__(self):
        super(StableBCELoss, self).__init__()

    def forward(self, input, target):
        """Return the mean of max(x, 0) - x * t + log(1 + exp(-|x|))."""
        hinge_part = input.clamp(min=0)
        # exp is only ever evaluated on non-positive values.
        log_part = (1 + (-input.abs()).exp()).log()
        per_element = hinge_part - input * target + log_part
        return per_element.mean()


def binary_xloss(logits, labels, ignore=None):
    """
    Binary cross entropy loss
      logits: [B, H, W] logits at each pixel (between -inf and +inf)
      labels: [B, H, W] binary ground truth masks (0 or 1)
      ignore: void class id
    """
    flat_logits, flat_labels = flatten_binary_scores(logits, labels, ignore)
    criterion = StableBCELoss()
    return criterion(flat_logits, Variable(flat_labels.float()))


# --------------------------- MULTICLASS LOSSES ---------------------------


def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
      probas: [B, C, H, W] class probabilities at each prediction (between 0 and 1)
      labels: [B, H, W] ground truth labels (between 0 and C - 1)
      only_present: average only on classes present in ground truth
      per_image: compute the loss per image instead of per batch
      ignore: void class labels
    """
    if per_image:
        # Must be cal_mean: in this notebook `mean` is the module-level
        # normalisation array, so calling it raises TypeError.
        loss = cal_mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), only_present=only_present)
                        for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), only_present=only_present)
    return loss


def lovasz_softmax_flat(probas, labels, only_present=False):
    """
    Multi-class Lovasz-Softmax loss on flattened predictions
      probas: [P, C] class probabilities at each prediction (between 0 and 1)
      labels: [P] ground truth labels (between 0 and C - 1)
      only_present: average only on classes present in ground truth
    """
    n_classes = probas.size(1)
    class_losses = []
    for c in range(n_classes):
        fg = (labels == c).float()  # 1 where the pixel belongs to class c
        if only_present and fg.sum() == 0:
            continue
        errors = (Variable(fg) - probas[:, c]).abs()
        errors_sorted, order = torch.sort(errors, 0, descending=True)
        fg_sorted = fg[order.data]
        weights = Variable(lovasz_grad(fg_sorted))
        class_losses.append(torch.dot(errors_sorted, weights))
    return cal_mean(class_losses)


def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch.

      probas: [B, C, H, W] class probabilities
      labels: [B, H, W] ground truth labels
      ignore: label value whose positions are removed (None keeps everything)

    Returns (probas [P, C], labels [P]) with P the number of kept pixels.
    """
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    # Boolean-mask indexing always returns a [P, C] tensor. The previous
    # `valid.nonzero().squeeze()` collapsed to a 0-d index when exactly one
    # pixel was valid, which produced a [C] tensor and broke downstream code.
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels

def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss
      logits: [N, C, ...] unnormalised class scores
      labels: ground truth class indices
      ignore: label value excluded from the loss; when None the historical
              default of 255 is used
    """
    # The `ignore` argument used to be silently discarded in favour of a
    # hard-coded 255.
    ignore_index = 255 if ignore is None else ignore
    return F.cross_entropy(logits, labels, ignore_index=ignore_index)


# --------------------------- HELPER FUNCTIONS ---------------------------

def cal_mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.

      l: iterable of values supporting ``+`` and division by an int
      ignore_nan: skip values for which np.isnan is true
      empty: value returned for an empty iterable, or the string 'raise' to
             raise ValueError instead
    """
    # Imported locally: the module-level import only binds `filterfalse` on
    # Python 3, so the old reference to `ifilterfalse` raised NameError.
    from itertools import filterfalse

    values = iter(l)
    if ignore_nan:
        values = filterfalse(np.isnan, values)
    try:
        n = 1
        acc = next(values)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(values, 2):
        # Out-of-place add so the caller's first element (tensor or array)
        # is not mutated.
        acc = acc + v
    if n == 1:
        return acc
    return acc / n

In [None]:
def symmetric_lovasz(outputs, targets):
    """Average the Lovasz hinge loss of the prediction and of its complement.

    The second term evaluates the same loss with negated outputs against the
    inverted targets, so foreground and background are treated alike.
    """
    foreground_term = lovasz_hinge(outputs, targets)
    background_term = lovasz_hinge(-outputs ,1.0-targets)
    return 0.5*(foreground_term + background_term)

In [None]:
class LovaszLoss(nn.Module):
    """Symmetric Lovasz hinge loss operating directly on raw logits."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return the symmetric Lovasz hinge loss.

        Args:
            inputs: raw model outputs (logits); no activation is applied.
            targets: binary ground-truth masks shaped like ``inputs``.
            smooth: unused; kept so the call signature matches DiceLoss.
        """
        # The hinge formulation is defined on unbounded logits: its margin is
        # 1 - logit * sign. Squashing the outputs through a sigmoid first
        # confined them to (0, 1), so the margin could never be satisfied.
        # The sigmoid has therefore been removed.
        return symmetric_lovasz(inputs, targets)

In [None]:
def UnetPlusPlus():
    """Build a single-class segmentation model on an EfficientNet-B7 encoder.

    Despite the function's name, the model constructed here is ``smp.Unet``.
    The encoder is initialised with ImageNet weights and takes 3-channel input.
    """
    settings = dict(
        encoder_name='efficientnet-b7',
        encoder_weights='imagenet',
        in_channels=3,
        classes=1,
    )
    return smp.Unet(**settings)

In [None]:
def UnetResNext():
    """Build a single-class ``smp.Unet`` on an SE-ResNeXt50 (32x4d) encoder.

    The encoder is initialised with ImageNet weights and takes 3-channel input.
    """
    settings = dict(
        encoder_name='se_resnext50_32x4d',
        encoder_weights='imagenet',
        in_channels=3,
        classes=1,
    )
    return smp.Unet(**settings)

In [None]:
import segmentation_models_pytorch as smp
def UnetDenseNet():
    """Build a single-class ``smp.Unet`` on a DenseNet-201 encoder.

    The encoder is initialised with ImageNet weights and takes 3-channel input.
    """
    settings = dict(
        encoder_name='densenet201',
        encoder_weights='imagenet',
        in_channels=3,
        classes=1,
    )
    return smp.Unet(**settings)

In [None]:
def train_one_epoch(fold, model, dataloader_train, dataloader_valid, optimizer, loss_function, epoch=None):
    """Run one training pass and one validation pass, then log both losses.

    Args:
        fold: zero-based fold index, used only in the log line.
        model: network to optimise; switched to train and then eval mode.
        dataloader_train: iterable of ``(images, masks)`` training batches.
        dataloader_valid: iterable of ``(images, masks)`` validation batches.
        optimizer: optimiser stepped once per training batch.
        loss_function: callable ``loss_function(outputs, masks)`` returning a
            scalar tensor.
        epoch: zero-based epoch index for the log line. When omitted, the
            function falls back to a module-level ``epoch`` variable (the old
            implicit behaviour) and prints ``?`` if none exists, instead of
            raising NameError.

    Returns:
        ``(train_loss, valid_loss)``, each the mean of the per-batch losses.
    """
    if epoch is None:
        # Backward compatibility with callers that rely on the loop variable
        # leaking into the module namespace.
        epoch = globals().get('epoch')
    epoch_label = epoch + 1 if epoch is not None else '?'

    # Training phase.
    model.train()
    train_loss = 0
    for imgs, masks in dataloader_train:
        optimizer.zero_grad()
        imgs = imgs.to(DEVICE)
        masks = masks.to(DEVICE)
        outputs = model(imgs)
        loss = loss_function(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    train_loss /= max(len(dataloader_train), 1)

    # Validation phase: no gradients, no parameter updates.
    model.eval()
    valid_loss = 0
    with torch.no_grad():
        for imgs, masks in dataloader_valid:
            imgs = imgs.to(DEVICE)
            masks = masks.to(DEVICE)
            outputs = model(imgs)
            loss = loss_function(outputs, masks)
            valid_loss += loss.item()
    valid_loss /= max(len(dataloader_valid), 1)

    print(f'FOLD: {fold + 1}, EPOCH: {epoch_label} - train loss: {train_loss} -  valid_loss: {valid_loss}')
    return train_loss, valid_loss



In [None]:
# Train one model per fold. Each fold's final weights go to
# model_fold_<fold>.pth, and the model with the lowest final validation loss
# across folds is additionally pickled to best_unet_model.pth.
# +inf means "no fold evaluated yet", so no legitimate loss value can be
# mistaken for the sentinel.
best_valid_loss = float('inf')
for fold in range(nfolds):
    ds_t = HuBMAPDataset(fold=fold, train=True, tfms=get_augmentation())
    ds_v = HuBMAPDataset(fold=fold, train=False)
    # Training batches must be reshuffled every epoch; validation order is
    # irrelevant and stays fixed.
    dataloader_t = torch.utils.data.DataLoader(ds_t, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    dataloader_v = torch.utils.data.DataLoader(ds_v, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    model = UnetResNext().to(DEVICE)
    # The name is historical: the criterion is the Lovasz loss, not Dice.
    diceloss = LovaszLoss()
    # Encoder and decoder are separate parameter groups (same learning rate)
    # so they can be tuned independently. No LR scheduler is used.
    optimizer = torch.optim.Adam([
        {'params': model.decoder.parameters(), 'lr': 1e-3},
        {'params': model.encoder.parameters(), 'lr': 1e-3},
    ])

    train_loss = 0
    valid_loss = 0

    for epoch in tqdm(range(NUM_EPOCHS)):
        train_loss, valid_loss = train_one_epoch(fold, model, dataloader_t, dataloader_v, optimizer, diceloss)

    torch.save(model.state_dict(), f'model_fold_{fold}.pth')
    if valid_loss <= best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model, 'best_unet_model.pth')

    gc.collect()
    # Return cached GPU blocks to the driver before the next fold allocates
    # a fresh model.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
