tensordict/nn/cudagraphs.py (29 additions, 7 deletions)
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+import contextlib
 import functools
 import os
 import warnings
@@ -168,6 +169,12 @@ def __init__(
         if not isinstance(warmup, int) or warmup < 1:
             raise ValueError("warmup must be an integer greater than 0.")
         self._warmup = warmup
+        if torch.cuda.is_available():
+            self._warmup_stream = torch.cuda.Stream()
+            self._warmup_stream_cm = torch.cuda.stream(self._warmup_stream)
+        else:
+            self._warmup_stream = None
+            self._warmup_stream_cm = contextlib.nullcontext()
 
         if hasattr(module, "in_keys"):
             self.in_keys = module.in_keys
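The constructor now allocates the side stream once and pairs it with a reusable stream context, falling back to contextlib.nullcontext() so CPU-only runs can enter the same with-block. A minimal standalone sketch of that pattern, not part of the diff (the WarmupStreams class name is hypothetical):

    import contextlib

    import torch

    class WarmupStreams:
        def __init__(self) -> None:
            if torch.cuda.is_available():
                # Dedicated side stream: warmup kernels stay off the
                # default stream, as CUDA graph capture later requires.
                self.stream = torch.cuda.Stream()
                self.stream_cm = torch.cuda.stream(self.stream)
            else:
                # CPU fallback: a no-op context manager keeps the call
                # sites uniform (`with self.stream_cm: ...`).
                self.stream = None
                self.stream_cm = contextlib.nullcontext()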
@@ -202,12 +209,17 @@ def _call(
                **kwargs: Any,
            ) -> Any:
                if self.counter < self._warmup:
-                    if tensordict_out is not None:
-                        kwargs["tensordict_out"] = tensordict_out
-                    out = self.module(tensordict, *args, **kwargs)
-                    if self._out_matches_in is None:
-                        self._out_matches_in = out is tensordict
+                    if self._warmup_stream is not None:
+                        self._warmup_stream.wait_stream(torch.cuda.current_stream())
+                    with self._warmup_stream_cm:
+                        if tensordict_out is not None:
+                            kwargs["tensordict_out"] = tensordict_out
+                        out = self.module(tensordict, *args, **kwargs)
+                        if self._out_matches_in is None:
+                            self._out_matches_in = out is tensordict
                    self.counter += self._has_cuda
+                    if self._warmup_stream is not None:
+                        torch.cuda.current_stream().wait_stream(self._warmup_stream)
                    return out
                elif self.counter == self._warmup:
                    if tensordict.device is None:
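The two wait_stream calls above implement the standard PyTorch warmup handshake: the side stream first waits for work already queued on the current stream, the module runs on the side stream, and the current stream then waits for the warmup to finish before anything downstream may consume the output. A standalone sketch of the recipe, not part of the diff (net and x are hypothetical placeholders):

    s = torch.cuda.Stream()
    # Side stream must not start before already-queued work completes.
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        y = net(x)  # warmup iterations run off the default stream
    # Default stream must not consume y before the side stream is done.
    torch.cuda.current_stream().wait_stream(s)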
@@ -220,6 +232,7 @@ def _call(
                    tree_map(self._check_non_tensor, (args, kwargs))
 
                    self.graph = torch.cuda.CUDAGraph()
+                    torch.cuda.synchronize()
                    self._tensordict = tensordict.copy()
                    with torch.cuda.graph(self.graph):
                        if tensordict_out is not None:
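The added torch.cuda.synchronize() quiesces the device between the side-stream warmup and graph capture. A standalone capture sketch under the same assumptions, not part of the diff (net and static_input are hypothetical):

    g = torch.cuda.CUDAGraph()
    torch.cuda.synchronize()  # let all warmup work drain before capture
    with torch.cuda.graph(g):
        # Every kernel launched here is recorded into g; the input and
        # output tensors become fixed ("static") buffers.
        static_output = net(static_input)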
@@ -260,7 +273,9 @@ def check_tensor_id(name, t0, t1):
                    return out.clone() if self._out is not None else None
                else:
                    self._tensordict.update_(tensordict)
+                    torch.cuda.synchronize()
                    self.graph.replay()
+                    torch.cuda.synchronize()
                    if self._out_matches_in:
                        return tensordict.update(
                            self._out, keys_to_update=self._selected_keys
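Replay is now fenced by torch.cuda.synchronize() on both sides, so pending work on other streams cannot race the recorded kernels and the results are complete before they are read. Continuing the capture sketch above (new_input is a hypothetical fresh batch):

    static_input.copy_(new_input)  # refresh the captured buffers in place
    torch.cuda.synchronize()       # fence: in-flight work must finish first
    g.replay()                     # re-launch the recorded kernels
    torch.cuda.synchronize()       # fence: static_output is now safe to read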
@@ -273,11 +288,15 @@
 
            def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                if self.counter < self._warmup:
-                    out = self.module(*args, **kwargs)
+                    if self._warmup_stream is not None:
+                        self._warmup_stream.wait_stream(torch.cuda.current_stream())
+                    with self._warmup_stream_cm:
+                        out = self.module(*args, **kwargs)
+                    if self._warmup_stream is not None:
+                        torch.cuda.current_stream().wait_stream(self._warmup_stream)
                    self.counter += self._has_cuda
                    return out
                elif self.counter == self._warmup:
-                    self.graph = torch.cuda.CUDAGraph()
 
                    def check_device_and_clone(x):
                        if isinstance(x, torch.Tensor) or is_tensor_collection(x):
@@ -298,6 +317,7 @@ def check_device_and_clone(x):
                    self._args, self._kwargs = tree_map(
                        check_device_and_clone, (args, kwargs)
                    )
+                    self.graph = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(self.graph):
                        out = self.module(*self._args, **self._kwargs)
                    self.graph.replay()
@@ -334,7 +354,9 @@ def check_device_and_clone(x):
                        (self._args, self._kwargs),
                        (args, kwargs),
                    )
+                    torch.cuda.synchronize()
                    self.graph.replay()
+                    torch.cuda.synchronize()
                    if self._return_unchanged == "clone":
                        return self._out.clone()
                    elif self._return_unchanged:
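For context, a hedged sketch of how the wrapped module is typically driven; the module, shapes, and warmup count below are illustrative and not taken from this PR:

    import torch
    from tensordict.nn import CudaGraphModule

    net = torch.nn.Linear(3, 4, device="cuda")
    gm = CudaGraphModule(net, warmup=2)

    x = torch.randn(8, 3, device="cuda")
    for _ in range(4):
        # Calls 1-2 warm up on the side stream, call 3 captures the
        # graph (and replays it once), later calls only replay.
        y = gm(x)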