# BGAI Game Worker - Colab Edition

This notebook runs a game worker on Google Colab that connects back to your main training node via Tailscale.

## Prerequisites

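1. **Tailscale Auth Key**: Generate a reusable auth key at https://login.tailscale.com/admin/settings/keys. Marking it ephemeral as well suits Colab, because ephemeral nodes are removed from your tailnet automatically after they go offline.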
2. **Head Node Running**: Your main node should have:
   - Tailscale installed and running
   - Redis server running (port 6379)
   - Ray head node running, with the Ray Client server listening on port 10001
   - Coordinator actor running (registered as `coordinator` in the `bgai` namespace)

## Setup Steps

1. Run the cells in order
2. Enter your Tailscale auth key when prompted
3. Enter your head node's Tailscale IP when prompted
4. The game worker will start generating games!

## 1. Install Tailscale

In [None]:
# Install Tailscale on Colab
!curl -fsSL https://tailscale.com/install.sh | sh

In [None]:
import getpass

# Get Tailscale auth key (hidden input)
print("Enter your Tailscale auth key (get one from https://login.tailscale.com/admin/settings/keys):")
TAILSCALE_AUTH_KEY = getpass.getpass("Tailscale Auth Key: ")
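A pasted URL or an empty input only fails later, inside `tailscale up`. The minimal sketch below catches that early. It assumes current Tailscale auth keys start with the `tskey-` prefix, and `looks_like_auth_key` is a hypothetical helper, not part of any Tailscale tooling.

```python
def looks_like_auth_key(key: str) -> bool:
    """Heuristic check: non-empty, no inner whitespace, and starts with 'tskey-'.

    The 'tskey-' prefix is an assumption based on the format of current Tailscale keys.
    """
    key = key.strip()
    return bool(key) and not any(c.isspace() for c in key) and key.startswith("tskey-")

# Usage in the notebook (TAILSCALE_AUTH_KEY comes from the cell above):
# if not looks_like_auth_key(TAILSCALE_AUTH_KEY):
#     raise ValueError("That does not look like a Tailscale auth key")
```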

In [None]:
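# Start the Tailscale daemon and authenticate
import subprocess
import time

# Start tailscaled in the background, logging to a file so startup errors stay visible.
# This uses a kernel TUN device. If the runtime has none, tailscaled exits and you need
# --tun=userspace-networking, in which case local apps reach tailnet peers only through
# the proxy enabled with --socks5-server or --outbound-http-proxy-listen.
!mkdir -p /var/lib/tailscale /run/tailscale
tailscaled_log = open("/tmp/tailscaled.log", "w")
tailscaled_proc = subprocess.Popen(
    ["sudo", "tailscaled",
     "--state=/var/lib/tailscale/tailscaled.state",
     "--socket=/run/tailscale/tailscaled.sock"],
    stdout=tailscaled_log,
    stderr=subprocess.STDOUT,
)
time.sleep(3)  # give the daemon time to create its socket
if tailscaled_proc.poll() is not None:
    print("tailscaled exited early; see /tmp/tailscaled.log")

# Authenticate with Tailscale
!sudo tailscale up --authkey={TAILSCALE_AUTH_KEY} --hostname=colab-game-worker

# Show our Tailscale IP
!tailscale ip -4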
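To confirm from Python that the login succeeded, rather than by reading the cell output, you can parse `tailscale status --json`. This sketch relies only on the top-level `BackendState` field and `Self.TailscaleIPs`; the helper names are hypothetical.

```python
import json
import subprocess

def parse_tailscale_status(raw: str) -> tuple:
    """Return (backend_state, ipv4_addresses) from `tailscale status --json` output."""
    status = json.loads(raw)
    state = status.get("BackendState", "Unknown")
    ips = (status.get("Self") or {}).get("TailscaleIPs") or []
    # Keep IPv4 addresses only (Tailscale also assigns an IPv6 address)
    return state, [ip for ip in ips if ":" not in ip]

def tailscale_status() -> tuple:
    """Run the CLI and parse its JSON status (requires tailscale on PATH)."""
    out = subprocess.run(["tailscale", "status", "--json"],
                         capture_output=True, text=True, check=True).stdout
    return parse_tailscale_status(out)

# Usage:
# state, ips = tailscale_status()
# assert state == "Running", f"Tailscale not connected (state: {state})"
```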

In [None]:
# Enter your head node's Tailscale IP
HEAD_NODE_IP = input("Enter head node Tailscale IP (e.g., 100.x.x.x): ").strip()
REDIS_PASSWORD = getpass.getpass("Redis password (default: bgai-password): ") or "bgai-password"

print("\nConfiguration:")
print(f"  Head Node IP: {HEAD_NODE_IP}")
print(f"  Ray Address: ray://{HEAD_NODE_IP}:10001")
print(f"  Redis: {HEAD_NODE_IP}:6379")
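A typo in the head node IP only shows up later as a timeout. Tailscale assigns IPv4 addresses from the CGNAT range `100.64.0.0/10`, so a quick range check catches most mistakes. This is a sketch that assumes your tailnet uses that default addressing; `is_tailscale_ipv4` is a hypothetical helper.

```python
import ipaddress

# Tailscale's IPv4 addresses come from the CGNAT range 100.64.0.0/10
TAILSCALE_V4_NET = ipaddress.ip_network("100.64.0.0/10")

def is_tailscale_ipv4(addr: str) -> bool:
    """Return True if addr parses as IPv4 and falls inside 100.64.0.0/10."""
    try:
        ip = ipaddress.ip_address(addr.strip())
    except ValueError:
        return False
    return ip.version == 4 and ip in TAILSCALE_V4_NET

# Usage:
# if not is_tailscale_ipv4(HEAD_NODE_IP):
#     print(f"Warning: {HEAD_NODE_IP} is not in the Tailscale range 100.64.0.0/10")
```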

In [None]:
# Test connectivity to head node
import socket

def test_port(host, port, name, timeout=5.0):
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        # create_connection raises ConnectionRefusedError, socket.timeout, etc. (all OSError),
        # so the printed reason reflects what actually went wrong
        with socket.create_connection((host, port), timeout=timeout):
            pass
        print(f"  [OK] {name} ({host}:{port})")
        return True
    except OSError as e:
        print(f"  [FAIL] {name} ({host}:{port}) - {e}")
        return False

print("Testing connectivity to head node...")
redis_ok = test_port(HEAD_NODE_IP, 6379, "Redis")
ray_ok = test_port(HEAD_NODE_IP, 10001, "Ray Client")

if redis_ok and ray_ok:
    print("\nAll services reachable! Ready to proceed.")
else:
    print("\nSome services unreachable. Please check:")
    print("  1. Head node is running all services")
    print("  2. Tailscale is connected on both ends")
    print("  3. Head node firewall allows ports 6379 and 10001 from the tailnet")
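A probe run right after `tailscale up` can fail while the node is still coming online. A small retry loop makes the check less sensitive to timing; `wait_for_port` is a hypothetical helper, and the default attempt count and delay are arbitrary choices.

```python
import socket
import time

def wait_for_port(host: str, port: int, attempts: int = 5,
                  delay: float = 2.0, timeout: float = 5.0) -> bool:
    """Try a TCP connect up to `attempts` times, sleeping `delay` seconds between tries."""
    for attempt in range(1, attempts + 1):
        try:
            with socket.create_connection((host, port), timeout=timeout):
                return True
        except OSError:
            if attempt < attempts:
                time.sleep(delay)
    return False

# Usage:
# redis_ok = wait_for_port(HEAD_NODE_IP, 6379)
# ray_ok = wait_for_port(HEAD_NODE_IP, 10001)
```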

## 2. Install Python Dependencies

In [None]:
# Check GPU availability
!nvidia-smi

In [None]:
# Install JAX with GPU support (Colab GPU runtimes usually ship a CUDA-enabled JAX; this upgrades it)
!pip install --upgrade "jax[cuda12]"

In [None]:
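# Install other dependencies.
# Ray Client requires the same Ray version and Python minor version as the head node,
# so pin Ray to the head node's version if they differ (e.g. "ray[client]==X.Y.Z").
!pip install pgx flax chex optax "ray[client]" redis msgpack msgpack-numpy prometheus_client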
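Ray Client raises an error at `ray.init` when the client and cluster versions differ, so comparing them up front gives a clearer message. This sketch encodes the rule above (same Python major.minor, same Ray version); `ray_client_mismatches` is a hypothetical helper, and the cluster versions must be read on the head node yourself.

```python
def ray_client_mismatches(client_python: str, client_ray: str,
                          cluster_python: str, cluster_ray: str) -> list:
    """List version mismatches that would stop a Ray Client from connecting.

    Compares the Python major.minor version and the exact Ray version.
    """
    problems = []
    if client_python.split(".")[:2] != cluster_python.split(".")[:2]:
        problems.append(f"Python {client_python} (client) vs {cluster_python} (cluster)")
    if client_ray != cluster_ray:
        problems.append(f"Ray {client_ray} (client) vs {cluster_ray} (cluster)")
    return problems

# Usage: run `python --version` and `ray --version` on the head node, then here:
# import platform, ray
# print(ray_client_mismatches(platform.python_version(), ray.__version__,
#                             HEAD_PYTHON_VERSION, HEAD_RAY_VERSION))  # placeholders
```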

In [None]:
# Install turbozero from the repository
!pip install git+https://github.com/sile16/turbozero.git@main

In [None]:
# Clone the bgai repository
!git clone https://github.com/sile16/bgai.git /content/bgai 2>/dev/null || (cd /content/bgai && git pull)

import sys
sys.path.insert(0, '/content/bgai')

In [None]:
# Verify JAX GPU setup
import jax
print(f"JAX devices: {jax.devices()}")
print(f"JAX default backend: {jax.default_backend()}")

## 3. Test Redis Connection

In [None]:
import redis

# Test Redis connection (timeouts make it fail fast instead of hanging if the host is unreachable)
try:
    r = redis.Redis(
        host=HEAD_NODE_IP,
        port=6379,
        password=REDIS_PASSWORD,
        decode_responses=True,
        socket_connect_timeout=5,
        socket_timeout=5,
    )
    r.ping()
    print("Redis connection successful!")
    
    # Check buffer stats if available
    buffer_size = r.get('bgai:buffer:metadata:size')
    if buffer_size:
        print(f"Current buffer size: {buffer_size}")
except Exception as e:
    print(f"Redis connection failed: {e}")
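When the error message is ambiguous, it can help to see the raw server replies. Redis speaks RESP, where each command is an array of bulk strings and errors start with `-`. The sketch below sends `AUTH` (if a password is given) and `PING` over a plain socket with no redis-py involved; the helper names are hypothetical.

```python
import socket
from typing import Optional

def encode_command(*args: str) -> bytes:
    """Encode a command as a RESP array of bulk strings."""
    parts = [f"*{len(args)}\r\n".encode()]
    for arg in args:
        data = arg.encode()
        parts.append(b"$%d\r\n%s\r\n" % (len(data), data))
    return b"".join(parts)

def redis_ping(host: str, port: int, password: Optional[str] = None,
               timeout: float = 5.0) -> str:
    """Send optional AUTH, then PING; return the first line of the last reply read."""
    commands = ([("AUTH", password)] if password else []) + [("PING",)]
    with socket.create_connection((host, port), timeout=timeout) as sock:
        with sock.makefile("rb") as f:
            reply = b""
            for cmd in commands:
                sock.sendall(encode_command(*cmd))
                reply = f.readline().rstrip(b"\r\n")
                if reply.startswith(b"-"):  # RESP error reply, e.g. a wrong password
                    break
            return reply.decode()

# Usage: print(redis_ping(HEAD_NODE_IP, 6379, REDIS_PASSWORD))  # "+PONG" on success
```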

## 4. Start Game Worker

In [None]:
# Game worker configuration
WORKER_CONFIG = {
    'batch_size': 16,           # Parallel environments (adjust based on GPU memory)
    'num_simulations': 100,     # MCTS iterations per move
    'max_nodes': 400,           # Max MCTS tree nodes
    'temperature': 1.0,         # Exploration temperature
    'redis_host': HEAD_NODE_IP,
    'redis_port': 6379,
    'redis_password': REDIS_PASSWORD,
    'metrics_port': 9100,       # Prometheus metrics port
    'heartbeat_interval': 10.0, # Seconds between heartbeats
}

# Generate unique worker ID
import uuid
WORKER_ID = f"colab-{str(uuid.uuid4())[:8]}"
print(f"Worker ID: {WORKER_ID}")
print(f"Configuration: {WORKER_CONFIG}")
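A bad value in `WORKER_CONFIG` would otherwise surface only inside the remote actor. This hypothetical validator checks basic ranges before launch. The `max_nodes >= num_simulations + 1` rule is an assumption (each MCTS simulation adding at most one node, plus the root), not something stated by turbozero.

```python
def validate_worker_config(cfg: dict) -> list:
    """Return a list of problems found in a WORKER_CONFIG-style dict (empty if none)."""
    problems = []
    for key in ("batch_size", "num_simulations", "max_nodes"):
        if not isinstance(cfg.get(key), int) or cfg[key] <= 0:
            problems.append(f"{key} must be a positive integer")
    # Assumption: each MCTS simulation adds at most one node, plus the root
    sims, nodes = cfg.get("num_simulations"), cfg.get("max_nodes")
    if isinstance(sims, int) and isinstance(nodes, int) and nodes < sims + 1:
        problems.append("max_nodes should be at least num_simulations + 1")
    if cfg.get("temperature", 1.0) < 0:
        problems.append("temperature must be non-negative")
    for key in ("redis_port", "metrics_port"):
        port = cfg.get(key)
        if not isinstance(port, int) or not 0 < port < 65536:
            problems.append(f"{key} must be a valid TCP port")
    if cfg.get("heartbeat_interval", 1.0) <= 0:
        problems.append("heartbeat_interval must be positive")
    return problems

# Usage:
# issues = validate_worker_config(WORKER_CONFIG)
# assert not issues, issues
```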

In [None]:
# Connect to Ray cluster
import ray
import os

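# Cap JAX's GPU memory preallocation. os.environ only affects this notebook's own JAX
# if set before JAX initializes (the JAX check above has already initialized it);
# the runtime_env below is what passes the setting to Ray worker processes.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.7'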

# Initialize Ray with connection to head node
ray_address = f"ray://{HEAD_NODE_IP}:10001"
print(f"Connecting to Ray cluster at {ray_address}...")

try:
    ray.init(
        address=ray_address,
        namespace="bgai",
        runtime_env={
            "env_vars": {
                "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.7",
            }
        }
    )
    print("Connected to Ray cluster!")
    print(f"Available resources: {ray.available_resources()}")
except Exception as e:
    print(f"Failed to connect to Ray: {e}")
    print("\nTroubleshooting:")
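    print("  1. Ensure Ray is running on the head node with the client server enabled")
    print("  2. Check that port 10001 is accessible via Tailscale")
    print("  3. Start it with: ray start --head --port=6380 --ray-client-server-port=10001")
    print("     (GCS on port 6380 so it does not collide with Redis on 6379)")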
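An actor created with `num_gpus=1.0` stays pending until some node in the cluster has a whole GPU free, and `ray.get` on its methods just waits. Checking the dictionary returned by `ray.available_resources()`, which reports free GPUs under the `"GPU"` key, surfaces that before launch. A minimal sketch; `can_place_gpu_actor` is a hypothetical helper.

```python
def can_place_gpu_actor(resources: dict, num_gpus: float = 1.0) -> bool:
    """True if `resources` shows at least `num_gpus` free GPUs.

    ray.available_resources() omits the "GPU" key when no GPU is free.
    """
    return float(resources.get("GPU", 0.0)) >= num_gpus

# Usage:
# if not can_place_gpu_actor(ray.available_resources()):
#     print("No free GPU in the cluster; the game worker actor would stay pending.")
```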

In [None]:
# Get coordinator handle
try:
    coordinator = ray.get_actor("coordinator", namespace="bgai")
    print("Found coordinator!")
    
    # Get current model version
    version = ray.get(coordinator.get_model_version.remote())
    print(f"Current model version: {version}")
except Exception as e:
    print(f"Failed to get coordinator: {e}")
    print("\nMake sure the coordinator is running on the head node:")
    print("  python -m distributed.cli.main coordinator --redis-host localhost")
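If this notebook starts before the coordinator has registered, `ray.get_actor` raises `ValueError`. A small retry wrapper avoids re-running the cell by hand. `retry` is a hypothetical helper, and the attempt count and delay are arbitrary.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")

def retry(fn: Callable[[], T], attempts: int = 10, delay: float = 3.0,
          exceptions: tuple = (ValueError,)) -> T:
    """Call fn until it succeeds or attempts run out, sleeping delay seconds between tries."""
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except exceptions:
            if attempt == attempts:
                raise  # give up and surface the last error
            time.sleep(delay)
    raise ValueError("attempts must be >= 1")

# Usage:
# coordinator = retry(lambda: ray.get_actor("coordinator", namespace="bgai"))
```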

In [None]:
# Import and start the game worker
from distributed.workers.game_worker import GameWorker

print(f"Starting game worker '{WORKER_ID}'...")
print(f"  Batch size: {WORKER_CONFIG['batch_size']}")
print(f"  MCTS simulations: {WORKER_CONFIG['num_simulations']}")
print(f"  Redis: {WORKER_CONFIG['redis_host']}:{WORKER_CONFIG['redis_port']}")

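# Create the game worker as a Ray actor that requires one GPU.
# With a ray:// (Ray Client) connection, Ray places this actor on a cluster node with a
# free GPU; it runs on this Colab VM only if the VM has joined the cluster as a worker node.
# max_concurrency=2 lets stop() and get_stats() execute while run() is in progress; with
# the default of 1 they queue behind run(). This assumes GameWorker.stop() is thread-safe.
GameWorkerActor = ray.remote(
    num_gpus=1.0,
    max_restarts=3,
    max_concurrency=2,
)(GameWorker)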

worker = GameWorkerActor.remote(
    coordinator_handle=coordinator,
    worker_id=WORKER_ID,
    config=WORKER_CONFIG,
)

print("Game worker actor created!")

In [None]:
# Start the worker (this runs indefinitely)
# The worker will generate games and send them to the Redis buffer

print("Starting game generation...")
print("Press the stop button to interrupt.")
print("-" * 50)

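from ray.exceptions import GetTimeoutError

run_ref = worker.run.remote(num_iterations=-1)  # -1 = run until stopped
try:
    # Block until run() returns or the cell is interrupted
    result = ray.get(run_ref)
    print(f"Worker finished: {result}")
except KeyboardInterrupt:
    print("\nStopping worker...")
    try:
        # If the actor handles only one call at a time, stop() cannot run until
        # run() returns, so give up after a timeout and kill the actor instead.
        ray.get(worker.stop.remote(), timeout=60)
        print("Worker stopped.")
    except GetTimeoutError:
        print("stop() timed out; killing the actor.")
        ray.kill(worker)
except Exception as e:
    print(f"Worker error: {e}")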
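Stopping a worker that is inside an open-ended `run()` works only if `run()` checks a flag that another call can set while it loops. The source does not show whether `GameWorker` is built this way; the class below is a hypothetical sketch of the cooperative-stop pattern.

```python
import threading
import time

class LoopWorker:
    """Hypothetical stand-in for GameWorker showing a cooperative stop flag."""

    def __init__(self) -> None:
        self._stop = threading.Event()
        self.iterations = 0

    def run(self, num_iterations: int = -1) -> int:
        # -1 means "run until stop() is called", mirroring the cell above
        while not self._stop.is_set() and (num_iterations < 0 or self.iterations < num_iterations):
            self.iterations += 1  # one batch of self-play would happen here
            time.sleep(0.01)
        return self.iterations

    def stop(self) -> None:
        # Only takes effect if stop() can run while run() is looping,
        # e.g. in a Ray actor created with max_concurrency >= 2
        self._stop.set()
```

In a threaded Ray actor, `run()` and `stop()` execute on different threads, which is why a `threading.Event` is used instead of a plain boolean attribute.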

## 5. Monitoring (Optional)

In [None]:
# Check worker stats
try:
    stats = ray.get(worker.get_stats.remote())
    print("Worker Statistics:")
    for key, value in stats.items():
        print(f"  {key}: {value}")
except Exception as e:
    print(f"Could not get stats: {e}")
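A single stats snapshot shows totals. Two snapshots taken a known interval apart give throughput. This sketch assumes `get_stats()` reports a cumulative counter; the key name `games_completed` is hypothetical, so substitute whichever field the worker actually returns.

```python
def games_per_second(prev: dict, curr: dict, elapsed_s: float,
                     key: str = "games_completed") -> float:
    """Rate of change of a cumulative counter between two stats snapshots.

    `key` is a hypothetical counter name; use whichever cumulative field get_stats() reports.
    """
    if elapsed_s <= 0:
        raise ValueError("elapsed_s must be positive")
    return (curr.get(key, 0) - prev.get(key, 0)) / elapsed_s

# Usage:
# s1 = ray.get(worker.get_stats.remote()); time.sleep(30)
# s2 = ray.get(worker.get_stats.remote())
# print(f"{games_per_second(s1, s2, 30.0):.2f} games/s")
```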

In [None]:
# Check Redis buffer status
try:
    r = redis.Redis(
        host=HEAD_NODE_IP,
        port=6379,
        password=REDIS_PASSWORD,
        decode_responses=True
    )
    
    print("Buffer Status:")
    print(f"  Total experiences: {r.llen('bgai:buffer:experiences')}")
    print(f"  Total episodes: {r.llen('bgai:buffer:episodes')}")
    
    # Check registered workers (SCAN iterates incrementally instead of blocking Redis like KEYS)
    worker_keys = list(r.scan_iter(match='bgai:worker:metrics:*'))
    print(f"  Active workers: {len(worker_keys)}")
    for key in worker_keys:
        print(f"    - {key.split(':')[-1]}")
except Exception as e:
    print(f"Could not check buffer: {e}")

## 6. Cleanup

In [None]:
# Stop the worker gracefully
try:
    ray.get(worker.stop.remote())
    print("Worker stopped.")
except Exception as e:
    print(f"Error stopping worker: {e}")

# Disconnect from Ray
ray.shutdown()
print("Disconnected from Ray cluster.")

In [None]:
# Disconnect Tailscale (optional)
!sudo tailscale down
print("Tailscale disconnected.")

## Troubleshooting

### Tailscale Issues
- **Auth key expired**: Generate a new key at https://login.tailscale.com/admin/settings/keys
- **Can't reach head node**: Check that Tailscale is running on both machines

### Ray Connection Issues
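- **Connection refused**: Ensure the Ray head node was started with the client server enabled (`--ray-client-server-port=10001`)
- **Version mismatch**: Ray Client requires the same Ray version and Python minor version on Colab and on the head node
- **Namespace mismatch**: Both coordinator and worker must use `namespace="bgai"`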

### Redis Issues
- **Auth failed**: Check the Redis password matches
- **Connection refused**: Ensure Redis is bound to 0.0.0.0 or the head node's Tailscale IP (the `bind` directive in redis.conf)

### GPU/Memory Issues
- **OOM**: Reduce `batch_size` in WORKER_CONFIG
- **No GPU**: Check that the Colab runtime is GPU-enabled (Runtime > Change runtime type)