In [None]:
import sys
from pathlib import Path
# Make the directory two levels above the notebook's working directory
# importable (presumably the repository root holding the `katacv` package;
# verify against the project layout).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path every time.
_repo_root = str(Path.cwd().parent.parent)
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

In [None]:
from katacv.utils.related_pkgs.jax_flax_optax_orbax import *
from katacv.yolov1.yolov1 import get_yolov1_state

# Build a train state with get_yolov1_state(), then replace its `params` and
# `batch_stats` with the values restored from the checkpoint directory.
# NOTE: `ocp` is not imported explicitly in this cell; it is expected to come
# from the star import of katacv.utils.related_pkgs.jax_flax_optax_orbax.
state = get_yolov1_state()
weights = ocp.PyTreeCheckpointer().restore(
    "/home/wty/Coding/models/YOLOv1/YOLOv1-0080-lite"
)
# Alternative checkpoint location (kept for switching machines):
# weights = ocp.PyTreeCheckpointer().restore("/home/yy/Coding/modelsets/YOLOv1/YOLOv1-0080-lite")
state = state.replace(
    params=weights['params'],
    batch_stats=weights['batch_stats'],
)

In [None]:
@jax.jit
def predict(x):
    """Run the restored model in inference mode on an image batch.

    Args:
        x: Input batch forwarded unchanged to ``state.apply_fn``. The warm-up
            call below uses shape (1, 448, 448, 3).

    Returns:
        The ``(proba, boxes)`` pair produced by ``state.apply_fn`` with
        ``train=False``, as JAX arrays.

    Note:
        The previous version wrapped both outputs in ``jax.device_get`` inside
        this jitted function. During tracing the values are tracers, which
        ``jax.device_get`` hands back untouched, so the call never produced
        host arrays. It has been removed; the returned values are the same as
        before. Call ``jax.device_get(predict(x))`` outside of jit when host
        (NumPy) arrays are required.
    """
    proba, boxes = state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        x, train=False
    )
    return proba, boxes

# Trigger tracing/compilation once up front so the first real batch is not
# slowed down by it.
# NOTE(review): the warm-up input is uint8; if the dataset yields a different
# dtype or shape, jit will compile again on the first real call - confirm.
print("XLA compile...")
y = predict(jnp.empty((1,448,448,3), dtype='uint8'))
print("Compile complete!")

In [None]:
from katacv.utils.VOC.build_dataset import VOCBuilder, Path
from collections import namedtuple
# Settings handed to VOCBuilder for this run: single-image batches read from
# the TFRecord directory below.
args = dict(
    path_dataset_tfrecord=Path('/home/wty/Coding/datasets/VOC/tfrecord'),
    batch_size=1,
    shuffle_size=1,
    image_size=448,
    split_size=7,
)
# Convert the mapping into an attribute-style record; the field names are the
# dict keys, in insertion order.
args = namedtuple('Args', args.keys())(**args)
ds_builder = VOCBuilder(args)
# 'val' split with augmentation disabled; get_dataset returns the dataset and
# its size.
ds, ds_size = ds_builder.get_dataset('val', use_aug=False)
# The bare string below appears to be a log of scores from earlier runs. As the
# last expression of the cell it is also echoed as the cell's output.
"""
8examples:
average mAP: 0.9583333333333334
average coco mAP: 0.7729166666666666

100examples:
average mAP: 0.8321197411003236
average coco mAP: 0.5845573008437084

val:
average mAP: 0.52847098262852
average coco mAP: 0.279217751483262
"""

In [None]:
import matplotlib.pyplot as plt
from katacv.utils.detection import plot_box, cvt_coord_cell2image, nms, get_best_boxes_and_classes, mAP, cvt_one_label2boxes, iou, coco_mAP
from katacv.utils.VOC.label2realname import label2realname
import numpy as np
from tqdm import tqdm
S = 7

# Wrap the box decoder with jit once, instead of rebuilding the wrapper on
# every loop iteration. B and C are still passed by keyword exactly as before.
# presumably B = boxes per cell and C = number of VOC classes - TODO confirm
decode_cells = jax.jit(get_best_boxes_and_classes, static_argnums=[1,2])

# Running sums of the per-image scores, plus a count of the images actually
# evaluated. Dividing by this count (rather than ds_size) keeps the averages
# correct when the loop sees fewer items than ds_size reports, e.g. when `ds`
# has been truncated for a quick run.
mAP_avg = 0
coco_mAP_avg = 0
num_evaluated = 0
for image, label in tqdm(ds, total=ds_size):
    # Only the first label of the batch is used (batch_size is 1 here).
    image, label = image.numpy(), label[0].numpy()
    # fig, ax = plt.subplots()
    # ax.imshow(image[0])

    proba, boxes = predict(image)
    # Join the two model outputs along the last axis before decoding.
    cells = jnp.concatenate([proba, boxes], -1)
    boxes = decode_cells(cells, B=2, C=20)[0]
    boxes = nms(boxes)

    target_boxes = cvt_one_label2boxes(label, 20)
    # print("target:", target_boxes)
    # print("pred:", boxes)
    _mAP = mAP(boxes, target_boxes)
    # print("mAP:", _mAP)
    coco = coco_mAP(boxes, target_boxes)
    # print("coco_mAP:", coco)
    mAP_avg += _mAP
    coco_mAP_avg += coco
    num_evaluated += 1

    # Optional visualisation of targets and predictions:
    # for box in target_boxes:
    #     text= f"{label2realname[int(box[5])]}"
    #     plot_box(ax, image[0].shape, box[1:5], text)
    # for box in boxes:
    #     text= f"pred {label2realname[int(box[5])]}"
    #     plot_box(ax, image[0].shape, box[1:5], text)
    # plt.show()

# Fail with a clear message instead of a ZeroDivisionError (or a silently
# reported 0.0) when nothing was evaluated.
if num_evaluated == 0:
    raise ValueError("The dataset yielded no examples; cannot compute average mAP.")
mAP_avg /= num_evaluated
coco_mAP_avg /= num_evaluated
print("average mAP:", mAP_avg)
print("average coco mAP:", coco_mAP_avg)