
# Turn Your AI Into GBA Code!

Ever dreamed of running your own neural network on a real Game Boy Advance? That’s exactly what this project lets you do: convert your ONNX model into C code that runs right inside Pokémon Emerald!

Last time, we trained a model. Now, let’s bring it to life on the GBA.

## What You’ll Do in This Tutorial
- Load your trained ONNX model
- Quantize it to full int8 for retro hardware
- Fuse its nodes and export it as C code for the GBA
- Build the ROM, then run and test it directly in the emulator

### Imports

```python
import sys
sys.path.append("..")

import os

import numpy as np
import torch
import torch.nn as nn
import onnx
import onnxruntime as ort
from onnx import shape_inference
from onnx import numpy_helper

import torch.nn.functional as F

from pkmn_rl_arena.quantize.quantize import FullQuantizer
from pkmn_rl_arena.export.onnx_exporter import ONNXExporter
from pkmn_rl_arena.paths import PATHS

from pkmn_rl_arena.export.passes.delete_pass import (
    DeleteFirstInputQDQPass,
    DeleteQuantizePass,
    DeleteFirstLastQuantizeDequantizePass,
)
from pkmn_rl_arena.export.passes.fusion_pass import (
    GemmQuantDequantFusionPass,
)

from pkmn_rl_arena.export.base import ExportBaseGba
from pkmn_rl_arena.export.onnx_utils import OnnxUtils
from pkmn_rl_arena.data.parser import MapAnalyzer

from pkmn_rl_arena.export.exporters.parameters import ExportParameters
import rustboyadvance_py
```



### Set Model Paths

Before moving forward, let's organize the three key ONNX files we'll use:

- **Original ONNX:** The model you trained, in full precision.
- **Quantized ONNX:** Converted to int8 for hardware compatibility, and used to test inference with ONNX Runtime.
- **Fused ONNX:** Created from the quantized model, with Gemm and QDQ operations fused together so our export module can process it for the GBA.

Defining these paths now will make the next steps much smoother.

```python
onnx_path = "pokemon_battle_model.onnx"
quantized_onnx_path = "pokemon_battle_model_quantized.onnx"
fused_path = "pokemon_battle_model_quantized_fused.onnx"
```

You can print the graph with ONNX to see the changes we make at each step.  
For a clearer visualization, try Netron, a handy tool for exploring ONNX model graphs.

```python
onnx_raw = onnx.load(onnx_path)
print(onnx.helper.printable_graph(onnx_raw.graph))
```

```
graph main_graph (
  %input[FLOAT, 1x360]
) initializers (
  %w0[FLOAT, 360x128]
  %b0[FLOAT, 128]
  %w1[FLOAT, 128x10]
  %b1[FLOAT, 10]
) {
  %/Add_output_0 = Gemm[alpha = 1, beta = 1, transA = 0, transB = 0](%input, %w0, %b0)
  %/Relu_output_0 = Relu(%/Add_output_0)
  %output = Gemm[alpha = 1, beta = 1, transA = 0, transB = 0](%/Relu_output_0, %w1, %b1)
  return %output
}
```


## Quantize the model

Quantization is the process of converting your model’s weights and activations from floating point to int8. This makes the model far cheaper to run on hardware like the GBA, whose ARM7TDMI CPU has no floating-point unit.  

To do this, we calibrate the QDQ (QuantizeLinear/DequantizeLinear) pairs using sample data: the calibration pass observes the range of values each tensor takes and derives a scale and zero point that map those values onto int8.

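> Note:  
> Calibration works best with real data from your application. Here we use randomly generated data (`create_fake_calibration_data`), which is enough to test the pipeline but will not necessarily match the value ranges your model sees in practice. In future versions, we’ll provide a tool to help generate calibration data automatically.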
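To make "calibration" concrete, here is a minimal sketch of min-max symmetric calibration for a single tensor: the scale is chosen so the largest absolute value seen in the calibration data maps to 127, with a zero point of 0. This is an illustration under that assumption, not necessarily what `FullQuantizer` does (ONNX Runtime's static quantization supports several calibration methods), and `calibrate_symmetric_scale` is a hypothetical helper. It also shows why real data matters: the scale depends only on the ranges the calibration data happens to cover.

```python
import numpy as np


def calibrate_symmetric_scale(batches):
    """Return an int8 scale so that the largest |value| observed maps to 127.

    `batches` is an iterable of float arrays, e.g. activations recorded
    while running the float model on calibration inputs.
    """
    max_abs = 0.0
    for batch in batches:
        max_abs = max(max_abs, float(np.max(np.abs(batch))))
    # Avoid a zero scale if every calibration value was 0.
    return max_abs / 127.0 if max_abs > 0.0 else 1.0


# Fake calibration data in [-1, 1), like the random inputs used in this notebook.
rng = np.random.default_rng(0)
fake = [rng.uniform(-1, 1, (1, 360)).astype(np.float32) for _ in range(10)]
scale = calibrate_symmetric_scale(fake)
print(scale)  # just under 1/127, since every sample is below 1 in magnitude
```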

```python
num_samples = 10

quantizer = FullQuantizer(onnx_path, quantized_onnx_path)
calib_reader = FullQuantizer.create_fake_calibration_data(
    onnx_path, num_samples=num_samples
)
quantizer.quantize(calib_reader)

# Infer shapes
quantized_model = onnx.load(quantized_onnx_path)
inferred_model = shape_inference.infer_shapes(quantized_model)
onnx.save(inferred_model, quantized_onnx_path)

print(onnx.helper.printable_graph(inferred_model.graph))
```

```
graph main_graph (
  %input[FLOAT, 1x360]
) initializers (
  %input_zero_point[INT8, scalar]
  %input_scale[FLOAT, scalar]
  %/Add_output_0_zero_point[INT8, scalar]
  %/Add_output_0_scale[FLOAT, scalar]
  %w0_zero_point[INT8, scalar]
  %w0_scale[FLOAT, scalar]
  %w0_quantized[INT8, 360x128]
  %/Relu_output_0_zero_point[INT8, scalar]
  %/Relu_output_0_scale[FLOAT, scalar]
  %output_zero_point[INT8, scalar]
  %output_scale[FLOAT, scalar]
  %w1_zero_point[INT8, scalar]
  %w1_scale[FLOAT, scalar]
  %w1_quantized[INT8, 128x10]
  %b0_quantized[INT32, 128]
  %b0_quantized_scale[FLOAT, 1]
  %b0_quantized_zero_point[INT32, scalar]
  %b1_quantized[INT32, 10]
  %b1_quantized_scale[FLOAT, 1]
  %b1_quantized_zero_point[INT32, scalar]
) {
  %b0 = DequantizeLinear(%b0_quantized, %b0_quantized_scale, %b0_quantized_zero_point)
  %b1 = DequantizeLinear(%b1_quantized, %b1_quantized_scale, %b1_quantized_zero_point)
  %input_QuantizeLinear_Output = QuantizeLinear(%input, %input_scale, %input_zero_point)
```

### Get the GBA template folder

To test your exported model, a ready-to-use GBA project folder is provided.  
This folder contains everything needed to build and run your model on the emulator (and on real hardware, too!).

```python
ExportBaseGba.copy_gba_folder(".")
```

```
'./gba'
```

### Generating random inputs

To test our exported model, we generate random inputs and run inference using both ONNX Runtime and the exported GBA model.

Since the exported model uses full int8 and the ONNX model uses QDQ pairs, their input and output formats differ.  
To compare results accurately, we need to retrieve the input and output scaling factors and use them to convert values between formats.
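The conversions themselves mirror ONNX's QuantizeLinear and DequantizeLinear operators. Below is a minimal sketch, assuming zero points of 0 (the notebook's own conversions also ignore them); `float_to_int8` and `int8_to_float` are hypothetical helper names, and the scale is a made-up value rather than the model's `input_scale`. Saturating to [-128, 127] matters: without the clip, values outside the calibrated range would silently wrap around when cast to int8.

```python
import numpy as np


def float_to_int8(x, scale, zero_point=0):
    """Mirror ONNX QuantizeLinear: round half to even, add zero point, saturate."""
    q = np.round(np.asarray(x, dtype=np.float32) / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)


def int8_to_float(q, scale, zero_point=0):
    """Mirror ONNX DequantizeLinear: (q - zero_point) * scale."""
    return (np.asarray(q).astype(np.float32) - zero_point) * scale


scale = 1.0 / 127.0  # hypothetical; use input_scale / output_scale in the notebook
x = np.array([-1.0, -0.4, 0.0, 0.25, 1.0, 3.0], dtype=np.float32)
q = float_to_int8(x, scale)       # q holds [-127, -51, 0, 32, 127, 127]
x_back = int8_to_float(q, scale)
# Inside the representable range the round-trip error is at most scale / 2;
# 3.0 lies outside it and saturates to 127.
print(q, x_back)
```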

```python
quantized_model = onnx.load(quantized_onnx_path)
output_scale = OnnxUtils.get_last_qdq_scaling_factor(quantized_model.graph)[0]
input_scale = OnnxUtils.get_first_qdq_scaling_factor(quantized_model.graph)[0]

input_random = np.random.uniform(-1, 1, (1, 360)).astype(np.float32)
```

Next, we run inference with ONNX Runtime to get the model outputs.  
This lets us compare the results from the quantized ONNX model with those from the exported GBA model.

```python
ort_session = ort.InferenceSession(quantized_onnx_path)
model_in = ort_session.get_inputs()[0]
if "int8" in model_in.type:
    # The graph expects int8: quantize with the input scale and saturate.
    model_input = np.clip(np.round(input_random / input_scale), -128, 127).astype(np.int8)
else:
    # The QDQ graph takes float input and quantizes it internally.
    model_input = input_random.astype(np.float32)
ort_outputs = ort_session.run(None, {model_in.name: model_input})
onnx_output = ort_outputs[0]
print("ONNX Runtime output:", onnx_output)
```

```
ONNX Runtime output: [[132.09496    -4.9847155 154.52618   -19.938862  239.26634   184.43448
    2.4923577 107.17138   -67.293655  -99.694305 ]]
```


## Fuse and delete nodes

To make the model compatible with our export module, we need to fuse QDQ pairs and Gemm operations.

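This process creates a custom QGemm node for each layer. A second pass then deletes the remaining QuantizeLinear -> DequantizeLinear pairs on the activations, since the GBA model works in int8 end to end. The result is a simpler graph that is easier to turn into C code for the GBA.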
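What a fused quantized Gemm computes can be sketched with integer arithmetic only: multiply int8 activations by int8 weights, accumulate in int32 together with an int32 bias quantized at scale `s_x * s_w`, then requantize the accumulator to the output's int8 scale. This is a minimal sketch assuming all zero points are 0; `qgemm_int8` is a hypothetical helper, not the project's QGemm implementation. A GBA version would more likely replace the float multiplier `m` with an integer multiply and shift, since the GBA has no FPU; how the generated `forward.c` handles this is not shown here.

```python
import numpy as np


def qgemm_int8(x_q, w_q, b_q, s_x, s_w, s_y):
    """Integer Gemm sketch with all zero points assumed to be 0.

    x_q: int8 (1, K) activations, w_q: int8 (K, N) weights,
    b_q: int32 (N,) bias already quantized with scale s_x * s_w.
    """
    # Accumulate in int32 so sums of int8 products cannot overflow
    # (e.g. 360 * 128 * 128 is far below the int32 limit).
    acc = x_q.astype(np.int32) @ w_q.astype(np.int32) + b_q
    # Requantize: rescale the accumulator onto the output's int8 grid.
    m = (s_x * s_w) / s_y
    y = np.round(acc * m)
    return np.clip(y, -128, 127).astype(np.int8)


x_q = np.array([[10, -20, 30]], dtype=np.int8)
w_q = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.int8)
b_q = np.array([100, -100], dtype=np.int32)
print(qgemm_int8(x_q, w_q, b_q, s_x=0.1, s_w=0.05, s_y=0.5))  # [[2 0]]
```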

```python
fusion_pass = GemmQuantDequantFusionPass()
delete_pass = DeleteQuantizePass()

fusion_pass.run(quantized_model.graph)
delete_pass.run(quantized_model.graph)
onnx.save(quantized_model, fused_path)

fused_model = onnx.load(fused_path)
print(onnx.helper.printable_graph(fused_model.graph))
```

```
Deleting QuantizeLinear -> DequantizeLinear pair: input_QuantizeLinear -> input_DequantizeLinear
Deleting QuantizeLinear -> DequantizeLinear pair: /Relu_output_0_QuantizeLinear -> /Relu_output_0_DequantizeLinear
Remapped input: input_DequantizeLinear_Output -> input in node /MatMul/MatMulAddFusion_fused
Remapped input: /Relu_output_0_DequantizeLinear_Output -> /Relu_output_0 in node /MatMul_1/MatMulAddFusion_fused
graph main_graph (
  %input[FLOAT, 1x360]
) initializers (
  %input_zero_point[INT8, scalar]
  %input_scale[FLOAT, scalar]
  %/Add_output_0_zero_point[INT8, scalar]
  %/Add_output_0_scale[FLOAT, scalar]
  %w0_zero_point[INT8, scalar]
  %w0_scale[FLOAT, scalar]
  %w0_quantized[INT8, 360x128]
  %/Relu_output_0_zero_point[INT8, scalar]
  %/Relu_output_0_scale[FLOAT, scalar]
  %output_zero_point[INT8, scalar]
  %output_scale[FLOAT, scalar]
  %w1_zero_point[INT8, scalar]
  %w1_scale[FLOAT, scalar]
  %w1_quantized[INT8, 128x10]
  %b0_quantized[INT32, 128]
  %b0_quantized_scale[FLOA
```

## Generate the export

To generate the export, simply call the export function.  
This will create the necessary header files containing all model parameters (weights and biases), and a `forward.c` file that implements the model’s inference logic for the GBA.
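As an illustration of what "header files containing the model parameters" means, here is a sketch that renders an int8 tensor as a C array definition. `int8_array_to_c` is a hypothetical helper and the output format is an assumption; the real `ONNXExporter` output (names, layout, memory placement) will differ.

```python
import numpy as np


def int8_array_to_c(name: str, values, per_line: int = 16) -> str:
    """Render an int8 tensor as a flattened C array definition (hypothetical format)."""
    flat = np.asarray(values, dtype=np.int8).reshape(-1)
    rows = []
    for start in range(0, flat.size, per_line):
        chunk = flat[start:start + per_line]
        rows.append("    " + ", ".join(str(int(v)) for v in chunk))
    body = ",\n".join(rows)
    return f"static const int8_t {name}[{flat.size}] = {{\n{body}\n}};\n"


# A 2x2 weight matrix is flattened in row-major (C) order.
print(int8_array_to_c("w_example", np.array([[1, -2], [3, 127]])))
```

Flattening in row-major order keeps the C indexing `w[row * cols + col]` consistent with NumPy's default layout.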

```python
exporter = ONNXExporter(fused_path)
exporter.export(output_dir="gba")
```

```
Exported QGemm layer parameters for _MATMUL_MATMULADDFUSION_FUSED
Exported layer _MATMUL_MATMULADDFUSION_FUSED parameters to gba/include
Exported layer _RELU parameters to gba/include
Exported QGemm layer parameters for _MATMUL_1_MATMULADDFUSION_FUSED
Exported layer _MATMUL_1_MATMULADDFUSION_FUSED parameters to gba/include
Forward function exported to: gba/source/forward.c
Forward function header exported to: gba/include/forward.h
```


### Compile it!

```python
os.system("make -C gba")
```

```
make: Entering directory '/home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba'
forward.c
template.c
linking cartridge
built ... gba.gba
ROM fixed!
make: Leaving directory '/home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba'
```

The compiler also prints a few warnings to stderr (truncated in this export). They do not stop the build: the ROM is produced and `os.system` returns `0`.

```
In file included from /home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba/include/forward.h:5,
                 from /home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba/source/forward.c:5:
   28 | IN_IWRAM static int8_t weight_buffer[WEIGHT_BUFFER_BYTES] __attribute__((aligned(4)));
      |                        ^~~~~~~~~~~~~
In file included from /home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba/source/template.c:9:
/home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba/include/tests.h: In function 'all_tests':
  428 |     u32 start_time, end_time, result;
      |         ^~~~~~~~~~
/home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba/source/template.c: In function 'main':
   41 |         forward(input, output);
      |                 ^~~~~
In file included from /home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba/source/template.c:13:
/home/wb
```

```
0
```

## Communicate with the Rust emulator

To run the emulator, we need to set up stop addresses, write the input data, and read the output.  
For a detailed explanation of how this works, check out our tutorial:

https://github.com/wissammm/PkmnRLArena/wiki/How-stopHandleTurn-works

```c
volatile u16 stopWriteData IN_EWRAM;
volatile u16 stopReadData IN_EWRAM;
volatile int8_t input[1024] IN_EWRAM;
volatile int8_t output[10] IN_EWRAM;

int main(void)
{
    ...
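    stopWriteData = 1;       // host stops here (stop id 3) and writes `input`
    forward(input, output);  // run the exported network on the int8 input
    stopReadData = 1;        // host stops here (stop id 4) and reads `output`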
    ...
}
```

```python
def setup_gba_environment(rom_path, map_path):
    """Set up the GBA emulator and return the objects needed for inference."""
    gba = rustboyadvance_py.RustGba()
    gba.load(PATHS["BIOS"], rom_path)
    parser = MapAnalyzer(map_path)
    # Look up the handshake variables in the linker map
    addr_write = int(parser.get_address("stopWriteData"), 16)
    addr_read = int(parser.get_address("stopReadData"), 16)
    gba.add_stop_addr(addr_write, 1, True, "stopWriteData", 3)
    gba.add_stop_addr(addr_read, 1, True, "stopReadData", 4)

    output_addr = int(parser.get_address("output"), 16)
    input_addr = int(parser.get_address("input"), 16)

    return gba, parser, addr_write, addr_read, output_addr, input_addr


def run_gba_inference(
    gba, addr_write, input_addr, output_addr, input_data, output_size
):
    """Run inference on the GBA and return the int8 outputs."""
    # Wait until the ROM signals it is ready for input (stop id 3)
    stop_id = gba.run_to_next_stop(20000)
    while stop_id != 3:
        stop_id = gba.run_to_next_stop(20000)

    # Write the input data, then clear the flag so the ROM can continue
    gba.write_i8_list(input_addr, input_data.reshape(-1).tolist())
    gba.write_u16(addr_write, 0)

    # Wait for the computation to complete (stop id 4)
    stop_id = gba.run_to_next_stop(20000)
    while stop_id != 4:
        stop_id = gba.run_to_next_stop(20000)

    # Read the output
    output_read = gba.read_i8_list(output_addr, output_size)
    return np.array(output_read, dtype=np.int8).reshape(-1)
```
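The `while` loops above never terminate if a stop is never reached, for example if the ROM crashes or a symbol address is wrong. Below is a sketch of a bounded wait, assuming (as the code above already relies on) that `run_to_next_stop` returns the id of the stop that fired. `run_until_stop` and `FakeGba` are hypothetical names; the fake only replays a list of ids so the helper can be tried without a ROM.

```python
class FakeGba:
    """Hypothetical stand-in for rustboyadvance_py.RustGba, for testing only.

    It replays a fixed list of stop ids; 0 means "no stop fired" by this
    fake's own convention, not necessarily the real emulator's.
    """

    def __init__(self, stop_ids):
        self._stop_ids = iter(stop_ids)

    def run_to_next_stop(self, cycles):
        return next(self._stop_ids, 0)


def run_until_stop(gba, wanted_id, cycles=20000, max_calls=100_000):
    """Call run_to_next_stop until it reports `wanted_id`; return the call count.

    Raises TimeoutError instead of looping forever if the stop never fires.
    """
    for calls in range(1, max_calls + 1):
        if gba.run_to_next_stop(cycles) == wanted_id:
            return calls
    raise TimeoutError(f"stop id {wanted_id} not reached after {max_calls} calls")


# With the fake, the wanted stop (3) fires on the third call.
print(run_until_stop(FakeGba([0, 0, 3]), wanted_id=3))  # 3
```

In `run_gba_inference`, each `while` loop could then become `run_until_stop(gba, 3)` or `run_until_stop(gba, 4)`.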


```python
gba, parser, addr_write, addr_read, output_addr, input_addr = setup_gba_environment(
    "gba/gba.elf", "gba/build/gba.map"
)
# Scale to the int8 grid, saturate to [-128, 127] so values cannot wrap, and convert
input_gba = np.clip(np.round(input_random / input_scale), -128, 127).astype(np.int8)
gba_output = run_gba_inference(
    gba, addr_write, input_addr, output_addr, input_gba, 10
)
```


```
Adding stop address: addr=33554976, value=1, is_active=true, name=stopWriteData, id=3
Adding stop address: addr=33554978, value=1, is_active=true, name=stopReadData, id=4

WARN [rustboyadvance_utils::elf] ELF: skipping program header ProgramHeader { p_type: "PT_LOAD", p_flags: 0x6, p_offset: 0xb8, p_vaddr: 0x30000b8, p_paddr: 0x30000b8, p_filesz: 0x0, p_memsz: 0x464, p_align: 4096 }
INFO [rustboyadvance_utils::elf] ELF: loading segment phdr: ProgramHeader { p_type: "PT_LOAD", p_flags: 0x5, p_offset: 0x1000, p_vaddr: 0x8000000, p_paddr: 0x8000000, p_filesz: 0x22b50, p_memsz: 0x22b50, p_align: 4096 } range 0x1000..0x23b50 vec range 0x8000000..0x8022b50
INFO [rustboyadvance_utils::elf] ELF: loading segment phdr: ProgramHeader { p_type: "PT_LOAD", p_flags: 0x5, p_offset: 0x24000, p_vaddr: 0x3000000, p_paddr: 0x8022b50, p_filesz: 0xb8, p_memsz: 0xb8, p_align: 4096 } range 0x24000..0x240b8 vec range 0x8022b50..0x8022c08
INFO [rustboyadvance_utils::elf] ELF: loading segment phdr: ProgramHeader { p_type: "PT_LOAD", p_flags: 0x6, p_offset: 0x2451c, p_vaddr: 0x300051c, p_paddr: 0x8022c08, p_filesz: 0x176c, p_memsz: 0x176c, p_align
```

## Compare Results

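Finally, we compare the outputs from ONNX Runtime and the GBA model.

The ONNX Runtime output is already in float, while the GBA returns int8 values, so we dequantize the GBA output by multiplying it by the output scaling factor. Both outputs are then in the same range and can be compared.  
The tolerance used here (`atol=6e4`) is very loose: `np.allclose` checks `|a - b| <= atol + rtol * |b|`, so for outputs in the hundreds it passes almost regardless of the values. Treat it as a smoke test and also look at the element-wise differences.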
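A more informative check is to express each difference in quantization steps, i.e. divide it by the output scale. Here is a minimal sketch with made-up values; `error_in_steps` is a hypothetical helper. In this notebook it would be called as `error_in_steps(onnx_float, gba_float, output_scale)`.

```python
import numpy as np


def error_in_steps(reference, candidate, scale):
    """Element-wise |reference - candidate|, measured in int8 quantization steps."""
    ref = np.asarray(reference, dtype=np.float32)
    cand = np.asarray(candidate, dtype=np.float32)
    return np.round(np.abs(ref - cand) / scale).astype(int)


steps = error_in_steps([1.0, 2.0, -1.0], [1.0, 1.5, -1.25], scale=0.25)
print(steps)             # [0 2 1]
print(steps.max() <= 2)  # True
```

A difference of one step is expected from rounding alone. How many steps you accept before calling the outputs a mismatch is a judgment call for your model.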

```python
onnx_float = onnx_output.astype(np.float32).reshape(-1)
gba_float = (gba_output.astype(np.float32)) * output_scale

print(f"ONNX output (float):{onnx_float}")
print(f"GBA output (float):    {gba_float}")
float_match = np.allclose(onnx_float, gba_float, rtol=1e-1, atol=6e4)
if float_match:
    print("Dequantized outputs match with (a lot of) tolerance!")
else:
    print("Outputs do not match :O")
```


```
ONNX output (float):[132.09496    -4.9847155 154.52618   -19.938862  239.26634   184.43448
   2.4923577 107.17138   -67.293655  -99.694305 ]
GBA output (float):    [127.110245  -7.477073 127.110245 -19.938862 236.77399  179.44975
   0.       104.67902  -64.8013   -94.709595]
Dequantized outputs match with (a lot of) tolerance!
```
