In [4]:
import cv2  # 代入OpenCV模块
from PIL import Image
import numpy as np
import torch
import os

from getBBox import get_BBox
from face_alignment.detection.sfd.sfd_detector import SFDDetector

# ---- Global settings ----
video_path = '/home/wenchi/zxy/HSD/jjk_video_2.mp4'  # input video to process
extract_root_path = '/data1/wc_log/zxy/custom_dataset'  # root directory that receives the cropped frame images
extract_frequency = 6  # keep one frame out of every `extract_frequency` decoded frames
batch_size = 2  # forwarded to get_BBox as its batch_size argument

# Face detector: the SFD detector from the face_alignment package, created on the GPU.
face_detector = SFDDetector(device='cuda')

In [2]:
def extract_frames(video_path):
    """Decode a video and return every `extract_frequency`-th frame.

    Frame 0 is always kept, then frames `extract_frequency`, `2 * extract_frequency`, ...
    The sampling step is read from the module-level `extract_frequency` setting.

    Args:
        video_path: Path of the video file handed to `cv2.VideoCapture`.

    Returns:
        A list with the selected frames, exactly as produced by
        `cv2.VideoCapture.read` (OpenCV decodes to BGR channel order).
        The list is empty when the video cannot be opened or has no frames.
    """
    video = cv2.VideoCapture(video_path)
    frame_list = []
    frame_count = 0

    try:
        # Walk through the stream frame by frame until the decoder runs dry.
        while True:
            ok, frame = video.read()
            if not ok or frame is None:
                break
            # Keep only the frames that fall on the sampling grid.
            if frame_count % extract_frequency == 0:
                frame_list.append(frame)
            frame_count += 1  # number of frames decoded so far
    finally:
        # Always free the capture handle (file descriptor / decoder state),
        # even if decoding raised part-way through.
        video.release()

    return frame_list


def crop_image(bbox, img, reshape_size = 512):
    """Cut a square window around a bounding box and resize it.

    The window is centred on the centre of `bbox` and its side length is
    twice the box height. Any part of the window that falls outside the
    image is left black (zero-filled); a window that does not touch the
    image at all therefore yields an all-black result.

    Args:
        bbox: Indexable with at least four entries `(x1, y1, x2, y2)`;
            only the first four are read.
        img: Image array of shape `(H, W, C)`. The crop canvas is created
            as 3-channel uint8, so `img` is expected to be 3-channel.
        reshape_size: Side length in pixels of the returned square image.

    Returns:
        Array of shape `(reshape_size, reshape_size, 3)` holding the
        resized crop, with the same channel order as `img`.

    Raises:
        ValueError: If the box height truncates to zero or is negative,
            so that no crop window can be built.
    """
    h, w, _ = img.shape
    x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]

    # Centre of the box (truncated to whole pixels).
    center_x = int((x1 + x2) / 2)
    center_y = int((y1 + y2) / 2)
    expand_size = int(y2 - y1)  # half of the total crop size
    crop_size = expand_size * 2
    if crop_size <= 0:
        raise ValueError(
            "bounding box has non-positive height, cannot crop: {}".format(
                [x1, y1, x2, y2]))

    # Crop window in image coordinates; it may stick out of the image.
    new_x1 = center_x - expand_size
    new_y1 = center_y - expand_size
    new_x2 = new_x1 + crop_size
    new_y2 = new_y1 + crop_size

    # Part of the window that actually lies inside the image.
    origin_left, origin_right = max(new_x1, 0), min(new_x2, w)
    origin_top, origin_bottom = max(new_y1, 0), min(new_y2, h)

    aligned_img = np.zeros((crop_size, crop_size, 3), dtype=np.uint8)
    # Paste only when the window overlaps the image. Without this guard a
    # window lying fully outside would produce negative slice bounds, which
    # wrap around and trigger a shape-mismatch error.
    if origin_left < origin_right and origin_top < origin_bottom:
        # Same region expressed in canvas coordinates (shifted by the window origin).
        crop_left, crop_right = origin_left - new_x1, origin_right - new_x1
        crop_top, crop_bottom = origin_top - new_y1, origin_bottom - new_y1
        aligned_img[crop_top:crop_bottom, crop_left:crop_right] = \
            img[origin_top:origin_bottom, origin_left:origin_right]

    aligned_img = Image.fromarray(aligned_img)
    aligned_img = aligned_img.resize((reshape_size, reshape_size))
    aligned_img = np.asarray(aligned_img)
    return aligned_img

def save_croped_images(frame_list, bboxlist, save_path):
    """Crop every frame around its bounding box and write the crops as PNG files.

    Frames and boxes are paired by position; pairing stops at the shorter
    of the two sequences. Frames whose box is `None` or empty are skipped.
    Output files are numbered consecutively over the crops that are
    actually written (`00000000.png`, `00000001.png`, ...), so the file
    number is not the frame index when frames are skipped.

    Args:
        frame_list: Sequence of image arrays accepted by `crop_image`.
        bboxlist: Sequence of bounding boxes, one per frame.
        save_path: Output directory; created if it does not exist.

    Raises:
        IOError: If `cv2.imwrite` reports that a file could not be written.
    """
    os.makedirs(save_path, exist_ok=True)
    index = 0  # number of crops written so far, also used as file name

    for frame, bbox in zip(frame_list, bboxlist):
        if bbox is None or len(bbox) == 0:
            continue

        cropped = crop_image(bbox, frame)
        # Zero-padded, 8-digit running number as file name.
        image_save_path = os.path.join(save_path, "{}.png".format(str(index).zfill(8)))
        # cv2.imwrite signals failure through its return value instead of
        # raising; surface it so missing images do not go unnoticed.
        if not cv2.imwrite(image_save_path, cropped):
            raise IOError("Failed to write image: {}".format(image_save_path))
        index += 1



In [3]:
# Decode the sampled frames of the configured video.
frame_list = extract_frames(video_path)
if not frame_list:
    # An unreadable path makes extract_frames return an empty list; stop here
    # with a clear message instead of failing later in transpose().
    raise RuntimeError(
        "No frames could be read from {!r}; check that the file exists "
        "and can be decoded.".format(video_path))

# Stack to (ALL, H, W, 3) and reverse the channel axis (OpenCV's BGR -> RGB).
# copy() materialises the reversed view; presumably needed because the next
# cell wraps the array with torch.from_numpy, which rejects negative strides.
imgs_numpy = np.asarray(frame_list)[..., ::-1].copy()
imgs_numpy = imgs_numpy.transpose(0, 3, 1, 2) # (ALL, 3, H, W)

In [5]:
# Detect one face bounding box per sampled frame.
# The full frame stack is first wrapped as a tensor and moved to the GPU,
# then handed to get_BBox together with the detector.
imgs_tensor = torch.from_numpy(imgs_numpy)
imgs_tensor = imgs_tensor.cuda()
bboxlist = get_BBox(
    imgs_tensor,
    face_detector,
    batch_size=batch_size,
)  # per the original note: (ALL, 4) -- confirm against getBBox.py

In [22]:
# Name the output folder after the video file, without directory and extension.
# splitext removes only the final extension, so "clip.v2.mp4" becomes "clip.v2"
# and no longer collapses onto the same folder as "clip.v1.mp4".
save_name = os.path.basename(video_path)
save_name = os.path.splitext(save_name)[0]

save_path = os.path.join(extract_root_path, save_name)

# Crop every frame around its detected face and write the crops to save_path.
save_croped_images(frame_list, bboxlist, save_path)
