Merged — changes from all commits (27 commits)
6 changes: 6 additions & 0 deletions torchtitan/experiments/compiler_toolkit/README.md
@@ -55,3 +55,9 @@ NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./to
```shell
NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor
```

**SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor + cudagraph**

```shell
NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor,cudagraph
```
9 changes: 9 additions & 0 deletions torchtitan/experiments/compiler_toolkit/common_utils.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from typing import Callable

import torch
from torch.distributed.tensor import DTensor, Replicate
@@ -53,3 +54,11 @@ def register_blockmask_pytree_node():
flatten_with_keys_fn=BlockMask._flatten_with_keys,
serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
)


def end_with_pass(passes: list[Callable], names: list[str]) -> bool:
    # bool() guards against returning None when the last pass has no __name__
    return bool(
        len(passes) > 0
        and (last_pass_name := getattr(passes[-1], "__name__", None))
        and (last_pass_name in names)
    )
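
A quick usage sketch of `end_with_pass` (the two stub passes below are hypothetical, for illustration only):

```python
def bucketing_pass(gm):  # hypothetical stand-in for a real pass
    return gm

def cudagraph_pass(gm):  # hypothetical stand-in for a real pass
    return gm

# True only when the final pass's __name__ is one of the given names.
assert end_with_pass([bucketing_pass, cudagraph_pass], ["cudagraph_pass"])
assert not end_with_pass([cudagraph_pass, bucketing_pass], ["cudagraph_pass"])
assert not end_with_pass([], ["cudagraph_pass"])
```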
169 changes: 169 additions & 0 deletions torchtitan/experiments/compiler_toolkit/cudagraph.py
@@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
CUDAGraph pass for the compiler toolkit.

This module provides a cudagraph pass that can be applied to graph modules
during compilation.
"""

import warnings
from typing import Any, Callable, Optional, Sequence

import torch
from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager
from torch.utils._ordered_set import OrderedSet


def init_global_graph_pool() -> tuple[
torch.cuda.CUDAGraph, torch.cuda._POOL_HANDLE, torch.cuda.Stream
]:
dummy_graph = torch.cuda.CUDAGraph()

# create a global cudagraph memory pool to allow memory reuse across cudagraphs.
graph_pool = torch.cuda.graph_pool_handle()

# create a global cuda stream for graph capture. we need to use a single stream
# for all allocations to the memory pool; otherwise, allocations made on separate
# streams will not be reused.
graph_capture_stream = torch.cuda.Stream()

# use a dummy graph to keep the global graph pool alive
with (
# suppress an empty cudagraph warning, since we intentionally create
# an empty cudagraph here
warnings.catch_warnings(record=True),
torch.cuda.graph(
dummy_graph,
pool=graph_pool,
stream=graph_capture_stream,
capture_error_mode="thread_local",
),
):
pass

return dummy_graph, graph_pool, graph_capture_stream


(
_global_dummy_graph,
_global_graph_pool,
_global_graph_capture_stream,
) = init_global_graph_pool()
Comment on lines +52 to +56:

> does this work when backward is on a separate stream? or not an issue?

@BoyuanFeng (Contributor, Author), Nov 19, 2025:

> IIUC, this is not an issue currently, since fwd and bwd are on the same CUDA stream by default.
> cudagraph trees uses the same graph capture stream for both fwd and bwd:
> https://github.com/pytorch/pytorch/blob/7a928397cda89b71c24b0efe9db6df7fb04a46cb/torch/_inductor/cudagraph_trees.py#L1945
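
For intuition, a standalone, hedged sketch (not part of this diff) of the shared-pool idea: two graphs captured into the same pool on the same stream can reuse each other's memory, which is what the global pool above enables across forward and backward graphs.

```python
import torch

pool = torch.cuda.graph_pool_handle()   # shared memory pool
stream = torch.cuda.Stream()            # single capture stream

g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
x = torch.randn(8, device="cuda")

with torch.cuda.graph(g1, pool=pool, stream=stream):
    y = x * 2
with torch.cuda.graph(g2, pool=pool, stream=stream):
    z = x + 1

# graphs sharing a pool must be replayed in the order they were captured
g1.replay()
g2.replay()
```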



class CUDAGraphWrapper:
def __init__(
self,
runnable: Callable,
example_inputs: Sequence[Any],
static_input_indices: Optional[Sequence[int]] = None,
should_check_address: bool = False,
):
self.runnable = runnable
self.graph_pool = _global_graph_pool
self.stream = _global_graph_capture_stream
self.static_input_indices = OrderedSet(
static_input_indices if static_input_indices is not None else []
)
self.input_indices_to_copy = [
i
for i, inp in enumerate(example_inputs)
if isinstance(inp, torch.Tensor) and i not in self.static_input_indices
]
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
self.has_warmup = False

self.args = None
self.output = None

# (debug only) whether to check static input tensor addresses at runtime
self.should_check_address = should_check_address

def copy_non_static_inputs(self, *args):
for i in self.input_indices_to_copy:
self.args[i].copy_(args[i])
Comment from the Contributor Author:

> we could replace this for loop with a foreach copy. However, I empirically observed there is only 1 tensor to copy for fwd and 1 for bwd, so there is no need to add code complexity here.
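
For reference, a hedged sketch of the foreach alternative mentioned above; `torch._foreach_copy_` is a private API, so treat its availability as an assumption:

```python
import torch

def copy_non_static_inputs_foreach(recorded_args, new_args, indices):
    # hypothetical helper: batch the per-tensor copies into one foreach call
    dsts = [recorded_args[i] for i in indices]
    srcs = [new_args[i] for i in indices]
    torch._foreach_copy_(dsts, srcs)  # private API; assumed available
```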


def check_input_types(self, inputs) -> None:
for inp in inputs:
assert isinstance(inp, (torch.Tensor, int, torch._C.Generator)), (
"args must be tensor, integer (for dynamic shapes), "
"or Generator (for random number generator), "
f"but found {type(inp)}"
)

def check_static_inputs_address(self, args) -> None:
    # compare the current call's static input addresses against those
    # recorded at capture time
    for i in self.static_input_indices:
        actual = args[i].data_ptr()
        expected = self.input_addresses[i]
        assert expected == actual, (
            "Expected the same static tensor address but found "
            f"{expected} != {actual}"
        )

def __call__(self, *args):
if not self.has_warmup:
self.has_warmup = True
device = torch.cuda.current_device()

# warm up in the cudagraph memory pool to avoid fragmentation
# between the eager and cudagraph memory pools.
with _use_cuda_memory_pool_manager(device, self.graph_pool, self.stream):
out = self.runnable(*args)
return out

if self.cudagraph is None:
self.check_input_types(args)
self.args = args
self.input_addresses = [
x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args
]

self.cudagraph = torch.cuda.CUDAGraph()

with torch.cuda.graph(
self.cudagraph, pool=self.graph_pool, stream=self.stream
):
# `output` is managed by pytorch's cudagraph pool
self.output = self.runnable(*args)
Comment from the Contributor Author:

> we could potentially use a weakref for the output tensor to reduce memory. Will do in a follow-up PR.


if self.should_check_address:
    self.check_static_inputs_address(args)

self.copy_non_static_inputs(*args)
self.cudagraph.replay()
return self.output

Comment:

> The persistent input and output are not good for memory, as you've commented.

Reply from the Contributor Author:

> Yes, will add that in the next PR.
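
A minimal, hedged usage sketch of `CUDAGraphWrapper`, driving the warmup → capture → replay lifecycle by hand (the wrapped function and shapes are illustrative):

```python
import torch

def f(x):
    return x * 2 + 1

x = torch.randn(8, device="cuda")
wrapped = CUDAGraphWrapper(f, example_inputs=(x,))

out1 = wrapped(x)   # run 1: eager warmup inside the cudagraph memory pool
out2 = wrapped(x)   # run 2: capture the cudagraph, then replay it

x2 = torch.randn(8, device="cuda")
out3 = wrapped(x2)  # run 3+: copy non-static inputs into place, then replay

# out2 and out3 are the same pool-managed output buffer, refreshed in place
assert out2.data_ptr() == out3.data_ptr()
```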



def get_static_input_indices(gm: torch.fx.GraphModule, is_forward: bool) -> list[int]:
"""
Get indices of gm inputs that are static input tensors, i.e., tensors whose addresses
do not change across runs. Examples of static input tensors include weights, buffers,
and outputs of previous cudagraph-wrapped functions.
"""
from torch._inductor.utils import count_tangents

static_input_indices = []
if (
is_forward
and (tracing_context := torch._guards.TracingContext.try_get())
and hasattr(tracing_context, "fw_metadata")
):
# for forward, we rely on graph capture (i.e., dynamo or export) to provide
# the correct static input indices stored in tracing context. Typical examples
# include weights and buffers.
static_input_indices = tracing_context.fw_metadata.static_input_indices

elif not is_forward:
# for backward, we identify saved tensors as static inputs, since saved tensors
# are outputs of cudagraph-wrapped forward run. In PT2-generated backward gm,
# saved tensors are always the leading args. So we can get the number of saved
# tensors and generate static input indices.
fixed = count_tangents(gm)
static_input_indices = list(range(fixed))

return static_input_indices
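
To make the backward convention concrete, a small hedged sketch (the placeholder layout is hypothetical):

```python
# Hypothetical PT2 backward graph inputs: saved tensors first, then tangents.
#   placeholders = [saved0, saved1, saved2, tangent0, tangent1]
fixed = 3                                   # what count_tangents(gm) would report
static_input_indices = list(range(fixed))   # -> [0, 1, 2]
```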
50 changes: 41 additions & 9 deletions torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -20,6 +20,7 @@
from torch.distributed.tensor import DTensor
from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.experiments.compiler_toolkit.common_utils import end_with_pass
from torchtitan.tools.logging import logger


@@ -217,6 +218,7 @@ def compiler(
example_inputs,
passes: List[Callable] = None,
dump_folder: str | None = None,
is_forward: bool = True,
):
"""
Compile a graph module by applying a sequence of compiler passes.
@@ -239,6 +241,17 @@
)
_dump_gm(dump_folder, gm, f"{name}_before_compiler")

if end_with_pass(passes, ["cudagraph_pass"]):
# cudagraph pass is always the last pass if it is applied
cg_pass = passes[-1]

# to identify static input indices, the cudagraph pass behaves differently for
# the forward and backward passes, so we explicitly pass that info.
_cg_pass = functools.partial(cg_pass, is_forward=is_forward)

# keep the function name for debug log
passes[-1] = functools.wraps(cg_pass)(_cg_pass)

for pass_fn in passes:
pass_name = (
pass_fn.func.__name__
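
One subtlety worth noting above: `functools.partial` objects have no `__name__`, but they do accept attribute assignment, so wrapping the partial with `functools.wraps` restores the name for the debug log. A standalone sketch:

```python
import functools

def cudagraph_pass(gm, example_inputs, is_forward):
    return gm

cg = functools.partial(cudagraph_pass, is_forward=True)
cg = functools.wraps(cudagraph_pass)(cg)

assert cg.__name__ == "cudagraph_pass"   # name restored for logging
assert cg.func is cudagraph_pass         # .func still reaches the original
```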
@@ -271,17 +284,42 @@ def make_compiler_with_passes(

def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
return compiler(
"fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
"fwd_gm",
gm,
example_inputs,
passes=passes,
dump_folder=dump_folder,
is_forward=True,
)

def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
return compiler(
"bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
"bwd_gm",
gm,
example_inputs,
passes=passes,
dump_folder=dump_folder,
is_forward=False,
)

return fw_compiler, bw_compiler


def validate_pass_names(pass_names: list[str]) -> None:
if "cudagraph" in pass_names:
assert (
pass_names[-1] == "cudagraph"
), "cudagraph has to be the last pass to apply"

if (
"autobucketing_reordering" in pass_names
and "transformer_block_bucketing" in pass_names
):
raise ValueError(
"Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!"
)
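
Hedged examples of what `validate_pass_names` accepts and rejects:

```python
validate_pass_names(["transformer_block_bucketing", "regional_inductor", "cudagraph"])  # OK

# AssertionError: cudagraph has to be the last pass to apply
# validate_pass_names(["cudagraph", "regional_inductor"])

# ValueError: the two bucketing passes are mutually exclusive
# validate_pass_names(["autobucketing_reordering", "transformer_block_bucketing"])
```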


def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig):
"""
Extract and validate compiler passes from job config.
@@ -298,13 +336,7 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig):
)

pass_names = getattr(job_config.compile, "passes", [])
if (
"autobucketing_reordering" in pass_names
and "transformer_block_bucketing" in pass_names
):
raise ValueError(
"Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!"
)
validate_pass_names(pass_names)
compiler_passes = []

for pass_name in pass_names:
24 changes: 24 additions & 0 deletions torchtitan/experiments/compiler_toolkit/passes.py
@@ -11,10 +11,16 @@
during compilation. Passes can be selected and configured via job config.
"""

from typing import Any, Sequence

import torch
from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
from torch.fx.passes.regional_inductor import regional_inductor
from torchtitan.experiments.compiler_toolkit.cudagraph import (
CUDAGraphWrapper,
get_static_input_indices,
)
from torchtitan.experiments.simple_fsdp.reshard_after_forward import (
annotate_fsdp_all_gather,
)
@@ -56,6 +62,23 @@ def regional_inductor_pass(
return regional_inductor(gm, example_inputs)


def cudagraph_pass(
gm: torch.fx.GraphModule, example_inputs: Sequence[Any], is_forward: bool
) -> torch.fx.GraphModule:
"""
Apply cudagraph.

This pass wraps the forward function with cudagraph during compilation; the
cudagraph itself is not recorded until runtime.
- On the first run, it warms up operators such as NCCL.
- On the second run, it records the cudagraph and then replays it.
- On following runs, it replays the cudagraph.
"""
static_input_indices = get_static_input_indices(gm, is_forward)
gm.forward = CUDAGraphWrapper(gm.forward, example_inputs, static_input_indices)
return gm


def validate_flex_attn_annotation_pass(
gm: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
@@ -88,4 +111,5 @@ def fsdp_reshard_after_fwd_pass(
"autobucketing_reordering": autobucketing_reordering_pass,
"transformer_block_bucketing": transformer_block_bucketing_reordering_pass,
"regional_inductor": regional_inductor_pass,
"cudagraph": cudagraph_pass,
}
29 changes: 29 additions & 0 deletions torchtitan/experiments/compiler_toolkit/tests/integration_tests.py
@@ -58,6 +58,20 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]:
"llama3_fsdp_tp_manualbucketing",
ngpu=4,
),
OverrideDefinitions(
[
[
"--model.name compiler_toolkit.llama3",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
"--compile.passes cudagraph",
],
],
"llama3 FSDP+TP+cudagraph",
"llama3_fsdp_tp_cudagraph",
ngpu=4,
),
OverrideDefinitions(
[
[
@@ -86,6 +100,21 @@
"llama3_fsdp_tp_flexattn_autobucketing_regional_inductor",
ngpu=4,
),
OverrideDefinitions(
[
[
"--model.name compiler_toolkit.llama3",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--model.flavor debugmodel_flex_attn",
"--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
"--compile.passes autobucketing_reordering,regional_inductor,cudagraph",
],
],
"llama3 FSDP+TP+FlexAttn autobucketing regional_inductor+cudagraph",
"llama3_fsdp_tp_flexattn_autobucketing_regional_inductor_cudagraph",
ngpu=4,
),
OverrideDefinitions(
[
[