# Phân đoạn Pneumothorax (tràn khí màng phổi)

## Tổng quan bài toán

Trong dự án này, chúng ta cần phân đoạn vùng tràn khí màng phổi trên ảnh X-quang đã được gán nhãn trước.

Dữ liệu gồm 12955 ảnh X-quang của 12047 bệnh nhân kèm nhãn tương ứng (một số ảnh có nhãn; những ảnh không có nhãn là ảnh không bị tràn khí màng phổi).

## Mục tiêu

* Đào tạo mô hình segment dựa trên Unet với các backbone:
    * Resnet50
    * Efficientnet
    * Xceptionnet
* Thử nghiệm các phương pháp augmentation trên hình ảnh X-quang
* Đánh giá mô hình:
    * Trên metrics dice score
    * Đánh giá cho từng mô hình với các backbone khác nhau

In [None]:
# import libraries
import numpy as np
import pandas as pd
from glob import glob
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.image as image
from tqdm.notebook import tqdm
import glob
import pydicom
import sys
import os

print(os.listdir("../input/siim-acr-pneumothorax-segmentation"))
print()
sys.path.insert(0, '../input/siim-acr-pneumothorax-segmentation')

from mask_functions import rle2mask
%matplotlib inline

In [None]:
# install libraries for image augmentation and
# for visualizing the U-Net model architecture
!pip install albumentations
!pip install torchsummary

In [None]:
import torch
import torchvision
import torch.nn as nn 
from torch.utils.data import DataLoader, Dataset

from torchsummary import summary
import torchvision.models as models
import torchvision.transforms as T
import albumentations as A

from torch.autograd import Variable

## Load information for dataset

In [None]:
# Read the pre-split annotation tables: first the training split, then the
# validation split.
train_df, val_df = [
    pd.read_csv(csv_path)
    for csv_path in (
        "../input/segment-data/train_info_split.csv",
        "../input/segment-data/val_info_split.csv",
    )
]

In [None]:
# Collect the path of every .dcm file found three directory levels below
# the training image folder.
file_path = "../input/siim-acr-pneumothorax-segmentation-data/dicom-images-train/*/*/*.dcm"
file_paths = list(glob.iglob(file_path))

## Visualize the original image and its mask from a .dcm file

In [None]:
# Show the first training image that has two or more annotation rows:
# the raw X-ray on the left, the X-ray with the union of its masks on the right.
#
# Rows are grouped by ImageId up front (sort=False keeps first-appearance
# order), so a DICOM file is read only for the image that is actually plotted
# rather than for every image inspected on the way there.
for image_id, record_arr in train_df.groupby("ImageId", sort=False):
    # Only images with several annotation rows are of interest here.
    if len(record_arr) < 2:
        continue

    # The ImageId is matched as a substring of the file path.
    matching_paths = [path for path in file_paths if image_id in path]
    dataset = pydicom.dcmread(matching_paths[0])
    image_data = dataset.pixel_array

    fig, (ax1, ax2) = plt.subplots(1, 2)
    fig.set_figheight(15)
    fig.set_figwidth(15)
    ax1.imshow(image_data, cmap=plt.cm.bone)
    ax2.imshow(image_data, cmap=plt.cm.bone)

    # Merge every run-length-encoded segment of this image into one mask.
    mask = np.zeros((1024, 1024))
    for encoded_pixels in record_arr["EncodedPixels"]:
        # ' -1' marks a row without any segment.
        if encoded_pixels != ' -1':
            mask_ = rle2mask(encoded_pixels, 1024, 1024).T
            mask[mask_ == 255] = 255

    ax2.imshow(mask, alpha=0.3, cmap="Blues")
    break

## Khởi tạo kiến trúc mô hình

### Các hàm trợ giúp

In [None]:
def toTensor(np_array, axis=(2,0,1)):
    """Copy array data into a torch tensor and reorder its dimensions.

    Args:
        np_array: Data accepted by ``torch.tensor``.
        axis: Dimension order handed to ``Tensor.permute``. The default
            ``(2, 0, 1)`` moves the last axis of a 3-D input to the front.

    Returns:
        The permuted tensor.
    """
    tensor = torch.tensor(np_array)
    return tensor.permute(axis)

def toNumpy(tensor, axis=(1,2,0)):
    """Convert a tensor into a NumPy array with reordered dimensions.

    The tensor is detached from the autograd graph and moved to the CPU
    before the conversion.

    Args:
        tensor: Source ``torch.Tensor``.
        axis: Dimension order handed to ``Tensor.permute``. The default
            ``(1, 2, 0)`` moves the first axis of a 3-D input to the end.

    Returns:
        The permuted data as a ``numpy.ndarray``.
    """
    host_tensor = tensor.detach().cpu()
    return host_tensor.permute(axis).numpy()

### Tạo Data Loader

In [None]:
# Show the first training image that has two or more annotation rows:
# the raw X-ray on the left, the X-ray with the union of its masks on the right.
# NOTE(review): this cell performs the same visualization as the earlier
# "Visualize image mask" cell of this notebook.
#
# Rows are grouped by ImageId up front (sort=False keeps first-appearance
# order), so a DICOM file is read only for the image that is actually plotted
# rather than for every image inspected on the way there.
for image_id, record_arr in train_df.groupby("ImageId", sort=False):
    # Only images with several annotation rows are of interest here.
    if len(record_arr) < 2:
        continue

    # The ImageId is matched as a substring of the file path.
    matching_paths = [path for path in file_paths if image_id in path]
    dataset = pydicom.dcmread(matching_paths[0])
    image_data = dataset.pixel_array

    fig, (ax1, ax2) = plt.subplots(1, 2)
    fig.set_figheight(15)
    fig.set_figwidth(15)
    ax1.imshow(image_data, cmap=plt.cm.bone)
    ax2.imshow(image_data, cmap=plt.cm.bone)

    # Merge every run-length-encoded segment of this image into one mask.
    mask = np.zeros((1024, 1024))
    for encoded_pixels in record_arr["EncodedPixels"]:
        # ' -1' marks a row without any segment.
        if encoded_pixels != ' -1':
            mask_ = rle2mask(encoded_pixels, 1024, 1024).T
            mask[mask_ == 255] = 255

    ax2.imshow(mask, alpha=0.3, cmap="Blues")
    break

In [None]:
def get_infor(df):
    """Build one record per unique image listed in *df*.

    Each record pairs an image with its DICOM path and all of its encoded
    segments. The annotation rows are taken from *df* itself, so the
    validation table yields validation masks (the rows used to be looked up
    in the global ``train_df`` regardless of the argument).

    Relies on the module-level ``file_paths`` list; an ImageId is matched
    as a substring of a path and the first matching path is used.

    Args:
        df: DataFrame with ``ImageId`` and ``EncodedPixels`` columns.

    Returns:
        A list of dicts with keys ``"key"`` (the ImageId), ``"file_path"``
        (the matching DICOM path) and ``"mask"`` (list of the image's
        ``EncodedPixels`` values), in order of first appearance in *df*.

    Raises:
        IndexError: If no entry of ``file_paths`` contains an ImageId.
    """
    infor = []
    # sort=False keeps the groups in first-appearance order.
    for image_id, record_arr in tqdm(df.groupby("ImageId", sort=False)):
        full_image_path = next(
            (path for path in file_paths if image_id in path), None
        )
        if full_image_path is None:
            raise IndexError(f"no DICOM path contains ImageId {image_id!r}")

        infor.append({
            "key": image_id,
            "file_path": full_image_path,
            # Every encoded segment annotated for this image.
            "mask": record_arr["EncodedPixels"].tolist(),
        })
    return infor

# Build the per-image records (key, file path, encoded masks) for each split;
# the prints label the progress bars emitted by get_infor.
print("Loading information for training set \n")
train_infor = get_infor(train_df)
print("Loading information for validation set \n")
val_infor = get_infor(val_df)

In [None]:
import cv2

class MaskDataset(Dataset):
    """Dataset yielding ``(image, mask)`` tensor pairs read from DICOM files.

    Every item is resized to 512x512, optionally augmented, divided by 255
    and returned as float tensors with a leading channel axis.

    Args:
        df: Annotation table; stored on the instance but not read by this class.
        img_info: Sequence of dicts with ``"file_path"`` (DICOM path) and
            ``"mask"`` (list of run-length-encoded segments).
        transforms: Optional callable invoked as
            ``transforms(image=..., mask=...)`` that returns a mapping with
            the same two keys.
    """

    def __init__(self, df, img_info, transforms=None):
        self.df = df
        self.img_info = img_info
        self.transforms = transforms

    def __getitem__(self, idx):
        """Load, resize, augment and tensorize the sample at position *idx*."""
        record = self.img_info[idx]

        # load image data
        img = pydicom.dcmread(record["file_path"]).pixel_array
        img = cv2.resize(img, dsize=(512, 512), interpolation=cv2.INTER_CUBIC)

        # Union of all encoded segments of this image.
        mask = np.zeros((512, 512))
        for item in record["mask"]:
            # "-1" (stored with a leading space in the CSV) means "no segment".
            if str(item).strip() == "-1":
                continue
            mask_ = rle2mask(item, 1024, 1024).T
            # Nearest-neighbour keeps the mask strictly binary. Cubic
            # interpolation blends 0/255 along region borders, and the
            # `== 255` test below then silently dropped those pixels.
            mask_ = cv2.resize(mask_, dsize=(512, 512), interpolation=cv2.INTER_NEAREST)
            mask[mask_ == 255] = 255

        if self.transforms:
            sample = self.transforms(image=img, mask=mask)
            img = sample["image"]
            mask = sample["mask"]

        # to Tensor: add a channel axis, scale by 1/255, move channels first.
        mask = np.expand_dims(mask, axis=-1) / 255.0
        mask = toTensor(mask).float()

        # NOTE(review): dividing by 255 presumes 8-bit pixel data — confirm.
        img = np.expand_dims(img, axis=-1) / 255.0
        img = toTensor(img).float()

        return img, mask

    def __len__(self):
        """Return the number of image records."""
        return len(self.img_info)

In [None]:
# Augmentation pipeline for training samples: a horizontal flip, one optional
# intensity change, one optional geometric distortion, then a random
# shift/scale/rotate.
# NOTE(review): RandomContrast / RandomBrightness and some of the keyword
# arguments below may be deprecated or missing in newer albumentations
# releases — confirm against the installed version.
_intensity_jitter = A.OneOf([
    A.RandomContrast(),
    A.RandomGamma(),
    A.RandomBrightness(),
], p=0.3)
_geometric_warp = A.OneOf([
    A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
    A.GridDistortion(),
    A.OpticalDistortion(distort_limit=2, shift_limit=0.5),
], p=0.3)
train_transform = A.Compose([
    A.HorizontalFlip(),
    _intensity_jitter,
    _geometric_warp,
    A.ShiftScaleRotate(),
])

# Training samples are augmented; validation samples are left untouched.
train_dataset = MaskDataset(train_df, train_infor, train_transform)
val_dataset = MaskDataset(val_df, val_infor)

# Both loaders serve single-sample batches in shuffled order.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True, drop_last=True)

# Preview the first five (augmented) training samples: the image alone on the
# left, the image with its mask overlaid on the right.
for number_visualize, (img, mask) in enumerate(train_dataset, start=1):
    if number_visualize > 5:
        break

    # Drop the channel axis again for plotting.
    img = toNumpy(img)[..., 0]
    mask = toNumpy(mask)[..., 0]

    fig, (ax1, ax2) = plt.subplots(1, 2)
    fig.set_size_inches(15, 15)
    ax1.imshow(img, cmap=plt.cm.bone)
    ax2.imshow(img, cmap=plt.cm.bone)
    ax2.imshow(mask, alpha=0.3, cmap="Blues")


### Định nghĩa hàm đo dice score và tính toán dice loss

Đọc trong: [jeremyjordan semantic-segmentation](https://www.jeremyjordan.me/semantic-segmentation/)

In [None]:
!pip install torchsummary
import os
import cv2
import pandas
import numpy as np
from tqdm import tqdm
from glob import glob

import matplotlib
import matplotlib.pyplot as plt
from random import choice, choices, shuffle

import torch
import torchvision
import torch.nn as nn 
from torch.utils.data import DataLoader, Dataset

from torchsummary import summary
import torchvision.models as models
import torchvision.transforms as T
from sklearn.model_selection import train_test_split
from random import randint
import albumentations as A
from PIL import Image


In [None]:
def Deconv(n_input, n_output, k_size=4, stride=2, padding=1):
    """Build a transposed-convolution block.

    The block is ``ConvTranspose2d`` (without bias) -> ``BatchNorm2d`` ->
    ``LeakyReLU`` (default slope, in place).

    Args:
        n_input: Number of input channels.
        n_output: Number of output channels.
        k_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.
        padding: Padding of the transposed convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    return nn.Sequential(
        nn.ConvTranspose2d(
            n_input,
            n_output,
            kernel_size=k_size,
            stride=stride,
            padding=padding,
            bias=False,
        ),
        nn.BatchNorm2d(n_output),
        nn.LeakyReLU(inplace=True),
    )
        

def Conv(n_input, n_output, k_size=4, stride=2, padding=0, bn=False, dropout=0):
    """Build a convolution block.

    The block is ``Conv2d`` (without bias) -> ``BatchNorm2d`` ->
    ``LeakyReLU(0.2)`` (in place) -> ``Dropout``.

    Args:
        n_input: Number of input channels.
        n_output: Number of output channels.
        k_size: Kernel size of the convolution.
        stride: Stride of the convolution.
        padding: Padding of the convolution.
        bn: Accepted but not consulted; batch normalization is always added.
        dropout: Drop probability of the trailing ``Dropout`` layer.

    Returns:
        An ``nn.Sequential`` holding the four layers in that order.
    """
    return nn.Sequential(
        nn.Conv2d(
            n_input,
            n_output,
            kernel_size=k_size,
            stride=stride,
            padding=padding,
            bias=False,
        ),
        nn.BatchNorm2d(n_output),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Dropout(dropout),
    )

def get_layer_efficientnet(efficientnet, start_index, end_index, x):
    """Feed *x* through a consecutive range of the network's blocks.

    Args:
        efficientnet: Object exposing an indexable ``_blocks`` collection of
            callables.
        start_index: Index of the first block to apply (inclusive).
        end_index: Index at which to stop (exclusive).
        x: Input handed to the first block.

    Returns:
        The output of the last applied block, or *x* unchanged when the
        range is empty.
    """
    for block_index in range(start_index, end_index):
        block = efficientnet._blocks[block_index]
        x = block(x)
    return x

class Unet(nn.Module):
    """U-Net style segmentation network built on an EfficientNet encoder.

    The encoder is the given network's stem plus its first 16 blocks; the
    output of each encoder stage is kept as a skip connection. The decoder
    alternates a ``Deconv`` block, a channel-wise concatenation with the
    matching skip tensor and a 1x1 ``Conv`` block, and ends with a sigmoid.

    Args:
        efficientnet: Encoder exposing ``_conv_stem``, ``_bn0``, ``_blocks``,
            ``_conv_head``, ``_bn1``, ``_avg_pooling`` and ``_dropout``.
            NOTE(review): the hard-coded channel counts below presumably
            match efficientnet-b0 — confirm before using another variant.
    """

    def __init__(self, efficientnet):
        super().__init__()
        
        # Encoder: keep the whole network (its _blocks are used in forward)
        # and expose the stem layers directly.
        self.efficientnet = efficientnet
        self._conv_stem = efficientnet._conv_stem
        self._bn0 = efficientnet._bn0
        
        # NOTE(review): the head/pooling/dropout layers and tanh are stored
        # here but never called in forward().
        self._conv_head = efficientnet._conv_head
        self._bn1 = efficientnet._bn1
        self._avg_pooling = efficientnet._avg_pooling
        self._dropout = efficientnet._dropout
        
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        
        
        # 1x1 convolutions applied after each skip concatenation to reduce
        # the channel count (and with it the number of weights).
        self.conv_6 = Conv(384, 192, 1, 1, 0)
        self.conv_5 = Conv(160, 80, 1, 1, 0)
        self.conv_4 = Conv(80, 40, 1, 1, 0)
        self.conv_3 = Conv(48, 24, 1, 1, 0)
        self.conv_2 = Conv(32, 16, 1, 1, 0)
        self.conv_1 = Conv(64, 32, 1, 1, 0)
        self.conv_0 = Conv(1, 1, 1, 1, 0)
        
        # Transposed-convolution layers of the decoder, deepest first.
        self.deconv6 = Deconv(320, 192, 3, 1, 1)
        self.deconv5 = Deconv(192, 80, 3, 2, 0)
        self.deconv4 = Deconv(80, 40, 5, 2, 1)
        self.deconv3 = Deconv(40, 24, 5, 2, 1)
        self.deconv2 = Deconv(24, 16, 5, 2, 1)
        self.deconv1 = Deconv(16, 32, 1, 1, 0)
        self.deconv0 = Deconv(32, 1, 4, 2, 0)
        
        
    def forward(self, x):
        """Run the encoder, then decode with skip connections.

        Args:
            x: Input batch. NOTE(review): presumably (batch, 1, 512, 512),
                as used elsewhere in this notebook — confirm.

        Returns:
            The sigmoid-activated output of the last decoder stage.
        """
        
        # Encoder: record the activation after each stage for the skips.
        x = self._conv_stem(x)
        x = self._bn0(x)
        skip_1 = x
        
        x = get_layer_efficientnet(self.efficientnet, 0, 1, x)
        skip_2 = x
        x = get_layer_efficientnet(self.efficientnet, 1, 3, x)
        skip_3 = x
        x = get_layer_efficientnet(self.efficientnet, 3, 5, x)
        skip_4 = x
        
        x = get_layer_efficientnet(self.efficientnet, 5, 8, x)
        skip_5 = x
        
        x = get_layer_efficientnet(self.efficientnet, 8, 15, x)
        skip_6 = x
        
        # Deepest encoder features (blocks 15..15), input to the decoder.
        x = get_layer_efficientnet(self.efficientnet, 15, 16, x)
        x7 = x
        
        # Decoder: deconv -> concatenate skip along channels -> 1x1 conv.
        x6 = self.deconv6(x7)
        x6 = torch.cat([x6, skip_6], dim=1)
        x6 = self.conv_6(x6)
        
        x5 = self.deconv5(x6)
        x5 = torch.cat([x5, skip_5], dim=1)
        x5 = self.conv_5(x5)
        
        x4 = self.deconv4(x5)
        x4 = torch.cat([x4, skip_4], dim=1)
        x4 = self.conv_4(x4)
        
        x3 = self.deconv3(x4)
        x3 = torch.cat([x3, skip_3], dim=1)
        x3 = self.conv_3(x3)
        
        x2 = self.deconv2(x3)
        x2 = torch.cat([x2, skip_2], dim=1)
        x2 = self.conv_2(x2)
        
        x1 = self.deconv1(x2)
        x1 = torch.cat([x1, skip_1], dim=1)
        x1 = self.conv_1(x1)
        
        # Final stage has no skip connection; sigmoid maps to (0, 1).
        x0 = self.deconv0(x1)
        x0 = self.conv_0(x0)
        x0 = self.sigmoid(x0)
        return x0

In [None]:
!pip install torchsummary
from torchsummary import summary

# Make the bundled EfficientNet-PyTorch sources importable.
package_path = '../input/efficientnet/EfficientNet-PyTorch-master/'
sys.path.append(package_path)
from efficientnet_pytorch import EfficientNet

# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build efficientnet-b0 and load its weights from the checkpoint file.
model_ft = EfficientNet.from_name('efficientnet-b0')
model_ft.load_state_dict(torch.load('../input/efficientnet-pytorch/efficientnet-b0-08094119.pth'))
# Replace the stem with a 1-input-channel convolution. It is assigned after
# load_state_dict, so this layer keeps its fresh initialization.
model_ft._conv_stem = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)

# Wrap the encoder in the U-Net and move everything to the chosen device.
model_ft.to(device)
model = Unet(model_ft)
model.to(device)

# for i, child in enumerate(model.children()):
#     if i <= 7:
#         for param in child.parameters():
#             param.requires_grad = False

# NOTE(review): summary() presumably prints the table itself, in which case
# the outer print() only adds its return value — confirm.
print(summary(model,input_size=(1,512,512)))

In [None]:
# Focal-loss hyper-parameters.
# NOTE(review): the loss functions visible in this notebook do not read
# these names; get_focal_loss carries its own alpha/gamma defaults.
ALPHA = 0.8
GAMMA = 2
import torch.nn.functional as F

def dice_score(inputs, targets, smooth=1):
    """Compute the smoothed Dice coefficient between two tensors.

    Both tensors are flattened and the score is
    ``(2 * sum(inputs * targets) + smooth) / (sum(inputs) + sum(targets) + smooth)``.

    Args:
        inputs: Prediction tensor.
        targets: Label tensor with the same number of elements.
        smooth: Constant added to numerator and denominator.

    Returns:
        A scalar tensor holding the Dice coefficient.
    """
    flat_inputs = inputs.view(-1)
    flat_targets = targets.view(-1)

    overlap = torch.sum(flat_inputs * flat_targets)
    numerator = 2. * overlap + smooth
    denominator = flat_inputs.sum() + flat_targets.sum() + smooth
    return numerator / denominator

def get_dice_loss(inputs, targets, smooth=1):
    """Compute the Dice loss, i.e. one minus the smoothed Dice coefficient.

    Args:
        inputs: Prediction tensor.
        targets: Label tensor with the same number of elements.
        smooth: Constant added to the numerator and denominator of the
            Dice coefficient.

    Returns:
        A scalar tensor holding ``1 - dice``.
    """
    flat_inputs = inputs.view(-1)
    flat_targets = targets.view(-1)

    overlap = torch.sum(flat_inputs * flat_targets)
    numerator = 2. * overlap + smooth
    denominator = flat_inputs.sum() + flat_targets.sum() + smooth
    return 1 - numerator / denominator

def get_focal_loss(inputs, targets, alpha=0.8, gamma=2, smooth=1):
    """Compute the mean focal loss for probability predictions.

    The modulating factor ``(1 - p_t) ** gamma`` is applied to every element
    before averaging. Applying it to the already averaged cross-entropy, as
    this function used to do, scales the whole batch by a single factor and
    therefore cannot down-weight easy elements relative to hard ones.

    Args:
        inputs: Predicted probabilities in [0, 1].
        targets: Labels with the same number of elements, as floats.
        alpha: Constant weight applied to every element.
        gamma: Focusing exponent; 0 reduces the loss to ``alpha * BCE``.
        smooth: Unused; kept so existing calls keep working.

    Returns:
        A scalar tensor holding the mean focal loss.
    """
    # flatten label and prediction tensors
    inputs = inputs.view(-1)
    targets = targets.view(-1)

    # Per-element cross-entropy; exp(-BCE) is the probability the model
    # assigned to the true class of that element.
    bce = F.binary_cross_entropy(inputs, targets, reduction='none')
    p_true = torch.exp(-bce)
    focal_loss = alpha * (1 - p_true) ** gamma * bce

    return focal_loss.mean()

def combo_loss(inputs, targets):
    """Combine Dice, focal and binary cross-entropy losses.

    Args:
        inputs: Predicted probabilities.
        targets: Labels with the same shape as *inputs*.

    Returns:
        ``1 * dice_loss + 4 * focal_loss + 3 * BCE`` as a scalar tensor.
    """
    dice_term = get_dice_loss(inputs, targets)
    focal_term = get_focal_loss(inputs, targets)
    bce_term = F.binary_cross_entropy(inputs, targets, reduction='mean')

    return 1*dice_term + 4*focal_term + 3*bce_term

In [None]:
def dice_score_validation(inputs, targets, smooth=1):
    """Compute the smoothed Dice coefficient for evaluation.

    When prediction and label are both entirely zero the score is defined
    as 1, which also covers ``smooth=0`` where the formula would be 0/0.
    The check uses torch operations only, so tensors on any device are
    accepted, and every code path returns a tensor so callers can rely on
    ``.item()``.

    Args:
        inputs: Prediction tensor.
        targets: Label tensor with the same number of elements.
        smooth: Constant added to numerator and denominator.

    Returns:
        A scalar tensor holding the Dice coefficient.
    """
    # flatten label and prediction tensors
    inputs = inputs.view(-1)
    targets = targets.view(-1)

    # Empty prediction for an empty label counts as a perfect match.
    if bool((inputs == 0).all()) and bool((targets == 0).all()):
        return torch.ones((), dtype=inputs.dtype, device=inputs.device)

    intersection = (inputs * targets).sum()
    return (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

In [None]:
# import csv

# train_params = [param for param in model.parameters() if param.requires_grad]
# optimizer = torch.optim.Adam(train_params, lr=0.001, betas=(0.9, 0.99))

# start_epochs = 281
# end_epochs = 381
# path_model = "../input/segment-data/model_unet_efficienet_epoch_280.0_train_dice_score_0.5216_val_dice_score_0.3641.pth"
# model.load_state_dict(torch.load(path_model))
# model.train()
# saved_dir = "model"
# os.makedirs(saved_dir, exist_ok=True)
# path_csv_history = "./model/history.csv"

# with open(path_csv_history, mode='w') as csv_file:
#     fieldnames = ['epoch', 'train_dice_score', 'val_dice_score']
#     writer = csv.DictWriter(csv_file, fieldnames=fieldnames)

#     writer.writeheader()
#     for epoch in range(start_epochs, end_epochs):
#         number_iter = 1
#         train_dice_score = []
#         with tqdm(train_loader, unit="batch") as tepoch:
#             for imgs, masks in tepoch:
#                 tepoch.set_description(f"Epoch {epoch}, interation {number_iter}")
#                 optimizer.zero_grad()
#                 imgs_gpu = imgs.to(device)
#                 outputs = model(imgs_gpu)
#                 # outputs = torch.sigmoid(outputs)
#                 masks = masks.to(device)

#                 dice_scores = dice_score(outputs, masks)
#                 loss = combo_loss(outputs, masks)
                
#                 loss.backward()
#                 optimizer.step()
#                 tepoch.set_postfix(loss=loss.item(), 
#                                    dice_score=dice_scores.item())
#                 number_iter += 1
                
#         train_dice_score = []
#         train_losses = []
#         with torch.no_grad():
#             for imgs, masks in tqdm(train_loader):
#                 imgs_gpu = imgs.to(device)
#                 outputs = model(imgs_gpu)
#                 # outputs = torch.sigmoid(outputs) 
#                 masks = masks.to(device)
                
#                 dice_scores = dice_score(outputs, masks)
#                 loss = combo_loss(outputs, masks)
                
#                 train_dice_score.extend([dice_scores.item()])
#                 train_losses.extend([loss.item()])
#             train_dice_score = np.mean(np.array(train_dice_score))
#             train_losses = np.mean(np.array(train_losses))
        
#         val_dice_score = []
#         val_losses = []
#         with torch.no_grad():
#             for imgs, masks in tqdm(val_loader):
#                 imgs_gpu = imgs.to(device)
#                 outputs = model(imgs_gpu)
#                 # outputs = torch.sigmoid(outputs) 
#                 masks = masks.to(device)
                
#                 dice_scores = dice_score(outputs, masks)
#                 loss = combo_loss(outputs, masks)
                
#                 val_dice_score.extend([dice_scores.item()])
#                 val_losses.extend([loss.item()])
#             val_dice_score = np.mean(np.array(val_dice_score))
#             val_losses = np.mean(np.array(val_losses))
#             print(f"Epoch {epoch}, Train loss {train_losses:0.4f}, train dice score: {train_dice_score:0.4f}, Validation loss {val_losses:0.4f}, validation dice score: {val_dice_score:0.4f}")
#             path = "./model/model_unet_efficienet_epoch_{:.1f}_train_dice_score_{:0.4f}_val_dice_score_{:0.4f}.pth".format(epoch, train_dice_score, val_dice_score)
#             torch.save(model.state_dict(), path)
        
#             writer.writerow({
#                 "epoch": epoch,
#                 "train_dice_score": train_dice_score,
#                 "val_dice_score": val_dice_score
#             })

In [None]:
# Test model: load a saved checkpoint and report the mean dice score on the
# training and validation loaders.
import csv

# NOTE(review): the optimizer is built here but not used by this evaluation.
train_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(train_params, lr=0.001, betas=(0.9, 0.99))

path_model = "../input/segment-data/model_unet_efficienet_epoch_280.0_train_dice_score_0.5216_val_dice_score_0.3641.pth"
# map_location lets a checkpoint saved from a GPU run load on a CPU-only host.
model.load_state_dict(torch.load(path_model, map_location=device))
# Evaluation mode: normalization layers use their stored statistics and are
# not updated while scoring.
# NOTE(review): the scores embedded in the checkpoint file name may have been
# measured without eval(), so the numbers printed here can differ.
model.eval()


def _mean_dice(loader):
    """Return the mean per-batch validation dice score over *loader*."""
    scores = []
    with torch.no_grad():
        for imgs, masks in tqdm(loader):
            outputs = model(imgs.to(device))
            # Score on the CPU so the metric works wherever the model runs;
            # float() accepts both tensor and plain-number results.
            scores.append(float(dice_score_validation(outputs.cpu(), masks)))
    return np.mean(np.array(scores))


train_dice_score = _mean_dice(train_loader)
val_dice_score = _mean_dice(val_loader)
print(f"train dice score: {train_dice_score:0.4f}, validation dice score: {val_dice_score:0.4f}")


In [None]:
# visualize: for validation samples with a non-empty mask, show the image,
# the image with the rounded prediction, and the image with the true mask.
# At most 100 figures are produced.
number = 0
# Evaluation mode so normalization layers use their stored statistics.
model.eval()
with torch.no_grad():
    for imgs, masks in val_loader:
        imgs_gpu = imgs.to(device)
        outputs = model(imgs_gpu)
        # Round the probabilities to {0, 1} and scale to 0/255 for display.
        outputs = torch.round(outputs) * 255
        masks = masks.to(device)
        # Walk over the samples actually present in this batch; a fixed
        # range(10) raised IndexError for smaller batches.
        for index in range(imgs_gpu.shape[0]):
            if number >= 100:
                break
            img_origin = np.reshape(imgs_gpu[index].cpu().numpy(), (512, 512))
            pred_ = np.reshape(outputs[index].cpu().numpy(), (512, 512))
            mask_ = np.reshape(masks[index].cpu().numpy()*255, (512, 512))
            # Skip samples whose true mask is empty.
            if np.all(mask_==0):
                continue
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
            fig.set_figheight(15)
            fig.set_figwidth(15)
            ax1.imshow(img_origin, cmap=plt.cm.bone)
            ax2.imshow(img_origin, cmap=plt.cm.bone)
            ax2.imshow(pred_, alpha=0.3, cmap="Blues")
            ax3.imshow(img_origin, cmap=plt.cm.bone)
            ax3.imshow(mask_, alpha=0.3, cmap="Blues")
            number += 1
        # >= rather than ==, so the limit cannot be stepped over.
        if number >= 100:
            break