# Tennis Contact Point Detection System

A two-stage detection pipeline for analyzing tennis ball-racket contact in stroke videos.

## Pipeline Overview

### Stage 1: Temporal Detection (WHEN)
Determines the **frame number** where contact occurs.

**Primary**: Audio-based contact detection
- Transient peak analysis with bandpass filtering (1-4kHz)
- Isolates ball-string impact sound from background noise
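
The audio step above can be sketched in plain Python. This is an illustrative outline only, not the code behind `detect_contacts`: the function names, the 5 ms energy window, and the 8x noise-floor ratio are all assumptions chosen for the example. It band-passes the signal around 1-4 kHz with a single biquad section, then reports windows whose energy stands well above the median.

```python
import math

def bandpass_biquad(samples, fs, low_hz=1000.0, high_hz=4000.0):
    """Single band-pass biquad (Audio EQ Cookbook form) covering low_hz..high_hz."""
    f0 = math.sqrt(low_hz * high_hz)           # geometric centre of the band
    q = f0 / (high_hz - low_hz)                # wide band -> low Q
    w0 = 2.0 * math.pi * f0 / fs
    alpha = math.sin(w0) / (2.0 * q)
    a0 = 1.0 + alpha
    b0, b2 = alpha / a0, -alpha / a0           # b1 is zero for this filter
    a1, a2 = -2.0 * math.cos(w0) / a0, (1.0 - alpha) / a0
    out, x1, x2, y1, y2 = [], 0.0, 0.0, 0.0, 0.0
    for x in samples:
        y = b0 * x + b2 * x2 - a1 * y1 - a2 * y2
        out.append(y)
        x2, x1 = x1, x
        y2, y1 = y1, y
    return out

def find_transients(samples, fs, window_s=0.005, ratio=8.0):
    """Return times (s) of energy peaks that exceed ratio x the median window energy.

    window_s and ratio are hypothetical tuning values for this sketch.
    """
    filtered = bandpass_biquad(samples, fs)
    win = max(1, int(window_s * fs))
    energies = [
        sum(v * v for v in filtered[i:i + win]) / win
        for i in range(0, len(filtered) - win + 1, win)
    ]
    if not energies:
        return []
    floor = sorted(energies)[len(energies) // 2] + 1e-12   # median as noise floor
    peaks = []
    for k, e in enumerate(energies):
        left = energies[k - 1] if k > 0 else 0.0
        right = energies[k + 1] if k + 1 < len(energies) else 0.0
        if e > ratio * floor and e >= left and e > right:
            peaks.append((k * win + win / 2) / fs)          # window centre time
    return peaks

def make_demo_signal(fs=16000, hit_time=0.5):
    """One second of 100 Hz hum plus a short decaying 2.5 kHz burst at hit_time."""
    out = []
    for n in range(fs):
        t = n / fs
        s = 0.5 * math.sin(2 * math.pi * 100 * t)
        dt = t - hit_time
        if 0 <= dt < 0.01:
            s += math.sin(2 * math.pi * 2500 * dt) * math.exp(-dt / 0.003)
        out.append(s)
    return out
```

Calling `find_transients(make_demo_signal(), 16000)` on the synthetic clip should report a single peak close to the 0.5 s burst. Converting a peak time to a frame number is then `round(time * fps)`.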

**Secondary**: Visual verification
- Ball trajectory change detection (velocity reversal, speed spike, deceleration)
- Trajectory analysis from TrackNet ball tracking
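
A minimal sketch of the trajectory checks listed above, assuming the tracker output has been reduced to a `{frame: (x, y)}` dictionary. The function name and the `spike_ratio` threshold are hypothetical, and the repository's own logic may differ. The sketch compares the velocity before and after each tracked frame and labels reversals, speed spikes, and decelerations.

```python
import math

def detect_trajectory_events(track, fps, spike_ratio=1.8):
    """Flag frames where the ball's motion changes abruptly.

    track: dict mapping frame index -> (x, y) pixel position.
    Returns a list of (frame, event) with event in {"reversal", "spike", "decel"}.
    spike_ratio is an illustrative threshold, not a value from the pipeline.
    """
    events = []
    frames = sorted(track)
    prev_v = None
    for a, b in zip(frames, frames[1:]):
        dt = (b - a) / fps                      # handles gaps between tracked frames
        v = ((track[b][0] - track[a][0]) / dt,
             (track[b][1] - track[a][1]) / dt)
        if prev_v is not None:
            s_prev, s_now = math.hypot(*prev_v), math.hypot(*v)
            dot = prev_v[0] * v[0] + prev_v[1] * v[1]
            if dot < 0:                         # direction flipped by more than 90 degrees
                events.append((a, "reversal"))
            elif s_prev > 0 and s_now > spike_ratio * s_prev:
                events.append((a, "spike"))
            elif s_now > 0 and s_prev > spike_ratio * s_now:
                events.append((a, "decel"))
        prev_v = v
    return events
```

Each event is attributed to the frame where the incoming and outgoing velocity segments meet, which is the natural candidate for the contact frame.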

**Output**: Contact frame number with confidence score
- HIGH confidence: Audio and visual signals agree
- MEDIUM confidence: Single-source detection scoring at least 70%
- LOW confidence (<70%): Manual frame selection recommended (Cell 4)
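
The labelling rule can be illustrated as follows. The HIGH/MEDIUM/LOW assignment mirrors the check performed in Cell 2 (`source == 'both'`, then the 0.7 threshold). The matching step that pairs audio and visual candidates within a frame tolerance is an assumption made for this sketch, since the internals of `detect_contacts` are not shown here. `fuse_detections` and `tolerance` are hypothetical names.

```python
def fuse_detections(audio, visual, tolerance=2, low_threshold=0.7):
    """Merge audio and visual candidates and attach a confidence level.

    audio, visual: lists of (frame, confidence) tuples.
    tolerance: max frame gap for two candidates to count as the same contact
               (hypothetical value for this sketch).
    """
    results, used_visual = [], set()
    for a_frame, a_conf in audio:
        match = None
        for i, (v_frame, _) in enumerate(visual):
            if i not in used_visual and abs(v_frame - a_frame) <= tolerance:
                match = i
                break
        if match is not None:
            used_visual.add(match)
            conf = max(a_conf, visual[match][1])
            results.append({"frame": a_frame, "confidence": conf, "source": "both"})
        else:
            results.append({"frame": a_frame, "confidence": a_conf, "source": "audio"})
    # Visual candidates that no audio peak confirmed
    for i, (v_frame, v_conf) in enumerate(visual):
        if i not in used_visual:
            results.append({"frame": v_frame, "confidence": v_conf, "source": "visual"})
    # Same rule as Cell 2: agreement wins, otherwise compare against the threshold
    for r in results:
        if r["source"] == "both":
            r["level"] = "HIGH"
        elif r["confidence"] >= low_threshold:
            r["level"] = "MEDIUM"
        else:
            r["level"] = "LOW"
    return sorted(results, key=lambda r: r["frame"])
```

For example, `fuse_detections([(120, 0.8), (300, 0.5)], [(121, 0.6)])` pairs frames 120 and 121 into one HIGH contact and leaves frame 300 as a LOW, audio-only candidate.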

### Stage 2: Spatial Localization (WHERE)
Determines the **3D position** of contact relative to the player's body.

**Process**:
1. MediaPipe Pose extracts body keypoints (wrist, elbow, shoulder, hip, ankle)
2. Computer-vision techniques localize the contact point (ball position, estimated racket center)
3. Transform the contact point into a body-relative, pelvis-centered coordinate system
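
The pixel-to-metric step used later in Cell 5 can be summarized with a small standalone sketch. It assumes 2D landmarks in pixels and 3D landmarks in metres, and uses shoulder width as the ruler, as the notebook does. The function name is hypothetical.

```python
import math

def pixel_offset_to_metres(ball_px, wrist_px, l_sh_px, r_sh_px, l_sh_3d, r_sh_3d):
    """Convert the wrist-to-ball pixel offset into metres.

    The scale is the 3D shoulder width divided by the shoulder width in pixels,
    so it is only as accurate as the pose estimate of both shoulders.
    """
    px_width = math.dist(l_sh_px, r_sh_px)
    if px_width == 0:
        raise ValueError("shoulder landmarks coincide; cannot derive a scale")
    scale = math.dist(l_sh_3d, r_sh_3d) / px_width      # metres per pixel
    return ((ball_px[0] - wrist_px[0]) * scale,
            (ball_px[1] - wrist_px[1]) * scale)
```

The returned offsets are added to the wrist's 3D x and y coordinates before subtracting the pelvis position. One caveat worth keeping in mind: when the player stands side-on to the camera, the shoulders are foreshortened in the image, so this scale will tend to overestimate distances.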

**Output**:
- 3D contact coordinates (pelvis-centered)
- Continuous measurements: height (cm), forward/back (cm), lateral (cm)
- Categorical labels: low/hip-level/waist-level/chest-level/high, behind/neutral/forward
- Position relative to shoulder and hip landmarks
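
To show how continuous measurements could map onto the categorical labels, here is a sketch that bins height by where the contact falls between hip and shoulder, and bins forward/back distance with a dead band around zero. All cutoffs below are hypothetical values chosen for illustration. The thresholds actually used by `compute_measurements` are not shown in this notebook and may differ.

```python
def height_category(contact_cm, hip_cm, shoulder_cm):
    """Label contact height; all inputs are heights above ground in cm."""
    torso = shoulder_cm - hip_cm
    if torso <= 0:
        raise ValueError("shoulder must be above hip")
    rel = (contact_cm - hip_cm) / torso      # 0.0 at the hip, 1.0 at the shoulder
    # Illustrative cutoffs only
    if rel < -0.25:
        return "low"
    if rel < 0.25:
        return "hip-level"
    if rel < 0.6:
        return "waist-level"
    if rel <= 1.0:
        return "chest-level"
    return "high"

def forward_category(forward_cm, dead_band_cm=10.0):
    """Label forward/back offset from the pelvis; positive means in front."""
    if forward_cm > dead_band_cm:
        return "forward"
    if forward_cm < -dead_band_cm:
        return "behind"
    return "neutral"
```

Normalizing by torso length rather than using fixed centimetre bands keeps the labels comparable across players of different heights.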

---

## Cells

1. **Setup** - Install dependencies
2. **Upload & Stage 1 Detection** - Temporal contact detection (audio + visual)
3. **Review Detections** - Diagnostic video and confidence check
4. **Manual Frame Selection** - Fallback for low confidence contacts
5. **Stage 2: Spatial Localization** - Analyze selected contact
6. **Batch Analysis** - Analyze all contacts with full measurements
7. **Download Results** - Download all generated output files

In [None]:
#@title 1. Setup - Install Dependencies
import os, sys, shutil

REPO_URL = "https://github.com/xiaoxiang-ma/tennis_contact_point_spacing.git"
REPO_DIR = "/content/tennis_contact_point_spacing"

# Always re-clone to ensure latest code
if os.path.exists(REPO_DIR):
    shutil.rmtree(REPO_DIR)
!git clone {REPO_URL} {REPO_DIR}

!pip install -q -r {REPO_DIR}/requirements.txt

# Install ipywidgets for manual selection interface
!pip install -q ipywidgets

# Clear any cached module imports from previous runs
for mod_name in list(sys.modules.keys()):
    if mod_name.startswith(("src.", "utils.")):
        del sys.modules[mod_name]

if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

# Create weights directory
os.makedirs(os.path.join(REPO_DIR, "weights"), exist_ok=True)

print("\n" + "="*60)
print("SETUP COMPLETE")
print("="*60)
print("\nTwo-Stage Contact Detection Pipeline Ready:")
print("  Stage 1: Temporal Detection (audio + visual)")
print("  Stage 2: Spatial Localization (pose + CV)")
print("\nTrackNet weights will download automatically on first use (~50MB)")

In [None]:
#@title 2. Upload Video & Stage 1: Temporal Detection
import numpy as np
import cv2
from google.colab import files
from IPython.display import display, Image as IPImage, HTML
import os

from utils.video_io import load_video
from src.tracknet import TrackNetDetector, TrackNetV4Detector
from src.contact_detection import detect_contacts, get_contact_ball_position

#@markdown ## Shot Settings
SHOT_TYPE = "right_forehand"  #@param ["right_forehand", "right_backhand", "left_forehand", "left_backhand"]

#@markdown ## Stage 1: Temporal Detection Settings
#@markdown ---
#@markdown ### Primary: Audio Detection
#@markdown Audio-based detection uses transient peak analysis with bandpass filtering (1-4kHz)
USE_AUDIO = True  #@param {type:"boolean"}

#@markdown ### Secondary: Visual Detection (TrackNet)
#@markdown Ball trajectory analysis for velocity reversal/spike detection
CONFIDENCE_THRESHOLD = 0.3  #@param {type:"slider", min:0.1, max:0.9, step:0.1}
FILL_GAPS = True  #@param {type:"boolean"}
MAX_GAP_FRAMES = 10  #@param {type:"slider", min:3, max:20, step:1}

#@markdown ### Motion Attention (TrackNetV4)
USE_MOTION_ATTENTION = True  #@param {type:"boolean"}
MOTION_BOOST = 2.0  #@param {type:"slider", min:1.0, max:4.0, step:0.5}
MOTION_MIN_ATTENTION = 0.3  #@param {type:"slider", min:0.1, max:0.5, step:0.1}

#@markdown ### Debug Options
DEBUG_MODE = True  #@param {type:"boolean"}
SAVE_DEBUG_FRAMES = False  #@param {type:"boolean"}

# Create output directory
output_dir = "/content/output"
os.makedirs(output_dir, exist_ok=True)

# --- Upload video ---
print("="*60)
print("STAGE 1: TEMPORAL DETECTION (When does contact occur?)")
print("="*60)
print("\nUpload your tennis video (MP4, MOV, etc.):")
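uploaded = files.upload()
if not uploaded:
    raise ValueError("No file uploaded. Re-run this cell and choose a video.")
video_filename = list(uploaded.keys())[0]
video_path = os.path.join("/content", video_filename)
# files.upload() returns the file bytes; write them to a known path
with open(video_path, "wb") as f:
    f.write(uploaded[video_filename])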

# --- Load video ---
print(f"\nLoading video: {video_filename}")
frames, metadata = load_video(video_path)
if len(frames) == 0:
    raise ValueError(f"No frames could be read from {video_filename}")
fps = metadata["fps"]
print(f"  Resolution: {metadata['width']}x{metadata['height']}")
print(f"  Frame rate: {fps:.1f} fps")
print(f"  Duration: {metadata['duration_sec']:.2f}s ({len(frames)} frames)")

# --- Initialize TrackNet ---
print("\n" + "-"*60)
print("Initializing ball tracking (TrackNet)...")
if USE_MOTION_ATTENTION:
    print("  Using TrackNetV4 with motion attention")
    print(f"  Motion boost: {MOTION_BOOST}x, Min attention: {MOTION_MIN_ATTENTION}")
    tracknet = TrackNetV4Detector(
        weights_path=None,
        device=None,
        confidence_threshold=CONFIDENCE_THRESHOLD,
        save_debug_frames=SAVE_DEBUG_FRAMES,
        debug_output_dir=os.path.join(output_dir, "tracknet_debug"),
        motion_boost=MOTION_BOOST,
        motion_min_attention=MOTION_MIN_ATTENTION,
        use_motion_attention=True,
    )
else:
    print("  Using TrackNet (standard)")
    tracknet = TrackNetDetector(
        weights_path=None,
        device=None,
        confidence_threshold=CONFIDENCE_THRESHOLD,
        save_debug_frames=SAVE_DEBUG_FRAMES,
        debug_output_dir=os.path.join(output_dir, "tracknet_debug"),
    )

if not tracknet.weights_loaded:
    print("\n⚠️  WARNING: TrackNet weights failed to load!")
    print("    Ball detection may not work properly.")

# --- Run Stage 1: Temporal Detection ---
print("\n" + "-"*60)
print("Running temporal detection...")
print("  Primary: Audio analysis (1-4kHz transient detection)")
print("  Secondary: Trajectory analysis (velocity reversal/spike)")
if FILL_GAPS:
    print(f"  Gap filling: enabled (max {MAX_GAP_FRAMES} frames)")

contacts, ball_detections = detect_contacts(
    video_path=video_path,
    frames=frames,
    fps=fps,
    tracknet_detector=tracknet,
    use_audio=USE_AUDIO,
    fill_gaps=FILL_GAPS,
    max_gap_frames=MAX_GAP_FRAMES,
    debug=DEBUG_MODE,
    save_debug_frames=SAVE_DEBUG_FRAMES,
)

# Build trajectory for diagnostics
trajectory = tracknet.get_ball_trajectory(ball_detections)

# --- Display Stage 1 Results ---
print("\n" + "="*60)
print("STAGE 1 RESULTS: TEMPORAL DETECTION")
print("="*60)
print(f"\nBall tracking: {len(ball_detections)}/{len(frames)} frames ({100*len(ball_detections)/len(frames):.1f}%)")

print(f"\nDetected {len(contacts)} contact(s):")
print("-"*60)

# Confidence threshold for manual selection recommendation
LOW_CONFIDENCE_THRESHOLD = 0.7
needs_manual = False

contact_info = []
for i, (frame_num, confidence, source) in enumerate(contacts):
    time_sec = frame_num / fps
    ball_pos, ball_method = get_contact_ball_position(frame_num, ball_detections)

    # Determine confidence level
    if source == 'both':
        conf_level = "HIGH (audio+visual)"
    elif confidence >= LOW_CONFIDENCE_THRESHOLD:
        conf_level = "MEDIUM"
    else:
        conf_level = "LOW ⚠️"
        needs_manual = True

    contact_info.append({
        'index': i,
        'frame': frame_num,
        'time': time_sec,
        'confidence': confidence,
        'confidence_level': conf_level,
        'source': source,
        'ball_pos': ball_pos,
        'ball_method': ball_method,
    })

    print(f"\n  Contact {i+1}: Frame {frame_num} ({time_sec:.2f}s)")
    print(f"    Confidence: {confidence:.0%} - {conf_level}")
    print(f"    Source: {source}")
    if ball_pos:
        print(f"    Ball position: ({ball_pos[0]:.0f}, {ball_pos[1]:.0f}) [{ball_method}]")

if not contacts:
    print("\n  ⚠️  No contacts detected automatically.")
    print("      Please use manual frame selection (Cell 4)")
    needs_manual = True

# Store for subsequent cells
ANALYSIS_DATA = {
    'frames': frames,
    'fps': fps,
    'metadata': metadata,
    'contacts': contact_info,
    'contacts_raw': contacts,
    'ball_detections': ball_detections,
    'trajectory': trajectory,
    'shot_type': SHOT_TYPE,
    'video_path': video_path,
    'motion_attention_enabled': USE_MOTION_ATTENTION,
    'needs_manual_selection': needs_manual,
}

print("\n" + "="*60)
if needs_manual:
    print("⚠️  LOW CONFIDENCE DETECTED - Manual review recommended")
    print("   Run Cell 3 for diagnostic video, then Cell 4 for manual selection")
else:
    print("✓ Detection complete - Run Cell 3 to verify with diagnostic video")
print("="*60)

In [None]:
#@title 3. Review Detections - Diagnostic Video
import os
from IPython.display import display, HTML
from google.colab import files as colab_files

from src.visualization import create_diagnostic_video

#@markdown ### Diagnostic Settings
#@markdown Set a known real contact frame (if you know it) to compare with detection:
KNOWN_CONTACT_FRAME = 0  #@param {type:"integer"}
SHOW_TRAJECTORY_TRAIL = True  #@param {type:"boolean"}
TRAJECTORY_TAIL_FRAMES = 30  #@param {type:"integer"}

if 'ANALYSIS_DATA' not in dir():
    raise ValueError("Please run Cell 2 first!")

output_dir = "/content/output"
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "diagnostic_video.mp4")

print("="*60)
print("DETECTION VERIFICATION")
print("="*60)
print("\nGenerating diagnostic video with overlays...")
print("  - Ball position (green circle)")
print("  - Velocity vector (yellow arrow)")
print("  - Detection signals: SPIKE (cyan), REVERSAL (magenta), DECEL (orange)")
print("  - Detected contacts (red border)")
if KNOWN_CONTACT_FRAME > 0:
    print(f"  - Known contact frame {KNOWN_CONTACT_FRAME} (yellow border)")

create_diagnostic_video(
    frames=ANALYSIS_DATA['frames'],
    ball_detections=ANALYSIS_DATA['ball_detections'],
    contacts=ANALYSIS_DATA['contacts_raw'],
    fps=ANALYSIS_DATA['fps'],
    output_path=output_path,
    known_contact_frame=KNOWN_CONTACT_FRAME if KNOWN_CONTACT_FRAME > 0 else None,
    show_trajectory=SHOW_TRAJECTORY_TRAIL,
    trajectory_tail=TRAJECTORY_TAIL_FRAMES,
)

print("\n" + "="*60)
print("VIDEO OVERLAY LEGEND")
print("="*60)
print("""
TOP-LEFT PANEL:
  Frame/Time     - Current frame number and timestamp
  Ball           - DETECTED (green) or MISSING (red)
  Speed          - Current ball speed in px/s
  Signal boxes   - SPIKE (cyan), REVERSAL (magenta), DECEL (orange)
  Conf bar       - Detection confidence (green=high, yellow=medium, orange=low)

ON FRAME:
  Green circle   - Ball position (size = confidence)
  Yellow arrow   - Velocity direction
  Green trail    - Recent ball trajectory
  RED border     - Frame detected as contact
  YELLOW border  - Known real contact frame (if set)
""")

# Display video
from base64 import b64encode
# Read inside a context manager so the file handle is closed promptly
with open(output_path, 'rb') as video_file:
    video_b64 = b64encode(video_file.read()).decode()
display(HTML(f'''
<video width="800" controls>
  <source src="data:video/mp4;base64,{video_b64}" type="video/mp4">
</video>
'''))

# Show detection summary
print("\n" + "="*60)
print("DETECTION SUMMARY")
print("="*60)
contacts = ANALYSIS_DATA['contacts']
if contacts:
    for c in contacts:
        print(f"\n  Contact {c['index']+1}: Frame {c['frame']} ({c['time']:.2f}s)")
        print(f"    Confidence: {c['confidence']:.0%} - {c['confidence_level']}")
        print(f"    Source: {c['source']}")
else:
    print("\n  No contacts detected.")

# Recommendation
print("\n" + "-"*60)
if ANALYSIS_DATA.get('needs_manual_selection', False):
    print("⚠️  RECOMMENDATION: Use Cell 4 for manual frame selection")
    print("   Some detections have low confidence (<70%)")
else:
    print("✓ Detections look good - proceed to Stage 2 (Cell 5)")

print(f"\nVideo saved to: {output_path}")
print("Downloading...")
colab_files.download(output_path)

In [None]:
#@title 4. Manual Frame Selection (Fallback for Low Confidence)
import numpy as np
import cv2
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
from base64 import b64encode
import os

#@markdown ### Manual Selection Options
#@markdown Use this cell when automatic detection has low confidence (<70%)
#@markdown or when no contacts were detected.

#@markdown Enter the contact frame number manually:
MANUAL_CONTACT_FRAME = 0  #@param {type:"integer"}

#@markdown Or use the interactive selector below:
USE_INTERACTIVE_SELECTOR = True  #@param {type:"boolean"}

if 'ANALYSIS_DATA' not in dir():
    raise ValueError("Please run Cell 2 first!")

frames = ANALYSIS_DATA['frames']
fps = ANALYSIS_DATA['fps']
contacts = ANALYSIS_DATA['contacts']
ball_detections = ANALYSIS_DATA['ball_detections']

print("="*60)
print("MANUAL FRAME SELECTION")
print("="*60)

if contacts:
    print("\nCurrently detected contacts:")
    for c in contacts:
        status = "✓" if c['confidence'] >= 0.7 else "⚠️ LOW"
        print(f"  {status} Contact {c['index']+1}: Frame {c['frame']} ({c['time']:.2f}s) - {c['confidence']:.0%}")
else:
    print("\nNo contacts were automatically detected.")

print("\n" + "-"*60)

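if MANUAL_CONTACT_FRAME > 0:
    # Use the manually specified frame, after checking it exists in the video
    manual_frame = MANUAL_CONTACT_FRAME
    if manual_frame >= len(frames):
        raise ValueError(f"MANUAL_CONTACT_FRAME must be between 1 and {len(frames) - 1}")
    print(f"\nUsing manually specified frame: {manual_frame}")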

    # Get ball position at this frame
    from src.contact_detection import get_contact_ball_position
    ball_pos, ball_method = get_contact_ball_position(manual_frame, ball_detections)

    # Add or replace contact
    new_contact = {
        'index': 0,
        'frame': manual_frame,
        'time': manual_frame / fps,
        'confidence': 1.0,
        'confidence_level': 'MANUAL',
        'source': 'manual',
        'ball_pos': ball_pos,
        'ball_method': ball_method,
    }

    # Update ANALYSIS_DATA
    ANALYSIS_DATA['contacts'] = [new_contact]
    ANALYSIS_DATA['contacts_raw'] = [(manual_frame, 1.0, 'manual')]
    ANALYSIS_DATA['needs_manual_selection'] = False

    print(f"  Time: {manual_frame / fps:.2f}s")
    if ball_pos:
        print(f"  Ball position: ({ball_pos[0]:.0f}, {ball_pos[1]:.0f}) [{ball_method}]")

    # Show the frame
    frame = frames[manual_frame].copy()
    h, w = frame.shape[:2]
    cv2.putText(frame, f"MANUAL CONTACT - Frame {manual_frame}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 3)
    cv2.putText(frame, f"MANUAL CONTACT - Frame {manual_frame}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)

    if ball_pos:
        cv2.circle(frame, (int(ball_pos[0]), int(ball_pos[1])), 15, (0, 255, 0), 2)
        cv2.circle(frame, (int(ball_pos[0]), int(ball_pos[1])), 5, (0, 255, 0), -1)

    _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 90])
    b64 = b64encode(buffer).decode()
    display(HTML(f'<img src="data:image/jpeg;base64,{b64}" width="{min(w, 800)}"/>'))

    print("\n✓ Manual contact frame set. Proceed to Stage 2 (Cell 5)")

elif USE_INTERACTIVE_SELECTOR:
    # Interactive frame browser
    print("\nInteractive Frame Browser")
    print("Use the slider to find the exact contact frame.")
    print("-"*60)

    # State for selected frame
    selected_frame = {'value': contacts[0]['frame'] if contacts else len(frames) // 2}

    # Create widgets
    frame_slider = widgets.IntSlider(
        value=selected_frame['value'],
        min=0,
        max=len(frames) - 1,
        step=1,
        description='Frame:',
        continuous_update=False,
        layout=widgets.Layout(width='80%'),
    )

    step_buttons = widgets.HBox([
        widgets.Button(description='<< -10', layout=widgets.Layout(width='80px')),
        widgets.Button(description='< -1', layout=widgets.Layout(width='70px')),
        widgets.Button(description='+1 >', layout=widgets.Layout(width='70px')),
        widgets.Button(description='+10 >>', layout=widgets.Layout(width='80px')),
    ])

    select_button = widgets.Button(
        description='✓ Select This Frame as Contact',
        button_style='success',
        layout=widgets.Layout(width='300px'),
    )

    output = widgets.Output()

    def update_display(frame_num):
        frame = frames[frame_num].copy()
        h, w = frame.shape[:2]
        time_sec = frame_num / fps

        # Add frame info
        info_text = f"Frame {frame_num}/{len(frames)-1} | {time_sec:.2f}s"
        cv2.putText(frame, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 3)
        cv2.putText(frame, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

        # Mark if this is a detected contact
        for c in contacts:
            if c['frame'] == frame_num:
                marker = f"DETECTED: {c['confidence']:.0%} confidence"
                color = (0, 255, 0) if c['confidence'] >= 0.7 else (0, 165, 255)
                cv2.putText(frame, marker, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 3)
                cv2.putText(frame, marker, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

        # Show ball if detected
        if frame_num in ball_detections:
            bx, by, bconf = ball_detections[frame_num]
            cv2.circle(frame, (int(bx), int(by)), 15, (0, 255, 0), 2)
            cv2.circle(frame, (int(bx), int(by)), 5, (0, 255, 0), -1)

        _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
        b64 = b64encode(buffer).decode()

        with output:
            clear_output(wait=True)
            display(HTML(f'<img src="data:image/jpeg;base64,{b64}" width="{min(w, 800)}"/>'))

            # Show suggested contacts
            if contacts:
                print("\nSuggested contacts from automatic detection:")
                for c in contacts:
                    status = "✓" if c['confidence'] >= 0.7 else "⚠️"
                    print(f"  {status} Frame {c['frame']} ({c['time']:.2f}s): {c['confidence']:.0%}")

    def on_slider_change(change):
        update_display(change['new'])

    def step(delta):
        new_val = max(0, min(len(frames) - 1, frame_slider.value + delta))
        frame_slider.value = new_val

    def on_select(button):
        selected_frame['value'] = frame_slider.value
        frame_num = frame_slider.value

        # Get ball position
        from src.contact_detection import get_contact_ball_position
        ball_pos, ball_method = get_contact_ball_position(frame_num, ball_detections)

        # Update ANALYSIS_DATA
        new_contact = {
            'index': 0,
            'frame': frame_num,
            'time': frame_num / fps,
            'confidence': 1.0,
            'confidence_level': 'MANUAL',
            'source': 'manual',
            'ball_pos': ball_pos,
            'ball_method': ball_method,
        }
        ANALYSIS_DATA['contacts'] = [new_contact]
        ANALYSIS_DATA['contacts_raw'] = [(frame_num, 1.0, 'manual')]
        ANALYSIS_DATA['needs_manual_selection'] = False

        with output:
            clear_output(wait=True)
            print(f"✓ Selected Frame {frame_num} as contact")
            print(f"  Time: {frame_num / fps:.2f}s")
            if ball_pos:
                print(f"  Ball position: ({ball_pos[0]:.0f}, {ball_pos[1]:.0f})")
            print("\n✓ Proceed to Stage 2 (Cell 5)")

    frame_slider.observe(on_slider_change, names='value')
    step_buttons.children[0].on_click(lambda b: step(-10))
    step_buttons.children[1].on_click(lambda b: step(-1))
    step_buttons.children[2].on_click(lambda b: step(1))
    step_buttons.children[3].on_click(lambda b: step(10))
    select_button.on_click(on_select)

    # Display widgets
    display(widgets.VBox([
        widgets.HTML("<p><b>Instructions:</b> Use slider or buttons to find the frame where ball hits racket.</p>"),
        frame_slider,
        step_buttons,
        select_button,
        output,
    ]))

    update_display(frame_slider.value)

else:
    print("\nTo manually select a contact frame:")
    print("  1. Set MANUAL_CONTACT_FRAME to the frame number, OR")
    print("  2. Enable USE_INTERACTIVE_SELECTOR and use the slider")
    print("\nTip: Watch the diagnostic video (Cell 3) to identify the contact frame.")

In [None]:
#@title 5. Stage 2: Spatial Localization (Where is contact?)
import numpy as np
import pandas as pd
import cv2
from IPython.display import display, Image as IPImage, HTML
import os
from base64 import b64encode

from utils.coordinate_transforms import (
    pelvis_origin_transform, estimate_ground_plane, apply_ground_plane
)
from src.pose_estimation import PoseEstimator
from src.measurements import compute_measurements, compute_relative_to_landmarks
from src.spatial_localization import localize_contact_point
from src.visualization import (
    draw_skeleton, draw_contact_point, save_annotated_frame
)

#@markdown ### Select Contact to Analyze
CONTACT_INDEX = 0  #@param {type:"integer"}

if 'ANALYSIS_DATA' not in dir():
    raise ValueError("Please run Cell 2 first!")

contacts = ANALYSIS_DATA['contacts']
if len(contacts) == 0:
    raise ValueError("No contacts detected. Please use Cell 4 for manual selection.")

if CONTACT_INDEX < 0 or CONTACT_INDEX >= len(contacts):
    raise ValueError(f"Invalid contact index. Valid range: 0-{len(contacts)-1}")

contact = contacts[CONTACT_INDEX]
frame_num = contact['frame']
ball_pos = contact['ball_pos']
ball_method = contact['ball_method']
shot_type = ANALYSIS_DATA['shot_type']
frames = ANALYSIS_DATA['frames']
fps = ANALYSIS_DATA['fps']

print("="*60)
print("STAGE 2: SPATIAL LOCALIZATION (Where is contact?)")
print("="*60)
print(f"\nAnalyzing Contact {CONTACT_INDEX + 1}")
print(f"  Frame: {frame_num} ({contact['time']:.2f}s)")
print(f"  Detection: {contact['confidence']:.0%} confidence ({contact['source']})")
print(f"  Shot type: {shot_type}")

# Get frame
frame = frames[frame_num]
h, w = frame.shape[:2]

# Determine contact wrist
if shot_type in ["right_forehand", "right_backhand"]:
    contact_wrist_name = "right_wrist"
else:
    contact_wrist_name = "left_wrist"

# --- Stage 2a: Pose Estimation ---
print("\n" + "-"*60)
print("Step 1: Pose Estimation (MediaPipe)")
pose_estimator = PoseEstimator(static_image_mode=True, model_complexity=2)
landmarks, raw_result = pose_estimator.process_frame(frame)

if landmarks is None:
    pose_estimator.close()
    raise ValueError("No pose detected. Player may not be visible.")

pixel_lm = pose_estimator.get_pixel_landmarks(raw_result, frame.shape)
pose_estimator.close()
print("  ✓ Pose detected successfully")
print(f"  Detected {len(landmarks)} body landmarks")

# --- Stage 2b: Contact Point Localization ---
print("\n" + "-"*60)
print("Step 2: Contact Point Localization")

loc_result = localize_contact_point(
    frame=frame,
    ball_position=ball_pos,
    pixel_landmarks=pixel_lm,
    landmarks_3d=landmarks,
    shot_type=shot_type,
)

contact_pixel = loc_result['contact_pixel']
contact_method = loc_result['method']
loc_confidence = loc_result['confidence']

print(f"  Method: {contact_method}")
print(f"  Localization confidence: {loc_confidence:.0%}")
if contact_pixel is not None:
    print(f"  Contact position (pixels): ({contact_pixel[0]}, {contact_pixel[1]})")
if loc_result.get('racket_center') is not None:
    print(f"  Estimated racket center: {loc_result['racket_center']}")

# --- Stage 2c: 3D Coordinate Transform ---
print("\n" + "-"*60)
print("Step 3: Body-Relative 3D Coordinates")

# Transform to pelvis-centered coordinates
pelvis = landmarks.get("pelvis", np.zeros(3))
centered = pelvis_origin_transform(landmarks)
ground_z = estimate_ground_plane(centered)
adjusted = apply_ground_plane(centered, ground_z)

# Get wrist 3D position as base for contact point
wrist_3d = landmarks.get(contact_wrist_name, np.zeros(3))
contact_3d = wrist_3d.copy()

# Adjust contact 3D based on pixel offset from wrist to ball
if ball_pos is not None and contact_wrist_name in pixel_lm:
    wrist_px = pixel_lm[contact_wrist_name]
    px_offset_x = ball_pos[0] - wrist_px[0]
    px_offset_y = ball_pos[1] - wrist_px[1]

    # Estimate scale factor (metres per pixel) from shoulder width
    left_shoulder = pixel_lm.get("left_shoulder")
    right_shoulder = pixel_lm.get("right_shoulder")
    if left_shoulder is not None and right_shoulder is not None:
        px_shoulder_width = np.sqrt(
            (right_shoulder[0] - left_shoulder[0])**2 +
            (right_shoulder[1] - left_shoulder[1])**2
        )
        real_shoulder_width = np.linalg.norm(
            landmarks.get("right_shoulder", np.zeros(3)) -
            landmarks.get("left_shoulder", np.zeros(3))
        )
        scale = real_shoulder_width / (px_shoulder_width + 1e-6)
    else:
        scale = 0.001

    contact_3d[0] += px_offset_x * scale
    contact_3d[1] += px_offset_y * scale

contact_adjusted = contact_3d - pelvis - np.array([0, 0, ground_z])

print(f"  Contact 3D (pelvis-relative): ({contact_adjusted[0]*100:.1f}, {contact_adjusted[1]*100:.1f}, {contact_adjusted[2]*100:.1f}) cm")

# --- Stage 2d: Compute Measurements ---
print("\n" + "-"*60)
print("Step 4: Computing Measurements")

meas = compute_measurements(adjusted, contact_adjusted)
rel_meas = compute_relative_to_landmarks(landmarks, contact_3d)
meas.update(rel_meas)

# Add metadata
meas["contact_source"] = f"{contact['source']} + {contact_method}"
meas["detection_confidence"] = contact['confidence']
meas["localization_confidence"] = loc_confidence
meas["shot_type"] = shot_type
meas["frame_num"] = frame_num
meas["time_sec"] = contact['time']

# --- Create annotated frame ---
output_dir = "/content/output"
os.makedirs(output_dir, exist_ok=True)

annotated = frame.copy()
annotated = draw_skeleton(annotated, pixel_lm, thickness=3)

if contact_pixel is not None:
    cx, cy = int(contact_pixel[0]), int(contact_pixel[1])
    draw_contact_point(annotated, cx, cy, radius=15)
    label = f"CONTACT ({contact_method})"
    cv2.putText(annotated, label, (cx + 20, cy - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

# Add frame info
info_text = f"Frame {frame_num} | {contact['time']:.2f}s | {contact['confidence']:.0%} conf"
cv2.putText(annotated, info_text, (20, 30),
            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 3)
cv2.putText(annotated, info_text, (20, 30),
            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

out_path = os.path.join(output_dir, f"contact_{CONTACT_INDEX+1}_frame_{frame_num}.png")
save_annotated_frame(annotated, out_path)

# --- Display Results ---
print("\n" + "="*60)
print("SPATIAL LOCALIZATION RESULTS")
print("="*60)

# Show annotated frame
_, buffer = cv2.imencode('.jpg', annotated, [cv2.IMWRITE_JPEG_QUALITY, 90])
b64 = b64encode(buffer).decode()
display(HTML(f'<img src="data:image/jpeg;base64,{b64}" width="{min(w, 800)}"/>'))

print(f"\n{'='*60}")
print("CONTACT POINT MEASUREMENTS")
print(f"{'='*60}")
print(f"\nShot type: {shot_type}")
print(f"Contact source: {meas['contact_source']}")

print(f"\n--- POSITION (relative to pelvis) ---")
print(f"  Lateral offset:      {meas.get('lateral_offset_cm', 0):>7.1f} cm  ({meas.get('lateral_offset_inches', 0):>5.1f} in)")
print(f"  Forward/back:        {meas.get('forward_back_cm', 0):>7.1f} cm  ({meas.get('forward_back_inches', 0):>5.1f} in)")
print(f"  Height above ground: {meas.get('height_above_ground_cm', 0):>7.1f} cm  ({meas.get('height_above_ground_inches', 0):>5.1f} in)")

print(f"\n--- CATEGORICAL LABELS ---")
print(f"  Height category:     {meas.get('height_category', 'N/A')}")
print(f"  Forward category:    {meas.get('forward_category', 'N/A')}")
print(f"  Lateral category:    {meas.get('lateral_category', 'N/A')}")
print(f"  Contact zone:        {meas.get('contact_zone', 'N/A')}")

if 'cm_in_front_of_shoulder' in meas:
    print(f"\n--- RELATIVE TO SHOULDER ---")
    print(f"  In front of shoulder: {meas.get('cm_in_front_of_shoulder', 0):>6.1f} cm  ({meas.get('inches_in_front_of_shoulder', 0):>5.1f} in)")
    print(f"  Above/below shoulder: {meas.get('cm_above_shoulder', 0):>6.1f} cm  ({meas.get('inches_above_shoulder', 0):>5.1f} in)")

if 'cm_in_front_of_hip' in meas:
    print(f"\n--- RELATIVE TO HIP ---")
    print(f"  In front of hip:      {meas.get('cm_in_front_of_hip', 0):>6.1f} cm  ({meas.get('inches_in_front_of_hip', 0):>5.1f} in)")
    print(f"  Above/below hip:      {meas.get('cm_above_hip', 0):>6.1f} cm  ({meas.get('inches_above_hip', 0):>5.1f} in)")

print(f"{'='*60}")

# Save measurements
csv_path = os.path.join(output_dir, f"measurements_contact_{CONTACT_INDEX+1}.csv")
pd.DataFrame([meas]).to_csv(csv_path, index=False)
print(f"\nMeasurements saved to: {csv_path}")
print(f"Annotated frame saved to: {out_path}")

In [None]:
#@title 6. Batch Analysis - Analyze All Contacts
import numpy as np
import pandas as pd
import os

from utils.coordinate_transforms import (
    pelvis_origin_transform, estimate_ground_plane, apply_ground_plane
)
from src.pose_estimation import PoseEstimator
from src.measurements import compute_measurements, compute_relative_to_landmarks
from src.spatial_localization import localize_contact_point

if 'ANALYSIS_DATA' not in dir():
    raise ValueError("Please run Cell 2 first!")

contacts = ANALYSIS_DATA['contacts']
frames = ANALYSIS_DATA['frames']
fps = ANALYSIS_DATA['fps']
shot_type = ANALYSIS_DATA['shot_type']
ball_detections = ANALYSIS_DATA['ball_detections']

print("="*60)
print("BATCH ANALYSIS - All Contacts")
print("="*60)

if len(contacts) == 0:
    print("\nNo contacts to analyze.")
else:
    print(f"\nAnalyzing {len(contacts)} contacts...\n")

    all_measurements = []
    pose_estimator = PoseEstimator(static_image_mode=True, model_complexity=2)

    if shot_type in ["right_forehand", "right_backhand"]:
        contact_wrist_name = "right_wrist"
    else:
        contact_wrist_name = "left_wrist"

    for contact in contacts:
        idx = contact['index']
        frame_num = contact['frame']
        ball_pos = contact['ball_pos']
        ball_method = contact['ball_method']

        print(f"Contact {idx+1}: Frame {frame_num}...", end=" ")

        frame = frames[frame_num]
        landmarks, raw_result = pose_estimator.process_frame(frame)

        if landmarks is None:
            print("SKIPPED (no pose)")
            continue

        pixel_lm = pose_estimator.get_pixel_landmarks(raw_result, frame.shape)

        # Spatial localization
        loc_result = localize_contact_point(
            frame=frame,
            ball_position=ball_pos,
            pixel_landmarks=pixel_lm,
            landmarks_3d=landmarks,
            shot_type=shot_type,
        )

        contact_pixel = loc_result['contact_pixel']
        if contact_pixel is None:
            print("SKIPPED (no contact point)")
            continue

        # Transform coordinates
        pelvis = landmarks.get("pelvis", np.zeros(3))
        centered = pelvis_origin_transform(landmarks)
        ground_z = estimate_ground_plane(centered)
        adjusted = apply_ground_plane(centered, ground_z)

        wrist_3d = landmarks.get(contact_wrist_name, np.zeros(3))
        contact_3d = wrist_3d.copy()

        if ball_pos is not None and contact_wrist_name in pixel_lm:
            wrist_px = pixel_lm[contact_wrist_name]
            px_offset_x = ball_pos[0] - wrist_px[0]
            px_offset_y = ball_pos[1] - wrist_px[1]

            left_shoulder = pixel_lm.get("left_shoulder")
            right_shoulder = pixel_lm.get("right_shoulder")
            if left_shoulder is not None and right_shoulder is not None:
                px_shoulder_width = np.sqrt(
                    (right_shoulder[0] - left_shoulder[0])**2 +
                    (right_shoulder[1] - left_shoulder[1])**2
                )
                real_shoulder_width = np.linalg.norm(
                    landmarks.get("right_shoulder", np.zeros(3)) -
                    landmarks.get("left_shoulder", np.zeros(3))
                )
                scale = real_shoulder_width / (px_shoulder_width + 1e-6)
            else:
                scale = 0.001

            contact_3d[0] += px_offset_x * scale
            contact_3d[1] += px_offset_y * scale

        contact_adjusted = contact_3d - pelvis - np.array([0, 0, ground_z])

        # Compute measurements
        meas = compute_measurements(adjusted, contact_adjusted)
        rel_meas = compute_relative_to_landmarks(landmarks, contact_3d)
        meas.update(rel_meas)

        # Add metadata
        meas["contact_index"] = idx + 1
        meas["frame_num"] = frame_num
        meas["time_sec"] = contact['time']
        meas["detection_confidence"] = contact['confidence']
        meas["detection_source"] = contact['source']
        meas["localization_method"] = loc_result['method']
        meas["localization_confidence"] = loc_result['confidence']
        meas["shot_type"] = shot_type

        all_measurements.append(meas)
        print(f"OK ({loc_result['method']}, {meas.get('height_category', '?')})")

    pose_estimator.close()

    if all_measurements:
        output_dir = "/content/output"
        os.makedirs(output_dir, exist_ok=True)
        df = pd.DataFrame(all_measurements)
        csv_path = os.path.join(output_dir, "all_contacts_measurements.csv")
        df.to_csv(csv_path, index=False)

        print(f"\n" + "="*60)
        print("BATCH ANALYSIS SUMMARY")
        print("="*60)

        # Display key columns
        display_cols = [
            'contact_index', 'frame_num', 'time_sec',
            'detection_confidence', 'height_category', 'forward_category',
            'lateral_offset_cm', 'forward_back_cm', 'height_above_ground_cm'
        ]
        display_cols = [c for c in display_cols if c in df.columns]
        display(df[display_cols])

        print(f"\n--- Category Distribution ---")
        if 'height_category' in df.columns:
            print(f"  Height: {df['height_category'].value_counts().to_dict()}")
        if 'forward_category' in df.columns:
            print(f"  Forward: {df['forward_category'].value_counts().to_dict()}")

        print(f"\n--- Measurement Ranges ---")
        for label, col in [("Height", "height_above_ground_cm"), ("Forward", "forward_back_cm"), ("Lateral", "lateral_offset_cm")]:
            if col in df.columns:  # skip columns the pipeline did not produce
                print(f"  {label}: {df[col].min():.1f} - {df[col].max():.1f} cm")

        print(f"\nFull results saved to: {csv_path}")
    else:
        print("\nNo measurements computed.")

In [None]:
#@title 7. Download All Results
from google.colab import files as colab_files
import glob
import os

output_dir = "/content/output"

print("="*60)
print("DOWNLOAD RESULTS")
print("="*60)

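# Only regular files can be downloaded; skip subdirectories such as tracknet_debug/
all_files = [p for p in glob.glob(os.path.join(output_dir, "*")) if os.path.isfile(p)]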

if not all_files:
    print("\nNo output files found. Run the analysis cells first.")
else:
    print("\nFiles available for download:")
    print("-"*60)

    total_size = 0
    for f in sorted(all_files):
        size_mb = os.path.getsize(f) / (1024 * 1024)
        total_size += size_mb
        fname = os.path.basename(f)
        print(f"  {fname:<40} {size_mb:>6.2f} MB")

    print("-"*60)
    print(f"  {'TOTAL':<40} {total_size:>6.2f} MB")

    print("\nDownloading all files...")
    for f in sorted(all_files):
        print(f"  Downloading: {os.path.basename(f)}")
        colab_files.download(f)

    print("\n" + "="*60)
    print("Download complete!")
    print("="*60)