In [None]:
!pip install segmentation_models_pytorch

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import os
import cv2
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader, Dataset
import albumentations as albu
import segmentation_models_pytorch as smp
import torch
from tqdm.auto import tqdm

# Data preprocessing: see https://www.kaggle.com/iafoss/256x256-images (this notebook reads the 512x512-images dataset).

In [None]:
!mkdir data
!mkdir data/images
!unzip ../input/512x512-images/train.zip -d data/images

In [None]:
!mkdir data/masks
!unzip ../input/512x512-images/masks.zip -d data/masks

In [None]:
class config:
    """Notebook-wide settings, read elsewhere as plain class attributes."""

    # Set to False to run the full training instead of the shortened debug pass.
    DEBUG = True

    # Directories the image and mask archives are extracted into.
    images_path = './data/images'
    masks_path = './data/masks'

    # Encoder name and the side length images/masks are resized to.
    backbone = 'resnet34'
    im_size = 512

    # Optimiser / scheduler settings.
    lr = 1e-3
    T_max = 500
    epochs = 10

    # DataLoader settings.
    batch_size = 8
    num_workers = 4

In [None]:
def to_tensor(x, **kwargs):
    """Reorder a 3-D array from (axis0, axis1, axis2) to (axis2, axis0, axis1) as float32.

    Any extra keyword arguments are accepted and ignored.
    """
    channels_first = np.transpose(x, (2, 0, 1))
    return channels_first.astype(np.float32)

def get_train_augmentation(size=1024):
    """Build the albumentations pipeline used for training samples.

    The pipeline applies, in order: a horizontal flip, one of three
    contrast/gamma/brightness changes (group probability 0.3), one of three
    elastic/grid/optical distortions (group probability 0.3), a
    shift-scale-rotate, and finally a resize to ``size`` x ``size``.

    Args:
        size: side length of the square output.

    Returns:
        An ``albu.Compose`` instance.
    """
    intensity_group = albu.OneOf(
        [
            albu.RandomContrast(),
            albu.RandomGamma(),
            albu.RandomBrightness(),
        ],
        p=0.3,
    )
    distortion_group = albu.OneOf(
        [
            albu.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
            albu.GridDistortion(),
            albu.OpticalDistortion(distort_limit=2, shift_limit=0.5),
        ],
        p=0.3,
    )
    steps = [
        albu.HorizontalFlip(),
        intensity_group,
        distortion_group,
        albu.ShiftScaleRotate(),
        # The resize is always applied so every sample leaves with the same shape.
        albu.Resize(size, size, always_apply=True),
    ]
    return albu.Compose(steps)

def get_valid_augmentation(size=1024):
    """Build the validation pipeline: only a resize to ``size`` x ``size``.

    Args:
        size: side length of the square output.

    Returns:
        An ``albu.Compose`` instance holding the single resize step.
    """
    resize = albu.Resize(size, size, always_apply=True)
    return albu.Compose([resize])


def get_preprocessing(preprocessing_fn):
    """Compose the final per-sample preprocessing.

    Args:
        preprocessing_fn: callable applied to the image only.

    Returns:
        An ``albu.Compose`` that first runs ``preprocessing_fn`` on the image
        and then passes both image and mask through ``to_tensor``.
    """
    normalise_image = albu.Lambda(image=preprocessing_fn)
    reorder_axes = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalise_image, reorder_axes])



class HuBMAPDataset(Dataset):
    """Dataset of images and masks addressed by file name.

    Every id is a file name looked up under both ``config.images_path`` and
    ``config.masks_path``.

    Args:
        ids: sequence of file names.
        transforms: optional callable invoked as ``transforms(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
        preprocessing: optional callable with the same calling convention,
            applied after ``transforms``.
    """

    def __init__(self, ids, transforms=None, preprocessing=None):
        self.ids = ids
        self.transforms = transforms
        self.preprocessing = preprocessing

    @staticmethod
    def _read(path):
        """Read ``path`` with OpenCV, raising instead of silently returning ``None``.

        Raises:
            FileNotFoundError: if OpenCV could not load the file.
        """
        array = cv2.imread(path)
        if array is None:
            # cv2.imread does not raise on a missing/unreadable file; fail here with
            # the offending path rather than later with an opaque TypeError.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return array

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair for position ``idx``.

        The image is returned in RGB channel order. The mask keeps only its
        first channel, with the channel axis preserved.
        """
        name = self.ids[idx]
        # OpenCV loads colour images as BGR; convert to RGB.
        # NOTE(review): the pretrained-encoder preprocessing used with this dataset
        # is assumed to expect RGB input - confirm against the library docs.
        img = cv2.cvtColor(self._read(f"{config.images_path}/{name}"), cv2.COLOR_BGR2RGB)
        # 0:1 (not 0) keeps a trailing channel axis on the mask.
        mask = self._read(f"{config.masks_path}/{name}")[:, :, 0:1]
        if self.transforms:
            augmented = self.transforms(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']
        if self.preprocessing:
            preprocessed = self.preprocessing(image=img, mask=mask)
            img = preprocessed['image']
            mask = preprocessed['mask']
        return img, mask

    def __len__(self):
        """Number of ids in the dataset."""
        return len(self.ids)

In [None]:
# Split the files into train/validation by the part of the file name before the first "_".
# os.listdir() returns entries in arbitrary order, so sort for a reproducible listing.
data = sorted(os.listdir(config.images_path))
# sorted() rather than list(): the iteration order of a set of strings can differ between
# interpreter runs, which would silently change which prefixes are held out.
train_lsit = sorted(set(row.split("_")[0] for row in data))
# All prefixes except the last two are used for training; build the lookup set once.
train_prefixes = set(train_lsit[:-2])
train_idx = [row for row in data if row.split("_")[0] in train_prefixes]
valid_idx = [row for row in data if row.split("_")[0] not in train_prefixes]
# With fewer than three prefixes the training split would be empty; fail early and clearly.
assert train_idx and valid_idx, "need at least three distinct file-name prefixes to split"
len(train_idx), len(valid_idx)

In [None]:
# Build the preprocessing from the same encoder name the model uses (config.backbone),
# so changing the backbone in config cannot leave the input normalisation out of sync.
preprocessing = get_preprocessing(
    smp.encoders.get_preprocessing_fn(config.backbone, "imagenet")
)
train_aug = get_train_augmentation(config.im_size)
val_aug = get_valid_augmentation(config.im_size)

train_datasets = HuBMAPDataset(train_idx, transforms=train_aug, preprocessing=preprocessing)
valid_datasets = HuBMAPDataset(valid_idx, transforms=val_aug, preprocessing=preprocessing)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(
    train_datasets,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
)
valid_loader = DataLoader(
    valid_datasets,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=True,
)

In [None]:
# Sanity check: fetch one training sample and display the image and mask shapes.
x, y = train_datasets[1]
tuple(arr.shape for arr in (x, y))

In [None]:
# U-Net with the configured encoder: 3 input channels, 1 output channel,
# batch norm disabled in the decoder.
model = smp.Unet(
    encoder_name=config.backbone,
    in_channels=3,
    classes=1,
    decoder_use_batchnorm=False,
)

# AdamW with a cosine-annealed learning rate.
optim = torch.optim.AdamW(model.parameters(), lr=config.lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=config.T_max)

# The loss takes raw logits (the sigmoid is part of BCEWithLogitsLoss).
loss_fn = torch.nn.BCEWithLogitsLoss()
# NOTE(review): the reported "metric" is a Dice *loss* object - presumably lower is better; confirm.
metric = smp.utils.losses.DiceLoss()

In [None]:
class AverageMeter:
    """Tracks the latest value and the running weighted mean of a series.

    Attributes (all reset to 0 by ``reset``):
        val: most recent value passed to ``update``.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count

class trainer:
    """Train/validate loop for a model whose forward pass returns logits.

    Args:
        model: network to optimise. It is moved to the GPU when CUDA is
            available and stays on the CPU otherwise.
        optim: optimizer already bound to ``model``'s parameters.
        scheduler: learning-rate scheduler; stepped once per training batch.
        loss_fn: callable ``loss_fn(logits, target)``; its result is backpropagated.
        metric: callable ``metric(probabilities, target)`` used for reporting only.
            It receives the sigmoid of the logits.
    """

    def __init__(self, model, optim, scheduler, loss_fn, metric):
        # Use CUDA when present instead of failing outright on CPU-only machines.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.opt = optim
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        self.metric = metric

    def train(self, train_loader, e, epochs):
        """Run one training pass over ``train_loader``.

        Args:
            train_loader: iterable of ``(x, y)`` batches.
            e: index of the current epoch (only used in the progress text).
            epochs: total number of epochs (only used in the progress text).

        When ``config.DEBUG`` is set, the pass stops once ``batch_idx`` exceeds 10.
        """
        self.model.train()
        tqdm_loder = tqdm(train_loader)
        current_loss_mean = 0.0
        current_dice_mean = AverageMeter()
        self.opt.zero_grad()
        for batch_idx, (x, y) in enumerate(tqdm_loder):
            x = x.to(self.device).float()
            y = y.to(self.device).float()
            predicted = self.model(x)
            loss = self.loss_fn(predicted, y)
            # The metric is for display only: evaluate it on probabilities and
            # outside autograd so it adds nothing to the graph.
            with torch.no_grad():
                dice = self.metric(torch.sigmoid(predicted), y)
            # float(): accumulate plain numbers. Accumulating tensors would keep
            # every batch's autograd history alive and grow memory over the epoch.
            current_dice_mean.update(float(dice))
            loss.backward()
            self.opt.step()
            self.opt.zero_grad()
            self.scheduler.step()
            current_loss_mean = (current_loss_mean * batch_idx + loss.item()) / (batch_idx + 1)
            lr = self.opt.param_groups[0]['lr']
            tqdm_loder.set_description(f"Epoch {e}/{epochs}, train loss: {current_loss_mean:.4} dice {current_dice_mean.avg:.4},lr: {lr:.4}")
            if config.DEBUG and batch_idx > 10:
                break

    def valid(self, val_loader, epoch):
        """Evaluate the model on ``val_loader`` without gradient tracking.

        Args:
            val_loader: iterable of ``(x, y)`` batches.
            epoch: index of the current epoch (currently unused).

        Returns:
            The running mean of the per-batch loss as a Python float
            (``0.0`` if the loader yields no batches).

        When ``config.DEBUG`` is set, the pass stops once ``batch_idx`` exceeds 10.
        """
        self.model.eval()
        tqdm_loder = tqdm(val_loader)
        current_loss_mean = 0.0
        current_dice_mean = AverageMeter()
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(tqdm_loder):
                x = x.to(self.device).float()
                y = y.to(self.device).float()
                predicted = self.model(x)
                loss = self.loss_fn(predicted, y)
                # Same convention as in train(): the metric sees probabilities.
                dice = self.metric(torch.sigmoid(predicted), y)
                current_dice_mean.update(float(dice))
                current_loss_mean = (current_loss_mean * batch_idx + loss.item()) / (batch_idx + 1)
                tqdm_loder.set_description(f"val loss: {current_loss_mean:.4}, dice {current_dice_mean.avg:.4}")
                if config.DEBUG and batch_idx > 10:
                    break
        return current_loss_mean

    def run(self, train_loader, val_loader, epochs):
        """Alternate training and validation, checkpointing the best model.

        After each epoch the validation loss is compared with the best seen so
        far; on improvement the weights are written to ``best.pth``. When
        ``config.DEBUG`` is set, only the first epoch is run.

        Args:
            train_loader: iterable of training batches.
            val_loader: iterable of validation batches.
            epochs: maximum number of epochs.
        """
        best = float("inf")
        for e in range(epochs):
            self.train(train_loader, e, epochs)
            score = self.valid(val_loader, e)
            if score < best:
                best = score
                # Save the model this trainer owns, not whatever global happens
                # to be called `model`.
                torch.save(self.model.state_dict(), "best.pth")
                print("save best model")
            if config.DEBUG:
                break


In [None]:
# Launch training with the objects built in the cells above.
epochs = config.epochs
if config.DEBUG:
    print("DEBUG mode")

T = trainer(model=model, optim=optim, scheduler=scheduler, loss_fn=loss_fn, metric=metric)
T.run(train_loader=train_loader, val_loader=valid_loader, epochs=epochs)

In [None]:
!rm -rf *