In [1]:
import numpy as np
import torch

from torchrl.envs import ParallelEnv
from torchrl.collectors import SyncDataCollector

from miscellaneous.tiny_sim_wrapper import TinySimWrapper
from miscellaneous.torchrl_ac import ActorCritic
from miscellaneous.attention_ac import ActorCriticWithAttention

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fix RNG seeds so policy sampling (and any randomness drawn from these generators) is reproducible
# under Restart & Run All. TODO(review): if TinySimWrapper uses another RNG (e.g. `random`), seed it too.
SEED = 0
np.random.seed(SEED)     # legacy global numpy generator, in case library code calls np.random.*
torch.manual_seed(SEED)  # seeds torch's generators on all devices

# Run on the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Simulator assets, relative to this notebook's directory.
MODEL_PATH = "../models/tinyphysics.onnx"
DATA_DIR = "../data"

# A single (unbatched) simulator instance placed on `device`. For batched rollouts this constructor
# can be wrapped in `ParallelEnv(n_workers, lambda: TinySimWrapper(...), device=device)`.
env = TinySimWrapper(model_path=MODEL_PATH, data_directory_path=DATA_DIR, device=device)

In [4]:
# Flat input width for an MLP-style policy: the `current_state` features plus one scalar `time`
# feature. Only the MLP ActorCritic variant takes this; the attention model below does not.
in_features = env.observation_spec["current_state"].shape[-1] + 1
num_actions = env.action_spec.shape[-1]

# Box bounds of the unbatched action space, handed to the policy
# (presumably to bound its actions — TODO confirm in ActorCriticWithAttention).
action_space = env.action_spec_unbatched.space
low, high = action_space.low, action_space.high

# Attention-based actor-critic. The MLP alternative is
# ActorCritic(in_features, num_actions, low, high, 256, in_keys=["current_state", "time"]).
ac = ActorCriticWithAttention(num_actions, low, high).to(device)

In [5]:
# Start an episode: reset() returns the initial observation TensorDict (inspected in the next cell).
td = env.reset()

In [6]:
# Inspect the initial observation: its keys, tensor shapes, and device.
td

TensorDict(
    fields={
        current_state: Tensor(shape=torch.Size([5]), device=cuda:0, dtype=torch.float32, is_shared=True),
        done: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        future_plans: Tensor(shape=torch.Size([49, 4]), device=cuda:0, dtype=torch.float32, is_shared=True),
        past_states: Tensor(shape=torch.Size([49, 5]), device=cuda:0, dtype=torch.float32, is_shared=True),
        terminated: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        time: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True)

In [7]:
# Run the policy on the observation. The result (shown in the next cell) holds the input keys plus
# action, loc, scale, hidden, sample_log_prob and state_value. The module therefore presumably writes
# into `td` in place and returns it — TODO confirm `action is td`.
# NOTE(review): the printed `torch.Size([64])` matches `hidden`'s shape. It looks like a leftover debug
# print inside ActorCriticWithAttention; consider removing it there.
action = ac(td)

torch.Size([64])


In [8]:
# Inspect what the policy added to the TensorDict (sampled action, distribution params, value estimate).
action

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.float32, is_shared=True),
        current_state: Tensor(shape=torch.Size([5]), device=cuda:0, dtype=torch.float32, is_shared=True),
        done: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        future_plans: Tensor(shape=torch.Size([49, 4]), device=cuda:0, dtype=torch.float32, is_shared=True),
        hidden: Tensor(shape=torch.Size([64]), device=cuda:0, dtype=torch.float32, is_shared=True),
        loc: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.float32, is_shared=True),
        past_states: Tensor(shape=torch.Size([49, 5]), device=cuda:0, dtype=torch.float32, is_shared=True),
        sample_log_prob: Tensor(shape=torch.Size([]), device=cuda:0, dtype=torch.float32, is_shared=True),
        scale: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.float32, is_shared=True),
        state_value: Tensor(shape=torch.Size([1]), device=

In [9]:
# Advance the simulator one step through the public EnvBase.step() API instead of the private _step()
# hook, so TorchRL's post-processing and validation of the step output runs.
# step() stores the transition under the "next" key of its input, so `action` gains a "next" entry.
# Selecting that key keeps `new_td` as the next-state TensorDict (with "reward"), matching _step()'s layout.
new_td = env.step(action)["next"]

In [10]:
# Reward for the single transition above. It is negative in the run shown,
# presumably a cost/penalty — TODO confirm the sign convention in TinySimWrapper.
new_td["reward"]

tensor([-12.0095], device='cuda:0')