In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from config.config import Config
from src.data_generator import VOCData
from src.visualize import draw_box
from src.model import get_rcnn_target, generate_anchors, unparameterize_box, box_iou

# Raise the default figure DPI so plots render sharper inside the notebook.
import matplotlib as mpl
mpl.rc('figure', dpi=150)


# Configuration: start from the project defaults and force a batch size of 1,
# since the cells below only ever look at element [0] of each batch.
config = Config()
config.BATCH_SIZE = 1

# Dataset: wrap the VOC2007 directory in the project's data generator.
# NOTE(review): the variable is named "test" but the path points at the
# trainval split — confirm this is the intended split.
test_data_dir = 'dataset/VOC2007/trainval/VOCdevkit/VOC2007'
test_dataset = VOCData(
    test_data_dir,
    config.INPUT_SHAPE,
    batch_size=config.BATCH_SIZE,
    max_gt_instance=config.MAX_GT_INSTANCE,
    debug=True,
)

# Model: restore the network saved under the timestamped run directory.
model_dir = 'model/20210817-054509'
model = tf.keras.models.load_model(model_dir)

In [None]:
# Draw a batch after reshuffling the dataset, then run the network on it in
# inference mode (training=False).
# NOTE(review): shuffle() is not seeded here, so the selected image may change
# from run to run — confirm whether that is intended.
test_dataset.shuffle()
img_batch, (gt_cls_ids, gt_boxes, num_gt_instance) = test_dataset[0]
outputs = model(img_batch, training=False)

# Bind the first six network outputs to named variables.
# Shape notes are carried over from the original annotations.
(
    rpn_cls_output,   # [B, A, 2]
    rpn_reg_output,   # [B, A, 4]
    rcnn_cls_output,  # [B*N, 21]
    rcnn_reg_output,  # [B*N, 20*4]
    roi_boxes,        # [B*N, 4]
    valid_num,        # [B]
) = outputs[:6]

In [None]:
# Keep only the first 100 rows of the ROI boxes.
# NOTE(review): the name suggests the rows are ordered by score — confirm
# against the model's ROI output.
top_k_rois = roi_boxes[:100]

# Draw the ground-truth labels onto the first image of the batch, passing only
# the first num_gt_instance[0] class ids and boxes. The result is stored in
# gt_box_img and is not displayed by this cell.
_num_gt = num_gt_instance[0]
gt_box_img = test_dataset.draw_label_img(
    img_batch[0],
    gt_cls_ids[0][:_num_gt],
    gt_boxes[0][:_num_gt],
)


In [None]:
# visualize rcnn targets
# Compute the RCNN targets for the first image of the batch from the ROI boxes
# produced by the network. Ground-truth boxes and class ids are trimmed to the
# first num_gt_instance[0] entries before being passed in; the sampling
# thresholds and counts come from the RCNN_* settings in config.
targets = get_rcnn_target(roi_boxes,
                          gt_boxes[0][:num_gt_instance[0]],
                          gt_cls_ids[0][:num_gt_instance[0]],
                          valid_num[0],
                          pos_threshold=config.RCNN_POS_IOU_THOLD,
                          neg_threshold=config.RCNN_NEG_IOU_THOLD,
                          total_sample_num=config.RCNN_TOTAL_SAMPLE_NUM,
                          pos_sample_ratio=config.RCNN_POS_SAMPLE_RATIO)

# Unpack the targets. They come from get_rcnn_target, so they are bound to
# rcnn_* names. The previous rpn_* names were misleading (these are not RPN
# targets); they are kept as aliases only so that any cell still referring to
# them keeps working, and should be migrated to the rcnn_* names.
# NOTE(review): the meaning of each element is taken from the original
# variable names — confirm against get_rcnn_target's return value.
rcnn_cls_target = rpn_cls_target = targets[0]
rcnn_reg_target = rpn_reg_target = targets[1]
rcnn_train_indices = rpn_train_indices = targets[2]
rcnn_valid_train_num = rpn_valid_train_num = targets[3]
rcnn_sampled_pos_num = rpn_sampled_pos_num = targets[4]