1111from typing import Any , Callable , Iterable , List , OrderedDict , overload
1212
1313from tensordict ._nestedkey import NestedKey
14+ from tensordict ._td import TensorDict
1415
1516from tensordict .nn .common import (
1617 dispatch ,
@@ -61,14 +62,17 @@ class TensorDictSequential(TensorDictModule):
6162 Regular ``dict`` inputs will be converted to ``OrderedDict`` if necessary.
6263
6364 Keyword Args:
64- partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
65+ partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
6566 If so, the only module that will be executed are those who can be executed given the keys that
6667 are present.
6768 Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the
6869 stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts
6970 looking for those that have the required keys, if any. Defaults to False.
70- selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all
71+ selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all
7172 ``out_keys`` will be written.
73+ inplace (bool, optional): if `True`, the input tensordict is modified in-place. If `False`, a new empty
74+ :class:`~tensordict.TensorDict` instance is created. If `"empty"`, `input.empty()` is used instead (ie, the
75+ output preserves type, device and batch-size). Defaults to `None` (relies on sub-modules).
7276
7377 .. note::
7478 A :class:`TensorDictSequential` instance may have a long list of output keys, and one may wish to remove
@@ -185,6 +189,7 @@ def __init__(
185189 * ,
186190 partial_tolerant : bool = False ,
187191 selected_out_keys : List [NestedKey ] | None = None ,
192+ inplace : bool | None = None ,
188193 ) -> None : ...
189194
190195 @overload
@@ -194,13 +199,15 @@ def __init__(
194199 * ,
195200 partial_tolerant : bool = False ,
196201 selected_out_keys : List [NestedKey ] | None = None ,
202+ inplace : bool | None = None ,
197203 ) -> None : ...
198204
199205 def __init__ (
200206 self ,
201207 * modules : Callable [[TensorDictBase ], TensorDictBase ],
202208 partial_tolerant : bool = False ,
203209 selected_out_keys : List [NestedKey ] | None = None ,
210+ inplace : bool | None = None ,
204211 ) -> None :
205212
206213 if len (modules ) == 1 and isinstance (modules [0 ], collections .OrderedDict ):
@@ -236,6 +243,7 @@ def __init__(
236243 module = nn .ModuleList (list (modules )), in_keys = in_keys , out_keys = out_keys
237244 )
238245
246+ self .inplace = inplace
239247 self .partial_tolerant = partial_tolerant
240248 if selected_out_keys :
241249 self ._select_before_return = True
@@ -452,6 +460,43 @@ def select_subsequence(
452460 in_keys=['b'],
453461 out_keys=['d', 'e'])
454462
463+ The `inplace` argument allows for a fine-grained control over the output type, allowing for instance to write
464+ the result of the computational graph in the input object without tracking the intermediate tensors.
465+
466+ Example:
467+ >>> import torch
468+ >>> from tensordict import TensorClass
469+ >>> from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
470+ >>>
471+ >>> class MyClass(TensorClass):
472+ ... input: torch.Tensor
473+ ... output: torch.Tensor | None = None
474+ >>>
475+ >>> obj = MyClass(torch.randn(2, 3), batch_size=(2,))
476+ >>>
477+ >>> model = Seq(
478+ ... Mod(
479+ ... lambda x: (x + 1, x - 1),
480+ ... in_keys=["input"],
481+ ... out_keys=[("intermediate", "0"), ("intermediate", "1")],
482+ ... inplace=False
483+ ... ),
484+ ... Mod(
485+ ... lambda y0, y1: y0 * y1,
486+ ... in_keys=[("intermediate", "0"), ("intermediate", "1")],
487+ ... out_keys=["output"],
488+ ... inplace=False
489+ ... ),
490+ ... inplace=True, )
491+ >>> print(model(obj))
492+ MyClass(
493+ input=Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
494+ output=Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
495+ output=None,
496+ batch_size=torch.Size([2]),
497+ device=None,
498+ is_shared=False)
499+
455500 """
456501 if in_keys is None :
457502 in_keys = deepcopy (self .in_keys )
@@ -558,6 +603,14 @@ def forward(
558603 tensordict_exec = tensordict .copy ()
559604 else :
560605 tensordict_exec = tensordict
606+ if tensordict_out is None :
607+ if self .inplace is True :
608+ tensordict_out = tensordict
609+ elif self .inplace is False :
610+ tensordict_out = TensorDict ()
611+ elif self .inplace == "empty" :
612+ tensordict_out = tensordict .empty ()
613+
561614 if not len (kwargs ):
562615 for module in self ._module_iter ():
563616 try :
0 commit comments