# Tennis Contact Point Analysis

Upload a **video** of your tennis rally to automatically detect contact frames and analyze body positioning.

**How it works:**
1. **TrackNet** (deep learning) tracks the ball throughout the video
2. **Audio analysis** detects impact sounds (ball hitting racket)
3. Both signals are **fused** for robust contact detection
4. For each contact, we analyze pose and measure contact point spacing

**Cells:**
1. **Setup** - Clone the repository and install dependencies
2. **Upload & Detect** - Upload video and run contact detection
3. **Generate Diagnostic Video** - Create a video with overlays showing the detection signals
4. **Text Diagnostics** - Frame-by-frame text output
5. **Analyze Contact** - Detailed analysis of a selected contact
6. **Batch Analysis** - Analyze all contacts at once
7. **Download** - Download results

In [None]:
#@title 1. Setup - Install Dependencies
import os, sys, shutil

REPO_URL = "https://github.com/xiaoxiang-ma/tennis_contact_point_spacing.git"
REPO_DIR = "/content/tennis_contact_point_spacing"

# Always re-clone to ensure latest code
if os.path.exists(REPO_DIR):
    shutil.rmtree(REPO_DIR)
!git clone {REPO_URL} {REPO_DIR}

!pip install -q -r {REPO_DIR}/requirements.txt

# Clear any cached module imports from previous runs
for mod_name in list(sys.modules.keys()):
    if mod_name.startswith(("src.", "utils.")):
        del sys.modules[mod_name]

if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

# Create weights directory
os.makedirs(os.path.join(REPO_DIR, "weights"), exist_ok=True)

print("\nSetup complete!")
print("Note: TrackNet weights will be downloaded automatically on first use (~50MB)")

In [None]:
#@title 2. Upload Video & Detect Contacts
# Upload a video, run TrackNet (+ optional audio) contact detection, and stash
# everything later cells need in the ANALYSIS_DATA dict.
import numpy as np
import cv2
from google.colab import files
from IPython.display import display, Image as IPImage, HTML
import os

from utils.video_io import load_video
from src.tracknet import TrackNetDetector
from src.contact_detection import detect_contacts, get_contact_ball_position

#@markdown ### Settings
SHOT_TYPE = "right_forehand"  #@param ["right_forehand", "right_backhand", "left_forehand", "left_backhand"]
USE_AUDIO = True  #@param {type:"boolean"}
DEBUG_MODE = True  #@param {type:"boolean"}
SAVE_DEBUG_FRAMES = False  #@param {type:"boolean"}

# Create output directory
output_dir = "/content/output"
os.makedirs(output_dir, exist_ok=True)

# --- Upload video ---
print("Upload your tennis video (MP4, MOV, etc.):")
uploaded = files.upload()  # filename -> uploaded file contents
if not uploaded:
    # Cancelling the dialog yields nothing; fail clearly instead of with an IndexError.
    raise ValueError("No video was uploaded. Please re-run this cell and select a file.")
video_filename = next(iter(uploaded))
if len(uploaded) > 1:
    print(f"  Multiple files uploaded; analyzing only '{video_filename}'.")
video_path = os.path.join("/content", video_filename)
with open(video_path, "wb") as f:
    f.write(uploaded[video_filename])

# --- Load video ---
print(f"\nLoading video: {video_filename}")
frames, metadata = load_video(video_path)
num_frames = len(frames)
if num_frames == 0:
    # Guard here so the percentage below (and later cells) never divide by zero.
    raise ValueError(f"Could not read any frames from '{video_filename}'. Is it a supported video file?")
fps = metadata["fps"]
print(f"  Resolution: {metadata['width']}x{metadata['height']}")
print(f"  Frame rate: {fps:.1f} fps")
print(f"  Duration: {metadata['duration_sec']:.2f}s ({num_frames} frames)")

# --- Initialize TrackNet ---
print("\nInitializing TrackNet (downloading weights if needed)...")
tracknet = TrackNetDetector(
    weights_path=None,  # Auto-download
    device=None,  # Auto-detect GPU/CPU
    confidence_threshold=0.5,
    save_debug_frames=SAVE_DEBUG_FRAMES,
    debug_output_dir=os.path.join(output_dir, "tracknet_debug"),
)

# Check if weights loaded successfully
if not tracknet.weights_loaded:
    print("\n" + "!"*60)
    print("WARNING: TrackNet weights failed to load!")
    print("Ball detection will NOT work properly.")
    print("Check the error message above for details.")
    print("!"*60 + "\n")

# --- Detect contacts ---
print("\nDetecting contacts (this may take a few minutes)...")
contacts, ball_detections = detect_contacts(
    video_path=video_path,
    frames=frames,
    fps=fps,
    tracknet_detector=tracknet,
    use_audio=USE_AUDIO,
    debug=DEBUG_MODE,
    save_debug_frames=SAVE_DEBUG_FRAMES,
)

# --- Build trajectory for diagnostics ---
trajectory = tracknet.get_ball_trajectory(ball_detections)

# --- Display results ---
num_detected = len(ball_detections)
print("\n" + "="*60)
print("CONTACT DETECTION RESULTS")
print("="*60)
print(f"Ball detected in {num_detected}/{num_frames} frames ({100*num_detected/num_frames:.1f}%)")

print(f"\nDetected {len(contacts)} contact(s):")
print()

# Store contact info for later use; each raw contact unpacks as
# (frame_num, confidence, source).
contact_info = []
for i, (frame_num, confidence, source) in enumerate(contacts):
    time_sec = frame_num / fps
    ball_pos, ball_method = get_contact_ball_position(frame_num, ball_detections)

    contact_info.append({
        'index': i,
        'frame': frame_num,
        'time': time_sec,
        'confidence': confidence,
        'source': source,
        'ball_pos': ball_pos,
        'ball_method': ball_method,
    })

    print(f"  Contact {i+1}: Frame {frame_num} ({time_sec:.2f}s)")
    print(f"    Confidence: {confidence:.0%} (source: {source})")
    # Explicit None check (as in cell 5): truthiness is ambiguous for numpy
    # arrays and would hide a legitimate position such as (0, 0).
    if ball_pos is not None:
        print(f"    Ball position: ({ball_pos[0]:.0f}, {ball_pos[1]:.0f}) [{ball_method}]")
    else:
        print("    Ball position: Not available")
    print()

# Store for next cells
ANALYSIS_DATA = {
    'frames': frames,
    'fps': fps,
    'metadata': metadata,
    'contacts': contact_info,
    'contacts_raw': contacts,  # Keep raw format for video generation
    'ball_detections': ball_detections,
    'trajectory': trajectory,
    'shot_type': SHOT_TYPE,
    'video_path': video_path,
}

print("\n" + "="*60)
print("Next: Run the DIAGNOSTIC VIDEO cell to visualize detection")
print("="*60)

In [None]:
#@title 3. Generate Diagnostic Video (Recommended)
# Render the detection overlays into an MP4, show it inline, and download it.
import os
from base64 import b64encode
from IPython.display import display, HTML
from google.colab import files as colab_files

from src.visualization import create_diagnostic_video

#@markdown ### Settings
#@markdown Set the known real contact frame (if you know it) to highlight in yellow:
#@markdown (use 0 if you don't know it)
KNOWN_CONTACT_FRAME = 81  #@param {type:"integer"}
SHOW_TRAJECTORY_TRAIL = True  #@param {type:"boolean"}
TRAJECTORY_TAIL_FRAMES = 30  #@param {type:"integer"}

if 'ANALYSIS_DATA' not in dir():
    raise ValueError("Please run cell 2 first!")

# A non-positive value means "no known contact". Resolve it once so the legend
# printed below and the frame passed to the renderer always agree.
known_frame = KNOWN_CONTACT_FRAME if KNOWN_CONTACT_FRAME > 0 else None

output_dir = "/content/output"
os.makedirs(output_dir, exist_ok=True)

output_path = os.path.join(output_dir, "diagnostic_video.mp4")

print("Generating diagnostic video with overlays...")
print("  - Ball position (green circle)")
print("  - Velocity vector (yellow arrow)")
print("  - Detection signals: SPIKE (cyan), REVERSAL (magenta), DECEL (orange)")
print("  - Confidence meter")
print("  - Detected contacts (red border)")
if known_frame is not None:
    print(f"  - Known contact frame {known_frame} (yellow border)")
print()

create_diagnostic_video(
    frames=ANALYSIS_DATA['frames'],
    ball_detections=ANALYSIS_DATA['ball_detections'],
    contacts=ANALYSIS_DATA['contacts_raw'],
    fps=ANALYSIS_DATA['fps'],
    output_path=output_path,
    known_contact_frame=known_frame,
    show_trajectory=SHOW_TRAJECTORY_TRAIL,
    trajectory_tail=TRAJECTORY_TAIL_FRAMES,
)

print("\n" + "="*60)
print("VIDEO OVERLAY LEGEND")
print("="*60)
print("""
TOP-LEFT PANEL:
  Frame/Time     - Current frame number and timestamp
  Ball           - DETECTED (green) or MISSING (red)
  Speed          - Current ball speed in px/s
  Signal boxes   - SPIKE (cyan), REVERSAL (magenta), DECEL (orange)
  Conf bar       - Detection confidence (green=high, yellow=medium, orange=low)

ON FRAME:
  Green circle   - Ball position (size = confidence)
  Yellow arrow   - Velocity direction
  Green trail    - Recent ball trajectory
  RED border     - Frame detected as contact
  YELLOW border  - Known real contact frame (for comparison)
""")

# Display video in notebook (embedded as base64; the file handle is closed promptly).
with open(output_path, 'rb') as video_file:
    video_b64 = b64encode(video_file.read()).decode()
display(HTML(f'''
<video width="800" controls>
  <source src="data:video/mp4;base64,{video_b64}" type="video/mp4">
</video>
'''))

print(f"\nVideo saved to: {output_path}")
print("Downloading...")
colab_files.download(output_path)

In [None]:
#@title 4. Text Diagnostics - Frame-by-Frame Analysis (Optional)
# Print a frame-by-frame table of ball velocity and heuristic contact signals
# around a known contact frame, plus a summary of why it was (not) detected.
import numpy as np
import pandas as pd
from IPython.display import display

from src.contact_detection import compute_ball_velocity, debug_frame_region

#@markdown ### Diagnostic Settings
KNOWN_CONTACT_FRAME = 81  #@param {type:"integer"}
SHOW_ALL_FRAMES = False  #@param {type:"boolean"}
FOCUS_WINDOW = 15  #@param {type:"integer"}

if 'ANALYSIS_DATA' not in dir():
    raise ValueError("Please run cell 2 first!")


def _frame_signals(prev_v, v, ref_speed, spike_threshold):
    """Compute this table's heuristic contact signals for one velocity sample.

    NOTE(review): this is a local re-implementation for display; it is not
    guaranteed to match the scoring inside detect_contacts — confirm there.

    Args:
        prev_v: previous velocity sample dict ('vx', 'vy', 'speed') or None.
        v: current velocity sample dict with the same keys.
        ref_speed: reference (median) speed in px/s; must be > 0.
        spike_threshold: speed above which a sample counts as a spike.

    Returns:
        (is_spike, reversal_val, decel_val, conf). reversal_val is the
        negative cosine between consecutive velocities (0 unless they point
        in opposing directions), decel_val is the fractional speed drop
        (0 unless speed fell below half of an above-reference previous
        speed), and conf is their weighted sum capped at 0.7.
    """
    speed, vx, vy = v['speed'], v['vx'], v['vy']
    is_spike = speed > spike_threshold

    reversal_val = 0
    decel_val = 0
    if prev_v is not None:
        dot = prev_v['vx'] * vx + prev_v['vy'] * vy
        if dot < 0:
            mag0 = np.sqrt(prev_v['vx']**2 + prev_v['vy']**2)
            mag1 = np.sqrt(vx**2 + vy**2)
            cos_angle = dot / (mag0 * mag1 + 1e-6)  # epsilon avoids 0/0
            reversal_val = max(0, -cos_angle)
        prev_speed = prev_v['speed']
        if prev_speed > ref_speed and speed < prev_speed * 0.5:
            decel_val = 1 - (speed / prev_speed)

    conf = 0.0
    if reversal_val > 0:
        conf += 0.3 * (0.5 + 0.5 * reversal_val)
    if is_spike:
        conf += 0.2 * min((speed / ref_speed) / 2.0, 1.5)
    if decel_val > 0:
        conf += 0.2 * decel_val
    return is_spike, reversal_val, decel_val, min(conf, 0.7)


ball_detections = ANALYSIS_DATA['ball_detections']
trajectory = ANALYSIS_DATA['trajectory']
fps = ANALYSIS_DATA['fps']
contacts = ANALYSIS_DATA['contacts']
num_frames = len(ANALYSIS_DATA['frames'])

# Compute velocities; each sample is unpacked below as (frame, vx, vy, speed).
velocities = compute_ball_velocity(trajectory, fps)
vel_dict = {v[0]: {'vx': v[1], 'vy': v[2], 'speed': v[3]} for v in velocities}

# Reference speed for spike detection: median of non-zero speeds (100 px/s fallback).
speeds = np.array([v[3] for v in velocities])
nonzero_speeds = speeds[speeds > 1e-3]
ref_speed = np.median(nonzero_speeds) if len(nonzero_speeds) > 0 else 100
spike_threshold = ref_speed * 2.0

detection_pct = 100 * len(ball_detections) / num_frames if num_frames else 0.0

print("="*80)
print("DIAGNOSTIC: FRAME-BY-FRAME CONTACT DETECTION ANALYSIS")
print("="*80)
print(f"\nVideo: {num_frames} frames at {fps:.1f} fps")
print(f"Ball detected in: {len(ball_detections)} frames ({detection_pct:.1f}%)")
print(f"Velocity samples: {len(velocities)}")
print(f"\nReference speed (median): {ref_speed:.1f} px/s")
print(f"Spike threshold (2x median): {spike_threshold:.1f} px/s")

# Find detection gaps (runs of 3+ missing frames between the first and last detection)
detected_frames = set(ball_detections)
if detected_frames:
    min_f, max_f = min(detected_frames), max(detected_frames)
    gaps = []
    gap_start = None
    for f in range(min_f, max_f + 1):
        if f not in detected_frames:
            if gap_start is None:
                gap_start = f
        elif gap_start is not None:
            gap_len = f - gap_start
            if gap_len >= 3:
                gaps.append((gap_start, f - 1, gap_len))
            gap_start = None
    # No trailing-gap handling is needed: max_f is itself detected, so every
    # open gap is closed inside the loop.

    if gaps:
        print("\n" + "-"*40)
        print("DETECTION GAPS (3+ consecutive missing frames):")
        print("-"*40)
        for start, end, length in gaps:
            contains_known = start <= KNOWN_CONTACT_FRAME <= end
            marker = " <<<< KNOWN CONTACT IN THIS GAP!" if contains_known else ""
            print(f"  Frames {start:3d}-{end:3d}: {length} frames missing{marker}")

# Determine frame range
if SHOW_ALL_FRAMES:
    frame_range = range(num_frames)
else:
    frame_range = range(
        max(0, KNOWN_CONTACT_FRAME - FOCUS_WINDOW),
        min(num_frames, KNOWN_CONTACT_FRAME + FOCUS_WINDOW + 1)
    )

# Build frame-by-frame data
print("\n" + "="*80)
print("FRAME-BY-FRAME VELOCITY ANALYSIS")
if not SHOW_ALL_FRAMES:
    # Report the clamped range actually printed, not the raw window.
    print(f"(Showing frames {frame_range.start} to {frame_range.stop - 1})")
print("="*80)

print(f"\n{'Frame':>6} {'Time':>7} {'Ball?':>6} {'Speed':>10} {'Vx':>8} {'Vy':>8} {'Spike':>7} {'Reversal':>9} {'Decel':>7} {'Conf':>6} {'Notes'}")
print("-"*100)

# Seed with the last velocity sample before the window, so the first displayed
# frame gets the same reversal/decel values it would get with SHOW_ALL_FRAMES.
earlier_vel_frames = [vf for vf in vel_dict if vf < frame_range.start]
prev_vel = vel_dict[max(earlier_vel_frames)] if earlier_vel_frames else None

for f in frame_range:
    time_sec = f / fps

    # Ball detection status
    ball_detected = f in ball_detections
    ball_str = "YES" if ball_detected else "---"

    # Velocity data and derived signals
    v = vel_dict.get(f)
    if v is not None:
        speed, vx, vy = v['speed'], v['vx'], v['vy']
        is_spike, reversal_val, decel_val, conf = _frame_signals(
            prev_vel, v, ref_speed, spike_threshold
        )
        speed_str = f"{speed:8.1f}"
        vx_str = f"{vx:+7.1f}"
        vy_str = f"{vy:+7.1f}"
        spike_str = f"{speed/ref_speed:.1f}x" if is_spike else ""
        reversal_str = f"{reversal_val:.2f}" if reversal_val > 0 else ""
        decel_str = f"{decel_val:.2f}" if decel_val > 0 else ""
        conf_str = f"{conf:.2f}" if conf > 0.1 else ""
        # prev_vel tracks the last velocity sample seen, even across frames without one.
        prev_vel = v
    else:
        speed_str = vx_str = vy_str = "---"
        spike_str = reversal_str = decel_str = conf_str = ""

    # Notes
    notes = []
    if f == KNOWN_CONTACT_FRAME:
        notes.append("<<< KNOWN CONTACT")
    detected_contact = next((c for c in contacts if c['frame'] == f), None)
    if detected_contact is not None:
        notes.append(f"DETECTED (conf={detected_contact['confidence']:.2f})")
    if not ball_detected and f > 0 and (f-1) in ball_detections:
        notes.append("ball lost")
    if ball_detected and f > 0 and (f-1) not in ball_detections:
        notes.append("ball found")

    notes_str = " | ".join(notes)

    # Highlight important rows
    marker = ">>>" if f == KNOWN_CONTACT_FRAME else "   "

    print(f"{marker}{f:3d} {time_sec:7.2f}s {ball_str:>6} {speed_str:>10} {vx_str:>8} {vy_str:>8} {spike_str:>7} {reversal_str:>9} {decel_str:>7} {conf_str:>6}  {notes_str}")

print("-"*100)

# Summary
print("\n" + "="*80)
print("SUMMARY")
print("="*80)
print(f"\nKnown contact frame: {KNOWN_CONTACT_FRAME}")
print(f"  Ball detected at frame {KNOWN_CONTACT_FRAME}? {KNOWN_CONTACT_FRAME in ball_detections}")
print(f"  Velocity data at frame {KNOWN_CONTACT_FRAME}? {KNOWN_CONTACT_FRAME in vel_dict}")

# Was it detected?
known_contact = next((c for c in contacts if c['frame'] == KNOWN_CONTACT_FRAME), None)
detected_at_known = known_contact is not None
if detected_at_known:
    print(f"  Contact WAS detected at frame {KNOWN_CONTACT_FRAME}")
elif contacts:
    # Find closest detection
    closest = min(contacts, key=lambda c: abs(c['frame'] - KNOWN_CONTACT_FRAME))
    print(f"  Contact NOT detected at frame {KNOWN_CONTACT_FRAME}")
    print(f"  Closest detection: frame {closest['frame']} ({abs(closest['frame'] - KNOWN_CONTACT_FRAME)} frames away)")
else:
    print("  No contacts detected at all!")

# Check if frame is in a gap; only blame the gap when the contact was actually missed.
in_gap = KNOWN_CONTACT_FRAME not in ball_detections
if in_gap:
    print(f"\n  ** FRAME {KNOWN_CONTACT_FRAME} IS MISSING BALL DETECTION **")
    if detected_at_known:
        print(f"     The contact was still detected (source: {known_contact['source']}).")
    else:
        print("     This is likely why the contact was not detected.")
        print("     The ball is probably occluded by the racket at contact.")

In [None]:
#@title 5. Analyze Selected Contact
# Run pose estimation on one detected contact frame, draw the annotated frame,
# and compute/save contact-point spacing measurements.
import numpy as np
import pandas as pd
import cv2
from IPython.display import display, Image as IPImage
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os

from utils.coordinate_transforms import (
    pelvis_origin_transform, estimate_ground_plane, apply_ground_plane
)
from src.pose_estimation import PoseEstimator
from src.measurements import compute_measurements
from src.visualization import (
    draw_skeleton, draw_contact_point, save_annotated_frame
)

#@markdown ### Select which contact to analyze:
CONTACT_INDEX = 0  #@param {type:"integer"}

# Validate
if 'ANALYSIS_DATA' not in dir():
    raise ValueError("Please run cell 2 first to detect contacts!")

contacts = ANALYSIS_DATA['contacts']
if len(contacts) == 0:
    raise ValueError("No contacts were detected in the video.")

if CONTACT_INDEX < 0 or CONTACT_INDEX >= len(contacts):
    raise ValueError(f"Invalid contact index. Valid range: 0-{len(contacts)-1}")

contact = contacts[CONTACT_INDEX]
frame_num = contact['frame']
ball_pos = contact['ball_pos']
ball_method = contact['ball_method']
shot_type = ANALYSIS_DATA['shot_type']
frames = ANALYSIS_DATA['frames']
fps = ANALYSIS_DATA['fps']

print(f"Analyzing Contact {CONTACT_INDEX + 1}")
print(f"  Frame: {frame_num} ({contact['time']:.2f}s)")
print(f"  Detection confidence: {contact['confidence']:.0%}")
print(f"  Ball position method: {ball_method}")

# Get frame
frame = frames[frame_num]
h, w = frame.shape[:2]

# Determine which wrist to use based on shot type
if shot_type in ["right_forehand", "right_backhand"]:
    contact_wrist_name = "right_wrist"
    hand_label = "RIGHT"
else:
    contact_wrist_name = "left_wrist"
    hand_label = "LEFT"

# --- Pose estimation ---
print("\nEstimating pose...")
pose_estimator = PoseEstimator(static_image_mode=True, model_complexity=2)
try:
    landmarks, raw_result = pose_estimator.process_frame(frame)
    if landmarks is None:
        raise ValueError("No pose detected in frame. The player may not be clearly visible.")
    pixel_lm = pose_estimator.get_pixel_landmarks(raw_result, frame.shape)
finally:
    # Release the estimator on every path, including when processing raises.
    pose_estimator.close()
print("  Pose detected successfully!")

# --- Get contact point (ball position or fallback to wrist) ---
if ball_pos is not None:
    contact_pixel = ball_pos
    contact_source = f"ball ({ball_method})"
    print(f"  Using ball position as contact point: ({ball_pos[0]:.0f}, {ball_pos[1]:.0f})")
elif contact_wrist_name in pixel_lm:
    contact_pixel = pixel_lm[contact_wrist_name]
    contact_source = "wrist (fallback)"
    print("  Ball not detected - using wrist position as fallback")
else:
    raise ValueError("Neither ball nor wrist position available")

# --- Transform coordinates for measurements ---
pelvis = landmarks.get("pelvis", np.zeros(3))
centered = pelvis_origin_transform(landmarks)
ground_z = estimate_ground_plane(centered)
adjusted = apply_ground_plane(centered, ground_z)

# For contact point, we need to estimate 3D position from 2D ball detection
# Use wrist depth as approximation (ball is near wrist at contact)
wrist_3d = landmarks.get(contact_wrist_name, np.zeros(3))
contact_3d = wrist_3d.copy()  # Start with wrist position

# If we have ball pixel position, adjust x/y based on pixel offset from wrist
if ball_pos is not None and contact_wrist_name in pixel_lm:
    wrist_px = pixel_lm[contact_wrist_name]
    # Estimate scale from wrist (pixels to normalized coords)
    # This is approximate - proper depth estimation would need stereo or depth camera
    px_offset_x = ball_pos[0] - wrist_px[0]
    px_offset_y = ball_pos[1] - wrist_px[1]

    # Rough scaling (assume ~1000px corresponds to ~1 normalized unit)
    scale = 0.001
    contact_3d[0] += px_offset_x * scale
    contact_3d[1] += px_offset_y * scale

contact_adjusted = contact_3d - pelvis - np.array([0, 0, ground_z])

# --- Compute measurements ---
print("Computing measurements...")
meas = compute_measurements(adjusted, contact_adjusted)
meas["contact_source"] = contact_source
meas["ball_detection_method"] = ball_method
meas["shot_type"] = shot_type
meas["frame_num"] = frame_num
meas["contact_confidence"] = contact['confidence']

# --- Create output directory ---
output_dir = "/content/output"
os.makedirs(output_dir, exist_ok=True)

# --- 1. Annotated frame with skeleton + contact point ---
print("\n" + "="*60)
print("CONTACT FRAME ANALYSIS")
print("="*60)

annotated = frame.copy()
annotated = draw_skeleton(annotated, pixel_lm, thickness=3)

# Draw contact point (ball position or wrist)
cx, cy = int(contact_pixel[0]), int(contact_pixel[1])
draw_contact_point(annotated, cx, cy, radius=15)

# Add label
label = f"CONTACT ({contact_source})"
cv2.putText(annotated, label, (cx + 20, cy - 10),
            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

# Add frame info
info_text = f"Frame {frame_num} | {contact['time']:.2f}s | Conf: {contact['confidence']:.0%}"
cv2.putText(annotated, info_text, (20, 30),
            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

out_path = os.path.join(output_dir, f"contact_{CONTACT_INDEX+1}_frame_{frame_num}.png")
save_annotated_frame(annotated, out_path)

print("\nAnnotated contact frame:")
display(IPImage(filename=out_path, width=800))

# --- 2. Measurements ---
print("\n" + "="*60)
print("CONTACT POINT MEASUREMENTS")
print("="*60)
print(f"\nShot type: {shot_type}")
print(f"Contact source: {contact_source}")
print()
print(f"Lateral offset:      {meas.get('lateral_offset_cm', 0):>7.1f} cm  ({meas.get('lateral_offset_inches', 0):>5.1f} in)")
print("  (+ = to your left, - = to your right)")
print(f"Forward/back:        {meas.get('forward_back_cm', 0):>7.1f} cm  ({meas.get('forward_back_inches', 0):>5.1f} in)")
print("  (+ = in front, - = behind)")
print(f"Height above ground: {meas.get('height_above_ground_cm', 0):>7.1f} cm  ({meas.get('height_above_ground_inches', 0):>5.1f} in)")
if "shoulder_line_distance_cm" in meas:
    print(f"Shoulder line dist:  {meas.get('shoulder_line_distance_cm', 0):>7.1f} cm  ({meas.get('shoulder_line_distance_inches', 0):>5.1f} in)")
if "relative_to_shoulder_height_cm" in meas:
    print(f"vs Shoulder height:  {meas.get('relative_to_shoulder_height_cm', 0):>7.1f} cm  ({meas.get('relative_to_shoulder_height_inches', 0):>5.1f} in)")
    print("  (+ = above shoulder, - = below)")
print("="*60)

# Save measurements CSV
csv_path = os.path.join(output_dir, f"measurements_contact_{CONTACT_INDEX+1}.csv")
df = pd.DataFrame([meas])
df.to_csv(csv_path, index=False)
print(f"\nMeasurements saved to {csv_path}")

In [None]:
#@title 6. Batch Analysis - Analyze All Contacts
# Run pose estimation + contact-point measurements for every detected contact
# and save a combined CSV.
import numpy as np
import pandas as pd
import os
from IPython.display import display

from utils.coordinate_transforms import (
    pelvis_origin_transform, estimate_ground_plane, apply_ground_plane
)
from src.pose_estimation import PoseEstimator
from src.measurements import compute_measurements

#@markdown Run this cell to analyze ALL detected contacts at once.

if 'ANALYSIS_DATA' not in dir():
    raise ValueError("Please run cell 2 first!")

contacts = ANALYSIS_DATA['contacts']
frames = ANALYSIS_DATA['frames']
fps = ANALYSIS_DATA['fps']
shot_type = ANALYSIS_DATA['shot_type']
ball_detections = ANALYSIS_DATA['ball_detections']

# Columns shown in the on-screen summary (the CSV keeps every column).
SUMMARY_COLUMNS = [
    'contact_index', 'frame_num', 'time_sec', 'contact_confidence',
    'lateral_offset_cm', 'forward_back_cm', 'height_above_ground_cm',
    'contact_source',
]

if len(contacts) == 0:
    print("No contacts to analyze.")
else:
    print(f"Analyzing {len(contacts)} contacts...\n")

    all_measurements = []

    # Determine wrist based on shot type
    if shot_type in ["right_forehand", "right_backhand"]:
        contact_wrist_name = "right_wrist"
    else:
        contact_wrist_name = "left_wrist"

    pose_estimator = PoseEstimator(static_image_mode=True, model_complexity=2)
    try:
        for contact in contacts:
            idx = contact['index']
            frame_num = contact['frame']
            ball_pos = contact['ball_pos']
            ball_method = contact['ball_method']

            print(f"Contact {idx+1}: Frame {frame_num}...", end=" ")

            frame = frames[frame_num]
            landmarks, raw_result = pose_estimator.process_frame(frame)

            if landmarks is None:
                print("SKIPPED (no pose detected)")
                continue

            pixel_lm = pose_estimator.get_pixel_landmarks(raw_result, frame.shape)

            # Get contact point (ball position, else wrist fallback)
            if ball_pos is not None:
                contact_source = f"ball ({ball_method})"
            elif contact_wrist_name in pixel_lm:
                contact_source = "wrist (fallback)"
            else:
                print("SKIPPED (no contact point)")
                continue

            # Transform and measure
            pelvis = landmarks.get("pelvis", np.zeros(3))
            centered = pelvis_origin_transform(landmarks)
            ground_z = estimate_ground_plane(centered)
            adjusted = apply_ground_plane(centered, ground_z)

            # Approximate the 3D contact point from the wrist, shifted by the
            # ball's pixel offset (same rough 0.001 scale as cell 5).
            wrist_3d = landmarks.get(contact_wrist_name, np.zeros(3))
            contact_3d = wrist_3d.copy()

            if ball_pos is not None and contact_wrist_name in pixel_lm:
                wrist_px = pixel_lm[contact_wrist_name]
                px_offset_x = ball_pos[0] - wrist_px[0]
                px_offset_y = ball_pos[1] - wrist_px[1]
                scale = 0.001
                contact_3d[0] += px_offset_x * scale
                contact_3d[1] += px_offset_y * scale

            contact_adjusted = contact_3d - pelvis - np.array([0, 0, ground_z])
            meas = compute_measurements(adjusted, contact_adjusted)

            meas["contact_index"] = idx + 1
            meas["frame_num"] = frame_num
            meas["time_sec"] = contact['time']
            meas["contact_confidence"] = contact['confidence']
            meas["contact_source"] = contact_source
            meas["detection_source"] = contact['source']
            meas["shot_type"] = shot_type

            all_measurements.append(meas)
            print("OK")
    finally:
        # Release the estimator even if a frame raises mid-loop.
        pose_estimator.close()

    # Save combined results
    if all_measurements:
        output_dir = "/content/output"
        os.makedirs(output_dir, exist_ok=True)

        df = pd.DataFrame(all_measurements)
        csv_path = os.path.join(output_dir, "all_contacts_measurements.csv")
        df.to_csv(csv_path, index=False)

        print("\n" + "="*60)
        print("SUMMARY OF ALL CONTACTS")
        print("="*60)
        # Only show columns that exist: measurement keys may be absent (cell 5
        # reads them with .get defaults), and df[[missing]] raises KeyError.
        display(df[[col for col in SUMMARY_COLUMNS if col in df.columns]])

        print(f"\nResults saved to {csv_path}")
    else:
        print("\nNo contacts could be analyzed (see SKIPPED reasons above).")

In [None]:
#@title 7. Download All Results
# List and download every result file written to the shared output directory.
from google.colab import files as colab_files
import glob
import os

output_dir = "/content/output"
# Only regular files can be downloaded; skip sub-directories (e.g. the TrackNet
# debug folder cell 2 may create there). Sorted for a stable listing order.
all_files = sorted(
    path for path in glob.glob(os.path.join(output_dir, "*")) if os.path.isfile(path)
)

if not all_files:
    print(f"No result files found in {output_dir}. Run the analysis cells first.")
else:
    print("Files available for download:")
    for path in all_files:
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f"  {os.path.basename(path)} ({size_mb:.1f} MB)")

    print("\nDownloading...")
    for path in all_files:
        colab_files.download(path)