In [2]:
pip install ultralytics

Collecting ultralytics
  Downloading ultralytics-8.2.72-py3-none-any.whl.metadata (41 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/41.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.3/41.3 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.0-py3-none-any.whl.metadata (8.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu1

In [1]:
import cv2
import numpy as np
import torch
from collections import deque
from ultralytics import YOLO
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import Sequence
from google.colab import drive
drive.mount('/content/drive')

ModuleNotFoundError: No module named 'ultralytics'

In [None]:
# Load the YOLOv8 pose-estimation weights and the Keras swing-detection model
# stored on the mounted Google Drive.
model = YOLO('yolov8n-pose.pt')
swing_detect_model = load_model('/content/drive/MyDrive/swing_detect.h5')
# The shot-type classifier is currently disabled (left commented out).
#shot_class_model =load_model('/content/drive/MyDrive/swing_class_100frames.h5')

In [None]:
# Class names for the swing-detection model. The processing loop treats
# predicted index 2 ('swing_end') as "not swinging".
# NOTE(review): presumably listed in the model's output order — confirm against training code.
swing_detect_class=['swing_begin','swing_middle','swing_end']
# Shot-type labels for the (currently disabled) shot classifier.
shot_class = ['forehand_stroke', 'forehand_slice', 'forehand_volley', 'backhand_stroke', 'backhand_volley', 'backhand_slice']

In [None]:
def preprocess_frame(frame):
	"""Convert one BGR frame into a float32 array of shape (1, 3, 640, 640).

	The frame is resized to 640x640, converted from BGR to RGB, scaled to
	the [0, 1] range and rearranged from height-width-channel order into a
	single-item batch in channel-first order.
	"""
	resized = cv2.resize(frame, (640, 640))
	rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
	scaled = rgb.astype(np.float32) / 255.0
	# Prepend the batch axis, then move channels ahead of height/width.
	batched = scaled[np.newaxis, ...]
	return np.transpose(batched, (0, 3, 1, 2))

def extract_keypoints(frames):
	"""Run pose estimation on each frame and keep the largest person's keypoints.

	Args:
		frames: iterable of BGR frames (list, deque, ...). A single frame
			passed as a 3-D ndarray is also accepted and is treated as a
			one-frame sequence instead of being iterated row by row.

	Returns:
		(all_keypoints, frames): ``all_keypoints`` holds one keypoint array per
		frame in which a person was detected, expressed in the 640x640 input
		space produced by ``preprocess_frame``. Frames without a detection are
		skipped, so the list can be shorter than ``frames``. ``frames`` is the
		(possibly wrapped) input sequence, returned unchanged.
	"""
	# A bare frame is (H, W, C); wrap it so the loop below sees one frame.
	if isinstance(frames, np.ndarray) and frames.ndim == 3:
		frames = [frames]

	all_keypoints = []
	# Fall back to the CPU so the notebook also runs on runtimes without a GPU.
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
	for frame in frames:
		frame = preprocess_frame(frame)
		frame = torch.tensor(frame, dtype=torch.float32).to(device)
		results = model(frame, verbose=False)

		if len(results) == 0:
			continue
		keypoints_list = results[0].keypoints
		bboxes = results[0].boxes
		if keypoints_list is None or len(keypoints_list) == 0:
			continue
		if bboxes is None or len(bboxes) == 0:
			continue

		try:
			# The largest bounding box is taken to be the person closest to the camera.
			xyxy = bboxes.xyxy.cpu().numpy()
			areas = [(box[2] - box[0]) * (box[3] - box[1]) for box in xyxy]
		except IndexError as e:
			print(f"Error calculating areas: {e}")
			print(f"Bounding boxes: {bboxes}")
			continue
		max_area_index = np.argmax(areas)
		# Keep only the keypoints that belong to the largest box.
		keypoints = keypoints_list[max_area_index].xy[0].cpu().numpy()
		all_keypoints.append(keypoints)
	return all_keypoints, frames

def normalize_keypoints(keypoints):
	"""Express every frame's keypoints relative to the left hip, in torso units.

	For each frame the left hip (index 11) becomes the origin and the
	coordinates are divided by the hip-to-shoulder (index 5) distance, unless
	that distance is zero. Frames that do not contain enough keypoints are
	reported and left out.

	Returns:
		np.ndarray stacking the normalized frames that were kept.
	"""
	left_hip, left_shoulder = 11, 5
	highest_index = max(left_hip, left_shoulder)

	kept = []
	for pts in keypoints:
		if len(pts) <= highest_index:
			print(f"Warning: Frame with insufficient keypoints detected. Skipping this frame.")
			continue
		origin = pts[left_hip]
		# Translate so the hip sits at (0, 0).
		centered = pts - origin
		# Hip-to-shoulder length is the unit of measure.
		torso_length = np.linalg.norm(pts[left_shoulder] - origin)
		if torso_length != 0:
			centered /= torso_length
		kept.append(centered)

	return np.array(kept)


In [None]:
def rescale_keypoints(keypoints, input_size, output_size):
  """Map (x, y) points from one image size to another.

  Args:
    keypoints: iterable of points; element 0 is x and element 1 is y.
    input_size: (width, height) the points are currently expressed in.
    output_size: (width, height) to convert the points to.

  Returns:
    A list of [x, y] pairs in output coordinates.
  """
  x_ratio = output_size[0] / input_size[0]
  y_ratio = output_size[1] / input_size[1]
  return [[point[0] * x_ratio, point[1] * y_ratio] for point in keypoints]

def draw_keypoints(frame, original_keypoints, color=(0, 255, 0), radius=5):
	"""Draw a filled circle on ``frame`` for every usable keypoint.

	Entries that are None or that do not have exactly two components are
	ignored. The frame is modified in place and also returned.
	"""
	for point in original_keypoints:
		if point is None or len(point) != 2:
			continue
		center = (int(point[0]), int(point[1]))
		# Thickness -1 asks OpenCV for a filled circle.
		cv2.circle(frame, center, radius, color, -1)
	return frame

def draw_rec(all_keypoints,frames):
	"""Return copies of ``frames`` with their keypoints drawn as green dots.

	Args:
		all_keypoints: one collection of (x, y, ...) points per frame, in the
			same order as ``frames``.
		frames: frames to annotate; each must provide ``.copy()``.

	Returns:
		A list of annotated copies, one per (keypoints, frame) pair. The
		input frames are left untouched.
	"""
	annotated_frames = []
	# Pair each frame with its own keypoints; extra items on either side are ignored.
	for frame_keypoints, frame in zip(all_keypoints, frames):
		annotated_frame = frame.copy()
		for point in frame_keypoints:
			x, y = point[:2]
			cv2.circle(annotated_frame, (int(x), int(y)), 5, (0, 255, 0), -1)
		annotated_frames.append(annotated_frame)
	return annotated_frames

def add_text(new_frame):
  """Write the "Swing" label at (50, 50) on the frame and return the result."""
  label_origin = (50, 50)
  label_color = (0, 0, 255)
  return cv2.putText(new_frame, "Swing", label_origin, cv2.FONT_HERSHEY_SIMPLEX, 1, label_color, 2)




In [None]:
class DataGenerator(Sequence):
	"""Serve an in-memory array to Keras in fixed-size batches (inputs only).

	Only inputs are yielded, so this is suitable for ``model.predict`` but
	not for training.
	"""

	def __init__(self, x_set, batch_size, **kwargs):
		"""Store the data and batch size.

		Args:
			x_set: indexable/sliceable collection of samples.
			batch_size: number of samples per batch.
			**kwargs: forwarded to the Keras base class.
		"""
		# Newer Keras versions expect the base initializer to run and warn
		# when it is skipped.
		super().__init__(**kwargs)
		self.x = x_set
		self.batch_size = batch_size

	def __len__(self):
		"""Number of batches, counting a final partial batch."""
		return int(np.ceil(len(self.x) / float(self.batch_size)))

	def __getitem__(self, idx):
		"""Return the ``idx``-th batch of samples."""
		start = idx * self.batch_size
		return self.x[start:start + self.batch_size]

In [None]:
# Input video to analyse, read from the mounted Google Drive.
video_path = "/content/drive/MyDrive/test_30s.mp4"

In [None]:
from google.colab.patches import cv2_imshow

In [None]:
import cv2
import numpy as np
from collections import deque

# Sliding-window swing detection over the input video.
#
# Every input frame is written to output.mp4 exactly once, in arrival order,
# with the detected pose keypoints drawn on it. Frames that arrive while a
# swing is in progress are held back in `swing_record` and written with a
# "Swing" label as soon as the detector reports that the swing is over.

WINDOW_SIZE = 40    # frames fed to the swing detector at once
WINDOW_STRIDE = 2   # oldest frames dropped from the window after each prediction
NOT_SWINGING = 2    # index of 'swing_end' in swing_detect_class

cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
	raise RuntimeError(f"Could not open video: {video_path}")

frame_buffer = deque(maxlen=WINDOW_SIZE)
is_recording_swing = False
swing_record = []  # (frame, keypoints-or-None) pairs of the swing in progress

# VideoWriter setup
fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # or another suitable codec
fps = cap.get(cv2.CAP_PROP_FPS)
width, height = 640, 640  # size the pose model sees (see preprocess_frame)
# Use the real video size for the output; fall back to 1280x720 if it is unknown.
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 1280
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 720
out = cv2.VideoWriter('output.mp4', fourcc, fps, (original_width, original_height))

if not out.isOpened():
	cap.release()
	# Raise instead of calling exit(), which would shut the notebook kernel down.
	raise RuntimeError("Error: VideoWriter not opened.")


def _write_annotated(frame, frame_keypoints, with_label=False):
	"""Write one frame to the output video with optional keypoints and "Swing" label."""
	# Resizing also produces a fresh array, so drawing never touches the
	# buffered frame that is re-used later for pose estimation.
	out_frame = cv2.resize(frame, (original_width, original_height))
	if frame_keypoints is not None:
		scaled_keypoints = rescale_keypoints(frame_keypoints, (width, height), (original_width, original_height))
		out_frame = draw_keypoints(out_frame, scaled_keypoints, color=(0, 255, 0), radius=5)
	if with_label:
		out_frame = add_text(out_frame)
	out.write(out_frame)


def _flush_swing_record():
	"""Write every held-back swing frame with the "Swing" label, then clear the record."""
	for recorded_frame, recorded_keypoints in swing_record:
		_write_annotated(recorded_frame, recorded_keypoints, with_label=True)
	swing_record.clear()


try:
	while cap.isOpened():
		ret, frame = cap.read()
		if not ret:
			break

		frame_buffer.append(frame)

		if len(frame_buffer) < WINDOW_SIZE:
			# Window not full yet: only estimate the pose of the newest frame.
			keypoints, _ = extract_keypoints([frame])
			newest_keypoints = keypoints[0] if len(keypoints) > 0 else None
		else:
			keypoints, frames = extract_keypoints(frame_buffer)
			# Frames without a detection are skipped by extract_keypoints, so the
			# last entry belongs to the newest frame only when nothing was skipped.
			newest_keypoints = keypoints[-1] if len(keypoints) == len(frame_buffer) else None

			if len(keypoints) > 0:
				normalized_keypoints = normalize_keypoints(keypoints)
				X_array = np.array(normalized_keypoints)
				if X_array.ndim == 3:
					# (frames, keypoints, 2) -> (1, frames, keypoints * 2)
					X_array = X_array.reshape((1, X_array.shape[0], X_array.shape[1] * X_array.shape[2]))
					generator = DataGenerator(X_array, batch_size=1)
					predictions = swing_detect_model.predict(generator)
					swing_prediction = np.argmax(predictions, axis=1)
					if swing_prediction[0] != NOT_SWINGING:  # swing in progress
						is_recording_swing = True
						print("swinging")
					else:  # not swinging
						is_recording_swing = False
			# With no usable keypoints the previous swing state is kept as is.

			# Slide the window forward.
			for _ in range(WINDOW_STRIDE):
				if frame_buffer:
					frame_buffer.popleft()

		if is_recording_swing:
			# Hold the frame back until the swing ends so it can be labelled.
			swing_record.append((frame, newest_keypoints))
		else:
			# A swing that just ended is written first to keep frames in order.
			_flush_swing_record()
			_write_annotated(frame, newest_keypoints)

	# The video may end in the middle of a swing.
	_flush_swing_record()
finally:
	# Release resources even if processing fails part-way through.
	cap.release()
	out.release()
	cv2.destroyAllWindows()



In [None]:
# Copy /content/output.mp4 into the mounted Google Drive folder.
!cp /content/output.mp4 /content/drive/MyDrive/

In [None]:
def is_swing_start(keypoints_sequence):
    """Tell whether a 35-frame keypoint sequence is predicted to start a swing.

    Returns False without calling the model when the sequence is None or
    does not contain exactly 35 entries. Otherwise the sequence is sent to
    ``swing_start_model`` as a batch of one and the first output value is
    compared against 0.5.
    """
    if keypoints_sequence is None or len(keypoints_sequence) != 35:
        return False
    batch = np.expand_dims(np.array(keypoints_sequence), axis=0)
    prediction = swing_start_model.predict(batch)
    # Adjust the threshold as needed.
    return prediction[0][0] > 0.5

def classify_swing(keypoints_sequence):
    """Return the predicted class index for a 30-frame keypoint sequence.

    Returns None without calling the model when the sequence is None or
    does not contain exactly 30 entries. Otherwise the sequence is sent to
    ``swing_classification_model`` as a batch of one and the arg-max of the
    prediction is returned.
    """
    if keypoints_sequence is None or len(keypoints_sequence) != 30:
        return None
    batch = np.expand_dims(np.array(keypoints_sequence), axis=0)
    prediction = swing_classification_model.predict(batch)
    return np.argmax(prediction)

# Frame-by-frame swing-start detection and shot classification with a live preview.
# NOTE(review): relies on `swing_start_model` and `swing_classification_model`,
# which are not loaded in the visible cells, and on a placeholder video path.

# Load the video file
cap = cv2.VideoCapture('path/to/your/video.mp4')

frame_buffer = deque(maxlen=35)
swing_frames = deque(maxlen=30)
is_recording_swing = False

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # extract_keypoints takes a sequence of frames and returns
        # (keypoints, frames); keypoints is empty when nobody was detected.
        keypoints, _ = extract_keypoints([frame])
        normalized = normalize_keypoints(keypoints)
        # One flat feature vector per frame, mirroring the (frames, keypoints * 2)
        # layout used for the swing-detection model.
        # NOTE(review): confirm this matches the input shape of both models.
        normalized_keypoints = normalized[0].reshape(-1) if len(normalized) > 0 else None

        if normalized_keypoints is not None:
            frame_buffer.append(normalized_keypoints)

            if len(frame_buffer) == 35:
                if is_swing_start(list(frame_buffer)):
                    is_recording_swing = True
                    swing_frames.clear()
                    swing_frames.extend(list(frame_buffer)[-30:])  # use the last 30 frames

            if is_recording_swing:
                swing_frames.append(normalized_keypoints)

                if len(swing_frames) == 30:
                    swing_class = classify_swing(list(swing_frames))
                    if swing_class is not None:
                        # shot_class is the label list defined in this notebook.
                        label = shot_class[swing_class]
                        cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

                    is_recording_swing = False

        # NOTE(review): cv2.imshow needs a local display; it is not available in Colab.
        cv2.imshow('Swing Classification', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

        # Once the window is full, drop its 3 oldest entries.
        if len(frame_buffer) == 35:
            for _ in range(3):
                if frame_buffer:
                    frame_buffer.popleft()
finally:
    # Release the capture and close the preview window even on failure.
    cap.release()
    cv2.destroyAllWindows()

In [None]:
# Collect the output of process_video for every shot type, using the position
# of the shot type in the list as its label.
# NOTE(review): `shot_types` and `process_video` are not defined in the visible
# cells, and the folder is an absolute local path — confirm before running.
all_data = []
for label, shot_type in enumerate(shot_types):
    folder_path = f'/Users/yusuke.s/Documents/pickleball_videos/{shot_type}'
    all_data.extend(process_video(folder_path, label))