In [1]:
import os
os.environ["ONNXRUNTIME_EXECUTION_PROVIDERS"] = "[CUDAExecutionProvider]"

In [2]:
import os
HOME = os.getcwd()
print(HOME)

/home/ubuntu/projects/sure-football-analysis


In [3]:
# from inference import get_model
from ultralytics import YOLO

# ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")
# PLAYER_DETECTION_MODEL_ID = "football-players-detection-3zvbc/12"
# PLAYER_DETECTION_MODEL = get_model(PLAYER_DETECTION_MODEL_ID, api_key=ROBOFLOW_API_KEY)
PLAYER_DETECTION_MODEL = YOLO("app/models/yolo11_football_v2/weights/best.pt")

In [4]:
import torch
from transformers import AutoProcessor, SiglipVisionModel

SIGLIP_MODEL_PATH = 'google/siglip-base-patch16-224'

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EMBEDDINGS_MODEL = SiglipVisionModel.from_pretrained(SIGLIP_MODEL_PATH).to(DEVICE)
EMBEDDINGS_PROCESSOR = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH)

In [5]:
import supervision as sv
import numpy as np
from more_itertools import chunked
from tqdm import tqdm
import warnings
import torch

# Suppress FutureWarnings from sklearn
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")

# Configuration
SOURCE_VIDEO_PATH = "app/test_data/raw/121364_0.mp4"
BATCH_SIZE = 64
PLAYER_ID = 2
STRIDE = 30

# Frame generator
frame_generator = sv.get_video_frames_generator(
    source_path=SOURCE_VIDEO_PATH, stride=STRIDE
)

# Collect crops
crops = []
for frame in tqdm(frame_generator, desc="collecting crops"):
    result = PLAYER_DETECTION_MODEL.predict(frame, conf=0.3)[0]
    detections = sv.Detections.from_ultralytics(result)
    detections = detections.with_nms(threshold=0.5, class_agnostic=True)
    detections = detections[detections.class_id == PLAYER_ID]
    players_crops = [sv.crop_image(frame, xyxy) for xyxy in detections.xyxy]
    crops += players_crops

# Convert crops to pillow format
crops = [sv.cv2_to_pillow(crop) for crop in crops]

# Process crops in batches
batches = chunked(crops, BATCH_SIZE)
data = []
with torch.no_grad():
    for batch in tqdm(batches, desc="embedding extraction"):
        inputs = EMBEDDINGS_PROCESSOR(images=batch, return_tensors="pt").to(DEVICE)
        outputs = EMBEDDINGS_MODEL(**inputs)
        embeddings = torch.mean(outputs.last_hidden_state, dim=1).cpu().numpy()
        data.append(embeddings)

# Concatenate all embeddings
data = np.concatenate(data)

collecting crops: 0it [00:00, ?it/s]


0: 736x1280 1 ball, 2 goalkeepers, 20 players, 2 referees, 53.7ms
Speed: 9.3ms preprocess, 53.7ms inference, 99.3ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 1it [00:00,  1.05it/s]


0: 736x1280 1 goalkeeper, 20 players, 2 referees, 33.4ms
Speed: 9.8ms preprocess, 33.4ms inference, 1.1ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 2it [00:01,  2.11it/s]


0: 736x1280 1 goalkeeper, 21 players, 2 referees, 33.4ms
Speed: 11.0ms preprocess, 33.4ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 3it [00:01,  3.12it/s]


0: 736x1280 20 players, 2 referees, 33.5ms
Speed: 9.7ms preprocess, 33.5ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 4it [00:01,  4.10it/s]


0: 736x1280 1 ball, 20 players, 2 referees, 33.2ms
Speed: 10.6ms preprocess, 33.2ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 5it [00:01,  4.83it/s]


0: 736x1280 1 ball, 20 players, 2 referees, 33.0ms
Speed: 10.1ms preprocess, 33.0ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 6it [00:01,  5.57it/s]


0: 736x1280 20 players, 2 referees, 33.2ms
Speed: 6.8ms preprocess, 33.2ms inference, 1.4ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 7it [00:01,  6.05it/s]


0: 736x1280 1 ball, 19 players, 2 referees, 33.7ms
Speed: 9.3ms preprocess, 33.7ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 8it [00:01,  6.44it/s]


0: 736x1280 1 ball, 1 goalkeeper, 20 players, 2 referees, 33.0ms
Speed: 8.1ms preprocess, 33.0ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 9it [00:02,  6.77it/s]


0: 736x1280 1 ball, 21 players, 2 referees, 33.0ms
Speed: 7.0ms preprocess, 33.0ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 10it [00:02,  6.85it/s]


0: 736x1280 1 ball, 22 players, 2 referees, 33.0ms
Speed: 9.7ms preprocess, 33.0ms inference, 1.1ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 11it [00:02,  7.02it/s]


0: 736x1280 1 ball, 1 goalkeeper, 20 players, 2 referees, 33.1ms
Speed: 6.9ms preprocess, 33.1ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 12it [00:02,  7.08it/s]


0: 736x1280 1 goalkeeper, 21 players, 2 referees, 33.1ms
Speed: 9.9ms preprocess, 33.1ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 13it [00:02,  7.31it/s]


0: 736x1280 1 ball, 1 goalkeeper, 20 players, 2 referees, 33.0ms
Speed: 7.2ms preprocess, 33.0ms inference, 1.1ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 14it [00:02,  7.29it/s]


0: 736x1280 1 goalkeeper, 21 players, 2 referees, 33.2ms
Speed: 11.3ms preprocess, 33.2ms inference, 1.1ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 15it [00:02,  7.19it/s]


0: 736x1280 1 goalkeeper, 21 players, 2 referees, 33.0ms
Speed: 11.2ms preprocess, 33.0ms inference, 1.5ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 16it [00:03,  6.95it/s]


0: 736x1280 20 players, 3 referees, 33.5ms
Speed: 11.4ms preprocess, 33.5ms inference, 1.1ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 17it [00:03,  6.89it/s]


0: 736x1280 22 players, 1 referee, 33.0ms
Speed: 11.7ms preprocess, 33.0ms inference, 1.3ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 18it [00:03,  6.77it/s]


0: 736x1280 22 players, 2 referees, 33.1ms
Speed: 11.6ms preprocess, 33.1ms inference, 1.4ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 19it [00:03,  6.59it/s]


0: 736x1280 23 players, 2 referees, 33.4ms
Speed: 11.3ms preprocess, 33.4ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 20it [00:03,  6.60it/s]


0: 736x1280 1 ball, 1 goalkeeper, 19 players, 2 referees, 33.3ms
Speed: 11.6ms preprocess, 33.3ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 21it [00:03,  6.60it/s]


0: 736x1280 1 goalkeeper, 21 players, 2 referees, 33.2ms
Speed: 11.1ms preprocess, 33.2ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 22it [00:03,  6.77it/s]


0: 736x1280 1 ball, 1 goalkeeper, 21 players, 2 referees, 33.2ms
Speed: 7.0ms preprocess, 33.2ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 23it [00:04,  6.73it/s]


0: 736x1280 1 goalkeeper, 22 players, 1 referee, 33.3ms
Speed: 12.6ms preprocess, 33.3ms inference, 1.3ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 24it [00:04,  6.70it/s]


0: 736x1280 1 goalkeeper, 22 players, 2 referees, 33.5ms
Speed: 11.1ms preprocess, 33.5ms inference, 1.1ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 25it [00:04,  5.62it/s]
embedding extraction: 8it [00:03,  2.42it/s]


In [6]:
import umap
from sklearn.cluster import KMeans
from sports.common.team import TeamClassifier


REDUCER = umap.UMAP(n_components=3)
CLUSTERING_MODEL = KMeans(n_clusters=2)

projections = REDUCER.fit_transform(data)
clusters = CLUSTERING_MODEL.fit_predict(projections)

frame_generator = sv.get_video_frames_generator(
    source_path=SOURCE_VIDEO_PATH, stride=STRIDE)

crops = []
for frame in tqdm(frame_generator, desc='collecting crops'):
    # result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
    result = PLAYER_DETECTION_MODEL.predict(frame, conf=0.3)[0]
    # detections = sv.Detections.from_inference(result)
    detections = sv.Detections.from_ultralytics(result)
    players_detections = detections[detections.class_id == PLAYER_ID]
    players_crops = [sv.crop_image(frame, xyxy) for xyxy in detections.xyxy]
    crops += players_crops

team_classifier = TeamClassifier(device="cuda")
team_classifier.fit(crops)

collecting crops: 0it [00:00, ?it/s]


0: 736x1280 1 ball, 2 goalkeepers, 20 players, 2 referees, 33.0ms
Speed: 12.7ms preprocess, 33.0ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)

0: 736x1280 1 goalkeeper, 20 players, 2 referees, 33.5ms
Speed: 7.4ms preprocess, 33.5ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 2it [00:00,  9.21it/s]


0: 736x1280 1 goalkeeper, 21 players, 2 referees, 32.8ms
Speed: 6.9ms preprocess, 32.8ms inference, 1.1ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 3it [00:00,  8.43it/s]


0: 736x1280 20 players, 2 referees, 33.5ms
Speed: 10.1ms preprocess, 33.5ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 4it [00:00,  8.21it/s]


0: 736x1280 1 ball, 20 players, 2 referees, 33.5ms
Speed: 10.8ms preprocess, 33.5ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 5it [00:00,  7.80it/s]


0: 736x1280 1 ball, 20 players, 2 referees, 33.1ms
Speed: 10.1ms preprocess, 33.1ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 6it [00:00,  7.85it/s]


0: 736x1280 20 players, 2 referees, 33.4ms
Speed: 7.1ms preprocess, 33.4ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 7it [00:00,  7.74it/s]


0: 736x1280 1 ball, 19 players, 2 referees, 33.3ms
Speed: 10.0ms preprocess, 33.3ms inference, 1.2ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 8it [00:01,  7.79it/s]


0: 736x1280 1 ball, 1 goalkeeper, 20 players, 2 referees, 33.3ms
Speed: 7.0ms preprocess, 33.3ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 9it [00:01,  7.72it/s]


0: 736x1280 1 ball, 21 players, 2 referees, 33.5ms
Speed: 7.5ms preprocess, 33.5ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 10it [00:01,  7.50it/s]


0: 736x1280 1 ball, 22 players, 2 referees, 32.9ms
Speed: 9.9ms preprocess, 32.9ms inference, 1.3ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 11it [00:01,  7.49it/s]


0: 736x1280 1 ball, 1 goalkeeper, 20 players, 2 referees, 33.3ms
Speed: 11.2ms preprocess, 33.3ms inference, 1.1ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 12it [00:01,  7.32it/s]


0: 736x1280 1 goalkeeper, 21 players, 2 referees, 33.1ms
Speed: 10.1ms preprocess, 33.1ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 13it [00:01,  7.46it/s]


0: 736x1280 1 ball, 1 goalkeeper, 20 players, 2 referees, 33.4ms
Speed: 6.8ms preprocess, 33.4ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 14it [00:01,  7.40it/s]


0: 736x1280 1 goalkeeper, 21 players, 2 referees, 32.9ms
Speed: 10.6ms preprocess, 32.9ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 15it [00:01,  7.49it/s]


0: 736x1280 1 goalkeeper, 21 players, 2 referees, 33.1ms
Speed: 11.2ms preprocess, 33.1ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 16it [00:02,  7.35it/s]


0: 736x1280 20 players, 3 referees, 32.9ms
Speed: 11.1ms preprocess, 32.9ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 17it [00:02,  7.40it/s]


0: 736x1280 22 players, 1 referee, 33.5ms
Speed: 10.0ms preprocess, 33.5ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 18it [00:02,  7.44it/s]


0: 736x1280 22 players, 2 referees, 33.5ms
Speed: 11.2ms preprocess, 33.5ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 19it [00:02,  7.32it/s]


0: 736x1280 23 players, 2 referees, 32.8ms
Speed: 10.1ms preprocess, 32.8ms inference, 1.0ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 20it [00:02,  7.44it/s]


0: 736x1280 1 ball, 1 goalkeeper, 19 players, 2 referees, 32.9ms
Speed: 6.8ms preprocess, 32.9ms inference, 1.2ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 21it [00:02,  7.41it/s]


0: 736x1280 1 goalkeeper, 21 players, 2 referees, 33.1ms
Speed: 11.2ms preprocess, 33.1ms inference, 1.1ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 22it [00:02,  7.54it/s]


0: 736x1280 1 ball, 1 goalkeeper, 21 players, 2 referees, 33.3ms
Speed: 10.7ms preprocess, 33.3ms inference, 1.2ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 23it [00:03,  7.37it/s]


0: 736x1280 1 goalkeeper, 22 players, 1 referee, 33.4ms
Speed: 6.9ms preprocess, 33.4ms inference, 1.1ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 24it [00:03,  7.55it/s]


0: 736x1280 1 goalkeeper, 22 players, 2 referees, 33.0ms
Speed: 6.8ms preprocess, 33.0ms inference, 1.1ms postprocess per image at shape (1, 3, 736, 1280)


collecting crops: 25it [00:03,  7.42it/s]
Embedding extraction: 19it [00:03,  4.81it/s]


## Imports and Initial Setup

In [8]:
import supervision as sv
from tqdm import tqdm
import numpy as np
from boxmot import BotSort # Using BoTSORT
import cv2
from pathlib import Path
import torch
from collections import defaultdict, deque
import warnings
import logging
import traceback # Import traceback for detailed error printing
import os
import random # Added for sparkle effect
import math # Added for distance calculation
import argparse # Added for command-line arguments
import base64 # For encoding images for Groq API
from groq import Groq # For interacting with Groq API

# Suppress most logging messages
logging.basicConfig(level=logging.ERROR) # Show only CRITICAL errors
logging.disable(logging.WARNING) # Disable WARNING messages
logging.disable(logging.INFO) # Disable INFO messages specifically
warnings.filterwarnings('ignore', category=UserWarning, module='paddle')
warnings.filterwarnings('ignore', category=UserWarning, module='torchvision') # Ignore potential torchvision warnings
warnings.filterwarnings('ignore', category=FutureWarning) # Ignore potential future warnings

# Attempt to import PaddleOCR
try:
    from paddleocr import PaddleOCR
    PADDLEOCR_AVAILABLE = True
except ImportError:
    print("Warning: PaddleOCR not found. OCR functionality will be disabled.")
    PADDLEOCR_AVAILABLE = False


In [9]:
# ----- Configuration -----
# --- Paths ---
DEFAULT_SOURCE_VIDEO_PATH = "app/test_data/raw/your_video.mp4" # INPUT: Path to your video
DEFAULT_OUTPUT_VIDEO_PATH = "your_video_tracking_actions_analyzed.mp4" # OUTPUT: Path for annotated video
DEFAULT_ACTIONS_DIR = "app/test_data/predicted/action_frames" # OUTPUT: Base directory for saved frame sequences
OCR_DEBUG_DIR = "ocr_debug_crops" # Optional output for OCR debugging
DEFAULT_REID_WEIGHTS_PATH = 'clip_market1501.pt' # INPUT: Path for BoTSORT ReID weights (if used)

# --- Processing Device ---
DEVICE = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')

# --- Class IDs ---
# Initial detection model IDs (MUST match your detection model output)
BALL_ID = 0
GOALKEEPER_ID = 1
PLAYER_ID = 2
REFEREE_ID = 3
# Team/Role Class IDs (assigned *after* classification/resolution)
TEAM_A_ID = 0 # Example ID for Team A
TEAM_B_ID = 1 # Example ID for Team B
REFEREE_TEAM_ID = 2 # Example ID for Referee

# --- OCR Configuration ---
OCR_CONFIDENCE_THRESHOLD = 0.8 # Minimum confidence for accepting OCR result
MIN_JERSEY_DIGITS = 1 # Min number of digits expected on a jersey
MAX_JERSEY_DIGITS = 2 # Max number of digits expected on a jersey

# --- ID Management Configuration ---
LOST_TRACK_MEMORY_SECONDS = 20 # How long to remember a lost track ID for potential re-identification
MISMATCH_CONSISTENCY_FRAMES = 3 # How many consecutive frames an OCR mismatch must occur to update ID

# --- Ball Trail Configuration ---
BALL_TRAIL_SECONDS = 1 # Duration of the visual ball trail
SPARKLE_COUNT = 3 # Number of sparkles per trail point
SPARKLE_RADIUS = 2 # Radius of sparkles
SPARKLE_OFFSET = 3 # Max random offset for sparkles
MAX_BALL_DISTANCE_PER_FRAME = 400 # Max pixels ball can move between frames (for outlier rejection)

# --- Interaction Detection & Frame Saving Configuration ---
IOU_THRESHOLD = 0.05 # Min Intersection over Union for player-ball interaction
MIN_INTERACTION_FRAMES = 5  # Min consecutive frames for a valid interaction
CLIP_PADDING_SECONDS = 2 # Seconds of padding before/after interaction core
FRAMES_PER_SECOND_TO_SAVE = 5 # Target FPS for saved frame sequences (for VLM)

# --- Groq API Configuration ---
# IMPORTANT: Set the GROQ_API_KEY environment variable before running!
# export GROQ_API_KEY='your_actual_api_key'
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
GROQ_VLM_MODEL = "llama3-groq-vision-alpha" # Adjust if needed, check Groq docs for current vision models


In [10]:
# ----- Initialize PaddleOCR -----
ocr_model = None
if PADDLEOCR_AVAILABLE:
    try:
        # Initialize PaddleOCR - adjust parameters as needed
        ocr_model = PaddleOCR(
            use_angle_cls=False, # Usually false for jersey numbers
            lang='en',           # Assume English digits
            use_gpu=(DEVICE.type == 'cuda'), # Use GPU if available
            show_log=False       # Suppress PaddleOCR's internal logs
        )
        print("PaddleOCR initialized successfully.")
    except Exception as e:
        print(f"Error initializing PaddleOCR: {e}. Disabling OCR.")
        PADDLEOCR_AVAILABLE = False
else:
    print("PaddleOCR not available. Skipping initialization.")


PaddleOCR initialized successfully.


In [11]:
# ----- Color Calculation & Helper -----
DEFAULT_TEAM_A_COLOR = sv.Color.from_hex('#FF0000') # Red
DEFAULT_TEAM_B_COLOR = sv.Color.from_hex('#00FFFF') # Cyan
DEFAULT_REFEREE_COLOR = sv.Color.from_hex('#00FF00') # Green
FALLBACK_COLOR = sv.Color.from_hex('#808080') # Grey
COLOR_SIMILARITY_THRESHOLD = 50.0 # Max RGB distance diff to be considered ambiguous

def calculate_average_color(frame: np.ndarray, detections: sv.Detections, central_fraction: float = 0.5) -> sv.Color | None:
    """Calculates the average color from the central region of detection boxes."""
    if len(detections) == 0: return None
    avg_colors = []
    height, width, _ = frame.shape
    for xyxy in detections.xyxy:
        x1, y1, x2, y2 = map(int, xyxy)
        x1, y1 = max(0, x1), max(0, y1); x2, y2 = min(width, x2), min(height, y2)
        if x1 >= x2 or y1 >= y2: continue # Skip invalid boxes
        # Extract central region
        box_w, box_h = x2 - x1, y2 - y1
        center_x, center_y = x1 + box_w // 2, y1 + box_h // 2
        central_w, central_h = int(box_w * central_fraction), int(box_h * central_fraction)
        cx1 = max(x1, center_x - central_w // 2); cy1 = max(y1, center_y - central_h // 2)
        cx2 = min(x2, center_x + central_w // 2); cy2 = min(y2, center_y + central_h // 2)
        if cx1 >= cx2 or cy1 >= cy2: continue # Skip if central region is invalid
        crop = frame[cy1:cy2, cx1:cx2]
        if crop.size > 0:
            avg_bgr = cv2.mean(crop)[:3] # Calculate mean BGR
            avg_colors.append(avg_bgr)
    if not avg_colors: return None
    # Calculate overall average and convert to sv.Color
    final_avg_bgr = np.mean(avg_colors, axis=0)
    b, g, r = map(int, final_avg_bgr)
    # Prevent very dark colors (optional adjustment)
    min_intensity = 50
    if r < min_intensity and g < min_intensity and b < min_intensity:
        r, g, b = min_intensity, min_intensity, min_intensity
    return sv.Color(r=r, g=g, b=b)

def color_distance(color1: sv.Color | None, color2: sv.Color | None) -> float:
    """Calculates Euclidean distance between two sv.Color objects in RGB space."""
    if color1 is None or color2 is None: return float('inf')
    # Basic check for valid sv.Color objects
    if not all(hasattr(c, attr) for c in [color1, color2] for attr in ['r', 'g', 'b']): return float('inf')
    try:
        rgb1 = np.array([color1.r, color1.g, color1.b])
        rgb2 = np.array([color2.r, color2.g, color2.b])
        return np.linalg.norm(rgb1 - rgb2) # Euclidean distance
    except Exception as e: print(f"Error calculating color distance: {e}"); return float('inf')

# ----- Enhanced Goalkeeper Resolution Function -----
def resolve_goalkeepers_team_id(
    frame: np.ndarray,
    goalkeepers: sv.Detections,
    team_a_color: sv.Color | None,
    team_b_color: sv.Color | None,
    color_similarity_threshold: float = COLOR_SIMILARITY_THRESHOLD
) -> np.ndarray:
    """Assigns team IDs (TEAM_A_ID or TEAM_B_ID) to goalkeepers."""
    goalkeeper_team_ids = []
    if len(goalkeepers) == 0: return np.array([], dtype=int)

    frame_height, frame_width, _ = frame.shape
    valid_team_colors = team_a_color is not None and team_b_color is not None

    for i in range(len(goalkeepers)):
        gk_detection_single = goalkeepers[i:i+1] # Process one GK at a time
        # Get center x-coordinate for positional fallback
        gk_center_x, _ = gk_detection_single.get_anchors_coordinates(sv.Position.CENTER)[0]
        assigned_id = -1 # Default to invalid ID

        # 1. Try Color Similarity if team colors are valid
        if valid_team_colors:
            gk_color = calculate_average_color(frame, gk_detection_single)
            if gk_color is not None:
                dist_a = color_distance(gk_color, team_a_color)
                dist_b = color_distance(gk_color, team_b_color)
                # Assign if colors are distinct enough
                if abs(dist_a - dist_b) > color_similarity_threshold:
                    assigned_id = TEAM_A_ID if dist_a < dist_b else TEAM_B_ID
                # else: Color is ambiguous, proceed to fallback

        # 2. Positional Fallback (if color failed or was ambiguous)
        if assigned_id == -1:
            # Assign based on which half of the pitch they are on
            # Assumes Team A (ID 0) defends left goal, Team B (ID 1) defends right
            assigned_id = TEAM_A_ID if gk_center_x < frame_width / 2 else TEAM_B_ID

        goalkeeper_team_ids.append(assigned_id)

    return np.array(goalkeeper_team_ids, dtype=int)


# ----- OCR Function -----
def perform_ocr_on_crop(crop: np.ndarray) -> tuple[str | None, float | None]:
    """Performs OCR on a given crop, returning the best digit sequence and confidence."""
    if not PADDLEOCR_AVAILABLE or ocr_model is None or crop.size == 0:
        return None, None # Return None if OCR not available or crop is empty
    try:
        # Perform OCR using the initialized model
        result = ocr_model.ocr(crop, cls=False) # cls=False might improve speed for digits
        best_num, highest_conf = None, 0.0

        # Process results: PaddleOCR typically returns a list of lists
        # Each inner list can contain [box_coordinates, (text, confidence)]
        if result and result[0]: # Check if result is not empty
             for res_item in result[0]:
                 # Ensure the structure is as expected
                 if len(res_item) == 2 and isinstance(res_item[1], tuple) and len(res_item[1]) == 2:
                     text, confidence = res_item[1]
                     # Validate the extracted text and confidence
                     if (isinstance(text, str) and text.isdigit() and
                         MIN_JERSEY_DIGITS <= len(text) <= MAX_JERSEY_DIGITS and
                         isinstance(confidence, (float, int)) and
                         confidence > OCR_CONFIDENCE_THRESHOLD):
                         # Keep the result with the highest confidence
                         if confidence > highest_conf:
                             highest_conf, best_num = confidence, text
        # Return the best number found, or None if no valid number met the threshold
        return best_num, highest_conf if best_num else None
    except Exception as e:
        # Log errors during OCR inference
        print(f"Error during PaddleOCR inference: {e}")
        return None, None


In [12]:
# ----- Interaction Tracking Class -----
class InteractionTracker:
    """
    Tracks interactions between ball and players based on IoU,
    manages confirmed interactions, and buffers frames for clip extraction.
    """
    def __init__(self, fps):
        self.fps = fps if fps > 0 else 30 # Use default FPS if invalid
        # Stores active interactions: key=(ball_id, player_id) -> {details}
        self.active_interactions = defaultdict(lambda: {
            'start_frame': None, 'end_frame': None,
            'ball_id': None, 'player_id': None, 'active': False
        })
        # List to store details of confirmed interactions
        self.confirmed_interactions = []
        # Calculate buffer size based on padding and a safety margin
        buffer_seconds = (CLIP_PADDING_SECONDS * 2) + 5 # e.g., 2s before + 2s after + 5s margin
        max_buffer_frames = int(buffer_seconds * self.fps) if self.fps > 0 else 150 # Default if FPS invalid
        # Deque to efficiently store recent frames (frame_number, frame_image)
        self.frame_buffer = deque(maxlen=max_buffer_frames)
        # Keep track of interaction indices already processed (for saving frames)
        self.processed_interaction_indices = set()

    def update(self, frame_number, current_frame, interactions):
        """
        Updates interaction status based on current frame's interactions
        and stores the frame in the buffer.

        Args:
            frame_number (int): The current frame number.
            current_frame (np.ndarray): The current video frame (original, pre-annotation).
            interactions (list): List of (ball_tracker_id, player_tracker_id) tuples
                                 representing interactions in the current frame.
        """
        # Store a *copy* of the original frame to avoid modification by later annotation steps
        self.frame_buffer.append((frame_number, current_frame.copy()))

        # Set of active (ball_id, player_id) keys in the current frame
        current_interaction_keys = set(interactions)

        # --- Update Active Interactions ---
        for ball_id, player_id in current_interaction_keys:
            key = (ball_id, player_id)
            if not self.active_interactions[key]['active']:
                # Start tracking a new potential interaction
                self.active_interactions[key] = {
                    'start_frame': frame_number, 'end_frame': frame_number,
                    'ball_id': ball_id, 'player_id': player_id, 'active': True
                }
            else:
                # Continue an existing interaction, update the end frame
                self.active_interactions[key]['end_frame'] = frame_number

        # --- Check for Ended Interactions ---
        # Find keys that were active previously but are not in the current frame
        ended_keys = [k for k in self.active_interactions
                      if self.active_interactions[k]['active'] and k not in current_interaction_keys]

        for key in ended_keys:
            interaction_data = self.active_interactions[key]
            # Mark as inactive *before* checking duration
            self.active_interactions[key]['active'] = False

            # Calculate interaction duration (inclusive)
            duration = (interaction_data['end_frame'] - interaction_data['start_frame']) + 1

            # Check if the interaction meets the minimum duration threshold
            if duration >= MIN_INTERACTION_FRAMES:
                # Calculate padding in frames
                padding_frames = int(CLIP_PADDING_SECONDS * self.fps) if self.fps > 0 else 0
                # Calculate clip start/end frames including padding
                clip_start_frame = max(0, interaction_data['start_frame'] - padding_frames)
                # Pad *after* the last frame the interaction was seen
                clip_end_frame = interaction_data['end_frame'] + padding_frames

                # Store the confirmed interaction details
                self.confirmed_interactions.append({
                    'ball_id': interaction_data['ball_id'],
                    'player_id': interaction_data['player_id'],
                    'interaction_start_frame': interaction_data['start_frame'],
                    'interaction_end_frame': interaction_data['end_frame'],
                    'clip_start_frame': clip_start_frame,
                    'clip_end_frame': clip_end_frame,
                })
                # Optional: Log confirmed interaction
                # print(f"Confirmed interaction: Ball {interaction_data['ball_id']} / Player {interaction_data['player_id']} "
                #       f"(Frames {interaction_data['start_frame']}-{interaction_data['end_frame']}, Duration: {duration} frames)")

    def get_clip_frames_data(self, interaction_index):
        """
        Retrieves frame numbers and frame images for a specific confirmed interaction,
        ensuring it hasn't been processed before for saving.

        Args:
            interaction_index (int): The index of the confirmed interaction.

        Returns:
            list or None: A list of (frame_number, frame_image) tuples for the clip,
                          or None if the index is invalid or already processed.
        """
        # Check if index is valid and not already processed
        if interaction_index in self.processed_interaction_indices or \
           interaction_index >= len(self.confirmed_interactions):
             return None # Invalid index or already handled

        interaction = self.confirmed_interactions[interaction_index]
        start_frame = interaction['clip_start_frame']
        end_frame = interaction['clip_end_frame']

        # Extract relevant frames from the buffer based on calculated clip range
        clip_frames_data = [(f_num, frame) for (f_num, frame) in self.frame_buffer
                           if start_frame <= f_num <= end_frame]

        # Handle cases where buffer might not contain all needed frames (e.g., very long interactions)
        if not clip_frames_data:
            print(f"Warning: No frames found in buffer for interaction {interaction_index} "
                  f"(Requested frames {start_frame}-{end_frame}). Buffer might be too small "
                  f"or interaction happened too long ago.")
            # Mark as processed even if no frames found to avoid retrying
            self.processed_interaction_indices.add(interaction_index)
            return None

        # Mark this interaction index as processed to prevent duplicate saving/analysis
        self.processed_interaction_indices.add(interaction_index)
        return clip_frames_data


In [13]:
# ----- Frame Saving Function -----
def save_interaction_frames(clip_frames_data, base_dir, interaction_info, original_fps) -> str | None:
    """
    Saves selected frames from clip data as images, sampled at target FPS.

    Args:
        clip_frames_data (list): List of (frame_number, frame_image) tuples.
        base_dir (str): The root directory to save action frame sequences.
        interaction_info (dict): Dictionary containing 'index', 'ball_id', 'player_id'.
        original_fps (float): The FPS of the source video.

    Returns:
        str | None: The path to the directory where frames were saved, or None on failure.
    """
    if not clip_frames_data:
        # print(f"Info: No frames to save for interaction {interaction_info.get('index', 'N/A')}.")
        return None # Nothing to save

    # Create a unique directory name for this interaction clip sequence
    clip_dir_name = (f"action_{interaction_info.get('index', 'unk'):04d}_"
                     f"b{interaction_info.get('ball_id', 'unk')}_"
                     f"p{interaction_info.get('player_id', 'unk')}")
    clip_path = os.path.join(base_dir, clip_dir_name)
    try:
        # Create the directory, including parent directories if needed
        os.makedirs(clip_path, exist_ok=True)
    except OSError as e:
        print(f"Error creating directory {clip_path}: {e}")
        return None # Cannot save frames if directory creation fails

    # Calculate the frame step needed to achieve the target save FPS
    if FRAMES_PER_SECOND_TO_SAVE <= 0 or original_fps <= 0:
        frame_step = 1 # Save all frames if target/original FPS is invalid
        # print("Warning: Saving all frames for clip due to invalid target/original FPS.")
    else:
        # Ensure step is at least 1
        frame_step = max(1, round(original_fps / FRAMES_PER_SECOND_TO_SAVE))

    saved_count = 0
    # Iterate through the clip frames using the calculated step
    for i in range(0, len(clip_frames_data), frame_step):
        frame_number, frame_image = clip_frames_data[i]
        # Use the original frame number in the filename for traceability
        frame_filename = f"frame_{frame_number:06d}.jpg" # Using jpg for potentially smaller size
        frame_filepath = os.path.join(clip_path, frame_filename)

        try:
            # Save the frame as a JPG image with decent quality
            cv2.imwrite(frame_filepath, frame_image, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
            saved_count += 1
        except Exception as e:
            # Log error but continue trying to save other frames
            print(f"Error saving frame {frame_number} to {frame_filepath}: {e}")

    if saved_count > 0:
        # print(f"Saved {saved_count} frames for interaction {interaction_info.get('index', 'N/A')} "
        #       f"to {clip_path} (saved every {frame_step} frame(s))")
        return clip_path # Return the path where frames were saved
    else:
         print(f"Info: No frames were successfully saved for interaction {interaction_info.get('index', 'N/A')}")
         # Optionally remove the empty directory
         try: os.rmdir(clip_path)
         except OSError: pass
         return None


In [14]:
# ----- Groq API Interaction Functions -----

def encode_image(image_path: str) -> str | None:
    """Encodes a single image file to a base64 string."""
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError:
        print(f"Error: Image file not found at {image_path}")
        return None
    except Exception as e:
        print(f"Error encoding image {image_path}: {e}")
        return None

def analyze_action_frames(clip_dir: str, groq_client: Groq, player_id: int | str, ball_id: int | str):
    """
    Sends frames from a directory to Groq VLM and prints the analysis.

    Args:
        clip_dir (str): Path to the directory containing saved frame images (.jpg).
        groq_client (Groq): Initialized Groq API client.
        player_id (int | str): The tracker ID of the player involved in the interaction.
        ball_id (int | str): The tracker ID (or placeholder) of the ball involved.
    """
    print(f"\n--- Analyzing action for Player {player_id} (Interaction Dir: {os.path.basename(clip_dir)}) ---")
    try:
        image_files = sorted([os.path.join(clip_dir, f) for f in os.listdir(clip_dir) if f.lower().endswith(".jpg")])
        if not image_files:
            print("No image files found in the directory.")
            return

        # Prepare message content with text prompt and image URLs
        content = [{"type": "text", "text": f"Describe the action performed by the player (Tracker ID: {player_id}) interacting with the ball (ID: {ball_id}) in this sequence of frames."}]
        print(f"Sending {len(image_files)} frames to Groq VLM...")

        for image_path in image_files:
            base64_image = encode_image(image_path)
            if base64_image:
                content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                })
            else:
                print(f"Skipping invalid image: {image_path}")

        if len(content) <= 1: # Only text prompt, no valid images
             print("No valid images could be encoded for analysis.")
             return

        # Make the API call
        chat_completion = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": content}],
            model=GROQ_VLM_MODEL,
            # max_tokens=1024, # Optional: Adjust token limit if needed
            # temperature=0.2 # Optional: Adjust temperature for creativity vs factuality
        )

        # Print the response
        analysis = chat_completion.choices[0].message.content
        print(f"Groq VLM Analysis:\n{analysis}")

    except Exception as e:
        print(f"Error during Groq API call for {clip_dir}: {e}")
        traceback.print_exc()
    print("--- Analysis complete ---")


In [15]:
# ----- Annotation Parameters -----
ELLIPSE_THICKNESS = 1
LABEL_TEXT_COLOR = sv.Color.BLACK
LABEL_TEXT_POSITION = sv.Position.BOTTOM_CENTER
LABEL_TEXT_SCALE = 0.4
LABEL_TEXT_THICKNESS = 1
BALL_TRAIL_BASE_COLOR = (0, 255, 255) # Bright Cyan (BGR)
BALL_TRAIL_THICKNESS = 1
SPARKLE_BASE_INTENSITY = 150
SPARKLE_MAX_INTENSITY = 255
CURRENT_BALL_MARKER_RADIUS = 4
CURRENT_BALL_MARKER_COLOR = (255, 255, 255) # White (BGR)
CURRENT_BALL_MARKER_THICKNESS = -1 # Filled circle


In [16]:
# ----- Global State -----
player_data = {} # Stores data per track_id: {jersey_id, confidence, last_seen, team_id, mismatch_history}
recently_lost_jerseys = defaultdict(lambda: deque(maxlen=10)) # jersey_num -> list of lost track info
ball_positions = None # Deque for ball trail positions (will be initialized in main)

# ----- Frame Processing Function -----
def process_frame(
    frame: np.ndarray,
    frame_idx: int,
    tracker: BotSort,
    interaction_tracker: InteractionTracker,
    width: int, # Pass frame dimensions
    height: int # Pass frame dimensions
    ):
    """Processes a single frame for detection, tracking, OCR, interaction, and annotation."""
    # Access global state (or pass models/state as arguments if preferred)
    global player_data, recently_lost_jerseys, ball_positions, detection_model, team_classifier

    # 1. Detection
    results = detection_model.predict(frame, conf=0.3, iou=0.5, device=DEVICE, verbose=False)
    if not results or len(results) == 0:
        interaction_tracker.update(frame_idx, frame, []) # Update tracker even if no detections
        return frame # Return original frame

    detections = sv.Detections.from_ultralytics(results[0])

    # 2. Pre-processing & Ball Position Update
    ball_detections = detections[detections.class_id == BALL_ID]
    people_detections = detections[detections.class_id != BALL_ID]

    # Update ball trail deque (check if initialized)
    if ball_positions is not None and len(ball_detections) > 0 :
        x1, y1, x2, y2 = ball_detections.xyxy[0] # Assume single ball
        current_center = (int((x1 + x2) / 2), int((y1 + y2) / 2))
        is_valid_position = True # Check for outliers
        if len(ball_positions) > 0:
            prev_center = ball_positions[-1]
            if isinstance(prev_center, tuple) and len(prev_center) == 2:
                distance = math.dist(current_center, prev_center)
                if distance > MAX_BALL_DISTANCE_PER_FRAME: is_valid_position = False
        if is_valid_position: ball_positions.append(current_center)

    # 3. Team/Role Classification
    players_detections = people_detections[people_detections.class_id == PLAYER_ID]
    goalkeepers_detections = people_detections[people_detections.class_id == GOALKEEPER_ID]
    referees_detections = people_detections[people_detections.class_id == REFEREE_ID]

    # --- Player Classification ---
    classified_players = sv.Detections.empty()
    if len(players_detections) > 0:
        players_crops = []; valid_indices = []
        for i, xyxy in enumerate(players_detections.xyxy):
            crop = sv.crop_image(frame, xyxy);
            if crop is not None and crop.size > 0: players_crops.append(crop); valid_indices.append(i)
        if players_crops:
            predicted_team_ids = team_classifier.predict(players_crops)
            if predicted_team_ids is not None and len(predicted_team_ids) == len(players_crops):
                assigned_ids = np.full(len(players_detections), -1, dtype=int) # Default -1
                for i, pred_id in enumerate(predicted_team_ids): assigned_ids[valid_indices[i]] = pred_id
                valid_classification_mask = (assigned_ids != -1) # Filter out failed classifications
                players_detections.class_id = assigned_ids # Update class IDs
                classified_players = players_detections[valid_classification_mask]

    # --- Calculate Dynamic Team Colors ---
    team_a_detections = classified_players[classified_players.class_id == TEAM_A_ID]
    team_b_detections = classified_players[classified_players.class_id == TEAM_B_ID]
    referee_only_detections = referees_detections # Use original referee detections for color calculation
    current_team_a_color = calculate_average_color(frame, team_a_detections) or DEFAULT_TEAM_A_COLOR
    current_team_b_color = calculate_average_color(frame, team_b_detections) or DEFAULT_TEAM_B_COLOR
    current_referee_color = calculate_average_color(frame, referee_only_detections) or DEFAULT_REFEREE_COLOR
    # Map final team/role IDs to their current average color
    dynamic_color_map = {
        TEAM_A_ID: current_team_a_color,
        TEAM_B_ID: current_team_b_color,
        REFEREE_TEAM_ID: current_referee_color
    }

    # --- Goalkeeper Classification ---
    classified_gks = sv.Detections.empty()
    if len(goalkeepers_detections) > 0:
        gk_team_ids = resolve_goalkeepers_team_id(frame, goalkeepers_detections, current_team_a_color, current_team_b_color)
        if gk_team_ids is not None and len(gk_team_ids) == len(goalkeepers_detections):
            valid_gk_mask = (gk_team_ids != -1) # Ensure resolution was successful
            goalkeepers_detections.class_id = gk_team_ids # Update class IDs to TEAM_A_ID or TEAM_B_ID
            classified_gks = goalkeepers_detections[valid_gk_mask]

    # --- Referee Classification ---
    classified_refs = sv.Detections.empty()
    if len(referees_detections) > 0:
        ref_team_ids = np.full(len(referees_detections), REFEREE_TEAM_ID) # Assign specific referee ID
        referees_detections.class_id = ref_team_ids
        classified_refs = referees_detections

    # --- Merge Detections for Tracking ---
    # Combine all classified entities (players, GKs, refs) before sending to tracker
    detections_to_track = sv.Detections.merge([classified_players, classified_gks, classified_refs])

    # 4. Tracking using BoTSORT
    tracked_people = sv.Detections.empty() # Detections with tracker IDs assigned
    current_frame_tracker_ids = set() # Keep track of IDs present in this frame
    if len(detections_to_track) > 0 and tracker is not None:
        # Prepare input for BoTSORT: [x1, y1, x2, y2, conf, cls]
        # Use the *final* classified IDs (TEAM_A_ID, TEAM_B_ID, REFEREE_TEAM_ID)
        boxmot_input = np.hstack((
            detections_to_track.xyxy,
            detections_to_track.confidence[:, np.newaxis],
            detections_to_track.class_id[:, np.newaxis]
        ))
        try:
            # Update tracker state
            tracks = tracker.update(boxmot_input, frame) # Pass original frame for ReID
            if tracks.shape[0] > 0:
                # Process tracker output: [x1, y1, x2, y2, track_id, conf, cls, ...]
                tracked_people = sv.Detections(
                    xyxy=tracks[:, 0:4],
                    tracker_id=tracks[:, 4].astype(int),
                    confidence=tracks[:, 5],
                    class_id=tracks[:, 6].astype(int) # Use the class ID returned by the tracker
                )
                current_frame_tracker_ids = set(tracked_people.tracker_id)
        except Exception as e: print(f"[Frame {frame_idx}] Error during tracker update: {e}")
    elif tracker is not None: # Update tracker even if no detections this frame to advance internal state
         try: tracker.update(np.empty((0, 6)), frame) # Empty update
         except Exception as e: print(f"[Frame {frame_idx}] Error updating tracker with empty input: {e}")

    # --- Use DETECTED ball for interaction checks (simpler than tracking ball) ---
    tracked_ball = ball_detections

    # 5. Interaction Detection (using TRACKED people and DETECTED ball)
    current_interactions = [] # List of (ball_id, player_id) tuples for this frame
    placeholder_ball_id = -1 # Use a placeholder as ball isn't tracked with persistent ID here

    if len(tracked_ball) > 0 and len(tracked_people) > 0:
        ball_box = tracked_ball.xyxy[0:1] # Get the first detected ball's box [1, 4]
        # Consider only players/GKs for interaction (filter out referees)
        player_mask = np.isin(tracked_people.class_id, [TEAM_A_ID, TEAM_B_ID])
        interacting_players = tracked_people[player_mask]

        if len(interacting_players) > 0:
            # Calculate IoU between the ball and all potential players
            iou_matrix = sv.box_iou_batch(ball_box, interacting_players.xyxy) # Shape: (1, num_players)
            # Find indices of players whose IoU with the ball exceeds the threshold
            interacting_player_indices = np.where(iou_matrix[0] > IOU_THRESHOLD)[0]

            # Record interactions using player's tracker ID
            for player_idx in interacting_player_indices:
                player_tracker_id = interacting_players.tracker_id[player_idx]
                current_interactions.append((placeholder_ball_id, player_tracker_id))

    # 6. Update Interaction Tracker (Pass ORIGINAL frame before annotation)
    interaction_tracker.update(frame_idx, frame, current_interactions)

    # 7. OCR and Player ID Management (Label Generation)
    final_labels = [] # List to store display labels for each tracked person
    current_player_data = {} # Stores updated player data for this frame

    if len(tracked_people) > 0:
        # Get frame dimensions (needed for cropping checks)
        frame_height, frame_width, _ = frame.shape

        for i in range(len(tracked_people)):
            track_id = tracked_people.tracker_id[i]
            team_id = tracked_people.class_id[i] # Final team/role ID from tracker
            bbox = tracked_people.xyxy[i]
            # Ensure bbox coordinates are valid integers within frame bounds
            x1, y1, x2, y2 = map(int, bbox)
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(frame_width, x2), min(frame_height, y2) # Use frame dimensions

            detected_jersey_num, ocr_confidence = None, None
            # Attempt OCR only if bbox is valid and OCR is available
            if x1 < x2 and y1 < y2 and PADDLEOCR_AVAILABLE:
                 player_crop = frame[y1:y2, x1:x2]
                 if player_crop.size > 0: # Check crop is not empty
                    # Convert crop to grayscale for potentially better OCR
                    gray_crop = cv2.cvtColor(player_crop, cv2.COLOR_BGR2GRAY)
                    # Apply some basic preprocessing (optional, might help OCR)
                    # gray_crop = cv2.threshold(gray_crop, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
                    # gray_crop = cv2.medianBlur(gray_crop, 3)
                    detected_jersey_num, ocr_confidence = perform_ocr_on_crop(gray_crop)
                    # --- Optional: Save debug crops ---
                    # if detected_jersey_num is not None:
                    #     try:
                    #         debug_filename = os.path.join(OCR_DEBUG_DIR, f"f{frame_idx}_t{track_id}_ocr_{detected_jersey_num}.png")
                    #         cv2.imwrite(debug_filename, gray_crop)
                    #     except Exception as write_e: print(f"Error saving OCR debug crop: {write_e}")

            # --- Update player data based on OCR result and history ---
            assigned_jersey_id = None # Jersey ID to display for this track in this frame
            if track_id in player_data:
                # Track already exists, update its data
                p_data = player_data[track_id]
                p_data["last_seen"] = frame_idx
                p_data["team_id"] = team_id # Update team ID (usually shouldn't change)
                current_jersey_id = p_data["jersey_id"] # Stored jersey ID
                mismatch_history = p_data["mismatch_history"]

                if detected_jersey_num is not None:
                    # OCR successful this frame
                    if current_jersey_id is None or detected_jersey_num == current_jersey_id:
                        # First time seeing number, or it matches the stored one
                        p_data["jersey_id"] = detected_jersey_num
                        p_data["jersey_confidence"] = ocr_confidence
                        mismatch_history.clear() # Reset mismatch counter
                    else:
                        # Mismatch detected! Current OCR differs from stored ID
                        mismatch_history.append(detected_jersey_num)
                        # Check if mismatch is consistent over several frames
                        if len(mismatch_history) >= MISMATCH_CONSISTENCY_FRAMES and all(num == detected_jersey_num for num in mismatch_history):
                            # Consistent mismatch, update the stored jersey ID
                            # print(f"[Frame {frame_idx}] Track {track_id}: Jersey ID updated from {current_jersey_id} to {detected_jersey_num} (Consistent mismatch).")
                            p_data["jersey_id"] = detected_jersey_num
                            p_data["jersey_confidence"] = ocr_confidence
                            mismatch_history.clear() # Reset after update
                        # else: Inconsistent mismatch, keep stored ID for now
                else:
                    # OCR failed this frame, clear mismatch history
                    mismatch_history.clear()
                assigned_jersey_id = p_data["jersey_id"] # Use the potentially updated stored ID for display
                current_player_data[track_id] = p_data # Store updated data for this frame

            else:
                # New track ID encountered
                found_match = False
                # Try to match with recently lost tracks using jersey number (if OCR succeeded)
                if detected_jersey_num is not None and detected_jersey_num in recently_lost_jerseys:
                    potential_matches = []
                    # Check lost tracks with the same jersey number
                    for lost_track_info in reversed(recently_lost_jerseys[detected_jersey_num]):
                        time_diff = frame_idx - lost_track_info["last_seen"]
                        # Check if within time window and team ID matches
                        if time_diff < LOST_TRACK_MEMORY_FRAMES and lost_track_info["team_id"] == team_id:
                            potential_matches.append((lost_track_info, time_diff))

                    if potential_matches:
                        # Found potential matches, pick the most recent one
                        potential_matches.sort(key=lambda x: x[1]) # Sort by time difference
                        best_match_info, _ = potential_matches[0]
                        assigned_jersey_id = detected_jersey_num
                        # Re-establish player data using the matched info
                        p_data = {
                            "jersey_id": assigned_jersey_id, "jersey_confidence": ocr_confidence,
                            "last_seen": frame_idx, "team_id": team_id,
                            "mismatch_history": deque(maxlen=MISMATCH_CONSISTENCY_FRAMES) # Reset history
                        }
                        current_player_data[track_id] = p_data # Assign data to the *new* track ID
                        # print(f"[Frame {frame_idx}] Track {track_id}: Re-identified Jersey #{assigned_jersey_id} (was lost track {best_match_info['tracker_id']}).")
                        # Remove the matched entry from recently_lost list
                        try: recently_lost_jerseys[detected_jersey_num].remove(best_match_info)
                        except ValueError: pass # Ignore if already removed
                        found_match = True

                if not found_match:
                    # No match found, create a new player data entry for this track ID
                    assigned_jersey_id = detected_jersey_num # Could be None if OCR failed
                    current_player_data[track_id] = {
                        "jersey_id": assigned_jersey_id,
                        "jersey_confidence": ocr_confidence if detected_jersey_num is not None else None,
                        "last_seen": frame_idx, "team_id": team_id,
                        "mismatch_history": deque(maxlen=MISMATCH_CONSISTENCY_FRAMES)
                    }

            # --- Generate Display Label ---
            # Determine team prefix based on final assigned team/role ID
            if team_id == TEAM_A_ID: team_prefix = "T1"
            elif team_id == TEAM_B_ID: team_prefix = "T2"
            elif team_id == REFEREE_TEAM_ID: team_prefix = "Ref"
            else: team_prefix = f"T{team_id}" # Fallback for unexpected IDs

            base_label = f"{team_prefix} P{track_id}" # Base label: Team P<TrackID>
            display_id = base_label
            # Add jersey number if available and assigned
            if assigned_jersey_id is not None:
                display_id = f"{base_label} #{assigned_jersey_id}"

            final_labels.append(display_id) # Add the final label for this track

    # 8. Update Global Player Data & Handle Lost Tracks
    # Find tracks that were present in the previous frame but not this one
    lost_tracker_ids = set(player_data.keys()) - current_frame_tracker_ids
    for lost_id in lost_tracker_ids:
        lost_info = player_data[lost_id]
        # If the lost track had a known jersey ID, add it to the recently lost list
        # This allows matching if it reappears soon with the same number
        if lost_info.get("jersey_id") is not None:
            recently_lost_jerseys[lost_info["jersey_id"]].append({
                "tracker_id": lost_id, # Store the original tracker ID that was lost
                "last_seen": lost_info["last_seen"],
                "team_id": lost_info["team_id"]
            })

    # Periodic cleanup of very old entries in recently_lost_jerseys (e.g., every minute)
    # This prevents the dictionary from growing indefinitely
    current_fps = interaction_tracker.fps # Get FPS from interaction tracker
    if current_fps > 0 and frame_idx > 0 and frame_idx % (int(current_fps) * 60) == 0:
        # print(f"[Frame {frame_idx}] Cleaning up old lost tracks...")
        for jersey_num in list(recently_lost_jerseys.keys()): # Iterate over keys copy
            q = recently_lost_jerseys[jersey_num]
            # Keep only entries seen within twice the memory window (generous buffer)
            valid_entries = deque([entry for entry in q if (frame_idx - entry["last_seen"]) < LOST_TRACK_MEMORY_FRAMES * 2], maxlen=10)
            if valid_entries:
                recently_lost_jerseys[jersey_num] = valid_entries
            else:
                # Remove jersey number if no recent tracks associated with it
                del recently_lost_jerseys[jersey_num]

    # Update the global player data state with the data processed in this frame
    player_data = current_player_data

    # 9. Annotation
    annotated_frame = frame.copy() # Start annotation on a clean copy of the original frame

    # --- Annotate "Magical" Ball Trail ---
    if ball_positions is not None and len(ball_positions) >= 2:
        num_points = len(ball_positions)
        for i in range(1, num_points):
            pt1 = ball_positions[i-1]; pt2 = ball_positions[i]
            # Ensure points are valid tuples before drawing
            if isinstance(pt1, tuple) and isinstance(pt2, tuple) and len(pt1)==2 and len(pt2)==2:
                 # Draw the trail line
                 cv2.line(annotated_frame, pt1, pt2, BALL_TRAIL_BASE_COLOR, BALL_TRAIL_THICKNESS)
                 # --- Add Sparkles (Optional visual flair) ---
                 alpha_fraction = (i - 1) / max(1, num_points - 1) # Fade effect for sparkles
                 sparkle_intensity = int(SPARKLE_BASE_INTENSITY + (SPARKLE_MAX_INTENSITY - SPARKLE_BASE_INTENSITY) * alpha_fraction)
                 sparkle_color = (sparkle_intensity, sparkle_intensity, sparkle_intensity) # Grey sparkles
                 for _ in range(SPARKLE_COUNT):
                     # Add random offset to the end point of the line segment
                     offset_x = random.randint(-SPARKLE_OFFSET, SPARKLE_OFFSET)
                     offset_y = random.randint(-SPARKLE_OFFSET, SPARKLE_OFFSET)
                     sparkle_pt = (pt2[0] + offset_x, pt2[1] + offset_y)
                     # Draw small filled circle as a sparkle
                     cv2.circle(annotated_frame, sparkle_pt, SPARKLE_RADIUS, sparkle_color, -1)

    # --- Annotate Current Ball Position ---
    if ball_positions is not None and len(ball_positions) > 0:
         last_pos = ball_positions[-1] # Get the most recent position
         if isinstance(last_pos, tuple) and len(last_pos)==2:
              # Draw a marker (e.g., white filled circle) at the current ball position
              cv2.circle(annotated_frame, last_pos, CURRENT_BALL_MARKER_RADIUS, CURRENT_BALL_MARKER_COLOR, CURRENT_BALL_MARKER_THICKNESS)

    # --- Annotate Tracked People (Ellipses and Labels) ---
    # Ensure we have the same number of labels as tracked people
    if len(tracked_people) > 0 and len(final_labels) == len(tracked_people):
        # Group by team ID for consistent color annotation per team/role
        unique_team_ids = np.unique(tracked_people.class_id)
        for current_team_id in unique_team_ids:
            # Create mask for current team/role
            team_mask = (tracked_people.class_id == current_team_id)
            # Get detections and labels for this team/role
            team_detections = tracked_people[team_mask]
            team_labels = [label for i, label in enumerate(final_labels) if team_mask[i]]

            if len(team_detections) == 0: continue # Skip if no detections for this team

            # Get the dynamic color calculated earlier for this team/role ID
            team_color = dynamic_color_map.get(current_team_id, FALLBACK_COLOR) # Use fallback if ID unexpected

            # Create temporary annotators with the specific color for this group
            temp_ellipse_annotator = sv.EllipseAnnotator(color=team_color, thickness=ELLIPSE_THICKNESS)
            temp_label_annotator = sv.LabelAnnotator(
                color=team_color, text_color=LABEL_TEXT_COLOR,
                text_position=LABEL_TEXT_POSITION, text_scale=LABEL_TEXT_SCALE,
                text_thickness=LABEL_TEXT_THICKNESS
            )
            # Annotate ellipses and labels for this team/role group
            try:
                annotated_frame = temp_ellipse_annotator.annotate(annotated_frame, team_detections)
                annotated_frame = temp_label_annotator.annotate(annotated_frame, team_detections, team_labels)
            except Exception as e:
                # Log errors during annotation phase
                print(f"[Frame {frame_idx}] Error during annotation for team {current_team_id}: {e}")

    # Return the fully annotated frame
    return annotated_frame


In [17]:
# ----- Main Execution Function -----
def main(args):
    # Make global state variables modifiable within this function
    global player_data, recently_lost_jerseys, ball_positions, detection_model, team_classifier

    # --- Print Configuration ---
    print("--- Configuration ---")
    print(f"Source Video: {args.source}")
    print(f"Annotated Video Output: {args.output}")
    print(f"Action Frames Output Dir: {args.actions_dir}")
    print(f"Using device: {DEVICE}")
    print(f"ReID Weights Path: {args.reid_weights}")
    print(f"OCR Enabled: {PADDLEOCR_AVAILABLE}")
    print(f"Groq API Key Loaded: {'Yes' if GROQ_API_KEY else 'No (VLM Analysis Disabled)'}")
    print(f"Groq VLM Model: {GROQ_VLM_MODEL if GROQ_API_KEY else 'N/A'}")
    print("--------------------")


    # --- Create Output Directories ---
    os.makedirs(os.path.dirname(args.output), exist_ok=True)
    os.makedirs(args.actions_dir, exist_ok=True)
    if PADDLEOCR_AVAILABLE: os.makedirs(OCR_DEBUG_DIR, exist_ok=True)

    # --- Load Models ---
    # In a real scenario, load your trained models here
    print("Loading models...")
    try:
        # Replace MockModel and MockTeamClassifier with your actual loading logic
        # Example: detection_model = load_yolo(args.model_path)
        # Example: team_classifier = load_classifier(args.classifier_path)
        detection_model = MockModel() # Using placeholder
        team_classifier = MockTeamClassifier() # Using placeholder
        print("Models loaded (using placeholders).")
    except Exception as e:
        print(f"Error loading models: {e}. Exiting.")
        exit(1)

    # --- Get Video Info ---
    print("Getting video info...")
    try:
        # Use supervision to get video properties
        video_info = sv.VideoInfo.from_video_path(str(args.source))
        width, height, fps = video_info.width, video_info.height, video_info.fps
        total_frames = video_info.total_frames if video_info.total_frames else 0
        if not fps or fps <= 0: # Handle cases where FPS might be missing/invalid
             print("Warning: Invalid FPS from supervision. Trying OpenCV.")
             raise ValueError("Invalid FPS")
    except Exception as e_sv:
        print(f"Warning: Could not get video info via supervision ({e_sv}). Using OpenCV fallback.")
        try:
            cap = cv2.VideoCapture(str(args.source))
            if not cap.isOpened(): raise IOError(f"Cannot open video file: {args.source}")
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            cap.release()
            if not fps or fps <= 0: fps = 30; print("Warning: OpenCV also failed to get valid FPS. Assuming 30 FPS.")
            if total_frames <= 0: total_frames = 0; print("Warning: Total frames unknown.")
            # Create a VideoInfo object for VideoSink based on OpenCV results
            video_info = sv.VideoInfo(width=width, height=height, fps=fps)
        except Exception as e_cv:
             print(f"FATAL: Error getting video info via OpenCV fallback: {e_cv}. Exiting.")
             exit(1)

    print(f"Video Info: {width}x{height}, FPS: {fps:.2f}, Total Frames: {total_frames if total_frames > 0 else 'Unknown'}")

    # --- Initialize Tracker, InteractionTracker, Ball Trail ---
    print("Initializing trackers...")
    reid_weights_path = Path(args.reid_weights)
    tracker = BotSort(
        reid_weights=reid_weights_path if reid_weights_path.exists() else None,
        device=DEVICE,
        half=False, # Adjust based on your model/GPU (True for FP16)
        with_reid=reid_weights_path.exists(),
        frame_rate=fps if fps > 0 else 30 # Provide frame rate to tracker
    )
    print(f"BoTSORT initialized. ReID enabled: {reid_weights_path.exists()}")

    interaction_tracker = InteractionTracker(fps=fps)
    print(f"InteractionTracker initialized. Buffer size: {interaction_tracker.frame_buffer.maxlen} frames.")

    # Initialize ball trail deque
    if fps > 0:
        trail_maxlen = int(fps * BALL_TRAIL_SECONDS)
        ball_positions = deque(maxlen=trail_maxlen)
        print(f"Ball trail deque initialized with maxlen={trail_maxlen}")
    else:
        ball_positions = deque(maxlen=1) # Minimal deque if FPS unknown
        print("Warning: FPS unknown. Ball trail length may be inaccurate.")

    # Update global LOST_TRACK_MEMORY_FRAMES based on actual FPS
    global LOST_TRACK_MEMORY_FRAMES
    LOST_TRACK_MEMORY_FRAMES = int(fps * LOST_TRACK_MEMORY_SECONDS) if fps > 0 else 30 * LOST_TRACK_MEMORY_SECONDS
    print(f"Lost track memory set to {LOST_TRACK_MEMORY_FRAMES} frames ({LOST_TRACK_MEMORY_SECONDS} seconds)")

    # --- Initialize Groq Client (if API key is available) ---
    groq_client = None
    if GROQ_API_KEY:
        try:
            groq_client = Groq(api_key=GROQ_API_KEY)
            print("Groq client initialized.")
        except Exception as e:
            print(f"Error initializing Groq client: {e}. VLM analysis will be disabled.")
            GROQ_API_KEY = None # Disable Groq if client fails
    else:
        print("GROQ_API_KEY not set. VLM analysis will be disabled.")

    # --- Video Processing Loop ---
    print("Starting video processing loop...")
    frame_generator = sv.get_video_frames_generator(source_path=str(args.source), stride=1)
    # Use VideoSink for efficient video writing
    with sv.VideoSink(target_path=args.output, video_info=video_info) as sink:
        # Setup progress bar if total frames are known
        tqdm_total = total_frames if total_frames and total_frames > 0 else None
        with tqdm(total=tqdm_total, desc="Processing video", unit="frame", mininterval=1.0) as pbar:
            for frame_idx, frame in enumerate(frame_generator):
                if frame is None:
                    print(f"\nWarning: Received None frame at index {frame_idx}, ending processing.")
                    break # Stop processing if a frame is None
                try:
                    # Process the current frame
                    annotated_frame = process_frame(frame, frame_idx, tracker, interaction_tracker, width, height)
                    # Write the *annotated* frame to the output video file
                    sink.write_frame(frame=annotated_frame)

                except Exception as e:
                    # Catch errors during frame processing
                    print(f"\n--- CRITICAL ERROR processing frame {frame_idx}: {e} ---")
                    traceback.print_exc() # Print detailed traceback
                    print("Attempting to continue processing...")
                    # Optionally write the original frame to avoid stopping video output
                    # sink.write_frame(frame=frame)
                # Update progress bar
                if pbar is not None: pbar.update(1)

    # --- Post-Processing: Save Interaction Frames & Analyze with Groq ---
    print(f"\nFinished processing video frames.")
    confirmed_interactions_count = len(interaction_tracker.confirmed_interactions)
    print(f"Found {confirmed_interactions_count} confirmed interactions.")

    if confirmed_interactions_count > 0:
        print("Saving interaction frame sequences and performing VLM analysis (if enabled)...")
        saved_clips_count = 0
        analyzed_clips_count = 0
        # Iterate through all confirmed interactions using their index
        for idx in range(confirmed_interactions_count):
            # Retrieve frames for this specific interaction index
            # This also marks the interaction as processed in the tracker
            clip_frames_data = interaction_tracker.get_clip_frames_data(idx)

            if clip_frames_data:
                 # Get interaction details for naming and analysis
                 interaction = interaction_tracker.confirmed_interactions[idx]
                 interaction_info = {
                     'index': idx,
                     'ball_id': interaction['ball_id'], # Note: This is the placeholder ID (-1)
                     'player_id': interaction['player_id'] # This is the player's tracker ID
                 }
                 # Save the frames as images, get the directory path if successful
                 saved_clip_path = save_interaction_frames(
                     clip_frames_data, args.actions_dir, interaction_info, fps
                 )

                 if saved_clip_path:
                     saved_clips_count += 1
                     # --- Call Groq API if enabled and save was successful ---
                     if groq_client:
                         try:
                             analyze_action_frames(
                                 clip_dir=saved_clip_path,
                                 groq_client=groq_client,
                                 player_id=interaction_info['player_id'],
                                 ball_id=interaction_info['ball_id']
                             )
                             analyzed_clips_count += 1
                         except Exception as e_groq:
                             print(f"Error during Groq analysis for {saved_clip_path}: {e_groq}")
                 # else: Saving frames failed (error already printed by save_interaction_frames)
            # else: Frame data retrieval failed (error already printed by get_clip_frames_data)

        print(f"\nFinished saving frame sequences for {saved_clips_count} interactions to '{args.actions_dir}'.")
        if GROQ_API_KEY:
             print(f"Attempted VLM analysis for {analyzed_clips_count} interactions.")
    else:
        print("No interactions met the criteria for saving/analysis.")

    print(f"Annotated video saved to: {args.output}")

    # --- Cleanup ---
    if DEVICE.type == 'cuda':
        try:
            torch.cuda.empty_cache()
            print("CUDA cache cleared.")
        except Exception as e:
            print(f"Error clearing CUDA cache: {e}")
    print("Processing complete.")


In [18]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Advanced Football Tracking with Action Frame Extraction and VLM Analysis")
    parser.add_argument("--source", type=str, default=DEFAULT_SOURCE_VIDEO_PATH, help="Path to the source video file.")
    parser.add_argument("--output", type=str, default=DEFAULT_OUTPUT_VIDEO_PATH, help="Path to save the annotated output video.")
    parser.add_argument("--actions_dir", type=str, default=DEFAULT_ACTIONS_DIR, help="Directory to save action frame sequences.")
    parser.add_argument("--reid_weights", type=str, default=DEFAULT_REID_WEIGHTS_PATH, help="Path to the ReID model weights for BoTSORT (e.g., clip_market1501.pt).")
    # Add arguments for your actual model paths if needed, e.g.:
    # parser.add_argument("--model_path", type=str, required=True, help="Path to detection model weights.")
    # parser.add_argument("--classifier_path", type=str, required=True, help="Path to team classifier model weights.")

    # In a notebook environment, you might prefer to define args manually:
    # class Args:
    #     source = "path/to/your/video.mp4"
    #     output = "output/annotated_video.mp4"
    #     actions_dir = "output/action_frames"
    #     reid_weights = "models/clip_market1501.pt"
    #     # model_path = "models/yolo.pt"
    #     # classifier_path = "models/classifier.pt"
    # args = Args()

    # Parse arguments if running as a script
    args = parser.parse_args()

    # --- Run the main processing function ---
    main(args)


usage: ipykernel_launcher.py [-h] [--source SOURCE] [--output OUTPUT]
                             [--actions_dir ACTIONS_DIR]
                             [--reid_weights REID_WEIGHTS]
ipykernel_launcher.py: error: unrecognized arguments: --f=/run/user/1000/jupyter/runtime/kernel-v382f4ee6a98fac3bcf9fffa63abbf8064afb62cbb.json


SystemExit: 2