# About

ChocoBallDetectorの学習済みモデルを評価する。

評価指標は、チョコボール検出個数のMSE(平均二乗誤差)とする。このモデルの目的はチョコボールの個数を数えることであるため。

In [1]:
%config Completer.use_jedi = False

In [2]:
from ipywidgets import interact
import ipywidgets as widgets
import logging

In [3]:
from src import util
from src.preprocessor import ChocoPreProcessor
from src.evaluator import ChocoEvaluator

In [4]:
%matplotlib inline
import matplotlib.pyplot as plt

In [5]:
%load_ext autoreload
%autoreload 2

In [6]:
# Create the notebook-level logger and pass it to the project helper for setup.
# NOTE(review): util.set_logger presumably attaches the handler/format seen in
# the log output of later cells — confirm in src/util.
logger = logging.getLogger(__name__)
util.set_logger(logger)

In [7]:
# Paths used by the evaluation run (relative to the notebook's directory).
OUT = "../out"

# Test images and their bounding-box annotations share one directory.
IMG_DIR = "../data/test"
BBOX_DIR = IMG_DIR

CLASSES_FILE = "../data/classes.txt"

# Trained model file, stored under the output directory.
MODEL = OUT + "/choco_faster_rcnn.npz"

# 評価データの前処理

In [8]:
# Build the preprocessor, routing its log messages through the notebook logger.
choco_prep = ChocoPreProcessor(logger=logger)
# Register the object classes from CLASSES_FILE. Kept as the cell's last
# expression so that its return value is displayed as the cell output.
choco_prep.set_classes(class_file=CLASSES_FILE)

2021-03-25 13:33:40,885 - __main__ - INFO - set object class: ../data/classes.txt
2021-03-25 13:33:40,886 - __main__ - INFO - classes: dict_keys(['choco-ball', 'choco-package'])


{'choco-ball': 0, 'choco-package': 1}

In [9]:
# Load the evaluation set from the annotation and image directories.
# NOTE(review): the getters below presumably rely on set_dataset having run
# first — keep this statement order.
dataset = choco_prep.set_dataset(anno_dir=BBOX_DIR, img_dir=IMG_DIR)
bboxs = choco_prep.get_bbox_list()
imgs = choco_prep.get_img_array()
# obj_ids is what the evaluation cell passes as true_labels.
obj_ids = choco_prep.get_object_ids_list()
classes = choco_prep.get_object_classes()
# Quick sanity check of what was loaded.
print(imgs.shape)

2021-03-25 13:33:40,989 - __main__ - INFO - annotation_file_path: ../data/test
2021-03-25 13:33:40,991 - __main__ - INFO - image_file_path: ../data/test
2021-03-25 13:33:40,993 - __main__ - INFO - annotation_file_size: 7
100%|██████████| 7/7 [00:00<00:00, 188.08it/s]

(7, 3, 302, 402)





# 評価の実行

In [10]:
# Create the evaluator and load the model file to be evaluated.
# NOTE(review): gpu=0 presumably selects the first GPU device; confirm how
# ChocoEvaluator should be configured on a CPU-only machine.
ce = ChocoEvaluator(gpu=0)
ce.load_model(model_file=MODEL)

In [11]:
%%time
res_list, mse = ce.evaluate_chocoball_number(images=imgs, true_labels=obj_ids)

CPU times: user 1.3 s, sys: 104 ms, total: 1.41 s
Wall time: 1.4 s


In [12]:
# Report the number of evaluated images and the resulting MSE, one per line.
print(
    f"Evaluation Images: {imgs.shape[0]}",
    f"MSE: {mse}",
    sep="\n",
)

Evaluation Images: 7
MSE: 0.0


# 推論結果の可視化

In [13]:
def visualize_detect_image(idx):
    """Draw the detection result stored at ``res_list[idx]``.

    Creates a 12x4 figure, asks the evaluator ``ce`` to render the result
    into it with ``vis_score=True``, and shows the figure.

    Args:
        idx: Index into the notebook-level ``res_list``.
    """
    canvas = plt.figure(figsize=(12, 4))
    # The return value of vis_detect_image is not needed here.
    ce.vis_detect_image(res_list[idx], vis_score=True, fig=canvas)
    plt.show()

In [14]:
# Interactive browser for the detection results: choose a dataset index from
# the dropdown and the corresponding image is redrawn.
interact(
    visualize_detect_image,
    idx=widgets.Dropdown(
        options=list(range(imgs.shape[0])),
        # Default to index 3, clamped to the last valid index so the widget
        # can still be built when the test set has fewer than four images.
        value=min(3, imgs.shape[0] - 1),
        description="dataset idx",
    ),
);  # trailing ';' hides the `<function ...>` repr that interact returns

interactive(children=(Dropdown(description='dataset idx', index=3, options=(0, 1, 2, 3, 4, 5, 6), value=3), Ou…

<function __main__.visualize_detect_image(idx)>