diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py index 55308ab3b..6185e89dc 100644 --- a/tensordict/nn/cudagraphs.py +++ b/tensordict/nn/cudagraphs.py @@ -24,7 +24,16 @@ from tensordict.utils import strtobool from torch import Tensor -from torch.utils._pytree import SUPPORTED_NODES, tree_leaves, tree_map +from torch.utils._pytree import SUPPORTED_NODES, tree_map + +try: + from torch.utils._pytree import tree_leaves +except ImportError: + from torch.utils._pytree import tree_flatten + + def tree_leaves(pytree): + """Torch 2.0 compatible version of tree_leaves.""" + return tree_flatten(pytree)[0] class CudaGraphModule: