<a href="https://colab.research.google.com/github/ykitaguchi77/Article_implementation/blob/main/3D-UNET_brain_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#**Brain_segmentation_pytorch 3D-UNET**

Segment a stack of CT slices as a 3D volume. (The example dataset used in this notebook is brain MRI.)

GitHub: https://github.com/mateuszbuda/brain-segmentation-pytorch

WebPage: https://towardsdatascience.com/creating-and-training-a-u-net-model-with-pytorch-for-2d-3d-semantic-segmentation-model-building-6ab09d6a0862


In [1]:
!git clone https://github.com/mateuszbuda/brain-segmentation-pytorch.git

# Move into the working directory
%cd brain-segmentation-pytorch

Cloning into 'brain-segmentation-pytorch'...
remote: Enumerating objects: 97, done.
remote: Counting objects: 100% (13/13), done.
remote: Compressing objects: 100% (12/12), done.
remote: Total 97 (delta 6), reused 2 (delta 1), pack-reused 84
Unpacking objects: 100% (97/97), done.
/content/brain-segmentation-pytorch


#Downloading the Kaggle_3M dataset

You first need to register on Kaggle and obtain an API token. See the page below for details.

https://www.currypurin.com/entry/2018/kaggle-api
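
Before running the download cell, it can help to verify that the token file is in place. The following is a minimal sketch, assuming the usual `kaggle.json` layout with `username` and `key` fields. The helper name `check_kaggle_token` is hypothetical and not part of the kaggle package.

```python
import json
import os
import stat


def check_kaggle_token(path):
    """Return a list of problems found with a kaggle.json token file (empty if OK)."""
    if not os.path.isfile(path):
        return [f"{path} does not exist"]
    with open(path) as f:
        try:
            token = json.load(f)
        except json.JSONDecodeError:
            return [f"{path} is not valid JSON"]
    if not isinstance(token, dict):
        return [f"{path} does not contain a JSON object"]

    problems = []
    for field in ("username", "key"):
        if not token.get(field):
            problems.append(f"missing field: {field}")

    # The Kaggle CLI warns when the token is readable by other users,
    # which is why the notebook runs chmod 600 on it.
    mode = stat.S_IMODE(os.stat(path).st_mode)
    if mode & 0o077:
        problems.append(f"permissions {oct(mode)} are too open; run chmod 600")
    return problems
```

On Colab, `check_kaggle_token(os.path.expanduser("~/.kaggle/kaggle.json"))` should return an empty list once the copy and `chmod` steps have run.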

In [8]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Install the kaggle library
!pip install kaggle

# Create the .kaggle folder in the (temporary) home directory
!mkdir ~/.kaggle

# Copy kaggle.json (the API token file) from MyDrive into the .kaggle folder
!cp /content/drive/MyDrive/Kaggle/kaggle.json ~/.kaggle/

# Restrict access permissions on the token
!chmod 600 ~/.kaggle/kaggle.json

# NOTE: duplicate of the mkdir above; it only prints "File exists" and can be removed
!mkdir ~/.kaggle

# Download the zip file
!kaggle datasets download -d mateuszbuda/lgg-mri-segmentation
#!kaggle competitions download -c rsna-2022-cervical-spine-fracture-detection -p /content/drive/MyDrive/Kaggle
# Unzip (on a re-run unzip asks before overwriting; add -o to overwrite without prompting)
!unzip ./lgg-mri-segmentation.zip -d ./

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
mkdir: cannot create directory ‘/root/.kaggle’: File exists
mkdir: cannot create directory ‘/root/.kaggle’: File exists
lgg-mri-segmentation.zip: Skipping, found more recently modified local copy (use --force to force download)
Archive:  ./lgg-mri-segmentation.zip
replace ./kaggle_3m/README.md? [y]es, [n]o, [A]ll, [N]one, [r]ename: 
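
The `unzip` step stops at an interactive overwrite prompt when the archive has already been extracted, as in the output above. One way to make the step re-runnable is to extract from Python and skip the work when the target already exists. This is a sketch using only the standard library. `extract_once` and its marker-folder convention are assumptions for illustration.

```python
import os
import zipfile


def extract_once(zip_path, dest_dir, marker):
    """Extract zip_path into dest_dir unless dest_dir/marker already exists.

    Returns True if extraction ran, False if it was skipped.
    """
    if os.path.exists(os.path.join(dest_dir, marker)):
        return False
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
    return True


# In this notebook the call would look like this (paths taken from the cell above):
# extract_once("./lgg-mri-segmentation.zip", "./", "kaggle_3m")
```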

# Start from here

In [None]:
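# Install the modules listed in the repository's requirements
# Note: the pinned versions conflict heavily, so no versions are specified
# Everything except medpy is already installed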

modules = """
numpy
tensorflow
scikit-learn
scikit-image
imageio
medpy
Pillow
scipy
pandas
tqdm
"""

with open("requirements.txt", mode="w") as f:
    f.write(modules)
!pip install -r requirements.txt

import argparse
import os

import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from medpy.filter.binary import largest_connected_component
from skimage.io import imsave
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import BrainSegmentationDataset as Dataset
from unet import UNet
from utils import dsc, gray2rgb, outline

import numpy as np
from PIL import Image
import glob


#**Modules**

In [5]:
def postprocess_per_volume(
    input_list, pred_list, true_list, patient_slice_index, patients
):
    volumes = {}
    num_slices = np.bincount([p[0] for p in patient_slice_index])  # count occurrences of each patient index, i.e. slices per patient
    index = 0
    for p in range(len(num_slices)):
        volume_in = np.array(input_list[index : index + num_slices[p]])
        volume_pred = np.round(
            np.array(pred_list[index : index + num_slices[p]])
        ).astype(int)
        volume_pred = largest_connected_component(volume_pred)
        volume_true = np.array(true_list[index : index + num_slices[p]])
        volumes[patients[p]] = (volume_in, volume_pred, volume_true)
        index += num_slices[p]
    return volumes

def dsc_distribution(volumes):
    dsc_dict = {}
    for p in volumes:
        y_pred = volumes[p][1]
        y_true = volumes[p][2]
        dsc_dict[p] = dsc(y_pred, y_true, lcc=False)
    return dsc_dict

def plot_dsc(dsc_dist):
    y_positions = np.arange(len(dsc_dist))
    dsc_dist = sorted(dsc_dist.items(), key=lambda x: x[1])
    values = [x[1] for x in dsc_dist]
    labels = [x[0] for x in dsc_dist]
    labels = ["_".join(l.split("_")[1:-1]) for l in labels]
    fig = plt.figure(figsize=(12, 8))
    canvas = FigureCanvasAgg(fig)
    plt.barh(y_positions, values, align="center", color="skyblue")
    plt.yticks(y_positions, labels)
    plt.xticks(np.arange(0.0, 1.0, 0.1))
    plt.xlim([0.0, 1.0])
    plt.gca().axvline(np.mean(values), color="tomato", linewidth=2)
    plt.gca().axvline(np.median(values), color="forestgreen", linewidth=2)
    plt.xlabel("Dice coefficient", fontsize="x-large")
    plt.gca().xaxis.grid(color="silver", alpha=0.5, linestyle="--", linewidth=1)
    plt.tight_layout()
    canvas.draw()
    plt.close()
    s, (width, height) = canvas.print_to_buffer()
    # np.fromstring is deprecated for binary data; np.frombuffer reads the same bytes
    return np.frombuffer(s, np.uint8).reshape((height, width, 4))

def outline(image, mask, color):
    mask = np.round(mask)
    yy, xx = np.nonzero(mask)
    for y, x in zip(yy, xx):
        if 0.0 < np.mean(mask[max(0, y - 1) : y + 2, max(0, x - 1) : x + 2]) < 1.0:
            image[max(0, y) : y + 1, max(0, x) : x + 1] = color
    return image
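
`postprocess_per_volume` and `dsc_distribution` rely on two ideas. `np.bincount` over the patient indices gives the number of slices per patient, which is enough to cut the flat list of slices back into volumes. The Dice coefficient is then computed once per volume. The following is a minimal NumPy sketch with toy data. Here `dice` is a simplified stand-in for `utils.dsc` (assumed to be 2·|A∩B| / (|A| + |B|) on binarized arrays), not the repository's implementation, and `split_per_patient` is a hypothetical helper.

```python
import numpy as np


def dice(y_pred, y_true):
    """Dice coefficient of two arrays after rounding to {0, 1}."""
    y_pred = np.round(y_pred).astype(int)
    y_true = np.round(y_true).astype(int)
    denom = y_pred.sum() + y_true.sum()
    if denom == 0:
        return 1.0  # both masks empty: counted as agreement (a convention chosen for this sketch)
    return 2.0 * np.sum(y_pred * y_true) / denom


def split_per_patient(slices, patient_slice_index):
    """Cut a flat list of slices into one array per patient."""
    num_slices = np.bincount([p for p, _ in patient_slice_index])
    volumes, start = [], 0
    for n in num_slices:
        volumes.append(np.array(slices[start:start + n]))
        start += n
    return volumes


# Toy example: patient 0 has 2 slices, patient 1 has 3 slices (each slice is 2x2)
index = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]
true = [np.ones((2, 2)) for _ in index]
pred = [np.ones((2, 2)) for _ in index]
pred[0] = np.zeros((2, 2))  # the first slice of patient 0 is predicted empty

true_vols = split_per_patient(true, index)
pred_vols = split_per_patient(pred, index)
scores = [dice(p, t) for p, t in zip(pred_vols, true_vols)]

print([v.shape for v in true_vols])          # [(2, 2, 2), (3, 2, 2)]
print([round(float(s), 3) for s in scores])  # [0.667, 1.0]
```

Because the score is computed per volume, one missed slice lowers the patient's Dice in proportion to its share of the total mask area rather than counting as a separate failure.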


import os
import random

import numpy as np
import torch
from skimage.io import imread
from skimage.exposure import rescale_intensity
from skimage.transform import resize
from torch.utils.data import Dataset

from utils import crop_sample, pad_sample, resize_sample, normalize_volume


# dataset (already defined in dataset.py, so it need not be written here; redefined for explanation)
class BrainSegmentationDataset(Dataset):
    """Brain MRI dataset for FLAIR abnormality segmentation"""

    in_channels = 3
    out_channels = 1

    def __init__(
        self,
        images_dir,
        transform=None,
        image_size=256,
        subset="train",
        random_sampling=True,
        validation_cases=10,
        seed=42,
    ):
        assert subset in ["all", "train", "validation"]  # any other subset value raises an AssertionError

        # read images
        volumes = {}
        masks = {}
        print("reading {} images...".format(subset))
        for (dirpath, dirnames, filenames) in os.walk(images_dir):  # dirpath: path of the folder being visited, filenames: names of the files in it
            image_slices = []
            mask_slices = []

            for filename in sorted(
                filter(lambda f: ".tif" in f, filenames),
                key=lambda x: int(x.split(".")[-2].split("_")[4]),
            ):  # sort the .tif files by slice number
                filepath = os.path.join(dirpath, filename)
                if "mask" in filename:  # mask files have "mask" in the file name
                    mask_slices.append(imread(filepath, as_gray=True))
                else:  # everything else is an image slice
                    image_slices.append(imread(filepath))
            if len(image_slices) > 0:
                patient_id = dirpath.split("/")[-1]  # e.g. TCGA_HT_8018_19970411
                volumes[patient_id] = np.array(image_slices[1:-1])  # volumes: images without masks, RGB; the first and last slices are dropped (possibly because they hold other information?)
                masks[patient_id] = np.array(mask_slices[1:-1])  # masks: mask images, grayscale

        print(len(volumes))
        print(len(masks))

        self.patients = sorted(volumes)  # sorted list of patient IDs

        # select cases to subset
        if not subset == "all":
            random.seed(seed)
            validation_patients = random.sample(self.patients, k=validation_cases)  # randomly draw validation_cases patients
            if subset == "validation":
                self.patients = validation_patients
            else:
                self.patients = sorted(
                    list(set(self.patients).difference(validation_patients))
                )

        print("preprocessing {} volumes...".format(subset))
        
        # create list of tuples (volume, mask)
        self.volumes = [(volumes[k], masks[k]) for k in self.patients]  # one (volume, mask) pair per patient


        ##############################################################
        # print(f"volumes: {volumes[self.patients[2]].shape}") #(18,256,256,3) --> slices, height, width, RGB
        # print(f"masks: {masks[self.patients[2]].shape}") #(18,256,256)
        ##############################################################
        ##################################################################
        print(f"normalize_volume: {self.volumes[2][0].shape}")  # shape of the volume for the case at index 2, before any preprocessing: (18,256,256,3) --> slices, height, width, RGB
        print(f"normalize_masks: {self.volumes[2][1].shape}")  # shape of the masks for the case at index 2: (18,256,256)
        #################################################################

        print("cropping {} volumes...".format(subset))
        # crop to smallest enclosing volume: the bounding box of the voxels at or above 10% of the maximum intensity (see crop_sample)
        self.volumes = [crop_sample(v) for v in self.volumes]


        print("padding {} volumes...".format(subset))
        # pad to square (zero-pad non-square slices)
        self.volumes = [pad_sample(v) for v in self.volumes]

        print("resizing {} volumes...".format(subset))
        # resize to image_size x image_size (256 x 256 by default)
        self.volumes = [resize_sample(v, size=image_size) for v in self.volumes]

        print("normalizing {} volumes...".format(subset))
        # normalize channel-wise
        self.volumes = [(normalize_volume(v), m) for v, m in self.volumes]  # v: volume, m: mask; intensities are rescaled between the 10th and 99th percentiles, then standardized per channel

        ##################################################################
        print(f"normalized_volume_shape: {self.volumes[0][0].shape}")  # shape of the volume for case 0, e.g. (28,256,256,3)
        print(f"normalized_masks_shape: {self.volumes[0][1].shape}")  # shape of the masks for case 0, e.g. (28,256,256)
        #################################################################


        # probabilities for sampling slices based on masks
        self.slice_weights = [m.sum(axis=-1).sum(axis=-1) for v, m in self.volumes]  # mask area per slice; one array per patient (10 for validation)
        print(len(self.slice_weights))
        self.slice_weights = [
            (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1) for s in self.slice_weights  # normalized so that each patient's weights sum to 1
        ]
        print(f"n_slice_weights: {len(self.slice_weights)}")


        # add channel dimension to masks
        self.volumes = [(v, m[..., np.newaxis]) for (v, m) in self.volumes]
        ##################################################################
        print(f"final_volume_shape: {self.volumes[0][0].shape}")  # shape of the volume for case 0, e.g. (28,256,256,3) --> slices, height, width, RGB
        print(f"final_masks_shape: {self.volumes[0][1].shape}")  # shape of the masks for case 0, e.g. (28,256,256,1) <-- trailing dimension of size 1 added
        #################################################################
        #################################
        print(len(self.volumes))  # 10 for the validation subset
        #################################
        #################################

        print("done creating {} dataset".format(subset))

        # create global index for patient and slice (idx -> (p_idx, s_idx))
        # e.g. [(0, 0), (0, 1), (0, 2), ..., (1, 0), (1, 1), (1, 2), ...]
        num_slices = [v.shape[0] for v, m in self.volumes]  # number of slices per patient
        self.patient_slice_index = list(
            zip(
                sum([[i] * num_slices[i] for i in range(len(num_slices))], []), 
                sum([list(range(x)) for x in num_slices], []), 
            )
        )

        self.random_sampling = random_sampling

        self.transform = transform

    def __len__(self):
        return len(self.patient_slice_index)

    def __getitem__(self, idx):
        patient = self.patient_slice_index[idx][0]
        slice_n = self.patient_slice_index[idx][1]
        #print(f"idx:{idx}, patient:{patient}, slice_n:{slice_n}")

        if self.random_sampling:
            patient = np.random.randint(len(self.volumes))
            slice_n = np.random.choice(
                range(self.volumes[patient][0].shape[0]), p=self.slice_weights[patient]
            )
 
        v, m = self.volumes[patient]
        image = v[slice_n] #(256,256,3)
        mask = m[slice_n] #(256,256,1)

        if self.transform is not None:
            image, mask = self.transform((image, mask))

        # fix dimensions (C, H, W)
        image = image.transpose(2, 0, 1)
        mask = mask.transpose(2, 0, 1)

        image_tensor = torch.from_numpy(image.astype(np.float32))
        mask_tensor = torch.from_numpy(mask.astype(np.float32))

        # return tensors
        return image_tensor, mask_tensor
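
With `random_sampling=True`, `__getitem__` ignores `idx`: it draws a random patient and then a slice with the probabilities stored in `slice_weights`. Those weights are the per-slice mask areas with a uniform floor added, so slices without any mask can still be drawn. Writing `s_i` for the mask area of slice `i`, `S` for their sum and `n` for the number of slices, the expression is `w_i = (s_i + 0.1 * S / n) / (1.1 * S)`, which sums to 1. A small NumPy check with made-up areas:

```python
import numpy as np

# Mask area (number of foreground pixels) per slice for one toy patient
s = np.array([0.0, 10.0, 30.0, 0.0])

# Same expression as in BrainSegmentationDataset.__init__
w = (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1)

print(w)        # equals [1/44, 11/44, 31/44, 1/44]: empty slices keep a small probability
print(w.sum())  # 1 up to floating-point error

# Draw slice indices the way __getitem__ does
draws = np.random.choice(range(len(s)), size=1000, p=w)
```

Note that if a patient has no mask pixels at all, `S` is 0 and the expression divides by zero, so the weights become NaN for that patient.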


# Excerpted from utils below (these are imported from utils, so they do not need to be defined here)
def crop_sample(x):
    volume, mask = x
    volume[volume < np.max(volume) * 0.1] = 0  # zero out voxels below 10% of the maximum intensity
    z_projection = np.max(np.max(np.max(volume, axis=-1), axis=-1), axis=-1)
    z_nonzero = np.nonzero(z_projection)
    z_min = np.min(z_nonzero)
    z_max = np.max(z_nonzero) + 1
    y_projection = np.max(np.max(np.max(volume, axis=0), axis=-1), axis=-1)
    y_nonzero = np.nonzero(y_projection)
    y_min = np.min(y_nonzero)
    y_max = np.max(y_nonzero) + 1
    x_projection = np.max(np.max(np.max(volume, axis=0), axis=0), axis=-1)
    x_nonzero = np.nonzero(x_projection)
    x_min = np.min(x_nonzero)
    x_max = np.max(x_nonzero) + 1
    return (
        volume[z_min:z_max, y_min:y_max, x_min:x_max],
        mask[z_min:z_max, y_min:y_max, x_min:x_max],
    )

def pad_sample(x):  # zero-pad non-square slices to a square
    volume, mask = x
    a = volume.shape[1]
    b = volume.shape[2]
    if a == b:
        return volume, mask
    diff = (max(a, b) - min(a, b)) / 2.0
    if a > b:
        padding = ((0, 0), (0, 0), (int(np.floor(diff)), int(np.ceil(diff))))
    else:
        padding = ((0, 0), (int(np.floor(diff)), int(np.ceil(diff))), (0, 0))
    mask = np.pad(mask, padding, mode="constant", constant_values=0)
    padding = padding + ((0, 0),)
    volume = np.pad(volume, padding, mode="constant", constant_values=0)
    return volume, mask

def normalize_volume(volume):
    p10 = np.percentile(volume, 10)
    p99 = np.percentile(volume, 99)
    volume = rescale_intensity(volume, in_range=(p10, p99))  # clip intensities to the 10th-99th percentile range and rescale with skimage
    m = np.mean(volume, axis=(0, 1, 2))
    s = np.std(volume, axis=(0, 1, 2))
    volume = (volume - m) / s
    return volume

def resize_sample(x, size=256):  # resize each slice to size x size (default 256) with skimage.transform.resize
    volume, mask = x
    v_shape = volume.shape
    out_shape = (v_shape[0], size, size)
    mask = resize(
        mask,
        output_shape=out_shape,
        order=0,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    ) #order=0: nearest neighbor
    out_shape = out_shape + (v_shape[3],)
    volume = resize(
        volume,
        output_shape=out_shape,
        order=2,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    ) #order=2: bi-quadratic
    return volume, mask
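
`crop_sample` is the least obvious of these helpers. It zeroes every voxel below 10% of the volume maximum and then keeps only the bounding box of what remains along the slice, row and column axes. The sketch below shows the same idea on a small array. `bounding_box_crop` is a hypothetical, simplified helper that handles a single-channel volume only, whereas `crop_sample` also carries the channel axis and the mask.

```python
import numpy as np


def bounding_box_crop(volume, rel_threshold=0.1):
    """Crop a (Z, Y, X) volume to the box enclosing voxels >= rel_threshold * max."""
    volume = volume.copy()  # crop_sample modifies its input in place; this sketch does not
    volume[volume < volume.max() * rel_threshold] = 0
    box = []
    for axis in range(volume.ndim):
        # Collapse all other axes: profile is non-zero wherever this axis has signal
        other_axes = tuple(a for a in range(volume.ndim) if a != axis)
        profile = volume.max(axis=other_axes)
        nonzero = np.nonzero(profile)[0]
        box.append(slice(int(nonzero.min()), int(nonzero.max()) + 1))
    return volume[tuple(box)], tuple(box)


vol = np.zeros((4, 6, 6))
vol[1:3, 2:5, 1:4] = 100.0  # the "brain"
vol[0, 0, 0] = 5.0          # faint background noise, below 10% of the maximum

cropped, box = bounding_box_crop(vol)
print(cropped.shape)  # (2, 3, 3)
```

The faint voxel at the corner is removed by the threshold, so it does not enlarge the box. This is why thresholding happens before the projections are taken.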


#**Inference**

In [None]:
# Importing this makes the class definition above unnecessary
#from dataset import BrainSegmentationDataset as Dataset


#############################
#weights_dir = "./weights/unet.pt"  # where the weights are stored
weights_dir = "/content/drive/MyDrive/Kaggle/Brain_segmentation_3dUNET/unet.pt"  # use the model trained in the Train section below
#############################

# Create the predictions folder
os.makedirs("./predictions", exist_ok=True)

# Define the device
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

# Dataset and data loader
dataset = BrainSegmentationDataset(
    images_dir="/content/brain-segmentation-pytorch/kaggle_3m",  # uses the Kaggle dataset
    subset="validation",
    image_size=256,
    random_sampling=False,
)
loader = DataLoader(
    dataset, batch_size=32, drop_last=False, num_workers=0
)

with torch.set_grad_enabled(False):
    unet = UNet(in_channels=BrainSegmentationDataset.in_channels, out_channels=BrainSegmentationDataset.out_channels)
    
    # Load the model weights
    state_dict = torch.load(weights_dir, map_location=device)
    unet.load_state_dict(state_dict)

    unet.eval()
    unet.to(device)

    input_list = []
    pred_list = []
    true_list = []

    for i, data in tqdm(enumerate(loader)):
        x, y_true = data
        x, y_true = x.to(device), y_true.to(device)

        y_pred = unet(x)
        y_pred_np = y_pred.detach().cpu().numpy()
        pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])

        y_true_np = y_true.detach().cpu().numpy()
        true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])

        x_np = x.detach().cpu().numpy()
        input_list.extend([x_np[s] for s in range(x_np.shape[0])])

    volumes = postprocess_per_volume(
        input_list,
        pred_list,
        true_list,
        loader.dataset.patient_slice_index,
        loader.dataset.patients,
    )

    dsc_dist = dsc_distribution(volumes)

    dsc_dist_plot = plot_dsc(dsc_dist)
    imsave("./dsc.png", dsc_dist_plot)

    for p in volumes:
        x = volumes[p][0]
        y_pred = volumes[p][1]
        y_true = volumes[p][2]
        for s in range(x.shape[0]):
            image = gray2rgb(x[s, 1])  # channel 1 is for FLAIR
            image = outline(image, y_pred[s, 0], color=[255, 0, 0])  # red: prediction
            image = outline(image, y_true[s, 0], color=[0, 255, 0])  # green: ground truth
            filename = "{}-{}.png".format(p, str(s).zfill(2))
            filepath = os.path.join("./predictions", filename)
            imsave(filepath, image)

reading validation images...
110
110
preprocessing validation volumes...
normalize_volume: (18, 256, 256, 3)
normalize_masks: (18, 256, 256)
cropping validation volumes...
padding validation volumes...
resizing validation volumes...
normalizing validation volumes...
normalized_volume_shape: (28, 256, 256, 3)
normalized_masks_shape: (28, 256, 256)
10
n_slice_weights: 10
final_volume_shape: (28, 256, 256, 3)
final_masks_shape: (28, 256, 256, 1)
10
done creating validation dataset


11it [00:08,  1.27it/s]
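
The overlay images are drawn by `outline`, which treats a mask pixel as boundary when its 3x3 neighbourhood is neither all foreground nor all background, that is, when the neighbourhood mean lies strictly between 0 and 1. A toy check of that rule on a 5x5 mask follows. The function body repeats the logic of the `outline` defined earlier in this notebook, with the single-pixel assignment written as `image[y, x]`.

```python
import numpy as np


def outline(image, mask, color):
    """Paint the boundary pixels of mask onto image."""
    mask = np.round(mask)
    yy, xx = np.nonzero(mask)
    for y, x in zip(yy, xx):
        # Interior pixels have a neighbourhood mean of exactly 1.0 and are skipped
        if 0.0 < np.mean(mask[max(0, y - 1): y + 2, max(0, x - 1): x + 2]) < 1.0:
            image[y, x] = color
    return image


mask = np.zeros((5, 5))
mask[1:4, 1:4] = 1  # a 3x3 square of foreground
image = np.zeros((5, 5, 3), dtype=np.uint8)

image = outline(image, mask, color=[255, 0, 0])
print(int((image[..., 0] == 255).sum()))  # 8: the ring around the single interior pixel
```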


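### Display the prediction results ###

Show the images (up to the first 90) in a 3-column grid

In [None]:
# Display the prediction results (red: prediction, green: ground truth)
images = [Image.open(img) for img in sorted(glob.glob("./predictions/*"))[0:90]]

cols = 3
rows = len(images)//cols+1  # number of rows in the grid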


fig = plt.figure(figsize=(cols*5, rows*5))


for i, im in enumerate(images):
    fig.add_subplot(rows, cols, i+1).set_title(str(i+1))
    plt.imshow(im)

plt.show()

#**Train**

Note: the logger would have to be rewritten from TensorFlow 1 to TensorFlow 2, so it is omitted here.
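
If scalar logging is still wanted without TensorFlow, a small stand-in can mimic the one method the commented-out calls use, `scalar_summary(tag, value, step)`. The class below is a hypothetical replacement that appends rows to a CSV file. It is not the repository's `Logger` and does not cover `image_list_summary`.

```python
import csv
import os


class CSVLogger:
    """Append (step, tag, value) rows to <log_dir>/scalars.csv."""

    def __init__(self, log_dir):
        os.makedirs(log_dir, exist_ok=True)
        self.path = os.path.join(log_dir, "scalars.csv")
        if not os.path.exists(self.path):
            # Write the header once, when the file is first created
            with open(self.path, "w", newline="") as f:
                csv.writer(f).writerow(["step", "tag", "value"])

    def scalar_summary(self, tag, value, step):
        with open(self.path, "a", newline="") as f:
            csv.writer(f).writerow([step, tag, float(value)])
```

With this in place, `logger = CSVLogger("./logs")` could be used where the training cell has `#logger = Logger("./logs")`, and a call such as `logger.scalar_summary("val_dsc", mean_dsc, step)` would add one row per validation pass.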

In [11]:
import json
import os

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

#from dataset import BrainSegmentationDataset as Dataset  # not imported because the class is redefined above
#from logger import Logger
from loss import DiceLoss
from transform import transforms
from unet import UNet
from utils import log_images, dsc

from statistics import mean
import time



def dsc_per_volume(validation_pred, validation_true, patient_slice_index):
    dsc_list = []
    num_slices = np.bincount([p[0] for p in patient_slice_index])  # count occurrences of each patient index, i.e. the number of slices per patient
    index = 0
    for p in range(len(num_slices)):
        y_pred = np.array(validation_pred[index : index + num_slices[p]])
        y_true = np.array(validation_true[index : index + num_slices[p]])
        dsc_list.append(dsc(y_pred, y_true))
        index += num_slices[p]
    return dsc_list


def log_loss_summary(logger, loss, step, prefix=""):
    logger.scalar_summary(prefix + "loss", np.mean(loss), step)



#main

# Start timing
time_start = time.perf_counter()

# Random seed (np.random.seed must be called; assigning to it would replace the function)
np.random.seed(42)

#parameters
batch_size = 8
image_size = 256
num_workers = 0
lr = 0.0001
n_epochs = 100
vis_freq = 10 #frequency of saving images to log file
vis_images = 200 #number of visualization images to save in log file
#weight_path = "./weights"
weights_dir = "/content/drive/MyDrive/Kaggle/Brain_segmentation_3dUNET"
load_weight = False  # whether to load the weights saved on Google Drive
earlystopping = 10  # set to 0 to turn early stopping off

# Create folders
os.makedirs("./weights", exist_ok=True)
os.makedirs("./logs", exist_ok=True)

# snapshotargs()

# Define the device
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")


if 'loader_valid' in globals():  # skip if the data loaders already exist (loading takes a long time)
    pass
else:
    #dataset
    dataset_train = BrainSegmentationDataset(
        images_dir="/content/brain-segmentation-pytorch/kaggle_3m",
        subset="train",
        image_size=image_size,
        transform=transforms(scale=0.05, angle=15, flip_prob=0.5),  # scale, angle: scaling factor and rotation angle used for augmentation
    )
    dataset_valid = BrainSegmentationDataset(
        images_dir="/content/brain-segmentation-pytorch/kaggle_3m",
        subset="validation",
        image_size=image_size,
        random_sampling=False,
    )

    #dataloader
    loader_train = DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        worker_init_fn= None,
    )
    loader_valid = DataLoader(
        dataset_valid,
        batch_size=batch_size,
        drop_last=False,
        num_workers=num_workers,
        worker_init_fn= None,
    )
    loaders = {"train": loader_train, "valid": loader_valid}

    print(f"elapsed_time: {time.perf_counter() - time_start}")
    print("")

#model
unet = UNet(in_channels=BrainSegmentationDataset.in_channels, out_channels=BrainSegmentationDataset.out_channels)
unet.to(device)

if load_weight:
    unet.load_state_dict(torch.load(os.path.join(weights_dir, "unet.pt"), map_location=device))
    print("loading weight...")

dsc_loss = DiceLoss()
best_validation_dsc = 0.0

optimizer = optim.Adam(unet.parameters(), lr=lr)

#logger = Logger("./logs")
loss_train = []
loss_valid = []

step = 0
earlystopping_counter = 0
time_start = time.perf_counter()  # start timing
for epoch in tqdm(range(n_epochs), total=n_epochs): 
    for phase in ["train", "valid"]:
        if phase == "train":
            unet.train()
        else:
            unet.eval()

        validation_pred = []
        validation_true = []

        for i, data in enumerate(loaders[phase]):
            if phase == "train":
                step += 1

            x, y_true = data  # x: input images, y_true: label (mask) images
            x, y_true = x.to(device), y_true.to(device)

            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == "train"):
                y_pred = unet(x)

                loss = dsc_loss(y_pred, y_true)

                if phase == "valid":
                    loss_valid.append(loss.item())
                    val_loss = loss.item()  # for the progress printout
                    y_pred_np = y_pred.detach().cpu().numpy()
                    validation_pred.extend(
                        [y_pred_np[s] for s in range(y_pred_np.shape[0])]
                    )
                    y_true_np = y_true.detach().cpu().numpy()
                    validation_true.extend(
                        [y_true_np[s] for s in range(y_true_np.shape[0])]
                    )
                    if (epoch % vis_freq == 0) or (epoch == n_epochs - 1):
                        if i * batch_size < vis_images:
                            tag = "image/{}".format(i)
                            num_images = vis_images - i * batch_size
                            # logger.image_list_summary(
                            #     tag,
                            #     log_images(x, y_true, y_pred)[:num_images],
                            #     step,
                            # )

                if phase == "train":
                    loss_train.append(loss.item())
                    train_loss = loss.item()  # for the progress printout
                    loss.backward()
                    optimizer.step()

            if phase == "train" and (step + 1) % 10 == 0:
                #log_loss_summary(logger, loss_train, step)
                loss_train = []

        if phase == "valid":
            #log_loss_summary(logger, loss_valid, step, prefix="val_")
            mean_dsc = np.mean(
                dsc_per_volume(
                    validation_pred,
                    validation_true,
                    loader_valid.dataset.patient_slice_index,
                )
            )
            # logger.scalar_summary("val_dsc", mean_dsc, step)
            
            #### progress ####
            print("")
            print(f"epoch: {str(epoch+1)}")
            print(f"train_loss: {train_loss:.5f}")
            print(f"val_loss: {val_loss:.5f}")
            print(f"val_dsc: {mean_dsc:.5f}") 
            print(f"elapsed_time: {time.perf_counter() - time_start:.5f}")
            
            if mean_dsc > best_validation_dsc:
                print(f'mean_dsc increased ({best_validation_dsc:5f} --> {mean_dsc:5f}). Saving model...')
                best_validation_dsc = mean_dsc
                torch.save(unet.state_dict(), os.path.join(weights_dir, "unet.pt"))
                earlystopping_counter = 0 #reset earlystopping
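            else:
                earlystopping_counter += 1
                if earlystopping >= 1:
                    print(f"earlystopping_counter: {earlystopping_counter}")
            print("")
            loss_valid = []

    # A break inside the phase loop would only leave that loop, so the stop condition is checked here, at epoch level
    if earlystopping >= 1 and earlystopping_counter >= earlystopping:
        print("The training stopped with earlystopping!")
        break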

print("Best validation mean DSC: {:4f}".format(best_validation_dsc))



  1%|          | 1/100 [01:36<2:39:50, 96.87s/it]


epoch: 1
train_loss: 0.78171
val_loss: 0.90799
val_dsc: 0.59761
elapsed_time: 96.75017
mean_dsc increased (0.000000 --> 0.597606). Saving model...



  2%|▏         | 2/100 [03:12<2:37:05, 96.18s/it]


epoch: 2
train_loss: 0.78682
val_loss: 0.87244
val_dsc: 0.71170
elapsed_time: 192.45216
mean_dsc increased (0.597606 --> 0.711697). Saving model...



  3%|▎         | 3/100 [04:47<2:34:27, 95.54s/it]


epoch: 3
train_loss: 0.47919
val_loss: 0.80471
val_dsc: 0.77219
elapsed_time: 287.22216
mean_dsc increased (0.711697 --> 0.772185). Saving model...



  4%|▍         | 4/100 [06:22<2:32:41, 95.43s/it]


epoch: 4
train_loss: 0.45336
val_loss: 0.68390
val_dsc: 0.78878
elapsed_time: 382.49654
mean_dsc increased (0.772185 --> 0.788783). Saving model...



  5%|▌         | 5/100 [07:59<2:31:41, 95.81s/it]


epoch: 5
train_loss: 0.26241
val_loss: 0.52789
val_dsc: 0.80667
elapsed_time: 478.97162
mean_dsc increased (0.788783 --> 0.806667). Saving model...



  5%|▌         | 5/100 [08:45<2:46:24, 105.10s/it]


KeyboardInterrupt: ignored
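
Early stopping in a loop with nested phases needs some care. In Python, `break` leaves only the innermost loop, so a check placed inside the `for phase in [...]` loop cannot end the `for epoch` loop by itself. The following is a stripped-down sketch of patience-based early stopping. The function name and the list of validation scores are made up for illustration, and no model is trained.

```python
def run_with_early_stopping(val_scores, patience):
    """Return (epochs_run, best_score); patience=0 disables early stopping."""
    best = 0.0
    counter = 0
    epochs_run = 0
    for epoch, score in enumerate(val_scores):
        epochs_run = epoch + 1
        for phase in ["train", "valid"]:
            if phase == "valid":
                if score > best:
                    best, counter = score, 0  # improvement: reset the counter
                else:
                    counter += 1
        # Checked at epoch level so that break ends the epoch loop
        if patience >= 1 and counter >= patience:
            break
    return epochs_run, best


scores = [0.60, 0.71, 0.70, 0.69, 0.72, 0.71]
print(run_with_early_stopping(scores, patience=2))  # (4, 0.71)
print(run_with_early_stopping(scores, patience=0))  # (6, 0.72)
```

With `patience=2` the run stops after two consecutive epochs without improvement and never reaches the later 0.72. That is the usual trade-off of a short patience.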