In [10]:
import torch as th

class OnnxablePolicy(th.nn.Module):
    """Export-friendly wrapper exposing a policy as one tensor-in / tensors-out module.

    The wrapper feeds the output of ``extractor`` directly into both heads; it
    applies nothing in between. If the wrapped policy has an intermediate
    network between its extractor and its heads, the caller must fold that
    network into ``extractor`` or into the heads before constructing this class.

    Args:
        extractor: Module mapping an observation batch to one latent tensor.
        action_net: Module mapping that latent tensor to the action output.
        value_net: Module mapping that latent tensor to the value output.
    """

    def __init__(self, extractor, action_net, value_net):
        super().__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, observation):
        """Return ``(action_output, value_output)`` for ``observation``."""
        # NOTE: You may have to process (normalize) observation in the correct
        #       way before using this. See `common.preprocessing.preprocess_obs`
        # Both heads consume the same features, so run the extractor once
        # instead of once per head; otherwise a traced/exported graph would
        # contain the whole extractor twice.
        hidden = self.extractor(observation)
        return self.action_net(hidden), self.value_net(hidden)



In [11]:
import time
import torch
from stable_baselines3 import PPO

# Load the trained agent on CPU and export its policy network to TorchScript.
PolicyModel = PPO
model_path = "models/ppo_8x8_cnn/model_10w_1"
model = PolicyModel.load(model_path, device="cpu")
policy = model.policy

# The policy owns an `mlp_extractor` whose `policy_net` / `value_net` branches
# sit between the features extractor and the output heads. Feeding the raw
# features straight into `action_net` / `value_net` still runs (all widths are
# 256) but silently skips those layers, so the export would not match the
# trained policy. Chain each branch in front of its head instead.
# NOTE(review): assumes actor and critic share `features_extractor`
# (the stable-baselines3 default) -- confirm for this checkpoint.
actor_head = th.nn.Sequential(policy.mlp_extractor.policy_net, policy.action_net)
critic_head = th.nn.Sequential(policy.mlp_extractor.value_net, policy.value_net)
onnxable_model = OnnxablePolicy(policy.features_extractor, actor_head, critic_head)
# Trace in inference mode so train-only layer behavior is not baked in.
onnxable_model.eval()

# Tracing only needs an input of the right shape: a batch of one observation.
# Observation preprocessing is not part of the wrapped modules, so consumers
# of the exported file must pass already-preprocessed tensors.
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)

torchscript_file = model_path + '.torchscript'
traced = torch.jit.trace(onnxable_model, dummy_input)
traced.save(torchscript_file)

In [12]:
# Reload the TorchScript archive written above to confirm it deserializes.
th_model = torch.jit.load(torchscript_file)

In [13]:
# Sanity check: 128 * 64 equals the 8192 `in_features` of the extractor's Linear layer.
# Presumably 128 conv output channels x 64 (8x8) board cells -- confirm.
128*64

8192

In [8]:
# Display the module structure of the loaded policy.
model.policy

ActorCriticCnnPolicy(
  (features_extractor): CustomCNN(
    (cnn): Sequential(
      (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): ReLU()
      (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (3): ReLU()
      (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (5): ReLU()
      (6): Flatten(start_dim=1, end_dim=-1)
    )
    (linear): Sequential(
      (0): Linear(in_features=8192, out_features=256, bias=True)
      (1): ReLU()
    )
  )
  (pi_features_extractor): CustomCNN(
    (cnn): Sequential(
      (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): ReLU()
      (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (3): ReLU()
      (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (5): ReLU()
      (6): Flatten(start_dim=1, end_dim=-1)
    )
    (linear): Sequential(
      (0): Linear(in_features=8192, out_features=256

In [4]:
# Display the policy's MLP extractor (separate `policy_net` and `value_net` branches).
model.policy.mlp_extractor

MlpExtractor(
  (policy_net): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): Tanh()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): Tanh()
  )
  (value_net): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): Tanh()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): Tanh()
  )
)

In [6]:
# Display the policy's action head.
model.policy.action_net

Linear(in_features=256, out_features=64, bias=True)

In [7]:
# Display the policy's value head.
model.policy.value_net

Linear(in_features=256, out_features=1, bias=True)