# 01 roboflow track

**in**: match video (mp4)  
**out**: `data/track/{match}_track.csv`

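tracking csv schema (one row per tracked detection per processed frame; `x`, `y` are pitch coordinates in meters; `team` is 0/1 for the two clustered player teams, 2 = goalkeeper, 3 = referee, 4 = ball/other):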
```
frame,t_sec,track_id,cls,team,x,y
0,0.00,1,player,0,45.2,32.1
```

In [None]:
# cell 1: imports + paths
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import supervision as sv
import torch
from sklearn.cluster import KMeans
from ultralytics import YOLO

from sports.common.view import ViewTransformer
from sports.configs.soccer import SoccerPitchConfiguration


class SimpleTeamClassifier:
    """Two-team classifier that clusters the mean colour of each crop's centre.

    ``fit`` runs a 2-cluster KMeans over one colour vector per crop and
    ``predict`` assigns every crop to one of the two clusters. Until ``fit``
    has seen at least two usable (non-empty) crops the classifier stays
    unfitted and ``predict`` returns team 0 for every crop.
    """

    def __init__(self):
        # KMeans model created by fit(); None while the classifier is unfitted.
        self.kmeans = None

    @staticmethod
    def _center_color(crop):
        """Return the per-channel mean of the central region of ``crop``.

        A quarter of the height and width (rounded down) is trimmed from each
        side. The far edge is computed as ``size - margin`` rather than
        ``3 * margin`` so the region stays centred for any size and is never
        empty for a crop that has pixels (crops under 4 px use the whole crop).

        Returns ``None`` when ``crop`` contains no pixels.
        """
        if crop.size == 0:
            return None
        height, width = crop.shape[:2]
        top, left = height // 4, width // 4
        region = crop[top:height - top, left:width - left]
        return region.mean(axis=(0, 1))

    def fit(self, crops):
        """Fit the two colour clusters on an iterable of image crops.

        Crops without pixels are ignored. If fewer than two usable crops
        remain, nothing is fitted and ``self.kmeans`` is left unchanged.
        """
        colors = [
            color
            for color in (self._center_color(crop) for crop in crops)
            if color is not None
        ]

        if len(colors) >= 2:
            self.kmeans = KMeans(n_clusters=2, random_state=42, n_init=10)
            self.kmeans.fit(np.array(colors))

    def predict(self, crops):
        """Predict the cluster index for each crop.

        Args:
            crops: sequence of image arrays.

        Returns:
            Integer array with one entry per crop, in input order. Crops
            without pixels get 0; if the classifier is unfitted, all entries
            are 0.
        """
        teams = np.zeros(len(crops), dtype=int)
        if self.kmeans is None:
            return teams

        colors = [self._center_color(crop) for crop in crops]
        usable = [i for i, color in enumerate(colors) if color is not None]
        if usable:
            # one batched predict call instead of one call per crop
            features = np.array([colors[i] for i in usable]).reshape(len(usable), -1)
            teams[usable] = self.kmeans.predict(features)

        return teams


# paths - use the alphabetically first video in the directory
VIDEO_DIR = Path('../data/video')
# glob() yields entries in filesystem-dependent order; sorting makes the
# selected video reproducible across runs and machines
video_files = sorted(VIDEO_DIR.glob('*.mp4'))
if not video_files:
    raise FileNotFoundError(f'no video found in {VIDEO_DIR}')
VIDEO_PATH = video_files[0]

MODEL_DIR = Path('../data/models')
TRACK_DIR = Path('../data/track')
TRACK_DIR.mkdir(parents=True, exist_ok=True)

# output csv is named after the video file (without extension)
MATCH_NAME = VIDEO_PATH.stem
TRACK_CSV = TRACK_DIR / f'{MATCH_NAME}_track.csv'

# device selection: prefer CUDA, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'
print(f'using device: {DEVICE}')
print(f'video: {VIDEO_PATH}')

In [None]:
# cell 2: load yolo models (player + pitch keypoints)

# weights are not fetched automatically; if a file is missing, download it from roboflow:
# https://universe.roboflow.com/roboflow-jvuqo/football-players-detection-3zvbc
# https://universe.roboflow.com/roboflow-jvuqo/football-field-detection-f07vi

PL_MODEL_PATH = MODEL_DIR / 'football-player-detection.pt'
KP_MODEL_PATH = MODEL_DIR / 'football-pitch-detection.pt'

# fail fast, before loading anything, if either weights file is absent
if not PL_MODEL_PATH.exists():
    raise FileNotFoundError(f'player model not found: {PL_MODEL_PATH}')
if not KP_MODEL_PATH.exists():
    raise FileNotFoundError(f'pitch model not found: {KP_MODEL_PATH}')

# load both models, then place each on the selected device
pl_model = YOLO(str(PL_MODEL_PATH))
kp_model = YOLO(str(KP_MODEL_PATH))
for _model in (pl_model, kp_model):
    _model.to(DEVICE)

# class ids from roboflow soccer models: ball=0, goalkeeper=1, player=2, referee=3
CLS_BALL, CLS_GK, CLS_PLAYER, CLS_REF = range(4)

# class id -> value written to the `cls` column of the tracking csv
CLS_NAMES = {
    CLS_BALL: 'ball',
    CLS_GK: 'goalkeeper',
    CLS_PLAYER: 'player',
    CLS_REF: 'referee',
}

print(f'loaded models: player={PL_MODEL_PATH.name}, pitch={KP_MODEL_PATH.name}')

In [None]:
# cell 3: setup ByteTrack + team classifier

# pitch configuration (12000 x 7000 cm = 120m x 70m)
pitch_cfg = SoccerPitchConfiguration()
PITCH_LENGTH_CM = pitch_cfg.length  # 12000 cm
PITCH_WIDTH_CM = pitch_cfg.width    # 7000 cm

# normalize to meters
PITCH_X = PITCH_LENGTH_CM / 100  # 120m
PITCH_Y = PITCH_WIDTH_CM / 100   # 70m
print(f'pitch dimensions: {PITCH_X}m x {PITCH_Y}m')

# tracker
tracker = sv.ByteTrack(minimum_consecutive_frames=3)

# collect crops for team classifier
print('collecting player crops for team classification...')
crops = []
crop_limit = 300       # stop sampling once this many crops are collected
CROP_STRIDE = 60       # sample one frame out of every 60
CROPS_PER_FRAME = 5    # cap per frame so no single frame dominates the sample

for frame in sv.get_video_frames_generator(str(VIDEO_PATH), stride=CROP_STRIDE):
    det = sv.Detections.from_ultralytics(pl_model(frame, imgsz=1280, verbose=False)[0])
    # only outfield players are used to learn the two team colours
    player_det = det[det.class_id == CLS_PLAYER]
    for xyxy in player_det.xyxy[:CROPS_PER_FRAME]:
        crop = sv.crop_image(frame, xyxy)
        if crop.size > 0:
            crops.append(crop)
    # the limit is checked once per frame, so it can be overshot by a few crops
    if len(crops) >= crop_limit:
        break

print(f'collected {len(crops)} crops')

# fit team classifier
team_clf = SimpleTeamClassifier()
team_clf.fit(crops)
# fit() silently leaves the classifier unfitted when it gets fewer than two
# usable crops; report that instead of claiming success
if team_clf.kmeans is None:
    print('warning: too few usable crops - team classifier NOT fitted, all players will be team 0')
else:
    print('team classifier fitted')

In [None]:
# cell 4: process video -> tracking csv

video_info = sv.VideoInfo.from_video_path(str(VIDEO_PATH))
fps = video_info.fps
total_frames = video_info.total_frames
print(f'video: {total_frames} frames @ {fps} fps ({total_frames/fps:.1f}s)')

# process every 2nd frame for speed
STRIDE = 2

# values written to the `team` column for non-player classes
# (players get the classifier's cluster id)
TEAM_GOALKEEPER = 2
TEAM_REFEREE = 3
TEAM_OTHER = 4  # ball or any other class

# pitch template vertices (cm); the keypoint mask below indexes this array, so
# the pitch model's keypoint order is assumed to match pitch_cfg.vertices
pitch_vertices = np.array(pitch_cfg.vertices).astype(np.float32)

rows = []
frame_idx = 0
skipped = 0  # frames dropped because no usable homography could be built

for step, frame in enumerate(sv.get_video_frames_generator(str(VIDEO_PATH), stride=STRIDE)):
    # index of this frame in the source video; derived from the loop counter
    # so it stays correct on every early `continue` below
    frame_idx = step * STRIDE

    # pitch keypoints
    kp_result = kp_model(frame, verbose=False)[0]
    kps = sv.KeyPoints.from_ultralytics(kp_result)

    if len(kps.xy) == 0:
        skipped += 1
        continue

    # keep keypoints whose pixel coordinates are both > 1;
    # need at least 4 of them for a homography
    mask = (kps.xy[0][:, 0] > 1) & (kps.xy[0][:, 1] > 1)
    if mask.sum() < 4:
        skipped += 1
        continue

    # view transformer (pixel -> pitch coords in cm)
    try:
        vtf = ViewTransformer(
            source=kps.xy[0][mask].astype(np.float32),
            target=pitch_vertices[mask],
        )
    except (ValueError, cv2.error):
        # homography could not be computed for this frame; other errors
        # (e.g. a keypoint/vertex count mismatch) are real bugs and propagate
        skipped += 1
        continue

    # player detection + tracking
    det = sv.Detections.from_ultralytics(pl_model(frame, imgsz=1280, verbose=False)[0])
    det = tracker.update_with_detections(det)

    if len(det) == 0:
        continue

    # team classification for players
    player_mask = det.class_id == CLS_PLAYER
    player_crops = [sv.crop_image(frame, xyxy) for xyxy in det.xyxy[player_mask]]
    player_teams = team_clf.predict(player_crops) if len(player_crops) > 0 else np.array([])

    # get bottom-center anchor points
    xy = det.get_anchors_coordinates(anchor=sv.Position.BOTTOM_CENTER)
    pitch_xy_cm = vtf.transform_points(points=xy)

    # convert cm to meters
    pitch_xy = pitch_xy_cm / 100.0

    # team id per detection: start from "other", then overwrite per class;
    # player_teams is in the same order as the detections selected by player_mask
    team_ids = np.full(len(det), TEAM_OTHER, dtype=int)
    team_ids[det.class_id == CLS_GK] = TEAM_GOALKEEPER
    team_ids[det.class_id == CLS_REF] = TEAM_REFEREE
    team_ids[player_mask] = player_teams

    # append rows
    t_sec = frame_idx / fps
    for tid, cls_id, team_id, pos in zip(det.tracker_id, det.class_id, team_ids, pitch_xy):
        if tid is None:
            continue
        rows.append({
            'frame': frame_idx,
            't_sec': round(t_sec, 3),
            'track_id': int(tid),
            'cls': CLS_NAMES.get(cls_id, 'unknown'),
            'team': int(team_id),
            'x': round(float(pos[0]), 2),  # meters
            'y': round(float(pos[1]), 2),  # meters
        })

    # progress (count of source frames consumed so far, including this one)
    frames_done = frame_idx + STRIDE
    if frames_done % 200 == 0:
        print(f'processed {frames_done}/{total_frames} frames, {len(rows)} detections')

print(f'done: {len(rows)} total detections, {skipped} frames skipped')

In [None]:
# cell 5: save csv + sanity plot

import matplotlib.pyplot as plt
from mplsoccer import Pitch

# save tracking csv; explicit columns keep the csv header and the column
# lookups below valid even when no detections were recorded
TRACK_COLUMNS = ['frame', 't_sec', 'track_id', 'cls', 'team', 'x', 'y']
track_df = pd.DataFrame(rows, columns=TRACK_COLUMNS)
track_df.to_csv(TRACK_CSV, index=False)
print(f'saved: {TRACK_CSV}')
print(f'shape: {track_df.shape}')
print(track_df.head(10))

if track_df.empty:
    # nothing to plot or summarise; avoid indexing into an empty frame list
    print('no detections recorded - skipping sanity plot and stats')
else:
    # sanity check: plot positions on pitch (120m x 70m)
    pitch = Pitch(pitch_type='custom', pitch_length=120, pitch_width=70, line_color='white')
    fig, ax = pitch.draw(figsize=(12, 8))
    fig.set_facecolor('#1a1a1a')

    # sample the middle entry of the list of frames that have detections
    frames = track_df['frame'].unique()
    mid_frame = frames[len(frames)//2]
    sample = track_df[track_df['frame'] == mid_frame]

    # marker colour per `team` value (0/1 teams, 2 goalkeeper, 3 referee, 4 ball/other)
    colors = {0: 'red', 1: 'blue', 2: 'yellow', 3: 'black', 4: 'white'}
    for row in sample.itertuples(index=False):
        is_ball = row.cls == 'ball'
        ax.scatter(
            row.x,
            row.y,
            c=colors.get(row.team, 'gray'),
            s=150 if is_ball else 100,
            marker='s' if is_ball else 'o',
            edgecolors='white',
            linewidths=1,
        )

    ax.set_title(f'frame {mid_frame} - team positions', color='white', fontsize=14)
    plt.tight_layout()
    plt.show()

    # stats
    print('\ntracking stats:')
    print(f"  frames: {track_df['frame'].nunique()}")
    print(f"  unique tracks: {track_df['track_id'].nunique()}")
    print("  class distribution:")
    print(track_df['cls'].value_counts())
    print("  team distribution:")
    print(track_df['team'].value_counts())
    print("\n  position ranges:")
    print(f"    x: {track_df['x'].min():.1f} to {track_df['x'].max():.1f} m")
    print(f"    y: {track_df['y'].min():.1f} to {track_df['y'].max():.1f} m")

## verification checklist

- [ ] tracking csv has >1000 rows for 1min video
- [ ] positions plot looks sensible on pitch
- [ ] ball track_id persists through passes
- [ ] team classification mostly correct