# Tennis Contact Point Analysis - MVP

Upload a **single image** captured at the moment your racket contacts the ball to analyze body positioning and measure contact point spacing.

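**Steps:**
1. Run the **Setup** cell to clone the repository and install dependencies
2. Select your shot type, then run the **Upload & Analyze** cell and upload your contact frame image
3. Review the results: skeleton overlay, 3D and top-down pose views, and a measurements summary (also saved as a CSV)
4. Optionally run the **Download Results** cell to download the generated files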

In [None]:
#@title 1. Setup - Install dependencies and clone repo
import os, sys, shutil

REPO_URL = "https://github.com/xiaoxiang-ma/tennis_contact_point_spacing.git"
REPO_DIR = "/content/tennis_contact_point_spacing"

# Always re-clone to ensure latest code
if os.path.exists(REPO_DIR):
    shutil.rmtree(REPO_DIR)
!git clone {REPO_URL} {REPO_DIR}

!pip install -q -r {REPO_DIR}/requirements.txt

# Clear any cached module imports from previous runs
for mod_name in list(sys.modules.keys()):
    if mod_name.startswith(("src.", "utils.")):
        del sys.modules[mod_name]

if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

print("Setup complete!")

In [None]:
#@title 2. Upload & Analyze Contact Frame
# Analyze one uploaded contact frame: detect the pose, express the hitting
# wrist (the contact point) in the same body-centred coordinates as the rest
# of the skeleton, then show a skeleton overlay, 3D / top-down pose plots and
# a measurements summary (also saved as CSV under /content/output).
import numpy as np
import pandas as pd
import cv2
from google.colab import files
from IPython.display import display, Image as IPImage
import os
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 - provides the '3d' projection on older matplotlib

from utils.coordinate_transforms import (
    pelvis_origin_transform, estimate_ground_plane,
    apply_ground_plane
)
from src.pose_estimation import PoseEstimator
from src.measurements import compute_measurements
from src.visualization import (
    draw_skeleton, draw_contact_point, draw_measurements, save_annotated_frame
)

#@markdown ### Shot Type (select one):
SHOT_TYPE = "right_forehand"  #@param ["right_forehand", "right_backhand", "left_forehand", "left_backhand"]

# The contact point is the wrist of the hitting hand; the shot type's side
# ("right_..." / "left_...") selects which one.
if SHOT_TYPE in ["right_forehand", "right_backhand"]:
    contact_wrist_name = "right_wrist"
    hand_label = "RIGHT"
else:
    contact_wrist_name = "left_wrist"
    hand_label = "LEFT"

print(f"Shot type: {SHOT_TYPE} → using {hand_label} wrist as contact point")

# --- Upload image ---
print("\nUpload your contact frame image (PNG, JPG):")
uploaded = files.upload()
if not uploaded:
    # Nothing was uploaded (e.g. the dialog was cancelled); fail with a clear
    # message instead of an IndexError.
    raise ValueError("No image uploaded. Re-run this cell and select an image file.")
upload_name = next(iter(uploaded))
# Keep only the base name so the copy is always written directly into /content.
image_filename = os.path.basename(upload_name)
image_stem = os.path.splitext(image_filename)[0]
image_path = os.path.join("/content", image_filename)
with open(image_path, "wb") as f:
    f.write(uploaded[upload_name])

# --- Load image ---
print(f"\nLoading image: {image_filename}")
frame = cv2.imread(image_path)
if frame is None:
    raise ValueError(f"Could not load image: {image_path}")
h, w = frame.shape[:2]
print(f"  Resolution: {w}x{h}")

# --- Pose estimation ---
print("\nEstimating pose...")
pose_estimator = PoseEstimator(static_image_mode=True, model_complexity=2)
try:
    landmarks, raw_result = pose_estimator.process_frame(frame)
    if landmarks is None:
        raise ValueError("No pose detected in image. Make sure the player is clearly visible.")
    pixel_lm = pose_estimator.get_pixel_landmarks(raw_result, frame.shape)
finally:
    # Release the estimator on every path, including errors.
    pose_estimator.close()
print("  Pose detected successfully!")

# --- Get contact point ---
contact_point = landmarks.get(contact_wrist_name)
if contact_point is None:
    # Defaulting to the origin would silently report the pelvis as the
    # contact point, so stop here instead.
    raise ValueError(
        f"{hand_label} wrist not detected; cannot locate the contact point."
    )
pelvis = landmarks.get("pelvis", np.zeros(3))

# --- Transform coordinates ---
centered = pelvis_origin_transform(landmarks)
ground_z = estimate_ground_plane(centered)
adjusted = apply_ground_plane(centered, ground_z)

# Contact point in the same coordinate system as the body. Prefer the wrist
# as transformed by the pipeline above so body and contact can never
# disagree; fall back to the manual pelvis/ground offset if it is missing.
contact_adjusted = adjusted.get(contact_wrist_name)
if contact_adjusted is None:
    contact_adjusted = contact_point - pelvis - np.array([0, 0, ground_z])
contact_adjusted = np.asarray(contact_adjusted, dtype=float)

# --- Compute measurements ---
print("Computing measurements...")
meas = compute_measurements(adjusted, contact_adjusted)
meas["contact_wrist"] = contact_wrist_name
meas["shot_type"] = SHOT_TYPE

# --- Create output directory ---
output_dir = "/content/output"
os.makedirs(output_dir, exist_ok=True)

# --- 1. Annotated frame with FULL skeleton ---
print("\n" + "="*60)
print("POSE VERIFICATION - Check that skeleton matches your body")
print("="*60)

annotated = frame.copy()
annotated = draw_skeleton(annotated, pixel_lm, thickness=3)

# Highlight the contact wrist with a larger red marker and a text label.
if contact_wrist_name in pixel_lm:
    px, py = pixel_lm[contact_wrist_name]
    draw_contact_point(annotated, px, py, radius=15)
    # cv2.putText requires integer pixel coordinates.
    cv2.putText(annotated, f"CONTACT ({hand_label})", (int(px) + 20, int(py) - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

out_skeleton_path = os.path.join(output_dir, f"skeleton_{image_stem}.png")
save_annotated_frame(annotated, out_skeleton_path)

print("\nSkeleton overlay (verify joints are correctly placed):")
display(IPImage(filename=out_skeleton_path, width=800))

# --- 2. 3D Pose Visualization ---
print("\n" + "="*60)
print("3D POSE VIEW - Verify depth estimation")
print("="*60)


def _to_plot_cm(coords):
    """Map an (x, y, z) landmark to plot axes (lateral, forward, height) in cm.

    Follows the convention x=lateral, y=vertical (neg=up), z=depth
    (neg=forward), so y and z are negated to make "up" and "forward" positive.
    """
    return coords[0] * 100, -coords[2] * 100, -coords[1] * 100


colors = {
    'pelvis': 'black', 'left_hip': 'green', 'right_hip': 'green',
    'left_shoulder': 'blue', 'right_shoulder': 'blue',
    'left_elbow': 'cyan', 'right_elbow': 'cyan',
    'left_wrist': 'magenta', 'right_wrist': 'magenta',
    'left_knee': 'orange', 'right_knee': 'orange',
    'left_ankle': 'brown', 'right_ankle': 'brown',
    'nose': 'red', 'head': 'red'
}

fig = plt.figure(figsize=(12, 5))
ax1 = fig.add_subplot(121, projection='3d')
ax1.set_title("3D Pose (pelvis-centered)")
ax2 = fig.add_subplot(122)
ax2.set_title("Top-down view (bird's eye)")
ax2.set_aspect('equal')

# Plot each colored landmark in both views (the top-down view drops height).
for name, coords in adjusted.items():
    if name not in colors:
        continue
    lateral, forward, height = _to_plot_cm(coords)
    ax1.scatter(lateral, forward, height, c=colors[name], s=50, label=name)
    ax1.text(lateral, forward, height, name, fontsize=7)
    ax2.scatter(lateral, forward, c=colors[name], s=50)
    ax2.annotate(name, (lateral, forward), fontsize=7)

# Mark the contact point in both views.
contact_lat, contact_fwd, contact_hgt = _to_plot_cm(contact_adjusted)
ax1.scatter(contact_lat, contact_fwd, contact_hgt, c='red', s=200, marker='*', label='CONTACT')
ax2.scatter(contact_lat, contact_fwd, c='red', s=200, marker='*')
ax2.annotate('CONTACT', (contact_lat, contact_fwd), fontsize=9, color='red')

ax1.set_xlabel('Lateral (cm)')
ax1.set_ylabel('Forward (cm)')
ax1.set_zlabel('Height (cm)')

ax2.set_xlabel('Lateral (cm) ← Left | Right →')
ax2.set_ylabel('Forward (cm) ↑')
ax2.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
ax2.axvline(x=0, color='gray', linestyle='--', alpha=0.5)
ax2.grid(True, alpha=0.3)

fig.tight_layout()
out_3d_path = os.path.join(output_dir, f"pose_3d_{image_stem}.png")
fig.savefig(out_3d_path, dpi=120)
plt.close(fig)

display(IPImage(filename=out_3d_path, width=900))

# --- 3. Measurements ---
print("\n" + "="*60)
print("CONTACT POINT MEASUREMENTS")
print("="*60)
print(f"\nShot type: {SHOT_TYPE}")
print(f"Contact wrist: {hand_label}")
print()
print(f"Lateral offset:      {meas.get('lateral_offset_cm', 0):>7.1f} cm  ({meas.get('lateral_offset_inches', 0):>5.1f} in)")
print("  (+ = to your left, - = to your right)")
print(f"Forward/back:        {meas.get('forward_back_cm', 0):>7.1f} cm  ({meas.get('forward_back_inches', 0):>5.1f} in)")
print("  (+ = in front, - = behind)")
print(f"Height above ground: {meas.get('height_above_ground_cm', 0):>7.1f} cm  ({meas.get('height_above_ground_inches', 0):>5.1f} in)")
# Optional measurements are only printed when compute_measurements provided them.
if "shoulder_line_distance_cm" in meas:
    print(f"Shoulder line dist:  {meas.get('shoulder_line_distance_cm', 0):>7.1f} cm  ({meas.get('shoulder_line_distance_inches', 0):>5.1f} in)")
if "relative_to_shoulder_height_cm" in meas:
    print(f"vs Shoulder height:  {meas.get('relative_to_shoulder_height_cm', 0):>7.1f} cm  ({meas.get('relative_to_shoulder_height_inches', 0):>5.1f} in)")
    print("  (+ = above shoulder, - = below)")
print("="*60)

# Save CSV (one row holding every measurement plus the shot metadata).
csv_path = os.path.join(output_dir, f"measurements_{image_stem}.csv")
df = pd.DataFrame([meas])
df.to_csv(csv_path, index=False)
print(f"\nResults saved to {output_dir}/")

In [None]:
#@title 3. Download Results
# Download the artifacts written by the Upload & Analyze cell from
# /content/output: skeleton overlays, 3D pose plots and measurement CSVs.
import glob
import os
from google.colab import files as colab_files

output_dir = "/content/output"
# These patterns must match the file names the analysis cell saves
# ("skeleton_<name>.png", "pose_3d_<name>.png", "measurements_<name>.csv").
png_files = [
    path
    for pattern in ("skeleton_*.png", "pose_3d_*.png")
    for path in sorted(glob.glob(os.path.join(output_dir, pattern)))
]
csv_files = sorted(glob.glob(os.path.join(output_dir, "measurements_*.csv")))

result_files = png_files + csv_files
if not result_files:
    print(f"No results found in {output_dir}/ - run the Upload & Analyze cell first.")
else:
    print("Downloading files...")
    for path in result_files:
        print(f"  {os.path.basename(path)}")
        colab_files.download(path)