In [None]:
import argparse
import os

import lib.medloaders as medical_loaders
import lib.medzoo as medzoo
# Lib files
import lib.utils as utils
from lib.utils.early_stopping import EarlyStopping
from lib.utils.general import prepare_input
from lib.losses3D import DiceLoss, create_loss, BCEDiceLoss

import torch
import numpy as np

from easydict import EasyDict

from lib.metric3D.DiceCoefficient import DiceCoefficient
from lib.metric3D.MeanIoU import MeanIoU

from tqdm import tqdm

import time

import shutil

import matplotlib.pyplot as plt

In [None]:
# Run configuration; EasyDict allows attribute access (args.lr) on top of a plain dict.
args = EasyDict({
    "batch_size" : 1,
    "dataset_name" : "miccai2008-mslesions-t2w",
    "dim" : (208,224,208),
    "nEpochs" : 100,
    "classes" : 1,
    "split" : (0.8,0.2,0.0),
    "inChannels" : 1,
    "inModalities" : 1,
    "fold_id" : '1',
    "lr" : 1e-3,
    "cuda" : True,
    "resume" : '',
    "model" : 'NESTEDDENSEUNET3D', # VNET VECT2 UNET3D DENSENET1 DENSENET2 DENSENET3 HYPERDENSENET DENSEUNET3D NESTEDUNET3D NESTEDDENSEUNET3D
    "opt" : 'adam', # sgd adam rmsprop
    "log_dir" : 'runs',
    "loadData" : False,
    "terminal_show_freq" : 10,
    "channel" : "Flair",
})

# One local-time stamp (e.g. "20220921-111810") tags every output directory of this run.
start_time = time.strftime("%Y%m%d-%H%M%S")
run_tag = f'{args.dataset_name}-{start_time}'
model_run_tag = f'{args.model}-{run_tag}'

# Output locations, assigned in a fixed order so the saved argument dump stays stable.
args.result_path = f'results/{run_tag}'
args.save = f'saved_models/{args.model}_checkpoints/{model_run_tag}'
args.save_checkpoint = os.path.join(args.save, 'checkpoint.pt')
args.tb_log_dir = f'{args.log_dir}/{model_run_tag}'

# Start from an empty results directory, then create it and the TensorBoard directory.
shutil.rmtree(args.result_path, ignore_errors=True)
for run_dir in (args.result_path, args.tb_log_dir):
    utils.make_dirs(run_dir)

In [None]:
# Expose only GPU 0 to this process. This only takes effect if CUDA has not yet been
# initialised in the kernel, so it must run before anything touches the GPU.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

# Reproducibility: fixed seed handed to the project helper, then persist the run
# arguments into the checkpoint directory.
seed = 1777777
utils.reproducibility(args, seed)
utils.make_dirs(args.save)
utils.save_arguments(args, args.save)

# TensorBoard

In [None]:
from torch.utils.tensorboard import SummaryWriter

# SummaryWriter would default to ./runs/<datetime>_<host>; log under this run's directory instead.
writer = SummaryWriter(log_dir=f'{args.tb_log_dir}/log')

# Augmentation

In [None]:
import torchio as tio

from torchio.transforms import (
    RandomFlip,
    RandomAffine,
    RandomElasticDeformation, 
    RandomNoise,
    RandomMotion,
    RandomBiasField,
    RescaleIntensity,
    Resample,
    ToCanonical,
    ZNormalization,
    CropOrPad,
    HistogramStandardization,
    OneOf,
    Compose,
)

### Full augmentation (active) ###
# Crop/pad to the network input size (args.dim), apply random acquisition artefacts and
# geometric perturbations, then z-score normalise the intensities.
transform = Compose([
    CropOrPad(args.dim),
    RandomMotion(p=0.2),
    RandomBiasField(p=0.3),
    RandomNoise(p=0.5),
    RandomFlip(axes=(0,)),          # flip candidates restricted to the first spatial axis
    RandomAffine(p=0.5),
    RandomElasticDeformation(p=0.5),
    ZNormalization(),
])

### Selective augmentation (alternative, disabled) ###
# transform = tio.Compose([
#     CropOrPad((144,144,144)),
#     # RandomMotion(p=0.2),
#     # RandomBiasField(p=0.3),
#     RandomNoise(p=0.5),
#     RandomFlip(axes=(0,)),
#     # RandomAffine(p=0.5),
#     RandomElasticDeformation(p=0.5),
#     ZNormalization(),
# ])

# Validation is deterministic: same crop/pad and normalisation, no random transforms.
validation_transform = Compose([
    CropOrPad(args.dim),
    ZNormalization()
])


# Dataset

In [None]:
from lib.medloaders.miccai_2008_ms_lesions import MICCAI2008MSLESIONS
from torch.utils.data import DataLoader

# Local root of the MICCAI 2008 MS-lesion challenge data.
DATASET_ROOT = r'D:\MS-Lesion-Dataset\MS_Lesion_Challenge'

train_dataset = MICCAI2008MSLESIONS(train_mode='train',
                                dataset_path=DATASET_ROOT,
                                classes=args.classes,
                                # channel=args.channel,
                                crop_dim=args.dim,
                                split=args.split,
                                transform=transform,
                                sample_per_image=10)

val_dataset = MICCAI2008MSLESIONS(train_mode='val',
                                dataset_path=DATASET_ROOT,
                                classes=args.classes,
                                # channel=args.channel,
                                crop_dim=args.dim,
                                split=args.split,
                                transform=validation_transform,
                                sample_per_image=1)

params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': 4,
        'prefetch_factor' : 2
}

# Training must be shuffled: the previous shuffle=False fed samples in a fixed order every
# epoch (with sample_per_image=10, presumably the same subject back-to-back — verify in the
# loader). Validation keeps a fixed order so the per-epoch figures remain comparable.
train_generator = DataLoader(train_dataset, **{**params, 'shuffle': True})
val_generator = DataLoader(val_dataset, **params)

In [None]:
# The dataset index (``dataset.list``) appears to hold absolute paths from the machine it was
# built on; re-root each (flair, mprage, pdw, t2w, mask) path onto ``dataset.dataset_path``.
pre_dataset_path = r'C:\Users\VIP444\Documents\MS-Lesion-Dataset\MS_Lesion_Challenge'

for ds in (train_dataset, val_dataset):
    # Only paths that still contain the old root are rewritten, so re-running this cell is a
    # no-op instead of an IndexError. split(..., 1) keeps everything after the first match.
    # Slice assignment mutates the list object the dataset already holds.
    ds.list[:] = [
        tuple(
            ds.dataset_path + p.split(pre_dataset_path, 1)[1] if pre_dataset_path in p else p
            for p in entry
        )
        for entry in ds.list
    ]

# Model & Optimizer

In [None]:
from lib.medzoo.ResUnet3D import ResUNet3D
from lib.medzoo.DenseUnet3D import DenseUNet3D
from lib.medzoo.Nested_DenseUnet3D import NestedDenseUNet3D
import torch.optim as optim
import torchsummaryX

# Alternative architectures — swap one in to experiment:
# model = ResUNet3D(in_channels=args.inChannels, n_classes=args.classes, base_n_filter=8)
# model = DenseUNet3D(in_channels=args.inChannels, out_channels=args.classes)
# model = NestedUNet(in_ch=args.inChannels, out_ch=args.classes)
model = NestedDenseUNet3D(in_channels=args.inChannels, out_channels=args.classes)

# Move the model to its device *before* building the optimizer (torch.optim recommendation).
device = ('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Build the optimizer named by args.opt ('adam' | 'sgd' | 'rmsprop'); args.opt used to be
# ignored and Adam was always used. weight_decay=1e-10 is a negligible L2 penalty.
_OPTIMIZERS = {'adam': optim.Adam, 'sgd': optim.SGD, 'rmsprop': optim.RMSprop}
_opt_name = args.opt.lower()
if _opt_name not in _OPTIMIZERS:
    raise ValueError(f"Unsupported optimizer {args.opt!r}; expected one of {sorted(_OPTIMIZERS)}")
optimizer = _OPTIMIZERS[_opt_name](model.parameters(), lr=args.lr, weight_decay=1e-10)

# torchsummaryX.summary(model, torch.zeros((1,1,208,224,208)).to(device))

# Loss Function

In [None]:
# Combined BCE + Dice loss; presumably alpha weights the BCE term and beta the Dice term
# (see lib.losses3D.BCEDiceLoss) — both set to 1 here.
criterion = BCEDiceLoss(beta=1, alpha=1)

In [None]:
def _save_prediction_grid(prediction_images, slice_idx, save_path):
    """Save a (N x 3) grid of input / prediction / mask slices to ``save_path``.

    Args:
        prediction_images: list of ``[input, output, target]`` numpy arrays, each indexed
            as ``arr[0, 0, :, :, slice_idx]`` (first batch item, first channel).
        slice_idx: index along the last axis to display.
        save_path: PNG file to write.
    """
    columns = 3
    rows = len(prediction_images)
    if rows == 0:
        return  # nothing was validated; a 0-row subplot grid would raise
    fig, axes = plt.subplots(rows, columns, figsize=(3 * columns, 3 * rows), squeeze=False)
    titles = ("original Image", "Predicited Image", "Original Mask")
    for row_axes, volumes in zip(axes, prediction_images):
        for ax, volume, title in zip(row_axes, volumes, titles):
            ax.imshow(volume[0, 0, :, :, slice_idx], cmap='gray')
            ax.set_title(title)
            ax.axis('off')
    fig.savefig(save_path)
    plt.close(fig)  # release the figure; looping over epochs would otherwise accumulate them


def train_model(model, train_generator, val_generator, patience, n_epochs, metrics, tensorboard=None):
    """Train ``model`` with EarlyStopping checkpointing and return the best weights.

    Uses notebook globals: ``args`` (passed to ``prepare_input``; ``save_checkpoint`` and
    ``result_path`` locations), ``optimizer`` and ``criterion``.

    Args:
        model: network to train; updated in place and also returned.
        train_generator: training DataLoader; items are passed to ``prepare_input``.
        val_generator: validation DataLoader; items are passed to ``prepare_input``.
        patience: patience forwarded to ``EarlyStopping``.
        n_epochs: number of epochs to run (early stopping does not break the loop; see below).
        metrics: callables ``metric(prediction, target)`` returning a 1-element tensor and
            exposing a ``metric_name`` attribute.
        tensorboard: optional SummaryWriter. When given, per-epoch loss/metric scalars are
            written to it and it is closed at the end.

    Returns:
        ``(model, avg_train_losses, avg_valid_losses)``: the model reloaded from the last
        EarlyStopping checkpoint, and the per-epoch mean train / validation losses.
    """
    avg_train_losses = []
    avg_valid_losses = []

    early_stopping = EarlyStopping(patience=patience, verbose=True, path=args.save_checkpoint)

    width = len(str(n_epochs))
    for epoch in range(1, n_epochs + 1):
        # Fresh per-epoch accumulators.
        train_losses, valid_losses = [], []
        train_metrics = [[] for _ in metrics]
        val_metrics = [[] for _ in metrics]

        ###################
        # train the model #
        ###################
        model.train()
        progressbar = tqdm(train_generator)
        for batch, input_tuple in enumerate(progressbar, start=1):
            optimizer.zero_grad()

            input, target = prepare_input(input_tuple=input_tuple, args=args)
            # Kept from the original. NOTE(review): only needed if the model uses
            # re-entrant checkpointing, which requires an input that requires grad;
            # otherwise it just costs memory — confirm and drop.
            input.requires_grad = True

            output = model(input)
            # The loss is evaluated on CPU float32 tensors (same as .type(torch.FloatTensor)).
            output = output.float().cpu()
            target = target.float().cpu()

            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

            # Binarise a detached copy for the metrics. Previously only values > 0.5 were set
            # to 1 and the rest left as-is, so metrics saw a half-thresholded map.
            # NOTE(review): 0.5 assumes probabilities; if the model emits logits, adjust.
            prediction = (output.detach() > 0.5).float()
            for i, metric in enumerate(metrics):
                train_metrics[i].append(metric(prediction, target).item())

            postfix = {'loss': f'{loss.item():.5f}'}
            for i, metric in enumerate(metrics):
                postfix[metric.metric_name] = f'{train_metrics[i][-1]:.5f}'
            progressbar.set_postfix(postfix)
            progressbar.set_description(
                f'[{epoch:>{width}}/{n_epochs:>{width}}][{batch}/{len(train_generator)}]')

        ######################
        # validate the model #
        ######################
        model.eval()
        prediction_images = []
        with torch.no_grad():
            for input_tuple in val_generator:
                input, target = prepare_input(input_tuple=input_tuple, args=args)

                output = model(input).float().cpu()
                target = target.float().cpu()

                valid_losses.append(criterion(output, target).item())

                # Keep the raw (un-thresholded) prediction for the figures; nothing below
                # modifies `output` in place, so the numpy view stays raw.
                prediction_images.append([input.detach().cpu().numpy(),
                                          output.detach().numpy(),
                                          target.detach().numpy()])

                prediction = (output > 0.5).float()
                for i, metric in enumerate(metrics):
                    val_metrics[i].append(metric(prediction, target).item())

        # Two depth slices per epoch, each to its own file. Both used to be written to
        # '{epoch}.png', so the second figure silently overwrote the first.
        for slice_idx in (23, 80):
            _save_prediction_grid(prediction_images, slice_idx,
                                  f'{args.result_path}/{epoch}_slice{slice_idx}.png')

        # Per-epoch mean loss and metrics.
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        train_metric = {m.metric_name: np.average(v) for m, v in zip(metrics, train_metrics)}
        val_metric = {m.metric_name: np.average(v) for m, v in zip(metrics, val_metrics)}

        print_msg = (f'[{epoch:>{width}}/{n_epochs:>{width}}] '
                     f'loss: {train_loss:.5f} '
                     f'val_loss: {valid_loss:.5f} ')
        print_msg += ''.join(f'{key} : {value:.5f} ' for key, value in train_metric.items())
        print_msg += ''.join(f'val_{key} : {value:.5f} ' for key, value in val_metric.items())
        # This summary used to be built but never shown.
        print(print_msg)

        # Log to the writer that was passed in (not a notebook global).
        if tensorboard is not None:
            tensorboard.add_scalars("loss", {'train': train_loss, 'val': valid_loss}, epoch)
            for name in train_metric:
                tensorboard.add_scalars(name,
                                        {'train': train_metric[name], 'val': val_metric[name]},
                                        epoch)

        # EarlyStopping checks whether the validation loss decreased and, if so,
        # checkpoints the current model.
        early_stopping(valid_loss, model)

        # Stopping early is deliberately disabled; uncomment to honour `patience`.
        # if early_stopping.early_stop:
        #     print("Early stopping")
        #     break

    # Reload the best weights, i.e. the last checkpoint written by EarlyStopping.
    # Loading to CPU first avoids a second GPU copy; load_state_dict moves it onto the model.
    model.load_state_dict(torch.load(args.save_checkpoint, map_location='cpu'))
    if tensorboard is not None:
        tensorboard.close()

    return model, avg_train_losses, avg_valid_losses

# Metric

In [None]:
# Metrics evaluated each epoch; metric_name is the key used in the progress bar, the
# epoch summary and the TensorBoard tags.
meanIoU, diceCoefficient = MeanIoU(metric_name="MeanIoU"), DiceCoefficient(metric_name="DiCE")

In [None]:
# Run training; returns the model reloaded from the best checkpoint plus per-epoch mean losses.
model, train_loss, valid_loss = train_model(
    model,
    train_generator,
    val_generator,
    patience=5,
    n_epochs=args.nEpochs,
    metrics=[diceCoefficient, meanIoU],
    tensorboard=writer,
)

In [None]:
# Persist the trained model into *this* run's checkpoint directory (args.save). Previously it
# was written into a hard-coded directory from an earlier run (...-20220921-111810), using a
# non-raw '\model.pt' literal ('\m' is an invalid escape sequence).
# The state_dict is the portable artefact (reload via model.load_state_dict); model.pt pickles
# the whole module and needs the same class definitions importable when loading.
torch.save(model.state_dict(), os.path.join(args.save, 'model_state_dict.pt'))
torch.save(model, os.path.join(args.save, 'model.pt'))

In [None]:
# Load a previously saved, fully pickled model. One raw string replaces the old
# r'...' + '\model.pt' concatenation, whose '\m' is an invalid escape sequence.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects full
# pickled modules; prefer the state_dict route if you upgrade.
test_model = torch.load(r'C:\Users\VIP444\Documents\Github\MS-Lesions-Pytorch\saved_models\NESTEDDENSEUNET3D_checkpoints\NESTEDDENSEUNET3D-miccai2008-mslesions-flair-20220919-114041\model.pt')

In [None]:
import torch
import torchio as tio
from lib.medzoo.Unet3D import Unet3D
from lib.medzoo.ResUnet3D import ResUNet3D
from lib.medzoo.Nested_DenseUnet3D import NestedDenseUNet3D

# Weights (state_dict) of the FLAIR run to evaluate. A single raw string avoids the invalid
# '\m' escape that the old r'...' + '\model_state_dict.pt' concatenation relied on.
path = r'C:\Users\VIP444\Documents\Github\MS-Lesions-Pytorch\saved_models\NESTEDDENSEUNET3D_checkpoints\NESTEDDENSEUNET3D-miccai2008-mslesions-flair-20220919-114041\model_state_dict.pt'
# Checkpoint of an older run; not used in this cell.
check_point = r'C:\Users\VIP444\Documents\Github\MS-Lesions-Pytorch\saved_models\NESTEDDENSEUNET3D_checkpoints\NESTEDDENSEUNET3D-miccai2008-mslesions-flair-20220419-180313\checkpoint.pt'
test_model = NestedDenseUNet3D(in_channels=args.inChannels, out_channels=args.classes)

if torch.cuda.is_available():
    test_model = test_model.cuda()

# map_location='cpu' lets CUDA-saved weights deserialize on CPU-only machines;
# load_state_dict then copies them onto whatever device test_model is on.
test_model.load_state_dict(torch.load(path, map_location='cpu'))

In [None]:
from lib.medloaders.miccai_2008_ms_lesions import MICCAI2008MSLESIONS
from torch.utils.data import DataLoader

# Test-time preprocessing: crop/pad to the network input size and z-normalise; no augmentation.
test_transform = tio.Compose([tio.CropOrPad(args.dim), tio.ZNormalization()])

test_dataset = MICCAI2008MSLESIONS(
    train_mode='test',
    dataset_path=r'D:\MS-Lesion-Dataset\MS_Lesion_Challenge',
    classes=args.classes,
    crop_dim=args.dim,
    transform=test_transform,
)

# Fixed order, so the i-th prediction corresponds to test_dataset.list[i].
params = dict(batch_size=args.batch_size, shuffle=False, num_workers=4)

test_generator = DataLoader(test_dataset, **params)

In [None]:
# Display the first index entry of the test set (e.g. to check which path prefix it was built with).
test_dataset.list[0]

In [None]:
# Old root baked into the test index entries (flair, mprage, pdw, t2w).
pre_dataset_path = r'C:\Users\VIP444\Documents\MS-Lesions-Segmentation\MS-Lesions-Pytorch\datasets\MICCAI_2008_MS_Lesions'

# Re-root every path onto test_dataset.dataset_path. Paths that no longer contain the old root
# are left untouched, so re-running this cell is a no-op instead of an IndexError.
# split(..., 1) keeps everything after the first match. Slice assignment keeps the list object.
test_dataset.list[:] = [
    tuple(
        test_dataset.dataset_path + p.split(pre_dataset_path, 1)[1] if pre_dataset_path in p else p
        for p in entry
    )
    for entry in test_dataset.list
]

In [None]:
import os
import nibabel as nib
from lib import utils
import matplotlib.pyplot as plt

output_name = 'KIT'

test_result =r'C:\Users\VIP444\Desktop\test_result'
utils.make_dirs(test_result)

test_model.eval() # prep model for evaluation
prediction_images = []

# Send inputs to whatever device test_model lives on; it is only moved to CUDA when
# available, so an unconditional .cuda() would crash on CPU-only machines.
test_device = next(test_model.parameters()).device

# Crop back to the original volume size; only used by the commented-out step below. It has
# its own name so the training augmentation pipeline `transform` is not overwritten.
crop_transform = tio.CropOrPad(args.dim)

with torch.no_grad():
    for i, input_tuple in enumerate(test_generator):
        # Forward pass. Input layout: [batch, channel, width, height, depth], e.g. [1,1,208,224,208].
        input = input_tuple[0].to(test_device)
        target = input_tuple[1].to(test_device)
        # path = input_tuple[1][0]

        output = test_model(input)

        input = input.cpu()
        output = output.cpu()
        target = target.cpu()

        # Binarise: 1.0 where output >= 0.5, otherwise 0.0.
        output = (output >= 0.5).float()

        # Drop batch then channel dims: [1,1,X,Y,Z] -> [1,X,Y,Z] -> [X,Y,Z].
        output = torch.squeeze(output, 0)
        # crop [1,208,224,208] -> [1,181,217,181]
        # output = crop_transform(output)
        output = torch.squeeze(output, 0)
        target = torch.squeeze(torch.squeeze(target, 0), 0)

        # save prediction result
        prediction_images.append(
            [input.detach().numpy(), output.detach().numpy(), target.detach().numpy()])

        # path_basename = '_'.join(os.path.basename(path).split('_')[:2])
        # affine = nib.load(path).affine

        # nii_image = nib.Nifti1Image(output.detach().numpy(), test_dataset.affine)
        # nib.save(nii_image, f'{test_result}/{path_basename}_KIT_{i}.nii')

In [None]:
# One figure per test volume: input (left) and binarised prediction (right) at the middle
# slice of the last axis. Entries may also carry a target, which is ignored here. The old
# `for [_input, _output] in ...` failed on 3-item entries, and saving/clearing one tall
# rows x 2 figure inside the loop produced mostly empty images.
for idx, (_input, _output, *_extra) in enumerate(prediction_images):
    fig, (ax_input, ax_pred) = plt.subplots(1, 2, figsize=(3 * 2, 3))

    ax_input.imshow(_input[0, 0, :, :, _input.shape[-1] // 2], cmap='gray')
    ax_input.set_title("original Image"); ax_input.axis('off')

    ax_pred.imshow(_output[:, :, _output.shape[-1] // 2], cmap='gray')
    ax_pred.set_title("Predicited Image"); ax_pred.axis('off')

    fig.savefig(f'{test_result}/test_KIT_{idx}.png')
    plt.close(fig)