diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index 41a7d7cb1..15269a287 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -242,7 +242,7 @@ def select_subsequence( "No modules left after selection. Make sure that in_keys and out_keys are coherent." ) - return TensorDictSequential(*modules) + return self.__class__(*modules) def _run_module( self,