# Tennis Contact Point Analysis - MVP

Upload a tennis rally video to detect contact points, estimate 3D pose, and measure contact position relative to body landmarks.

**Steps:**
1. Update `REPO_URL` in the **Setup** cell, then run it to clone the repository and install dependencies
2. Run the **Upload & Process** cell to upload your video and run the pipeline
3. Run the **Results** cell to view the annotated frames and download the images + CSV

In [None]:
#@title 1. Setup - Install dependencies and clone repo
import os

# Clone the repository if not already present
REPO_URL = "https://github.com/YOUR_USERNAME/tennis_contact_point_spacing.git"  # Update this
REPO_DIR = "/content/tennis_contact_point_spacing"

if not os.path.exists(REPO_DIR):
    !git clone {REPO_URL} {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull

!pip install -q -r {REPO_DIR}/requirements.txt

import sys
sys.path.insert(0, REPO_DIR)

print("Setup complete!")

In [None]:
#@title 2. Upload & Process Video
import numpy as np
import pandas as pd
import cv2
import mediapipe as mp
from google.colab import files
from tqdm.notebook import tqdm
from IPython.display import display, Image as IPImage
import os

from utils.video_io import load_video
from utils.coordinate_transforms import (
    pelvis_origin_transform, estimate_ground_plane,
    apply_ground_plane, estimate_player_height_scale
)
from src.ball_detection import BallTracker
from src.pose_estimation import PoseEstimator, LANDMARK_MAP
from src.contact_detection import detect_contacts
from src.measurements import compute_measurements
from src.visualization import (
    annotate_contact_frame, save_annotated_frame, _world_to_pixel
)

# --- Upload video ---
print("Upload your tennis video (MP4, MOV, AVI):")
uploaded = files.upload()

# Cancelling the upload dialog yields an empty mapping; fail with a clear
# message instead of an IndexError on the first-key lookup.
if not uploaded:
    raise RuntimeError("No file was uploaded. Re-run this cell and choose a video file.")

# Only the first uploaded file is processed.
video_filename = next(iter(uploaded))
if len(uploaded) > 1:
    print(f"  Multiple files uploaded; processing only {video_filename}")

# Write the uploaded bytes to a known absolute path for the video loader.
video_path = os.path.join("/content", video_filename)
with open(video_path, "wb") as f:
    f.write(uploaded[video_filename])

# --- Load video ---
print(f"\nLoading video: {video_filename}")
frames, meta = load_video(video_path)
fps = meta["fps"]
print(f"  {len(frames)} frames, {meta['width']}x{meta['height']}, {fps:.1f} fps, {meta['duration_sec']:.1f}s")

# Fail early with an actionable message. With no frames the rest of the cell
# would end in a misleading "No contacts detected" hint, and fps is used as a
# divisor when contact timestamps are computed further down.
if len(frames) == 0:
    raise RuntimeError(
        f"No frames could be read from {video_filename}. "
        "Check that the file is a valid video (MP4, MOV, AVI)."
    )
if not fps > 0:
    raise RuntimeError(
        f"Video reports an invalid frame rate ({fps}); timestamps cannot be computed."
    )

# --- Ball detection ---
print("\nDetecting ball...")
# Default-constructed tracker: HSV fallback (no TrackNet weights)
tracker = BallTracker()
ball_detections = tracker.detect_all(frames, progress=True)
detected_count, total_count = len(ball_detections), len(frames)
print(f"  Ball detected in {detected_count}/{total_count} frames")

# --- Pose estimation (run on frames near ball detections to get wrist positions) ---
print("\nEstimating pose...")
# NOTE(review): pose_estimator is not used anywhere else in this cell; it is kept
# so the name stays defined for any code that relies on it.
pose_estimator = PoseEstimator(static_image_mode=True, model_complexity=2)

# MediaPipe Pose landmark indices used for the pixel-space wrist selection
_LEFT_WRIST, _RIGHT_WRIST = 15, 16
_LEFT_HIP, _RIGHT_HIP = 23, 24

# Number of frames on each side of a ball detection that also get pose estimation
POSE_PADDING_FRAMES = 5


def _extract_world_landmarks(result):
    """Return a name -> np.array([x, y, z]) dict from a MediaPipe pose result.

    Reads ``result.pose_world_landmarks`` using the indices in LANDMARK_MAP and
    adds two derived entries when their inputs are present: "pelvis" (midpoint
    of the two hips) and "head" (alias of "nose").
    """
    wl = result.pose_world_landmarks.landmark
    landmarks = {}
    for name, idx in LANDMARK_MAP.items():
        lm = wl[idx]
        landmarks[name] = np.array([lm.x, lm.y, lm.z])
    if "left_hip" in landmarks and "right_hip" in landmarks:
        landmarks["pelvis"] = (landmarks["left_hip"] + landmarks["right_hip"]) / 2.0
    if "nose" in landmarks:
        landmarks["head"] = landmarks["nose"]
    return landmarks


def _pick_extended_wrist_px(result, width, height):
    """Return pixel (x, y) of the wrist farther from the player's mid-hip point.

    Normalized image landmarks are scaled by the frame size. The reference
    point is the player's own hip midpoint rather than the frame centre, so the
    choice does not depend on where the player stands in the image. The right
    wrist is chosen only when it is strictly farther; otherwise the left.
    """
    lms = result.pose_landmarks.landmark

    def to_px(index):
        return lms[index].x * width, lms[index].y * height

    rw_x, rw_y = to_px(_RIGHT_WRIST)
    lw_x, lw_y = to_px(_LEFT_WRIST)
    lh_x, lh_y = to_px(_LEFT_HIP)
    rh_x, rh_y = to_px(_RIGHT_HIP)
    mid_x, mid_y = (lh_x + rh_x) / 2.0, (lh_y + rh_y) / 2.0

    right_dist = np.hypot(rw_x - mid_x, rw_y - mid_y)
    left_dist = np.hypot(lw_x - mid_x, lw_y - mid_y)
    if right_dist > left_dist:
        return (rw_x, rw_y)
    return (lw_x, lw_y)


# Wrist positions for contact detection (in pixel coords)
wrist_positions = {}  # frame_num -> (x, y)
pose_cache = {}  # frame_num -> (world_landmarks, mp_result)

# Only process frames near ball detections + some padding
ball_frames = set(d[0] for d in ball_detections)
frames_to_process = set()
for bf in ball_frames:
    for offset in range(-POSE_PADDING_FRAMES, POSE_PADDING_FRAMES + 1):
        frame_idx = bf + offset
        if 0 <= frame_idx < len(frames):
            frames_to_process.add(frame_idx)

# We need image-space landmarks for visualization, so use raw MediaPipe.
# The context manager closes the model even if a frame fails mid-loop.
with mp.solutions.pose.Pose(static_image_mode=True, model_complexity=2) as mp_pose:
    for i in tqdm(sorted(frames_to_process), desc="Pose estimation"):
        rgb = cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)
        result = mp_pose.process(rgb)
        if result.pose_world_landmarks is None:
            continue  # no person found in this frame

        pose_cache[i] = (_extract_world_landmarks(result), result)

        if result.pose_landmarks is not None:
            h, w = frames[i].shape[:2]
            wrist_positions[i] = _pick_extended_wrist_px(result, w, h)

print(f"  Pose estimated for {len(pose_cache)} frames")

# --- Contact detection ---
print("\nDetecting contacts...")
# Tuning knobs handed to the detector; contacts must be at least 0.5s apart
detector_options = {
    "wrist_positions": wrist_positions,
    "velocity_spike_threshold": 1.8,
    "wrist_proximity_px": 200,
    "min_frame_gap": int(fps * 0.5),
}
contacts = detect_contacts(ball_detections, fps, **detector_options)
print(f"  Found {len(contacts)} contact(s)")

if not contacts:
    print("\nNo contacts detected. Try adjusting camera angle or lighting.")

# --- Process each contact ---
# For every detected contact: obtain the pose, pick the hitting wrist, express
# it in pelvis-centered / ground-adjusted coordinates, record measurements and
# write an annotated PNG of the frame.
output_dir = "/content/output"
os.makedirs(output_dir, exist_ok=True)

all_measurements = []

for contact_frame, confidence in contacts:
    if contact_frame not in pose_cache:
        # No cached pose for this frame: run MediaPipe on it directly.
        # The context manager closes the model after use, so each fallback
        # does not leave an open Pose instance behind.
        rgb = cv2.cvtColor(frames[contact_frame], cv2.COLOR_BGR2RGB)
        with mp.solutions.pose.Pose(static_image_mode=True, model_complexity=2) as fallback_pose:
            result = fallback_pose.process(rgb)
        if result.pose_world_landmarks is None:
            print(f"  Frame {contact_frame}: no pose detected, skipping")
            continue
        landmarks = {}
        wl = result.pose_world_landmarks.landmark
        for name, idx in LANDMARK_MAP.items():
            lm = wl[idx]
            landmarks[name] = np.array([lm.x, lm.y, lm.z])
        # Derived landmarks: pelvis = hip midpoint, head = nose
        if "left_hip" in landmarks and "right_hip" in landmarks:
            landmarks["pelvis"] = (landmarks["left_hip"] + landmarks["right_hip"]) / 2.0
        if "nose" in landmarks:
            landmarks["head"] = landmarks["nose"]
        pose_cache[contact_frame] = (landmarks, result)

    landmarks, mp_result = pose_cache[contact_frame]

    # Determine dominant wrist (the one further from the pelvis in world space).
    # Missing landmarks fall back to the origin.
    lw = landmarks.get("left_wrist", np.zeros(3))
    rw = landmarks.get("right_wrist", np.zeros(3))
    pelvis = landmarks.get("pelvis", np.zeros(3))
    if np.linalg.norm(rw - pelvis) > np.linalg.norm(lw - pelvis):
        contact_wrist_name = "right_wrist"
        contact_point = rw
    else:
        contact_wrist_name = "left_wrist"
        contact_point = lw

    # Transform to pelvis-centered, ground-adjusted coordinates
    centered = pelvis_origin_transform(landmarks)
    ground_z = estimate_ground_plane(centered)
    adjusted = apply_ground_plane(centered, ground_z)

    # Contact point in same coordinate system
    # NOTE(review): this assumes apply_ground_plane shifts only the third (z)
    # component by ground_z -- confirm against utils.coordinate_transforms.
    contact_adjusted = contact_point - pelvis - np.array([0, 0, ground_z])

    # Compute measurements and attach per-contact metadata
    meas = compute_measurements(adjusted, contact_adjusted)
    meas["frame_number"] = contact_frame
    meas["timestamp"] = contact_frame / fps
    meas["confidence"] = confidence
    meas["contact_x"] = contact_point[0]
    meas["contact_y"] = contact_point[1]
    meas["contact_z"] = contact_point[2]
    all_measurements.append(meas)

    # Visualize
    pixel_lm = _world_to_pixel(landmarks, frames[contact_frame], mp_result)
    annotated = annotate_contact_frame(
        frames[contact_frame], pixel_lm, contact_wrist_name,
        meas, contact_frame, fps
    )
    out_path = os.path.join(output_dir, f"contact_frame_{contact_frame}.png")
    save_annotated_frame(annotated, out_path)
    print(f"  Contact at frame {contact_frame} (t={contact_frame/fps:.2f}s, conf={confidence:.2f})")

# Save CSV: one row per processed contact, written only when there is data
if all_measurements:
    video_stem, _ = os.path.splitext(video_filename)
    csv_path = os.path.join(output_dir, f"measurements_{video_stem}.csv")

    # Identification columns lead; the remaining measurement columns follow
    # in alphabetical order.
    leading_cols = ["frame_number", "timestamp", "confidence",
                    "contact_x", "contact_y", "contact_z"]
    table = pd.DataFrame(all_measurements)
    trailing_cols = sorted(col for col in table.columns if col not in leading_cols)
    df = table[leading_cols + trailing_cols]

    df.to_csv(csv_path, index=False)
    print(f"\nMeasurements saved to {csv_path}")

print("\nProcessing complete!")

In [None]:
#@title 3. Results - View annotated frames & download
import glob
import os  # imported here so this cell also runs on its own
from IPython.display import display, Image as IPImage, HTML
import pandas as pd

output_dir = "/content/output"


def _frame_sort_key(path):
    """Sort key ordering contact_frame_<N>.png files by numeric frame N.

    A plain string sort would put frame 100 before frame 20. Names without a
    numeric suffix sort first, then alphabetically.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    suffix = stem.rsplit("_", 1)[-1]
    return (int(suffix) if suffix.isdigit() else -1, path)


# Display annotated images in frame order
png_files = sorted(
    glob.glob(os.path.join(output_dir, "contact_frame_*.png")),
    key=_frame_sort_key,
)
if png_files:
    print(f"Found {len(png_files)} contact frame(s):\n")
    for png_path in png_files:
        print(os.path.basename(png_path))
        display(IPImage(filename=png_path, width=800))
        print()
else:
    print("No annotated frames found.")

# Display measurements table
csv_files = glob.glob(os.path.join(output_dir, "measurements_*.csv"))
if csv_files:
    # glob order is unspecified and CSVs from earlier videos may still be in
    # the folder, so show the most recently written one.
    latest_csv = max(csv_files, key=os.path.getmtime)
    df = pd.read_csv(latest_csv)
    print("\nMeasurements:")
    display(df)

# Trigger a browser download for every output file
print("\n--- Download Files ---")
from google.colab import files as colab_files
for f in png_files + csv_files:
    colab_files.download(f)