In [1]:
from stable_baselines3 import PPO
import torch

In [2]:
class OnnxableActionPolicy(torch.nn.Module):
    """Bundle a policy's feature extractor and action head into one module.

    Args:
        extractor: Module that, called on an observation, returns a pair
            ``(action_latent, value_latent)``.
        action_net: Module applied to the action latent.
        value_net: Value head. It is kept as an attribute but is not
            evaluated by ``forward``.
    """

    def __init__(self, extractor, action_net, value_net):
        super().__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, observation):
        """Return ``action_net`` applied to the extractor's action latent.

        The observation is handed to the extractor unchanged, so any
        preprocessing or normalisation it needs (see
        ``common.preprocessing.preprocess_obs``) has to be applied by the
        caller before this method is invoked.
        """
        # The extractor yields two latents; only the first one feeds the
        # returned output, the value latent is discarded.
        policy_latent, _value_latent = self.extractor(observation)
        return self.action_net(policy_latent)

In [3]:
# Restore the trained agent from its saved archive.
# (Alternative for a freshly constructed agent: model = PPO("MlpPolicy", "Pendulum-v0"))
model = PPO.load("all_ppo_acc_3000000_steps.zip")

# Move the policy network to the CPU before it is wrapped.
model.policy.to("cpu")

# Pass the policy's three sub-networks to the export wrapper.
onnxable_model = OnnxableActionPolicy(
    extractor=model.policy.mlp_extractor,
    action_net=model.policy.action_net,
    value_net=model.policy.value_net,
)

In [4]:
# Random example input handed to the exporter: a batch of one observation
# with two features.
dummy_input = torch.randn((1, 2))

# Write the wrapped policy to an ONNX file using opset 9.
torch.onnx.export(
    onnxable_model,
    dummy_input,
    "all_ppo_acc_3000000_steps.onnx",
    opset_version=9,
)

In [4]:
##### Load the exported model and test it with ONNX Runtime

import onnx
import onnxruntime as ort
import numpy as np

In [5]:
# Read the exported graph back from disk and run the ONNX checker on it.
onnx_model = onnx.load("all_ppo_acc_3000000_steps.onnx")
onnx.checker.check_model(onnx_model)

# All-zero float32 observation with shape (1, 2).
observation = np.zeros((1, 2), dtype=np.float32)

# ONNX Runtime inference session over the same exported file.
ort_sess = ort.InferenceSession("all_ppo_acc_3000000_steps.onnx")

In [6]:
# Run one sample through the session. The sample is passed as an explicit
# float32 array of shape (1, 2) -- the dtype and shape the model was exported
# with (torch.randn(1, 2)) -- instead of a nested Python list of floats, so
# the tensor type does not depend on implicit conversion.
print(ort_sess.run(None, {'input.1': np.array([[51.24999277597358, 6.0]], dtype=np.float32)}))

[array([[-6.025987]], dtype=float32)]


In [1]:
import gym

In [2]:
# Create the environment registered under the id "acc-variant-v0".
# NOTE(review): no import that registers this id is visible in this notebook;
# presumably a custom package has to be imported first -- confirm.
env = gym.make("acc-variant-v0")



In [7]:
# Roll the exported policy out in the environment for a fixed number of steps,
# rendering every step.
obs = env.reset()
for i in range(0, 3000):
    # Feed the two observation components as a float32 batch of shape (1, 2),
    # matching the input the network was exported with. The first output's
    # first batch row is the action for this observation.
    action = ort_sess.run(
        None, {'input.1': np.array([[obs[0], obs[1]]], dtype=np.float32)}
    )[0][0]
    obs, rewards, dones, info = env.step(action)
    env.render()
    if dones:
        # Continue from the new episode's initial observation. Without the
        # assignment the next action would be computed from the last
        # observation of the episode that just ended.
        obs = env.reset()

In [12]:
# Remove the name `gym` from the notebook namespace. Objects created earlier
# (such as `env`) are bound to their own names and are not deleted by this.
del gym

# Rename the output node so that its name is not purely numeric

In [18]:
# Load a fresh in-memory copy of the exported model; the following cells edit
# names inside its graph.
onnx_model = onnx.load('all_ppo_acc_3000000_steps.onnx')

In [19]:
# Rename the graph's first output to "out1". Every node output (and node
# input) that still carries the previous name is renamed in the same step, so
# the graph output keeps referring to a tensor that a node actually produces
# and the rename does not depend on the position of that node in the graph.
previous_output_name = onnx_model.graph.output[0].name
for node in onnx_model.graph.node:
    for slot, tensor_name in enumerate(node.output):
        if tensor_name == previous_output_name:
            node.output[slot] = "out1"
    for slot, tensor_name in enumerate(node.input):
        if tensor_name == previous_output_name:
            node.input[slot] = "out1"
onnx_model.graph.output[0].name = "out1"

In [20]:
# Give the first output of node 4 the name "out1" so that it matches the
# renamed graph output.
# NOTE(review): the index 4 is hard-coded; presumably it is the node that
# produces the graph output for this particular network -- confirm before
# reusing this cell with a different architecture.
onnx_model.graph.node[4].output[0]="out1"

In [24]:
# Display the name of the graph's first input; this is the key used in the
# feed dictionary when running the inference session.
onnx_model.graph.input[0].name

'input.1'

In [21]:
# Write the edited graph to a separate file, so the originally exported
# .onnx file is not overwritten.
onnx.save(onnx_model, 'all_ppo_acc_3000000_steps-renamed.onnx')