# Segmentation models
This notebook shows how to use the awesome segmentation models package (https://github.com/qubvel/segmentation_models.pytorch) to create a U-Net segmentation model that predicts catheter/line positions from the annotations that accompany this dataset. The predicted masks can be used as an auxiliary training task or as downstream input to a different model. This notebook generates masks for a single validation fold, but it can easily be extended to make predictions for the full out-of-fold training set.

There are several ways to pose this problem. I have chosen a 4-class segmentation, so the model only has to identify the 4 different catheter/line types. It does not generate masks that predict whether they are correctly placed. It should be trivial to extend this to 11 classes or to reduce it to 1.

This model does not train the classification branch, but the stub is still there if you would like to modify the dataloader and a few other parts to train it.

In [None]:
!pip install segmentation-models-pytorch

In [None]:
import os

# Experiment configuration.
image_size = 512          # side length images are resized / cropped to
seed = 42
use_amp = True            # mixed-precision training via torch.cuda.amp
debug = False             # when True, train on a 10% sample of the rows
model_name = 'resnet18'   # encoder name handed to smp.Unet through SegNet
phase = "only_seg"        # tag embedded in checkpoint / log file names

# Create the output directory tree. The parent is created idempotently so
# re-running this cell does not crash with FileExistsError.
os.makedirs("./models", exist_ok=True)
if os.path.exists("./models/" + model_name):
    print("Warning model directory already exists \n"*5)
else:
    os.mkdir("./models/" + model_name)
data_dir = '../input/ranzcr-clip-catheter-line-classification/train/'

In [None]:
import pandas as pd
import numpy as np
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm
import segmentation_models_pytorch as smp
import time
import cv2
import PIL.Image
import random
from sklearn.metrics import accuracy_score
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR 
import albumentations as a
from albumentations import *
from albumentations.pytorch import ToTensorV2
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import gc
from sklearn.metrics import roc_auc_score
import seaborn as sns
from pylab import rcParams
import timm
import ast
from warnings import filterwarnings
filterwarnings("ignore")

device = torch.device('cuda')

## Model

In [None]:
class SegNet(nn.Module):
    """U-Net (segmentation_models_pytorch) with an auxiliary classification head.

    ``forward`` returns ``(mask_logits, logits)``. ``mask_logits`` is permuted
    with ``(0, 2, 3, 1)`` so that the class axis is last, presumably to line up
    with the ``(H, W, 4)`` masks built by the dataset. ``logits`` come from the
    smp auxiliary (``aux_params``) classification head.
    """

    def __init__(self, model_name='resnet200d', out_dim=11, pretrained=False, seg_classes=4):
        """
        Args:
            model_name: encoder name passed to ``smp.Unet``.
                NOTE(review): confirm the default 'resnet200d' is an encoder
                smp supports; callers here always pass a name explicitly.
            out_dim: number of outputs of the auxiliary classification head.
            pretrained: when True load ImageNet encoder weights, otherwise
                initialise the encoder randomly.
            seg_classes: number of segmentation mask channels.
        """
        super().__init__()
        aux_params = dict(
            pooling='max',
            dropout=0.1,
            classes=out_dim,
        )
        self.model = smp.Unet(
            model_name,
            encoder_weights="imagenet" if pretrained else None,
            classes=seg_classes,
            aux_params=aux_params,
        )

    def forward(self, x):
        mask_logits, logits = self.model(x)
        # Move the class axis last: (B, C, H, W) -> (B, H, W, C).
        mask_logits = mask_logits.permute(0, 2, 3, 1)
        return mask_logits, logits
        
        

## Utils

In [None]:
def seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    Also sets PYTHONHASHSEED and turns on cuDNN's deterministic mode. cuDNN
    benchmarking is deliberately left enabled for speed, so GPU runs are not
    bit-for-bit reproducible.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = True  # faster, at the cost of determinism
    
# Seed every RNG once, using the `seed` value from the configuration cell.
seed_everything(seed)

## Transforms

In [None]:
# Training augmentation pipeline. SegDataset calls it with both `image=` and
# `mask=`; albumentations is expected to apply the spatial transforms to the
# mask as well (see its documentation for per-transform target support).
transforms_train = a.Compose(
    [
        a.RandomResizedCrop(image_size, image_size, scale=(0.9, 1), p=1),
        a.HorizontalFlip(p=0.5),
        a.ShiftScaleRotate(p=0.5),
        a.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10, p=0.7),
        a.RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=0.7),
        a.CLAHE(clip_limit=(1, 4), p=0.5),
        # At most one geometric distortion.
        a.OneOf(
            [
                a.OpticalDistortion(distort_limit=1.0),
                a.GridDistortion(num_steps=5, distort_limit=1.),
                a.ElasticTransform(alpha=3),
            ],
            p=0.2,
        ),
        # At most one noise / blur degradation.
        a.OneOf(
            [
                a.GaussNoise(var_limit=[10, 50]),
                a.GaussianBlur(),
                a.MotionBlur(),
                a.MedianBlur(),
            ],
            p=0.2,
        ),
        a.Resize(image_size, image_size),
        # At most one compression-style degradation.
        a.OneOf(
            [
                a.JpegCompression(),
                a.Downscale(scale_min=0.1, scale_max=0.15),
            ],
            p=0.2,
        ),
        a.IAAPiecewiseAffine(p=0.2),
        a.IAASharpen(p=0.2),
        a.Cutout(
            max_h_size=int(image_size * 0.1),
            max_w_size=int(image_size * 0.1),
            num_holes=5,
            p=0.5,
        ),
        a.Normalize(),
        ToTensorV2(),
    ]
)

# Validation pipeline: deterministic resize + normalisation only.
transforms_valid = a.Compose(
    [
        a.Resize(image_size, image_size),
        a.Normalize(),
        ToTensorV2(),
    ]
)

## Dataset

This interpolation function densifies the annotations so that each one forms a continuous path instead of a sparse set of points. The annotation labels give a series of points that trace the path of each catheter/line. It would be hard for a model to predict these points directly, because the spacing between them is somewhat arbitrary. To remedy this, a straight line is drawn between each pair of consecutive points. It is not perfect, but it yields reasonably good-looking masks, and this is what the model will be trying to predict.

In [None]:
from scipy import interpolate
def interpolate_mask(data):
    """Densify a polyline into integer pixel coordinates.

    Args:
        data: array-like of shape (N, 2) holding (x, y) points in drawing order.

    Returns:
        int array of shape (M, 2) with (x, y) columns. Consecutive points of each
        segment differ by at most one pixel along each axis, and both endpoints
        are included.

    Each segment is sampled uniformly along its longer axis. Sampling only
    along x would leave steep segments full of gaps and drop vertical segments
    entirely.
    """
    points = np.asarray(data, dtype=float).reshape(-1, 2)
    if len(points) == 0:
        return np.zeros((0, 2), dtype=int)
    if len(points) == 1:
        return np.rint(points).astype(int)
    pieces = []
    for start, end in zip(points[:-1], points[1:]):
        delta = end - start
        # One sample per pixel along the dominant axis, endpoints inclusive.
        n_steps = int(np.ceil(np.abs(delta).max())) + 1
        t = np.linspace(0.0, 1.0, n_steps)[:, None]
        pieces.append(start + t * delta)
    return np.rint(np.concatenate(pieces, axis=0)).astype(int)

In [None]:
# data = train_annotations["data"][0]
# data = np.array(ast.literal_eval(data))

In [None]:
# interpolate_mask(data).shape

Classes are reduced to four groups: ETT, NGT, CVC and Swan Ganz. With some slight modification, this model could give each label its own class and then be used to classify the catheters/lines directly. For this use, though, it is reasonable to group them together and get the most accurate catheter tracing first.

In [None]:
# Annotation label -> mask channel index. Despite the name, the values are
# channel indices into the segmentation mask, not colours:
# 0 = ETT, 1 = NGT, 2 = CVC, 3 = Swan Ganz catheter.
COLOR_MAP = {
    label: channel
    for channel, labels in enumerate((
        ('ETT - Abnormal', 'ETT - Borderline', 'ETT - Normal'),
        ('NGT - Abnormal', 'NGT - Borderline', 'NGT - Incompletely Imaged', 'NGT - Normal'),
        ('CVC - Abnormal', 'CVC - Borderline', 'CVC - Normal'),
        ('Swan Ganz Catheter Present',),
    ))
    for label in labels
}


class SegDataset(Dataset):
    """Chest X-ray dataset yielding images, rasterised line masks and labels.

    Masks have shape (H, W, 4) and dtype float32. Each annotation's point list
    is joined segment by segment with ``interpolate_mask``. Every resulting
    point is stamped as an ``annot_size`` x ``annot_size`` square of ones in
    channel ``COLOR_MAP[label]``.

    In ``mode='test'`` only the image tensor is returned. Otherwise
    ``__getitem__`` returns ``(file_name, image, mask, label, has_anno)``,
    where ``has_anno`` is 1.0 when the study has at least one annotation row
    and 0.0 otherwise.
    """

    def __init__(self, df, df_annotations, annot_size=10, transform=None, mode = 'train'):
        self.df = df
        self.df_annotations = df_annotations
        self.annot_size = annot_size
        self.file_names = df['file_path'].values
        self.patient_id = df['StudyInstanceUID'].values
        # Relies on the module-level `target_cols` list.
        self.labels = df[target_cols].values
        self.transform = transform
        self.mode = mode
        # Group the annotations once so __getitem__ does a dict lookup instead
        # of scanning the whole annotation table with DataFrame.query per sample.
        self._annotations_by_uid = {
            uid: group for uid, group in df_annotations.groupby('StudyInstanceUID')
        }

    def __len__(self):
        return len(self.df)

    def _build_mask(self, patient_id, height, width):
        """Rasterise the annotations of one study; return (mask, 1 if annotated else 0)."""
        seg_mask = np.zeros((height, width, 4), dtype=np.float32)
        annotations = self._annotations_by_uid.get(patient_id)
        if annotations is None or len(annotations) == 0:
            return seg_mask, 0
        half = self.annot_size // 2
        for label, raw_points in zip(annotations["label"], annotations["data"]):
            channel = COLOR_MAP[label]
            # ast.literal_eval only evaluates Python literals, so it is safe on CSV text.
            points = np.array(ast.literal_eval(raw_points))
            for start in range(len(points) - 1):
                for x, y in interpolate_mask(points[start:start + 2]):
                    # Clamp at 0: a negative slice start would wrap to the far
                    # edge and silently drop points near the top/left border.
                    seg_mask[max(y - half, 0):max(y + half, 0),
                             max(x - half, 0):max(x + half, 0),
                             channel] = 1
        return seg_mask, 1

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        patient_id = self.patient_id[idx]
        image = cv2.imread(file_name)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {file_name}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask, no_anno = self._build_mask(patient_id, image.shape[0], image.shape[1])
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        if self.mode == 'test':
            return torch.as_tensor(image).float()
        label = torch.tensor(self.labels[idx]).float()
        return (
            file_name,
            torch.as_tensor(image).float(),
            torch.as_tensor(mask).float(),
            label,
            torch.tensor(no_anno).float(),
        )

In [None]:
# Line annotations (one row per traced line) and the pre-split fold table.
train_annotations = pd.read_csv('../input/ranzcr-clip-catheter-line-classification/train_annotations.csv')
df_train = pd.read_csv('../input/how-to-properly-split-folds/train_folds.csv')
df_train['file_path'] = [os.path.join(data_dir, f'{uid}.jpg') for uid in df_train['StudyInstanceUID']]
if debug:
    # Quick smoke run on a 10% subsample.
    df_train = df_train.sample(frac=0.1)
# Columns 1..11 (positional) are taken as the target label columns.
target_cols = list(df_train.columns[1:12])


## Utils

Training uses only the samples that have annotations, while validation runs on every sample in the validation fold. Only annotated samples contribute to the segmentation validation loss. Masks are still predicted for every sample and written to .npy files for later inspection and reuse.

In [None]:
def macro_multilabel_auc(label, pred):
    """Return the mean per-column ROC AUC of a multi-label prediction.

    Args:
        label: (n_samples, n_labels) binary ground-truth array.
        pred: (n_samples, n_labels) scores aligned with `label`.

    Per-column AUCs are printed, rounded to 4 decimals. A column whose labels
    contain a single class has no defined AUC (roc_auc_score raises for it).
    Such columns are shown as NaN and excluded from the mean. Returns NaN when
    no column contains both classes.
    """
    label = np.asarray(label)
    pred = np.asarray(pred)
    aucs = []
    for i in range(label.shape[1]):
        column = label[:, i]
        if np.unique(column).size < 2:
            aucs.append(np.nan)
        else:
            aucs.append(roc_auc_score(column, pred[:, i]))
    print(np.round(aucs, 4))
    defined = [auc for auc in aucs if not np.isnan(auc)]
    return np.mean(defined) if defined else np.nan


def train_func(train_loader):
    """Run one training epoch and return the mean classification loss.

    Only the segmentation loss is back-propagated. The classification loss is
    computed for monitoring, so the classification branch is not trained. The
    segmentation loss is a per-sample BCE averaged over the samples with
    annotations (``no_annos == 1``). Uses the module-level `model`,
    `criterion`, `optimizer`, `device` and `use_amp`.
    """
    model.train()
    bar = tqdm(train_loader)
    # enabled=use_amp turns scaler/autocast into no-ops when AMP is off,
    # instead of leaving `scaler` undefined.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    losses = []
    seg_losses = []
    for _, images, masks, targets, no_annos in bar:
        images, masks = images.to(device), masks.to(device)
        targets, no_annos = targets.to(device), no_annos.to(device)
        with torch.cuda.amp.autocast(enabled=use_amp):
            mask_logits, logits = model(images)
            loss = criterion(logits, targets).mean()
            # reduction='none' is required: a mean-reduced criterion returns a
            # scalar, and weighting that scalar by no_annos masks nothing.
            per_sample = F.binary_cross_entropy_with_logits(
                mask_logits, masks, reduction='none'
            ).flatten(1).mean(dim=1)
            seg_loss = (per_sample * no_annos).sum() / no_annos.sum().clamp(min=1)
            total_loss = seg_loss  # + loss to also train the classification head
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        losses.append(loss.item())
        seg_losses.append(seg_loss.item())
        smooth_loss = np.mean(losses)
        smooth_seg_loss = np.mean(seg_losses)
        bar.set_description(f'loss: {loss.item():.5f}, smth: {smooth_loss:.5f}, seg_smth: {smooth_seg_loss:.5f}')

    return np.mean(losses)


def valid_func(valid_loader):
    """Evaluate on `valid_loader` and save every predicted mask to disk.

    Each sample's sigmoid mask is scaled by 255, cast to uint8 and saved as
    ``./models/{model_name}/{fold_id}/seg_masks/<image stem>.npy``.

    Returns:
        (loss_valid, roc_auc, seg_loss):
        - loss_valid: mean classification loss over batches.
        - roc_auc: the value of ``macro_multilabel_auc``.
        - seg_loss: mean per-sample segmentation BCE over annotated samples,
          or NaN if the loader contains none.

    Uses the module-level `model`, `criterion`, `device`, `model_name` and
    `fold_id`.
    """
    model.eval()
    bar = tqdm(valid_loader)
    mask_dir = f'./models/{model_name}/{fold_id}/seg_masks/'
    os.makedirs(mask_dir, exist_ok=True)

    PREDS = []
    TARGETS = []
    losses = []
    seg_loss_sum = 0.0
    seg_count = 0.0
    with torch.no_grad():
        for filenames, images, masks, targets, no_annos in bar:
            images, masks = images.to(device), masks.to(device)
            targets, no_annos = targets.to(device), no_annos.to(device)
            mask_logits, logits = model(images)

            # One device->host copy per batch, then save each sample.
            mask_probs = mask_logits.sigmoid().cpu().numpy()
            for filename, mask_prob in zip(filenames, mask_probs):
                stem = os.path.splitext(os.path.basename(filename))[0]
                np.save(os.path.join(mask_dir, f'{stem}.npy'),
                        (mask_prob * 255).astype(np.uint8))

            # Keep accumulated predictions on the CPU to avoid growing GPU memory.
            PREDS.append(logits.sigmoid().cpu())
            TARGETS.append(targets.cpu())
            loss = criterion(logits, targets).mean()
            losses.append(loss.item())
            smooth_loss = np.mean(losses)

            # Per-sample BCE so that unannotated (all-zero) masks can be excluded.
            per_sample = F.binary_cross_entropy_with_logits(
                mask_logits, masks, reduction='none'
            ).flatten(1).mean(dim=1)
            n_annotated = no_annos.sum().item()
            if n_annotated > 0:
                seg_loss_sum += (per_sample * no_annos).sum().item()
                seg_count += n_annotated
                bar.set_description(f'loss: {loss.item():.5f}, smth: {smooth_loss:.5f}, seg_smth: {seg_loss_sum / seg_count:.5f}')

    PREDS = torch.cat(PREDS).numpy()
    TARGETS = torch.cat(TARGETS).numpy()
    roc_auc = macro_multilabel_auc(TARGETS, PREDS)
    loss_valid = np.mean(losses)
    seg_loss = seg_loss_sum / seg_count if seg_count > 0 else float('nan')
    return loss_valid, roc_auc, seg_loss

## Training

In [None]:
def make_model(model_name):
    """Build a SegNet (requesting pretrained weights), wrap it in DataParallel and move it to `device`."""
    net = SegNet(model_name=model_name, out_dim=len(target_cols), pretrained=True)
    return torch.nn.DataParallel(net).to(device)

In [None]:
# Optimisation / data-loading hyper-parameters.
init_lr = 0.001                    # Adam learning rate; cosine schedule anneals from it
batch_size = valid_batch_size = 8  # train and validation batch sizes
n_epochs = 1                       # epochs per fold
num_workers = 4                    # DataLoader worker processes
early_stop = 5                     # epochs without seg-loss improvement before stopping

In [None]:
def save_state(fp, model, scheduler, optimizer, epoch):
    """Write a checkpoint dict (model/scheduler/optimizer state and epoch) to `fp`.

    Side effect: switches `model` to eval mode, and does not switch it back.
    """
    model.eval()
    checkpoint = dict(
        model=model.state_dict(),
        scheduler=scheduler.state_dict(),
        optimizer=optimizer.state_dict(),
        epoch=epoch,
    )
    torch.save(checkpoint, fp)
def load_state(fp, model, scheduler, optimizer, map_location=None):
    """Restore model, scheduler and optimizer state from a checkpoint written by save_state.

    Args:
        fp: checkpoint path.
        model, scheduler, optimizer: objects whose state is loaded in place.
        map_location: forwarded to ``torch.load``. For example, pass "cpu" to
            load a GPU checkpoint on a machine without CUDA. The default (None)
            keeps the original behaviour.

    Returns:
        (model, scheduler, optimizer, epoch), where `epoch` is the value stored
        in the checkpoint.
    """
    state = torch.load(fp, map_location=map_location)
    model.load_state_dict(state["model"])
    scheduler.load_state_dict(state["scheduler"])
    optimizer.load_state_dict(state["optimizer"])
    epoch = state["epoch"]
    return model, scheduler, optimizer, epoch

In [None]:
# for item in dataset_train:
#     break

In [None]:
# plt.imshow(item[1][:, :, 10].detach().cpu())

In [None]:
# Cross-validation driver. Only fold 0 is trained (`range(5)[:1]`); widen the
# slice to train the remaining folds.
for fold_id in range(5)[:1]:
    log = {}
    roc_auc_max = 0.
    seg_loss_min = 100
    loss_min = 99999
    not_improving = 0
    df_train_this = df_train[df_train['fold'] != fold_id]
    df_valid_this = df_train[df_train['fold'] == fold_id]

    # Train only on studies with at least one annotation; validation keeps every study.
    df_train_this = df_train_this[df_train_this['StudyInstanceUID'].isin(train_annotations['StudyInstanceUID'].unique())].reset_index(drop=True)

    dataset_train = SegDataset(df_train_this, train_annotations, annot_size = 20, transform=transforms_train, mode = 'train')
    dataset_valid = SegDataset(df_valid_this, train_annotations, annot_size = 20, transform=transforms_valid, mode = 'train')

    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True,  num_workers=num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(dataset_valid, batch_size=valid_batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    model = make_model(model_name)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=init_lr)

    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs, eta_min=1e-7)
    model_path = f'./models/{model_name}/{fold_id}/{model_name}_fold{fold_id}_{phase}_best_loss.bin'
    if os.path.exists(model_path):
        print("model is resuming training")
        model, scheduler_cosine, optimizer, epoch = load_state(model_path, model, scheduler_cosine, optimizer)
        # Checkpoints are saved after `epoch` has finished, so resume with the
        # next epoch instead of re-training the saved one.
        start_epoch = epoch + 1
    else:
        print("model being trained from scratch")
        epoch = 1
        start_epoch = 1
    for epoch in range(start_epoch, n_epochs+1):
        scheduler_cosine.step(epoch-1)
        loss_train = train_func(train_loader)
        loss_valid, roc_auc, seg_loss = valid_func(valid_loader)

        log['loss_train'] = log.get('loss_train', []) + [loss_train]
        log['loss_valid'] = log.get('loss_valid', []) + [loss_valid]
        log['lr'] = log.get('lr', []) + [optimizer.param_groups[0]["lr"]]
        log['roc_auc'] = log.get('roc_auc', []) + [roc_auc]
        log['seg_loss'] = log.get('seg_loss', []) + [seg_loss]

        content = time.ctime() + ' ' + f'Fold {fold_id}, Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, loss_train: {loss_train:.5f}, loss_valid: {loss_valid:.5f}, roc_auc: {roc_auc:.6f}.'
        print(content)
        not_improving += 1
        os.makedirs(f'./models/{model_name}/{fold_id}/', exist_ok = True)
        # Best segmentation-loss checkpoint. The file keeps its historical
        # "_best_AUC" name although it is selected on seg_loss.
        if seg_loss < seg_loss_min:
            print(f'seg_loss_min ({seg_loss_min:.6f} --> {seg_loss:.6f}). Saving model ...')
            save_state(f'./models/{model_name}/{fold_id}/{model_name}_fold{fold_id}_{phase}_best_AUC.bin',
                       model,
                       scheduler_cosine,
                       optimizer,
                       epoch)
            seg_loss_min = seg_loss
            not_improving = 0

        # Best classification-loss checkpoint; this is the file resumed from above.
        if loss_valid < loss_min:
            loss_min = loss_valid
            save_state(f'./models/{model_name}/{fold_id}/{model_name}_fold{fold_id}_{phase}_best_loss.bin',
                       model,
                       scheduler_cosine,
                       optimizer,
                       epoch)
        if not_improving == early_stop:
            print('Early Stopping...')
            break
    log_df = pd.DataFrame(log)
    log_df.to_csv(f'./models/{model_name}/{fold_id}/logs_{fold_id}_{phase}.csv')
    df_valid_this.to_csv(f'./models/{model_name}/{fold_id}/val_df_{fold_id}_{phase}.csv')

    save_state(f'./models/{model_name}/{fold_id}/{model_name}_fold{fold_id}_{phase}_final.bin',
           model,
           scheduler_cosine,
           optimizer,
           epoch)

In [None]:
# Load one saved fold-0 validation prediction for inspection. The path follows
# the layout valid_func writes to: ./models/{model_name}/{fold_id}/seg_masks/<stem>.npy
img = np.load(f"./models/{model_name}/0/seg_masks/1.2.826.0.1.3680043.8.498.11073617724281949099281046870716891732.npy")

In [None]:
# Expected to be channels-last with 4 channels (one per line class), since
# SegNet.forward permutes the mask logits to put the class axis last.
img.shape

Visualization of the model's predictions for the 4 different classes. Results are likely not great, since the model is not fully trained. You can try training for more epochs, and with larger encoders than resnet18, to get better-looking masks.

In [None]:
# Show each of the 4 class channels of the predicted mask as a separate figure.
for i in range(4):
    plt.imshow(img[..., i])
    plt.show()
    