In [1]:
pip install ultralytics

Collecting ultralytics
  Downloading ultralytics-8.2.70-py3-none-any.whl.metadata (41 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/41.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m41.0/41.3 kB[0m [31m16.9 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m41.0/41.3 kB[0m [31m16.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.3/41.3 kB[0m [31m298.5 kB/s[0m eta [36m0:00:00[0m
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.0-py3-none-any.whl.metadata (8.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cuda_runtime_c

In [2]:
import cv2
import numpy as np
import torch
from collections import deque
from ultralytics import YOLO
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import Sequence
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
model = YOLO('yolov8n-pose.pt')
swing_detect_model = load_model('/content/drive/MyDrive/swing_detect.h5')
shot_class_model =load_model('/content/drive/MyDrive/swing_class_100frames.h5')

Downloading https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-pose.pt to 'yolov8n-pose.pt'...


100%|██████████| 6.52M/6.52M [00:00<00:00, 253MB/s]


In [4]:
swing_detect_class=['swing_begin','swing_middle','swing_end']
shot_class = ['forehand_stroke', 'forehand_slice', 'forehand_volley', 'backhand_stroke', 'backhand_volley', 'backhand_slice']

In [11]:
def preprocess_frame(frame):
    # Resize the frame to a specific size (e.g., 640x640)
    frame = cv2.resize(frame, (640, 640))
    # Convert BGR to RGB
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Normalize pixel values to [0, 1]
    frame = frame.astype(np.float32) / 255.0
    # Add batch dimension
    frame = np.expand_dims(frame, axis=0)
    # Permute dimensions to match the model's expected input (BCHW)
    frame = np.transpose(frame, (0, 3, 1, 2))
    return frame

def extract_keypoints(frames):
    all_keypoints = []
    device = torch.device('cuda')
    annotated_frames = []
    for frame in frames:
        frame = preprocess_frame(frame)
        frame = torch.tensor(frame, dtype=torch.float32).to(device)
        results = model(frame,verbose=False)
        annotated_frame = results[0].plot()

        if len(results) > 0 and len(results[0].keypoints) > 0:
            keypoints_list = results[0].keypoints
            bboxes = results[0].boxes
            if bboxes is not None and len(bboxes) > 0:
                try:
                    xyxy = bboxes.xyxy.cpu().numpy()  # get xyxy cordinates to find the closest peron to the camera.
                    areas = [(box[2] - box[0]) * (box[3] - box[1]) for box in xyxy]
                except IndexError as e:
                    print(f"Error calculating areas: {e}")
                    print(f"Bounding boxes: {bboxes}")
                    continue

                max_area_index = np.argmax(areas)   # get the index that contain the biggest bbox
                keypoints = keypoints_list[max_area_index].xy[0].cpu().numpy()  #get the keypoints from the biggest bbox
                all_keypoints.append(keypoints)

                # キーポイントを描画（オプション）
                for point in keypoints:
                    x, y = point[:2]
                    cv2.circle(annotated_frame, (int(x), int(y)), 5, (0, 255, 0), -1)
                    annotated_frames.append(annotated_frame)
            else:
                print(f"Warning: No bounding boxes found in frame.")
        else:
            print(f"Warning: No keypoints found in frame.")

       # out.write(annotated_frame)  #これswingの結果も書く必要があるので後に持っていく.

    return all_keypoints, annotated_frames  #ここで次にannotated_framesを次に繋げる

def normalize_keypoints(keypoints):
    hip_index = 11  #  index of left hip
    shoulder_index = 5  # index of shoulder

    normalized_keypoints = []
    for frame_keypoints in keypoints:
        if len(frame_keypoints) > max(hip_index, shoulder_index):
            hip_point = frame_keypoints[hip_index]
            shoulder_point = frame_keypoints[shoulder_index]
            #  set hip point as (0,0) cordinates, set distance from hip to shoulder as 1 in every input.
            relative_points = frame_keypoints - hip_point
            # scaling
            scale_factor = np.linalg.norm(shoulder_point - hip_point)
            if scale_factor != 0:
                relative_points /= scale_factor
            normalized_keypoints.append(relative_points)
        else:
            print(f"Warning: Frame with insufficient keypoints detected. Skipping this frame.")

    return np.array(normalized_keypoints)



In [6]:
class DataGenerator(Sequence):
    def __init__(self, x_set, batch_size):
        self.x = x_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        return batch_x



In [7]:
video_path = "/content/drive/MyDrive/test_30s.mp4"

In [8]:
from google.colab.patches import cv2_imshow

In [37]:
def add_text (annotated_frames):   #ここがannotated_framesではなくswing=2になるまでのannotated_framesをすべて含んでいたやつにtextを入れると効率的
  texted_frames =[]
  for text_frame in annotated_frames:
    cv2.putText(text_frame, "Swing", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
    texted_frames.append(text_frame)
    return texted_frames




In [46]:
import cv2
import numpy as np
from collections import deque

def add_text(annotated_frames):
    texted_frames = []
    for text_frame in annotated_frames:
        cv2.putText(text_frame, "Swing", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        texted_frames.append(text_frame)
    return texted_frames

cap = cv2.VideoCapture(video_path)
frame_buffer = deque(maxlen=25)
is_recording_swing = False
swing_record = []

# VideoWriterの設定
fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # または適切なコーデック
fps = cap.get(cv2.CAP_PROP_FPS)
width, height = 640, 640
out = cv2.VideoWriter('output.mp4', fourcc, fps, (width, height))

if not out.isOpened():
    print("Error: VideoWriter not opened.")
    cap.release()
    cv2.destroyAllWindows()
    exit()

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    frame_buffer.append(frame)

    if len(frame_buffer) == 25:
        keypoints, annotated_frames = extract_keypoints(frame_buffer)
        if len(keypoints) > 0:
            normalized_keypoints = normalize_keypoints(keypoints)
            X_array = np.array(normalized_keypoints)
            X_array = X_array.reshape((1, X_array.shape[0], X_array.shape[1] * X_array.shape[2]))
            generator = DataGenerator(X_array, batch_size=1)
            predictions = swing_detect_model.predict(generator)
            swing_prediction = np.argmax(predictions, axis=1)

            if swing_prediction[0] != 0:
                is_recording_swing = True
                print("swinging")

            if is_recording_swing:
                swing_record.extend([cv2.resize(frame, (width, height)) for frame in list(annotated_frames)[-2:]])
                if swing_prediction[0] == 2:
                    is_recording_swing = False
                    new_frames = add_text(swing_record)
                    if new_frames is not None:  # None チェックを追加
                        for frame in new_frames:
                            if frame.shape[1] != width or frame.shape[0] != height:
                                print(f"Error: Frame size mismatch. Expected ({width}, {height}), got ({frame.shape[1]}, {frame.shape[0]})")
                            out.write(frame)
                    swing_record.clear()


        for _ in range(2):
            if frame_buffer:
                frame_buffer.popleft()

    # スイング中でない場合は、通常のフレームを書き出す
    if not is_recording_swing:
        resized_frame = cv2.resize(frame, (width, height))
        out.write(resized_frame)

# リソースの解放
cap.release()
out.release()
cv2.destroyAllWindows()





swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
swinging
s

In [50]:
!cp /content/output.mp4 /content/drive/MyDrive/

In [24]:
import cv2
import numpy as np
from ultralytics import YOLO

# モデルの読み込み
model = YOLO("yolov8n-pose.pt")

# 入力と出力のビデオファイルパス
input_video = video_path
output_video = "output_video.mp4"

# ビデオキャプチャオブジェクトの作成
cap = cv2.VideoCapture(input_video)

# ビデオの属性を取得
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(cap.get(cv2.CAP_PROP_FPS))

# 出力ビデオライターの設定
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_video, fourcc, fps, (width, height))

# 座標データを保存するリスト
all_keypoints = []

frame_count = 0
while cap.isOpened():
    success, frame = cap.read()
    if not success:
        break

    # YOLOモデルで推論を行う
    results = model(frame)

    # 結果を描画したフレームを取得
    annotated_frame = results[0].plot()

    if len(results[0].boxes) > 0:
        # 最大面積のバウンディングボックスを選択
        areas = [(box.xyxy[0][2] - box.xyxy[0][0]) * (box.xyxy[0][3] - box.xyxy[0][1]) for box in results[0].boxes]
        largest_box_index = np.argmax(areas)

        keypoints = results[0].keypoints[largest_box_index]

        if keypoints.xy.ndim == 3 and keypoints.xy.shape[2] >= 2:
            all_keypoints.append(keypoints.xy[0].cpu().numpy())

            # キーポイントを描画
            for point in keypoints.xy[0]:
                x, y = point[:2]
                cv2.circle(annotated_frame, (int(x), int(y)), 5, (0, 255, 0), -1)

    frame_count += 1

    # 処理したフレームを出力ビデオに書き込む
    out.write(annotated_frame)

# キャプチャとビデオライターを解放
cap.release()
out.release()
cv2.destroyAllWindows()

# キーポイントデータをNumPy配列として保存
if all_keypoints:
    keypoints_array = np.array(all_keypoints)
    np.save('keypoints_data.npy', keypoints_array)
    print(f"Processed {frame_count} frames")
    print(f"Keypoints data shape: {keypoints_array.shape}")

    # スイングの分類を行う
    swing_class = classify_swing(keypoints_array)

    # 分類結果に基づいてビデオを再処理
    cap = cv2.VideoCapture(input_video)
    out = cv2.VideoWriter('classified_' + output_video, fourcc, fps, (width, height))

    frame_count = 0
    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            break

        results = model(frame)
        annotated_frame = results[0].plot()

        # スイングクラスに基づいてテキストを追加
        if swing_class == 0:
            cv2.putText(annotated_frame, "Swing Start", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        elif swing_class == 1:
            cv2.putText(annotated_frame, "Swing In Progress", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        elif swing_class == 2:
            cv2.putText(annotated_frame, "Swing End", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

        out.write(annotated_frame)
        frame_count += 1

    cap.release()
    out.release()

    print(f"Classified video saved as classified_{output_video}")
else:
    print("No valid keypoints data was collected.")


0: 384x640 1 person, 12.2ms
Speed: 1.9ms preprocess, 12.2ms inference, 2.3ms postprocess per image at shape (1, 3, 384, 640)


TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [None]:
def is_swing_start(keypoints_sequence):
    if keypoints_sequence is not None and len(keypoints_sequence) == 35:
        keypoints_sequence = np.array(keypoints_sequence)
        keypoints_sequence = np.expand_dims(keypoints_sequence, axis=0)
        prediction = swing_start_model.predict(keypoints_sequence)
        return prediction[0][0] > 0.5  # 閾値は適宜調整してください
    return False

def classify_swing(keypoints_sequence):
    if keypoints_sequence is not None and len(keypoints_sequence) == 30:
        keypoints_sequence = np.array(keypoints_sequence)
        keypoints_sequence = np.expand_dims(keypoints_sequence, axis=0)
        prediction = swing_classification_model.predict(keypoints_sequence)
        return np.argmax(prediction)
    return None

# 動画ファイルの読み込み
cap = cv2.VideoCapture('path/to/your/video.mp4')

frame_buffer = deque(maxlen=35)
swing_frames = deque(maxlen=30)
is_recording_swing = False

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    keypoints = extract_keypoints(frame)
    normalized_keypoints = normalize_keypoints(keypoints)

    if normalized_keypoints is not None:
        frame_buffer.append(normalized_keypoints)

        if len(frame_buffer) == 35:
            if is_swing_start(list(frame_buffer)):
                is_recording_swing = True
                swing_frames.clear()
                swing_frames.extend(list(frame_buffer)[-30:])  # 最後の30フレームを使用

        if is_recording_swing:
            swing_frames.append(normalized_keypoints)

            if len(swing_frames) == 30:
                swing_class = classify_swing(list(swing_frames))
                if swing_class is not None:
                    label = shot_types[swing_class]
                    cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

                is_recording_swing = False

    cv2.imshow('Swing Classification', frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

    # 3フレームごとにframe_bufferの先頭を削除
    if len(frame_buffer) == 35:
        for _ in range(3):
            if frame_buffer:
                frame_buffer.popleft()

cap.release()
cv2.destroyAllWindows()

In [None]:
all_data = []
for label, shot_type in enumerate(shot_types):
    folder_path = f'/Users/yusuke.s/Documents/pickleball_videos/{shot_type}'
    all_data.extend(process_video(folder_path, label))