Skip to content

Commit 9733d6e

Browse files
author
Vincent Moens
committed
[Feature] COMPOSITE_LP_AGGREGATE env variable
ghstack-source-id: 16b07d0 Pull Request resolved: #1190
1 parent 790bef6 commit 9733d6e

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

tensordict/nn/utils.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,13 @@ def _generate_next_value_(name, start, count, last_values):
450450
return name.lower()
451451

452452

453-
_composite_lp_aggregate = _ContextManager()
453+
_composite_lp_aggregate = _ContextManager(
454+
default=(
455+
strtobool(os.getenv("COMPOSITE_LP_AGGREGATE"))
456+
if os.getenv("COMPOSITE_LP_AGGREGATE") is not None
457+
else None
458+
)
459+
)
454460

455461

456462
def composite_lp_aggregate(nowarn: bool = False) -> bool | None:
@@ -467,9 +473,9 @@ def composite_lp_aggregate(nowarn: bool = False) -> bool | None:
467473
if not nowarn:
468474
warnings.warn(
469475
"Composite log-prob aggregation wasn't defined explicitly and ``composite_lp_aggregate()`` will "
470-
"currently return ``True``. However, from v0.9, this behaviour will change and ``composite_lp_aggregate`` will "
476+
"currently return ``True``. However, from v0.9, this behavior will change and ``composite_lp_aggregate`` will "
471477
"return ``False``. Please change your code accordingly by specifying the aggregation strategy via "
472-
"`tensordict.nn.set_composite_lp_aggregate`.",
478+
"`tensordict.nn.set_composite_lp_aggregate` or via the `COMPOSITE_LP_AGGREGATE` environment variable.",
473479
category=DeprecationWarning,
474480
)
475481
return True
@@ -483,6 +489,8 @@ class set_composite_lp_aggregate(_DecoratorContextManager):
483489
will be summed into a single tensor with the shape of the root tensordict. This behaviour is being deprecated in favor of
484490
non-aggregated log-probs, which offer more flexibility and a somewhat more natural API (tensordict samples, tensordict log-probs, tensordict entropies).
485491
492+
The value of composite_lp_aggregate can also be controlled through the `COMPOSITE_LP_AGGREGATE` environment variable.
493+
486494
Example:
487495
>>> _ = torch.manual_seed(0)
488496
>>> from tensordict import TensorDict

0 commit comments

Comments
 (0)