In [None]:
import sys
from pathlib import Path
# Add the directory two levels above the notebook's working directory to the import path,
# presumably the repository root so that the `katacv` package imported in the next cell resolves.
sys.path.append(str(Path.cwd().parent.parent))

In [None]:
from katacv.utils.related_pkgs.jax_flax_optax_orbax import *
from katacv.yolov3.parser import get_args_and_writer
from katacv.yolov3.yolov3_model import get_yolov3_state

class Args:
    """
    Static configuration object handed to `get_yolov3_state`.

    Anchor sizes are expressed as fractions of the input side; the pixel sizes they
    approximate for a 416x416 input are given in the trailing comments.
    """
    # ---- model ----
    B = 3                        # anchors per grid cell at each output scale
    C = 20                       # presumably the number of classes (PASCAL VOC has 20)
    input_shape = (1,416,416,3)  # (batch, height, width, channels) used to build the state
    anchors = [
        (0.02, 0.03), (0.04, 0.07), (0.08, 0.06),  # (10, 13), (16, 30), (33, 23) px
        (0.07, 0.15), (0.15, 0.11), (0.14, 0.29),  # (30, 61), (62, 45), (59, 119) px
        (0.28, 0.22), (0.38, 0.48), (0.90, 0.78),  # (116, 90), (156, 198), (373, 326) px
    ]
    freeze = False               # presumably whether to freeze the backbone — confirm in get_yolov3_state

    # ---- optimisation / schedule ----
    seed = 0
    learning_rate = 1e-4
    batch_size = 32
    warmup_epochs = 5
    total_epochs = 80
    

# Orbax checkpoint holding 'params', 'params_darknet' and 'batch_stats' for the YOLOv3 model.
CKPT_PATH = "/home/wty/Coding/models/YOLOv3/YOLOv3-PASCAL-0044-lite"

args = Args()
state = get_yolov3_state(args)  # freshly initialised state; its weights are overwritten below
weights = ocp.PyTreeCheckpointer().restore(CKPT_PATH)
print(weights.keys())
state = state.replace(
    params=weights['params'],
    params_darknet=weights['params_darknet'],
    batch_stats=weights['batch_stats'],
)

In [None]:
from katacv.yolov3.yolov3 import TrainState
from katacv.utils.detection import slice_by_idxs, cvt_coord_cell2image

@partial(jax.jit, static_argnames=['B'])
def predict(state: TrainState, x, anchors, B=3):
    """
    Run the YOLOv3 model and keep, for every grid cell of every output scale,
    the most confident of its `B` anchor boxes.

    Args:
        state: TrainState providing `apply_fn`, `params` and `batch_stats`.
        x: batch of input images, passed unchanged to `apply_fn`.
        anchors: flat sequence of `(w, h)` anchor pairs, `B` per output scale, ordered
            like the model outputs (scale `i` uses `anchors[i*B:(i+1)*B]`).
        B: number of anchors per grid cell (static under jit).

    Returns:
        Array of shape (N, sum_i S_i*S_i, 6); each row is `(c, x, y, w, h, cls)` where
        `c` = sigmoid(channel 0) * max class probability and `cls` is the argmax class.
    """
    logits = state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        x, train=False
    )

    def convert_coord2image(logits, anchors, S):
        """Decode one scale: (N,S,S,B,5+C) raw outputs -> (N,S*S,B,6) boxes."""
        N = logits.shape[0]
        ret = []
        for k in range(B):
            now = logits[:, :, :, k, :]  # (N,S,S,5+C) outputs of anchor k
            xy = cvt_coord_cell2image(jax.nn.sigmoid(now[..., 1:3]))  # (N,S,S,2)
            w = jnp.exp(now[..., 3:4]) * anchors[k][0]  # (N,S,S,1)
            h = jnp.exp(now[..., 4:5]) * anchors[k][1]  # (N,S,S,1)
            probs = jax.nn.softmax(now[..., 5:])  # (N,S,S,C)
            cls = jnp.argmax(probs, -1, keepdims=True)  # (N,S,S,1)
            c = jax.nn.sigmoid(now[..., 0:1]) * jnp.max(probs, -1, keepdims=True)  # (N,S,S,1)
            ret.append(jnp.concatenate([c, xy, w, h, cls], -1).reshape(N, S*S, 6))
        # Stack the anchors on a new axis 2 -> (N,S*S,B,6). Building a (B,N,S*S,6) array
        # and *reshaping* it to (N,S*S,B,6) would interleave boxes from different cells
        # and anchors instead of transposing them.
        return jnp.stack(ret, axis=2)

    def get_best_box_in_cell(y):  # (N,S*S,B,6)
        idxs = jnp.argmax(y[..., 0], -1)  # (N,S*S) index of the most confident anchor
        # Flatten anchors to (N,S*S,B*6) and take the 6 values starting at idxs*6.
        best_boxes = slice_by_idxs(y.reshape(*y.shape[:-2], -1), idxs * 6, 6)
        return best_boxes  # (N,S*S,6)

    best_boxes = []
    for i in range(len(logits)):
        S = logits[i].shape[1]  # grid size of this scale (static Python int)
        pred = convert_coord2image(logits[i], anchors[i*B:(i+1)*B], S)
        best_boxes.append(get_best_box_in_cell(pred))
    return jnp.concatenate(best_boxes, 1)

In [None]:
from katacv.utils.VOC.build_dataset_yolov3 import DatasetBuilder, Path, split_targets
from collections import namedtuple
# Dataset options for evaluation on the PASCAL VOC validation split. Input size,
# anchors and anchors-per-cell come from the model config (`args`) so the images
# and labels match what the network was built for.
eval_image_size = args.input_shape[1]  # 416
args_dataset = {
    'path_dataset_tfrecord': Path('/home/wty/Coding/datasets/PASCAL/tfrecord'),
    'batch_size': 1,  # the evaluation loop only scores the first image of each batch
    'shuffle_size': 16,
    # Previously 448 (the YOLOv1 input size), which disagreed with the 416x416 model
    # input and with split_sizes = 416 / (8, 16, 32).
    'image_size': eval_image_size,
    'split_sizes': [eval_image_size // stride for stride in (8, 16, 32)],  # [52, 26, 13]
    'anchors': args.anchors,
    'bounding_box': args.B,
    # NOTE(review): an IoU threshold of 3 can never be exceeded; 0.5 was presumably
    # intended. It presumably only affects ignored (non-positive) label entries, which
    # the evaluation filters out (it keeps c == 1 only), so it is left as-is — confirm.
    'iou_ignore_threshold': 3,
    'B': args.B,
}
args_dataset = namedtuple('Args_dataset', args_dataset)(**args_dataset)
ds_builder = DatasetBuilder(args_dataset)
ds, ds_size = ds_builder.get_dataset('val', shuffle=False, use_aug=False)
"""
YOLOv1 evaluation results:
8 examples:
average mAP: 0.9583333333333334
average coco mAP: 0.7729166666666666

100 examples:
average mAP: 0.8321197411003236
average coco mAP: 0.5845573008437084

val (full validation split):
average mAP: 0.52847098262852
average coco mAP: 0.279217751483262

YOLOv3 evaluation results:
8 examples:
average mAP: 0.8411458333333333
average coco mAP: 0.5552083333333333

100 examples: (CPU time: 39 s)
average mAP: 0.7738129862765785
average coco mAP: 0.49074226760968254

val (full validation split): (CPU time: 30:25, mm:ss)
average mAP: 0.6551234067298597
average coco mAP: 0.389151482568806
"""

In [None]:
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = ['serif']
# NOTE: `cvt_one_yolov3_label2boxes` imported here is shadowed by the local definition later in this cell.
from katacv.utils.detection import plot_box, cvt_coord_cell2image, nms, get_best_boxes_and_classes, mAP, cvt_one_yolov3_label2boxes, iou, coco_mAP
from katacv.utils.VOC.label2realname import label2realname
# Select the PASCAL mapping from integer class index to human-readable class name.
label2realname = label2realname['PASCAL']
import numpy as np
from tqdm import tqdm
from pathlib import Path
# Directory that per-image evaluation plots (test_images/test_img{i}.png) are saved into.
Path("test_images").mkdir(exist_ok=True)
# NOTE(review): `S` is not read at module level in the visible code (predict() receives its
# own `S`); it looks like a leftover from the 7x7 YOLOv1 grid — confirm before removing.
S = 7

def cvt_one_yolov3_label2boxes(label):
    """
    (Numpy) Convert one YOLOv3 label to origin bboxes (relative to the whole image).
    Assume the labels output of the dataset is `labels`;
    the bboxes in the first example can be obtained by `cvt_one_yolov3_label2boxes(labels[0][0])`.

    Args:
        label: array-like of shape (S,S,B,6); the last axis is (c,x,y,w,h,cls).
            It is never modified: the conversion works on a private copy.

    Returns:
        np.ndarray of shape (M,6) with one row per cell/anchor whose c == 1,
        xy converted to image coordinates and wh divided by S.
        Returns an empty (0,6) array when no cell is marked positive.
    """
    # np.array() always copies, so the in-place edits on the slice views below
    # cannot leak back into the caller's label array (or fail on read-only input).
    label = np.array(label)
    grid = label.shape[0]  # S
    boxes = []
    for k in range(label.shape[2]):
        box = label[:, :, k, :]  # (S,S,6) view into the private copy
        box[..., 1:3] = cvt_coord_cell2image(box[..., 1:3])
        box[..., 3:5] /= grid  # wh are stored in cell units; scale to image-relative
        box = box[box[..., 0] == 1]  # (M_k,6) positive entries only
        if box.size != 0:
            boxes.append(box)
    if not boxes:
        # np.concatenate([]) raises; an image without positives has no target boxes.
        return np.empty((0, label.shape[-1]), dtype=label.dtype)
    return np.concatenate(boxes, 0)

SAVE_PLOTS = False  # True: save ground-truth vs. prediction overlays to test_images/

def _save_prediction_plot(idx, image, boxes, target_boxes, mAP_value, coco_value):
    """Draw target boxes (default colour) and predicted boxes (green) on `image` and save it."""
    fig, ax = plt.subplots(figsize=(5,5))
    ax.imshow(image)
    for box in target_boxes:
        text = f"{label2realname[int(box[5])]}"
        plot_box(ax, image.shape, box[1:5], text)
    for box in boxes:
        text = f"pred {label2realname[int(box[5])]} {box[0]:.2f}"
        plot_box(ax, image.shape, box[1:5], text, box_color='green')
    ax.set_title(f"mAP: {mAP_value:.2f}, coco_mAP: {coco_value:.2f}")
    fig.tight_layout()
    fig.savefig(f"test_images/test_img{idx}.png", dpi=200)
    plt.close(fig)  # one figure per image; close it so memory does not grow over the loop

# Running sums of per-image scores; `n_images` counts images actually evaluated, so the
# averages stay correct if `ds` is truncated (e.g. `ds.take(10)`) or differs from `ds_size`.
mAP_avg = 0
coco_mAP_avg = 0
n_images = 0
bar = tqdm(ds, total=ds_size)
for i, (image, label) in enumerate(bar):
    image = image.numpy()
    label = split_targets(label, args_dataset)

    # Only the first example of each batch is scored.
    boxes = predict(state, image, args.anchors)[0]
    boxes = nms(boxes, iou_threshold=0.45, conf_threshold=0.2)
    # NOTE(review): ground truth is read from the first label scale only; this assumes the
    # dataset builder writes every object into each scale's target — confirm.
    target_boxes = cvt_one_yolov3_label2boxes(label[0][0])

    _mAP = mAP(boxes, target_boxes)
    coco = coco_mAP(boxes, target_boxes)
    mAP_avg += _mAP; coco_mAP_avg += coco
    n_images += 1

    bar.set_description(f"avg: mAP={mAP_avg/n_images:.2f} coco_mAP={coco_mAP_avg/n_images:.2f}")
    if SAVE_PLOTS:
        _save_prediction_plot(i, image[0], boxes, target_boxes, _mAP, coco)

# Average over the images actually evaluated; max(...,1) avoids ZeroDivisionError on an empty dataset.
mAP_avg /= max(n_images, 1); coco_mAP_avg /= max(n_images, 1)
print("average mAP:", mAP_avg)
print("average coco mAP:", coco_mAP_avg)