Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions tensordict/nn/cudagraphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from torch.utils._pytree import SUPPORTED_NODES, tree_map

try:
from torch.utils._pytree import tree_leaves
from torch.utils._pytree import tree_flatten, tree_leaves, tree_unflatten
except ImportError:
from torch.utils._pytree import tree_flatten
from torch.utils._pytree import tree_flatten, tree_unflatten

def tree_leaves(pytree):
"""Torch 2.0 compatible version of tree_leaves."""
Expand Down Expand Up @@ -293,11 +293,13 @@ def check_tensor_id(name, t0, t1):

def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
if self.counter >= self._warmup:
tree_map(
lambda x, y: x.copy_(y, non_blocking=True),
(self._args, self._kwargs),
(args, kwargs),
)
srcs, dests = [], []
for arg_src, arg_dest in zip(
tree_leaves((args, kwargs)), self._flat_tree
):
self._maybe_copy_onto_(arg_src, arg_dest, srcs, dests)
if dests:
torch._foreach_copy_(dests, srcs)
torch.cuda.synchronize()
self.graph.replay()
if self._return_unchanged == "clone":
Expand All @@ -322,8 +324,13 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
self.counter += self._has_cuda
return out
else:
args, kwargs = self._args, self._kwargs = tree_map(
self._check_device_and_clone, (args, kwargs)
self._flat_tree, self._tree_spec = tree_flatten((args, kwargs))

self._flat_tree = tuple(
self._check_device_and_clone(arg) for arg in self._flat_tree
)
args, kwargs = self._args, self._kwargs = tree_unflatten(
self._flat_tree, self._tree_spec
)

torch.cuda.synchronize()
Expand Down Expand Up @@ -360,6 +367,27 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
_call_func = functools.wraps(self.module)(_call)
self._call_func = _call_func

@staticmethod
def _maybe_copy_onto_(src, dest, srcs, dests):
if isinstance(src, torch.Tensor):
srcs.append(src)
dests.append(dest)
return
if is_tensor_collection(src):
dest.copy_(src)
return
isdiff = False
try:
isdiff = src != dest
except Exception as err:
raise RuntimeError(
"Couldn't assess input value. Make sure your function only takes tensor inputs or that "
"the input value can be easily checked and is constant. For a better efficiency, avoid "
"passing non-tensor inputs to your function."
) from err
if isdiff:
raise ValueError("Varying inputs must be torch.Tensor subclasses.")

@classmethod
def _check_device_and_clone(cls, x):
if isinstance(x, torch.Tensor) or is_tensor_collection(x):
Expand Down
13 changes: 13 additions & 0 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,19 @@ def test_td_input_non_tdmodule(self, compiled):
if i == 5:
assert not func._is_tensordict_module

def test_td_input_non_tdmodule_nontensor(self, compiled):
    """Non-tensor inputs are accepted as long as they stay constant across calls.

    After warmup, a CUDA-graph-captured function bakes non-tensor values into
    the graph, so changing such a value must raise rather than silently replay
    with the stale constant.
    """
    # PEP 8 (E731): prefer a named function over assigning a lambda to a name.
    def add(x, y):
        return x + y

    func = self._make_cudagraph(add, compiled)
    for i in range(10):
        assert func(torch.zeros(()), 1.0) == 1.0
        if i == 5:
            assert not func._is_tensordict_module
    if torch.cuda.is_available():
        # Graph capture/replay only happens on CUDA; varying the non-tensor
        # argument after warmup must be rejected.
        with pytest.raises(
            ValueError, match="Varying inputs must be torch.Tensor subclasses."
        ):
            func(torch.zeros(()), 2.0)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down