In [1]:

import gym
import numpy as np
from torchsummary import summary
from tianshou.utils.net.common import BDQNet

In [2]:
class DiscreteToContinuous(gym.ActionWrapper):
    """Drive a continuous-action environment with per-dimension discrete choices.

    For every dimension of the wrapped environment's action space a grid of
    ``action_per_branch`` evenly spaced values running from that dimension's
    lower bound to its upper bound is precomputed. An incoming action is a
    sequence of integer indices, one per dimension, each selecting a value
    from the corresponding grid.

    NOTE(review): this wrapper does not assign a new ``action_space``, so the
    space it reports appears to remain the wrapped environment's continuous
    one -- confirm whether downstream code expects a discrete space instead.

    Args:
        env: Environment whose ``action_space`` exposes ``low`` and ``high``
            bounds that can be iterated dimension by dimension.
        action_per_branch: Number of grid points per action dimension.

    Raises:
        ValueError: If ``action_per_branch`` is smaller than 1.
    """

    def __init__(self, env, action_per_branch):
        super().__init__(env)
        # An empty grid would only surface later as an IndexError inside
        # ``action``; fail at construction time with a clear message instead.
        if action_per_branch < 1:
            raise ValueError(
                f"action_per_branch must be at least 1, got {action_per_branch}"
            )
        self.action_per_branch = action_per_branch
        lower_bounds = self.action_space.low
        upper_bounds = self.action_space.high
        # One 1-D grid of candidate values per action dimension.
        self.mesh = [
            np.linspace(lower, upper, action_per_branch)
            for lower, upper in zip(lower_bounds, upper_bounds)
        ]

    def action(self, act):
        """Translate per-dimension grid indices into a continuous action.

        Args:
            act: Sequence holding one integer index per action dimension;
                entry ``i`` selects a value from ``self.mesh[i]``.

        Returns:
            ``np.ndarray`` with one selected grid value per action dimension.

        Raises:
            ValueError: If ``act`` does not hold exactly one index per
                dimension.
            IndexError: If an index lies outside its grid.
        """
        branch_indices = list(act)
        # Without this check a too-short ``act`` would silently yield a
        # truncated action and a too-long one a bare IndexError.
        if len(branch_indices) != len(self.mesh):
            raise ValueError(
                f"expected {len(self.mesh)} branch indices, "
                f"got {len(branch_indices)}"
            )
        return np.array(
            [grid[index] for grid, index in zip(self.mesh, branch_indices)]
        )

In [9]:
# Wrap BipedalWalker so that every action dimension is chosen from a
# 2-point grid built by DiscreteToContinuous.
env = DiscreteToContinuous(
    env=gym.make('BipedalWalker-v3'),
    action_per_branch=2,
)

In [10]:
# Sizes read off the wrapped environment: length of the first observation
# axis, length of the first action axis, and grid points per action dimension.
state_shape, action_shape = (
    env.observation_space.shape[0],
    env.action_space.shape[0],
)
action_per_branch = env.action_per_branch

# Hidden-layer sizes handed to BDQNet when the model is built.
action_hidden_sizes = [128]
value_hidden_sizes = [128]
common_hidden_sizes = [512, 256]

In [11]:
# Build the network from the sizes defined above. Arguments are passed
# positionally, in the order BDQNet is called with in this notebook.
model = BDQNet(
    state_shape,
    action_shape,
    action_per_branch,
    common_hidden_sizes,
    value_hidden_sizes,
    action_hidden_sizes,
)

In [18]:
import torch

# Smoke test: push one random observation through the model.
# Indexing with None prepends a size-1 axis to the 1-D random vector.
x = torch.randn(state_shape)[None, :]
y = model(x)
# model(x) returns an indexable result; show the shape of its first item.
y[0].shape

torch.Size([1, 4, 2])

In [23]:
# Position of the largest entry along the last axis of y[0].
y[0].max(dim=-1).indices

tensor([[0, 0, 1, 0]])

In [25]:
# Re-check the shape of the first model output.
y[0].size()

torch.Size([1, 4, 2])

In [26]:
# Length of the leading axis of the first model output.
y[0].size(0)

1

In [31]:
# Ten random integers drawn from {0, 1}, then viewed with two trailing
# size-1 axes, i.e. shape (10, 1, 1).
a = np.random.randint(low=0, high=2, size=10)
a[:, None, None]

array([[[0]],

       [[1]],

       [[1]],

       [[1]],

       [[0]],

       [[0]],

       [[1]],

       [[1]],

       [[1]],

       [[1]]])