In [1]:
import os
import cv2
import numpy as np
import mediapipe as mp
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.models import Sequential
from tqdm import tqdm

2025-04-06 21:32:42.561383: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# NOTE: TensorFlow was already imported in the first cell, and its native
# logging had already emitted output at that point (see the cpu_feature_guard
# message above). Setting TF_CPP_MIN_LOG_LEVEL this late therefore most likely
# does not silence C++-side messages; to do that, set it before
# `import tensorflow` (or in the environment that launches the kernel).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Hide INFO/WARNING logs
# Python-side TensorFlow logging can still be adjusted after import.
tf.get_logger().setLevel('ERROR')

In [3]:
# Model input geometry: every sample is a fixed-length window of per-frame
# pose vectors (NUM_KEYPOINTS landmarks x (x, y, visibility)).
SEQUENCE_LENGTH = 15  # 0.5s clips at 30fps
NUM_KEYPOINTS = 33    # MediaPipe pose points

# Stroke labels. A class's position in this list is its integer label
# (used both when building the dataset and when decoding predictions).
CLASSES = [
    'backhand_drive',
    'backhand_net_shot',
    'forehand_clear',
    'forehand_drive',
    'forehand_lift',
    'forehand_net_shot',
]

In [4]:
# Initialize MediaPipe Pose
# Module-level alias for MediaPipe's Pose solution; both the offline keypoint
# extractor and the live classifier construct `mp_pose.Pose(...)` detectors
# from it.
mp_pose = mp.solutions.pose

In [5]:
def process_keypoints(keypoints, hip_index=23):
    """Translate pose keypoints so they are relative to a hip landmark.

    Each frame is a flat vector of ``(x, y, visibility)`` triplets, one triplet
    per landmark. For every frame, the x/y of landmark ``hip_index`` is
    subtracted from every landmark's x/y; visibility values are untouched.
    Only a single hip landmark is used as the origin (index 23 by default),
    not the midpoint of both hips. Frames that are all zeros (how frames
    without a detection are recorded) stay all zeros, since their offset is 0.

    Args:
        keypoints: array-like whose LAST axis holds the flattened triplets,
            e.g. ``(frames, 3 * landmarks)`` for one clip or
            ``(batch, frames, 3 * landmarks)`` for a batch of clips.
            The input is not modified.
        hip_index: landmark index used as the origin.

    Returns:
        A new float array with the same shape as ``keypoints``.
    """
    # Copy (never mutate the caller's data) and force float so the in-place
    # subtraction of float offsets is valid even for an all-integer input.
    keypoints = np.array(keypoints, dtype=float)

    # Index the last axis with `...` so the same code handles a single clip
    # (2-D) and a batch (3-D); `[:, ...]` would slice the time axis of a batch.
    # Copy the hip coordinates so the in-place updates below can't alias them.
    start = hip_index * 3
    hip_xy = keypoints[..., start:start + 2].copy()

    keypoints[..., 0::3] -= hip_xy[..., 0:1]  # X
    keypoints[..., 1::3] -= hip_xy[..., 1:2]  # Y
    return keypoints

In [6]:
def extract_keypoints(video_path):
    """Extract hip-normalized pose keypoints from the start of a video.

    Runs MediaPipe Pose frame by frame and records, per frame, a flat list of
    ``(x, y, visibility)`` for every detected landmark (NUM_KEYPOINTS of them
    are assumed). Frames with no detected pose are recorded as all zeros so
    the clip's timing is preserved.

    Only the first SEQUENCE_LENGTH frames are used, so decoding stops as soon
    as that many frames have been read.

    Args:
        video_path: anything accepted by ``cv2.VideoCapture`` (e.g. a file path).

    Returns:
        Array of shape ``(n, NUM_KEYPOINTS * 3)`` with ``n <= SEQUENCE_LENGTH``.
        ``n`` is 0 when the video could not be opened or had no frames.
    """
    cap = cv2.VideoCapture(video_path)
    keypoint_sequence = []

    try:
        with mp_pose.Pose(
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5) as pose:

            # Frames beyond SEQUENCE_LENGTH would be discarded; don't decode them.
            while cap.isOpened() and len(keypoint_sequence) < SEQUENCE_LENGTH:
                ret, frame = cap.read()
                if not ret:
                    break

                # MediaPipe expects RGB; OpenCV decodes to BGR.
                results = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

                if results.pose_landmarks:
                    kps = []
                    for landmark in results.pose_landmarks.landmark:
                        kps.extend([landmark.x, landmark.y, landmark.visibility])
                    keypoint_sequence.append(kps)
                else:
                    keypoint_sequence.append([0.0] * (NUM_KEYPOINTS * 3))  # Zero-pad missing frames
    finally:
        # Release the decoder even if pose processing raises.
        cap.release()

    if not keypoint_sequence:
        # Unreadable path, unsupported file or empty video: return a correctly
        # shaped empty array instead of failing inside normalization.
        return np.empty((0, NUM_KEYPOINTS * 3))

    return process_keypoints(keypoint_sequence)

In [7]:
def build_dataset(dataset_path):
    """Build ``(X, y)`` from a directory containing one sub-folder per class.

    Expects ``dataset_path/<class_name>/`` for every name in CLASSES. A clip's
    label is the index of its folder's name in CLASSES. Every regular,
    non-hidden file is passed through ``extract_keypoints``. Clips yielding
    fewer than SEQUENCE_LENGTH frames are skipped silently.

    Args:
        dataset_path: root directory holding the class folders.

    Returns:
        ``X`` of shape ``(N, SEQUENCE_LENGTH, NUM_KEYPOINTS * 3)`` and ``y`` of
        shape ``(N,)`` with integer class indices (``N`` may be 0).

    Raises:
        FileNotFoundError: if a class folder is missing (from ``os.listdir``).
    """
    X, y = [], []

    for class_idx, class_name in enumerate(CLASSES):
        class_dir = os.path.join(dataset_path, class_name)
        # os.listdir order is arbitrary; sort so the sample order (and thus any
        # seeded train/test split) is reproducible. Skip sub-directories and
        # hidden files such as .DS_Store, which are not videos.
        videos = sorted(
            name for name in os.listdir(class_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(class_dir, name))
        )

        for video in tqdm(videos, desc=f'Processing {class_name}'):
            sequence = extract_keypoints(os.path.join(class_dir, video))

            if len(sequence) == SEQUENCE_LENGTH:
                X.append(sequence)
                y.append(class_idx)

    if not X:
        # Keep the documented shapes even when no clip qualified.
        return (np.empty((0, SEQUENCE_LENGTH, NUM_KEYPOINTS * 3)),
                np.empty((0,), dtype=int))

    return np.array(X), np.array(y)

In [8]:
# # Build dataset
# X, y = build_dataset('/kaggle/input/video-dataset/vdataset')

# # Train-test split
# X_train, X_test, y_train, y_test = train_test_split(
#     X, y, test_size=0.2, stratify=y)

# # Convert labels to one-hot
# y_train = tf.keras.utils.to_categorical(y_train)
# y_test = tf.keras.utils.to_categorical(y_test)

In [9]:
# # Build LSTM model
# model = Sequential([
#     LSTM(64, input_shape=(SEQUENCE_LENGTH, NUM_KEYPOINTS*3), return_sequences=True),
#     Dropout(0.3),
#     LSTM(32),
#     Dense(32, activation='relu'),
#     Dense(len(CLASSES), activation='softmax')
# ])

In [10]:
# model.compile(
#     optimizer='adam',
#     loss='categorical_crossentropy',
#     metrics=['accuracy']
# )

In [11]:
# history = model.fit(
#     X_train, y_train,
#     validation_data=(X_test, y_test),
#     epochs=20,
#     batch_size=32,
#     callbacks=[
#         tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
#         tf.keras.callbacks.ModelCheckpoint(
#             'pose_model.keras',
#             save_best_only=True,
#             monitor='val_accuracy'
#         )
#     ]
# )

In [12]:
# # Evaluation
# loss, accuracy = model.evaluate(X_test, y_test)
# print(f"Test Accuracy: {accuracy*100:.2f}%")

In [None]:
class PoseClassifier:
    """Live stroke classifier over a sliding window of pose keypoints.

    Each ``classify`` call runs MediaPipe Pose on one BGR frame. It appends
    the frame's flattened ``(x, y, visibility)`` landmarks (zeros when no pose
    is found) to a rolling window. Once the window holds ``sequence_length``
    frames, it runs the Keras model on the hip-normalized window.

    The MediaPipe detector stays open for the object's lifetime; call
    ``close()`` or use the instance as a context manager to release it.
    """

    def __init__(self, model_path, sequence_length=15):
        if sequence_length < 1:
            raise ValueError("sequence_length must be >= 1")
        self.model = tf.keras.models.load_model(model_path)
        self.sequence = []
        self.sequence_length = sequence_length
        # A single detector reused for every frame, rather than a new
        # (expensive) one per call.
        self.pose = mp_pose.Pose(
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5
        )

    def _frame_keypoints(self, frame):
        """Return flattened (x, y, visibility) landmarks for one BGR frame.

        Returns all zeros when no pose is detected, matching how missing
        frames are represented in the training data.
        """
        # MediaPipe expects RGB; OpenCV frames are BGR.
        results = self.pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if not results.pose_landmarks:
            return [0] * (NUM_KEYPOINTS * 3)
        kps = []
        for landmark in results.pose_landmarks.landmark:
            kps.extend([landmark.x, landmark.y, landmark.visibility])
        return kps

    def classify(self, frame):
        """Add one frame to the window and classify the window once it is full.

        Args:
            frame: BGR image, e.g. as returned by ``cv2.VideoCapture.read``.

        Returns:
            ``(label, confidence)``: the top class name from CLASSES and the
            model's output value for it, once ``sequence_length`` frames have
            been seen; otherwise ``("", 0)``.
        """
        self.sequence.append(self._frame_keypoints(frame))

        # Keep only the most recent `sequence_length` frames.
        if len(self.sequence) > self.sequence_length:
            del self.sequence[:-self.sequence_length]

        if len(self.sequence) < self.sequence_length:
            return "", 0

        # Normalize the 2-D (frames, features) window first, then add the batch
        # axis the model expects.
        window = process_keypoints(np.array(self.sequence))[np.newaxis, ...]
        # verbose=0: don't print a Keras progress bar on every frame.
        prediction = self.model.predict(window, verbose=0)[0]
        best = int(np.argmax(prediction))
        return CLASSES[best], prediction[best]

    def close(self):
        """Release the MediaPipe detector held by this classifier."""
        self.pose.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False


: 

In [None]:
# Usage: live classification from the default webcam; press 'q' in the window to stop.
classifier = PoseClassifier('pose_model.keras')

cap = cv2.VideoCapture(0)
if not cap.isOpened():
    # `exit()` would shut down the notebook kernel; raise so only this cell fails.
    cap.release()
    classifier.pose.close()
    raise RuntimeError("Error: Could not open camera")

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        pose_label, confidence = classifier.classify(frame)

        cv2.putText(frame, f"{pose_label} ({confidence:.2f})",
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.imshow('Badminton Pose Classification', frame)

        # Keep only the low byte (common OpenCV idiom): waitKey can report
        # extra high-order bits on some platforms/backends.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera, windows and pose detector, including when the
    # loop is stopped by interrupting the kernel (KeyboardInterrupt).
    cap.release()
    cv2.destroyAllWindows()
    classifier.pose.close()
