In [1]:
pip install ultralytics

Collecting ultralytics
  Downloading ultralytics-8.2.70-py3-none-any.whl.metadata (41 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/41.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.3/41.3 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.0-py3-none-any.whl.metadata (8.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu1

In [2]:
import cv2
import numpy as np
import torch
from collections import deque
from ultralytics import YOLO
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import Sequence
from google.colab import drive
# Mount Google Drive (Colab only) so that the model weights and the test video
# referenced under /content/drive/MyDrive in later cells are readable.
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Pose estimator used by extract_keypoints().
# NOTE(review): the recorded output of this cell shows 'yolov8l-pose.pt' being
# downloaded while the code requests 'yolov8n-pose.pt' — the output looks stale;
# confirm which checkpoint the Keras models below were trained against.
model = YOLO('yolov8n-pose.pt')
# Keras model fed with 25-frame keypoint windows by the detection loop.
swing_detect_model = load_model('/content/drive/MyDrive/swing_detect.h5')
# Keras shot-type model; the file name suggests 100-frame inputs — TODO confirm.
shot_class_model =load_model('/content/drive/MyDrive/swing_class_100frames.h5')

Downloading https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8l-pose.pt to 'yolov8l-pose.pt'...


100%|██████████| 85.3M/85.3M [00:00<00:00, 128MB/s]


In [4]:
# Human-readable names for the swing-phase classes
# (presumably indexed by the detector's argmax output — TODO confirm ordering).
swing_detect_class = [
    'swing_begin',
    'swing_middle',
    'swing_end',
]

# Human-readable names for the shot-type classes
# (presumably indexed by the classifier's argmax output — TODO confirm ordering).
shot_class = [
    'forehand_stroke',
    'forehand_slice',
    'forehand_volley',
    'backhand_stroke',
    'backhand_volley',
    'backhand_slice',
]

In [5]:
def preprocess_frame(frame):
    """Convert one BGR image into a float32 batch laid out as (1, C, 640, 640).

    The image is resized to 640x640 (aspect ratio is not preserved), converted
    from BGR to RGB, divided by 255 (values land in [0, 1] for 8-bit input) and
    reordered from height/width/channel to a single-item batch/channel/height/width
    array.
    """
    rgb = cv2.cvtColor(cv2.resize(frame, (640, 640)), cv2.COLOR_BGR2RGB)
    scaled = rgb.astype(np.float32) / 255.0
    # HWC -> CHW first, then prepend the batch axis.
    return np.transpose(scaled, (2, 0, 1))[np.newaxis, ...]

def extract_keypoints(frames):
    """Run the pose model on each frame and keep the keypoints of the largest person.

    Args:
        frames: Iterable of BGR images (e.g. a list or deque of frames read with
            ``cv2.VideoCapture``). A single image passed as a 3-D ``numpy`` array
            is treated as a one-frame sequence.

    Returns:
        list of ``numpy`` arrays, one per frame in which a person was detected,
        each holding the ``xy`` keypoints of the detection whose bounding box has
        the largest area. Frames without a usable detection are skipped with a
        printed warning, so the list can be shorter than ``frames``.
    """
    # A lone H x W x C image would otherwise be iterated row by row.
    if isinstance(frames, np.ndarray) and frames.ndim == 3:
        frames = [frames]

    # Fall back to the CPU instead of crashing when no CUDA device is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    all_keypoints = []
    for frame in frames:
        batch = torch.tensor(preprocess_frame(frame), dtype=torch.float32).to(device)
        results = model(batch, verbose=False)

        if len(results) == 0 or results[0].keypoints is None or len(results[0].keypoints) == 0:
            print("Warning: No keypoints found in frame.")
            continue

        keypoints_list = results[0].keypoints
        boxes = results[0].boxes
        if boxes is None or len(boxes) == 0:
            print("Warning: No bounding boxes found in frame.")
            continue

        try:
            # The biggest box is used as a proxy for the person closest to the camera.
            xyxy = boxes.xyxy.cpu().numpy()
            areas = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
        except IndexError as e:
            print(f"Error calculating areas: {e}")
            print(f"Bounding boxes: {boxes}")
            continue

        largest = int(np.argmax(areas))
        all_keypoints.append(keypoints_list[largest].xy[0].cpu().numpy())

    return all_keypoints


def normalize_keypoints(keypoints):
    """Re-express each frame's keypoints relative to the hip, in hip-to-shoulder units.

    For every frame the hip keypoint (index 11, the left hip) is moved to the
    origin and all coordinates are divided by the distance between that hip and
    the shoulder keypoint (index 5). If that distance is zero the frame is only
    translated, not scaled.

    Args:
        keypoints: Iterable of per-frame keypoint arrays (array-likes of
            ``(x, y)`` rows). Non-float input is converted to ``float32``;
            float input keeps its dtype. The inputs are not modified.

    Returns:
        ``numpy`` array stacking the normalised frames. Frames with fewer than
        12 keypoints are skipped with a printed warning.
    """
    hip_index = 11  # index of left hip
    shoulder_index = 5  # index of shoulder
    required_points = max(hip_index, shoulder_index) + 1

    normalized_keypoints = []
    for frame_keypoints in keypoints:
        points = np.asarray(frame_keypoints)
        if not np.issubdtype(points.dtype, np.floating):
            # Integer (or list) input cannot hold the scaled result.
            points = points.astype(np.float32)

        if len(points) < required_points:
            print("Warning: Frame with insufficient keypoints detected. Skipping this frame.")
            continue

        hip_point = points[hip_index]
        # Hip becomes (0, 0); the hip-to-shoulder distance becomes 1.
        relative_points = points - hip_point
        scale_factor = np.linalg.norm(points[shoulder_index] - hip_point)
        if scale_factor != 0:
            relative_points = relative_points / scale_factor
        normalized_keypoints.append(relative_points)

    return np.array(normalized_keypoints)



In [6]:
class DataGenerator(Sequence):
    """Keras ``Sequence`` that serves consecutive slices of ``x_set`` as batches.

    Only inputs are produced (no targets), so it is suitable for ``predict``.

    Args:
        x_set: Indexable collection of samples; sliced along its first axis.
        batch_size: Number of samples per batch.
        **kwargs: Forwarded to the ``Sequence`` base initialiser.
    """

    def __init__(self, x_set, batch_size, **kwargs):
        # Keras 3 implements Sequence as PyDataset and expects subclasses to call
        # the base initialiser; on older Keras versions this call is a no-op.
        super().__init__(**kwargs)
        self.x = x_set
        self.batch_size = batch_size

    def __len__(self):
        """Number of batches, counting a final partial batch."""
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Return the ``idx``-th batch of samples."""
        start = idx * self.batch_size
        return self.x[start:start + self.batch_size]



In [7]:
# Test clip on the mounted Drive; opened by the detection loop in the next cell.
video_path = "/content/drive/MyDrive/test_30s.mp4"

In [None]:
# Sliding-window swing detection over the test video.
# Frames are buffered until a full window is available; the window is converted to
# normalised keypoints, classified by swing_detect_model, and the predicted class
# index is appended to test_pred. The window then advances by STRIDE frames.
WINDOW_SIZE = 25  # frames per model input (length of the frame buffer)
STRIDE = 5        # frames dropped from the front after each evaluated window

cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Fail loudly instead of silently producing an empty prediction list.
    raise IOError(f"Could not open video: {video_path}")

frame_buffer = deque(maxlen=WINDOW_SIZE)
swing_frames = deque(maxlen=100)  # not used yet; reserved for the recording stage
is_recording_swing = False        # never set to True in this cell (recording stage is a TODO)

test_pred = []

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame_buffer.append(frame)

        if len(frame_buffer) < WINDOW_SIZE or is_recording_swing:
            # TODO: handle the is_recording_swing == True stage.
            continue

        keypoints = extract_keypoints(frame_buffer)
        if len(keypoints) > 0:
            normalized_keypoints = normalize_keypoints(keypoints)
            if len(normalized_keypoints) == WINDOW_SIZE:
                # (frames, keypoints, 2) -> (1, frames, keypoints * 2)
                X_array = normalized_keypoints.reshape(1, WINDOW_SIZE, -1)
                generator = DataGenerator(X_array, batch_size=1)
                predictions = swing_detect_model.predict(generator)
                swing_prediction = np.argmax(predictions, axis=1)
                test_pred.append(swing_prediction[0])
            else:
                # Frames without a detection are dropped upstream, which would feed
                # the model a shorter, temporally misaligned sequence.
                print(f"Warning: only {len(normalized_keypoints)}/{WINDOW_SIZE} frames had usable keypoints; window skipped.")

        # Always advance the window, so a window without detections is not
        # re-evaluated in full on every following frame.
        for _ in range(min(STRIDE, len(frame_buffer))):
            frame_buffer.popleft()
finally:
    # Release the capture even if keypoint extraction or prediction raises.
    cap.release()





In [None]:
# Display the class indices collected per window by the detection loop.
test_pred

In [None]:
def is_swing_start(keypoints_sequence):
    """Report whether a 35-frame keypoint sequence is scored as the start of a swing.

    Returns ``False`` when the sequence is ``None`` or does not contain exactly
    35 items. Otherwise the sequence is batched (leading axis of size 1), passed
    to ``swing_start_model.predict`` and the first score of the first row is
    compared against 0.5.

    NOTE(review): ``swing_start_model`` is not defined in the visible cells —
    confirm which loaded model this is meant to be.
    """
    if keypoints_sequence is None or len(keypoints_sequence) != 35:
        return False

    batch = np.array(keypoints_sequence)[np.newaxis, ...]
    scores = swing_start_model.predict(batch)
    # Adjust the threshold as needed.
    return scores[0][0] > 0.5

def classify_swing(keypoints_sequence):
    """Return the predicted class index for a 30-frame keypoint sequence.

    Returns ``None`` when the sequence is ``None`` or does not contain exactly
    30 items. Otherwise the sequence is batched (leading axis of size 1), passed
    to ``swing_classification_model.predict`` and the argmax of the prediction
    is returned.

    NOTE(review): ``swing_classification_model`` is not defined in the visible
    cells — confirm which loaded model this is meant to be.
    """
    if keypoints_sequence is None or len(keypoints_sequence) != 30:
        return None

    batch = np.array(keypoints_sequence)[np.newaxis, ...]
    return np.argmax(swing_classification_model.predict(batch))

# Draft per-frame pipeline: detect the start of a swing on a 35-frame keypoint
# buffer, classify the last 30 frames, and draw the label on the displayed frame.
# NOTE(review): is_swing_start()/classify_swing() rely on models that are not
# defined in the visible cells, and cv2.imshow needs a local GUI session (it is
# not expected to work in a hosted notebook such as Colab) — confirm the target
# environment before running.

# Open the video file.
cap = cv2.VideoCapture('path/to/your/video.mp4')

frame_buffer = deque(maxlen=35)
swing_frames = deque(maxlen=30)
is_recording_swing = False

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # extract_keypoints() iterates over its argument, so a single frame must
        # be wrapped in a list; otherwise the image is walked row by row.
        keypoints = extract_keypoints([frame])
        normalized_keypoints = normalize_keypoints(keypoints)

        # normalize_keypoints() returns an array (empty when nobody was detected),
        # never None, so test its length instead.
        if len(normalized_keypoints) > 0:
            frame_keypoints = normalized_keypoints[0]
            frame_buffer.append(frame_keypoints)

            if len(frame_buffer) == 35 and is_swing_start(list(frame_buffer)):
                is_recording_swing = True
                swing_frames.clear()
                # Use the last 30 frames. The current frame is already the newest
                # buffer entry, so it must not be appended a second time.
                swing_frames.extend(list(frame_buffer)[-30:])
            elif is_recording_swing:
                swing_frames.append(frame_keypoints)

            if is_recording_swing and len(swing_frames) == 30:
                swing_class = classify_swing(list(swing_frames))
                if swing_class is not None:
                    # shot_class is the shot-type label list defined in this notebook.
                    label = shot_class[swing_class]
                    cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

                is_recording_swing = False

        cv2.imshow('Swing Classification', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

        # Once the buffer is full, drop its 3 oldest entries to slide the window.
        if len(frame_buffer) == 35:
            for _ in range(3):
                if frame_buffer:
                    frame_buffer.popleft()
finally:
    # Release the capture and close windows even if a step above raises.
    cap.release()
    cv2.destroyAllWindows()

In [None]:
# Build a labelled dataset: one folder per shot type, labelled with the shot
# type's position in the list.
# NOTE(review): neither `shot_types` nor `process_video` is defined in the visible
# cells (the label list defined earlier is named `shot_class`), so this cell
# raises NameError on a fresh kernel — confirm the intended names.
# NOTE(review): the folder is an absolute path on a local machine; consider a
# configurable data directory so the notebook runs elsewhere.
all_data = []
for label, shot_type in enumerate(shot_types):
    folder_path = f'/Users/yusuke.s/Documents/pickleball_videos/{shot_type}'
    all_data.extend(process_video(folder_path, label))