In [None]:
# ============================================================
# Tennis Player Tracking and Kinematic Analysis Pipeline
#
# Author: Peterson Antonio
# Affiliation: FCMax Performance
#
# Description:
# This notebook implements a complete computer vision pipeline for
# tennis player tracking and motion analysis using YOLOv12, ByteTrack,
# court keypoint detection, homography transformation, and temporal filtering.
#
# The pipeline enables conversion from image-space coordinates (pixels)
# to real-world coordinates (meters), allowing quantitative analysis of:
#
# • Player trajectories
# • Distance covered
# • Velocity
# • Acceleration and deceleration
#
# Core components:
# • YOLOv12 for object detection
# • ByteTrack for multi-object tracking
# • Court keypoint detection for homography estimation
# • Ground-plane coordinate transformation
# • Temporal filtering and interpolation
# • Kinematic metrics computation
#
# Applications:
# • Sports performance analysis
# • Player movement quantification
# • Computer vision research
# • Automated sports analytics systems
#
# ============================================================

In [None]:
# ============================================================
# ENVIRONMENT INITIALIZATION — Environment setup and project loading
#
# This block initializes the Google Colab runtime environment and prepares
# the workspace for execution of the analysis pipeline.
#
# Main objectives:
#
# 1. Mount Google Drive
#    - Provides access to project files, including:
#        • input videos
#        • trained models
#        • configuration files
#        • output directories
#
# 2. Define the project working directory
#    - Ensures all relative paths function correctly
#    - Centralizes execution within the main project folder
#
# 3. Install dependencies
#    - Automatically installs all required libraries as specified
#      in the requirements.txt file
#
# Benefits:
# - Ensures environment reproducibility
# - Prevents import errors and version incompatibilities
# - Allows consistent execution across different Colab sessions
#
# Technical note:
# This block should be executed once at the beginning of the session.
# ============================================================


# === Standard initialization block ===
from google.colab import drive
import os

# 1. Mount Google Drive
drive.mount('/content/drive')

# 2. Define project path (adjust if located elsewhere in your Drive)
project_path = "/content/drive/MyDrive/tennis_analysis"

# 3. Change to project directory
os.chdir(project_path)
print("Current working directory:", os.getcwd())

# 4. Install dependencies from requirements.txt
!pip install -r requirements.txt

In [None]:
# ============================================================
# TRACKER CONFIGURATION — Mode 1: Fixed environment-specific path
#
# This mode loads the default ByteTrack configuration file using an absolute
# filesystem path and creates a modified version with custom tracking parameters.
#
# Characteristics:
#
# • Simple and direct implementation
# • Uses a fixed path to the installed Ultralytics tracker configuration
# • Environment-dependent (Python version, OS, and installation location)
#
# Limitation:
#
# This approach may require manual path adjustment when executed on a different
# system, virtual environment, or Python installation.
#
# Recommended usage:
#
# Suitable for controlled environments such as Google Colab or a fixed local setup.
#
# Output:
#
# Generates a modified ByteTrack configuration file for improved tracking stability.
#
# ============================================================


# Generate a modified ByteTrack configuration file
import yaml

# Path to the default tracker configuration file.
# NOTE: this absolute path embeds the Python version and install location of
# the runtime; it must be adjusted by hand when either of them changes.
tracker_path = "/usr/local/lib/python3.12/dist-packages/ultralytics/cfg/trackers/bytetrack.yaml"

# Load configuration
with open(tracker_path, "r") as f:
    config = yaml.safe_load(f)

# Modify tracking parameters
config["track_buffer"] = 100        # increases track persistence during temporary occlusions
config["new_track_thresh"] = 0.7    # confidence threshold for initializing new tracks

# Save modified configuration file.
# sort_keys=False keeps the key order of the original file, which is also what
# the portable (Mode 2) variant of this cell writes to the same file name.
with open("/content/drive/MyDrive/tennis_analysis/bytetrack_modified.yaml", "w") as f:
    yaml.dump(config, f, sort_keys=False)

In [None]:
# ============================================================
# TRACKER CONFIGURATION — Mode 2: Dynamic environment-independent path (recommended)
#
# This mode automatically detects the installation location of the Ultralytics
# package and accesses the correct ByteTrack configuration file.
#
# Characteristics:
#
# • Fully portable across environments
# • Works on Google Colab, Windows, Linux, virtual environments, and conda
# • Independent of Python version or installation path
# • Does not require manual path adjustments
#
# Recommended usage:
#
# This is the preferred approach for shared notebooks, GitHub repositories,
# and production pipelines where portability and reproducibility are required.
#
# Output:
#
# Generates a modified ByteTrack configuration file with customized tracking
# parameters, saved in the current project directory.
#
# ============================================================


# Generate a modified ByteTrack configuration file (portable mode)

import yaml
import ultralytics
import os

# Locate bytetrack.yaml inside whichever Ultralytics installation is imported
ultra_path = os.path.dirname(ultralytics.__file__)
tracker_path = os.path.join(ultra_path, "cfg", "trackers", "bytetrack.yaml")

print("Using original tracker configuration from:", tracker_path, sep="\n")

# Read the stock tracker settings
with open(tracker_path, "r") as f:
    config = yaml.safe_load(f)

# Override the tracking parameters:
#   track_buffer     -> longer track persistence during temporary occlusions
#   new_track_thresh -> confidence threshold for initializing new tracks
config.update({
    "track_buffer": 100,
    "new_track_thresh": 0.7,
})

# The modified copy is written next to the notebook's working directory
output_path = os.path.join(os.getcwd(), "bytetrack_modified.yaml")

with open(output_path, "w") as f:
    yaml.dump(config, f, sort_keys=False)

print("\nModified configuration file saved at:", output_path, sep="\n")

In [None]:
# ============================================================
# OBJECT DETECTION AND TRACKING — Player detection and tracking with YOLOv12 + ByteTrack
#
# This module performs multi-object player detection and tracking using YOLOv12
# integrated with the ByteTrack algorithm for temporal ID association.
#
# Main goals:
# - Detect players in each video frame
# - Assign persistent IDs over time (multi-object tracking)
# - Extract stable coordinates for downstream kinematic analysis
#
# Pipeline overview:
#
# 1. YOLOv12 initialization
#    - Deep learning-based object detector optimized for real-time inference
#      with a strong accuracy/efficiency trade-off.
#
# 2. Multi-object tracking with ByteTrack
#    - Associates detections across consecutive frames
#    - Maintains persistent IDs for each player
#    - Uses a customized tracker configuration (bytetrack_modified.yaml) to
#      improve temporal stability and reduce ID switches.
#
#    Key parameters:
#      persist=True  → keeps tracker state across frames
#      tracker=...   → path to custom ByteTrack configuration
#
# 3. Player position extraction (footpoint heuristic)
#    - For each detected bounding box:
#
#        cx = horizontal center of the bounding box
#        cy = bottom edge of the bounding box (y2)
#
#    - The point (cx, cy) approximates the player's ground contact position
#      (ground-plane proxy), which is suitable for homography-based mapping
#      into real-world court coordinates.
#
# 4. Tabular data structure (pandas DataFrame)
#
#    Output format:
#
#        id1_x | id1_y | id2_x | id2_y | ...
#
#    Each row corresponds to a video frame.
#    Each column pair corresponds to a tracked player ID.
#    Missing detections are stored as NaN.
#
# Downstream usage:
# - Primary player selection (2-player extraction)
# - Homography transformation (pixels → meters)
# - Kinematic metrics computation (distance, speed, acceleration)
# - Trajectory analysis and visualization
#
# Technical note:
# The footpoint heuristic is a standard approach in sports tracking pipelines
# when direct 3D pose/foot contact estimation is not available.
# ============================================================


# ------------------------------- IMPORTANT --------------------------------
# Upload your input video to the `input_videos/` folder and update the path below.
# -------------------------------------------------------------------------

%cd /content/drive/MyDrive/tennis_analysis

from ultralytics import YOLO
import pandas as pd

# Load the YOLOv12 model
model = YOLO("yolo12n.pt")

# Input video path (absolute; also re-declared by the court keypoint cell)
source = "/content/drive/MyDrive/tennis_analysis/input_videos/Djokovic.mp4"

# Run tracking (YOLOv12 + ByteTrack)
# NOTE: `tracker` points to a customized ByteTrack config file, i.e. the file
# written by the tracker-configuration cells above.
# `results` is kept as a whole because later cells call len(results) and index
# it per frame (results[idx]).
# NOTE(review): for long videos this presumably holds every frame's result in
# memory at once — confirm this is acceptable for the target videos.
results = model.track(
    source,
    save=True,
    persist=True,
    tracker="/content/drive/MyDrive/tennis_analysis/bytetrack_modified.yaml"
)

# ------------------------------------------------------------
# Convert tracking output into a per-frame DataFrame
# ------------------------------------------------------------

frames_data = []

for res in results:
    track_ids = res.boxes.id
    frame_dict = {}

    # Frames without tracked IDs contribute an empty row (all NaN in the table)
    if track_ids is not None:
        ids = track_ids.cpu().numpy().astype(int)
        xyxy = res.boxes.xyxy.cpu().numpy()  # rows are [x1, y1, x2, y2]

        for obj_id, box in zip(ids, xyxy):
            # footpoint heuristic: horizontal bbox center, bbox bottom edge
            frame_dict.update({
                f"id{obj_id}_x": (box[0] + box[2]) / 2.0,
                f"id{obj_id}_y": box[3],
            })

    frames_data.append(frame_dict)

# One row per frame; IDs absent from a frame become NaN
player_detections = pd.DataFrame(frames_data)

# Order columns by numeric ID, with each ID's x column before its y column
ordered_columns = sorted(
    player_detections.columns,
    key=lambda name: (int(name.split("_")[0][2:]), name[-1]),
)
player_detections = player_detections.reindex(columns=ordered_columns)

print(player_detections.head())

In [None]:
# ============================================================
# COURT KEYPOINT DETECTION — Extraction of court geometric reference points
#
# This module automatically detects structural keypoints of the tennis court
# using a deep learning model trained to identify line intersections and
# court boundaries.
#
# Main objective:
# Provide accurate correspondences between image-space coordinates (pixels)
# and real-world court coordinates (meters), enabling homography computation.
#
# Pipeline:
#
# 1. Video loading
#    - The video is loaded into memory as a sequence of frames.
#
# 2. Court keypoint detection model initialization
#    - CNN-based model trained specifically to detect structural reference
#      points of a tennis court.
#
# 3. Keypoint detection on the first frame
#    - Assumes a static camera setup, allowing the use of a single fixed
#      homography for the entire video sequence.
#    - Returns keypoint coordinates in image-space.
#
# Output:
# court_keypoints → numpy array of shape (N, 2)
# containing detected court keypoints in pixel coordinates (x, y).
#
# Downstream usage:
# These keypoints are used to compute the homography matrix that maps player
# positions from image-space (pixels) to real-world court coordinates (meters).
#
# Assumption:
# The camera remains fixed throughout the video. If the camera moves,
# homography must be recomputed dynamically per frame.
# ============================================================

from utils import read_video
from court_line_detector import CourtLineDetector

# Load video frames into memory
source = "/content/drive/MyDrive/tennis_analysis/input_videos/Djokovic.mp4"
video_frames = read_video(source)

# Fail early with a clear message instead of an IndexError on video_frames[0]
if len(video_frames) == 0:
    raise ValueError(f"No frames could be read from video: {source}")

# Initialize court keypoint detection model
court_model_path = "/content/drive/MyDrive/tennis_analysis/models/keypoints_model_50.pth"
court_line_detector = CourtLineDetector(court_model_path)

# Detect court keypoints from the first frame only; later cells reuse this
# single set of keypoints for the whole video (static-camera assumption).
court_keypoints = court_line_detector.predict(video_frames[0])

In [None]:
# ============================================================
# PLAYER VISUALIZATION PIPELINE — Selection, point rendering, and temporal motion trail
#
# This module automatically identifies the two primary tracked players and
# generates a visualization video with a simplified kinematic representation.
#
# Key features:
#
# 1. Automatic selection of the two main players
#    - Scores all tracked IDs based on:
#        • number of valid frames (track continuity)
#        • spatial consistency relative to the court center
#    - Selects the two most stable and relevant IDs for analysis.
#
# 2. Point-based representation using the footpoint heuristic
#    - Each player is represented by the bottom-center of the bounding box (cx, y2),
#      which approximates ground contact position (ground-plane proxy).
#    - This reduces visual clutter while keeping a consistent proxy of player position
#      on the court plane.
#
# 3. Optional temporal trail rendering ("motion brush")
#    - Accumulates player positions over time to visualize trajectories and
#      movement patterns clearly.
#    - The `trail_decay` parameter controls temporal fade-out:
#        • 1.0  → infinite trail (no decay)
#        • <1.0 → trail gradually fades over time
#
# 4. Court keypoint overlay
#    - Renders detected court keypoints to provide geometric context.
#
# 5. Optional ball rendering support
#    - If ball detections are provided, the ball is rendered as a point using
#      the same bottom-center heuristic.
#
# Use cases:
# - Visual QA of tracking stability (ID switches, drift, missed detections)
# - Qualitative trajectory inspection prior to quantitative analysis
# - Debugging and demonstration video generation
#
# Technical note:
# This module is for visualization only; it does not modify the coordinate data
# used for homography estimation or metric computation.
# ============================================================

import numpy as np
import pandas as pd
import cv2
import os

def normalize_court_keypoints(court_keypoints):
    """
    Convert court keypoints into a float array of shape (N, 2).

    Accepted inputs:
      - a flat sequence [x0, y0, x1, y1, ...] with an even number of values
      - a 2D sequence/array that already has exactly two columns

    Raises:
      ValueError: for scalars, odd-length flat input, 2D input whose second
                  dimension is not 2, or input with more than two dimensions.
    """
    points = np.array(court_keypoints, dtype=float)
    rank = points.ndim

    if rank == 0:
        raise ValueError("court_keypoints is scalar; provide a list/array of points (N, 2).")

    if rank == 1:
        # flat layout: consecutive values are interpreted as (x, y) pairs
        if points.size % 2 != 0:
            raise ValueError("1D array with odd length cannot be interpreted as (x, y) pairs.")
        return points.reshape(-1, 2)

    if rank == 2:
        if points.shape[1] != 2:
            raise ValueError(f"court_keypoints with shape {points.shape} is not in (N, 2) format.")
        return points

    raise ValueError(f"court_keypoints with ndim={points.ndim} is not supported.")


def select_two_players_balanced(player_detections: pd.DataFrame, court_keypoints, alpha=0.5,
                                min_valid_ratio=0.5):
    """
    Pick the two tracked IDs that best combine track length and closeness to
    the court center, and return their coordinates under canonical names.

    Each ID with columns "<id>_x" / "<id>_y" is scored as

        score = mean_distance_to_court_center / (n_valid ** alpha)

    where the court center is the mean of the court keypoints (image space)
    and n_valid is the number of frames in which both coordinates are finite.
    Lower scores are better.

    Parameters:
      - player_detections: per-frame table with "<id>_x" / "<id>_y" columns
      - court_keypoints: keypoints accepted by normalize_court_keypoints
      - alpha: exponent weighting track length in the score
      - min_valid_ratio: IDs valid in fewer than this fraction of the frames
        are discarded (default 0.5, the previously hard-coded value)

    Returns:
      (players_detections_selected, sorted_ids, id_map) where
      - players_detections_selected has columns
        player1_x, player1_y, player2_x, player2_y
      - sorted_ids is the list of the two selected ID prefixes, best first
      - id_map maps "player1"/"player2" to those prefixes

    Raises:
      ValueError: if fewer than two IDs pass the validity filter.
    """
    kp = normalize_court_keypoints(court_keypoints)
    court_center = np.mean(kp, axis=0)

    # minimum number of valid frames an ID needs to be considered
    min_valid_frames = len(player_detections) * min_valid_ratio

    metrics = {}
    x_cols = [c for c in player_detections.columns if c.endswith("_x")]

    for col_x in x_cols:
        player_id = col_x.rsplit("_", 1)[0]
        col_y = f"{player_id}_y"
        if col_y not in player_detections.columns:
            continue

        x_vals = pd.to_numeric(player_detections[col_x], errors="coerce").to_numpy(dtype=float)
        y_vals = pd.to_numeric(player_detections[col_y], errors="coerce").to_numpy(dtype=float)
        valid_mask = np.isfinite(x_vals) & np.isfinite(y_vals)

        n_valid = int(np.sum(valid_mask))
        if n_valid == 0 or n_valid < min_valid_frames:
            continue

        dx = x_vals[valid_mask] - court_center[0]
        dy = y_vals[valid_mask] - court_center[1]
        dist_mean = float(np.mean(np.sqrt(dx * dx + dy * dy)))

        # Lower score is better: closer to court center + more valid frames
        score = dist_mean / (n_valid ** alpha)
        metrics[player_id] = {"score": score, "dist_mean": dist_mean, "n_valid": n_valid}

    if len(metrics) < 2:
        raise ValueError("Fewer than two tracked IDs with sufficient valid data were found.")

    sorted_ids = sorted(metrics.keys(), key=lambda k: metrics[k]["score"])[:2]

    selected_cols = []
    for pid in sorted_ids:
        selected_cols.extend([f"{pid}_x", f"{pid}_y"])

    players_detections_selected = player_detections.reindex(columns=selected_cols).copy()
    players_detections_selected.columns = ["player1_x", "player1_y", "player2_x", "player2_y"]

    id_map = {"player1": sorted_ids[0], "player2": sorted_ids[1]}
    return players_detections_selected, sorted_ids, id_map


def draw_selected_players_and_court(
    video_frames,
    results,
    selected_ids,
    court_keypoints,
    output_path="players_and_court.mp4",
    fps=30,
    # optional: ball detections
    ball_detections=None,
    ball_color=(0, 0, 255),
    ball_radius=5,
    # point + trail
    point_radius=6,
    draw_trail=True,
    trail_decay=0.96,   # 1.0 = infinite trail; <1.0 = trail fades over time
    trail_mix=0.75,     # 0 = trail only, 1 = frame only
):
    """
    Generates a visualization video with:
      - court keypoints overlay
      - the selected players rendered as POINTS (no bounding boxes), using the bbox bottom-center (cx, y2)
      - optional accumulated motion trail ("brush")
      - optional ball rendering as a point (bbox bottom-center)

    Parameters:
      - video_frames: sequence of frames (image arrays); the first frame defines the output size
      - results: per-frame tracking results exposing `.boxes.id` and `.boxes.xyxy`
      - selected_ids: list of ID keys of the form "id<N>" to render
      - court_keypoints: one set of keypoints for the whole video, or a list/tuple
        with one set per frame (detected when its length equals len(video_frames))
      - output_path: destination video file
      - fps: frame rate of the written video
      - ball_detections: optional per-frame ball detections
      - ball_color / ball_radius: ball marker appearance
      - point_radius: player marker radius in pixels
      - draw_trail: if True, renders the accumulated motion trail
      - trail_decay: trail decay factor (0.90–0.99 typical). Use 1.0 for no decay.
      - trail_mix: blend ratio between current frame and trail (0.6–0.85 typical)

    Raises:
      - ValueError: if video_frames is empty
      - IOError: if the video writer cannot be opened for output_path

    The number of written frames is the shortest of video_frames, results and
    (when given) ball_detections. The writer is always released, even when
    rendering fails part-way.
    """
    if len(video_frames) == 0:
        raise ValueError("video_frames is empty; nothing to render.")

    dynamic_kp = isinstance(court_keypoints, (list, tuple)) and len(court_keypoints) == len(video_frames)
    kp_static = None if dynamic_kp else normalize_court_keypoints(court_keypoints)

    colors = [(0, 255, 0), (255, 0, 0)]  # player1 green, player2 blue (BGR); reused cyclically

    h, w = video_frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
    if not out.isOpened():
        # without this check every out.write() would be silently dropped
        out.release()
        raise IOError(f"Could not open video writer for: {output_path}")

    # persistent trail canvas
    trail = np.zeros_like(video_frames[0], dtype=np.uint8) if draw_trail else None

    def _extract_boxes(det_obj):
        """Best-effort conversion of a detection object into a list of [x1, y1, x2, y2]."""
        boxes = []
        if det_obj is None:
            return boxes

        if isinstance(det_obj, (list, tuple, np.ndarray)):
            arr = np.array(det_obj)
            if arr.ndim == 2 and arr.shape[1] >= 4:
                for row in arr:
                    boxes.append([float(row[0]), float(row[1]), float(row[2]), float(row[3])])
                return boxes
            # otherwise try each item individually, skipping unusable ones
            for item in det_obj:
                try:
                    a = np.array(item)
                    if a.size >= 4:
                        boxes.append([float(a.flatten()[0]), float(a.flatten()[1]),
                                      float(a.flatten()[2]), float(a.flatten()[3])])
                except Exception:
                    continue
            return boxes

        if hasattr(det_obj, "boxes"):
            xyxy = getattr(det_obj.boxes, "xyxy", None)
            if xyxy is not None:
                try:
                    arr = xyxy.cpu().numpy() if hasattr(xyxy, "cpu") else np.array(xyxy)
                    for row in arr:
                        boxes.append([float(row[0]), float(row[1]), float(row[2]), float(row[3])])
                    return boxes
                except Exception:
                    pass

        try:
            arr = np.array(det_obj, dtype=float)
            if arr.ndim == 1 and arr.size % 4 == 0 and arr.size >= 4:
                arr2 = arr.reshape(-1, 4)
                for row in arr2:
                    boxes.append([float(row[0]), float(row[1]), float(row[2]), float(row[3])])
                return boxes
        except Exception:
            pass

        return boxes

    n_frames = min(len(video_frames), len(results))
    if ball_detections is not None:
        n_frames = min(n_frames, len(ball_detections))

    try:
        for idx in range(n_frames):
            frame = video_frames[idx]
            res = results[idx]
            img = frame.copy()

            kp_frame = normalize_court_keypoints(court_keypoints[idx]) if dynamic_kp else kp_static

            # draw court keypoints
            for (x, y) in kp_frame:
                cv2.circle(img, (int(round(x)), int(round(y))), 4, (0, 255, 255), -1)

            # decay trail
            if draw_trail and trail is not None and trail_decay < 1.0:
                trail[:] = (trail.astype(np.float32) * float(trail_decay)).astype(np.uint8)

            # draw players as points (bbox bottom-center)
            try:
                if hasattr(res, "boxes") and res.boxes is not None and getattr(res.boxes, "id", None) is not None:
                    ids = res.boxes.id.cpu().numpy().astype(int)
                    xyxy = res.boxes.xyxy.cpu().numpy()

                    for obj_id, box in zip(ids, xyxy):
                        key = f"id{int(obj_id)}"
                        if key not in selected_ids:
                            continue

                        p_idx = selected_ids.index(key)
                        color = colors[p_idx % len(colors)]

                        x1, y1, x2, y2 = box.astype(float)
                        cx = int(round((x1 + x2) / 2.0))
                        cy = int(round(y2))

                        if draw_trail and trail is not None:
                            cv2.circle(trail, (cx, cy), point_radius, color, -1)

                        cv2.circle(img, (cx, cy), point_radius, color, -1)
                        cv2.putText(img, f"P{p_idx+1}", (cx + point_radius + 2, max(0, cy - point_radius - 2)),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

            except Exception as e:
                # best effort: a drawing problem is reported on the frame itself
                # so the rest of the video is still produced
                cv2.putText(img, f"draw error: {e}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

            # optional ball point rendering
            if ball_detections is not None:
                dets = ball_detections[idx]
                boxes = _extract_boxes(dets)
                for b in boxes:
                    try:
                        x1, y1, x2, y2 = map(float, b[:4])
                    except Exception:
                        continue
                    cx = int(round((x1 + x2) / 2.0))
                    cy = int(round(y2))
                    cv2.circle(img, (cx, cy), ball_radius, ball_color, -1)
                    cv2.putText(img, "Ball", (cx + ball_radius + 2, max(0, cy)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, ball_color, 2)

            # blend the frame with the trail canvas
            if draw_trail and trail is not None:
                img = cv2.addWeighted(img, float(trail_mix), trail, 1.0 - float(trail_mix), 0.0)

            # ID legend
            for i, sid in enumerate(selected_ids):
                cv2.putText(img, f"P{i+1}: {sid}", (10, 20 + i * 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, colors[i % len(colors)], 2)

            out.write(img)
    finally:
        # always finalize the file handle, even if a frame failed to render
        out.release()

    print(f"Video saved at: {output_path}")


# =========================
# RESULTS FOLDER (per video)
# =========================

# Use the same variable from your notebook that points to the input video.
# In your case, you already use source in other cells.
VIDEO_PATH = source  # ex: "/content/drive/MyDrive/tennis_analysis/input_videos/Djokovic2.mp4"

# File name of the input video without directory and extension
video_name = os.path.splitext(os.path.basename(VIDEO_PATH))[0]

# One results sub-folder per input video, created if it does not exist yet
RESULTS_BASE_DIR = "/content/drive/MyDrive/tennis_analysis/results"
RESULTS_VIDEO_DIR = os.path.join(RESULTS_BASE_DIR, video_name)
os.makedirs(RESULTS_VIDEO_DIR, exist_ok=True)

# output video goes inside the per-video results folder
OUTPUT_VIDEO_PATH = os.path.join(RESULTS_VIDEO_DIR, f"{video_name}_tracked.mp4")

print("Results directory:", RESULTS_VIDEO_DIR)
print("Output video path:", OUTPUT_VIDEO_PATH)

# =========================
# RUN (selection + render)
# =========================

players_detections_selected, original_ids, id_map = select_two_players_balanced(
    player_detections, court_keypoints, alpha=0.5
)

# Write the visualization at the frame rate of the input video so playback
# speed matches the source; fall back to 30 fps when it is not reported.
_capture = cv2.VideoCapture(VIDEO_PATH)
_reported_fps = _capture.get(cv2.CAP_PROP_FPS)
_capture.release()
render_fps = float(_reported_fps) if _reported_fps and _reported_fps > 0 else 30

draw_selected_players_and_court(
    video_frames, results, original_ids, court_keypoints,
    output_path=OUTPUT_VIDEO_PATH,
    fps=render_fps,
    draw_trail=True,
    trail_decay=0.96,
    trail_mix=0.75,
    point_radius=6
)

In [None]:
# ============================================================
# BASE PIPELINE — Homography, interpolation, and temporal smoothing
#
# This pipeline implements the minimum workflow required to convert
# pixel coordinates into real-world coordinates (meters) using a homography,
# followed by gap filling and temporal smoothing.
#
# Stages:
# 1. Homography: maps player positions from the image plane (pixels) to the
#    real court plane (meters) using detected court keypoints.
#
# 2. Linear interpolation: fills missing samples (NaNs) caused by short-lived
#    detection/tracking dropouts.
#
# 3. Butterworth low-pass filter: reduces high-frequency jitter introduced by
#    the detector and tracker, while preserving the main kinematic structure
#    of the movement.
#
# 4. Metric computation: total distance, mean/max velocity, and acceleration.
#
# This pipeline is typically sufficient when tracking is stable and does not
# contain significant outliers.
# ============================================================
import os
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt

# --- REAL COURT REFERENCE POINTS (meters) ---
# Court-plane coordinates of 14 reference points, used as the destination
# points of the homography. x spans 0–23.77 and y spans 0–10.97.
# Row i is paired with the i-th detected court keypoint when the homography
# is estimated.
# NOTE(review): the ordering presumably mirrors the output order of the court
# keypoint model — confirm against CourtLineDetector.
real_coords = np.array([
    [0, 0],
    [0, 10.97],
    [23.77, 0],
    [23.77, 10.97],
    [0, 1.37],
    [23.77, 1.37],
    [0, 9.6],
    [23.77, 9.6],
    [5.485, 1.37],
    [5.485, 9.6],
    [18.285, 1.37],
    [18.285, 9.6],
    [5.485, 5.485],
    [18.285, 5.485]
], dtype=np.float32)


# --- HOMOGRAPHY + REAL-WORLD COORDINATES ---
def compute_fixed_homography_and_real_coords(players_detections_selected, court_keypoints, real_coords):
    """
    Estimate one image→court homography and project both players' positions.

    Parameters:
      - players_detections_selected: per-frame table with exactly four columns
        ordered as player1_x, player1_y, player2_x, player2_y (pixels)
      - court_keypoints: detected court keypoints in pixels, flat or (N, 2)
      - real_coords: matching court reference points, shape (N, 2)

    Returns:
      DataFrame with columns player1_x, player1_y, player2_x, player2_y holding
      the projected coordinates, one row per input row. NaN inputs stay NaN.

    Raises:
      - ValueError: if the number of keypoints and reference points differ,
        fewer than 4 correspondences are given, or the table does not have
        four columns
      - RuntimeError: if the homography cannot be computed
    """
    real_coords = np.array(real_coords, dtype=np.float32)
    img_pts = np.array(court_keypoints, dtype=np.float32).reshape(-1, 2)

    # Validate up front: a mismatch would otherwise surface as an opaque cv2 error
    if img_pts.shape[0] != real_coords.shape[0]:
        raise ValueError(
            f"Got {img_pts.shape[0]} court keypoints but {real_coords.shape[0]} reference points."
        )
    if img_pts.shape[0] < 4:
        raise ValueError("At least 4 point correspondences are required to compute a homography.")

    H, status = cv2.findHomography(img_pts, real_coords)
    if H is None or status is None or not status.all():
        raise RuntimeError("Failed to compute homography using the first-frame keypoints.")

    px = players_detections_selected.to_numpy(dtype=float)
    if px.ndim != 2 or px.shape[1] != 4:
        raise ValueError(
            "players_detections_selected must have exactly 4 columns "
            "(player1_x, player1_y, player2_x, player2_y)."
        )
    n_frames = px.shape[0]

    # Project all points at once: (2 * n_frames, 3) homogeneous pixel coordinates
    pts = px.reshape(-1, 2)
    pts_homog = np.hstack([pts, np.ones((pts.shape[0], 1))])
    real_pts_homog = pts_homog @ H.T
    real_pts = real_pts_homog[:, :2] / real_pts_homog[:, 2:3]  # normalize by w

    players_real_coords = pd.DataFrame(
        real_pts.reshape(n_frames, 4),
        columns=["player1_x", "player1_y", "player2_x", "player2_y"]
    )
    return players_real_coords


# --- NaN INTERPOLATION (GAP FILLING) ---
def interpolate_nan_coords(df):
    """
    Return a copy of `df` in which missing samples of every column are filled.

    Gaps are filled by linear interpolation in both directions; a forward and
    backward fill follows as a safety net for anything still missing. The
    input frame itself is not modified.
    """
    def _fill_gaps(series):
        bridged = series.interpolate(method="linear", limit_direction="both")
        # ffill()/bfill() rather than fillna(method=...), which is deprecated
        return bridged.ffill().bfill()

    return df.copy().apply(_fill_gaps)


# --- BUTTERWORTH LOW-PASS FILTER ---
def butterworth_filter(data, cutoff=0.3, fs=30, order=4):
    """
    Low-pass filter every player's x and y series with a zero-phase Butterworth filter.

    Parameters:
      - data: list of frames; each frame is a list with one (x, y) entry per player
      - cutoff: cutoff frequency, in the same unit as fs
      - fs: sampling rate (frames per second)
      - order: filter order

    Returns:
      A list with the same per-frame layout as `data`, where each player entry
      is an array [x_filtered, y_filtered].
    """
    # cutoff is normalized by the Nyquist frequency (fs / 2)
    b, a = butter(order, cutoff / (0.5 * fs), btype="low", analog=False)

    n_frames = len(data)
    n_players = len(data[0])

    # one (n_frames, 2) array per player, x and y filtered independently
    smoothed_tracks = []
    for pid in range(n_players):
        xs = np.array([data[i][pid][0] for i in range(n_frames)])
        ys = np.array([data[i][pid][1] for i in range(n_frames)])
        smoothed_tracks.append(np.column_stack((filtfilt(b, a, xs), filtfilt(b, a, ys))))

    # regroup from per-player series back to per-frame lists
    return [[track[i] for track in smoothed_tracks] for i in range(n_frames)]


# --- METRIC COMPUTATION ---
def compute_metrics(coords_per_frame, fs=30):
    """
    Compute per-player distance, speed and acceleration metrics.

    Parameters:
      - coords_per_frame: list of frames; each frame holds one (x, y) position
        per player (same number of players in every frame)
      - fs: sampling rate in frames per second

    Returns:
      Five dicts keyed by player index (0-based):
      (distances, max_vel, mean_vel, max_accel, max_decel)
      - distances: sum of frame-to-frame displacements
      - max_vel / mean_vel: maximum / mean frame-to-frame speed
      - max_accel / max_decel: maximum / minimum change of speed per second
      Metrics that cannot be computed (too few frames) are 0.0; an empty input
      yields five empty dicts. NaN positions propagate as NaN.
    """
    n_frames = len(coords_per_frame)
    if n_frames == 0:
        return {}, {}, {}, {}, {}

    n_players = len(coords_per_frame[0])
    positions = np.asarray(coords_per_frame, dtype=float).reshape(n_frames, n_players, 2)

    dt = 1.0 / fs
    # frame-to-frame displacement per player: shape (n_frames - 1, n_players)
    step_lengths = np.linalg.norm(np.diff(positions, axis=0), axis=2)
    speeds = step_lengths / dt

    distances = {pid: 0.0 for pid in range(n_players)}
    max_vel = {pid: 0.0 for pid in range(n_players)}
    mean_vel = {pid: 0.0 for pid in range(n_players)}
    max_accel = {pid: 0.0 for pid in range(n_players)}
    max_decel = {pid: 0.0 for pid in range(n_players)}

    for pid in range(n_players):
        v = speeds[:, pid]
        distances[pid] = float(step_lengths[:, pid].sum())

        # speed statistics need only one interval (two frames)
        if v.size >= 1:
            max_vel[pid] = float(np.max(v))
            mean_vel[pid] = float(np.mean(v))

        # acceleration needs at least two speed samples
        if v.size > 1:
            a = np.diff(v) * fs
            max_accel[pid] = float(np.max(a))
            max_decel[pid] = float(np.min(a))

    return distances, max_vel, mean_vel, max_accel, max_decel


# --- DATAFRAME -> LIST[FRAME] ---
def df_to_frame_list(df):
    """
    Convert the two-player coordinate table into a per-frame list.

    Each row becomes [(player1_x, player1_y), (player2_x, player2_y)];
    the returned list has one such entry per row, in row order.
    """
    return [
        [
            (row["player1_x"], row["player1_y"]),
            (row["player2_x"], row["player2_y"]),
        ]
        for _, row in df.iterrows()
    ]


# --- TRAJECTORY PLOTTING ---
def plot_trajectory(player_coords_filt, save_path=None, show=True, dpi=300):
    """
    Plot player trajectories on top of a schematic tennis court (meters).

    Parameters:
      - player_coords_filt: list of frames; each frame holds one (x, y)
        position per player. An empty list draws the court only.
      - save_path: if given, the figure is saved there (parent folders are
        created when the path contains a directory part)
      - show: if True the figure is shown, otherwise it is closed
      - dpi: resolution used when saving

    Returns:
      (fig, ax) of the created figure.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    # outer court (doubles) and net line
    ax.plot([0, 23.77], [0, 0], "k")
    ax.plot([23.77, 23.77], [0, 10.97], "k")
    ax.plot([0, 23.77], [10.97, 10.97], "k")
    ax.plot([0, 0], [0, 10.97], "k")
    ax.plot([23.77 / 2, 23.77 / 2], [0, 10.97], "k", linewidth=3)

    # inner lines at y = 1.37 and y = 9.6 plus the end segments between them
    ax.plot([0, 23.77], [1.37, 1.37], "k")
    ax.plot([23.77, 23.77], [1.37, 9.6], "k")
    ax.plot([0, 23.77], [9.6, 9.6], "k")
    ax.plot([0, 0], [1.37, 9.6], "k")

    # vertical service lines
    ax.plot([5.485, 5.485], [1.37, 9.6], "k")
    ax.plot([18.285, 18.285], [1.37, 9.6], "k")

    # middle service line, drawn as two halves meeting at the net
    ax.plot([5.485, 23.77 / 2], [10.97 / 2, 10.97 / 2], "k")
    ax.plot([23.77 / 2, 18.285], [10.97 / 2, 10.97 / 2], "k")

    ax.set_xlim([-8, 23.77 + 8])
    ax.set_ylim([-2, 10.97 + 2])
    ax.set_aspect("equal")
    ax.set_xlabel("X (m)")
    ax.set_ylabel("Y (m)")

    # trajectories (skipped entirely when there are no frames)
    n_players = len(player_coords_filt[0]) if len(player_coords_filt) > 0 else 0
    for pid in range(n_players):
        traj = np.array([frame[pid] for frame in player_coords_filt], dtype=float)
        ax.plot(traj[:, 0], traj[:, 1], marker="o", label=f"Player {pid + 1}")

    if n_players > 0:
        ax.legend()

    # save this specific figure (fig) before it is shown
    if save_path is not None:
        save_dir = os.path.dirname(save_path)
        # a bare file name has no directory part; os.makedirs("") would raise
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
        print("Saved trajectory plot:", save_path)

    if show:
        plt.show()
    else:
        plt.close(fig)

    return fig, ax


# ============================================================
# RESULTS DIR (per video) — ensure it exists here too
# ============================================================
VIDEO_PATH = source  # path of the input video
video_name = os.path.splitext(os.path.basename(VIDEO_PATH))[0]

# Same per-video results folder layout as in the visualization cell
RESULTS_BASE_DIR = "/content/drive/MyDrive/tennis_analysis/results"
RESULTS_VIDEO_DIR = os.path.join(RESULTS_BASE_DIR, video_name)
os.makedirs(RESULTS_VIDEO_DIR, exist_ok=True)

print("Saving outputs to:", RESULTS_VIDEO_DIR)

# ============================================================
# PIPELINE
# ============================================================

# Sampling rate shared by the filter and the metrics. It is read from the
# input video because every velocity/acceleration value scales with it;
# 30 fps is used only when the file does not report a frame rate.
_capture = cv2.VideoCapture(VIDEO_PATH)
_reported_fps = _capture.get(cv2.CAP_PROP_FPS)
_capture.release()
fs = float(_reported_fps) if _reported_fps and _reported_fps > 0 else 30
print("Sampling rate (fps):", fs)

# 1) Real-world coordinates via homography
players_real_coords = compute_fixed_homography_and_real_coords(
    players_detections_selected,
    court_keypoints,
    real_coords
)

# (optional, but useful) save the raw real-world coordinates
players_real_coords.to_csv(os.path.join(RESULTS_VIDEO_DIR, "real_coordinates.csv"), index=False)

# 2) Fill gaps (NaNs)
players_real_coords_interp = interpolate_nan_coords(players_real_coords)

# (optional) save the interpolated coordinates
# NOTE(review): despite the file name, this CSV holds the interpolated data
# before the Butterworth filter is applied — confirm this is intended.
players_real_coords_interp.to_csv(os.path.join(RESULTS_VIDEO_DIR, "filtered_coordinates.csv"), index=False)

# 3) Convert to per-frame list
player_coords_list = df_to_frame_list(players_real_coords_interp)

# 4) Temporal smoothing (Butterworth)
player_coords_filt = butterworth_filter(player_coords_list, cutoff=1, fs=fs, order=4)

# 5) Metrics
distances, max_vel, mean_vel, max_accel, max_decel = compute_metrics(player_coords_filt, fs=fs)

# 6) Print results
for pid in distances.keys():
    print(f"\n=== Player {pid+1} ===")
    print(f"Total distance: {distances[pid]:.2f} m")
    print(f"Mean velocity: {mean_vel[pid]:.2f} m/s")
    print(f"Max velocity: {max_vel[pid]:.2f} m/s")
    print(f"Max acceleration: {max_accel[pid]:.2f} m/s²")
    print(f"Max deceleration: {max_decel[pid]:.2f} m/s²")

# ============================================================
# SAVE METRICS CSV
# ============================================================

# One record per player, in ascending player-index order
metrics_rows = [
    {
        "video": os.path.basename(VIDEO_PATH),
        "player": f"player{pid+1}",
        "total_distance_m": float(distances[pid]),
        "mean_velocity_m_s": float(mean_vel[pid]),
        "max_velocity_m_s": float(max_vel[pid]),
        "max_acceleration_m_s2": float(max_accel[pid]),
        "max_deceleration_m_s2": float(max_decel[pid]),
        "fps": fs,
    }
    for pid in sorted(distances.keys())
]

metrics_path = os.path.join(RESULTS_VIDEO_DIR, "metrics.csv")
metrics_df = pd.DataFrame(metrics_rows)
metrics_df.to_csv(metrics_path, index=False, encoding="utf-8")
print("Saved metrics:", metrics_path)

# ============================================================
# SAVE TRAJECTORY PLOT
# ============================================================

plot_path = os.path.join(RESULTS_VIDEO_DIR, "trajectory.png")

# plot_trajectory (defined earlier in this cell) accepts save_path and saves
# the figure before showing it. It is called directly: a TypeError fallback
# would hide genuine errors raised inside the function, and saving the
# "current" figure after plt.show() would write an empty image.
trajectory_fig, trajectory_ax = plot_trajectory(player_coords_filt, save_path=plot_path)