diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index d8d7c128857..bd73683ef0a 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -46,6 +46,8 @@ class TDEstimate(nn.Module):
             Defaults to "advantage".
         value_target_key (str or tuple of str, optional): the key of the value target entry.
             Defaults to "value_target".
+        value_key (str or tuple of str, optional): the value key to read from the input tensordict.
+            Defaults to "state_value".
 
     """
 
@@ -57,6 +59,7 @@ def __init__(
         differentiable: bool = False,
         advantage_key: Union[str, Tuple] = "advantage",
         value_target_key: Union[str, Tuple] = "value_target",
+        value_key: Union[str, Tuple] = "state_value",
     ):
         super().__init__()
         try:
@@ -68,7 +71,11 @@ def __init__(
         self.average_rewards = average_rewards
         self.differentiable = differentiable
 
-        self.value_key = value_network.out_keys[0]
+        self.value_key = value_key
+        if value_key not in value_network.out_keys:
+            raise KeyError(
+                f"value key '{value_key}' not found in value network out_keys."
+            )
         self.advantage_key = advantage_key
         self.value_target_key = value_target_key
 
@@ -209,6 +216,8 @@ class TDLambdaEstimate(nn.Module):
             Defaults to "advantage".
         value_target_key (str or tuple of str, optional): the key of the value target entry.
             Defaults to "value_target".
+        value_key (str or tuple of str, optional): the value key to read from the input tensordict.
+            Defaults to "state_value".
 
     """
 
@@ -222,6 +231,7 @@ def __init__(
         vectorized: bool = True,
         advantage_key: Union[str, Tuple] = "advantage",
         value_target_key: Union[str, Tuple] = "value_target",
+        value_key: Union[str, Tuple] = "state_value",
     ):
         super().__init__()
         try:
@@ -235,7 +245,11 @@ def __init__(
         self.average_rewards = average_rewards
         self.differentiable = differentiable
 
-        self.value_key = self.value_network.out_keys[0]
+        self.value_key = value_key
+        if value_key not in value_network.out_keys:
+            raise KeyError(
+                f"value key '{value_key}' not found in value network out_keys."
+            )
         self.advantage_key = advantage_key
         self.value_target_key = value_target_key
 
@@ -389,6 +403,8 @@ class GAE(nn.Module):
             Defaults to "advantage".
         value_target_key (str or tuple of str, optional): the key of the value target entry.
             Defaults to "value_target".
+        value_key (str or tuple of str, optional): the value key to read from the input tensordict.
+            Defaults to "state_value".
 
     GAE will return an :obj:`"advantage"` entry containing the advantage value. It will also
     return a :obj:`"value_target"` entry with the return value that is to be used
@@ -408,6 +424,7 @@ def __init__(
         differentiable: bool = False,
         advantage_key: Union[str, Tuple] = "advantage",
         value_target_key: Union[str, Tuple] = "value_target",
+        value_key: Union[str, Tuple] = "state_value",
     ):
         super().__init__()
         try:
@@ -417,7 +434,11 @@ def __init__(
         self.register_buffer("gamma", torch.tensor(gamma, device=device))
         self.register_buffer("lmbda", torch.tensor(lmbda, device=device))
         self.value_network = value_network
-        self.value_key = self.value_network.out_keys[0]
+        self.value_key = value_key
+        if value_key not in value_network.out_keys:
+            raise KeyError(
+                f"value key '{value_key}' not found in value network out_keys."
+            )
        self.average_gae = average_gae
        self.differentiable = differentiable
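
Below is a minimal usage sketch (not part of the patch) of how the new `value_key` argument behaves, using `GAE` as the example since all three estimators gain the same parameter. The key name `"my_value"`, the 4-dimensional observation, and the `TensorDictModule` import path are illustrative assumptions and may need adjusting to your torchrl/tensordict version.

```python
import torch
from torch import nn
from tensordict.nn import TensorDictModule  # import path may differ across versions

from torchrl.objectives.value.advantages import GAE

# A value network that writes its estimate under a custom key instead of
# the default "state_value". The 4 -> 1 shapes are placeholders.
value_net = TensorDictModule(
    nn.Linear(4, 1),
    in_keys=["observation"],
    out_keys=["my_value"],
)

# Tell the estimator which key to read; before this patch it silently
# used value_network.out_keys[0].
gae = GAE(gamma=0.99, lmbda=0.95, value_network=value_net, value_key="my_value")

# A key the network does not produce now fails fast at construction time:
try:
    GAE(gamma=0.99, lmbda=0.95, value_network=value_net, value_key="state_value")
except KeyError as err:
    print(err)  # "value key 'state_value' not found in value network out_keys."
```

Making the key explicit matters most when the value network has several `out_keys`, where `out_keys[0]` is not guaranteed to be the state value; the eager `KeyError` turns a silent mismatch into a construction-time failure.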