In [None]:
!pip install segmentation_models_pytorch

In [None]:
%%time
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import os
import cv2
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader, Dataset
from albumentations import *
import segmentation_models_pytorch as smp
from albumentations.pytorch import ToTensor
import torch
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

# parameters

In [None]:
class config:
    """Namespace of experiment settings, read as class attributes (e.g. ``config.lr``)."""

    # model
    backbone = 'efficientnet-b4'
    ACTIVATION = 'sigmoid'
    ENCODER_WEIGHTS = 'imagenet'

    # optimisation / data loading
    lr = 5e-4
    epochs = 50
    batch_size = 8
    im_size = 256
    num_workers = 4

    # Each tile set lives under ../input/<im_size><im_size>-pu/ (e.g. 256256-pu/)
    # and has a "<prefix>train" image directory plus a "<prefix>masks" directory.
    images_path = '../input/{0}{0}-pu/train/'.format(im_size)
    masks_path = '../input/{0}{0}-pu/masks/'.format(im_size)

    images_path2 = '../input/{0}{0}-pu/2train/'.format(im_size)
    masks_path2 = '../input/{0}{0}-pu/2masks/'.format(im_size)

    images_path3 = '../input/{0}{0}-pu/3train/'.format(im_size)
    masks_path3 = '../input/{0}{0}-pu/3masks/'.format(im_size)

# data (augment)

In [None]:
def get_aug(p=1.0):
    """Build the training augmentation pipeline.

    Stages, in order: flips / 90-degree rotation / shift-scale-rotate (each with
    its own probability), then a OneOf group choosing a single distortion
    (group p=0.3), then a OneOf group choosing a single colour/contrast change
    (group p=0.4).

    Args:
        p: probability that the whole Compose is applied to a sample.

    NOTE(review): IAA* transforms (IAAPiecewiseAffine) are deprecated in newer
    albumentations releases — pin the version or migrate to the native
    equivalent.
    """
    geometric = [
        HorizontalFlip(),
        VerticalFlip(),
        RandomRotate90(),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15,
                         p=0.9, border_mode=cv2.BORDER_REFLECT),
    ]
    distortion = OneOf(
        [OpticalDistortion(p=0.3), GridDistortion(p=.1), IAAPiecewiseAffine(p=0.3)],
        p=0.3,
    )
    colour = OneOf(
        [HueSaturationValue(10, 15, 10), CLAHE(clip_limit=2), RandomBrightnessContrast()],
        p=0.4,
    )
    return Compose(geometric + [distortion, colour], p=p)

# Per-channel (R, G, B) statistics used by HuBMAPDataset to standardise images
# after scaling pixels to [0, 1].
# NOTE(review): presumably computed over the training tiles — confirm provenance.
mean = np.asarray([0.65459856, 0.48386562, 0.69428385])
std = np.asarray([0.15167958, 0.23584107, 0.13146145])

def img2tensor(img, dtype: np.dtype = np.float32):
    """Convert an HWC (or HW) numpy array into a CHW torch tensor of ``dtype``.

    A 2-D input is treated as a single-channel image and becomes 1xHxW.
    The cast uses ``copy=False``, so when ``img`` already has ``dtype`` the
    returned tensor shares memory with the (transposed) input.
    """
    hwc = np.expand_dims(img, axis=-1) if img.ndim == 2 else img
    chw = hwc.transpose((2, 0, 1))
    return torch.from_numpy(chw.astype(dtype, copy=False))

class HuBMAPDataset(Dataset):
    """Image/mask tile dataset.

    ``ids`` are paths to image tiles inside a directory named ``train`` or
    ``<prefix>train`` (e.g. ``2train``). The matching mask is read from the file
    with the same name in the sibling ``masks`` / ``<prefix>masks`` directory,
    mirroring the layout in ``config``.

    Each item is ``(image, mask)``:
      * image: RGB, scaled to [0, 1] and standardised with the module-level
        ``mean``/``std``, returned as a float32 CxHxW tensor;
      * mask: grayscale pixel values as a float32 1xHxW tensor (not rescaled).

    NOTE(review): DiceLoss/IoU assume mask values in {0, 1}; confirm the mask
    PNGs are not stored as 0/255.
    """

    def __init__(self, ids, transforms=None):
        # ids: sequence of image file paths.
        # transforms: optional callable taking image=/mask= keywords and returning
        # a dict with 'image' and 'mask' (albumentations-style).
        self.ids = ids
        self.transforms = transforms

    @staticmethod
    def _mask_path(image_path):
        """Map ``.../<prefix>train/<name>`` to ``.../<prefix>masks/<name>``.

        Works from directory names rather than fixed character offsets, so it is
        independent of ``im_size`` and of the length of the directory prefix.

        Raises:
            ValueError: if the image is not inside a ``*train`` directory.
        """
        image_dir, file_name = os.path.split(image_path)
        parent_dir, folder = os.path.split(image_dir)
        if not folder.endswith('train'):
            raise ValueError(f'expected an image inside a "*train" directory, got {image_path!r}')
        mask_folder = folder[:-len('train')] + 'masks'
        return os.path.join(parent_dir, mask_folder, file_name)

    @staticmethod
    def _read(path, flags):
        """cv2.imread that fails loudly instead of returning None."""
        array = cv2.imread(path, flags)
        if array is None:
            raise FileNotFoundError(f'could not read image file {path!r}')
        return array

    def __getitem__(self, idx):
        path = self.ids[idx]
        img = cv2.cvtColor(self._read(path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        mask = self._read(self._mask_path(path), cv2.IMREAD_GRAYSCALE)

        # `is not None` rather than truthiness: Compose-like objects can define
        # __len__, and an empty pipeline should still be a valid argument.
        if self.transforms is not None:
            augmented = self.transforms(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        return img2tensor((img / 255.0 - mean) / std), img2tensor(mask)

    def __len__(self):
        return len(self.ids)

In [None]:
# Hold out the last two image ids (file-name prefix before the first "_") as
# validation. Ids are sorted so the split is identical on every run (set
# iteration order of strings changes between interpreter runs), and entries are
# full paths because HuBMAPDataset opens them directly with cv2.imread.
data = sorted(os.listdir(config.images_path))  # slice (e.g. [:100]) for a quick debug run
train_lsit = sorted({row.split("_")[0] for row in data})
train_ids = set(train_lsit[:-2])
train_idx = [os.path.join(config.images_path, row) for row in data if row.split("_")[0] in train_ids]
valid_idx = [os.path.join(config.images_path, row) for row in data if row.split("_")[0] not in train_ids]
len(train_idx), len(valid_idx)

In [None]:
# Same split for the second tile set: the last two sorted image ids are held out,
# and entries are full paths so HuBMAPDataset can open them.
data2 = sorted(os.listdir(config.images_path2))  # slice (e.g. [:100]) for a quick debug run
train_lsit2 = sorted({row.split("_")[0] for row in data2})
train_ids2 = set(train_lsit2[:-2])
train_idx2 = [os.path.join(config.images_path2, row) for row in data2 if row.split("_")[0] in train_ids2]
valid_idx2 = [os.path.join(config.images_path2, row) for row in data2 if row.split("_")[0] not in train_ids2]
len(train_idx2), len(valid_idx2)

In [None]:
# Same split for the third tile set: the last two sorted image ids are held out,
# and entries are full paths so HuBMAPDataset can open them.
data3 = sorted(os.listdir(config.images_path3))  # slice (e.g. [:100]) for a quick debug run
train_lsit3 = sorted({row.split("_")[0] for row in data3})
train_ids3 = set(train_lsit3[:-2])
train_idx3 = [os.path.join(config.images_path3, row) for row in data3 if row.split("_")[0] in train_ids3]
valid_idx3 = [os.path.join(config.images_path3, row) for row in data3 if row.split("_")[0] not in train_ids3]
len(train_idx3), len(valid_idx3)

In [None]:
# Set to True to also train on the 2train/ and 3train/ tile sets.
USE_EXTRA_TRAIN_DATA = False
if USE_EXTRA_TRAIN_DATA:
    # dict.fromkeys de-duplicates while preserving order, so re-running this
    # cell does not append the extra tiles a second time.
    train_idx = list(dict.fromkeys(train_idx + train_idx2 + train_idx3))
len(train_idx)

In [None]:
# Set to True to also validate on the held-out tiles of the 2train/ and 3train/ sets.
USE_EXTRA_VALID_DATA = False
if USE_EXTRA_VALID_DATA:
    # dict.fromkeys de-duplicates while preserving order, so re-running this
    # cell does not append the extra tiles a second time.
    valid_idx = list(dict.fromkeys(valid_idx + valid_idx2 + valid_idx3))
len(valid_idx)

# dataset

In [None]:
# Augmentation is applied to the training split only; validation tiles are
# served un-augmented and in a fixed order.
train_datasets = HuBMAPDataset(train_idx, transforms=get_aug())
valid_datasets = HuBMAPDataset(valid_idx)

train_loader, valid_loader = (
    DataLoader(dataset, batch_size=config.batch_size, shuffle=is_train, num_workers=config.num_workers)
    for dataset, is_train in ((train_datasets, True), (valid_datasets, False))
)

In [None]:
# helper function for data visualization
ds = HuBMAPDataset(train_idx[:40],transforms= None)
dl = DataLoader(ds,batch_size=40,shuffle=False,num_workers=config.num_workers)
imgs,masks = next(iter(dl))

plt.figure(figsize=(16,16))
for i,(img,mask) in enumerate(zip(imgs,masks)):
    img = ((img.permute(1,2,0)*std + mean)*255.0).numpy().astype(np.uint8)
#     img = img.permute(1,2,0).numpy().astype(np.uint8)
    plt.subplot(8,8,i+1)
    plt.imshow(img,vmin=0,vmax=255)
    plt.imshow(mask.squeeze().numpy(), alpha=0.3)
    plt.axis('off')
    plt.subplots_adjust(wspace=None, hspace=None)
    
del ds,dl,imgs,masks

# Model

https://segmentation-modelspytorch.readthedocs.io/en/latest/docs/api.html#unet

In [None]:
# U-Net with the configured pretrained encoder and one output channel passed
# through the configured activation; decoder batch-norm is disabled.
model = smp.Unet(
    encoder_name=config.backbone,
    encoder_weights=config.ENCODER_WEIGHTS,
    in_channels=3,
    classes=1,
    activation=config.ACTIVATION,
    decoder_use_batchnorm=False,
)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config.lr)

# Dice loss on the activated output; the training loop reads it under the key 'dice_loss'.
# NOTE(review): the commented alternative, BCEWithLogitsLoss, expects raw logits,
# so switching to it presumably also requires activation=None.
loss_fn = smp.utils.losses.DiceLoss()

# IoU of the prediction binarised at 0.5; the training loop reads it as 'iou_score'.
# NOTE(review): smp.utils is deprecated in newer segmentation_models_pytorch
# releases — pin the package version.
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

In [None]:
# One-epoch runners sharing the same loss, metrics and device; only the training
# runner receives the optimizer.
_epoch_kwargs = dict(loss=loss_fn, metrics=metrics, device=DEVICE, verbose=True)
train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_epoch_kwargs)
valid_epoch = smp.utils.train.ValidEpoch(model, **_epoch_kwargs)
del _epoch_kwargs

In [None]:
def savelogs(logs, name):
    """Append one line of ``key value`` pairs from ``logs`` to ``<name>.txt``.

    Pairs are separated by ", " so a line such as
    ``dice_loss 0.31, iou_score 0.52`` can be split back into fields.

    Args:
        logs: mapping of metric name -> value (e.g. the dict returned by an epoch run).
        name: output file path without the ``.txt`` extension.
    """
    line = ', '.join(f'{k} {v}' for k, v in logs.items())
    with open(f'{name}.txt', 'a') as f:
        f.write(line + '\n')

# Train

In [None]:
%%time

# Train for config.epochs epochs, checkpointing whenever the validation Dice loss
# improves, and lowering the learning rate once at LR_DROP_EPOCH.
LR_DROP_EPOCH = 15   # epoch index at which the learning rate is lowered
LR_AFTER_DROP = 1e-4

patience = 0              # epochs since the last improvement (tracked; not used to stop early)
max_score = float('inf')  # despite the name: best (lowest) validation Dice loss so far
losses = {'train': [], 'valid': []}
ious = {'train': [], 'valid': []}

for i in range(config.epochs):
    print(f'\nEpoch: {i}')
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    savelogs(train_logs, 'train_logs')
    savelogs(valid_logs, 'valid_logs')

    for split, logs in (('train', train_logs), ('valid', valid_logs)):
        losses[split].append(logs['dice_loss'])
        ious[split].append(logs['iou_score'])

    patience += 1
    if valid_logs['dice_loss'] < max_score:
        max_score = valid_logs['dice_loss']
        # Saves the whole pickled module; NOTE(review): loading it back needs the
        # same smp version and a full (non weights-only) torch.load.
        torch.save(model, 'best.pth')
        patience = 0
        print('get the best score: ', 1 - max_score)
        print('Model saved!')

    if i == LR_DROP_EPOCH:
        # The optimizer was built over all model parameters, so every group is lowered.
        for group in optimizer.param_groups:
            group['lr'] = LR_AFTER_DROP
        print(f'Decrease learning rate to {LR_AFTER_DROP:g}!')

In [None]:
# PLOT
def plot(scores, name):
    plt.figure(figsize=(15,5))
    plt.plot(range(len(scores["train"])), scores["train"], label=f'train {name}')
    plt.plot(range(len(scores["train"])), scores["valid"], label=f'val {name}')
    plt.title(f'{name} plot'); plt.xlabel('Epoch'); plt.ylabel(f'{name}');
    plt.legend(); 
    plt.show()

plot(losses, "loss")
plot(ious, "iou")