# Tennis Contact Point Analysis

Upload a **video** of your tennis rally to automatically detect contact frames and analyze body positioning.

**How it works:**
1. **TrackNet** (deep learning) tracks the ball throughout the video
2. **Audio analysis** detects impact sounds (ball hitting racket)
3. Both signals are **fused** for robust contact detection
4. For each contact, we analyze pose and measure contact point spacing

**Steps:**
1. Run the **Setup** cell to install dependencies
2. Run the **Upload Video & Detect Contacts** cell and upload your video
3. Review the detected contacts, then set `CONTACT_INDEX` in the **Analyze Selected Contact** cell to analyze one in detail

In [None]:
#@title 1. Setup - Install dependencies and clone repo
import os, sys, shutil

REPO_URL = "https://github.com/xiaoxiang-ma/tennis_contact_point_spacing.git"
REPO_DIR = "/content/tennis_contact_point_spacing"

# Always re-clone to ensure latest code
if os.path.exists(REPO_DIR):
    shutil.rmtree(REPO_DIR)
!git clone {REPO_URL} {REPO_DIR}

!pip install -q -r {REPO_DIR}/requirements.txt

# Clear any cached module imports from previous runs
for mod_name in list(sys.modules.keys()):
    if mod_name.startswith(("src.", "utils.")):
        del sys.modules[mod_name]

if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

# Create weights directory
os.makedirs(os.path.join(REPO_DIR, "weights"), exist_ok=True)

print("\nSetup complete!")
print("Note: TrackNet weights will be downloaded automatically on first use (~50MB)")

In [None]:
#@title 2. Upload Video & Detect Contacts
# Uploads a video, runs contact detection (TrackNet, plus audio when USE_AUDIO
# is set), prints a summary and stores the results in ANALYSIS_DATA for the
# cells that follow.
import glob
import os

import numpy as np
import cv2
from google.colab import files
from IPython.display import display, Image as IPImage, HTML

from utils.video_io import load_video
from src.tracknet import TrackNetDetector
from src.contact_detection import detect_contacts, get_contact_ball_position

#@markdown ### Settings
SHOT_TYPE = "right_forehand"  #@param ["right_forehand", "right_backhand", "left_forehand", "left_backhand"]
USE_AUDIO = True  #@param {type:"boolean"}
DEBUG_MODE = True  #@param {type:"boolean"}
SAVE_DEBUG_FRAMES = True  #@param {type:"boolean"}

# Create output directory
output_dir = "/content/output"
os.makedirs(output_dir, exist_ok=True)
tracknet_debug_dir = os.path.join(output_dir, "tracknet_debug")

# --- Upload video ---
print("Upload your tennis video (MP4, MOV, etc.):")
uploaded = files.upload()
if not uploaded:
    # The upload dialog was cancelled or returned nothing; fail with a clear
    # message instead of an IndexError on the lookup below.
    raise ValueError("No file was uploaded. Re-run this cell and choose a video file.")

# Only the first uploaded file is analyzed.
video_filename = next(iter(uploaded))
video_path = os.path.join("/content", video_filename)
with open(video_path, "wb") as f:
    f.write(uploaded[video_filename])

# --- Load video ---
print(f"\nLoading video: {video_filename}")
frames, metadata = load_video(video_path)
if len(frames) == 0:
    # Everything downstream indexes into frames and divides by len(frames).
    raise ValueError(f"No frames could be read from {video_filename}.")

fps = metadata["fps"]
print(f"  Resolution: {metadata['width']}x{metadata['height']}")
print(f"  Frame rate: {fps:.1f} fps")
print(f"  Duration: {metadata['duration_sec']:.2f}s ({len(frames)} frames)")

# --- Initialize TrackNet ---
print("\nInitializing TrackNet (downloading weights if needed)...")
tracknet = TrackNetDetector(
    weights_path=None,  # Auto-download
    device=None,  # Auto-detect GPU/CPU
    confidence_threshold=0.5,
    save_debug_frames=SAVE_DEBUG_FRAMES,
    debug_output_dir=tracknet_debug_dir,
)

# Check if weights loaded successfully
if not tracknet.weights_loaded:
    print("\n" + "!"*60)
    print("WARNING: TrackNet weights failed to load!")
    print("Ball detection will NOT work properly.")
    print("Check the error message above for details.")
    print("!"*60 + "\n")

# --- Detect contacts ---
print("\nDetecting contacts (this may take a few minutes)...")
contacts, ball_detections = detect_contacts(
    video_path=video_path,
    frames=frames,
    fps=fps,
    tracknet_detector=tracknet,
    use_audio=USE_AUDIO,
    debug=DEBUG_MODE,
    save_debug_frames=SAVE_DEBUG_FRAMES,
)

# --- Display results ---
print("\n" + "="*60)
print("CONTACT DETECTION RESULTS")
print("="*60)
detection_pct = 100 * len(ball_detections) / len(frames)
print(f"Ball detected in {len(ball_detections)}/{len(frames)} frames ({detection_pct:.1f}%)")

# Show sample debug frames if available
if SAVE_DEBUG_FRAMES:
    debug_dir = tracknet_debug_dir
    debug_frames = sorted(glob.glob(os.path.join(debug_dir, "frame_*.png")))[:3]
    if debug_frames:
        print(f"\nSample TrackNet debug frames (from {debug_dir}):")
        for df in debug_frames:
            print(f"  - {os.path.basename(df)}")
            display(IPImage(filename=df, width=800))

print(f"\nDetected {len(contacts)} contact(s):")
print()

# Store contact info for later use. Each detected contact unpacks as
# (frame number, confidence, source label).
contact_info = []
for i, (frame_num, confidence, source) in enumerate(contacts):
    time_sec = frame_num / fps

    # Get ball position at this contact
    ball_pos, ball_method = get_contact_ball_position(frame_num, ball_detections)

    contact_info.append({
        'index': i,
        'frame': frame_num,
        'time': time_sec,
        'confidence': confidence,
        'source': source,
        'ball_pos': ball_pos,
        'ball_method': ball_method,
    })

    print(f"  Contact {i+1}: Frame {frame_num} ({time_sec:.2f}s)")
    print(f"    Confidence: {confidence:.0%} (source: {source})")
    # Explicit None test: matches how the analysis cells check ball_pos and
    # avoids an ambiguous truth value should the position be an array.
    if ball_pos is not None:
        print(f"    Ball position: ({ball_pos[0]:.0f}, {ball_pos[1]:.0f}) [{ball_method}]")
    else:
        print("    Ball position: Not available")
    print()

# Store for next cell
ANALYSIS_DATA = {
    'frames': frames,
    'fps': fps,
    'metadata': metadata,
    'contacts': contact_info,
    'ball_detections': ball_detections,
    'shot_type': SHOT_TYPE,
    'video_path': video_path,
}

if len(contacts) > 0:
    print("\nRun the next cell to analyze a specific contact in detail.")
else:
    print("\nNo contacts detected. Check the debug frames above to see what TrackNet is detecting.")
    print("You can also run the Debug cell at the bottom to see more details.")

In [None]:
#@title 3. Analyze Selected Contact
# Runs pose estimation on one detected contact frame, derives the contact
# point (ball position, or the hitting wrist as a fallback), computes the
# spacing measurements and writes an annotated frame, a trajectory image and
# a CSV to /content/output.
import numpy as np
import pandas as pd
import cv2
from IPython.display import display, Image as IPImage
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os

from utils.coordinate_transforms import (
    pelvis_origin_transform, estimate_ground_plane, apply_ground_plane
)
from src.pose_estimation import PoseEstimator
from src.measurements import compute_measurements
from src.visualization import (
    draw_skeleton, draw_contact_point, save_annotated_frame
)

#@markdown ### Select which contact to analyze:
CONTACT_INDEX = 0  #@param {type:"integer"}

# Validate
if 'ANALYSIS_DATA' not in globals():
    raise ValueError("Please run cell 2 first to detect contacts!")

contacts = ANALYSIS_DATA['contacts']
if len(contacts) == 0:
    raise ValueError("No contacts were detected in the video.")

if CONTACT_INDEX < 0 or CONTACT_INDEX >= len(contacts):
    raise ValueError(f"Invalid contact index. Valid range: 0-{len(contacts)-1}")

contact = contacts[CONTACT_INDEX]
frame_num = contact['frame']
ball_pos = contact['ball_pos']
ball_method = contact['ball_method']
shot_type = ANALYSIS_DATA['shot_type']
frames = ANALYSIS_DATA['frames']
fps = ANALYSIS_DATA['fps']

print(f"Analyzing Contact {CONTACT_INDEX + 1}")
print(f"  Frame: {frame_num} ({contact['time']:.2f}s)")
print(f"  Detection confidence: {contact['confidence']:.0%}")
print(f"  Ball position method: {ball_method}")

# Get frame
frame = frames[frame_num]
h, w = frame.shape[:2]

# Determine which wrist to use based on shot type
if shot_type in ["right_forehand", "right_backhand"]:
    contact_wrist_name = "right_wrist"
    hand_label = "RIGHT"
else:
    contact_wrist_name = "left_wrist"
    hand_label = "LEFT"

# --- Pose estimation ---
print("\nEstimating pose...")
pose_estimator = PoseEstimator(static_image_mode=True, model_complexity=2)
try:
    landmarks, raw_result = pose_estimator.process_frame(frame)
    if landmarks is None:
        raise ValueError("No pose detected in frame. The player may not be clearly visible.")
    pixel_lm = pose_estimator.get_pixel_landmarks(raw_result, frame.shape)
finally:
    # Close the estimator on every path, including when estimation raises.
    pose_estimator.close()
print("  Pose detected successfully!")

# --- Get contact point (ball position or fallback to wrist) ---
if ball_pos is not None:
    contact_pixel = ball_pos
    contact_source = f"ball ({ball_method})"
    print(f"  Using ball position as contact point: ({ball_pos[0]:.0f}, {ball_pos[1]:.0f})")
elif contact_wrist_name in pixel_lm:
    # Fallback to wrist
    contact_pixel = pixel_lm[contact_wrist_name]
    contact_source = "wrist (fallback)"
    print("  Ball not detected - using wrist position as fallback")
else:
    raise ValueError("Neither ball nor wrist position available")

# --- Transform coordinates for measurements ---
pelvis = landmarks.get("pelvis", np.zeros(3))
centered = pelvis_origin_transform(landmarks)
ground_z = estimate_ground_plane(centered)
adjusted = apply_ground_plane(centered, ground_z)

# For contact point, we need to estimate 3D position from 2D ball detection
# Use wrist depth as approximation (ball is near wrist at contact)
wrist_3d = landmarks.get(contact_wrist_name, np.zeros(3))
contact_3d = wrist_3d.copy()  # Start with wrist position

# If we have ball pixel position, adjust x/y based on pixel offset from wrist
if ball_pos is not None and contact_wrist_name in pixel_lm:
    wrist_px = pixel_lm[contact_wrist_name]
    # Estimate scale from wrist (pixels to normalized coords)
    # This is approximate - proper depth estimation would need stereo or depth camera
    px_offset_x = ball_pos[0] - wrist_px[0]
    px_offset_y = ball_pos[1] - wrist_px[1]

    # Rough scaling (assume ~1000px corresponds to ~1 normalized unit)
    scale = 0.001
    contact_3d[0] += px_offset_x * scale
    contact_3d[1] += px_offset_y * scale

# Apply the same pelvis-origin and ground-plane shifts to the contact point.
contact_adjusted = contact_3d - pelvis - np.array([0, 0, ground_z])

# --- Compute measurements ---
print("Computing measurements...")
meas = compute_measurements(adjusted, contact_adjusted)
meas["contact_source"] = contact_source
meas["ball_detection_method"] = ball_method
meas["shot_type"] = shot_type
meas["frame_num"] = frame_num
meas["contact_confidence"] = contact['confidence']

# --- Create output directory ---
output_dir = "/content/output"
os.makedirs(output_dir, exist_ok=True)

# --- 1. Annotated frame with skeleton + contact point ---
print("\n" + "="*60)
print("CONTACT FRAME ANALYSIS")
print("="*60)

annotated = frame.copy()
annotated = draw_skeleton(annotated, pixel_lm, thickness=3)

# Draw contact point (ball position or wrist)
cx, cy = int(contact_pixel[0]), int(contact_pixel[1])
draw_contact_point(annotated, cx, cy, radius=15)

# Add label
label = f"CONTACT ({contact_source})"
cv2.putText(annotated, label, (cx + 20, cy - 10),
            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

# Add frame info
info_text = f"Frame {frame_num} | {contact['time']:.2f}s | Conf: {contact['confidence']:.0%}"
cv2.putText(annotated, info_text, (20, 30),
            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

out_path = os.path.join(output_dir, f"contact_{CONTACT_INDEX+1}_frame_{frame_num}.png")
save_annotated_frame(annotated, out_path)

print("\nAnnotated contact frame:")
display(IPImage(filename=out_path, width=800))

# --- 2. Ball trajectory visualization ---
print("\n" + "="*60)
print("BALL TRAJECTORY")
print("="*60)

ball_detections = ANALYSIS_DATA['ball_detections']
if ball_detections:
    traj_frame = frame.copy()

    # Ball positions in frame order; each detection's first two items are
    # its x and y pixel coordinates.
    sorted_frames = sorted(ball_detections.keys())
    points = [(int(ball_detections[f][0]), int(ball_detections[f][1])) for f in sorted_frames]

    # Draw trajectory line
    for i in range(1, len(points)):
        # Color gradient: blue (old) -> green (new)
        t = i / len(points)
        color = (int(255*(1-t)), int(255*t), 0)
        cv2.line(traj_frame, points[i-1], points[i], color, 2)

    # Draw ball positions as dots (red for the contact frame, green otherwise)
    for f, point in zip(sorted_frames, points):
        color = (0, 255, 0) if f != frame_num else (0, 0, 255)
        cv2.circle(traj_frame, point, 4, color, -1)

    # Highlight contact frame ball position. Uses an explicit None test, the
    # same check applied to ball_pos earlier in this cell.
    if ball_pos is not None:
        cv2.circle(traj_frame, (int(ball_pos[0]), int(ball_pos[1])), 12, (0, 0, 255), 3)
        cv2.putText(traj_frame, "CONTACT", (int(ball_pos[0])+15, int(ball_pos[1])-5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

    traj_path = os.path.join(output_dir, f"trajectory_{CONTACT_INDEX+1}.png")
    cv2.imwrite(traj_path, traj_frame)
    display(IPImage(filename=traj_path, width=800))
else:
    print("No ball trajectory data available.")

# --- 3. Measurements ---
print("\n" + "="*60)
print("CONTACT POINT MEASUREMENTS")
print("="*60)
print(f"\nShot type: {shot_type}")
print(f"Contact source: {contact_source}")
print()
print(f"Lateral offset:      {meas.get('lateral_offset_cm', 0):>7.1f} cm  ({meas.get('lateral_offset_inches', 0):>5.1f} in)")
print("  (+ = to your left, - = to your right)")
print(f"Forward/back:        {meas.get('forward_back_cm', 0):>7.1f} cm  ({meas.get('forward_back_inches', 0):>5.1f} in)")
print("  (+ = in front, - = behind)")
print(f"Height above ground: {meas.get('height_above_ground_cm', 0):>7.1f} cm  ({meas.get('height_above_ground_inches', 0):>5.1f} in)")
if "shoulder_line_distance_cm" in meas:
    print(f"Shoulder line dist:  {meas.get('shoulder_line_distance_cm', 0):>7.1f} cm  ({meas.get('shoulder_line_distance_inches', 0):>5.1f} in)")
if "relative_to_shoulder_height_cm" in meas:
    print(f"vs Shoulder height:  {meas.get('relative_to_shoulder_height_cm', 0):>7.1f} cm  ({meas.get('relative_to_shoulder_height_inches', 0):>5.1f} in)")
    print("  (+ = above shoulder, - = below)")
print("="*60)

# Save measurements CSV
csv_path = os.path.join(output_dir, f"measurements_contact_{CONTACT_INDEX+1}.csv")
df = pd.DataFrame([meas])
df.to_csv(csv_path, index=False)
print(f"\nMeasurements saved to {csv_path}")

In [None]:
#@title 4. Analyze All Contacts (Batch)
# Repeats the single-contact measurement for every detected contact and
# writes one combined CSV to /content/output.
import os

import numpy as np
import pandas as pd
from IPython.display import display

from utils.coordinate_transforms import (
    pelvis_origin_transform, estimate_ground_plane, apply_ground_plane
)
from src.pose_estimation import PoseEstimator
from src.measurements import compute_measurements

#@markdown Run this cell to analyze ALL detected contacts at once.

if 'ANALYSIS_DATA' not in globals():
    raise ValueError("Please run cell 2 first!")

contacts = ANALYSIS_DATA['contacts']
frames = ANALYSIS_DATA['frames']
fps = ANALYSIS_DATA['fps']
shot_type = ANALYSIS_DATA['shot_type']
ball_detections = ANALYSIS_DATA['ball_detections']

if len(contacts) == 0:
    print("No contacts to analyze.")
else:
    print(f"Analyzing {len(contacts)} contacts...\n")

    all_measurements = []

    # Determine wrist based on shot type
    if shot_type in ["right_forehand", "right_backhand"]:
        contact_wrist_name = "right_wrist"
    else:
        contact_wrist_name = "left_wrist"

    pose_estimator = PoseEstimator(static_image_mode=True, model_complexity=2)
    try:
        for contact in contacts:
            idx = contact['index']
            frame_num = contact['frame']
            ball_pos = contact['ball_pos']
            ball_method = contact['ball_method']

            print(f"Contact {idx+1}: Frame {frame_num}...", end=" ")

            frame = frames[frame_num]
            landmarks, raw_result = pose_estimator.process_frame(frame)

            if landmarks is None:
                print("SKIPPED (no pose detected)")
                continue

            pixel_lm = pose_estimator.get_pixel_landmarks(raw_result, frame.shape)

            # Get contact point: ball position, else the hitting wrist
            if ball_pos is not None:
                contact_pixel = ball_pos
                contact_source = f"ball ({ball_method})"
            elif contact_wrist_name in pixel_lm:
                contact_pixel = pixel_lm[contact_wrist_name]
                contact_source = "wrist (fallback)"
            else:
                print("SKIPPED (no contact point)")
                continue

            # Transform and measure
            pelvis = landmarks.get("pelvis", np.zeros(3))
            centered = pelvis_origin_transform(landmarks)
            ground_z = estimate_ground_plane(centered)
            adjusted = apply_ground_plane(centered, ground_z)

            # Start from the wrist's 3D position and shift x/y by the pixel
            # offset between ball and wrist, using a fixed rough scale.
            wrist_3d = landmarks.get(contact_wrist_name, np.zeros(3))
            contact_3d = wrist_3d.copy()

            if ball_pos is not None and contact_wrist_name in pixel_lm:
                wrist_px = pixel_lm[contact_wrist_name]
                px_offset_x = ball_pos[0] - wrist_px[0]
                px_offset_y = ball_pos[1] - wrist_px[1]
                scale = 0.001
                contact_3d[0] += px_offset_x * scale
                contact_3d[1] += px_offset_y * scale

            contact_adjusted = contact_3d - pelvis - np.array([0, 0, ground_z])
            meas = compute_measurements(adjusted, contact_adjusted)

            meas["contact_index"] = idx + 1
            meas["frame_num"] = frame_num
            meas["time_sec"] = contact['time']
            meas["contact_confidence"] = contact['confidence']
            meas["contact_source"] = contact_source
            meas["detection_source"] = contact['source']
            meas["shot_type"] = shot_type

            all_measurements.append(meas)
            print("OK")
    finally:
        # Close the estimator even if a contact fails part-way through.
        pose_estimator.close()

    # Save combined results
    if all_measurements:
        output_dir = "/content/output"
        os.makedirs(output_dir, exist_ok=True)

        df = pd.DataFrame(all_measurements)
        csv_path = os.path.join(output_dir, "all_contacts_measurements.csv")
        df.to_csv(csv_path, index=False)

        print("\n" + "="*60)
        print("SUMMARY OF ALL CONTACTS")
        print("="*60)
        # Show only the summary columns that are present, so a measurement
        # key missing from compute_measurements' output does not raise KeyError.
        summary_cols = ['contact_index', 'frame_num', 'time_sec', 'contact_confidence',
                        'lateral_offset_cm', 'forward_back_cm', 'height_above_ground_cm',
                        'contact_source']
        display(df[[col for col in summary_cols if col in df.columns]])

        print(f"\nResults saved to {csv_path}")

In [None]:
#@title 5. Download Results
# Downloads every regular file found directly under /content/output.
import glob
import os

from google.colab import files as colab_files

output_dir = "/content/output"
# glob("*") also returns sub-directories (e.g. the TrackNet debug-frame
# folder), which are not single downloadable files, so keep regular files only.
all_files = sorted(
    path for path in glob.glob(os.path.join(output_dir, "*"))
    if os.path.isfile(path)
)

if not all_files:
    print(f"No output files found in {output_dir}. Run the analysis cells first.")
else:
    print("Downloading files...")
    for f in all_files:
        print(f"  {os.path.basename(f)}")
        colab_files.download(f)

## Debug: View Ball Detection & Audio Analysis (Optional)

Run the cell below to visualize TrackNet's ball detection and audio envelope.

In [None]:
#@title Debug: Ball Detection & Audio Visualization
# Plots ball-detection confidence, ball position and the audio envelope over
# time, marks each detected contact, and saves the figure to /content/output.
import matplotlib.pyplot as plt
import numpy as np

if 'ANALYSIS_DATA' not in globals():
    raise ValueError("Please run cell 2 first!")

ball_detections = ANALYSIS_DATA['ball_detections']
contacts = ANALYSIS_DATA['contacts']
fps = ANALYSIS_DATA['fps']
num_frames = len(ANALYSIS_DATA['frames'])
video_path = ANALYSIS_DATA['video_path']

fig, axes = plt.subplots(3, 1, figsize=(14, 10))

# Plot 1: Ball detection confidence over time
ax1 = axes[0]
if ball_detections:
    frames_detected = sorted(ball_detections.keys())
    confidences = [ball_detections[f][2] for f in frames_detected]
    times = [f/fps for f in frames_detected]

    ax1.plot(times, confidences, 'b-', alpha=0.7, label='Ball confidence')
    ax1.axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='Threshold')

    # Mark contacts
    for c in contacts:
        ax1.axvline(x=c['time'], color='g', linestyle='-', alpha=0.8)
        ax1.text(c['time'], 1.0, f"C{c['index']+1}", ha='center', va='bottom', fontsize=10)

    # Only build a legend when labelled artists exist; an empty legend call
    # makes matplotlib emit a warning.
    ax1.legend()

ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Detection Confidence')
ax1.set_title('TrackNet Ball Detection Confidence')
ax1.set_ylim(0, 1.1)

# Plot 2: Ball position (x, y) over time
ax2 = axes[1]
if ball_detections:
    xs = [ball_detections[f][0] for f in frames_detected]
    ys = [ball_detections[f][1] for f in frames_detected]

    ax2.plot(times, xs, 'r-', alpha=0.7, label='X position')
    ax2.plot(times, ys, 'b-', alpha=0.7, label='Y position')

    for c in contacts:
        ax2.axvline(x=c['time'], color='g', linestyle='-', alpha=0.8)

    ax2.legend()

ax2.set_xlabel('Time (s)')
ax2.set_ylabel('Position (pixels)')
ax2.set_title('Ball Position Over Time')

# Plot 3: Audio envelope (if available)
ax3 = axes[2]
try:
    from src.audio_detection import get_audio_envelope_for_debug
    audio, sr, envelope = get_audio_envelope_for_debug(video_path)

    audio_times = np.arange(len(envelope)) / sr
    ax3.plot(audio_times, envelope, 'purple', alpha=0.7, label='Audio envelope (1-4kHz)')

    for c in contacts:
        ax3.axvline(x=c['time'], color='g', linestyle='-', alpha=0.8)
        ax3.text(c['time'], envelope.max(), f"C{c['index']+1}", ha='center', va='bottom', fontsize=10)

    ax3.set_xlabel('Time (s)')
    ax3.set_ylabel('Amplitude')
    ax3.set_title('Audio Envelope (bandpass filtered for impact sounds)')
    ax3.legend()
except Exception as e:
    # Best effort: the audio panel is optional, so show the reason in the
    # plot instead of aborting the whole debug figure.
    ax3.text(0.5, 0.5, f'Audio analysis not available: {e}',
             ha='center', va='center', transform=ax3.transAxes)
    ax3.set_title('Audio Envelope (not available)')

plt.tight_layout()
plt.savefig('/content/output/detection_debug.png', dpi=120)
plt.show()

# Guard the percentage against a zero frame count.
coverage_pct = 100*len(ball_detections)/num_frames if num_frames else 0.0
print(f"\nBall detected in {len(ball_detections)}/{num_frames} frames ({coverage_pct:.1f}%)")
print(f"Contacts detected: {len(contacts)}")
for c in contacts:
    print(f"  Contact {c['index']+1}: {c['time']:.2f}s (source: {c['source']}, conf: {c['confidence']:.0%})")