In [1]:
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.framework.formats import landmark_pb2
from mediapipe.python.solutions import drawing_styles
import pandas as pd
import warnings
import urllib.request
warnings.filterwarnings("ignore") # Mainly for FutureWarning in pandas

In [2]:
# --- Model file ---------------------------------------------------------------
# The MediaPipe "pose landmarker heavy" model is expected to already be on disk.
# To fetch it, uncomment the three commented-out lines below and run this cell
# once.
# model_url = "https://storage.googleapis.com/mediapipe-models/pose_landmarker/pose_landmarker_heavy/float16/latest/pose_landmarker_heavy.task"
filename = "../models/pose_landmarker_heavy.task"
# urllib.request.urlretrieve(model_url, filename)
# print(f"Model downloaded and saved as {filename}")
model_path = filename # Path to model file

# --- Input / output -----------------------------------------------------------
video_source = '../videos/name_of_file.mp4' # Path to video file
# Prefix of the exported files: one CSV is written per detected pose slot as
# f'{output_name}_{i}_keypoints.csv', where i is the pose index.
output_name = '../keypoints/keypoints_file_name'

# --- Detection settings -------------------------------------------------------
num_poses = 4 # number of poses to detect; must be >= 1
# Confidence thresholds forwarded unchanged to PoseLandmarkerOptions below.
min_pose_detection_confidence = 0.5
min_pose_presence_confidence = 0.5
min_tracking_confidence = 0.5

# --- Landmark names -----------------------------------------------------------
# Used to name the CSV columns (<part>_x, <part>_y, <part>_z). Coordinates are
# written positionally, so this order must match the order in which the
# landmarker returns its landmarks.
body_parts = [
    # head / face
    "nose",
    "left_eye_inner", "left_eye", "left_eye_outer",
    "right_eye_inner", "right_eye", "right_eye_outer",
    "left_ear", "right_ear",
    "mouth_left", "mouth_right",
    # arms and hands
    "left_shoulder", "right_shoulder",
    "left_elbow", "right_elbow",
    "left_wrist", "right_wrist",
    "left_pinky", "right_pinky",
    "left_index", "right_index",
    "left_thumb", "right_thumb",
    # legs and feet
    "left_hip", "right_hip",
    "left_knee", "right_knee",
    "left_ankle", "right_ankle",
    "left_heel", "right_heel",
    "left_foot_index", "right_foot_index",
]

# --- Landmarker options -------------------------------------------------------
base_options = python.BaseOptions(model_asset_path=model_path)
# VIDEO running mode: frames are submitted one by one, each with a timestamp.
options = vision.PoseLandmarkerOptions(
    base_options=base_options,
    running_mode=vision.RunningMode.VIDEO,
    num_poses=num_poses,
    min_pose_detection_confidence=min_pose_detection_confidence,
    min_pose_presence_confidence=min_pose_presence_confidence,
    min_tracking_confidence=min_tracking_confidence,
)

In [3]:
def draw_landmarks_on_image(rgb_image, detection_result):
    """Return a copy of ``rgb_image`` with every detected pose drawn on it.

    Args:
        rgb_image: Image array to annotate. It is copied first, so the
            caller's array is left untouched.
        detection_result: Object exposing ``pose_landmarks``: one collection
            of landmarks (each with ``x``, ``y`` and ``z``) per detected pose.

    Returns:
        The annotated copy of the input image.
    """
    canvas = np.copy(rgb_image)

    for single_pose in detection_result.pose_landmarks:
        # Repack the landmarks into a NormalizedLandmarkList before handing
        # them to the drawing helper.
        landmark_list = landmark_pb2.NormalizedLandmarkList()
        repacked = [
            landmark_pb2.NormalizedLandmark(x=point.x, y=point.y, z=point.z)
            for point in single_pose
        ]
        landmark_list.landmark.extend(repacked)

        mp.solutions.drawing_utils.draw_landmarks(
            canvas,
            landmark_list,
            mp.solutions.pose.POSE_CONNECTIONS,
            mp.solutions.drawing_styles.get_default_pose_landmarks_style())

    return canvas

def create_empty_dataframe():
    """Build an empty DataFrame with the keypoint column layout.

    Returns:
        A DataFrame with no rows whose columns are ``frame`` followed by
        ``<part>_x``, ``<part>_y``, ``<part>_z`` for each entry of the
        module-level ``body_parts`` list, in list order.
    """
    coordinate_columns = [
        f'{part}_{axis}' for part in body_parts for axis in ('x', 'y', 'z')
    ]
    return pd.DataFrame(columns=['frame'] + coordinate_columns)

# One DataFrame per pose slot; the list index is the pose's position in the
# detection result.
pose_dfs = []

def append_keypoints_to_df(pose_landmarks_list, frame_index):
    """Append one row of landmark coordinates per detected pose to ``pose_dfs``.

    Args:
        pose_landmarks_list: Sequence with one collection of landmarks per
            detected pose; each landmark exposes ``x``, ``y`` and ``z``.
        frame_index: Value stored in the ``frame`` column of the new rows.

    Side effects:
        Grows the module-level ``pose_dfs`` list when a new pose index shows
        up, and replaces ``pose_dfs[i]`` with a DataFrame that has one more
        row for every pose ``i`` in ``pose_landmarks_list``.
    """
    for pose_idx, pose_landmarks in enumerate(pose_landmarks_list):
        row = [frame_index]
        for landmark in pose_landmarks:
            row.extend((landmark.x, landmark.y, landmark.z))

        if pose_idx >= len(pose_dfs):
            pose_dfs.append(create_empty_dataframe())

        existing = pose_dfs[pose_idx]
        new_row = pd.DataFrame([row], columns=existing.columns)
        if existing.empty:
            # Concatenating onto an empty frame triggers a pandas FutureWarning
            # and lets the empty frame's object dtype influence the result, so
            # seed the slot with the first row directly instead.
            pose_dfs[pose_idx] = new_row
        else:
            # NOTE: this copies the accumulated frame on every call. It is kept
            # because the export step expects ``pose_dfs`` to hold DataFrames.
            pose_dfs[pose_idx] = pd.concat([existing, new_row], ignore_index=True)

In [4]:
# Start from a clean slate so re-running this cell does not append a second
# copy of every frame to the DataFrames left over from a previous run.
pose_dfs.clear()

with vision.PoseLandmarker.create_from_options(options) as landmarker:
    cap = cv2.VideoCapture(video_source)
    try:
        if not cap.isOpened():
            # Without this message a wrong path would silently produce no output.
            print(f"Could not open video source: {video_source}")

        # Timestamps handed to the landmarker are derived from the frame's
        # position in the video, not from the wall clock, so results do not
        # depend on how fast this machine processes frames.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Frame rate unavailable (0 or NaN); assume 30 fps - adjust if the
            # source is known to differ.
            fps = 30.0

        frame_index = 0
        previous_timestamp_ms = -1

        while cap.isOpened():
            success, image = cap.read()
            if not success:
                # read() also reports failure once the end of the video is reached.
                print("Image capture failed.")
                break

            # Convert the frame received from OpenCV to a MediaPipe Image object.
            mp_image = mp.Image(
                image_format=mp.ImageFormat.SRGB,
                data=cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

            # Keep timestamps strictly increasing even for very high frame rates.
            timestamp_ms = max(int(frame_index * 1000 / fps), previous_timestamp_ms + 1)
            previous_timestamp_ms = timestamp_ms

            detection_result = landmarker.detect_for_video(mp_image, timestamp_ms)
            if detection_result.pose_landmarks:
                append_keypoints_to_df(detection_result.pose_landmarks, frame_index)

            # Draw the pose landmarks on the original BGR frame, which is the
            # channel order cv2.imshow displays.
            annotated_image = draw_landmarks_on_image(image, detection_result)
            cv2.imshow('Pose Landmarks', annotated_image)

            # Esc (key code 27) stops processing early.
            if cv2.waitKey(5) & 0xFF == 27:
                break

            frame_index += 1
    finally:
        # Release the capture and close the preview window even if a frame
        # fails to process.
        cap.release()
        cv2.destroyAllWindows()

# Save the dataframes to CSV files
for i, df in enumerate(pose_dfs):
    df.to_csv(f'{output_name}_{i}_keypoints.csv', index=False)