# **Lung Segmentation using pretrained U-net**

* This notebook uses a U-Net architecture with an ImageNet-pretrained ResNet34 encoder. It is built with the [segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch) library, which provides many built-in segmentation architectures with interchangeable backbones.

* The purpose of this notebook is to identify pneumothorax, a collapsed lung, in chest X-rays. Pneumothorax can make people suddenly gasp for air and feel breathless for no apparent reason. Ultimately, we want to develop a model that detects and segments pneumothorax in a set of chest radiographs.

In [None]:
!pip install segmentation-models-pytorch

In [None]:
import os
import glob
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict, Counter

from PIL import Image, ImageFile
import albumentations as A
import matplotlib.pyplot as plt

from sklearn import model_selection
import segmentation_models_pytorch as smp

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

ImageFile.LOAD_TRUNCATED_IMAGES = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# **Utility Functions**

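* Most segmentation problems like this one provide two images per sample: the input and its mask. When an image contains several objects, there are several masks. This dataset provides the masks as **RLE** strings instead. **RLE** stands for **Run-Length Encoding**, a compact way to store a binary mask: rather than listing every pixel, it records runs of consecutive foreground pixels.

* The utility function below creates a mask from this RLE data. In the encoding it expects, each pair of numbers is (offset from the end of the previous run, run length), and pixels are counted column by column, which is why the decoded array is transposed at the end.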

In [None]:
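def run_length_decode(rle, height=1024, width=1024, fill_value=1):
    # Start from an empty (background) mask, flattened to 1-D
    component = np.zeros((height, width), np.float32)
    component = component.reshape(-1)
    # Parse "offset length offset length ..." into (offset, length) pairs
    rle = np.array([int(s) for s in rle.strip().split()])
    rle = rle.reshape(-1, 2)
    start = 0
    for index, length in rle:
        start = start + index             # offset is relative to the end of the previous run
        end = start + length
        component[start:end] = fill_value
        start = end
    # Pixels are numbered column by column: reshape to (width, height), then transpose
    component = component.reshape(width, height).T
    return component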
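To see the encoding concretely, here is a small sketch that pairs a self-contained copy of the decoder with a hypothetical inverse, `rle_encode`, which is not part of the original notebook. It assumes exactly the scheme the decoder above implements: relative offsets and column-major pixel order.

```python
import numpy as np

def rle_decode(rle, height, width):
    """Decode a relative-offset, column-major RLE string (same scheme as run_length_decode)."""
    flat = np.zeros(height * width, dtype=np.float32)
    pairs = np.array([int(s) for s in rle.split()]).reshape(-1, 2)
    pos = 0
    for offset, length in pairs:
        pos += offset                  # skip `offset` background pixels
        flat[pos:pos + length] = 1     # paint `length` foreground pixels
        pos += length
    # Pixels are numbered column by column, hence reshape to (width, height) and transpose
    return flat.reshape(width, height).T

def rle_encode(mask):
    """Hypothetical inverse: binary (height, width) mask -> relative-offset RLE string."""
    flat = mask.T.reshape(-1)          # column-major pixel order
    padded = np.concatenate([[0], flat, [0]])
    changes = np.flatnonzero(padded[1:] != padded[:-1])  # indices where a run starts or ends
    starts, ends = changes[0::2], changes[1::2]
    out, prev_end = [], 0
    for s, e in zip(starts, ends):
        out += [s - prev_end, e - s]   # offset from end of previous run, run length
        prev_end = e
    return ' '.join(str(v) for v in out)

mask = np.array([[0, 1, 0],
                 [0, 1, 0],
                 [0, 0, 1]], dtype=np.float32)
rle = rle_encode(mask)
print(rle)  # → 3 2 3 1
print(np.array_equal(rle_decode(rle, 3, 3), mask))  # → True
```

Reading the columns top to bottom, the first foreground run starts at pixel 3 and is 2 pixels long; the next run starts 3 pixels after it ends and is 1 pixel long.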

# **Create Dataset Class**

* This cell defines a standard PyTorch `Dataset` class. It is written so that, with small changes (such as the mask column name and the image size), it can be applied to other segmentation problems.
* Here the training data is a CSV file with one column of ImageIds, which are also the image filenames, and another column containing the RLE data. The `__init__()` method fetches the image IDs, groups the RLE rows by image, and initializes the other parameters used later in `__getitem__()`.
* The `__getitem__()` method loads the image, builds its mask, applies augmentation and preprocessing, and returns a dictionary containing the image and the mask.

In [None]:
class SIIMDataset(Dataset):
    
    def __init__(self, df, data_dir, transform=None, preprocessing_fun=None, channel_first=True):
        self.data_dir = data_dir
        self.transform = transform                       # for augmentations
        self.preprocessing_fun = preprocessing_fun       # preprocessing_fun to normalize images
        self.channel_first = channel_first               # set channels as first dimension
        self.image_ids = df.ImageId.unique()             # one entry per image, even if it has several RLE rows
        self.group_by = df.groupby('ImageId')

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        df = self.group_by.get_group(img_id)
        annotations = df[' EncodedPixels'].tolist()
        
        img_path = os.path.join(self.data_dir, img_id + ".png")
        img = Image.open(img_path).convert('RGB')
        img = np.array(img)

        mask = np.zeros(shape=(1024, 1024))
        if annotations[0].strip() != '-1':                # '-1' marks an image without pneumothorax
            for rle in annotations:
                mask += run_length_decode(rle)
        mask = (mask >= 1).astype('float32')
        mask = np.expand_dims(mask, axis=-1)
        
        # apply augmentation
        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        if self.preprocessing_fun:
            img = self.preprocessing_fun(img)
        
        # convert shape from (height, width, channel) ----> (channel, height, width)
        if self.channel_first:
            img = np.transpose(img, (2, 0, 1)).astype(np.float32)
            mask = np.transpose(mask, (2, 0, 1)).astype(np.float32)

        return {
            'image': torch.Tensor(img),
            'mask': torch.Tensor(mask)
        }
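The CSV can hold several rows for one ImageId, one per annotated region, so `__getitem__` gathers them with `groupby` and merges the decoded masks by summing and thresholding. The toy sketch below uses made-up data (not from the real CSV) to show both steps, and why `.values` and `.unique()` give different counts of images.

```python
import numpy as np
import pandas as pd

# Toy stand-in for train-rle.csv (made-up values): image "b" has two annotated
# regions, image "c" has none (" -1")
df = pd.DataFrame({
    'ImageId': ['a', 'b', 'b', 'c'],
    ' EncodedPixels': [' 1 2', ' 0 1', ' 5 2', ' -1'],
})

# groupby collects every RLE row that belongs to one image
groups = df.groupby('ImageId')
print(groups.get_group('b')[' EncodedPixels'].tolist())  # → [' 0 1', ' 5 2']

# .values keeps one entry per row, .unique() one entry per image
print(len(df.ImageId.values), len(df.ImageId.unique()))  # → 4 3

# Masks of one image are merged as in __getitem__: sum, then threshold at 1
m1 = np.array([[1, 1], [0, 0]], dtype=np.float32)
m2 = np.array([[0, 1], [0, 1]], dtype=np.float32)
merged = ((m1 + m2) >= 1).astype('float32')  # union; overlapping pixels stay 1, not 2
print(merged)
```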

# **Create Training and Evaluation Function**

* This section defines the training and evaluation functions for a single epoch. `train` takes a DataLoader, the model, the criterion (loss function) and the optimizer; `evaluate` takes the same arguments except the optimizer. Both return the average loss over the epoch.

In [None]:
# Train model for one epoch

def train(data_loader, model, criterion, optimizer):
    model.train()
    train_loss = 0
    for data in tqdm(data_loader):
        inputs = data['image']
        labels = data['mask']

        inputs = inputs.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

    return train_loss/len(data_loader)

In [None]:
# Evaluate the model

def evaluate(data_loader, model, criterion):
    model.eval()
    eval_loss = 0
    with torch.no_grad():
        for data in tqdm(data_loader):
            inputs = data['image']
            labels = data['mask']

            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.float)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            eval_loss += loss.item()

    return eval_loss/len(data_loader)
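`evaluate` reports only the loss. A thresholded Dice score is often easier to interpret for segmentation; the following is a hypothetical helper, not part of the original notebook. It assumes logits and 0/1 targets of shape `(N, 1, H, W)`, a threshold of 0.5, and treats an empty prediction on an empty mask as a perfect score.

```python
import torch

def dice_score(logits, target, threshold=0.5, eps=1e-7):
    """Hypothetical per-image Dice score for binary masks.

    logits, target: tensors of shape (N, 1, H, W); target contains 0/1.
    Returns a tensor of shape (N,). Empty prediction + empty target -> 1.
    """
    preds = (torch.sigmoid(logits) > threshold).float()  # hard 0/1 prediction
    preds = preds.flatten(1)                              # (N, H*W)
    target = target.flatten(1)
    intersection = (preds * target).sum(dim=1)
    total = preds.sum(dim=1) + target.sum(dim=1)
    return (2 * intersection + eps) / (total + eps)

# Image 0: predicts 2 pixels, 1 of which is correct; image 1: empty vs empty
logits = torch.tensor([[[[5., 5.], [-5., -5.]]],
                       [[[-5., -5.], [-5., -5.]]]])
target = torch.tensor([[[[1., 0.], [0., 0.]]],
                       [[[0., 0.], [0., 0.]]]])
print(dice_score(logits, target))  # approximately [0.6667, 1.0]
```

Inside the `evaluate` loop, one could accumulate `dice_score(outputs, labels).mean().item()` next to the loss.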

# **Define train-val Dataset and Create DataLoader**

In [None]:
# Initialize some useful variables

DATA_DIR = '../input/siim-png-images/train_png'
data_csv = '../input/siim-acr-pneumothorax-segmentation-data/train-rle.csv'
batch_size = 32

Encoder = 'resnet34'
Weights = 'imagenet'

In [None]:
# Define augmentations and the preprocessing function (matched to the encoder)

transform = A.Compose([
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=10, p=0.5),
    A.OneOf([A.RandomGamma(gamma_limit=(90,110)),
             A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1)], p=0.5),
    A.Resize(width=224, height=224)
])

prep_fun = smp.encoders.get_preprocessing_fn(
    Encoder,
    Weights
)
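`prep_fun` normalizes images the way the pretrained encoder expects. As a rough sketch of what that amounts to, assuming the ImageNet statistics that the visualization helper later in this notebook also uses, and assuming the library only rescales by 255 when the maximum pixel value exceeds 1, an equivalent NumPy version and its inverse could look like this (hypothetical names, not the library's code):

```python
import numpy as np

# ImageNet statistics; assumed to match smp's settings for resnet34 + 'imagenet'
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def imagenet_preprocess(img):
    """Approximate prep_fun: RGB array (H, W, 3) -> per-channel standardized floats."""
    x = img.astype(np.float64)
    if x.max() > 1:
        x = x / 255.0              # scale uint8 pixels to [0, 1]
    return (x - MEAN) / STD        # broadcast over the channel axis

def imagenet_denormalize(x):
    """Inverse mapping back to [0, 1], e.g. for display."""
    return np.clip(x * STD + MEAN, 0, 1)

img = np.full((2, 2, 3), 255, dtype=np.uint8)  # a tiny all-white RGB image
x = imagenet_preprocess(img)
print(x[0, 0])                   # (1 - MEAN) / STD for each channel
print(imagenet_denormalize(x)[0, 0])  # → [1. 1. 1.]
```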

In [None]:
# Split data into training and validation

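df = pd.read_csv(data_csv)
# Split on unique image IDs so that all RLE rows of an image land in the same split
image_ids = df.ImageId.unique()
train_ids, val_ids = model_selection.train_test_split(image_ids, test_size=0.15)
df_train = df[df.ImageId.isin(train_ids)]
df_val = df[df.ImageId.isin(val_ids)]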

In [None]:
# Initialize Dataset
train_dataset = SIIMDataset(df_train,
                            DATA_DIR,
                            transform = transform,
                            preprocessing_fun = prep_fun)

val_dataset = SIIMDataset(df_val,
                          DATA_DIR,
                          transform = A.Compose([A.Resize(width=224, height=224)]),  # resize only: no random augmentation for validation
                          preprocessing_fun = prep_fun)

# Create DataLoader
train_loader = DataLoader(train_dataset,
                          batch_size = batch_size,
                          shuffle = True,
                          num_workers = 8)

val_loader = DataLoader(val_dataset,
                        batch_size = batch_size,
                        num_workers = 4)

In [None]:
# Explore DataLoader

print('Training data Info:')
data = next(iter(train_loader))    # use next(); recent PyTorch DataLoader iterators have no .next() method
images, labels = data['image'], data['mask']
print("shape of images : {}".format(images.shape))
print("shape of labels : {}".format(labels.shape))

print('\nValidation data Info:')
data = next(iter(val_loader))
images, labels = data['image'], data['mask']
print("shape of images : {}".format(images.shape))
print("shape of labels : {}".format(labels.shape))

# **Visualize Images**

In [None]:
def denormalize(img):
    img = img.permute(1, 2, 0)            # change shape ---> (height, width, channel)
    mean = torch.FloatTensor([0.485, 0.456, 0.406])
    std = torch.FloatTensor([0.229, 0.224, 0.225])
    img = img*std + mean                  # undo the ImageNet normalization
    img = img.clamp(0, 1)                 # keep pixel values in the range [0, 1]
    return img

def imshow(img, mask):
    mask = mask.squeeze(0)                # (1, height, width) ---> (height, width)
    fig = plt.figure(figsize=(15, 10))
    a = fig.add_subplot(1, 3, 1)
    plt.imshow(denormalize(img), cmap='bone')
    a.set_title("Original x-ray image")
    plt.grid(False)
    plt.axis("off")

    a = fig.add_subplot(1, 3, 2)
    plt.imshow(mask, cmap='binary')
    a.set_title("The mask")
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    a = fig.add_subplot(1, 3, 3)
    plt.imshow(denormalize(img), cmap='bone')
    plt.imshow(mask, cmap='binary', alpha=0.3)
    a.set_title("Mask on the X-ray image")

    plt.axis("off")
    plt.grid(False)


def show_batch_image(dataloader, num_images):
    data = next(iter(dataloader))
    image, mask = data['image'], data['mask']
    # pick distinct random indices within the actual size of this batch
    img_idx = torch.randperm(image.shape[0])[:num_images]
    for i in img_idx:
        imshow(image[i], mask[i])

show_batch_image(train_loader, 5)

# **Create Loss Class**

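* Here, a soft Dice term and a Focal Loss are defined. Note that `dice_loss` actually returns the soft Dice *coefficient* (1 means perfect overlap); `MixedLoss` turns it into a loss by taking its negative logarithm.
* This notebook trains with a combination of the two losses, `alpha * focal - log(dice)`. The focal term down-weights pixels the model already classifies confidently, which can help when the positive region covers only a small part of the image.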

In [None]:
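def dice_loss(input, target):
    # Despite the name, this returns the soft Dice coefficient (1 = perfect overlap)
    input = torch.sigmoid(input)          # logits -> probabilities
    smooth = 1.0                          # avoids division by zero on empty masks
    iflat = input.reshape(-1)             # flatten the whole batch
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    return ((2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))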


class FocalLoss(nn.Module):
    def __init__(self, gamma):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))
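        # Numerically stable binary cross-entropy computed from logits
        max_val = (-input).clamp(min=0)
        loss = input - input * target + max_val + \
            ((-max_val).exp() + (-input - max_val).exp()).log()
        # invprobs = log(1 - p_t), where p_t is the predicted probability of the true class
        invprobs = F.logsigmoid(-input * (target * 2.0 - 1.0))
        # Scale the BCE by (1 - p_t) ** gamma
        loss = (invprobs * self.gamma).exp() * loss
        return loss.mean()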


class MixedLoss(nn.Module):
    def __init__(self, alpha, gamma):
        super().__init__()
        self.alpha = alpha
        self.focal = FocalLoss(gamma)

    def forward(self, input, target):
        loss = self.alpha*self.focal(input, target) - torch.log(dice_loss(input, target))
        return loss.mean()
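As a sanity check on the arithmetic above, the sketch below re-implements the `FocalLoss.forward` computation as a standalone function and compares it with the textbook form `(1 - p_t) ** gamma * BCE`, built from PyTorch's `binary_cross_entropy_with_logits`. It also evaluates the soft Dice term on a tiny example. The function names are hypothetical helpers for this check only.

```python
import torch
import torch.nn.functional as F

def focal_loss_stable(input, target, gamma):
    # Same arithmetic as FocalLoss.forward
    max_val = (-input).clamp(min=0)
    bce = input - input * target + max_val + \
        ((-max_val).exp() + (-input - max_val).exp()).log()
    invprobs = F.logsigmoid(-input * (target * 2.0 - 1.0))
    return ((invprobs * gamma).exp() * bce).mean()

def focal_loss_reference(input, target, gamma):
    # Textbook form: (1 - p_t)^gamma * BCE, p_t = probability assigned to the true class
    bce = F.binary_cross_entropy_with_logits(input, target, reduction='none')
    p = torch.sigmoid(input)
    p_t = p * target + (1 - p) * (1 - target)
    return ((1 - p_t) ** gamma * bce).mean()

def soft_dice(input, target, smooth=1.0):
    # Same formula as dice_loss: a coefficient, not a loss
    p = torch.sigmoid(input).reshape(-1)
    t = target.reshape(-1)
    return (2.0 * (p * t).sum() + smooth) / (p.sum() + t.sum() + smooth)

torch.manual_seed(0)
x = torch.randn(2, 1, 4, 4) * 3                 # random logits
t = (torch.rand(2, 1, 4, 4) > 0.7).float()      # random binary mask
print(focal_loss_stable(x, t, 2.0).item(), focal_loss_reference(x, t, 2.0).item())

# All logits 0 (p = 0.5), half the pixels positive: (2*1 + 1) / (2 + 2 + 1) = 0.6
print(soft_dice(torch.zeros(4), torch.tensor([1., 1., 0., 0.])).item())
```

With `gamma = 0` the focal term reduces to plain BCE; larger `gamma` shrinks the contribution of well-classified pixels.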

# **Train The Model**

In [None]:
# Create Model

model = smp.Unet(
    encoder_name = Encoder,
    encoder_weights = Weights,
    in_channels = 3,
    classes = 1,
    activation = None
)

model.to(device)

In [None]:
# Define loss function and Set Optimizer

criterion = MixedLoss(alpha = 10.0,
                      gamma = 2.0)

optimizer = optim.Adam(model.parameters(),
                       lr = 0.0001)

In [None]:
# Loop over all Epochs

epochs = 15

for epoch in range(epochs):

    train_loss = train(train_loader,
                       model,
                       criterion,
                       optimizer)

    val_loss = evaluate(val_loader,
                        model,
                        criterion)

    print(f'Epoch: {epoch+1}')
    print(f'Training Loss: {train_loss}, \t Validation Loss: {val_loss}\n')
