In [1]:
import cv2
from pytube import YouTube
from pathlib import Path

# Working directories: the downloaded video is stored in save_dir and the
# extracted frames in save_dir/images.
save_dir = Path('./data/source/')
# parents=True: on a fresh checkout ./data may not exist yet, and a plain
# mkdir() would raise FileNotFoundError.
save_dir.mkdir(parents=True, exist_ok=True)

img_dir = save_dir.joinpath('images')
img_dir.mkdir(exist_ok=True)

# Download the Bruno Mars music video from youtube
yt = YouTube('https://www.youtube.com/watch?v=PMivT7MJ41M')
yt.streams.first().download(save_dir, 'mv')

# Extract .png frames from the video
cap = cv2.VideoCapture(str(save_dir.joinpath('mv.mp4')))

# Skip the first 125 frames to cut straight to dancing: the counter starts
# negative and frames are only written once it reaches 0.
i = -125
try:
    while cap.isOpened():
        flag, frame = cap.read()
        # Stop at end of stream (read failed) or once frames 00000..00999
        # have been written.
        if not flag or i == 1000:  # each second is 25 frames
            break
        if i >= 0:
            cv2.imwrite(str(img_dir.joinpath(f'{i:05d}.png')), frame)
        i += 1
finally:
    # Always free the decoder/file handle, even if writing a frame fails.
    cap.release()

In [3]:
# use pose estimation to generate poses
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

import sys

# Directory holding the pose-estimation sources.
openpose_dir = Path('./src/PoseEstimation/')

# Make the pose-estimation package and the utils folder importable by
# appending both (in this order) to the module search path.
sys.path.extend([str(openpose_dir), './src/utils'])

In [6]:
# openpose
from coco_eval import get_multiplier, get_outputs
from rtpose_vgg import get_model
# utils
from openpose_utils import remove_noise, get_pose
import os

weight_name = './src/PoseEstimation/network/weight/pose_model.pth'

# Build the pose network, load its weights, wrap it for (multi-)GPU use and
# switch to inference mode.
model = get_model('vgg19')
model.load_state_dict(torch.load(weight_name))
model = torch.nn.DataParallel(model).cuda()
model.float()
model.eval()

# make label images for pix2pix
test_img_dir = save_dir.joinpath('test_img')
test_img_dir.mkdir(exist_ok=True)
test_label_dir = save_dir.joinpath('test_label_ori')
test_label_dir.mkdir(exist_ok=True)
test_head_dir = save_dir.joinpath('test_head_ori')
test_head_dir.mkdir(exist_ok=True)

pose_source_path = str(save_dir.joinpath('pose_source.npy'))

# Count only the extracted .png frames; stray files in the folder (e.g.
# checkpoints or OS metadata) must not inflate the frame count.
num_frames = len(list(img_dir.glob('*.png')))

pose_cords = []
for idx in tqdm(range(num_frames)):
    img_path = img_dir.joinpath('{:05}.png'.format(idx))
    img = cv2.imread(str(img_path))
    if img is None:
        # cv2.imread returns None instead of raising when a file is missing
        # or unreadable; fail with an explicit message.
        raise FileNotFoundError('Could not read frame {}'.format(img_path))

    # Centre-crop to a square using the shorter side, then resize to 512x512.
    shape_dst = np.min(img.shape[:2])
    oh = (img.shape[0] - shape_dst) // 2
    ow = (img.shape[1] - shape_dst) // 2

    img = img[oh:oh + shape_dst, ow:ow + shape_dst]
    img = cv2.resize(img, (512, 512))
    multiplier = get_multiplier(img)
    with torch.no_grad():
        paf, heatmap = get_outputs(multiplier, img, model, 'rtpose')
    # Apply remove_noise to every heatmap channel except the last one and
    # write the result back in place.
    r_heatmap = np.array([remove_noise(ht)
                          for ht in heatmap.transpose(2, 0, 1)[:-1]]) \
        .transpose(1, 2, 0)
    heatmap[:, :, :-1] = r_heatmap
    param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
    label, cord = get_pose(param, heatmap, paf)
    index = 13  # entry of `cord` used as the head point
    crop_size = 25  # half side length (pixels) of the saved head crop
    try:
        head_cord = cord[index]
    except Exception:
        # if there is not head point in picture, use last frame
        if not pose_cords:
            raise IndexError(
                'No head point found in frame {} and no previous frame '
                'to fall back on'.format(idx))
        head_cord = pose_cords[-1]

    pose_cords.append(head_cord)
    # Clamp the crop start to 0: a negative start index would wrap around
    # and produce an empty or wrong crop when the head is near the border.
    # The end index is clipped by slicing itself.
    top = max(int(head_cord[1] - crop_size), 0)
    bottom = max(int(head_cord[1] + crop_size), 0)
    left = max(int(head_cord[0] - crop_size), 0)
    right = max(int(head_cord[0] + crop_size), 0)
    head = img[top:bottom, left:right, :]
    # NOTE(review): img comes from cv2.imread (BGR) while plt.imshow expects
    # RGB, so the saved head crops likely have swapped colours -- confirm
    # whether that matters for these preview images.
    plt.imshow(head)
    plt.savefig(str(test_head_dir.joinpath('pose_{}.jpg'.format(idx))))
    plt.clf()
    cv2.imwrite(str(test_img_dir.joinpath('{:05}.png'.format(idx))), img)
    cv2.imwrite(str(test_label_dir.joinpath('{:05}.png'.format(idx))), label)
    if idx % 100 == 0 and idx != 0:
        # Periodic checkpoint so a crash does not lose all coordinates.
        # Builtin int replaces np.int, which was removed in NumPy 1.24.
        pose_cords_arr = np.array(pose_cords, dtype=int)
        np.save(pose_source_path, pose_cords_arr)

pose_cords_arr = np.array(pose_cords, dtype=int)
np.save(pose_source_path, pose_cords_arr)
torch.cuda.empty_cache()


Bulding VGG19


  init.normal(m.weight, std=0.01)
  init.constant(m.bias, 0.0)
  init.normal(self.model1_1[8].weight, std=0.01)
  init.normal(self.model1_2[8].weight, std=0.01)
  init.normal(self.model2_1[12].weight, std=0.01)
  init.normal(self.model3_1[12].weight, std=0.01)
  init.normal(self.model4_1[12].weight, std=0.01)
  init.normal(self.model5_1[12].weight, std=0.01)
  init.normal(self.model6_1[12].weight, std=0.01)
  init.normal(self.model2_2[12].weight, std=0.01)
  init.normal(self.model3_2[12].weight, std=0.01)
  init.normal(self.model4_2[12].weight, std=0.01)
  init.normal(self.model5_2[12].weight, std=0.01)
  init.normal(self.model6_2[12].weight, std=0.01)
100%|██████████| 1000/1000 [2:04:17<00:00,  7.36s/it] 


<Figure size 432x288 with 0 Axes>