In [1]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent.parent))

In [9]:
from katacv.utils.related_pkgs.jax_flax_optax_orbax import *
from katacv.yolov3.yolov3 import get_args_and_writer
from katacv.yolov3.yolov3_model import get_yolov3_state

class Args:
    B = 3
    C = 20
    seed = 0
    input_shape = (1,416,416,3)
    learning_rate = 1e-4

args = Args()
state = get_yolov3_state(args)
weights = ocp.PyTreeCheckpointer().restore("/home/wty/Coding/models/YOLOv3/YOLOv3-PASCAL-0040-lite")
print(weights.keys())
# weights = ocp.PyTreeCheckpointer().restore("/home/yy/Coding/models/YOLOv1/YOLOv1-0080-lite")
state = state.replace(params=weights['params'], batch_stats=weights['batch_stats'])

dict_keys(['batch_stats', 'opt_state', 'params', 'step'])


In [11]:
print(state.params.keys())

dict_keys(['ConvBlock_0', 'ConvBlock_1', 'DarkNet_0', 'ScalePredict_0', 'ScalePredict_1', 'ScalePredict_2', 'YOLOBlock_0', 'YOLOBlock_1', 'YOLOBlock_2'])


In [12]:
@jax.jit
def predict(x):
    outputs = state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        x, train=False
    )
    return jax.device_get(outputs)
print("XLA compile...")
y = predict(jnp.empty((1,416,416,3), dtype='float32'))
print("Compile complete!")

XLA compile...


ScopeParamShapeError: Initializer expected to generate shape (1, 1, 1024, 255) but got shape (1, 1, 1024, 75) instead for parameter "kernel" in "/ScalePredict_0/ConvBlock_1/Conv_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

In [None]:
import io, time
import ipywidgets as widgets
from IPython.display import display, clear_output
from PIL import Image, ImageEnhance
import numpy as np
import matplotlib.pyplot as plt
from katacv.utils.detection import nms, plot_box, iou, slice_by_idxs, cvt_coord_cell2image, get_best_boxes_and_classes
from katacv.utils.VOC.label2realname import label2realname

upload_button = widgets.FileUpload(description="上传图片")
iou_slider = widgets.FloatSlider(description="IOU阈值 (NMS)", min=0.0, max=1.0, step=0.01, value=0.3)
conf_slider = widgets.FloatSlider(description="最低置信度阈值", min=-0.5, max=2.0, step=0.01, value=0.2)
aux_output = widgets.Text("模型/识别用时：")

# @jax.jit  # optimizer use jit
def NMS(proba, boxes, iou_threshold, conf_threshold):
    start_time = time.time()
    cells = jnp.concatenate([proba, boxes], -1)
    # boxes = jax.jit(get_best_boxes_and_classes, static_argnums=[1,2])(cells, B=2, C=20)[0]
    boxes = get_best_boxes_and_classes(cells, B=2, C=20)[0]
    print("get boxes time:", time.time() - start_time)
    start_time = time.time()
    boxes = nms(boxes, iou_threshold=iou_threshold, conf_threshold=conf_threshold)
    print("nms time:", time.time() - start_time)
    return boxes

# 显示图片和调整亮度、对比度的函数
def show_image(upload, iou_threshold, conf_threshold):
    if upload:
        uploaded_file = upload[0]
        origin_image = Image.open(io.BytesIO(uploaded_file['content'])).convert('RGB')
        resize_image = origin_image.resize((448,448))
        # x = np.expand_dims(np.array(resize_image), 0)
        x = jnp.expand_dims(jnp.array(resize_image), 0).astype('float32')
        start_time = time.time()
        
        proba, boxes = predict(x)
        aux_output.value = f"模型用时：{time.time() - start_time:.4f} s"
        
        boxes = NMS(proba, boxes, iou_threshold, conf_threshold)
        
        aux_output.value += f" 识别总用时：{time.time() - start_time:.4f} s"
        
        fig, ax = plt.subplots(figsize=(15,8))
        ax.imshow(origin_image)
        for i in range(boxes.shape[0]):
            plot_box(ax, origin_image.size[::-1], boxes[i,1:5], text=f"{label2realname[int(boxes[i,5])]} {float(boxes[i,0]):.2f}")
        plt.show()
        # display(img)

interactive_output = widgets.interactive_output(
    show_image,
    {
        "upload": upload_button,
        "iou_threshold": iou_slider, 
        "conf_threshold": conf_slider,
    }
)

display(widgets.HBox([upload_button, iou_slider, conf_slider, aux_output]), interactive_output)
