In [3]:
from stable_baselines3 import PPO
import torch
import numpy as np

In [4]:
class OnnxableActionPolicy(torch.nn.Module):
    """Export-friendly wrapper around a Stable-Baselines3 policy's action path.

    ``forward`` runs ``extractor`` -> ``action_net`` -> ``normalizer`` and returns
    only the normalized action; ``value_net`` is stored but not evaluated.

    ``normalizer`` is a hand-weighted Linear/ReLU stack. With ``A = 100`` and
    ``s = x + 0.05`` (``x`` being the raw action) it computes the saturating ramp::

        y = A + 1e-3        for s > 1
        y = (A + 1e-3) * s  for 0 <= s <= 1
        y = (A - 3e-3) * s  for -1 <= s < 0
        y = -(A - 3e-3)     for s < -1

    i.e. roughly ``100 * clip(x + 0.05, -1, 1)``. The +0.05 shift and the small
    asymmetric offsets are kept exactly as originally chosen.

    Args:
        extractor: callable returning ``(action_latent, value_latent)`` for an
            observation batch (SB3's ``mlp_extractor`` is passed here).
        action_net: module mapping the action latent to the raw action. The
            normalizer's first layer takes a single input feature, so this must
            produce a one-dimensional action.
        value_net: kept for reference; not used by ``forward``.
    """

    def __init__(self, extractor, action_net, value_net):
        super().__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

        A = 100
        # s = x + 0.05
        shift = self._fixed_linear([[1]], [0.05])
        # Hidden units before ReLU: s, -s, s - 1, -s - 1.
        split = self._fixed_linear([[1], [-1], [1], [-1]], [0, 0, -1, -1])
        # Recombine the 4 ReLU units into the clamped ramp. in_features is derived
        # from the weight (4), so the layer's declared shape matches its parameters.
        combine = self._fixed_linear(
            [[A + 1e-3, -A + 3e-3, -A - 1e-3, A - 3e-3]], [0]
        )
        self.normalizer = torch.nn.Sequential(
            shift,
            split,
            torch.nn.ReLU(),
            combine,
        )

    @staticmethod
    def _fixed_linear(weight, bias):
        """Build a Linear layer sized from ``weight`` (out x in) and set to the given values."""
        weight_t = torch.as_tensor(weight)
        bias_t = torch.as_tensor(bias)
        layer = torch.nn.Linear(weight_t.shape[1], weight_t.shape[0])
        # copy_ checks shapes (unlike assigning .data) and casts to the layer's dtype.
        with torch.no_grad():
            layer.weight.copy_(weight_t)
            layer.bias.copy_(bias_t)
        return layer

    def forward(self, observation):
        """Return the normalized action for ``observation``.

        NOTE: You may have to process (normalize) observation in the correct
        way before using this. See `common.preprocessing.preprocess_obs`.
        The value head is intentionally not evaluated.
        """
        action_hidden, _value_hidden = self.extractor(observation)
        action = self.action_net(action_hidden)
        return self.normalizer(action)

In [5]:
# Example: model = PPO("MlpPolicy", "Pendulum-v0")
model = PPO.load("model_backup/acc-2000000-64-64-64-64-100000-200000-0.9")
# The dummy export input is a CPU tensor, so keep the trained policy on the CPU too.
model.policy.to("cpu")
onnxable_model = OnnxableActionPolicy(
    extractor=model.policy.mlp_extractor,
    action_net=model.policy.action_net,
    value_net=model.policy.value_net,
)

In [6]:
# Sanity check: a large positive action saturates just above +100 (100 + 1e-3).
normalized_hi = onnxable_model.normalizer(torch.Tensor([2]))
assert torch.allclose(normalized_hi, torch.tensor([100.001]), atol=1e-3), normalized_hi
normalized_hi

tensor([100.0010], grad_fn=<AddBackward0>)

In [7]:
# Sanity check: a large negative action saturates just above -100 (-(100 - 3e-3)).
normalized_lo = onnxable_model.normalizer(torch.Tensor([-2]))
assert torch.allclose(normalized_lo, torch.tensor([-99.997]), atol=1e-3), normalized_lo
normalized_lo

tensor([-99.9970], grad_fn=<AddBackward0>)

In [8]:
# Tracing only needs an input of the right shape (batch=1, obs_dim); a fixed zero
# tensor keeps the export deterministic and leaves the global torch RNG untouched.
obs_dim = model.observation_space.shape[0]
dummy_input = torch.zeros(1, obs_dim)
torch.onnx.export(
    onnxable_model,
    dummy_input,
    "acc-2000000-64-64-64-64-100000-200000-0.9.onnx",
    opset_version=9,
    # Pin tensor names so they do not depend on the torch version: the inference
    # cells feed observations under 'input.1' and the output is renamed to "out1".
    input_names=["input.1"],
    output_names=["out1"],
)

In [9]:
##### Load and test with onnx

import onnx
import onnxruntime as ort
import numpy as np

In [10]:
# Load the exported graph back from disk and validate it with the ONNX checker.
onnx_model = onnx.load("acc-2000000-64-64-64-64-100000-200000-0.9.onnx")
onnx.checker.check_model(onnx_model)

In [11]:
# The model was already loaded and validated in the previous cell; only open a session here.
# Zero observation of shape (batch=1, obs_dim=2) in float32, the dtype the export was traced with.
observation = np.zeros((1, 2), dtype=np.float32)
# Recent onnxruntime releases require `providers` to be given explicitly when more than
# one execution provider is available (e.g. GPU builds); the CPU provider always exists.
ort_sess = ort.InferenceSession(
    "acc-2000000-64-64-64-64-100000-200000-0.9.onnx",
    providers=["CPUExecutionProvider"],
)

In [12]:
# Smoke test on one hand-picked observation; the raw output lies roughly in [-100, 100].
# Pass an explicit float32 array rather than relying on onnxruntime converting a list.
sample_obs = np.array([[0.08, -4]], dtype=np.float32)
print(ort_sess.run(None, {'input.1': sample_obs}))

[array([[100.001]], dtype=float32)]


In [13]:
import gym
import acc

In [14]:
# "acc-variant-v1" is presumably registered with gym by `import acc` above — TODO confirm.
env = gym.make("acc-variant-v1")



In [15]:
# Seed every RNG the rollout may touch so runs are reproducible.
SEED = 2022
env.seed(SEED)
np.random.seed(SEED)  # NumPy's global RNG, in case the env or its dependencies draw from it
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f90d40769b0>

In [18]:
# Roll out the exported ONNX policy and report episodes that end in an unsafe state.
N_EPISODES = 10000
MAX_STEPS = 500
ACTION_SCALE = 100.0  # undoes the x100 output scaling of the exported normalizer
RENDER = True         # rendering every step can be slow; set False for long searches

for episode in range(N_EPISODES):
    obs = env.reset()
    # Copy: `obs` is handed to the env as its internal state just below, so if the env
    # updated that array in place the recorded start state would silently change.
    starting_state = np.copy(obs)
    # Re-assigns the reset observation as the env state; presumably a hook for forcing
    # a specific start state (e.g. obs = [0.08, -4]) — TODO confirm it is still needed.
    env.unwrapped.state = obs
    for step in range(MAX_STEPS):
        net_input = np.asarray(obs, dtype=np.float32).reshape(1, -1)
        action = ort_sess.run(None, {'input.1': net_input})[0][0] / ACTION_SCALE
        action = np.clip(action, -1.0, 1.0)
        obs, reward, done, _info = env.step(action)
        if RENDER:
            env.render()
        if done:
            print(obs[0])
            print(".", end="")
            # NOTE(review): only checked when the episode ends; presumably obs[0] is the
            # gap to the lead vehicle and the env terminates on collision — TODO confirm.
            if obs[0] <= 0:
                print("\nEncountered unsafe behaviour!")
                print("Starting state: ", starting_state)
            break

28.776630715521794
.28.775768132958248
.28.72983095656857
.

KeyboardInterrupt: 

In [19]:
# Release the environment's resources (e.g. a window opened by env.render(), if any).
env.close()

# Rename the graph output (and the node that produces it) so its name is not purely numeric

In [20]:
# Reload the exported file so its tensor names can be patched and saved back.
onnx_model = onnx.load('acc-2000000-64-64-64-64-100000-200000-0.9.onnx')

In [21]:
# Give the graph output a non-numeric name; the node producing it is patched in the next cell.
onnx_model.graph.output[0].name = "out1"

In [22]:
# Point the final node's first output at the renamed graph output so the graph stays wired.
# NOTE(review): assumes the last node in the node list is the one producing the graph output.
onnx_model.graph.node[-1].output[0] = "out1"

In [23]:
# Display the graph's input name; the inference cells feed observations under this key.
onnx_model.graph.input[0].name

'input.1'

In [24]:
# Validate the patched graph before overwriting the exported file in place.
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, 'acc-2000000-64-64-64-64-100000-200000-0.9.onnx')