# In [None]:
import os
import time
import json
import argparse
import mmcv
from mmcv import Config
from mmcv.parallel import scatter, collate, MMDataParallel
from mmcv.runner import obj_from_dict, load_checkpoint, save_checkpoint, parallel_test
from mmcv.runner.log_buffer import LogBuffer

from mmdet.datasets import get_dataset, build_dataloader
from mmdet.models import build_detector, detectors
from utils.util import set_random_seed, batch_processor, get_current_lr
from utils import lr_scheduler as LRschedule
from utils.logger import init_logger
from utils.deep_lesion_eval import evaluate_deep_lesion
from utils.reorganize import reorganize_data
import torch
from torch.nn.utils import clip_grad
import multiprocessing

def parse_args():
    """Build the option namespace used by this script.

    The parser is fed an explicit empty argument list, so ``sys.argv`` is
    never consulted and every option takes its declared default
    (presumably because this code runs inside a notebook kernel whose own
    argv is not meant for this parser -- confirm).

    Returns:
        argparse.Namespace: namespace whose ``config`` attribute holds the
        config file path.
    """
    arg_parser = argparse.ArgumentParser(description='Train a detector')
    arg_parser.add_argument(
        '--config',
        default='./configs/fpn_msb.py',
        help='train config file path',
    )
    # Explicit empty list: ignore whatever is on the real command line.
    return arg_parser.parse_args(args=[])

    
def test(data_loader, model, cfg, logger):
    """Run ``model`` over every batch of ``data_loader`` and collect outputs.

    Args:
        data_loader: Sized iterable of batches; each batch is passed through
            ``reorganize_data`` and then unpacked as keyword arguments to
            ``model``.
        model: Detector to evaluate. It is switched to eval mode and called
            with ``return_loss=False, rescale=True``.
        cfg: Config object providing ``cfg_3dce.num_images_3dce`` and
            ``cfg_3dce.num_slices``.
        logger: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        list: One model output per batch, in loader order.
    """
    model.eval()
    cfg_3dce = cfg.cfg_3dce  # loop-invariant, looked up once
    results = []
    # The bar is sized by the number of batches, so it advances once per batch.
    prog_bar = mmcv.ProgressBar(len(data_loader))
    with torch.no_grad():  # inference only: no autograd bookkeeping
        for data_batch in data_loader:
            data_batch = reorganize_data(
                data_batch, cfg_3dce.num_images_3dce, cfg_3dce.num_slices)
            results.append(model(return_loss=False, rescale=True, **data_batch))
            prog_bar.update()

    return results

# ---------------------------------------------------------------------------
# Evaluation driver: load the config, build the validation loader and the
# detector, restore the checkpoint, run inference, then score the results.
# All names stay at module level so they remain inspectable afterwards.
# ---------------------------------------------------------------------------
args = parse_args()
cfg = Config.fromfile(args.config)
work_dir = cfg.work_dir

# Expose only CUDA device 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Create the log directory; exist_ok avoids the check-then-create race of a
# separate os.path.exists() test.
log_dir = os.path.join(work_dir, 'logs')
os.makedirs(log_dir, exist_ok=True)
logger = init_logger(log_dir)

seed = cfg.seed
logger.info('Set random seed to {}'.format(seed))
set_random_seed(seed)

# Evaluation data comes from the `val` entry of the config.
test_dataset = get_dataset(cfg.data.val)
# NOTE(review): the three positional 1s are presumably imgs_per_gpu,
# workers_per_gpu and num_gpus -- confirm against the installed mmdet.
test_data_loader = build_dataloader(
    test_dataset,
    1,
    1,
    1,
    dist=False,
    shuffle=False,
)

model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
model = MMDataParallel(model).cuda()

# Weights are loaded into the already-wrapped model from cfg.load_from.
checkpoint_file = cfg.load_from
load_checkpoint(model, checkpoint_file)
logger.info('load model from {}'.format(checkpoint_file))

results = test(test_data_loader, model, cfg, logger)
sensitivity = evaluate_deep_lesion(results, test_dataset, cfg.cfg_3dce, logger)