In [None]:
pip install ultralytics

Collecting ultralytics
  Downloading ultralytics-8.2.70-py3-none-any.whl.metadata (41 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/41.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m41.0/41.3 kB[0m [31m22.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.3/41.3 kB[0m [31m801.7 kB/s[0m eta [36m0:00:00[0m
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.0-py3-none-any.whl.metadata (8.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.8.0->ultr

In [None]:
import cv2
import numpy as np
import torch
from collections import deque
from ultralytics import YOLO
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import Sequence
from google.colab import drive
# Mount Google Drive; the model files and the test video used below are read
# from paths under /content/drive/MyDrive.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Ultralytics YOLO model; extract_keypoints() reads per-person keypoints and
# bounding boxes from its results.
model = YOLO('yolov8n-pose.pt')
# Keras models stored on Google Drive (the drive.mount() cell must have run first).
# NOTE(review): judging by the names, the first detects the swing phase and the
# second classifies the shot type from a 100-frame window — confirm against training.
swing_detect_model = load_model('/content/drive/MyDrive/swing_detect.h5')
shot_class_model =load_model('/content/drive/MyDrive/swing_class_100frames.h5')

In [None]:
# Human-readable labels for the two Keras models loaded above.
# NOTE(review): presumably ordered by the class index each model outputs —
# confirm against the training code.
swing_detect_class=['swing_begin','swing_middle','swing_end']
shot_class = ['forehand_stroke', 'forehand_slice', 'forehand_volley', 'backhand_stroke', 'backhand_volley', 'backhand_slice']

In [None]:
def preprocess_frame(frame):
    """Turn one BGR video frame into a model-ready float array.

    The frame is resized to 640x640, converted from BGR to RGB, scaled to
    the [0, 1] range as float32, and rearranged from HWC to a single-item
    BCHW batch.

    Args:
        frame: Image array in BGR channel order (as returned by OpenCV).

    Returns:
        A float32 numpy array of shape (1, 3, 640, 640).
    """
    resized = cv2.resize(frame, (640, 640))
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    scaled = rgb.astype(np.float32) / 255.0
    # Prepend the batch axis, then move channels ahead of height/width.
    batched = scaled[np.newaxis, ...]
    return np.transpose(batched, (0, 3, 1, 2))

def extract_keypoints(frames):
    """Run the pose model on each frame and keep the largest person's keypoints.

    When several people are detected, the one with the largest bounding box
    is used as a proxy for the person closest to the camera.

    Args:
        frames: An iterable of BGR images (e.g. a list or deque of frames read
            with OpenCV). A single image passed as a 3-D numpy array is
            treated as a one-frame sequence.

    Returns:
        A list of numpy arrays, one per frame in which a person was found,
        each holding that person's (x, y) keypoints as reported by the model.
        Frames without a usable detection are skipped (a warning is printed),
        so the list can be shorter than ``frames``.
    """
    # A lone HxWxC image would otherwise be iterated row by row.
    if isinstance(frames, np.ndarray) and frames.ndim == 3:
        frames = [frames]

    all_keypoints = []
    device = torch.device('cpu')
    for frame in frames:
        batch = torch.tensor(preprocess_frame(frame), dtype=torch.float32).to(device)
        # verbose=False silences the per-image speed log that otherwise floods the cell output.
        results = model(batch, verbose=False)
        if len(results) == 0 or len(results[0].keypoints) == 0:
            print("Warning: No keypoints found in frame.")
            continue

        keypoints_list = results[0].keypoints
        bboxes = results[0].boxes
        if bboxes is None or len(bboxes) == 0:
            print("Warning: No bounding boxes found in frame.")
            continue

        try:
            # xyxy corner coordinates -> box areas, used to pick the largest person.
            xyxy = bboxes.xyxy.cpu().numpy()
            areas = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
        except IndexError as e:
            print(f"Error calculating areas: {e}")
            print(f"Bounding boxes: {bboxes}")
            continue

        max_area_index = int(np.argmax(areas))  # index of the biggest bounding box
        keypoints = keypoints_list[max_area_index].xy[0].cpu().numpy()
        all_keypoints.append(keypoints)

    return all_keypoints


def normalize_keypoints(keypoints, hip_index=11, shoulder_index=5):
    """Express each frame's keypoints relative to the hip, scaled by torso length.

    For every frame the hip keypoint becomes the origin (0, 0) and all
    coordinates are divided by the hip-to-shoulder distance, so that distance
    equals 1 in every normalized frame. If the distance is zero the frame is
    translated but left unscaled.

    Args:
        keypoints: Iterable of per-frame keypoint arrays (points x coordinates).
        hip_index: Row index of the hip keypoint (default 11, the left hip).
        shoulder_index: Row index of the shoulder keypoint (default 5).
            NOTE(review): 5 is presumably the left shoulder in the COCO
            layout — confirm.

    Returns:
        A numpy array stacking the normalized frames. Frames with too few
        keypoints are skipped (a warning is printed), so the result can
        contain fewer frames than the input.
    """
    min_points = max(hip_index, shoulder_index) + 1

    normalized_keypoints = []
    for frame_keypoints in keypoints:
        frame_keypoints = np.asarray(frame_keypoints)
        if len(frame_keypoints) < min_points:
            print("Warning: Frame with insufficient keypoints detected. Skipping this frame.")
            continue

        hip_point = frame_keypoints[hip_index]
        shoulder_point = frame_keypoints[shoulder_index]
        # Translate so the hip sits at the origin.
        relative_points = frame_keypoints - hip_point
        scale_factor = np.linalg.norm(shoulder_point - hip_point)
        if scale_factor != 0:
            # Out-of-place division: an in-place `/=` raises on integer arrays.
            relative_points = relative_points / scale_factor
        normalized_keypoints.append(relative_points)

    return np.array(normalized_keypoints)



In [None]:
class DataGenerator(Sequence):
    """Keras ``Sequence`` that serves an input array in consecutive batches.

    Batches are slices along the first axis of ``x_set``; only inputs are
    returned (no targets), so this is meant for ``predict``-style use.

    Args:
        x_set: Indexable collection of samples; the first axis is the sample axis.
        batch_size: Number of samples per batch.
        **kwargs: Forwarded to the ``Sequence`` base-class constructor.
            Leave empty on Keras versions whose base class takes no arguments.
    """

    def __init__(self, x_set, batch_size, **kwargs):
        # Newer Keras versions expect subclasses to initialise the base class.
        super().__init__(**kwargs)
        self.x = x_set
        self.y = 0  # placeholder; no targets are served
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Return the ``idx``-th batch of inputs."""
        start = idx * self.batch_size
        return self.x[start:start + self.batch_size]




In [None]:
# Input video for the detection loop; read from the mounted Google Drive.
video_path = "/content/drive/MyDrive/test_video.mp4"

In [None]:
# Sliding-window swing detection: keep the last 35 frames, extract and normalize
# their keypoints, and feed the window to swing_detect_model as ONE sequence.
#
# The model's LSTM needs input shaped (batch, timesteps, features) with 34
# features per timestep (consistent with 17 keypoints x 2 coordinates).
# Batching the (frames, 17, 2) array along the frame axis made each frame look
# like a separate 17-step, 2-feature sequence and triggered
# "Matrix size-incompatible: In[0]: [16,2], In[1]: [34,512]".
#
# NOTE(review): every new frame re-runs pose estimation on the whole window;
# caching per-frame keypoints would avoid the repeated work.
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
  print(f"Warning: could not open video: {video_path}")

window_size = 35
frame_buffer = deque(maxlen=window_size)
swing_frames = deque(maxlen=100)
is_recording_swing = False
predictions = None

try:
  while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
      break
    frame_buffer.append(frame)
    if len(frame_buffer) < window_size or is_recording_swing:
      continue

    keypoints = extract_keypoints(frame_buffer)
    if len(keypoints) == 0:
      continue
    normalized_keypoints = normalize_keypoints(keypoints)
    # Frames without a detection are dropped upstream; skip incomplete windows
    # rather than feeding the model a shorter sequence.
    if len(normalized_keypoints) != window_size:
      continue

    # (frames, points, coords) -> (1, frames, points * coords)
    X_array = np.asarray(normalized_keypoints).reshape(1, window_size, -1)
    predictions = swing_detect_model.predict(X_array, verbose=0)

    # TODO: act on `predictions` and handle the is_recording_swing == True case.
finally:
  # Release the capture even if inference raises.
  cap.release()




0: 640x640 1 person, 251.6ms
Speed: 1.5ms preprocess, 251.6ms inference, 4.2ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 1 person, 202.7ms
Speed: 0.0ms preprocess, 202.7ms inference, 4.0ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 1 person, 220.2ms
Speed: 0.0ms preprocess, 220.2ms inference, 4.3ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 1 person, 220.4ms
Speed: 0.1ms preprocess, 220.4ms inference, 4.3ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 1 person, 200.2ms
Speed: 0.0ms preprocess, 200.2ms inference, 4.3ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 1 person, 204.2ms
Speed: 0.0ms preprocess, 204.2ms inference, 4.1ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 2 persons, 222.0ms
Speed: 0.0ms preprocess, 222.0ms inference, 4.5ms postprocess per image at shape (1, 3, 640, 640)

0: 640x640 2 persons, 210.7ms
Speed: 0.0ms preprocess, 210.7ms inference, 4.3ms postprocess per image 

InvalidArgumentError: Graph execution error:

Detected at node while/MatMul defined at (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main

  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code

  File "/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py", line 37, in <module>

  File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start

  File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start

  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

  File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run

  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 685, in <lambda>

  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 738, in _run_callback

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 825, in inner

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 361, in process_one

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "<ipython-input-51-ebc6da651d9f>", line 22, in <cell line: 7>

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 2655, in predict

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 2440, in predict_function

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 2425, in step_function

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 2413, in run_step

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 2381, in predict_step

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 590, in __call__

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/functional.py", line 515, in call

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/functional.py", line 672, in _run_internal_graph

  File "/usr/local/lib/python3.10/dist-packages/keras/src/layers/rnn/base_rnn.py", line 556, in __call__

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/layers/rnn/lstm.py", line 749, in call

  File "/usr/local/lib/python3.10/dist-packages/keras/src/layers/rnn/lstm.py", line 1339, in lstm_with_backend_selection

  File "/usr/local/lib/python3.10/dist-packages/keras/src/layers/rnn/lstm.py", line 981, in standard_lstm

  File "/usr/local/lib/python3.10/dist-packages/keras/src/backend.py", line 5168, in rnn

  File "/usr/local/lib/python3.10/dist-packages/keras/src/backend.py", line 5147, in _step

  File "/usr/local/lib/python3.10/dist-packages/keras/src/layers/rnn/lstm.py", line 967, in step

  File "/usr/local/lib/python3.10/dist-packages/keras/src/backend.py", line 2463, in dot

Matrix size-incompatible: In[0]: [16,2], In[1]: [34,512]
	 [[{{node while/MatMul}}]]
	 [[model/lstm/PartitionedCall]] [Op:__inference_predict_function_8731]

In [None]:
def is_swing_start(keypoints_sequence, threshold=0.5):
    """Decide whether a 35-frame keypoint window marks the start of a swing.

    Args:
        keypoints_sequence: Sequence of 35 per-frame keypoint arrays, or None.
        threshold: Score above which the window counts as a swing start.
            Tune as appropriate for the model.

    Returns:
        False when the input is None or not exactly 35 frames long; otherwise
        the result of comparing the model's first output value with ``threshold``.
    """
    if keypoints_sequence is None or len(keypoints_sequence) != 35:
        return False

    sequence = np.asarray(keypoints_sequence)
    # Flatten each frame's keypoints into one feature vector:
    # (frames, ...) -> (1, frames, features), the layout the model's LSTM needs.
    batch = sequence.reshape(1, sequence.shape[0], -1)
    # swing_detect_model is the detection model loaded earlier in this notebook.
    # NOTE(review): assumes output index 0 is the swing-start score — confirm.
    prediction = swing_detect_model.predict(batch, verbose=0)
    return prediction[0][0] > threshold

def classify_swing(keypoints_sequence, sequence_length=30):
    """Classify the shot type of a fixed-length keypoint window.

    Args:
        keypoints_sequence: Sequence of per-frame keypoint arrays, or None.
        sequence_length: Required number of frames (default 30).
            NOTE(review): the loaded model file is named
            'swing_class_100frames.h5', so it may expect 100 frames — confirm.

    Returns:
        The index of the highest-scoring class, or None when the input is
        None or not exactly ``sequence_length`` frames long.
    """
    if keypoints_sequence is None or len(keypoints_sequence) != sequence_length:
        return None

    sequence = np.asarray(keypoints_sequence)
    # (frames, ...) -> (1, frames, features).
    # NOTE(review): presumably the classifier takes the same flattened
    # per-frame layout as the swing-detection model — confirm.
    batch = sequence.reshape(1, sequence.shape[0], -1)
    # shot_class_model is the classification model loaded earlier in this notebook.
    prediction = shot_class_model.predict(batch, verbose=0)
    return np.argmax(prediction)

# Load the video file (uses the notebook's configured video_path).
cap = cv2.VideoCapture(video_path)

frame_buffer = deque(maxlen=35)   # per-frame keypoints for swing-start detection
swing_frames = deque(maxlen=30)   # per-frame keypoints for shot classification
is_recording_swing = False

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # extract_keypoints expects a sequence of frames, so wrap the single frame.
        keypoints = extract_keypoints([frame])
        normalized_keypoints = normalize_keypoints(keypoints)

        # normalize_keypoints returns an array (possibly empty), never None.
        if len(normalized_keypoints) > 0:
            frame_keypoints = normalized_keypoints[0]
            frame_buffer.append(frame_keypoints)

            swing_just_started = False
            if len(frame_buffer) == 35 and is_swing_start(list(frame_buffer)):
                is_recording_swing = True
                swing_just_started = True
                swing_frames.clear()
                swing_frames.extend(list(frame_buffer)[-30:])  # use the last 30 frames

            if is_recording_swing:
                # The current frame is already in the seeded window; do not add it twice.
                if not swing_just_started:
                    swing_frames.append(frame_keypoints)

                if len(swing_frames) == 30:
                    swing_class = classify_swing(list(swing_frames))
                    if swing_class is not None:
                        label = shot_class[swing_class]
                        cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

                    is_recording_swing = False

        # NOTE(review): cv2.imshow needs a GUI and is disabled on Google Colab —
        # confirm where this cell is meant to run.
        cv2.imshow('Swing Classification', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

        # Once the buffer is full, drop its 3 oldest entries so the next
        # detection runs after 3 new frames.
        if len(frame_buffer) == 35:
            for _ in range(3):
                if frame_buffer:
                    frame_buffer.popleft()
finally:
    # Always free the capture and any windows, even if inference raises.
    cap.release()
    cv2.destroyAllWindows()

In [None]:
# Collect data from one folder per shot type, passing the shot type's index as its label.
# NOTE(review): `shot_types` and `process_video` are not defined in any cell visible
# in this notebook (the label list defined earlier is named `shot_class`) — confirm
# where they come from before running this cell.
# NOTE(review): the folder path is an absolute path on a local machine and will not
# exist on Colab — consider a configurable data directory.
all_data = []
for label, shot_type in enumerate(shot_types):
    folder_path = f'/Users/yusuke.s/Documents/pickleball_videos/{shot_type}'
    all_data.extend(process_video(folder_path, label))