In [None]:
!pip install albumentations==0.4.6
!pip install git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
import numpy as np 
import pandas as pd 

import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold

import cv2
import tifffile

import os
import time
from IPython.display import clear_output
import torch

In [None]:
class Config:
    """Static notebook settings: input locations, seed, plot palette and tile size."""

    # Seed handed to seed_everything() and to the KFold splitter.
    seed = 55

    # Input locations.
    root_path = '../input/hubmap-kidney-segmentation'
    images_path = '../input/hubmap-256x256/train'
    masks_path = '../input/hubmap-256x256/masks'
    test_csv_path = ''

    # Side length (pixels) every image/mask is resized to by the augmentations.
    img_size = 256

    # Optional resume state; both stay None for a run from scratch.
    pretrained_model_path = None
    train_logs_path = None

    # Palette reused by the exploratory bar/count plots.
    custom_colors = [
        '#35FCFF', '#FF355A', '#96C503', '#C5035B',
        '#28B463', '#35FFAF', '#8000FF', '#F400FF',
    ]
    
def seed_everything(seed: int):
    """
    Seed every random number generator this notebook draws from.

    Params:
        seed: integer seed applied to Python's `random`, NumPy and PyTorch
              (CPU and all CUDA devices).
    """
    # Local import: `random` is not part of the notebook's import cell.
    import random

    # Python's own generator was previously left unseeded.
    # NOTE(review): albumentations is believed to draw transform probabilities
    # from `random` in the pinned version - confirm.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
       
# Single shared settings object; later cells read paths, seed and palette from it.
config = Config()
# Seed the random number generators before any data splitting or training happens.
seed_everything(config.seed) 

In [None]:
# Load train.csv from the competition root directory.
train_df = pd.read_csv(os.path.join(config.root_path, 'train.csv'))
# Bare last expression: the notebook renders the frame as the cell output.
train_df

In [None]:
def get_mask(rle_mask, shape=(1600, 256)):
    """
    Decode a run-length-encoded string into a binary mask.

    Params:
        rle_mask: whitespace-separated "start length start length ..." string;
                  starts are 1-based offsets into the flattened mask.
        shape: (dim0, dim1) used to reshape the flat buffer before transposing.
    Returns: uint8 array of shape (shape[1], shape[0]) holding 0/1 values.
    """
    tokens = rle_mask.split()
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)

    # Even tokens are run starts, odd tokens are run lengths.
    for start_token, length_token in zip(tokens[::2], tokens[1::2]):
        begin = int(start_token) - 1
        flat[begin:begin + int(length_token)] = 1

    return flat.reshape((shape[0], shape[1])).T

# Overlay the decoded mask of one training image on top of the image itself.
sample_id = 'aaa6a05cc'
img = tifffile.imread(os.path.join(config.root_path, f"train/{sample_id}.tiff"))
sample_rle = train_df.loc[train_df['id'] == sample_id, 'encoding'].values[0]
# get_mask transposes its output, so the shape is passed as (width, height).
mask = get_mask(sample_rle, (img.shape[1], img.shape[0]))
print("img shape ->", img.shape)

plt.figure(figsize=(15, 15))
plt.imshow(img)
# Mask out zero pixels so only the annotated regions are tinted.
mask_overlay = np.ma.masked_where(mask == 0, mask)
plt.imshow(mask_overlay, alpha=0.7, cmap='winter')
plt.axis("off");

In [None]:
# Load the dataset information table that ships with the competition data.
meta_df = pd.read_csv(os.path.join(config.root_path,'HuBMAP-20-dataset_information.csv'))
# Show five random rows as the cell output.
meta_df.sample(5)

In [None]:
# Count distinct image files for every (width, height) combination.
unique_images = meta_df.drop_duplicates(subset=['image_file'])
image_dimensions = (
    unique_images
    .groupby(['width_pixels', 'height_pixels'])['image_file']
    .count()
    .to_dict()
)
print("Unique Image dimensions:")
image_dimensions

In [None]:
# Bar chart of how many metadata rows fall into each race category.
race_dist = meta_df['race'].value_counts().to_dict()
race_names = list(race_dist)
race_counts = list(race_dist.values())
bar_positions = np.arange(len(race_dist))

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(bar_positions, race_counts, color=config.custom_colors)
ax.set_xticks(bar_positions)
ax.set_xticklabels(race_names, fontsize=12)
ax.set_ylabel("count", fontsize=15)
ax.set_xlabel("race", fontsize=18)
ax.set_title("Race Distribution", fontsize=15);

# Write the count on top of every bar.
for n_rows, p in zip(race_counts, ax.patches):
    x = p.get_x() + p.get_width() / 2 - 0.15
    y = p.get_y() + p.get_height()
    ax.annotate(f'{n_rows} people', (x, y), fontsize=12, fontweight='bold')

In [None]:
# Count rows per sex once, so bars, tick labels and annotations all come from
# the data instead of hard-coded labels and counts.
sex_counts = meta_df['sex'].value_counts()

fig, ax = plt.subplots(figsize=(10, 6))
# Pass the column as `x=`: a positional first argument is not accepted as the
# x variable by every seaborn release.
sns.countplot(x=meta_df['sex'], order=sex_counts.index,
              palette=config.custom_colors, ax=ax)

ax.set_xticks(np.arange(len(sex_counts)))
ax.set_xticklabels(sex_counts.index, fontsize=12, rotation=0)
ax.set_xlabel('sex', fontsize=18)
ax.set_ylabel(ax.get_ylabel(), fontsize=15)
ax.set_title("Sex Distribution", fontsize=15);

# Bars are drawn in `order`, so they line up with sex_counts.
for count, p in zip(sex_counts.values, ax.patches):
    x = p.get_x() + p.get_width() / 2 - 0.1
    y = p.get_y() + p.get_height()
    ax.annotate(str(count)+ " people", (x, y), fontsize=12, fontweight='bold')

In [None]:
# Count plot of the `age` column of the metadata table.
fig, ax = plt.subplots(figsize=(10, 6))
# Pass the column as `x=`: a positional first argument is not accepted as the
# x variable by every seaborn release.
sns.countplot(x=meta_df['age'], palette=config.custom_colors, ax=ax)

ax.set_xlabel(ax.get_xlabel(), fontsize=15)
ax.set_ylabel(ax.get_ylabel(), fontsize=15)
ax.set_title("Age Distribution", fontsize=15);

In [None]:
# Bar plot of age (y) per race (x), split by sex (hue), drawn from meta_df.
sns.catplot(x='race',y='age', hue='sex', data=meta_df, kind="bar", height=6, aspect=2, palette="cool", capsize=.05)
# Axis labels and title are set through the pyplot state interface.
plt.ylabel('Age', fontsize=15)
plt.xlabel('race vs sex', fontsize=15);
plt.title("Distribution by Gender, Race and Age.", fontsize=20);

In [None]:
# Bar plot of age (y) per sex (x), split by weight_kilograms (hue), drawn from meta_df.
sns.catplot(x='sex',y='age', hue='weight_kilograms', data=meta_df, kind="bar", height=6, aspect=2, palette="cool", capsize=.05)
# Axis labels and title are set through the pyplot state interface.
plt.ylabel('Age', fontsize=15)
plt.xlabel('sex vs weight_kilograms', fontsize=15);
plt.title("Distribution by Gender, weight kilograms and Age.", fontsize=20);

Create a dataframe with additional metadata extracted from the JSON annotation files.

In [None]:
def get_all_json_file(root: str) -> list:
    """
    Collect the paths of all `.json` files found under `root` (recursively).

    Params:
        root: directory to walk; a missing directory yields an empty list.
    Returns: list of file paths, sorted so the order is identical on every
             run (the previous set-based de-duplication made it arbitrary).
    """
    ids = sorted(
        os.path.join(dirname, filename)
        for dirname, _, filenames in os.walk(root)
        for filename in filenames
        if filename.endswith(".json")
    )
    print(f"Extracted {len(ids)} json files.")
    return ids


def _json_file_id(path: str) -> str:
    """Return the part of the file name before the first '-' or '.'."""
    return os.path.basename(path).replace('.', '-').split('-')[0]


# Same directories as before, now derived from the shared config.
test_json_files = get_all_json_file(os.path.join(config.root_path, 'test'))
train_json_files = get_all_json_file(os.path.join(config.root_path, 'train'))

all_json_files = train_json_files + test_json_files

# Train and test ids are parsed the same way; the test ids previously relied on
# a fixed path depth (`split("/")[4]`), which breaks if the root path changes.
train_ids = [_json_file_id(x) for x in train_json_files]
test_ids = [_json_file_id(x) for x in test_json_files]
all_ids = train_ids + test_ids

In [None]:
# Read every JSON file once, tag its rows with the file id, and concatenate in
# a single call (concatenating inside the loop copies the frame each time).
json_frames = [
    pd.read_json(js).assign(ID=file_id)
    for js, file_id in zip(all_json_files, all_ids)
]
df = pd.concat(json_frames, ignore_index=True, sort=False)

# Spread the dict-valued columns into real columns, one level at a time.
# Presumably 'classification' only appears once 'properties' has been
# expanded, hence the fixed order - verify against the JSON files.
# `columns=` replaces the positional axis argument, which newer pandas rejects.
for nested_col in ('geometry', 'properties', 'classification'):
    expanded = df[nested_col].dropna().apply(pd.Series)
    df = df.drop(columns=nested_col).assign(**expanded)

# 1 for rows that come from a train file, 0 otherwise.
df['is_train'] = df['ID'].isin(train_ids).astype(int)

df.to_csv("meta_data2.csv", index=False)

In [None]:
# Reload the metadata table written to meta_data2.csv by the previous cell.
meta_df2 = pd.read_csv('meta_data2.csv')
# Show five random rows as the cell output.
meta_df2.sample(5)

In [None]:
# Count each classification name once and reuse it for the bar order,
# percentages and annotations.
name_counts = meta_df2['name'].value_counts()
# Share of all rows (including rows without a name) per classification.
percentages = [c / meta_df2.shape[0] * 100 for c in name_counts]

fig, ax = plt.subplots(figsize=(10, 6))
# Pass the column as `x=`: a positional first argument is not accepted as the
# x variable by every seaborn release.
sns.countplot(x=meta_df2['name'], order=name_counts.index,
              ax=ax, palette=config.custom_colors)

ax.set_xticks(np.arange(len(name_counts)))
ax.set_xticklabels(name_counts.index, fontsize=12, rotation=0)
ax.set_xlabel('name', fontsize=18)
ax.set_ylabel(ax.get_ylabel(), fontsize=15)
ax.set_title("Distribution by classification in metadata", fontsize=15);

# Bars follow `order`, so they line up with name_counts.
for percentage, count, p in zip(percentages, name_counts.values, ax.patches):
    percentage = f'{np.round(percentage, 2)}%'
    x = p.get_x() + p.get_width() / 2 - 0.4
    y = p.get_y() + p.get_height()
    ax.annotate(str(percentage)+" / "+str(count), (x, y), fontsize=12, fontweight='bold')

# Training Process

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

import albumentations as A

from segmentation_models_pytorch.unet import Unet

### Dataset and Dataloader

Dataset - https://www.kaggle.com/iafoss/256x256-images

In [None]:
class HuBMAPDataset(Dataset):
    """
    Image/mask pairs read from `config.images_path` and `config.masks_path`.

    Params:
        ids: file names present in both the image and the mask folder.
        phase: 'train' selects the training augmentations, anything else the
               validation transform (see `get_augmentations`).
    """
    def __init__(self, ids, phase):
        self.ids = ids
        self.augmentations = get_augmentations(phase)

    @staticmethod
    def _read_image(path):
        """Read an image with OpenCV, failing loudly if it cannot be loaded."""
        data = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if data is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return data

    def __getitem__(self, idx):
        """Return (image, mask) as channel-first float32 arrays."""
        name = self.ids[idx]
        # NOTE(review): OpenCV is understood to return BGR channel order while
        # the encoder weights were presumably trained on RGB - confirm whether
        # a conversion is intended.
        img = self._read_image(f"{config.images_path}/{name}").astype("float32")
        img /= 255.
        # Keep one channel of the mask while preserving the channel axis.
        # Assumes the mask files store 0/1 values - TODO confirm.
        mask = self._read_image(f"{config.masks_path}/{name}")[:, :, 0:1]

        augmented = self.augmentations(image=img, mask=mask)
        img = augmented['image']
        mask = augmented['mask']
        # HWC -> CHW, the layout the network consumes.
        img = img.transpose(2, 0, 1).astype('float32')
        mask = mask.transpose(2, 0, 1).astype('float32')
        return img, mask

    def __len__(self):
        return len(self.ids)

    
def get_augmentations(phase: str='train'):
    """
    Build the albumentations pipeline for a phase.

    Params:
        phase: 'train' returns flips, photometric and geometric distortions
               followed by a resize; any other value returns the resize only.
    Returns: an `A.Compose` callable taking `image=` and `mask=`.
    """
    size = config.img_size

    if phase != 'train':
        return A.Compose([A.Resize(size, size)])

    # One of three intensity changes, applied with probability 0.3.
    photometric = A.OneOf(
        [A.RandomContrast(), A.RandomGamma(), A.RandomBrightness()],
        p=0.3,
    )
    # One of three spatial distortions, applied with probability 0.3.
    geometric = A.OneOf(
        [
            A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
            A.GridDistortion(),
            A.OpticalDistortion(distort_limit=2, shift_limit=0.5),
        ],
        p=0.3,
    )
    return A.Compose([
        A.HorizontalFlip(),
        photometric,
        geometric,
        A.ShiftScaleRotate(p=0.2),
        A.Resize(size, size, always_apply=True),
    ])


def get_dataloader(
    phase: str ='train',
    batch_size: int = 8,
    num_workers: int = 6,
    val_size: float = 0.2,
    fold: int = 0, 
):
    """
    Build the DataLoader for one phase of one cross-validation fold.

    Params:
        phase: 'train' uses every fold except `fold`; any other value uses
               only `fold` (validation).
        batch_size: samples per batch.
        num_workers: worker processes used by the DataLoader.
        val_size: currently unused; the split is a fixed 7-fold KFold.
        fold: index (0-6) of the held-out validation fold.
    Returns: DataLoader over a HuBMAPDataset.
    """
    # os.listdir returns names in arbitrary order; sort so that the same file
    # always ends up in the same fold for a given seed.
    ids = sorted(os.listdir(config.images_path))
    train_data = pd.DataFrame(ids, columns=['ids'])
    skf = KFold(n_splits=7, random_state=config.seed, shuffle=True)

    # Tag every row with the fold in which it serves as validation data.
    for i, (_, val_index) in enumerate(skf.split(train_data)):
        train_data.loc[val_index, "fold"] = i

    is_train = phase == "train"
    in_fold = train_data['fold'] == fold
    selected = train_data.loc[~in_fold if is_train else in_fold]
    ids = selected['ids'].tolist()

    image_dataset = HuBMAPDataset(ids, phase)
    dataloader = DataLoader(
        image_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        # Only training benefits from shuffling; validation keeps a fixed order.
        shuffle=is_train,
    )

    return dataloader

In [None]:
# Smoke test: pull a single batch and print the tensor shapes and dtypes.
# NOTE(review): the variable is named val_dataloader but is built with
# phase='train' - confirm which split was intended.
val_dataloader = get_dataloader(phase='train')
imgs, masks = next(iter(val_dataloader))
print("images ->", imgs.shape, imgs.dtype)
print("masks ->", masks.shape, masks.dtype)

# Disabled preview of the first image/mask pair of the batch.
#fig, ax = plt.subplots(1,2, figsize=(15, 7))
#ax[0].imshow(imgs[0].permute(1,2,0))
#ax[1].imshow(masks[0][0,:,:])

Exploring several augmentations 

In [None]:
def plot_with_augmentation(image, mask, augment):
    """
    Show an image/mask pair next to its augmented version in a 2x2 grid.

    Params:
        image: image array accepted by `plt.imshow`.
        mask: mask array accepted by `plt.imshow`.
        augment: callable taking `image=` and `mask=` and returning a dict
                 with 'image' and 'mask' entries.
    """
    result = augment(image=image, mask=mask)

    # Row-major panel order: originals on top, transformed versions below.
    panels = [
        (image, 'Original image'),
        (mask, 'Original mask'),
        (result['image'], 'Transformed image'),
        (result['mask'], 'Transformed mask'),
    ]

    f, ax = plt.subplots(2, 2, figsize=(15, 15))
    for axis, (picture, title) in zip(ax.ravel(), panels):
        axis.imshow(picture)
        axis.set_title(title, fontsize=20)

    plt.tight_layout()
    plt.show()

# Sample tile and its mask, used by the augmentation preview cells below.
img = cv2.imread(os.path.join(config.images_path, '0486052bb_150.png'))
# The mask is cast to float32 before it is passed to the augmentations and imshow.
mask = cv2.imread(os.path.join(config.masks_path, '0486052bb_150.png')).astype("float32")


In [None]:
# Preview: horizontal flip, always applied (p=1).
plot_with_augmentation(img, mask,  A.HorizontalFlip(p=1))

In [None]:
# Preview: elastic transform with the same parameters as the training pipeline, always applied (p=1).
plot_with_augmentation(img, mask,  A.ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03))

In [None]:
# Preview: optical distortion with the same parameters as the training pipeline, always applied (p=1).
plot_with_augmentation(img, mask,  A.OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5))

In [None]:
# Preview: random brightness change, always applied (p=1).
plot_with_augmentation(img, mask,  A.RandomBrightness(p=1))

### Metric and Loss

In [None]:
def dice_coef_metric(probabilities: torch.Tensor,
                     truth: torch.Tensor,
                     treshold: float = 0.5,
                     eps: float = 1e-9) -> np.ndarray:
    """
    Mean Dice score over the samples of a batch.
    Params:
        probabilities: model outputs after the activation function.
        truth: ground-truth masks, same shape as `probabilities`.
        treshold: probabilities >= this value count as foreground.
        eps: small constant added to the numerator.
        Returns: mean of the per-sample Dice scores; a sample where both
                 prediction and truth are empty scores 1.0.
    """
    predictions = (probabilities >= treshold).float()
    assert(predictions.shape == truth.shape)

    per_sample = []
    for predicted, expected in zip(predictions, truth):
        expected_total = expected.sum()
        predicted_total = predicted.sum()
        if expected_total == 0 and predicted_total == 0:
            # Nothing to find and nothing predicted: perfect agreement.
            per_sample.append(1.0)
            continue
        overlap = 2.0 * (expected * predicted).sum()
        per_sample.append((overlap + eps) / (expected_total + predicted_total))
    return np.mean(per_sample)


def jaccard_coef_metric(probabilities: torch.Tensor,
               truth: torch.Tensor,
               treshold: float = 0.5,
               eps: float = 1e-9) -> np.ndarray:
    """
    Mean Jaccard index (IoU) over the samples of a batch.
    Params:
        probabilities: model outputs after the activation function.
        truth: ground-truth masks, same shape as `probabilities`.
        treshold: probabilities >= this value count as foreground.
        eps: small constant added to numerator and denominator.
        Returns: mean of the per-sample IoU scores; a sample where both
                 prediction and truth are empty scores 1.0.
    """
    predictions = (probabilities >= treshold).float()
    assert(predictions.shape == truth.shape)

    per_sample = []
    for predicted, expected in zip(predictions, truth):
        predicted_total = predicted.sum()
        expected_total = expected.sum()
        if expected_total == 0 and predicted_total == 0:
            # Nothing to find and nothing predicted: perfect agreement.
            per_sample.append(1.0)
            continue
        overlap = (predicted * expected).sum()
        combined = (predicted_total + expected_total) - overlap + eps
        per_sample.append((overlap + eps) / combined)
    return np.mean(per_sample)


class Meter:
    '''Accumulates per-batch dice and iou scores and reports their averages.'''
    def __init__(self, treshold: float = 0.5):
        self.threshold: float = treshold
        self.dice_scores: list = []
        self.iou_scores: list = []
    
    def update(self, logits: torch.Tensor, targets: torch.Tensor):
        """
        Convert logits to probabilities with a sigmoid, score them against
        the targets, and record one dice and one iou value for the batch.
        """
        probs = torch.sigmoid(logits)
        # Compute both scores before storing either of them.
        batch_scores = (
            dice_coef_metric(probs, targets, self.threshold),
            jaccard_coef_metric(probs, targets, self.threshold),
        )
        self.dice_scores.append(batch_scores[0])
        self.iou_scores.append(batch_scores[1])
    
    def get_metrics(self) -> np.ndarray:
        """
        Returns: (mean dice, mean iou) over all recorded batches.
        """
        return np.mean(self.dice_scores), np.mean(self.iou_scores)
    

class DiceLoss(nn.Module):
    """Soft Dice loss computed from logits over the whole batch at once."""
    def __init__(self, eps: float = 1e-9):
        super().__init__()
        self.eps = eps
        
    def forward(self,
                logits: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """Return 1 - dice, with dice summed over every element of the batch."""
        batch = targets.shape[0]
        flat_probs = torch.sigmoid(logits).view(batch, -1)
        flat_targets = targets.view(batch, -1)
        assert(flat_probs.shape == flat_targets.shape)
        
        overlap = 2.0 * (flat_probs * flat_targets).sum()
        total = flat_probs.sum() + flat_targets.sum()
        return 1.0 - (overlap + self.eps) / total
        
        
class BCEDiceLoss(nn.Module):
    """Objective loss: binary cross-entropy on logits plus Dice loss."""
    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
        
    def forward(self, 
                logits: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """Return bce(logits, targets) + dice(logits, targets)."""
        assert(logits.shape == targets.shape)
        dice_term = self.dice(logits, targets)
        bce_term = self.bce(logits, targets)
        return bce_term + dice_term

### Model

In [None]:
# U-Net with an 'efficientnet-b2' encoder initialised from "imagenet" weights.
# classes=1 and activation=None: the sigmoid is applied later, in the loss and the Meter.
model = Unet('efficientnet-b2', encoder_weights="imagenet", classes=1, activation=None)

In [None]:
class Trainer:
    """
    Runs the train/validation loop for a segmentation network.
    Args:
        net: neural network for mask prediction.
        criterion: module computing the objective loss from (logits, targets).
        lr: learning rate for the Adam optimizer.
        accumulation_steps: number of samples to accumulate gradients over
                    before an optimizer step; converted internally to a number
                    of batches (accumulation_steps // batch_size, at least 1)
                    (https://www.kaggle.com/c/understanding_cloud_organization/discussion/105614).
        batch_size: batch size used by the data loaders.
        num_epochs: number of passes over the data.
        display_plot: if True - plot train history after each epoch.
        fold: index of the cross-validation fold held out for validation.
    Attributes:
        device: 'cuda' when available, otherwise 'cpu'.
        optimizer: Adam over the network parameters.
        scheduler: ReduceLROnPlateau driven by the validation loss.
        phases: ["train", "val"].
        dataloaders: dict with one data loader per phase.
        best_loss: lowest validation loss seen so far.
        losses: dict of per-epoch loss lists, one per phase.
        dice_scores: dict of per-epoch dice lists, one per phase.
        jaccard_scores: dict of per-epoch jaccard lists, one per phase.
    """
    # Style names tried in order when plotting; newer matplotlib releases
    # renamed the bundled seaborn styles with a "seaborn-v0_8-" prefix.
    _PLOT_STYLES = ("seaborn-dark-palette", "seaborn-v0_8-dark-palette")

    def __init__(self,
                 net: nn.Module,
                 criterion: nn.Module,
                 lr: float,
                 accumulation_steps: int,
                 batch_size: int,
                 num_epochs: int,
                 display_plot: bool = True,
                 fold: int = 0,
                ):

        """Initialization."""
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("device:", self.device)
        self.display_plot = display_plot
        self.net = net.to(self.device)
        self.criterion = criterion
        self.optimizer = Adam(self.net.parameters(), lr=lr)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode="min",
                                           patience=3, verbose=True)
        # Never below 1: a value of 0 would divide the loss by zero.
        self.accumulation_steps = max(1, accumulation_steps // batch_size)
        self.phases = ["train", "val"]
        self.num_epochs = num_epochs

        # Honour the requested batch size and fold instead of fixed values.
        self.dataloaders = {
            phase: get_dataloader(
                phase = phase,
                batch_size = batch_size,
                num_workers = 4,
                fold = fold,
            )
            for phase in self.phases
        }
        self.best_loss = float("inf")
        self.losses = {phase: [] for phase in self.phases}
        self.dice_scores = {phase: [] for phase in self.phases}
        self.jaccard_scores = {phase: [] for phase in self.phases}
         
    def _compute_loss_and_outputs(self,
                                  images: torch.Tensor,
                                  targets: torch.Tensor):
        """Move a batch to the device and return (loss, logits)."""
        images = images.to(self.device)
        targets = targets.to(self.device)
        logits = self.net(images)
        loss = self.criterion(logits, targets)
        return loss, logits
        
    def _do_epoch(self, epoch: int, phase: str):
        """
        Run one pass over the data of `phase` and record its metrics.
        Weights are updated only when phase == "train".
        Returns: mean loss per batch for the epoch.
        """
        print(f"{phase} epoch: {epoch} | time: {time.strftime('%H:%M:%S')}")

        if phase == "train":
            self.net.train()
        else:
            self.net.eval()
        meter = Meter()
        dataloader = self.dataloaders[phase]
        total_batches = len(dataloader)
        running_loss = 0.0
        self.optimizer.zero_grad()
        for itr, (images, targets) in enumerate(dataloader):
            loss, logits = self._compute_loss_and_outputs(images, targets)
            # Scale so the accumulated gradient averages over the group.
            loss = loss / self.accumulation_steps
            if phase == "train":
                loss.backward()
                group_complete = (itr + 1) % self.accumulation_steps == 0
                # Also step on the final batch, so gradients of an incomplete
                # last group are applied instead of being discarded.
                last_batch = (itr + 1) == total_batches
                if group_complete or last_batch:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
            running_loss += loss.item()
            meter.update(logits.detach().cpu(),
                         targets.detach().cpu()
                        )
            
        # Undo the accumulation scaling to report the mean loss per batch.
        epoch_loss = (running_loss * self.accumulation_steps) / total_batches
        epoch_dice, epoch_iou = meter.get_metrics()
        
        self.losses[phase].append(epoch_loss)
        self.dice_scores[phase].append(epoch_dice)
        self.jaccard_scores[phase].append(epoch_iou)

        return epoch_loss
        
    def run(self):
        """
        Train for `num_epochs`, validating after every epoch.
        Saves "best_model.pth" whenever the validation loss improves and
        writes the final weights and logs once training ends.
        """
        for epoch in range(self.num_epochs):
            self._do_epoch(epoch, "train")
            with torch.no_grad():
                val_loss = self._do_epoch(epoch, "val")
                self.scheduler.step(val_loss)
            if self.display_plot:
                self._plot_train_history()
                
            if val_loss < self.best_loss:
                print(f"\n{'#'*20}\nSaved new checkpoint\n{'#'*20}\n")
                self.best_loss = val_loss
                torch.save(self.net.state_dict(), "best_model.pth")
            print()
        self._save_train_history()
            
    def _plot_train_history(self):
        """Redraw loss, dice and jaccard curves for both phases."""
        data = [self.losses, self.dice_scores, self.jaccard_scores]
        colors = ['deepskyblue', "crimson"]
        labels = [
            f"""
            train loss {self.losses['train'][-1]}
            val loss {self.losses['val'][-1]}
            """,
            
            f"""
            train dice score {self.dice_scores['train'][-1]}
            val dice score {self.dice_scores['val'][-1]} 
            """, 
                  
            f"""
            train jaccard score {self.jaccard_scores['train'][-1]}
            val jaccard score {self.jaccard_scores['val'][-1]}
            """,
        ]
        
        # Use the first style this matplotlib installation knows about.
        style = next(
            (name for name in self._PLOT_STYLES if name in plt.style.available),
            "default",
        )

        clear_output(True)
        with plt.style.context(style):
            fig, axes = plt.subplots(3, 1, figsize=(8, 10))
            for i, ax in enumerate(axes):
                ax.plot(data[i]['val'], c=colors[0], label="val")
                ax.plot(data[i]['train'], c=colors[-1], label="train")
                ax.set_title(labels[i])
                ax.legend(loc="upper right")
                
            plt.tight_layout()
            plt.show()
            
    def load_predtrain_model(self,
                             state_path: str):
        """Load network weights from `state_path` onto the trainer's device."""
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        self.net.load_state_dict(
            torch.load(state_path, map_location=self.device)
        )
        print("Predtrain model loaded")
        
    def _save_train_history(self):
        """writing model weights and training logs to files."""
        torch.save(self.net.state_dict(),
                   f"last_epoch_model.pth")

        histories = {
            "_loss": self.losses,
            "_dice": self.dice_scores,
            "_jaccard": self.jaccard_scores,
        }
        # Columns: train_loss, val_loss, train_dice, val_dice, train_jaccard, val_jaccard.
        columns = {
            phase + suffix: values
            for suffix, history in histories.items()
            for phase, values in history.items()
        }
        pd.DataFrame(columns).to_csv("train_log.csv", index=False)

In [None]:
# Build the trainer for fold 0; accumulation_steps=32 with batch_size=8 means
# an optimizer step every 4 batches.
trainer = Trainer(net=model,
                  criterion=BCEDiceLoss(),
                  lr=5e-4,
                  accumulation_steps=32,
                  batch_size=8,
                  num_epochs=30,
                  fold=0,
                 )

if config.pretrained_model_path is not None:
    trainer.load_predtrain_model(config.pretrained_model_path)
    
    # if need - load the logs.
    # Skipped when no log path is configured: previously pd.read_csv(None)
    # was reached and failed after the weights had been loaded.
    if config.train_logs_path is not None:
        train_logs = pd.read_csv(config.train_logs_path)
        history_by_metric = {
            "loss": trainer.losses,
            "dice": trainer.dice_scores,
            "jaccard": trainer.jaccard_scores,
        }
        # Column names follow the "<phase>_<metric>" layout of train_log.csv.
        for metric_name, history in history_by_metric.items():
            for phase_name in ("train", "val"):
                history[phase_name] = train_logs.loc[:, f"{phase_name}_{metric_name}"].to_list()

In [None]:
%%time
# Start training; checkpoints and train_log.csv are written by Trainer.run().
trainer.run()