From 70d9e6096871296b850b7c3b1d12cb3aeb1a4cb2 Mon Sep 17 00:00:00 2001
From: Tom Begley
Date: Wed, 21 Dec 2022 11:37:31 +0000
Subject: [PATCH 1/4] Fix test_cost

---
 torchrl/envs/common.py                | 12 ++++++++++--
 torchrl/envs/transforms/transforms.py |  6 +++---
 torchrl/objectives/deprecated.py      |  3 +++
 torchrl/objectives/dreamer.py         |  7 ++++---
 4 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 6eeadf5c733..a37a4752743 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -351,14 +351,22 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
         reward = tensordict_out.get("reward")
         # unsqueeze rewards if needed
         expected_reward_shape = torch.Size([*tensordict_out.batch_size, *self.reward_spec.shape])
-        if reward.shape != expected_reward_shape:
+        n = len(expected_reward_shape)
+        if len(reward.shape) >= n and reward.shape[-n:] != expected_reward_shape:
+            reward = reward.view(*reward.shape[:n], *expected_reward_shape)
+            tensordict_out.set("reward", reward)
+        elif len(reward.shape) < n:
             reward = reward.view(expected_reward_shape)
             tensordict_out.set("reward", reward)
 
         done = tensordict_out.get("done")
         # unsqueeze done if needed
         expected_done_shape = torch.Size([*tensordict_out.batch_size, 1])
-        if done.shape != expected_done_shape:
+        n = len(expected_done_shape)
+        if len(done.shape) >= n and done.shape[-n:] != expected_done_shape:
+            done = done.view(*done.shape[:n], *expected_done_shape)
+            tensordict_out.set("done", done)
+        elif len(done.shape) < n:
             done = done.view(expected_done_shape)
             tensordict_out.set("done", done)
 
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 569a3ac550d..c9612abde2b 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -561,9 +561,9 @@ def __repr__(self) -> str:
 
     def _erase_metadata(self):
         if self.cache_specs:
-            self._input_spec = None
-            self._observation_spec = None
-            self._reward_spec = None
+            self.__dict__["_input_spec"] = None
+            self.__dict__["_observation_spec"] = None
+            self.__dict__["_reward_spec"] = None
 
     def to(self, device: DEVICE_TYPING) -> TransformedEnv:
         self.base_env.to(device)
diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py
index b80cf2854ff..38a284079cb 100644
--- a/torchrl/objectives/deprecated.py
+++ b/torchrl/objectives/deprecated.py
@@ -218,6 +218,9 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor:
                 next_td,
                 selected_q_params,
             )
+            state_action_value = next_td.get("state_action_value")
+            if state_action_value.shape[-len(sample_log_prob.shape):] != sample_log_prob.shape:
+                sample_log_prob = sample_log_prob.unsqueeze(-1)
             state_value = (
                 next_td.get("state_action_value") - self.alpha * sample_log_prob
             )
diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py
index 8d839078154..c42c8e29387 100644
--- a/torchrl/objectives/dreamer.py
+++ b/torchrl/objectives/dreamer.py
@@ -73,7 +73,7 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor:
             tensordict.get(("next", "prior_std")),
             tensordict.get(("next", "posterior_mean")),
             tensordict.get(("next", "posterior_std")),
-        )
+        ).unsqueeze(-1)
         reco_loss = distance_loss(
             tensordict.get(("next", "pixels")),
             tensordict.get(("next", "reco_pixels")),
@@ -81,7 +81,7 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor:
         )
         if not self.global_average:
             reco_loss = reco_loss.sum((-3, -2, -1))
-        reco_loss = reco_loss.mean()
+        reco_loss = reco_loss.mean().unsqueeze(-1)
 
         reward_loss = distance_loss(
             tensordict.get("true_reward"),
@@ -90,7 +90,8 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor:
         )
         if not self.global_average:
             reward_loss = reward_loss.squeeze(-1)
-        reward_loss = reward_loss.mean()
+        reward_loss = reward_loss.mean().unsqueeze(-1)
+        # import ipdb; ipdb.set_trace()
         return (
             TensorDict(
                 {
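For context on the `EnvBase.step` change in PATCH 1/4, here is a minimal standalone sketch (not part of the patch; the shapes and tensors are made up for illustration) of the common case the new logic targets: a reward that comes back with fewer dimensions than `batch_size + reward_spec.shape` is viewed into the expected shape, so downstream code can rely on the trailing reward dimension being present.

```python
import torch

# Illustrative shapes: a batch of 4 envs whose reward_spec has shape (1,).
batch_size = torch.Size([4])
reward_spec_shape = torch.Size([1])
expected_reward_shape = torch.Size([*batch_size, *reward_spec_shape])  # [4, 1]

# The env returned a per-env scalar reward of shape [4], one dim short.
reward = torch.randn(4)
n = len(expected_reward_shape)
if len(reward.shape) < n:
    # mirrors the `elif` branch added in the patch
    reward = reward.view(expected_reward_shape)

print(reward.shape)  # torch.Size([4, 1])
```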
From ce4f4d23db3775a82f567d52c6610a19d4b8a72f Mon Sep 17 00:00:00 2001
From: Tom Begley
Date: Wed, 21 Dec 2022 11:38:53 +0000
Subject: [PATCH 2/4] Lint and format

---
 torchrl/data/tensor_specs.py     | 8 +++++---
 torchrl/envs/common.py           | 4 +++-
 torchrl/envs/gym_like.py         | 4 +++-
 torchrl/envs/libs/gym.py         | 2 +-
 torchrl/objectives/deprecated.py | 5 ++++-
 5 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index debe7e43a1c..debcb7f1623 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -222,13 +222,15 @@ def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
             ):
                 val = val.copy()
             val = torch.tensor(val, dtype=self.dtype, device=self.device)
-            if val.shape[-len(self.shape):] != self.shape:
+            if val.shape[-len(self.shape) :] != self.shape:
                 # option 1: add a singleton dim at the end
                 if self.shape == torch.Size([1]):
                     val = val.unsqueeze(-1)
                 else:
-                    raise RuntimeError(f"Shape mismatch: the value has shape {val.shape} which "
-                                       f"is incompatible with the spec shape {self.shape}.")
+                    raise RuntimeError(
+                        f"Shape mismatch: the value has shape {val.shape} which "
+                        f"is incompatible with the spec shape {self.shape}."
+                    )
         if not _NO_CHECK_SPEC_ENCODE:
             self.assert_is_in(val)
         return val
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index a37a4752743..920b7e3ad7b 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -350,7 +350,9 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
 
         reward = tensordict_out.get("reward")
         # unsqueeze rewards if needed
-        expected_reward_shape = torch.Size([*tensordict_out.batch_size, *self.reward_spec.shape])
+        expected_reward_shape = torch.Size(
+            [*tensordict_out.batch_size, *self.reward_spec.shape]
+        )
         n = len(expected_reward_shape)
         if len(reward.shape) >= n and reward.shape[-n:] != expected_reward_shape:
             reward = reward.view(*reward.shape[:n], *expected_reward_shape)
diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py
index 5b117be37c6..f3fd531be44 100644
--- a/torchrl/envs/gym_like.py
+++ b/torchrl/envs/gym_like.py
@@ -204,7 +204,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
 
         reward = self.read_reward(reward, _reward)
 
-        if isinstance(done, bool) or (isinstance(done, np.ndarray) and not len(done)):
+        if isinstance(done, bool) or (
+            isinstance(done, np.ndarray) and not len(done)
+        ):
             done = torch.tensor([done], device=self.device)
 
         done, do_break = self.read_done(done)
diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py
index 2ef53b188a6..64e9c756e1d 100644
--- a/torchrl/envs/libs/gym.py
+++ b/torchrl/envs/libs/gym.py
@@ -15,9 +15,9 @@
     DiscreteTensorSpec,
     MultOneHotDiscreteTensorSpec,
     NdBoundedTensorSpec,
+    NdUnboundedContinuousTensorSpec,
     OneHotDiscreteTensorSpec,
     TensorSpec,
-    UnboundedContinuousTensorSpec, NdUnboundedContinuousTensorSpec,
 )
 
 from ..._utils import implement_for
diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py
index 38a284079cb..69d09fc3fe0 100644
--- a/torchrl/objectives/deprecated.py
+++ b/torchrl/objectives/deprecated.py
@@ -219,7 +219,10 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor:
                 selected_q_params,
             )
             state_action_value = next_td.get("state_action_value")
-            if state_action_value.shape[-len(sample_log_prob.shape):] != sample_log_prob.shape:
+            if (
+                state_action_value.shape[-len(sample_log_prob.shape) :]
+                != sample_log_prob.shape
+            ):
                 sample_log_prob = sample_log_prob.unsqueeze(-1)
             state_value = (
                 next_td.get("state_action_value") - self.alpha * sample_log_prob

From 84b5f37acebc49d2a9fad2414fdb581c78e62d3f Mon Sep 17 00:00:00 2001
From: Tom Begley
Date: Wed, 21 Dec 2022 11:55:51 +0000
Subject: [PATCH 3/4] Fix test_postprocs

---
 test/test_postprocs.py      | 14 +++-----------
 torchrl/collectors/utils.py |  2 +-
 2 files changed, 4 insertions(+), 12 deletions(-)

diff --git a/test/test_postprocs.py b/test/test_postprocs.py
index d50b74bb08f..38cbf806f68 100644
--- a/test/test_postprocs.py
+++ b/test/test_postprocs.py
@@ -97,8 +97,8 @@ def create_fake_trajs(
         num_workers=32,
         traj_len=200,
     ):
-        traj_ids = torch.arange(num_workers).unsqueeze(-1)
-        steps_count = torch.zeros(num_workers).unsqueeze(-1)
+        traj_ids = torch.arange(num_workers)
+        steps_count = torch.zeros(num_workers)
         workers = torch.arange(num_workers)
 
         out = []
@@ -125,15 +125,7 @@ def create_fake_trajs(
         return out
 
     @pytest.mark.parametrize("num_workers", range(3, 34, 3))
-    @pytest.mark.parametrize(
-        "traj_len",
-        [
-            10,
-            17,
-            50,
-            97,
-        ],
-    )
+    @pytest.mark.parametrize("traj_len", [10, 17, 50, 97])
     def test_splits(self, num_workers, traj_len):
         trajs = TestSplits.create_fake_trajs(num_workers, traj_len)
 
diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py
index 9edbb81c0e5..8a7c360b37b 100644
--- a/torchrl/collectors/utils.py
+++ b/torchrl/collectors/utils.py
@@ -72,7 +72,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase:
     td = TensorDict(
         source=out_dict,
         device=rollout_tensordict.device,
-        batch_size=out_dict["mask"].shape[:-1],
+        batch_size=out_dict["mask"].shape,
     )
     td = td.unflatten_keys(sep)
     if (out_dict["done"].sum(1) > 1).any():
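The condition reformatted in PATCH 2/4 above (`state_action_value.shape[-len(sample_log_prob.shape) :] != sample_log_prob.shape` in `deprecated.py`) guards against a silent broadcasting bug when the Q-value carries a trailing singleton dimension that the log-probability lacks. A rough standalone sketch with made-up shapes, not taken from the patch:

```python
import torch

alpha = 0.2
state_action_value = torch.randn(4, 1)  # Q-network output: [*batch, 1]
sample_log_prob = torch.randn(4)        # log-prob of the sampled action: [*batch]

# Without the extra dim, [4, 1] - [4] would broadcast to [4, 4].
if state_action_value.shape[-sample_log_prob.ndim :] != sample_log_prob.shape:
    sample_log_prob = sample_log_prob.unsqueeze(-1)  # [4] -> [4, 1]

state_value = state_action_value - alpha * sample_log_prob
print(state_value.shape)  # torch.Size([4, 1])
```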
From d61f7a53fec5717fe12e35dc414470c933c2d58e Mon Sep 17 00:00:00 2001
From: Tom Begley
Date: Wed, 21 Dec 2022 16:53:40 +0000
Subject: [PATCH 4/4] Fix test_collectors

---
 torchrl/collectors/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py
index 8a7c360b37b..9ea8c67e831 100644
--- a/torchrl/collectors/utils.py
+++ b/torchrl/collectors/utils.py
@@ -69,10 +69,11 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase:
         key: torch.nn.utils.rnn.pad_sequence(_o, batch_first=True)
         for key, _o in out_splits.items()
     }
+    out_dict["mask"] = out_dict["mask"].squeeze(-1)
     td = TensorDict(
         source=out_dict,
         device=rollout_tensordict.device,
-        batch_size=out_dict["mask"].shape,
+        batch_size=out_dict["mask"].squeeze(-1).shape,
     )
     td = td.unflatten_keys(sep)
    if (out_dict["done"].sum(1) > 1).any():
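For the last two patches, the relevant shape behaviour of `pad_sequence` and the mask squeeze in `split_trajectories` can be seen in isolation with a small sketch (the trajectory lengths below are made up; this is an illustration, not part of the patch series):

```python
import torch

# Two fake trajectory masks of different lengths, each shaped [T, 1] as they
# come out of the rollout (done/mask entries carry a trailing singleton dim).
masks = [torch.ones(3, 1, dtype=torch.bool), torch.ones(5, 1, dtype=torch.bool)]

# pad_sequence stacks them into [num_trajs, max_T, 1] ...
padded = torch.nn.utils.rnn.pad_sequence(masks, batch_first=True)
print(padded.shape)  # torch.Size([2, 5, 1])

# ... and squeezing the trailing singleton gives the [num_trajs, max_T] shape
# that split_trajectories now uses as the batch_size of the returned TensorDict.
print(padded.squeeze(-1).shape)  # torch.Size([2, 5])
```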