In [1]:
# define directories and imports
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from pathlib import Path
import os
import warnings
warnings.filterwarnings('ignore')
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

# Location of the bundled OpenPose (pose estimation) sources.
openpose_dir = Path('./src/PoseEstimation/')

# Working directory for the target video and the frames extracted from it.
save_dir = Path('./data/target/')
img_dir = save_dir.joinpath('images')
# parents=True so a fresh checkout without ./data does not raise
# FileNotFoundError; creating img_dir also creates save_dir.
img_dir.mkdir(parents=True, exist_ok=True)

In [7]:
# Extract .png frames from the target video into img_dir as 00000.png, 00001.png, ...
# Only the first `max_frames` frames are kept (fewer if the video is shorter).
max_frames = 125
video_path = save_dir.joinpath('IMG_1501.mp4')

cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
    # Without this check a missing/unreadable video silently yields zero frames.
    raise IOError('could not open video: {}'.format(video_path))

i = 0
try:
    while i < max_frames:
        flag, frame = cap.read()
        if not flag:
            # end of stream (or a decode failure): stop extracting
            break

        cv2.imwrite(str(img_dir.joinpath(f'{i:05d}.png')), frame)
        i += 1
finally:
    # Always free the decoder / file handle, even if writing a frame fails.
    cap.release()

In [10]:
# Load the pretrained VGG19 pose-estimation network used to label the frames.
import sys

# Make the bundled OpenPose code and the project utilities importable.
for extra_path in (str(openpose_dir), './src/utils'):
    sys.path.append(extra_path)

# openpose
from rtpose_vgg import get_model
from coco_eval import get_multiplier, get_outputs
# utils
from openpose_utils import remove_noise, get_pose

weight_name = './src/PoseEstimation/network/weight/pose_model.pth'
print('load model...')
model = get_model('vgg19')
# Restore the pretrained weights before wrapping the network.
state_dict = torch.load(weight_name)
model.load_state_dict(state_dict)
# Wrap in DataParallel, move to the GPU, cast to float32 and switch to
# inference mode (each of these nn.Module calls returns the module itself).
model = torch.nn.DataParallel(model).cuda().float().eval()

# Directory layout written by this cell (everything lives under save_dir):
#   images/             frames extracted from the target video
#   train/train_img/    cropped + resized frames
#   train/train_label/  pose label images
#   train/head_img/     head crop previews
save_dir = Path('./data/target/')
img_dir = save_dir.joinpath('images')

# make label images for pix2pix
train_dir = save_dir.joinpath('train')
train_img_dir = train_dir.joinpath('train_img')
train_label_dir = train_dir.joinpath('train_label')
train_head_dir = train_dir.joinpath('head_img')

# parents=True creates save_dir / train_dir (and ./data) on the way, so the
# cell no longer fails with FileNotFoundError on a fresh checkout.
for directory in (img_dir, train_img_dir, train_label_dir, train_head_dir):
    directory.mkdir(parents=True, exist_ok=True)

# For every extracted frame: crop + resize it to 512x512, run the pose network,
# write the frame and its label image as a pix2pix training pair, save a head
# crop preview, and record the head coordinate. The coordinates of all frames
# are finally stored in save_dir/pose.npy.

# Loop-invariant settings, hoisted out of the loop.
param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}  # thresholds handed to get_pose
index = 13      # entry of `cord` that is used as the head position
crop_size = 25  # half side length (in pixels) of the head crop

pose_cords = []
# NOTE(review): the `- 1` means the highest-numbered frame is never processed,
# and any stray non-frame file in img_dir shifts the count — confirm intent.
for idx in tqdm(range(len(os.listdir(str(img_dir)))-1)):
    img_path = img_dir.joinpath('{:05}.png'.format(idx))
    img = cv2.imread(str(img_path))
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError('could not read frame: {}'.format(img_path))

    # Offsets of the centred square whose side is the shorter image dimension.
    shape_dst = np.min(img.shape[:2])
    oh = (img.shape[0] - shape_dst) // 2
    ow = (img.shape[1] - shape_dst) // 2

    # NOTE(review): the +200 extends the crop past the centred square (clipped at
    # the image border), so the crop may be non-square before the resize below
    # — confirm this is intended.
    img = img[oh:oh + shape_dst+200, ow:ow + shape_dst+200]
    img = cv2.resize(img, (512, 512))

    multiplier = get_multiplier(img)
    with torch.no_grad():
        paf, heatmap = get_outputs(multiplier, img, model, 'rtpose')
    # Denoise every heatmap channel except the last one, then write them back.
    r_heatmap = np.array([remove_noise(ht)
                          for ht in heatmap.transpose(2, 0, 1)[:-1]]).transpose(1, 2, 0)
    heatmap[:, :, :-1] = r_heatmap

    # get_pose
    label, cord = get_pose(param, heatmap, paf)
    try:
        head_cord = cord[index]
    except (LookupError, TypeError):
        # No head point in this frame: reuse the one from the previous frame.
        if not pose_cords:
            raise RuntimeError(
                'no head point found in the first frame ({}), '
                'so there is no previous frame to fall back on'.format(img_path))
        head_cord = pose_cords[-1]

    pose_cords.append(head_cord)

    # Clamp the lower bounds at 0: a negative slice start would wrap around to
    # the other side of the image and produce an empty crop.
    top = max(int(head_cord[1] - crop_size), 0)
    left = max(int(head_cord[0] - crop_size), 0)
    head = img[top: int(head_cord[1] + crop_size),
               left: int(head_cord[0] + crop_size), :]
    # OpenCV images are BGR while matplotlib expects RGB; flip the channel axis
    # so the saved preview has the right colours.
    plt.imshow(head[:, :, ::-1])
    plt.savefig(str(train_head_dir.joinpath('pose_{}.jpg'.format(idx))))
    plt.clf()
    cv2.imwrite(str(train_img_dir.joinpath('{:05}.png'.format(idx))), img)
    cv2.imwrite(str(train_label_dir.joinpath('{:05}.png'.format(idx))), label)

# Close the now-empty figure so the notebook does not render it as cell output.
plt.close()

# `np.int` was removed in NumPy 1.24; it was only an alias for the builtin int.
pose_cords = np.array(pose_cords, dtype=int)
np.save(str((save_dir.joinpath('pose.npy'))), pose_cords)
torch.cuda.empty_cache()


load model...
Bulding VGG19




  0%|          | 0/124 [00:00<?, ?it/s][A[A

  1%|          | 1/124 [00:12<25:39, 12.51s/it][A[A

  2%|▏         | 2/124 [00:24<25:08, 12.37s/it][A[A

  2%|▏         | 3/124 [00:36<24:24, 12.10s/it][A[A

  3%|▎         | 4/124 [00:47<24:03, 12.03s/it][A[A

  4%|▍         | 5/124 [00:59<23:21, 11.78s/it][A[A

  5%|▍         | 6/124 [01:10<22:57, 11.67s/it][A[A

  6%|▌         | 7/124 [01:22<22:43, 11.66s/it][A[A

  6%|▋         | 8/124 [01:33<22:26, 11.61s/it][A[A

  7%|▋         | 9/124 [01:45<22:13, 11.59s/it][A[A

  8%|▊         | 10/124 [01:56<22:00, 11.59s/it][A[A

  9%|▉         | 11/124 [02:08<21:42, 11.52s/it][A[A

 10%|▉         | 12/124 [02:19<21:33, 11.55s/it][A[A

 10%|█         | 13/124 [02:31<21:29, 11.61s/it][A[A

 11%|█▏        | 14/124 [02:42<21:10, 11.55s/it][A[A

 12%|█▏        | 15/124 [02:54<21:07, 11.63s/it][A[A

 13%|█▎        | 16/124 [03:05<20:38, 11.47s/it][A[A

 14%|█▎        | 17/124 [03:17<20:35, 11.55s/it][A[A

 15%|█▍ 

<Figure size 432x288 with 0 Axes>