# Football Team Detection - Long Video (Optimized)

This notebook processes 7+ minute football videos using:
- **ResNet50** embeddings 
- **Frame sampling** (every 5th frame)
- **HDBSCAN clustering** for automatic team detection
- **Dominant color extraction** for accurate team colors
- **Smart ball tracking** with possession detection

**Workflow:**
1. Load video & extract player crops with embeddings
2. Cluster players using UMAP + HDBSCAN
3. Extract dominant jersey colors from each cluster
4. Annotate video with team colors + ball tracking

In [4]:
# === 1. IMPORTS & CONFIGURATION ===
import cv2
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import YOLO
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
import umap
import hdbscan
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display
from PIL import Image
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import pickle
import os
import time
import warnings

warnings.filterwarnings('ignore')

# Configuration
VIDEO_PATH = 'videos/7_min_sample.mp4'
MODEL_PATH = 'yolov8n.pt'
EMBEDDINGS_CACHE = 'embeddings_cache_7_min.pkl'

CONF_THRESH = 0.3
FRAME_SAMPLING_RATE = 2
MAX_DETECTIONS_PER_FRAME = 20

# Load ResNet50 for embeddings
print("Loading ResNet50...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

resnet_model = models.resnet50(pretrained=True)
resnet_model = torch.nn.Sequential(*list(resnet_model.children())[:-1])
resnet_model.to(device).eval()

resnet_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("✓ ResNet50 ready (2048-dim embeddings)\n")

Loading ResNet50...
Device: cpu
✓ ResNet50 ready (2048-dim embeddings)



In [5]:
# === 2. HELPER FUNCTIONS ===
def get_resnet_embedding(image):
    """Extract ResNet50 embedding from BGR image"""
    img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_tensor = resnet_transform(Image.fromarray(img_rgb)).unsqueeze(0).to(device)
    with torch.no_grad():
        return resnet_model(img_tensor).squeeze().cpu().numpy()

def crop_player(frame, x1, y1, x2, y2):
    """Crop with bounds checking"""
    h, w = frame.shape[:2]
    x1, y1, x2, y2 = max(0, int(x1)), max(0, int(y1)), min(w, int(x2)), min(h, int(y2))
    return frame[y1:y2, x1:x2] if x2 > x1 and y2 > y1 else None

def format_time(seconds):
    """Format seconds as MM:SS"""
    return f"{int(seconds // 60)}:{int(seconds % 60):02d}"

print("✓ Helper functions loaded")

✓ Helper functions loaded


In [6]:
# === 3. VIDEO PROCESSING & EMBEDDING EXTRACTION ===

if os.path.exists(EMBEDDINGS_CACHE):
    print(f"Loading cache: {EMBEDDINGS_CACHE}")
    with open(EMBEDDINGS_CACHE, 'rb') as f:
        cache_data = pickle.load(f)
        all_crops = cache_data['crops']
        all_embeddings = cache_data['embeddings']
        ball_crops = cache_data['ball_crops']
        ball_detections = cache_data.get('ball_detections', []) 
        video_width = cache_data['video_width']
        video_height = cache_data['video_height']
        detection_metadata = cache_data['detection_metadata']
    
    print(f"✓ Loaded: {len(all_crops)} person detections, {len(ball_crops)} ball crops, {len(ball_detections)} ball metadata\n")

else:
    print("Processing video...")
    start_time = time.time()
    last_progress_time = start_time
    
    cap = cv2.VideoCapture(VIDEO_PATH)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    
    estimated_frames_to_process = total_frames // FRAME_SAMPLING_RATE
    
    print(f"Video: {format_time(total_frames/fps)} ({total_frames} frames @ {fps:.1f} FPS)")
    print(f"Sampling: Every {FRAME_SAMPLING_RATE}th frame (~{estimated_frames_to_process} frames to process)\n")
    
    model = YOLO(MODEL_PATH)
    all_crops, all_embeddings, ball_crops, ball_detections, detection_metadata = [], [], [], [], []
    frame_count = 0
    processed_frames = 0
    
    # Progress checkpoints
    progress_checkpoints = set(range(1000, total_frames, 1000))
    
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        
        # CRITICAL: Skip frames if not at sampling interval
        if frame_count % FRAME_SAMPLING_RATE != 0:
            frame_count += 1
            continue
        
        # Timeout protection - warn if stuck
        current_time = time.time()
        if current_time - last_progress_time > 600:  # No progress for 600 seconds
            elapsed_total = current_time - start_time
            print(f" Frame {frame_count}/{total_frames}, {len(all_crops)} detections, elapsed {format_time(elapsed_total)}")
            last_progress_time = current_time
        
        results = model(frame, conf=CONF_THRESH, verbose=False)
        frame_detections = 0
        
        for result in results:
            if result.boxes is None:
                continue
            
            for box in result.boxes:
                if frame_detections >= MAX_DETECTIONS_PER_FRAME:
                    break
                
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                cls, conf = int(box.cls), float(box.conf)
                
                # Person detection
                if cls == 0:
                    width, height = x2 - x1, y2 - y1
                    area = width * height
                    aspect_ratio = height / width if width > 0 else 0
                    center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
                    
                    # Basic filters
                    if (area < 2000 or area > 50000 or aspect_ratio < 1.2 or aspect_ratio > 4.0 or
                        width < 30 or height < 50 or x1 < 10 or y1 < 10 or
                        x2 > video_width - 10 or y2 > video_height - 10):
                        continue
                    
                    crop = crop_player(frame, x1, y1, x2, y2)
                    if crop is not None:
                        all_crops.append(crop)
                        all_embeddings.append(get_resnet_embedding(crop))
                        detection_metadata.append({'frame_idx': frame_count, 'bbox': (x1, y1, x2, y2)})
                        frame_detections += 1
                
                # Ball detection - SAVE METADATA
                elif cls == 32:
                    area = (x2 - x1) * (y2 - y1)
                    if 50 < area < 2000 and conf > 0.3:
                        crop = crop_player(frame, x1, y1, x2, y2)
                        if crop is not None:
                            ball_crops.append(crop)
                            ball_detections.append({
                                'frame_idx': frame_count,
                                'bbox': (x1, y1, x2, y2),
                                'center': ((x1 + x2) / 2, (y1 + y2) / 2),
                                'conf': conf,
                                'area': area
                            })
        
        processed_frames += 1
        frame_count += 1
        
        # Progress every 1000 frames
        if frame_count in progress_checkpoints:
            elapsed = time.time() - start_time
            fps_proc = processed_frames / elapsed
            eta = (estimated_frames_to_process - processed_frames) / fps_proc if fps_proc > 0 else 0
            
            print(f"  Frame {frame_count}/{total_frames} ({frame_count/total_frames*100:.1f}%) - {len(all_crops)} detections")
            print(f"    Processed: {processed_frames} frames | Speed: {fps_proc:.1f} FPS | ETA: {format_time(eta)}")
            last_progress_time = time.time()
    
    cap.release()
    
    total_elapsed = time.time() - start_time
    print(f"\n✓ Processed in {format_time(total_elapsed)}")
    print(f"  Frames: {frame_count} total, {processed_frames} processed")
    print(f"  Detections: {len(all_crops)} persons, {len(ball_crops)} ball crops")
    print(f"  Average speed: {processed_frames/total_elapsed:.2f} FPS")
    
    # Save cache with ball metadata
    print(f"\nSaving cache...")
    with open(EMBEDDINGS_CACHE, 'wb') as f:
        pickle.dump({
            'crops': all_crops, 'embeddings': all_embeddings, 
            'ball_crops': ball_crops, 'ball_detections': ball_detections,  
            'video_width': video_width, 'video_height': video_height, 
            'detection_metadata': detection_metadata
        }, f)
    print(f"✓ Cache saved\n")

Processing video...


Video: 7:50 (14109 frames @ 30.0 FPS)
Sampling: Every 2th frame (~7054 frames to process)


✓ Processed in 133:41
  Frames: 14109 total, 7055 processed
  Detections: 48645 persons, 1451 ball crops
  Average speed: 0.88 FPS

Saving cache...
✓ Cache saved



In [7]:
# === 4. CLUSTERING (UMAP + HDBSCAN) ===

embeddings_array = np.array(all_embeddings)
print(f"Clustering {len(embeddings_array)} detections ({embeddings_array.shape[1]}-dim)...")

# Standardize
scaler = StandardScaler()
embeddings_scaled = scaler.fit_transform(embeddings_array)

# UMAP 4D reduction
umap_reducer = umap.UMAP(n_components=4, n_neighbors=15, min_dist=0.1, random_state=42)
embeddings_4d = umap_reducer.fit_transform(embeddings_scaled)
embeddings_3d = embeddings_4d[:, :3]
print(f"✓ UMAP 4D → 3D")

# HDBSCAN clustering
clusterer = hdbscan.HDBSCAN(min_cluster_size=50, min_samples=50, cluster_selection_epsilon=0.05)
team_labels = clusterer.fit_predict(embeddings_4d)

unique_labels = np.unique(team_labels)
n_clusters = len(unique_labels[unique_labels >= 0])
n_noise = np.sum(team_labels == -1)

print(f"✓ Found {n_clusters} clusters, {n_noise} outliers ({n_noise/len(team_labels)*100:.1f}%)")

# Analyze clusters
cluster_info = []
for label in unique_labels:
    if label == -1:
        continue
    
    cluster_mask = team_labels == label
    cluster_size = np.sum(cluster_mask)
    
    cluster_bboxes = [detection_metadata[i]['bbox'] for i in range(len(team_labels)) if team_labels[i] == label]
    avg_x_position = np.mean([(bbox[0] + bbox[2])/2 for bbox in cluster_bboxes])
    
    cluster_info.append({
        'id': label,
        'size': cluster_size,
        'avg_x_pos': avg_x_position
    })

cluster_info.sort(key=lambda x: x['size'], reverse=True)

# Team mapping
team_mapping = {}
if len(cluster_info) >= 2:
    team_mapping[cluster_info[0]['id']] = "TEAM 1"
    team_mapping[cluster_info[1]['id']] = "TEAM 2"
    print(f"\n  TEAM 1: {cluster_info[0]['size']} detections")
    print(f"  TEAM 2: {cluster_info[1]['size']} detections")

if len(cluster_info) >= 3 and cluster_info[2]['size'] > 100:
    team_mapping[cluster_info[2]['id']] = "REFEREE"
    print(f"  REFEREE: {cluster_info[2]['size']} detections")

for info in cluster_info:
    if info['id'] not in team_mapping:
        x_pos = "L" if info['avg_x_pos'] < video_width * 0.3 else ("R" if info['avg_x_pos'] > video_width * 0.7 else "C")
        team_mapping[info['id']] = f"OTHER (C{info['id']}, {x_pos})"

print(f"\n✓ Clustering complete!\n")

Clustering 48645 detections (2048-dim)...
✓ UMAP 4D → 3D
✓ Found 19 clusters, 386 outliers (0.8%)

  TEAM 1: 21072 detections
  TEAM 2: 20742 detections
  REFEREE: 2909 detections

✓ Clustering complete!



In [15]:
# === 5. INTERACTIVE 3D VISUALIZATION ===

import plotly.graph_objects as go
import numpy as np

# Create basic plotly figure (not FigureWidget to avoid anywidget issues)
fig = go.Figure()

color_palette_rgb = {
    'blue': 'rgb(0, 100, 255)', 'yellow': 'rgb(255, 220, 0)', 'white': 'rgb(220, 220, 220)',
    'red': 'rgb(255, 50, 50)', 'orange': 'rgb(255, 165, 0)', 'cyan': 'rgb(0, 255, 255)',
    'magenta': 'rgb(255, 0, 255)', 'pink': 'rgb(255, 192, 203)', 'brown': 'rgb(165, 42, 42)',
    'purple': 'rgb(160, 0, 200)', 'lightgray': 'rgb(180, 180, 180)'
}
color_map = ['blue', 'yellow', 'white', 'red', 'orange', 'cyan', 'magenta', 'pink', 'brown', 'purple']

# Outliers
noise_mask = team_labels == -1
if np.any(noise_mask):
    fig.add_trace(go.Scatter3d(
        x=embeddings_3d[noise_mask, 0], 
        y=embeddings_3d[noise_mask, 1], 
        z=embeddings_3d[noise_mask, 2],
        mode='markers', 
        name=f'Outliers ({np.sum(noise_mask)})',
        marker=dict(size=2, color='lightgray', opacity=0.2, line=dict(width=0)),
        text=[f'Outlier #{i}' for i in np.where(noise_mask)[0]],
        hovertemplate='%{text}<br>UMAP: (%{x:.2f}, %{y:.2f}, %{z:.2f})<extra></extra>'
    ))

# Clusters
for i, info in enumerate(cluster_info):
    cluster_mask = team_labels == info['id']
    cluster_indices = np.where(cluster_mask)[0]
    
    fig.add_trace(go.Scatter3d(
        x=embeddings_3d[cluster_mask, 0], 
        y=embeddings_3d[cluster_mask, 1], 
        z=embeddings_3d[cluster_mask, 2],
        mode='markers', 
        name=f'{team_mapping[info["id"]]} ({info["size"]})',
        marker=dict(
            size=4 if i < 2 else 3,
            color=color_palette_rgb[color_map[i % len(color_map)]],
            opacity=0.8 if i < 2 else 0.7,
            line=dict(width=0)
        ),
        text=[f'{team_mapping[info["id"]]} #{idx}' for idx in cluster_indices],
        hovertemplate='%{text}<br>UMAP: (%{x:.2f}, %{y:.2f}, %{z:.2f})<extra></extra>'
    ))

fig.update_layout(
    title='3D UMAP Clustering - Football Player Detection',
    scene=dict(
        xaxis_title='UMAP 1', 
        yaxis_title='UMAP 2', 
        zaxis_title='UMAP 3',
        bgcolor='rgb(240, 240, 240)',
        xaxis=dict(gridcolor='white', showbackground=True),
        yaxis=dict(gridcolor='white', showbackground=True),
        zaxis=dict(gridcolor='white', showbackground=True)
    ),
    width=1000, 
    height=700,
    hovermode='closest',
    showlegend=True,
    legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01,
        bgcolor='rgba(255, 255, 255, 0.8)'
    )
)

fig.show()

print("✓ 3D plot displayed!")
print("\n💡 Tip: Click and drag to rotate, scroll to zoom")
print(f"   Total points: {len(all_crops)} detections")
print(f"   Clusters: {len(cluster_info)} teams/groups\n")

✓ 3D plot displayed!

💡 Tip: Click and drag to rotate, scroll to zoom
   Total points: 48645 detections
   Clusters: 19 teams/groups



In [9]:
# === 6. COLOR EXTRACTION ===

def extract_dominant_color(crops, num_samples=30):
    """Extract dominant color from center 60% of crops"""
    if len(crops) == 0:
        return (200, 200, 200)
    
    sampled_crops = [crops[i] for i in np.random.choice(len(crops), min(num_samples, len(crops)), replace=False)]
    all_pixels = []
    
    for crop in sampled_crops:
        crop_rgb = cv2.cvtColor(cv2.resize(crop, (64, 64)), cv2.COLOR_BGR2RGB)
        h, w = crop_rgb.shape[:2]
        
        # Center 60%
        center_crop = crop_rgb[int(h*0.2):int(h*0.8), int(w*0.2):int(w*0.8)]
        pixels = center_crop.reshape(-1, 3)
        
        # Filter green (pitch) and extremes
        is_green = (pixels[:, 1] > pixels[:, 0] + 20) & (pixels[:, 1] > pixels[:, 2] + 20)
        pixels = pixels[~is_green]
        brightness = pixels.mean(axis=1)
        pixels = pixels[(brightness > 30) & (brightness < 225)]
        
        if len(pixels) > 0:
            all_pixels.append(pixels)
    
    if len(all_pixels) == 0:
        return (200, 200, 200)
    
    all_pixels = np.vstack(all_pixels)
    
    # KMeans for dominant color
    kmeans = KMeans(n_clusters=3, random_state=42, n_init=10)
    kmeans.fit(all_pixels)
    dominant_color_rgb = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))].astype(int)
    
    # Boost brightness & saturation in HSV
    hsv = cv2.cvtColor(np.uint8([[dominant_color_rgb]]), cv2.COLOR_RGB2HSV)[0][0]
    hsv[1] = np.clip(hsv[1] * 1.2, 0, 255)  # Saturation
    hsv[2] = np.clip(hsv[2] * 1.2, 0, 255)  # Brightness
    boosted_rgb = cv2.cvtColor(np.uint8([[hsv]]), cv2.COLOR_HSV2RGB)[0][0]
    
    return tuple(boosted_rgb[::-1].tolist())  # BGR

# Extract colors
cluster_colors = {}
print("Extracting dominant colors...\n")

for info in cluster_info:
    cluster_crops = [all_crops[i] for i in range(len(team_labels)) if team_labels[i] == info['id']]
    cluster_colors[info['id']] = extract_dominant_color(cluster_crops)
    print(f"  {team_mapping[info['id']]}: BGR{cluster_colors[info['id']]}")

print("\n✓ Colors extracted!\n")

Extracting dominant colors...

  TEAM 1: BGR(117, 85, 37)
  TEAM 2: BGR(248, 255, 213)
  REFEREE: BGR(34, 62, 39)
  OTHER (C6, C): BGR(50, 72, 76)
  OTHER (C8, C): BGR(53, 70, 49)
  OTHER (C18, C): BGR(47, 70, 52)
  OTHER (C1, C): BGR(192, 184, 146)
  OTHER (C11, C): BGR(52, 64, 23)
  OTHER (C0, R): BGR(38, 44, 75)
  OTHER (C3, R): BGR(20, 0, 198)
  OTHER (C16, C): BGR(58, 68, 59)
  OTHER (C5, C): BGR(89, 82, 139)
  OTHER (C4, R): BGR(103, 78, 46)
  OTHER (C9, L): BGR(116, 75, 33)
  OTHER (C2, R): BGR(50, 60, 36)
  OTHER (C13, L): BGR(240, 255, 215)
  OTHER (C17, C): BGR(45, 56, 32)
  OTHER (C7, C): BGR(138, 75, 11)
  OTHER (C15, R): BGR(45, 49, 30)

✓ Colors extracted!



In [11]:
# === 7. VIDEO ANNOTATION (ENHANCED SMOOTHING + ADVANCED BALL TRACKING) ===

NUM_BOTTOM_CLUSTERS_TO_SKIP = 3

# Filter clusters
clusters_to_draw = set(c['id'] for c in cluster_info[:-NUM_BOTTOM_CLUSTERS_TO_SKIP]) if len(cluster_info) > NUM_BOTTOM_CLUSTERS_TO_SKIP else set(c['id'] for c in cluster_info)

print(f"Drawing {len(clusters_to_draw)} clusters (skipping bottom {NUM_BOTTOM_CLUSTERS_TO_SKIP} + outliers)\n")

# Build frame mapping for players
frame_detections = {}
for i, metadata in enumerate(detection_metadata):
    if team_labels[i] != -1 and team_labels[i] in clusters_to_draw:
        frame_idx = metadata['frame_idx']
        if frame_idx not in frame_detections:
            frame_detections[frame_idx] = []
        frame_detections[frame_idx].append({'bbox': metadata['bbox'], 'cluster': team_labels[i]})

print(f"Mapped {sum(len(v) for v in frame_detections.values())} player detections\n")

# Build frame mapping for balls
ball_frame_map = {}
for ball_det in ball_detections:
    frame_idx = ball_det['frame_idx']
    if frame_idx not in ball_frame_map:
        ball_frame_map[frame_idx] = []
    ball_frame_map[frame_idx].append(ball_det)

print(f"Mapped {len(ball_detections)} ball detections across {len(ball_frame_map)} frames\n")

# Video I/O
cap = cv2.VideoCapture(VIDEO_PATH)
fps = int(cap.get(cv2.CAP_PROP_FPS))
width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

out = cv2.VideoWriter('videos/7_min_annotated.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

player_history = {}
ball_history = []
frame_idx = 0
start_time = time.time()

# ENHANCED SMOOTHING PARAMETERS
SMOOTHING_WINDOW = 10
EMA_ALPHA = 0.3
DUPLICATE_THRESHOLD = 40
IOU_THRESHOLD = 0.2
TRACK_LIFETIME = 15
CLEANUP_INTERVAL = 100

# ADVANCED BALL TRACKING PARAMETERS
BALL_PERSISTENCE_FRAMES = 8
BALL_MAX_SPEED = 150
BALL_CONF_THRESHOLD = 0.15
POSSESSION_DISTANCE_THRESHOLD = 120
BALL_SMOOTHING_WINDOW = 5

print("Annotating video...")
print(f"🎯 Player smoothing: {SMOOTHING_WINDOW}-frame EMA (alpha={EMA_ALPHA})")
print(f"✨ Duplicate prevention: {DUPLICATE_THRESHOLD}px threshold")
print(f"⏱️ Track lifetime: {TRACK_LIFETIME} frames")
print(f"📊 Frame sampling: Every {FRAME_SAMPLING_RATE}th frame")
print(f"⚽ Ball tracking: {BALL_PERSISTENCE_FRAMES}-frame persistence, {BALL_SMOOTHING_WINDOW}-frame smoothing\n")

# Progress tracking
progress_checkpoints = [int(total_frames * p / 100) for p in range(10, 100, 10)]
last_progress_time = start_time

def compute_iou(box1, box2):
    x1_1, y1_1, x2_1, y2_1 = box1
    x1_2, y1_2, x2_2, y2_2 = box2
    xi1, yi1 = max(x1_1, x1_2), max(y1_1, y1_2)
    xi2, yi2 = min(x2_1, x2_2), min(y2_1, y2_2)
    inter = max(0, xi2 - xi1) * max(0, yi2 - yi1)
    union = (x2_1 - x1_1) * (y2_1 - y1_1) + (x2_2 - x1_2) * (y2_2 - y2_2) - inter
    return inter / union if union > 0 else 0

def average_position(history, window=SMOOTHING_WINDOW):
    """Simple average of last N positions"""
    if not history:
        return None, None
    
    recent = history[-window:]
    positions = [((h['bbox'][0] + h['bbox'][2]) / 2, h['bbox'][3]) for h in recent]
    
    avg_x = sum(p[0] for p in positions) / len(positions)
    avg_y = sum(p[1] for p in positions) / len(positions)
    
    return avg_x, avg_y

def exponential_moving_average(history, alpha=EMA_ALPHA):
    """Calculate EMA position with velocity prediction"""
    if not history:
        return None, None
    
    if len(history) == 1:
        bbox = history[0]['bbox']
        return (bbox[0] + bbox[2]) / 2, bbox[3]
    
    # Start with oldest position
    bbox = history[0]['bbox']
    ema_x = (bbox[0] + bbox[2]) / 2
    ema_y = bbox[3]
    
    # Apply EMA
    for h in history[1:]:
        bbox = h['bbox']
        curr_x = (bbox[0] + bbox[2]) / 2
        curr_y = bbox[3]
        
        ema_x = alpha * curr_x + (1 - alpha) * ema_x
        ema_y = alpha * curr_y + (1 - alpha) * ema_y
    
    return ema_x, ema_y

def interpolate_position(history, current_frame, max_gap=FRAME_SAMPLING_RATE):
    """Interpolate position between detections for smoother movement"""
    if not history or len(history) < 2:
        return exponential_moving_average(history)
    
    last_det = history[-1]
    last_frame = last_det['frame']
    
    # If we have a very recent detection, use it
    if current_frame - last_frame <= 1:
        return exponential_moving_average(history)
    
    # Find two bracketing detections for interpolation
    if len(history) >= 2:
        prev_det = history[-2]
        prev_frame = prev_det['frame']
        
        # Only interpolate if gap is reasonable
        if last_frame - prev_frame <= max_gap and current_frame - last_frame <= max_gap:
            # Calculate velocity
            prev_bbox = prev_det['bbox']
            last_bbox = last_det['bbox']
            
            prev_x = (prev_bbox[0] + prev_bbox[2]) / 2
            prev_y = prev_bbox[3]
            last_x = (last_bbox[0] + last_bbox[2]) / 2
            last_y = last_bbox[3]
            
            dt = last_frame - prev_frame
            if dt > 0:
                vx = (last_x - prev_x) / dt
                vy = (last_y - prev_y) / dt
                
                # Predict forward
                time_ahead = current_frame - last_frame
                pred_x = last_x + vx * time_ahead
                pred_y = last_y + vy * time_ahead
                
                # Blend prediction with EMA for stability
                ema_x, ema_y = exponential_moving_average(history)
                blend_factor = min(time_ahead / max_gap, 0.7)  # More EMA weight as we get further from detection
                
                final_x = blend_factor * ema_x + (1 - blend_factor) * pred_x
                final_y = blend_factor * ema_y + (1 - blend_factor) * pred_y
                
                return final_x, final_y
    
    # Fallback to EMA
    return exponential_moving_average(history)

def find_nearest_player(position, player_history, current_frame, max_time_diff=3):
    """Find the player nearest to a given position"""
    min_dist = float('inf')
    nearest_player_id = None
    
    for player_id, history in player_history.items():
        if not history:
            continue
        
        last_frame = history[-1]['frame']
        if abs(current_frame - last_frame) > max_time_diff:
            continue
        
        bbox = history[-1]['bbox']
        player_x = (bbox[0] + bbox[2]) / 2
        player_y = bbox[3]  # Feet position
        
        dist = np.sqrt((player_x - position[0])**2 + (player_y - position[1])**2)
        
        if dist < min_dist:
            min_dist = dist
            nearest_player_id = player_id
    
    if min_dist < POSSESSION_DISTANCE_THRESHOLD:
        return nearest_player_id
    return None

def predict_ball_position(ball_history, current_frame):
    """Predict ball position based on recent trajectory"""
    if len(ball_history) < 2:
        return None
    
    last = ball_history[-1]
    prev = ball_history[-2]
    
    if current_frame - last['frame'] > 5:
        return None
    
    dt = last['frame'] - prev['frame']
    if dt == 0:
        return last['center']
    
    vx = (last['center'][0] - prev['center'][0]) / dt
    vy = (last['center'][1] - prev['center'][1]) / dt
    
    time_ahead = current_frame - last['frame']
    pred_x = last['center'][0] + vx * time_ahead
    pred_y = last['center'][1] + vy * time_ahead
    
    return (pred_x, pred_y)

def is_valid_ball_detection(center, ball_history, current_frame):
    """Check if detection is consistent with ball trajectory"""
    if not ball_history:
        return True
    
    predicted = predict_ball_position(ball_history, current_frame)
    if predicted is None:
        return True
    
    dist = np.sqrt((center[0] - predicted[0])**2 + (center[1] - predicted[1])**2)
    return dist < BALL_MAX_SPEED

def smooth_ball_position(ball_history, frame_idx):
    """Calculate smoothed ball position from recent history"""
    recent_positions = []
    for b in ball_history[-BALL_SMOOTHING_WINDOW:]:
        if frame_idx - b['frame'] <= BALL_PERSISTENCE_FRAMES:
            recent_positions.append(b['center'])
    
    if recent_positions:
        center_x = sum(p[0] for p in recent_positions) / len(recent_positions)
        center_y = sum(p[1] for p in recent_positions) / len(recent_positions)
        return (center_x, center_y)
    
    return None

ball_possession_player = None
model = YOLO(MODEL_PATH)

while True:
    ret, frame = cap.read()
    if not ret:
        break
    
    # PERIODIC CLEANUP - Remove dead tracks to prevent slowdown
    if frame_idx % CLEANUP_INTERVAL == 0:
        dead_pids = [pid for pid, hist in player_history.items() 
                     if not hist or frame_idx - hist[-1]['frame'] > TRACK_LIFETIME * 2]
        for pid in dead_pids:
            del player_history[pid]
        
        if frame_idx % 1000 == 0 and dead_pids:
            print(f"    Cleaned {len(dead_pids)} dead tracks, {len(player_history)} active")
    
    # Get nearest sampled frame detections for players
    nearest_frame = min((f for f in frame_detections.keys() if abs(f - frame_idx) <= FRAME_SAMPLING_RATE), 
                       key=lambda f: abs(f - frame_idx), default=None)
    
    if nearest_frame and nearest_frame in frame_detections:
        for det in frame_detections[nearest_frame]:
            bbox, cluster = det['bbox'], det['cluster']
            
            # Match to existing track using IOU
            best_id, best_iou = None, 0
            for pid, hist in player_history.items():
                if hist and frame_idx - hist[-1]['frame'] <= TRACK_LIFETIME:
                    iou = compute_iou(bbox, hist[-1]['bbox'])
                    # Bonus for same cluster
                    if hist[-1]['cluster'] == cluster:
                        iou += 0.2
                    
                    if iou > IOU_THRESHOLD and iou > best_iou:
                        best_iou, best_id = iou, pid
            
            if best_id is not None:
                player_history[best_id].append({'bbox': bbox, 'cluster': cluster, 'frame': frame_idx})
                if len(player_history[best_id]) > SMOOTHING_WINDOW + 2:
                    player_history[best_id].pop(0)
            else:
                player_history[len(player_history)] = [{'bbox': bbox, 'cluster': cluster, 'frame': frame_idx}]
    
    # Draw players with ENHANCED smoothing and interpolation
    drawn_positions = []
    
    for pid, hist in player_history.items():
        if hist and frame_idx - hist[-1]['frame'] <= TRACK_LIFETIME:
            center_x, center_y = interpolate_position(hist, frame_idx)
            
            if center_x is None:
                continue
            
            # Check for duplicates
            is_duplicate = False
            for drawn_x, drawn_y, drawn_cluster in drawn_positions:
                distance = np.sqrt((center_x - drawn_x)**2 + (center_y - drawn_y)**2)
                if distance < DUPLICATE_THRESHOLD:
                    is_duplicate = True
                    break
            
            if not is_duplicate:
                drawn_positions.append((center_x, center_y, hist[-1]['cluster']))
                color = cluster_colors.get(hist[-1]['cluster'], (255, 255, 255))
                cv2.ellipse(frame, (int(center_x), int(center_y)), (40, 12), 0, 0, 360, color, 3)
    
    # ADVANCED BALL TRACKING
    results = model(frame, conf=BALL_CONF_THRESHOLD, verbose=False, classes=[32])
    
    ball_candidates = []
    for result in results:
        boxes = result.boxes
        if boxes is not None:
            for box in boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                conf = float(box.conf)
                
                width_box = x2 - x1
                height_box = y2 - y1
                area = width_box * height_box
                
                if 30 < area < 2500 and conf > BALL_CONF_THRESHOLD:
                    ball_center = ((x1 + x2) / 2, (y1 + y2) / 2)
                    
                    if is_valid_ball_detection(ball_center, ball_history, frame_idx):
                        ball_candidates.append({
                            'center': ball_center,
                            'conf': conf,
                            'area': area
                        })
    
    # Process ball detection and possession
    if ball_candidates:
        # Ball detected - find best candidate
        best_ball = max(ball_candidates, key=lambda b: b['conf'] * (1.0 if 100 < b['area'] < 800 else 0.5))
        ball_center = best_ball['center']
        
        # Find nearest player
        nearest_player_id = find_nearest_player(ball_center, player_history, frame_idx)
        
        # Update ball history
        ball_history.append({
            'center': ball_center,
            'frame': frame_idx,
            'type': 'detected',
            'possession_player': nearest_player_id
        })
        
        # Update possession
        if nearest_player_id is not None:
            ball_possession_player = nearest_player_id
        else:
            ball_possession_player = None
        
        if len(ball_history) > 20:
            ball_history.pop(0)
    else:
        # Ball not detected - check if player has possession
        if ball_history:
            last_ball = ball_history[-1]
            frames_since_detection = frame_idx - last_ball['frame']
            
            # If ball was recently with a player, assume they still have it
            if frames_since_detection <= 15 and last_ball.get('possession_player') is not None:
                possession_player_id = last_ball['possession_player']
                
                # Check if player is still visible
                if possession_player_id in player_history:
                    player_hist = player_history[possession_player_id]
                    if player_hist and frame_idx - player_hist[-1]['frame'] <= 3:
                        # Player visible - ball is with them
                        bbox = player_hist[-1]['bbox']
                        estimated_ball_pos = ((bbox[0] + bbox[2]) / 2, bbox[3])
                        
                        ball_history.append({
                            'center': estimated_ball_pos,
                            'frame': frame_idx,
                            'type': 'possession',
                            'possession_player': possession_player_id
                        })
                        ball_possession_player = possession_player_id
                        
                        if len(ball_history) > 20:
                            ball_history.pop(0)
    
    # Draw ball with smoothing
    if ball_history:
        last_ball = ball_history[-1]
        frames_since_update = frame_idx - last_ball['frame']
        
        if frames_since_update <= BALL_PERSISTENCE_FRAMES:
            # Get smoothed position
            smoothed_pos = smooth_ball_position(ball_history, frame_idx)
            
            if smoothed_pos:
                center_x, center_y = smoothed_pos
            elif frames_since_update <= 3:
                # Try prediction
                predicted = predict_ball_position(ball_history, frame_idx)
                if predicted:
                    center_x, center_y = predicted
                else:
                    center_x, center_y = last_ball['center']
            else:
                center_x, center_y = last_ball['center']
            
            # Draw triangle
            triangle = np.array([
                [int(center_x), int(center_y - 25)],
                [int(center_x - 12), int(center_y - 8)],
                [int(center_x + 12), int(center_y - 8)]
            ])
            cv2.fillPoly(frame, [triangle], (0, 255, 0))
            cv2.polylines(frame, [triangle], True, (0, 0, 0), 2)
    
    out.write(frame)
    frame_idx += 1
    
    # Progress tracking every 10%
    if frame_idx in progress_checkpoints:
        elapsed = time.time() - start_time
        progress_pct = (frame_idx / total_frames) * 100
        fps_current = frame_idx / elapsed
        eta_seconds = (total_frames - frame_idx) / fps_current
        
        print(f"  📊 {progress_pct:.0f}% complete ({frame_idx}/{total_frames})")
        print(f"     ⏱️  Elapsed: {format_time(elapsed)} | ETA: {format_time(eta_seconds)} | Speed: {fps_current:.1f} FPS")
        last_progress_time = time.time()

cap.release()
out.release()

total_time = time.time() - start_time
print(f"\n✓ Video saved: videos/7_min_annotated.mp4")
print(f"  Time: {format_time(total_time)}")
print(f"  Average FPS: {total_frames/total_time:.1f}")

# Ball tracking statistics
detected_frames = sum(1 for b in ball_history if b.get('type') == 'detected')
possession_frames = sum(1 for b in ball_history if b.get('type') == 'possession')

print(f"\n⚽ Ball Tracking Statistics:")
print(f"  Detected: {detected_frames} frames ({detected_frames/frame_idx*100:.1f}%)")
print(f"  In possession: {possession_frames} frames ({possession_frames/frame_idx*100:.1f}%)")
print(f"  Total visible: {len(ball_history)} frames ({len(ball_history)/frame_idx*100:.1f}%)")

Drawing 16 clusters (skipping bottom 3 + outliers)

Mapped 48103 player detections

Mapped 1451 ball detections across 1402 frames

Annotating video...
🎯 Player smoothing: 10-frame EMA (alpha=0.3)
✨ Duplicate prevention: 40px threshold
⏱️ Track lifetime: 15 frames
📊 Frame sampling: Every 2th frame
⚽ Ball tracking: 8-frame persistence, 5-frame smoothing

    Cleaned 2 dead tracks, 5 active
  📊 10% complete (1410/14109)
     ⏱️  Elapsed: 3:07 | ETA: 28:08 | Speed: 7.5 FPS
  📊 20% complete (2821/14109)
     ⏱️  Elapsed: 6:51 | ETA: 27:27 | Speed: 6.8 FPS
    Cleaned 25 dead tracks, 7 active
  📊 30% complete (4232/14109)
     ⏱️  Elapsed: 11:30 | ETA: 26:52 | Speed: 6.1 FPS
  📊 40% complete (5643/14109)
     ⏱️  Elapsed: 16:20 | ETA: 24:30 | Speed: 5.8 FPS
    Cleaned 3 dead tracks, 5 active
  📊 50% complete (7054/14109)
     ⏱️  Elapsed: 21:09 | ETA: 21:10 | Speed: 5.6 FPS
  📊 60% complete (8465/14109)
     ⏱️  Elapsed: 26:03 | ETA: 17:22 | Speed: 5.4 FPS
    Cleaned 11 dead tracks, 0 act

In [None]:
# === 5.5 SAMPLE CROPS FROM EACH CLUSTER ===

print("Sample crops from each cluster:\n")

# Calculate grid layout
num_clusters = len(cluster_info)
samples_per_cluster = 8

fig, axes = plt.subplots(num_clusters, samples_per_cluster, figsize=(20, num_clusters * 2.5))

# Handle single cluster case
if num_clusters == 1:
    axes = axes.reshape(1, -1)

for cluster_idx, info in enumerate(cluster_info):
    cluster_mask = team_labels == info['id']
    cluster_indices = np.where(cluster_mask)[0]
    
    # Sample random crops from this cluster
    sample_size = min(samples_per_cluster, len(cluster_indices))
    sample_indices = np.random.choice(cluster_indices, sample_size, replace=False)
    
    # Add row label
    row_label = f"{team_mapping[info['id']]}\n({info['size']} detections)"
    
    for col_idx in range(samples_per_cluster):
        ax = axes[cluster_idx, col_idx] if num_clusters > 1 else axes[col_idx]
        
        if col_idx < sample_size:
            crop_idx = sample_indices[col_idx]
            crop = all_crops[crop_idx]
            crop_rgb = cv2.cvtColor(cv2.resize(crop, (128, 128)), cv2.COLOR_BGR2RGB)
            
            ax.imshow(crop_rgb)
            ax.set_title(f'Det #{crop_idx}', fontsize=8)
        
        # Add row label on first column
        if col_idx == 0:
            ax.set_ylabel(row_label, fontsize=10, fontweight='bold')
        
        ax.axis('off')

plt.suptitle('Sample Player Crops per Cluster', fontsize=14, y=0.998)
plt.tight_layout()
plt.show()

print("✓ Sample crops displayed!\n")