In [1]:
import torch
import cv2
import C3D_model
import numpy as np

In [2]:
def center_corp(frame):
    """Cut the fixed window rows 8..119, cols 30..141 out of ``frame`` as uint8.

    For a 128x171 frame this is a 112x112 crop. The name looks like a
    misspelling of ``center_crop`` and this helper is not called in the visible
    cells -- NOTE(review): consider removing it. Note that ``center_crop`` starts
    at column 29 on a 171-wide frame, while the column offset here is fixed at 30.

    Returns a new uint8 array (a copy), never a view of ``frame``.
    """
    row_window = slice(8, 120)
    col_window = slice(30, 142)
    patch = frame[row_window, col_window, :]
    as_array = np.array(patch)
    return as_array.astype(np.uint8)

In [3]:
def center_crop(frame, new_height=112, new_width=112):
    """Return the central ``new_height`` x ``new_width`` window of ``frame``.

    Parameters
    ----------
    frame : numpy.ndarray
        Image indexed as ``frame[row, col, ...]``; both 2-D (grayscale) and
        3-D (H, W, C) arrays are accepted.
    new_height, new_width : int, optional
        Crop size; defaults to 112x112.

    Returns
    -------
    numpy.ndarray
        A view into ``frame`` (not a copy). When the leftover margin is odd the
        extra pixel goes to the bottom/right side, e.g. a 171-wide frame is
        cropped at columns 29..140.

    Raises
    ------
    ValueError
        If the frame is smaller than the requested crop. Without this check the
        negative offsets would silently wrap around and return a wrong-size slice.
    """
    # shape[:2] works for both grayscale (H, W) and colour (H, W, C) frames.
    height, width = frame.shape[:2]
    if new_height > height or new_width > width:
        raise ValueError(
            "cannot crop %dx%d from a %dx%d frame"
            % (new_height, new_width, height, width))
    top = (height - new_height) // 2
    left = (width - new_width) // 2
    return frame[top:top + new_height, left:left + new_width]

def _load_model(checkpoint_path, num_classes, device):
    """Build a C3D network, load ``checkpoint['state_dict']`` into it, move it to ``device`` and set eval mode."""
    model = C3D_model.C3D(num_classes=num_classes)
    # map_location lets a checkpoint saved on a GPU be loaded when only the CPU
    # is available (device falls back to "cpu" in inference()).
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()
    return model


def _preprocess_frame(frame):
    """Resize a frame to 171x128, centre-crop it to 112x112 and subtract the per-channel mean.

    The means are applied in the channel order of the frame as returned by
    ``cv2.VideoCapture.read`` (BGR). Returns a float64 array.
    """
    cropped = center_crop(cv2.resize(frame, (171, 128)))
    return cropped - np.array([[[90.0, 98.0, 102.0]]])


def _predict(model, clip, device):
    """Classify a list of preprocessed frames.

    Returns ``(label, prob)``: the top-1 class index and its softmax probability.
    """
    # Stack to (T, H, W, C), add a batch axis and move channels first:
    # (1, T, H, W, C) -> (1, C, T, H, W).
    inputs = np.array(clip).astype(np.float32)
    inputs = np.expand_dims(inputs, axis=0)
    inputs = np.transpose(inputs, (0, 4, 1, 2, 3))
    inputs = torch.from_numpy(inputs).to(device)

    with torch.no_grad():
        outputs = model(inputs)  # __call__ rather than .forward() so hooks run

    probs = torch.nn.functional.softmax(outputs, dim=1)
    top_prob, top_label = torch.max(probs, 1)
    return int(top_label[0]), float(top_prob[0])


def inference(video='./data/testvideo_2.avi',
              labels_path='./data/labels.txt',
              checkpoint_path='./model_resule/models/C3D_epoch-14.pth.tar',
              num_classes=101):
    """Run the C3D classifier over a video and show the predictions live.

    Each frame is preprocessed and appended to a sliding window of 16 frames.
    Once the window is full, every new frame triggers a prediction over the
    window. The top-1 class name and its probability are drawn on the frame,
    and the window then drops its oldest frame. Press the space bar to stop early.

    Parameters
    ----------
    video : str
        Path of the video file to read with ``cv2.VideoCapture``.
    labels_path : str
        Text file with one class per line. Lines are presumably
        ``"<id> <name>"``; the last space-separated token is displayed.
    checkpoint_path : str
        Checkpoint whose ``'state_dict'`` entry holds the model weights.
        NOTE(review): the 'model_resule' directory name is kept from the
        original; verify it is not a typo.
    num_classes : int
        Number of output classes of the C3D model.

    Raises
    ------
    IOError
        If the video cannot be opened.
    """
    clip_len = 16  # number of frames per model input
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load the class labels (the with-block closes the file).
    with open(labels_path, 'r') as f:
        class_names = f.readlines()

    model = _load_model(checkpoint_path, num_classes, device)

    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        raise IOError("Cannot open video: %s" % video)

    clip = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret or frame is None:
                break  # end of the stream

            clip.append(_preprocess_frame(frame))

            if len(clip) == clip_len:
                label, prob = _predict(model, clip, device)
                # Draw the prediction onto the displayed frame.
                cv2.putText(frame, class_names[label].split(' ')[-1].strip(), (20, 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 1)
                cv2.putText(frame, "prob: %.4f" % prob, (20, 40),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 1)
                clip.pop(0)  # slide the window by one frame

            cv2.imshow('C3D_model', frame)
            if cv2.waitKey(10) & 0xFF == ord(' '):  # space bar stops playback
                break
    finally:
        # Release the capture and close the window even if an error occurs.
        cap.release()
        cv2.destroyAllWindows()

In [4]:
# Launch the live demo. In a Jupyter/IPython kernel __name__ is also "__main__",
# so running this cell starts it too.
if __name__ == "__main__":
    inference()