# TorchRL Exploration

In [1]:
from torchrl.envs import BraxWrapper
import brax.envs as brax_envs
from Rodent_Env_Brax import Rodent
import torch

In [2]:
brax_envs.register_environment("rodent", Rodent)

In [3]:
device = torch.device("cpu")

In [4]:
env = BraxWrapper(brax_envs.get_environment("rodent"), device=device)

In [27]:
# env.set_seed(0)
# td = env.reset()
# print(td)

In [28]:
# td["action"] = env.action_spec.rand() # random move
# td = env.step(td) # step the env
# print(td)

In [29]:
# td = env.rand_step(td)

In [30]:
# %timeit env.rand_step(td)

## Checking observation connection

In [12]:
torch.ones(env.observation_spec['observation'].shape).shape[-1]

1260

## Checking parralel environment

In [32]:
from torchrl.modules import (
    ActorValueOperator,
    ConvNet,
    MLP,
    OneHotCategorical,
    ProbabilisticActor,
    TanhNormal,
    ValueOperator,
    )

from torchrl.envs import (BraxWrapper,
                          ParallelEnv,
                          EnvCreator,
                          TransformedEnv,
                          VecNorm,
                          RewardSum,
                          ExplorationType,
                          )

import brax.envs as brax_envs

from tensordict.nn import TensorDictModule

from torchrl.data import CompositeSpec

from torchrl.data.tensor_specs import DiscreteBox

In [13]:
def make_env(env_name="rodent", frame_skip=4, is_test=False):
    brax_envs.register_environment(env_name, Rodent)

    env = BraxWrapper(brax_envs.get_environment(env_name), 
                      iterations=6,
                      ls_iterations=3)
    env.set_seed(0)
    env = TransformedEnv(env)
    return env

def make_parallel_env(env_name, num_envs, device, is_test=False):
    env = ParallelEnv(
        num_envs,
        EnvCreator(lambda: make_env(env_name)),
        serial_for_single=True,
        device=device,
    )
    env = TransformedEnv(env)
    env.append_transform(VecNorm(in_keys=["observation"]))
    env.append_transform(RewardSum())
    return env

proof_environment = make_parallel_env('rodent', 1, device="cpu")

In [16]:
proof_environment.observation_spec["observation"].shape

torch.Size([1, 1260])

In [19]:
torch.ones(proof_environment.observation_spec["observation"].shape)

tensor([[1., 1., 1.,  ..., 1., 1., 1.]])

## Checking cnn shapes

In [20]:
common_cnn = ConvNet(
        activation_class=torch.nn.ReLU,
        num_cells=[32, 64, 64],
        kernel_sizes=[8, 4, 3],
        strides=[4, 2, 1],
    )



In [22]:
# common_cnn(torch.ones(proof_environment.observation_spec["observation"].shape))

## Checking mlp shapes

In [26]:
input_shape = proof_environment.observation_spec["observation"].shape
in_keys = ["observation"]

common_mlp = MLP(
        in_features=input_shape[-1], #common_cnn_output.shape[-1],
        activation_class=torch.nn.ReLU,
        activate_last_layer=True,
        out_features=512,
        num_cells=[],
    )

common_mlp_output = common_mlp(torch.ones(input_shape))#(common_cnn_output)

common_module = TensorDictModule(
        module=torch.nn.Sequential(common_mlp),#(common_cnn, common_mlp),
        in_keys=in_keys,
        out_keys=["common_features"],
    )

## Checking policy & value

In [33]:
if isinstance(proof_environment.action_spec.space, DiscreteBox):
        num_outputs = proof_environment.action_spec.space.n
        distribution_class = OneHotCategorical
        distribution_kwargs = {}
else:  # is ContinuousBox
        num_outputs = proof_environment.action_spec.shape
        distribution_class = TanhNormal
        distribution_kwargs = {
            "min": proof_environment.action_spec.space.low,
            "max": proof_environment.action_spec.space.high,
}

policy_net = MLP(
        in_features=common_mlp_output.shape[-1],
        out_features=num_outputs,
        activation_class=torch.nn.ReLU,
        num_cells=[],
)
policy_module = TensorDictModule(
        module=policy_net,
        in_keys=["common_features"],
        out_keys=["logits"],
)

# Add probabilistic sampling of the actions
policy_module = ProbabilisticActor(
        policy_module,
        in_keys=["logits"],
        spec=CompositeSpec(action=proof_environment.action_spec),
        distribution_class=distribution_class,
        distribution_kwargs=distribution_kwargs,
        return_log_prob=True,
        default_interaction_type=ExplorationType.RANDOM,
)

# Define another head for the value
value_net = MLP(
        activation_class=torch.nn.ReLU,
        in_features=common_mlp_output.shape[-1],
        out_features=1,
        num_cells=[],
)
value_module = ValueOperator(
        value_net,
        in_keys=["common_features"],
)

## Checking `make_ppo_models` function

In [34]:
proof_environment.rollout(max_steps=100, break_when_any_done=False)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1, 100, 30]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        episode_reward: Tensor(shape=torch.Size([1, 100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                episode_reward: Tensor(shape=torch.Size([1, 100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                observation: Tensor(shape=torch.Size([1, 100, 1260]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([1, 100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                state: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([1, 100, 1]), device=cpu, dt

In [35]:
actor_critic = ActorValueOperator(
        common_operator=common_module,
        policy_operator=policy_module,
        value_operator=value_module,
    )

In [36]:
with torch.no_grad():
    td = proof_environment.rollout(max_steps=100, break_when_any_done=False)
    td = actor_critic(td)
    del td

TypeError: distribution keywords and tensordict keys indicated by ProbabilisticTensorDictModule.dist_keys must match.Got this error message: 
    TanhNormal.__init__() got an unexpected keyword argument 'logits'
with dist_keys=['logits']

## Compare to brax

In [9]:
b_env = Rodent()

In [10]:
import jax
import brax
key = jax.random.key(0)
state = b_env.reset(key)

In [11]:
jit_step = jax.jit(b_env.step)

In [12]:
action = jax.random.uniform(key, shape=state.pipeline_state.ctrl.shape)
state = jit_step(state, action)

In [13]:
action = jax.random.uniform(key, shape=state.pipeline_state.ctrl.shape)
%timeit jit_step(state, action)

2.68 ms ± 434 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Testing code

In [14]:
import jax
env = Rodent()

In [15]:
key = jax.random.PRNGKey(0)
state = env.reset(key)

In [16]:
state.__dict__.keys()

dict_keys(['pipeline_state', 'obs', 'reward', 'done', 'metrics', 'info'])

In [17]:
for k in state.__dict__.keys():
    print(type(getattr(state, k)))
    

<class 'brax.mjx.base.State'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'dict'>
<class 'dict'>


In [18]:
data = state.pipeline_state
print(type(data.contact))

<class 'brax.base.Contact'>


In [19]:
contact = data.contact

In [20]:
for k in contact.__dict__.keys():
    print(type(getattr(contact, k)))

<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'tuple'>
<class 'jaxlib.xla_extension.ArrayImpl'>


In [21]:
contact.link_idx

(Array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],      dtype=int32),
 Array([10, 10, 11, 11, 11, 11, 11, 11, 14, 14, 15, 15, 15, 15, 15, 15, 24,
        24, 35, 35, 59, 59, 59, 59, 59, 59, 64, 64, 64, 64, 64, 64, 58, 63],      dtype=int32))

: 