In [1]:
from stable_baselines3 import PPO
import torch
import numpy as np

In [2]:
class OnnxableActionPolicy(torch.nn.Module):
    def __init__(self, extractor, action_net, value_net):
        super(OnnxableActionPolicy, self).__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net
        normalize_linear1 = torch.nn.Linear(1, 4)
        # 100* ((max(0,x) - max(0,-x)) - max(0,x-1) + max(0,-x-1))
        normalize_linear1.weight.data = torch.Tensor([[1],[-1],[1],[-1]])
        normalize_linear1.bias.data=torch.Tensor([0,0,-1,-1])
        #print(normalize_linear1.weight)
        #print(normalize_linear1.bias)
        A = 100-1e-6
        normalize_linear2 = torch.nn.Linear(3,1)
        normalize_linear2.weight.data = torch.Tensor([[A,-A,-A,A]])
        normalize_linear2.bias.data=torch.Tensor([0])
        self.normalizer = torch.nn.Sequential(
            normalize_linear1,
            torch.nn.ReLU(),
            normalize_linear2)

    def forward(self, observation):
        # NOTE: You may have to process (normalize) observation in the correct
        #       way before using this. See `common.preprocessing.preprocess_obs`
        action_hidden, value_hidden = self.extractor(observation)
        action = self.action_net(action_hidden)
        return self.normalizer(action) #, self.value_net(value_hidden)

In [4]:
# Example: model = PPO("MlpPolicy", "Pendulum-v0")
model = PPO.load("ppo_acc_bigger_200000_steps.zip")
model.policy.to("cpu")
onnxable_model = OnnxableActionPolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net)

In [5]:
onnxable_model.normalizer(torch.Tensor([2]))

tensor([100.], grad_fn=<AddBackward0>)

In [6]:
dummy_input = torch.randn(1, 2)
torch.onnx.export(onnxable_model, dummy_input, "ppo_acc_bigger_200000_steps.onnx", opset_version=9)

In [8]:
##### Load and test with onnx

import onnx
import onnxruntime as ort
import numpy as np

In [10]:
onnx_model = onnx.load("ppo_acc_bigger_2000000_steps.onnx")
onnx.checker.check_model(onnx_model)

In [11]:
onnx_model = onnx.load("ppo_acc_bigger_2000000_steps.onnx")
onnx.checker.check_model(onnx_model)

observation = np.zeros((1, 2)).astype(np.float32)
ort_sess = ort.InferenceSession("ppo_acc_bigger_2000000_steps.onnx")

In [12]:
print(ort_sess.run(None, {'input.1': [[50,-99.0]]}))

[array([[100.000015]], dtype=float32)]


In [13]:
import gym

In [14]:
env = gym.make("acc-variant-v2")



In [15]:
obs = env.reset()
obs = [50,-99.0]
env.unwrapped.state = obs
for i in range(0,300):
    action = ort_sess.run(None, {'input.1': [[obs[0],obs[1]]]})[0][0]
    #action = np.clip(action,-100.,100.)
    obs, rewards, dones, info = env.step(action)
    env.render()
    if dones:
        print("DONE")
        print(obs)
        env.reset()

DONE
[100.4 141. ]
DONE
[118.63347988 192.79825841]
DONE
[ -0.21748271 -26.94745295]


AssertionError: [-100.00001] (of type <class 'numpy.ndarray'>) invalid

In [16]:
del env

# Rename output nodes to not purely numeric names!

In [17]:
onnx_model = onnx.load('ppo_acc_bigger_2000000_steps.onnx')

In [18]:
onnx_model.graph.output[0].name = "out1"

In [19]:
onnx_model.graph.node[len(onnx_model.graph.node)-1].output[0]="out1"

In [20]:
onnx_model.graph.input[0].name

'input.1'

In [21]:
onnx.save(onnx_model, 'ppo_acc_bigger_2000000_steps.onnx')