<a href="https://colab.research.google.com/github/whale1510/KSEB_AI_proj/blob/main/modules/swing_detect_inference_model_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 야구 동작(스윙)을 인식하기 위한 모델 추론 모듈
- name : 조병웅
- project : 야구 AI 캐스터
- stack : mediapipe(0.10.14), openCV(3.10.12), python(3.9.5)

In [None]:
import os
import cv2
import numpy as np
import mediapipe as mp
from sklearn.ensemble import RandomForestClassifier
import joblib  # 모델 저장 및 로드에 사용

# MediaPipe handles shared by the functions below:
#   mp_pose    - the Pose solution module (also supplies POSE_CONNECTIONS)
#   mp_drawing - drawing helpers for rendering landmarks onto frames
#   pose       - one Pose estimator instance, reused for every processed frame
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
pose = mp_pose.Pose()

def extract_pose_from_frame(frame):
    """Run the shared MediaPipe Pose estimator on a single BGR frame.

    Args:
        frame: image array in OpenCV's BGR channel order.

    Returns:
        ``(landmark_list, pose_landmarks)``. ``landmark_list`` is the sequence
        of individual landmarks. ``pose_landmarks`` is the full landmark
        message, which can be passed to ``mp_drawing.draw_landmarks``. When no
        pose is detected, returns ``(None, None)``.
    """
    # MediaPipe expects RGB input; OpenCV frames are BGR.
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # cvtColor allocates a new array, so making it read-only does not touch the
    # caller's frame. It lets MediaPipe use the buffer by reference instead of
    # copying it.
    frame_rgb.flags.writeable = False
    results = pose.process(frame_rgb)

    if results.pose_landmarks:
        return results.pose_landmarks.landmark, results.pose_landmarks
    return None, None

def landmarks_to_array(landmarks):
    """Flatten pose landmarks into a 1-D feature vector.

    Each landmark contributes its ``x``, ``y`` and ``z`` values, in order.
    If ``landmarks`` is None (no pose detected), a zero vector of length
    33 * 3 is returned instead.
    """
    if landmarks is not None:
        coords = [(point.x, point.y, point.z) for point in landmarks]
        return np.asarray(coords).reshape(-1)
    return np.zeros(33 * 3)

# Load the trained model
model = joblib.load('trained_model.pkl')#swing 모델

def predict_single_frame(frame, return_landmarks=False):
    """Classify the pose in a single BGR frame with the loaded swing model.

    Args:
        frame: image array in OpenCV's BGR channel order, e.g. from
            ``cv2.VideoCapture.read``.
        return_landmarks: if True, also return the raw MediaPipe landmark
            message, or None when no pose was detected. Callers can draw it
            with ``mp_drawing.draw_landmarks``. Defaults to False, which
            returns the label alone.

    Returns:
        The predicted label, or a ``(label, landmark_points)`` tuple when
        ``return_landmarks`` is True.
    """
    landmarks, landmark_points = extract_pose_from_frame(frame)
    pose_vector = landmarks_to_array(landmarks)
    # predict() takes a batch of samples (scikit-learn convention), so wrap
    # the single feature vector as a one-row batch and take the first result.
    label = model.predict([pose_vector])[0]
    if return_landmarks:
        return label, landmark_points
    return label

def swing_detecting(frame):
    """Return the swing-model label predicted for ``frame``.

    This is the public entry point. It delegates to ``predict_single_frame``
    and passes its label through unchanged.
    """
    return predict_single_frame(frame)
"""
# Real-time prediction from webcam
cap = cv2.VideoCapture(0)

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    label, landmarks = predict_single_frame(frame)

    if landmarks is not None:
        mp_drawing.draw_landmarks(frame, landmarks, mp_pose.POSE_CONNECTIONS)

    cv2.putText(frame, f'Label: {label}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
    cv2.imshow('Real-time Prediction', frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()