In [None]:
import time
import os
import skimage.io
import numpy as np
import pandas as pd
import cv2
import rasterio
from rasterio.windows import Window
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsampler import ImbalancedDatasetSampler
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler, SequentialSampler
from warmup_scheduler import GradualWarmupScheduler
#from efficientnet_pytorch import model as enet
#from efficientnet_pytorch import EfficientNet
import segmentation_models_pytorch as smp
import albumentations as albu
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import StratifiedKFold
import matplotlib.pyplot as plt
from sklearn.metrics import cohen_kappa_score
#from tqdm import tqdm_notebook as tqdm
from tqdm.notebook import tqdm
#from lookahead import Lookahead
#from radam import *
from losses import *
from utils import rle_decode, make_grid, seed_everything
from prefetch_generator import BackgroundGenerator

In [None]:
# Show the library versions used for this run (rendered as the cell output).
(smp.__version__, torch.__version__, albu.__version__)

In [None]:
# Enable cuDNN's benchmark mode for this process.
# NOTE(review): benchmark mode is commonly associated with run-to-run
# non-determinism — confirm this is acceptable together with seed_everything below.
torch.backends.cudnn.benchmark = True

In [None]:
# Seed the run with the project-local helper imported from `utils` (seed value 42).
# NOTE(review): which RNGs / backend flags the helper touches is not visible
# in this notebook — check utils.seed_everything.
seed_everything(42)

# Config

In [None]:
# ---- paths ----
data_dir = 'F:/hubmap-kidney-segmentation/'
logdir = 'F:/HuBMAP/exp007/'  # per-fold logs, plots and checkpoints are written here

# ---- model ----
encoder = 'timm-efficientnet-b4'  # alternative tried: 'efficientnet-b3'
ENCODER_WEIGHTS = 'noisy-student'

# ---- data loading ----
image_size = 512
batch_size = 8
num_workers = 0

# ---- optimisation ----
init_lr = 3e-4  # previously 1e-4
warmup_factor = 10  # optimizer starts at init_lr / warmup_factor
warmup_epo = 1
n_epochs = 50
mix_up = False
use_amp = True

# ---- early stopping ----
early_stop = True
n_epochs_stop = 5
# NOTE: Trainer.fit assigns names identical to these inside the method, which
# makes them function locals there; these module-level values are not updated.
epochs_no_improve1 = 0
epochs_no_improve2 = 0

# ---- runtime objects ----
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
device = torch.device('cuda')

In [None]:
# Load the sample table; later cells read its `path` and `fold` columns.
df = pd.read_csv('disk_folds_1024-128.csv')
df.shape

In [None]:
# Preview the first rows of the sample table.
df.head(n=5)

# Model

In [None]:
# Build a U-Net whose encoder architecture and pretrained weights come from
# the config cell (`encoder`, `ENCODER_WEIGHTS`).
# NOTE(review): train() constructs its own model for every fold, so this
# module-level instance does not appear to be used by the cells below.
model = smp.Unet(
    encoder_name=encoder,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,  # 3-channel input images
    classes=1,      # single output channel (binary mask)
)

# Dataset

In [None]:
class HubMapDataset(Dataset):
    """Image/mask dataset driven by a DataFrame that has a `path` column.

    The mask file path is derived from the image path by replacing
    ``'images'`` with ``'masks'`` and ``'.png'`` with ``'.jpg'``.

    Args:
        df: DataFrame with one row per sample; its index is reset internally.
        train: Stored on the instance; not read by this class itself.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, df, train=True, transform=None):
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.train = train
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __gen_data__(self, index):
        """Read the image and its mask for row `index`.

        Returns:
            ``(image, mask)`` where `mask` is the grayscale mask divided by 255.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or the mask.
        """
        item = self.df.iloc[index]
        image_pth = item.path
        mask_pth = image_pth.replace('images', 'masks').replace('.png', '.jpg')

        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of a TypeError further down.
        # NOTE(review): OpenCV loads colour images in BGR channel order — confirm
        # the normalisation / inference code expects the same order.
        image = cv2.imread(image_pth)
        if image is None:
            raise FileNotFoundError(f'could not read image file: {image_pth}')

        # NOTE(review): masks are read from .jpg files; lossy compression may
        # leave values that are not exactly 0/1 after scaling — confirm.
        mask = cv2.imread(mask_pth, 0)
        if mask is None:
            raise FileNotFoundError(f'could not read mask file: {mask_pth}')

        return image, mask / 255

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample `idx`, transformed when a transform is set."""
        image, mask = self.__gen_data__(idx)

        if self.transform is not None:
            augments = self.transform(image=image, mask=mask)
            image = augments['image']
            # Add a leading channel axis and cast to float for the loss.
            mask = augments['mask'].unsqueeze(0).float()
        return image, mask

# Augmentations

In [None]:
# Training pipeline: resize, geometric jitter, colour jitter, one optional
# distortion, coarse dropout, then normalisation and tensor conversion.
transforms_train = albu.Compose(
    # bring every sample to the network input size
    # (a RandomResizedCrop(832, 832, p=0.4) variant was tried and disabled)
    [albu.Resize(image_size, image_size)]
    # geometric augmentations
    + [
        albu.HorizontalFlip(p=0.5),
        albu.VerticalFlip(p=0.5),
        albu.RandomRotate90(p=0.5),
        albu.Transpose(p=0.5),
        albu.ShiftScaleRotate(
            shift_limit=0.0625,
            scale_limit=0.2,
            rotate_limit=15,
            p=0.5,
            border_mode=cv2.BORDER_REFLECT,
        ),
    ]
    # colour augmentations
    # NOTE(review): the HueSaturationValue limits of 0.2 look very small
    # relative to the library's usual integer-scale limits — confirm intended.
    + [
        albu.RandomBrightnessContrast(p=0.5),
        albu.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        albu.CLAHE(p=0.5),
    ]
    # at most one distortion, applied to 30% of the samples
    # NOTE(review): IAAPiecewiseAffine may be unavailable in newer
    # albumentations releases — verify against the installed version.
    + [
        albu.OneOf(
            [
                albu.OpticalDistortion(p=0.3),
                albu.GridDistortion(p=.1),
                albu.IAAPiecewiseAffine(p=0.3),
            ],
            p=0.3,
        ),
    ]
    # https://www.kaggle.com/c/hubmap-kidney-segmentation/discussion/202375, improve CV, but hurt LB
    + [albu.CoarseDropout(max_holes=8, max_height=64, max_width=64, fill_value=0, mask_fill_value=0, p=0.2)]
    # model input formatting
    + [albu.Normalize(), ToTensorV2()]
)

# Validation pipeline: deterministic resize + the same input formatting.
transforms_val = albu.Compose(
    [albu.Resize(image_size, image_size)]
    + [albu.Normalize(), ToTensorV2()]
)

In [None]:
 # Sanity-check cell: draw one batch of augmented training samples with the mask overlaid.
 if 1:
    # Uses the full table `df` (no fold split) with the training augmentations.
    dataset_show = HubMapDataset(df=df, train=True, transform=transforms_train)
    # Per-channel statistics used below to undo the normalisation for display.
    # NOTE(review): presumably these match what albu.Normalize() applied — confirm.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Earlier preview variant, kept disabled:
    # from pylab import rcParams
    # rcParams['figure.figsize'] = 20,10
    # for i in range(2):
    #     f, axarr = plt.subplots(1,6)
    #     for p in range(6):
    #         idx = np.random.randint(0, len(dataset_show))
    #         img, mask = dataset_show[idx]
    #         if p == 0 or p == 2 or p == 4:
    #             axarr[p].imshow(img.transpose(0, 1).transpose(1,2).squeeze())
    #             axarr[p+1].imshow(mask[0,:,:])

            #axarr[p].set_title(str(label))

    # Take the first 32 samples in dataset order (shuffle is off).
    dl = DataLoader(dataset_show, batch_size=32, shuffle=False)
    imgs, masks = next(iter(dl))

    plt.figure(figsize=(16, 16))
    for i, (img, mask) in enumerate(zip(imgs, masks)):
        #print(s)
        # Move channels last, undo the normalisation and rescale to 0-255 for imshow.
        img = ((img.permute(1,2,0)*std + mean)*255.0).numpy().astype(np.uint8)
        # 8x8 grid; a 32-sample batch fills the first 32 slots.
        plt.subplot(8, 8, i+1)
        #plt.imshow(img.permute(1,2,0), vmin=0, vmax=255)
        plt.imshow(img)
        # Overlay the mask semi-transparently on top of the image.
        plt.imshow(mask.squeeze().numpy(), alpha=0.2)
        plt.axis('off')
        plt.subplots_adjust(wspace=None, hspace=None)
    # The preview is also written to the current working directory.
    plt.savefig('./viz.jpg')
    plt.show()

    #del dataset_show, dl, imgs, masks

# Utils

In [None]:
# Dice coefficient
def dice_coeff(pred, target):
    """Smoothed Dice coefficient computed over the whole batch at once.

    Both tensors are flattened to ``(batch, -1)`` and the intersection and the
    sums are taken over every element of the batch, so the result is a single
    batch-level score rather than a mean of per-sample scores.

    Args:
        pred: Prediction tensor whose first dimension is the batch.
        target: Tensor with the same number of elements per sample as `pred`.

    Returns:
        A 0-dim tensor: ``(2 * |pred * target| + 1) / (|pred| + |target| + 1)``.
    """
    smooth = 1.
    num = pred.size(0)
    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work too.
    m1 = pred.reshape(num, -1)
    m2 = target.reshape(num, -1)
    intersection = (m1 * m2).sum()

    return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)

In [None]:
# https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda.

    Args:
        x: Input batch; mixed with a randomly permuted copy of itself.
        y: Targets indexed along the first dimension like `x`.
        alpha: Beta(alpha, alpha) parameter used to draw lambda; when it is
            not positive, lambda is fixed to 1 (no mixing).
        use_cuda: Kept for backward compatibility. The permutation index now
            follows the device of `x` / `y`, so CPU batches work with the
            default value as well.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` with ``y_a = y`` and ``y_b = y[index]``.
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = x.size()[0]
    # Draw on the CPU, then move next to the data, so indexing never mixes
    # a CUDA index with a CPU tensor.
    index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index.to(x.device), :]
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against both target sets with weights `lam` and `1 - lam`."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

In [None]:
class Meter:
    """Collects one Dice score per batch and reports their mean.

    Args:
        threshold: Probability above which a pixel counts as positive.
    """

    def __init__(self, threshold=0.5):
        self.dice = []  # one Python float per update() call
        self.threshold = threshold

    def update(self, targets, outputs):
        """Binarise sigmoid(outputs) at the threshold and record the batch Dice."""
        probs = torch.sigmoid(outputs.float())
        preds = probs > self.threshold
        # Store a plain float so the mean does not depend on NumPy implicitly
        # converting tensors (which fails for tensors that are not on the CPU).
        self.dice.append(float(dice_coeff(preds, targets)))

    def get_metrics(self):
        """Return the mean of the recorded scores, or NaN if none were recorded."""
        if not self.dice:
            return float('nan')
        return float(np.mean(self.dice))

def save_log(fold_id, phase, epoch, epoch_loss, acc):
    """Append one epoch's result line to this fold's text log under `logdir`."""
    log_path = os.path.join(logdir, f'result_fold{fold_id}.txt')
    with open(log_path, 'a') as log_file:
        record = f'epoch:{epoch} phase:{phase} loss:{epoch_loss} acc:{acc} \n'
        log_file.write(record)

def epoch_log(phase, epoch, epoch_loss, meter, start):
    '''Print the epoch's loss and Dice for `phase` and return the Dice.

    `epoch` and `start` are accepted but not used here.
    '''
    dice = meter.get_metrics()
    print(f"{phase} loss: {epoch_loss:0.4f} | {phase} dice: {dice:0.4f}")
    return dice

def plot(scores, name, idx=None):
    """Plot train/val curves, mark their best epochs, save and show the figure.

    Args:
        scores: Mapping with ``"train"`` and ``"val"`` sequences (one value per epoch).
        name: Curve name; names starting with ``'loss'`` are minimised, all
            others maximised. Also used for the title and the output file name.
        idx: Optional extra epoch to mark on the validation curve.

    Returns:
        Index of the best validation epoch.
    """
    plt.figure(figsize=(15,5))
    train_curve = scores["train"]
    val_curve = scores["val"]
    plt.plot(range(len(train_curve)), train_curve, label=f'train {name}')
    plt.plot(range(len(val_curve)), val_curve, label=f'val {name}')
    plt.title(f'{name} plot')
    plt.xlabel('Epoch')
    plt.ylabel(f'{name}')

    # Best epoch per curve: lowest value for losses, highest for metrics.
    pick_best = np.argmin if name.startswith('loss') else np.argmax
    indx1 = pick_best(train_curve)
    indx2 = pick_best(val_curve)
    best_points = [(indx1, train_curve[indx1]), (indx2, val_curve[indx2])]

    # Same drawing order as before: black markers, labels, then green markers.
    for epoch_no, value in best_points:
        plt.plot(epoch_no, value, 'ks')
    labels = ['['+str(epoch_no)+', {:.4f}'.format(value)+']' for epoch_no, value in best_points]
    for label, point in zip(labels, best_points):
        plt.annotate(label, xytext=point, xy=point)
    for epoch_no, value in best_points:
        plt.plot(epoch_no, value, 'gs')

    if idx is not None:
        extra_point = (idx, val_curve[idx])
        plt.plot(idx, val_curve[idx], 'ks')
        show = '['+str(idx)+', {:.4f}'.format(val_curve[idx])+']'
        plt.annotate(show, xytext=extra_point, xy=extra_point)
        plt.plot(idx, val_curve[idx], 'gs')

    plt.legend()
    plt.savefig(os.path.join(logdir, f'{name}.jpg'))
    plt.show()

    return indx2
    
def utils(file):
    """Parse a log file under `logdir` into per-phase loss and metric lists.

    Each line is split on single spaces; fields 1, 2 and 3 are read as
    ``phase:<name>``, ``loss:<float>`` and ``acc:<float>``.

    Returns:
        ``(losses, accs)``, each a dict with ``'train'`` and ``'val'`` lists.
    """
    losses = {'train': [], 'val': []}
    accs = {'train': [], 'val': []}
    with open(os.path.join(logdir, file), 'r') as f:
        for line in f.read().splitlines():
            fields = line.split(' ')
            phase = fields[1].split(':')[1]
            loss = fields[2].split(':')[1]
            acc = fields[3].split(':')[1]
            losses[phase].append(float(loss))
            accs[phase].append(float(acc))

    return losses, accs

In [None]:
# label smoothing
class LS(nn.Module):
    """Binary cross-entropy on logits against label-smoothed targets.

    Targets are mapped to ``t * (1 - smooth) + 0.5 * smooth`` before the loss.
    """

    def __init__(self, smooth=0.2):
        super().__init__()
        self.label_smoothing = smooth

    def forward(self, inputs, targets):
        # `inputs` are raw logits: the sigmoid is applied inside the loss itself.
        eps = self.label_smoothing
        soft_targets = targets.float() * (1 - eps) + 0.5 * eps
        return F.binary_cross_entropy_with_logits(inputs, soft_targets, reduction='mean')

# Trainer

In [None]:
class Trainer(object):
    """Runs the train/validation loop for one fold and checkpoints the best weights.

    Besides its constructor arguments, the class reads the notebook-level
    globals `logdir`, `device`, `scaler`, `early_stop` and `n_epochs_stop`.

    Args:
        fold_id: Fold identifier, used in log and checkpoint file names.
        model: Network to train; moved to `device` on construction.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepped once per training batch.
        epochs: Maximum number of epochs.
        scheduler: Learning-rate scheduler stepped once per epoch.
        train_loader: Loader yielding ``(images, targets)`` training batches.
        valid_loader: Loader yielding ``(images, targets)`` validation batches.
        use_amp: Whether the training forward pass runs under autocast.
        mix_up: Whether to apply mixup to roughly half of the training batches.
    """

    def __init__(self, fold_id, model, criterion, optimizer, epochs, scheduler, train_loader, valid_loader, use_amp, mix_up=False):
        self.best_loss = float("inf")
        self.best_dice = -float("inf")
        self.phases = ["train", "val"]
        self.num_epochs = epochs
        self.fold_id = fold_id
        self.logdir = logdir
        self.device = device
        self.use_amp = use_amp
        self.net = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler

        self.dataloaders = {
            self.phases[0]: train_loader,
            self.phases[1]: valid_loader
        }
        self.mix_up = mix_up
        self.net.to(self.device)

        self.losses = {phase: [] for phase in self.phases}
        self.dice = {phase: [] for phase in self.phases}

        # Logs and checkpoints are written into this directory; create it up
        # front so a missing folder does not surface only after a full epoch.
        os.makedirs(self.logdir, exist_ok=True)

    def loss_fn(self, pred, gt):
        """Evaluate the criterion with the targets moved to the training device."""
        loss = self.criterion(pred, gt.to(self.device))
        return loss

    def forward(self, x, y=None):
        """Move `x` to the device and run the network; `y` is accepted but unused."""
        x = x.to(self.device)
        outputs = self.net(x)
        return outputs

    def iterate(self, epoch, phase):
        """Run one pass over the `phase` loader.

        In the "train" phase every batch does a scaled backward pass and an
        optimizer step; in any other phase only the forward pass and the loss
        are computed. Gradient tracking is not disabled here — the caller is
        expected to wrap validation in ``torch.no_grad()``.

        Returns:
            ``(epoch_loss, epoch_dice)``: mean batch loss and the meter's Dice.
        """
        meter = Meter()
        start = time.strftime("%H:%M:%S")
        print(f"Epoch: {epoch} | phase: {phase} | Time: {start}")
        self.net.train(phase == "train")
        dataloader = self.dataloaders[phase]
        running_loss = 0.0
        total_batches = len(dataloader)
        tk0 = tqdm(BackgroundGenerator(dataloader), total=total_batches)
        self.optimizer.zero_grad(set_to_none=True)
        for itr, batch in enumerate(tk0):
            images, targets = batch

            if phase == "train":
                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    # Mixup, when enabled, is applied to about half of the batches.
                    if self.mix_up and np.random.rand() > 0.5:
                        images, targets_a, targets_b, lam = mixup_data(images, targets)
                        outputs = self.forward(images)
                        loss = mixup_criterion(self.criterion, outputs, targets_a.to(self.device), targets_b.to(self.device), lam)
                    else:
                        outputs = self.forward(images)
                        loss = self.loss_fn(outputs, targets)
                scaler.scale(loss).backward()
                scaler.step(self.optimizer)
                scaler.update()
                self.optimizer.zero_grad(set_to_none=True)
            else:
                outputs = self.forward(images)
                loss = self.loss_fn(outputs, targets)

            running_loss += loss.item()
            outputs = outputs.detach().cpu()
            # The metric is always computed against the original (un-mixed) targets.
            meter.update(targets, outputs)
            tk0.set_postfix(loss=(running_loss / (itr + 1)))
        epoch_loss = running_loss / total_batches
        epoch_dice = epoch_log(phase, epoch, epoch_loss, meter, start)
        save_log(self.fold_id, phase, epoch, epoch_loss, epoch_dice)
        self.losses[phase].append(epoch_loss)
        self.dice[phase].append(epoch_dice)
        torch.cuda.empty_cache()
        return epoch_loss, epoch_dice

    def _step_scheduler(self, val_loss):
        """Advance the LR scheduler by one epoch.

        Only plateau-type schedulers consume the validation loss. For every
        other scheduler the loss must not be passed positionally: a scheduler
        whose first ``step`` argument is the epoch (such as the warmup wrapper
        built in ``train``) would otherwise treat the loss value as the epoch.
        """
        wrapped = getattr(self.scheduler, "after_scheduler", None)
        if isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(val_loss)
        elif isinstance(wrapped, lr_scheduler.ReduceLROnPlateau):
            # Warmup wrapper around a plateau scheduler: pass the loss by keyword.
            self.scheduler.step(metrics=val_loss)
        else:
            self.scheduler.step()

    def fit(self):
        """Train for up to `num_epochs`, saving the best-loss and best-Dice weights.

        Stops early when `early_stop` is set and neither the validation loss
        nor the validation Dice has improved for `n_epochs_stop` epochs.
        """
        # Initialised here so the counters exist even if the very first
        # comparison fails (e.g. a NaN validation loss).
        epochs_no_improve1 = 0  # epochs since the validation loss last improved
        epochs_no_improve2 = 0  # epochs since the validation Dice last improved

        for epoch in range(self.num_epochs):
            self.iterate(epoch, "train")

            with torch.no_grad():
                val_loss, val_dice = self.iterate(epoch, "val")
            self._step_scheduler(val_loss)

            # monitor val loss
            if val_loss < self.best_loss:
                print(f"****** New optimal loss found @ {epoch}, saving state ******")
                epochs_no_improve1 = 0
                self.best_loss = val_loss
                torch.save(self.net.state_dict(), f"{self.logdir}/best_loss_fold{self.fold_id}.pth")
            else:
                epochs_no_improve1 += 1

            # monitor val metric
            if val_dice > self.best_dice:
                print(f"****** New optimal dice found @ {epoch}, saving state ******")
                epochs_no_improve2 = 0
                self.best_dice = val_dice
                torch.save(self.net.state_dict(), f"{self.logdir}/best_metric_fold{self.fold_id}.pth")
            else:
                epochs_no_improve2 += 1

            if early_stop and epochs_no_improve1 >= n_epochs_stop and epochs_no_improve2 >= n_epochs_stop:
                print('Early stopping!' )
                break
            print()

        print(f'train finished. best loss: {self.best_loss}, best dice: {self.best_dice}')

# Train

In [None]:
def train(fold_id):
    """Train one cross-validation fold: rows with ``fold == fold_id`` are the validation set."""
    print(f'###################### training fold: {fold_id} ######################')

    # ---- data ----
    dataset_train = HubMapDataset(df=df[df.fold != fold_id], train=True, transform=transforms_train)
    dataset_valid = HubMapDataset(df=df[df.fold == fold_id], train=False, transform=transforms_val)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(dataset_train, shuffle=True, **loader_kwargs)
    valid_loader = torch.utils.data.DataLoader(dataset_valid, shuffle=False, **loader_kwargs)

    # ---- model ----
    model = smp.Unet(
        encoder_name=encoder,
        encoder_weights=ENCODER_WEIGHTS,
        in_channels=3,
        classes=1,
        activation=None,  # the model emits raw logits; the loss applies the sigmoid
        decoder_use_batchnorm=True,
    )

    # ---- loss, optimizer, schedule ----
    # Other criteria tried earlier: BCEDiceLoss, FocalTverskyLoss,
    # nn.BCEWithLogitsLoss, smp DiceLoss, Jaccard loss, nn.BCELoss.
    criterion = LS()
    # The optimizer starts at init_lr / warmup_factor and the warmup wrapper
    # is given multiplier=warmup_factor over `warmup_epo` epochs, after which
    # it hands over to the cosine schedule.
    # (Earlier alternatives: Lookahead(RAdam) and ReduceLROnPlateau.)
    optimizer = optim.Adam(model.parameters(), lr=init_lr/warmup_factor)
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs-warmup_epo)
    scheduler = GradualWarmupScheduler(
        optimizer,
        multiplier=warmup_factor,
        total_epoch=warmup_epo,
        after_scheduler=scheduler_cosine,
    )

    # ---- run ----
    trainer = Trainer(
        fold_id=fold_id,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        epochs=n_epochs,
        train_loader=train_loader,
        valid_loader=valid_loader,
        use_amp=use_amp,
        mix_up=mix_up,
    )
    trainer.fit()

# Run Training

In [None]:
# Train fold 0 (rows with fold == 0 are held out for validation).
train(fold_id=0)

In [None]:
# Parse the fold-0 log, then plot the loss and Dice curves.
losses, dice = utils(file='result_fold0.txt')
best_val_idx = plot(scores=losses, name='loss_0')
# Also mark the epoch with the lowest validation loss on the Dice curve.
plot(scores=dice, name='dice_0', idx=best_val_idx)

In [None]:
# Train fold 1 (rows with fold == 1 are held out for validation).
train(fold_id=1)

In [None]:
# Parse the fold-1 log, then plot the loss and Dice curves.
losses, dice = utils(file='result_fold1.txt')
best_val_idx = plot(scores=losses, name='loss_1')
# Also mark the epoch with the lowest validation loss on the Dice curve.
plot(scores=dice, name='dice_1', idx=best_val_idx)

In [None]:
# Train fold 2 (rows with fold == 2 are held out for validation).
train(fold_id=2)

In [None]:
# Parse the fold-2 log, then plot the loss and Dice curves.
losses, dice = utils(file='result_fold2.txt')
best_val_idx = plot(scores=losses, name='loss_2')
# Also mark the epoch with the lowest validation loss on the Dice curve.
plot(scores=dice, name='dice_2', idx=best_val_idx)

In [None]:
# Train fold 3 (rows with fold == 3 are held out for validation).
train(fold_id=3)

In [None]:
# Parse the fold-3 log, then plot the loss and Dice curves.
losses, dice = utils(file='result_fold3.txt')
best_val_idx = plot(scores=losses, name='loss_3')
# Also mark the epoch with the lowest validation loss on the Dice curve.
plot(scores=dice, name='dice_3', idx=best_val_idx)

In [None]:
# Train fold 4 (rows with fold == 4 are held out for validation).
train(fold_id=4)

In [None]:
# Parse the fold-4 log, then plot the loss and Dice curves.
losses, dice = utils(file='result_fold4.txt')
best_val_idx = plot(scores=losses, name='loss_4')
# Also mark the epoch with the lowest validation loss on the Dice curve.
plot(scores=dice, name='dice_4', idx=best_val_idx)