Skip to content

Commit a3d3164

Browse files
authored
[BugFix] Reinstantiate custom value key for multioutput value networks (#754)
1 parent f6df86c commit a3d3164

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

torchrl/objectives/value/advantages.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class TDEstimate(nn.Module):
4646
Defaults to "advantage".
4747
value_target_key (str or tuple of str, optional): the key of the value target entry.
4848
Defaults to "value_target".
49+
value_key (str or tuple of str, optional): the value key to read from the input tensordict.
50+
Defaults to "state_value".
4951
5052
"""
5153

@@ -57,6 +59,7 @@ def __init__(
5759
differentiable: bool = False,
5860
advantage_key: Union[str, Tuple] = "advantage",
5961
value_target_key: Union[str, Tuple] = "value_target",
62+
value_key: Union[str, Tuple] = "state_value",
6063
):
6164
super().__init__()
6265
try:
@@ -68,7 +71,11 @@ def __init__(
6871

6972
self.average_rewards = average_rewards
7073
self.differentiable = differentiable
71-
self.value_key = value_network.out_keys[0]
74+
self.value_key = value_key
75+
if value_key not in value_network.out_keys:
76+
raise KeyError(
77+
f"value key '{value_key}' not found in value network out_keys."
78+
)
7279

7380
self.advantage_key = advantage_key
7481
self.value_target_key = value_target_key
@@ -209,6 +216,8 @@ class TDLambdaEstimate(nn.Module):
209216
Defaults to "advantage".
210217
value_target_key (str or tuple of str, optional): the key of the value target entry.
211218
Defaults to "value_target".
219+
value_key (str or tuple of str, optional): the value key to read from the input tensordict.
220+
Defaults to "state_value".
212221
213222
"""
214223

@@ -222,6 +231,7 @@ def __init__(
222231
vectorized: bool = True,
223232
advantage_key: Union[str, Tuple] = "advantage",
224233
value_target_key: Union[str, Tuple] = "value_target",
234+
value_key: Union[str, Tuple] = "state_value",
225235
):
226236
super().__init__()
227237
try:
@@ -235,7 +245,11 @@ def __init__(
235245

236246
self.average_rewards = average_rewards
237247
self.differentiable = differentiable
238-
self.value_key = self.value_network.out_keys[0]
248+
self.value_key = value_key
249+
if value_key not in value_network.out_keys:
250+
raise KeyError(
251+
f"value key '{value_key}' not found in value network out_keys."
252+
)
239253

240254
self.advantage_key = advantage_key
241255
self.value_target_key = value_target_key
@@ -389,6 +403,8 @@ class GAE(nn.Module):
389403
Defaults to "advantage".
390404
value_target_key (str or tuple of str, optional): the key of the value target entry.
391405
Defaults to "value_target".
406+
value_key (str or tuple of str, optional): the value key to read from the input tensordict.
407+
Defaults to "state_value".
392408
393409
GAE will return an :obj:`"advantage"` entry containing the advantage value. It will also
394410
return a :obj:`"value_target"` entry with the return value that is to be used
@@ -408,6 +424,7 @@ def __init__(
408424
differentiable: bool = False,
409425
advantage_key: Union[str, Tuple] = "advantage",
410426
value_target_key: Union[str, Tuple] = "value_target",
427+
value_key: Union[str, Tuple] = "state_value",
411428
):
412429
super().__init__()
413430
try:
@@ -417,7 +434,11 @@ def __init__(
417434
self.register_buffer("gamma", torch.tensor(gamma, device=device))
418435
self.register_buffer("lmbda", torch.tensor(lmbda, device=device))
419436
self.value_network = value_network
420-
self.value_key = self.value_network.out_keys[0]
437+
self.value_key = value_key
438+
if value_key not in value_network.out_keys:
439+
raise KeyError(
440+
f"value key '{value_key}' not found in value network out_keys."
441+
)
421442

422443
self.average_gae = average_gae
423444
self.differentiable = differentiable

0 commit comments

Comments
 (0)