@@ -46,6 +46,8 @@ class TDEstimate(nn.Module):
4646 Defaults to "advantage".
4747 value_target_key (str or tuple of str, optional): the key of the advantage entry.
4848 Defaults to "value_target".
49+ value_key (str or tuple of str, optional): the value key to read from the input tensordict.
50+ Defaults to "state_value".
4951
5052 """
5153
@@ -57,6 +59,7 @@ def __init__(
5759 differentiable : bool = False ,
5860 advantage_key : Union [str , Tuple ] = "advantage" ,
5961 value_target_key : Union [str , Tuple ] = "value_target" ,
62+ value_key : Union [str , Tuple ] = "state_value" ,
6063 ):
6164 super ().__init__ ()
6265 try :
@@ -68,7 +71,11 @@ def __init__(
6871
6972 self .average_rewards = average_rewards
7073 self .differentiable = differentiable
71- self .value_key = value_network .out_keys [0 ]
74+ self .value_key = value_key
75+ if value_key not in value_network .out_keys :
76+ raise KeyError (
77+ f"value key '{ value_key } ' not found in value network out_keys."
78+ )
7279
7380 self .advantage_key = advantage_key
7481 self .value_target_key = value_target_key
@@ -209,6 +216,8 @@ class TDLambdaEstimate(nn.Module):
209216 Defaults to "advantage".
210217 value_target_key (str or tuple of str, optional): the key of the advantage entry.
211218 Defaults to "value_target".
219+ value_key (str or tuple of str, optional): the value key to read from the input tensordict.
220+ Defaults to "state_value".
212221
213222 """
214223
@@ -222,6 +231,7 @@ def __init__(
222231 vectorized : bool = True ,
223232 advantage_key : Union [str , Tuple ] = "advantage" ,
224233 value_target_key : Union [str , Tuple ] = "value_target" ,
234+ value_key : Union [str , Tuple ] = "state_value" ,
225235 ):
226236 super ().__init__ ()
227237 try :
@@ -235,7 +245,11 @@ def __init__(
235245
236246 self .average_rewards = average_rewards
237247 self .differentiable = differentiable
238- self .value_key = self .value_network .out_keys [0 ]
248+ self .value_key = value_key
249+ if value_key not in value_network .out_keys :
250+ raise KeyError (
251+ f"value key '{ value_key } ' not found in value network out_keys."
252+ )
239253
240254 self .advantage_key = advantage_key
241255 self .value_target_key = value_target_key
@@ -389,6 +403,8 @@ class GAE(nn.Module):
389403 Defaults to "advantage".
390404 value_target_key (str or tuple of str, optional): the key of the advantage entry.
391405 Defaults to "value_target".
406+ value_key (str or tuple of str, optional): the value key to read from the input tensordict.
407+ Defaults to "state_value".
392408
393409 GAE will return an :obj:`"advantage"` entry containing the advange value. It will also
394410 return a :obj:`"value_target"` entry with the return value that is to be used
@@ -408,6 +424,7 @@ def __init__(
408424 differentiable : bool = False ,
409425 advantage_key : Union [str , Tuple ] = "advantage" ,
410426 value_target_key : Union [str , Tuple ] = "value_target" ,
427+ value_key : Union [str , Tuple ] = "state_value" ,
411428 ):
412429 super ().__init__ ()
413430 try :
@@ -417,7 +434,11 @@ def __init__(
417434 self .register_buffer ("gamma" , torch .tensor (gamma , device = device ))
418435 self .register_buffer ("lmbda" , torch .tensor (lmbda , device = device ))
419436 self .value_network = value_network
420- self .value_key = self .value_network .out_keys [0 ]
437+ self .value_key = value_key
438+ if value_key not in value_network .out_keys :
439+ raise KeyError (
440+ f"value key '{ value_key } ' not found in value network out_keys."
441+ )
421442
422443 self .average_gae = average_gae
423444 self .differentiable = differentiable
0 commit comments