diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index be856e59298..114d9f168fd 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -216,6 +216,7 @@ in the environment. The keys to be included in this inverse transform are passed
     VecNorm
     gSDENoise
     TensorDictPrimer
+    RewardSum
     R3MTransform
     VIPTransform
     VIPRewardTransform
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 090d0837585..8d3e40f7550 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -42,6 +42,7 @@
     Resize,
     RewardClipping,
     RewardScaling,
+    RewardSum,
     SerialEnv,
     StepCounter,
     ToTensorImage,
@@ -803,6 +804,64 @@ def test_grayscale(self, keys, device):
         for key in keys:
             assert observation_spec[key].shape == torch.Size([1, 16, 16])
 
+    @pytest.mark.parametrize(
+        "keys",
+        [["done", "reward"]],
+    )
+    @pytest.mark.parametrize("device", get_available_devices())
+    def test_sum_reward(self, keys, device):
+        torch.manual_seed(0)
+        batch = 4
+        rs = RewardSum()
+        td = TensorDict(
+            {
+                "done": torch.zeros((batch, 1), dtype=torch.bool),
+                "reward": torch.rand((batch, 1)),
+            },
+            device=device,
+            batch_size=[batch],
+        )
+
+        # apply one time, episode_reward should be equal to reward
+        td = rs(td)
+        assert "episode_reward" in td.keys()
+        assert (td.get("episode_reward") == td.get("reward")).all()
+
+        # apply a second time, episode_reward should be twice the reward
+        td = rs(td)
+        assert (td.get("episode_reward") == 2 * td.get("reward")).all()
+
+        # reset environments
+        td.set("reset_workers", torch.ones((batch, 1), dtype=torch.bool, device=device))
+        rs.reset(td)
+
+        # apply a third time, episode_reward should be equal to reward again
+        td = rs(td)
+        assert (td.get("episode_reward") == td.get("reward")).all()
+
+        # test transform_observation_spec
+        base_env = ContinuousActionVecMockEnv(
+            reward_spec=UnboundedContinuousTensorSpec(shape=(3, 16, 16)),
+        )
+        transformed_env = TransformedEnv(base_env, RewardSum())
+        transformed_observation_spec1 = transformed_env.specs["observation_spec"]
+        assert isinstance(transformed_observation_spec1, CompositeSpec)
+        assert "episode_reward" in transformed_observation_spec1.keys()
+        assert "observation" in transformed_observation_spec1.keys()
+
+        base_env = ContinuousActionVecMockEnv(
+            reward_spec=UnboundedContinuousTensorSpec(),
+            observation_spec=CompositeSpec(
+                observation=UnboundedContinuousTensorSpec(),
+                some_extra_observation=UnboundedContinuousTensorSpec(),
+            ),
+        )
+        transformed_env = TransformedEnv(base_env, RewardSum())
+        transformed_observation_spec2 = transformed_env.specs["observation_spec"]
+        assert isinstance(transformed_observation_spec2, CompositeSpec)
+        assert "some_extra_observation" in transformed_observation_spec2.keys()
+        assert "episode_reward" in transformed_observation_spec2.keys()
+
     @pytest.mark.parametrize("batch", [[], [1], [3, 2]])
     @pytest.mark.parametrize(
         "keys",
diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py
index 1bc51074197..f8b6630b703 100644
--- a/torchrl/envs/__init__.py
+++ b/torchrl/envs/__init__.py
@@ -26,6 +26,7 @@
     Resize,
     RewardClipping,
     RewardScaling,
+    RewardSum,
     StepCounter,
     TensorDictPrimer,
     ToTensorImage,
diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py
index 61ee5c20d7d..b10a9efdc56 100644
--- a/torchrl/envs/transforms/__init__.py
+++ b/torchrl/envs/transforms/__init__.py
@@ -23,6 +23,7 @@
     Resize,
     RewardClipping,
     RewardScaling,
+    RewardSum,
     SqueezeTransform,
     StepCounter,
     TensorDictPrimer,
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 11587053295..3dc630b7f31 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -2479,6 +2479,124 @@ def __repr__(self) -> str:
         )
 
 
+class RewardSum(Transform):
+    """Tracks episode cumulative rewards.
+
+    This transform accepts a list of tensordict reward keys (i.e. `in_keys`) and tracks their cumulative
+    value along each episode. When called, the transform creates a new tensordict key for each in_key named
+    `episode_{in_key}` where the cumulative values are written. All `in_keys` should be part of the env
+    reward and be present in the env reward_spec.
+
+    If no `in_keys` are specified, this transform assumes `reward` to be the input key. However, multiple rewards
+    (e.g. `reward1` and `reward2`) can also be specified. If the `in_keys` are not present in the provided
+    tensordict, this transform has no effect.
+    """
+
+    inplace = True
+
+    def __init__(
+        self,
+        in_keys: Optional[Sequence[str]] = None,
+        out_keys: Optional[Sequence[str]] = None,
+    ):
+        """Initialises the transform. Filters out non-reward input keys and defines output keys."""
+        if in_keys is None:
+            in_keys = ["reward"]
+        if out_keys is None:
+            out_keys = [f"episode_{in_key}" for in_key in in_keys]
+
+        super().__init__(in_keys=in_keys, out_keys=out_keys)
+
+    def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
+        """Resets episode rewards."""
+        # Non-batched environments
+        if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1:
+            for out_key in self.out_keys:
+                if out_key in tensordict.keys():
+                    tensordict[out_key] = 0.0
+
+        # Batched environments
+        else:
+            reset_workers = tensordict.get(
+                "reset_workers",
+                torch.ones(
+                    *tensordict.batch_size,
+                    1,
+                    dtype=torch.bool,
+                    device=tensordict.device,
+                ),
+            )
+            for out_key in self.out_keys:
+                if out_key in tensordict.keys():
+                    tensordict[out_key][reset_workers] = 0.0
+
+        return tensordict
+
+    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        """Updates the episode rewards with the step rewards."""
+        # Sanity checks
+        self._check_inplace()
+        for in_key in self.in_keys:
+            if in_key not in tensordict.keys():
+                return tensordict
+
+        # Update episode rewards
+        for in_key, out_key in zip(self.in_keys, self.out_keys):
+            reward = tensordict.get(in_key)
+            if out_key not in tensordict.keys():
+                tensordict.set(
+                    out_key,
+                    torch.zeros(
+                        *tensordict.shape, 1, dtype=reward.dtype, device=reward.device
+                    ),
+                )
+            tensordict[out_key] += reward
+
+        return tensordict
+
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        """Transforms the observation spec, adding the new keys generated by RewardSum."""
+        # Retrieve parent reward spec
+        reward_spec = self.parent.specs["reward_spec"]
+
+        episode_specs = {}
+        if isinstance(reward_spec, CompositeSpec):
+            # If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
+            if not all(k in reward_spec.keys() for k in self.in_keys):
+                raise KeyError("Not all in_keys are present in `reward_spec`")
+
+            # Define episode specs for all out_keys
+            for out_key in self.out_keys:
+                episode_spec = UnboundedContinuousTensorSpec(
+                    shape=reward_spec.shape,
+                    device=reward_spec.device,
+                    dtype=reward_spec.dtype,
+                )
+                episode_specs.update({out_key: episode_spec})
+
+        else:
+            # If reward_spec is not a CompositeSpec, the only in_key should be `reward`
+            if set(self.in_keys) != {"reward"}:
+                raise KeyError(
+                    "reward_spec is not a CompositeSpec class, in_keys should only include `reward`"
+                )
+
+            # Define episode spec
+            episode_spec = UnboundedContinuousTensorSpec(
+                device=reward_spec.device,
+                dtype=reward_spec.dtype,
+                shape=reward_spec.shape,
+            )
+            episode_specs.update({"episode_reward": episode_spec})
+
+        # Update observation_spec with episode_specs
+        if not isinstance(observation_spec, CompositeSpec):
+            observation_spec = CompositeSpec(observation=observation_spec)
+        observation_spec.update(episode_specs)
+        return observation_spec
+
+
 class StepCounter(Transform):
     """Counts the steps from a reset and sets the done state to True after a certain number of steps.
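
For illustration, a minimal standalone usage sketch of the new transform, mirroring `test_sum_reward` above. This is not part of the diff, and the `torchrl.data` import path for `TensorDict` is an assumption that may differ across torchrl versions:

import torch

from torchrl.data import TensorDict  # assumption: some versions use `from tensordict import TensorDict`
from torchrl.envs.transforms import RewardSum

batch = 4
rs = RewardSum()  # defaults: in_keys=["reward"], out_keys=["episode_reward"]

td = TensorDict(
    {
        "done": torch.zeros((batch, 1), dtype=torch.bool),
        "reward": torch.ones((batch, 1)),
    },
    batch_size=[batch],
)

td = rs(td)  # creates "episode_reward", equal to "reward"
td = rs(td)  # "episode_reward" now holds twice the step reward
assert (td["episode_reward"] == 2 * td["reward"]).all()

# Flag every worker for reset; RewardSum.reset zeroes their cumulative rewards.
td.set("reset_workers", torch.ones((batch, 1), dtype=torch.bool))
rs.reset(td)
assert (td["episode_reward"] == 0).all()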