# Tennis Contact Point Analysis - MVP

Upload a tennis rally video to detect contact points, estimate 3D pose, and measure contact position relative to body landmarks.

**Steps:**
1. Run the **Setup** cell to install dependencies
2. Run the **Upload & Process** cell to upload your video and run the pipeline
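3. Run the **Results** cell to view the annotated frames and measurements and to download the annotated images + CSV (the **Visual Debug** cell is optional)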

In [None]:
#@title 1. Setup - Install dependencies and clone repo
import os, sys, shutil

REPO_URL = "https://github.com/xiaoxiang-ma/tennis_contact_point_spacing.git"
REPO_DIR = "/content/tennis_contact_point_spacing"

# Always re-clone to ensure latest code
if os.path.exists(REPO_DIR):
    shutil.rmtree(REPO_DIR)
!git clone {REPO_URL} {REPO_DIR}

!pip install -q -r {REPO_DIR}/requirements.txt

# Clear any cached module imports from previous runs
for mod_name in list(sys.modules.keys()):
    if mod_name.startswith(("src.", "utils.")):
        del sys.modules[mod_name]

if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

print("Setup complete!")

In [None]:
#@title 2. Upload & Process Video
import numpy as np
import pandas as pd
import cv2
from google.colab import files
from tqdm.notebook import tqdm
from IPython.display import display, Image as IPImage
import os

from utils.video_io import load_video
from utils.coordinate_transforms import (
    pelvis_origin_transform, estimate_ground_plane,
    apply_ground_plane, estimate_player_height_scale
)
from src.ball_detection import BallTracker
from src.pose_estimation import PoseEstimator, LANDMARK_MAP
from src.contact_detection import detect_contacts
from src.measurements import compute_measurements
from src.visualization import (
    annotate_contact_frame, save_annotated_frame
)

# --- Upload video ---
# Ask the user for a video through the Colab upload widget and store it
# under /content so the loader can read it from disk.
print("Upload your tennis video (MP4, MOV, AVI):")
uploaded = files.upload()
if not uploaded:
    # Nothing came back from the upload widget; fail with a clear message
    # instead of an IndexError on the empty mapping.
    raise RuntimeError("No video was uploaded. Re-run this cell and choose a file.")
# Only the first uploaded file is processed.
video_filename = next(iter(uploaded))
video_path = os.path.join("/content", video_filename)
with open(video_path, "wb") as f:
    f.write(uploaded[video_filename])

# --- Load video ---
# `frames` is indexed by frame number below; `meta` supplies fps, width,
# height and duration_sec for the summary line.
print(f"\nLoading video: {video_filename}")
frames, meta = load_video(video_path)
fps = meta["fps"]
print(f"  {len(frames)} frames, {meta['width']}x{meta['height']}, {fps:.1f} fps, {meta['duration_sec']:.1f}s")

# --- Ball detection ---
# Run the ball tracker over every frame and report how many frames produced
# a detection.
print("\nDetecting ball...")
tracker = BallTracker()  # HSV fallback (no TrackNet weights)
ball_detections = tracker.detect_all(frames, progress=True)
print(f"  Ball detected in {len(ball_detections)}/{len(frames)} frames")

# Quick diagnostic: show ball position range.
# Each detection is indexed as [1] for x and [2] for y.
if ball_detections:
    ball_x = [det[1] for det in ball_detections]
    ball_y = [det[2] for det in ball_detections]
    print(f"  Ball x range: {min(ball_x):.0f}-{max(ball_x):.0f}, y range: {min(ball_y):.0f}-{max(ball_y):.0f}")

# --- Pose estimation ---
# Run the pose estimator on every frame. Successful results are cached for
# the contact-processing step, and one pixel-space wrist position per frame
# is recorded for contact detection.
print("\nEstimating pose...")
pose_estimator = PoseEstimator(static_image_mode=True, model_complexity=2)

wrist_positions = {}  # frame_num -> (x, y) in pixels
pose_cache = {}  # frame_num -> (world_landmarks, raw_result)

try:
    # Process ALL frames (short video, and we need full wrist trajectory)
    for i, frame in enumerate(tqdm(frames, desc="Pose estimation")):
        landmarks, result = pose_estimator.process_frame(frame)
        if landmarks is None:
            continue
        pose_cache[i] = (landmarks, result)

        # Get pixel-space wrist positions for contact detection
        pixel_lm = pose_estimator.get_pixel_landmarks(result, frame.shape)
        if pixel_lm is None:
            continue
        rw = pixel_lm.get("right_wrist")
        lw = pixel_lm.get("left_wrist")
        # Explicit None checks: a truthiness test on a coordinate container
        # raises ValueError if the landmarks are arrays rather than tuples.
        if rw is None or lw is None:
            continue

        # Heuristic: track whichever wrist is horizontally farther from the
        # centre of the frame.
        mid_x = frame.shape[1] / 2
        tracked = rw if abs(rw[0] - mid_x) > abs(lw[0] - mid_x) else lw
        wrist_positions[i] = (float(tracked[0]), float(tracked[1]))
finally:
    # Release the estimator even if a frame fails part-way through.
    pose_estimator.close()

print(f"  Pose estimated for {len(pose_cache)} frames")
print(f"  Wrist positions for {len(wrist_positions)} frames")

# --- Contact detection ---
# Combine the ball track and the wrist trajectory into a list of
# (frame, confidence) contact candidates.
print("\nDetecting contacts...")
contact_params = dict(
    wrist_positions=wrist_positions,
    velocity_spike_threshold=1.5,
    wrist_proximity_px=300,
    min_frame_gap=int(fps * 0.3),  # at least 0.3s apart
    debug=True,
)
contacts = detect_contacts(ball_detections, fps, **contact_params)
print(f"  Found {len(contacts)} contact(s)")

# With no contacts the per-contact loop below simply does nothing; just
# tell the user why the output will be empty.
if not contacts:
    print("\nNo contacts detected. Try adjusting camera angle or lighting.")

# --- Process each contact ---
# For every detected contact: pick the contact wrist, express the pose in
# pelvis-centred / ground-adjusted coordinates, compute the measurements,
# and save an annotated image of the frame.
output_dir = "/content/output"
os.makedirs(output_dir, exist_ok=True)

# Annotated frames are named only by frame number, so images left over from
# an earlier run (possibly of a different video) would be shown and
# downloaded together with this run's results. Remove them first.
for stale_name in os.listdir(output_dir):
    if stale_name.startswith("contact_frame_") and stale_name.endswith(".png"):
        os.remove(os.path.join(output_dir, stale_name))

all_measurements = []

# Create a fresh pose estimator for any missing contact frames
pe2 = PoseEstimator(static_image_mode=True, model_complexity=2)

try:
    for contact_frame, confidence in contacts:
        frame = frames[contact_frame]

        if contact_frame not in pose_cache:
            landmarks, result = pe2.process_frame(frame)
            if landmarks is None:
                print(f"  Frame {contact_frame}: no pose detected, skipping")
                continue
            pose_cache[contact_frame] = (landmarks, result)

        landmarks, raw_result = pose_cache[contact_frame]

        # Determine dominant wrist (the one further extended from pelvis).
        # Missing landmarks fall back to the origin.
        lw = landmarks.get("left_wrist", np.zeros(3))
        rw = landmarks.get("right_wrist", np.zeros(3))
        pelvis = landmarks.get("pelvis", np.zeros(3))
        if np.linalg.norm(rw - pelvis) > np.linalg.norm(lw - pelvis):
            contact_wrist_name = "right_wrist"
            contact_point = rw
        else:
            contact_wrist_name = "left_wrist"
            contact_point = lw

        # Transform to pelvis-centered, ground-adjusted coordinates
        centered = pelvis_origin_transform(landmarks)
        ground_z = estimate_ground_plane(centered)
        adjusted = apply_ground_plane(centered, ground_z)

        # Contact point in same coordinate system.
        # NOTE(review): this assumes apply_ground_plane only shifts the z
        # component by ground_z — confirm against utils.coordinate_transforms.
        contact_adjusted = contact_point - pelvis - np.array([0, 0, ground_z])

        # Compute measurements; the raw (untransformed) contact point is
        # stored alongside them.
        meas = compute_measurements(adjusted, contact_adjusted)
        meas["frame_number"] = contact_frame
        meas["timestamp"] = contact_frame / fps
        meas["confidence"] = confidence
        meas["contact_x"] = contact_point[0]
        meas["contact_y"] = contact_point[1]
        meas["contact_z"] = contact_point[2]
        all_measurements.append(meas)

        # Visualize
        pixel_lm = pe2.get_pixel_landmarks(raw_result, frame.shape)
        if pixel_lm is None:
            # Fallback: re-run pose just for pixel landmarks
            _, r2 = pe2.process_frame(frame)
            pixel_lm = pe2.get_pixel_landmarks(r2, frame.shape)

        if pixel_lm:
            annotated = annotate_contact_frame(
                frame, pixel_lm, contact_wrist_name,
                meas, contact_frame, fps
            )
            out_path = os.path.join(output_dir, f"contact_frame_{contact_frame}.png")
            save_annotated_frame(annotated, out_path)
        print(f"  Contact at frame {contact_frame} (t={contact_frame/fps:.2f}s, conf={confidence:.2f})")
finally:
    # Release the estimator even if processing one of the contacts raised.
    pe2.close()

# Save CSV
# Write one row per processed contact, with the identifying columns first
# and every other measurement column after them in alphabetical order.
if all_measurements:
    video_stem, _ = os.path.splitext(video_filename)
    csv_path = os.path.join(output_dir, f"measurements_{video_stem}.csv")

    leading_cols = ["frame_number", "timestamp", "confidence",
                    "contact_x", "contact_y", "contact_z"]
    df = pd.DataFrame(all_measurements)
    trailing_cols = sorted(col for col in df.columns if col not in leading_cols)
    df = df[leading_cols + trailing_cols]

    df.to_csv(csv_path, index=False)
    print(f"\nMeasurements saved to {csv_path}")

print("\nProcessing complete!")

In [None]:
#@title 2b. Visual Debug — Wrist speed, ball detections, sample frames
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("Agg")
from IPython.display import display, Image as IPImage

# --- Compute wrist speed for plotting ---
# Speed is the pixel distance between consecutive tracked wrist positions
# divided by the time between their frames; it is attributed to the later
# frame of each pair.
sorted_wf = sorted(wrist_positions.keys())
wrist_speed_frames = []
wrist_speed_vals = []
for prev_frame, next_frame in zip(sorted_wf, sorted_wf[1:]):
    elapsed = (next_frame - prev_frame) / fps
    if elapsed > 0:
        prev_x, prev_y = wrist_positions[prev_frame]
        next_x, next_y = wrist_positions[next_frame]
        distance = np.sqrt((next_x - prev_x)**2 + (next_y - prev_y)**2)
        wrist_speed_frames.append(next_frame)
        wrist_speed_vals.append(distance / elapsed)

# 2x2 diagnostic figure: wrist speed, wrist position, ball track, and a
# text summary. The figure is saved to disk and shown as an image.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
speed_ax, wrist_ax = axes[0, 0], axes[0, 1]
ball_ax, text_ax = axes[1, 0], axes[1, 1]

# Top-left: Wrist speed over time (PRIMARY signal)
speed_ax.plot(wrist_speed_frames, wrist_speed_vals, 'b.-', markersize=2)
speed_ax.set_title("Wrist speed over time (px/s) — PRIMARY contact signal")
speed_ax.set_xlabel("Frame")
speed_ax.set_ylabel("Speed (px/s)")
if wrist_speed_vals:
    # 85th percentile of the speeds, drawn as a reference line.
    p85 = np.percentile(wrist_speed_vals, 85)
    speed_ax.axhline(y=p85, color='orange', linestyle=':', label=f'p85 threshold: {p85:.0f}')
for contact_idx, _ in contacts:
    speed_ax.axvline(x=contact_idx, color='red', alpha=0.8, linestyle='--', label=f'contact {contact_idx}')
speed_ax.legend(fontsize=8)

# Top-right: Wrist X and Y position
wrist_x_series = [wrist_positions[frame_no][0] for frame_no in sorted_wf]
wrist_y_series = [wrist_positions[frame_no][1] for frame_no in sorted_wf]
wrist_ax.plot(sorted_wf, wrist_x_series, 'b.-', markersize=2, label='wrist x')
wrist_ax.plot(sorted_wf, wrist_y_series, 'c.-', markersize=2, label='wrist y')
for contact_idx, _ in contacts:
    wrist_ax.axvline(x=contact_idx, color='red', alpha=0.8, linestyle='--')
wrist_ax.set_title("Wrist position over time")
wrist_ax.set_xlabel("Frame")
wrist_ax.set_ylabel("Pixels")
wrist_ax.legend()

# Bottom-left: Ball detection positions (for reference, may be wrong)
ball_frames_list = [det[0] for det in ball_detections]
ball_xs = [det[1] for det in ball_detections]
ball_ys = [det[2] for det in ball_detections]
ball_ax.plot(ball_frames_list, ball_xs, 'g.-', markersize=2, label='HSV "ball" x')
ball_ax.plot(ball_frames_list, ball_ys, 'y.-', markersize=2, label='HSV "ball" y')
ball_ax.set_title("HSV ball detection (may be wrong object)")
ball_ax.set_xlabel("Frame")
ball_ax.set_ylabel("Pixels")
ball_ax.legend()

# Bottom-right: summary text instead of a plot
text_ax.axis("off")
summary_parts = [
    f"Video: {len(frames)} frames, {fps:.0f} fps, {len(frames)/fps:.1f}s\n",
    f"Pose detected: {len(pose_cache)}/{len(frames)} frames\n",
    f"Wrist tracked: {len(wrist_positions)} frames\n",
    f"Contacts found: {len(contacts)}\n",
]
summary_parts.extend(
    f"  Frame {contact_idx} (t={contact_idx/fps:.2f}s, conf={contact_conf:.2f})\n"
    for contact_idx, contact_conf in contacts
)
summary = "".join(summary_parts)
text_ax.text(0.1, 0.5, summary, fontsize=12, family='monospace',
             verticalalignment='center', transform=text_ax.transAxes)

plt.tight_layout()
plt.savefig("/content/output/debug_plots.png", dpi=120)
plt.close()
display(IPImage(filename="/content/output/debug_plots.png", width=900))

# --- Annotated sample frames showing wrist (blue) and contact frames (red) ---
# Renders a grid of up to MAX_SAMPLE_FRAMES frames with the tracked wrist
# circled and contact frames outlined, saves it, and displays it.
print("\nSample frames with wrist position (blue) and contacts (red border):")
contact_frame_set = {cf for cf, _ in contacts}

# Contact frames are selected first so that the cap never drops them in
# favour of filler frames; the remaining slots are filled with frames
# spaced evenly through the video.
MAX_SAMPLE_FRAMES = 8
sample_indices = sorted(contact_frame_set)[:MAX_SAMPLE_FRAMES]
if len(frames) > 0:
    for idx in np.linspace(0, len(frames) - 1, 6, dtype=int).tolist():
        if len(sample_indices) >= MAX_SAMPLE_FRAMES:
            break
        if idx not in sample_indices:
            sample_indices.append(idx)
sample_indices.sort()

if not sample_indices:
    # No frames at all: a grid with zero columns cannot be built.
    print("  No frames available to display.")
else:
    ncols = min(4, len(sample_indices))
    nrows = (len(sample_indices) + ncols - 1) // ncols
    # squeeze=False always yields a 2-D array of axes, so a single-panel
    # grid needs no special case.
    fig2, axes2 = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows),
                               squeeze=False)
    axes2 = axes2.ravel()

    for idx, ax in zip(sample_indices, axes2):
        # Convert to RGB for matplotlib; the colour tuples below are RGB.
        frame_rgb = cv2.cvtColor(frames[idx], cv2.COLOR_BGR2RGB).copy()

        # Draw wrist position
        if idx in wrist_positions:
            wx, wy = int(wrist_positions[idx][0]), int(wrist_positions[idx][1])
            cv2.circle(frame_rgb, (wx, wy), 12, (50, 50, 255), 3)
            cv2.putText(frame_rgb, "wrist", (wx + 15, wy - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (50, 50, 255), 2)

        is_contact = idx in contact_frame_set
        title = f"Frame {idx} (t={idx/fps:.2f}s)"
        if is_contact:
            # Red border for contact frames
            cv2.rectangle(frame_rgb, (0, 0),
                          (frame_rgb.shape[1] - 1, frame_rgb.shape[0] - 1),
                          (255, 0, 0), 6)
            title += " *** CONTACT ***"

        ax.imshow(frame_rgb)
        ax.set_title(title, fontsize=10, color='red' if is_contact else 'black')
        ax.axis("off")

    # Hide the unused panels of the grid.
    for spare_ax in axes2[len(sample_indices):]:
        spare_ax.axis("off")

    plt.tight_layout()
    plt.savefig("/content/output/debug_frames.png", dpi=120)
    plt.close()
    display(IPImage(filename="/content/output/debug_frames.png", width=900))

In [None]:
#@title 3. Results - View annotated frames & download
# Shows the annotated contact frames and the measurements table written by
# the processing cell, then triggers a browser download for each file.
import glob
import os  # imported here so the cell does not rely on an earlier cell's import
from IPython.display import display, Image as IPImage, HTML
import pandas as pd

output_dir = "/content/output"


def _contact_frame_order(path):
    """Sort key that orders contact_frame_<N>.png files by frame number N."""
    stem = os.path.splitext(os.path.basename(path))[0]
    suffix = stem.rsplit("_", 1)[-1]
    # Names without a numeric suffix sort first, then by path.
    return (int(suffix) if suffix.isdigit() else -1, path)


# Display annotated images in chronological order. A plain string sort
# would place contact_frame_100 before contact_frame_45.
png_files = sorted(glob.glob(os.path.join(output_dir, "contact_frame_*.png")),
                   key=_contact_frame_order)
if png_files:
    print(f"Found {len(png_files)} contact frame(s):\n")
    for png_path in png_files:
        print(os.path.basename(png_path))
        display(IPImage(filename=png_path, width=800))
        print()
else:
    print("No annotated frames found.")

# Display measurements table. Several CSVs may exist if more than one video
# was processed; show the most recently written one rather than whichever
# glob happens to return first.
csv_files = sorted(glob.glob(os.path.join(output_dir, "measurements_*.csv")))
if csv_files:
    latest_csv = max(csv_files, key=os.path.getmtime)
    df = pd.read_csv(latest_csv)
    print("\nMeasurements:")
    display(df)

# Download links
print("\n--- Download Files ---")
from google.colab import files as colab_files
for download_path in png_files + csv_files:
    colab_files.download(download_path)