Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions torchrl/objectives/value/advantages.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class TDEstimate(nn.Module):
Defaults to "advantage".
value_target_key (str or tuple of str, optional): the key of the value target entry.
Defaults to "value_target".
value_key (str or tuple of str, optional): the value key to read from the input tensordict.
Defaults to "state_value".

"""

Expand All @@ -57,6 +59,7 @@ def __init__(
differentiable: bool = False,
advantage_key: Union[str, Tuple] = "advantage",
value_target_key: Union[str, Tuple] = "value_target",
value_key: Union[str, Tuple] = "state_value",
):
super().__init__()
try:
Expand All @@ -68,7 +71,11 @@ def __init__(

self.average_rewards = average_rewards
self.differentiable = differentiable
self.value_key = value_network.out_keys[0]
self.value_key = value_key
if value_key not in value_network.out_keys:
raise KeyError(
f"value key '{value_key}' not found in value network out_keys."
)

self.advantage_key = advantage_key
self.value_target_key = value_target_key
Expand Down Expand Up @@ -209,6 +216,8 @@ class TDLambdaEstimate(nn.Module):
Defaults to "advantage".
value_target_key (str or tuple of str, optional): the key of the value target entry.
Defaults to "value_target".
value_key (str or tuple of str, optional): the value key to read from the input tensordict.
Defaults to "state_value".

"""

Expand All @@ -222,6 +231,7 @@ def __init__(
vectorized: bool = True,
advantage_key: Union[str, Tuple] = "advantage",
value_target_key: Union[str, Tuple] = "value_target",
value_key: Union[str, Tuple] = "state_value",
):
super().__init__()
try:
Expand All @@ -235,7 +245,11 @@ def __init__(

self.average_rewards = average_rewards
self.differentiable = differentiable
self.value_key = self.value_network.out_keys[0]
self.value_key = value_key
if value_key not in value_network.out_keys:
raise KeyError(
f"value key '{value_key}' not found in value network out_keys."
)

self.advantage_key = advantage_key
self.value_target_key = value_target_key
Expand Down Expand Up @@ -389,6 +403,8 @@ class GAE(nn.Module):
Defaults to "advantage".
value_target_key (str or tuple of str, optional): the key of the value target entry.
Defaults to "value_target".
value_key (str or tuple of str, optional): the value key to read from the input tensordict.
Defaults to "state_value".

GAE will return an :obj:`"advantage"` entry containing the advantage value. It will also
return a :obj:`"value_target"` entry with the return value that is to be used
Expand All @@ -408,6 +424,7 @@ def __init__(
differentiable: bool = False,
advantage_key: Union[str, Tuple] = "advantage",
value_target_key: Union[str, Tuple] = "value_target",
value_key: Union[str, Tuple] = "state_value",
):
super().__init__()
try:
Expand All @@ -417,7 +434,11 @@ def __init__(
self.register_buffer("gamma", torch.tensor(gamma, device=device))
self.register_buffer("lmbda", torch.tensor(lmbda, device=device))
self.value_network = value_network
self.value_key = self.value_network.out_keys[0]
self.value_key = value_key
if value_key not in value_network.out_keys:
raise KeyError(
f"value key '{value_key}' not found in value network out_keys."
)

self.average_gae = average_gae
self.differentiable = differentiable
Expand Down