In [4]:
%load_ext autoreload
import os
import json # Added to initialize the setting in Jupyter Notebook-by Mingyang
import easydict # Added to initialize the setting in Jupyter Notebook-by Mingyang
import time

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from geneva.data.datasets import DATASETS
from geneva.evaluation.evaluate_metrics import report_inception_objects_score
from geneva.utils.config import keys, parse_config
from geneva.models.models import INFERENCE_MODELS
from geneva.data import codraw_dataset
from geneva.data import clevr_dataset

In [5]:
class Tester():
    """Run a trained inference model over one dataset split.

    Construction instantiates ``INFERENCE_MODELS[cfg.gan_type]`` and calls its
    ``load`` method. It then builds the dataset and an ordered DataLoader for
    the selected split. Finally it makes sure ``cfg.results_path`` names an
    existing directory. ``test()`` then feeds every batch to ``model.predict``.
    """

    def __init__(self, cfg, use_val=False, iteration=None, test_eval=False):
        """
        Args:
            cfg: experiment configuration with attribute access (e.g. an
                EasyDict). ``cfg.results_path`` is filled in when missing/None.
            use_val: evaluate ``cfg.val_dataset`` with weights loaded from
                ``<cfg.log_path>/<cfg.exp_name>``.
            iteration: forwarded unchanged to ``model.load``.
            test_eval: evaluate ``cfg.test_dataset`` with weights from
                ``cfg.load_snapshot``; takes precedence over ``use_val``.
                With neither flag, ``cfg.dataset`` + ``cfg.load_snapshot``
                are used.
        """
        self.model = INFERENCE_MODELS[cfg.gan_type](cfg)

        dataset_path, model_path = self._resolve_paths(cfg, use_val, test_eval)
        self.model.load(model_path, iteration)

        self.dataset = DATASETS[cfg.dataset](path=keys[dataset_path],
                                             cfg=cfg,
                                             img_size=cfg.img_size)
        # Pass the collate function at construction time instead of patching
        # the attribute afterwards; None selects PyTorch's default collation.
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=cfg.batch_size,
                                     shuffle=False,
                                     num_workers=cfg.num_workers,
                                     drop_last=True,
                                     collate_fn=self._collate_fn_for(cfg.dataset))

        # drop_last=True: the trailing partial batch is skipped, so the number
        # of batches is the floor division.
        self.iterations = len(self.dataset) // cfg.batch_size

        cfg.results_path = self._ensure_results_path(cfg)

        self.cfg = cfg
        self.dataset_path = dataset_path

    @staticmethod
    def _resolve_paths(cfg, use_val, test_eval):
        """Return ``(dataset_key, model_path)`` for the requested split."""
        if test_eval:
            return cfg.test_dataset, cfg.load_snapshot
        if use_val:
            return cfg.val_dataset, os.path.join(cfg.log_path, cfg.exp_name)
        return cfg.dataset, cfg.load_snapshot

    @staticmethod
    def _collate_fn_for(dataset_name):
        """Return the dataset-specific collate function, or None for the default."""
        if dataset_name == 'codraw':
            return codraw_dataset.collate_data
        if dataset_name == 'iclevr':
            return clevr_dataset.collate_data
        return None

    @staticmethod
    def _ensure_results_path(cfg):
        """Return the results directory, creating it (and parents) if needed.

        Defaults to ``<cfg.log_path>/<cfg.exp_name>/results`` when
        ``cfg.results_path`` is missing or None. ``makedirs`` is used because
        the experiment directory itself may not exist yet (e.g. when only a
        snapshot is evaluated), which made the former ``os.mkdir`` fail.
        """
        results_path = getattr(cfg, 'results_path', None)
        if results_path is None:
            results_path = os.path.join(cfg.log_path, cfg.exp_name, 'results')
        os.makedirs(results_path, exist_ok=True)
        return results_path

    def test(self):
        """Call ``model.predict`` once per batch over the whole dataloader."""
        for batch in tqdm(self.dataloader, total=self.iterations):
            self.model.predict(batch)

In [7]:
# Read the CoDraw experiment settings from JSON and wrap them in an EasyDict,
# so that every option can be accessed as an attribute (cfg.batch_size, ...).
config_file = "example_args/codraw_args.json"
with open(config_file, 'r') as config_fp:
    cfg = easydict.EasyDict(json.load(config_fp))

# Build the evaluator on the test split (weights come from cfg.load_snapshot).
tester = Tester(cfg, test_eval=True)
print("finish setting up a tester")

finish setting up a tester


In [8]:
# Run inference over the whole evaluation split: Tester.test calls
# model.predict once per dataloader batch. The next cell scores whatever is
# found under cfg.results_path — presumably predict writes its generated
# images there (TODO confirm against the inference model implementation).
tester.test()

100%|██████████| 31/31 [00:13<00:00,  4.16it/s]


In [10]:
# Release the inference model before scoring so the cached GPU memory it held
# can be released by empty_cache(). The guard keeps this cell re-runnable: a
# bare `del tester` raises NameError once the name has already been deleted.
if 'tester' in globals():
    del tester
torch.cuda.empty_cache()

metrics_report = dict()
if cfg.metric_inception_objects:
    # The first three positional arguments are deliberately passed as None.
    # NOTE(review): their meaning is defined by report_inception_objects_score
    # in geneva.evaluation.evaluate_metrics — confirm before changing them.
    io_jss, io_ap, io_ar, io_af1, io_cs, io_gs = report_inception_objects_score(None,
                                                                                None,
                                                                                None,
                                                                                cfg.results_path,
                                                                                keys[cfg.dataset + '_inception_objects'],
                                                                                keys[cfg.test_dataset],
                                                                                cfg.dataset)

    # Normalise every score to a built-in float. In the recorded output the
    # 'cossim' value prints with float32 precision (apparently a numpy scalar)
    # unlike the other scores; mixed scalar types would e.g. break json.dump.
    metrics_report = {
        'jaccard': float(io_jss),
        'precision': float(io_ap),
        'recall': float(io_ar),
        'f1': float(io_af1),
        'cossim': float(io_cs),
        'relsim': float(io_gs),
    }
print(metrics_report)

100%|██████████| 1909/1909 [01:48<00:00, 18.00it/s]


Number of images used: 1909
JSS: 0.24663185293019207
 AP: 0.45078372677901224
AR: 0.3334648399734832
 F1: 0.3647282236615615
CS: 0.3995373845100403
GS: 0.17221260111025694
{'jaccard': 0.24663185293019207, 'precision': 0.45078372677901224, 'recall': 0.3334648399734832, 'f1': 0.3647282236615615, 'cossim': 0.39953738, 'relsim': 0.17221260111025694}



  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)
