# Generate a C export for the GBA
If you have a trained model in ONNX format, you can generate a C export of its inference code for the GBA.

In [1]:
import sys
sys.path.append("..")

import numpy as np
import torch
import torch.nn as nn
import onnx
import onnxruntime as ort
from onnx import shape_inference
from onnx import numpy_helper

import torch.nn.functional as F

from pkmn_rl_arena.quantize.quantize import FullQuantizer
from pkmn_rl_arena.export.onnx_exporter import ONNXExporter
from pkmn_rl_arena.paths import PATHS

from pkmn_rl_arena.export.passes.delete_pass import (
    DeleteFirstInputQDQPass,
    DeleteQuantizePass,
    DeleteFirstLastQuantizeDequantizePass,
)
from pkmn_rl_arena.export.passes.fusion_pass import (
    GemmQuantDequantFusionPass,
)

from pkmn_rl_arena.export.base import ExportBaseGba

from pkmn_rl_arena.export.exporters.parameters import ExportParameters
import rustboyadvance_py


In [2]:

# Model files used by this notebook: the float model, its quantized version,
# and the fused model that is later handed to ONNXExporter for the C export.
onnx_path = "pokemon_battle_model.onnx"
quantized_onnx_path = "pokemon_battle_model_quantized.onnx"
fused_path = "pokemon_battle_model_quantized_fused.onnx"

# Number of calibration samples requested from create_fake_calibration_data.
num_samples = 10

# Quantize onnx_path into quantized_onnx_path, calibrating on generated
# ("fake") data produced from the float model rather than a real dataset.
quantizer = FullQuantizer(onnx_path, quantized_onnx_path)
calib_reader = FullQuantizer.create_fake_calibration_data(
    onnx_path, num_samples=num_samples
)
quantizer.quantize(calib_reader)

# Infer shapes: annotate the quantized model with inferred tensor shapes and
# overwrite the quantized file on disk with the annotated model.
quantized_model = onnx.load(quantized_onnx_path)
inferred_model = shape_inference.infer_shapes(quantized_model)
onnx.save(inferred_model, quantized_onnx_path)

# Copy the GBA folder into the current directory. In the recorded run this
# returned './gba', the same directory the C export is written to later.
ExportBaseGba.copy_gba_folder(".")

'./gba'

In [3]:
def get_last_qdq_scaling_factor(graph):
    """
    Get the scale and zero point of the last DequantizeLinear node in the graph.

    Nodes are scanned from the end of ``graph.node`` backwards. A node is only
    accepted when its scale (and its zero point, if one is given) are stored as
    graph initializers; otherwise the scan moves on to the previous
    DequantizeLinear node.

    Args:
        graph: an ``onnx.GraphProto``.

    Returns:
        ``(scale_value, zero_point_value)`` as ``(float, int)``, or
        ``(None, None)`` if no suitable DequantizeLinear node exists.
        When the optional zero-point input is omitted, the ONNX default of 0
        is returned.

    Raises:
        ValueError: if the scale or zero point holds more than one value
            (e.g. per-axis quantization).
    """
    # Index initializers once instead of rescanning them for every node.
    initializers = {init.name: init for init in graph.initializer}

    for node in reversed(graph.node):
        if node.op_type != "DequantizeLinear":
            continue

        scale_init = initializers.get(node.input[1])
        if scale_init is None:
            continue

        # x_zero_point is optional: it may be absent or given as an empty name,
        # in which case ONNX defines it as 0.
        if len(node.input) > 2 and node.input[2]:
            zero_point_init = initializers.get(node.input[2])
            if zero_point_init is None:
                continue
            zero_point_value = int(numpy_helper.to_array(zero_point_init).item())
        else:
            zero_point_value = 0

        scale_value = float(numpy_helper.to_array(scale_init).item())
        return scale_value, zero_point_value

    return None, None


def get_first_qdq_scaling_factor(graph):
    """
    Get the scale and zero point of the first QuantizeLinear node in the graph.

    Nodes are scanned from the start of ``graph.node`` forwards. A node is only
    accepted when its scale (and its zero point, if one is given) are stored as
    graph initializers; otherwise the scan moves on to the next QuantizeLinear
    node.

    Args:
        graph: an ``onnx.GraphProto``.

    Returns:
        ``(scale_value, zero_point_value)`` as ``(float, int)``, or
        ``(None, None)`` if no suitable QuantizeLinear node exists.
        When the optional zero-point input is omitted, the ONNX default of 0
        is returned.

    Raises:
        ValueError: if the scale or zero point holds more than one value
            (e.g. per-axis quantization).
    """
    # Index initializers once instead of rescanning them for every node.
    initializers = {init.name: init for init in graph.initializer}

    for node in graph.node:
        if node.op_type != "QuantizeLinear":
            continue

        scale_init = initializers.get(node.input[1])
        if scale_init is None:
            continue

        # y_zero_point is optional: it may be absent or given as an empty name,
        # in which case ONNX defines it as 0.
        if len(node.input) > 2 and node.input[2]:
            zero_point_init = initializers.get(node.input[2])
            if zero_point_init is None:
                continue
            zero_point_value = int(numpy_helper.to_array(zero_point_init).item())
        else:
            zero_point_value = 0

        scale_value = float(numpy_helper.to_array(scale_init).item())
        return scale_value, zero_point_value

    return None, None


In [4]:
def apply_fusion_passes(
    model_path,
    output_path=None,
    use_gemm_fusion=True,
    use_delete_pass=True,
    use_delete_first_last_pass=False,
    use_delete_first_pass=True,
):
    """
    Run the selected graph-rewriting passes over an ONNX model and save it.

    Passes run in this fixed order, each only if its flag is truthy:
    GemmQuantDequantFusionPass, DeleteQuantizePass,
    DeleteFirstLastQuantizeDequantizePass, DeleteFirstInputQDQPass.
    Each enabled pass is instantiated fresh and run on the loaded graph.

    Args:
        model_path: path of the ONNX model to read.
        output_path: where to write the rewritten model; ``None`` means
            overwrite ``model_path``.
        use_gemm_fusion: enable GemmQuantDequantFusionPass.
        use_delete_pass: enable DeleteQuantizePass.
        use_delete_first_last_pass: enable DeleteFirstLastQuantizeDequantizePass.
        use_delete_first_pass: enable DeleteFirstInputQDQPass.

    Returns:
        The path the rewritten model was saved to.
    """
    destination = model_path if output_path is None else output_path

    model = onnx.load(model_path)

    # (enabled?, pass class) listed in the order the passes must be applied.
    pipeline = (
        (use_gemm_fusion, GemmQuantDequantFusionPass),
        (use_delete_pass, DeleteQuantizePass),
        (use_delete_first_last_pass, DeleteFirstLastQuantizeDequantizePass),
        (use_delete_first_pass, DeleteFirstInputQDQPass),
    )
    for enabled, pass_cls in pipeline:
        if enabled:
            pass_cls().run(model.graph)

    onnx.save(model, destination)
    return destination

In [None]:
# End-to-end check: run the quantized model in ONNX Runtime and the exported C
# code on the GBA emulator with the same random input, then compare the
# dequantized outputs.
#
# NOTE(review): `launch_makefile`, `setup_gba_environment` and
# `run_gba_inference` are neither defined nor imported in this notebook (this
# cell looks adapted from a unittest) — import or define them before running.
# Likewise `rom_path` / `map_path` (formerly `self.rom_path` / `self.map_path`)
# must be set to the ROM and map file produced by the Makefile.

quantized_model = onnx.load(quantized_onnx_path)

output_scale, output_zero_point = get_last_qdq_scaling_factor(quantized_model.graph)
if output_scale is None:
    raise ValueError("Could not find output DQ node scale and zero point")
input_scale, input_zero_point = get_first_qdq_scaling_factor(quantized_model.graph)
if input_scale is None:
    raise ValueError("Could not find input Q node scale and zero point")

input_random = np.random.uniform(-1, 1, (1, 360)).astype(np.float32)

# Quantize the input like ONNX QuantizeLinear:
# saturate(round_half_to_even(x / scale) + zero_point).
# The clip is essential: round(x / scale) can land just outside [-128, 127]
# (e.g. 128), which a bare int8 cast would silently wrap to -128.
# NOTE(review): assumes the input is quantized to int8 — confirm.
int8_info = np.iinfo(np.int8)
input_gba = np.clip(
    np.round(input_random / input_scale) + input_zero_point,
    int8_info.min,
    int8_info.max,
).astype(np.int8)

ort_session = ort.InferenceSession(quantized_onnx_path)
ort_input = ort_session.get_inputs()[0]
# Feed the int8 tensor only when the graph input is declared int8 (an exact
# match, so uint8 inputs are not mistaken for int8); otherwise feed floats.
model_input = input_gba if ort_input.type == "tensor(int8)" else input_random
ort_outputs = ort_session.run(None, {ort_input.name: model_input})
onnx_output = ort_outputs[0]

apply_fusion_passes(
    quantized_onnx_path,
    fused_path,
    use_gemm_fusion=True,
    use_delete_pass=True,
    use_delete_first_pass=False,
    use_delete_first_last_pass=False,
)
exporter = ONNXExporter(fused_path)
exporter.export(output_dir="gba")
launch_makefile()
gba, parser, addr_write, addr_read, output_addr, input_addr = setup_gba_environment(
    rom_path, map_path
)
gba_output = run_gba_inference(gba, addr_write, input_addr, output_addr, input_gba, 5)

# Dequantize the GBA int8 output like DequantizeLinear: (q - zero_point) * scale.
onnx_float = onnx_output.astype(np.float32)
gba_float = (gba_output.astype(np.float32) - output_zero_point) * output_scale

print(f"ONNX output (raw): {onnx_output}")
print(f"GBA output (int8): {gba_output}")
print(f"ONNX output (float): {onnx_float}")
print(f"GBA output (float): {gba_float}")

# NOTE(review): atol=6e4 is likely far larger than one output quantization
# step (output_scale), so this comparison accepts almost any result; consider
# a tolerance expressed in quantization steps, e.g. a few * output_scale.
rtol, atol = 1e-1, 6e4
float_match = np.allclose(onnx_float, gba_float, rtol=rtol, atol=atol)
if float_match:
    print("Dequantized outputs match within tolerance!")
else:
    print("Dequantized outputs don't match")
    diff = np.abs(onnx_float - gba_float)
    print(f"Max difference: {np.max(diff)}")
    print(f"Average difference: {np.mean(diff)}")
    print(f"Differences: {diff}")
    # There is no unittest `self` in a notebook; fail the cell explicitly.
    raise AssertionError("Outputs don't match even after dequantization")

Deleting QuantizeLinear -> DequantizeLinear pair: input_QuantizeLinear -> input_DequantizeLinear
Deleting QuantizeLinear -> DequantizeLinear pair: /net/net.1/Relu_output_0_QuantizeLinear -> /net/net.1/Relu_output_0_DequantizeLinear
Remapped input: input_DequantizeLinear_Output -> input in node /net/net.0/Gemm_fused
Remapped input: /net/net.1/Relu_output_0_DequantizeLinear_Output -> /net/net.1/Relu_output_0 in node /net/net.2/Gemm_fused


NameError: name 'os' is not defined