In [None]:
# Mount Google Drive so the module code, the test examples and the results
# folder under /content/drive/MyDrive are reachable from this runtime.
# With force_remount=False an already-mounted drive is left untouched.
import google.colab.drive as drive

drive.mount('/content/drive', force_remount=False)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import numpy as np
import cv2
import os
import sys
import time

import torch
from torch import nn

# Root folder of the keypoint module (models/, utils/, test_examples/, results/).
module_dir = '/content/drive/MyDrive/pose_estimation/keypoint_module/'

# Make the module's packages importable. Guarded so that re-running this cell
# does not keep appending duplicate entries to sys.path.
if module_dir not in sys.path:
  sys.path.append(module_dir)

from models.yolo import *
from models.hrnet import *
from utils.detector import *

In [None]:
def predict(file_path, pred_path, module_dir, draw_bbox=False, box_tr=0.7,
            yolov5=None, keypoint_net=None, device=None):
  """Run the YOLO model and the keypoint network on an image or a video file.

  The input type is chosen from the (case-insensitive) extension of
  `file_path`. Images are handled by `predict_image` and videos by
  `predict_video`.

  Args:
    file_path: absolute path to the input file.
    pred_path: absolute path where the prediction should be written.
    module_dir: path of the module folder; passed to the model loaders.
    draw_bbox: whether to draw bounding boxes.
    box_tr: threshold for bounding-box confidence.
    yolov5: optional pre-loaded detector. When None it is loaded from
      `module_dir` and moved to `device`. A model passed in is used as given,
      so it must already live on `device`.
    keypoint_net: optional pre-loaded keypoint network. Same rules as `yolov5`.
    device: optional torch.device. Defaults to cuda:0 when available,
      otherwise cpu.

  Returns:
    The path returned by `predict_image` / `predict_video`. For an
    unsupported extension, `pred_path` is returned unchanged.
  """
  image_formats = {'.jpg', '.png', '.jpeg', '.bmp'}
  video_formats = {'.mp4', '.mov', '.avi', '.webm', '.mkv', '.m4v'}

  # splitext only inspects the last path component: dots in directory names
  # are ignored, and a file without an extension yields '' instead of raising.
  file_format = os.path.splitext(file_path)[1].lower()

  if file_format in image_formats:
    predictor = predict_image
  elif file_format in video_formats:
    predictor = predict_video
  else:
    # Bail out before spending time loading model weights for a file we cannot process.
    print('Unknown file format')
    return pred_path

  if device is None:
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  if yolov5 is None:
    yolov5 = load_yolo_model(module_dir).to(device)
  if keypoint_net is None:
    keypoint_net = load_keypoint_net(module_dir).to(device)

  return predictor(file_path, pred_path, yolov5, keypoint_net, device,
                   draw_bbox=draw_bbox, box_tr=box_tr)

In [None]:
# Input examples and prediction outputs both live inside the module folder.
# Deriving them from module_dir keeps a single source of truth for the root
# path. Trailing slashes are kept so plain string concatenation still works.
files_dir = os.path.join(module_dir, 'test_examples/')
result_dir = os.path.join(module_dir, 'results/')

In [None]:
# Run predict() on every file in files_dir and write results to result_dir.
# Make sure the output folder exists on a fresh runtime.
os.makedirs(result_dir, exist_ok=True)

# sorted() makes the processing order deterministic across runs.
for filename in sorted(os.listdir(files_dir)):
  file_path = os.path.join(files_dir, filename)
  if not os.path.isfile(file_path):
    continue  # skip sub-directories and other non-regular entries

  start = time.perf_counter()
  # splitext handles names without an extension, where rindex('.') would raise.
  stem = os.path.splitext(filename)[0]
  pred_path = os.path.join(result_dir, stem + '_predict.webm')
  pred_path = predict(file_path, pred_path, module_dir)
  print(pred_path, '--  ready, time:', round(time.perf_counter() - start), 's')


/content/drive/MyDrive/pose_estimation/keypoint_module/results/hands1_predict.webm --  ready, time: 40 s
/content/drive/MyDrive/pose_estimation/keypoint_module/results/hands2_predict.webm --  ready, time: 14 s
/content/drive/MyDrive/pose_estimation/keypoint_module/results/hands5_predict.webm --  ready, time: 16 s
/content/drive/MyDrive/pose_estimation/keypoint_module/results/hands6_predict.webm --  ready, time: 18 s
/content/drive/MyDrive/pose_estimation/keypoint_module/results/hands7_predict.webm --  ready, time: 19 s
/content/drive/MyDrive/pose_estimation/keypoint_module/results/hands10_predict.webm --  ready, time: 21 s
/content/drive/MyDrive/pose_estimation/keypoint_module/results/hands11_predict.webm --  ready, time: 19 s
/content/drive/MyDrive/pose_estimation/keypoint_module/results/hands12_predict.webm --  ready, time: 16 s
/content/drive/MyDrive/pose_estimation/keypoint_module/results/hands14_predict.webm --  ready, time: 22 s
/content/drive/MyDrive/pose_estimation/keypoint_mod