In [None]:
import torch as th
import numpy as np
import onnx
import onnxruntime as ort

from sb3_contrib import MaskablePPO

from lib.onnxable import OnnxableMaskableACPolicy
from lib.briscola.game import BriscolaGame
from lib.briscola_env.embedding import game_embedding

In [None]:
# Load the selected MaskablePPO checkpoint and export its policy to ONNX.
model_name = "models/briscola_4p_1M_20250501-181557"
model_source_file = f"{model_name}.zip"
model_dest_file = f"{model_name}-probdist.onnx"

model = MaskablePPO.load(model_source_file, device="cpu")
# Exportable wrapper (from lib.onnxable) around the loaded policy; this module is
# what gets traced by th.onnx.export below.
onnx_policy = OnnxableMaskableACPolicy(model.policy)

# Tracing input: a single (batch size 1) random observation shaped like the
# policy's observation space. Only its shape/dtype matter for the export.
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
    onnx_policy,
    dummy_input,
    model_dest_file,
    opset_version=17,
    input_names=["input"],
    output_names=["dist", "values"],
)

In [3]:
# Fresh 4-player game; its state is used below to build a probe observation.
game = BriscolaGame(
    players=4,
)

In [4]:
# Load the exported ONNX graph and validate it structurally.
onnx_model = onnx.load(model_dest_file)
onnx.checker.check_model(onnx_model)

# Observation of player 0, wrapped in a list so it carries a leading batch axis.
observation = np.asarray([game_embedding(game, 0)], dtype=np.float32)

# The former `np.reshape(observation, (1, 86))` discarded its result (a no-op)
# and hard-coded the width; check the shape against the policy instead.
expected_shape = (1, *model.observation_space.shape)
assert observation.shape == expected_shape, (
    f"observation shape {observation.shape} does not match {expected_shape}"
)
print(observation)

ort_sess = ort.InferenceSession(model_dest_file)

print("Inputs:")
for i in ort_sess.get_inputs():
    print("\t", i)

print("Outputs")
for out in ort_sess.get_outputs():
    print("\t", out)

# Run the ONNX graph. Outputs follow the export's output_names ("dist", "values");
# `actions` therefore holds the "dist" output, not sampled action indices.
actions, values = ort_sess.run(None, {"input": observation})
print(actions, values)

# Check that the ONNX predictions match the PyTorch policy.
with th.no_grad():
    obs_tensor = th.as_tensor(observation)
    print(model.policy(obs_tensor, deterministic=True))
    # NOTE(review): assumes the ONNX "dist" output is the per-action probability
    # vector (the file is named "-probdist") — confirm against OnnxableMaskableACPolicy.
    torch_dist = model.policy.get_distribution(obs_tensor).distribution.probs.numpy()
    torch_values = model.policy.predict_values(obs_tensor).numpy()

np.testing.assert_allclose(actions, torch_dist, rtol=1e-4, atol=1e-5)
np.testing.assert_allclose(values, torch_values, rtol=1e-4, atol=1e-5)

[[ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  0.
   0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0. 23.  2.  1.  1.  1.  1.  1.  1.  1.  1.
   1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  0.  1.  1.
   1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.]]
Inputs:
	 NodeArg(name='input', type='tensor(float)', shape=[1, 86])
Outputs
	 NodeArg(name='actions', type='tensor(float)', shape=[1, 40])
	 NodeArg(name='values', type='tensor(float)', shape=[1, 1])
[[1.03327498e-01 1.14174532e-02 1.54052451e-01 1.30279353e-04
  3.21751507e-03 4.33949870e-04 2.25688703e-03 4.66828700e-03
  9.19541810e-03 2.50689834e-02 7.70970881e-02 6.42255647e-04
  1.05952196e-01 1.47218263e-04 5.64788410e-04 3.09772906e-04
  2.34098849e-03 3.41040851e-03 5.85791562e-03 2.80531533e-02
  1.30125031e-01 3.34504788e-04 7.50642866e-02 3.25879053e-04
  6.17327925e-04 9.32262265e-05 8.56808096e-04 2.22337572e-03
  5.02975