From 4f0ba72f6a4b8c948c8d1f16b3488cb764ae2cae Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 20 Dec 2022 13:24:47 +0000 Subject: [PATCH 1/2] init --- torchrl/objectives/value/advantages.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index d8d7c128857..ca8f30b5a75 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -46,6 +46,8 @@ class TDEstimate(nn.Module): Defaults to "advantage". value_target_key (str or tuple of str, optional): the key of the advantage entry. Defaults to "value_target". + value_key (str or tuple of str, optional): the value key to read from the input tensordict. + Defaults to "state_value". """ @@ -57,6 +59,7 @@ def __init__( differentiable: bool = False, advantage_key: Union[str, Tuple] = "advantage", value_target_key: Union[str, Tuple] = "value_target", + value_key: Union[str, Tuple] = "state_value", ): super().__init__() try: @@ -68,7 +71,9 @@ def __init__( self.average_rewards = average_rewards self.differentiable = differentiable - self.value_key = value_network.out_keys[0] + self.value_key = value_key + if value_key not in value_network.out_keys: + raise KeyError(f"value key '{value_key}' not found in value network out_keys.") self.advantage_key = advantage_key self.value_target_key = value_target_key @@ -209,6 +214,8 @@ class TDLambdaEstimate(nn.Module): Defaults to "advantage". value_target_key (str or tuple of str, optional): the key of the advantage entry. Defaults to "value_target". + value_key (str or tuple of str, optional): the value key to read from the input tensordict. + Defaults to "state_value". """ @@ -222,6 +229,7 @@ def __init__( vectorized: bool = True, advantage_key: Union[str, Tuple] = "advantage", value_target_key: Union[str, Tuple] = "value_target", + value_key: Union[str, Tuple] = "state_value", ): super().__init__() try: @@ -235,7 +243,9 @@ def __init__( self.average_rewards = average_rewards self.differentiable = differentiable - self.value_key = self.value_network.out_keys[0] + self.value_key = value_key + if value_key not in value_network.out_keys: + raise KeyError(f"value key '{value_key}' not found in value network out_keys.") self.advantage_key = advantage_key self.value_target_key = value_target_key @@ -389,6 +399,8 @@ class GAE(nn.Module): Defaults to "advantage". value_target_key (str or tuple of str, optional): the key of the advantage entry. Defaults to "value_target". + value_key (str or tuple of str, optional): the value key to read from the input tensordict. + Defaults to "state_value". GAE will return an :obj:`"advantage"` entry containing the advange value. It will also return a :obj:`"value_target"` entry with the return value that is to be used @@ -408,6 +420,7 @@ def __init__( differentiable: bool = False, advantage_key: Union[str, Tuple] = "advantage", value_target_key: Union[str, Tuple] = "value_target", + value_key: Union[str, Tuple] = "state_value", ): super().__init__() try: @@ -417,7 +430,9 @@ def __init__( self.register_buffer("gamma", torch.tensor(gamma, device=device)) self.register_buffer("lmbda", torch.tensor(lmbda, device=device)) self.value_network = value_network - self.value_key = self.value_network.out_keys[0] + self.value_key = value_key + if value_key not in value_network.out_keys: + raise KeyError(f"value key '{value_key}' not found in value network out_keys.") self.average_gae = average_gae self.differentiable = differentiable From 56b0913bd951911b4221a9a799d26082677b1a0e Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 20 Dec 2022 13:28:48 +0000 Subject: [PATCH 2/2] lint --- torchrl/objectives/value/advantages.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index ca8f30b5a75..bd73683ef0a 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -73,7 +73,9 @@ def __init__( self.differentiable = differentiable self.value_key = value_key if value_key not in value_network.out_keys: - raise KeyError(f"value key '{value_key}' not found in value network out_keys.") + raise KeyError( + f"value key '{value_key}' not found in value network out_keys." + ) self.advantage_key = advantage_key self.value_target_key = value_target_key @@ -245,7 +247,9 @@ def __init__( self.differentiable = differentiable self.value_key = value_key if value_key not in value_network.out_keys: - raise KeyError(f"value key '{value_key}' not found in value network out_keys.") + raise KeyError( + f"value key '{value_key}' not found in value network out_keys." + ) self.advantage_key = advantage_key self.value_target_key = value_target_key @@ -432,7 +436,9 @@ def __init__( self.value_network = value_network self.value_key = value_key if value_key not in value_network.out_keys: - raise KeyError(f"value key '{value_key}' not found in value network out_keys.") + raise KeyError( + f"value key '{value_key}' not found in value network out_keys." + ) self.average_gae = average_gae self.differentiable = differentiable