In [None]:
pip install ultralytics

In [None]:
import cv2
import numpy as np
import torch
from collections import deque
from ultralytics import YOLO
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import Sequence
from google.colab import drive
from google.colab.patches import cv2_imshow
# Mount Google Drive at /content/drive (presumably where the model weights and video live).
drive.mount(mountpoint='/content/drive')

In [None]:
# Pose estimator plus the two Keras sequence models: the swing detector whose
# argmax is compared against the 'swing_end' index, and the shot classifier.
model = YOLO('yolov8n-pose.pt')
swing_detect_model, shot_class_model = (
    load_model(weights_path)
    for weights_path in ('path/swing_detect.h5', 'path/swing_class.h5')
)

In [None]:
# Index -> label tables for the two Keras models' outputs
# (presumably in the same order the models were trained with).
swing_detect_class = [
    'swing_begin',
    'swing_middle',
    'swing_end',
]
shot_class = [
    'forehand_stroke',
    'forehand_slice',
    'forehand_volley',
    'backhand_stroke',
    'backhand_volley',
    'backhand_slice',
]

In [None]:
# Input clip to analyse (placeholder path; point it at the real video).
video_path = 'path/video.mp4'

In [None]:
def add_text (new_frame):
  """Write the caption "Swing" in red at pixel (50, 50) of ``new_frame``.

  cv2.putText draws onto the array it is given, so ``new_frame`` itself is
  modified; the same image is returned for convenience.
  """
  red_bgr = (0, 0, 255)  # OpenCV colours are BGR
  return cv2.putText(new_frame, "Swing", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, red_bgr, 2)

def extract_keypoints(results):
    """Return the keypoints of the most prominent person in a YOLO pose result.

    "Most prominent" means the detection whose bounding box has the largest area.

    Args:
        results: output of calling the YOLO pose model on one frame; only
            ``results[0].keypoints`` and ``results[0].boxes`` are read.

    Returns:
        A list with zero or one entry. When a person is found the entry is
        ``keypoints[i].xy[0]`` converted to a NumPy array (presumably shape
        (17, 2) for YOLOv8-pose; confirm). Otherwise a warning is printed and
        the list is empty.
    """
    selected = []
    if not (len(results) > 0 and len(results[0].keypoints) > 0):
        print(f"Warning: No keypoints found in frame.")
        return selected

    people = results[0].keypoints
    boxes = results[0].boxes
    if boxes is None or len(boxes) == 0:
        print(f"Warning: No bounding boxes found in frame.")
        return selected

    try:
        corners = boxes.xyxy.cpu().numpy()  # one (x1, y1, x2, y2) row per detection
        areas = []
        for box in corners:
            areas.append((box[2] - box[0]) * (box[3] - box[1]))
    except IndexError as e:
        print(f"Error calculating areas: {e}")
        print(f"Bounding boxes: {boxes}")
        return selected

    biggest = np.argmax(areas)
    selected.append(people[biggest].xy[0].cpu().numpy())
    return selected

def process_frame(frame_buffer):
    """Run the global pose ``model`` on every frame in ``frame_buffer``.

    Returns:
        (original_keypoints, plotted_frames):
        ``original_keypoints`` gets the output of ``extract_keypoints`` for each
        frame (the largest person's keypoints, or nothing when no person was
        found), so it can be shorter than the input and is not index-aligned
        with it. ``plotted_frames`` has exactly one entry per input frame: the
        result of ``results[0].plot()`` (the frame with detections drawn).
    """
    keypoint_rows = []
    annotated = []
    for img in frame_buffer:
        detections = model(img)
        keypoint_rows += extract_keypoints(detections)
        annotated.append(detections[0].plot())
    return keypoint_rows, annotated


def normalize_keypoints(keypoints):
    """Express each frame's keypoints relative to the hip, scaled by hip-shoulder distance.

    For every frame, the (x, y) of all keypoints are shifted so keypoint 11 sits
    at the origin, then divided by the distance between keypoints 11 and 5
    (presumably left hip / left shoulder in the COCO order YOLOv8-pose uses).
    A zero distance leaves the shifted points unscaled. The input arrays are not
    modified.

    Frames with 11 or fewer keypoints are skipped with a warning, so the result
    can hold fewer frames than ``keypoints``.

    Returns:
        np.ndarray stacking the per-frame (num_keypoints, 2) results.
    """
    hip_idx, shoulder_idx = 11, 5
    last_needed_idx = max(hip_idx, shoulder_idx)
    frames_out = []
    for pts in keypoints:
        if len(pts) <= last_needed_idx:
            print(f"Warning: Frame with insufficient keypoints detected. Keypoints: {len(pts)}")
            continue
        anchor = pts[hip_idx][:2]
        shifted = pts[:, :2] - anchor  # new array; `pts` is left untouched
        torso_len = np.linalg.norm(pts[shoulder_idx][:2] - anchor)
        if torso_len != 0:
            shifted /= torso_len
        frames_out.append(shifted)
    print(f"Normalized {len(frames_out)} frames out of {len(keypoints)}")
    return np.array(frames_out)


# NOTE: a second, byte-identical `def add_text` stood here and silently
# re-defined the function declared at the top of this cell; it was removed so
# the function is defined exactly once.

def add_shot (new_frame,shot_prediction):
  """Write the predicted shot name in red at pixel (50, 50) of ``new_frame``.

  ``shot_prediction`` is compared against 0..4 in order; the first match picks
  the caption. Anything else (5 or an unexpected value) falls back to
  "Backhand-slice". cv2.putText draws onto ``new_frame`` itself, and the same
  image is returned.
  """
  captions = ("Forehand-stroke", "Forehand-slice", "Forehand-volley",
              "Backhand-stroke", "Backhand-volley")
  caption = "Backhand-slice"  # fallback when no index below matches
  for idx, name in enumerate(captions):
    if shot_prediction == idx:
      caption = name
      break
  return cv2.putText(new_frame, caption, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)


def pad_sequences(sequences, max_length):
    """Truncate or zero-pad each sequence along its first (time) axis to ``max_length``.

    Sequences longer than ``max_length`` are sliced to their first ``max_length``
    entries. Shorter ones are padded with zeros at the end of the first axis only;
    all trailing axes are left as they are. Any rank is accepted: (T, K, 2)
    keypoint stacks, flattened (T, F) features, or 1-D series.

    Args:
        sequences: iterable of array-likes, one per sample.
        max_length: target length of the first axis.

    Returns:
        list with one truncated/padded sequence per input, in input order.
    """
    result = []
    for seq in sequences:
        if len(seq) > max_length:
            result.append(seq[:max_length])
            continue
        # Pad only the time axis; the old hard-coded ((0, n), (0, 0), (0, 0))
        # made np.pad fail for anything that was not exactly 3-D.
        pad_width = [(0, max_length - len(seq))] + [(0, 0)] * (np.ndim(seq) - 1)
        result.append(np.pad(seq, pad_width, 'constant'))
    return result


In [None]:
# Sliding-window swing detection and shot classification over the whole video.
#
# Each frame goes through the pose model exactly once and is cached, together
# with its keypoints, in `frame_buffer`. Whenever the buffer holds WINDOW
# frames, the swing detector sees the whole window and decides what happens to
# its STRIDE newest frames:
#   * no swing          -> written straight to the output video;
#   * swing in progress -> held in swing_frame / swing_record;
#   * swing just ended  -> the held swing is classified, its frames are stamped
#                          with the shot label and written.
# Frames are written as soon as they are final (not collected in memory).
WINDOW = 40          # frames per swing-detector input
STRIDE = 2           # frames the window advances per step
SWING_END = 2        # index of 'swing_end' in swing_detect_class
SHOT_MAX_LEN = 100   # time steps fed to shot_class_model
SHOT_FEATURES = 34   # values per time step (17 keypoints x (x, y))
OUTPUT_PATH = 'federer_output.mp4'


def _classify_and_write_swing(frames, keypoints, writer):
    """Label a finished swing with its predicted shot, write its frames, then clear both lists.

    ``keypoints`` are already normalized per frame, so they go to the classifier
    as-is (zero-padded / truncated to SHOT_MAX_LEN frames). If no keypoints were
    recorded, the frames are written unlabelled.
    """
    to_write = frames
    if keypoints:
        X_padded = pad_sequences([np.array(keypoints)], SHOT_MAX_LEN)
        X_array = np.array(X_padded).reshape((1, SHOT_MAX_LEN, SHOT_FEATURES))
        predictions = shot_class_model.predict(X_array, verbose=0)
        shot_prediction = int(np.argmax(predictions, axis=1)[0])
        print(f"Swing of {len(frames)} frames -> shot class {shot_prediction}")
        to_write = [add_shot(img, shot_prediction) for img in frames]
    for img in to_write:
        writer.write(img)
    frames.clear()
    keypoints.clear()


cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise RuntimeError(f"Error: could not open video {video_path}")

fps = cap.get(cv2.CAP_PROP_FPS)
# The annotated frames are presumably at the input resolution (ultralytics draws
# on the original image), and VideoWriter does not resize, so use the real size
# instead of a hard-coded one.
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(OUTPUT_PATH, fourcc, fps, (original_width, original_height))
if not out.isOpened():
    cap.release()
    # Raise instead of exit(): exit() ends the whole kernel session in Jupyter.
    raise RuntimeError("Error: VideoWriter not opened.")

frame_buffer = deque(maxlen=WINDOW)  # (keypoints list, annotated frame) per frame
swing_frame = []     # annotated frames of the swing being recorded
swing_record = []    # normalized keypoints recorded for that swing
is_recording_swing = False

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_keypoints, plotted = process_frame([frame])
        frame_buffer.append((frame_keypoints, plotted[0]))

        # Warm-up: the first WINDOW - STRIDE frames are never among the newest
        # STRIDE frames of a full window, so emit them directly. After the first
        # window the buffer never drops below WINDOW - STRIDE, so this only fires
        # at the start. (The old `len < 40 & flag == False` test was always False.)
        if len(frame_buffer) <= WINDOW - STRIDE:
            out.write(plotted[0])
            continue
        if len(frame_buffer) < WINDOW:
            continue  # wait for the window to fill up

        newest_frames = [img for _, img in list(frame_buffer)[-STRIDE:]]
        window_keypoints = [kp for kps, _ in frame_buffer for kp in kps]
        normalized_keypoints = None
        if len(window_keypoints) == WINDOW:
            normalized_keypoints = normalize_keypoints(window_keypoints)

        if normalized_keypoints is None or len(normalized_keypoints) != WINDOW:
            # No person in some frame: the keypoints would be shorter than the
            # window and misaligned with the frames, so skip detection this step
            # and keep the current state.
            if is_recording_swing:
                swing_frame.extend(newest_frames)
            else:
                for img in newest_frames:
                    out.write(img)
        else:
            X_array = normalized_keypoints.reshape((1, WINDOW, -1))
            predictions = swing_detect_model.predict(X_array, verbose=0)
            swing_prediction = int(np.argmax(predictions, axis=1)[0])

            if swing_prediction != SWING_END:      # swing starting / in progress
                is_recording_swing = True
                swing_frame.extend(newest_frames)
                swing_record.extend(normalized_keypoints[-STRIDE:])
            elif is_recording_swing:               # swing just finished
                swing_frame.extend(newest_frames)
                swing_record.extend(normalized_keypoints[-STRIDE:])
                is_recording_swing = False
                _classify_and_write_swing(swing_frame, swing_record, out)
            else:                                  # no swing
                for img in newest_frames:
                    out.write(img)

        for _ in range(STRIDE):
            frame_buffer.popleft()

    # Frames read after the last full window never became "newest"; flush them,
    # and finish a swing that was still in progress when the video ended.
    leftover = max(0, len(frame_buffer) - (WINDOW - STRIDE))
    tail = [img for _, img in list(frame_buffer)[len(frame_buffer) - leftover:]]
    if is_recording_swing:
        swing_frame.extend(tail)
        is_recording_swing = False
        _classify_and_write_swing(swing_frame, swing_record, out)
    else:
        for img in tail:
            out.write(img)
finally:
    cap.release()
    out.release()
    cv2.destroyAllWindows()