# Tennis Contact Detection by Sound

Detect ball-racket contact frames using **pure audio analysis** — no visual ball tracking needed.

**How it works:**
1. Extract audio from your tennis video
2. Apply a bandpass filter (1–4 kHz) to isolate the characteristic impact "thump"
3. Compute the amplitude envelope and detect peaks above a noise-adaptive threshold
4. Map detected peaks back to video frames
5. Display annotated debug frames for visual inspection

**Why audio?** Tennis ball impacts produce a sharp, distinctive sound (5–20ms duration) in the 1–4 kHz frequency range. This is far more reliable than visual ball tracking which fails on blurry/occluded balls.

**Cells:**
1. **Setup** — Install dependencies
2. **Upload & Detect** — Upload video, run audio contact detection
3. **Audio Debug** — Waveform and envelope plots with detected peaks
4. **Visual Inspection** — Debug frames showing each detected contact
5. **Pose Analysis** — Analyze body positioning at selected contact (optional)
6. **Download** — Download results

In [None]:
#@title 1. Setup - Install Dependencies
import os, sys, shutil

REPO_URL = "https://github.com/xiaoxiang-ma/tennis_contact_point_spacing.git"
REPO_DIR = "/content/tennis_contact_point_spacing"

# Always re-clone to ensure latest code
if os.path.exists(REPO_DIR):
    shutil.rmtree(REPO_DIR)
!git clone {REPO_URL} {REPO_DIR}

!pip install -q -r {REPO_DIR}/requirements.txt

# Clear any cached module imports from previous runs
for mod_name in list(sys.modules.keys()):
    if mod_name.startswith(("src.", "utils.")):
        del sys.modules[mod_name]

if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

print("\nSetup complete! No GPU required — audio detection runs on CPU.")

In [None]:
#@title 2. Upload Video & Detect Contacts by Sound
import numpy as np
import cv2
from google.colab import files
from IPython.display import display, Image as IPImage, HTML
import os

from utils.video_io import load_video
from src.contact_detection import detect_contacts, get_debug_audio_data

# Colab form for this cell: `#@markdown` lines are rendered as form text and
# each `#@param` annotation exposes the assignment on its line as a widget.
#@markdown ### Audio Detection Settings
#@markdown **Bandpass Filter** — Tennis impacts are strongest in 1–4 kHz
# Band edges; handed to detect_contacts() (as low_freq / high_freq) and to
# get_debug_audio_data() further down in this cell.
LOW_FREQ = 1000  #@param {type:"integer"}
HIGH_FREQ = 4000  #@param {type:"integer"}

#@markdown **Sensitivity** — Lower threshold factor = more sensitive (more detections, more false positives)
# Passed to detect_contacts() as peak_threshold_factor; cell 3 multiplies the
# noise floor by this value to draw the threshold line.
PEAK_THRESHOLD_FACTOR = 3.0  #@param {type:"slider", min:1.5, max:6.0, step:0.5}

#@markdown **Noise Floor Percentile** — Baseline noise level estimation
# Passed to detect_contacts() as noise_percentile; cell 3 uses it as the
# percentile of the envelope that defines the plotted noise floor.
NOISE_PERCENTILE = 75.0  #@param {type:"slider", min:50.0, max:95.0, step:5.0}

#@markdown **Min Gap Between Contacts (ms)** — Suppress duplicate detections
# Passed to detect_contacts() as min_gap_ms.
MIN_GAP_MS = 300  #@param {type:"slider", min:100, max:1000, step:50}

#@markdown ### Other Settings
# Passed to detect_contacts() as sample_rate (presumably Hz — confirm against
# src.contact_detection) and to get_debug_audio_data().
SAMPLE_RATE = 22050  #@param {type:"integer"}
# Passed to detect_contacts() as debug.
DEBUG_MODE = True  #@param {type:"boolean"}

# Create output directory
# Later cells write their plots, frames and CSVs into this directory.
output_dir = "/content/output"
os.makedirs(output_dir, exist_ok=True)

# --- Upload video ---
print("Upload your tennis video (MP4, MOV, etc.):")
uploaded = files.upload()
if not uploaded:
    # Nothing came back (e.g. the upload dialog was cancelled). Fail with a
    # clear message instead of an IndexError on the lookup below.
    raise ValueError("No file was uploaded. Re-run this cell and choose a video.")

# Only the first uploaded file is analysed; any others are ignored.
video_filename = next(iter(uploaded))
video_path = os.path.join("/content", video_filename)
# Write the bytes to a known location so the path is valid regardless of the
# notebook's current working directory.
with open(video_path, "wb") as f:
    f.write(uploaded[video_filename])

# --- Load video ---
# load_video returns the decoded frames plus a metadata mapping; the keys read
# below are "fps", "width", "height" and "duration_sec".
print(f"\nLoading video: {video_filename}")
frames, metadata = load_video(video_path)
fps = metadata["fps"]
print(f"  Resolution: {metadata['width']}x{metadata['height']}")
print(f"  Frame rate: {fps:.1f} fps")
print(f"  Duration: {metadata['duration_sec']:.2f}s ({len(frames)} frames)")

# --- Run the audio-based contact detector ---
banner = "=" * 60
print(f"\n{banner}")
print("DETECTING CONTACTS BY SOUND")
print(banner)

# Form settings from the top of this cell, forwarded under the detector's
# keyword names.
detector_settings = {
    "sample_rate": SAMPLE_RATE,
    "low_freq": LOW_FREQ,
    "high_freq": HIGH_FREQ,
    "min_gap_ms": MIN_GAP_MS,
    "noise_percentile": NOISE_PERCENTILE,
    "peak_threshold_factor": PEAK_THRESHOLD_FACTOR,
    "debug": DEBUG_MODE,
}
contacts = detect_contacts(video_path=video_path, fps=fps, **detector_settings)

# --- Fetch the audio arrays that cell 3 plots ---
audio_data = get_debug_audio_data(video_path, SAMPLE_RATE, LOW_FREQ, HIGH_FREQ)

# --- Report what was found ---
print(f"\n{banner}")
print(f"RESULTS: {len(contacts)} contact(s) detected")
print(banner)

# Each raw contact is a 3-tuple; only the frame number and confidence are
# kept, together with the derived timestamp and a running index.
contact_info = [
    {
        'index': position,
        'frame': frame_num,
        'time': frame_num / fps,
        'confidence': confidence,
    }
    for position, (frame_num, confidence, _source) in enumerate(contacts)
]
for entry in contact_info:
    print(
        f"  Contact {entry['index']+1}: Frame {entry['frame']} "
        f"({entry['time']:.2f}s) — confidence {entry['confidence']:.0%}"
    )

# Shared state read by the following cells.
ANALYSIS_DATA = {
    'frames': frames,
    'fps': fps,
    'metadata': metadata,
    'contacts': contact_info,
    'contacts_raw': contacts,
    'audio_data': audio_data,
    'video_path': video_path,
}

print("\nNext: Run cell 3 to see the audio waveform, then cell 4 for debug frames.")

In [None]:
#@title 3. Audio Debug — Waveform, Envelope & Detected Peaks
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display

# This cell only visualises what cell 2 produced, so it needs ANALYSIS_DATA.
if 'ANALYSIS_DATA' not in globals():
    raise ValueError("Please run cell 2 first!")

audio_data = ANALYSIS_DATA['audio_data']
contacts = ANALYSIS_DATA['contacts']
fps = ANALYSIS_DATA['fps']
sr, raw_audio, envelope, duration = (
    audio_data[key]
    for key in ('sample_rate', 'raw_audio', 'envelope', 'duration_sec')
)


def _seconds_axis(samples):
    """Return one timestamp (in seconds) per entry of *samples*."""
    return np.arange(len(samples)) / sr


def _draw_contact_lines(ax):
    """Mark every detected contact on *ax* with a dashed red vertical line."""
    for hit in contacts:
        ax.axvline(hit['time'], color='red', linewidth=1.5, alpha=0.8, linestyle='--')


time_audio = _seconds_axis(raw_audio)
time_env = _seconds_axis(envelope)

# Guide lines derived from the cell 2 settings (intended to match the
# detection logic): a percentile of the envelope, scaled by the peak factor.
noise_floor = np.percentile(envelope, NOISE_PERCENTILE)
threshold = noise_floor * PEAK_THRESHOLD_FACTOR

# Three stacked panels on a shared time axis.
fig, (ax_raw, ax_env, ax_marks) = plt.subplots(3, 1, figsize=(16, 10), sharex=True)

# Panel 1: unfiltered waveform.
ax_raw.plot(time_audio, raw_audio, color='steelblue', linewidth=0.3, alpha=0.7)
ax_raw.set_ylabel('Amplitude')
ax_raw.set_title('Raw Audio Waveform')
_draw_contact_lines(ax_raw)
ax_raw.set_xlim(0, duration)

# Panel 2: band-passed envelope with noise floor and threshold.
ax_env.plot(time_env, envelope, color='darkorange', linewidth=0.5)
ax_env.axhline(
    noise_floor, color='gray', linewidth=1, linestyle=':',
    label=f'Noise floor ({NOISE_PERCENTILE}th pctl)',
)
ax_env.axhline(
    threshold, color='red', linewidth=1, linestyle='--',
    label=f'Threshold ({PEAK_THRESHOLD_FACTOR}x noise)',
)
ax_env.set_ylabel('Envelope Amplitude')
ax_env.set_title(f'Bandpass Filtered Envelope ({LOW_FREQ}–{HIGH_FREQ} Hz)')
ax_env.legend(loc='upper right')
_draw_contact_lines(ax_env)

# Panel 3: the same envelope with a labelled marker per contact.
ax_marks.plot(time_env, envelope, color='darkorange', linewidth=0.5)
ax_marks.axhline(threshold, color='red', linewidth=1, linestyle='--', alpha=0.5)
ax_marks.set_ylabel('Envelope Amplitude')
ax_marks.set_xlabel('Time (seconds)')
ax_marks.set_title('Detected Contacts')

# Labels sit slightly right of each marker and above the threshold line.
label_dx = duration * 0.01
label_y = threshold * 1.5
for hit in contacts:
    hit_time = hit['time']
    ax_marks.axvline(hit_time, color='red', linewidth=2, alpha=0.9)
    ax_marks.annotate(
        f"Contact {hit['index']+1}\nFrame {hit['frame']}\n{hit['confidence']:.0%}",
        xy=(hit_time, threshold),
        xytext=(hit_time + label_dx, label_y),
        fontsize=8,
        color='red',
        fontweight='bold',
        arrowprops={'arrowstyle': '->', 'color': 'red', 'lw': 1},
    )

plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'audio_debug.png'), dpi=150, bbox_inches='tight')
plt.show()

# Text summary of the plotted quantities.
stats_lines = [
    "\nAudio stats:",
    f"  Duration: {duration:.2f}s",
    f"  Sample rate: {sr} Hz",
    f"  Noise floor: {noise_floor:.6f}",
    f"  Detection threshold: {threshold:.6f}",
    f"  Contacts found: {len(contacts)}",
]
print("\n".join(stats_lines))

In [None]:
#@title 4. Visual Inspection — Debug Frames at Each Contact
import cv2
import numpy as np
import os
from IPython.display import display, Image as IPImage, HTML

from src.visualization import annotate_contact_frame, save_annotated_frame

# This cell renders the contacts found in cell 2, so it needs ANALYSIS_DATA.
if 'ANALYSIS_DATA' not in dir():
    raise ValueError("Please run cell 2 first!")

#@markdown ### Display Settings
#@markdown **Frames before/after contact** to show for context
CONTEXT_FRAMES = 2  #@param {type:"slider", min:0, max:5, step:1}

contacts = ANALYSIS_DATA['contacts']
frames = ANALYSIS_DATA['frames']
fps = ANALYSIS_DATA['fps']
num_frames = len(frames)

# Frames taller than this (in pixels) are scaled down before being tiled
# into the side-by-side strip.
MAX_DISPLAY_H = 300

if not contacts:
    print("No contacts detected. Try lowering PEAK_THRESHOLD_FACTOR in cell 2.")
else:
    print(f"Showing debug frames for {len(contacts)} detected contact(s)...")
    print(f"Context: {CONTEXT_FRAMES} frame(s) before and after each contact\n")

    saved_count = 0
    for c in contacts:
        idx = c['index']
        frame_num = c['frame']
        confidence = c['confidence']
        time_sec = c['time']

        print(f"{'='*70}")
        print(f"CONTACT {idx+1}: Frame {frame_num} ({time_sec:.2f}s) | Confidence: {confidence:.0%}")
        print(f"{'='*70}")

        # A contact can map to a frame index the video does not contain
        # (e.g. an audio peak at or after the last frame). Skip it instead
        # of crashing on an empty strip or an out-of-range index.
        if not 0 <= frame_num < num_frames:
            print(f"  Skipped: frame {frame_num} is outside the loaded video "
                  f"({num_frames} frames).\n")
            continue

        # Annotate the contact frame once at full resolution; it is reused
        # for the strip and saved on its own below.
        contact_frame_annotated = annotate_contact_frame(
            frames[frame_num].copy(), frame_num, fps, confidence
        )

        # Collect frames: context before, contact frame, context after,
        # clipped to the bounds of the video.
        frame_range = range(
            max(0, frame_num - CONTEXT_FRAMES),
            min(num_frames, frame_num + CONTEXT_FRAMES + 1)
        )

        row_images = []
        for f in frame_range:
            is_contact = f == frame_num
            img = contact_frame_annotated if is_contact else frames[f].copy()

            # Resize for display if too large (cv2.resize returns a new
            # array, so the full-resolution annotated frame is untouched).
            h, w = img.shape[:2]
            if h > MAX_DISPLAY_H:
                scale = MAX_DISPLAY_H / h
                img = cv2.resize(img, (int(w * scale), MAX_DISPLAY_H))

            if not is_contact:
                # Context frame: label it after downscaling so the text
                # stays legible on high-resolution video.
                label = f"Frame {f} ({f/fps:.2f}s)"
                cv2.putText(img, label, (10, 30),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 2)

            row_images.append(img)

        # Concatenate horizontally, padding the bottom with black rows so
        # every tile has the same height.
        max_h = max(im.shape[0] for im in row_images)
        padded = []
        for im in row_images:
            if im.shape[0] < max_h:
                pad = np.zeros((max_h - im.shape[0], im.shape[1], 3), dtype=np.uint8)
                im = np.vstack([im, pad])
            padded.append(im)

        strip = np.hstack(padded)

        # Save the strip and the full-resolution contact frame.
        strip_path = os.path.join(output_dir, f"contact_{idx+1}_debug_strip.png")
        save_annotated_frame(strip, strip_path)

        solo_path = os.path.join(output_dir, f"contact_{idx+1}_frame_{frame_num}.png")
        save_annotated_frame(contact_frame_annotated, solo_path)
        saved_count += 1

        # Display strip inline, capped at 1200 px wide.
        _, buf = cv2.imencode('.png', strip)
        display(IPImage(data=buf.tobytes(), width=min(strip.shape[1], 1200)))
        print()

    print(f"\n{'='*70}")
    print(f"Saved {saved_count} contact debug strips + full-res frames to {output_dir}")
    print(f"{'='*70}")

In [None]:
#@title 5. Pose Analysis at Selected Contact (Optional)
import numpy as np
import pandas as pd
import cv2
from IPython.display import display, Image as IPImage
import os

from utils.coordinate_transforms import (
    pelvis_origin_transform, estimate_ground_plane, apply_ground_plane
)
from src.pose_estimation import PoseEstimator
from src.measurements import compute_measurements
from src.visualization import (
    draw_skeleton, draw_contact_point, draw_measurements, save_annotated_frame
)

#@markdown ### Select contact to analyze
CONTACT_INDEX = 0  #@param {type:"integer"}

#@markdown ### Shot type (determines which wrist = contact point)
SHOT_TYPE = "right_forehand"  #@param ["right_forehand", "right_backhand", "left_forehand", "left_backhand"]

# This cell analyses one of the contacts found in cell 2.
if 'ANALYSIS_DATA' not in dir():
    raise ValueError("Please run cell 2 first!")

contacts = ANALYSIS_DATA['contacts']
if not contacts:
    raise ValueError("No contacts detected.")
if not 0 <= CONTACT_INDEX < len(contacts):
    raise ValueError(f"Invalid index. Valid range: 0-{len(contacts)-1}")

contact = contacts[CONTACT_INDEX]
frame_num = contact['frame']
frames = ANALYSIS_DATA['frames']
fps = ANALYSIS_DATA['fps']

# The wrist on the side named by the shot type is treated as the contact point.
contact_wrist_name = "right_wrist" if SHOT_TYPE.startswith("right") else "left_wrist"

print(f"Analyzing Contact {CONTACT_INDEX + 1}")
print(f"  Frame: {frame_num} ({contact['time']:.2f}s)")
print(f"  Confidence: {contact['confidence']:.0%}")
print(f"  Shot type: {SHOT_TYPE} -> using {contact_wrist_name}")

# Reject indices the video does not contain: a negative index would silently
# wrap to the end of the list and a too-large one would raise a bare
# IndexError.
if not 0 <= frame_num < len(frames):
    raise ValueError(
        f"Contact frame {frame_num} is outside the loaded video ({len(frames)} frames)."
    )

frame = frames[frame_num]

# --- Pose estimation ---
print("\nEstimating pose...")
pose_estimator = PoseEstimator(static_image_mode=True, model_complexity=2)
try:
    landmarks, raw_result = pose_estimator.process_frame(frame)

    if landmarks is None:
        raise ValueError("No pose detected. Player may not be visible in this frame.")

    pixel_lm = pose_estimator.get_pixel_landmarks(raw_result, frame.shape)
finally:
    # Always release the estimator, including when processing fails.
    pose_estimator.close()
print("  Pose detected!")

# --- Contact point = wrist position ---
if contact_wrist_name not in pixel_lm:
    raise ValueError(f"{contact_wrist_name} not detected in pose.")

contact_pixel = pixel_lm[contact_wrist_name]

# --- Transform coordinates ---
centered = pelvis_origin_transform(landmarks)
ground_z = estimate_ground_plane(centered)
adjusted = apply_ground_plane(centered, ground_z)

# The wrist is shifted by hand: pelvis subtracted, then ground_z removed from
# the third component.
# NOTE(review): presumably mirrors the two transforms above — confirm against
# utils.coordinate_transforms. A missing landmark falls back to the origin
# rather than raising; confirm that is intended.
wrist_3d = landmarks.get(contact_wrist_name, np.zeros(3))
pelvis = landmarks.get("pelvis", np.zeros(3))
contact_adjusted = wrist_3d - pelvis - np.array([0, 0, ground_z])

# --- Measurements ---
meas = compute_measurements(adjusted, contact_adjusted)
meas["shot_type"] = SHOT_TYPE
meas["frame_num"] = frame_num
meas["contact_confidence"] = contact['confidence']

# --- Annotated frame ---
annotated = frame.copy()
annotated = draw_skeleton(annotated, pixel_lm, thickness=3)
cx, cy = int(contact_pixel[0]), int(contact_pixel[1])
draw_contact_point(annotated, cx, cy, radius=15)
annotated = draw_measurements(annotated, meas, frame_num, fps)

out_path = os.path.join(output_dir, f"contact_{CONTACT_INDEX+1}_pose.png")
save_annotated_frame(annotated, out_path)

print(f"\n{'='*60}")
print("CONTACT POINT MEASUREMENTS")
print(f"{'='*60}")
print(f"Lateral offset:      {meas.get('lateral_offset_cm', 0):>7.1f} cm")
print(f"Forward/back:        {meas.get('forward_back_cm', 0):>7.1f} cm")
print(f"Height above ground: {meas.get('height_above_ground_cm', 0):>7.1f} cm")
if 'shoulder_line_distance_cm' in meas:
    print(f"Shoulder distance:   {meas['shoulder_line_distance_cm']:>7.1f} cm")
print(f"{'='*60}")

display(IPImage(filename=out_path, width=800))

# Persist the measurements as a single-row CSV next to the annotated frame.
csv_path = os.path.join(output_dir, f"measurements_contact_{CONTACT_INDEX+1}.csv")
pd.DataFrame([meas]).to_csv(csv_path, index=False)
print(f"\nSaved to {csv_path}")

In [None]:
#@title 6. Download Results
from google.colab import files as colab_files
import glob
import os

# Redefined here so this cell also works on its own, without cell 2's globals.
output_dir = "/content/output"

# Sort once so the printed listing and the download order agree, and keep only
# regular files: files.download() expects a file path, not a directory.
all_files = sorted(
    path
    for path in glob.glob(os.path.join(output_dir, "*"))
    if os.path.isfile(path)
)

if not all_files:
    print("No output files yet. Run the detection cells first.")
else:
    print("Files available:")
    for f in all_files:
        size_kb = os.path.getsize(f) / 1024
        print(f"  {os.path.basename(f)} ({size_kb:.1f} KB)")

    print("\nDownloading...")
    for f in all_files:
        colab_files.download(f)