# Our Library

### Neural Network Training Library

We made a Python library that lets us train AI by actually playing Pokémon! The library handles everything: watching the game, picking moves, and learning from mistakes—so the AI gets better all by itself.

### Rust-based Emulator with PyO3 Bindings

We souped up a super-fast Game Boy Advance emulator (written in Rust) so it can chat with Python. Thanks to PyO3, Python can now:
- **See what’s happening:** Grab all the juicy game details—Pokémon stats, battle info, and more—and send them to the AI.
- **Control the action:** Let the AI pick moves or switch Pokémon, and actually make those things happen in-game.


### Custom Pokémon Disassembly for RL

We also hacked the Pokémon Emerald game code so it can:
- **Understand the AI’s commands** sent from Python.
- **Skip all the boring stuff** (like graphics and text) so training is way, way faster.

---

# The tutorial
In this tutorial, we will train a neural network step by step with our library.

## Goals
- Train a small model with multi-agent reinforcement learning (MARL) to win 1v1 battles
- Monitor the model's performance
- Export the model to ONNX (needed for the next tutorial, which runs the model on the GBA)

## Imports
Make sure you have followed the installation steps in README.md.

In [None]:
# Imports for training and interacting with the environment
import sys
sys.path.append("..")  # let Python find pkmn_rl_arena in the parent directory

import numpy as np
import random

# PettingZoo for multi-agent RL environments
from pettingzoo.utils import parallel_to_aec
from pettingzoo.test import parallel_api_test

# Main environment and core components
from pkmn_rl_arena.env.battle_core import BattleCore
from pkmn_rl_arena.env.battle_arena import BattleArena, RenderMode, ReplayBuffer
from pkmn_rl_arena.env.pkmn_team_factory import PkmnTeamFactory
from pkmn_rl_arena.env.observation import ObservationFactory, ObsIdx
from pkmn_rl_arena.paths import PATHS

# Logging and debugging
from pkmn_rl_arena import log

# For RL algorithms and neural networks
import torch
import torch.nn as nn
import torch.optim as optim

log.setLevel("CRITICAL")  # hide library log messages below CRITICAL during training



## Instantiate
With PyTorch, we can easily create a model, and as you can see, this one is really small. Why keep it so small?
On a regular PC, model size barely matters (the limit is mostly your hardware).
However, if you want to export the model and run it on the GBA, memory becomes a huge limitation. Here is the memory usage reported when compiling **pokeemerald**:
```bash
Memory region         Used Size  Region Size  %age Used
           EWRAM:      251688 B       256 KB     96.01%
           IWRAM:       30416 B        32 KB     92.82%
             ROM:    13334028 B        32 MB     39.74%
```
What does this mean in practice?
- EWRAM → about 10.2 KB (10,456 bytes) left. This is the main read-write RAM, similar to system RAM on a PC.
- ROM → about 19.3 MB free. This is slower, read-only memory, but it is where we can store the model's weights.

If we quantize the model to int8 (one byte per value), then:
- We can store up to ~20 million parameters in ROM.
- But RAM is extremely limited: only 10,456 int8 values can fit into EWRAM.
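
These figures follow directly from the memory report. The sketch below recomputes the free space per region, assuming the byte counts shown in the compilation output and binary units (1 KB = 1024 bytes). With int8 quantization, each free byte holds one parameter or activation.

```python
# Byte counts copied from the pokeemerald memory report above.
KB = 1024
MB = 1024 * KB

REGIONS = {
    # region: (used bytes, region size in bytes)
    "EWRAM": (251_688, 256 * KB),
    "IWRAM": (30_416, 32 * KB),
    "ROM": (13_334_028, 32 * MB),
}


def free_bytes(region: str) -> int:
    """Bytes still available in a memory region."""
    used, size = REGIONS[region]
    return size - used


for name in REGIONS:
    free = free_bytes(name)
    # With int8 quantization, one free byte stores one value.
    print(f"{name}: {free:,} bytes free -> room for {free:,} int8 values")
```

EWRAM comes out at 10,456 bytes and ROM at 20,220,404 bytes (about 19.3 MB), matching the numbers above.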

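That’s why we need to be very careful. For each node $n$ of the model (for example, each layer), the number of input values plus the number of output values must fit in the free EWRAM, since both buffers sit in RAM while the node is computed (the weights themselves are read from ROM):
$$
\text{input}_n + \text{output}_n \leq 10\,456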
$$
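
To check a candidate architecture against this constraint before training, you can walk through its dense layers. This is a minimal sketch that only models a plain stack of fully connected layers like the `DQN` below; the helper names `layer_budget` and `param_count` and the example sizes (200 observations, 10 actions) are hypothetical, not values from the environment, and any extra buffers the GBA-side inference code might need are ignored.

```python
from typing import List, Tuple

EWRAM_FREE = 10_456    # free EWRAM bytes = int8 activations that fit
ROM_FREE = 20_220_404  # free ROM bytes = int8 weights that fit


def layer_budget(widths: List[int], ram_limit: int = EWRAM_FREE) -> List[Tuple[int, int, bool]]:
    """For each dense layer n, return (n, input_n + output_n, fits_in_ram).

    `widths` lists the layer sizes, e.g. [obs_size, 128, action_size];
    each consecutive pair forms one fully connected layer.
    """
    report = []
    for n, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
        total = n_in + n_out
        report.append((n, total, total <= ram_limit))
    return report


def param_count(widths: List[int]) -> int:
    """Weights plus biases of the dense stack (one byte each once int8-quantized)."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(widths, widths[1:]))


# Hypothetical sizes: obs_size=200, action_size=10 (use your environment's real values)
widths = [200, 128, 10]
for n, total, fits in layer_budget(widths):
    print(f"layer {n}: input + output = {total} -> {'OK' if fits else 'too large for EWRAM'}")
print(f"parameters: {param_count(widths):,} (ROM room: {ROM_FREE:,})")
```

For such a small network, the RAM check per layer is far below the limit, and the parameter count is a tiny fraction of the ROM space, which is why the hidden layer could grow a lot before ROM becomes the bottleneck.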

![My model architecture](./assets/gba-archi-model.png)


In [None]:
# Instantiate the environment
core = BattleCore(PATHS["ROM"], PATHS["BIOS"], PATHS["MAP"])
env = BattleArena(core)

# get observation and action space sizes
obs = env.reset()[0]
obs_size = obs["player"]["observation"].shape[0]
action_size = env.action_manager.action_space_size

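class DQN(nn.Module):
    """Small MLP that maps an observation vector to one Q-value per action."""

    def __init__(self, obs_size, action_size):
        super().__init__()
        # obs_size -> 128 hidden units -> action_size Q-values,
        # kept tiny so weights and activations fit the GBA memory budget
        self.net = nn.Sequential(
            nn.Linear(obs_size, 128),
            nn.ReLU(),
            nn.Linear(128, action_size)
        )

    def forward(self, x):
        return self.net(x)


# One network (and one optimizer) shared by both agents
shared_agent = DQN(obs_size, action_size)
optimizer = optim.Adam(shared_agent.parameters(), lr=1e-3)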

print("Environment and agents initialized.")

### Training Method
As mentioned above, we use MARL to train our model: a single shared network plays both sides of the battle ("player" and "enemy"), and the experience of both agents goes into the same replay buffer.
Here is a diagram of how it works:

![Training environment diagram](./assets/env.png)

To keep training fast, this example only runs a few episodes (`NUM_EPISODES = 3`).


In [None]:
# Hyperparameters
GAMMA = 0.99  # discount factor
EPSILON_START = 1.0  # initial exploration rate
EPSILON_END = 0.1  # final exploration rate
EPSILON_DECAY = 0.95  # decay rate per episode
BATCH_SIZE = 64  # minibatch size sampled from the replay buffer
NUM_EPISODES = 3  # number of episodes to train
TARGET_UPDATE = 10  # update target network every N episodes

target_network = DQN(obs_size, action_size)
target_network.load_state_dict(shared_agent.state_dict())

# For tracking performance
rewards_history = []
win_rates = []
epsilon = EPSILON_START

replay_buffer = ReplayBuffer(capacity=10000)
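
Epsilon is multiplied by `EPSILON_DECAY` after every episode and clipped at `EPSILON_END`, so after $k$ episodes it equals $\max(\epsilon_{end}, \epsilon_{start} \cdot \text{decay}^k)$. The pure-Python sketch below shows how long that decay takes with these hyperparameters, which helps when choosing `NUM_EPISODES`; `epsilon_after` and `episodes_to_floor` are illustrative helpers, not part of the library.

```python
import math


def epsilon_after(k: int, start: float = 1.0, end: float = 0.1, decay: float = 0.95) -> float:
    """Exploration rate after k episodes of multiplicative decay, clipped at `end`."""
    return max(end, start * decay ** k)


def episodes_to_floor(start: float = 1.0, end: float = 0.1, decay: float = 0.95) -> int:
    """Smallest k such that start * decay**k <= end."""
    return math.ceil(math.log(end / start) / math.log(decay))


print(f"epsilon after 3 episodes: {epsilon_after(3):.3f}")
print(f"episodes until epsilon reaches its floor: {episodes_to_floor()}")
```

With the 3 episodes used here, epsilon only drops to about 0.86, so most actions are still random; reaching the 0.1 floor takes 45 episodes. Similarly, with `TARGET_UPDATE = 10` the target network is only synced after episode 0.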

### Action Selection and Model Optimization
**Action Selection (`select_action`)**

This function implements the classic epsilon-greedy strategy for reinforcement learning:

- With probability epsilon, the agent chooses a random valid action (exploration).
- Otherwise, it selects the action with the highest predicted Q-value from the neural network (exploitation).
- The action mask ensures only legal actions are considered.

**Model Optimization (`optimize_model`)**

This function updates the neural network using experiences sampled from the replay buffer:

- A random batch of transitions is sampled.
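- The shared network predicts Q-values for the current states, and a separate target network predicts Q-values for the next states.
- The loss is the mean squared error between the predicted Q-values and the targets given by the Bellman equation.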
- The optimizer updates the network weights to minimize this loss.
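
Written out, this is the target and loss that `optimize_model` computes for a batch of $B$ transitions $(s_i, a_i, r_i, s'_i, d_i)$, where $Q_\theta$ is the shared network, $Q_{\theta^-}$ the target network, $\gamma$ is `GAMMA`, $B$ is `BATCH_SIZE`, and $d_i = 1$ marks a terminal transition (so no future value is added):
$$
y_i = r_i + \gamma \, (1 - d_i) \, \max_{a'} Q_{\theta^-}(s'_i, a')
$$
$$
\mathcal{L}(\theta) = \frac{1}{B} \sum_{i=1}^{B} \big( Q_\theta(s_i, a_i) - y_i \big)^2
$$
For example, with $\gamma = 0.99$, a non-terminal transition with $r_i = 1$ and $\max_{a'} Q_{\theta^-}(s'_i, a') = 2$ gives $y_i = 1 + 0.99 \times 2 = 2.98$.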

In [None]:
def select_action(agent, state, epsilon):
    """Epsilon-greedy action selection.

    `agent` is the agent id ("player" or "enemy"); both agents share `shared_agent`.
    """
    if random.random() < epsilon:
        # Random action (exploration), restricted to legal actions
        valid_actions = np.where(state["action_mask"] == 1)[0]
        if len(valid_actions) > 0:
            return int(random.choice(valid_actions))
        else:
            return random.randint(0, action_size - 1)
    else:
        # Greedy action (exploitation)
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state["observation"])
            q_values = shared_agent(state_tensor)

            # Illegal actions get -inf so argmax can never pick them
            mask = torch.FloatTensor(state["action_mask"])
            q_values = q_values.masked_fill(mask == 0, float("-inf"))

            return q_values.argmax().item()

def optimize_model():
    """Update the shared network from a random batch of replay-buffer experiences."""
    if len(replay_buffer) < BATCH_SIZE:
        return 0  # not enough experience yet

    # Each transition is (state, action, reward, next_state, done)
    transitions = replay_buffer.sample(BATCH_SIZE)
    states = torch.FloatTensor(np.array([t[0]["observation"] for t in transitions]))
    actions = torch.LongTensor([t[1] for t in transitions]).unsqueeze(1)
    rewards = torch.FloatTensor([t[2] for t in transitions])
    next_states = torch.FloatTensor(np.array([t[3]["observation"] for t in transitions]))
    dones = torch.FloatTensor([t[4] for t in transitions])

    # Q(s, a) predicted by the shared network for the actions actually taken
    current_q = shared_agent(states).gather(1, actions).squeeze(1)
    with torch.no_grad():
        # Bellman target from the frozen target network; no bootstrapping on terminal states
        max_next_q = target_network(next_states).max(1)[0]
        target_q = rewards + GAMMA * max_next_q * (1 - dones)

    loss = nn.MSELoss()(current_q, target_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


### Training
**Training Loop:** Where the Magic (and RNG) Happens

This is where our neural network gets its hands dirty—by battling itself, over and over, until it (hopefully) learns something!

In [None]:
# Training loop
print(f"Starting training for {NUM_EPISODES} episodes")
for episode in range(NUM_EPISODES):
    # Reset environment
    observations, _ = env.reset()
    episode_rewards = {"player": 0, "enemy": 0}
    done = False
    
    # Episode loop
    while not done:
        # Select actions for both agents
        actions = {}
        for agent in env.agents:
            actions[agent] = select_action(agent, observations[agent], epsilon)
        
        # Take a step in the environment
        next_observations, rewards, terminations, truncations, _ = env.step(actions)
        
        # Check if episode is done
        done = all(terminations.values()) or all(truncations.values())
        
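        # Store transitions in the shared replay buffer.
        # Iterate over the agents that acted this step rather than env.agents,
        # which PettingZoo-style environments may empty once the battle ends
        # (that would drop the final, reward-carrying transitions).
        for agent in actions:
            replay_buffer.push((
                observations[agent],
                actions[agent],
                rewards[agent],
                next_observations[agent],
                terminations[agent]  # truncation is not a true terminal state
            ))

            # Track rewards
            episode_rewards[agent] += rewards[agent]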
        
        # Update observations
        observations = next_observations
        
        # Optimize the model
        loss = optimize_model()
    
    # Decay epsilon
    epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
    
    # Update target network periodically
    if episode % TARGET_UPDATE == 0:
        target_network.load_state_dict(shared_agent.state_dict())
    
    # Track rewards
    total_reward = sum(episode_rewards.values())
    rewards_history.append(total_reward)
    
    # Record a win/loss/draw proxy: the agent with the higher episode reward counts as the winner
    if episode_rewards["player"] > episode_rewards["enemy"]:
        win_rates.append(1)  # Player won
    elif episode_rewards["player"] < episode_rewards["enemy"]:
        win_rates.append(0)  # Enemy won
    else:
        win_rates.append(0.5)  # Draw

    # Log progress every 10 episodes and after the last one (averages over at most the last 10 episodes)
    if episode % 10 == 0 or episode == NUM_EPISODES - 1:
        avg_reward = np.mean(rewards_history[-10:])
        win_rate = np.mean(win_rates[-10:])
        print(f"Episode {episode + 1}/{NUM_EPISODES}, Avg Reward: {avg_reward:.2f}, Win Rate: {win_rate:.2f}, Epsilon: {epsilon:.2f}")


### Export
Once our model is trained, we can export it to ONNX so that it can later run inside a GBA.
See you in the next tutorial, `export.ipynb`!

In [None]:
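# Export the trained model to ONNX format
print("Exporting model to ONNX format...")
shared_agent.eval()  # inference mode (no effect on this MLP, but good practice before export)
dummy_input = torch.randn(1, obs_size)  # example input that fixes the input shape (batch of 1)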
torch.onnx.export(
    shared_agent,
    dummy_input,
    "pokemon_battle_model.onnx",
    export_params=True,
    opset_version=11,
    input_names=['input'],
    output_names=['output']
)
print("Model exported successfully!")
