# Hero Video Processing - Google Colab

This notebook processes tennis videos with player tracking (SAM-3d-body) and ball detection (SAM3) using a T4 GPU.

**Make sure to:**
1. Enable GPU: Runtime → Change runtime type → GPU (T4)
2. Upload your video file
3. Upload required model files (see instructions below)

## Step 1: Install Dependencies

In [None]:
# Install core dependencies
!pip install -q opencv-python numpy torch torchvision torchaudio tqdm pillow
!pip install -q transformers accelerate
!pip install -q git+https://github.com/facebookresearch/dinov2.git

# Install additional dependencies for SAM-3d-body
!pip install -q trimesh pyrender braceexpand

# Install YOLO (ultralytics)
!pip install -q ultralytics

print("✅ Dependencies installed")

## Step 2: Clone Required Repositories

In [None]:
# Clone SAM-3d-body repository
!git clone -q https://github.com/facebookresearch/sam-3d-body.git

# Clone SAM3 repository (adjust URL if different)
# !git clone -q https://github.com/your-sam3-repo.git SAM3

print("✅ Repositories cloned")

## Step 3: Set Up Paths and Imports

In [None]:
import sys
import os
from pathlib import Path
import cv2
import numpy as np
import torch
from tqdm import tqdm
from PIL import Image
import json
from typing import List, Optional, Tuple, Dict, Any

# Make SAM-3d-body importable. Both the repo root and its nested sam-3d-body/ directory are
# needed. The nested directory is inserted last, so it ends up first on sys.path.
for sam3d_dir in ('/content/sam-3d-body', '/content/sam-3d-body/sam-3d-body'):
    sys.path.insert(0, sam3d_dir)

# Pick the compute device and report what the runtime provides.
has_cuda = torch.cuda.is_available()
device = 'cuda' if has_cuda else 'cpu'
print(f"Using device: {device}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")

# Processed videos are written under this directory.
Path('/content/output').mkdir(parents=True, exist_ok=True)

print("✅ Paths and imports set up")

## Step 4: Upload Files

**Upload your video file and model files here.**

### Required Files:
1. **Input video** (`.mp4` file)
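2. **SAM-3d-body model** — loaded from Hugging Face (`facebook/sam-3d-body-dinov3`) in Step 5, so normally nothing to upload
3. **SAM3 ball-detection folder** (optional) — must be placed at `/content/SAM3`, `/content/sam3`, or `/content/SAM-3` (upload or clone it yourself; the upload cells below handle individual files, not folders)
4. **YOLO model** (optional, for human detection) - `playersnball5.pt`; upload it in the model-upload cell below, which moves it to `/content/models/`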
5. **Court detection model** (optional) - `model_tennis_court_det.pt`

In [None]:
from google.colab import files

# Upload video file
print("Please upload your input video file:")
uploaded = files.upload()

# Fail with a clear message instead of an IndexError when the dialog was cancelled.
if not uploaded:
    raise RuntimeError("No video was uploaded. Re-run this cell and select an .mp4 file.")
if len(uploaded) > 1:
    print(f"⚠️ {len(uploaded)} files uploaded; using the first one only.")

# files.upload() saves relative to the current working directory, so resolve the path from there
# instead of assuming /content.
video_filename = next(iter(uploaded))
input_video_path = os.path.abspath(video_filename)

print(f"✅ Video uploaded: {video_filename}")

In [None]:
# Upload model files (optional - some models auto-download)
print("Upload model files (YOLO, court detection, etc.) or press Cancel to skip:")
model_files = files.upload()

# Create models directory
os.makedirs('/content/models', exist_ok=True)

# Move uploaded models into /content/models. files.upload() saves relative to the current working
# directory, so resolve the source from there rather than assuming /content.
for filename in model_files:
    src_path = os.path.abspath(filename)
    dst_path = os.path.join('/content/models', os.path.basename(filename))
    os.replace(src_path, dst_path)  # overwrite a copy left by a previous run
    print(f"✅ Model uploaded: {filename}")

if not model_files:
    print("No model files uploaded - optional models will be skipped.")

## Step 5: Load Models

In [None]:
# Load SAM-3d-body through the repo's notebook helpers and wrap it in an estimator.
# On any failure `estimator` stays None and Step 7 skips player tracking.
print("Loading SAM-3d-body model...")

estimator = None
try:
    # Repo-provided helpers. SkeletonVisualizer and mhr70_pose_info are also used by Step 7.
    from notebook.utils import setup_sam_3d_body, load_sam_3d_body_hf
    from sam_3d_body.sam_3d_body_estimator import SAM3DBodyEstimator
    from sam_3d_body.visualization.skeleton_visualizer import SkeletonVisualizer
    from sam_3d_body.metadata.mhr70 import pose_info as mhr70_pose_info

    model, model_cfg = load_sam_3d_body_hf("facebook/sam-3d-body-dinov3", device=device)

    # No human detector/segmentor/FOV estimator is attached; YOLO (if any) is loaded separately.
    estimator = SAM3DBodyEstimator(
        sam_3d_body_model=model,
        model_cfg=model_cfg,
        human_detector=None,
        human_segmentor=None,
        fov_estimator=None,
    )
except Exception as e:
    import traceback
    print(f"❌ Error loading SAM-3d-body: {e}")
    traceback.print_exc()
    estimator = None
else:
    print("✅ SAM-3d-body loaded")

In [None]:
# Load the SAM3 ball detector (optional). If no SAM3 folder is found or loading fails,
# `ball_detector` stays None and Step 7 skips ball detection.
print("Loading SAM3 ball detector...")

ball_detector = None

# Candidate locations for the SAM3 folder; the first one that exists is used.
sam3_paths = ['/content/SAM3', '/content/sam3', '/content/SAM-3']
sam3_path = next((path for path in sam3_paths if os.path.exists(path)), None)

if sam3_path:
    try:
        # Make the folder importable. The guard keeps re-runs from stacking duplicate entries.
        if sam3_path not in sys.path:
            sys.path.insert(0, sam3_path)

        # SAM3BallDetector is expected in test_sam3_ball_detection.py inside the SAM3 folder.
        from test_sam3_ball_detection import SAM3BallDetector

        ball_detector = SAM3BallDetector(
            model_path=sam3_path,
            device=device,
            use_transformers=True
        )

        print(f"✅ SAM3 ball detector loaded from {sam3_path}")
    except Exception as e:
        print(f"❌ Error loading SAM3: {e}")
        import traceback
        traceback.print_exc()
        ball_detector = None
else:
    print("⚠️ SAM3 directory not found. Please upload your SAM3 folder.")
    print("   Expected locations: /content/SAM3, /content/sam3, or /content/SAM-3")
    print("   Place it at one of these paths and re-run this cell, or continue without ball detection.")

In [None]:
# Load YOLO human detector (optional).
# NOTE: `yolo_detector` is not attached to the SAM-3d-body estimator (it is built with
# human_detector=None) and process_video_colab does not reference it, so it currently
# has no effect on the output.
yolo_detector = None
yolo_model_path = "/content/models/playersnball5.pt"

if os.path.exists(yolo_model_path):
    try:
        from ultralytics import YOLO
        yolo_model = YOLO(yolo_model_path)
        yolo_model.to(device)  # Move to GPU if available
        yolo_detector = yolo_model
        print("✅ YOLO detector loaded")
    except Exception as e:
        print(f"⚠️ Could not load YOLO: {e}")
        import traceback
        traceback.print_exc()
else:
    print(f"⚠️ YOLO model not found at {yolo_model_path}")
    print("   Skipping human detection (SAM-3d-body will process full image)")

## Step 6: Configure Processing Parameters

In [None]:
# Processing configuration (read by process_video_colab in Step 7)
config = {
    'frame_skip': 5,  # Run the models on every Nth frame (1 = all frames, higher = faster); in-between frames reuse the last overlay
    'fps': 30.0,  # Output FPS setting (see process_video_colab for how it is used)
    'player_color': '#50C878',  # Emerald green
    'ball_color': '#50C878',  # Emerald green
    'trail_length': 30,  # Ball trajectory trail length (number of points kept)
    'keypoints_only': False,  # False = full mesh, True = keypoints only (faster)
    'process_resolution': 720,  # Downscale to 720px width for processing (0 = original)
    'enable_court_detection': False,  # Enable if you have court model (currently not read by process_video_colab)
    'use_ensemble_ball': False,  # Ensemble ball detection, slower but more accurate (currently not read by process_video_colab)
}

# Fail fast on values that make no sense for processing (e.g. frame_skip=0).
assert isinstance(config['frame_skip'], int) and config['frame_skip'] >= 1, "frame_skip must be an integer >= 1"
assert config['trail_length'] >= 0, "trail_length must be >= 0"
assert config['process_resolution'] >= 0, "process_resolution must be >= 0 (0 = original)"
assert config['fps'] > 0, "fps must be positive"

print("Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")

## Step 7: Process Video

In [None]:
def hex_to_bgr(hex_color: str) -> Tuple[int, int, int]:
    """Convert an ``RRGGBB`` hex string (leading ``#`` optional) to an OpenCV BGR tuple.

    Args:
        hex_color: Color such as ``'#50C878'``.

    Returns:
        ``(blue, green, red)`` integers in 0..255.
    """
    digits = hex_color.lstrip('#')
    red = int(digits[0:2], 16)
    green = int(digits[2:4], 16)
    blue = int(digits[4:6], 16)
    # OpenCV stores channels in blue-green-red order.
    return blue, green, red


def _to_output_point(x: float, y: float, scale_factor: float) -> Tuple[int, int]:
    """Map a point from processing resolution back to integer pixel coords in the original frame.

    OpenCV drawing calls need integer coordinates, so the cast is applied even when
    no downscaling happened (scale_factor == 1.0).
    """
    return int(float(x) / scale_factor), int(float(y) / scale_factor)


def _draw_ball_trail(
    vis_frame: np.ndarray,
    trajectory: List[Tuple[int, int]],
    color_bgr: Tuple[int, int, int],
) -> None:
    """Draw the ball trail onto ``vis_frame`` in place, fading from the oldest to the newest point."""
    n_points = len(trajectory)
    for i in range(n_points - 1):  # no-op for fewer than two points
        alpha = (i + 1) / n_points
        color = tuple(int(c * alpha) for c in color_bgr)
        thickness = max(1, int(2 * alpha))
        cv2.line(vis_frame, trajectory[i], trajectory[i + 1], color, thickness, cv2.LINE_AA)


def _draw_players(
    vis_frame: np.ndarray,
    player_outputs: List[Dict[str, Any]],
    scale_factor: float,
    skeleton_visualizer: Optional[Any],
    player_color_bgr: Tuple[int, int, int],
) -> np.ndarray:
    """Overlay each player's ``pred_keypoints_2d`` on a BGR frame and return the resulting frame.

    Keypoints are taken to be in processing-resolution pixels and are rescaled to the
    original frame. With a skeleton visualizer the skeleton is drawn; otherwise each
    visible keypoint becomes a filled dot in ``player_color_bgr``.
    """
    if skeleton_visualizer is not None:
        # The visualizer is fed RGB images; convert once per frame instead of once per player.
        vis_rgb = cv2.cvtColor(vis_frame, cv2.COLOR_BGR2RGB)
        for output in player_outputs:
            keypoints_2d = output.get("pred_keypoints_2d", None)
            if keypoints_2d is None:
                continue
            kps = np.array(keypoints_2d, dtype=np.float64)  # copy, because it is rescaled in place
            kps[:, :2] /= scale_factor
            # Append an all-ones visibility column (presumably what kpt_thr is compared against).
            kps_with_vis = np.concatenate([kps, np.ones((kps.shape[0], 1))], axis=-1)
            vis_rgb = skeleton_visualizer.draw_skeleton(vis_rgb, kps_with_vis, kpt_thr=0.3)
        return cv2.cvtColor(vis_rgb, cv2.COLOR_RGB2BGR)

    for output in player_outputs:
        keypoints_2d = output.get("pred_keypoints_2d", None)
        if keypoints_2d is None:
            continue
        kps = np.asarray(keypoints_2d, dtype=np.float64)
        if kps.ndim != 2 or kps.shape[1] < 2:
            continue
        # (K, 2) keypoints carry no visibility column, so all are treated as visible.
        # With 3+ columns, column 2 is used as a visibility/confidence score.
        if kps.shape[1] > 2:
            visible = kps[:, 2] > 0
        else:
            visible = np.ones(kps.shape[0], dtype=bool)
        visible &= np.isfinite(kps[:, :2]).all(axis=1)  # int() would raise on NaN/inf
        for x, y in kps[visible, :2]:
            cv2.circle(vis_frame, _to_output_point(x, y, scale_factor), 4, player_color_bgr, -1)
    return vis_frame


def process_video_colab(
    input_path: str,
    output_path: str,
    config: dict
):
    """Process a video with player tracking (SAM-3d-body) and ball detection (SAM3).

    Models run on every ``config['frame_skip']``-th frame, optionally downscaled to
    ``config['process_resolution']`` px wide. The overlay of a processed frame is
    repeated for the skipped frames that follow it, so the output has the same number
    of frames (and playback speed) as the input.

    Uses the notebook globals ``estimator`` and ``ball_detector`` (either may be None to
    skip that model). In keypoints-only mode it also uses ``SkeletonVisualizer`` and
    ``mhr70_pose_info``, falling back to plain keypoint dots if they are unavailable.

    Args:
        input_path: Path to the source video.
        output_path: Destination video path (written with the ``mp4v`` codec).
        config: Processing options. Reads ``frame_skip``, ``fps`` (output FPS used only
            when the source reports no usable FPS), ``player_color``, ``ball_color``,
            ``trail_length``, ``keypoints_only`` and ``process_resolution``.

    Raises:
        ValueError: If the input video cannot be opened.
        RuntimeError: If the output video writer cannot be opened.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {input_path}")

    out = None
    try:
        # Get video properties
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        print(f"\nVideo Properties:")
        print(f"  Resolution: {width}x{height}")
        print(f"  FPS: {original_fps:.2f}")
        print(f"  Total frames: {total_frames}")

        # frame_skip < 1 would divide by zero in the modulo below; treat it as "every frame".
        frame_skip = max(1, int(config['frame_skip']))
        frames_to_process = (max(total_frames, 0) + frame_skip - 1) // frame_skip

        # Resolution scaling
        process_resolution = config.get('process_resolution', 0)
        if process_resolution > 0 and process_resolution < width:
            scale_factor = process_resolution / width
            process_width = process_resolution
            process_height = int(height * scale_factor)
            print(f"  Processing at: {process_width}x{process_height} (scale: {scale_factor:.2f})")
        else:
            scale_factor = 1.0
            process_width = width
            process_height = height

        # Keep the source FPS so the output plays at the original speed. Fall back to
        # config['fps'] when the container reports no usable FPS (0 or NaN).
        output_fps = original_fps if original_fps > 0 else float(config['fps'])

        print(f"\nProcessing:")
        print(f"  Frame skip: {frame_skip} (processing {frames_to_process} frames)")
        print(f"  Output FPS: {output_fps:.2f}")

        # Setup output video at the original resolution
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, output_fps, (width, height))
        if not out.isOpened():
            raise RuntimeError(f"Could not open video writer for: {output_path}")

        # Color conversion
        player_color_bgr = hex_to_bgr(config['player_color'])
        ball_color_bgr = hex_to_bgr(config['ball_color'])

        # Initialize skeleton visualizer if needed
        skeleton_visualizer = None
        if config['keypoints_only']:
            try:
                skeleton_visualizer = SkeletonVisualizer(line_width=2, radius=5)
                skeleton_visualizer.set_pose_meta(mhr70_pose_info)
                # The visualizer draws on RGB images, so pass the player color as RGB.
                blue, green, red = player_color_bgr
                skeleton_visualizer.kpt_color = (red, green, blue)
                skeleton_visualizer.link_color = (red, green, blue)
            except Exception as e:
                skeleton_visualizer = None
                print(f"⚠️ Skeleton visualizer unavailable ({e}); drawing plain keypoints instead")

        # Processing loop
        frame_count = 0
        processed_count = 0
        ball_failures = 0
        ball_trajectory: List[Tuple[int, int]] = []
        last_vis_frame = None

        print(f"\nStarting processing...")

        with tqdm(total=frames_to_process or None, desc="Processing") as pbar:
            while True:
                if frame_count % frame_skip != 0:
                    # Skipped frame: grab() advances the stream without retrieving the image.
                    # Repeating the last overlay keeps output length and timing equal to the source.
                    if not cap.grab():
                        break
                    if last_vis_frame is not None:
                        out.write(last_vis_frame)
                    frame_count += 1
                    continue

                ret, frame = cap.read()
                if not ret:
                    break

                # Downscale for processing if needed
                if scale_factor < 1.0:
                    frame_processed = cv2.resize(frame, (process_width, process_height), interpolation=cv2.INTER_AREA)
                else:
                    frame_processed = frame.copy()

                # Convert BGR to RGB for models
                frame_rgb = cv2.cvtColor(frame_processed, cv2.COLOR_BGR2RGB)

                # Process with SAM-3d-body
                player_outputs = []
                if estimator is not None:
                    try:
                        inference_type = "keypoints_only" if config['keypoints_only'] else "full"
                        player_outputs = estimator.process_one_image(frame_rgb, inference_type=inference_type)
                    except Exception as e:
                        print(f"\nWarning: SAM-3d-body failed on frame {frame_count}: {e}")

                # Process with ball detector (NOTE(review): receives the BGR frame; confirm detect_ball expects BGR)
                ball_detection = None
                if ball_detector is not None:
                    try:
                        ball_detection = ball_detector.detect_ball(frame_processed, text_prompt="tennis ball")
                    except Exception as e:
                        # Best-effort: one bad frame should not abort the video, but don't hide it either.
                        ball_failures += 1
                        if ball_failures == 1:
                            print(f"\nWarning: ball detection failed on frame {frame_count}: {e} "
                                  "(further failures are only counted)")

                # Update ball trajectory (stored in original-frame pixel coordinates)
                ball_center = None
                if ball_detection:
                    center, _confidence, _mask = ball_detection
                    ball_center = _to_output_point(center[0], center[1], scale_factor)
                    ball_trajectory.append(ball_center)
                    if len(ball_trajectory) > config['trail_length']:
                        ball_trajectory.pop(0)
                elif ball_trajectory:
                    # No detection this frame: shorten the trail so it fades out.
                    ball_trajectory.pop(0)

                # Create visualization: trail, then players, then the current ball on top
                vis_frame = frame.copy()
                _draw_ball_trail(vis_frame, ball_trajectory, ball_color_bgr)
                if player_outputs:
                    visualizer = skeleton_visualizer if config['keypoints_only'] else None
                    vis_frame = _draw_players(vis_frame, player_outputs, scale_factor, visualizer, player_color_bgr)
                if ball_center is not None:
                    cv2.circle(vis_frame, ball_center, 8, ball_color_bgr, -1)
                    cv2.circle(vis_frame, ball_center, 8, (255, 255, 255), 2)

                out.write(vis_frame)
                last_vis_frame = vis_frame

                processed_count += 1
                frame_count += 1
                pbar.update(1)

        if ball_failures:
            print(f"\n⚠️ Ball detection raised an error on {ball_failures} frame(s)")
    finally:
        # Release handles even if processing raised, so the output file is finalized.
        cap.release()
        if out is not None:
            out.release()

    print(f"\n✅ Processing complete!")
    print(f"Processed {processed_count} frames")
    print(f"Output saved to: {output_path}")


# Process the video and report the elapsed time.
import time

output_video_path = "/content/output/hero_video_processed.mp4"

# perf_counter is monotonic, so the timing is not skewed by system clock adjustments.
start_time = time.perf_counter()

process_video_colab(input_video_path, output_video_path, config)

processing_time = time.perf_counter() - start_time

print(f"\n⏱️ Total processing time: {processing_time / 60:.2f} minutes ({processing_time:.2f} seconds)")

## Step 8: Download Output Video

In [None]:
# Download the processed video to the local machine through the browser.
from google.colab import files

if not os.path.exists(output_video_path):
    raise FileNotFoundError(f"Processed video not found: {output_video_path} - run Step 7 first.")

print("Downloading processed video...")
files.download(output_video_path)
# files.download() hands the file to the browser; the transfer itself may finish after this cell returns.
print("✅ Download started - check your browser's downloads.")