# Generate a C export for the GBA
If you have trained model in onnx, you can generate a c export of inference for the gba 

In [22]:
import sys
sys.path.append("..")

import os

import numpy as np
import torch
import torch.nn as nn
import onnx
import onnxruntime as ort
from onnx import shape_inference
from onnx import numpy_helper

import torch.nn.functional as F

from pkmn_rl_arena.quantize.quantize import FullQuantizer
from pkmn_rl_arena.export.onnx_exporter import ONNXExporter
from pkmn_rl_arena.paths import PATHS

from pkmn_rl_arena.export.passes.delete_pass import (
    DeleteFirstInputQDQPass,
    DeleteQuantizePass,
    DeleteFirstLastQuantizeDequantizePass,
)
from pkmn_rl_arena.export.passes.fusion_pass import (
    GemmQuantDequantFusionPass,
)

from pkmn_rl_arena.export.base import ExportBaseGba
from pkmn_rl_arena.export.onnx_utils import OnnxUtils
from pkmn_rl_arena.data.parser import MapAnalyzer

from pkmn_rl_arena.export.exporters.parameters import ExportParameters
import rustboyadvance_py



### Set Path
We need three paths : 
- The regular onnx, the one that we trained before 
- The quantized one, used to run an inference with onnx_runtime and compare oiur results 
- The fused one, from the quantized model, we create a custom model wich fuse Gemm and QDQ, to be interpretable by our export module

In [23]:
onnx_path = "pokemon_battle_model.onnx"
quantized_onnx_path = "pokemon_battle_model_quantized.onnx"
fused_path = "pokemon_battle_model_quantized_fused.onnx"

You can print the actual graph with onnx to see how much we manipulate it ;) (But it's better with netron, a tool to vizualize onnx graph)


In [24]:
onnx_raw = onnx.load(onnx_path)
print(onnx.helper.printable_graph(onnx_raw.graph))

graph main_graph (
  %input[FLOAT, 1x360]
) initializers (
  %net.0.weight[FLOAT, 128x360]
  %net.0.bias[FLOAT, 128]
  %net.2.weight[FLOAT, 10x128]
  %net.2.bias[FLOAT, 10]
) {
  %/net/net.0/Gemm_output_0 = Gemm[alpha = 1, beta = 1, transB = 1](%input, %net.0.weight, %net.0.bias)
  %/net/net.1/Relu_output_0 = Relu(%/net/net.0/Gemm_output_0)
  %output = Gemm[alpha = 1, beta = 1, transB = 1](%/net/net.1/Relu_output_0, %net.2.weight, %net.2.bias)
  return %output
}


### Quantize the model 
To quantize our model, we need to calibrate QDQ pairs, to do so, calibration data is needed

> Note : 
> It's better to calibrate data, with real data. In the next version of our project, a tool to generate calibration data will be available

In [25]:

num_samples = 10

quantizer = FullQuantizer(onnx_path, quantized_onnx_path)
calib_reader = FullQuantizer.create_fake_calibration_data(
    onnx_path, num_samples=num_samples
)
quantizer.quantize(calib_reader)

# Infer shapes
quantized_model = onnx.load(quantized_onnx_path)
inferred_model = shape_inference.infer_shapes(quantized_model)
onnx.save(inferred_model, quantized_onnx_path)

print(onnx.helper.printable_graph(inferred_model.graph))

graph main_graph (
  %input[FLOAT, 1x360]
) initializers (
  %input_zero_point[INT8, scalar]
  %input_scale[FLOAT, scalar]
  %/net/net.0/Gemm_output_0_zero_point[INT8, scalar]
  %/net/net.0/Gemm_output_0_scale[FLOAT, scalar]
  %net.0.weight_zero_point[INT8, scalar]
  %net.0.weight_scale[FLOAT, scalar]
  %net.0.weight_quantized[INT8, 128x360]
  %/net/net.1/Relu_output_0_zero_point[INT8, scalar]
  %/net/net.1/Relu_output_0_scale[FLOAT, scalar]
  %output_zero_point[INT8, scalar]
  %output_scale[FLOAT, scalar]
  %net.2.weight_zero_point[INT8, scalar]
  %net.2.weight_scale[FLOAT, scalar]
  %net.2.weight_quantized[INT8, 10x128]
  %net.0.bias_quantized[INT32, 128]
  %net.0.bias_quantized_scale[FLOAT, 1]
  %net.0.bias_quantized_zero_point[INT32, scalar]
  %net.2.bias_quantized[INT32, 10]
  %net.2.bias_quantized_scale[FLOAT, 1]
  %net.2.bias_quantized_zero_point[INT32, scalar]
) {
  %input_QuantizeLinear_Output = QuantizeLinear(%input, %input_scale, %input_zero_point)
  %net.0.bias = Dequantize

### Get the GBA template folder
To tests our export, a GBA folder is available

In [26]:

ExportBaseGba.copy_gba_folder(".")

'./gba'

### Generating random inputs
To test our exported model, we generate random inputs and infer it both with onnx_runtime and our model.

But our exported model and our onnx are different, one use QDQ pair and the other is full int8, so we need to retrive our input and output scaling factors to put results on the right range

In [27]:
quantized_model = onnx.load(quantized_onnx_path)
output_scale = OnnxUtils.get_last_qdq_scaling_factor(quantized_model.graph)[0]
input_scale = OnnxUtils.get_first_qdq_scaling_factor(quantized_model.graph)[0]

input_random = np.random.uniform(-1, 1, (1, 360)).astype(np.float32)

Then we run the inference with onnx_runtime and get the outputs.

In [28]:

ort_session = ort.InferenceSession(quantized_onnx_path)
input_type = ort_session.get_inputs()[0].type
if 'int8' in input_type:
    model_input = input_random
else:
    model_input = input_random.astype(np.float32)
ort_outputs = ort_session.run(None, {ort_session.get_inputs()[0].name: model_input})
onnx_output = ort_outputs[0]
print("ONNX Runtime output:", onnx_output)

ONNX Runtime output: [[0.57719505 0.44138443 0.3395265  0.30557385 0.23766854 0.16976325
  0.44138443 0.10185795 0.44138443 0.40743178]]


### Fuse and delete nodes
To have model intrepretable by our export module, we need to fuse QDQ pairs and Gemm OP, it creates a QGemmCustom. It makes the graph more understandable, and we generate Gemm and QDQ as a whole.

In [29]:
fusion_pass = GemmQuantDequantFusionPass()
delete_pass = DeleteQuantizePass()

fusion_pass.run(quantized_model.graph)
delete_pass.run(quantized_model.graph)
onnx.save(quantized_model, fused_path)

fused_model = onnx.load(fused_path)
print(onnx.helper.printable_graph(fused_model.graph))

Deleting QuantizeLinear -> DequantizeLinear pair: input_QuantizeLinear -> input_DequantizeLinear
Deleting QuantizeLinear -> DequantizeLinear pair: /net/net.1/Relu_output_0_QuantizeLinear -> /net/net.1/Relu_output_0_DequantizeLinear
Remapped input: input_DequantizeLinear_Output -> input in node /net/net.0/Gemm_fused
Remapped input: /net/net.1/Relu_output_0_DequantizeLinear_Output -> /net/net.1/Relu_output_0 in node /net/net.2/Gemm_fused
graph main_graph (
  %input[FLOAT, 1x360]
) initializers (
  %input_zero_point[INT8, scalar]
  %input_scale[FLOAT, scalar]
  %/net/net.0/Gemm_output_0_zero_point[INT8, scalar]
  %/net/net.0/Gemm_output_0_scale[FLOAT, scalar]
  %net.0.weight_zero_point[INT8, scalar]
  %net.0.weight_scale[FLOAT, scalar]
  %net.0.weight_quantized[INT8, 128x360]
  %/net/net.1/Relu_output_0_zero_point[INT8, scalar]
  %/net/net.1/Relu_output_0_scale[FLOAT, scalar]
  %output_zero_point[INT8, scalar]
  %output_scale[FLOAT, scalar]
  %net.2.weight_zero_point[INT8, scalar]
  %net.

### Generate the export
To generate the export we just need to call a function. We generate includes, which contains headers with all of the parameters (weights and bias). And a forward.c which do ... 

In [30]:
exporter = ONNXExporter(fused_path)
exporter.export(output_dir="gba")

Exported QGemm layer parameters for _NET_NET_0_GEMM_FUSED
Exported layer _NET_NET_0_GEMM_FUSED parameters to gba/include
Exported layer _NET_NET_1_RELU parameters to gba/include
Exported QGemm layer parameters for _NET_NET_2_GEMM_FUSED
Exported layer _NET_NET_2_GEMM_FUSED parameters to gba/include
Forward function exported to: gba/source/forward.c
Forward function header exported to: gba/include/forward.h


### Make the project

In [31]:
os.system("make -C gba")

make: Entering directory '/home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba'
forward.c
template.c
linking cartridge
built ... gba.gba
ROM fixed!
make: Leaving directory '/home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba'


In file included from /home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba/include/forward.h:5,
                 from /home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba/source/forward.c:5:
   28 | IN_IWRAM static int8_t weight_buffer[WEIGHT_BUFFER_BYTES] __attribute__((aligned(4)));
      |                        ^~~~~~~~~~~~~
In file included from /home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba/source/template.c:9:
/home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba/include/tests.h: In function 'all_tests':
  428 |     u32 start_time, end_time, result;
      |         ^~~~~~~~~~
/home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba/source/template.c: In function 'main':
   41 |         forward(input, output);
      |                 ^~~~~
In file included from /home/wboussella/Documents/rl_new_pokemon_ai/rl_new_pokemon_ai/examples/gba/source/template.c:13:
/home/wb

0

### Communicate with the rust emulator
To run the emulator, we need to add stop addrs, write inside the input and read the output. To understand how it works, please refer to our well writed tutorial on how interragir with the emulator.

https://github.com/wissammm/PkmnRLArena/wiki/How-stopHandleTurn-works
```c
volatile u16 stopWriteData IN_EWRAM;
volatile u16 stopReadData IN_EWRAM;
volatile int8_t input[1024] IN_EWRAM;
volatile int8_t output[10] IN_EWRAM;

int main(void)
{
    ...
    stopWriteData = 1;
    forward(input, output);
    stopReadData = 1;
    ...
}
```


In [32]:
def setup_gba_environment(rom_path, map_path):
    """Setup GBA environment and return necessary objects."""
    gba = rustboyadvance_py.RustGba()
    gba.load(PATHS["BIOS"], rom_path)
    parser = MapAnalyzer(map_path)
    addr_write = int(parser.get_address("stopWriteData"), 16)
    addr_read = int(parser.get_address("stopReadData"), 16)
    gba.add_stop_addr(addr_write, 1, True, "stopWriteData", 3)
    gba.add_stop_addr(addr_read, 1, True, "stopReadData", 4)

    output_addr = int(parser.get_address("output"), 16)
    input_addr = int(parser.get_address("input"), 16)

    return gba, parser, addr_write, addr_read, output_addr, input_addr


def run_gba_inference(
    gba, addr_write, input_addr, output_addr, input_data, output_size
):
    """Run inference on GBA and return results."""
    # Wait for initial stop
    id = gba.run_to_next_stop(20000)
    while id != 3:
        id = gba.run_to_next_stop(20000)

    # Write input data
    gba.write_i8_list(input_addr, input_data.reshape(-1).tolist())
    gba.write_u16(addr_write, 0)

    # Wait for computation to complete
    id = gba.run_to_next_stop(20000)
    while id != 4:
        id = gba.run_to_next_stop(20000)

    # Read output
    output_read = gba.read_i8_list(output_addr, output_size)
    return np.array(output_read, dtype=np.int8).reshape(-1)


In [None]:

gba, parser, addr_write, addr_read, output_addr, input_addr = (
    setup_gba_environment("gba/gba.elf", "gba/build/gba.map")
)
input_gba = np.round(input_random / input_scale).astype(np.int8)
gba_output = run_gba_inference(gba, addr_write, input_addr, output_addr, 
                            input_gba, 10)


[1;38;5;208mWARN[0m [rustboyadvance_utils::elf] [1;38;5;208mELF: skipping program header ProgramHeader { p_type: "PT_LOAD", p_flags: 0x6, p_offset: 0xb8, p_vaddr: 0x30000b8, p_paddr: 0x30000b8, p_filesz: 0x0, p_memsz: 0x464, p_align: 4096 }[0m
INFO [rustboyadvance_utils::elf] ELF: loading segment phdr: ProgramHeader { p_type: "PT_LOAD", p_flags: 0x5, p_offset: 0x1000, p_vaddr: 0x8000000, p_paddr: 0x8000000, p_filesz: 0x22b50, p_memsz: 0x22b50, p_align: 4096 } range 0x1000..0x23b50 vec range 0x8000000..0x8022b50
INFO [rustboyadvance_utils::elf] ELF: loading segment phdr: ProgramHeader { p_type: "PT_LOAD", p_flags: 0x5, p_offset: 0x24000, p_vaddr: 0x3000000, p_paddr: 0x8022b50, p_filesz: 0xb8, p_memsz: 0xb8, p_align: 4096 } range 0x24000..0x240b8 vec range 0x8022b50..0x8022c08
INFO [rustboyadvance_utils::elf] ELF: loading segment phdr: ProgramHeader { p_type: "PT_LOAD", p_flags: 0x6, p_offset: 0x2451c, p_vaddr: 0x300051c, p_paddr: 0x8022c08, p_filesz: 0x176c, p_memsz: 0x176c, p_align

Adding stop address: addr=33554976, value=1, is_active=true, name=stopWriteData, id=3
Adding stop address: addr=33554978, value=1, is_active=true, name=stopReadData, id=4
ONNX output (int8): [[0.57719505 0.44138443 0.3395265  0.30557385 0.23766854 0.16976325
  0.44138443 0.10185795 0.44138443 0.40743178]]
GBA output (int8): [ 6 23 -1 11 17 11 11  5  2 19]
ONNX output (float): [[0.57719505 0.44138443 0.3395265  0.30557385 0.23766854 0.16976325
  0.44138443 0.10185795 0.44138443 0.40743178]]
GBA output (float): [ 0.20371589  0.78091097 -0.03395265  0.37347916  0.57719505  0.37347916
  0.37347916  0.16976325  0.0679053   0.64510036]
Dequantized outputs match within tolerance!


INFO [rustboyadvance_core::sound] bias - setting sample frequency to 32768hz


NameError: name 'self' is not defined

In [None]:
print(f"ONNX output (int8): {onnx_output}")
print(f"GBA output (int8): {gba_output}")
onnx_float = onnx_output.astype(np.float32)
gba_float = (gba_output.astype(np.float32)) * output_scale
print(f"ONNX output (float): {onnx_float}")
print(f"GBA output (float): {gba_float}")
float_match = np.allclose(onnx_float, gba_float, rtol=1e-1, atol=6e4)
if float_match:
    print("Dequantized outputs match within tolerance!")
else:
    print("Dequantized outputs don't match")
    diff = np.abs(onnx_float - gba_float)
    max_diff = np.max(diff)
    avg_diff = np.mean(diff)
    print(f"Max difference: {max_diff}")
    print(f"Average difference: {avg_diff}")
    print(f"Differences: {diff}")