<a href="https://colab.research.google.com/github/ykitaguchi77/Article_implementation/blob/main/3D-UNET_brain_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#**Brain_segmentation_pytorch 3D-UNET**

CTの複数画像を立体的にsegmentationする

GitHub: https://github.com/mateuszbuda/brain-segmentation-pytorch

WebPage: https://towardsdatascience.com/creating-and-training-a-u-net-model-with-pytorch-for-2d-3d-semantic-segmentation-model-building-6ab09d6a0862


In [1]:
!git clone https://github.com/mateuszbuda/brain-segmentation-pytorch.git

# Move into the cloned repository so it becomes the working directory
%cd brain-segmentation-pytorch

Cloning into 'brain-segmentation-pytorch'...
remote: Enumerating objects: 97, done.[K
remote: Counting objects: 100% (13/13), done.[K
remote: Compressing objects: 100% (12/12), done.[K
remote: Total 97 (delta 6), reused 2 (delta 1), pack-reused 84[K
Unpacking objects: 100% (97/97), done.
/content/brain-segmentation-pytorch


#Kaggle_3M datasetのダウンロード

まずKaggleに登録してAPIの使用許可を申請する必要あり。詳細は下記を参考に。

https://www.currypurin.com/entry/2018/kaggle-api

In [None]:
# mount Gdrive
from google.colab import drive
drive.mount('/content/drive')

# kaggle ライブラリのインストール
!pip install kaggle

# 一時フォルダに .kaggleフォルダを作成
!mkdir ~/.kaggle

# MyDrive の kaggle.json　(permissionファイル) を一時フォルダ内の .kaggleフォルダにコピー
!cp /content/drive/MyDrive/Kaggle/kaggle.json ~/.kaggle/

# アクセス権限の設定
!chmod 600 ~/.kaggle/kaggle.json

!mkdir ~/.kaggle

# zipファイルのダウンロード
!kaggle datasets download -d mateuszbuda/lgg-mri-segmentation
#!kaggle competitions download -c rsna-2022-cervical-spine-fracture-detection -p /content/drive/MyDrive/Kaggle
# 解凍
!unzip ./lgg-mri-segmentation.zip -d ./

# ここから

In [3]:
# Install the modules listed in the repository's requirements
# Note: version pins are omitted because the pinned versions conflict with each other
# Everything except medpy is already installed (per the original author's note)

modules = """
numpy
tensorflow
scikit-learn
scikit-image
imageio
medpy
Pillow
scipy
pandas
tqdm
"""

# Write the unpinned list to requirements.txt, then install from it
with open("requirements.txt", mode="w") as f:
    f.write(modules)
!pip install -r requirements.txt

import argparse
import os

import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from medpy.filter.binary import largest_connected_component
from skimage.io import imsave
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import BrainSegmentationDataset as Dataset
from unet import UNet
from utils import dsc, gray2rgb, outline

import numpy as np
from PIL import Image
import glob


  Building wheel for medpy (setup.py) ... [?25l[?25hdone
  Created wheel for medpy: filename=MedPy-0.4.0-cp37-cp37m-linux_x86_64.whl size=754481 sha256=a2ba0c389c80a7a211e92c6e0b8eba67373bd19c135dfabb76b7e236fc27f41b
  Stored in directory: /root/.cache/pip/wheels/b0/57/3a/da1183f22a6afb42e11138daa6a759de233fd977a984333602
Successfully built medpy
Installing collected packages: SimpleITK, medpy
Successfully installed SimpleITK-2.1.1.2 medpy-0.4.0


#**inference**

In [28]:
def postprocess_per_volume(
    input_list, pred_list, true_list, patient_slice_index, patients
):
    """Regroup flat per-slice lists into one (input, prediction, truth) triple per patient.

    Args:
        input_list: per-slice inputs, in the same order as ``patient_slice_index``.
        pred_list: per-slice predictions, same order.
        true_list: per-slice ground-truth masks, same order.
        patient_slice_index: sequence whose entries start with the patient index
            of the corresponding slice.
        patients: patient identifiers, indexed by patient index.

    Returns:
        dict mapping patient identifier to ``(volume_in, volume_pred, volume_true)``.
        The prediction is rounded, cast to int and passed through
        ``largest_connected_component``.
    """
    # Number of slices belonging to each patient index
    slices_per_patient = np.bincount([entry[0] for entry in patient_slice_index])

    volumes = {}
    start = 0
    for patient_idx, n_slices in enumerate(slices_per_patient):
        stop = start + n_slices

        volume_in = np.array(input_list[start:stop])
        rounded_pred = np.round(np.array(pred_list[start:stop])).astype(int)
        volume_pred = largest_connected_component(rounded_pred)
        volume_true = np.array(true_list[start:stop])

        volumes[patients[patient_idx]] = (volume_in, volume_pred, volume_true)
        start = stop
    return volumes

def dsc_distribution(volumes):
    """Compute the Dice coefficient of every volume.

    Args:
        volumes: dict mapping a patient identifier to a sequence whose element 1
            is the prediction and element 2 the ground truth.

    Returns:
        dict mapping each patient identifier to ``dsc(prediction, truth, lcc=False)``.
    """
    return {
        patient: dsc(entry[1], entry[2], lcc=False)
        for patient, entry in volumes.items()
    }

def plot_dsc(dsc_dist):
    """Render a horizontal bar chart of per-patient Dice coefficients as an image array.

    Bars are sorted by ascending Dice value. Two vertical lines mark the mean
    ("tomato") and the median ("forestgreen") of the values.

    Args:
        dsc_dist: dict mapping patient identifier (str) to Dice coefficient.

    Returns:
        Writable ``np.uint8`` array of shape ``(height, width, 4)`` built from the
        buffer returned by ``canvas.print_to_buffer()``.
    """
    y_positions = np.arange(len(dsc_dist))
    dsc_dist = sorted(dsc_dist.items(), key=lambda x: x[1])
    values = [x[1] for x in dsc_dist]
    labels = [x[0] for x in dsc_dist]
    # Shorten the tick labels: drop the first and last "_"-separated fields of each id
    labels = ["_".join(l.split("_")[1:-1]) for l in labels]
    fig = plt.figure(figsize=(12, 8))
    # Attach an Agg canvas so the figure can be rasterised into a buffer
    canvas = FigureCanvasAgg(fig)
    plt.barh(y_positions, values, align="center", color="skyblue")
    plt.yticks(y_positions, labels)
    plt.xticks(np.arange(0.0, 1.0, 0.1))
    plt.xlim([0.0, 1.0])
    plt.gca().axvline(np.mean(values), color="tomato", linewidth=2)
    plt.gca().axvline(np.median(values), color="forestgreen", linewidth=2)
    plt.xlabel("Dice coefficient", fontsize="x-large")
    plt.gca().xaxis.grid(color="silver", alpha=0.5, linestyle="--", linewidth=1)
    plt.tight_layout()
    canvas.draw()
    plt.close()
    s, (width, height) = canvas.print_to_buffer()
    # np.fromstring is deprecated for binary data and rejected by current NumPy;
    # np.frombuffer is the replacement. It returns a read-only view of the bytes,
    # so copy to keep the result writable as it was before.
    return np.frombuffer(s, dtype=np.uint8).reshape((height, width, 4)).copy()

def outline(image, mask, color):
    """Paint the boundary pixels of ``mask`` onto ``image`` in place.

    A pixel is a boundary pixel when it is set in the rounded mask and the
    neighbourhood around it (up to 3x3, clipped at the top and left edges by the
    explicit ``max`` and at the bottom and right edges by slicing) contains at
    least one unset pixel.

    Args:
        image: array indexed as ``image[y, x]``; modified in place.
        mask: 2-D array; values are rounded before use.
        color: value assigned to every boundary pixel.

    Returns:
        The same ``image`` object that was passed in.
    """
    binary = np.round(mask)
    rows, cols = np.nonzero(binary)
    for r, c in zip(rows, cols):
        window = binary[max(0, r - 1) : r + 2, max(0, c - 1) : c + 2]
        if not (0.0 < np.mean(window) < 1.0):
            # Fully surrounded by set pixels: interior, not boundary
            continue
        image[r : r + 1, c : c + 1] = color
    return image


import os
import random

import numpy as np
import torch
from skimage.io import imread
from torch.utils.data import Dataset

from utils import crop_sample, pad_sample, resize_sample, normalize_volume


# Dataset class (it already lives in the repository's dataset module, so it does not need to be written here; it is redefined for explanation)
class BrainSegmentationDataset(Dataset):
    """Brain MRI dataset for FLAIR abnormality segmentation

    Walks ``images_dir``, groups the ``.tif`` slices of every directory into one
    image volume and one mask volume, optionally splits the patients into
    train/validation, runs the crop / pad / resize / normalize helpers from
    ``utils`` and serves single slices as ``(image_tensor, mask_tensor)`` pairs.
    """

    # Channel counts read by callers when constructing the network
    in_channels = 3
    out_channels = 1

    def __init__(
        self,
        images_dir,
        transform=None,
        image_size=256,
        subset="train",
        random_sampling=True,
        validation_cases=10,
        seed=42,
    ):
        """Load and preprocess every volume found under ``images_dir``.

        Args:
            images_dir: root directory that is walked recursively for ``.tif`` files.
            transform: optional callable applied to an ``(image, mask)`` tuple in
                ``__getitem__``; must return an ``(image, mask)`` tuple.
            image_size: passed as ``size`` to ``resize_sample``.
            subset: one of "all", "train", "validation".
            random_sampling: when True, ``__getitem__`` ignores its index and draws
                a random patient and slice instead.
            validation_cases: number of patients drawn for the validation split.
            seed: seed given to ``random.seed`` before drawing the split.

        Raises:
            AssertionError: if ``subset`` is not one of the accepted values.
        """
        assert subset in ["all", "train", "validation"] # reject any subset other than all / train / validation

        # read images
        volumes = {}
        masks = {}
        print("reading {} images...".format(subset))
        for (dirpath, dirnames, filenames) in os.walk(images_dir): # dirpath: path of the directory being visited, filenames: names of the files in it
            image_slices = []
            mask_slices = []

            for filename in sorted(
                filter(lambda f: ".tif" in f, filenames),  
                key=lambda x: int(x.split(".")[-2].split("_")[4]),
            ): # keep the .tif files and sort them by the number in the 5th "_"-separated field of the name
                filepath = os.path.join(dirpath, filename)
                if "mask" in filename: # files with "mask" in their name
                    mask_slices.append(imread(filepath, as_gray=True))
                else: # all other files
                    image_slices.append(imread(filepath))
            if len(image_slices) > 0:
                patient_id = dirpath.split("/")[-1] # e.g. TCGA_HT_8018_19970411
                volumes[patient_id] = np.array(image_slices[1:-1]) # volumes: images without mask, RGB (first and last slice are dropped; the original author was unsure why -- TODO confirm)
                masks[patient_id] = np.array(mask_slices[1:-1]) # masks: mask images, grayscale

        self.patients = sorted(volumes)

        # select cases to subset
        if not subset == "all":
            # NOTE(review): this reseeds Python's global `random` module
            random.seed(seed)
            validation_patients = random.sample(self.patients, k=validation_cases) # randomly draw validation_cases patients
            if subset == "validation":
                self.patients = validation_patients
            else:
                self.patients = sorted(
                    list(set(self.patients).difference(validation_patients))
                )

        print("preprocessing {} volumes...".format(subset))
        # create list of tuples (volume, mask)
        self.volumes = [(volumes[k], masks[k]) for k in self.patients]

        print("cropping {} volumes...".format(subset))
        # crop to smallest enclosing volume
        self.volumes = [crop_sample(v) for v in self.volumes]

        print("padding {} volumes...".format(subset))
        # pad to square
        self.volumes = [pad_sample(v) for v in self.volumes]

        print("resizing {} volumes...".format(subset))
        # resize
        self.volumes = [resize_sample(v, size=image_size) for v in self.volumes]

        print("normalizing {} volumes...".format(subset))
        # normalize channel-wise
        self.volumes = [(normalize_volume(v), m) for v, m in self.volumes]

        # probabilities for sampling slices based on masks
        # (mask summed over its last two axes; assumes masks are (slices, H, W) -- TODO confirm)
        self.slice_weights = [m.sum(axis=-1).sum(axis=-1) for v, m in self.volumes]
        # smooth the weights: every slice gets an extra 10% of the mean, then the
        # result is divided by 1.1 * total so the weights of one volume sum to 1
        # NOTE(review): a volume whose masks sum to 0 divides by zero here -- confirm it cannot occur
        self.slice_weights = [
            (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1) for s in self.slice_weights
        ]

        # add channel dimension to masks
        self.volumes = [(v, m[..., np.newaxis]) for (v, m) in self.volumes]

        print("done creating {} dataset".format(subset))

        # create global index for patient and slice (idx -> (p_idx, s_idx))
        num_slices = [v.shape[0] for v, m in self.volumes]
        self.patient_slice_index = list(
            zip(
                sum([[i] * num_slices[i] for i in range(len(num_slices))], []),
                sum([list(range(x)) for x in num_slices], []),
            )
        )

        self.random_sampling = random_sampling

        self.transform = transform

    def __len__(self):
        """Return the total number of slices across all selected patients."""
        return len(self.patient_slice_index)

    def __getitem__(self, idx):
        """Return one slice as ``(image_tensor, mask_tensor)`` float32 tensors.

        With ``random_sampling`` enabled, ``idx`` is ignored: a patient is drawn
        uniformly and a slice is drawn according to ``slice_weights``.
        """
        patient = self.patient_slice_index[idx][0]
        slice_n = self.patient_slice_index[idx][1]

        if self.random_sampling:
            patient = np.random.randint(len(self.volumes))
            slice_n = np.random.choice(
                range(self.volumes[patient][0].shape[0]), p=self.slice_weights[patient]
            )

        v, m = self.volumes[patient]
        image = v[slice_n]
        mask = m[slice_n]

        if self.transform is not None:
            image, mask = self.transform((image, mask))

        # fix dimensions (C, H, W)
        image = image.transpose(2, 0, 1)
        mask = mask.transpose(2, 0, 1)

        image_tensor = torch.from_numpy(image.astype(np.float32))
        mask_tensor = torch.from_numpy(mask.astype(np.float32))

        # return tensors
        return image_tensor, mask_tensor

In [29]:
# The dataset class can also be imported instead of being defined in the notebook:
#from dataset import BrainSegmentationDataset as Dataset


#############################
weights_dir = "./weights/unet.pt" # path of the saved model weights
#############################

# Create the folder that receives the per-slice prediction images
os.makedirs("./predictions", exist_ok=True)

# Use the first GPU when one is available, otherwise the CPU
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

# Dataset and data loader.
# random_sampling=False and the loader's default (unshuffled) order are required:
# postprocess_per_volume regroups slices per patient purely by position.
dataset = BrainSegmentationDataset(
    images_dir="./lgg-mri-segmentation/kaggle_3m", # Kaggle dataset
    subset="validation",
    image_size=256,
    random_sampling=False,
)
loader = DataLoader(
    dataset, batch_size=32, drop_last=False, num_workers=1
)

with torch.set_grad_enabled(False):
    # Read the channel counts from BrainSegmentationDataset, not from `Dataset`:
    # the cell that defines the class re-imports `Dataset` from torch.utils.data,
    # and that class has no in_channels / out_channels attributes.
    unet = UNet(
        in_channels=BrainSegmentationDataset.in_channels,
        out_channels=BrainSegmentationDataset.out_channels,
    )

    # Load the trained weights
    state_dict = torch.load(weights_dir, map_location=device)
    unet.load_state_dict(state_dict)

    unet.eval()
    unet.to(device)

    input_list = []
    pred_list = []
    true_list = []

    # Wrap the loader itself so tqdm knows the number of batches
    for x, y_true in tqdm(loader):
        x, y_true = x.to(device), y_true.to(device)

        y_pred = unet(x)

        # Collect inputs, predictions and ground truth slice by slice
        y_pred_np = y_pred.detach().cpu().numpy()
        pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])

        y_true_np = y_true.detach().cpu().numpy()
        true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])

        x_np = x.detach().cpu().numpy()
        input_list.extend([x_np[s] for s in range(x_np.shape[0])])

    # Regroup the flat slice lists into one volume triple per patient
    volumes = postprocess_per_volume(
        input_list,
        pred_list,
        true_list,
        loader.dataset.patient_slice_index,
        loader.dataset.patients,
    )

    # Per-patient Dice coefficients and their bar chart
    dsc_dist = dsc_distribution(volumes)

    dsc_dist_plot = plot_dsc(dsc_dist)
    imsave("./dsc.png", dsc_dist_plot)

    # Save every slice with the prediction outlined in red and the ground truth in green
    for p in volumes:
        x = volumes[p][0]
        y_pred = volumes[p][1]
        y_true = volumes[p][2]
        for s in range(x.shape[0]):
            image = gray2rgb(x[s, 1])  # channel 1 is for FLAIR
            image = outline(image, y_pred[s, 0], color=[255, 0, 0])
            image = outline(image, y_true[s, 0], color=[0, 255, 0])
            filename = "{}-{}.png".format(p, str(s).zfill(2))
            filepath = os.path.join("./predictions", filename)
            imsave(filepath, image)

[1;30;43mストリーミング出力は最後の 5000 行に切り捨てられました。[0m
         [0, 0, 1],
         [0, 0, 1]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        ...,

        [[0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         ...,
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1]],

        [[0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         ...,
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1]],

        [[0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         ...,
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
    

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



### prediction結果を表示 ###

In [None]:
# Show the saved prediction images (at most 30) in a grid.
# glob returns paths in arbitrary order, so sort them first: the same images
# are then shown in the same order on every run.
images = [Image.open(img) for img in sorted(glob.glob("./predictions/*"))[0:30]]

cols = 3
# Number of grid rows: ceiling division, so a count that is a multiple of `cols`
# no longer produces an empty extra row; at least one row keeps the figure size valid.
rows = max(1, -(-len(images) // cols))


fig = plt.figure(figsize=(cols*5, rows*5))


for i, im in enumerate(images):
    # Draw on the subplot's own axes instead of relying on the current axes
    ax = fig.add_subplot(rows, cols, i+1)
    ax.set_title(str(i+1))
    ax.imshow(im)

plt.show()