diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md
index 7d00e1f48b..620911ce60 100644
--- a/torchtitan/experiments/compiler_toolkit/README.md
+++ b/torchtitan/experiments/compiler_toolkit/README.md
@@ -55,3 +55,9 @@ NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./to
 ```shell
 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor
 ```
+
+**SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor + cudagraph**
+
+```shell
+NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor,cudagraph
+```
diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py
index 965e027bdb..997af9a2c4 100644
--- a/torchtitan/experiments/compiler_toolkit/common_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/common_utils.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from contextlib import contextmanager
+from typing import Callable
 
 import torch
 from torch.distributed.tensor import DTensor, Replicate
@@ -53,3 +54,11 @@ def register_blockmask_pytree_node():
         flatten_with_keys_fn=BlockMask._flatten_with_keys,
         serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
     )
+
+
+def end_with_pass(passes: list[Callable], names: list[str]) -> bool:
+    return bool(
+        len(passes) > 0
+        and (last_pass_name := getattr(passes[-1], "__name__", None))
+        and (last_pass_name in names)
+    )
diff --git a/torchtitan/experiments/compiler_toolkit/cudagraph.py b/torchtitan/experiments/compiler_toolkit/cudagraph.py
new file mode 100644
index 0000000000..cd6e4cfc22
--- /dev/null
+++ b/torchtitan/experiments/compiler_toolkit/cudagraph.py
@@ -0,0 +1,169 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+CUDAGraph pass for the compiler toolkit.
+
+This module provides a cudagraph pass that can be applied to graph modules
+during compilation.
+"""
+
+import warnings
+from typing import Any, Callable, Optional, Sequence
+
+import torch
+from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager
+from torch.utils._ordered_set import OrderedSet
+
+
+def init_global_graph_pool() -> tuple[
+    torch.cuda.CUDAGraph, torch.cuda._POOL_HANDLE, torch.cuda.Stream
+]:
+    dummy_graph = torch.cuda.CUDAGraph()
+
+    # create a global cudagraph memory pool to allow memory reuse across cudagraphs.
+    graph_pool = torch.cuda.graph_pool_handle()
+
+    # create a global cuda stream for graph capture. we need to use a single stream
+    # for all allocations to the memory pool, otherwise allocations made from
+    # separate streams will not be reused.
+    graph_capture_stream = torch.cuda.Stream()
+
+    # use a dummy graph to keep the global graph pool alive
+    with (
+        # suppress an empty cudagraph warning, since we intentionally create
+        # an empty cudagraph here
+        warnings.catch_warnings(record=True),
+        torch.cuda.graph(
+            dummy_graph,
+            pool=graph_pool,
+            stream=graph_capture_stream,
+            capture_error_mode="thread_local",
+        ),
+    ):
+        pass
+
+    return dummy_graph, graph_pool, graph_capture_stream
+
+
+(
+    _global_dummy_graph,
+    _global_graph_pool,
+    _global_graph_capture_stream,
+) = init_global_graph_pool()
+
+
+class CUDAGraphWrapper:
+    def __init__(
+        self,
+        runnable: Callable,
+        example_inputs: Sequence[Any],
+        static_input_indices: Optional[Sequence[int]] = None,
+        should_check_address: bool = False,
+    ):
+        self.runnable = runnable
+        self.graph_pool = _global_graph_pool
+        self.stream = _global_graph_capture_stream
+        self.static_input_indices = OrderedSet(
+            static_input_indices if static_input_indices is not None else []
+        )
+        self.input_indices_to_copy = [
+            i
+            for i, inp in enumerate(example_inputs)
+            if isinstance(inp, torch.Tensor) and i not in self.static_input_indices
+        ]
+        self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
+        self.has_warmup = False
+
+        self.args = None
+        self.output = None
+
+        # (debug only) whether to check static input tensor addresses at runtime
+        self.should_check_address = should_check_address
+
+    def copy_non_static_inputs(self, *args):
+        for i in self.input_indices_to_copy:
+            self.args[i].copy_(args[i])
+
+    def check_input_types(self, inputs) -> None:
+        for inp in inputs:
+            assert isinstance(inp, (torch.Tensor, int, torch._C.Generator)), (
+                "args must be tensor, integer (for dynamic shapes), "
+                "or Generator (for random number generator), "
+                f"but found {type(inp)}"
+            )
+
+    def check_static_inputs_address(self, args) -> None:
+        for i in self.static_input_indices:
+            actual = args[i].data_ptr()
+            expected = self.input_addresses[i]
+            assert expected == actual, (
+                "Expected the same static tensor address but found "
+                f"{expected} != {actual}"
+            )
+
+    def __call__(self, *args):
+        if not self.has_warmup:
+            self.has_warmup = True
+            device = torch.cuda.current_device()
+
+            # warmup in cudagraph memory pool to avoid fragmentation
+            # across eager memory pool and cudagraph memory pool.
+            with _use_cuda_memory_pool_manager(device, self.graph_pool, self.stream):
+                out = self.runnable(*args)
+            return out
+
+        if self.cudagraph is None:
+            self.check_input_types(args)
+            self.args = args
+            self.input_addresses = [
+                x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args
+            ]
+
+            self.cudagraph = torch.cuda.CUDAGraph()
+
+            with torch.cuda.graph(
+                self.cudagraph, pool=self.graph_pool, stream=self.stream
+            ):
+                # `output` is managed by pytorch's cudagraph pool
+                self.output = self.runnable(*args)
+
+        if self.should_check_address:
+            self.check_static_inputs_address(args)
+
+        self.copy_non_static_inputs(*args)
+        self.cudagraph.replay()
+        return self.output
+
+
+def get_static_input_indices(gm: torch.fx.GraphModule, is_forward: bool) -> list[int]:
+    """
+    Get indices of gm inputs that are static input tensors, i.e., tensors whose
+    addresses do not change across runs. Examples include weights, buffers, and
+    outputs of previously cudagraph-wrapped functions.
+    """
+    from torch._inductor.utils import count_tangents
+
+    static_input_indices = []
+    if (
+        is_forward
+        and (tracing_context := torch._guards.TracingContext.try_get())
+        and hasattr(tracing_context, "fw_metadata")
+    ):
+        # for forward, we rely on graph capture (i.e., dynamo or export) to provide
+        # the correct static input indices stored in the tracing context. Typical
+        # examples include weights and buffers.
+        static_input_indices = tracing_context.fw_metadata.static_input_indices
+
+    elif not is_forward:
+        # for backward, we identify saved tensors as static inputs, since saved tensors
+        # are outputs of the cudagraph-wrapped forward run. In a PT2-generated backward
+        # gm, saved tensors are always the leading args, so we can get the number of
+        # saved tensors and generate the static input indices from it.
+        fixed = count_tangents(gm)
+        static_input_indices = list(range(fixed))
+
+    return static_input_indices
diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py
index 51ac8ba983..e097579cc0 100644
--- a/torchtitan/experiments/compiler_toolkit/graph_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -20,6 +20,7 @@ from torch.distributed.tensor import DTensor
 
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.experiments.compiler_toolkit.common_utils import end_with_pass
 from torchtitan.tools.logging import logger
 
 
@@ -217,6 +218,7 @@ def compiler(
     example_inputs,
     passes: List[Callable] = None,
     dump_folder: str | None = None,
+    is_forward: bool = True,
 ):
     """
     Compile a graph module by applying a sequence of compiler passes.
@@ -239,6 +241,17 @@
     )
     _dump_gm(dump_folder, gm, f"{name}_before_compiler")
 
+    if end_with_pass(passes, ["cudagraph_pass"]):
+        # the cudagraph pass, if applied, is always the last pass
+        cg_pass = passes[-1]
+
+        # to identify static input indices, the cudagraph pass behaves differently
+        # for the forward and backward passes, so we explicitly pass this info.
+        _cg_pass = functools.partial(cg_pass, is_forward=is_forward)
+
+        # keep the function name for the debug log
+        passes[-1] = functools.wraps(cg_pass)(_cg_pass)
+
     for pass_fn in passes:
         pass_name = (
             pass_fn.func.__name__
@@ -271,17 +284,42 @@ def make_compiler_with_passes(
 
     def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
         return compiler(
-            "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
+            "fwd_gm",
+            gm,
+            example_inputs,
+            passes=passes,
+            dump_folder=dump_folder,
+            is_forward=True,
         )
 
     def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
         return compiler(
-            "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
+            "bwd_gm",
+            gm,
+            example_inputs,
+            passes=passes,
+            dump_folder=dump_folder,
+            is_forward=False,
        )
 
     return fw_compiler, bw_compiler
 
 
+def validate_pass_names(pass_names: list[str]) -> None:
+    if "cudagraph" in pass_names:
+        assert (
+            pass_names[-1] == "cudagraph"
+        ), "cudagraph has to be the last pass to apply"
+
+    if (
+        "autobucketing_reordering" in pass_names
+        and "transformer_block_bucketing" in pass_names
+    ):
+        raise ValueError(
+            "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!"
+        )
+
+
 def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig):
     """
     Extract and validate compiler passes from job config.
@@ -298,13 +336,7 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi
     )
 
     pass_names = getattr(job_config.compile, "passes", [])
-    if (
-        "autobucketing_reordering" in pass_names
-        and "transformer_block_bucketing" in pass_names
-    ):
-        raise ValueError(
-            "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!"
-        )
+    validate_pass_names(pass_names)
 
     compiler_passes = []
     for pass_name in pass_names:
diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py
index 64276a91bc..5657eb2b2b 100644
--- a/torchtitan/experiments/compiler_toolkit/passes.py
+++ b/torchtitan/experiments/compiler_toolkit/passes.py
@@ -11,10 +11,16 @@ during compilation. Passes can be selected and configured via job config.
 """
 
+from typing import Any, Sequence
+
 import torch
 from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing
 from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
 from torch.fx.passes.regional_inductor import regional_inductor
 
+from torchtitan.experiments.compiler_toolkit.cudagraph import (
+    CUDAGraphWrapper,
+    get_static_input_indices,
+)
 from torchtitan.experiments.simple_fsdp.reshard_after_forward import (
     annotate_fsdp_all_gather,
 )
@@ -56,6 +62,23 @@ def regional_inductor_pass(
     return regional_inductor(gm, example_inputs)
 
 
+def cudagraph_pass(
+    gm: torch.fx.GraphModule, example_inputs: Sequence[Any], is_forward: bool
+) -> torch.fx.GraphModule:
+    """
+    Apply cudagraph.
+
+    This pass wraps the forward function with cudagraph during compilation, but does
+    not record the cudagraph until runtime.
+    - On the first run, it warms up operators such as nccl.
+    - On the second run, it records and then replays the cudagraph.
+    - On the following runs, it only replays the cudagraph.
+    """
+    static_input_indices = get_static_input_indices(gm, is_forward)
+    gm.forward = CUDAGraphWrapper(gm.forward, example_inputs, static_input_indices)
+    return gm
+
+
 def validate_flex_attn_annotation_pass(
     gm: torch.fx.GraphModule,
 ) -> torch.fx.GraphModule:
@@ -88,4 +111,5 @@ def fsdp_reshard_after_fwd_pass(
     "autobucketing_reordering": autobucketing_reordering_pass,
     "transformer_block_bucketing": transformer_block_bucketing_reordering_pass,
     "regional_inductor": regional_inductor_pass,
+    "cudagraph": cudagraph_pass,
 }
diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py
index b0155a9f2a..f01a1c4380 100644
--- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py
+++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py
@@ -58,6 +58,20 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]:
             "llama3_fsdp_tp_manualbucketing",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--model.name compiler_toolkit.llama3",
+                    "--parallelism.data_parallel_shard_degree 2",
+                    "--parallelism.tensor_parallel_degree 2",
+                    "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
+                    "--compile.passes cudagraph",
+                ],
+            ],
+            "llama3 FSDP+TP+cudagraph",
+            "llama3_fsdp_tp_cudagraph",
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [
@@ -86,6 +100,21 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]:
             "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--model.name compiler_toolkit.llama3",
+                    "--parallelism.data_parallel_shard_degree 2",
+                    "--parallelism.tensor_parallel_degree 2",
+                    "--model.flavor debugmodel_flex_attn",
+                    "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
+                    "--compile.passes autobucketing_reordering,regional_inductor,cudagraph",
+                ],
+            ],
+            "llama3 FSDP+TP+FlexAttn autobucketing regional_inductor+cudagraph",
+            "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor_cudagraph",
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [
diff --git a/torchtitan/experiments/compiler_toolkit/train.py b/torchtitan/experiments/compiler_toolkit/train.py
index 26e3245b2b..7b0d58aa5a 100644
--- a/torchtitan/experiments/compiler_toolkit/train.py
+++ b/torchtitan/experiments/compiler_toolkit/train.py
@@ -4,11 +4,24 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import gc
+
 from torchtitan.train import main, Trainer
 
 
 class CompilerToolkitTrainer(Trainer):
-    pass
+    def close(self) -> None:
+        super().close()
+
+        # Note [explicit cudagraph close]
+        # cudagraph holds references to nccl resources, which prevents destroying
+        # the nccl process group. So we need to explicitly delete the cudagraphs
+        # held in joint_graph_module. An explicit gc.collect() is necessary
+        # to clean up reference cycles.
+        for part in self.model_parts:
+            if hasattr(part, "joint_graph_module"):
+                part.joint_graph_module = None
+        gc.collect()
 
 
 if __name__ == "__main__":
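
Reviewer note (illustrative, not part of the patch): for readers unfamiliar with the warmup / capture / replay protocol that `CUDAGraphWrapper` implements above, the following standalone sketch shows the same three-phase lifecycle on a toy function using only public `torch.cuda` graph APIs. The names `SimpleGraphRunner` and `toy_step` are hypothetical, the shared memory pool and static-input-index machinery are omitted, and a CUDA device is assumed.

```python
# Minimal sketch of the warmup -> capture -> replay protocol (illustrative only).
import torch


def toy_step(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return (x @ w).relu()


class SimpleGraphRunner:
    def __init__(self, fn):
        self.fn = fn
        self.graph = None       # recorded lazily, like CUDAGraphWrapper
        self.warmed_up = False
        self.static_x = None    # captured placeholder for the per-step input
        self.output = None      # output tensor owned by the graph's memory pool

    def __call__(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        if not self.warmed_up:
            # 1st call: eager warm-up on a side stream (the pattern recommended
            # by the torch.cuda.graph docs) so that capture on the 2nd call is safe.
            self.warmed_up = True
            side = torch.cuda.Stream()
            side.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(side):
                out = self.fn(x, w)
            torch.cuda.current_stream().wait_stream(side)
            return out

        if self.graph is None:
            # 2nd call: record the graph. Tensor addresses seen here are baked into
            # the capture, so `x`'s storage becomes the reusable input buffer.
            self.static_x = x
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.output = self.fn(self.static_x, w)

        # 2nd and later calls: copy the new per-step input into the captured buffer
        # (the weight `w` is "static": same address every call), then replay.
        self.static_x.copy_(x)
        self.graph.replay()
        return self.output


if torch.cuda.is_available():
    w = torch.randn(16, 16, device="cuda")      # static input (fixed address)
    runner = SimpleGraphRunner(toy_step)
    for _ in range(4):
        x = torch.randn(8, 16, device="cuda")   # non-static input, copied in
        out = runner(x, w)
    print(out.shape)
```

In the actual pass, `CUDAGraphWrapper` additionally routes the warm-up run into the shared cudagraph memory pool and copies only the non-static inputs, as implemented in `cudagraph.py` above.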