In [1]:
import sys
from pathlib import Path

# Put the parent of the notebook's working directory on the import path,
# presumably so that the `katacv` package imported in the next cell resolves.
# The membership check keeps sys.path free of duplicates when this cell is re-run.
PROJECT_ROOT = str(Path.cwd().parent)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import os

from katacv.utils.related_pkgs.jax_flax_optax_orbax import *
from katacv.yolov1.yolov1 import get_yolov1_state

# Checkpoint directory to restore. Set the YOLOV1_CKPT environment variable to
# point elsewhere instead of editing this machine-specific absolute path.
CKPT_PATH = os.environ.get("YOLOV1_CKPT", "/home/wty/Coding/models/YOLOv1/YOLOv1-0080-lite")

state = get_yolov1_state()
# NOTE(review): `ocp` is assumed to come from the star import above
# (orbax.checkpoint) — confirm.
weights = ocp.PyTreeCheckpointer().restore(CKPT_PATH)
# Swap the freshly built state's parameters and batch statistics for the
# restored ones.
state = state.replace(params=weights['params'], batch_stats=weights['batch_stats'])

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
@jax.jit
def _predict_jit(variables, x):
    """Jitted inference-mode forward pass (``train=False``).

    The model variables are an explicit argument rather than being read from the
    global ``state`` inside the traced body. That way they are not baked into
    the compiled computation as constants, and replacing ``state`` later takes
    effect without recompiling.
    """
    return state.apply_fn(variables, x, train=False)


def predict(x):
    """Run the restored model on a batch ``x`` and return host arrays.

    Args:
        x: input batch passed straight to ``state.apply_fn``
            (presumably ``(N, H, W, 3)`` images — confirm against the model).

    Returns:
        ``(proba, boxes)`` as returned by the model, copied to host NumPy arrays.
    """
    proba, boxes = _predict_jit(
        {'params': state.params, 'batch_stats': state.batch_stats}, x
    )
    # The host transfer must happen outside jit: inside a traced function
    # jax.device_get returns tracers unchanged, so it would be a no-op.
    return jax.device_get(proba), jax.device_get(boxes)

In [4]:
import io, time
import ipywidgets as widgets
from IPython.display import display, clear_output
from PIL import Image, ImageEnhance
import numpy as np
import matplotlib.pyplot as plt
from katacv.utils.detection import get_best_boxes_and_classes, nms, plot_box
from katacv.utils.VOC.label2realname import label2realname

# Input controls. The descriptions are Chinese UI labels: "Upload image",
# "IoU threshold (NMS)" and "Minimum confidence threshold".
# accept="image/*" limits the file picker to images, which is all show_image can decode.
upload_button = widgets.FileUpload(description="上传图片", accept="image/*")
iou_slider = widgets.FloatSlider(description="IOU阈值 (NMS)", min=0.0, max=1.0, step=0.01, value=0.4)
# NOTE(review): the range extends beyond [0, 1]; presumably intentional because
# the model's scores are not clamped — confirm.
conf_slider = widgets.FloatSlider(description="最低置信度阈值", min=-0.5, max=2.0, step=0.01, value=0.2)
# Read-only readout of the detection time ("识别用时" = "detection time").
# It is disabled so the user cannot type over the value written by show_image.
aux_output = widgets.Text("识别用时：", disabled=True)

def show_image(upload, iou_threshold, conf_threshold,
               input_size=448, grid_size=7, num_boxes=2, num_classes=20):
    """Run the detector on the uploaded image and plot the boxes that survive NMS.

    Used as the callback of ``widgets.interactive_output``, so it is re-run
    whenever the upload widget or one of the threshold sliders changes.

    Args:
        upload: value of the ``FileUpload`` widget. Only its first entry is used;
            its ``'content'`` item must hold the encoded image bytes. Nothing
            happens while it is empty.
        iou_threshold: IoU threshold passed to ``nms``.
        conf_threshold: confidence threshold passed to ``nms``.
        input_size: side length (pixels) the image is resized to before inference.
        grid_size, num_boxes, num_classes: forwarded to
            ``get_best_boxes_and_classes`` as ``S``, ``B`` and ``C``.

    Side effects:
        Writes the elapsed detection time into ``aux_output.value`` and shows a
        matplotlib figure of the original image with the detected boxes drawn.
    """
    if not upload:
        return

    uploaded_file = upload[0]
    origin_image = Image.open(io.BytesIO(uploaded_file['content'])).convert('RGB')
    resize_image = origin_image.resize((input_size, input_size))
    x = np.expand_dims(np.array(resize_image), 0)  # add a batch axis of size 1

    # perf_counter is monotonic and high-resolution, unlike time.time().
    # The first call also includes JIT compilation of `predict`.
    start_time = time.perf_counter()
    proba, boxes = predict(x)
    cells = jnp.concatenate([proba, boxes], -1)
    # [0] presumably selects the single image in the batch — confirm.
    boxes = get_best_boxes_and_classes(cells, S=grid_size, B=num_boxes, C=num_classes)[0]
    boxes = nms(boxes, iou_threshold=iou_threshold, conf_threshold=conf_threshold)
    # JAX dispatches work asynchronously. Pulling the result to the host here
    # makes the measured time cover the computation, not just the dispatch.
    boxes = np.asarray(boxes)
    aux_output.value = f"识别用时：{time.perf_counter() - start_time:.4f} s"

    fig, ax = plt.subplots(figsize=(15, 8))
    ax.imshow(origin_image)
    # PIL's size is (width, height); it is reversed to (height, width) for plot_box.
    image_hw = origin_image.size[::-1]
    # Each row is read as [confidence, 4 box coordinates, class index]
    # (layout inferred from this indexing — confirm against nms).
    for box in boxes:
        label = label2realname[int(box[5])]
        plot_box(ax, image_hw, box[1:5], text=f"{label} {float(box[0]):.2f}")
    plt.show()

# Controls that drive show_image; each key must match one of its parameter names.
detector_controls = {
    "upload": upload_button,
    "iou_threshold": iou_slider,
    "conf_threshold": conf_slider,
}

# Output area that re-runs show_image whenever any control's value changes.
interactive_output = widgets.interactive_output(show_image, detector_controls)

# One row of inputs plus the timing readout, followed by the plot output.
control_bar = widgets.HBox([upload_button, iou_slider, conf_slider, aux_output])
display(control_bar, interactive_output)


HBox(children=(FileUpload(value=(), description='上传图片'), FloatSlider(value=0.4, description='IOU阈值 (NMS)', max…

Output()