In [1]:
import copy
import os
import torch

from rsl_rl.runners import OnPolicyRunner
from rsl_rl.modules import ActorCriticRecurrent

In [2]:
class _OnnxPolicyExporter(torch.nn.Module):
    """Wrap the policy part of an actor-critic so it can be exported to an ONNX file.

    Only what inference needs is kept: a deep copy of the actor network, a deep
    copy of the actor's recurrent module (recurrent policies only) and a deep copy
    of the observation normalizer. The actor-critic passed in is never modified.

    Args:
        actor_critic: Object exposing ``actor`` and ``is_recurrent``; when
            ``is_recurrent`` is true it must also expose ``memory_a.rnn``.
        normalizer: Optional module applied to the observations before the
            policy. ``torch.nn.Identity`` is used when it is ``None``.
        verbose: Forwarded to ``torch.onnx.export`` as its ``verbose`` flag.
    """

    def __init__(self, actor_critic, normalizer=None, verbose=False):
        super().__init__()
        self.verbose = verbose
        self.actor = copy.deepcopy(actor_critic.actor)
        self.is_recurrent = actor_critic.is_recurrent
        if self.is_recurrent:
            self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
            self.rnn.cpu()
            # Recurrent policies take (obs, h, c), so route calls to the LSTM
            # forward instead of the single-argument default.
            self.forward = self.forward_lstm
        # Compare against None explicitly: truthiness would discard a normalizer
        # module that defines __len__ and happens to be empty.
        if normalizer is not None:
            self.normalizer = copy.deepcopy(normalizer)
        else:
            self.normalizer = torch.nn.Identity()

    def forward_lstm(self, x_in, h_in, c_in):
        """Run one recurrent policy step.

        Args:
            x_in: Observation batch.
            h_in: Hidden state handed to the recurrent module.
            c_in: Cell state handed to the recurrent module.

        Returns:
            Tuple ``(actions, h_out, c_out)``.
        """
        # Normalize here as well, so the recurrent path treats observations the
        # same way as the feed-forward path does.
        x_in = self.normalizer(x_in)
        # unsqueeze(0) adds a leading axis before the recurrent module and
        # squeeze(0) removes it again — presumably a sequence axis of length 1
        # (assumes the LSTM uses batch_first=False; TODO confirm).
        out, (h_out, c_out) = self.rnn(x_in.unsqueeze(0), (h_in, c_in))
        return self.actor(out.squeeze(0)), h_out, c_out

    def forward(self, x):
        """Return the actor output for the normalized observations ``x``."""
        return self.actor(self.normalizer(x))

    def export(self, path, filename):
        """Trace the policy on zero-valued example inputs and write it to ``path/filename``.

        The exported graph uses opset 11 and declares no dynamic axes. Inputs are
        named ``obs`` (plus ``h_in``/``c_in`` for recurrent policies) and outputs
        ``actions`` (plus ``h_out``/``c_out``).

        Args:
            path: Directory the file is written into.
            filename: Name of the ONNX file.
        """
        self.to("cpu")
        if self.is_recurrent:
            obs = torch.zeros(1, self.rnn.input_size)
            h_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)
            c_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)
            example_inputs = (obs, h_in, c_in)
            input_names = ["obs", "h_in", "c_in"]
            output_names = ["actions", "h_out", "c_out"]
        else:
            # Assumes the first actor layer exposes `in_features` (e.g. a Linear).
            obs = torch.zeros(1, self.actor[0].in_features)
            example_inputs = (obs,)
            input_names = ["obs"]
            output_names = ["actions"]

        # Dry run: a shape mismatch surfaces as a plain PyTorch error before the
        # ONNX tracer gets involved.
        with torch.no_grad():
            self(*example_inputs)

        torch.onnx.export(
            self,
            example_inputs,
            os.path.join(path, filename),
            export_params=True,
            opset_version=11,
            verbose=self.verbose,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes={},
        )

In [3]:
def export_policy_as_onnx(
    actor_critic: object, path: str, normalizer: object | None = None, filename="policy.onnx", verbose=False
):
    """Export a policy to an ONNX file.

    Args:
        actor_critic: The actor-critic torch module.
        path: Directory to save into. It is created (including parents) when it
            does not exist; an empty string means the current working directory.
        normalizer: The empirical normalizer module. If None, Identity is used.
        filename: The name of the exported ONNX file. Defaults to "policy.onnx".
        verbose: Passed on to the exporter's ``verbose`` setting. Defaults to False.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # os.makedirs("") raises FileNotFoundError, so only create a directory when
    # one is actually named. exist_ok=True already covers the "already there"
    # case, which makes a separate os.path.exists() check unnecessary.
    if path:
        os.makedirs(path, exist_ok=True)
    policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose)
    policy_exporter.export(path, filename)


In [4]:
# Rebuild the network with the layer sizes the checkpoint was saved with, so that
# load_state_dict() can match every parameter.
actor_critic = ActorCriticRecurrent(
    num_actor_obs=47,
    num_critic_obs=50,
    num_actions=12,
    actor_hidden_dims=[32],
    critic_hidden_dims=[32],
    rnn_hidden_size=64,
)
# The exporter moves everything to the CPU, so load the checkpoint there too;
# mapping to "cuda:0" makes this cell fail on machines without a GPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
ckpt = torch.load("./g1/model_1650.pt", map_location="cpu")
actor_critic.load_state_dict(ckpt["model_state_dict"])

Actor MLP: Sequential(
  (0): Linear(in_features=64, out_features=32, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=32, out_features=12, bias=True)
)
Critic MLP: Sequential(
  (0): Linear(in_features=64, out_features=32, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=32, out_features=1, bias=True)
)
Actor RNN: Memory(
  (rnn): LSTM(47, 64)
)
Critic RNN: Memory(
  (rnn): LSTM(50, 64)
)


  ckpt = torch.load("./g1/model_1650.pt", map_location="cuda:0")


<All keys matched successfully>

In [None]:
# Go through the public helper rather than the private exporter class, so this
# cell exercises the same code path other callers use.
export_policy_as_onnx(actor_critic, path=".", normalizer=None, filename="unitree_policy.onnx")



In [None]:
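# Scratch, left disabled: `env`, `train_cfg_dict` and `args` are not defined in
# the cells shown in this notebook, so un-commenting the next line raises NameError.
# runner = OnPolicyRunner(env, train_cfg_dict, None, device=args.rl_device)
# actor_critic["model_state_dict"]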

NameError: name 'env' is not defined