In [None]:
import cv2
import mediapipe as mp
import numpy as np
import torch
import torch.nn as nn
import time
from collections import deque
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import torchvision.transforms as transforms

# Class labels in classifier-output order: predict_position() indexes this
# list with the argmax of the model's output, so the order must not change.
POSITION_CLASSES = [
    'standing',
    'takedown1',
    'takedown2',
    'open_guard1',
    'open_guard2',
    'half_guard1',
    'half_guard2',
    'closed_guard1',
    'closed_guard2',
    '50-50_guard',
    'side_control1',
    'side_control2',
    'mount1',
    'mount2',
    'back1',
    'back2',
    'turtle1',
    'turtle2',
]

class PoseTransformer(nn.Module):
    """Transformer-encoder classifier over a pair of pose vectors.

    Every token of the input sequence is linearly projected to ``embed_dim``,
    run through a two-layer ``nn.TransformerEncoder`` (batch-first), and the
    encoded tokens are then flattened and passed to a small MLP head that
    produces ``num_classes`` logits.

    The first linear layer of the head takes ``embed_dim * 2`` features, so
    the input must have exactly two tokens per sample, i.e. shape
    ``(batch, 2, input_dim)``.
    """

    def __init__(self, input_dim=34, embed_dim=64, num_heads=4, num_classes=18, dropout=0.1):
        super().__init__()

        # Per-token projection from raw keypoint vector to model width.
        self.embedding = nn.Linear(input_dim, embed_dim)

        block = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=128,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(block, num_layers=2)

        # Layer order (and therefore the Sequential indices) is kept fixed so
        # that parameter names in saved state dicts stay the same.
        head_layers = [
            nn.Flatten(),
            nn.Linear(embed_dim * 2, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        tokens = self.embedding(x)
        encoded = self.transformer_encoder(tokens)
        return self.classifier(encoded)

def load_model(model_path='pose_transformer.pth'):
    """Build a default ``PoseTransformer`` and load its weights from disk.

    The checkpoint is mapped onto the CPU. Any exception raised while reading
    or applying the checkpoint is reported on stdout and turned into a
    ``None`` return value instead of propagating.

    Args:
        model_path: Path of the saved state dict.

    Returns:
        The model switched to eval mode, or ``None`` if loading failed.
    """
    net = PoseTransformer()
    cpu = torch.device('cpu')

    try:
        state = torch.load(model_path, map_location=cpu)
        net.load_state_dict(state)
        print(f"Model loaded successfully from {model_path}")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

    # Inference only: disable dropout.
    net.eval()
    return net

def setup_mediapipe_multi():
    """Create a MediaPipe Pose estimator configured for video input.

    Returns:
        A ``(mp_pose, pose)`` tuple: the ``mp.solutions.pose`` module and a
        ``Pose`` instance built with segmentation disabled, landmark
        smoothing enabled and 0.5 detection/tracking confidence thresholds.
    """
    pose_module = mp.solutions.pose
    options = {
        'static_image_mode': False,
        'model_complexity': 1,
        'smooth_landmarks': True,
        'enable_segmentation': False,
        'min_detection_confidence': 0.5,
        'min_tracking_confidence': 0.5,
    }
    estimator = pose_module.Pose(**options)
    return pose_module, estimator

def preprocess_keypoints(landmarks, image_height=480, image_width=640):
    """Convert MediaPipe pose landmarks into a flat 17-keypoint pixel vector.

    Seventeen of the MediaPipe landmarks are selected and emitted in COCO
    keypoint order (nose, eyes, ears, shoulders, elbows, wrists, hips, knees,
    ankles). Each landmark's ``x`` is multiplied by ``image_width`` and its
    ``y`` by ``image_height``.

    Args:
        landmarks: Indexable sequence of landmark objects exposing ``x`` and
            ``y`` attributes. It must be indexable up to position 28.
        image_height: Frame height used to scale ``y``.
        image_width: Frame width used to scale ``x``.

    Returns:
        A ``float32`` tensor of shape ``(34,)`` laid out as
        ``[x0, y0, x1, y1, ...]``, or ``None`` when ``landmarks`` is empty
        or ``None``.
    """
    if not landmarks:
        return None

    # MediaPipe landmark indices, listed in COCO keypoint order. An explicit
    # ordered tuple makes the output layout independent of dict ordering.
    mp_indices_in_coco_order = (
        0,   # nose
        2,   # left eye
        5,   # right eye
        7,   # left ear
        8,   # right ear
        11,  # left shoulder
        12,  # right shoulder
        13,  # left elbow
        14,  # right elbow
        15,  # left wrist
        16,  # right wrist
        23,  # left hip
        24,  # right hip
        25,  # left knee
        26,  # right knee
        27,  # left ankle
        28,  # right ankle
    )

    coords = []
    for mp_idx in mp_indices_in_coco_order:
        lm = landmarks[mp_idx]
        coords.append(lm.x * image_width)
        coords.append(lm.y * image_height)

    return torch.tensor(coords, dtype=torch.float32)

def predict_position(model, keypoints1, keypoints2=None):
    """Classify the position formed by one or two keypoint vectors.

    Args:
        model: Callable classifier; it receives a tensor of shape
            ``(1, 2, len(keypoints1))`` and must return per-class scores
            with the class axis at dim 1.
        keypoints1: Keypoint tensor of the first person, or ``None``.
        keypoints2: Keypoint tensor of the second person. When ``None`` an
            all-zero tensor shaped like ``keypoints1`` is used instead.

    Returns:
        ``(label, confidence)`` where ``label`` comes from
        ``POSITION_CLASSES`` and ``confidence`` is the softmax probability of
        that class; ``(None, 0.0)`` when ``keypoints1`` is ``None``.
    """
    if keypoints1 is None:
        return None, 0.0

    # A missing second person is represented by zeros.
    second = torch.zeros_like(keypoints1) if keypoints2 is None else keypoints2

    # Stack the two people as a 2-token sequence and add the batch axis.
    pair = torch.stack((keypoints1, second)).unsqueeze(0)

    with torch.no_grad():
        scores = model(pair)
        probabilities = torch.softmax(scores, dim=1)
        best_prob, best_index = probabilities.max(dim=1)

    return POSITION_CLASSES[best_index.item()], best_prob.item()

def draw_pose(image, pose_landmarks, mp_pose, mp_drawing, color=(0, 255, 0)):
    """Overlay a pose skeleton on ``image`` and return the image.

    Args:
        image: Frame passed to ``mp_drawing.draw_landmarks``.
        pose_landmarks: Landmarks to draw; nothing is drawn when falsy.
        mp_pose: Object providing ``POSE_CONNECTIONS``.
        mp_drawing: Object providing ``draw_landmarks`` and ``DrawingSpec``.
        color: Colour used for the landmark points.

    Returns:
        The same ``image`` object that was passed in.
    """
    if not pose_landmarks:
        return image

    point_style = mp_drawing.DrawingSpec(color=color, thickness=2, circle_radius=2)
    # Connections always use a fixed colour regardless of ``color``.
    line_style = mp_drawing.DrawingSpec(color=(0, 128, 255), thickness=2)

    mp_drawing.draw_landmarks(
        image,
        pose_landmarks,
        mp_pose.POSE_CONNECTIONS,
        point_style,
        line_style,
    )
    return image

def setup_prediction_smoothing(window_size=5):
    """Return an empty bounded deque that keeps the last ``window_size`` items."""
    history = deque([], window_size)
    return history

def get_smoothed_prediction(prediction_history):
    """Majority-vote the recent predictions.

    Args:
        prediction_history: Iterable of ``(label, confidence)`` pairs.

    Returns:
        ``(label, mean_confidence)`` for the most frequent label, where the
        mean is taken over that label's entries only. Ties go to the label
        that appeared first in the history. Returns ``(None, 0.0)`` for an
        empty history.
    """
    if not prediction_history:
        return None, 0.0

    votes = {}
    confidence_totals = {}
    for label, score in prediction_history:
        votes[label] = votes.get(label, 0) + 1
        confidence_totals[label] = confidence_totals.get(label, 0) + score

    # max() returns the first key with the highest count, i.e. the earliest
    # inserted label among ties.
    winner = max(votes, key=votes.get)

    if not winner:
        # A falsy winning label is reported with zero confidence.
        return winner, 0
    return winner, confidence_totals[winner] / votes[winner]

def run_multi_person_detection(model_path='pose_transformer.pth'):
    """Run live two-person position recognition using segmentation masking.

    Frames are read from camera 0. The first pose pass runs on the full
    frame with segmentation enabled; the pixels of the detected person
    (segmentation score above 0.5) are blacked out and a second pose pass
    runs on the remainder to look for another person. Both keypoint sets
    (the second may be missing) are classified, the label is smoothed over
    recent frames, and position, confidence, FPS and the number of people
    found are drawn onto the preview window.

    The loop ends when a frame cannot be read, when ``q`` is pressed in the
    window, or on ``KeyboardInterrupt``.

    Args:
        model_path: Checkpoint path passed to ``load_model``.
    """
    model = load_model(model_path)
    if model is None:
        print("Failed to load model. Exiting.")
        return

    mp_pose = mp.solutions.pose
    mp_drawing = mp.solutions.drawing_utils

    pose_options = dict(
        static_image_mode=False,
        model_complexity=1,
        smooth_landmarks=True,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    )
    # MediaPipe Pose reports one person per call. In video mode an instance
    # also keeps tracking/smoothing state from the previous frame it saw, so
    # the masked second pass gets its own instance rather than sharing state
    # with the full-frame pass.
    pose = mp_pose.Pose(enable_segmentation=True, **pose_options)
    masked_pose = mp_pose.Pose(enable_segmentation=False, **pose_options)

    cap = cv2.VideoCapture(0)

    prediction_history = setup_prediction_smoothing()

    prev_time = 0

    try:
        while cap.isOpened():
            success, image = cap.read()
            if not success:
                print("Failed to read from webcam.")
                break

            current_time = time.time()
            elapsed = current_time - prev_time
            # No FPS on the first frame; also guard against a zero interval.
            fps = 1 / elapsed if prev_time > 0 and elapsed > 0 else 0
            prev_time = current_time

            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image_height, image_width, _ = image.shape

            results = pose.process(image_rgb)

            keypoints1 = None
            keypoints2 = None

            if results.pose_landmarks:
                keypoints1 = preprocess_keypoints(results.pose_landmarks.landmark, image_height, image_width)

                image = draw_pose(image, results.pose_landmarks, mp_pose, mp_drawing, color=(0, 255, 0))

                if results.segmentation_mask is not None:
                    # Black out the first person so the second pass can only
                    # lock onto someone else.
                    masked_image = image_rgb.copy()
                    masked_image[results.segmentation_mask > 0.5] = 0

                    second_results = masked_pose.process(masked_image)

                    if second_results.pose_landmarks:
                        keypoints2 = preprocess_keypoints(second_results.pose_landmarks.landmark, image_height, image_width)

                        image = draw_pose(image, second_results.pose_landmarks, mp_pose, mp_drawing, color=(255, 0, 0))

            if keypoints1 is not None:
                position, confidence = predict_position(model, keypoints1, keypoints2)
                prediction_history.append((position, confidence))
                smoothed_position, smoothed_confidence = get_smoothed_prediction(prediction_history)
            else:
                smoothed_position, smoothed_confidence = None, 0.0

            prediction_text = f"Position: {smoothed_position if smoothed_position else 'None'}"
            confidence_text = f"Confidence: {smoothed_confidence:.2f}"
            fps_text = f"FPS: {fps:.1f}"
            person_text = f"People detected: {(1 if keypoints1 is not None else 0) + (1 if keypoints2 is not None else 0)}"

            cv2.putText(image, prediction_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.putText(image, confidence_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.putText(image, fps_text, (10, 110), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.putText(image, person_text, (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            cv2.imshow('BJJ Position Recognition', image)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    except KeyboardInterrupt:
        print("Detection stopped by user")
    finally:
        # Release the camera, the window and both MediaPipe graphs.
        cap.release()
        cv2.destroyAllWindows()
        pose.close()
        masked_pose.close()

def run_multi_person_detection_with_holistic(model_path='pose_transformer.pth'):
    """Run live two-person position recognition using MediaPipe Holistic.

    Frames are read from camera 0 and mirrored horizontally. The first
    Holistic pass runs on the full frame; the bounding box of the detected
    pose landmarks (padded by 50 px and clamped to the frame) is blacked out
    and a second Holistic pass runs on the remainder to look for another
    person. Both keypoint sets (the second may be missing) are classified,
    the label is smoothed over recent frames, and position, confidence, FPS
    and the number of people found are drawn onto the preview window.

    The loop ends when a frame cannot be read, when ``q`` is pressed in the
    window, or on ``KeyboardInterrupt``.

    Args:
        model_path: Checkpoint path passed to ``load_model``.
    """
    model = load_model(model_path)
    if model is None:
        print("Failed to load model. Exiting.")
        return

    mp_holistic = mp.solutions.holistic
    mp_pose = mp.solutions.pose
    mp_drawing = mp.solutions.drawing_utils

    holistic_options = dict(
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    )
    # One instance per pass: a Holistic instance keeps tracking state from
    # the previous image it processed, so sharing one between the full frame
    # and the masked frame would mix the two people's state.
    holistic = mp_holistic.Holistic(**holistic_options)
    masked_holistic = mp_holistic.Holistic(**holistic_options)

    cap = cv2.VideoCapture(0)

    prediction_history = setup_prediction_smoothing()

    prev_time = 0

    try:
        while cap.isOpened():
            success, image = cap.read()
            if not success:
                print("Failed to read from webcam.")
                break

            current_time = time.time()
            elapsed = current_time - prev_time
            # No FPS on the first frame; also guard against a zero interval.
            fps = 1 / elapsed if prev_time > 0 and elapsed > 0 else 0
            prev_time = current_time

            # Mirror the frame before detection and display.
            image = cv2.flip(image, 1)

            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image_height, image_width, _ = image.shape

            results = holistic.process(image_rgb)

            person1_landmarks = results.pose_landmarks

            keypoints1 = None
            keypoints2 = None

            if person1_landmarks:
                keypoints1 = preprocess_keypoints(person1_landmarks.landmark, image_height, image_width)
                image = draw_pose(image, person1_landmarks, mp_pose, mp_drawing, color=(0, 255, 0))

                landmarks = np.array([[lm.x * image_width, lm.y * image_height] for lm in person1_landmarks.landmark])
                x_min, y_min = np.min(landmarks, axis=0).astype(int)
                x_max, y_max = np.max(landmarks, axis=0).astype(int)

                # Pad the box, then clamp both ends into the frame. Clamping
                # the upper bound at 0 matters: a negative slice end would
                # otherwise wrap around and blank the wrong region.
                padding = 50
                x_min, x_max = np.clip([x_min - padding, x_max + padding], 0, image_width)
                y_min, y_max = np.clip([y_min - padding, y_max + padding], 0, image_height)

                # Black out the first person's box for the second pass.
                second_image = image_rgb.copy()
                second_image[y_min:y_max, x_min:x_max] = 0

                second_results = masked_holistic.process(second_image)

                if second_results.pose_landmarks:
                    keypoints2 = preprocess_keypoints(second_results.pose_landmarks.landmark, image_height, image_width)
                    image = draw_pose(image, second_results.pose_landmarks, mp_pose, mp_drawing, color=(255, 0, 0))

            if keypoints1 is not None:
                position, confidence = predict_position(model, keypoints1, keypoints2)
                prediction_history.append((position, confidence))
                smoothed_position, smoothed_confidence = get_smoothed_prediction(prediction_history)
            else:
                smoothed_position, smoothed_confidence = None, 0.0

            # Overlay text
            prediction_text = f"Position: {smoothed_position if smoothed_position else 'None'}"
            confidence_text = f"Confidence: {smoothed_confidence:.2f}"
            fps_text = f"FPS: {fps:.1f}"
            person_text = f"People detected: {(1 if keypoints1 is not None else 0) + (1 if keypoints2 is not None else 0)}"

            cv2.putText(image, prediction_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.putText(image, confidence_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.putText(image, fps_text, (10, 110), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.putText(image, person_text, (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            cv2.imshow('BJJ Position Recognition', image)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    except KeyboardInterrupt:
        print("Detection stopped by user")
    finally:
        # Release the camera, the window and both MediaPipe graphs.
        cap.release()
        cv2.destroyAllWindows()
        holistic.close()
        masked_holistic.close()

def run_optimized_multi_person_detection(model_path='pose_transformer.pth'):
    """Run live two-person position recognition with two Pose instances.

    Frames are read from camera 0. ``pose1`` runs on the full frame; the
    bounding box of its landmarks (padded by 30 px and clamped to the frame)
    is blacked out and ``pose2`` runs on the remainder to look for another
    person. Using a dedicated instance per pass keeps each instance's
    frame-to-frame state tied to one image stream. Both keypoint sets (the
    second may be missing) are classified, the label is smoothed over recent
    frames, and position, confidence, FPS and the number of people found are
    drawn onto the preview window.

    The loop ends when a frame cannot be read, when ``q`` is pressed in the
    window, or on ``KeyboardInterrupt``.

    Args:
        model_path: Checkpoint path passed to ``load_model``.
    """
    model = load_model(model_path)
    if model is None:
        print("Failed to load model. Exiting.")
        return

    mp_pose = mp.solutions.pose
    mp_drawing = mp.solutions.drawing_utils

    pose_options = dict(
        static_image_mode=False,
        model_complexity=1,
        smooth_landmarks=True,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    )
    pose1 = mp_pose.Pose(**pose_options)
    pose2 = mp_pose.Pose(**pose_options)

    cap = cv2.VideoCapture(0)

    prediction_history = setup_prediction_smoothing()

    prev_time = 0

    try:
        while cap.isOpened():
            success, image = cap.read()
            if not success:
                print("Failed to read from webcam.")
                break

            current_time = time.time()
            elapsed = current_time - prev_time
            # No FPS on the first frame; also guard against a zero interval.
            fps = 1 / elapsed if prev_time > 0 and elapsed > 0 else 0
            prev_time = current_time

            image_height, image_width, _ = image.shape

            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            results1 = pose1.process(image_rgb)

            keypoints1 = None
            keypoints2 = None

            if results1.pose_landmarks:
                keypoints1 = preprocess_keypoints(results1.pose_landmarks.landmark, image_height, image_width)

                image = draw_pose(image, results1.pose_landmarks, mp_pose, mp_drawing, color=(0, 255, 0))

                landmarks = np.array([[lm.x * image_width, lm.y * image_height] for lm in results1.pose_landmarks.landmark])
                x_min, y_min = np.min(landmarks, axis=0).astype(int)
                x_max, y_max = np.max(landmarks, axis=0).astype(int)

                # Pad the box, then clamp both ends into the frame. Clamping
                # the upper bound at 0 matters: a negative slice end would
                # otherwise wrap around and blank the wrong region.
                padding = 30
                x_min, x_max = np.clip([x_min - padding, x_max + padding], 0, image_width)
                y_min, y_max = np.clip([y_min - padding, y_max + padding], 0, image_height)

                # Black out the first person's box for the second pass.
                masked_image = image_rgb.copy()
                masked_image[y_min:y_max, x_min:x_max] = 0

                results2 = pose2.process(masked_image)

                if results2.pose_landmarks:
                    keypoints2 = preprocess_keypoints(results2.pose_landmarks.landmark, image_height, image_width)

                    image = draw_pose(image, results2.pose_landmarks, mp_pose, mp_drawing, color=(255, 0, 0))

            if keypoints1 is not None:
                position, confidence = predict_position(model, keypoints1, keypoints2)
                prediction_history.append((position, confidence))
                smoothed_position, smoothed_confidence = get_smoothed_prediction(prediction_history)
            else:
                smoothed_position, smoothed_confidence = None, 0.0

            prediction_text = f"Position: {smoothed_position if smoothed_position else 'None'}"
            confidence_text = f"Confidence: {smoothed_confidence:.2f}"
            fps_text = f"FPS: {fps:.1f}"
            person_text = f"People detected: {(1 if keypoints1 is not None else 0) + (1 if keypoints2 is not None else 0)}"

            cv2.putText(image, prediction_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.putText(image, confidence_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.putText(image, fps_text, (10, 110), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.putText(image, person_text, (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            cv2.imshow('BJJ Position Recognition', image)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    except KeyboardInterrupt:
        print("Detection stopped by user")
    finally:
        # Release the camera, the window and both MediaPipe graphs.
        cap.release()
        cv2.destroyAllWindows()
        pose1.close()
        pose2.close()

if __name__ == "__main__":
    # Interactive entry point: show the menu, read one choice from stdin and
    # start the matching detection loop.
    for banner_line in (
        "Starting BJJ position detection with multi-person support...",
        "Choose detection method:",
        "1: Basic Multi-Person Detection",
        "2: Holistic-based Multi-Person Detection",
        "3: Optimized Multi-Person Detection (Recommended)",
    ):
        print(banner_line)

    choice = input("Enter choice (1/2/3): ")

    runners = {
        '1': run_multi_person_detection,
        '2': run_multi_person_detection_with_holistic,
    }
    # Any input other than '1' or '2' selects the optimized variant.
    runners.get(choice, run_optimized_multi_person_detection)()

Starting BJJ position detection with multi-person support...
Choose detection method:
1: Basic Multi-Person Detection
2: Holistic-based Multi-Person Detection
3: Optimized Multi-Person Detection (Recommended)
Model loaded successfully from pose_transformer.pth


I0000 00:00:1746550441.284911 43561978 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 88.1), renderer: Apple M1
I0000 00:00:1746550441.292908 43561978 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 88.1), renderer: Apple M1
W0000 00:00:1746550441.412176 43582043 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1746550441.412183 43582034 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1746550441.440229 43582034 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1746550441.441353 43582046 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


Detection stopped by user


: 