# びまん性肺疾患(4クラス, (train, validate, test))

## 前準備
### 主要パッケージを読み込む
loggerの設定も行う

In [None]:
import pathlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display
from logging import basicConfig, getLogger, INFO
# Configure root logging: INFO and above, each record prefixed with a timestamp and level.
basicConfig(level=INFO, format='%(asctime)s %(levelname)s :%(message)s')
# Module-level logger used by the later cells for progress messages.
logger = getLogger(__name__)

### データディレクトリの指定

In [None]:
# Root directory of the image data; passed to tut_utils.create_dataset_df below.
DATA_ROOT = pathlib.Path('Data/Images/LIDC_DLD')
# Names of the four target classes; also used as target_names / bar labels in the
# evaluation cells, so the order is presumably the integer class order (confirm in tut_utils).
CLASS_LABELS = ('normal', 'GGO', 'emphysema', 'honeycomb')
# File extension handed to create_dataset_df (presumably to select image files; confirm in tut_utils).
IMAGE_EXT = '.png'

### 画像ファイルを基にpd.DataFrameを作成する

In [None]:
import tut_utils
# Build the dataset table from the files under DATA_ROOT.
df_dataset = tut_utils.create_dataset_df(DATA_ROOT, CLASS_LABELS, IMAGE_EXT)

# Sanity check: the labels found in the table must be exactly CLASS_LABELS.
_labels_found = set(df_dataset['class_label'].unique())
assert _labels_found == set(CLASS_LABELS), \
    'Discrepancy between CLASS_LABELS and df_dataset'

display(df_dataset)

### クラス毎の画像数を確認する。

In [None]:
# Count rows per class label; as the last expression of the cell the result is shown by the notebook.
df_dataset['class_label'].value_counts()

### 各クラスの画像を表示してみる

In [None]:
import tut_utils
# Show sample images for every class (n_rows=1); the exact layout is decided inside tut_utils.
tut_utils.show_images_each_class(df_dataset, n_rows=1)

### データ読み込み用の関数を作成

In [None]:
from PIL import Image
# Network input shape, channels first: (channels, height, width).
IMG_SHAPE = (1, 32, 32)


def load_img(filepath):
    """Load one image file, resize it to the network input size, return an array.

    Args:
        filepath: Path of the image file to open with PIL.

    Returns:
        A numpy array with at least three dimensions; a single-channel image
        of the current IMG_SHAPE comes back as (height, width, 1).
    """
    # IMG_SHAPE is (channels, height, width) but PIL's resize expects
    # (width, height), so the two spatial sizes are passed explicitly.
    _, height, width = IMG_SHAPE
    # The context manager releases the file handle as soon as the pixels have
    # been read; resize() returns an independent in-memory image.
    with Image.open(filepath) as img:
        resized = img.resize((width, height))
    return np.atleast_3d(resized)

## Data augmentation
いくつかの画像に対して実際にaugmentationを適用し表示する

In [None]:
import torch
import albumentations as A
from tut_utils import AugmentedDataset, load_dataset

# Augmentation pipeline applied to training images by random_transform:
# brightness/contrast jitter, flips, then a small shift/scale with a rotation
# limit of 180 degrees.
album_transform = A.Compose([
    A.RandomBrightnessContrast(
        brightness_limit=0.1,
        contrast_limit=0.1,
        p=0.5,
    ),
    A.Flip(p=0.5),
    A.ShiftScaleRotate(
        shift_limit=0.05,
        scale_limit=0.1,
        rotate_limit=180,
        p=0.8,
    ),
])

# Number of leading dataset rows previewed by test_augmentation.
N_TEST = 3


def random_transform(x, y):
    """Scale an image by 1/255, augment it, and return it channels-first.

    Args:
        x: Channels-last image array (it is transposed with axes (2, 0, 1)).
        y: Label, passed through unchanged.

    Returns:
        Tuple of (augmented float32 array in channels-first layout, y).
    """
    scaled = (x / 255).astype(np.float32)
    augmented = album_transform(image=scaled)['image']
    channels_first = augmented.transpose(2, 0, 1)
    return channels_first.astype(np.float32), y


def base_transform(x, y):
    """Scale an image by 1/255 and return it channels-first, without augmentation.

    Args:
        x: Channels-last image array (it is transposed with axes (2, 0, 1)).
        y: Label, passed through unchanged.

    Returns:
        Tuple of (float32 array in channels-first layout, y).
    """
    channels_first = (x / 255).transpose(2, 0, 1)
    return channels_first.astype(np.float32), y


def test_augmentation(df_dataset):
    """Plot the first N_TEST images next to their augmented versions.

    One small figure is drawn per image: the loaded image on the left and
    the output of random_transform on the right.

    Args:
        df_dataset: Dataset table; only its first N_TEST rows are used.
    """
    df_train = df_dataset.iloc[:N_TEST]
    train_data, train_labels = load_dataset(df_train, load_img)
    dataset = AugmentedDataset(train_data, train_labels, random_transform)
    # batch_size=1 with shuffle=False keeps batch i aligned with train_data[i].
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=0)
    for i, batch in enumerate(loader):
        plt.figure(figsize=(4, 1.5))

        plt.subplot(1, 2, 1)
        # Drop the singleton channel axis before plotting, as the worst-case
        # display cell does; an (H, W, 1) array is not accepted by every
        # matplotlib version.
        plt.imshow(train_data[i].squeeze(), cmap='gray')
        plt.title('pre-augmentation')
        plt.axis('off')

        plt.subplot(1, 2, 2)
        # batch[0] holds the image tensor of this one-sample batch.
        plt.imshow(batch[0].squeeze().numpy(), cmap='gray')
        plt.title('post-augmentation')
        plt.axis('off')

        plt.show()


# Run the visual augmentation check on the first N_TEST rows of the dataset.
test_augmentation(df_dataset)

## ネットワーク作成
画像が小さいのでこれまでより小さいネットワークを作成する。

In [None]:
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Two consecutive Conv2d -> BatchNorm2d -> ReLU stages.

    The convolutions use no padding, so each stage shrinks the spatial size
    by ``kernel_size - 1``.

    Args:
        in_chs: Number of input channels of the first convolution.
        out_chs: Number of output channels of both convolutions.
        kernel_size: Kernel size of both convolutions.
    """

    def __init__(self, in_chs, out_chs, kernel_size=3):
        super().__init__()
        layers = []
        # The first stage maps in_chs -> out_chs, the second out_chs -> out_chs.
        for stage_in in (in_chs, out_chs):
            layers += [
                nn.Conv2d(stage_in, out_chs, kernel_size=kernel_size),
                nn.BatchNorm2d(out_chs),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        return self.block(x)


class CNN(nn.Module):
    """Small classifier: two ConvBlock + max-pool stages followed by a two-layer head.

    The 400 input features of the first linear layer correspond to a
    1 x 32 x 32 input: 32 -> 28 -> 14 -> 10 -> 5 spatially, with 16 channels
    (16 * 5 * 5 = 400). The output has one logit per entry of CLASS_LABELS.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            ConvBlock(1, 8),
            nn.MaxPool2d(2),
            ConvBlock(8, 16),
            nn.MaxPool2d(2),
            nn.Dropout(.25),
        ]
        head_layers = [
            nn.Flatten(start_dim=1),
            nn.Linear(400, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, len(CLASS_LABELS)),
        ]
        self.network = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        return self.network(x)


# Instantiate a throwaway CNN on the CPU and summarise it for an IMG_SHAPE input.
# NOTE(review): the `verbose` keyword is not accepted by every package that
# installs a `torchsummary` module; confirm which one is installed.
from torchsummary import summary
summary(CNN(), IMG_SHAPE, verbose=False, device='cpu')

In [None]:
import torch.optim as optim
import torch.nn.functional as F
import pytorch_lightning as pl
import os


class LitNet(pl.LightningModule):
    """LightningModule wrapper that trains a CNN with cross-entropy and Adam."""

    def __init__(self):
        super().__init__()
        self.model = CNN()

    def forward(self, x):
        """Return the logits of the wrapped CNN for ``x``."""
        return self.model(x)

    def shared_step(self, batch):
        """Compute the cross-entropy loss of one ``(inputs, targets)`` batch."""
        inputs, targets = batch
        logits = self.model(inputs)
        return F.cross_entropy(logits, targets)

    def training_step(self, batch, batch_idx):
        """Log and return the training loss of the batch."""
        train_loss = self.shared_step(batch)
        self.log('train_loss', train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss of the batch."""
        val_loss = self.shared_step(batch)
        self.log('val_loss', val_loss)
        return val_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

## K-Fold 交差検証(train, validate, test)
各foldでは、データセットの$\frac{2}{4}$を学習用、$\frac{1}{4}$をvalidation(EarlyStopping)用、残りの$\frac{1}{4}$を評価用に使用する。

### DataFrameに交差検証用の列を追加する

In [None]:
import itertools
from sklearn.model_selection import StratifiedKFold

K_FOLD = 4
# NOTE: no random_state is given, so the fold assignment differs on every run.
kfold = StratifiedKFold(n_splits=K_FOLD, shuffle=True)

# split() yields (train_positions, test_positions) per fold. Only the test
# part is kept, so every row belongs to exactly one fold.
test_indices = [
    test_positions
    for _, test_positions in kfold.split(df_dataset['filepath'],
                                         df_dataset['class'])
]

# Row position -> fold number.
index2fold = {
    position: fold
    for fold, positions in enumerate(test_indices)
    for position in positions
}

# The keys are row positions, not index labels, so the column is filled in
# row order instead of being mapped through df_dataset.index. The result is
# the same for a default RangeIndex and stays correct for any other index.
df_dataset['set'] = [index2fold[pos] for pos in range(len(df_dataset))]
df_dataset

### train, validate, testを用いた交差検証を行う

In [None]:
import copy
import tut_utils
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import CSVLogger
from torch.utils.data import DataLoader

# Training settings shared by every fold.
BATCH_SIZE = 64
EPOCHS = 32
PATIENCE = 4  # early stopping
NUM_WORKERS = 0 if os.name == 'nt' else 2
gpus = 1 if torch.cuda.is_available() else 0

# Options common to the train / validation / test loaders.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

results = []
for i_iter, test_fold in enumerate(range(K_FOLD)):
    logger.info('{i}th iteration of {k}-fold CV'.format(i=i_iter + 1,
                                                        k=K_FOLD))

    # The fold after the test fold (cyclically) is used for validation;
    # every remaining fold is used for training.
    val_fold = (test_fold + 1) % K_FOLD
    train_folds = set(range(K_FOLD)) - {test_fold, val_fold}
    df_train = df_dataset[df_dataset['set'].isin(train_folds)]
    df_val = df_dataset[df_dataset['set'] == val_fold]
    df_test = df_dataset[df_dataset['set'] == test_fold]

    train_data, train_labels = load_dataset(df_train, load_img)
    val_data, val_labels = load_dataset(df_val, load_img)
    test_data, test_labels = load_dataset(df_test, load_img)

    # Only the training set is augmented and shuffled.
    train_set = AugmentedDataset(train_data, train_labels, random_transform)
    trainloader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_set = AugmentedDataset(val_data, val_labels, base_transform)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)

    model = LitNet()
    early_stop_callback = EarlyStopping(monitor='val_loss',
                                        patience=PATIENCE,
                                        verbose=False,
                                        mode='min')
    csv_logger = CSVLogger('train_logs', name='dld')
    trainer = pl.Trainer(gpus=gpus,
                         max_epochs=EPOCHS,
                         logger=csv_logger,
                         log_every_n_steps=len(trainloader),
                         callbacks=[early_stop_callback])

    trainer.fit(model, trainloader, val_loader)
    logger.info('Finish training')

    # train_loss and val_loss are written on different rows of the metrics
    # CSV; drop the gaps in each column and line the two curves up.
    df_logs = pd.read_csv(csv_logger.experiment.metrics_file_path)
    loss_curves = [
        df_logs[column].dropna().reset_index(drop=True)
        for column in ('train_loss', 'val_loss')
    ]
    df_logs = pd.DataFrame(loss_curves).T
    df_logs.plot(y=['train_loss', 'val_loss'])
    plt.show()

    # Predict on the held-out fold; results are keyed by the test rows' index.
    test_set = AugmentedDataset(test_data, test_labels, base_transform)
    testloader = DataLoader(test_set, shuffle=False, **loader_kwargs)

    df_result = tut_utils.predict_multiclass(model, testloader, df_test.index)
    results.append(df_result)

In [None]:
# Stack the per-fold prediction tables and attach them to the dataset rows by index.
df_result = df_dataset.join(pd.concat(results, axis=0))
display(df_result)

## 評価
### 混同行列

In [None]:
df_cm = tut_utils.confusion_matrix(df_result)

# Accuracy is the diagonal sum of the confusion matrix over the sum of all its cells.
n_correct = df_cm.values.trace()
n_total = df_cm.values.sum()
print('Accuracy = {n} / {d} = {a:.03g}%'.format(n=n_correct,
                                                d=n_total,
                                                a=100 * n_correct / n_total))

display(df_cm)

### ROC

In [None]:
# Draw ROC curves from the pooled cross-validation predictions (plotting is done inside tut_utils).
tut_utils.plot_roc_curves(df_result, figsize=(4, 3))
plt.show()

In [None]:
from sklearn import metrics
# Per-class precision / recall / F1 as a dict, shown with one row per class or average.
report = metrics.classification_report(
    y_true=df_result['class'],
    y_pred=df_result['pred_class'],
    target_names=CLASS_LABELS,
    output_dict=True,
)
df_report = pd.DataFrame(report)
display(df_report.transpose())

### クラスごとに間違えている例を表示
#### 画像ごとにlossを計算

In [None]:
# Per-image cross-entropy between the stored logits and the true class.
# The pandas columns are converted to plain arrays with explicit dtypes first:
# building a tensor straight from a Series relies on integer item lookup,
# which pandas resolves by index label, and leaves the target dtype implicit.
logits = torch.tensor(
    np.array(df_result['pred_logits'].tolist(), dtype=np.float32))
targets = torch.tensor(df_result['class'].to_numpy(), dtype=torch.long)
# reduction='none' keeps one loss value per row instead of the batch mean.
df_result['loss'] = F.cross_entropy(logits, targets,
                                    reduction='none').numpy()
display(df_result)

#### lossの値が大きい画像を表示

In [None]:
# Number of highest-loss images shown per class.
N_SAMPLES = 2
for class_label, group in df_result.groupby('class_label'):
    print(class_label)
    # Rows of this class with the largest loss first.
    worst = group.sort_values('loss', ascending=False).iloc[:N_SAMPLES]
    worst_data, worst_labels = load_dataset(worst, load_img)
    for img, pred_proba in zip(worst_data, worst['pred_proba']):
        # Left: the image. Right: the predicted probability of each class.
        fig, (ax_img, ax) = plt.subplots(1, 2, figsize=(4, 1.5))
        ax_img.imshow(img.squeeze(), cmap='gray')
        ax_img.axis('off')
        proba_table = pd.DataFrame(pred_proba, index=CLASS_LABELS)
        proba_table.plot(ax=ax, kind='barh', legend=False)
        plt.tight_layout()
        plt.show()