# TorchRL Exploration

In [1]:
from torchrl.envs import BraxWrapper, ParallelEnv
import brax.envs as envs
from Rodent_Env_Brax import Rodent
import torch

In [2]:
# Register the custom Rodent environment in Brax's registry under "rodent",
# so it can be looked up with envs.get_environment("rodent").
envs.register_environment("rodent", Rodent)

In [3]:
# Device handed to the TorchRL wrapper; its output tensors report device=cpu.
device = torch.device("cpu")

In [4]:
# Wrap the registered Brax env so it exposes TorchRL's TensorDict-in /
# TensorDict-out environment API.
env = BraxWrapper(envs.get_environment("rodent"), device=device)

In [5]:
# Seed the wrapper, then reset to get the initial TensorDict. Besides
# "observation" and "done", it carries the underlying Brax state under "state"
# (see the printed output).
env.set_seed(0)
td = env.reset()
print(td)

TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([1260]), device=cpu, dtype=torch.float32, is_shared=False),
        state: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                metrics: TensorDict(
                    fields={
                        distance_from_origin: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                        forward_reward: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward_alive: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward_linvel: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward_quadctrl: Tensor(shape=torch.Size([]), device=c

In [11]:
# Take one step with a random action. The returned TensorDict keeps the action
# at the root and stores the transition (observation, reward, done, state)
# under "next" (see the printed output).
td["action"] = env.action_spec.rand() # random action sampled from the action spec
td = env.step(td) # step the env; results are nested under td["next"]
print(td)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([30]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([1260]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                state: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                        metrics: TensorDict(
                            fields={
                                distance_from_origin: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                             

In [12]:
# NOTE(review): after env.step, the post-step state lives under td["next"],
# while the root presumably still holds the pre-step "state". If the wrapper
# steps from the root (TODO confirm), this replays a step from the same start
# state instead of continuing the trajectory. Use torchrl's step_mdp(td) to
# move "next" to the root when a continuing rollout is intended.
td = env.rand_step(td)

In [8]:
# Wall-clock cost of one random step through the TorchRL wrapper. This includes
# sampling the action from the spec and any JAX<->torch conversion the wrapper
# performs, so it is not a pure physics-step timing.
%timeit env.rand_step(td)

30.5 ms ± 476 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Checking the observation size

In [12]:
# Size of the last observation dimension, read straight from the spec's shape.
# There is no need to allocate a tensor just to query that shape.
env.observation_spec['observation'].shape[-1]

1260

## Checking the parallel environment

In [13]:
from torchrl.envs import (BraxWrapper,
                          ParallelEnv,
                          EnvCreator,
                          TransformedEnv,
                          VecNorm,
                          RewardSum)
import brax.envs as brax_envs

def make_env(env_name="rodent", frame_skip=4, is_test=False, seed=0):
    """Build a single TorchRL-wrapped Brax Rodent environment.

    Rodent is (re)registered under ``env_name`` on every call. Presumably this
    is so processes that build their own env (e.g. ParallelEnv workers) also
    have the registration; confirm.

    Args:
        env_name: Name to register Rodent under and then look up in Brax's
            registry.
        frame_skip: Currently unused; kept for signature compatibility.
            TODO: apply a frame-skip transform if frame skipping is wanted.
        is_test: Currently unused; kept for signature compatibility.
        seed: Seed passed to ``env.set_seed``. Defaults to 0, the previously
            hard-coded value. Pass distinct seeds when building several envs
            so they do not all start from the same random state.

    Returns:
        A ``TransformedEnv`` wrapping the ``BraxWrapper``, with no transforms
        appended yet.
    """
    brax_envs.register_environment(env_name, Rodent)

    # NOTE(review): `iterations` / `ls_iterations` look like MJX solver
    # settings. Confirm that BraxWrapper forwards them to the Brax env. If it
    # does not, they belong in brax_envs.get_environment(env_name, ...).
    env = BraxWrapper(brax_envs.get_environment(env_name),
                      iterations=6,
                      ls_iterations=3)
    env.set_seed(seed)
    # Empty TransformedEnv so callers can append transforms later.
    env = TransformedEnv(env)
    return env

def make_parallel_env(env_name, num_envs, device, is_test=False, seed=0):
    """Run ``num_envs`` Rodent envs in a ParallelEnv with obs/reward transforms.

    Args:
        env_name: Brax registry name forwarded to ``make_env``.
        num_envs: Number of environment copies in the batch.
        device: Device passed to the ParallelEnv.
        is_test: Currently unused; kept for signature compatibility.
        seed: Base seed for the batched env. TorchRL's batched envs seed each
            sub-env with its own seed derived from this one, so the copies no
            longer all start from the same random state.

    Returns:
        A ``TransformedEnv`` that applies ``VecNorm`` to "observation" and
        ``RewardSum`` on top of the ParallelEnv.
    """
    env = ParallelEnv(
        num_envs,
        EnvCreator(lambda: make_env(env_name)),
        # With a single env, TorchRL builds a serial (in-process) env instead.
        serial_for_single=True,
        device=device,
    )
    # make_env(env_name) seeds every copy with the same value. Reseed the
    # batch so each sub-env gets a distinct seed.
    env.set_seed(seed)
    env = TransformedEnv(env)
    # Running normalization of the observation vector.
    env.append_transform(VecNorm(in_keys=["observation"]))
    # Accumulates episode reward.
    env.append_transform(RewardSum())
    return env


# Single-env batched instance, used only to inspect specs and shapes below.
proof_environment = make_parallel_env('rodent', 1, device="cpu")

In [16]:
# The batched env prepends a batch dimension (1 here), unlike the plain
# wrapper's [1260] observation shape.
proof_environment.observation_spec["observation"].shape

torch.Size([1, 1260])

In [19]:
# Dummy all-ones tensor with the batched observation shape, e.g. for feeding
# a network as a smoke test.
torch.ones(proof_environment.observation_spec["observation"].shape)

tensor([[1., 1., 1.,  ..., 1., 1., 1.]])

In [20]:
from torchrl.modules import (
    ActorValueOperator,
    ConvNet,
    MLP,
    OneHotCategorical,
    ProbabilisticActor,
    TanhNormal,
    ValueOperator,
)

# NOTE(review): this ConvNet is never applied. Its call is commented out and
# the MLP below takes the flat 1260-dim observation directly. Kept for
# reference only.
common_cnn = ConvNet(
    num_cells=[32, 64, 64],   # channels of each conv layer
    kernel_sizes=[8, 4, 3],
    strides=[4, 2, 1],
    activation_class=torch.nn.ReLU,
)



In [22]:
# common_cnn(torch.ones(proof_environment.observation_spec["observation"].shape))

In [24]:
input_shape = proof_environment.observation_spec["observation"].shape

# num_cells=[] means no hidden layers, so this should reduce to a single
# Linear(in_features -> 512). activate_last_layer=True puts a ReLU on that
# output layer.
common_mlp = MLP(
    in_features=input_shape[-1],
    out_features=512,
    num_cells=[],
    activation_class=torch.nn.ReLU,
    activate_last_layer=True,
)

# Smoke test: push an all-ones batch with the observation shape through the MLP.
common_mlp_output = common_mlp(torch.ones(input_shape))

## Comparing with raw Brax

In [9]:
# Raw (unwrapped) Brax Rodent env, used to time a bare jitted step and compare
# it with the TorchRL wrapper timing above.
b_env = Rodent()

In [10]:
import jax
import brax  # NOTE(review): appears unused; `brax.envs` is imported separately as `envs`
# New-style typed PRNG key.
key = jax.random.key(0)
state = b_env.reset(key)

In [11]:
# JIT-compile the raw step. The first call (next cell) pays the compilation
# cost, so the later %timeit measures only the compiled step.
jit_step = jax.jit(b_env.step)

In [12]:
# Draw the action from a fresh subkey. `key` was already consumed by reset(),
# and reusing a JAX PRNG key gives correlated random numbers.
# NOTE(review): uniform() samples in [0, 1); confirm this matches the Rodent
# control range.
key, action_key = jax.random.split(key)
action = jax.random.uniform(action_key, shape=state.pipeline_state.ctrl.shape)
# First call triggers JIT compilation (warm-up before timing).
state = jit_step(state, action)

In [13]:
action = jax.random.uniform(key, shape=state.pipeline_state.ctrl.shape)
%timeit jit_step(state, action)

2.68 ms ± 434 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Testing code: inspecting the raw Brax state and contact data

In [14]:
import jax
# NOTE(review): this rebinds `env`, which earlier cells use for the
# BraxWrapper, to a raw Rodent instance. Re-running those earlier cells after
# this one would likely fail or behave differently. A distinct name (e.g.
# `rodent_env`) would avoid the hidden-state hazard.
env = Rodent()

In [15]:
# NOTE(review): jax.random.PRNGKey creates a legacy raw key, whereas the
# comparison cell above uses the typed jax.random.key(0). Consider using one
# style consistently.
key = jax.random.PRNGKey(0)
state = env.reset(key)

In [16]:
# Field names carried by the env State object.
vars(state).keys()

dict_keys(['pipeline_state', 'obs', 'reward', 'done', 'metrics', 'info'])

In [17]:
# Print the Python type of each State field, in field order.
for field_name in vars(state):
    print(type(getattr(state, field_name)))
    

<class 'brax.mjx.base.State'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'dict'>
<class 'dict'>


In [18]:
# The env State wraps the physics pipeline state (brax.mjx.base.State, per the
# type listing above). Check the type of its contact field.
data = state.pipeline_state
print(type(data.contact))

<class 'brax.base.Contact'>


In [19]:
# Contact data (a brax.base.Contact) for the current pipeline state.
contact = data.contact

In [20]:
# Print the Python type of each Contact field, in field order.
for field_name in vars(contact):
    print(type(getattr(contact, field_name)))

<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'tuple'>
<class 'jaxlib.xla_extension.ArrayImpl'>


In [21]:
# Tuple of two index arrays with one entry per contact. The first array is all
# -1 here, which presumably denotes the world/floor. TODO confirm against the
# Brax Contact docs.
contact.link_idx

(Array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],      dtype=int32),
 Array([10, 10, 11, 11, 11, 11, 11, 11, 14, 14, 15, 15, 15, 15, 15, 15, 24,
        24, 35, 35, 59, 59, 59, 59, 59, 59, 64, 64, 64, 64, 64, 64, 58, 63],      dtype=int32))

: 