# LOCOMOT.IO Neural Network Training v2

## Quick Start (Works in Colab!)
1. Runtime → Run all
2. If no brain file is found locally, upload `old_brain.json` when prompted
3. Wait ~5-10 min for training
4. Download the timestamped `brain_YYYYMMDD_HHMMSS.json` when done

**Two training modes:**
- **Self-play**: AI learns by playing against itself
- **Imitation**: AI learns from recorded top player data

In [None]:
# Setup - works in Colab or locally
import subprocess
import sys

# Install dependencies if needed (Colab)
try:
    import torch
    from tqdm.auto import tqdm
except ImportError:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'torch', 'numpy', 'matplotlib', 'tqdm', '-q'])
    from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import json
from collections import deque
from copy import deepcopy
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using: {device}')

# Colab registers the google.colab module; its presence decides whether the
# notebook must ask the user to upload files interactively.
IN_COLAB = 'google.colab' in sys.modules
if not IN_COLAB:
    print('Running locally')
else:
    from google.colab import files
    print('Running in Colab - will prompt for file uploads')

## Setup (just run these cells)

In [None]:
# Architecture constants
OLD_INPUT_SIZE = 48  # 8 directions × 6 features
NEW_INPUT_SIZE = 60  # 8 directions × 7 features + 4 state inputs
HIDDEN_SIZE = 64
OUTPUT_SIZE = 3

class LocomotNetwork(nn.Module):
    """Q-network: a 2-hidden-layer ReLU MLP from a vision vector to one value per action.

    The Linear layers sit at indices 0, 2 and 4 of ``self.net``; those indices
    are the keys ('net.0.weight', ...) of the JSON brain format, so the
    Sequential layout must not change.
    """

    def __init__(self, input_size=NEW_INPUT_SIZE):
        super().__init__()
        self.input_size = input_size
        layers = [
            nn.Linear(input_size, HIDDEN_SIZE),
            nn.ReLU(),
            nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
            nn.ReLU(),
            nn.Linear(HIDDEN_SIZE, OUTPUT_SIZE),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

    def get_action(self, state, epsilon=0.0):
        """Pick an action epsilon-greedily.

        Args:
            state: one observation vector (array-like of floats).
            epsilon: probability of returning a uniformly random action.

        Returns:
            int action index (0, 1 or 2).
        """
        explore = random.random() < epsilon
        if explore:
            return random.randint(0, 2)
        with torch.no_grad():
            batch = torch.FloatTensor(state).unsqueeze(0).to(device)
            return self.forward(batch).argmax(dim=1).item()


def load_existing_weights(model, json_path):
    """Load a brain JSON whose layer shapes already match *model*.

    The JSON must contain 'net.{0,2,4}.{weight,bias}' entries as nested lists;
    any extra keys (such as 'input_size') are ignored. The model is updated in
    place with a strict ``load_state_dict`` and also returned.
    """
    with open(json_path, 'r') as fh:
        raw = json.load(fh)

    layer_keys = [f'net.{layer}.{kind}' for layer in (0, 2, 4) for kind in ('weight', 'bias')]
    model.load_state_dict({key: torch.FloatTensor(raw[key]) for key in layer_keys})
    return model


def transfer_from_old_brain(new_model, old_json_path):
    """
    Transfer learning: Load old 48-input brain weights into new 60-input model.

    Old format (48 inputs = 8 dirs × 6 features):
      Per direction: [food, self, wall, smaller_head, bigger_head, enemy_body]

    New format (60 inputs = 8 dirs × 7 features + 4 state):
      Per direction: [health_pickup, gun_pickup, self, wall, smaller_head, bigger_head, enemy_body]
      State: [health_ratio, arena_position, threat_density, my_length]

    Mapping strategy:
    - Old 'food' weights → copied unchanged into both health_pickup and gun_pickup
    - Old self/wall/enemy weights → copied to their new positions (shifted by one slot)
    - New state inputs → Xavier-uniform initialised; learned during training
    - First-layer bias and both later layers are copied as-is (their shapes don't change)

    Args:
        new_model: a NEW_INPUT_SIZE-input LocomotNetwork; its weights are replaced in place.
        old_json_path: path to a brain JSON whose 'net.0.weight' has OLD_INPUT_SIZE columns.

    Returns:
        new_model (the same object), for convenience.

    Raises:
        ValueError: if the JSON's first layer is not OLD_INPUT_SIZE wide. A brain
            that already has NEW_INPUT_SIZE inputs (e.g. one written by
            save_brain_json) must be loaded with load_existing_weights instead;
            remapping it here would silently scramble its columns.
    """
    with open(old_json_path, 'r') as f:
        old_weights = json.load(f)

    old_w0 = torch.FloatTensor(old_weights['net.0.weight'])  # [HIDDEN_SIZE, OLD_INPUT_SIZE]
    old_b0 = torch.FloatTensor(old_weights['net.0.bias'])    # [HIDDEN_SIZE]

    # Refuse anything that isn't a v1 brain: the column remap below is only
    # meaningful for the 8 × 6 layout.
    if old_w0.dim() != 2 or old_w0.shape[1] != OLD_INPUT_SIZE:
        raise ValueError(
            f"{old_json_path}: expected a {OLD_INPUT_SIZE}-input brain, got first-layer "
            f"shape {tuple(old_w0.shape)}; use load_existing_weights() for brains that "
            f"already match the model"
        )

    num_dirs = 8
    old_per_dir = OLD_INPUT_SIZE // num_dirs  # 6 features per direction
    new_per_dir = 7                           # health + gun + the 5 carried-over features
    state_start = num_dirs * new_per_dir      # index of the first state input (56)

    # Create new first layer weights [HIDDEN_SIZE, NEW_INPUT_SIZE]
    new_w0 = torch.zeros(HIDDEN_SIZE, NEW_INPUT_SIZE)

    for d in range(num_dirs):
        old_base = d * old_per_dir
        new_base = d * new_per_dir

        # Old food → health_pickup and gun_pickup (new slots 0 and 1, identical initially)
        food_weights = old_w0[:, old_base]
        new_w0[:, new_base + 0] = food_weights
        new_w0[:, new_base + 1] = food_weights

        # self, wall, smaller_head, bigger_head, enemy_body: old slots 1-5 → new slots 2-6
        new_w0[:, new_base + 2:new_base + new_per_dir] = old_w0[:, old_base + 1:old_base + old_per_dir]

    # State inputs have no v1 counterpart; start them Xavier-uniform and let training learn them
    nn.init.xavier_uniform_(new_w0[:, state_start:NEW_INPUT_SIZE])

    # Load into new model
    new_state_dict = new_model.state_dict()
    new_state_dict['net.0.weight'] = new_w0
    new_state_dict['net.0.bias'] = old_b0  # Bias stays same size

    # Copy remaining layers directly (they stay the same size)
    for key in ('net.2.weight', 'net.2.bias', 'net.4.weight', 'net.4.bias'):
        new_state_dict[key] = torch.FloatTensor(old_weights[key])

    new_model.load_state_dict(new_state_dict)
    print("✓ Transferred weights from old 48-input brain to new 60-input model")
    print("  - Spatial features mapped (food → health+gun)")
    print("  - State inputs initialized with Xavier uniform")
    return new_model


def save_brain_json(model, path):
    """Write *model*'s weights to *path* as the JavaScript brain JSON.

    The file holds 'input_size' followed by 'net.{0,2,4}.{weight,bias}' as
    nested Python lists (tensors are moved to CPU first).
    """
    params = model.state_dict()
    payload = {'input_size': model.input_size}  # lets loaders check compatibility
    for layer in (0, 2, 4):
        for kind in ('weight', 'bias'):
            key = f'net.{layer}.{kind}'
            payload[key] = params[key].cpu().numpy().tolist()

    with open(path, 'w') as out:
        json.dump(payload, out)
    print(f'Saved brain to {path} (input_size={model.input_size})')

In [None]:
class LocomotEnv:
    """FAST simplified LOCOMOT.IO environment for self-play with extended inputs"""
    
    WORLD_COLS = 50  # Smaller world = faster
    WORLD_ROWS = 40
    DIRECTIONS = [(0, -1), (1, 0), (0, 1), (-1, 0)]  # UP, RIGHT, DOWN, LEFT
    
    # Pre-compute direction offsets for 8 rays
    RAY_OFFSETS = [
        (0, -1), (1, -1), (1, 0), (1, 1),
        (0, 1), (-1, 1), (-1, 0), (-1, -1)
    ]
    
    def __init__(self, num_agents=4):
        self.num_agents = num_agents
        self.reset()
    
    def reset(self):
        self.agents = []
        self.health_pickups = set()  # Health pickups (green)
        self.gun_pickups = set()     # Gun pickups (various colors)
        self.segment_map = {}  # (x,y) -> (agent_idx, seg_idx) for fast collision
        self.step_count = 0
        
        # Spawn agents with health tracking
        for i in range(self.num_agents):
            x = random.randint(8, self.WORLD_COLS - 8)
            y = random.randint(8, self.WORLD_ROWS - 8)
            direction = random.randint(0, 3)
            dx, dy = self.DIRECTIONS[direction]
            
            # Each segment has hp/maxHp (head is invincible)
            segments = []
            for j in range(4):
                seg = {
                    'x': x - j*dx,
                    'y': y - j*dy,
                    'hp': 100 if j > 0 else float('inf'),  # Head invincible
                    'maxHp': 100 if j > 0 else float('inf')
                }
                segments.append(seg)
            
            self.agents.append({
                'segments': segments,
                'direction': direction,
                'alive': True,
                'score': 0,
                'prev_tail_pos': None  # Track previous tail position for pickup growth
            })
            
            # Add to segment map
            for seg_idx, seg in enumerate(segments):
                self.segment_map[(seg['x'], seg['y'])] = (i, seg_idx)
        
        # Spawn pickups (70% gun, 30% health like the real game)
        for _ in range(15):
            self.gun_pickups.add((
                random.randint(0, self.WORLD_COLS - 1),
                random.randint(0, self.WORLD_ROWS - 1)
            ))
        for _ in range(5):
            self.health_pickups.add((
                random.randint(0, self.WORLD_COLS - 1),
                random.randint(0, self.WORLD_ROWS - 1)
            ))
        
        return [self.get_vision_v2(i) for i in range(self.num_agents)]
    
    def get_health_ratio(self, agent):
        """Calculate health ratio (0-1) from body segments"""
        total_hp = 0
        total_max = 0
        for seg in agent['segments'][1:]:  # Skip head (infinite hp)
            if seg['hp'] != float('inf'):
                total_hp += seg['hp']
                total_max += seg['maxHp']
        return total_hp / total_max if total_max > 0 else 1.0
    
    def get_arena_position(self, head_x, head_y):
        """Calculate arena position safety (0-1, higher = more centered)"""
        dist_to_edge = min(head_x, head_y, 
                          self.WORLD_COLS - 1 - head_x, 
                          self.WORLD_ROWS - 1 - head_y)
        max_dist = min(self.WORLD_COLS, self.WORLD_ROWS) / 2
        return min(dist_to_edge / max_dist, 1.0)
    
    def get_threat_density(self, agent_idx, head_x, head_y):
        """Calculate threat density from nearby enemies (0-1)"""
        threat = 0.0
        my_length = len(self.agents[agent_idx]['segments'])
        
        for i, other in enumerate(self.agents):
            if i == agent_idx or not other['alive']:
                continue
            
            other_head = other['segments'][0]
            dist = abs(other_head['x'] - head_x) + abs(other_head['y'] - head_y)
            
            if dist < 15:
                # Bigger enemies are more threatening
                size_factor = 2.0 if len(other['segments']) > my_length else 0.5
                threat += size_factor / max(dist, 1)
        
        # Normalize to 0-1 (cap at reasonable max)
        return min(threat / 3.0, 1.0)
    
    def get_vision_v2(self, agent_idx):
        """
        Extended vision with 60 inputs:
        - 56 spatial: 8 directions × 7 features
          [health_pickup, gun_pickup, self, wall, smaller_head, bigger_head, enemy_body]
        - 4 state: [health_ratio, arena_position, threat_density, my_length_normalized]
        """
        agent = self.agents[agent_idx]
        if not agent['alive']:
            return np.zeros(NEW_INPUT_SIZE, dtype=np.float32)
        
        head = agent['segments'][0]
        head_x, head_y = head['x'], head['y']
        current_dir = agent['direction']
        my_length = len(agent['segments'])
        my_segments = set((s['x'], s['y']) for s in agent['segments'][1:])
        
        vision = np.zeros(NEW_INPUT_SIZE, dtype=np.float32)
        
        # Rotate ray offsets based on direction
        rot = current_dir
        
        for ray_idx in range(8):
            rotated_idx = (ray_idx + rot * 2) % 8
            dx, dy = self.RAY_OFFSETS[rotated_idx]
            
            health_dist = gun_dist = self_danger = wall_dist = 0.0
            enemy_smaller = enemy_bigger = enemy_body = 0.0
            
            for dist in range(1, 16):
                cx, cy = head_x + dx * dist, head_y + dy * dist
                
                # Wall
                if cx < 0 or cx >= self.WORLD_COLS or cy < 0 or cy >= self.WORLD_ROWS:
                    if wall_dist == 0:
                        wall_dist = 1.0 / dist
                    break
                
                # Self body
                if self_danger == 0 and (cx, cy) in my_segments:
                    self_danger = 1.0 / dist
                
                # Health pickup
                if health_dist == 0 and (cx, cy) in self.health_pickups:
                    health_dist = 1.0 / dist
                
                # Gun pickup
                if gun_dist == 0 and (cx, cy) in self.gun_pickups:
                    gun_dist = 1.0 / dist
                
                # Other agents
                pos = (cx, cy)
                if pos in self.segment_map:
                    other_idx, seg_idx = self.segment_map[pos]
                    if other_idx != agent_idx and self.agents[other_idx]['alive']:
                        other_len = len(self.agents[other_idx]['segments'])
                        if seg_idx == 0:  # Head
                            if other_len < my_length and enemy_smaller == 0:
                                enemy_smaller = 1.0 / dist
                            elif enemy_bigger == 0:
                                enemy_bigger = 1.0 / dist
                        elif enemy_body == 0:
                            enemy_body = 1.0 / dist
            
            # Store 7 features per direction
            base = ray_idx * 7
            vision[base:base+7] = [health_dist, gun_dist, self_danger, wall_dist, 
                                   enemy_smaller, enemy_bigger, enemy_body]
        
        # Add 4 state inputs (indices 56-59)
        vision[56] = self.get_health_ratio(agent)
        vision[57] = self.get_arena_position(head_x, head_y)
        vision[58] = self.get_threat_density(agent_idx, head_x, head_y)
        vision[59] = min(my_length / 20.0, 1.0)  # Normalized length (cap at 20)
        
        return vision
    
    def step(self, actions):
        """Execute actions. Returns (observations, rewards, dones, game_over)"""
        self.step_count += 1
        rewards = np.zeros(self.num_agents, dtype=np.float32)
        
        # Apply turns
        for i, agent in enumerate(self.agents):
            if not agent['alive']:
                continue
            action = actions[i]
            if action == 0:
                agent['direction'] = (agent['direction'] - 1) % 4
            elif action == 2:
                agent['direction'] = (agent['direction'] + 1) % 4
        
        # Move agents and update segment map
        for i, agent in enumerate(self.agents):
            if not agent['alive']:
                continue
            
            # Save previous tail position BEFORE movement (for pickup growth)
            old_tail = agent['segments'][-1]
            agent['prev_tail_pos'] = (old_tail['x'], old_tail['y'])
            
            # Remove old tail from map
            tail_pos = agent['prev_tail_pos']
            if tail_pos in self.segment_map and self.segment_map[tail_pos][0] == i:
                del self.segment_map[tail_pos]
            
            dx, dy = self.DIRECTIONS[agent['direction']]
            head = agent['segments'][0]
            new_head = {
                'x': head['x'] + dx,
                'y': head['y'] + dy,
                'hp': float('inf'),
                'maxHp': float('inf')
            }
            
            agent['segments'].insert(0, new_head)
            agent['segments'].pop()
            
            # Update segment map
            for seg_idx, seg in enumerate(agent['segments']):
                self.segment_map[(seg['x'], seg['y'])] = (i, seg_idx)
            
            rewards[i] += 0.01
        
        # Check collisions
        for i, agent in enumerate(self.agents):
            if not agent['alive']:
                continue
            
            head = agent['segments'][0]
            hx, hy = head['x'], head['y']
            
            # Wall
            if hx < 0 or hx >= self.WORLD_COLS or hy < 0 or hy >= self.WORLD_ROWS:
                agent['alive'] = False
                rewards[i] -= 5.0
                continue
            
            # Self collision
            my_body = set((s['x'], s['y']) for s in agent['segments'][1:])
            if (hx, hy) in my_body:
                agent['alive'] = False
                rewards[i] -= 5.0
                continue
            
            # Enemy collision
            for j, other in enumerate(self.agents):
                if i == j or not other['alive']:
                    continue
                
                oh = other['segments'][0]
                if hx == oh['x'] and hy == oh['y']:
                    # Head-to-head
                    if len(agent['segments']) > len(other['segments']):
                        other['alive'] = False
                        rewards[j] -= 5.0
                        rewards[i] += 3.0
                    elif len(agent['segments']) < len(other['segments']):
                        agent['alive'] = False
                        rewards[i] -= 5.0
                        rewards[j] += 3.0
                    else:
                        agent['alive'] = other['alive'] = False
                        rewards[i] = rewards[j] = -3.0
                    break
                
                # Head into body
                other_body = set((s['x'], s['y']) for s in other['segments'][1:])
                if (hx, hy) in other_body:
                    agent['alive'] = False
                    rewards[i] -= 5.0
                    rewards[j] += 2.0
                    break
        
        # Pickup collection
        for i, agent in enumerate(self.agents):
            if not agent['alive']:
                continue
            
            head = agent['segments'][0]
            head_pos = (head['x'], head['y'])
            
            # Health pickup - heals segments
            if head_pos in self.health_pickups:
                self.health_pickups.remove(head_pos)
                # Heal all body segments
                for seg in agent['segments'][1:]:
                    if seg['hp'] != float('inf'):
                        seg['hp'] = min(seg['hp'] + 30, seg['maxHp'])
                rewards[i] += 0.5
                # Respawn
                self.health_pickups.add((
                    random.randint(0, self.WORLD_COLS - 1),
                    random.randint(0, self.WORLD_ROWS - 1)
                ))
            
            # Gun pickup - grows snake (use PREVIOUS tail position to avoid self-collision)
            if head_pos in self.gun_pickups:
                self.gun_pickups.remove(head_pos)
                prev_tail = agent['prev_tail_pos']
                agent['segments'].append({
                    'x': prev_tail[0],
                    'y': prev_tail[1],
                    'hp': 100,
                    'maxHp': 100
                })
                # Update segment map for new segment
                self.segment_map[prev_tail] = (i, len(agent['segments']) - 1)
                agent['score'] += 1
                rewards[i] += 1.0
                # Respawn
                self.gun_pickups.add((
                    random.randint(0, self.WORLD_COLS - 1),
                    random.randint(0, self.WORLD_ROWS - 1)
                ))
        
        # Random damage to simulate combat
        if self.step_count % 20 == 0:
            for agent in self.agents:
                if agent['alive'] and len(agent['segments']) > 1:
                    # Random segment takes minor damage
                    idx = random.randint(1, len(agent['segments']) - 1)
                    if agent['segments'][idx]['hp'] != float('inf'):
                        agent['segments'][idx]['hp'] -= random.randint(5, 15)
                        if agent['segments'][idx]['hp'] <= 0:
                            # Segment destroyed - snake shrinks
                            agent['segments'].pop(idx)
        
        observations = [self.get_vision_v2(i) for i in range(self.num_agents)]
        dones = [not a['alive'] for a in self.agents]
        
        alive = sum(1 for a in self.agents if a['alive'])
        game_over = alive <= 1 or self.step_count >= 500
        
        if game_over and alive == 1:
            for i, a in enumerate(self.agents):
                if a['alive']:
                    rewards[i] += 5.0
        
        return observations, rewards.tolist(), dones, game_over

In [None]:
class SelfPlayTrainer:
    """Double-DQN self-play trainer for LocomotNetwork.

    The learner always plays agent 0 of a 4-agent LocomotEnv. The other three
    agents are driven either by a copy of the current model (30% each) or by a
    frozen snapshot from ``opponent_pool``, sampled with a bias toward newer
    snapshots. Optional imitation learning on recorded player frames can be
    run before self-play.
    """

    def __init__(self, pool_size=10, old_brain_path=None):
        self.current_model = LocomotNetwork(input_size=NEW_INPUT_SIZE).to(device)

        loaded = False
        if old_brain_path and os.path.exists(old_brain_path):
            self._load_brain(old_brain_path)
            loaded = True
            print(f"Loaded weights from {old_brain_path}")
        else:
            print("Starting with random weights")

        self.target_model = LocomotNetwork(input_size=NEW_INPUT_SIZE).to(device)
        self.target_model.load_state_dict(self.current_model.state_dict())

        self.opponent_pool = [deepcopy(self.current_model.state_dict())]
        self.pool_size = pool_size
        self.memory = deque(maxlen=50000)

        self.optimizer = optim.Adam(self.current_model.parameters(), lr=0.001)
        self.gamma = 0.99
        # Only a brain that was actually loaded justifies starting with less exploration
        self.epsilon = 0.5 if loaded else 1.0
        self.epsilon_min = 0.05
        self.epsilon_decay = 0.995
        self.batch_size = 64

        self.episode_rewards = []
        self.win_rates = []

    def _load_brain(self, path):
        """Load a brain JSON into current_model, choosing the loader by input width.

        A brain that already has NEW_INPUT_SIZE inputs (e.g. one saved by this
        notebook) is loaded directly; a legacy OLD_INPUT_SIZE brain is remapped
        with transfer_from_old_brain.

        Raises:
            ValueError: if the brain's first layer has any other width.
        """
        with open(path, 'r') as f:
            weights = json.load(f)
        input_size = len(weights['net.0.weight'][0])

        if input_size == NEW_INPUT_SIZE:
            load_existing_weights(self.current_model, path)
            print(f"✓ Brain already has {NEW_INPUT_SIZE} inputs - loaded without remapping")
        elif input_size == OLD_INPUT_SIZE:
            transfer_from_old_brain(self.current_model, path)
        else:
            raise ValueError(
                f"{path}: unsupported brain input size {input_size} "
                f"(expected {OLD_INPUT_SIZE} or {NEW_INPUT_SIZE})"
            )

    def train_from_player_data(self, data_path, epochs=10):
        """Imitation learning - train from recorded top player data.

        The network's outputs are treated as logits and fitted with
        cross-entropy to the recorded actions. Frames whose state length does
        not match the model (e.g. recorded by an older client) or whose action
        is outside 0..OUTPUT_SIZE-1 are skipped instead of crashing training.

        Args:
            data_path: JSON file with a list of {"state": [...], "action": 0/1/2}.
            epochs: number of passes over the data.
        """
        with open(data_path, 'r') as f:
            data = json.load(f)

        if not data:
            print("No player data found")
            return

        expected = self.current_model.input_size
        frames = [d for d in data
                  if len(d['state']) == expected and 0 <= d['action'] < OUTPUT_SIZE]
        if len(frames) < len(data):
            print(f"Skipping {len(data) - len(frames)} frames with state size != {expected} "
                  f"or invalid action")
        if not frames:
            print("No usable player data found")
            return

        states = torch.FloatTensor(np.array([d['state'] for d in frames], dtype=np.float32)).to(device)
        actions = torch.LongTensor([int(d['action']) for d in frames]).to(device)

        print(f"Training on {len(frames)} recorded frames...")

        criterion = nn.CrossEntropyLoss()

        for epoch in tqdm(range(epochs), desc="Imitation Learning"):
            # Shuffle
            perm = torch.randperm(len(states))
            states, actions = states[perm], actions[perm]

            total_loss = 0.0
            num_batches = 0
            correct = 0

            for i in range(0, len(states), self.batch_size):
                batch_states = states[i:i+self.batch_size]
                batch_actions = actions[i:i+self.batch_size]

                outputs = self.current_model(batch_states)
                loss = criterion(outputs, batch_actions)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                num_batches += 1
                correct += (outputs.argmax(1) == batch_actions).sum().item()

            acc = correct / len(states) * 100
            avg_loss = total_loss / num_batches
            tqdm.write(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} | Accuracy: {acc:.1f}%")

        # Update target
        self.target_model.load_state_dict(self.current_model.state_dict())
        print("Imitation learning complete!")

    def select_opponents(self, num_opponents):
        """Return *num_opponents* networks: 30% live copies, otherwise pool snapshots.

        Pool indices are drawn from a triangular distribution with its mode at
        the newest snapshot.
        """
        opponents = []
        for _ in range(num_opponents):
            if random.random() < 0.3:
                opponents.append(deepcopy(self.current_model))
            else:
                idx = min(int(random.triangular(0, len(self.opponent_pool), len(self.opponent_pool))),
                          len(self.opponent_pool) - 1)
                opponent = LocomotNetwork(input_size=NEW_INPUT_SIZE).to(device)
                opponent.load_state_dict(self.opponent_pool[idx])
                opponent.eval()
                opponents.append(opponent)
        return opponents

    def play_episode(self):
        """Play one game and push the learner's transitions into replay memory.

        Returns:
            (total_reward, learner_alive_at_end, learner_length, learner_health_ratio)
        """
        env = LocomotEnv(num_agents=4)
        observations = env.reset()
        opponents = self.select_opponents(3)

        episode_transitions = []
        total_reward = 0

        while True:
            learner_was_alive = env.agents[0]['alive']

            actions = [self.current_model.get_action(observations[0], self.epsilon)]
            for i, opp in enumerate(opponents):
                if env.agents[i + 1]['alive']:
                    actions.append(opp.get_action(observations[i + 1], 0.0))
                else:
                    actions.append(1)

            next_obs, rewards, dones, game_over = env.step(actions)

            # Record only steps the learner actually played (including the step it
            # died on); afterwards the env returns all-zero observations that would
            # only pollute the replay buffer.
            if learner_was_alive:
                # The end of the game is terminal even for a survivor, so the win
                # bonus is not bootstrapped past it. (Timeouts are treated as
                # terminal too, as a simplification.)
                terminal = dones[0] or game_over
                episode_transitions.append((observations[0], actions[0], rewards[0], next_obs[0], terminal))
                total_reward += rewards[0]

            observations = next_obs
            # Nothing more is recorded once the learner is dead, so stop early.
            if game_over or not env.agents[0]['alive']:
                break

        for t in episode_transitions:
            self.memory.append(t)

        final_health = env.get_health_ratio(env.agents[0]) if env.agents[0]['alive'] else 0
        return total_reward, env.agents[0]['alive'], len(env.agents[0]['segments']), final_health

    def train_step(self):
        """One Double-DQN update on a random minibatch; returns the loss (0 if memory is too small)."""
        if len(self.memory) < self.batch_size:
            return 0

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(np.array(states)).to(device)
        actions = torch.LongTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        next_states = torch.FloatTensor(np.array(next_states)).to(device)
        dones = torch.FloatTensor(dones).to(device)

        current_q = self.current_model(states).gather(1, actions.unsqueeze(1))

        with torch.no_grad():
            # Double DQN: online net picks the next action, target net evaluates it
            next_actions = self.current_model(next_states).argmax(1)
            next_q = self.target_model(next_states).gather(1, next_actions.unsqueeze(1)).squeeze()
            target_q = rewards + self.gamma * next_q * (1 - dones)

        loss = nn.MSELoss()(current_q.squeeze(), target_q)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.current_model.parameters(), 1.0)
        self.optimizer.step()
        return loss.item()

    def update_target(self):
        """Polyak-average the target network toward the online network (tau = 0.01)."""
        tau = 0.01
        for tp, cp in zip(self.target_model.parameters(), self.current_model.parameters()):
            tp.data.copy_(tau * cp.data + (1 - tau) * tp.data)

    def add_to_pool(self):
        """Snapshot the current model; when over capacity, drop the middle snapshot."""
        self.opponent_pool.append(deepcopy(self.current_model.state_dict()))
        if len(self.opponent_pool) > self.pool_size:
            del self.opponent_pool[len(self.opponent_pool) // 2]

    def train(self, num_episodes=2000, save_every=500):
        """Run self-play; logs avg reward / win rate every 25 episodes and returns the model."""
        import time
        wins = 0
        recent_rewards = deque(maxlen=50)
        start_time = time.time()

        pbar = tqdm(range(num_episodes), desc="Self-Play Training")
        for episode in pbar:
            reward, won, final_length, final_health = self.play_episode()
            recent_rewards.append(reward)
            if won:
                wins += 1

            for _ in range(2):
                self.train_step()

            self.update_target()
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

            if episode % 50 == 0 and episode > 0:
                self.add_to_pool()

            if episode % 25 == 0:
                win_rate = wins / 25 if episode > 0 else 0
                avg_reward = np.mean(recent_rewards) if recent_rewards else 0
                pbar.set_postfix({
                    'R': f'{avg_reward:.1f}',
                    'Win': f'{win_rate:.0%}',
                    'ε': f'{self.epsilon:.2f}'
                })
                self.episode_rewards.append(avg_reward)
                self.win_rates.append(win_rate)
                wins = 0

            if episode % save_every == 0 and episode > 0:
                save_brain_json(self.current_model, f'brain_v2_ep{episode}.json')

        elapsed = max(time.time() - start_time, 1e-9)  # guard the ep/s division
        print(f'\nDone! {elapsed:.1f}s ({num_episodes/elapsed:.1f} ep/s)')
        return self.current_model

## Training
Run the cell below to start training. Takes ~5-10 minutes.

In [None]:
# Auto-find newest brain file or upload
import glob
from datetime import datetime

def find_newest_brain(patterns=('brain_*.json', 'old_brain*.json')):
    """Return the most recently modified brain file in the working directory.

    Candidates are files matching any of *patterns* (relative glob patterns).
    They are ranked by filesystem modification time, not by any timestamp in
    the filename.

    Args:
        patterns: glob patterns to search. The default covers files this
            notebook writes (``brain_*.json``) and legacy ``old_brain*.json``
            uploads.

    Returns:
        The newest matching path, or None when nothing matches.
    """
    # dict.fromkeys de-duplicates paths matched by several patterns, keeping order
    candidates = list(dict.fromkeys(path for pattern in patterns for path in glob.glob(pattern)))
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)

# Pick a starting brain: newest local file first, otherwise ask for an upload in Colab.
brain_path = find_newest_brain()

if brain_path:
    print(f"✓ Found brain: {brain_path}")
elif not IN_COLAB:
    print("⚠ No brain found - starting fresh")
else:
    print("No brain found. Upload one:")
    uploaded = files.upload()
    brain_files = [name for name in uploaded if name.endswith('.json')]
    if not brain_files:
        print("⚠ No brain uploaded - starting fresh")
        brain_path = None
    else:
        brain_path = brain_files[0]
        print(f"✓ Using uploaded: {brain_path}")

# Fetch player training data from server (best effort: any failure just skips imitation)
import urllib.request
print("\nFetching player training data from server...")
has_player_data = False
try:
    req = urllib.request.Request(
        'https://locomot-io.savecharlie.partykit.dev/party/collective',
        method='POST',
        headers={'Content-Type': 'application/json'},
        data=json.dumps({'type': 'get_player_data'}).encode(),
    )
    with urllib.request.urlopen(req, timeout=10) as response:
        player_data = json.loads(response.read().decode())
        if not (player_data and len(player_data) > 0):
            print("No player data on server yet")
        else:
            with open('player_data.json', 'w') as f:
                json.dump(player_data, f)
            print(f"✓ Got {len(player_data)} frames from top players!")
            has_player_data = True
except Exception as e:
    print(f"Could not fetch player data: {e}")
    has_player_data = False

# Train: optional imitation warm-start, then self-play
trainer = SelfPlayTrainer(pool_size=10, old_brain_path=brain_path)

if has_player_data:
    print("\n=== Imitation Learning from Top Players ===")
    trainer.train_from_player_data('player_data.json', epochs=10)

print("\n=== Self-Play Training ===")
trained_model = trainer.train(num_episodes=2000, save_every=500)

In [None]:
# Plot training progress. SelfPlayTrainer.train appends one point every 25 episodes:
# the mean reward over the last 50 episodes and the survival rate over the last 25.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(trainer.episode_rewards)
ax1.set_title('Average Episode Reward')
ax1.set_xlabel('Episode (×25)')
ax1.set_ylabel('Reward')
ax1.grid(True)

ax2.plot(trainer.win_rates)
ax2.set_title('Win Rate vs Self-Play Pool')
ax2.set_xlabel('Episode (×25)')
ax2.set_ylabel('Win Rate')
ax2.grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Save with timestamp so next run auto-finds it
from datetime import datetime
# Timestamped name, so find_newest_brain() picks this file up on the next run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = 'brain_' + timestamp + '.json'

save_brain_json(trained_model, filename)
for line in (f"\n✓ Saved as {filename}", "Next run will auto-load this brain!"):
    print(line)

if IN_COLAB:
    files.download(filename)
    print('Downloaded! Update index.html with this brain.')

## Next Steps

After training completes, ask Iris to:
1. Update `getVision()` in index.html for 60 inputs
2. Replace the brain data with the newly saved `brain_YYYYMMDD_HHMMSS.json`

The new format adds:
- Split food → health_pickup + gun_pickup
- 4 state inputs at the end (health, position, threat, length)