# LOCOMOT.IO Neural Network Training v2

## Quick Start (Works in Colab!)
1. Runtime → Run all
2. Upload `old_brain.json` when prompted
3. Wait ~5-10 min for training
4. Download `brain_v2_final.json` when done

**Two training modes:**
- **Self-play**: AI learns by playing against itself
- **Imitation**: AI learns from recorded top player data

In [None]:
# Setup - works in Colab or locally
import subprocess
import sys

# Install dependencies if needed (Colab)
try:
    import torch
    from tqdm.auto import tqdm
except ImportError:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'torch', 'numpy', 'matplotlib', 'tqdm', '-q'])
    from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import json
from collections import deque
from copy import deepcopy
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm

# Prefer the GPU whenever PyTorch reports one; everything else runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using: {device}')

# Detect Colab by checking whether the google.colab package is already loaded;
# only then is its file-upload helper imported.
IN_COLAB = 'google.colab' in sys.modules
if not IN_COLAB:
    print('Running locally')
else:
    from google.colab import files
    print('Running in Colab - will prompt for file uploads')

## Setup (just run these cells)

In [None]:
# Architecture constants shared by the network, the weight-transfer helper and
# the environment's observation vector.
OLD_INPUT_SIZE = 85  # Previous: 8 directions × 10 features + 5 state
NEW_INPUT_SIZE = 96  # New: 8 directions × 11 features + 8 state (gun type + MVP)
HIDDEN1_SIZE = 128   # Larger hidden layer for more features
HIDDEN2_SIZE = 64    # Width of the second hidden layer
OUTPUT_SIZE = 3      # One Q-value per action; LocomotEnv.step treats 0 as direction-1, 2 as direction+1, anything else as straight

class LocomotNetwork(nn.Module):
    """Fully connected Q-network: input -> HIDDEN1 -> HIDDEN2 -> OUTPUT_SIZE.

    The layers live in ``self.net`` (an ``nn.Sequential``), so the parameters
    are named ``net.0.*``, ``net.2.*`` and ``net.4.*``. Those names are the
    keys of the brain JSON read and written by the helper functions in this
    file, so the layer order must not change.

    Args:
        input_size: Length of one observation vector.
    """

    def __init__(self, input_size=NEW_INPUT_SIZE):
        super().__init__()
        self.input_size = input_size
        self.net = nn.Sequential(
            nn.Linear(input_size, HIDDEN1_SIZE),
            nn.ReLU(),
            nn.Linear(HIDDEN1_SIZE, HIDDEN2_SIZE),
            nn.ReLU(),
            nn.Linear(HIDDEN2_SIZE, OUTPUT_SIZE)
        )

    def forward(self, x):
        """Return one Q-value per action for each row of ``x``."""
        return self.net(x)

    def get_action(self, state, epsilon=0.0):
        """Get action with epsilon-greedy exploration.

        Args:
            state: One observation (sequence or array of ``input_size`` floats).
            epsilon: Probability of returning a uniformly random action.

        Returns:
            The chosen action index as a Python int.
        """
        # Read the action count from the output layer instead of hard-coding it.
        num_actions = self.net[-1].out_features
        if random.random() < epsilon:
            return random.randint(0, num_actions - 1)
        with torch.no_grad():
            # Put the observation on the device the weights are on, so the
            # call does not depend on a module-level device variable.
            model_device = next(self.parameters()).device
            state_t = torch.as_tensor(
                state, dtype=torch.float32, device=model_device
            ).unsqueeze(0)
            q_values = self.forward(state_t)
            return q_values.argmax(dim=1).item()


def load_existing_weights(model, json_path):
    """Fill ``model`` from a brain JSON file whose layer shapes already match.

    The file must contain the six ``net.<layer>.<weight|bias>`` entries as
    nested lists; any other entries in the file are ignored.

    Args:
        model: Module whose state dict has exactly those six entries.
        json_path: Path of the brain JSON file.

    Returns:
        The same ``model`` object, now holding the loaded weights.
    """
    with open(json_path, 'r') as f:
        raw = json.load(f)

    layer_keys = (
        'net.0.weight', 'net.0.bias',
        'net.2.weight', 'net.2.bias',
        'net.4.weight', 'net.4.bias',
    )
    model.load_state_dict({key: torch.FloatTensor(raw[key]) for key in layer_keys})
    return model


def transfer_from_old_brain(new_model, old_json_path):
    """
    Transfer learning: Load old 85-input (or 60-input) brain weights into the new 96-input model.

    Old format (85 inputs = 8 dirs × 10 features + 5 state):
      Per direction: [health_pickup, machinegun, shotgun, cannon, pulse, self, wall, smaller_head, bigger_head, enemy_body]
      State: [health_ratio, arena_position, threat_density, my_length, current_gun_type]

    Older format (60 inputs = 8 dirs × 7 features + 4 state):
      Per direction: [health, gun, self, wall, smaller, bigger, body]; the single
      gun weight is copied to all four gun-type inputs of the new layout.

    New format (96 inputs = 8 dirs × 11 features + 8 state):
      Per direction: [health_pickup, machinegun, shotgun, cannon, pulse, self, wall, smaller_head, bigger_head, enemy_body, enemy_gun_type]
      State: [health_ratio, arena_position, threat_density, my_length, current_gun_type, is_mvp, mvp_time, mvp_distance]

    Mapping strategy:
    - Old 10 spatial features → new positions (indices 0-9)
    - New enemy_gun_type (index 10) → Xavier-initialized
    - Old 5 state inputs → copy to new state positions (88-92)
    - New MVP inputs (93-95) → Xavier-initialized
    - Any other old input size → whole first layer Xavier-initialized
    - Hidden/output layers → overlapping part copied; added first-hidden
      columns and added second-hidden rows are Xavier-initialized

    Args:
        new_model: Model whose state dict has the ``net.0/2/4`` entries.
        old_json_path: Path of the old brain JSON file.

    Returns:
        ``new_model`` with the transferred weights loaded.
    """
    with open(old_json_path, 'r') as f:
        old_weights = json.load(f)

    old_input_size = len(old_weights['net.0.weight'][0])
    old_hidden_size = len(old_weights['net.0.bias'])

    print(f"  Old brain: {old_input_size} inputs, {old_hidden_size} hidden")

    old_w0 = torch.FloatTensor(old_weights['net.0.weight'])
    old_b0 = torch.FloatTensor(old_weights['net.0.bias'])

    new_w0 = torch.zeros(HIDDEN1_SIZE, NEW_INPUT_SIZE)
    new_b0 = torch.zeros(HIDDEN1_SIZE)

    # Only the hidden units present in both networks can be copied.
    rows = min(old_hidden_size, HIDDEN1_SIZE)
    new_b0[:rows] = old_b0[:rows]

    if old_input_size == 85:
        # Transfer from 85-input brain to 96-input: one slice copy per direction.
        for d in range(8):
            old_base = d * 10  # Old: 10 features per direction
            new_base = d * 11  # New: 11 features per direction
            new_w0[:rows, new_base:new_base + 10] = old_w0[:rows, old_base:old_base + 10]
            # Index 10 (enemy_gun_type) stays zero here, Xavier init below

        # Copy 5 state inputs (80-84 → 88-92)
        new_w0[:rows, 88:93] = old_w0[:rows, 80:85]
    elif old_input_size == 60:
        # Transfer from old 60-input brain
        for d in range(8):
            old_base = d * 7
            new_base = d * 11
            new_w0[:rows, new_base] = old_w0[:rows, old_base]  # health
            # The single old gun weight (a (rows, 1) slice) is broadcast over
            # the four gun-type columns.
            new_w0[:rows, new_base + 1:new_base + 5] = old_w0[:rows, old_base + 1:old_base + 2]
            # self, wall, smaller, bigger, body keep their relative order
            new_w0[:rows, new_base + 5:new_base + 10] = old_w0[:rows, old_base + 2:old_base + 7]

        # Copy 4 state inputs (56-59 → 88-91)
        new_w0[:rows, 88:92] = old_w0[:rows, 56:60]
    else:
        print(f"  Unknown old format ({old_input_size} inputs), using random init")
        nn.init.xavier_uniform_(new_w0)

    # Initialize new inputs with Xavier (in-place on column views of new_w0)
    for d in range(8):
        nn.init.xavier_uniform_(new_w0[:, d*11 + 10:d*11 + 11])  # enemy_gun_type
    nn.init.xavier_uniform_(new_w0[:, 93:96])  # MVP inputs

    new_state_dict = new_model.state_dict()
    new_state_dict['net.0.weight'] = new_w0
    new_state_dict['net.0.bias'] = new_b0

    # Hidden and output layers
    old_w2 = torch.FloatTensor(old_weights['net.2.weight'])
    old_b2 = torch.FloatTensor(old_weights['net.2.bias'])
    new_w2 = torch.zeros(HIDDEN2_SIZE, HIDDEN1_SIZE)
    new_b2 = torch.zeros(HIDDEN2_SIZE)

    h2_copy = min(old_w2.shape[0], HIDDEN2_SIZE)
    h1_copy = min(old_w2.shape[1], HIDDEN1_SIZE)
    new_w2[:h2_copy, :h1_copy] = old_w2[:h2_copy, :h1_copy]
    new_b2[:h2_copy] = old_b2[:h2_copy]

    if HIDDEN1_SIZE > old_w2.shape[1]:
        nn.init.xavier_uniform_(new_w2[:, old_w2.shape[1]:])

    if HIDDEN2_SIZE > old_w2.shape[0]:
        # Added second-hidden units would otherwise have all-zero inputs, bias
        # and outputs: their pre-activation is exactly 0, the ReLU gradient
        # there is 0, and they could never start learning. Their output
        # weights below stay zero, so the network's outputs are unchanged.
        nn.init.xavier_uniform_(new_w2[old_w2.shape[0]:, :h1_copy])

    new_state_dict['net.2.weight'] = new_w2
    new_state_dict['net.2.bias'] = new_b2

    old_w4 = torch.FloatTensor(old_weights['net.4.weight'])
    old_b4 = torch.FloatTensor(old_weights['net.4.bias'])
    new_w4 = torch.zeros(OUTPUT_SIZE, HIDDEN2_SIZE)
    new_b4 = old_b4.clone()

    h2_copy = min(old_w4.shape[1], HIDDEN2_SIZE)
    new_w4[:, :h2_copy] = old_w4[:, :h2_copy]

    new_state_dict['net.4.weight'] = new_w4
    new_state_dict['net.4.bias'] = new_b4

    new_model.load_state_dict(new_state_dict)
    print(f"Transferred weights from old {old_input_size}-input brain to new {NEW_INPUT_SIZE}-input model")
    print(f"  - Added enemy_gun_type per direction (index 10)")
    print(f"  - Added MVP inputs (is_mvp, mvp_time, mvp_distance)")
    return new_model


def save_brain_json(model, path):
    """Write the model's weights to ``path`` as JavaScript-compatible JSON.

    The file holds ``input_size`` followed by the six ``net.<layer>.<part>``
    entries, each converted to nested Python lists.

    Args:
        model: Module exposing ``input_size`` and the ``net.0/2/4`` parameters.
        path: Destination file; overwritten if it exists.
    """
    tensors = model.state_dict()
    brain = {'input_size': model.input_size}
    layer_keys = (
        'net.0.weight', 'net.0.bias',
        'net.2.weight', 'net.2.bias',
        'net.4.weight', 'net.4.bias',
    )
    for key in layer_keys:
        # Move to host memory first so GPU-resident weights can be serialised.
        brain[key] = tensors[key].cpu().numpy().tolist()

    with open(path, 'w') as out:
        json.dump(brain, out)
    print(f'Saved brain to {path} (input_size={model.input_size})')

## Arena Rules Extraction
Automatically reads game rules from index.html to keep training environment in sync.

In [None]:
class LocomotEnv:
    """FAST simplified LOCOMOT.IO environment for self-play with 96 inputs (gun types + MVP)

    Every agent is a snake of segment dicts (``x``, ``y``, ``hp``, ``maxHp``,
    ``gun_type``); segment 0 is the head and has infinite hp. ``segment_map``
    maps a grid cell to ``(agent_idx, seg_idx)`` and is what the vision rays
    read, so it has to be kept in step with the segment lists.
    """

    WORLD_COLS = 50  # Smaller world = faster
    WORLD_ROWS = 40
    DIRECTIONS = [(0, -1), (1, 0), (0, 1), (-1, 0)]  # UP, RIGHT, DOWN, LEFT

    # Pre-compute direction offsets for 8 rays
    RAY_OFFSETS = [
        (0, -1), (1, -1), (1, 0), (1, 1),
        (0, 1), (-1, 1), (-1, 0), (-1, -1)
    ]

    # Gun types
    GUN_TYPES = ['MACHINEGUN', 'SHOTGUN', 'CANNON', 'PULSE']
    GUN_TYPE_VALUE = {'MACHINEGUN': 0.25, 'SHOTGUN': 0.5, 'CANNON': 0.75, 'PULSE': 1.0}

    def __init__(self, num_agents=4):
        """Create the environment and immediately reset it."""
        self.num_agents = num_agents
        self.current_mvp_idx = None  # Track which agent is MVP
        self.reset()

    def reset(self):
        """Start a new game and return one observation per agent."""
        self.agents = []
        self.health_pickups = set()
        self.gun_pickups = {}
        self.segment_map = {}
        self.step_count = 0
        self.current_mvp_idx = None

        for i in range(self.num_agents):
            x = random.randint(8, self.WORLD_COLS - 8)
            y = random.randint(8, self.WORLD_ROWS - 8)
            direction = random.randint(0, 3)
            dx, dy = self.DIRECTIONS[direction]

            # "All Same Gun" - start with all MACHINEGUN
            segments = []
            for j in range(4):
                seg = {
                    'x': x - j*dx,
                    'y': y - j*dy,
                    'hp': 100 if j > 0 else float('inf'),
                    'maxHp': 100 if j > 0 else float('inf'),
                    'gun_type': 0 if j > 0 else -1  # 0 = MACHINEGUN
                }
                segments.append(seg)

            self.agents.append({
                'segments': segments,
                'direction': direction,
                'alive': True,
                'score': 0,
                'prev_tail_pos': None,
                'current_gun': 0,  # Track current gun type (All Same Gun)
                'mvp_time': 0,     # Accumulated MVP time
                'is_mvp': False
            })

            for seg_idx, seg in enumerate(segments):
                self.segment_map[(seg['x'], seg['y'])] = (i, seg_idx)

        for _ in range(15):
            pos = (random.randint(0, self.WORLD_COLS - 1), random.randint(0, self.WORLD_ROWS - 1))
            self.gun_pickups[pos] = random.randint(0, 3)

        for _ in range(5):
            self.health_pickups.add((
                random.randint(0, self.WORLD_COLS - 1),
                random.randint(0, self.WORLD_ROWS - 1)
            ))

        return [self.get_vision_v4(i) for i in range(self.num_agents)]

    def update_mvp(self):
        """Mark the living agent with the highest ``2 * length + score`` as MVP.

        Ties keep the lowest agent index. The MVP's ``mvp_time`` grows by one
        per call; ``current_mvp_idx`` is ``None`` when nobody is alive.
        """
        best_score = -1
        best_idx = None

        for i, agent in enumerate(self.agents):
            if not agent['alive']:
                agent['is_mvp'] = False
                continue
            score = len(agent['segments']) * 2 + agent['score']
            if score > best_score:
                best_score = score
                best_idx = i

        # Update MVP status
        for i, agent in enumerate(self.agents):
            agent['is_mvp'] = (i == best_idx)
            if agent['is_mvp']:
                agent['mvp_time'] += 1  # Accumulate MVP time

        self.current_mvp_idx = best_idx

    def get_health_ratio(self, agent):
        """Return summed hp / summed maxHp over body segments (1.0 if none)."""
        total_hp = 0
        total_max = 0
        for seg in agent['segments'][1:]:
            if seg['hp'] != float('inf'):
                total_hp += seg['hp']
                total_max += seg['maxHp']
        return total_hp / total_max if total_max > 0 else 1.0

    def get_arena_position(self, head_x, head_y):
        """Return the distance to the nearest edge, scaled and capped at 1.0."""
        dist_to_edge = min(head_x, head_y,
                          self.WORLD_COLS - 1 - head_x,
                          self.WORLD_ROWS - 1 - head_y)
        max_dist = min(self.WORLD_COLS, self.WORLD_ROWS) / 2
        return min(dist_to_edge / max_dist, 1.0)

    def get_threat_density(self, agent_idx, head_x, head_y):
        """Return a 0..1 score for enemy heads within Manhattan distance 15.

        Longer enemies weigh 2.0, others 0.5, each divided by their distance.
        """
        threat = 0.0
        my_length = len(self.agents[agent_idx]['segments'])

        for i, other in enumerate(self.agents):
            if i == agent_idx or not other['alive']:
                continue

            other_head = other['segments'][0]
            dist = abs(other_head['x'] - head_x) + abs(other_head['y'] - head_y)

            if dist < 15:
                size_factor = 2.0 if len(other['segments']) > my_length else 0.5
                threat += size_factor / max(dist, 1)

        return min(threat / 3.0, 1.0)

    def get_current_gun_type(self, agent):
        """Encode the agent's gun index 0..3 as 0.25..1.0 (0.0 if none)."""
        gun_type = agent.get('current_gun', -1)
        if gun_type >= 0:
            return (gun_type + 1) * 0.25
        return 0.0

    def _scan_ray(self, agent_idx, head_x, head_y, my_length, my_segments, dx, dy):
        """Walk up to 15 cells from the head along (dx, dy).

        Returns the 11 per-direction features in observation order:
        [health, mg, sg, cn, pl, self, wall, smaller, bigger, body, enemy_gun_type].
        Every "distance" feature is ``1 / dist`` of the first hit, or 0.0.
        """
        health_dist = 0.0
        gun_dists = [0.0, 0.0, 0.0, 0.0]
        self_danger = wall_dist = 0.0
        enemy_smaller = enemy_bigger = enemy_body = 0.0
        nearest_enemy_gun = 0.0
        nearest_enemy_dist = float('inf')

        for dist in range(1, 16):
            cx, cy = head_x + dx * dist, head_y + dy * dist

            # Leaving the arena ends the ray; nothing beyond the wall is seen.
            if cx < 0 or cx >= self.WORLD_COLS or cy < 0 or cy >= self.WORLD_ROWS:
                wall_dist = 1.0 / dist
                break

            pos = (cx, cy)

            if self_danger == 0 and pos in my_segments:
                self_danger = 1.0 / dist

            if health_dist == 0 and pos in self.health_pickups:
                health_dist = 1.0 / dist

            if pos in self.gun_pickups:
                gun_type = self.gun_pickups[pos]
                if gun_dists[gun_type] == 0:
                    gun_dists[gun_type] = 1.0 / dist

            if pos in self.segment_map:
                other_idx, seg_idx = self.segment_map[pos]
                if other_idx != agent_idx and self.agents[other_idx]['alive']:
                    other = self.agents[other_idx]
                    other_len = len(other['segments'])

                    if seg_idx == 0:  # Head
                        # Track nearest enemy gun type
                        if dist < nearest_enemy_dist:
                            nearest_enemy_dist = dist
                            gun = other.get('current_gun', 0)
                            nearest_enemy_gun = (gun + 1) * 0.25

                        # Classify on size first. Testing "smaller and not yet
                        # seen" in a single condition let a second smaller head
                        # fall through and be recorded as a bigger enemy.
                        if other_len < my_length:
                            if enemy_smaller == 0:
                                enemy_smaller = 1.0 / dist
                        elif enemy_bigger == 0:
                            enemy_bigger = 1.0 / dist
                    elif enemy_body == 0:
                        enemy_body = 1.0 / dist

        return [
            health_dist,
            gun_dists[0], gun_dists[1], gun_dists[2], gun_dists[3],
            self_danger, wall_dist,
            enemy_smaller, enemy_bigger, enemy_body,
            nearest_enemy_gun,
        ]

    def get_vision_v4(self, agent_idx):
        """
        96 inputs: 8 dirs × 11 features + 8 state
        Per direction: [health, mg, sg, cn, pl, self, wall, smaller, bigger, body, enemy_gun_type]
        State: [health_ratio, arena_pos, threat, length, my_gun, is_mvp, mvp_time, mvp_dist]

        Rays are rotated by the agent's heading so ray 0 always points the way
        the agent is moving. Dead agents get an all-zero vector.
        """
        agent = self.agents[agent_idx]
        if not agent['alive']:
            return np.zeros(NEW_INPUT_SIZE, dtype=np.float32)

        head = agent['segments'][0]
        head_x, head_y = head['x'], head['y']
        current_dir = agent['direction']
        my_length = len(agent['segments'])
        my_segments = set((s['x'], s['y']) for s in agent['segments'][1:])

        vision = np.zeros(NEW_INPUT_SIZE, dtype=np.float32)
        rot = current_dir

        for ray_idx in range(8):
            rotated_idx = (ray_idx + rot * 2) % 8
            dx, dy = self.RAY_OFFSETS[rotated_idx]

            # 11 features per direction
            base = ray_idx * 11
            vision[base:base + 11] = self._scan_ray(
                agent_idx, head_x, head_y, my_length, my_segments, dx, dy
            )

        # 8 state inputs (88-95)
        vision[88] = self.get_health_ratio(agent)
        vision[89] = self.get_arena_position(head_x, head_y)
        vision[90] = self.get_threat_density(agent_idx, head_x, head_y)
        vision[91] = min(my_length / 20.0, 1.0)
        vision[92] = self.get_current_gun_type(agent)
        vision[93] = 1.0 if agent.get('is_mvp', False) else 0.0
        vision[94] = min(agent.get('mvp_time', 0) / 500.0, 1.0)  # Normalize to ~500 steps

        # Distance to MVP
        mvp_dist = 1.0
        if self.current_mvp_idx is not None and self.current_mvp_idx != agent_idx:
            mvp = self.agents[self.current_mvp_idx]
            if mvp['alive'] and mvp['segments']:
                mvp_head = mvp['segments'][0]
                dist = abs(mvp_head['x'] - head_x) + abs(mvp_head['y'] - head_y)
                mvp_dist = min(dist / 50.0, 1.0)
        vision[95] = mvp_dist

        return vision

    def _apply_random_damage(self):
        """Hit one random body segment of every living agent for 5-15 hp.

        A segment whose hp drops to 0 or below is removed. Its cell is then
        cleared from ``segment_map`` and the agent's remaining segments are
        re-registered, so vision sees neither a phantom body cell nor stale
        segment indices.
        """
        for agent_idx, agent in enumerate(self.agents):
            if agent['alive'] and len(agent['segments']) > 1:
                idx = random.randint(1, len(agent['segments']) - 1)
                seg = agent['segments'][idx]
                if seg['hp'] != float('inf'):
                    seg['hp'] -= random.randint(5, 15)
                    if seg['hp'] <= 0:
                        agent['segments'].pop(idx)
                        pos = (seg['x'], seg['y'])
                        owner = self.segment_map.get(pos)
                        if owner is not None and owner[0] == agent_idx:
                            del self.segment_map[pos]
                        for seg_idx, remaining in enumerate(agent['segments']):
                            self.segment_map[(remaining['x'], remaining['y'])] = (agent_idx, seg_idx)

    def step(self, actions):
        """Advance the game by one tick.

        Args:
            actions: One action per agent; 0 turns to ``direction - 1``,
                2 turns to ``direction + 1``, anything else keeps the heading.

        Returns:
            ``(observations, rewards, dones, game_over)``. ``game_over`` is
            True once at most one agent is alive or 500 steps have elapsed.
        """
        self.step_count += 1
        rewards = np.zeros(self.num_agents, dtype=np.float32)

        # Apply turns
        for i, agent in enumerate(self.agents):
            if not agent['alive']:
                continue
            action = actions[i]
            if action == 0:
                agent['direction'] = (agent['direction'] - 1) % 4
            elif action == 2:
                agent['direction'] = (agent['direction'] + 1) % 4

        # Move agents
        for i, agent in enumerate(self.agents):
            if not agent['alive']:
                continue

            # Remember the vacated tail cell; a gun pickup grows the snake there.
            old_tail = agent['segments'][-1]
            agent['prev_tail_pos'] = (old_tail['x'], old_tail['y'])

            tail_pos = agent['prev_tail_pos']
            if tail_pos in self.segment_map and self.segment_map[tail_pos][0] == i:
                del self.segment_map[tail_pos]

            dx, dy = self.DIRECTIONS[agent['direction']]
            head = agent['segments'][0]
            new_head = {
                'x': head['x'] + dx,
                'y': head['y'] + dy,
                'hp': float('inf'),
                'maxHp': float('inf'),
                'gun_type': -1
            }

            agent['segments'].insert(0, new_head)
            agent['segments'].pop()

            for seg_idx, seg in enumerate(agent['segments']):
                self.segment_map[(seg['x'], seg['y'])] = (i, seg_idx)

            rewards[i] += 0.01

        # Check collisions
        for i, agent in enumerate(self.agents):
            if not agent['alive']:
                continue

            head = agent['segments'][0]
            hx, hy = head['x'], head['y']

            if hx < 0 or hx >= self.WORLD_COLS or hy < 0 or hy >= self.WORLD_ROWS:
                agent['alive'] = False
                rewards[i] -= 5.0
                continue

            my_body = set((s['x'], s['y']) for s in agent['segments'][1:])
            if (hx, hy) in my_body:
                agent['alive'] = False
                rewards[i] -= 5.0
                continue

            for j, other in enumerate(self.agents):
                if i == j or not other['alive']:
                    continue

                oh = other['segments'][0]
                if hx == oh['x'] and hy == oh['y']:
                    # Head-on: the longer snake survives, equal lengths both die.
                    if len(agent['segments']) > len(other['segments']):
                        other['alive'] = False
                        rewards[j] -= 5.0
                        rewards[i] += 3.0
                    elif len(agent['segments']) < len(other['segments']):
                        agent['alive'] = False
                        rewards[i] -= 5.0
                        rewards[j] += 3.0
                    else:
                        agent['alive'] = other['alive'] = False
                        rewards[i] = rewards[j] = -3.0
                    break

                other_body = set((s['x'], s['y']) for s in other['segments'][1:])
                if (hx, hy) in other_body:
                    agent['alive'] = False
                    rewards[i] -= 5.0
                    rewards[j] += 2.0
                    break

        # Pickup collection - "All Same Gun" system
        for i, agent in enumerate(self.agents):
            if not agent['alive']:
                continue

            head = agent['segments'][0]
            head_pos = (head['x'], head['y'])

            if head_pos in self.health_pickups:
                self.health_pickups.remove(head_pos)
                for seg in agent['segments'][1:]:
                    if seg['hp'] != float('inf'):
                        seg['hp'] = min(seg['hp'] + 30, seg['maxHp'])
                rewards[i] += 0.5
                self.health_pickups.add((
                    random.randint(0, self.WORLD_COLS - 1),
                    random.randint(0, self.WORLD_ROWS - 1)
                ))

            if head_pos in self.gun_pickups:
                gun_type = self.gun_pickups[head_pos]
                del self.gun_pickups[head_pos]

                # "All Same Gun" - convert ALL segments to new gun type
                for seg in agent['segments'][1:]:
                    seg['gun_type'] = gun_type
                agent['current_gun'] = gun_type

                # Add new segment
                prev_tail = agent['prev_tail_pos']
                agent['segments'].append({
                    'x': prev_tail[0],
                    'y': prev_tail[1],
                    'hp': 100,
                    'maxHp': 100,
                    'gun_type': gun_type
                })
                self.segment_map[prev_tail] = (i, len(agent['segments']) - 1)
                agent['score'] += 1
                rewards[i] += 1.0

                new_pos = (random.randint(0, self.WORLD_COLS - 1), random.randint(0, self.WORLD_ROWS - 1))
                self.gun_pickups[new_pos] = random.randint(0, 3)

        # Update MVP
        self.update_mvp()

        # MVP reward bonus
        for i, agent in enumerate(self.agents):
            if agent['alive'] and agent.get('is_mvp', False):
                rewards[i] += 0.02  # Small bonus for being MVP

        # Random damage
        if self.step_count % 20 == 0:
            self._apply_random_damage()

        observations = [self.get_vision_v4(i) for i in range(self.num_agents)]
        dones = [not a['alive'] for a in self.agents]

        alive = sum(1 for a in self.agents if a['alive'])
        game_over = alive <= 1 or self.step_count >= 500

        if game_over and alive == 1:
            for i, a in enumerate(self.agents):
                if a['alive']:
                    rewards[i] += 5.0

        return observations, rewards.tolist(), dones, game_over

In [None]:
class SelfPlayTrainer:
    """Double-DQN self-play trainer with an opponent pool and optional imitation pre-training.

    Agent 0 of every episode is the learning model; the other three seats are
    filled by copies of the current model or by snapshots from the pool.

    Args:
        pool_size: Maximum number of snapshots kept in the opponent pool.
        old_brain_path: Optional brain JSON to start from. A file whose first
            layer already has NEW_INPUT_SIZE inputs is loaded directly;
            anything else goes through ``transfer_from_old_brain``.
    """

    def __init__(self, pool_size=10, old_brain_path=None):
        self.current_model = LocomotNetwork(input_size=NEW_INPUT_SIZE).to(device)

        loaded_weights = False
        if old_brain_path and os.path.exists(old_brain_path):
            # Try to load - transfer if needed
            with open(old_brain_path, 'r') as f:
                old_data = json.load(f)
            old_input_size = len(old_data['net.0.weight'][0]) if 'net.0.weight' in old_data else 0

            if old_input_size == NEW_INPUT_SIZE:
                load_existing_weights(self.current_model, old_brain_path)
                print(f"✓ Loaded matching {old_input_size}-input brain from {old_brain_path}")
            else:
                transfer_from_old_brain(self.current_model, old_brain_path)
                print(f"Transferred from {old_input_size}-input to {NEW_INPUT_SIZE}-input")
            loaded_weights = True
        else:
            print("Starting with random weights")

        self.target_model = LocomotNetwork(input_size=NEW_INPUT_SIZE).to(device)
        self.target_model.load_state_dict(self.current_model.state_dict())

        self.opponent_pool = [deepcopy(self.current_model.state_dict())]
        self.pool_size = pool_size
        self.memory = deque(maxlen=50000)

        self.optimizer = optim.Adam(self.current_model.parameters(), lr=0.001)
        self.gamma = 0.99
        # Reduced exploration only makes sense when weights were actually
        # loaded; a path to a missing file still starts from random weights.
        self.epsilon = 0.5 if loaded_weights else 1.0
        self.epsilon_min = 0.05
        self.epsilon_decay = 0.995
        self.batch_size = 64

        self.episode_rewards = []
        self.win_rates = []

    def _model_device(self):
        """Return the device the learning model's parameters live on."""
        return next(self.current_model.parameters()).device

    def train_from_player_data(self, data_path, epochs=10):
        """Imitation learning - train from recorded top player data.

        Args:
            data_path: JSON file holding ``[{state: [...], action: 0/1/2}, ...]``.
            epochs: Number of passes over the data.

        Does nothing (after printing why) when the file is empty or the
        recorded states do not have NEW_INPUT_SIZE entries.
        """
        with open(data_path, 'r') as f:
            data = json.load(f)

        if not data:
            print("No player data found")
            return

        # Check input size compatibility
        sample_state = data[0].get('state', [])
        if len(sample_state) != NEW_INPUT_SIZE:
            print(f"⚠ Player data has {len(sample_state)} inputs, expected {NEW_INPUT_SIZE}")
            print("  Skipping imitation learning (incompatible data)")
            return

        # Data format: [{state: [...], action: 0/1/2}, ...]
        dev = self._model_device()
        states = torch.FloatTensor([d['state'] for d in data]).to(dev)
        actions = torch.LongTensor([d['action'] for d in data]).to(dev)

        print(f"Training on {len(data)} recorded frames...")

        criterion = nn.CrossEntropyLoss()

        for epoch in tqdm(range(epochs), desc="Imitation Learning"):
            # Reshuffle every epoch so batches differ between passes.
            perm = torch.randperm(len(states))
            states, actions = states[perm], actions[perm]

            total_loss = 0
            correct = 0

            for i in range(0, len(states), self.batch_size):
                batch_states = states[i:i+self.batch_size]
                batch_actions = actions[i:i+self.batch_size]

                outputs = self.current_model(batch_states)
                loss = criterion(outputs, batch_actions)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                correct += (outputs.argmax(1) == batch_actions).sum().item()

            acc = correct / len(states) * 100
            tqdm.write(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss:.4f} | Accuracy: {acc:.1f}%")

        self.target_model.load_state_dict(self.current_model.state_dict())
        print("Imitation learning complete!")

    def select_opponents(self, num_opponents):
        """Return ``num_opponents`` models: 30% current-model copies, else pool snapshots.

        Pool indices are drawn from a triangular distribution whose mode is
        the end of the pool, so recent snapshots are picked more often.
        """
        opponents = []
        for _ in range(num_opponents):
            if random.random() < 0.3:
                opponents.append(deepcopy(self.current_model))
            else:
                idx = min(int(random.triangular(0, len(self.opponent_pool), len(self.opponent_pool))),
                         len(self.opponent_pool) - 1)
                opponent = LocomotNetwork(input_size=NEW_INPUT_SIZE).to(device)
                opponent.load_state_dict(self.opponent_pool[idx])
                opponent.eval()
                opponents.append(opponent)
        return opponents

    def play_episode(self):
        """Play one 4-agent game and push agent 0's transitions into replay memory.

        Returns:
            ``(total_reward, survived, final_length, final_health)`` for agent 0.
        """
        env = LocomotEnv(num_agents=4)
        observations = env.reset()
        opponents = self.select_opponents(3)

        episode_transitions = []
        total_reward = 0

        while True:
            # Record only steps agent 0 actually played. The step on which it
            # dies is kept as the terminal transition; later steps would only
            # add zero-observation transitions to the replay memory.
            agent_was_alive = env.agents[0]['alive']

            actions = [self.current_model.get_action(observations[0], self.epsilon)]
            for i, opp in enumerate(opponents):
                if env.agents[i + 1]['alive']:
                    actions.append(opp.get_action(observations[i + 1], 0.0))
                else:
                    actions.append(1)

            next_obs, rewards, dones, game_over = env.step(actions)

            if agent_was_alive:
                episode_transitions.append((observations[0], actions[0], rewards[0], next_obs[0], dones[0]))
                total_reward += rewards[0]

            observations = next_obs
            if game_over:
                break

        for t in episode_transitions:
            self.memory.append(t)

        final_health = env.get_health_ratio(env.agents[0]) if env.agents[0]['alive'] else 0
        return total_reward, env.agents[0]['alive'], len(env.agents[0]['segments']), final_health

    def train_step(self):
        """Run one Double-DQN update on a random replay batch.

        Returns:
            The batch loss as a float, or 0 while the replay memory holds
            fewer than ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return 0

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        dev = self._model_device()
        states = torch.tensor(np.array(states), dtype=torch.float32, device=dev)
        actions = torch.tensor(actions, dtype=torch.long, device=dev)
        rewards = torch.tensor(rewards, dtype=torch.float32, device=dev)
        next_states = torch.tensor(np.array(next_states), dtype=torch.float32, device=dev)
        dones = torch.tensor(dones, dtype=torch.float32, device=dev)

        # squeeze(1) drops only the gathered column, so a batch of one keeps
        # its batch dimension.
        current_q = self.current_model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            # Double DQN: the online net picks the action, the target net scores it.
            next_actions = self.current_model(next_states).argmax(1)
            next_q = self.target_model(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
            target_q = rewards + self.gamma * next_q * (1 - dones)

        loss = nn.MSELoss()(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.current_model.parameters(), 1.0)
        self.optimizer.step()
        return loss.item()

    def update_target(self):
        """Soft-update the target network towards the current one (tau = 0.01)."""
        tau = 0.01
        for tp, cp in zip(self.target_model.parameters(), self.current_model.parameters()):
            tp.data.copy_(tau * cp.data + (1 - tau) * tp.data)

    def add_to_pool(self):
        """Snapshot the current weights into the pool, evicting a middle entry when full."""
        self.opponent_pool.append(deepcopy(self.current_model.state_dict()))
        if len(self.opponent_pool) > self.pool_size:
            del self.opponent_pool[len(self.opponent_pool) // 2]

    def train(self, num_episodes=2000, save_every=500):
        """Run the self-play loop and return the trained model.

        Args:
            num_episodes: Number of games to play.
            save_every: Checkpoint interval in episodes (episode 0 is skipped).
        """
        import time
        wins = 0
        recent_rewards = deque(maxlen=50)
        start_time = time.time()

        pbar = tqdm(range(num_episodes), desc="Self-Play Training")
        for episode in pbar:
            reward, won, final_length, final_health = self.play_episode()
            recent_rewards.append(reward)
            if won:
                wins += 1

            for _ in range(2):
                self.train_step()

            self.update_target()
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

            if episode % 50 == 0 and episode > 0:
                self.add_to_pool()

            if episode % 25 == 0:
                win_rate = wins / 25 if episode > 0 else 0
                avg_reward = np.mean(recent_rewards) if recent_rewards else 0
                pbar.set_postfix({
                    'R': f'{avg_reward:.1f}',
                    'Win': f'{win_rate:.0%}',
                    'ε': f'{self.epsilon:.2f}'
                })
                self.episode_rewards.append(avg_reward)
                self.win_rates.append(win_rate)
                wins = 0

            if episode % save_every == 0 and episode > 0:
                save_brain_json(self.current_model, f'brain_v3_ep{episode}.json')

        # Guard the rate against a zero-length interval (e.g. num_episodes == 0).
        elapsed = max(time.time() - start_time, 1e-9)
        print(f'\nDone! {elapsed:.1f}s ({num_episodes/elapsed:.1f} ep/s)')
        return self.current_model

## Training
Run the cell below to start training. Takes ~5-10 minutes.

In [None]:
# Auto-find or fetch brain file
import glob
import urllib.request
from datetime import datetime

# Raw GitHub URL of the brain file that fetch_brain_from_github() downloads
# when no local brain file is found
BRAIN_URL = 'https://raw.githubusercontent.com/savecharlie/locomot-io/main/training/brain_85_init.json'
# Base URL of the server; player training data is requested from its
# /party/collective path later in this cell
SERVER_URL = 'https://locomot-io.savecharlie.partykit.dev'

def find_local_brain():
    """Return the most recently modified local brain JSON file, or None.

    Looks directly inside the current directory and ``/content`` (no
    recursion) for ``*.json`` files whose basename contains "brain",
    compared case-insensitively.
    """
    candidates = [
        json_file
        for folder in ('.', '/content')
        if os.path.exists(folder)
        for json_file in glob.glob(os.path.join(folder, '*.json'))
        if 'brain' in os.path.basename(json_file).lower()
    ]
    if not candidates:
        return None
    # Newest modification time wins; on a tie the earliest-found file is kept.
    return max(candidates, key=os.path.getmtime)

def fetch_brain_from_github():
    """Download the brain at ``BRAIN_URL`` and save it as a local JSON file.

    The payload must be non-empty JSON containing a ``'net.0.weight'`` entry;
    the length of that entry's first row is reported as the input size.
    ``brain_from_github.json`` is written only after that check succeeds, so
    a payload that fails it is not left on disk for ``find_local_brain`` to
    pick up on a later run.

    Returns:
        ``'brain_from_github.json'`` on success, otherwise ``None``. Failures
        (network errors, invalid JSON, unexpected payload) are printed rather
        than raised.
    """
    try:
        print(f"Fetching from {BRAIN_URL}...")
        with urllib.request.urlopen(BRAIN_URL, timeout=30) as response:
            brain_data = json.loads(response.read().decode())

        if not brain_data or 'net.0.weight' not in brain_data:
            # Say why nothing was saved instead of returning None silently.
            print("Could not fetch brain from GitHub: unexpected format (no 'net.0.weight')")
            return None

        # Derive the input size BEFORE writing: if the weight matrix is empty
        # or malformed this raises, and no file is created.
        input_size = len(brain_data['net.0.weight'][0])

        with open('brain_from_github.json', 'w') as f:
            json.dump(brain_data, f)
        print(f"✓ Fetched brain from GitHub ({input_size} inputs)")
        return 'brain_from_github.json'
    except Exception as e:
        # Deliberate best-effort: the caller treats None as "no brain found".
        print(f"Could not fetch brain from GitHub: {e}")
    return None

# Pick a starting brain: a local file wins, GitHub is the fallback
brain_path = find_local_brain()

if not brain_path:
    print("No local brain found, fetching from GitHub...")
    brain_path = fetch_brain_from_github()
else:
    print(f"✓ Found local brain: {brain_path}")

if not brain_path:
    print("⚠ No brain found - starting fresh with random weights")

# Ask the server for player training data; any failure simply leaves the
# imitation phase disabled
print("\nFetching player training data from server...")
has_player_data = False
try:
    req = urllib.request.Request(
        SERVER_URL + '/party/collective',
        data=json.dumps({'type': 'get_player_data'}).encode(),
        headers={'Content-Type': 'application/json'},
        method='POST',
    )
    with urllib.request.urlopen(req, timeout=10) as response:
        player_data = json.loads(response.read().decode())

    if not player_data or len(player_data) == 0:
        print("No player data on server yet")
    else:
        # Persist the data so the trainer can read it back by filename
        with open('player_data.json', 'w') as f:
            json.dump(player_data, f)
        print(f"✓ Got {len(player_data)} frames from top players!")
        has_player_data = True
except Exception as e:
    print(f"Could not fetch player data: {e}")

# Build the trainer, seeded with whichever brain path was resolved above
trainer = SelfPlayTrainer(pool_size=10, old_brain_path=brain_path)

# Imitation phase comes first, and only when player data was saved
if has_player_data:
    print("\n=== Imitation Learning from Top Players ===")
    trainer.train_from_player_data('player_data.json', epochs=10)

# Self-play phase always runs
print("\n=== Self-Play Training ===")
trained_model = trainer.train(num_episodes=2000, save_every=500)

In [None]:
# Plot training progress
# SelfPlayTrainer.train appends one point to episode_rewards / win_rates every
# 25 episodes, so each x-axis step stands for 25 episodes (not 100).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(trainer.episode_rewards)
ax1.set_title('Average Episode Reward')
ax1.set_xlabel('Episode (×25)')
ax1.set_ylabel('Reward')
ax1.grid(True)

ax2.plot(trainer.win_rates)
ax2.set_title('Win Rate vs Self-Play Pool')
ax2.set_xlabel('Episode (×25)')
ax2.set_ylabel('Win Rate')
ax2.grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Write the trained brain under a timestamped name; the name contains "brain",
# which is what the local-brain search on the next run looks for
from datetime import datetime

timestamp = f'{datetime.now():%Y%m%d_%H%M%S}'
filename = 'brain_' + timestamp + '.json'

save_brain_json(trained_model, filename)
print(f"\n✓ Saved as {filename}")
print("Next run will auto-load this brain!")

if IN_COLAB:
    # `files` is google.colab.files, imported only when running in Colab
    files.download(filename)
    print('Downloaded! Update index.html with this brain.')

## Architecture v3 (85 inputs)

After training completes, the brain will work with the updated `index.html`.

**New format (85 inputs = 8 dirs × 10 features + 5 state):**

Per direction (10 features):
- health_pickup - distance to health pickup
- machinegun_pickup - distance to MACHINEGUN pickup
- shotgun_pickup - distance to SHOTGUN pickup
- cannon_pickup - distance to CANNON pickup
- pulse_pickup - distance to PULSE pickup
- self - distance to own body
- wall - distance to wall/edge
- smaller_head - distance to smaller enemy head
- bigger_head - distance to bigger enemy head
- enemy_body - distance to enemy body

State inputs (5):
- health_ratio - overall body health (0-1)
- arena_position - distance from edge (0-1)
- threat_density - nearby enemy threat (0-1)
- my_length - normalized snake length (0-1)
- current_gun_type - which gun the first body segment has (0-1)

**Changes from v2:**
- Gun pickup split from 1 to 4 inputs (per gun type)
- Added current_gun_type state input
- Hidden layer expanded (64 → 128) for more features
- Pickup cooloff: enemies can't pick up their own recent drops (3-second cooloff)