diff --git a/.github/workflows/integration_test_8gpu_full_dtensor.yaml b/.github/workflows/integration_test_8gpu_full_dtensor.yaml new file mode 100644 index 000000000..adf6b20c5 --- /dev/null +++ b/.github/workflows/integration_test_8gpu_full_dtensor.yaml @@ -0,0 +1,53 @@ +name: Full DTensor 8 GPU Integration Tests + +on: + push: + branches: [ main ] + paths: + - 'torchtitan/experiments/full_dtensor/**' + - '.github/workflows/integration_test_8gpu_full_dtensor.yaml' + pull_request: + paths: + - 'torchtitan/experiments/full_dtensor/**' + - '.github/workflows/integration_test_8gpu_full_dtensor.yaml' + schedule: + # Runs every 12 hours + - cron: '0 */12 * * *' + +concurrency: + group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} + cancel-in-progress: true + +defaults: + run: + shell: bash -l -eo pipefail {0} + +jobs: + build-test: + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + runner: linux.g5.48xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: "12.6" + # This image is faster to clone than the default, but it lacks CC needed by triton + # (1m25s vs 2m37s). + docker-image: torchtitan-ubuntu-20.04-clang12 + repository: pytorch/torchtitan + upload-artifact: outputs + script: | + set -eux + + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + # Log CUDA driver version for debugging. + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1 || true) + echo "CUDA driver version: ${DRIVER_VERSION}" + + pip config --user set global.progress_bar off + + python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + + mkdir artifacts-to-be-uploaded + TRAIN_FILE=torchtitan.experiments.full_dtensor.train python -m torchtitan.experiments.full_dtensor.tests.integration_tests artifacts-to-be-uploaded --ngpu 8 diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index f6f813bfa..49ca1b9f9 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -9,6 +9,7 @@ "gpt_oss", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", + "full_dtensor.llama3", "vlm", "compiler_toolkit.deepseek_v3", "compiler_toolkit.llama3", diff --git a/torchtitan/experiments/full_dtensor/llama3/__init__.py b/torchtitan/experiments/full_dtensor/llama3/__init__.py new file mode 100644 index 000000000..cdfa0b90b --- /dev/null +++ b/torchtitan/experiments/full_dtensor/llama3/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
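+#
+# This package registers the `full_dtensor.llama3` train spec (matching the entry added
+# to torchtitan/experiments/__init__.py and the `--model.name full_dtensor.llama3` flag
+# used by the integration test). It reuses SimpleFSDPTransformer and the stock llama3
+# model args; only the `parallelize_llama` imported from .parallelize is specific to
+# the full-DTensor experiment.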
+ +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.experiments.simple_fsdp.llama3.model import SimpleFSDPTransformer +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.models.llama3 import llama3_args +from torchtitan.protocols.train_spec import TrainSpec + +from .parallelize import parallelize_llama + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=SimpleFSDPTransformer, + model_args=llama3_args, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) diff --git a/torchtitan/experiments/full_dtensor/llama3/parallelize.py b/torchtitan/experiments/full_dtensor/llama3/parallelize.py new file mode 100644 index 000000000..0cc952537 --- /dev/null +++ b/torchtitan/experiments/full_dtensor/llama3/parallelize.py @@ -0,0 +1,264 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import functools +from typing import Any, Callable + +import torch +import torch.nn as nn +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims + +from torchtitan.experiments.compiler_toolkit.graph_utils import ( + CompiledModule, + joint_graph_builder, + make_compiler_with_passes, +) +from torchtitan.experiments.simple_fsdp.reshard_after_forward import ( + annotate_fsdp_all_gather, +) +from torchtitan.experiments.simple_fsdp.simple_fsdp import ( + data_parallel, + MixedPrecisionPolicy, +) +from torchtitan.tools.logging import logger + + +def _get_dp_mesh(parallel_dims: ParallelDims) -> DeviceMesh: + if parallel_dims.dp_replicate_enabled: + if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_replicate",) + else: + dp_mesh_dim_names = ("dp_shard_cp",) + + return parallel_dims.world_mesh[tuple(dp_mesh_dim_names)] + + +def _get_spmd_mesh(parallel_dims: ParallelDims) -> DeviceMesh: + return _get_dp_mesh(parallel_dims) + + +def apply_dp( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +) -> nn.Module: + if parallel_dims.dp_replicate_enabled: + if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: + dp_mode = "hybrid_shard" + else: + dp_mode = "replicate" + else: + dp_mode = "fully_shard" + + mp_policy = MixedPrecisionPolicy( + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + ) + + model = data_parallel( + model, + _get_dp_mesh(parallel_dims), + mode=dp_mode, + mp_policy=mp_policy, + full_dtensor=True, + ) + logger.info( + "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode + ) + return model + + +def parallelize_llama( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +) -> nn.Module: + 
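+    """
+    Apply parallelisms and compilation to the model while keeping parameters as full
+    DTensors (buffers and inputs are converted separately, by parallelize_buffers and
+    by FullDTensorTrainer respectively).
+
+    Currently only data parallelism (simple_fsdp's data_parallel with full_dtensor=True)
+    followed by compiler-toolkit compilation is supported; the CP, TP, and AC branches
+    below raise NotImplementedError.
+    """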
+    if parallel_dims.cp_enabled:
+        # TODO: SDPA + CP enablement:
+        # Dependency: https://github.com/pytorch/pytorch/pull/167381 (sharding rule fix)
+        # Goal: Enable the Shard(2) -> Replicate() transition on the CP mesh placement.
+        #
+        # Implementation options for handling the required allgather:
+        # 1. Transform into ring attention (requires converting the current
+        #    implementation to an async TP-like operation)
+        # 2. Retain the explicit allgather (approach used in Llama 4)
+
+        # TODO: FlexAttention + CP enablement:
+        # Need to resolve DTensor + FlexAttention compatibility issues.
+
+        raise NotImplementedError("CP is not implemented yet.")
+
+    if parallel_dims.tp_enabled:
+        # TODO: TP parallelization strategy - key architectural decision:
+        #
+        # Option 1: Parallelize parameters directly (current design approach)
+        #    - Apply the TP dimension immediately at this point
+        #    - Requires _StridedShard for implementation
+        #
+        # Option 2: Record the placement and apply full placements later
+        #    - Record the TP dimension placement now, apply the full placement later
+        #      together with the DP dimension
+        #    - No need to use _StridedShard, we can just use Shard()
+        #
+        # It's most likely that we will go with option 2, as we are going to use
+        # parameterization to handle the full parameter transformation, which
+        # makes option 2 more natural.
+        raise NotImplementedError("TP is not implemented yet.")
+
+    if job_config.activation_checkpoint.mode != "none":
+        # TODO: Graph-based AC.
+        raise NotImplementedError("AC is not implemented yet.")
+
+    # TODO: CP integration challenge:
+    #
+    # Problem:
+    # When CP is enabled, the mesh structure becomes ["dp_replicate", "dp_shard", "cp"]
+    # to maintain sequence sharding in DTensor. However, naively applying data_parallel
+    # may trigger two separate allgather operations because DTensor.redistribute cannot
+    # recognize that the two mesh dimensions can be flattened into a single allgather.
+    #
+    # Potential solution using SimpleFSDP:
+    # 1. Transform the mesh: ["dp_replicate", "dp_shard", "cp"] -> ["dp_replicate", "dp_shard_cp"]
+    #    via to_local() and from_local()
+    # 2. Redistribute the placement on the ["dp_shard_cp"] dimension
+    # 3. Transform the mesh back: ["dp_replicate", "dp_shard_cp"] -> ["dp_replicate", "dp_shard", "cp"]
+    #    via to_local() and from_local()
+    #
+    # Note: This solution leaves the dp_shard process group unused (we could
+    # initialize it with a fake backend).
+    #
+    # Note: We may be able to implement this solution with parameterization directly.
+
+    # Keep cp_enabled here to remind us that cp_enabled=True requires data_parallel
+    if (
+        parallel_dims.dp_replicate_enabled
+        or parallel_dims.dp_shard_enabled
+        or parallel_dims.cp_enabled
+    ):
+        model = apply_dp(model, parallel_dims, job_config)
+
+    # Apply compilation after SPMD parallelization is complete. This differs from
+    # eager mode parallelization, where compilation occurs earlier.
+    return apply_compile(model, parallel_dims, job_config)
+
+
+def parallelize_buffers(
+    model: nn.Module,
+    parallel_dims: ParallelDims,
+) -> nn.Module:
+    # Buffer-to-mesh mapping in multi-SPMD scenarios:
+    #
+    # When buffers are used with different SPMD meshes (e.g., dense vs. sparse meshes),
+    # we will need an explicit mapping to associate each buffer with its corresponding
+    # mesh. This indicates that the current implementation is not general enough to
+    # support nD meshes.
+    #
+    # The solution is to reparameterize the buffers together with the parameters
+    # within a module.
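+    #
+    # For the current 1D/2D dp meshes, the transformation below amounts to (buffer
+    # names are illustrative, e.g. llama3's `freqs_cis`):
+    #     dtensor_buf = DTensor.from_local(buf, spmd_mesh, (Replicate(),) * spmd_mesh.ndim)
+    # i.e. every buffer is treated as replicated across the SPMD mesh.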
+    spmd_mesh = _get_spmd_mesh(parallel_dims)
+    # Materialize the placements as a tuple; a generator would be exhausted after the
+    # first DTensor.from_local() call below.
+    placements = tuple(Replicate() for _ in range(spmd_mesh.ndim))
+    for m in model.modules():
+        buffers = {
+            name: DTensor.from_local(b, spmd_mesh, placements)
+            for name, b in m.named_buffers(recurse=False)
+        }
+        for name, b in buffers.items():
+            setattr(m, name, b)
+
+    return model
+
+
+def build_parallelize_inputs_fn(
+    parallel_dims: ParallelDims,
+) -> Callable[[torch.Tensor, torch.Tensor], tuple[DTensor, DTensor]]:
+    spmd_mesh = _get_spmd_mesh(parallel_dims)
+
+    # TODO: We need to make this more general to support nD meshes. But we can do this
+    # after the DeviceMesh revamp PR has landed.
+    spmd_placements = []
+    if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
+        spmd_placements.append(Shard(0))
+    if parallel_dims.dp_replicate_enabled:
+        spmd_placements.append(Shard(0))
+
+    def parallelize_inputs(
+        inputs: torch.Tensor, labels: torch.Tensor
+    ) -> tuple[DTensor, DTensor]:
+        inputs = DTensor.from_local(inputs, spmd_mesh, spmd_placements)
+        labels = DTensor.from_local(labels, spmd_mesh, spmd_placements)
+        return inputs, labels
+
+    return parallelize_inputs
+
+
+def joint_custom_pass_builder(
+    parallel_dims: ParallelDims, job_config: JobConfig
+) -> Callable:
+    match job_config.parallelism.fsdp_reshard_after_forward:
+        case "always":
+            fsdp_reshard_after_forward = True
+        case "never":
+            fsdp_reshard_after_forward = False
+        case "default":
+            # For PP, by default do not reshard after forward to avoid per-microbatch
+            # all-gathers, which can be expensive and non-overlapped.
+            fsdp_reshard_after_forward = not parallel_dims.pp_enabled
+        case _:
+            raise ValueError(
+                "Invalid fsdp_reshard_after_forward_policy: "
+                f"{job_config.parallelism.fsdp_reshard_after_forward}."
+            )
+
+    def joint_ac_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        gm = annotate_fsdp_all_gather(gm, fsdp_reshard_after_forward)
+        gm.recompile()
+        return gm
+
+    def joint_custom_pass(joint_with_descriptors) -> None:
+        # TODO: Is this safe? Or should we use update_joint_with_descriptors from auto_parallel?
+        joint_with_descriptors.graph_module = joint_ac_pass(
+            joint_with_descriptors.graph_module
+        )
+
+    return joint_custom_pass
+
+
+def apply_compile(
+    model: nn.Module, parallel_dims: ParallelDims, job_config: JobConfig
+) -> nn.Module:
+    # TODO: This API currently only implements the compiler-toolkit path.
+    # We should also add torch.compile() support.
+
+    if not (job_config.compile.enable and "model" in job_config.compile.passes):
+        return model
+
+    compiler_passes = []
+    # Create compilers with the specified passes (defaults to no passes)
+    fw_compiler, bw_compiler = make_compiler_with_passes(
+        compiler_passes, dump_folder=job_config.job.dump_folder
+    )
+
+    # Create a custom joint_graph_builder with llama-specific compilers
+    llama_joint_graph_builder = functools.partial(
+        joint_graph_builder,
+        fw_compiler=fw_compiler,
+        bw_compiler=bw_compiler,
+        joint_custom_pass=joint_custom_pass_builder(parallel_dims, job_config),
+        dump_folder=job_config.job.dump_folder,
+    )
+
+    # The full DTensor trainer converts the inputs to DTensor, so we don't need
+    # CompiledModule to do it.
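+    # The identity function below just preserves the (mesh, args, kwargs) ->
+    # (args, kwargs) hook signature; since the trainer has already converted the
+    # inputs to DTensors, it passes them through unchanged.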
+ def dummy_parallelize_inputs( + mesh: DeviceMesh, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + return args, kwargs + + return CompiledModule( + model, parallel_dims, llama_joint_graph_builder, dummy_parallelize_inputs + ) diff --git a/torchtitan/experiments/full_dtensor/tests/__init__.py b/torchtitan/experiments/full_dtensor/tests/__init__.py new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/torchtitan/experiments/full_dtensor/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtitan/experiments/full_dtensor/tests/integration_tests.py b/torchtitan/experiments/full_dtensor/tests/integration_tests.py new file mode 100644 index 000000000..46ef7e186 --- /dev/null +++ b/torchtitan/experiments/full_dtensor/tests/integration_tests.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +from tests.integration_tests import OverrideDefinitions +from tests.integration_tests.run_tests import run_tests + + +def build_full_dtensor_test_list() -> list[OverrideDefinitions]: + """ + returns a list of OverrideDefinitions that is used to generate + variations of integration tests based on the same root config file. + """ + integration_tests_flavors = [ + # llama3 FSDP test + OverrideDefinitions( + [ + [ + "--model.name full_dtensor.llama3", + "--parallelism.data_parallel_shard_degree 8", + "--activation_checkpoint.mode none", + ], + ], + "llama3 full dtensor fsdp", + "llama3_full_dtensor_fsdp", + ngpu=8, + ), + ] + return integration_tests_flavors + + +_TEST_SUITES_FUNCTION = { + "full_dtensor": build_full_dtensor_test_list, +} + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("output_dir") + parser.add_argument( + "--config_path", + default="./tests/integration_tests/base_config.toml", + help="Base config path for integration tests. This is the config that will be used as a base for all tests.", + ) + parser.add_argument( + "--test_name", + default="all", + help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)", + ) + parser.add_argument("--ngpu", default=8, type=int) + args = parser.parse_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + if os.listdir(args.output_dir): + raise RuntimeError("Please provide an empty output directory.") + + test_list = _TEST_SUITES_FUNCTION["full_dtensor"]() + run_tests(args, test_list) + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/full_dtensor/train.py b/torchtitan/experiments/full_dtensor/train.py new file mode 100644 index 000000000..8ee4635a3 --- /dev/null +++ b/torchtitan/experiments/full_dtensor/train.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
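+#
+# FullDTensorTrainer extends the default Trainer so that model buffers and per-step
+# inputs/labels are converted to DTensors before the forward pass. The integration
+# test workflow selects it via TRAIN_FILE=torchtitan.experiments.full_dtensor.train,
+# combined with --model.name full_dtensor.llama3 from the test config.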
+
+
+from typing import Any
+
+import torch
+
+from torchtitan.train import main, Trainer
+
+from .llama3.parallelize import build_parallelize_inputs_fn, parallelize_buffers
+
+
+class FullDTensorTrainer(Trainer):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        # parallelize_buffers must be called here instead of in parallelize_fn because
+        # buffers are re-initialized after parallelize_fn executes. The current buffer
+        # initialization creates regular tensors rather than DTensors.
+        # NOTE: This function will likely be removed when we rewrite parameterization.
+        for m in self.model_parts:
+            parallelize_buffers(m, self.parallel_dims)
+
+        self.parallelize_inputs = build_parallelize_inputs_fn(self.parallel_dims)
+
+    def post_dataloading_process(
+        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
+        inputs = input_dict["input"]
+
+        extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
+        # Arguments like attention_masks must go in a separate dict since extra_inputs
+        # are not forwarded to other stages in PP, but extra_kwargs are.
+        extra_kwargs: dict[str, Any] = {}
+
+        # We could let the model perform parallelize_inputs, but calling it here in the
+        # trainer preserves the option of implementing dataloading pipelining, which
+        # offloads logic (e.g., CP load balancing) to the CPU and overlaps it with the
+        # previous forward(). We also need to consider how PP shards inputs along the
+        # batch dimension. For now, keep this callsite in the trainer.
+        inputs, labels = self.parallelize_inputs(inputs, labels)
+
+        assert isinstance(inputs, torch.distributed.tensor.DTensor), type(inputs)
+        assert isinstance(labels, torch.distributed.tensor.DTensor), type(labels)
+
+        return inputs, labels, extra_inputs, extra_kwargs
+
+
+if __name__ == "__main__":
+    main(FullDTensorTrainer)
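+
+# Rough sketch of the input transform above (shapes are illustrative): with
+# --parallelism.data_parallel_shard_degree 8 and a per-rank batch of shape
+# [local_bs, seq], parallelize_inputs wraps each rank's local tensor via
+# DTensor.from_local(inputs, spmd_mesh, [Shard(0)]), so the model sees a single
+# logical [8 * local_bs, seq] DTensor sharded along the batch dimension.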