# 1. Prepare Env

In [8]:
from gym.envs.registration import register

register(
    id='SampleEnv-v0',
    entry_point='sample_env:SampleEnv',
)

  logger.warn("Overriding environment {}".format(id))


# 2. Prepare sample model

In [9]:
from stable_baselines3 import PPO
import torch

In [10]:
# create onnx model parser format
class OnnxablePolicy(torch.nn.Module):
    def __init__(self, extractor, action_net, value_net):
        super().__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, observation):
        # NOTE: You may have to process (normalize) observation in the correct
        #       way before using this. See `common.preprocessing.preprocess_obs`
        action_hidden, value_hidden = self.extractor(observation)
        return self.action_net(action_hidden), self.value_net(value_hidden)

In [11]:
# initialize sample RL model
model = PPO("MlpPolicy", "SampleEnv-v0", device="cpu")
print('observation_space: ', model.observation_space)
print('action_space: ', model.action_space)

observation_space:  Box([0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], [100. 100. 100. 100. 100. 100. 100. 100. 100. 100.], (10,), float32)
action_space:  MultiBinary(5)


  logger.warn(


In [12]:
print('observation_space sample: ', model.observation_space.sample())
print('action_space sample: ', model.action_space.sample())

observation_space sample:  [76.74602  65.90086  56.5687   25.135822 29.373598  9.519502 24.78885
 97.69087  57.811977 53.804802]
action_space sample:  [0 0 0 0 1]


In [14]:
# parse RL model to ONNX format
onnxable_model = OnnxablePolicy(
    model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
)

In [59]:
# define model input schema
observation_size = model.observation_space.shape
dummy_input = torch.randint(0, 100, observation_size, dtype=torch.float32)
print('observation_size: ', observation_size)
print('dummy_input: ', dummy_input)

observation_size:  (10,)
dummy_input:  tensor([12., 59., 83., 83., 26., 22., 36.,  2., 90., 36.])


In [82]:
# create ONNX model
torch.onnx.export(
    onnxable_model,
    dummy_input,
    "sample_ppo_model.onnx",
    opset_version=12,
    input_names=["input"],
)

verbose: False, log level: Level.ERROR



# 3. Run ONNX model

In [83]:
import onnx

# load ONNX model
onnx_path = "sample_ppo_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)  # if no exception, model is valid

In [87]:
import onnxruntime
import numpy as np

# sample model input
observation_size = (10,)  # model.observation_space.shape
observation = np.random.randint(0, 100, size=observation_size).astype(np.float32)
print('sample input: ', observation)

ort_sess = onnxruntime.InferenceSession(onnx_path)
actions, _ = ort_sess.run(None, {"input": observation})
print('predicted actions: ', actions)

sample input:  [92.  1. 26. 56. 21. 35. 24. 67. 38. 99.]
predicted actions:  [-0.00517289 -0.00588899  0.00750596  0.01193175  0.01188159]
