In [1]:
!pip install git+https://github.com/qubvel/segmentation_models.pytorch

Collecting git+https://github.com/qubvel/segmentation_models.pytorch
  Cloning https://github.com/qubvel/segmentation_models.pytorch to /tmp/pip-req-build-6u1p44jm
Collecting efficientnet-pytorch==0.6.3
  Downloading efficientnet_pytorch-0.6.3.tar.gz (16 kB)
Collecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[K     |████████████████████████████████| 58 kB 850 kB/s 
Collecting timm==0.2.1
  Downloading timm-0.2.1-py3-none-any.whl (225 kB)
[K     |████████████████████████████████| 225 kB 1.7 MB/s 
Building wheels for collected packages: segmentation-models-pytorch, efficientnet-pytorch, pretrainedmodels
  Building wheel for segmentation-models-pytorch (setup.py) ... [?25l- \ | done
[?25h  Created wheel for segmentation-models-pytorch: filename=segmentation_models_pytorch-0.1.2-py3-none-any.whl size=64350 sha256=e8c01c1b4525e510cfb4102cfea70abb74bfe6ecfc65e0014bb34eb3147bf0a7
  Stored in directory: /tmp/pip-ephem-wheel-cache-7arte

In [2]:
import gc
import os
import random
import time
import warnings
warnings.simplefilter("ignore")


from albumentations import *
from albumentations.pytorch import ToTensor
import cv2
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
from sklearn.model_selection import KFold
import tifffile as tiff
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, sampler
from tqdm import tqdm_notebook as tqdm

In [3]:
def seed_everything(seed=2**3):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(121)

In [4]:
fold = 0
nfolds = 5
reduce = 4
sz = 256

BATCH_SIZE = 16
DEVICE = ('cuda' if torch.cuda.is_available() else 'cpu')
NUM_WORKERS = 4
NUM_EPOCHS = 45
SEED = 2020
TH = 0.39

DATA = '../input/hubmap-kidney-segmentation/test/'
LABELS = '../input/hubmap-kidney-segmentation/train.csv'
MASKS = '../input/hubmap-256x256/masks'
TRAIN = '../input/hubmap-256x256/train'
df_sample = pd.read_csv('../input/hubmap-kidney-segmentation/sample_submission.csv')


In [5]:
def rle_encode_less_memory(img):
    pixels=img.T.flatten()
    pixels[0]=0
    pixels[-1]=0
    runs = np.where(pixels[1:] != pixels[:-1])[0]+2
    runs[1::2]-=runs[::2]
    return ' '.join(str(x) for x in runs)

In [6]:
mean = np.array([0.65459856,0.48386562,0.69428385])
std = np.array([0.15167958,0.23584107,0.13146145])

def img2tensor(img, dtype:np.dtype=np.float32):
    if img.ndim==2: 
        img=np.expand_dims(img, 2)
    img=np.transpose(img, (2, 0, 1))
    return torch.from_numpy(img.astype(dtype, copy=False))

class HuBMAPDataset(Dataset):
    def __init__(self, fold=fold, train=True, tfms=None):
        ids = pd.read_csv(LABELS).id.values
        kf = KFold(n_splits=nfolds, random_state=SEED, shuffle=True)
        ids=set(ids[list(kf.split(ids))[fold][0 if train else 1]])
        self.fnames=[fname for fname in os.listdir(TRAIN) if fname.split('_')[0] in ids]
        self.train = train
        self.tfms = tfms
    
    def __len__(self):
        return len(self.fnames)
    
    def __getitem__(self, idx):
        fname = self.fnames[idx]
        imgs=cv2.cvtColor(cv2.imread(os.path.join(TRAIN, fname)), cv2.COLOR_BGR2RGB)
        masks=cv2.imread(os.path.join(MASKS, fname), cv2.IMREAD_GRAYSCALE)
        if self.tfms is not None:
            augmented=self.tfms(image=imgs, mask=masks)
            imgs, masks=augmented['image'], augmented['mask']
        return img2tensor((imgs/255.0-mean)/std), img2tensor(masks)

In [7]:
def get_augmentation(p=1.0):
    return Compose([
        HorizontalFlip(),
        VerticalFlip(),
        RandomRotate90(),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.9, border_mode=cv2.BORDER_REFLECT),
        OneOf([
            OpticalDistortion(p=0.3),
            GridDistortion(p=.1),
            IAAPiecewiseAffine(p=0.3),
        ], p=0.3),
        OneOf([
            HueSaturationValue(10, 15, 10),
            CLAHE(clip_limit=2),
            RandomBrightnessContrast(),
        ], p=0.3),
    ], p=p)

In [8]:
# ds = HuBMAPDataset(tfms=get_augmentation())
# dl = DataLoader(ds, batch_size=16, shuffle=True, num_workers=NUM_WORKERS)
# imgs, masks = next(iter(dl))
# print(imgs.shape)
# print(masks.shape)

# plt.figure(figsize=(16, 16))
# for i, (img, mask) in enumerate(zip(imgs, masks)):
#     img = ((img.permute(1, 2, 0)*std + mean) * 255.0).numpy().astype(np.uint8)
#     plt.subplot(4, 4, i+1)
#     plt.imshow(img, vmin=0, vmax=255)
#     plt.imshow(mask.squeeze().numpy(), alpha=0.2)
#     plt.axis('off')
#     plt.subplots_adjust(wspace=None, hspace=None)
# plt.show()

# del ds, dl, imgs, masks

In [9]:
class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()
    
    def forward(self, inputs, targets, smooth=1):
        #
        inputs = F.sigmoid(inputs)
        #flatten
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        #element_wise production to get intersection score
        intersection = (inputs*targets).sum()
        dice_score = (2*intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice_score

In [10]:
def UnetPlusPlus():
    return smp.Unet(
        encoder_name='efficientnet-b7',
        encoder_weights='imagenet',
        in_channels=3,
        classes=1
    )

In [11]:
def UnetResNext():
    return smp.Unet(
        encoder_name='se_resnext50_32x4d',
        encoder_weights='imagenet',
        in_channels=3,
        classes=1
    )

In [12]:
import segmentation_models_pytorch as smp
def UnetDenseNet():
    return smp.Unet(
    encoder_name='densenet201',
    encoder_weights='imagenet',
    in_channels=3,
    classes=1)

In [13]:
def train_one_epoch(fold, model, dataloader_train, dataloader_valid, optimizer, loss_function):
    #training phase
    model.train()
    train_loss = 0
    for i, (imgs, masks) in enumerate(dataloader_train):
        optimizer.zero_grad()
        imgs = imgs.to(DEVICE)
        masks = masks.to(DEVICE)
        #forward pass
        outputs = model(imgs)
        #cal loss and backward
        loss = loss_function(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(dataloader_train)
    
    #validating phase
    model.eval()
    valid_loss = 0
    with torch.no_grad():
        for i, (imgs, masks) in enumerate(dataloader_valid):
            imgs = imgs.to(DEVICE)
            masks = masks.to(DEVICE)
            outputs = model(imgs)
            loss = loss_function(outputs, masks)
            valid_loss += loss.item()
    valid_loss /=len(dataloader_valid)
    print(f'FOLD: {fold + 1}, EPOCH: {epoch + 1} - train loss: {train_loss} -  valid_loss: {valid_loss}')
    return train_loss, valid_loss



In [14]:
best_valid_loss = 0
for fold in range(nfolds):
    ds_t = HuBMAPDataset(fold=fold, train=True, tfms=get_augmentation())
    ds_v = HuBMAPDataset(fold=fold, train=False)
    dataloader_t = torch.utils.data.DataLoader(ds_t, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    dataloader_v = torch.utils.data.DataLoader(ds_v, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    model = UnetDenseNet().to(DEVICE)
    diceloss = DiceLoss()
    optimizer = torch.optim.Adam([
        {'params': model.decoder.parameters(), 'lr': 1e-3},
        {'params': model.encoder.parameters(), 'lr': 1e-3},
    ])

#     scheduler = optim.lr_scheduler.OneCycleLR(optimizer=optimizer, 
#                                               pct_start=0.1, 
#                                               div_factor=1e-3, 
#                                               max_lr=1e-2, 
#                                               epochs=NUM_EPOCHS, 
#                                               steps_per_epoch=len(dataloader_t))
    train_loss = 0
    valid_loss = 0

    for epoch in tqdm(range(NUM_EPOCHS)):
        train_loss, valid_loss = train_one_epoch(fold, model, dataloader_t, dataloader_v, optimizer, diceloss)
    
    torch.save(model.state_dict(), f'model_fold_{fold}.pth')
    if best_valid_loss == 0:
        best_valid_loss = valid_loss
    if best_valid_loss >= valid_loss:
        best_valid_loss = valid_loss
        torch.save(model, 'best_unet_model.pth')
    
    gc.collect()


Downloading: "http://data.lip6.fr/cadene/pretrainedmodels/densenet201-5750cbb1e.pth" to /root/.cache/torch/hub/checkpoints/densenet201-5750cbb1e.pth


HBox(children=(FloatProgress(value=0.0, max=81139790.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=45.0), HTML(value='')))

FOLD: 1, EPOCH: 1 - train loss: 0.801529505036094 -  valid_loss: 0.7638275905659324
FOLD: 1, EPOCH: 2 - train loss: 0.7233243464649498 -  valid_loss: 0.6785366841052708
FOLD: 1, EPOCH: 3 - train loss: 0.6499556415266805 -  valid_loss: 0.6083566217046035
FOLD: 1, EPOCH: 4 - train loss: 0.5813209847970442 -  valid_loss: 0.5494786736212278
FOLD: 1, EPOCH: 5 - train loss: 0.5279059557171611 -  valid_loss: 0.45464250445365906
FOLD: 1, EPOCH: 6 - train loss: 0.43516403630182343 -  valid_loss: 0.4076589345932007
FOLD: 1, EPOCH: 7 - train loss: 0.4212330068860735 -  valid_loss: 0.4447571980325799
FOLD: 1, EPOCH: 8 - train loss: 0.3591645878630799 -  valid_loss: 0.27268237189242717
FOLD: 1, EPOCH: 9 - train loss: 0.3053201309272221 -  valid_loss: 0.2538790969472182
FOLD: 1, EPOCH: 10 - train loss: 0.2942481745373119 -  valid_loss: 0.34248795164258855
FOLD: 1, EPOCH: 11 - train loss: 0.28512586285541586 -  valid_loss: 0.2764375664685902
FOLD: 1, EPOCH: 12 - train loss: 0.2912157115998206 -  vali

HBox(children=(FloatProgress(value=0.0, max=45.0), HTML(value='')))

FOLD: 2, EPOCH: 1 - train loss: 0.8090035701865581 -  valid_loss: 0.7139610807100932
FOLD: 2, EPOCH: 2 - train loss: 0.7210320450773287 -  valid_loss: 0.575989955663681
FOLD: 2, EPOCH: 3 - train loss: 0.626916592690482 -  valid_loss: 0.6264069358507792
FOLD: 2, EPOCH: 4 - train loss: 0.5301034978373134 -  valid_loss: 0.4448293884595235
FOLD: 2, EPOCH: 5 - train loss: 0.4449212737937472 -  valid_loss: 0.3565768500169118
FOLD: 2, EPOCH: 6 - train loss: 0.36184448152039184 -  valid_loss: 0.28912336428960167
FOLD: 2, EPOCH: 7 - train loss: 0.32526676275243804 -  valid_loss: 0.34185622731844584
FOLD: 2, EPOCH: 8 - train loss: 0.28804431210702924 -  valid_loss: 0.2863815148671468
FOLD: 2, EPOCH: 9 - train loss: 0.2608046229205914 -  valid_loss: 0.21476194461186726
FOLD: 2, EPOCH: 10 - train loss: 0.2466979590221424 -  valid_loss: 0.20671125451723735
FOLD: 2, EPOCH: 11 - train loss: 0.24062115962232522 -  valid_loss: 0.39866352677345274
FOLD: 2, EPOCH: 12 - train loss: 0.2308127509420784 -  v

HBox(children=(FloatProgress(value=0.0, max=45.0), HTML(value='')))

FOLD: 3, EPOCH: 1 - train loss: 0.7570651715121618 -  valid_loss: 0.6252088173111873
FOLD: 3, EPOCH: 2 - train loss: 0.6269225296450824 -  valid_loss: 0.5453306952519203
FOLD: 3, EPOCH: 3 - train loss: 0.5241804101118227 -  valid_loss: 0.40451020358213735
FOLD: 3, EPOCH: 4 - train loss: 0.43093981757396604 -  valid_loss: 0.3647034559676896
FOLD: 3, EPOCH: 5 - train loss: 0.3882526191996365 -  valid_loss: 0.36444467039250616
FOLD: 3, EPOCH: 6 - train loss: 0.32368603021633335 -  valid_loss: 0.3948003080353808
FOLD: 3, EPOCH: 7 - train loss: 0.3038493620186317 -  valid_loss: 0.27663007956832203
FOLD: 3, EPOCH: 8 - train loss: 0.2631298921457151 -  valid_loss: 0.22246183922041707
FOLD: 3, EPOCH: 9 - train loss: 0.22984612969363608 -  valid_loss: 0.19446372452066907
FOLD: 3, EPOCH: 10 - train loss: 0.23204570899649365 -  valid_loss: 0.19623830247281202
FOLD: 3, EPOCH: 11 - train loss: 0.22723124158091662 -  valid_loss: 0.18150424334540297
FOLD: 3, EPOCH: 12 - train loss: 0.2154379944975783

HBox(children=(FloatProgress(value=0.0, max=45.0), HTML(value='')))

FOLD: 4, EPOCH: 1 - train loss: 0.7943519109889016 -  valid_loss: 0.8663271163639269
FOLD: 4, EPOCH: 2 - train loss: 0.7071024553145769 -  valid_loss: 0.8193971451960111
FOLD: 4, EPOCH: 3 - train loss: 0.6090835150659393 -  valid_loss: 0.6240015531841078
FOLD: 4, EPOCH: 4 - train loss: 0.5361303662389053 -  valid_loss: 0.6224151928173868
FOLD: 4, EPOCH: 5 - train loss: 0.44780650040028624 -  valid_loss: 0.588286994319213
FOLD: 4, EPOCH: 6 - train loss: 0.36916315030557506 -  valid_loss: 0.3585556259280757
FOLD: 4, EPOCH: 7 - train loss: 0.31267868700422774 -  valid_loss: 0.397341472537894
FOLD: 4, EPOCH: 8 - train loss: 0.2968382356698031 -  valid_loss: 0.3842690994865016
FOLD: 4, EPOCH: 9 - train loss: 0.26115226930904883 -  valid_loss: 0.2930273319545545
FOLD: 4, EPOCH: 10 - train loss: 0.2602710189596977 -  valid_loss: 0.2858901337573403
FOLD: 4, EPOCH: 11 - train loss: 0.2426767979261171 -  valid_loss: 0.29005362799293116
FOLD: 4, EPOCH: 12 - train loss: 0.23072127449697782 -  vali

HBox(children=(FloatProgress(value=0.0, max=45.0), HTML(value='')))

FOLD: 5, EPOCH: 1 - train loss: 0.7311443842592693 -  valid_loss: 0.708677609761556
FOLD: 5, EPOCH: 2 - train loss: 0.5775858183701833 -  valid_loss: 0.5073141370500837
FOLD: 5, EPOCH: 3 - train loss: 0.452739835920788 -  valid_loss: 0.40883550473621916
FOLD: 5, EPOCH: 4 - train loss: 0.3737366381145659 -  valid_loss: 0.3634787287030901
FOLD: 5, EPOCH: 5 - train loss: 0.31211312611897785 -  valid_loss: 0.2465197756176903
FOLD: 5, EPOCH: 6 - train loss: 0.27989201630864824 -  valid_loss: 0.20829100268227713
FOLD: 5, EPOCH: 7 - train loss: 0.2895317611240205 -  valid_loss: 0.20692110061645508
FOLD: 5, EPOCH: 8 - train loss: 0.22512446329707192 -  valid_loss: 0.16427123546600342
FOLD: 5, EPOCH: 9 - train loss: 0.22277944201514835 -  valid_loss: 0.16206391936256773
FOLD: 5, EPOCH: 10 - train loss: 0.23229515211922783 -  valid_loss: 0.18010302100862777
FOLD: 5, EPOCH: 11 - train loss: 0.20732415460404896 -  valid_loss: 0.14904131775810606
FOLD: 5, EPOCH: 12 - train loss: 0.1985232983316694 