diff --git a/run_train.sh b/run_train.sh
index 8aaf55de28..8981401ad7 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -19,6 +19,6 @@ TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
 
 PYTORCH_ALLOC_CONF="expandable_segments:True" \
 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
-torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
+torchrun --virtual-local-rank --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 -m ${TRAIN_FILE} --job.config_file ${CONFIG_FILE} "$@"
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 95588d2c3b..4e296b4d84 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -660,7 +660,12 @@ class Compile:
         default_factory=lambda: ["model", "loss"]
     )
     """Which components to compile"""
 
+    backend: str = "inductor"
+    """Which backend to compile with"""
+
+    enable_precompilation: bool = False
+    """Enable AOT precompilation to save compiled function to disk"""
 
 
 @dataclass
diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
index 737b6d3ec2..e0a295262b 100644
--- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py
+++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import sys
 from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -159,6 +160,8 @@ def _distribute_dtensor(
     )
 
 
+# global counter to ensure unique class names for precompilation
+_wrap_class_counter = 0
 def _register_parametrization(
     module: nn.Module, param_names: list[str], parametrization: nn.Module
 ) -> None:
@@ -169,6 +172,8 @@
     TODO: In checkpoint saving/loading, avoid parametrization calls when calling
     get_model_state_dict func in torchtitan's torchtitan/components/checkpoint.py.
     """
+    global _wrap_class_counter
+    _wrap_class_counter += 1
     param_name_to_property = {
         param_name: property(
             lambda self, pn=param_name: parametrization(self._parameters[pn])
@@ -176,11 +181,19 @@
         for param_name in param_names
     }
     module_cls = type(
-        f"SimpleFSDP{module.__class__.__name__}",
+        f"SimpleFSDP{module.__class__.__name__}_{_wrap_class_counter}",
         (module.__class__,),
         param_name_to_property,
     )
     module.__class__ = module_cls
+    module_cls.__module__ = module.__class__.__module__
+    # Expose the dynamically created subclass as a real, importable symbol in its module.
+    # This is necessary for precompilation to work.
+    setattr(
+        sys.modules[module_cls.__module__],
+        module_cls.__name__,
+        module_cls,
+    )
 
 
 def fsdp_policy():
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 0070806e94..d903600ffc 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -74,6 +74,9 @@ def __init__(self, job_config: JobConfig):
 
         self.job_config = job_config
 
+        if job_config.compile.enable_precompilation:
+            torch._dynamo.config.enable_aot_compile = True
+
         logger.info(f"Starting job: {job_config.job.description}")
 
         if job_config.experimental.custom_import:
@@ -380,6 +383,53 @@ def init_distributed(self) -> ParallelDims:
             world_size=world_size,
         )
 
+    def _get_precompiled_function_path(self) -> str:
+        """
+        Generate a unique path for the precompiled function based on model configuration.
+
+        Returns:
+            Path to the precompiled function file.
+        """
+        rank = int(os.environ["RANK"])
+        model_name = self.job_config.model.name
+        model_flavor = self.job_config.model.flavor
+
+        # Create a unique filename based on model configuration and rank
+        filename = f"compiled_fn_{model_name}_{model_flavor}_rank_{rank}.pt"
+        return os.path.join("/tmp", filename)
+
+    def _load_or_compile_model(
+        self,
+        model: torch.nn.Module,
+        inputs: torch.Tensor,
+        extra_inputs: dict[str, torch.Tensor],
+        extra_kwargs: dict[str, Any],
+    ) -> torch.Tensor:
+        """
+        Load a precompiled model function or compile and save it if not available.
+
+        Args:
+            model: The model to compile.
+            inputs: Main input tensor.
+            extra_inputs: Additional input tensors.
+            extra_kwargs: Additional keyword arguments.
+
+        Returns:
+            Model output predictions.
+        """
+        compiled_fn_path = self._get_precompiled_function_path()
+
+        if not os.path.exists(compiled_fn_path):
+            logger.info(f"Compiling model and saving to {compiled_fn_path}")
+            model.forward \
+                .aot_compile(((inputs,), {**extra_inputs, **extra_kwargs})) \
+                .save_compiled_function(compiled_fn_path)
+
+        with open(compiled_fn_path, "rb") as f:
+            return torch.compiler.load_compiled_function(f)(
+                model._orig_mod, inputs, **extra_inputs, **extra_kwargs
+            )
+
     def batch_generator(
         self, data_iterable: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
     ) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]:
@@ -524,8 +574,14 @@ def forward_backward_step(
         with self.train_context(optional_context_parallel_ctx):
             assert len(model_parts) == 1
             with self.maybe_enable_amp:
-                pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
-                loss = self.loss_fn(pred, labels)
+                if self.job_config.compile.enable_precompilation:
+                    pred = self._load_or_compile_model(
+                        model_parts[0], inputs, extra_inputs, extra_kwargs
+                    )
+                    loss = self.loss_fn(pred, labels)
+                else:
+                    pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
+                    loss = self.loss_fn(pred, labels)
             # need to free pred before bwd to avoid peaking memory
             del pred
             loss.backward()
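
For reference, the save/load flow that _load_or_compile_model wires into the trainer can be exercised outside torchtitan. The sketch below is illustrative only: it assumes a recent PyTorch build that exposes the experimental AOT-compile hooks referenced in this diff (torch._dynamo.config.enable_aot_compile, forward.aot_compile, save_compiled_function, torch.compiler.load_compiled_function); the toy model, tensor shapes, cache path, and the plain torch.compile wrapper standing in for torchtitan's compile component are made up for the example.

import os

import torch
import torch.nn as nn

# Opt in to the experimental AOT-compile path, as Trainer.__init__ does when
# compile.enable_precompilation is set (assumes a PyTorch build with this flag).
torch._dynamo.config.enable_aot_compile = True

model = nn.Linear(16, 16)
compiled_model = torch.compile(model, fullgraph=True)  # stand-in for torchtitan's compile step
example_input = torch.randn(4, 16)
path = "/tmp/compiled_fn_example.pt"  # stand-in for _get_precompiled_function_path()

if not os.path.exists(path):
    # First run: compile ahead of time for the example (args, kwargs) pair and
    # persist the artifact to disk, mirroring _load_or_compile_model above.
    compiled_model.forward.aot_compile(((example_input,), {})).save_compiled_function(path)

with open(path, "rb") as f:
    compiled_fn = torch.compiler.load_compiled_function(f)

# Later runs: skip recompilation and call the loaded function with the original
# (unwrapped) module plus the forward inputs, like the model._orig_mod call in the trainer.
out = compiled_fn(model, example_input)

In the trainer itself the feature is toggled through the new compile.enable_precompilation option, and the cached artifact is written under /tmp keyed by model name, flavor, and rank, so each rank compiles once and reuses its own artifact on later runs.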