From d8e1e76ab2063e50ebdeb69eb4f09d269acd262e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 16 Sep 2024 16:44:29 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/nn/cudagraphs.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py index 55308ab3b..6185e89dc 100644 --- a/tensordict/nn/cudagraphs.py +++ b/tensordict/nn/cudagraphs.py @@ -24,7 +24,16 @@ from tensordict.utils import strtobool from torch import Tensor -from torch.utils._pytree import SUPPORTED_NODES, tree_leaves, tree_map +from torch.utils._pytree import SUPPORTED_NODES, tree_map + +try: + from torch.utils._pytree import tree_leaves +except ImportError: + from torch.utils._pytree import tree_flatten + + def tree_leaves(pytree): + """Torch 2.0 compatible version of tree_leaves.""" + return tree_flatten(pytree)[0] class CudaGraphModule: