In [1]:
from stable_baselines3 import PPO
import torch
import numpy as np

In [2]:
class OnnxableActionPolicy(torch.nn.Module):
    """Policy wrapper that returns a single, normalized action tensor.

    ``forward`` runs ``extractor`` -> ``action_net`` -> ``normalizer`` and
    returns only the normalized action.

    The normalizer is built exclusively from ``Linear`` and ``ReLU`` layers.
    For a raw action ``a`` and ``x = a + 0.05`` it evaluates::

        (A + 1e-3) * relu(x)     - (A - 3e-3) * relu(-x)
      - (A + 1e-3) * relu(x - 1) + (A - 3e-3) * relu(-x - 1)

    with ``A = 100``, i.e. ``x`` clipped to ``[-1, 1]`` and scaled by about
    ``A`` (100.001 on the positive side, 99.997 on the negative side).

    Args:
        extractor: callable returning ``(action_hidden, value_hidden)``.
        action_net: maps ``action_hidden`` to the raw action. Its last
            dimension must be 1, because the normalizer starts with a
            1-input linear layer.
        value_net: stored on the module but not used by ``forward``.
    """

    def __init__(self, extractor, action_net, value_net):
        super().__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

        scale = 100
        # Shift the raw action by +0.05.
        shift = self._fixed_linear([[1]], [0.05])
        # Pre-activations x, -x, x-1, -x-1; after ReLU these are the four
        # hinge terms of
        # 100* ((max(0,x) - max(0,-x)) - max(0,x-1) + max(0,-x-1))
        hinges = self._fixed_linear([[1], [-1], [1], [-1]], [0, 0, -1, -1])
        # Recombine the four hinge terms into one value. The layer shape is
        # derived from the 1x4 weight, so in_features (4) always matches the
        # number of hinge terms.
        combine = self._fixed_linear(
            [[scale + 1e-3, -scale + 3e-3, -scale - 1e-3, scale - 3e-3]],
            [0],
        )
        self.normalizer = torch.nn.Sequential(
            shift,
            hinges,
            torch.nn.ReLU(),
            combine)

    @staticmethod
    def _fixed_linear(weight, bias):
        """Return a ``Linear`` layer sized from ``weight`` and loaded with ``weight``/``bias``."""
        weight = torch.tensor(weight, dtype=torch.float32)
        bias = torch.tensor(bias, dtype=torch.float32)
        out_features, in_features = weight.shape
        layer = torch.nn.Linear(in_features, out_features)
        with torch.no_grad():
            layer.weight.copy_(weight)
            layer.bias.copy_(bias)
        return layer

    def forward(self, observation):
        """Return the normalized action for ``observation``.

        The value branch of the extractor is computed but discarded, so the
        module has exactly one output.
        """
        # NOTE: You may have to process (normalize) observation in the correct
        #       way before using this. See `common.preprocessing.preprocess_obs`
        action_hidden, value_hidden = self.extractor(observation)
        action = self.action_net(action_hidden)
        return self.normalizer(action) #, self.value_net(value_hidden)

In [3]:
# Restore the trained PPO agent from its checkpoint.
# (A freshly built agent works too, e.g. model = PPO("MlpPolicy", "Pendulum-v0").)
model = PPO.load("model_backup/acc-2000000-64-64-64-64-100000-0.1")
# Move the policy to the CPU before wrapping it.
model.policy.to("cpu")
# Wrap the policy's sub-networks in the exportable module.
onnxable_model = OnnxableActionPolicy(
    extractor=model.policy.mlp_extractor,
    action_net=model.policy.action_net,
    value_net=model.policy.value_net,
)

In [4]:
# Sanity check: push a raw action of 2 through the normalizer alone.
# The recorded output is ~100.001, i.e. the upper saturation value.
onnxable_model.normalizer(torch.Tensor([2]))

tensor([100.0010], grad_fn=<AddBackward0>)

In [5]:
# Sanity check: push a raw action of -2 through the normalizer alone.
# The recorded output is ~-99.997, i.e. the lower saturation value.
onnxable_model.normalizer(torch.Tensor([-2]))

tensor([-99.9970], grad_fn=<AddBackward0>)

In [6]:
# One random observation of shape (1, 2); passed to the exporter as the example input.
dummy_input = torch.randn(1, 2)
torch.onnx.export(
    onnxable_model,
    dummy_input,
    "ppo_acc_bigger_retrain200000-100000-0.1.onnx",
    opset_version=9,
    # Pin the tensor names instead of relying on exporter-generated ones:
    # inference code feeds the input by the key 'input.1', and the output
    # should carry a name that is not purely numeric.
    input_names=["input.1"],
    output_names=["out1"],
)

In [7]:
##### Load and test with onnx

import onnx
import onnxruntime as ort
import numpy as np

In [8]:
# Reload the exported file and run ONNX's model checker on it.
onnx_model = onnx.load("ppo_acc_bigger_retrain200000-100000-0.1.onnx")
onnx.checker.check_model(onnx_model)

In [9]:
# All-zero float32 observation with the model's (1, 2) input shape.
observation = np.zeros((1, 2), dtype=np.float32)

# Open an inference session on the exported file. The provider is named
# explicitly so the session is created the same way on onnxruntime builds
# that ship several execution providers.
ort_sess = ort.InferenceSession(
    "ppo_acc_bigger_retrain200000-100000-0.1.onnx",
    providers=["CPUExecutionProvider"],
)

In [10]:
# Query the exported policy on one hand-picked observation. The input is
# passed as an explicit float32 array rather than a nested Python list.
print(ort_sess.run(
    None,
    {'input.1': np.array([[0.08, -4]], dtype=np.float32)},
))

[array([[59.13693]], dtype=float32)]


In [31]:
import gym
import acc

In [32]:
# Create the evaluation environment.
# NOTE(review): the "acc-variant-v2" id is presumably registered by `import acc` — confirm.
env = gym.make("acc-variant-v2")

In [13]:
# Seed the environment and torch with the same value before the rollouts.
# NOTE(review): numpy's global RNG is not seeded here — confirm nothing relies on it.
env.seed(2022)
torch.manual_seed(2022)

<torch._C.Generator at 0x7fac1c9ed9b0>

In [38]:
# Roll the exported policy out from many reset states and report every
# episode that terminates with obs[0] <= 0.
N_EPISODES = 10000
MAX_STEPS = 500

for episode in range(N_EPISODES):
    obs = env.reset()
    # Keep an independent copy for reporting: `obs` itself is handed to the
    # environment as its state below, so a plain reference could change if
    # the environment updates that object in place.
    starting_state = np.array(obs, copy=True)
    # Debug override for replaying one specific start state:
    #obs = [50,-99.0]
    env.unwrapped.state = obs
    for step in range(MAX_STEPS):
        # Feed the observation as an explicit float32 array of shape (1, 2).
        model_input = np.array([[obs[0], obs[1]]], dtype=np.float32)
        action = ort_sess.run(None, {'input.1': model_input})[0][0]
        action = np.clip(action, -100., 100.)
        obs, rewards, dones, info = env.step(action)
        env.render()
        if dones:
            # One dot per finished episode as a progress indicator.
            print(".", end="")
            if obs[0] <= 0:
                print("\nEncountered unsafe behaviour!")
                print("Starting state: ", starting_state)
            break

KeyboardInterrupt: 

In [16]:
# Close the environment once the rollouts are done.
# NOTE(review): the NameError recorded for this cell means it was last run on a
# kernel where `env` did not exist; run it only after the env-creation cell.
env.close()

NameError: name 'env' is not defined

# Rename the output nodes so that their names are not purely numeric

In [11]:
# Reload the exported model so its graph can be edited.
onnx_model = onnx.load('ppo_acc_bigger_retrain200000-100000-0.1.onnx')

In [12]:
# Give the graph's first declared output the name "out1".
onnx_model.graph.output[0].name = "out1"

In [13]:
# Point the first output of the graph's final node at the same "out1" name.
onnx_model.graph.node[-1].output[0] = "out1"

In [14]:
# Display the name of the graph's first input; the inference calls in this
# notebook feed the observation under this key.
onnx_model.graph.input[0].name

'input.1'

In [15]:
# Write the edited model back to the same path, overwriting the exported file.
onnx.save(onnx_model, 'ppo_acc_bigger_retrain200000-100000-0.1.onnx')