In [1]:
import torch as th
import numpy as np
import onnx
import onnxruntime as ort

from sb3_contrib import MaskablePPO

from lib.briscola_env.embedding import EMBEDDING_SHAPE
from lib.onnxable import OnnxableSB3Policy

In [2]:
# Select the checkpoint to export; the .onnx file is written next to the .zip.
model_name = "models/briscola_4p_1M_20250501-181557"
model_source_file = model_name + ".zip"
model_dest_file = model_name + ".onnx"

# Load the trained agent onto the CPU for export.
model = MaskablePPO.load(model_source_file, device="cpu")

# Wrap the SB3 policy in the project's ONNX-export adapter (lib.onnxable).
onnx_policy = OnnxableSB3Policy(model.policy)

# Build a random batch-of-one observation to trace the policy with. Its values
# presumably do not matter as long as the forward pass has no data-dependent
# control flow.
observation_size = model.observation_space.shape
dummy_input = th.randn((1,) + tuple(observation_size))

# The graph input is named "input"; the verification step feeds it by that name.
th.onnx.export(
    onnx_policy,
    dummy_input,
    model_dest_file,
    input_names=["input"],
    opset_version=17,
)

In [3]:
# Load the exported ONNX model, validate its structure, and check that it
# reproduces the PyTorch policy's outputs on a sample observation.
onnx_model = onnx.load(model_dest_file)
onnx.checker.check_model(onnx_model)

# All-zeros observation, batch of one, float32 like the traced dummy input.
observation = np.zeros((1, *observation_size), dtype=np.float32)
# Pin the execution provider explicitly: since onnxruntime 1.9, builds with
# several providers (e.g. GPU builds) raise if `providers` is omitted.
ort_sess = ort.InferenceSession(model_dest_file, providers=["CPUExecutionProvider"])

# Run the ONNX graph and the original PyTorch policy on the same observation.
actions, values, log_prob = ort_sess.run(None, {"input": observation})
print(actions, values, log_prob)
with th.no_grad():
    th_actions, th_values, th_log_prob = model.policy(
        th.as_tensor(observation), deterministic=True
    )
print((th_actions, th_values, th_log_prob))

# Check that the predictions are the same: actions must match exactly, while
# floating-point outputs may differ slightly after export.
np.testing.assert_array_equal(actions, th_actions.numpy())
np.testing.assert_allclose(values, th_values.numpy(), rtol=1e-03, atol=1e-05)
np.testing.assert_allclose(log_prob, th_log_prob.numpy(), rtol=1e-03, atol=1e-05)

[2] [[-12.487151]] [-3.3847802]
(tensor([2]), tensor([[-12.4871]]), tensor([-3.3848]))
