<a href="https://colab.research.google.com/github/yusuke-satani/swing_classification/blob/main/extract_keypoints.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import cv2
import numpy as np
from ultralytics import YOLO
import os

In [None]:
# Load the 'yolov8n-pose.pt' pose weights once at module level;
# extract_keypoints() reads this global `model` for every frame.
model = YOLO('yolov8n-pose.pt')

In [None]:
import torch

def extract_keypoints(frames):
    """Run the pose model on each frame and keep the largest detection's keypoints.

    Args:
        frames: Iterable of image frames. Each frame is converted to a
            float32 tensor and passed to the module-level ``model``.

    Returns:
        A list with one NumPy array of keypoint coordinates per frame in
        which a detection was found. Frames with no keypoints or no bounding
        boxes are skipped with a printed warning, so the list may be shorter
        than ``frames``.
    """
    # Prefer Apple's MPS backend (the original hard-coded choice), but fall
    # back to CUDA or the CPU so the function also runs where MPS is absent
    # instead of failing when the tensor is moved to the device.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        device = torch.device('mps')
    elif torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    all_keypoints = []
    for frame in frames:
        # NOTE(review): Ultralytics documents tensor sources as BCHW float
        # values in [0, 1]; confirm the layout of the frames produced by the
        # caller matches what the model expects.
        frame = torch.tensor(frame, dtype=torch.float32).to(device)
        results = model(frame)

        if len(results) == 0 or len(results[0].keypoints) == 0:
            print("Warning: No keypoints found in frame.")
            continue

        keypoints_list = results[0].keypoints
        bboxes = results[0].boxes  # bounding boxes of the detections

        if bboxes is None or len(bboxes) == 0:
            print("Warning: No bounding boxes found in frame.")
            continue

        # Area of every bounding box, computed from (x1, y1, x2, y2) corners.
        try:
            xyxy = bboxes.xyxy.cpu().numpy()
            areas = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
        except IndexError as e:
            print(f"Error calculating areas: {e}")
            print(f"Bounding boxes: {bboxes}")
            continue

        # The largest box is taken to be the subject of the frame; keep only
        # the keypoints that belong to that detection.
        max_area_index = np.argmax(areas)
        keypoints = keypoints_list[max_area_index].xy[0].cpu().numpy()
        all_keypoints.append(keypoints)
    return all_keypoints


def normalize_keypoints(keypoints):
    """Express each frame's keypoints relative to the hip, scaled by torso length.

    For every frame the hip keypoint becomes the origin (0, 0) and all points
    are divided by the hip-to-shoulder distance, which makes the features
    independent of where the person is in the image and how large they
    appear. When that distance is zero the frame is only translated, not
    scaled.

    Args:
        keypoints: Iterable of per-frame keypoint collections (arrays or
            nested sequences of coordinates). Integer inputs are converted
            to floating point before any arithmetic.

    Returns:
        A NumPy array stacking the normalized frames. Frames that have too
        few keypoints to contain both reference points are skipped with a
        printed warning. The inputs are never modified.
    """
    hip_index = 11  # index of the left hip
    shoulder_index = 5  # index of the shoulder
    min_keypoints = max(hip_index, shoulder_index) + 1

    normalized_keypoints = []
    for frame_keypoints in keypoints:
        if len(frame_keypoints) < min_keypoints:
            print("Warning: Frame with insufficient keypoints detected. Skipping this frame.")
            continue

        points = np.asarray(frame_keypoints)
        # Integer coordinates would wrap on subtraction (unsigned types) and
        # cannot hold the divided result, so promote them to float first.
        # Floating inputs keep their original precision.
        if not np.issubdtype(points.dtype, np.floating):
            points = points.astype(float)

        hip_point = points[hip_index]
        shoulder_point = points[shoulder_index]

        # Hip becomes the origin; every other point is relative to it.
        relative_points = points - hip_point
        scale_factor = np.linalg.norm(shoulder_point - hip_point)
        if scale_factor != 0:
            relative_points = relative_points / scale_factor

        normalized_keypoints.append(relative_points)

    return np.array(normalized_keypoints)

def process_folder(folder_path, label):
    """Extract normalized keypoints from every video in a folder.

    Each file whose extension is .mp4, .avi or .mov (case-insensitive) is
    passed through ``process_video``, ``extract_keypoints`` and
    ``normalize_keypoints``. ``process_video`` is not defined in this part of
    the file and must be provided elsewhere.

    Args:
        folder_path: Directory containing the video files.
        label: Class label attached to every video found in the folder.

    Returns:
        A list of ``(normalized_keypoints, label)`` tuples, one per video
        that produced at least one usable frame. Videos are visited in
        sorted filename order so repeated runs give the same ordering.
    """
    print(f"Processing folder: {folder_path}")
    video_extensions = ('.mp4', '.avi', '.mov')
    video_data = []
    for video_file in sorted(os.listdir(folder_path)):
        # Compare in lower case so files such as "CLIP.MOV" are not skipped.
        if not video_file.lower().endswith(video_extensions):
            continue

        video_path = os.path.join(folder_path, video_file)
        frames = process_video(video_path)
        keypoints = extract_keypoints(frames)
        if len(keypoints) == 0:
            print(f"Warning: No keypoints detected in video {video_file}")
            continue

        normalized_keypoints = normalize_keypoints(keypoints)
        if len(normalized_keypoints) == 0:
            print(f"Warning: No valid keypoints found in video {video_file}")
            continue

        video_data.append((normalized_keypoints, label))

    # Report before returning; placed after the return this line never ran.
    print(f"Number of videos processed: {len(video_data)}")
    return video_data

# Shot classes; each name's position in this list is used as its integer label.
shot_types = ['forehand_stroke','forehand_slice','forehand_volley', 'backhand_stroke', 'backhand_volley', 'backhand_slice']
# Accumulates the (normalized_keypoints, label) tuples returned by
# process_folder() across all shot classes.
all_data = []
for label, shot_type in enumerate(shot_types):
    # NOTE(review): 'file_path' is joined to the shot name with no path
    # separator; this looks like a placeholder for the dataset root —
    # confirm the intended directory layout before running.
    folder_path = f'file_path{shot_type}'
    all_data.extend(process_folder(folder_path, label))

In [None]:
import numpy as np
# Pack the collected entries of `all_data` into an object-dtype NumPy array
# (object dtype lets each entry hold an arbitrary Python object).
all_data_array = np.array(all_data, dtype=object)
# Write the array to 'swing_class.npy' in the current working directory.
# NOTE(review): object arrays are stored via pickle, so reading the file back
# should need np.load('swing_class.npy', allow_pickle=True) — confirm.
np.save('swing_class.npy', all_data_array)

"import numpy as np\n# Convert to numpy array\nall_data_array = np.array(all_data, dtype=object)\n# Save to .npy file\nnp.save('all_data.npy', all_data_array)"