2 changes: 1 addition & 1 deletion docs/source/reference/envs.rst
@@ -105,7 +105,7 @@ It is also possible to reset some but not all of the environments:
fields={
done: Tensor(torch.Size([4, 1]), dtype=torch.bool),
pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8),
reset_workers: Tensor(torch.Size([4, 1]), dtype=torch.bool)},
reset_workers: Tensor(torch.Size([4]), dtype=torch.bool)},
batch_size=torch.Size([4]),
device=None,
is_shared=True)
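For context, a minimal sketch of what this documentation change implies for partial resets. It is hedged: it assumes an existing ParallelEnv named env_parallel with 4 workers whose reset() accepts a tensordict, and uses the "reset_workers" convention shown above. The worker mask is now a flat boolean tensor of shape [4] rather than [4, 1].

import torch
from tensordict import TensorDict

# Hypothetical partial reset of workers 0 and 2 out of 4.
# "reset_workers" is now a 1D boolean mask with one entry per worker.
reset_mask = torch.tensor([True, False, True, False])
td_reset = TensorDict({"reset_workers": reset_mask}, batch_size=[4])
td_out = env_parallel.reset(td_reset)  # env_parallel: assumed pre-built ParallelEnv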
2 changes: 0 additions & 2 deletions test/mocking_classes.py
@@ -340,7 +340,6 @@ def _step(

if not self.categorical_action_encoding:
assert (a.sum(-1) == 1).all()
assert not self.is_done, "trying to execute step in done env"

obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep
tensordict = tensordict.select() # empty tensordict
@@ -423,7 +422,6 @@ def _step(
self.step_count += 1
tensordict = tensordict.to(self.device)
a = tensordict.get("action")
assert not self.is_done, "trying to execute step in done env"

obs = self._obs_step(self._get_in_obs(tensordict.get(self._out_key)), a)
tensordict = tensordict.select() # empty tensordict
60 changes: 58 additions & 2 deletions test/test_collector.py
@@ -306,8 +306,8 @@ def make_env():
)
for _data in collector:
continue
steps = _data["step_count"][..., 1:, :]
done = _data["done"][..., :-1, :]
steps = _data["step_count"][..., 1:]
done = _data["done"][..., :-1, :].squeeze(-1)
# we don't want just one done
assert done.sum() > 3
# check that after a done, the next step count is always 1
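A side note on the shape change above (a hedged illustration with made-up data, not part of the diff): step_count now has no trailing singleton dimension, while done keeps its [..., 1] shape, which is why it is squeezed before being used as a boolean index.

import torch

# Hypothetical collected batch: 2 workers, 5 steps; worker 0 finishes a
# trajectory at step index 2 and restarts counting from 1.
step_count = torch.tensor([[1, 2, 3, 1, 2],
                           [1, 2, 3, 4, 5]])      # [2, 5], no trailing dim
done = torch.zeros(2, 5, 1, dtype=torch.bool)     # [2, 5, 1]
done[0, 2, 0] = True

steps = step_count[..., 1:]                       # [2, 4]
done_prev = done[..., :-1, :].squeeze(-1)         # [2, 4], aligned with steps
assert (steps[done_prev] == 1).all()              # step count resets after a done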
@@ -370,6 +370,62 @@ def make_env(seed):
del collector


@pytest.mark.parametrize("frames_per_batch", [200, 10])
@pytest.mark.parametrize("num_env", [1, 3])
@pytest.mark.parametrize("env_name", ["vec"])
def test_split_trajs(num_env, env_name, frames_per_batch, seed=5):
if num_env == 1:

def env_fn(seed):
env = MockSerialEnv(device="cpu")
env.set_seed(seed)
return env

else:

def env_fn(seed):
def make_env(seed):
env = MockSerialEnv(device="cpu")
env.set_seed(seed)
return env

env = SerialEnv(
num_workers=num_env,
create_env_fn=make_env,
create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)],
allow_step_when_done=True,
)
env.set_seed(seed)
return env

policy = make_policy(env_name)

collector = SyncDataCollector(
create_env_fn=env_fn,
create_env_kwargs={"seed": seed},
policy=policy,
frames_per_batch=frames_per_batch * num_env,
max_frames_per_traj=2000,
total_frames=20000,
device="cpu",
pin_memory=False,
reset_when_done=True,
split_trajs=True,
)
for _, d in enumerate(collector): # noqa
break

assert d.ndimension() == 2
assert d["mask"].shape == d.shape
assert d["step_count"].shape == d.shape
assert d["traj_ids"].shape == d.shape
for traj in d.unbind(0):
assert traj["traj_ids"].unique().numel() == 1
assert (traj["step_count"][1:] - traj["step_count"][:-1] == 1).all()

del collector
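The new test exercises split_trajs=True. For orientation, a small self-contained sketch (plain torch/tensordict on hypothetical data, not the collector's actual implementation) of what trajectory splitting produces: a flat rollout carrying a "traj_ids" key is regrouped into a [num_trajs, max_len] layout, with a boolean "mask" marking the valid, non-padded steps of each row.

import torch
from tensordict import TensorDict

# A flat, hypothetical rollout: 6 frames belonging to 2 trajectories.
traj_ids = torch.tensor([0, 0, 0, 0, 1, 1])
step_count = torch.tensor([1, 2, 3, 4, 1, 2])

ids = traj_ids.unique().tolist()
lengths = [(traj_ids == i).sum().item() for i in ids]
max_len = max(lengths)

def pad_to(x, length):
    # right-pad a 1D tensor with zeros (False for bool) up to `length`
    return torch.cat([x, x.new_zeros(length - x.shape[0])])

rows = []
for i, n in zip(ids, lengths):
    sel = traj_ids == i
    rows.append(
        TensorDict(
            {
                "traj_ids": pad_to(traj_ids[sel], max_len),
                "step_count": pad_to(step_count[sel], max_len),
                "mask": pad_to(torch.ones(n, dtype=torch.bool), max_len),
            },
            batch_size=[max_len],
        )
    )

split = torch.stack(rows, 0)  # batch_size [2, max_len], analogous to split_trajs=True
assert split.ndimension() == 2
assert split["mask"].shape == split.shape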


# TODO: design a test that ensures that collectors are interrupted even if __del__ is not called
# @pytest.mark.parametrize("should_shutdown", [True, False])
# def test_shutdown_collector(should_shutdown, num_env=3, env_name="vec", seed=40):
100 changes: 56 additions & 44 deletions test/test_cost.py
@@ -26,7 +26,6 @@

# from torchrl.data.postprocs.utils import expand_as_right
from tensordict.tensordict import assert_allclose_td, TensorDict
from tensordict.utils import expand_as_right
from torch import autograd, nn
from torchrl.data import (
CompositeSpec,
@@ -253,20 +252,22 @@ def _create_seq_mock_data_dqn(
if action_spec_type == "categorical":
action_value = torch.max(action_value, -1, keepdim=True)[0]
action = torch.argmax(action, -1, keepdim=True)
# action_value = action_value.unsqueeze(-1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs * mask.to(obs.dtype),
"next": {"observation": next_obs * mask.to(obs.dtype)},
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
},
"done": done,
"mask": mask,
"reward": reward * mask.to(obs.dtype),
"action": action * mask.to(obs.dtype),
"action_value": action_value
* expand_as_right(mask.to(obs.dtype).squeeze(-1), action_value),
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action_value": action_value.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
)
return td
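Across these mock-data helpers, the mask is now built with shape [batch, T] (no trailing singleton dimension) and the zeroing switches from multiply-by-mask to masked_fill_. A hedged, self-contained illustration (hypothetical shapes) of why the unsqueeze(-1) is needed and that the two formulations zero out the same entries:

import torch

batch, T, feat = 2, 3, 4
obs = torch.randn(batch, T, feat)
mask = torch.ones(batch, T, dtype=torch.bool)   # [B, T], no trailing dim
mask[0, -1] = False                             # pretend the last step of row 0 is padding

old_style = obs * mask.unsqueeze(-1).to(obs.dtype)     # previous multiply-by-mask form
new_style = obs.masked_fill(~mask.unsqueeze(-1), 0.0)  # out-of-place variant of masked_fill_
assert torch.equal(old_style, new_style)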
@@ -488,16 +489,18 @@ def _create_seq_mock_data_ddpg(
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs * mask.to(obs.dtype),
"next": {"observation": next_obs * mask.to(obs.dtype)},
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
},
"done": done,
"mask": mask,
"reward": reward * mask.to(obs.dtype),
"action": action * mask.to(obs.dtype),
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
device=device,
)
@@ -726,16 +729,18 @@ def _create_seq_mock_data_sac(
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = torch.ones(batch, T, dtype=torch.bool, device=device)
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs * mask.to(obs.dtype),
"next": {"observation": next_obs * mask.to(obs.dtype)},
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
},
"done": done,
"mask": mask,
"reward": reward * mask.to(obs.dtype),
"action": action * mask.to(obs.dtype),
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
device=device,
)
@@ -1129,16 +1134,18 @@ def _create_seq_mock_data_redq(
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs * mask.to(obs.dtype),
"next": {"observation": next_obs * mask.to(obs.dtype)},
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
},
"done": done,
"mask": mask,
"reward": reward * mask.to(obs.dtype),
"action": action * mask.to(obs.dtype),
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
device=device,
)
@@ -1543,7 +1550,7 @@ def _create_mock_data_ppo(
"done": done,
"reward": reward,
"action": action,
"sample_log_prob": torch.randn_like(action[..., :1]) / 10,
"sample_log_prob": torch.randn_like(action[..., 1]) / 10,
},
device=device,
)
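The change from action[..., :1] to action[..., 1] drops the trailing singleton dimension of sample_log_prob. A tiny reminder of the indexing difference, on hypothetical shapes:

import torch

action = torch.randn(2, 3, 4)
print(action[..., :1].shape)  # torch.Size([2, 3, 1]) - slicing keeps the last dim
print(action[..., 1].shape)   # torch.Size([2, 3])    - integer indexing drops it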
@@ -1564,23 +1571,25 @@ def _create_seq_mock_data_ppo(
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = torch.ones(batch, T, dtype=torch.bool, device=device)
params_mean = torch.randn_like(action) / 10
params_scale = torch.rand_like(action) / 10
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs * mask.to(obs.dtype),
"next": {"observation": next_obs * mask.to(obs.dtype)},
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
},
"done": done,
"mask": mask,
"reward": reward * mask.to(obs.dtype),
"action": action * mask.to(obs.dtype),
"sample_log_prob": torch.randn_like(action[..., :1])
/ 10
* mask.to(obs.dtype),
"loc": params_mean * mask.to(obs.dtype),
"scale": params_scale * mask.to(obs.dtype),
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
"sample_log_prob": (torch.randn_like(action[..., 1]) / 10).masked_fill_(
~mask, 0.0
),
"loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0),
"scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
device=device,
)
@@ -1835,23 +1844,26 @@ def _create_seq_mock_data_a2c(
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
params_mean = torch.randn_like(action) / 10
params_scale = torch.rand_like(action) / 10
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs * mask.to(obs.dtype),
"next": {"observation": next_obs * mask.to(obs.dtype)},
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
},
"done": done,
"mask": mask,
"reward": reward * mask.to(obs.dtype),
"action": action * mask.to(obs.dtype),
"sample_log_prob": torch.randn_like(action[..., :1])
/ 10
* mask.to(obs.dtype),
"loc": params_mean * mask.to(obs.dtype),
"scale": params_scale * mask.to(obs.dtype),
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
"sample_log_prob": torch.randn_like(action[..., 1]).masked_fill_(
~mask, 0.0
)
/ 10,
"loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0),
"scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
device=device,
)
4 changes: 2 additions & 2 deletions test/test_env.py
@@ -501,7 +501,7 @@ def test_parallel_env(
td1 = env_parallel.step(td)

td_reset = TensorDict(
source={"reset_workers": torch.zeros(N, 1, dtype=torch.bool).bernoulli_()},
source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()},
batch_size=[
N,
],
@@ -585,7 +585,7 @@ def test_parallel_env_with_policy(
td1 = env_parallel.step(td)

td_reset = TensorDict(
source={"reset_workers": torch.zeros(N, 1, dtype=torch.bool).bernoulli_()},
source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()},
batch_size=[
N,
],
8 changes: 7 additions & 1 deletion test/test_libs.py
@@ -342,7 +342,13 @@ def test_habitat(self, envname):


@pytest.mark.skipif(not _has_jumanji, reason="jumanji not installed")
@pytest.mark.parametrize("envname", ["Snake-6x6-v0", "TSP50-v0"])
@pytest.mark.parametrize(
"envname",
[
"TSP50-v0",
"Snake-6x6-v0",
],
)
class TestJumanji:
def test_jumanji_seeding(self, envname):
final_seed = []
18 changes: 5 additions & 13 deletions test/test_postprocs.py
@@ -97,8 +97,8 @@ def create_fake_trajs(
num_workers=32,
traj_len=200,
):
traj_ids = torch.arange(num_workers).unsqueeze(-1)
steps_count = torch.zeros(num_workers).unsqueeze(-1)
traj_ids = torch.arange(num_workers)
steps_count = torch.zeros(num_workers)
workers = torch.arange(num_workers)

out = []
@@ -108,10 +108,10 @@ def create_fake_trajs(
td = TensorDict(
source={
"traj_ids": traj_ids,
"a": traj_ids.clone(),
"a": traj_ids.clone().unsqueeze(-1),
"steps_count": steps_count,
"workers": workers,
"done": done,
"done": done.unsqueeze(-1),
},
batch_size=[num_workers],
)
@@ -125,15 +125,7 @@ def create_fake_trajs(
return out

@pytest.mark.parametrize("num_workers", range(3, 34, 3))
@pytest.mark.parametrize(
"traj_len",
[
10,
17,
50,
97,
],
)
@pytest.mark.parametrize("traj_len", [10, 17, 50, 97])
def test_splits(self, num_workers, traj_len):

trajs = TestSplits.create_fake_trajs(num_workers, traj_len)
4 changes: 3 additions & 1 deletion test/test_shared.py
@@ -59,7 +59,9 @@ def test_shared(self, indexing_method):
td = tensordict.clone().share_memory_()
if indexing_method == 0:
subtd = TensorDict(
source={key: item[0] for key, item in td.items()}, batch_size=[]
source={key: item[0] for key, item in td.items()},
batch_size=[],
_is_shared=True,
)
elif indexing_method == 1:
subtd = td.get_sub_tensordict(0)
5 changes: 2 additions & 3 deletions test/test_tensor_spec.py
@@ -53,7 +53,7 @@ def test_discrete(cls):
r = ts.rand()
ts.to_numpy(r)
ts.encode(torch.tensor([5]))
ts.encode(torch.tensor([5]).numpy())
ts.encode(torch.tensor(5).numpy())
ts.encode(9)
with pytest.raises(AssertionError):
ts.encode(torch.tensor([11])) # out of bounds
@@ -887,9 +887,8 @@ def test_categorical_action_spec_rand(self):

sample = action_spec.rand((10000,))

sample_list = sample[:, 0]
sample_list = sample
sample_list = [sum(sample_list == i).item() for i in range(10)]
print(sample_list)
assert chisquare(sample_list).pvalue > 0.1

sample = action_spec.to_numpy(sample)