# NeuralODE for Brain Tumor Segmentation on BraTS dataset

In [None]:
import nibabel

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader


In [None]:
import torch
from torch.utils.data import Dataset
import os
import numpy as np

class bratsDataset(Dataset):
    """2-D BraTS dataset backed by pre-processed NumPy arrays.

    ``rootPath`` must contain two folders, ``num<PATH_SIZE>train`` and
    ``num<PATH_SIZE>labels`` (e.g. ``num64train`` / ``num64labels`` for
    ``PATH_SIZE='64'``), holding the ``.npy`` inputs and labels.  Inputs and
    labels are paired by position after sorting each folder's file names, so a
    sample and its label must sort to the same position (e.g. share a name).

    Each item is an ``(x, y)`` pair of float32 tensors loaded from those files.
    """

    def __init__(self, rootPath, PATH_SIZE):
        # str() lets callers pass the size as an int (64) as well as a string ('64').
        size = str(PATH_SIZE)
        self.train = self._npy_files(os.path.join(rootPath, 'num' + size + 'train'))
        self.labels = self._npy_files(os.path.join(rootPath, 'num' + size + 'labels'))

    @staticmethod
    def _npy_files(folder):
        """Return the full paths of the ``.npy`` files in *folder*, sorted by name.

        os.listdir() order is arbitrary and can differ between folders; sorting
        is what keeps ``self.train[i]`` and ``self.labels[i]`` the same case.
        """
        return sorted(os.path.join(folder, name)
                      for name in os.listdir(folder) if name.endswith('.npy'))

    def __len__(self):
        if len(self.train) != len(self.labels):
            raise ValueError('found %d training arrays but %d label arrays'
                             % (len(self.train), len(self.labels)))
        return len(self.train)

    def __getitem__(self, idx):
        x = torch.from_numpy(np.load(self.train[idx])).float()
        y = torch.from_numpy(np.load(self.labels[idx])).float()
        return x, y


class brats3dDataset(Dataset):
    """BraTS dataset backed by pre-processed NumPy arrays.

    ``rootPath`` must contain ``source`` and ``target`` folders holding the
    ``.npy`` inputs and targets.  Sources and targets are paired by position
    after sorting each folder's file names, so a source and its target must
    sort to the same position (e.g. share a file name).

    Each item is an ``(x, y)`` pair of float32 tensors loaded from those files.
    """

    def __init__(self, rootPath):
        self.source = self._npy_files(os.path.join(rootPath, 'source'))
        self.target = self._npy_files(os.path.join(rootPath, 'target'))

    @staticmethod
    def _npy_files(folder):
        """Return the full paths of the ``.npy`` files in *folder*, sorted by name.

        os.listdir() order is arbitrary and can differ between folders; sorting
        is what keeps ``self.source[i]`` and ``self.target[i]`` the same case.
        """
        return sorted(os.path.join(folder, name)
                      for name in os.listdir(folder) if name.endswith('.npy'))

    def __len__(self):
        if len(self.source) != len(self.target):
            raise ValueError('found %d source arrays but %d target arrays'
                             % (len(self.source), len(self.target)))
        return len(self.source)

    def __getitem__(self, idx):
        x = torch.from_numpy(np.load(self.source[idx])).float()
        y = torch.from_numpy(np.load(self.target[idx])).float()
        return x, y

# Preprocessing BraTS Dataset (preprocess.py)

In [None]:
import cv2
import os
import pdb
import numpy as np
import nibabel as nib
from sklearn.preprocessing import LabelBinarizer

IMG_ROOT = './Task01_BrainTumor/imagesTr'
IMG_PATH = './Task01_BrainTumor/imagesTr/BRATS_148.nii.gz'
IMG_OUTPUT_ROOT = './train/image_T1'

LABEL_ROOT = './Task01_BrainTumor/labelsTr'
LABEL_PATH = './Task01_BrainTumor/labelsTr/BRATS_148.nii.gz'
IABEL_PATH = LABEL_PATH  # backward-compatible alias for the old misspelled name
LABEL_OUTPUT_ROOT = './train/label'

# Grey level each label id is rendered as when label slices are exported.
L0 = 0      # Background
L1 = 50     # Necrotic and Non-enhancing Tumor
L2 = 100    # Edema
L3 = 150    # Enhancing Tumor

# MRI Image channels Description
# ch0: FLAIR / ch1: T1 / ch2: T1c/ ch3: T2
# Note: this project uses the FLAIR and T1c MRI channels.
# 
# Data Load Example
#img = nib.load(IMG_PATH)
#img = (img.get_fdata())[:,:,:,3]                # img shape = (240,240,155)


# MRI Label Channels Description
# 0: Background         / 1: Necrotic and non-enhancing tumor (paper, 1+3)
# 2: edema (paper, 2)   / 3: Enhancing tumor (paper, 4)
# 
# <Input>           <Prediction>
# FLAIR             Complete(1,2,3)
# FLAIR             Core(1,3)
# T1c               Enhancing(3)
#
# Data Load Example
# label = nib.load(LABEL_PATH)
# label = (label.get_fdata()).astype(np.uint16)   # label shape = (240,240,155)


def nii2jpg_img(img_path, output_root, channel=1):
    """Export every slice along axis 2 of one MRI channel of a NIfTI volume as 8-bit JPEGs.

    Slices are written to ``<output_root>/<name>/<name>_<slice>.jpg``, where
    ``<name>`` is the file name up to its first '.' (``BRATS_148.nii.gz`` ->
    ``BRATS_148``).  The selected channel is rescaled so its maximum maps to 255.

    Args:
        img_path: path to a 4-D NIfTI image whose last axis selects the channel.
        output_root: directory under which the per-volume folder is created.
        channel: index into the last axis (default 1, which the channel table
            in this cell lists as T1).

    Raises:
        IOError: if OpenCV fails to write a slice.
    """
    img_name = os.path.basename(img_path).split('.')[0]
    output_path = os.path.join(output_root, img_name)
    # exist_ok replaces the bare try/except, which also hid permission errors.
    os.makedirs(output_path, exist_ok=True)

    img = nib.load(img_path).get_fdata()[:, :, :, channel]
    peak = img.max()
    # An all-zero channel would otherwise divide 0/0 and produce NaNs.
    if peak > 0:
        img = img / peak * 255
    img = img.astype(np.uint8)

    for i in range(img.shape[2]):
        filename = os.path.join(output_path, img_name + '_' + str(i) + '.jpg')
        # cv2.imwrite reports failure by returning False rather than raising.
        if not cv2.imwrite(filename, img[:, :, i]):
            raise IOError('failed to write %s' % filename)


def nii2jpg_label(img_path, output_root, scale=50):
    """Export every slice along axis 2 of a NIfTI label volume as 8-bit JPEGs.

    Label ids are multiplied by ``scale`` so they become visible grey levels.
    With the default of 50, ids 0/1/2/3 map to 0/50/100/150, matching the
    L0-L3 constants of this cell.  Slices are written to
    ``<output_root>/<name>/<name>_<slice>.jpg``.

    NOTE(review): JPEG is lossy, so re-read grey levels may not be exactly
    0/50/100/150.  A lossless format (PNG) would preserve them.  The .jpg
    extension is kept because downstream code may expect it; confirm.

    Raises:
        IOError: if OpenCV fails to write a slice.
    """
    img_name = os.path.basename(img_path).split('.')[0]
    output_path = os.path.join(output_root, img_name)
    # makedirs also creates a missing output_root.  os.mkdir in a bare
    # try/except failed silently there, and so did every imwrite after it.
    os.makedirs(output_path, exist_ok=True)

    label = nib.load(img_path).get_fdata()
    # No debugger breakpoint here: a stray pdb.set_trace() used to stop every call.
    label = (label * scale).astype(np.uint8)

    for i in range(label.shape[2]):
        filename = os.path.join(output_path, img_name + '_' + str(i) + '.jpg')
        if not cv2.imwrite(filename, label[:, :, i]):
            raise IOError('failed to write %s' % filename)


# Which conversions run when this cell executes.  The label export was
# previously switched off by wrapping its loop in a string literal.
CONVERT_IMAGES = True
CONVERT_LABELS = False


def _convert_folder(src_root, output_root, convert):
    """Call ``convert(path, output_root)`` for each non-hidden entry of *src_root*, in name order."""
    for name in sorted(os.listdir(src_root)):
        print(name)
        if name.startswith('.'):
            continue
        convert(os.path.join(src_root, name), output_root)


if CONVERT_IMAGES:
    _convert_folder(IMG_ROOT, IMG_OUTPUT_ROOT, nii2jpg_img)
if CONVERT_LABELS:
    _convert_folder(LABEL_ROOT, LABEL_OUTPUT_ROOT, nii2jpg_label)


# Configuration (config.py)


In [None]:
import torch

# ---- Device ----------------------------------------------------------------
# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# ---- Model output ----------------------------------------------------------
# Number of output channels (classes) the segmentation network predicts.
class_num = 2

# ---- Data handling ---------------------------------------------------------
# (threshold, rate) pairs for the complete / core / enhancing targets.
complete_threshold, complete_rate = 0.05, 0.66
core_threshold, core_rate = 0.05, 0.66
enhancing_threshold, enhancing_rate = 0.02, 0.7

# ---- Data augmentation -----------------------------------------------------
# (no parameters defined yet)

# Data Augmentation Parameters


# Utilities (utils.py)


In [None]:
import pdb
import cv2
import os
import numpy as np
import nibabel as nib
import torch
import sys
import time
import logging
import logging.handlers
import pydensecrf.densecrf as dcrf

In [None]:
class Checkpoint:
    """Save and restore a model, its optimizer and training progress in one file.

    :meth:`save` writes a ``torch.save`` dict with the keys ``"model_state"``,
    ``"optimizer_state"`` (``None`` when no optimizer was given), ``"epoch"``
    and ``"best_score"``.  The model state always comes from the unwrapped
    network (``model.module`` for DataParallel-style wrappers).  Checkpoints
    therefore load into wrapped and plain models alike.
    """

    def __init__(self, model, optimizer=None, epoch=0, best_score=1):
        self.model = model
        self.optimizer = optimizer
        self.epoch = epoch
        self.best_score = best_score

    def _unwrapped_model(self):
        """Return the network inside a ``.module`` wrapper, or the model itself."""
        inner = getattr(self.model, "module", None)
        return inner if isinstance(inner, torch.nn.Module) else self.model

    def load(self, path):
        """Restore weights and the ``epoch``/``best_score`` counters from *path*."""
        # Deserialize onto the CPU so GPU-saved files also open on CPU-only hosts.
        # load_state_dict then copies tensors onto each parameter's own device.
        checkpoint = torch.load(path, map_location="cpu")
        model_state = checkpoint["model_state"]
        target = self._unwrapped_model()
        # Tolerate older files whose keys still carry the wrapper's "module." prefix.
        if target is not self.model and any(k.startswith("module.") for k in model_state):
            target = self.model
        target.load_state_dict(model_state)
        self.epoch = checkpoint["epoch"]
        self.best_score = checkpoint["best_score"]

        optimizer_state = checkpoint.get("optimizer_state")
        if self.optimizer is not None and optimizer_state is not None:
            # Optimizer.load_state_dict already moves the per-parameter state to
            # the matching parameter's device.  The old unconditional .cuda()
            # loop crashed on CPU-only hosts.
            self.optimizer.load_state_dict(optimizer_state)

    def save(self, path):
        """Write the model/optimizer state plus ``epoch`` and ``best_score`` to *path*."""
        optimizer_state = (self.optimizer.state_dict()
                           if self.optimizer is not None else None)
        torch.save({"model_state": self._unwrapped_model().state_dict(),
                    "optimizer_state": optimizer_state,
                    "epoch": self.epoch,
                    "best_score": self.best_score}, path)


def progress_bar(current, total, msg=None):
    '''Draw a one-line progress bar with per-step and total timing on stdout.

    Adapted from 'kuangliu/pytorch-cifar'
    (https://github.com/kuangliu/pytorch-cifar/blob/master/utils.py).

    Reads the module globals TOTAL_BAR_LENGTH and term_width, and updates the
    globals last_time and begin_time.  ``current`` is the 0-based step index:
    step 0 restarts the total timer, and the last step (``total - 1``) ends
    the line with a newline instead of a carriage return.
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()

    done = int(TOTAL_BAR_LENGTH * current / total)
    todo = int(TOTAL_BAR_LENGTH - done) - 1
    sys.stdout.write(' [' + '=' * done + '>' + '.' * todo + ']')

    now = time.time()
    step_time, last_time = now - last_time, now
    tot_time = now - begin_time

    pieces = ['  Step: %s' % format_time(step_time),
              ' | Tot: %s' % format_time(tot_time)]
    if msg:
        pieces.append(' | ' + msg)
    msg = ''.join(pieces)
    sys.stdout.write(msg)

    # Pad out to the terminal width, then back up so the counter lands mid-bar.
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3))
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2))
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    sys.stdout.write('\r' if current < total - 1 else '\n')
    sys.stdout.flush()


def format_time(seconds):
    '''Render a duration using at most its two largest non-zero units.

    Adapted from 'kuangliu/pytorch-cifar'
    (https://github.com/kuangliu/pytorch-cifar/blob/master/utils.py).
    Units are D, h, m, s and ms, e.g. 3661 -> '1h1m' and 1.5 -> '1s500ms'.
    Returns '0ms' when no unit is positive.
    '''
    days = int(seconds / 3600 / 24)
    seconds -= days * 3600 * 24
    hours = int(seconds / 3600)
    seconds -= hours * 3600
    minutes = int(seconds / 60)
    seconds -= minutes * 60
    whole_seconds = int(seconds)
    seconds -= whole_seconds
    millis = int(seconds * 1000)

    units = ((days, 'D'), (hours, 'h'), (minutes, 'm'),
             (whole_seconds, 's'), (millis, 'ms'))
    shown = [str(value) + suffix for value, suffix in units if value > 0][:2]
    return ''.join(shown) or '0ms'


def get_logger(level="DEBUG", file_level="DEBUG"):
    """Configure the root logger to write to a daily-rotating ``result.log``.

    Args:
        level: level set on the root logger.
        file_level: level set on the ``result.log`` handler.

    Returns:
        The root logger.

    Repeated calls reuse the existing handler for ``result.log`` (in the
    current working directory) and only update its level and format.  Adding
    a new handler on every call would write each record several times.
    """
    logger = logging.getLogger(None)
    logger.setLevel(level)
    formatter = logging.Formatter(
            '%(asctime)s  [%(levelname)s]  %(message)s  (%(filename)s:  %(lineno)s)')

    # FileHandler stores the absolute path, so compare against the same form.
    log_path = os.path.abspath('result.log')
    file_handler = next(
        (h for h in logger.handlers
         if isinstance(h, logging.handlers.TimedRotatingFileHandler)
         and h.baseFilename == log_path),
        None)
    if file_handler is None:
        file_handler = logging.handlers.TimedRotatingFileHandler(
                'result.log', when='d', encoding='utf-8')
        logger.addHandler(file_handler)
    file_handler.setLevel(file_level)
    file_handler.setFormatter(formatter)
    return logger

In [None]:
from pydensecrf.utils import compute_unary, create_pairwise_bilateral,\
         create_pairwise_gaussian, softmax_to_unary, unary_from_softmax

import shutil

# Terminal width (in columns) that progress_bar() pads to and rewinds from.
# `stty size` prints nothing when stdout is not a TTY (Jupyter, CI, redirected
# output), which made the old tuple-unpack raise ValueError.  get_terminal_size
# falls back to 80 columns in that case.
term_width = shutil.get_terminal_size(fallback=(80, 24)).columns or 80

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time


def _as_numpy(x):
    """Return *x* as a NumPy array, detaching and copying torch tensors to host memory."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def dice_coef(preds, targets, backprop=True, class_num=2):
    """Mean per-class Dice loss (``1 - Dice``) over ``class_num`` classes.

    Args:
        preds: with ``backprop=True``, a tensor indexed as ``preds[:, i, :, :]``,
            i.e. (batch, class, H, W) scores.  With ``backprop=False``, either
            such a 4-D score array or a 3-D array of already-decoded class ids.
            Tensors may live on any device and may require grad.
        targets: (batch, class, H, W) one-hot ground truth.  The
            non-differentiable branch decodes it with ``argmax(1)``.
        backprop: True computes the soft, differentiable Dice on the raw scores
            and returns a tensor.  False computes the hard Dice on argmax labels
            with NumPy and returns a NumPy float.
        class_num: number of classes averaged over (default 2, previously hard-coded).

    Returns:
        The mean over classes of ``1 - (2*|P.T| + 1) / (|P| + |T| + 1)``.
    """
    smooth = 1.0  # keeps 0/0 out when a class is absent from both pred and target
    if backprop:
        losses = []
        for i in range(class_num):
            pred = preds[:, i, :, :]
            target = targets[:, i, :, :]
            intersection = (pred * target).sum()
            losses.append(1 - ((2.0 * intersection + smooth)
                               / (pred.sum() + target.sum() + smooth)))
        return sum(losses) / class_num

    # np.array() on a CUDA tensor, or one that requires grad, raises.
    # Detach and copy to the host first.
    targets = _as_numpy(targets).argmax(1)
    preds = _as_numpy(preds)
    if preds.ndim > 3:
        preds = preds.argmax(1)
    losses = []
    for i in range(class_num):
        pred = (preds == i).astype(np.uint8)
        target = (targets == i).astype(np.uint8)
        intersection = (pred * target).sum()
        losses.append(1 - ((2.0 * intersection + smooth)
                           / (pred.sum() + target.sum() + smooth)))
    return sum(losses) / class_num


def get_crf_img(inputs, outputs):
    """Refine a batch of class-probability maps with a dense (fully-connected) CRF.

    Args:
        inputs: batch of images where ``inputs[i]`` is (H, W, C).  The last axis
            is the colour channel of the bilateral term (3 channels, per ``schan``).
        outputs: batch of softmax probabilities where ``outputs[i]`` is laid out
            as pydensecrf's ``unary_from_softmax`` expects, with labels on axis 0.

    Returns:
        An (N, H, W) array with the most probable label per pixel after 5
        mean-field iterations.  An empty batch gives an empty (0, H, W) array.
    """
    refined = []
    for i in range(outputs.shape[0]):
        img = inputs[i]
        softmax_prob = outputs[i]
        height, width = img.shape[0], img.shape[1]
        # The label count follows the probability maps instead of a fixed 2.
        n_labels = softmax_prob.shape[0]
        unary = np.ascontiguousarray(unary_from_softmax(softmax_prob))
        d = dcrf.DenseCRF(height * width, n_labels)
        d.setUnaryEnergy(unary)
        # Position-only Gaussian pairwise term.
        feats = create_pairwise_gaussian(sdims=(10, 10), shape=img.shape[:2])
        d.addPairwiseEnergy(feats, compat=3, kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)
        # Bilateral pairwise term over position and colour.
        feats = create_pairwise_bilateral(sdims=(50, 50), schan=(20, 20, 20),
                                          img=img, chdim=2)
        d.addPairwiseEnergy(feats, compat=10, kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)
        Q = d.inference(5)
        refined.append(np.argmax(Q, axis=0).reshape((height, width)))
    if not refined:
        # The old code raised UnboundLocalError for an empty batch.
        return np.zeros((0,) + tuple(np.shape(inputs)[1:3]), dtype=np.int64)
    # Stack once instead of re-concatenating the growing result every iteration.
    return np.stack(refined, axis=0)


def erode_dilate(outputs, kernel_size=7):
    """Clean up a batch of label masks with a morphological opening, then a closing.

    The opening removes small foreground specks and the closing fills small
    holes, both with a ``kernel_size`` x ``kernel_size`` square of ones.  Works
    on a uint8 copy of ``outputs`` (one 2-D mask per index along axis 0) and
    returns that copy, leaving the input untouched.
    """
    structuring_element = np.ones((kernel_size, kernel_size), dtype=np.uint8)
    masks = outputs.astype(np.uint8)
    for idx, mask in enumerate(masks):
        opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, structuring_element)
        masks[idx] = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, structuring_element)
    return masks


def _to_host_array(x):
    """Return *x* as a NumPy array, detaching and copying torch tensors to host memory."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def post_process(args, inputs, outputs, input_path=None,
                 crf_flag=True, erode_dilate_flag=True,
                 save=True, overlap=True):
    """Turn network outputs into label masks and optionally write them as images.

    Args:
        args: object with an ``output_root`` attribute (only read when ``save``).
        inputs: batch of single-channel images, (N, 1, H, W) or (N, H, W).
            Intensities are scaled by 255 here, so presumably in [0, 1]; TODO confirm.
        outputs: per-class scores, (N, C, H, W).  Tensors may live on any device.
        input_path: per-sample source paths.  The result for sample ``i`` is
            written to ``<args.output_root>/<parent dir of input_path[i]>/<file name>``.
        crf_flag: refine with ``get_crf_img`` instead of a plain argmax.
        erode_dilate_flag: smooth the masks with ``erode_dilate``.
        save: if False, return the (N, H, W) masks without writing anything.
        overlap: if True, draw the mask into the last colour channel (red in
            OpenCV's BGR order) over the input.  Otherwise write the mask * 255.

    Returns:
        The masks when ``save`` is False, otherwise None.

    Raises:
        ValueError: if ``save`` is True but ``input_path`` is None.
    """
    inputs = _to_host_array(inputs).astype(np.float32)
    # Drop only the channel axis.  A bare squeeze() also removed the batch axis
    # for single-image batches, which broke the channel stacking below.
    if inputs.ndim == 4 and inputs.shape[1] == 1:
        inputs = inputs[:, 0]
    inputs = inputs * 255
    # Replicate the grey image into 3 channels for the CRF and the colour overlay.
    inputs = np.repeat(inputs[..., np.newaxis], 3, axis=3)
    outputs = _to_host_array(outputs)

    # Conditional random field, or plain per-pixel argmax.
    if crf_flag:
        outputs = get_crf_img(inputs, outputs)
    else:
        outputs = outputs.argmax(1)

    # Erosion and dilation.
    if erode_dilate_flag:
        outputs = erode_dilate(outputs, kernel_size=7)
    if not save:
        return outputs
    if input_path is None:
        raise ValueError("input_path is required when save=True")

    outputs = outputs * 255
    for i, mask in enumerate(outputs):
        source = os.path.normpath(input_path[i])
        output_folder = os.path.join(args.output_root,
                                     os.path.basename(os.path.dirname(source)))
        # makedirs also creates a missing output_root, instead of failing silently.
        os.makedirs(output_folder, exist_ok=True)
        output_path = os.path.join(output_folder, os.path.basename(source))
        if overlap:
            red = mask.astype(np.float32)[:, :, np.newaxis]
            blank = np.zeros_like(red)
            img = inputs[i] + np.concatenate((blank, blank, red), axis=2)
            peak = img.max()
            img = img / peak * 255 if peak > 0 else img * 255
            cv2.imwrite(output_path, img)
        else:
            cv2.imwrite(output_path, mask)
    return None


# UNet Model architecture code

In [None]:
import pdb
import torch
import torch.nn as nn

from torch.nn.functional import softmax


def conv3x3(in_c, out_c, kernel_size=3, stride=1, padding=1,
            bias=True, useBN=False, drop_rate=0):
    """Two stacked pad -> conv -> [BatchNorm] -> dropout -> ReLU stages.

    Padding is done with explicit ReflectionPad2d layers, so each Conv2d runs
    with ``padding=0``.  The first conv maps ``in_c`` -> ``out_c`` and the
    second maps ``out_c`` -> ``out_c``.  With ``useBN`` each conv is followed by
    BatchNorm2d and the ReLUs are in-place; without it the ReLUs are not.
    """
    layers = []
    for stage_in in (in_c, out_c):
        layers.append(nn.ReflectionPad2d(padding))
        layers.append(nn.Conv2d(stage_in, out_c, kernel_size, stride,
                                padding=0, bias=bias))
        if useBN:
            layers.append(nn.BatchNorm2d(out_c))
        layers.append(nn.Dropout2d(p=drop_rate))
        layers.append(nn.ReLU(inplace=True) if useBN else nn.ReLU())
    return nn.Sequential(*layers)


def upsample(in_c, out_c, bias=True, drop_rate=0):
    """Learned 2x upsampling: ConvTranspose2d(kernel 4, stride 2, padding 1), dropout, ReLU.

    With these settings the output height/width is exactly twice the input's.
    """
    layers = [
        nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=bias),
        nn.Dropout2d(p=drop_rate),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)


class UNet(nn.Module):
    """U-Net for 2-D segmentation with a softmax over ``class_num`` output channels.

    Four 2x max-pool encoder stages (64 -> 1024 channels) are followed by four
    learned 2x upsampling decoder stages, each fused with the matching encoder
    output via a skip connection.  Height and width should be divisible by 16
    so the upsampled maps line up with the skip connections.

    Args:
        in_channel: number of input image channels.
        class_num: number of output classes (channels of the softmax output).
        useBN: add BatchNorm2d after each convolution of the double-conv blocks.
        drop_rate: Dropout2d probability used throughout.
    """

    def __init__(self, in_channel=1, class_num=2, useBN=False, drop_rate=0):
        super(UNet, self).__init__()
        self.output_dim = class_num
        self.drop_rate = drop_rate
        # Encoder double-conv blocks.
        self.conv1 = conv3x3(in_channel, 64, useBN=useBN, drop_rate=self.drop_rate)
        self.conv2 = conv3x3(64, 128, useBN=useBN, drop_rate=self.drop_rate)
        self.conv3 = conv3x3(128, 256, useBN=useBN, drop_rate=self.drop_rate)
        self.conv4 = conv3x3(256, 512, useBN=useBN, drop_rate=self.drop_rate)
        self.conv5 = conv3x3(512, 1024, useBN=useBN, drop_rate=self.drop_rate)

        # Decoder double-conv blocks; each input is (upsampled, skip) concatenated.
        self.conv4m = conv3x3(1024, 512, useBN=useBN, drop_rate=self.drop_rate)
        self.conv3m = conv3x3(512, 256, useBN=useBN, drop_rate=self.drop_rate)
        self.conv2m = conv3x3(256, 128, useBN=useBN, drop_rate=self.drop_rate)
        self.conv1m = conv3x3(128, 64, useBN=useBN, drop_rate=self.drop_rate)

        # NOTE(review): ReLU before the softmax clamps negative logits to 0.
        # Kept for compatibility with existing checkpoints' behaviour; confirm it is intended.
        self.conv0  = nn.Sequential(nn.ReflectionPad2d(1),
                                    nn.Conv2d(64, self.output_dim, 3, 1, 0),
                                    nn.Dropout2d(p=self.drop_rate),
                                    nn.ReLU())
        self.max_pool = nn.MaxPool2d(2)

        self.upsample54 = upsample(1024, 512, drop_rate=self.drop_rate)
        self.upsample43 = upsample(512, 256, drop_rate=self.drop_rate)
        self.upsample32 = upsample(256, 128, drop_rate=self.drop_rate)
        self.upsample21 = upsample(128, 64, drop_rate=self.drop_rate)

        # Weight initialization: zero biases, N(0, 0.01) weights for all (transposed) convs.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                if m.bias is not None:
                    m.bias.data.zero_()
                nn.init.normal_(m.weight.data, mean=0, std=0.01)

    def forward(self, x):
        # Encoder: each stage halves the spatial size before its double conv.
        enc1 = self.conv1(x)
        enc2 = self.conv2(self.max_pool(enc1))
        enc3 = self.conv3(self.max_pool(enc2))
        enc4 = self.conv4(self.max_pool(enc3))
        bottleneck = self.conv5(self.max_pool(enc4))

        # Decoder: upsample the PREVIOUS DECODER output and fuse the skip.
        # The old code upsampled the encoder outputs (output4/3/2) instead.  That
        # discarded the bottleneck and the conv4m/conv3m results entirely.
        dec4 = self.conv4m(torch.cat((self.upsample54(bottleneck), enc4), 1))
        dec3 = self.conv3m(torch.cat((self.upsample43(dec4), enc3), 1))
        dec2 = self.conv2m(torch.cat((self.upsample32(dec3), enc2), 1))
        dec1 = self.conv1m(torch.cat((self.upsample21(dec2), enc1), 1))

        return softmax(self.conv0(dec1), dim=1)


def test():
    """Smoke test: run a random batch of three 1x240x240 images through UNet and print the output size."""
    model = UNet(class_num=2)
    batch = torch.randn(3, 1, 240, 240)
    prediction = model(batch)
    print(prediction.size())

# Trainer Code (train.py) 

In [None]:
import argparse
import logging
import pdb

import torch
import torch.backends.cudnn as cudnn

from config import *
from dataset import *
from models import *
from utils import *


def _run_dice_epoch(net, loader, device, optimizer=None):
    """Make one pass over *loader* and return the mean Dice loss of its batches.

    With an *optimizer*, the soft Dice loss (``dice_coef(..., backprop=True)``)
    is back-propagated after every batch.  Without one, gradients are off and
    the hard Dice loss (``backprop=False``) is only measured.  Returns NaN for
    an empty loader, so an empty validation set never looks like a new best.
    """
    training = optimizer is not None
    total = 0.0
    seen = 0
    # The context manager restores the previous autograd mode on exit.  A bare
    # set_grad_enabled(False) used to leave autograd off after validation.
    with torch.set_grad_enabled(training):
        for idx, (inputs, targets, paths) in enumerate(loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            if isinstance(outputs, tuple):  # keep only the first (segmentation) output
                outputs = outputs[0]
            batch_loss = dice_coef(outputs, targets, backprop=training)
            if training:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
            total += float(batch_loss)
            seen = idx + 1
            progress_bar(idx, len(loader), 'Loss: %.5f, Dice-Coef: %.5f'
                         % (total / seen, 1 - total / seen))
    return total / seen if seen else float('nan')


def train(args):
    """Train ``args.model``, validating and checkpointing after every epoch.

    Per epoch:
      * the learning rate is set to ``args.lr * 0.5 ** (epoch // 4)``;
      * one training pass minimises the soft Dice loss;
      * one validation pass measures the hard Dice loss;
      * ``<args.ckpt_root>/<args.model>.tar`` is written when the validation
        Dice score (1 - loss) beats the best score so far.

    Uses ``config``, ``data_loader`` and ``load_model`` from the star-imports.
    Their contracts are assumed from how they are called here.
    """
    # NOTE(review): `from config import *` does not bind the name `config`.
    # This only works if something imports the config module under that name; confirm.
    device = config.device
    cudnn.benchmark = True
    get_logger()

    # Data load
    trainloader = data_loader(args, mode='train')
    validloader = data_loader(args, mode='valid')

    # Model load
    net, optimizer, best_score, start_epoch =\
        load_model(args, class_num=config.class_num, mode='train')
    logging.info('%s Train Start' % (args.model))
    # DataParallel-style wrappers keep the real network in .module.
    bare_net = getattr(net, 'module', net)
    epoch_msg = 'Epoch: %d  Loss: %.5f,  Dice-Coef:  %.5f'

    for epoch in range(start_epoch, start_epoch + args.epochs):

        # Train model
        print('\n\n\nEpoch: {}\n<Train>'.format(epoch))
        net.train(True)
        # Step schedule: halve the learning rate every 4 epochs.
        lr = args.lr * (0.5 ** (epoch // 4))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        train_loss = _run_dice_epoch(net, trainloader, device, optimizer)
        logging.info(epoch_msg % (epoch, train_loss, 1 - train_loss))

        # Validate model
        print('\n\n<Validation>')
        net.eval()
        # Dropout layers are switched back to train mode, so validation is stochastic.
        # NOTE(review): looks like deliberate Monte-Carlo dropout; confirm, otherwise
        # drop this loop to get deterministic validation scores.
        for module in bare_net.modules():
            if isinstance(module, (torch.nn.Dropout2d, torch.nn.Dropout)):
                module.train(True)
        valid_loss = _run_dice_epoch(net, validloader, device)
        logging.info(epoch_msg % (epoch, valid_loss, 1 - valid_loss))

        # Save model (NaN from an empty loader never compares greater).
        score = 1 - valid_loss
        if score > best_score:
            os.makedirs(args.ckpt_root, exist_ok=True)
            checkpoint = Checkpoint(net, optimizer, epoch, score)
            checkpoint.save(os.path.join(args.ckpt_root, args.model + '.tar'))
            best_score = score
            print("Saving...")


if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'."""
        text = str(value).strip().lower()
        if text in ('1', 'true', 't', 'yes', 'y'):
            return True
        if text in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    parser = argparse.ArgumentParser(description=__doc__)
    # type=bool turned any non-empty string, including "False", into True.
    # _str2bool parses the text; nargs='?' + const=True also accept a bare --resume.
    parser.add_argument("--resume", type=_str2bool, nargs='?', const=True, default=False,
                        help="Model Training resume.")
    parser.add_argument("--model", type=str, default='pspnet_res50',
                        help="Model Name (unet, pspnet_squeeze, pspnet_res50,\
                        pspnet_res34, pspnet_res50, deeplab)")
    parser.add_argument("--in_channel", type=int, default=1,
                        help="A number of images to use for input")
    parser.add_argument("--batch_size", type=int, default=80,
                        help="The batch size to load the data")
    parser.add_argument("--epochs", type=int, default=30,
                        help="The training epochs to run.")
    parser.add_argument("--drop_rate", type=float, default=0.1,
                        help="Drop-out Rate")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="Learning rate to use in training")
    parser.add_argument("--data", type=str, default="complete",
                        help="Label data type.")
    parser.add_argument("--img_root", type=str, default="../../data/train/image_FLAIR",
                        help="The directory containing the training image dataset.")
    parser.add_argument("--label_root", type=str, default="../../data/train/label",
                        help="The directory containing the training label dataset")
    parser.add_argument("--output_root", type=str, default="./output/prediction",
                        help="The directory containing the result predictions")
    parser.add_argument("--ckpt_root", type=str, default="./checkpoint",
                        help="The directory containing the checkpoint files")
    args = parser.parse_args()

    train(args)

# Testing script (test.py)

In [None]:
import pdb
import argparse
import torch
import torch.backends.cudnn as cudnn

from config import *
from dataset import *
from models import *
from utils import *


def test(args):
    """Run the trained model over the test split and write post-processed predictions.

    Each batch's outputs go through ``post_process`` (CRF, erosion/dilation,
    image export under ``args.output_root``), and the batch index is printed.
    Uses ``config``, ``data_loader`` and ``load_model`` from the star-imports.
    """
    # NOTE(review): `from config import *` does not bind the name `config`; confirm
    # that the config module is imported under that name somewhere.
    device = config.device
    cudnn.benchmark = True

    # Data load
    testloader = data_loader(args, mode='test')

    # Model load
    net, _, _, _ = load_model(args, class_num=config.class_num, mode='test')

    net.eval()
    # Scope no-grad to inference.  set_grad_enabled(False) left autograd
    # globally disabled after this function returned.
    with torch.no_grad():
        for idx, (inputs, paths) in enumerate(testloader):
            inputs = inputs.to(device)
            outputs = net(inputs)
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            # post_process converts to NumPy, which fails on GPU tensors, so hand it host copies.
            post_process(args, inputs.cpu(), outputs.cpu(), paths)
            print(idx)

if __name__ == "__main__":
    # Command-line interface for inference; every option has a default.
    parser = argparse.ArgumentParser(description=__doc__)
    cli_options = [
        # TODO: the --model and --batch_size defaults still need to be fixed.
        ("--model", dict(type=str, default='pspnet',
                         help="Model Name")),
        ("--batch_size", dict(type=int, default=155,
                              help="The batch size to load the data")),
        ("--data", dict(type=str, default="complete",
                        help="Label data type.")),
        ("--img_root", dict(type=str, default="../data/train/image_FLAIR",
                            help="The directory containing the training image dataset.")),
        ("--output_root", dict(type=str, default="./output/prediction",
                               help="The directory containing the results.")),
        ("--ckpt_root", dict(type=str, default="./checkpoint",
                             help="The directory containing the trained model checkpoint")),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    test(args)