From 3c224162c970047ea2142597c3dd13e38de6ada4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 11 Oct 2024 11:01:10 +0100 Subject: [PATCH 1/7] Update [ghstack-poisoned] --- tensordict/nn/cudagraphs.py | 39 ++++++++++++++++++++++++++++--------- test/test_compile.py | 13 +++++++++++++ 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py index d6eefe1eb..aed6ba7e3 100644 --- a/tensordict/nn/cudagraphs.py +++ b/tensordict/nn/cudagraphs.py @@ -28,9 +28,9 @@ from torch.utils._pytree import SUPPORTED_NODES, tree_map try: - from torch.utils._pytree import tree_leaves + from torch.utils._pytree import tree_flatten, tree_leaves, tree_unflatten except ImportError: - from torch.utils._pytree import tree_flatten + from torch.utils._pytree import tree_flatten, tree_unflatten def tree_leaves(pytree): """Torch 2.0 compatible version of tree_leaves.""" @@ -293,11 +293,12 @@ def check_tensor_id(name, t0, t1): def _call(*args: torch.Tensor, **kwargs: torch.Tensor): if self.counter >= self._warmup: - tree_map( - lambda x, y: x.copy_(y, non_blocking=True), - (self._args, self._kwargs), - (args, kwargs), - ) + srcs, dests = [], [] + for arg_src, arg_dest in zip( + tree_leaves((args, kwargs)), self._flat_tree + ): + self._maybe_copy_onto_(arg_src, arg_dest, srcs, dests) + torch._foreach_copy_(dests, srcs) torch.cuda.synchronize() self.graph.replay() if self._return_unchanged == "clone": @@ -322,8 +323,13 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor): self.counter += self._has_cuda return out else: - args, kwargs = self._args, self._kwargs = tree_map( - self._check_device_and_clone, (args, kwargs) + self._flat_tree, self._tree_spec = tree_flatten((args, kwargs)) + + self._flat_tree = tuple( + self._check_device_and_clone(arg) for arg in self._flat_tree + ) + args, kwargs = self._args, self._kwargs = tree_unflatten( + self._flat_tree, self._tree_spec ) torch.cuda.synchronize() @@ -360,6 +366,21 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor): _call_func = functools.wraps(self.module)(_call) self._call_func = _call_func + @staticmethod + def _maybe_copy_onto_(src, dest, srcs, dests): + if isinstance(src, (torch.Tensor, TensorDictBase)): + srcs.append(src) + dests.append(dest) + try: + if src != dest: + raise ValueError("Varying inputs must be torch.Tensor subclasses.") + except Exception: + raise RuntimeError( + "Couldn't assess input value. Make sure your function only takes tensor inputs or that " + "the input value can be easily checked and is constant. For a better efficiency, avoid " + "passing non-tensor inputs to your function." + ) + @classmethod def _check_device_and_clone(cls, x): if isinstance(x, torch.Tensor) or is_tensor_collection(x): diff --git a/test/test_compile.py b/test/test_compile.py index 755a928e2..ff4e79f38 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -1056,6 +1056,19 @@ def test_td_input_non_tdmodule(self, compiled): if i == 5: assert not func._is_tensordict_module + def test_td_input_non_tdmodule_nontensor(self, compiled): + func = lambda x, y: x + y + func = self._make_cudagraph(func, compiled) + for i in range(10): + assert func(torch.zeros(()), 1.0) == 1.0 + if i == 5: + assert not func._is_tensordict_module + if torch.cuda.is_available(): + with pytest.raises( + ValueError, match="Varying inputs must be torch.Tensor subclasses." + ): + func(torch.zeros(()), 2.0) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() From 610e9087cf0ed62df6d3c7bada5f11797f589a3a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 11 Oct 2024 11:02:14 +0100 Subject: [PATCH 2/7] Update [ghstack-poisoned] --- tensordict/nn/cudagraphs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py index aed6ba7e3..a2d37b51d 100644 --- a/tensordict/nn/cudagraphs.py +++ b/tensordict/nn/cudagraphs.py @@ -371,6 +371,7 @@ def _maybe_copy_onto_(src, dest, srcs, dests): if isinstance(src, (torch.Tensor, TensorDictBase)): srcs.append(src) dests.append(dest) + return try: if src != dest: raise ValueError("Varying inputs must be torch.Tensor subclasses.") From fddc70f73a4a45f253810c1dbd63190d279d4fd0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 11 Oct 2024 11:02:51 +0100 Subject: [PATCH 3/7] Update [ghstack-poisoned] --- tensordict/nn/cudagraphs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py index a2d37b51d..6bb4a7837 100644 --- a/tensordict/nn/cudagraphs.py +++ b/tensordict/nn/cudagraphs.py @@ -298,7 +298,8 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor): tree_leaves((args, kwargs)), self._flat_tree ): self._maybe_copy_onto_(arg_src, arg_dest, srcs, dests) - torch._foreach_copy_(dests, srcs) + if dests: + torch._foreach_copy_(dests, srcs) torch.cuda.synchronize() self.graph.replay() if self._return_unchanged == "clone": From e00c7ef0299d2c27534035c4b1f9731dbb82bbab Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 11 Oct 2024 11:04:57 +0100 Subject: [PATCH 4/7] Update [ghstack-poisoned] --- tensordict/nn/cudagraphs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py index 6bb4a7837..1cb54b3cb 100644 --- a/tensordict/nn/cudagraphs.py +++ b/tensordict/nn/cudagraphs.py @@ -369,10 +369,12 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor): @staticmethod def _maybe_copy_onto_(src, dest, srcs, dests): - if isinstance(src, (torch.Tensor, TensorDictBase)): + if isinstance(src, torch.Tensor): srcs.append(src) dests.append(dest) return + if is_tensor_collection(src): + dest.copy_(src) try: if src != dest: raise ValueError("Varying inputs must be torch.Tensor subclasses.") From 7f0c28f657237649d72f5061f86820110bbec8e6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 11 Oct 2024 11:05:30 +0100 Subject: [PATCH 5/7] Update [ghstack-poisoned] --- tensordict/nn/cudagraphs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py index 1cb54b3cb..82351f882 100644 --- a/tensordict/nn/cudagraphs.py +++ b/tensordict/nn/cudagraphs.py @@ -375,6 +375,7 @@ def _maybe_copy_onto_(src, dest, srcs, dests): return if is_tensor_collection(src): dest.copy_(src) + return try: if src != dest: raise ValueError("Varying inputs must be torch.Tensor subclasses.") From b567315d004eb2cec81ff2f8a3b4a13e79d01dc1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 11 Oct 2024 11:06:53 +0100 Subject: [PATCH 6/7] Update [ghstack-poisoned] --- tensordict/nn/cudagraphs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py index 82351f882..ab795fd37 100644 --- a/tensordict/nn/cudagraphs.py +++ b/tensordict/nn/cudagraphs.py @@ -379,12 +379,13 @@ def _maybe_copy_onto_(src, dest, srcs, dests): try: if src != dest: raise ValueError("Varying inputs must be torch.Tensor subclasses.") - except Exception: + return + except Exception as err: raise RuntimeError( "Couldn't assess input value. Make sure your function only takes tensor inputs or that " "the input value can be easily checked and is constant. For a better efficiency, avoid " "passing non-tensor inputs to your function." - ) + ) from err @classmethod def _check_device_and_clone(cls, x): From efce95bb5537e69736959e408413ceaa4ca28caa Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 11 Oct 2024 11:08:21 +0100 Subject: [PATCH 7/7] Update [ghstack-poisoned] --- tensordict/nn/cudagraphs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py index ab795fd37..e99236b48 100644 --- a/tensordict/nn/cudagraphs.py +++ b/tensordict/nn/cudagraphs.py @@ -376,16 +376,17 @@ def _maybe_copy_onto_(src, dest, srcs, dests): if is_tensor_collection(src): dest.copy_(src) return + isdiff = False try: - if src != dest: - raise ValueError("Varying inputs must be torch.Tensor subclasses.") - return + isdiff = src != dest except Exception as err: raise RuntimeError( "Couldn't assess input value. Make sure your function only takes tensor inputs or that " "the input value can be easily checked and is constant. For a better efficiency, avoid " "passing non-tensor inputs to your function." ) from err + if isdiff: + raise ValueError("Varying inputs must be torch.Tensor subclasses.") @classmethod def _check_device_and_clone(cls, x):