# 三叶青分类-ONNX Runtime部署-视频

使用 ONNX Runtime 推理引擎，载入自己训练得到的图像分类 onnx 模型，预测摄像头实时画面。

2024/5/1

## 导入工具包

In [1]:
import onnxruntime
import torch
import time

import cv2

from torchvision import transforms
import torch.nn.functional as F
import pandas as pd
import numpy as np
from PIL import Image, ImageFont, ImageDraw
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
ort_session = onnxruntime.InferenceSession(r'D:\Train_Custom_Dataset\图像分类\7-ONNX Runtime图像分类部署\1-Pytorch图像分类模型转ONNX\resnet101_sanyeqing_10.onnx')
#测试时，我们只使用确定性的图像预处理操作
transform_test = transforms.Compose([transforms.Resize(256),
                                    # 从图像中心裁切224x224大小的图片
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406],
                                                        [0.229, 0.224, 0.225])
                                    ])

## 处理单帧画面的函数


In [3]:
# 处理帧函数
def process_frame(img,n,idx_to_labels_csv):
    
    '''
    输入摄像头拍摄画面bgr-array，输出图像分类预测结果bgr-array
    '''
    
    # 记录该帧开始处理的时间
    start_time = time.time()
    
    ## 画面转成 RGB 的 Pillow 格式
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR转RGB
    img_pil = Image.fromarray(img_rgb) # array 转 PIL

    ## 预处理
    input_img = transform_test(img_pil) # 预处理
    input_tensor = input_img.unsqueeze(0).numpy()
    
    ## onnx runtime 预测
    ort_inputs = {'input': input_tensor} # onnx runtime 输入
    pred_logits = ort_session.run(['output'], ort_inputs)[0] # onnx runtime 输出
    pred_logits = torch.tensor(pred_logits)
    pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算
    
    # 载入类别和对应 ID
    idx_to_labels = pd.read_csv(idx_to_labels_csv,encoding='utf-8')
    # 创建一个字典来存储真实标签
    idx_to_labels = {str(index): row['labels'] for index, row in idx_to_labels.iterrows()}  

    ## 解析top-n预测结果的类别和置信度
    top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
    confs = top_n[0].cpu().detach().numpy().squeeze() # 解析出置信度
    
    ## 在图像上写中文
    draw = ImageDraw.Draw(img_pil) 
    for i in range(len(confs)):
        class_name = idx_to_labels[str(pred_ids[i])] # 获取类别名称
        text = '{:<15} {:>.3f}'.format(class_name, confs[i])
        # 文字坐标，中文字符串，字体，rgba颜色
        draw.text((50, 100 + 50 * i), text, fill=(255, 0, 0, 1))
    img = np.array(img_pil) # PIL 转 array
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # RGB转BGR
    
    # 记录该帧处理完毕的时间
    end_time = time.time()
    # 计算每秒处理图像帧数FPS
    FPS = 1/(end_time - start_time)  
    # 图片，添加的文字，左上角坐标，字体，字体大小，颜色，线宽，线型
    img = cv2.putText(img, 'FPS  '+str(int(FPS)), (50, 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 255), 4, cv2.LINE_AA)
    return img



## 调用摄像头获取每帧（模板）

In [15]:
# 调用摄像头逐帧实时处理模板
# 导入opencv-python
import cv2
import time

# 获取摄像头，传入0表示获取系统默认摄像头
cap = cv2.VideoCapture(1)

# 打开cap
cap.open(0)

# 无限循环，直到break被触发
while cap.isOpened():
    # 获取画面
    success, frame = cap.read()
    if not success:
        print('Error')
        break
    
    ## !!!处理帧函数
    frame = process_frame(frame,5,idx_to_labels_csv="D:\三叶青项目\wht-sanyeqing\idx_to_labels.csv")
    
    # 展示处理后的三通道图像
    cv2.imshow('my_window',frame)

    if cv2.waitKey(1) in [ord('q'),27]: # 按键盘上的q或esc退出（在英文输入法下）
        break
    
# 关闭摄像头
cap.release()

# 关闭图像窗口
cv2.destroyAllWindows()

### 按键盘上的`q`键退出

## 视频逐帧处理

In [16]:
import cv2
import numpy as np
import time
from tqdm import tqdm

def generate_video(input_path,n,idx_to_labels_csv):
    filehead = input_path.split('/')[-1]
    output_path = "out-" + filehead
    
    print('视频开始处理',input_path)
    
    # 获取视频总帧数
    cap = cv2.VideoCapture(input_path)
    frame_count = 0
    while(cap.isOpened()):
        success, frame = cap.read()
        frame_count += 1
        if not success:
            break
    cap.release()
    print('视频总帧数为',frame_count)
    
    # cv2.namedWindow('Crack Detection and Measurement Video Processing')
    cap = cv2.VideoCapture(input_path)
    frame_size = (cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
    # fourcc = cv2.VideoWriter_fourcc(*'XVID')
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = cap.get(cv2.CAP_PROP_FPS)

    out = cv2.VideoWriter(output_path, fourcc, fps, (int(frame_size[0]), int(frame_size[1])))
    
    # 进度条绑定视频总帧数
    with tqdm(total=frame_count-1) as pbar:
        try:
            while(cap.isOpened()):
                success, frame = cap.read()
                if not success:
                    break

                # # 处理帧
                # frame_path = './temp_frame.png'
                # cv2.imwrite(frame_path, frame)
                try:
                    frame = process_frame(frame,n,idx_to_labels_csv)
                except:
                    print('报错！', error)
                    pass
                
                if success == True:
                    cv2.imshow('Video Processing', frame)
                    out.write(frame)

                    # 进度条更新一帧
                    pbar.update(1)

                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        except:
            print('中途中断')
            pass

    cv2.destroyAllWindows()
    out.release()
    cap.release()
    print('视频已保存', output_path)

In [17]:
generate_video(input_path=r"D:/vedio/2024-05-03-01-17-14.mp4",n=5,idx_to_labels_csv="D:\三叶青项目\wht-sanyeqing\idx_to_labels.csv")

视频开始处理 D:/vedio/2024-05-03-01-17-14.mp4
视频总帧数为 684


100%|██████████| 683/683 [02:09<00:00,  5.28it/s]

视频已保存 out-2024-05-03-01-17-14.mp4



