Merged
Commits (35)
ac1f620  sumreward transform and tests (albertbou92, Dec 16, 2022)
a10e138  test fix (albertbou92, Dec 16, 2022)
d87b909  pre-commit ok (albertbou92, Dec 16, 2022)
7fb7ae4  tests ok (albertbou92, Dec 16, 2022)
839d9c1  pre-commit ok (albertbou92, Dec 16, 2022)
de2bb10  dont use done flag (albertbou92, Dec 20, 2022)
53e49e0  dont use done flag (albertbou92, Dec 20, 2022)
c0ed80c  format (albertbou92, Dec 20, 2022)
a25cecd  dont use done flag (albertbou92, Dec 20, 2022)
98f0e4c  format (albertbou92, Dec 20, 2022)
ad51d80  fix tests (albertbou92, Dec 20, 2022)
daa028b  minor fix (albertbou92, Dec 21, 2022)
9702041  suggested changes (albertbou92, Dec 22, 2022)
d89acce  format (albertbou92, Dec 22, 2022)
756138d  format (albertbou92, Dec 22, 2022)
36ded1b  format (albertbou92, Dec 22, 2022)
981b73d  fix (albertbou92, Dec 22, 2022)
13bb979  fix (albertbou92, Dec 22, 2022)
808bb0b  fix (albertbou92, Dec 22, 2022)
e6a67d6  fix (albertbou92, Dec 22, 2022)
d259f3e  transform obs spec (albertbou92, Dec 23, 2022)
de25e32  format (albertbou92, Dec 23, 2022)
b49add8  transform obs spec (albertbou92, Dec 23, 2022)
d307485  tests (albertbou92, Dec 23, 2022)
265fc2b  minor change (albertbou92, Dec 23, 2022)
f10dcf7  docs (albertbou92, Dec 23, 2022)
ad8e9bc  docs (albertbou92, Dec 23, 2022)
b192006  review suggested fixes (albertbou92, Dec 23, 2022)
8be9600  format (albertbou92, Dec 23, 2022)
185ba73  Merge branch 'main' into sumreward_transform (albertbou92, Jan 2, 2023)
6aebe0e  fix (albertbou92, Jan 2, 2023)
2d4150a  fix (albertbou92, Jan 2, 2023)
4e98a0c  fix (albertbou92, Jan 2, 2023)
ab0ad7b  format (albertbou92, Jan 2, 2023)
0902247  fix (albertbou92, Jan 2, 2023)
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
@@ -216,6 +216,7 @@ in the environment. The keys to be included in this inverse transform are passed
VecNorm
gSDENoise
TensorDictPrimer
RewardSum
R3MTransform
VIPTransform
VIPRewardTransform
59 changes: 59 additions & 0 deletions test/test_transforms.py
@@ -42,6 +42,7 @@
Resize,
RewardClipping,
RewardScaling,
RewardSum,
SerialEnv,
StepCounter,
ToTensorImage,
@@ -803,6 +804,64 @@ def test_grayscale(self, keys, device):
for key in keys:
assert observation_spec[key].shape == torch.Size([1, 16, 16])

@pytest.mark.parametrize(
"keys",
[["done", "reward"]],
)
@pytest.mark.parametrize("device", get_available_devices())
def test_sum_reward(self, keys, device):
torch.manual_seed(0)
batch = 4
rs = RewardSum()
td = TensorDict(
{
"done": torch.zeros((batch, 1), dtype=torch.bool),
"reward": torch.rand((batch, 1)),
},
device=device,
batch_size=[batch],
)

        # apply one time; episode_reward should be equal to the reward
td = rs(td)
assert "episode_reward" in td.keys()
assert (td.get("episode_reward") == td.get("reward")).all()

        # apply a second time; episode_reward should be twice the reward
td = rs(td)
assert (td.get("episode_reward") == 2 * td.get("reward")).all()

# reset environments
td.set("reset_workers", torch.ones((batch, 1), dtype=torch.bool, device=device))
rs.reset(td)

# apply a third time, episode_reward should be equal to reward again
td = rs(td)
assert (td.get("episode_reward") == td.get("reward")).all()

# test transform_observation_spec
base_env = ContinuousActionVecMockEnv(
reward_spec=UnboundedContinuousTensorSpec(shape=(3, 16, 16)),
)
        transformed_env = TransformedEnv(base_env, RewardSum())
        transformed_observation_spec1 = transformed_env.specs["observation_spec"]
assert isinstance(transformed_observation_spec1, CompositeSpec)
assert "episode_reward" in transformed_observation_spec1.keys()
assert "observation" in transformed_observation_spec1.keys()

base_env = ContinuousActionVecMockEnv(
reward_spec=UnboundedContinuousTensorSpec(),
observation_spec=CompositeSpec(
observation=UnboundedContinuousTensorSpec(),
some_extra_observation=UnboundedContinuousTensorSpec(),
),
)
        transformed_env = TransformedEnv(base_env, RewardSum())
        transformed_observation_spec2 = transformed_env.specs["observation_spec"]
assert isinstance(transformed_observation_spec2, CompositeSpec)
assert "some_extra_observation" in transformed_observation_spec2.keys()
assert "episode_reward" in transformed_observation_spec2.keys()

@pytest.mark.parametrize("batch", [[], [1], [3, 2]])
@pytest.mark.parametrize(
"keys",
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
@@ -26,6 +26,7 @@
Resize,
RewardClipping,
RewardScaling,
RewardSum,
StepCounter,
TensorDictPrimer,
ToTensorImage,
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
@@ -23,6 +23,7 @@
Resize,
RewardClipping,
RewardScaling,
RewardSum,
SqueezeTransform,
StepCounter,
TensorDictPrimer,
118 changes: 118 additions & 0 deletions torchrl/envs/transforms/transforms.py
@@ -2479,6 +2479,124 @@ def __repr__(self) -> str:
)


class RewardSum(Transform):
"""Tracks episode cumulative rewards.

    This transform accepts a list of tensordict reward keys (i.e. `in_keys`) and tracks their cumulative
    value along each episode. When called, the transform creates a new tensordict key for each `in_key`, named
    `episode_{in_key}`, where the cumulative values are written. All `in_keys` should be part of the env
    reward and be present in the env `reward_spec`.

    If no `in_keys` are specified, this transform assumes `reward` to be the input key. However, multiple rewards
    (e.g. `reward1` and `reward2`) can also be specified. If the `in_keys` are not present in the provided tensordict,
    this transform has no effect.
"""

inplace = True

def __init__(
self,
in_keys: Optional[Sequence[str]] = None,
out_keys: Optional[Sequence[str]] = None,
):
"""Initialises the transform. Filters out non-reward input keys and defines output keys."""
if in_keys is None:
in_keys = ["reward"]
out_keys = [f"episode_{in_key}" for in_key in in_keys]

super().__init__(in_keys=in_keys, out_keys=out_keys)

def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Resets episode rewards."""
# Non-batched environments
if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1:
for out_key in self.out_keys:
if out_key in tensordict.keys():
tensordict[out_key] = 0.0

# Batched environments
else:
reset_workers = tensordict.get(
"reset_workers",
torch.ones(
*tensordict.batch_size,
1,
dtype=torch.bool,
device=tensordict.device,
),
)
for out_key in self.out_keys:
if out_key in tensordict.keys():
tensordict[out_key][reset_workers] = 0.0

return tensordict

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Updates the episode rewards with the step rewards."""
# Sanity checks
self._check_inplace()
for in_key in self.in_keys:
if in_key not in tensordict.keys():
return tensordict

# Update episode rewards
for in_key, out_key in zip(self.in_keys, self.out_keys):
reward = tensordict.get(in_key)
if out_key not in tensordict.keys():
tensordict.set(
out_key,
torch.zeros(
*tensordict.shape, 1, dtype=reward.dtype, device=reward.device
),
)
tensordict[out_key] += reward

return tensordict

def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
"""Transforms the observation spec, adding the new keys generated by RewardSum."""
# Retrieve parent reward spec
reward_spec = self.parent.specs["reward_spec"]

episode_specs = {}
if isinstance(reward_spec, CompositeSpec):

# If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
if not all([k in reward_spec.keys() for k in self.in_keys]):
raise KeyError("Not all in_keys are present in ´reward_spec´")

# Define episode specs for all out_keys
for out_key in self.out_keys:
episode_spec = UnboundedContinuousTensorSpec(
shape=reward_spec.shape,
device=reward_spec.device,
dtype=reward_spec.dtype,
)
episode_specs.update({out_key: episode_spec})

else:

# If reward_spec is not a CompositeSpec, the only in_key should be ´reward´
if not set(self.in_keys) == {"reward"}:
raise KeyError(
"reward_spec is not a CompositeSpec class, in_keys should only include ´reward´"
)

# Define episode spec
episode_spec = UnboundedContinuousTensorSpec(
device=reward_spec.device,
dtype=reward_spec.dtype,
shape=reward_spec.shape,
)
episode_specs.update({"episode_reward": episode_spec})

# Update observation_spec with episode_specs
if not isinstance(observation_spec, CompositeSpec):
observation_spec = CompositeSpec(observation=observation_spec)
observation_spec.update(episode_specs)
return observation_spec


class StepCounter(Transform):
"""Counts the steps from a reset and sets the done state to True after a certain number of steps.

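
For context, a minimal usage sketch of the new transform wrapped around a gym environment (an illustration, not part of the diff above): it assumes a gym backend with Pendulum-v1 available, and that the accumulated key is written at the root of the rollout tensordict, which may differ across torchrl versions.

from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

# Wrap the environment so that every step also writes "episode_reward",
# the running sum of "reward" since the last reset.
env = TransformedEnv(GymEnv("Pendulum-v1"), RewardSum())

rollout = env.rollout(max_steps=10)

# The last entry of "episode_reward" should match the sum of the step rewards
# collected along this short trajectory (no reset occurs within the 10 steps).
print(rollout["episode_reward"][-1])
print(rollout["reward"].sum())

With several reward entries in the parent reward_spec, RewardSum(in_keys=["reward1", "reward2"]) would instead track episode_reward1 and episode_reward2, provided both keys are present in that spec.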