# Tennis TrackNet 2x — Colab Training

Train **TrackNet2x** (576x1024 resolution) on the tennis dataset with GPU optimizations.

**Optimizations enabled on Colab:**
- Mixed precision (FP16) — ~2x speedup
- torch.compile — ~1.3x speedup
- Larger batch size (8-16 vs 2-4 locally)
- Precomputed frames for fast data loading

**Requirements:** GPU runtime (T4 or better), Google Drive mounted

## 1. Setup

In [None]:
# Check GPU
!nvidia-smi --query-gpu=name,memory.total --format=csv,noheader

import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    mem_gb = torch.cuda.get_device_properties(0).total_mem / 1e9
    print(f'Memory: {mem_gb:.1f} GB')
    # Recommend batch size based on GPU memory
    if mem_gb >= 40:
        print('Recommended batch_size: 16')
    elif mem_gb >= 15:
        print('Recommended batch_size: 8')
    else:
        print('Recommended batch_size: 4')

In [None]:
# Mount Google Drive (for persistent storage across sessions)
from google.colab import drive
drive.mount('/content/drive')

# Persistent storage dir
DRIVE_DIR = '/content/drive/MyDrive/tennis-tracknet'
!mkdir -p {DRIVE_DIR}

In [None]:
# Clone private repo (will prompt for GitHub auth)
import os
if not os.path.exists('/content/tennis-tracknet'):
    !git clone https://github.com/smyng/tennis-tracknet.git /content/tennis-tracknet
else:
    !cd /content/tennis-tracknet && git pull

os.chdir('/content/tennis-tracknet')
!pwd

In [None]:
# Install dependencies
# -q keeps pip output quiet. tqdm is imported by the median-generation cell and
# tensorboard is loaded by the TensorBoard cell of this notebook.
# NOTE(review): `parse` is not imported anywhere in this notebook - presumably
# the repo scripts need it; confirm.
!pip install -q parse tqdm tensorboard

## 2. Dataset

Downloads the TrackNet v1 tennis dataset and converts it to TrackNetV3 format.
Converted data is cached in Google Drive so you only do this once.

In [None]:
import os

DATA_DRIVE = f'{DRIVE_DIR}/data'
DATA_LOCAL = '/content/tennis-tracknet/data'

# Check if converted data already exists in Drive
if os.path.exists(f'{DATA_DRIVE}/train/match1'):
    print('Converted dataset found in Drive, symlinking...')
    !rm -rf {DATA_LOCAL}
    !ln -s {DATA_DRIVE} {DATA_LOCAL}
    !ls {DATA_LOCAL}/train/
    print('Done.')
else:
    print('No converted dataset in Drive. Will download and convert.')
    print('This takes ~15-20 min the first time.')

In [None]:
# Download raw tennis dataset (skip if data already linked above)
import os
if not os.path.exists(f'{DATA_LOCAL}/train/match1'):
    !pip install -q gdown
    
    RAW_DIR = '/content/raw-tennis-dataset'
    !mkdir -p {RAW_DIR}
    
    # Download from the TrackNet v1 dataset Google Drive
    # Folder: https://drive.google.com/drive/folders/11r0RUaQHX7I3ANkaYG4jOxXK1OYo01Ut
    import gdown
    gdown.download_folder(
        'https://drive.google.com/drive/folders/11r0RUaQHX7I3ANkaYG4jOxXK1OYo01Ut',
        output=RAW_DIR, quiet=False
    )
    
    # The dataset is inside Dataset.zip — unzip it
    import zipfile
    zip_path = os.path.join(RAW_DIR, 'Dataset.zip')
    if os.path.exists(zip_path):
        print('Extracting Dataset.zip...')
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(RAW_DIR)
        os.remove(zip_path)
    
    # Find the directory containing game1/, game2/, etc.
    DATASET_DIR = RAW_DIR
    for candidate in [os.path.join(RAW_DIR, 'Dataset'), RAW_DIR]:
        if os.path.exists(os.path.join(candidate, 'game1')):
            DATASET_DIR = candidate
            break
    
    print(f'Dataset directory: {DATASET_DIR}')
    !ls {DATASET_DIR}/

In [None]:
# Convert dataset and save to Drive for persistence
import os
if not os.path.exists(f'{DATA_LOCAL}/train/match1'):
    # Find the dataset directory (set by previous cell, or detect it)
    if 'DATASET_DIR' not in dir():
        RAW_DIR = '/content/raw-tennis-dataset'
        DATASET_DIR = os.path.join(RAW_DIR, 'Dataset') if os.path.exists(os.path.join(RAW_DIR, 'Dataset', 'game1')) else RAW_DIR

    !python scripts/convert_tennis_dataset.py \
        --input {DATASET_DIR} \
        --output {DATA_DRIVE} \
        --test-games 9 10 \
        --verbose
    
    # Symlink Drive data into repo
    !rm -rf {DATA_LOCAL}
    !ln -s {DATA_DRIVE} {DATA_LOCAL}
    
    print('Conversion complete.')
    !ls {DATA_LOCAL}/train/

In [None]:
# Generate median.npz files (required by dataset.py for background subtraction)
# Saved to Drive, so this only runs once
import os, cv2, numpy as np
from pathlib import Path
from tqdm import tqdm

data_dir = DATA_DRIVE


def _iter_matches(root):
    """Yield (split, match_dir, rally_dirs) for every match that has a frame/ folder.

    Only the 'train' and 'test' splits are scanned; a missing split is skipped.
    rally_dirs is the sorted list of sub-directories of <match>/frame.
    """
    for split in ['train', 'test']:
        split_dir = Path(root) / split
        if not split_dir.exists():
            continue
        for match_dir in sorted(split_dir.iterdir()):
            frame_root = match_dir / 'frame'
            if not match_dir.is_dir() or not frame_root.exists():
                continue
            yield split, match_dir, [d for d in sorted(frame_root.iterdir()) if d.is_dir()]


def _rally_median(rally_dir):
    """Return the per-pixel median over up to 50 evenly spaced PNG frames.

    Returns None when the rally has no PNG frames or none of them can be read.
    """
    frames = sorted(rally_dir.glob('*.png'))
    if not frames:
        return None
    step = max(1, len(frames) // 50)
    imgs = []
    for frame_path in frames[::step][:50]:
        img = cv2.imread(str(frame_path))
        if img is None:
            # cv2.imread signals an unreadable/corrupt file by returning None
            # instead of raising; skip it rather than crash on the slice below.
            print(f'WARNING: could not read {frame_path}, skipping')
            continue
        imgs.append(img[..., ::-1])  # reverse channel order (BGR -> RGB)
    if not imgs:
        return None
    return np.median(np.array(imgs), axis=0)


# Generation is needed when any rally lacks its median.npz, or when a match
# with rallies lacks the match-level median.npz (e.g. after an interrupted run).
needs_generation = any(
    (bool(rally_dirs) and not (match_dir / 'median.npz').exists())
    or any(not (d / 'median.npz').exists() for d in rally_dirs)
    for _, match_dir, rally_dirs in _iter_matches(data_dir)
)

if needs_generation:
    print('Generating median.npz files (one-time)...')
    for split, match_dir, rally_dirs in _iter_matches(data_dir):
        rally_medians = []
        for rally_dir in tqdm(rally_dirs, desc=f'{split}/{match_dir.name}'):
            median_file = rally_dir / 'median.npz'
            if median_file.exists():
                # Reuse the cached rally median; close the archive once read
                with np.load(str(median_file)) as cached:
                    rally_medians.append(cached['median'])
                continue
            median = _rally_median(rally_dir)
            if median is None:
                continue
            np.savez(str(median_file), median=median)
            rally_medians.append(median)
        # Match-level median: the median of the rally medians
        match_median = match_dir / 'median.npz'
        if not match_median.exists() and rally_medians:
            median = np.median(np.array(rally_medians), axis=0)
            np.savez(str(match_median), median=median)
    print('Done.')
else:
    print('median.npz files already exist.')

## 3. Pretrained Weights

Download the original TrackNetV3 badminton (shuttlecock) checkpoint.

In [None]:
import os

CKPT_DIR = '/content/tennis-tracknet/ckpts'
os.makedirs(CKPT_DIR, exist_ok=True)

if not os.path.exists(f'{CKPT_DIR}/TrackNet_best.pt'):
    !pip install -q gdown
    import gdown
    
    # Original TrackNetV3 checkpoints
    # https://drive.google.com/file/d/1CfzE87a0f6LhBp0kniSl1-89zaLCZ8cA/view
    gdown.download(
        'https://drive.google.com/uc?id=1CfzE87a0f6LhBp0kniSl1-89zaLCZ8cA',
        output='/content/TrackNetV3_ckpts.zip', quiet=False
    )
    !cd /content && unzip -o TrackNetV3_ckpts.zip -d /content/ckpts_tmp/
    !rm /content/TrackNetV3_ckpts.zip
    
    # The zip contains a nested ckpts/ folder — flatten it
    import glob, shutil
    for pt_file in glob.glob('/content/ckpts_tmp/**/*.pt', recursive=True):
        shutil.move(pt_file, CKPT_DIR)
    !rm -rf /content/ckpts_tmp

print('Pretrained checkpoints:')
!ls -lh {CKPT_DIR}/*.pt

# Verify checkpoint
import torch
ckpt = torch.load(f'{CKPT_DIR}/TrackNet_best.pt', map_location='cpu', weights_only=False)
print(f"\nPretrained model: epoch {ckpt['epoch']}, bg_mode='{ckpt['param_dict']['bg_mode']}'")
print(f"Original training: seq_len={ckpt['param_dict']['seq_len']}, batch_size={ckpt['param_dict']['batch_size']}")

## 4. Precompute Frames (one-time)

Precomputes resized + background-subtracted frames as `.npy` files for fast data loading.
Cached in Google Drive so this only runs once (~20 min).

In [None]:
# Precompute frames (skip if already done)
# Since data/ is symlinked to Drive, output goes to Drive automatically
import os, glob

PRECOMPUTE_DIR = os.path.join(DATA_LOCAL, 'precomputed', 'subtract_concat_576x1024')

existing = glob.glob(os.path.join(PRECOMPUTE_DIR, '*.npy'))
if len(existing) > 90:
    print(f'Precomputed frames found: {len(existing)} files — skipping')
else:
    print(f'Precomputing frames ({len(existing)} found, need ~95)...')
    !python precompute_frames.py \
        --data_dir {DATA_LOCAL} \
        --bg_mode subtract_concat \
        --height 576 --width 1024 \
        --splits train val

In [None]:
# Upload local checkpoint to resume training on Colab
# Run this cell, then select your TrackNet2x_cur.pt file from local machine
import os
from google.colab import files

EXP_NAME = 'tennis_2x_colab'
SAVE_DIR = f'{DRIVE_DIR}/exps/{EXP_NAME}'
os.makedirs(SAVE_DIR, exist_ok=True)

# Symlink exps into repo
!mkdir -p {DRIVE_DIR}/exps
!rm -rf /content/tennis-tracknet/exps
!ln -s {DRIVE_DIR}/exps /content/tennis-tracknet/exps

# Upload checkpoint
print('Upload TrackNet2x_cur.pt (and optionally TrackNet2x_best.pt):')
uploaded = files.upload()

for name, data in uploaded.items():
    dest = os.path.join(SAVE_DIR, name)
    with open(dest, 'wb') as f:
        f.write(data)
    print(f'Saved {name} to {dest} ({len(data)/1e6:.1f} MB)')

# Verify
import torch
ckpt = torch.load(os.path.join(SAVE_DIR, 'TrackNet2x_cur.pt'), map_location='cpu', weights_only=False)
print(f"\nCheckpoint: epoch {ckpt['epoch']+1}, best_val_acc={ckpt['max_val_acc']:.4f}")

## 5. Train

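**Option A:** Start fresh training from scratch.

**Option B:** Resume from a local checkpoint — upload it with the upload cell above, then use a resume cell.

The config cell below sets `EXP_NAME`, `SAVE_DIR` and `BATCH_SIZE` and reports whether a checkpoint already exists.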

In [None]:
# Training config (used by both fresh start and resume)
import os, torch

EXP_NAME = 'tennis_2x_colab'
SAVE_DIR = f'{DRIVE_DIR}/exps/{EXP_NAME}'
os.makedirs(SAVE_DIR, exist_ok=True)

# Symlink exps into repo so train.py can find them
!mkdir -p {DRIVE_DIR}/exps
!rm -rf /content/tennis-tracknet/exps
!ln -s {DRIVE_DIR}/exps /content/tennis-tracknet/exps

# Batch size: 16 for A100, 8 for T4/V100, 4 for P100
mem_gb = torch.cuda.get_device_properties(0).total_mem / 1e9
BATCH_SIZE = 16 if mem_gb >= 40 else (8 if mem_gb >= 15 else 4)

print(f'Experiment: {EXP_NAME}')
print(f'Batch size: {BATCH_SIZE} (GPU: {mem_gb:.0f} GB)')
print(f'Checkpoints: {SAVE_DIR}')

# Check for existing checkpoint
ckpt_path = os.path.join(SAVE_DIR, 'TrackNet2x_cur.pt')
if os.path.exists(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    print(f"Found checkpoint: epoch {ckpt['epoch']+1}, best_val_acc={ckpt['max_val_acc']:.4f}")
    print('Use the RESUME cell below to continue training.')
else:
    print('No checkpoint found. Use the FRESH START cell below.')

In [None]:
# Resume training (run Setup cells 1-5 first, then this)
# This variant sets EXP_NAME itself and re-creates the data/ and exps/ symlinks,
# so within this cell the only name taken from elsewhere is DRIVE_DIR.
EXP_NAME = 'tennis_2x_colab'

# Re-link Drive dirs
# (rm -rf removes whatever currently sits at the repo's data/ and exps/ paths
# before the symlinks to Drive are created)
!rm -rf /content/tennis-tracknet/data /content/tennis-tracknet/exps
!ln -s {DRIVE_DIR}/data /content/tennis-tracknet/data
!ln -s {DRIVE_DIR}/exps /content/tennis-tracknet/exps

# Unlike the other resume cell in this notebook, no --batch_size is passed here.
# NOTE(review): presumably train.py then falls back to its default or to the
# value stored in the checkpoint - confirm against train.py.
!python train.py \
    --model_name TrackNet2x \
    --epochs 30 \
    --fp16 \
    --compile \
    --num_workers 4 \
    --save_dir exps/{EXP_NAME} \
    --resume_training \
    --verbose

In [None]:
# RESUME — continue training from checkpoint (local or previous Colab run)
# Needs BATCH_SIZE and EXP_NAME from the training config cell.
# NOTE(review): the config cell refers to a "FRESH START cell", but both
# training cells in this notebook pass --resume_training - confirm whether a
# fresh-start variant (without that flag) is missing.
!python train.py \
    --model_name TrackNet2x \
    --epochs 30 \
    --batch_size {BATCH_SIZE} \
    --fp16 \
    --compile \
    --num_workers 4 \
    --save_dir exps/{EXP_NAME} \
    --resume_training \
    --verbose

In [None]:
# TensorBoard
# Loads the notebook extension and points it at the experiment's logs/ folder
# on Drive; EXP_NAME is re-declared so the cell does not depend on the config cell.
EXP_NAME = 'tennis_2x_colab'
%load_ext tensorboard
%tensorboard --logdir {DRIVE_DIR}/exps/{EXP_NAME}/logs

In [None]:
# Evaluate best model on test set
# Runs the repo's test.py on the 'test' split with the experiment's
# TrackNet2x_best.pt and passes exps/<EXP_NAME>/eval as --save_dir.
EXP_NAME = 'tennis_2x_colab'

!python test.py \
    --split test \
    --tracknet_file exps/{EXP_NAME}/TrackNet2x_best.pt \
    --save_dir exps/{EXP_NAME}/eval

In [None]:
# Copy best model back to Drive for download
# The copy is placed directly under DRIVE_DIR as <EXP_NAME>_best.pt, next to
# (not inside) the exps/ folder.
EXP_NAME = 'tennis_2x_colab'
!cp exps/{EXP_NAME}/TrackNet2x_best.pt {DRIVE_DIR}/{EXP_NAME}_best.pt
print(f'Best model saved to: {DRIVE_DIR}/{EXP_NAME}_best.pt')
print('You can download it from Google Drive.')