<a href="https://colab.research.google.com/github/ykitaguchi77/Article_implementation/blob/main/3D-UNET_brain_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#**Brain_segmentation_pytorch 3D-UNET**

MRIの複数スライス画像を立体的にsegmentationする

GitHub: https://github.com/mateuszbuda/brain-segmentation-pytorch

WebPage: https://towardsdatascience.com/creating-and-training-a-u-net-model-with-pytorch-for-2d-3d-semantic-segmentation-model-building-6ab09d6a0862


In [1]:
!git clone https://github.com/mateuszbuda/brain-segmentation-pytorch.git

# Move into the cloned repository so it becomes the working directory
%cd brain-segmentation-pytorch

Cloning into 'brain-segmentation-pytorch'...
remote: Enumerating objects: 97, done.[K
remote: Counting objects: 100% (12/12), done.[K
remote: Compressing objects: 100% (11/11), done.[K
remote: Total 97 (delta 6), reused 2 (delta 1), pack-reused 85[K
Unpacking objects: 100% (97/97), done.
/content/brain-segmentation-pytorch


#Kaggle_3M datasetのダウンロード

まずKaggleに登録してAPIの使用許可を申請する必要あり。詳細は下記を参考に。

https://www.currypurin.com/entry/2018/kaggle-api

In [2]:
# mount Gdrive
from google.colab import drive
drive.mount('/content/drive')

# kaggle ライブラリのインストール
!pip install kaggle

# 一時フォルダに .kaggleフォルダを作成
!mkdir ~/.kaggle

# MyDrive の kaggle.json　(permissionファイル) を一時フォルダ内の .kaggleフォルダにコピー
!cp /content/drive/MyDrive/Kaggle/kaggle.json ~/.kaggle/

# アクセス権限の設定
!chmod 600 ~/.kaggle/kaggle.json

!mkdir ~/.kaggle

# zipファイルのダウンロード
!kaggle datasets download -d mateuszbuda/lgg-mri-segmentation
#!kaggle competitions download -c rsna-2022-cervical-spine-fracture-detection -p /content/drive/MyDrive/Kaggle
# 解凍
!unzip ./lgg-mri-segmentation.zip -d ./

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: ./lgg-mri-segmentation/kaggle_3m/TCGA_DU_7294_19890104/TCGA_DU_7294_19890104_9_mask.tif  
  inflating: ./lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_1.tif  
  inflating: ./lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_10.tif  
  inflating: ./lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_10_mask.tif  
  inflating: ./lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_11.tif  
  inflating: ./lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_11_mask.tif  
  inflating: ./lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_12.tif  
  inflating: ./lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_12_mask.tif  
  inflating: ./lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_13.tif  
  inflating: ./lgg-mri-seg

# ここから

In [3]:
#Requirementsからモジュールをインストール
#※バージョンがconflictしまくるのでバージョン指定なし
#medpy以外はすでに入っている

modules = """
numpy
tensorflow
scikit-learn
scikit-image
imageio
medpy
Pillow
scipy
pandas
tqdm
"""

with open("requirements.txt", mode="w") as f:
    f.write(modules)
!pip install -r requirements.txt

import argparse
import os

import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from medpy.filter.binary import largest_connected_component
from skimage.io import imsave
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import BrainSegmentationDataset as Dataset
from unet import UNet
from utils import dsc, gray2rgb, outline


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting medpy
  Downloading MedPy-0.4.0.tar.gz (151 kB)
[K     |████████████████████████████████| 151 kB 13.2 MB/s 
Collecting SimpleITK>=1.1.0
  Downloading SimpleITK-2.1.1.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (48.4 MB)
[K     |████████████████████████████████| 48.4 MB 20 kB/s 
Building wheels for collected packages: medpy
  Building wheel for medpy (setup.py) ... [?25l[?25hdone
  Created wheel for medpy: filename=MedPy-0.4.0-cp37-cp37m-linux_x86_64.whl size=754476 sha256=2ad9926bac0e3a7aff6c6663bcb31439c0d26fc1538725db279a7160cf260bd3
  Stored in directory: /root/.cache/pip/wheels/b0/57/3a/da1183f22a6afb42e11138daa6a759de233fd977a984333602
Successfully built medpy
Installing collected packages: SimpleITK, medpy
Successfully installed SimpleITK-2.1.1.2 medpy-0.4.0


#**Inference**

In [4]:
def postprocess_per_volume(
    input_list, pred_list, true_list, patient_slice_index, patients
):
    """Group flat per-slice lists into one stacked volume per patient.

    Args:
        input_list: Per-slice input arrays, ordered patient by patient.
        pred_list: Per-slice prediction arrays in the same order.
        true_list: Per-slice ground-truth arrays in the same order.
        patient_slice_index: Sequence whose items carry a non-negative
            integer patient index as their first element; the number of
            occurrences of each index is taken as that patient's slice count.
        patients: Sequence mapping a patient index to the key used in the
            returned dictionary.

    Returns:
        dict mapping ``patients[i]`` to a tuple
        ``(volume_in, volume_pred, volume_true)``. ``volume_pred`` is the
        rounded, integer-cast prediction after being passed through
        ``largest_connected_component``.
    """
    # Count how many slices belong to each patient index
    slices_per_patient = np.bincount(
        [entry[0] for entry in patient_slice_index]
    )

    volumes = {}
    start = 0
    for patient_idx, slice_count in enumerate(slices_per_patient):
        stop = start + slice_count

        stacked_input = np.array(input_list[start:stop])

        # Round the raw predictions to integers before the
        # connected-component filter is applied
        rounded_pred = np.round(np.array(pred_list[start:stop])).astype(int)
        filtered_pred = largest_connected_component(rounded_pred)

        stacked_true = np.array(true_list[start:stop])

        volumes[patients[patient_idx]] = (
            stacked_input,
            filtered_pred,
            stacked_true,
        )
        start = stop
    return volumes

def dsc_distribution(volumes):
    """Compute one ``dsc`` score per volume.

    Args:
        volumes: Mapping from a volume key to an indexable item whose
            element 1 is the prediction and element 2 is the ground truth.

    Returns:
        dict mapping each key of ``volumes`` to
        ``dsc(prediction, ground_truth, lcc=False)``.
    """
    # lcc=False is passed through to dsc; presumably because the predictions
    # were already filtered upstream -- verify against the caller.
    return {
        key: dsc(volumes[key][1], volumes[key][2], lcc=False)
        for key in volumes
    }

def plot_dsc(dsc_dist):
    """Render a horizontal bar chart of Dice scores and return it as an array.

    Bars are sorted by ascending score. A vertical "tomato" line marks the
    mean and a "forestgreen" line marks the median of the scores.

    Args:
        dsc_dist: dict mapping a label string to its Dice coefficient.
            For display, each label is split on "_" and its first and last
            tokens are dropped.

    Returns:
        numpy.ndarray of dtype uint8 with shape (height, width, 4), holding
        the rendered figure as written by the Agg canvas buffer.
    """
    y_positions = np.arange(len(dsc_dist))
    dsc_dist = sorted(dsc_dist.items(), key=lambda x: x[1])
    values = [x[1] for x in dsc_dist]
    labels = [x[0] for x in dsc_dist]
    # Shorten the tick labels: drop the first and last "_"-separated tokens
    labels = ["_".join(l.split("_")[1:-1]) for l in labels]
    fig = plt.figure(figsize=(12, 8))
    canvas = FigureCanvasAgg(fig)
    plt.barh(y_positions, values, align="center", color="skyblue")
    plt.yticks(y_positions, labels)
    plt.xticks(np.arange(0.0, 1.0, 0.1))
    plt.xlim([0.0, 1.0])
    plt.gca().axvline(np.mean(values), color="tomato", linewidth=2)
    plt.gca().axvline(np.median(values), color="forestgreen", linewidth=2)
    plt.xlabel("Dice coefficient", fontsize="x-large")
    plt.gca().xaxis.grid(color="silver", alpha=0.5, linestyle="--", linewidth=1)
    plt.tight_layout()
    canvas.draw()
    # Close this specific figure so repeated calls do not accumulate figures
    plt.close(fig)
    s, (width, height) = canvas.print_to_buffer()
    # np.fromstring is deprecated for binary data; np.frombuffer is the
    # documented replacement. frombuffer returns a read-only view of the
    # bytes, so copy to keep returning a writable array as before.
    return np.frombuffer(s, np.uint8).reshape((height, width, 4)).copy()


In [None]:
#############################
weights_dir = "./weights/unet.pt"  # path of the trained weights file to load
predictions_dir = "./predictions"  # folder that receives the overlay PNGs
#############################

# Create the predictions folder
os.makedirs(predictions_dir, exist_ok=True)

# Define the device: first CUDA GPU when available, otherwise CPU
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

# Dataset and data loader
dataset = Dataset(
    images_dir="./lgg-mri-segmentation/kaggle_3m",  # uses the Kaggle dataset
    subset="validation",
    image_size=256,
    random_sampling=False,
)
loader = DataLoader(
    dataset, batch_size=32, drop_last=False, num_workers=1
)

# Gradients are not needed for inference
with torch.set_grad_enabled(False):
    unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)

    # Load the model weights
    state_dict = torch.load(weights_dir, map_location=device)
    unet.load_state_dict(state_dict)

    unet.eval()
    unet.to(device)

    input_list = []
    pred_list = []
    true_list = []

    # Wrap the loader itself so tqdm can show the total number of batches
    for x, y_true in tqdm(loader):
        x, y_true = x.to(device), y_true.to(device)

        y_pred = unet(x)

        # Iterating a NumPy array yields its entries along the first axis,
        # so extend() appends one array per sample of the batch
        pred_list.extend(y_pred.detach().cpu().numpy())
        true_list.extend(y_true.detach().cpu().numpy())
        input_list.extend(x.detach().cpu().numpy())

    # Regroup the flat per-slice lists into one volume per patient
    volumes = postprocess_per_volume(
        input_list,
        pred_list,
        true_list,
        loader.dataset.patient_slice_index,
        loader.dataset.patients,
    )

    # Per-volume Dice scores, saved as a bar chart
    dsc_dist = dsc_distribution(volumes)

    dsc_dist_plot = plot_dsc(dsc_dist)
    imsave("./dsc.png", dsc_dist_plot)

    # Save one overlay image per slice: prediction outlined with
    # color [255, 0, 0], ground truth outlined with color [0, 255, 0]
    for p in volumes:
        x = volumes[p][0]
        y_pred = volumes[p][1]
        y_true = volumes[p][2]
        for s in range(x.shape[0]):
            image = gray2rgb(x[s, 1])  # channel 1 is for FLAIR
            image = outline(image, y_pred[s, 0], color=[255, 0, 0])
            image = outline(image, y_true[s, 0], color=[0, 255, 0])
            filename = "{}-{}.png".format(p, str(s).zfill(2))
            # `args` is never defined in this notebook; write into the
            # predictions folder created at the top of this cell instead
            filepath = os.path.join(predictions_dir, filename)
            imsave(filepath, image)

reading validation images...
