# BGAI Worker - Colab Edition

This notebook runs a **game worker** or **eval worker** on Google Colab that connects back to your head node (the main training node) via Tailscale. Both **TPU** and **GPU** runtimes are supported.

## Architecture

The distributed training system uses **Redis-based coordination** (no Ray required):
- Workers connect directly to Redis on the head node
- Model weights are synchronized via Redis
- Experiences are sent to a Redis replay buffer
- Workers auto-register and send heartbeats
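
The coordination pattern above can be sketched with a handful of Redis commands. This is a minimal illustration, not the bgai implementation: `FakeRedis` is a hypothetical in-memory stand-in so the example runs without a server, the `heartbeat` field name and the `running` status value are assumptions, and only the key names `bgai:model:version`, `bgai:buffer:episodes`, and `bgai:worker:<id>:info` are taken from the cells later in this notebook.

```python
import time
from collections import defaultdict


class FakeRedis:
    """Hypothetical in-memory stand-in for the few Redis commands used below."""

    def __init__(self):
        self.strings = {}
        self.hashes = defaultdict(dict)
        self.lists = defaultdict(list)

    def get(self, key):
        return self.strings.get(key)

    def set(self, key, value):
        self.strings[key] = str(value)

    def hset(self, key, mapping):
        self.hashes[key].update({k: str(v) for k, v in mapping.items()})

    def hgetall(self, key):
        return dict(self.hashes[key])

    def rpush(self, key, value):
        self.lists[key].append(value)
        return len(self.lists[key])

    def llen(self, key):
        return len(self.lists[key])


def register_worker(r, worker_id, worker_type):
    # Key pattern matches the one read by the monitoring cell in this notebook.
    r.hset(f"bgai:worker:{worker_id}:info",
           mapping={"worker_type": worker_type, "status": "running"})


def send_heartbeat(r, worker_id, now=None):
    # 'heartbeat' is an assumed field name, used here only for illustration.
    r.hset(f"bgai:worker:{worker_id}:info",
           mapping={"heartbeat": now if now is not None else time.time()})


r = FakeRedis()
register_worker(r, "colab-game-demo", "game")
send_heartbeat(r, "colab-game-demo", now=1700000000.0)
r.rpush("bgai:buffer:episodes", b"serialized-episode")  # experience -> replay buffer
r.set("bgai:model:version", 3)                          # published by the training side

info = r.hgetall("bgai:worker:colab-game-demo:info")
print(info["worker_type"], info["status"], r.llen("bgai:buffer:episodes"),
      r.get("bgai:model:version"))
```

Because `FakeRedis` mirrors the call shapes of a `redis.Redis` client for these commands, the two helper functions could be tried against a real connection as well, keeping in mind that the actual field names used by bgai may differ.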

## Worker Types

- **Game Worker**: Generates self-play games using MCTS, sends experiences to replay buffer
- **Eval Worker**: Evaluates model against baselines (random, self-play)

## Prerequisites

1. **Tailscale Auth Key**: Generate a reusable auth key at https://login.tailscale.com/admin/settings/keys
2. **Head Node Running**: Your head node should have:
   - Tailscale installed and running
   - Redis server running (port 6379)
   - Coordinator running (`python -m distributed.cli.main coordinator`)
3. **Colab Runtime**: Select GPU or TPU (Runtime -> Change runtime type)
4. **Colab Secrets**: Add `tailscale-key` and `redis-pass` to Colab secrets
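
Before running the cells below, it can help to confirm that both secrets are actually present. The following is a minimal sketch: `missing_secrets` is a hypothetical helper, and it takes the lookup function as a parameter so it can be tried outside Colab. The sample key value is a placeholder, not a real key.

```python
def missing_secrets(lookup, names=("tailscale-key", "redis-pass")):
    """Return the secret names for which lookup() fails or returns nothing."""
    missing = []
    for name in names:
        try:
            value = lookup(name)
        except Exception:  # e.g. secret not defined, or notebook access not granted
            value = None
        if not value:
            missing.append(name)
    return missing


# Stand-in lookup for illustration only.
demo_secrets = {"tailscale-key": "tskey-auth-example"}
print(missing_secrets(demo_secrets.__getitem__))
```

On Colab the call would be `missing_secrets(userdata.get)` after `from google.colab import userdata`.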

## 0. Fix JAX Version (Required)

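Colab's preinstalled JAX version is incompatible with this project's dependencies (it typically surfaces as a `TypeError` with `@jit`; see Troubleshooting). Upgrade it before running anything else.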

In [None]:
# IMPORTANT: Upgrade JAX first to fix version incompatibility
# This must run BEFORE any other imports
import os

# Detect if we're on TPU or GPU
if 'COLAB_TPU_ADDR' in os.environ:
    print("TPU detected - installing JAX with TPU support...")
    !pip install -q --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
else:
    print("GPU/CPU detected - installing JAX with CUDA support...")
    !pip install -q --upgrade "jax[cuda12]"

# Also upgrade optax and flax to compatible versions
!pip install -q --upgrade optax flax

print("\nJAX upgraded. You may need to restart the runtime if prompted.")
print("After restart, skip this cell and continue from the next one.")

In [None]:
# Verify JAX installation
import jax
import jax.numpy as jnp

# Force JAX initialization
_ = jnp.ones(1) + 1

print(f"JAX version: {jax.__version__}")
print(f"JAX backend: {jax.default_backend()}")
print(f"JAX devices: {jax.devices()}")

ACCELERATOR = jax.default_backend()
if ACCELERATOR == 'gpu':
    !nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
elif ACCELERATOR == 'tpu':
    print(f"TPU cores: {len(jax.devices())}")
else:
    print("WARNING: Running on CPU - this will be slow!")

## 1. Install Tailscale

In [None]:
# Install Tailscale on Colab
!curl -fsSL https://tailscale.com/install.sh | sh
!pip install -q pysocks

In [None]:
# Start Tailscale daemon in userspace mode (required for Colab)
import time

!pkill tailscaled 2>/dev/null || true
!nohup tailscaled --tun=userspace-networking --socks5-server=localhost:1055 --socket=/var/run/tailscale/tailscaled.sock > /dev/null 2>&1 &
time.sleep(3)
print("Tailscale daemon started")

In [None]:
# Authenticate with Tailscale
from google.colab import userdata

TAILSCALE_AUTH_KEY = userdata.get('tailscale-key')
!sudo tailscale up --authkey={TAILSCALE_AUTH_KEY} --hostname=colab-bgai-worker

print("\nColab Tailscale IP:")
!tailscale ip -4

In [None]:
# Configuration - Set your head node's Tailscale IP
from google.colab import userdata

HEAD_NODE_IP = "100.105.50.111"  # Your head node's Tailscale IP
REDIS_PASSWORD = userdata.get('redis-pass')

print(f"Head Node IP: {HEAD_NODE_IP}")
print(f"Redis: {HEAD_NODE_IP}:6379")

In [None]:
# Test connectivity to head node
!tailscale ping -c 3 {HEAD_NODE_IP}

In [None]:
# Configure SOCKS proxy for Redis connection through Tailscale
import socket
import socks

socks.set_default_proxy(socks.SOCKS5, "127.0.0.1", 1055)
socket.socket = socks.socksocket
print("SOCKS proxy configured for Tailscale")

In [None]:
# Test Redis connectivity
!pip install -q redis
import redis

try:
    r = redis.Redis(host=HEAD_NODE_IP, port=6379, password=REDIS_PASSWORD, decode_responses=True)
    r.ping()
    print("Redis connection: OK")
    print(f"  Model version: {r.get('bgai:model:version') or 'None'}")
    print(f"  Buffer episodes: {r.llen('bgai:buffer:episodes')}")
except Exception as e:
    print(f"Redis connection FAILED: {e}")
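
If the first ping fails, for example because the tunnel is still coming up, retrying with a short backoff is sometimes enough. The helper below is a generic sketch and not part of bgai: `retry_with_backoff` and its parameters are hypothetical names, and `flaky_ping` stands in for `r.ping()`.

```python
import time


def retry_with_backoff(fn, attempts=5, base_delay=1.0, sleep=time.sleep):
    """Call fn() until it succeeds, doubling the delay after each failure."""
    for attempt in range(attempts):
        try:
            return fn()
        except Exception as exc:
            if attempt == attempts - 1:
                raise  # out of attempts: surface the last error
            delay = base_delay * (2 ** attempt)
            print(f"Attempt {attempt + 1} failed ({exc}); retrying in {delay:.0f}s")
            sleep(delay)


# Stand-in for r.ping(): fails twice, then succeeds.
calls = {"n": 0}


def flaky_ping():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("not ready")
    return True


delays = []
result = retry_with_backoff(flaky_ping, sleep=delays.append)  # record delays instead of sleeping
```

In the notebook, the equivalent call would be `retry_with_backoff(r.ping)`.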

## 2. Install Dependencies

In [None]:
# Install remaining dependencies
!pip install -q pgx chex redis msgpack msgpack-numpy prometheus_client psutil

In [None]:
# Install turbozero
!pip install -q git+https://github.com/sile16/turbozero.git@main

In [None]:
# Clone bgai repository
!git clone https://github.com/sile16/bgai.git /content/bgai 2>/dev/null || (cd /content/bgai && git pull)

import sys
sys.path.insert(0, '/content/bgai')
print("bgai repository ready")

## 3. Configure Worker

In [None]:
# Worker configuration
import uuid

# =============================================================================
# CHOOSE WORKER TYPE: 'game' or 'eval'
# =============================================================================
WORKER_TYPE = 'game'  # @param ['game', 'eval']

# Common configuration - adjust based on your accelerator's memory
# TPU: can use larger batch sizes (32-64)
# GPU T4: use 16-32
# GPU A100: use 64-128
BATCH_SIZE = 16  # @param {type:"integer"}
NUM_SIMULATIONS = 100  # @param {type:"integer"}
MAX_NODES = 400  # @param {type:"integer"}

# Generate unique worker ID
WORKER_ID = f"colab-{WORKER_TYPE}-{str(uuid.uuid4())[:8]}"

# Build config based on worker type
if WORKER_TYPE == 'game':
    WORKER_CONFIG = {
        'batch_size': BATCH_SIZE,
        'num_simulations': NUM_SIMULATIONS,
        'max_nodes': MAX_NODES,
        'temperature': 1.0,
        'max_episode_steps': 500,
        'redis_host': HEAD_NODE_IP,
        'redis_port': 6379,
        'redis_password': REDIS_PASSWORD,
        'metrics_port': 9100,
    }
else:  # eval
    WORKER_CONFIG = {
        'batch_size': BATCH_SIZE,
        'num_simulations': NUM_SIMULATIONS * 2,  # Eval uses more simulations
        'max_nodes': MAX_NODES * 2,
        'eval_games': 100,
        'eval_interval': 300,
        'eval_types': ['random', 'self_play'],  # gnubg not available on Colab
        'redis_host': HEAD_NODE_IP,
        'redis_port': 6379,
        'redis_password': REDIS_PASSWORD,
        'metrics_port': 9300,
    }

print(f"Worker Type: {WORKER_TYPE.upper()}")
print(f"Worker ID: {WORKER_ID}")
print(f"Accelerator: {ACCELERATOR.upper()}")
print("\nConfiguration:")
for k, v in WORKER_CONFIG.items():
    if k != 'redis_password':
        print(f"  {k}: {v}")
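
The batch-size guidance in the comments above can be turned into a starting value chosen from the detected accelerator. This is a sketch under assumptions: `suggest_batch_size` is a hypothetical helper, it returns the lower end of each range quoted above, matching on the device name string is a heuristic, and the CPU value is a guess.

```python
def suggest_batch_size(backend, device_kind=""):
    """Pick a conservative starting batch size from the ranges quoted above."""
    kind = device_kind.lower()
    if backend == "tpu":
        return 32          # TPU: 32-64
    if backend == "gpu":
        if "a100" in kind:
            return 64      # A100: 64-128
        return 16          # T4 and unrecognized GPUs: start at 16
    return 4               # CPU fallback: an assumption, kept small


print(suggest_batch_size("gpu", "Tesla T4"))
```

With JAX initialized, the inputs could come from `jax.default_backend()` and `jax.devices()[0].device_kind`.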

## 4. Start Worker

In [None]:
# Create the worker
# Note: This imports JAX-dependent modules, so JAX must be initialized first
print(f"Creating {WORKER_TYPE} worker...")

try:
    if WORKER_TYPE == 'game':
        from distributed.workers.game_worker import GameWorker
        worker = GameWorker(config=WORKER_CONFIG, worker_id=WORKER_ID)
    else:
        from distributed.workers.eval_worker import EvalWorker
        worker = EvalWorker(config=WORKER_CONFIG, worker_id=WORKER_ID)
    
    print(f"{WORKER_TYPE.title()} worker created successfully!")
except Exception as e:
    print(f"ERROR creating worker: {e}")
    import traceback
    traceback.print_exc()
    raise

In [None]:
# Start the worker (runs indefinitely until interrupted)
print(f"Starting {WORKER_TYPE} worker...")
print("Press the stop button (or Runtime -> Interrupt) to stop.")
print("-" * 50)

try:
    result = worker.run(num_iterations=-1)  # -1 = infinite
    print(f"Worker finished: {result}")
except KeyboardInterrupt:
    print("\nInterrupted - stopping worker...")
    worker.stop()
    print("Worker stopped.")
except Exception as e:
    print(f"Worker error: {e}")
    import traceback
    traceback.print_exc()

## 5. Monitoring (Optional)

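Run these cells after stopping the worker (the worker cell blocks while it runs), or in a separate notebook that has already run the Tailscale and configuration cells, since they rely on `HEAD_NODE_IP` and `REDIS_PASSWORD`.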

In [None]:
# Check Redis buffer and cluster status
import redis

try:
    r = redis.Redis(host=HEAD_NODE_IP, port=6379, password=REDIS_PASSWORD, decode_responses=True)
    
    print("Cluster Status:")
    print(f"  Model version: {r.get('bgai:model:version') or 'None'}")
    print(f"  Buffer episodes: {r.llen('bgai:buffer:episodes')}")
    
    # Check registered workers (SCAN iterates incrementally; KEYS can block Redis on a large keyspace)
    worker_keys = list(r.scan_iter(match='bgai:worker:*:info'))
    print(f"  Registered workers: {len(worker_keys)}")
    for key in worker_keys:
        worker_id = key.split(':')[2]
        info = r.hgetall(key)
        print(f"    - {worker_id}: {info.get('worker_type', '?')} ({info.get('status', '?')})")
except Exception as e:
    print(f"Could not check status: {e}")
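
Once the worker hashes are fetched, a per-type summary is easier to scan than a long list. The snippet below is a minimal sketch: `summarize_workers` is a hypothetical helper, the sample data is made up, and the `running` status value is an assumption about what workers report.

```python
from collections import Counter


def summarize_workers(infos):
    """Count workers by (worker_type, status) from {key: info_hash} pairs."""
    counts = Counter()
    for info in infos.values():
        counts[(info.get("worker_type", "?"), info.get("status", "?"))] += 1
    return counts


# Sample data shaped like r.hgetall() results with decode_responses=True.
sample = {
    "bgai:worker:colab-game-1a2b3c4d:info": {"worker_type": "game", "status": "running"},
    "bgai:worker:colab-game-5e6f7a8b:info": {"worker_type": "game", "status": "running"},
    "bgai:worker:colab-eval-9c0d1e2f:info": {"worker_type": "eval"},
}
for (worker_type, status), n in sorted(summarize_workers(sample).items()):
    print(f"{worker_type:5s} {status:8s} {n}")
```

In the monitoring cell, the input could be built as `{key: r.hgetall(key) for key in worker_keys}`.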

## 6. Cleanup

In [None]:
# Stop the worker gracefully
try:
    worker.stop()
    print("Worker stopped and deregistered.")
except Exception as e:
    print(f"Error stopping worker: {e}")

In [None]:
# Disconnect Tailscale (optional)
!sudo tailscale down
print("Tailscale disconnected.")

## Troubleshooting

### JAX Version Errors
- **TypeError with @jit**: Run the JAX upgrade cell first, then restart runtime if prompted
- **Import errors**: Make sure you ran the JAX upgrade cell before importing anything

### Kernel Crashes
- **On worker creation**: Usually means JAX ran out of memory. Reduce `BATCH_SIZE`.
- **TPU initialization**: Make sure the JAX initialization cell ran successfully before creating the worker.

### Tailscale Issues
- **Auth key expired**: Generate a new key at https://login.tailscale.com/admin/settings/keys
- **Can't reach head node**: Check that Tailscale is running on both machines

### Redis Issues
- **Auth failed**: Check the Redis password matches
- **Connection refused**: Ensure Redis is bound to 0.0.0.0 or to the head node's Tailscale IP
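
To tell a refused connection apart from a timeout, a plain TCP probe can be run on a machine with a regular network stack, such as the head node probing its own Tailscale IP on port 6379. This is a sketch with a hypothetical helper name, `check_tcp`; inside this notebook, after the SOCKS cell has patched `socket.socket`, failures may surface as proxy errors instead.

```python
import socket


def check_tcp(host, port, timeout=3.0):
    """Return 'open', 'refused', 'timeout', or 'error: ...' for a TCP connect."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return "open"
    except ConnectionRefusedError:
        return "refused"   # nothing is listening on that address and port
    except socket.timeout:
        return "timeout"   # packets are dropped or the host is unreachable
    except OSError as exc:
        return f"error: {exc}"


# Self-test against a throwaway local listener instead of a real Redis server.
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("127.0.0.1", 0))      # port 0 lets the OS pick a free port
server.listen(1)
port = server.getsockname()[1]
status_open = check_tcp("127.0.0.1", port)
server.close()
status_closed = check_tcp("127.0.0.1", port)
print(status_open, status_closed)
```

A `refused` result against the Tailscale IP combined with `open` against `127.0.0.1` would be consistent with Redis listening only on localhost.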

### TPU/GPU/Memory Issues
- **OOM**: Reduce `BATCH_SIZE` in configuration (try 8 or 4)
- **No accelerator**: Check Colab runtime type (Runtime -> Change runtime type)

### Worker Issues
- **Model not found**: The worker will use random weights until training publishes a model
- **Coordinator not running**: Start coordinator on head node: `python -m distributed.cli.main coordinator`