In [1]:
pip install ultralytics



In [2]:
import cv2
import numpy as np
import torch
from collections import deque
from ultralytics import YOLO
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import Sequence
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# YOLOv8 pose model; extract_keypoints() runs it on every frame to obtain
# per-person keypoints and bounding boxes.
model = YOLO('yolov8n-pose.pt')
# Keras swing-detection model loaded from the mounted Google Drive; the
# processing cell feeds it windows of normalized keypoints.
swing_detect_model = load_model('/content/drive/MyDrive/swing_detect.h5')
# Shot-classification model is currently disabled (kept for later use).
#shot_class_model =load_model('/content/drive/MyDrive/swing_class_100frames.h5')



In [4]:
# Label names, presumably in the output order of swing_detect_model (TODO confirm).
# The processing cell treats predicted index 2 as "not swinging" and any other
# index as "swinging".
swing_detect_class=['swing_begin','swing_middle','swing_end']
# Shot-type label names; not referenced elsewhere in the visible notebook
# (the shot-classification model load is commented out).
shot_class = ['forehand_stroke', 'forehand_slice', 'forehand_volley', 'backhand_stroke', 'backhand_volley', 'backhand_slice']

In [5]:

def preprocess_frame(frame):
    """Convert one BGR image into a float32 batch array for the pose model.

    The image is resized to 640x640, converted from BGR to RGB, scaled to
    the [0, 1] range and rearranged to shape (1, 3, 640, 640).

    Args:
        frame: Image array as accepted by ``cv2.resize`` / ``cv2.cvtColor``
            with ``COLOR_BGR2RGB``.

    Returns:
        ``np.float32`` array laid out as (batch, channel, height, width).
    """
    resized = cv2.resize(frame, (640, 640))
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    scaled = rgb.astype(np.float32) / 255.0
    # HWC -> CHW first, then prepend the batch axis.
    channels_first = np.transpose(scaled, (2, 0, 1))
    return channels_first[np.newaxis, ...]

def extract_keypoints(frames):
    """Run the pose model on each frame and keep the largest person's keypoints.

    For every frame the detection whose bounding box has the largest area is
    selected and its (x, y) keypoints are collected. Coordinates are in the
    640x640 space produced by ``preprocess_frame``.

    Args:
        frames: A single image or a list of images (BGR arrays).

    Returns:
        Tuple ``(all_keypoints, frames)``:
        - ``all_keypoints``: list of numpy arrays, one per frame that had a
          usable detection. Frames without a detection are skipped, so this
          list can be shorter than ``frames`` and is then no longer
          index-aligned with it.
        - ``frames``: the input frames as a list (a single image is wrapped).
    """
    all_keypoints = []
    # Fall back to the CPU so the notebook also runs on runtimes without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if not isinstance(frames, list):
        frames = [frames]

    for frame in frames:
        batch = torch.tensor(preprocess_frame(frame), dtype=torch.float32).to(device)
        results = model(batch, verbose=False)
        if len(results) == 0:
            continue

        keypoints_list = results[0].keypoints
        bboxes = results[0].boxes
        if keypoints_list is None or len(keypoints_list) == 0:
            continue
        if bboxes is None or len(bboxes) == 0:
            continue

        try:
            xyxy = bboxes.xyxy.cpu().numpy()
            # Box area = width * height; pick the biggest detection.
            areas = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
            max_area_index = int(np.argmax(areas))
            keypoints = keypoints_list[max_area_index].xy[0].cpu().numpy()
            all_keypoints.append(keypoints)
        except IndexError as e:
            # Best effort: report the problem and move on to the next frame.
            print(f"Error calculating areas: {e}")
            print(f"Bounding boxes: {bboxes}")
            continue
    return all_keypoints, frames

def normalize_keypoints(keypoints):
    """Express each frame's keypoints relative to the hip, scaled by torso length.

    For every frame the point at index 11 (left hip, per the original note)
    becomes the origin, and coordinates are divided by the distance between
    index 5 (shoulder) and index 11. When that distance is zero the frame is
    only translated, not scaled. Frames with fewer than 12 keypoints are
    skipped with a printed warning.

    Args:
        keypoints: Iterable of per-frame numpy arrays of keypoint coordinates.

    Returns:
        ``np.ndarray`` stacking the normalized frames that were kept.
    """
    HIP, SHOULDER = 11, 5
    required = max(HIP, SHOULDER) + 1

    kept = []
    for pts in keypoints:
        if len(pts) < required:
            print("Warning: Frame with insufficient keypoints detected. Skipping this frame.")
            continue

        origin = pts[HIP]
        centered = pts - origin  # new array; the caller's data is not modified
        torso_len = np.linalg.norm(pts[SHOULDER] - origin)
        if torso_len != 0:
            centered /= torso_len
        kept.append(centered)

    return np.array(kept)

In [6]:
def rescale_keypoints(keypoints, input_size, output_size):
    """Map (x, y) points from one image size to another.

    Args:
        keypoints: Iterable of points; only the first two entries of each
            point (x, y) are used.
        input_size: (width, height) the points are currently expressed in.
        output_size: (width, height) to map the points into.

    Returns:
        List of ``[x, y]`` pairs scaled independently along each axis.
    """
    ratio_x = output_size[0] / input_size[0]
    ratio_y = output_size[1] / input_size[1]
    return [[pt[0] * ratio_x, pt[1] * ratio_y] for pt in keypoints]

def draw_keypoints(frame, original_keypoints, color=(0, 255, 0), radius=5):
    """Draw a filled circle on ``frame`` for every valid keypoint.

    Entries that are ``None`` or that do not have exactly two components are
    ignored. The frame is modified in place and also returned.

    Args:
        frame: Image to draw on.
        original_keypoints: Iterable of (x, y) points in the frame's pixel space.
        color: Circle colour passed to ``cv2.circle``.
        radius: Circle radius in pixels.

    Returns:
        The same ``frame`` object that was passed in.
    """
    drawable = (pt for pt in original_keypoints if pt is not None and len(pt) == 2)
    for pt in drawable:
        center = (int(pt[0]), int(pt[1]))
        cv2.circle(frame, center, radius, color, -1)  # thickness -1 => filled
    return frame

def draw_rec(all_keypoints, frames):
    """Return copies of ``frames`` with their keypoints drawn as green dots.

    The previous version referenced an undefined ``annotated_frame``, ignored
    ``frames`` and returned ``None``. This version pairs each frame with its
    keypoints and returns the annotated copies.

    Args:
        all_keypoints: Per-frame keypoint collections, assumed to be
            index-aligned with ``frames`` (TODO confirm with callers; the
            function is not called in the visible notebook). Each point must
            start with (x, y). Coordinates are drawn as given, without
            rescaling.
        frames: Images to annotate; each must support ``.copy()``.

    Returns:
        List of annotated frame copies, one per (frame, keypoints) pair.
        If the two inputs differ in length, the extra items are ignored.
        The input frames are left untouched.
    """
    annotated_frames = []
    for frame, frame_keypoints in zip(frames, all_keypoints):
        annotated_frame = frame.copy()
        for point in frame_keypoints:
            x, y = point[:2]
            cv2.circle(annotated_frame, (int(x), int(y)), 5, (0, 255, 0), -1)
        annotated_frames.append(annotated_frame)
    return annotated_frames

def add_text(new_frame):
    """Stamp the label "Swing" in red near the top-left corner of the frame.

    Args:
        new_frame: Image to draw on; passed straight to ``cv2.putText``.

    Returns:
        The value returned by ``cv2.putText`` for ``new_frame``.
    """
    return cv2.putText(
        new_frame,
        "Swing",
        (50, 50),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        (0, 0, 255),
        2,
    )






In [7]:
class DataGenerator(Sequence):
    """Serve ``x_set`` to Keras in consecutive batches (inputs only, no labels)."""

    def __init__(self, x_set, batch_size, **kwargs):
        """Store the data and batch size.

        Args:
            x_set: Sliceable collection of samples.
            batch_size: Number of samples per batch.
            **kwargs: Forwarded to the ``Sequence`` base class.
        """
        # Keras 3 warns when a Sequence/PyDataset subclass skips the base
        # constructor, so call it and forward any extra options.
        super().__init__(**kwargs)
        self.x = x_set
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches; the last one may be partial."""
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Return the ``idx``-th slice of ``x_set``, ``batch_size`` items long."""
        start = idx * self.batch_size
        return self.x[start:start + self.batch_size]

In [8]:
# Input video on the mounted Google Drive; opened with cv2.VideoCapture in the processing cell.
video_path = "/content/drive/MyDrive/test_30s.mp4"

In [9]:
from google.colab.patches import cv2_imshow

In [21]:
!cp /content/output.mp4 /content/drive/MyDrive/

In [20]:
import cv2
import numpy as np
import torch
from collections import deque



# Swing detection over a sliding window of frames.
#
# Each new frame is appended to a fixed-size buffer. Until the first full
# window is available, frames are annotated with pose keypoints and written
# straight to the output video. After that, every STRIDE frames the whole
# window is sent to the swing detector; the newest STRIDE frames are either
# written immediately (no swing) or held back in swing_record until the swing
# ends, at which point they are written with a "Swing" label.
#
# NOTE(review): the `!cp /content/output.mp4 ...` cell sits above this cell in
# the notebook; it has to run after this cell for the copy to pick up the file.

WINDOW_SIZE = 40               # frames per swing-detector window
STRIDE = 2                     # frames the window advances between predictions
MODEL_INPUT_SIZE = (640, 640)  # size preprocess_frame resizes to; keypoints use this space
SWING_END_CLASS = 2            # predicted index treated as "not swinging"


def _write_annotated(writer, frames, keypoints_per_frame, frame_size, label_swing=False):
    """Draw keypoints (and optionally the "Swing" label) on frames and write them."""
    for annotated, kp in zip(frames, keypoints_per_frame):
        scaled_kp = rescale_keypoints(kp, MODEL_INPUT_SIZE, frame_size)
        annotated = draw_keypoints(annotated, scaled_kp, color=(0, 255, 0), radius=5)
        if label_swing:
            annotated = add_text(annotated)
        writer.write(annotated)


cap = cv2.VideoCapture(video_path)
frame_buffer = deque(maxlen=WINDOW_SIZE)
is_recording_swing = False
swing_record = []      # frames held back while a swing is in progress
swing_keypoints = []   # keypoints for swing_record, kept index-aligned with it
flag = False           # becomes True once the first full window has been processed

# VideoWriter setup
fourcc = cv2.VideoWriter_fourcc(*'MP4V')
fps = cap.get(cv2.CAP_PROP_FPS)
print(fps)
width, height = MODEL_INPUT_SIZE
# Take the output size from the video itself; VideoWriter needs frames that
# match the size it was opened with.
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
frame_size = (original_width, original_height)
print(f"Frame size: {original_width}x{original_height}")
out = cv2.VideoWriter('output.mp4', fourcc, fps, frame_size)

try:
    # Raise instead of calling exit(), which would shut down the notebook kernel.
    if not cap.isOpened():
        raise RuntimeError(f"Error: could not open video: {video_path}")
    if not out.isOpened():
        raise RuntimeError("Error: VideoWriter not opened.")

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        frame_buffer.append(frame)

        # Warm-up: before the first full window, write frames as they arrive.
        # The old test `len(frame_buffer) < 40 & flag == False` parsed as
        # `len(frame_buffer) < (40 & flag) == False` and was never true.
        # The last STRIDE frames of the first window are written by the branch
        # below, so the warm-up stops STRIDE frames early to avoid duplicates.
        if not flag and len(frame_buffer) <= WINDOW_SIZE - STRIDE:
            keypoints, _ = extract_keypoints([frame])
            if keypoints:
                # Draw on a copy so the buffered frame stays clean for the pose model.
                _write_annotated(out, [frame.copy()], keypoints[:1], frame_size)

        elif len(frame_buffer) == WINDOW_SIZE:
            flag = True
            keypoints, frames = extract_keypoints(list(frame_buffer))
            if len(keypoints) > 0:
                normalized_keypoints = normalize_keypoints(keypoints)
                X_array = np.array(normalized_keypoints)
                # (frames, points, 2) -> (1, frames, points * 2)
                X_array = X_array.reshape((1, X_array.shape[0], X_array.shape[1] * X_array.shape[2]))

                # Pass X_array directly to predict() without using DataGenerator
                predictions = swing_detect_model.predict(X_array)
                swing_prediction = np.argmax(predictions, axis=1)

                # Only the newest STRIDE frames are new to this window.
                # cv2.resize also yields fresh arrays, so drawing on them does
                # not touch the frames still held in frame_buffer.
                # NOTE(review): extract_keypoints skips frames without a
                # detection, so keypoints[-STRIDE:] is only aligned with
                # frames[-STRIDE:] when every frame had one.
                newest_frames = [cv2.resize(f, frame_size) for f in list(frames)[-STRIDE:]]
                newest_keypoints = keypoints[-STRIDE:]

                if swing_prediction[0] != SWING_END_CLASS:   # swinging
                    is_recording_swing = True
                    print("swinging")
                    swing_record.extend(newest_frames)
                    swing_keypoints.extend(newest_keypoints)

                elif is_recording_swing:   # swing has just ended
                    swing_record.extend(newest_frames)
                    swing_keypoints.extend(newest_keypoints)
                    is_recording_swing = False
                    # Keypoints are stored with their frames because the
                    # current window only covers the last WINDOW_SIZE frames
                    # and a swing can be longer than that.
                    # The shot-classification model can be inserted here
                    # (with padding) before the frames are written.
                    _write_annotated(out, swing_record, swing_keypoints, frame_size, label_swing=True)
                    swing_record.clear()
                    swing_keypoints.clear()

                else:   # idle, no swing in progress
                    is_recording_swing = False
                    # Writing every frame of the window here instead of the
                    # newest STRIDE was noted as worth trying.
                    _write_annotated(out, newest_frames, newest_keypoints, frame_size)

            # Advance the window by STRIDE frames.
            for _ in range(STRIDE):
                if frame_buffer:
                    frame_buffer.popleft()

    # The video ended mid-swing: write the held-back frames instead of dropping them.
    if swing_record:
        _write_annotated(out, swing_record, swing_keypoints, frame_size, label_swing=True)
        swing_record.clear()
        swing_keypoints.clear()
finally:
    # Release resources even if an error interrupted processing.
    cap.release()
    out.release()
    cv2.destroyAllWindows()

29.97002997002997
Frame size: 1280x720
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
swinging
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
swinging
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
swinging
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
swinging
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m