# 肺野セグメンテーション
参考 [Image segmentation  |  TensorFlow Core](https://www.tensorflow.org/tutorials/images/segmentation)

## 前準備
### 主要パッケージのインポート

In [None]:
import pathlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from IPython.display import display
from logging import basicConfig, getLogger, INFO
# Configure root logging at INFO level with a "time level :message" format.
basicConfig(level=INFO, format='%(asctime)s %(levelname)s :%(message)s')
# Logger used by the cells of this notebook.
logger = getLogger(__name__)

### データディレクトリの指定

In [None]:
# Root directory that holds the image and label sub-directories.
DATA_ROOT = pathlib.Path('Data/Images/chest_xray')
# Sub-directory (below DATA_ROOT) with the input images.
IMAGE_DIR = 'regular'
# Sub-directory (below DATA_ROOT) with the label (mask) images.
LABEL_DIR = 'lung'
# The trailing comma makes this a one-element tuple; ('lung') is only the
# string 'lung', so iterating over it would yield single characters.
CLASS_LABELS = ('lung',)
# Only files ending with this extension are collected as images.
IMAGE_EXT = '.png'

### 画像ファイルを基にpd.DataFrameを作成する

In [None]:
def create_dataset_df(data_root, image_dir, label_dir, image_ext):
    """Build a table that pairs every input image with its label image.

    Args:
        data_root: Dataset root directory (str or path-like).
        image_dir: Sub-directory of ``data_root`` that holds the input images.
        label_dir: Sub-directory of ``data_root`` that holds the label images.
        image_ext: File extension, including the dot, of the images to list.

    Returns:
        pd.DataFrame with the columns ``image_path`` and ``label_path``.
        Each label path has the same file name as its image but lives in
        ``label_dir``; whether that file exists is not checked. Rows are
        sorted by image path so the result does not depend on the order in
        which the file system lists the directory.
    """
    root = pathlib.Path(data_root)
    label_root = root / pathlib.Path(label_dir)
    # glob() yields entries in arbitrary order; sort for reproducibility.
    image_paths = sorted(
        (root / pathlib.Path(image_dir)).glob('*' + image_ext))
    return pd.DataFrame({
        'image_path': image_paths,
        'label_path': [label_root / p.name for p in image_paths],
    })


# Index every image below DATA_ROOT / IMAGE_DIR together with its label path.
df_dataset = create_dataset_df(DATA_ROOT, IMAGE_DIR, LABEL_DIR, IMAGE_EXT)
display(df_dataset)

### 画像を表示
入力画像を背景にセグメンテーションを重畳表示する

In [None]:
# Opacity factor of the mask overlay.
OVERLAY_ALPHA = 0.5
# RGBA lookup table indexed by the binary label: row 0 is fully transparent,
# row 1 is red with alpha 255 * OVERLAY_ALPHA (truncated by the uint8 cast).
cmap = np.array([[0, 0, 0, 0], [255, 0, 0,
                                255 * OVERLAY_ALPHA]]).astype(np.uint8)

# Number of randomly drawn image/label pairs shown side by side.
N_SAMPLES = 5
plt.figure(figsize=(10, 5))
for i, sample in enumerate(df_dataset.sample(n=N_SAMPLES).itertuples()):
    image = Image.open(sample.image_path).convert('RGBA')
    label = np.array(Image.open(sample.label_path).convert('L'))
    # Every non-zero label pixel counts as foreground.
    label = (label > 0).astype(np.uint8)
    # Map each pixel through the lookup table to get an RGBA overlay image.
    label = Image.fromarray(cmap[label])
    plt.subplot(1, N_SAMPLES, i + 1)
    plt.imshow(Image.alpha_composite(image, label))
    plt.axis('off')
plt.tight_layout()
plt.show()

### ホールドアウト
学習に時間がかかるため、今回は交差検証は行わず、KFoldの最初の分割のみを用いて学習用データとテスト用データに分ける。

In [None]:
from sklearn.model_selection import KFold
# Number of folds requested from KFold; only the first split is used below.
N_SPLITS = 5
# NOTE(review): no random_state is given, so the shuffled split is not
# reproducible between runs.
kfold = KFold(n_splits=N_SPLITS, shuffle=True)
# Take the first (train, test) index pair instead of iterating over all folds.
train_index, test_index = next(kfold.split(df_dataset['image_path']))

df_train = df_dataset.iloc[train_index]
df_test = df_dataset.iloc[test_index]
print('training:', len(df_train), 'test:', len(df_test))

## モデル作成
U-Net

In [None]:
# Passed to UNet as ``in_chs`` (channels of the network input).
IN_CHS = 1
# Passed to UNet as ``out_chs`` (channels of the network output).
OUT_CHS = 1
# Passed to UNet as ``depth``.
UNET_DEPTH = 4

import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """2-D convolution followed by batch normalization and an in-place ReLU."""

    def __init__(self, in_chs, out_chs, kernel_size, padding):
        """Create the three layers.

        Args:
            in_chs: Number of input channels of the convolution.
            out_chs: Number of output channels of the convolution (also the
                number of features of the batch normalization).
            kernel_size: Kernel size passed to ``nn.Conv2d``.
            padding: Padding passed to ``nn.Conv2d``.
        """
        super().__init__()
        layers = [
            nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size,
                      padding=padding),
            nn.BatchNorm2d(out_chs),
            nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` after convolution, batch normalization and ReLU."""
        return self.block(x)


class Encoder(nn.Module):
    """Two stacked ConvBlocks mapping in_chs -> 2 * in_chs -> out_chs."""

    def __init__(self, in_chs: int, out_chs: int, kernel_size: int,
                 padding: int):
        """Create both ConvBlocks.

        Args:
            in_chs: Number of input channels.
            out_chs: Number of output channels.
            kernel_size: Kernel size forwarded to both ConvBlocks.
            padding: Padding forwarded to both ConvBlocks.
        """
        super().__init__()
        # The intermediate width is always twice the input width.
        hidden_chs = 2 * in_chs
        first = ConvBlock(in_chs, hidden_chs, kernel_size, padding)
        second = ConvBlock(hidden_chs, out_chs, kernel_size, padding)
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        """Return ``x`` after both ConvBlocks."""
        return self.block(x)


class Decoder(nn.Module):
    """Upsample, concatenate with a skip tensor, then apply two ConvBlocks."""

    def __init__(self, in_chs, out_chs, kernel_size, padding,
                 scale_factor: int):
        """Create the upsampling layer and the two ConvBlocks.

        Args:
            in_chs: Channel count of the concatenated tensor fed to the
                first ConvBlock.
            out_chs: Number of output channels.
            kernel_size: Kernel size forwarded to both ConvBlocks.
            padding: Padding forwarded to both ConvBlocks.
            scale_factor: Factor of the bilinear upsampling.
        """
        super().__init__()
        self.up = nn.Upsample(scale_factor=scale_factor,
                              mode='bilinear',
                              align_corners=True)
        # The intermediate width is the integer mean of input and output.
        hidden_chs = (in_chs + out_chs) // 2
        stages = [
            ConvBlock(in_chs, hidden_chs, kernel_size, padding),
            ConvBlock(hidden_chs, out_chs, kernel_size, padding),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x1, x2):
        """Upsample ``x1``, join it with ``x2`` along dim 1 and convolve."""
        upsampled = self.up(x1)
        merged = torch.cat((upsampled, x2), dim=1)
        return self.block(merged)


class UNet(nn.Module):
    """U-Net built from Encoder and Decoder stages with skip connections."""

    def __init__(self,
                 in_chs,
                 out_chs,
                 depth,
                 kernel_size=3,
                 padding=1,
                 scale_factor=2):
        """Create the encoder, pooling, decoder and output layers.

        Args:
            in_chs: Number of channels of the network input.
            out_chs: Number of channels produced by the final 1x1 convolution.
            depth (int): Number of encoder stages. Max pooling is applied
                after every stage except the last one, and one decoder is
                created per pooling, i.e. ``depth - 1`` of each.
            kernel_size: Kernel size forwarded to every Encoder / Decoder.
            padding: Padding forwarded to every Encoder / Decoder.
            scale_factor: Pooling window and decoder upsampling factor.
        """
        super().__init__()
        self.depth = depth
        base_width = 16

        # Encoder stage ``level`` outputs base_width * 2**(level + 1)
        # channels; each stage consumes the output of the previous one.
        self.encs = nn.ModuleList()
        self.pools = nn.ModuleList()
        enc_in = in_chs
        for level in range(depth):
            enc_out = base_width * 2**(level + 1)
            self.encs.append(Encoder(enc_in, enc_out, kernel_size, padding))
            if level != depth - 1:
                self.pools.append(nn.MaxPool2d(scale_factor))
            enc_in = enc_out

        # Each decoder receives the deeper feature map concatenated with a
        # skip tensor of half its width, and outputs the skip's width.
        self.decs = nn.ModuleList()
        for level in range(depth, 1, -1):
            deep_chs = base_width * 2**level
            skip_chs = deep_chs // 2
            self.decs.append(
                Decoder(deep_chs + skip_chs, skip_chs, kernel_size, padding,
                        scale_factor))

        self.output_layer = nn.Conv2d(base_width * 2,
                                      out_chs,
                                      kernel_size=1,
                                      padding=0)

    def forward(self, x):
        """Run the encoder path, then the decoder path, then the 1x1 conv.

        The output of every encoder stage except the last is kept as a skip
        tensor and handed to the decoders in reverse order.
        """
        bottom = self.depth - 1
        skips = []
        for level in range(self.depth):
            x = self.encs[level](x)
            if level < bottom:
                skips.append(x)
                x = self.pools[level](x)

        for dec, skip in zip(self.decs, reversed(skips)):
            x = dec(x, skip)

        return self.output_layer(x)

from torchsummary import summary
# Summarize a fresh UNet for a (1, 512, 512) input, evaluated on the CPU.
summary(UNet(IN_CHS, OUT_CHS, UNET_DEPTH), (1, 512, 512), verbose=0, device='cpu')

### ネットワーク構造の可視化

In [None]:
import torchviz
# Feed an all-zero dummy batch through a fresh UNet so that torchviz can
# draw the graph behind the resulting output tensor.
dummy_x = torch.zeros((1, 1, 512, 512), dtype=torch.float, requires_grad=False)
dummy_y = UNet(IN_CHS, OUT_CHS, UNET_DEPTH)(dummy_x)
dot = torchviz.make_dot(dummy_y)
# Render the graph as SVG.
dot.format = 'svg'
dot

### データの読み込み

In [None]:
import tqdm
# Target image shape in channel-first (C, H, W) order; load_img uses only
# the spatial part.
IMG_SHAPE = (3, 512, 512)


def load_img(filepath):
    """Load an image file and resize it to the spatial size of IMG_SHAPE.

    Args:
        filepath: Path of the image file to open with PIL.

    Returns:
        np.ndarray of the resized image with at least three dimensions;
        ``np.atleast_3d`` appends a channel axis to 2-D (single-channel)
        images.
    """
    _, height, width = IMG_SHAPE
    # PIL's resize expects (width, height), the reverse of IMG_SHAPE[1:].
    img = Image.open(filepath).resize((width, height))
    return np.atleast_3d(img)

def load_segmentation_dataset(df, load_img):
    """Load all images and labels listed in ``df`` into two stacked arrays.

    Args:
        df: Table with the columns ``image_path`` and ``label_path``.
        load_img: Callable that maps a file path to an image array.

    Returns:
        Tuple ``(data, labels)``. ``data`` is the stack of loaded images
        divided by 255 and cast to float32. ``labels`` is the stack of the
        first channel of every loaded label image divided by 255 (not cast).
    """
    images = np.stack(
        [load_img(path) for path in tqdm.tqdm(df['image_path'])])
    # Keep only the first channel of each label image (as a trailing axis).
    masks = np.stack(
        [load_img(path)[..., :1] for path in tqdm.tqdm(df['label_path'])])
    scaled_images = (images / 255).astype(np.float32)
    scaled_masks = masks / 255
    return scaled_images, scaled_masks

# Load every training image/label pair into memory up front.
train_data, train_labels = load_segmentation_dataset(df_train, load_img)

### Data augmentation
回転、左右反転等をランダムに適用する。

In [None]:
import albumentations as A

# Augmentation pipeline for image/mask pairs:
#   - horizontal flip, applied with probability 0.25
#   - ShiftScaleRotate, applied with probability 0.8, configured with
#     shift_limit=0, scale_limit=.1 and rotate_limit=5
album_transform = A.Compose([
    A.HorizontalFlip(p=.25),
    A.ShiftScaleRotate(shift_limit=0, scale_limit=.1, rotate_limit=5, p=.8)
])

class AugmentedSegmentationDataset(torch.utils.data.Dataset):
    """Dataset of (image, mask) pairs with an optional joint transform."""

    def __init__(self, x, y, transform=None):
        """Store the samples and the transform.

        Args:
            x: Indexable collection of channel-last images.
            y: Indexable collection of channel-last masks, aligned with ``x``.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping
                with the keys ``'image'`` and ``'mask'``.
        """
        self.xs = x
        self.ys = y
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.xs)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a channel-first float32 (image, mask)."""
        image = self.xs[idx]
        mask = self.ys[idx]
        if self.transform:
            # Image and mask go through the transform in one call.
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        # Move the channel axis from last to first.
        image_chw = image.transpose(2, 0, 1)
        mask_chw = mask.transpose(2, 0, 1)
        return image_chw.astype(np.float32), mask_chw.astype(np.float32)

# Number of leading training samples shown by test_augmentation.
N_TEST = 3


def test_augmentation(df_dataset):
    """Plot the first N_TEST augmented training samples next to their masks.

    Reads the module-level ``train_data``, ``train_labels`` and
    ``album_transform``; the ``df_dataset`` argument is not used.
    """
    preview_set = AugmentedSegmentationDataset(train_data[:N_TEST],
                                               train_labels[:N_TEST],
                                               album_transform)
    preview_loader = torch.utils.data.DataLoader(preview_set,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=0)
    for image, mask in preview_loader:
        plt.figure(figsize=(4, 1.5))
        panels = ((image, 'input image'), (mask, 'label image'))
        for column, (tensor, title) in enumerate(panels, start=1):
            plt.subplot(1, 2, column)
            plt.imshow(tensor.squeeze().numpy(), cmap='gray')
            plt.title(title)
            plt.axis('off')
        plt.show()


test_augmentation(df_dataset)

## 学習

### pytorch-lightning

In [None]:
import torch.optim as optim
import torch.nn.functional as F
import os

# DataLoader worker count: none on Windows (os.name == 'nt'), two elsewhere.
NUM_WORKERS = 0 if os.name == 'nt' else 2
# NOTE: BATCH_SIZE is assigned again (to 4) before the training DataLoader
# is built, so this value is not the one used for training.
BATCH_SIZE = 8

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping


class LitUNet(pl.LightningModule):
    """LightningModule that trains a UNet with a BCE-with-logits loss."""

    def __init__(self, in_chs, out_chs, depth):
        """Create the wrapped UNet from the given channel counts and depth."""
        super().__init__()
        self.model = UNet(in_chs, out_chs, depth)

    def forward(self, x):
        """Return the raw (pre-sigmoid) output of the wrapped UNet."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log (as 'train_loss') and return the loss of one batch."""
        inputs, targets = batch
        logits = self.model(inputs)
        bce = F.binary_cross_entropy_with_logits(logits, targets)
        self.log('train_loss', bce)
        return bce

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)


In [None]:
# Upper bound on the number of training epochs.
EPOCHS = 64
# Passed to EarlyStopping as ``patience``.
PATIENCE = 4
# Batch size of the training DataLoader below (replaces the earlier value).
BATCH_SIZE = 4

model = LitUNet(IN_CHS, OUT_CHS, UNET_DEPTH)
# Early stopping watches the logged 'train_loss' (lower is better); no
# validation loader is passed to fit() below.
early_stop_callback = EarlyStopping(monitor='train_loss',
                                    patience=PATIENCE,
                                    verbose=False,
                                    mode='min')
# Request one GPU when CUDA is available, otherwise none.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0,
                     max_epochs=EPOCHS,
                     callbacks=[early_stop_callback])

# Training samples pass through album_transform each time they are fetched.
dataset = AugmentedSegmentationDataset(train_data, train_labels, album_transform)
trainloader = torch.utils.data.DataLoader(dataset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=True,
                                          num_workers=NUM_WORKERS)
trainer.fit(model, trainloader)
logger.info('Finish training')

## 評価
Dice similarity coefficient(F1 score)とJaccard Index(IoU)を評価指標とする。

In [None]:
import math
from sklearn import metrics

# RGB colors indexed by class code: 0 black, 1 red, 2 green, 3 yellow.
label_cmap = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0], [255, 255, 0]])

test_data, test_labels = load_segmentation_dataset(df_test, load_img)

# Switch to inference mode and disable gradients for all parameters.
model.eval()
model.freeze()


# One (dice, jaccard) tuple per test image, in the order of df_test.
scores = []
for i, (data, label) in enumerate(zip(test_data, test_labels)):
    # Channel-last image -> channel-first batch of size one.
    data = data.transpose((2, 0, 1))[np.newaxis]
    # Labels are floats in [0, 1]. A plain uint8 cast would truncate every
    # value below 1.0 to background, so binarize with the same 0.5 threshold
    # that is applied to the prediction.
    label = (label > .5).astype(np.uint8).squeeze()
    with torch.no_grad():
        # Create the input on the model's device so that inference works
        # whether the model ended up on the CPU or on the GPU.
        inputs = torch.as_tensor(data,
                                 dtype=torch.float32,
                                 device=model.device)
        pred = torch.sigmoid(model(inputs))
        pred = pred.cpu().numpy().squeeze()
    pred_bin = (pred > .5).astype(np.uint8)
    scores.append((metrics.f1_score(label.ravel(), pred_bin.ravel()),
                   metrics.jaccard_score(label.ravel(), pred_bin.ravel())))
    if i < N_SAMPLES:
        # (title, image, colormap) of the five panels of one figure row.
        # In the comparison panel the code is pred + 2 * label:
        # 1 (red) predicted only, 2 (green) labelled only, 3 (yellow) both.
        panels = (
            ('input', data.squeeze(), 'gray'),
            ('result', pred, 'jet'),
            ('result label', label_cmap[[0, 1]][pred_bin], None),
            ('true label', label_cmap[[0, 2]][label], None),
            ('comparison', label_cmap[pred_bin + label * 2], None),
        )
        for col, (title, panel_img, panel_cmap) in enumerate(panels, start=1):
            plt.subplot(1, 5, col)
            plt.imshow(panel_img, cmap=panel_cmap)
            plt.title(title)
            plt.axis('off')
        plt.tight_layout()
        plt.show()

In [None]:
# Per-image scores, indexed by the row positions of the test split.
score_columns = ['dice coefficient', 'jaccard index']
df_score = pd.DataFrame(scores, columns=score_columns, index=test_index)
display(df_score.head())

# One row per metric, one column per summary statistic.
stat_names = ('median', 'mean', 'std', 'min', 'max')
display(
    pd.DataFrame({name: getattr(df_score, name)()
                  for name in stat_names}))

### 分布の確認
Dice similarity coefficientの分布を表示する。

#### ヒストグラム

In [None]:
import seaborn as sns
# Histogram of the per-image dice coefficients.
sns.histplot(x=df_score['dice coefficient'])
plt.show()

#### Boxplot

In [None]:
# Box plot of the dice coefficients with the individual scores drawn on top.
sns.boxplot(x='dice coefficient', data=df_score)
sns.swarmplot(x='dice coefficient', data=df_score, color='.2')
plt.show()

#### Letter value plot

In [None]:
# Letter-value plot of the dice coefficients; the individual scores are
# overlaid as white points with a black outline.
sns.boxenplot(x='dice coefficient', data=df_score)
sns.swarmplot(x='dice coefficient',
              data=df_score,
              color='white',
              edgecolor='black',
              linewidth=1)
plt.show()

#### Violinplot

In [None]:
# Violin plot of the dice coefficients without inner annotations; the
# individual scores are overlaid as white points with a gray outline.
sns.violinplot(x='dice coefficient', data=df_score, inner=None)
# edgecolor is only drawn when a non-zero linewidth is given, so set it
# explicitly, as in the letter-value plot.
sns.swarmplot(x='dice coefficient',
              data=df_score,
              color='white',
              edgecolor='gray',
              linewidth=1)
plt.show()