# LOCOMOT.IO Self-Play Training

Train the neural network AI by playing against itself. Creates a curriculum of progressively harder opponents.

**Architecture:**
- Input: 48 values (8 directions × 6 features)
- Hidden: 64 → 64 (ReLU)
- Output: 3 (left, straight, right)

**Self-Play Strategy:**
1. Start with the current model (randomly initialised unless existing weights are loaded with `load_existing_weights`)
2. Play games against copies of itself
3. Learn from wins and losses
4. Keep a pool of past versions for diverse training

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import json
from collections import deque
from copy import deepcopy
import matplotlib.pyplot as plt

# Pick the compute device once for the whole notebook: CUDA when PyTorch
# reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## Neural Network Architecture
Matches the JavaScript implementation exactly

In [None]:
class LocomotNetwork(nn.Module):
    """Feed-forward Q-network: 48 inputs -> 64 -> 64 -> 3 action values.

    The layers are held in an ``nn.Sequential`` called ``net`` so the
    state-dict keys are ``net.0.*``, ``net.2.*`` and ``net.4.*`` (the two
    ReLU layers at indices 1 and 3 have no parameters).
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(48, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 3)
        )

    def forward(self, x):
        """Return the three action values for each row of ``x``."""
        return self.net(x)

    def get_action(self, state, epsilon=0.0):
        """Get action with epsilon-greedy exploration.

        Args:
            state: Sequence or array of 48 floats (one observation).
            epsilon: Probability of returning a uniformly random action
                instead of the greedy one.

        Returns:
            An int in ``{0, 1, 2}``: the index of the chosen output.
        """
        if random.random() < epsilon:
            return random.randint(0, 2)

        # Run on whatever device the weights actually live on instead of a
        # module-level global, so the input can never end up on a different
        # device than the model.
        model_device = next(self.parameters()).device
        with torch.no_grad():
            state_t = torch.FloatTensor(state).unsqueeze(0).to(model_device)
            q_values = self.forward(state_t)
            return q_values.argmax(dim=1).item()

def load_existing_weights(model, json_path):
    """Populate ``model`` from a brain.json file and return it.

    The JSON object must contain nested-list values under the six keys
    ``net.{0,2,4}.{weight,bias}``; each is converted to a float tensor and
    handed to ``model.load_state_dict``.
    """
    param_names = (
        'net.0.weight', 'net.0.bias',
        'net.2.weight', 'net.2.bias',
        'net.4.weight', 'net.4.bias',
    )

    with open(json_path, 'r') as handle:
        raw = json.load(handle)

    # Only the expected keys are read; anything else in the file is ignored.
    tensors = {name: torch.FloatTensor(raw[name]) for name in param_names}
    model.load_state_dict(tensors)
    return model

def save_brain_json(model, path):
    """Write the model's six weight/bias tensors to ``path`` as JSON.

    Each tensor is moved to the CPU and converted to nested Python lists,
    keyed by its state-dict name (``net.{0,2,4}.{weight,bias}``).
    """
    param_names = (
        'net.0.weight', 'net.0.bias',
        'net.2.weight', 'net.2.bias',
        'net.4.weight', 'net.4.bias',
    )
    params = model.state_dict()
    brain = {}
    for name in param_names:
        brain[name] = params[name].cpu().numpy().tolist()

    with open(path, 'w') as out_file:
        json.dump(brain, out_file)
    print(f'Saved brain to {path}')

## Game Environment
Simplified version of LOCOMOT.IO for fast training

In [None]:
class LocomotEnv:
    """FAST simplified LOCOMOT.IO environment for self-play.

    Each agent is a dict with keys ``segments`` (list of ``(x, y)`` cells,
    head first), ``direction`` (index into ``DIRECTIONS``), ``alive`` and
    ``score``. ``pickup_set`` holds the food cells and ``segment_map`` maps
    every cell occupied by a live agent to ``(agent_idx, seg_idx)``.
    """

    WORLD_COLS = 50  # Smaller world = faster
    WORLD_ROWS = 40
    DIRECTIONS = [(0, -1), (1, 0), (0, 1), (-1, 0)]  # UP, RIGHT, DOWN, LEFT

    # Pre-compute direction offsets for 8 rays (45-degree steps, starting UP)
    RAY_OFFSETS = [
        (0, -1), (1, -1), (1, 0), (1, 1),
        (0, 1), (-1, 1), (-1, 0), (-1, -1)
    ]

    def __init__(self, num_agents=4):
        self.num_agents = num_agents
        self.reset()

    def _rebuild_segment_map(self):
        """Recompute ``segment_map`` from the segments of live agents only."""
        segment_map = {}
        for agent_idx, agent in enumerate(self.agents):
            if not agent['alive']:
                continue
            for seg_idx, seg in enumerate(agent['segments']):
                segment_map[seg] = (agent_idx, seg_idx)
        self.segment_map = segment_map

    def reset(self):
        """Start a new game and return one observation per agent."""
        self.agents = []
        self.pickup_set = set()  # Use set for O(1) lookup
        self.segment_map = {}  # (x,y) -> (agent_idx, seg_idx) for fast collision
        self.step_count = 0

        # Spawn agents: 4 segments laid out in a straight line behind the head
        for _ in range(self.num_agents):
            x = random.randint(8, self.WORLD_COLS - 8)
            y = random.randint(8, self.WORLD_ROWS - 8)
            direction = random.randint(0, 3)
            dx, dy = self.DIRECTIONS[direction]
            segments = [(x - j*dx, y - j*dy) for j in range(4)]

            self.agents.append({
                'segments': segments,
                'direction': direction,
                'alive': True,
                'score': 0
            })

        self._rebuild_segment_map()

        # Spawn food (duplicates collapse in the set, so there may be < 20)
        for _ in range(20):
            self.pickup_set.add((
                random.randint(0, self.WORLD_COLS - 1),
                random.randint(0, self.WORLD_ROWS - 1)
            ))

        return [self.get_vision_fast(i) for i in range(self.num_agents)]

    def get_vision_fast(self, agent_idx):
        """FAST vision using pre-computed offsets and set lookups.

        Casts 8 rays of up to 15 cells from the agent's head. Ray 0 points
        along the agent's heading; the rest follow in ``RAY_OFFSETS`` order.
        Each ray contributes 6 values, each ``1 / distance`` to the nearest
        hit of that kind (0 when nothing was seen):
        ``[food, own body, wall, smaller enemy head,
        equal-or-bigger enemy head, enemy body]``.

        Returns:
            A float32 array of 48 values; all zeros for a dead agent.
        """
        agent = self.agents[agent_idx]
        if not agent['alive']:
            return np.zeros(48, dtype=np.float32)

        head_x, head_y = agent['segments'][0]
        current_dir = agent['direction']
        my_length = len(agent['segments'])
        my_segments = set(agent['segments'][1:])  # Skip head

        vision = np.zeros(48, dtype=np.float32)

        for ray_idx in range(8):
            # Rotate ray index by current direction
            # (each of the 4 headings spans two 45-degree ray steps)
            rotated_idx = (ray_idx + current_dir * 2) % 8
            dx, dy = self.RAY_OFFSETS[rotated_idx]

            food_dist = self_danger = wall_dist = 0.0
            enemy_smaller = enemy_bigger = enemy_body = 0.0

            for dist in range(1, 16):  # Shorter rays = faster
                cx, cy = head_x + dx * dist, head_y + dy * dist
                pos = (cx, cy)

                # Wall: the ray ends at the first out-of-bounds cell
                if cx < 0 or cx >= self.WORLD_COLS or cy < 0 or cy >= self.WORLD_ROWS:
                    wall_dist = 1.0 / dist
                    break

                # Self body
                if self_danger == 0 and pos in my_segments:
                    self_danger = 1.0 / dist

                # Food
                if food_dist == 0 and pos in self.pickup_set:
                    food_dist = 1.0 / dist

                # Other agents (use segment map)
                if pos in self.segment_map:
                    other_idx, seg_idx = self.segment_map[pos]
                    if other_idx != agent_idx and self.agents[other_idx]['alive']:
                        other_len = len(self.agents[other_idx]['segments'])
                        if seg_idx == 0:  # Head
                            if other_len < my_length:
                                # Keep only the nearest smaller head. A second
                                # smaller head further along the ray must not
                                # leak into the "bigger" channel.
                                if enemy_smaller == 0:
                                    enemy_smaller = 1.0 / dist
                            elif enemy_bigger == 0:
                                enemy_bigger = 1.0 / dist
                        elif enemy_body == 0:
                            enemy_body = 1.0 / dist

            base = ray_idx * 6
            vision[base:base+6] = [food_dist, self_danger, wall_dist, enemy_smaller, enemy_bigger, enemy_body]

        return vision

    def step(self, actions):
        """Execute actions. Returns (observations, rewards, dones, game_over).

        Args:
            actions: One entry per agent; 0 turns left, 2 turns right, any
                other value keeps the heading. Entries for dead agents are
                ignored.

        Returns:
            observations: list of 48-value float32 arrays, one per agent.
            rewards: list of floats, one per agent, for this step.
            dones: list of bools, ``True`` for every agent that is dead.
            game_over: ``True`` when at most one agent is alive or 500 steps
                have been played.
        """
        self.step_count += 1
        rewards = np.zeros(self.num_agents, dtype=np.float32)

        # Apply turns
        for i, agent in enumerate(self.agents):
            if not agent['alive']:
                continue
            action = actions[i]
            if action == 0:
                agent['direction'] = (agent['direction'] - 1) % 4
            elif action == 2:
                agent['direction'] = (agent['direction'] + 1) % 4

        # Move agents: push a new head, drop the tail, small survival reward
        for i, agent in enumerate(self.agents):
            if not agent['alive']:
                continue

            dx, dy = self.DIRECTIONS[agent['direction']]
            old_x, old_y = agent['segments'][0]
            agent['segments'].insert(0, (old_x + dx, old_y + dy))
            agent['segments'].pop()

            rewards[i] += 0.01

        # Check collisions
        for i, agent in enumerate(self.agents):
            if not agent['alive']:
                continue

            hx, hy = agent['segments'][0]

            # Wall
            if hx < 0 or hx >= self.WORLD_COLS or hy < 0 or hy >= self.WORLD_ROWS:
                agent['alive'] = False
                rewards[i] -= 5.0
                continue

            # Self collision
            if (hx, hy) in agent['segments'][1:]:
                agent['alive'] = False
                rewards[i] -= 5.0
                continue

            # Enemy collision (agents already dead this step are skipped)
            for j, other in enumerate(self.agents):
                if i == j or not other['alive']:
                    continue

                oh = other['segments'][0]
                if hx == oh[0] and hy == oh[1]:
                    # Head-to-head: the longer agent survives
                    if len(agent['segments']) > len(other['segments']):
                        other['alive'] = False
                        rewards[j] -= 5.0
                        rewards[i] += 3.0
                    elif len(agent['segments']) < len(other['segments']):
                        agent['alive'] = False
                        rewards[i] -= 5.0
                        rewards[j] += 3.0
                    else:
                        agent['alive'] = other['alive'] = False
                        # NOTE: plain assignment; this replaces anything
                        # either agent accumulated earlier in this step.
                        rewards[i] = rewards[j] = -3.0
                    break

                # Head into body
                if (hx, hy) in other['segments'][1:]:
                    agent['alive'] = False
                    rewards[i] -= 5.0
                    rewards[j] += 2.0
                    break

        # Food
        for i, agent in enumerate(self.agents):
            if not agent['alive']:
                continue
            head = agent['segments'][0]
            if head in self.pickup_set:
                self.pickup_set.remove(head)
                # Grow by duplicating the tail cell
                agent['segments'].append(agent['segments'][-1])
                agent['score'] += 1
                rewards[i] += 1.0
                self.pickup_set.add((
                    random.randint(0, self.WORLD_COLS - 1),
                    random.randint(0, self.WORLD_ROWS - 1)
                ))

        # Rebuild the lookup from survivors only. An agent that died by
        # running into a live body occupies that same cell; leaving its
        # entry in place would hide the live body from everyone's vision.
        self._rebuild_segment_map()

        observations = [self.get_vision_fast(i) for i in range(self.num_agents)]
        dones = [not a['alive'] for a in self.agents]

        alive = sum(1 for a in self.agents if a['alive'])
        game_over = alive <= 1 or self.step_count >= 500  # Shorter episodes

        # Last-agent-standing bonus
        if game_over and alive == 1:
            for i, a in enumerate(self.agents):
                if a['alive']:
                    rewards[i] += 5.0

        return observations, rewards.tolist(), dones, game_over

## Self-Play Training
Uses a pool of past versions to ensure diverse opponents

In [None]:
class SelfPlayTrainer:
    """Q-learning trainer that plays agent 0 against a pool of past selves.

    ``current_model`` is the network being trained, ``target_model`` is a
    slowly-updated copy used to evaluate bootstrap targets, and
    ``opponent_pool`` holds state-dict snapshots of earlier versions.
    """

    def __init__(self, pool_size=10):
        self.current_model = LocomotNetwork().to(device)
        self.target_model = LocomotNetwork().to(device)
        self.target_model.load_state_dict(self.current_model.state_dict())

        self.opponent_pool = [deepcopy(self.current_model.state_dict())]
        self.pool_size = pool_size
        self.memory = deque(maxlen=50000)  # replay buffer of transitions

        self.optimizer = optim.Adam(self.current_model.parameters(), lr=0.001)
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.05
        self.epsilon_decay = 0.995  # Faster decay
        self.batch_size = 64

        # One entry is appended to each list every 25 episodes by train()
        self.episode_rewards = []
        self.win_rates = []

    def select_opponents(self, num_opponents):
        """Return ``num_opponents`` models to play against.

        Each opponent is, with probability 0.3, a copy of the current model;
        otherwise it is loaded from the pool, with the index drawn from a
        triangular distribution whose mode is the newest end of the pool.
        """
        opponents = []
        for _ in range(num_opponents):
            if random.random() < 0.3:
                opponents.append(deepcopy(self.current_model))
            else:
                # triangular() can return its upper bound, hence the clamp
                idx = min(int(random.triangular(0, len(self.opponent_pool), len(self.opponent_pool))),
                         len(self.opponent_pool) - 1)
                opponent = LocomotNetwork().to(device)
                opponent.load_state_dict(self.opponent_pool[idx])
                opponent.eval()
                opponents.append(opponent)
        return opponents

    def play_episode(self):
        """Play one 4-agent game and store agent 0's transitions in memory.

        Returns:
            ``(total_reward, alive_at_end, final_length)`` for agent 0.
        """
        env = LocomotEnv(num_agents=4)
        observations = env.reset()
        opponents = self.select_opponents(3)

        episode_transitions = []
        total_reward = 0

        while True:
            # Whether the learner takes part in this step. Once it has died
            # its observation is all zeros and its reward is 0, so those
            # steps must not be recorded as experience.
            learner_alive = env.agents[0]['alive']

            if learner_alive:
                actions = [self.current_model.get_action(observations[0], self.epsilon)]
            else:
                actions = [1]  # placeholder; the env ignores dead agents
            for i, opp in enumerate(opponents):
                if env.agents[i + 1]['alive']:
                    actions.append(opp.get_action(observations[i + 1], 0.0))
                else:
                    actions.append(1)

            next_obs, rewards, dones, game_over = env.step(actions)

            if learner_alive:
                episode_transitions.append((observations[0], actions[0], rewards[0], next_obs[0], dones[0]))
                total_reward += rewards[0]

            observations = next_obs
            if game_over:
                break

        self.memory.extend(episode_transitions)

        return total_reward, env.agents[0]['alive'], len(env.agents[0]['segments'])

    def train_step(self):
        """Run one gradient step on a random replay batch.

        The online network picks the best next action and the target network
        supplies that action's value. Returns the loss as a float, or 0 when
        the buffer holds fewer than ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return 0

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(np.array(states)).to(device)
        actions = torch.LongTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        next_states = torch.FloatTensor(np.array(next_states)).to(device)
        dones = torch.FloatTensor(dones).to(device)

        # squeeze(1) rather than squeeze(): keeps a 1-D result even when the
        # batch holds a single transition.
        current_q = self.current_model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            next_actions = self.current_model(next_states).argmax(1)
            next_q = self.target_model(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
            # No bootstrap value past a terminal transition
            target_q = rewards + self.gamma * next_q * (1 - dones)

        loss = nn.MSELoss()(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.current_model.parameters(), 1.0)
        self.optimizer.step()
        return loss.item()

    def update_target(self):
        """Blend the target weights 1% of the way toward the current model."""
        tau = 0.01
        for tp, cp in zip(self.target_model.parameters(), self.current_model.parameters()):
            tp.data.copy_(tau * cp.data + (1 - tau) * tp.data)

    def add_to_pool(self):
        """Snapshot the current model into the opponent pool.

        When the pool exceeds ``pool_size`` the middle entry is dropped, so
        the oldest and the newest snapshots are both kept.
        """
        self.opponent_pool.append(deepcopy(self.current_model.state_dict()))
        if len(self.opponent_pool) > self.pool_size:
            del self.opponent_pool[len(self.opponent_pool) // 2]

    def train(self, num_episodes=2000, save_every=500):
        """Main training loop with progress tracking.

        Per episode: play one game, run two gradient steps, soft-update the
        target network and decay epsilon. Every 50 episodes the current model
        is added to the opponent pool; every 25 a progress line is printed
        and appended to ``episode_rewards`` / ``win_rates``; every
        ``save_every`` episodes a JSON checkpoint is written.

        Returns:
            The trained ``current_model``.
        """
        import time
        wins = 0
        episodes_in_window = 0  # episodes played since the last report
        recent_rewards = deque(maxlen=50)
        start_time = time.time()

        for episode in range(num_episodes):
            reward, won, _final_length = self.play_episode()
            recent_rewards.append(reward)
            episodes_in_window += 1
            if won:
                wins += 1

            for _ in range(2):
                self.train_step()

            self.update_target()
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

            if episode % 50 == 0 and episode > 0:
                self.add_to_pool()

            # Print every 25 episodes for faster feedback
            if episode % 25 == 0:
                elapsed = time.time() - start_time
                eps_per_sec = (episode + 1) / elapsed if elapsed > 0 else 0
                # Divide by the episodes actually in this window: the first
                # report covers a single episode, later ones cover 25.
                win_rate = wins / episodes_in_window
                avg_reward = np.mean(recent_rewards) if recent_rewards else 0
                print(f'Ep {episode:4d} | Reward: {avg_reward:6.2f} | Win: {win_rate:5.1%} | ε: {self.epsilon:.3f} | {eps_per_sec:.1f} ep/s')
                self.episode_rewards.append(avg_reward)
                self.win_rates.append(win_rate)
                wins = 0
                episodes_in_window = 0

            if episode % save_every == 0 and episode > 0:
                save_brain_json(self.current_model, f'brain_selfplay_ep{episode}.json')

        print(f'\nDone! Total time: {time.time() - start_time:.1f}s')
        return self.current_model

## Run Training

In [None]:
# Create trainer (the opponent pool keeps at most 10 past snapshots)
trainer = SelfPlayTrainer(pool_size=10)

# Train for 2000 episodes, writing a JSON checkpoint every 500 episodes.
# (Author's estimate: 2000 episodes should take ~5-10 min.)
trained_model = trainer.train(num_episodes=2000, save_every=500)

In [None]:
# Plot training progress
# SelfPlayTrainer.train() records one data point every 25 episodes
# (episodes 0, 25, 50, ...), so the x values are rebuilt from that interval
# rather than plotted against the bare list index.
log_interval = 25
logged_episodes = [i * log_interval for i in range(len(trainer.episode_rewards))]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(logged_episodes, trainer.episode_rewards)
ax1.set_title('Average Episode Reward')
ax1.set_xlabel('Episode')
ax1.set_ylabel('Reward')
ax1.grid(True)

ax2.plot(logged_episodes, trainer.win_rates)
ax2.set_title('Win Rate vs Self-Play Pool')
ax2.set_xlabel('Episode')
ax2.set_ylabel('Win Rate')
ax2.grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Save final brain for use in game: exports the trained weights as JSON
# in the same key layout that load_existing_weights reads.
save_brain_json(trained_model, 'brain_selfplay_final.json')
print('\nDone! Download brain_selfplay_final.json and replace the BRAIN data in index.html')

## How to Use the Trained Brain

1. Download `brain_selfplay_final.json`
2. Open `index.html`
3. Find the `<script id="brainData" type="application/json">` section
4. Replace its contents with the downloaded JSON
5. Deploy and test!

The self-play trained AI should be more aggressive and strategic than the original.