diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh index ad126f23b0a..a7c9bb93976 100755 --- a/.circleci/unittest/linux/scripts/install.sh +++ b/.circleci/unittest/linux/scripts/install.sh @@ -40,7 +40,7 @@ python -c "import functorch" pip install git+https://github.com/pytorch/torchsnapshot # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict +pip install git+https://github.com/pytorch-labs/tensordict.git printf "* Installing torchrl\n" python setup.py develop diff --git a/.circleci/unittest/linux_examples/scripts/install.sh b/.circleci/unittest/linux_examples/scripts/install.sh index ad126f23b0a..a7c9bb93976 100755 --- a/.circleci/unittest/linux_examples/scripts/install.sh +++ b/.circleci/unittest/linux_examples/scripts/install.sh @@ -40,7 +40,7 @@ python -c "import functorch" pip install git+https://github.com/pytorch/torchsnapshot # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict +pip install git+https://github.com/pytorch-labs/tensordict.git printf "* Installing torchrl\n" python setup.py develop diff --git a/.circleci/unittest/linux_libs/scripts_brax/install.sh b/.circleci/unittest/linux_libs/scripts_brax/install.sh index 767070f2b25..91671e8d985 100755 --- a/.circleci/unittest/linux_libs/scripts_brax/install.sh +++ b/.circleci/unittest/linux_libs/scripts_brax/install.sh @@ -36,7 +36,7 @@ else fi # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict +pip install git+https://github.com/pytorch-labs/tensordict.git # smoke test python -c "import functorch;import tensordict" diff --git a/.circleci/unittest/linux_libs/scripts_gym/install.sh b/.circleci/unittest/linux_libs/scripts_gym/install.sh index 0cdee0320c1..7044df97232 100755 --- a/.circleci/unittest/linux_libs/scripts_gym/install.sh +++ b/.circleci/unittest/linux_libs/scripts_gym/install.sh @@ -42,7 +42,7 @@ else fi # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict +pip install git+https://github.com/pytorch-labs/tensordict.git # smoke test python -c "import tensordict" diff --git a/.circleci/unittest/linux_libs/scripts_habitat/install.sh b/.circleci/unittest/linux_libs/scripts_habitat/install.sh index e5833cd1356..8fb340c567c 100755 --- a/.circleci/unittest/linux_libs/scripts_habitat/install.sh +++ b/.circleci/unittest/linux_libs/scripts_habitat/install.sh @@ -38,7 +38,7 @@ else fi # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict +pip install git+https://github.com/pytorch-labs/tensordict.git # smoke test python -c "import functorch;import tensordict" diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/install.sh b/.circleci/unittest/linux_libs/scripts_jumanji/install.sh index 767070f2b25..91671e8d985 100755 --- a/.circleci/unittest/linux_libs/scripts_jumanji/install.sh +++ b/.circleci/unittest/linux_libs/scripts_jumanji/install.sh @@ -36,7 +36,7 @@ else fi # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict +pip install git+https://github.com/pytorch-labs/tensordict.git # smoke test python -c "import functorch;import tensordict" diff --git a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh b/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh index 0cdee0320c1..7044df97232 100755 --- a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh +++ b/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh @@ -42,7 +42,7 @@ else fi # install tensordict -pip install 
git+https://github.com/pytorch-labs/tensordict +pip install git+https://github.com/pytorch-labs/tensordict.git # smoke test python -c "import tensordict" diff --git a/.circleci/unittest/linux_optdeps/scripts/install.sh b/.circleci/unittest/linux_optdeps/scripts/install.sh index 84951e95f24..142ed709325 100755 --- a/.circleci/unittest/linux_optdeps/scripts/install.sh +++ b/.circleci/unittest/linux_optdeps/scripts/install.sh @@ -36,7 +36,7 @@ else fi # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict +pip install git+https://github.com/pytorch-labs/tensordict.git # smoke test python -c "import functorch" diff --git a/.circleci/unittest/linux_stable/scripts/install.sh b/.circleci/unittest/linux_stable/scripts/install.sh index e8688cc825d..cb396476522 100755 --- a/.circleci/unittest/linux_stable/scripts/install.sh +++ b/.circleci/unittest/linux_stable/scripts/install.sh @@ -34,7 +34,7 @@ else fi # install tensordict -pip install git+https://github.com/pytorch-labs/tensordict +pip install git+https://github.com/pytorch-labs/tensordict.git # smoke test python -c "import torch;import functorch" diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 1b8bfdeb240..be856e59298 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -105,7 +105,7 @@ It is also possible to reset some but not all of the environments: fields={ done: Tensor(torch.Size([4, 1]), dtype=torch.bool), pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8), - reset_workers: Tensor(torch.Size([4, 1]), dtype=torch.bool)}, + reset_workers: Tensor(torch.Size([4]), dtype=torch.bool)}, batch_size=torch.Size([4]), device=None, is_shared=True) diff --git a/test/test_collector.py b/test/test_collector.py index 769ef221ae6..b12e974097a 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -306,8 +306,8 @@ def make_env(): ) for _data in collector: continue - steps = _data["step_count"][..., 1:, :] - done = _data["done"][..., :-1, :] + steps = _data["step_count"][..., 1:] + done = _data["done"][..., :-1, :].squeeze(-1) # we don't want just one done assert done.sum() > 3 # check that after a done, the next step count is always 1 @@ -370,6 +370,62 @@ def make_env(seed): del collector +@pytest.mark.parametrize("frames_per_batch", [200, 10]) +@pytest.mark.parametrize("num_env", [1, 3]) +@pytest.mark.parametrize("env_name", ["vec"]) +def test_split_trajs(num_env, env_name, frames_per_batch, seed=5): + if num_env == 1: + + def env_fn(seed): + env = MockSerialEnv(device="cpu") + env.set_seed(seed) + return env + + else: + + def env_fn(seed): + def make_env(seed): + env = MockSerialEnv(device="cpu") + env.set_seed(seed) + return env + + env = SerialEnv( + num_workers=num_env, + create_env_fn=make_env, + create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], + allow_step_when_done=True, + ) + env.set_seed(seed) + return env + + policy = make_policy(env_name) + + collector = SyncDataCollector( + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=frames_per_batch * num_env, + max_frames_per_traj=2000, + total_frames=20000, + device="cpu", + pin_memory=False, + reset_when_done=True, + split_trajs=True, + ) + for _, d in enumerate(collector): # noqa + break + + assert d.ndimension() == 2 + assert d["mask"].shape == d.shape + assert d["step_count"].shape == d.shape + assert d["traj_ids"].shape == d.shape + for traj in d.unbind(0): + assert traj["traj_ids"].unique().numel() == 1 + assert 
(traj["step_count"][1:] - traj["step_count"][:-1] == 1).all() + + del collector + + # TODO: design a test that ensures that collectors are interrupted even if __del__ is not called # @pytest.mark.parametrize("should_shutdown", [True, False]) # def test_shutdown_collector(should_shutdown, num_env=3, env_name="vec", seed=40): diff --git a/test/test_cost.py b/test/test_cost.py index fbba7ac4c5a..08c705cbfcd 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -26,7 +26,6 @@ # from torchrl.data.postprocs.utils import expand_as_right from tensordict.tensordict import assert_allclose_td, TensorDict -from tensordict.utils import expand_as_right from torch import autograd, nn from torchrl.data import ( CompositeSpec, @@ -253,20 +252,22 @@ def _create_seq_mock_data_dqn( if action_spec_type == "categorical": action_value = torch.max(action_value, -1, keepdim=True)[0] action = torch.argmax(action, -1, keepdim=True) + # action_value = action_value.unsqueeze(-1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) - mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), source={ - "observation": obs * mask.to(obs.dtype), - "next": {"observation": next_obs * mask.to(obs.dtype)}, + "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "next": { + "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0) + }, "done": done, "mask": mask, - "reward": reward * mask.to(obs.dtype), - "action": action * mask.to(obs.dtype), - "action_value": action_value - * expand_as_right(mask.to(obs.dtype).squeeze(-1), action_value), + "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), + "action_value": action_value.masked_fill_(~mask.unsqueeze(-1), 0.0), }, ) return td @@ -488,16 +489,18 @@ def _create_seq_mock_data_ddpg( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) - mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), source={ - "observation": obs * mask.to(obs.dtype), - "next": {"observation": next_obs * mask.to(obs.dtype)}, + "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "next": { + "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0) + }, "done": done, "mask": mask, - "reward": reward * mask.to(obs.dtype), - "action": action * mask.to(obs.dtype), + "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), }, device=device, ) @@ -726,16 +729,18 @@ def _create_seq_mock_data_sac( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) - mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = torch.ones(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), source={ - "observation": obs * mask.to(obs.dtype), - "next": {"observation": next_obs * mask.to(obs.dtype)}, + "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "next": { + "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0) + }, "done": done, "mask": mask, - "reward": reward * 
mask.to(obs.dtype), - "action": action * mask.to(obs.dtype), + "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), }, device=device, ) @@ -1129,16 +1134,18 @@ def _create_seq_mock_data_redq( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) - mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), source={ - "observation": obs * mask.to(obs.dtype), - "next": {"observation": next_obs * mask.to(obs.dtype)}, + "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "next": { + "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0) + }, "done": done, "mask": mask, - "reward": reward * mask.to(obs.dtype), - "action": action * mask.to(obs.dtype), + "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), }, device=device, ) @@ -1543,7 +1550,7 @@ def _create_mock_data_ppo( "done": done, "reward": reward, "action": action, - "sample_log_prob": torch.randn_like(action[..., :1]) / 10, + "sample_log_prob": torch.randn_like(action[..., 1]) / 10, }, device=device, ) @@ -1564,23 +1571,25 @@ def _create_seq_mock_data_ppo( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) - mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = torch.ones(batch, T, dtype=torch.bool, device=device) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 td = TensorDict( batch_size=(batch, T), source={ - "observation": obs * mask.to(obs.dtype), - "next": {"observation": next_obs * mask.to(obs.dtype)}, + "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "next": { + "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0) + }, "done": done, "mask": mask, - "reward": reward * mask.to(obs.dtype), - "action": action * mask.to(obs.dtype), - "sample_log_prob": torch.randn_like(action[..., :1]) - / 10 - * mask.to(obs.dtype), - "loc": params_mean * mask.to(obs.dtype), - "scale": params_scale * mask.to(obs.dtype), + "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), + "sample_log_prob": (torch.randn_like(action[..., 1]) / 10).masked_fill_( + ~mask, 0.0 + ), + "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0), + "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0), }, device=device, ) @@ -1835,23 +1844,26 @@ def _create_seq_mock_data_a2c( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) - mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 td = TensorDict( batch_size=(batch, T), source={ - "observation": obs * mask.to(obs.dtype), - "next": {"observation": next_obs * mask.to(obs.dtype)}, + "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "next": { + "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0) + }, "done": done, "mask": mask, - "reward": reward * 
mask.to(obs.dtype), - "action": action * mask.to(obs.dtype), - "sample_log_prob": torch.randn_like(action[..., :1]) - / 10 - * mask.to(obs.dtype), - "loc": params_mean * mask.to(obs.dtype), - "scale": params_scale * mask.to(obs.dtype), + "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), + "sample_log_prob": torch.randn_like(action[..., 1]).masked_fill_( + ~mask, 0.0 + ) + / 10, + "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0), + "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0), }, device=device, ) diff --git a/test/test_env.py b/test/test_env.py index c4379ec203d..a1e6a725e34 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -422,6 +422,9 @@ def test_multitask(self): env2 = DMControlEnv("humanoid", "walk") env2_obs_keys = list(env2.observation_spec.keys()) + assert len(env1_obs_keys) + assert len(env2_obs_keys) + def env1_maker(): return TransformedEnv( DMControlEnv("humanoid", "stand"), @@ -449,6 +452,7 @@ def env2_maker(): ) env = ParallelEnv(2, [env1_maker, env2_maker]) + # env = SerialEnv(2, [env1_maker, env2_maker]) assert not env._single_task td = env.rollout(10, return_contiguous=False) @@ -497,7 +501,7 @@ def test_parallel_env( td1 = env_parallel.step(td) td_reset = TensorDict( - source={"reset_workers": torch.zeros(N, 1, dtype=torch.bool).bernoulli_()}, + source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()}, batch_size=[ N, ], @@ -581,7 +585,7 @@ def test_parallel_env_with_policy( td1 = env_parallel.step(td) td_reset = TensorDict( - source={"reset_workers": torch.zeros(N, 1, dtype=torch.bool).bernoulli_()}, + source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()}, batch_size=[ N, ], diff --git a/test/test_libs.py b/test/test_libs.py index c16774de4c2..de3e8c665c4 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -342,7 +342,13 @@ def test_habitat(self, envname): @pytest.mark.skipif(not _has_jumanji, reason="jumanji not installed") -@pytest.mark.parametrize("envname", ["Snake-6x6-v0", "TSP50-v0"]) +@pytest.mark.parametrize( + "envname", + [ + "TSP50-v0", + "Snake-6x6-v0", + ], +) class TestJumanji: def test_jumanji_seeding(self, envname): final_seed = [] diff --git a/test/test_postprocs.py b/test/test_postprocs.py index d50b74bb08f..d684793670d 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -97,8 +97,8 @@ def create_fake_trajs( num_workers=32, traj_len=200, ): - traj_ids = torch.arange(num_workers).unsqueeze(-1) - steps_count = torch.zeros(num_workers).unsqueeze(-1) + traj_ids = torch.arange(num_workers) + steps_count = torch.zeros(num_workers) workers = torch.arange(num_workers) out = [] @@ -108,10 +108,10 @@ def create_fake_trajs( td = TensorDict( source={ "traj_ids": traj_ids, - "a": traj_ids.clone(), + "a": traj_ids.clone().unsqueeze(-1), "steps_count": steps_count, "workers": workers, - "done": done, + "done": done.unsqueeze(-1), }, batch_size=[num_workers], ) @@ -125,15 +125,7 @@ def create_fake_trajs( return out @pytest.mark.parametrize("num_workers", range(3, 34, 3)) - @pytest.mark.parametrize( - "traj_len", - [ - 10, - 17, - 50, - 97, - ], - ) + @pytest.mark.parametrize("traj_len", [10, 17, 50, 97]) def test_splits(self, num_workers, traj_len): trajs = TestSplits.create_fake_trajs(num_workers, traj_len) diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index 0b64c6df52d..7a8a455615d 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -53,7 +53,7 @@ def test_discrete(cls): r = 
ts.rand() ts.to_numpy(r) ts.encode(torch.tensor([5])) - ts.encode(torch.tensor([5]).numpy()) + ts.encode(torch.tensor(5).numpy()) ts.encode(9) with pytest.raises(AssertionError): ts.encode(torch.tensor([11])) # out of bounds @@ -887,9 +887,8 @@ def test_categorical_action_spec_rand(self): sample = action_spec.rand((10000,)) - sample_list = sample[:, 0] + sample_list = sample sample_list = [sum(sample_list == i).item() for i in range(10)] - print(sample_list) assert chisquare(sample_list).pvalue > 0.1 sample = action_spec.to_numpy(sample) diff --git a/test/test_trainer.py b/test/test_trainer.py index bd0c8a8ea59..87919df592d 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -773,7 +773,7 @@ def test_masking(): ) td_out = trainer._process_batch_hook(td) assert td_out.shape[0] == td.get("mask").sum() - assert (td["tensor"][td["mask"].squeeze(-1)] == td_out["tensor"]).all() + assert (td["tensor"][td["mask"]] == td_out["tensor"]).all() class TestSubSampler: diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 48b80e30a58..f5881e05f3c 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -432,7 +432,8 @@ def __init__( self._tensordict = env.reset() self._tensordict.set( - "step_count", torch.zeros(*self.env.batch_size, 1, dtype=torch.int) + "step_count", + torch.zeros(self.env.batch_size, dtype=torch.int, device=env.device), ) if ( @@ -468,14 +469,21 @@ def __init__( ) # in addition to outputs of the policy, we add traj_ids and step_count to # _tensordict_out which will be collected during rollout - if len(self.env.batch_size): - traj_ids = torch.zeros(*self._tensordict_out.batch_size, 1) - else: - traj_ids = torch.zeros(*self._tensordict_out.batch_size, 1, 1) - - self._tensordict_out.set("traj_ids", traj_ids) self._tensordict_out.set( - "step_count", torch.zeros(*self._tensordict_out.batch_size, 1) + "traj_ids", + torch.zeros( + *self._tensordict_out.batch_size, + dtype=torch.int64, + device=self.env_device, + ), + ) + self._tensordict_out.set( + "step_count", + torch.zeros( + *self._tensordict_out.batch_size, + dtype=torch.int64, + device=self.env_device, + ), ) self.return_in_place = return_in_place @@ -589,7 +597,7 @@ def _reset_if_necessary(self) -> None: if not self.reset_when_done: done = torch.zeros_like(done) steps = self._tensordict.get("step_count") - done_or_terminated = done | (steps == self.max_frames_per_traj) + done_or_terminated = done.squeeze(-1) | (steps == self.max_frames_per_traj) if self._has_been_done is None: self._has_been_done = done_or_terminated else: @@ -604,7 +612,7 @@ def _reset_if_necessary(self) -> None: traj_ids = self._tensordict.get("traj_ids").clone() steps = steps.clone() if len(self.env.batch_size): - self._tensordict.masked_fill_(done_or_terminated.squeeze(-1), 0) + self._tensordict.masked_fill_(done_or_terminated, 0) self._tensordict.set("reset_workers", done_or_terminated) else: self._tensordict.zero_() @@ -620,8 +628,8 @@ def _reset_if_necessary(self) -> None: 1, done_or_terminated.sum() + 1, device=traj_ids.device ) steps[done_or_terminated] = 0 - self._tensordict.set("traj_ids", traj_ids) # no ops if they already match - self._tensordict.set("step_count", steps) + self._tensordict.set_("traj_ids", traj_ids) # no ops if they already match + self._tensordict.set_("step_count", steps) @torch.no_grad() def rollout(self) -> TensorDictBase: @@ -636,7 +644,7 @@ def rollout(self) -> TensorDictBase: self._tensordict.fill_("step_count", 0) n = self.env.batch_size[0] if 
len(self.env.batch_size) else 1 - self._tensordict.set("traj_ids", torch.arange(n).unsqueeze(-1)) + self._tensordict.set("traj_ids", torch.arange(n).view(self.env.batch_size[:1])) tensordict_out = [] with set_exploration_mode(self.exploration_mode): @@ -673,7 +681,6 @@ def reset(self, index=None, **kwargs) -> None: raise RuntimeError("resetting unique env with index is not permitted.") reset_workers = torch.zeros( *self.env.batch_size, - 1, dtype=torch.bool, device=self.env.device, ) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 9edbb81c0e5..2f3df6c3a5e 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -41,7 +41,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: splits = traj_ids.view(-1) splits = [(splits == i).sum().item() for i in splits.unique_consecutive()] # if all splits are identical then we can skip this function - if len(set(splits)) == 1 and splits[0] == traj_ids.shape[1]: + if len(set(splits)) == 1 and splits[0] == traj_ids.shape[-1]: rollout_tensordict.set( "mask", torch.ones( @@ -63,16 +63,19 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: dones = out_splits["done"] valid_ids = list(range(len(dones))) out_splits = {key: [_out[i] for i in valid_ids] for key, _out in out_splits.items()} - mask = [torch.ones_like(_out, dtype=torch.bool) for _out in out_splits["done"]] + mask = [ + torch.ones_like(_out[..., 0], dtype=torch.bool) for _out in out_splits["done"] + ] out_splits["mask"] = mask out_dict = { key: torch.nn.utils.rnn.pad_sequence(_o, batch_first=True) for key, _o in out_splits.items() } + out_dict["mask"] = out_dict["mask"] td = TensorDict( source=out_dict, device=rollout_tensordict.device, - batch_size=out_dict["mask"].shape[:-1], + batch_size=out_dict["mask"].shape, ) td = td.unflatten_keys(sep) if (out_dict["done"].sum(1) > 1).any(): diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index 6d6ed79d4b0..6463fc02d72 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -170,7 +170,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: done = tensordict.get("done") if "mask" in tensordict.keys(): - mask = tensordict.get("mask") + mask = tensordict.get("mask").view_as(done) else: mask = done.clone().flip(1).cumsum(1).flip(1).to(torch.bool) reward = tensordict.get("reward") diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 71a20e522ef..a46af2881e4 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -222,10 +222,24 @@ def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: ): val = val.copy() val = torch.tensor(val, dtype=self.dtype, device=self.device) + if val.shape[-len(self.shape) :] != self.shape: + # option 1: add a singleton dim at the end + if self.shape == torch.Size([1]): + val = val.unsqueeze(-1) + else: + raise RuntimeError( + f"Shape mismatch: the value has shape {val.shape} which " + f"is incompatible with the spec shape {self.shape}." + ) if not _NO_CHECK_SPEC_ENCODE: self.assert_is_in(val) return val + def __setattr__(self, key, value): + if key == "shape": + value = torch.Size(value) + super().__setattr__(key, value) + def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray: """Returns the np.ndarray correspondent of an input tensor. 
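
As a side note on the encode() change above: with the new shape check in place, a scalar fed to a spec whose shape is torch.Size([1]) is silently unsqueezed, while any other trailing-shape mismatch raises. A minimal sketch, assuming this patch is applied (spec bounds and values are illustrative only):

    import torch
    from torchrl.data import NdBoundedTensorSpec

    # Spec with an explicit trailing singleton dimension.
    spec = NdBoundedTensorSpec(minimum=-1.0, maximum=1.0, shape=torch.Size([1]))
    val = spec.encode(0.5)           # 0-d input is unsqueezed to match the spec
    assert val.shape == torch.Size([1])

    # For any other spec shape, an irreconcilable trailing shape raises instead:
    spec2 = NdBoundedTensorSpec(minimum=-1.0, maximum=1.0, shape=torch.Size([2]))
    # spec2.encode(torch.zeros(3))   # would raise RuntimeError: shape mismatch
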
@@ -433,7 +447,9 @@ def rand(self, shape=None) -> torch.Tensor: return out else: interval = self.space.maximum - self.space.minimum - r = torch.rand(*shape, *interval.shape, device=interval.device) + r = torch.rand( + torch.Size([*shape, *interval.shape]), device=interval.device + ) r = interval * r r = self.space.minimum + r r = r.to(self.dtype).to(self.device) @@ -519,7 +535,7 @@ def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) return torch.nn.functional.gumbel_softmax( - torch.rand(*shape, self.space.n, device=self.device), + torch.rand(torch.Size([*shape, self.space.n]), device=self.device), hard=True, dim=-1, ).to(torch.long) @@ -652,7 +668,7 @@ def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) interval = self.space.maximum - self.space.minimum - r = torch.rand(*shape, *interval.shape, device=interval.device) + r = torch.rand(torch.Size([*shape, *interval.shape]), device=interval.device) r = r * interval r = self.space.minimum + r r = r.to(self.dtype) @@ -1021,7 +1037,7 @@ def __init__( dtype: Optional[Union[str, torch.dtype]] = torch.long, ): if shape is None: - shape = torch.Size((1,)) + shape = torch.Size([]) dtype, device = _default_dtype_and_device(dtype, device) space = DiscreteBox(n) super().__init__(shape, space, device, dtype, domain="discrete") @@ -1056,7 +1072,7 @@ def __eq__(self, other): ) def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: - return super().to_numpy(val, safe).squeeze(-1) + return super().to_numpy(val, safe) class CompositeSpec(TensorSpec): diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index fef7249fa38..b93ccdb48f3 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -214,11 +214,11 @@ def __init__( if "is_closed" not in self.__dir__(): self.is_closed = True if "_input_spec" not in self.__dir__(): - self._input_spec = None + self.__dict__["_input_spec"] = None if "_reward_spec" not in self.__dir__(): - self._reward_spec = None + self.__dict__["_reward_spec"] = None if "_observation_spec" not in self.__dir__(): - self._observation_spec = None + self.__dict__["_observation_spec"] = None if batch_size is not None: # we want an error to be raised if we pass batch_size but # it's already been set @@ -240,6 +240,14 @@ def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs): cls._device = None return super().__new__(cls) + def __setattr__(self, key, value): + if key in ("_input_spec", "_observation_spec", "_action_spec", "_reward_spec"): + raise AttributeError( + "To set an environment spec, please use `env.observation_spec = obs_spec` (without the leading" + " underscore)." + ) + return super().__setattr__(key, value) + @property def batch_locked(self) -> bool: """Whether the environnement can be used with a batch size different from the one it was initialized with or not. 
@@ -268,9 +276,9 @@ def action_spec(self) -> TensorSpec: @action_spec.setter def action_spec(self, value: TensorSpec) -> None: if self._input_spec is None: - self._input_spec = CompositeSpec(action=value) + self.input_spec = CompositeSpec(action=value) else: - self._input_spec["action"] = value + self.input_spec["action"] = value @property def input_spec(self) -> TensorSpec: @@ -278,7 +286,9 @@ def input_spec(self) -> TensorSpec: @input_spec.setter def input_spec(self, value: TensorSpec) -> None: - self._input_spec = value + if not isinstance(value, CompositeSpec): + raise TypeError("The type of an input_spec must be Composite.") + self.__dict__["_input_spec"] = value @property def reward_spec(self) -> TensorSpec: @@ -286,7 +296,18 @@ def reward_spec(self) -> TensorSpec: @reward_spec.setter def reward_spec(self, value: TensorSpec) -> None: - self._reward_spec = value + if not hasattr(value, "shape"): + raise TypeError( + f"reward_spec of type {type(value)} do not have a shape " f"attribute." + ) + if len(value.shape) == 0: + raise RuntimeError( + "the reward_spec shape cannot be empty (this error" + " usually comes from trying to set a reward_spec" + " with a null number of dimensions. Try using a multidimensional" + " spec instead, for instance with a singleton dimension at the tail)." + ) + self.__dict__["_reward_spec"] = value @property def observation_spec(self) -> TensorSpec: @@ -294,7 +315,9 @@ def observation_spec(self) -> TensorSpec: @observation_spec.setter def observation_spec(self, value: TensorSpec) -> None: - self._observation_spec = value + if not isinstance(value, CompositeSpec): + raise TypeError("The type of an observation_spec must be Composite.") + self.__dict__["_observation_spec"] = value def step(self, tensordict: TensorDictBase) -> TensorDictBase: """Makes a step in the environment. 
@@ -320,7 +343,31 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: obs_keys = set(self.observation_spec.keys()) tensordict_out_select = tensordict_out.select(*obs_keys) tensordict_out = tensordict_out.exclude(*obs_keys) - tensordict_out["next"] = tensordict_out_select + tensordict_out.set("next", tensordict_out_select) + + reward = tensordict_out.get("reward") + # unsqueeze rewards if needed + expected_reward_shape = torch.Size( + [*tensordict_out.batch_size, *self.reward_spec.shape] + ) + n = len(expected_reward_shape) + if len(reward.shape) >= n and reward.shape[-n:] != expected_reward_shape: + reward = reward.view(*reward.shape[:n], *expected_reward_shape) + tensordict_out.set("reward", reward) + elif len(reward.shape) < n: + reward = reward.view(expected_reward_shape) + tensordict_out.set("reward", reward) + + done = tensordict_out.get("done") + # unsqueeze done if needed + expected_done_shape = torch.Size([*tensordict_out.batch_size, 1]) + n = len(expected_done_shape) + if len(done.shape) >= n and done.shape[-n:] != expected_done_shape: + done = done.view(*done.shape[:n], *expected_done_shape) + tensordict_out.set("done", done) + elif len(done.shape) < n: + done = done.view(expected_done_shape) + tensordict_out.set("done", done) if tensordict_out is tensordict: raise RuntimeError( @@ -382,6 +429,15 @@ def reset( """ tensordict_reset = self._reset(tensordict, **kwargs) + + done = tensordict_reset.get("done", None) + if done is not None: + # unsqueeze done if needed + expected_done_shape = torch.Size([*tensordict_reset.batch_size, 1]) + if done.shape != expected_done_shape: + done = done.view(expected_done_shape) + tensordict_reset.set("done", done) + if tensordict_reset.device != self.device: tensordict_reset = tensordict_reset.to(self.device) if tensordict_reset is tensordict: diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 9021a4f590b..bf914708080 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -204,7 +204,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: reward = self.read_reward(reward, _reward) - # TODO: check how to deal with np arrays + if isinstance(done, bool) or ( + isinstance(done, np.ndarray) and not len(done) + ): + done = torch.tensor([done], device=self.device) + done, do_break = self.read_done(done) if do_break: break diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 47dc69ba2a5..7ab3253ba8c 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -117,22 +117,25 @@ def _make_state_spec(self, env: "brax.envs.env.Env"): return state_spec def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 - self._input_spec = CompositeSpec( + self.input_spec = CompositeSpec( action=NdBoundedTensorSpec( minimum=-1, maximum=1, shape=(env.action_size,), device=self.device ) ) - self._reward_spec = NdUnboundedContinuousTensorSpec( - shape=(), device=self.device + self.reward_spec = NdUnboundedContinuousTensorSpec( + shape=[ + 1, + ], + device=self.device, ) - self._observation_spec = CompositeSpec( + self.observation_spec = CompositeSpec( observation=NdUnboundedContinuousTensorSpec( shape=(env.observation_size,), device=self.device ) ) # extract state spec from instance - self._state_spec = self._make_state_spec(env) - self._input_spec["state"] = self._state_spec + self.state_spec = self._make_state_spec(env) + self.input_spec["state"] = self.state_spec def _make_state_example(self): key = jax.random.PRNGKey(0) diff --git 
a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index a3a08158c21..fbe73e828d3 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -57,23 +57,29 @@ def _dmcontrol_to_torchrl_spec_transform( elif isinstance(spec, dm_env.specs.BoundedArray): if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] + shape = spec.shape + if not len(shape): + shape = torch.Size([1]) return NdBoundedTensorSpec( - shape=spec.shape, + shape=shape, minimum=spec.minimum, maximum=spec.maximum, dtype=dtype, device=device, ) elif isinstance(spec, dm_env.specs.Array): + shape = spec.shape + if not len(shape): + shape = torch.Size([1]) if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] if dtype in (torch.float, torch.double, torch.half): return NdUnboundedContinuousTensorSpec( - shape=spec.shape, dtype=dtype, device=device + shape=shape, dtype=dtype, device=device ) else: return NdUnboundedDiscreteTensorSpec( - shape=spec.shape, dtype=dtype, device=device + shape=shape, dtype=dtype, device=device ) else: @@ -213,7 +219,7 @@ def _output_transform( @property def input_spec(self) -> TensorSpec: if self._input_spec is None: - self._input_spec = CompositeSpec( + self.__dict__["_input_spec"] = CompositeSpec( action=_dmcontrol_to_torchrl_spec_transform( self._env.action_spec(), device=self.device ) @@ -222,31 +228,49 @@ def input_spec(self) -> TensorSpec: @input_spec.setter def input_spec(self, value: TensorSpec) -> None: - self._input_spec = value + if not isinstance(value, CompositeSpec): + raise TypeError("The type of an input_spec must be Composite.") + self.__dict__["_input_spec"] = value @property def observation_spec(self) -> TensorSpec: if self._observation_spec is None: - self._observation_spec = _dmcontrol_to_torchrl_spec_transform( + self.__dict__["_observation_spec"] = _dmcontrol_to_torchrl_spec_transform( self._env.observation_spec(), device=self.device ) return self._observation_spec @observation_spec.setter def observation_spec(self, value: TensorSpec) -> None: - self._observation_spec = value + if not isinstance(value, CompositeSpec): + raise TypeError("The type of an observation_spec must be Composite.") + self.__dict__["_observation_spec"] = value @property def reward_spec(self) -> TensorSpec: if self._reward_spec is None: - self._reward_spec = _dmcontrol_to_torchrl_spec_transform( + reward_spec = _dmcontrol_to_torchrl_spec_transform( self._env.reward_spec(), device=self.device ) + if len(reward_spec.shape) == 0: + reward_spec.shape = torch.Size([1]) + self.__dict__["_reward_spec"] = reward_spec return self._reward_spec @reward_spec.setter def reward_spec(self, value: TensorSpec) -> None: - self._reward_spec = value + if not hasattr(value, "shape"): + raise TypeError( + f"reward_spec of type {type(value)} do not have a shape " f"attribute." + ) + if len(value.shape) == 0: + raise RuntimeError( + "the reward_spec shape cannot be empty (this error" + " usually comes from trying to set a reward_spec" + " with a null number of dimensions. Try using a multidimensional" + " spec instead, for instance with a singleton dimension at the tail)." 
+ ) + self.__dict__["_reward_spec"] = value def __repr__(self) -> str: return ( @@ -344,12 +368,5 @@ def _check_kwargs(self, kwargs: Dict): else: raise TypeError("dm_control requires env_name to be specified") - # def _set_seed(self, _seed: int) -> int: - # self._env = self._build_env( - # _seed=_seed, **self._constructor_kwargs - # ) - # self.reset() - # return _seed - def __repr__(self) -> str: return f"{self.__class__.__name__}(env={self.env_name}, task={self.task_name}, batch_size={self.batch_size})" diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 309ed3b066a..abecc288c66 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -15,9 +15,9 @@ DiscreteTensorSpec, MultOneHotDiscreteTensorSpec, NdBoundedTensorSpec, + NdUnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, - UnboundedContinuousTensorSpec, ) from ..._utils import implement_for @@ -70,12 +70,15 @@ def _gym_to_torchrl_spec_transform( elif isinstance(spec, gym.spaces.multi_discrete.MultiDiscrete): return MultOneHotDiscreteTensorSpec(spec.nvec, device=device) elif isinstance(spec, gym.spaces.Box): + shape = spec.shape + if not len(shape): + shape = torch.Size([1]) if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] return NdBoundedTensorSpec( torch.tensor(spec.low, device=device, dtype=dtype), torch.tensor(spec.high, device=device, dtype=dtype), - torch.Size(spec.shape), + shape, dtype=dtype, device=device, ) @@ -257,17 +260,19 @@ def _make_specs(self, env: "gym.Env") -> None: device=self.device, categorical_action_encoding=self._categorical_action_encoding, ) - self.observation_spec = _gym_to_torchrl_spec_transform( + observation_spec = _gym_to_torchrl_spec_transform( env.observation_space, device=self.device, categorical_action_encoding=self._categorical_action_encoding, ) - if not isinstance(self.observation_spec, CompositeSpec): + if not isinstance(observation_spec, CompositeSpec): if self.from_pixels: - self.observation_spec = CompositeSpec(pixels=self.observation_spec) + observation_spec = CompositeSpec(pixels=observation_spec) else: - self.observation_spec = CompositeSpec(observation=self.observation_spec) - self.reward_spec = UnboundedContinuousTensorSpec( + observation_spec = CompositeSpec(observation=observation_spec) + self.observation_spec = observation_spec + self.reward_spec = NdUnboundedContinuousTensorSpec( + shape=[1], device=self.device, ) diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 2e59e452b19..67dceafbaec 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -59,25 +59,27 @@ def _jumanji_to_torchrl_spec_transform( dtype = numpy_to_torch_dtype_dict[spec.dtype] return action_space_cls(spec.num_values, dtype=dtype, device=device) elif isinstance(spec, jumanji.specs.BoundedArray): + shape = spec.shape if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] return NdBoundedTensorSpec( - shape=spec.shape, + shape=shape, minimum=np.asarray(spec.minimum), maximum=np.asarray(spec.maximum), dtype=dtype, device=device, ) elif isinstance(spec, jumanji.specs.Array): + shape = spec.shape if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] if dtype in (torch.float, torch.double, torch.half): return NdUnboundedContinuousTensorSpec( - shape=spec.shape, dtype=dtype, device=device + shape=shape, dtype=dtype, device=device ) else: return NdUnboundedDiscreteTensorSpec( - shape=spec.shape, dtype=dtype, device=device + shape=shape, dtype=dtype, device=device ) elif isinstance(spec, 
jumanji.specs.Spec) and hasattr(spec, "__dict__"): new_spec = {} @@ -188,18 +190,23 @@ def _make_observation_spec(self, env) -> TensorSpec: raise TypeError(f"Unsupported spec type {type(spec)}") def _make_reward_spec(self, env) -> TensorSpec: - return _jumanji_to_torchrl_spec_transform(env.reward_spec(), device=self.device) + reward_spec = _jumanji_to_torchrl_spec_transform( + env.reward_spec(), device=self.device + ) + if not len(reward_spec.shape): + reward_spec.shape = torch.Size([1]) + return reward_spec def _make_specs(self, env: "jumanji.env.Environment") -> None: # noqa: F821 # extract spec from jumanji definition - self._input_spec = self._make_input_spec(env) - self._observation_spec = self._make_observation_spec(env) - self._reward_spec = self._make_reward_spec(env) + self.input_spec = self._make_input_spec(env) + self.observation_spec = self._make_observation_spec(env) + self.reward_spec = self._make_reward_spec(env) # extract state spec from instance - self._state_spec = self._make_state_spec(env) - self._input_spec["state"] = self._state_spec + self.state_spec = self._make_state_spec(env) + self.input_spec["state"] = self.state_spec # build state example for data conversion self._state_example = self._make_state_example(env) @@ -221,7 +228,7 @@ def _set_seed(self, seed): def read_state(self, state): state_dict = _object_to_tensordict(state, self.device, self.batch_size) - return self._state_spec.encode(state_dict) + return self.state_spec.encode(state_dict) def read_obs(self, obs): if isinstance(obs, (list, jnp.ndarray, np.ndarray)): diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 4164cc44cfd..8ef11c6f691 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -333,8 +333,9 @@ def __init__( self._last_obs = None self.cache_specs = cache_specs - self._reward_spec = None - self._observation_spec = None + self.__dict__["_reward_spec"] = None + self.__dict__["_input_spec"] = None + self.__dict__["_observation_spec"] = None self.batch_size = self.base_env.batch_size def _set_env(self, env: EnvBase, device) -> None: @@ -395,7 +396,7 @@ def observation_spec(self) -> TensorSpec: deepcopy(self.base_env.observation_spec) ) if self.cache_specs: - self._observation_spec = observation_spec + self.__dict__["_observation_spec"] = observation_spec else: observation_spec = self._observation_spec return observation_spec @@ -413,7 +414,7 @@ def input_spec(self) -> TensorSpec: deepcopy(self.base_env.input_spec) ) if self.cache_specs: - self._input_spec = input_spec + self.__dict__["_input_spec"] = input_spec else: input_spec = self._input_spec return input_spec @@ -426,7 +427,7 @@ def reward_spec(self) -> TensorSpec: deepcopy(self.base_env.reward_spec) ) if self.cache_specs: - self._reward_spec = reward_spec + self.__dict__["_reward_spec"] = reward_spec else: reward_spec = self._reward_spec return reward_spec @@ -490,9 +491,9 @@ def close(self): self.is_closed = True def empty_cache(self): - self._observation_spec = None - self._input_spec = None - self._reward_spec = None + self.__dict__["_observation_spec"] = None + self.__dict__["_input_spec"] = None + self.__dict__["_reward_spec"] = None def append_transform(self, transform: Transform) -> None: self._erase_metadata() @@ -550,18 +551,18 @@ def __repr__(self) -> str: def _erase_metadata(self): if self.cache_specs: - self._input_spec = None - self._observation_spec = None - self._reward_spec = None + self.__dict__["_input_spec"] = None + 
self.__dict__["_observation_spec"] = None + self.__dict__["_reward_spec"] = None def to(self, device: DEVICE_TYPING) -> TransformedEnv: self.base_env.to(device) self.transform.to(device) if self.cache_specs: - self._input_spec = None - self._observation_spec = None - self._reward_spec = None + self.__dict__["_input_spec"] = None + self.__dict__["_observation_spec"] = None + self.__dict__["_reward_spec"] = None return self def __setattr__(self, key, value): diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index a601b417172..06f9821eb8a 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -195,8 +195,9 @@ def __init__( "memmap and shared memory are mutually exclusive features." ) self._batch_size = None - self._observation_spec = None - self._reward_spec = None + self.__dict__["_observation_spec"] = None + self.__dict__["_input_spec"] = None + self.__dict__["_reward_spec"] = None self._device = None self._dummy_env_str = None self._seeds = None @@ -276,9 +277,9 @@ def _set_properties(self): meta_data = deepcopy(self.meta_data) if self._single_task: self._batch_size = meta_data.batch_size - self._observation_spec = meta_data.specs["observation_spec"] - self._reward_spec = meta_data.specs["reward_spec"] - self._input_spec = meta_data.specs["input_spec"] + self.observation_spec = meta_data.specs["observation_spec"] + self.reward_spec = meta_data.specs["reward_spec"] + self.input_spec = meta_data.specs["input_spec"] self._dummy_env_str = meta_data.env_str self._device = meta_data.device self._env_tensordict = meta_data.tensordict @@ -287,15 +288,15 @@ def _set_properties(self): self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size]) self._device = meta_data[0].device # TODO: check that all action_spec and reward spec match (issue #351) - self._reward_spec = meta_data[0].specs["reward_spec"] + self.reward_spec = meta_data[0].specs["reward_spec"] _observation_spec = {} for md in meta_data: _observation_spec.update(dict(**md.specs["observation_spec"])) - self._observation_spec = CompositeSpec(**_observation_spec) + self.observation_spec = CompositeSpec(**_observation_spec) _input_spec = {} for md in meta_data: _input_spec.update(dict(**md.specs["input_spec"])) - self._input_spec = CompositeSpec(**_input_spec) + self.input_spec = CompositeSpec(**_input_spec) self._dummy_env_str = str(meta_data[0]) self._env_tensordict = torch.stack( [meta_data.tensordict for meta_data in meta_data], 0 @@ -334,7 +335,9 @@ def observation_spec(self) -> TensorSpec: @observation_spec.setter def observation_spec(self, value: TensorSpec) -> None: - self._observation_spec = value + if not isinstance(value, CompositeSpec) and value is not None: + raise TypeError("The type of an observation_spec must be Composite.") + self.__dict__["_observation_spec"] = value @property def input_spec(self) -> TensorSpec: @@ -344,7 +347,9 @@ def input_spec(self) -> TensorSpec: @input_spec.setter def input_spec(self, value: TensorSpec) -> None: - self._input_spec = value + if not isinstance(value, CompositeSpec) and value is not None: + raise TypeError("The type of an input_spec must be Composite.") + self.__dict__["_input_spec"] = value @property def reward_spec(self) -> TensorSpec: @@ -354,7 +359,18 @@ def reward_spec(self) -> TensorSpec: @reward_spec.setter def reward_spec(self, value: TensorSpec) -> None: - self._reward_spec = value + if not hasattr(value, "shape") and value is not None: + raise TypeError( + f"reward_spec of type {type(value)} do not have a shape " f"attribute." 
+ ) + if value is not None and len(value.shape) == 0: + raise RuntimeError( + "the reward_spec shape cannot be empty (this error" + " usually comes from trying to set a reward_spec" + " with a null number of dimensions. Try using a multidimensional" + " spec instead, for instance with a singleton dimension at the tail)." + ) + self.__dict__["_reward_spec"] = value def _create_td(self) -> None: """Creates self.shared_tensordict_parent, a TensorDict used to store the most recent observations.""" @@ -602,7 +618,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self._assert_tensordict_shape(tensordict) reset_workers = tensordict.get("reset_workers") else: - reset_workers = torch.ones(self.num_workers, 1, dtype=torch.bool) + reset_workers = torch.ones(self.num_workers, dtype=torch.bool) keys = set() for i, _env in enumerate(self._envs): @@ -818,7 +834,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self._assert_tensordict_shape(tensordict) reset_workers = tensordict.get("reset_workers") else: - reset_workers = torch.ones(self.num_workers, 1, dtype=torch.bool) + reset_workers = torch.ones(self.num_workers, dtype=torch.bool) for i, channel in enumerate(self.parent_channels): if not reset_workers[i]: diff --git a/torchrl/modules/models/recipes/impala.py b/torchrl/modules/models/recipes/impala.py index 6dedfb42e77..67d35635637 100644 --- a/torchrl/modules/models/recipes/impala.py +++ b/torchrl/modules/models/recipes/impala.py @@ -177,7 +177,7 @@ def forward(self, tensordict: TensorDictBase): # noqa: D102 x = tensordict.get(self.observation_key) done = tensordict.get("done").squeeze(-1) reward = tensordict.get("reward").squeeze(-1) - mask = tensordict.get("mask").squeeze(-1) + mask = tensordict.get("mask") core_state = ( tensordict.get("core_state") if "core_state" in tensordict.keys() else None ) diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index b80cf2854ff..69d09fc3fe0 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -218,6 +218,12 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: next_td, selected_q_params, ) + state_action_value = next_td.get("state_action_value") + if ( + state_action_value.shape[-len(sample_log_prob.shape) :] + != sample_log_prob.shape + ): + sample_log_prob = sample_log_prob.unsqueeze(-1) state_value = ( next_td.get("state_action_value") - self.alpha * sample_log_prob ) diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 8d839078154..c42c8e29387 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -73,7 +73,7 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: tensordict.get(("next", "prior_std")), tensordict.get(("next", "posterior_mean")), tensordict.get(("next", "posterior_std")), - ) + ).unsqueeze(-1) reco_loss = distance_loss( tensordict.get(("next", "pixels")), tensordict.get(("next", "reco_pixels")), @@ -81,7 +81,7 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: ) if not self.global_average: reco_loss = reco_loss.sum((-3, -2, -1)) - reco_loss = reco_loss.mean() + reco_loss = reco_loss.mean().unsqueeze(-1) reward_loss = distance_loss( tensordict.get("true_reward"), @@ -90,7 +90,8 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: ) if not self.global_average: reward_loss = reward_loss.squeeze(-1) - reward_loss = reward_loss.mean() + reward_loss = reward_loss.mean().unsqueeze(-1) + # import ipdb; ipdb.set_trace() return ( 
TensorDict( { diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 26b73e70612..0ab4924e271 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -105,13 +105,12 @@ def _log_weight( dist = self.actor.get_dist(tensordict_clone, params=self.actor_params) log_prob = dist.log_prob(action) - log_prob = log_prob.unsqueeze(-1) prev_log_prob = tensordict.get("sample_log_prob") if prev_log_prob.requires_grad: raise RuntimeError("tensordict prev_log_prob requires grad.") - log_weight = log_prob - prev_log_prob + log_weight = (log_prob - prev_log_prob).unsqueeze(-1) return log_weight, dist def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 994bb1cab7c..c5d31e19f5a 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -177,23 +177,29 @@ def device(self) -> torch.device: ) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + shape = None if tensordict.ndimension() > 1: - tensordict = tensordict.view(-1) + shape = tensordict.shape + tensordict_reshape = tensordict.reshape(-1) + else: + tensordict_reshape = tensordict device = self.device - td_device = tensordict.to(device) + td_device = tensordict_reshape.to(device) loss_actor = self._loss_actor(td_device) loss_qvalue, priority = self._loss_qvalue(td_device) loss_value = self._loss_value(td_device) loss_alpha = self._loss_alpha(td_device) - tensordict.set(self.priority_key, priority) + tensordict_reshape.set(self.priority_key, priority) if (loss_actor.shape != loss_qvalue.shape) or ( loss_actor.shape != loss_value.shape ): raise RuntimeError( f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}" ) + if shape: + tensordict.update(tensordict_reshape.view(shape)) return TensorDict( { "loss_actor": loss_actor.mean(), diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 3294c2b87fa..947e3e9a613 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -169,7 +169,6 @@ def make_dqn_actor( # automatically infer in key (in_key,) = itertools.islice(env_specs["observation_spec"], 1) - out_features = action_spec.shape[0] actor_class = QValueActor actor_kwargs = {} @@ -178,6 +177,8 @@ def make_dqn_actor( # to the number of possible choices and also set categorical behavioural for actors. 
actor_kwargs.update({"action_space": "categorical"}) out_features = env_specs["action_spec"].space.n + else: + out_features = action_spec.shape[0] if cfg.distributional: if not atoms: diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 07ad793668e..da312aaf9f5 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -646,7 +646,7 @@ def __init__( def extend(self, batch: TensorDictBase) -> TensorDictBase: if self.flatten_tensordicts: if "mask" in batch.keys(): - batch = batch[batch.get("mask").squeeze(-1)] + batch = batch[batch.get("mask")] else: batch = batch.reshape(-1) else: @@ -810,9 +810,7 @@ def __init__(self, logname="r_training", log_pbar: bool = False): def __call__(self, batch: TensorDictBase) -> Dict: if "mask" in batch.keys(): return { - self.logname: batch.get("reward")[batch.get("mask").squeeze(-1)] - .mean() - .item(), + self.logname: batch.get("reward")[batch.get("mask")].mean().item(), "log_pbar": self.log_pbar, } return { @@ -857,7 +855,7 @@ def __init__( def update_reward_stats(self, batch: TensorDictBase) -> None: reward = batch.get("reward") if "mask" in batch.keys(): - reward = reward[batch.get("mask").squeeze(-1)] + reward = reward[batch.get("mask")] if self._update_has_been_called and not self._normalize_has_been_called: # We'd like to check that rewards are normalized. Problem is that the trainer can collect data without calling steps... # raise RuntimeError( @@ -935,7 +933,7 @@ def mask_batch(batch: TensorDictBase) -> TensorDictBase: """ if "mask" in batch.keys(): mask = batch.get("mask") - return batch[mask.squeeze(-1)] + return batch[mask] return batch @@ -997,7 +995,7 @@ def __call__(self, batch: TensorDictBase) -> TensorDictBase: if "mask" in batch.keys(): # if a valid mask is present, it's important to sample only # valid steps - traj_len = batch.get("mask").sum(1).squeeze() + traj_len = batch.get("mask").sum(-1) sub_traj_len = max( self.min_sub_traj_len, min(sub_traj_len, traj_len.min().int().item()), diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 78fe2973492..91965a9f091 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -643,7 +643,7 @@ def make_replay_buffer(make_replay_buffer=3): if "mask" in tensordict.keys(): # if multi-step, a mask is present to help filter padded values current_frames = tensordict["mask"].sum() - tensordict = tensordict[tensordict.get("mask").squeeze(-1)] + tensordict = tensordict[tensordict.get("mask")] else: tensordict = tensordict.view(-1) current_frames = tensordict.numel() diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 78003cbd9b2..c5b728f440c 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -365,7 +365,7 @@ def make_model(): pbar = tqdm.tqdm(total=total_frames) for j, data in enumerate(data_collector): # trajectories are padded to be stored in the same tensordict: since we do not care about consecutive step, we'll just mask the tensordict and get the flattened representation instead. 
- mask = data["mask"].squeeze(-1) + mask = data["mask"] current_frames = mask.sum().cpu().item() pbar.update(current_frames) @@ -602,7 +602,7 @@ def make_model(): pbar = tqdm.tqdm(total=total_frames) for j, data in enumerate(data_collector): - mask = data["mask"].squeeze(-1) + mask = data["mask"] data = pad(data, [0, 0, 0, max_size - data.shape[1]]) current_frames = mask.sum().cpu().item() pbar.update(current_frames)
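
Taken together, the tutorial changes above reflect the mask convention adopted across this patch: "mask" (like "step_count", "traj_ids" and "reset_workers") now matches the tensordict batch size with no trailing singleton dimension, so it can be used directly for boolean indexing. A rough, self-contained sketch; the data is made up and only the shapes matter:

    import torch
    from tensordict.tensordict import TensorDict

    batch, T = 4, 10
    data = TensorDict(
        source={
            "observation": torch.randn(batch, T, 3),
            "reward": torch.randn(batch, T, 1),               # reward keeps its trailing singleton
            "mask": torch.ones(batch, T, dtype=torch.bool),   # mask matches the batch size exactly
        },
        batch_size=[batch, T],
    )

    # Boolean indexing no longer needs a squeeze(-1):
    flat = data[data.get("mask")]                 # TensorDict with batch_size [batch * T] (all entries valid here)
    masked_reward = data["reward"][data["mask"]]  # tensor of shape [batch * T, 1]
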