# Pose Estimation

In [None]:
!git clone https://www.github.com/WongKinYiu/yolov7.git

In [None]:
import cv2
import numpy as np
import os
import sys
import tqdm
import h5py
import gc

## utils_preprocess

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from os import listdir
from os.path import isfile, join
import cv2 as cv
import torch
from torchvision import transforms
import sys

sys.path.insert(1, 'yolov7/') 

from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt
from utils.plots import output_to_keypoint, plot_skeleton_kpts

In [None]:
def get_file_names_from_dir(dir: str):
    """Return the paths of all regular files directly inside ``dir``.

    Sub-directories are skipped. Entries come back in ``os.listdir`` order,
    which is arbitrary.

    Args:
        dir: Directory to scan. A trailing path separator is optional.

    Returns:
        A list of paths, each formed by joining ``dir`` with a file name.
    """
    # Build each path with join() -- the same way the isfile() test does --
    # so a directory given without a trailing slash still yields valid paths
    # (plain concatenation produced e.g. 'data/train/NonFightvideo.avi').
    candidates = (join(dir, name) for name in listdir(dir))
    return [path for path in candidates if isfile(path)]


def read_video_from_file(file: str):
    """Open ``file`` with OpenCV and hand back the ``cv.VideoCapture``.

    No validation is done here: callers can check ``isOpened()`` to detect an
    unreadable path and should call ``release()`` when finished.
    """
    capture = cv.VideoCapture(file)
    return capture


def play_video(video):
    """Show ``video`` frame by frame until it ends or 'q' is pressed.

    Args:
        video: An opened ``cv.VideoCapture`` (e.g. from ``read_video_from_file``).

    On every exit path (end of stream, 'q', or an exception) the window is
    destroyed and the capture is released.
    """
    video_window = 'video_window'
    cv.namedWindow(video_window)
    try:
        while True:
            ret, frame = video.read()  # read a single frame
            if not ret:
                print("EOV or Could not read the frame")
                break
            cv.imshow(video_window, frame)
            # waitKey(1) also pumps the GUI event loop so the frame is drawn.
            if (cv.waitKey(1) & 0xFF) == ord('q'):
                break
    finally:
        # Previously the capture was released only on 'q'; release it at
        # end-of-stream too so the underlying file/device is not leaked.
        cv.destroyWindow(video_window)
        video.release()


def break_video_into_frames(video, p=False):
    """Read every remaining frame from ``video`` into a list.

    Args:
        video: Object with an OpenCV-style ``read()`` returning
            ``(success, frame)``, e.g. a ``cv.VideoCapture``.
        p: If True, print how many frames were read.

    Returns:
        The frames in read order. Reading stops at the first unsuccessful
        ``read()``; the placeholder frame returned by that failed read is not
        included (previously a trailing ``None`` was appended, which inflated
        the count and crashed ``cv.imshow`` in the playback helpers).
    """
    frames = []
    while True:
        success, frame = video.read()
        if not success:
            break
        frames.append(frame)
    if p:
        print(f'Readed {len(frames)} frames.')
    return frames


def display_frame(frame):
    """Show one ``frame`` in a window titled 'Frame:' and block for a key press.

    Every OpenCV window is closed once a key has been pressed.
    """
    window_title = 'Frame:'
    cv.imshow(window_title, frame)
    cv.waitKey(0)  # a delay of 0 waits indefinitely for a key
    cv.destroyAllWindows()


def play_frames(frames):
    """Play a sequence of frames once in an OpenCV window.

    Press 'q' to stop early. ``None`` entries (e.g. placeholders from a failed
    read) are skipped because ``cv.imshow`` raises on them. The window is
    destroyed on every exit path, including normal end of playback.

    Args:
        frames: Iterable of images accepted by ``cv.imshow`` (or ``None``).
    """
    video_window = 'video_window'
    cv.namedWindow(video_window)
    try:
        for f in frames:
            if f is None:
                continue
            cv.imshow(video_window, f)
            if (cv.waitKey(1) & 0xFF) == ord('q'):
                break
    finally:
        cv.destroyWindow(video_window)


def loop_frames(frames):
    """Play ``frames`` in an endless loop in an OpenCV window until 'q' is pressed.

    ``None`` entries are skipped. If no displayable frame remains, the
    function returns immediately without opening a window.

    The previous index arithmetic never displayed the last one or two
    entries, exited after one pass for two frames, showed nothing for one
    frame and raised IndexError for an empty list.

    Args:
        frames: Sequence of images accepted by ``cv.imshow`` (or ``None``).
    """
    playable = [f for f in frames if f is not None]
    if not playable:
        return
    video_window = 'video_window'
    cv.namedWindow(video_window)
    i = 0
    try:
        while True:
            cv.imshow(video_window, playable[i])
            i = (i + 1) % len(playable)  # wrap around to loop forever
            if (cv.waitKey(1) & 0xFF) == ord('q'):
                break
    finally:
        cv.destroyWindow(video_window)


def load_model(model_path):
    """Load a pickled YOLOv7 checkpoint and return its model ready for inference.

    Args:
        model_path: Path to a ``torch.save`` checkpoint whose ``'model'`` entry
            is a full ``nn.Module`` (e.g. ``yolov7-w6-pose.pt``).

    Returns:
        The module in eval mode. If CUDA is available it is on ``cuda:0`` in
        float16, otherwise it is on CPU in float32.
    """
    import inspect

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The checkpoint stores a pickled nn.Module, not just a state_dict, so it
    # must be fully unpickled. torch >= 2.6 defaults to weights_only=True,
    # which rejects such checkpoints, so request full unpickling explicitly
    # where the argument exists (older torch versions lack it).
    # SECURITY: unpickling can execute arbitrary code -- only load trusted files.
    load_kwargs = {'map_location': device}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(model_path, **load_kwargs)['model']
    # Put in inference mode
    model.float().eval()

    if torch.cuda.is_available():
        print('using gpu')
        # half() turns the weights (and hence predictions) into float16,
        # which significantly lowers inference time on GPU
        model.half().to(device)
    return model


def run_inference(frame, model):
    """Run the pose model on a single frame as read by OpenCV.

    Args:
        frame: HxWx3 uint8 image array (OpenCV channel order; not reordered here).
        model: Model returned by ``load_model``.

    Returns:
        ``(output, frame_unsqueeze)``:
        ``output`` is the raw prediction tensor, before NMS.
        ``frame_unsqueeze`` is the letterboxed input batch of shape
        ``(1, 3, H', W')`` with values in [0, 1] that was fed to the model.
    """
    # Resize (longest side 960) and pad to a multiple of the model stride (64).
    frame_letter = letterbox(frame, 960, stride=64, auto=True)[0]
    # ToTensor: HWC uint8 -> CHW float in [0, 1]. Then match the model's device
    # and dtype: load_model only converts weights to float16 on GPU, so an
    # unconditional .half() failed on CPU with a dtype mismatch.
    param = next(model.parameters())
    frame_tensor = transforms.ToTensor()(frame_letter).to(device=param.device, dtype=param.dtype)
    # Turn image into a batch of one
    frame_unsqueeze = frame_tensor.unsqueeze(0)
    with torch.no_grad():
        output, _ = model(frame_unsqueeze)
    return output, frame_unsqueeze


def visualize_output(output, image, save_path, model, show=False):
    """Draw predicted skeletons on the model input image and save the figure.

    Args:
        output: 2-D keypoint array (e.g. from ``output_to_keypoint``), one row
            per detection; keypoint values start at column 7.
        image: Batch tensor ``(1, 3, H, W)`` with values in [0, 1], as returned
            by ``run_inference``; only the first image is drawn.
        save_path: File the rendered figure is written to.
        model: Unused; kept for backward compatibility with existing callers.
        show: If True, also display the figure with ``plt.show()``.
    """
    nimg = image[0].permute(1, 2, 0) * 255  # CHW [0, 1] -> HWC [0, 255]
    nimg = nimg.cpu().numpy().astype(np.uint8)
    # Swap the first and third channels for matplotlib. run_inference does not
    # reorder OpenCV's channel order, so this turns the frame into RGB.
    nimg = cv.cvtColor(nimg, cv.COLOR_RGB2BGR)
    print(f'the output shape is: {output.shape}')
    for idx in range(output.shape[0]):
        plot_skeleton_kpts(nimg, output[idx, 7:].T, 3)
    # Create one figure per call and always close it. pyplot keeps every
    # figure alive until it is closed, so calling this once per video frame
    # used to accumulate figures (and memory) without bound.
    fig = plt.figure(figsize=(12, 12))
    try:
        plt.axis('off')
        plt.imshow(nimg)
        fig.savefig(save_path)
        if show:
            plt.show()
    finally:
        plt.close(fig)

    

def supress_kpt(unfiltered_output, model, conf_thres=0.25, iou_thres=0.65):
    """Filter raw pose predictions with YOLOv7's keypoint-aware NMS.

    Passes ``unfiltered_output`` to ``non_max_suppression_kpt`` with the given
    confidence and IoU thresholds. The class and keypoint counts are read
    from ``model.yaml``.

    Args:
        unfiltered_output: Raw prediction tensor from ``run_inference``.
        model: Loaded model; its ``yaml`` dict must provide ``'nc'`` and ``'nkpt'``.
        conf_thres: Confidence threshold (default 0.25, the previously
            hard-coded value).
        iou_thres: IoU threshold (default 0.65, the previously hard-coded value).

    Returns:
        Whatever ``non_max_suppression_kpt`` returns. In upstream YOLOv7 this is
        presumably one detection tensor per image; confirm against the
        vendored ``utils.general``.
    """
    output = non_max_suppression_kpt(unfiltered_output,
                conf_thres, # Confidence Threshold
                iou_thres, # IoU Threshold
                nc=model.yaml['nc'], # Number of Classes
                nkpt=model.yaml['nkpt'], # Number of Keypoints
                kpt_label=True)
    return output

## main

In [None]:
VIOLENT_PATH = 'data/train/Fight/'
# Trailing slash for consistency with VIOLENT_PATH (file paths are built by
# joining this directory with each file name).
NON_VIOLENT_PATH = 'data/train/NonFight/'
IMAGES_OUT_FOLDER = 'poses_visualsition/'
H5_OUT_PATH = 'data/poses_pickle.h5'
POSE_ESTIMATION_MODEL_PATH = 'yolov7/yolov7-w6-pose.pt'


def _process_video(video_file, class_name, images_out_folder_class, model, hf):
    """Run pose estimation on every frame of one video and store the results.

    For each readable frame this does two things:
    - writes a skeleton visualisation PNG named ``{video_name}_{frame_num}.png``
      into ``images_out_folder_class``;
    - stores the keypoint array as the HDF5 dataset
      ``{class_name}/{video_name}/{frame_num}`` in the open file ``hf``.
    """
    video = read_video_from_file(video_file)
    try:
        video_frames = break_video_into_frames(video, True)
    finally:
        # All frames are in memory now; free the capture.
        video.release()

    video_name = os.path.splitext(os.path.basename(video_file))[0]
    frame_num = 0
    for frame in video_frames:
        # Skip failed reads without consuming a frame number.
        if frame is None:
            continue
        # create a unique name for current frame
        image_name = f'{video_name}_{frame_num}.png'
        visualize_plot_path = os.path.join(images_out_folder_class, image_name)

        # run pose estimation on frame
        output, image = run_inference(frame, model)
        # filter low-confidence / overlapping detections
        supressed_output = supress_kpt(output, model)
        with torch.no_grad():
            keypoints = output_to_keypoint(supressed_output)

        # plot the pose estimation on top of the frame
        visualize_output(keypoints, image, visualize_plot_path, model, show=False)
        torch.cuda.empty_cache()
        gc.collect()

        # store this frame's keypoints in the HDF5 file
        hf.create_dataset(f"{class_name}/{video_name}/{frame_num}", data=keypoints)
        frame_num = frame_num + 1


if __name__ == "__main__":
    with h5py.File(H5_OUT_PATH, 'w') as hf:
        # NOTE(review): 'non_viloent' is misspelled, but it is used both as an
        # output folder name and as an HDF5 group key. Renaming it may break
        # code that reads these outputs, so confirm before changing it.
        classes_paths = {'violent': VIOLENT_PATH, 'non_viloent': NON_VIOLENT_PATH}

        pose_estimation_model = load_model(POSE_ESTIMATION_MODEL_PATH)

        # Dictionary keys are used as class names for output folders and HDF5 groups.
        for class_name, class_path in classes_paths.items():
            images_out_folder_class = os.path.join(IMAGES_OUT_FOLDER, class_name)
            os.makedirs(images_out_folder_class, exist_ok=True)

            class_videos_files = get_file_names_from_dir(class_path)
            for video_file in tqdm.tqdm(class_videos_files, position=0):
                _process_video(video_file, class_name, images_out_folder_class,
                               pose_estimation_model, hf)