From 77e49fd82e84277c471a30dae75b588a8fa687f5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 13 Dec 2022 11:04:51 +0000 Subject: [PATCH 1/3] init --- torchrl/objectives/value/advantages.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 4f558feecef..677beaea396 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -336,6 +336,7 @@ def forward( ) tensordict.set("advantage", adv.detach()) + tensordict.set("value_target", value_target) if self.gradient_mode: tensordict.set("value_error", value_target - value) From 3854c810118b14b9b7c53da59ea92030a803103a Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 13 Dec 2022 13:08:40 +0000 Subject: [PATCH 2/3] docstr --- torchrl/objectives/value/advantages.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 677beaea396..a4ecd994a54 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -252,6 +252,13 @@ class GAE(nn.Module): gradient_mode (bool): if True, gradients are propagated throught the computation of the value function. Default is `False`. + GAE will return an :obj:`"advantage"` entry containing the advantage value. It will also + return a :obj:`"value_target"` entry with the return value that is to be used + to train the value network. Finally, if :obj:`gradient_mode` is :obj:`True`, + an additional and differentiable :obj:`"value_error"` entry will be returned, + which simply represents the difference between the return and the value network + output (i.e. an additional distance loss should be applied to that signed value). 
+ """ def __init__( From d3ebfd202813cea9d0a7a63f7dffd42512206d51 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 13 Dec 2022 16:16:37 +0000 Subject: [PATCH 3/3] nightly fix --- torchrl/data/tensor_specs.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index ec347be62ed..71a20e522ef 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -449,6 +449,11 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: maximum = maximum.expand_as(val) val[val < minimum] = minimum[val < minimum] val[val > maximum] = maximum[val > maximum] + except RuntimeError: + minimum = minimum.expand_as(val) + maximum = maximum.expand_as(val) + val[val < minimum] = minimum[val < minimum] + val[val > maximum] = maximum[val > maximum] return val def is_in(self, val: torch.Tensor) -> bool: