In [1]:
import argparse
import os

import lib.medloaders as medical_loaders
import lib.medzoo as medzoo
# Lib files
import lib.utils as utils
from lib.utils.early_stopping import EarlyStopping
from lib.utils.general import prepare_input
from lib.losses3D import DiceLoss, create_loss,BCEDiceLoss

import torch
import numpy as np

from easydict import EasyDict

from lib.metric3D.DiceCoefficient import DiceCoefficient
from lib.metric3D.MeanIoU import MeanIoU

from tqdm import tqdm

import time

import shutil

import matplotlib.pyplot as plt

In [2]:
args = EasyDict({
    "batchSz" : 1,
    "dataset_name" : "miccai2018",
    "dim" : (144,144,144),
    "nEpochs" : 50,
    "classes" : 1,
    "samples_train" : 100,
    "samples_val" : 10,
    "split" : 0.8,
    "inChannels" : 1,
    "inModalities" : 1,
    "fold_id" : '1',
    "lr" : 1e-4,
    "cuda" : True,
    "resume" : '',
    "model" : 'UNET3D', # VNET VECT2 UNET3D DENSENET1 DENSENET2 DENSENET3 HYPERDENSENET
    "opt" : 'adam', # sgd adam rmsprop
    "log_dir" : 'runs',
    "loadData" : False,
    "terminal_show_freq" : 10,
    "result_path" : f'results/{time.strftime("%Y%m%d-%H%M%S", time.localtime(time.time()))}'
})

shutil.rmtree(args.result_path, ignore_errors=True)
os.mkdir(args.result_path)

args.save = f'saved_models/{args.model}_checkpoints/{args.model}_{utils.datestr()}_{args.dataset_name}'
args.save_checkpoint = os.path.join(args.save,'checkpoint.pt')
args.tb_log_dir = f'runs/{args.model}_{utils.datestr()}_{args.dataset_name}'

shutil.rmtree(args.tb_log_dir, ignore_errors=True)
os.makedirs(args.tb_log_dir)

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
## FOR REPRODUCIBILITY OF RESULTS
seed = 1777777
utils.reproducibility(args, seed)

utils.make_dirs(args.save)
utils.save_arguments(args, args.save)

## Tensorboard 설정

In [4]:
from torch.utils.tensorboard import SummaryWriter

# 기본 `log_dir` 은 "runs"이며, 여기서는 더 구체적으로 지정하였습니다
writer = SummaryWriter(f'{args.tb_log_dir}/log')

In [5]:
# loss function
# criterion = create_loss('CrossEntropyLoss')
# criterion = BCEDiceLoss(classes=args.classes, weight=torch.tensor([0.1]).cuda())
criterion = BCEDiceLoss(alpha=1, beta=1)

# model & optimizer
model, optimizer = medzoo.create_model(args)

# model with cuda 
if args.cuda:
    model = model.cuda()
    print("Model transferred in GPU.....")

Building Model . . . . . . . .UNET3D
UNET3D Number of params: 1781192
Model transferred in GPU.....


In [6]:
def train_model(model, batch_size, train_generator, val_generator ,patience, n_epochs, metrics, tensorboard=None):

    # 모델이 학습되는 동안 trainning loss를 track
    train_losses = []
    # 모델이 학습되는 동안 validation loss를 track
    valid_losses = []
    # epoch당 average training loss를 track
    avg_train_losses = []
    # epoch당 average validation loss를 track
    avg_valid_losses = []

    train_metrics = [[] for i in range(len(metrics))]
    val_metrics = [[] for i in range(len(metrics))]

    # early_stopping object의 초기화
    early_stopping = EarlyStopping(patience = patience, verbose = True,path=args.save_checkpoint)

    n_epochs_length = len(str(n_epochs))
    for epoch in range(1, n_epochs + 1):
        ###################
        # train the model #
        ###################
        model.train() # prep model for training

        progressbar = tqdm(train_generator)
        for batch, input_tuple in enumerate(progressbar, start=1):
        
            # clear the gradients of all optimized variables
            optimizer.zero_grad()  

            # gpu 연산으로 변경
            input, target = prepare_input(input_tuple=input_tuple, args=args)
            input.requires_grad = True

            # forward pass: 입력된 값을 모델로 전달하여 예측 출력 계산
            output = model(input)

            # calculate the loss
            loss = criterion(output, target)

            # backward pass: 모델의 파라미터와 관련된 loss의 그래디언트 계산
            loss.backward()

            # perform a single optimization step (parameter update)
            optimizer.step()

            # record training loss
            train_losses.append(loss.item())

            output[output > 0.5] = 1.0

            # record matric
            for i,metric in enumerate(metrics):
                value = metric(output,target).item()
                train_metrics[i].append(value)

            # print metric & loss
            print_msg = {
                'loss' : f'{loss.item():.5f}',
            }

            for i,metric in enumerate(metrics):
                print_msg[metric.metric_name] = f'{train_metrics[i][-1]:.5f}'

            progressbar.set_postfix(print_msg)
            progressbar.set_description(f'[{epoch:>{n_epochs_length}}/{n_epochs:>{n_epochs_length}}][{batch}/{len(train_generator)}]')

        ######################    
        # validate the model #
        ######################
        model.eval() # prep model for evaluation

        prediction_images = []
        for i,input_tuple in enumerate(val_generator) :

            input, target = prepare_input(input_tuple=input_tuple, args=args)
            input.requires_grad = True

            # forward pass: 입력된 값을 모델로 전달하여 예측 출력 계산
            output = model(input)

            # calculate the loss
            loss = criterion(output, target)

            # record validation loss
            valid_losses.append(loss.item())

            # save prediction result
            if i < 10 :
                prediction_images.append(
                    [input.cpu().detach().numpy(),output.cpu().detach().numpy(), target.cpu().detach().numpy()])

            output[output > 0.5] = 1.0

            # metric
            for i,metric in enumerate(metrics):
                val_metrics[i].append(metric(output,target).item())

        fig=plt.figure(figsize=(10, 10))
        columns = 3
        rows = 3

        i = 1
        for [input,output, target] in prediction_images:

            fig.add_subplot(rows, columns, i); i+=1
            plt.imshow(input[0,0,:,:,80])
            plt.title("original Image"); plt.axis('off')

            fig.add_subplot(rows, columns, i); i+=1
            plt.imshow(output[0,0,:,:, 80])
            plt.title("Predicited Image"); plt.axis('off')

            fig.add_subplot(rows, columns, i); i+=1
            plt.imshow(target[0,0,:,:, 80])
            plt.title("Original Mask"); plt.axis('off')

        plt.savefig(f'{args.result_path}/{epoch}.png')


        # print 학습/검증 statistics
        # epoch당 평균 loss 계산
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        # epoch당 평균 metric 계산
        train_metric = {}
        val_metric = {}

        for i, metric in enumerate(metrics):
            train_metric[metric.metric_name] = np.average(train_metrics[i])
            val_metric[metric.metric_name] = np.average(val_metrics[i])


        # epoch 결과 출력
        print_msg = (f'[{epoch:>{n_epochs_length}}/{n_epochs:>{n_epochs_length}}] ' +
                     f'loss: {train_loss:.5f} ' +
                     f'val_loss: {valid_loss:.5f} ')

        for key in train_metric.keys():
            print_msg += f'{key} : {train_metric[key]:.5f} '

        for key in val_metric.keys():
            print_msg += f'val_{key} : {val_metric[key]:.5f} '

        print(print_msg)


        # tensorboard
        if tensorboard:
            writer.add_scalars("loss", 
                                {
                                    'train' : train_loss,
                                    'val' : valid_loss
                                 },
                                epoch)

            for train_key, val_key in zip(train_metric.keys(), val_metric.keys()):
                writer.add_scalars(train_key,
                                {
                                    'train' : train_metric[train_key],
                                     'val' : val_metric[val_key]
                                }, 
                                 epoch)

        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        train_metrics = [[] for i in range(len(metrics))]
        val_metrics = [[] for i in range(len(metrics))]

        # early_stopping는 validation loss가 감소하였는지 확인이 필요하며,
        # 만약 감소하였을경우 현제 모델을 checkpoint로 만든다.
        early_stopping(valid_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

   # best model이 저장되어있는 last checkpoint를 로드한다.
    model.load_state_dict(torch.load(args.save_checkpoint))
    writer.close()

    return  model, avg_train_losses, avg_valid_losses

In [7]:
# train generator & val generator
training_generator, val_generator, full_volume, affine = medical_loaders.generate_datasets(args, path=r'datasets\MICCAI_2018_Brain_Tumor')

train : mri t2 flair and label .nii.gz to .npy: 100%|██████████| 386/386 [06:40<00:00,  1.04s/it]
val : mri t2 flair and label .nii.gz to .npy: 100%|██████████| 98/98 [01:37<00:00,  1.00it/s]DATA SAMPLES HAVE BEEN GENERATED SUCCESSFULLY



In [8]:
meanIoU = MeanIoU(metric_name="MeanIoU")
diceCoefficient = DiceCoefficient(metric_name="DiCE")

model, train_loss, valid_loss = train_model(model = model,
                                            train_generator = training_generator,
                                            val_generator = val_generator,
                                            batch_size = args.batchSz,
                                            patience = 5,
                                            n_epochs = args.nEpochs,
                                            metrics=[diceCoefficient, meanIoU],
                                            tensorboard=writer)

[ 1/50][264/386]:  68%|██████▊   | 264/386 [04:10<01:55,  1.05it/s, loss=0.55375, DiCE=0.86086, MeanIoU=0.74400]

In [None]:
import matplotlib.pyplot as plt

# 훈련이 진행되는 과정에 따라 loss를 시각화
fig = plt.figure(figsize=(10,8))
plt.plot(range(1,len(train_loss)+1),train_loss, label='Training Loss')
plt.plot(range(1,len(valid_loss)+1),valid_loss,label='Validation Loss')

# validation loss의 최저값 지점을 찾기
minposs = valid_loss.index(min(valid_loss))+1
plt.axvline(minposs, linestyle='--', color='r',label='Early Stopping Checkpoint')

plt.xlabel('epochs')
plt.ylabel('loss')
plt.ylim(0, 0.5) # 일정한 scale
plt.xlim(0, len(train_loss)+1) # 일정한 scale
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
fig.savefig('loss_plot.png', bbox_inches = 'tight')

In [None]:
meanIoU = MeanIoU(metric_name="MeanIoU")
diceCoefficient = DiceCoefficient(metric_name="DiCE")

model, train_loss, valid_loss = train_model(model = model,
                                            train_generator = training_generator,
                                            val_generator = val_generator,
                                            batch_size = args.batchSz,
                                            patience = 5,
                                            n_epochs = args.nEpochs,
                                            metrics=[diceCoefficient, meanIoU],
                                            tensorboard=writer)