# Tennis TrackNet: Colab Training

Train TrackNet on the tennis dataset using **pretrained shuttlecock (badminton) weights** from the original TrackNetV3 repo.

**What this does:**
- Clones your private repo
- Downloads the original TrackNet v1 tennis dataset
- Downloads pretrained badminton weights
- Fine-tunes on tennis with those weights as initialization

**Requirements:** GPU runtime (T4 or better), Google Drive mounted

## 1. Setup

In [None]:
# Check GPU
!nvidia-smi --query-gpu=name,memory.total --format=csv,noheader

import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

In [None]:
# Mount Google Drive (for persistent storage across sessions)
from google.colab import drive
drive.mount('/content/drive')

# Persistent storage dir
DRIVE_DIR = '/content/drive/MyDrive/tennis-tracknet'
!mkdir -p {DRIVE_DIR}

In [None]:
# Clone private repo (will prompt for GitHub auth)
import os
if not os.path.exists('/content/tennis-tracknet'):
    !git clone https://github.com/smyng/tennis-tracknet.git /content/tennis-tracknet
else:
    !cd /content/tennis-tracknet && git pull

os.chdir('/content/tennis-tracknet')
!pwd

In [None]:
# Install dependencies
!pip install -q parse tqdm tensorboard

## 2. Dataset

Download the TrackNet v1 tennis dataset and convert it to TrackNetV3 format.
The converted data is cached in Google Drive, so the download and conversion only need to run once.
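
The cache check used in this section boils down to: if the converted data is already in Drive, link it into the repo; otherwise download and convert first. Below is a minimal standard-library sketch of that logic. The helper name `link_cached_dataset` is hypothetical and not part of the repo; the `train/match1` marker is the same path the notebook checks for.

```python
import os
import shutil


def link_cached_dataset(drive_data: str, local_data: str,
                        marker: str = "train/match1") -> bool:
    """Symlink local_data -> drive_data if the converted dataset is cached.

    Returns True when the link was created, False when the cache is missing.
    Hypothetical helper for illustration; not part of the repo.
    """
    if not os.path.isdir(os.path.join(drive_data, marker)):
        return False  # nothing cached yet: download and convert first

    # Remove whatever currently sits at local_data. Check for a symlink
    # first, because isdir() follows links and rmtree() refuses them.
    if os.path.islink(local_data):
        os.unlink(local_data)
    elif os.path.isdir(local_data):
        shutil.rmtree(local_data)

    os.symlink(drive_data, local_data)
    return True
```

Calling it as `link_cached_dataset(DATA_DRIVE, DATA_LOCAL)` would mirror the `rm -rf` plus `ln -s` pair in the cell below, with the difference that it reports whether a cached copy was found.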

In [None]:
import os

DATA_DRIVE = f'{DRIVE_DIR}/data'
DATA_LOCAL = '/content/tennis-tracknet/data'

# Check if converted data already exists in Drive
if os.path.exists(f'{DATA_DRIVE}/train/match1'):
    print('Converted dataset found in Drive, symlinking...')
    !rm -rf {DATA_LOCAL}
    !ln -s {DATA_DRIVE} {DATA_LOCAL}
    !ls {DATA_LOCAL}/train/
    print('Done.')
else:
    print('No converted dataset in Drive. Will download and convert.')
    print('This takes ~15-20 min the first time.')

In [None]:
# Download raw tennis dataset (skip if data already linked above)
import os
if not os.path.exists(f'{DATA_LOCAL}/train/match1'):
    !pip install -q gdown
    
    RAW_DIR = '/content/raw-tennis-dataset'
    !mkdir -p {RAW_DIR}
    
    # Download from the TrackNet v1 dataset Google Drive
    # Folder: https://drive.google.com/drive/folders/11r0RUaQHX7I3ANkaYG4jOxXK1OYo01Ut
    import gdown
    gdown.download_folder(
        'https://drive.google.com/drive/folders/11r0RUaQHX7I3ANkaYG4jOxXK1OYo01Ut',
        output=RAW_DIR, quiet=False
    )
    
    print(f'Downloaded to {RAW_DIR}')
    !ls {RAW_DIR}/

In [None]:
# Convert dataset and save to Drive for persistence
import os
if not os.path.exists(f'{DATA_LOCAL}/train/match1'):
    !python scripts/convert_tennis_dataset.py \
        --input /content/raw-tennis-dataset \
        --output {DATA_DRIVE} \
        --test-games 9 10 \
        --verbose
    
    # Symlink Drive data into repo
    !rm -rf {DATA_LOCAL}
    !ln -s {DATA_DRIVE} {DATA_LOCAL}
    
    print('Conversion complete.')
    !ls {DATA_LOCAL}/train/

## 3. Pretrained Weights

Download the original TrackNetV3 badminton (shuttlecock) checkpoint.

In [None]:
import os

CKPT_DIR = '/content/tennis-tracknet/ckpts'
os.makedirs(CKPT_DIR, exist_ok=True)

if not os.path.exists(f'{CKPT_DIR}/TrackNet_best.pt'):
    !pip install -q gdown
    import gdown
    
    # Original TrackNetV3 checkpoints
    # https://drive.google.com/file/d/1CfzE87a0f6LhBp0kniSl1-89zaLCZ8cA/view
    gdown.download(
        'https://drive.google.com/uc?id=1CfzE87a0f6LhBp0kniSl1-89zaLCZ8cA',
        output='/content/TrackNetV3_ckpts.zip', quiet=False
    )
    !cd /content && unzip -o TrackNetV3_ckpts.zip -d tennis-tracknet/ckpts/
    !rm /content/TrackNetV3_ckpts.zip

print('Pretrained checkpoints:')
!ls -lh {CKPT_DIR}/*.pt

# Verify checkpoint
import torch
ckpt = torch.load(f'{CKPT_DIR}/TrackNet_best.pt', map_location='cpu', weights_only=False)
print(f"\nPretrained model: epoch {ckpt['epoch']}, bg_mode='{ckpt['param_dict']['bg_mode']}'")
print(f"Original training: seq_len={ckpt['param_dict']['seq_len']}, batch_size={ckpt['param_dict']['batch_size']}")

## 4. Train

Fine-tune TrackNet on tennis using the pretrained shuttlecock weights.

**Note:** This run uses `bg_mode='concat'` to match the pretrained model's architecture (27 input channels).
Your local run uses `subtract_concat` (32 channels), so it can't load these weights directly.
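
The two channel counts follow from how each background mode stacks frames. Here is a small sketch of the arithmetic. It assumes `concat` prepends one RGB background frame to the `seq_len` RGB frames and `subtract_concat` adds one background-difference channel per RGB frame. The function `input_channels` is hypothetical; the model-construction code in the repo is the authoritative source for the rule.

```python
def input_channels(seq_len: int, bg_mode: str) -> int:
    """Number of input channels for a given background mode (assumed rule)."""
    if bg_mode == "concat":
        # seq_len RGB frames plus one RGB background frame
        return (seq_len + 1) * 3
    if bg_mode == "subtract_concat":
        # each frame contributes RGB plus one difference channel
        return seq_len * 4
    raise ValueError(f"unsupported bg_mode in this sketch: {bg_mode!r}")


for mode in ("concat", "subtract_concat"):
    print(mode, input_channels(8, mode))
# concat 27
# subtract_concat 32
```

The first convolution's weight tensor is sized by this count, so a 27-channel checkpoint does not fit a 32-channel model without extra handling.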

In [None]:
# Training config
import os
EXP_NAME = 'tennis_pretrained_shuttlecock'
SAVE_DIR = f'{DRIVE_DIR}/exps/{EXP_NAME}'  # Save to Drive for persistence
os.makedirs(SAVE_DIR, exist_ok=True)

# Also symlink exps into repo so train.py can find them
!mkdir -p {DRIVE_DIR}/exps
!rm -rf /content/tennis-tracknet/exps
!ln -s {DRIVE_DIR}/exps /content/tennis-tracknet/exps

print(f'Experiment: {EXP_NAME}')
print(f'Checkpoints will be saved to: {SAVE_DIR}')

In [None]:
# Start training with pretrained shuttlecock weights
!python train.py \
    --model_name TrackNet \
    --seq_len 8 \
    --epochs 30 \
    --batch_size 10 \
    --optim Adam \
    --learning_rate 0.001 \
    --bg_mode concat \
    --alpha 0.5 \
    --tolerance 4 \
    --save_dir exps/{EXP_NAME} \
    --pretrained ckpts/TrackNet_best.pt \
    --verbose

## 5. Resume Training (if disconnected)

Run this if your session disconnected. Checkpoints are saved to Google Drive.
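
Before launching the resume command, it can help to confirm that there is a checkpoint to resume from. This is a minimal sketch, assuming the training script writes its latest checkpoint as `TrackNet_cur.pt` in the experiment directory (the filename read by the monitoring cell in this notebook). `can_resume` is a hypothetical helper, not part of the repo.

```python
import os


def can_resume(save_dir: str, ckpt_name: str = "TrackNet_cur.pt") -> bool:
    """Return True if a non-empty checkpoint file exists in save_dir.

    Hypothetical helper; the checkpoint filename is an assumption.
    """
    path = os.path.join(save_dir, ckpt_name)
    return os.path.isfile(path) and os.path.getsize(path) > 0


# Example use in the notebook (DRIVE_DIR and EXP_NAME come from earlier cells):
# if can_resume(f"{DRIVE_DIR}/exps/{EXP_NAME}"):
#     print("Checkpoint found, safe to resume.")
# else:
#     print("No checkpoint yet, run the training cell in section 4 instead.")
```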

In [None]:
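# Resume training (run the Setup cells in section 1 first so Drive is mounted,
# DRIVE_DIR is defined, and the working directory is the repo; then run this)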
EXP_NAME = 'tennis_pretrained_shuttlecock'

# Re-link Drive dirs
!rm -rf /content/tennis-tracknet/data /content/tennis-tracknet/exps
!ln -s {DRIVE_DIR}/data /content/tennis-tracknet/data
!ln -s {DRIVE_DIR}/exps /content/tennis-tracknet/exps

!python train.py \
    --model_name TrackNet \
    --epochs 30 \
    --save_dir exps/{EXP_NAME} \
    --resume_training \
    --verbose

## 6. Monitor & Evaluate

In [None]:
# Check training progress
import os
import torch

EXP_NAME = 'tennis_pretrained_shuttlecock'
save_dir = f'{DRIVE_DIR}/exps/{EXP_NAME}'

for name in ['TrackNet_cur.pt', 'TrackNet_best.pt']:
    path = os.path.join(save_dir, name)
    if os.path.exists(path):
        ckpt = torch.load(path, map_location='cpu', weights_only=False)
        print(f'{name}:')
        print(f"  Epoch: {ckpt['epoch'] + 1} / 30")
        print(f"  Best val accuracy: {ckpt['max_val_acc']:.4f}")
        print()

In [None]:
# TensorBoard
%load_ext tensorboard
%tensorboard --logdir {DRIVE_DIR}/exps/{EXP_NAME}/logs

In [None]:
# Evaluate best model on test set
EXP_NAME = 'tennis_pretrained_shuttlecock'

!python test.py \
    --split test \
    --tracknet_file exps/{EXP_NAME}/TrackNet_best.pt \
    --save_dir exps/{EXP_NAME}/eval

In [None]:
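# Copy the best model to the top level of the Drive folder for easy download
# (exps/ is already symlinked to Drive, so this is a convenience copy)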
EXP_NAME = 'tennis_pretrained_shuttlecock'
!cp exps/{EXP_NAME}/TrackNet_best.pt {DRIVE_DIR}/{EXP_NAME}_best.pt
print(f'Best model saved to: {DRIVE_DIR}/{EXP_NAME}_best.pt')
print('You can download it from Google Drive.')