In [85]:
import sys
from pathlib import Path

# The notebook runs two directories below the project root; put that root on
# sys.path so the `katacv` package is importable. Guarded so that re-running
# this cell does not append the same entry over and over.
PROJECT_ROOT = str(Path.cwd().parent.parent)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [86]:
from katacv.utils.related_pkgs.jax_flax_optax_orbax import *
from katacv.yolov3.yolov3 import get_args_and_writer
from katacv.yolov3.yolov3_model import get_yolov3_state

class Args:
    """Inference-time configuration passed to ``get_yolov3_state``."""

    # Reproducibility.
    seed = 0

    # B: anchors per output scale; C: number of classes (20 for PASCAL VOC).
    B = 3
    C = 20

    # One 416x416 RGB image per batch: (N, H, W, channels).
    input_shape = (1, 416, 416, 3)
    learning_rate = 1e-4

    # Nine (w, h) anchors as fractions of the 416x416 input, rounded to two
    # decimals; the original pixel sizes are in the trailing comments.
    # Consecutive groups of B anchors belong to one output scale.
    anchors = [
        (0.02, 0.03),  # (10, 13)
        (0.04, 0.07),  # (16, 30)
        (0.08, 0.06),  # (33, 23)
        (0.07, 0.15),  # (30, 61)
        (0.15, 0.11),  # (62, 45)
        (0.14, 0.29),  # (59, 119)
        (0.28, 0.22),  # (116, 90)
        (0.38, 0.48),  # (156, 198)
        (0.90, 0.78),  # (373, 326)
    ]
    

import os

# Checkpoint to restore. Set the YOLOV3_CKPT environment variable to use a
# different one instead of editing this machine-specific default.
# Other checkpoint used previously: /home/wty/Coding/models/YOLOv3/YOLOv3-PASCAL-0050-lite
CKPT_PATH = os.environ.get(
    "YOLOV3_CKPT", "/home/wty/Coding/models/YOLOv3/YOLOv3-test-0020-lite"
)

args = Args()
state = get_yolov3_state(args)
weights = ocp.PyTreeCheckpointer().restore(CKPT_PATH)
print(weights.keys())
# Inference only needs the model variables; 'opt_state' and 'step' are ignored.
state = state.replace(params=weights['params'], batch_stats=weights['batch_stats'])

dict_keys(['batch_stats', 'opt_state', 'params', 'step'])


In [87]:
from katacv.yolov3.yolov3 import TrainState
from katacv.utils.detection import slice_by_idxs, cvt_coord_cell2image

@partial(jax.jit, static_argnames=['B'])
def predict(state: TrainState, x, anchors, B=3):
    """Run the YOLOv3 model and keep the most confident anchor box of every grid cell.

    Args:
        state: TrainState providing ``apply_fn``, ``params`` and ``batch_stats``.
        x: image batch; assumed (N, 416, 416, 3) in this notebook.
        anchors: sequence of (w, h) anchor sizes; output scale ``i`` uses
            ``anchors[i*B:(i+1)*B]``. NOTE(review): assumes the model's output
            scales are ordered to match the anchor groups in ``Args.anchors``.
        B: number of anchors per scale (static under jit).

    Returns:
        Array of shape (N, sum over scales of S*S, 6). Each row is
        ``(confidence, xy (2 values), w, h, class_id)`` for the best anchor of one cell.
    """
    logits = state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        x, train=False
    )

    def convert_coord2image(scale_logits, scale_anchors, S):
        """Decode one scale's raw output (N,S,S,B,5+C) into boxes of shape (N,S*S,B,6)."""
        N = scale_logits.shape[0]
        per_anchor = []
        for k in range(B):
            now = scale_logits[:, :, :, k, :]  # (N,S,S,5+C)
            xy = cvt_coord_cell2image(jax.nn.sigmoid(now[..., 1:3]))  # (N,S,S,2)
            w = jnp.exp(now[..., 3:4]) * scale_anchors[k][0]  # (N,S,S,1)
            h = jnp.exp(now[..., 4:5]) * scale_anchors[k][1]  # (N,S,S,1)
            class_proba = jax.nn.softmax(now[..., 5:])  # (N,S,S,C)
            cls = jnp.argmax(class_proba, -1, keepdims=True)  # (N,S,S,1)
            # Score = objectness * probability of the best class.
            c = jax.nn.sigmoid(now[..., 0:1]) * jnp.max(class_proba, -1, keepdims=True)  # (N,S,S,1)
            per_anchor.append(jnp.concatenate([c, xy, w, h, cls], -1).reshape(N, S * S, 6))
        # Stack along a new anchor axis. jnp.array(list) would put the anchor
        # axis first, (B,N,S*S,6), and reshaping that to (N,S*S,B,6) without a
        # transpose mixes boxes from different cells.
        return jnp.stack(per_anchor, axis=2)  # (N,S*S,B,6)

    def get_best_box_in_cell(y):  # (N,S*S,B,6)
        """Select, for every cell, the anchor box with the highest confidence."""
        idxs = jnp.argmax(y[..., 0], -1)  # (N,S*S)
        # Flatten the anchor axis so anchor k occupies columns [6k, 6k+6).
        best_boxes = slice_by_idxs(y.reshape(*y.shape[:-2], -1), idxs * 6, 6)
        return best_boxes  # (N,S*S,6)

    best_boxes = []
    for i, scale_logits in enumerate(logits):
        S = scale_logits.shape[1]
        pred = convert_coord2image(scale_logits, anchors[i * B:(i + 1) * B], S)
        best_boxes.append(get_best_box_in_cell(pred))
    return jnp.concatenate(best_boxes, 1)

In [88]:
import time

# Warm-up call: the first call traces and XLA-compiles `predict` for this input
# shape/dtype, so later calls with the same signature skip compilation.
start_time = time.perf_counter()
print("XLA compile...")
boxes = predict(state, jnp.empty(args.input_shape, dtype='uint8'), args.anchors)
# JAX dispatches asynchronously; wait for the result so the reported time
# covers compilation plus the first execution.
boxes = jax.block_until_ready(boxes)
print(f"Compile complete! Use time: {time.perf_counter() - start_time:.2f} s")

XLA compile...
Compile complete! Use time: 3.16 s


In [None]:
from katacv.utils.VOC.build_dataset_yolov3 import DatasetBuilder


In [82]:
# One best box per grid cell over all output scales (3549 cells for a 416x416
# input; presumably 13*13 + 26*26 + 52*52).
jnp.shape(boxes)

(1, 3549, 6)

In [90]:
from PIL import Image
import matplotlib.pyplot as plt

# Sanity check on one local test image. PIL gives uint8 pixels, the same dtype
# as the warm-up call, so the compiled `predict` is reused.
# .convert('RGB') guarantees 3 channels even for grayscale/RGBA files.
x = jnp.array(
    Image.open(r"/home/wty/Pictures/model_test/test_image/8examples/000007.jpg")
    .convert('RGB')
    .resize((416, 416))
)[None, ...]  # (1, 416, 416, 3)
boxes = predict(state, x, args.anchors)[0]
# Confidence of every candidate box, highest first.
print(jnp.sort(boxes[:, 0])[::-1])

[0.09618747 0.07229392 0.06658299 ... 0.025      0.025      0.025     ]


In [91]:
import io, time
import ipywidgets as widgets
from IPython.display import display, clear_output
from PIL import Image, ImageEnhance
import numpy as np
import matplotlib.pyplot as plt
from katacv.utils.detection import nms, plot_box, iou, slice_by_idxs, cvt_coord_cell2image, get_best_boxes_and_classes
from katacv.utils.VOC.label2realname import label2realname
label2realname = label2realname['PASCAL']

# Controls for the interactive demo (UI labels are Chinese):
#   "上传图片" = "Upload image", "IOU阈值" = "IoU threshold",
#   "最低置信度阈值" = "Minimum confidence threshold", "模型/识别用时" = "Model/detection time".
upload_button = widgets.FileUpload(description="上传图片")
iou_slider = widgets.FloatSlider(description="IOU阈值 (NMS)", min=0.0, max=1.0, step=0.01, value=0.3)
# `predict` scores boxes as sigmoid(objectness) * max class probability, which
# lies in [0, 1], so thresholds outside that range have no useful effect.
conf_slider = widgets.FloatSlider(description="最低置信度阈值", min=0.0, max=1.0, step=0.01, value=0.1)
aux_output = widgets.Text("模型/识别用时：")

def NMS(proba, boxes, iou_threshold, conf_threshold, B=2, C=20):
    """Legacy YOLOv1-style post-processing: pick per-cell boxes, then apply NMS.

    Currently unused in this section of the notebook. Prints the time spent
    in each stage.

    Args:
        proba: per-cell class scores, concatenated with ``boxes`` on the last
            axis (layout expected by ``get_best_boxes_and_classes`` — confirm).
        boxes: per-cell box predictions.
        iou_threshold: IoU threshold forwarded to ``nms``.
        conf_threshold: minimum confidence forwarded to ``nms``.
        B: boxes per cell expected by ``get_best_boxes_and_classes``. The default
            of 2 is the old YOLOv1 setting; the YOLOv3 model here uses ``Args.B``.
        C: number of classes.

    Returns:
        Whatever ``nms`` returns for the selected boxes.
    """
    start_time = time.perf_counter()
    cells = jnp.concatenate([proba, boxes], -1)
    # Block on the async JAX result so the printed time is the real cost.
    boxes = jax.block_until_ready(get_best_boxes_and_classes(cells, B=B, C=C)[0])
    print("get boxes time:", time.perf_counter() - start_time)
    start_time = time.perf_counter()
    boxes = jax.block_until_ready(
        nms(boxes, iou_threshold=iou_threshold, conf_threshold=conf_threshold)
    )
    print("nms time:", time.perf_counter() - start_time)
    return boxes

def show_image(upload, iou_threshold, conf_threshold):
    """Detect objects in the uploaded image and plot the boxes kept by NMS.

    Called by ``widgets.interactive_output`` whenever the upload or a slider
    changes. Writes the model time and total detection time into ``aux_output``.

    Args:
        upload: FileUpload widget value; a sequence of dicts with a ``'content'``
            entry (only the first file is used). Does nothing when empty.
        iou_threshold: IoU threshold passed to ``nms``.
        conf_threshold: minimum confidence passed to ``nms``.
    """
    if not upload:
        return

    uploaded_file = upload[0]
    origin_image = Image.open(io.BytesIO(uploaded_file['content'])).convert('RGB')
    resize_image = origin_image.resize((416, 416))
    # float32 here vs uint8 in the warm-up cell: the first call with this dtype
    # triggers a recompile, which will show up in the model time once.
    x = jnp.expand_dims(jnp.array(resize_image), 0).astype('float32')  # (1,416,416,3)

    start_time = time.perf_counter()
    # JAX dispatches asynchronously; block so the timer covers the computation
    # itself rather than just the dispatch.
    boxes = jax.block_until_ready(predict(state, x, args.anchors)[0])
    aux_output.value = f"模型用时：{time.perf_counter() - start_time:.4f} s"

    print("confidence:", jnp.sort(boxes[:, 0])[::-1])
    boxes = jax.block_until_ready(nms(boxes, iou_threshold, conf_threshold))

    aux_output.value += f" 识别总用时：{time.perf_counter() - start_time:.4f} s"

    print(boxes.shape)
    fig, ax = plt.subplots(figsize=(15, 8))
    ax.imshow(origin_image)
    image_hw = origin_image.size[::-1]  # PIL size is (W, H); reversed to (H, W)
    for i in range(boxes.shape[0]):
        plot_box(ax, image_hw, boxes[i, 1:5], text=f"{label2realname[int(boxes[i,5])]} {float(boxes[i,0]):.2f}")
    plt.show()
    # Release the figure so repeated slider moves do not accumulate figures.
    plt.close(fig)

# Re-run show_image whenever the uploaded file or either slider changes; the
# dict keys map each widget onto the matching show_image parameter.
interactive_output = widgets.interactive_output(
    show_image,
    dict(
        upload=upload_button,
        iou_threshold=iou_slider,
        conf_threshold=conf_slider,
    ),
)

# Controls in one row, with the rendered detection output below them.
display(
    widgets.HBox([upload_button, iou_slider, conf_slider, aux_output]),
    interactive_output,
)


HBox(children=(FileUpload(value=(), description='上传图片'), FloatSlider(value=0.3, description='IOU阈值 (NMS)', max…

Output()