# vlm-gym TPU Training Setup

Quick setup for training vlm-gym on a Colab TPU runtime, with low-memory optimizations.

**Runtime**: Select `Runtime > Change runtime type > TPU` before running.

In [None]:
# Install uv package manager
!curl -LsSf https://astral.sh/uv/install.sh | sh
import os
os.environ['PATH'] = f"{os.path.expanduser('~/.cargo/bin')}:{os.environ['PATH']}"

In [None]:
# Clone repo and checkout low-mem-gpu branch
%cd /content
!git clone https://github.com/sdan/vlm-gym.git
%cd vlm-gym
!git checkout low-mem-gpu

In [None]:
# Install dependencies
!uv venv .venv --python 3.10
!uv pip install -e .

# Install JAX TPU version
!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

In [None]:
# Clear any TPU locks (if runtime was restarted)
!sudo pkill -9 python3 || true
!sudo rm -rf /tmp/libtpu_lockfile /tmp/tpu_logs || true

# Verify TPU setup.
#
# The check runs in a subprocess on purpose. libtpu lets only one process own the
# TPU at a time (that is what /tmp/libtpu_lockfile guards). If this kernel
# imported JAX and initialized the TPU backend, the `!uv run ... vlmrl.core.train`
# subprocess launched later would fail with "TPU already in use".
# Going through `uv run` also checks the .venv JAX that training actually uses.
import os
import subprocess

# Set in the kernel's environment so every later `!` command (including the
# training runs) inherits it.
os.environ['JAX_PLATFORMS'] = 'tpu'

TPU_CHECK_SCRIPT = '''
import jax
devices = jax.devices()
print(f"JAX version: {jax.__version__}")
print(f"Devices: {devices}")
print(f"Device count: {jax.device_count()}")
print(f"Platform: {devices[0].platform}")
raise SystemExit(0 if devices[0].platform == "tpu" else 1)
'''

tpu_check = subprocess.run(
    ["uv", "run", "python", "-c", TPU_CHECK_SCRIPT],
    capture_output=True,
    text=True,
)
print(tpu_check.stdout)

# Non-zero exit: either the platform is not TPU, or JAX could not initialize
# the TPU backend at all (with JAX_PLATFORMS=tpu, jax.devices() raises).
if tpu_check.returncode != 0:
    print(tpu_check.stderr)
    print("\n⚠ WARNING: TPU not detected!")
    print("Please select: Runtime > Change runtime type > TPU")
    raise RuntimeError("TPU runtime required")

In [None]:
# Download pre-converted checkpoint from GCS
!mkdir -p checkpoints
!gsutil -m cp -r gs://geospot/checkpoints/qwen3vl_4b checkpoints/

# Verify download
!ls -lh checkpoints/qwen3vl_4b/

## Run Training

A short test run (single rollout, 100 training steps) using TPU-optimized, low-memory settings.

In [None]:
# Train with low-memory optimizations for TPU v3-8
# - bf16 params + Adafactor + gradient checkpointing
#   (--optimizer=adafactor and --grad_checkpoint=1 below; bf16 params are presumably
#   enabled by --low_memory=1 — confirm in vlmrl.core.train)
# - ppo_minibatch=8 enables pmap across 8 TPU cores (1 sample/core). This assumes an
#   8-core TPU; check the "Device count" printed by the TPU check first.
#   NOTE(review): confirm how --batch_size=2 relates to --ppo_minibatch=8.
# - Reduced batch_size, max_new_tokens, and vlm_max_pixels to fit in memory
# - Launched with `uv run` so it uses the project's .venv, where vlmrl was installed.
# - wandb runs in offline mode, so no W&B login should be required.

!uv run python -m vlmrl.core.train \
  --low_memory=1 \
  --model_dir=checkpoints/qwen3vl_4b \
  --save_dir=runs/colab-tpu-test \
  --wandb_mode=offline \
  --wandb_name=colab-tpu-test \
  --env_name=geospot \
  --env_split=test \
  --total_steps=100 \
  --batch_size=2 \
  --log_interval=10 \
  --temperature=0.7 \
  --max_new_tokens=24 \
  --vlm_max_pixels=65000 \
  --ppo_minibatch=8 \
  --ppo_epochs=1 \
  --optimizer=adafactor \
  --learning_rate=1e-6 \
  --max_grad_norm=1.0 \
  --entropy_coef=0.0 \
  --use_ema=0 \
  --grad_checkpoint=1

## Verify Results

In [None]:
# Report whether the training run left a checkpoint behind, and its size in MB.
# Only the file's existence is checked. A checkpoint left by an earlier run
# would also be reported here.
import os

ckpt_path = "runs/colab-tpu-test/train_state.pkl"
if not os.path.exists(ckpt_path):
    print("✗ No checkpoint found")
else:
    size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
    print(f"✓ Training complete. Checkpoint: {size_mb:.1f} MB")

## Continue Training (Optional)

In [None]:
# Resume from checkpoint for longer training
!python -m vlmrl.core.train \
  --resume_path=runs/colab-tpu-test/train_state.pkl \
  --total_steps=1000 \
  --low_memory=1 \
  --model_dir=checkpoints/qwen3vl_4b \
  --save_dir=runs/colab-tpu-test