CUDAGraph support for SimpleFSDP and TP #2050
The diff shown below adds a new file, a cudagraph pass module for the compiler toolkit (169 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
CUDAGraph pass for the compiler toolkit.

This module provides a cudagraph pass that can be applied to graph modules
during compilation.
"""

import warnings
from typing import Any, Callable, Optional, Sequence

import torch
from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager
from torch.utils._ordered_set import OrderedSet


def init_global_graph_pool() -> tuple[
    torch.cuda.CUDAGraph, torch.cuda._POOL_HANDLE, torch.cuda.Stream
]:
    dummy_graph = torch.cuda.CUDAGraph()

    # create a global cudagraph memory pool to allow memory reuse across cudagraphs.
    graph_pool = torch.cuda.graph_pool_handle()

    # create a global cuda stream for graph capture. we need to use a single stream
    # for all allocations to the memory pool, otherwise allocations made on separate
    # streams will not be reused.
    graph_capture_stream = torch.cuda.Stream()

    # use a dummy graph to keep the global graph pool alive
    with (
        # suppress an empty cudagraph warning, since we intentionally create
        # an empty cudagraph here
        warnings.catch_warnings(record=True),
        torch.cuda.graph(
            dummy_graph,
            pool=graph_pool,
            stream=graph_capture_stream,
            capture_error_mode="thread_local",
        ),
    ):
        pass

    return dummy_graph, graph_pool, graph_capture_stream


(
    _global_dummy_graph,
    _global_graph_pool,
    _global_graph_capture_stream,
) = init_global_graph_pool()


class CUDAGraphWrapper:
    def __init__(
        self,
        runnable: Callable,
        example_inputs: Sequence[Any],
        static_input_indices: Optional[tuple[int, ...]] = None,
        should_check_address: bool = False,
    ):
        self.runnable = runnable
        self.graph_pool = _global_graph_pool
        self.stream = _global_graph_capture_stream
        self.static_input_indices = OrderedSet(
            static_input_indices if static_input_indices is not None else []
        )
        self.input_indices_to_copy = [
            i
            for i, inp in enumerate(example_inputs)
            if isinstance(inp, torch.Tensor) and i not in self.static_input_indices
        ]
        self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
        self.has_warmup = False

        self.args = None
        self.output = None

        # (debug only) whether to check static input tensor addresses at runtime
        self.should_check_address = should_check_address

    def copy_non_static_inputs(self, *args):
        for i in self.input_indices_to_copy:
            self.args[i].copy_(args[i])
```
**Contributor (author), on `copy_non_static_inputs`:** we could replace this for loop with a foreach copy. However, I empirically observed there is only 1 tensor to copy for the forward and 1 for the backward, so there is no need to add code complexity here.
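As a side note, here is a hedged sketch of the foreach variant mentioned above. It is not part of this PR; it assumes the private `torch._foreach_copy_` op is available, and the helper name is made up:

```python
import torch


def copy_non_static_inputs_foreach(captured_args, new_args, input_indices_to_copy):
    """Fused alternative to the per-tensor copy loop (illustrative sketch only)."""
    dsts = [captured_args[i] for i in input_indices_to_copy]
    srcs = [new_args[i] for i in input_indices_to_copy]
    if dsts:
        # one fused copy per dtype/device group instead of one copy_ per tensor
        torch._foreach_copy_(dsts, srcs)
```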
```python
    def check_input_types(self, inputs) -> None:
        for inp in inputs:
            assert isinstance(inp, (torch.Tensor, int, torch._C.Generator)), (
                "args must be tensor, integer (for dynamic shapes), "
                "or Generator (for random number generator), "
                f"but found {type(inp)}"
            )

    def check_static_inputs_address(self, args) -> None:
        # compare the current call's static inputs against the addresses
        # recorded at capture time
        for i in self.static_input_indices:
            actual = args[i].data_ptr()
            expected = self.input_addresses[i]
            assert expected == actual, (
                "Expected the same static tensor address but found "
                f"{expected} != {actual}"
            )

    def __call__(self, *args):
        if not self.has_warmup:
            self.has_warmup = True
            device = torch.cuda.current_device()

            # warmup in the cudagraph memory pool to avoid fragmentation
            # across the eager memory pool and the cudagraph memory pool.
            with _use_cuda_memory_pool_manager(device, self.graph_pool, self.stream):
                out = self.runnable(*args)
            return out

        if self.cudagraph is None:
            self.check_input_types(args)
            self.args = args
            self.input_addresses = [
                x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args
            ]

            self.cudagraph = torch.cuda.CUDAGraph()

            with torch.cuda.graph(
                self.cudagraph, pool=self.graph_pool, stream=self.stream
            ):
                # `output` is managed by pytorch's cudagraph pool
                self.output = self.runnable(*args)
```
**Contributor (author):** we could potentially use a weakref for the output tensor to reduce memory. Will do in a follow-up PR.
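As an aside, one possible shape of the weakref idea above (purely illustrative, not this PR's code): hold the captured outputs weakly so the cudagraph pool can reclaim their memory once the caller drops them. A complete implementation, like Inductor's cudagraph trees, would reconstruct dead outputs from their static addresses rather than asserting:

```python
import weakref

import torch


def hold_outputs_weakly(outputs):
    """Keep weak references to tensor outputs instead of strong ones (sketch)."""
    return [weakref.ref(t) if isinstance(t, torch.Tensor) else t for t in outputs]


def resolve_outputs(refs):
    resolved = []
    for r in refs:
        if isinstance(r, weakref.ReferenceType):
            t = r()
            # a full implementation would rebuild the tensor from its static address here
            assert t is not None, "output tensor was freed; reconstruction not implemented"
            resolved.append(t)
        else:
            resolved.append(r)
    return resolved
```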
```python
        if self.should_check_address:
            self.check_static_inputs_address(args)

        # after capture (and on every later call), copy the fresh non-static
        # inputs into the captured buffers and replay the graph
        self.copy_non_static_inputs(*args)
        self.cudagraph.replay()
        return self.output
```
**Reviewer:** The persistent input and output are not good for memory, as you've commented.

**Contributor (author):** yes, will add in the next PR.
```python
def get_static_input_indices(gm: torch.fx.GraphModule, is_forward: bool) -> list[int]:
    """
    Get indices of gm inputs that are static input tensors, i.e., tensors whose
    addresses do not change across runs. Examples of static input tensors include
    weights, buffers, and outputs of previous cudagraph-wrapped functions.
    """
    from torch._inductor.utils import count_tangents

    static_input_indices = []
    if (
        is_forward
        and (tracing_context := torch._guards.TracingContext.try_get())
        and hasattr(tracing_context, "fw_metadata")
    ):
        # for forward, we rely on graph capture (i.e., dynamo or export) to provide
        # the correct static input indices stored in the tracing context. Typical
        # examples include weights and buffers.
        static_input_indices = tracing_context.fw_metadata.static_input_indices

    elif not is_forward:
        # for backward, we identify saved tensors as static inputs, since saved tensors
        # are outputs of the cudagraph-wrapped forward run. In a PT2-generated backward
        # gm, saved tensors are always the leading args, so we can count the saved
        # tensors and generate static input indices from that.
        fixed = count_tangents(gm)
        static_input_indices = list(range(fixed))

    return static_input_indices
```
**Reviewer, on `get_static_input_indices`:** does this work when backward is on a separate stream? Or is that not an issue?

**Contributor (author):** IIUC, this is not an issue currently, since forward and backward run on the same CUDA stream by default. cudagraph trees also uses the same graph capture stream for both forward and backward:
https://github.com/pytorch/pytorch/blob/7a928397cda89b71c24b0efe9db6df7fb04a46cb/torch/_inductor/cudagraph_trees.py#L1945
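For orientation, here is a minimal usage sketch of the new API under assumed names; the import path and the call-site wiring are hypothetical and not part of this diff:

```python
import torch

# hypothetical import path; the actual module location is defined by this PR's file layout
from cudagraph_pass import CUDAGraphWrapper, get_static_input_indices


def wrap_forward_with_cudagraph(gm: torch.fx.GraphModule, example_inputs):
    """Wrap a compiled forward graph module so that repeat calls replay a CUDA graph."""
    static_indices = get_static_input_indices(gm, is_forward=True)
    wrapped = CUDAGraphWrapper(
        gm,                                          # the runnable produced by compilation
        example_inputs,                              # used to decide which inputs must be copied
        static_input_indices=tuple(static_indices),  # weights/buffers keep fixed addresses
    )
    # call 1: warmup inside the cudagraph memory pool
    # call 2: graph capture followed by a replay
    # later calls: copy non-static inputs into the captured buffers and replay
    return wrapped
```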