run_train.sh (1 addition, 1 deletion)
@@ -19,6 +19,6 @@ TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}

PYTORCH_ALLOC_CONF="expandable_segments:True" \
TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
-torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
+torchrun --virtual-local-rank --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-m ${TRAIN_FILE} --job.config_file ${CONFIG_FILE} "$@"
torchtitan/config/job_config.py (5 additions, 0 deletions)
@@ -660,7 +660,12 @@ class Compile:
        default_factory=lambda: ["model", "loss"]
    )
    """Which components to compile"""

+    backend: str = "inductor"
+    """Which backend to compile with"""
+
+    enable_precompilation: bool = False
+    """Enable AOT precompilation to save compiled function to disk"""


@dataclass
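For context, a minimal sketch of how these two knobs are consumed; it is not part of the diff, the `components` field name and the scaffolding around it are assumptions, and the enable_aot_compile toggle mirrors the train.py change further below.

# Illustrative sketch only, not torchtitan code.
from dataclasses import dataclass, field

import torch


@dataclass
class Compile:
    components: list[str] = field(default_factory=lambda: ["model", "loss"])
    backend: str = "inductor"
    enable_precompilation: bool = False


cfg = Compile(enable_precompilation=True)
if cfg.enable_precompilation:
    # Mirrors Trainer.__init__ in this PR: flip Dynamo's AOT-compile switch
    # before the model is compiled so the compiled function can be saved to disk.
    torch._dynamo.config.enable_aot_compile = True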
torchtitan/experiments/simple_fsdp/simple_fsdp.py (14 additions, 1 deletion)
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+import sys
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
@@ -159,6 +160,8 @@ def _distribute_dtensor(
    )


+# global counter to ensure unique class names for precompilation
+_wrap_class_counter = 0
def _register_parametrization(
    module: nn.Module, param_names: list[str], parametrization: nn.Module
) -> None:
@@ -169,18 +172,28 @@ def _register_parametrization(
    TODO: In checkpoint saving/loading, avoid parametrization calls when calling
    get_model_state_dict func in torchtitan's torchtitan/components/checkpoint.py.
    """
+    global _wrap_class_counter
+    _wrap_class_counter += 1
    param_name_to_property = {
        param_name: property(
            lambda self, pn=param_name: parametrization(self._parameters[pn])
        )
        for param_name in param_names
    }
    module_cls = type(
-        f"SimpleFSDP{module.__class__.__name__}",
+        f"SimpleFSDP{module.__class__.__name__}_{_wrap_class_counter}",
        (module.__class__,),
        param_name_to_property,
    )
    module.__class__ = module_cls
+    module_cls.__module__ = module.__class__.__module__
+    # Expose the dynamically created subclass as a real, importable symbol in its module.
+    # This is necessary for precompilation to work.
+    setattr(
+        sys.modules[module_cls.__module__],
+        module_cls.__name__,
+        module_cls,
+    )


def fsdp_policy():
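An aside on the setattr above: a standalone sketch, in plain Python rather than torchtitan code, of why a type()-created subclass must be published on its defining module before any machinery that resolves classes by name (pickling stands in here for the precompiled artifact's class lookup) can find it.

import pickle
import sys


class Linear:  # stands in for the wrapped nn.Module subclass
    pass


_wrap_class_counter = 1  # unique suffix, as in _register_parametrization
cls = type(f"SimpleFSDPLinear_{_wrap_class_counter}", (Linear,), {})
cls.__module__ = __name__
# Without this setattr, pickle.dumps below raises
# "Can't pickle ...: attribute lookup ... failed".
setattr(sys.modules[__name__], cls.__name__, cls)

restored = pickle.loads(pickle.dumps(cls()))
assert type(restored).__name__ == "SimpleFSDPLinear_1"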
torchtitan/train.py (58 additions, 2 deletions)
@@ -74,6 +74,9 @@ def __init__(self, job_config: JobConfig):

        self.job_config = job_config

+        if job_config.compile.enable_precompilation:
Contributor:
qq: Is this simplefsdp-only, or does it also work with fsdp2 + block-level compile?

Maybe you want to add this config to apply_compile here for fsdp2:

def apply_compile(model: nn.Module, compile_config: CompileConfig):
    """
    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = torch.compile(
            transformer_block, backend=compile_config.backend, fullgraph=True
        )
        model.layers.register_module(layer_id, transformer_block)
    logger.info("Compiling each TransformerBlock with torch.compile")

and here for simplefsdp: https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/simple_fsdp/llama3/parallelize.py#L151-L152?

Contributor (author):
Currently only simplefsdp but this should work with fsdp2+block-level compile with some additional work.

+            torch._dynamo.config.enable_aot_compile = True
+
        logger.info(f"Starting job: {job_config.job.description}")

        if job_config.experimental.custom_import:
@@ -380,6 +383,53 @@ def init_distributed(self) -> ParallelDims:
            world_size=world_size,
        )

+    def _get_precompiled_function_path(self) -> str:
+        """
+        Generate a unique path for the precompiled function based on model configuration.
+
+        Returns:
+            Path to the precompiled function file.
+        """
+        rank = int(os.environ["RANK"])
+        model_name = self.job_config.model.name
+        model_flavor = self.job_config.model.flavor
+
+        # Create a unique filename based on model configuration and rank
+        filename = f"compiled_fn_{model_name}_{model_flavor}_rank_{rank}.pt"
+        return os.path.join("/tmp", filename)
Contributor:
This isn't a realistic file path for training on FB infra, as /tmp is cleared if you restart training.

Contributor (author):
Agreed. For FB infra, we would either package the artifact into the conda or fbpkg build, or place it in oilfs and keep a reference to it. For Torchtitan, using /tmp seemed acceptable, though I can make the location configurable through an environment variable. Did you have a different approach in mind?
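A possible shape for that follow-up, sketched here purely as an illustration; the TORCHTITAN_PRECOMPILE_DIR variable name is made up and not part of this PR.

import os


def _get_precompiled_function_dir() -> str:
    # Fall back to /tmp (the current behavior) unless overridden.
    return os.environ.get("TORCHTITAN_PRECOMPILE_DIR", "/tmp")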


+    def _load_or_compile_model(
+        self,
+        model: torch.nn.Module,
+        inputs: torch.Tensor,
+        extra_inputs: dict[str, torch.Tensor],
+        extra_kwargs: dict[str, Any],
+    ) -> torch.Tensor:
+        """
+        Load a precompiled model function or compile and save it if not available.
+
+        Args:
+            model: The model to compile.
+            inputs: Main input tensor.
+            extra_inputs: Additional input tensors.
+            extra_kwargs: Additional keyword arguments.
+
+        Returns:
+            Model output predictions.
+        """
+        compiled_fn_path = self._get_precompiled_function_path()
+
+        if not os.path.exists(compiled_fn_path):
+            logger.info(f"Compiling model and saving to {compiled_fn_path}")
+            model.forward \
+                .aot_compile(((inputs,), {**extra_inputs, **extra_kwargs})) \
+                .save_compiled_function(compiled_fn_path)
+
+        with open(compiled_fn_path, "rb") as f:
+            return torch.compiler.load_compiled_function(f)(
+                model._orig_mod, inputs, **extra_inputs, **extra_kwargs
+            )
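To make the caching behavior concrete, a small worked example of the path keying; "llama3" and "8B" are only sample values, not something this PR pins down. On the first step without an existing artifact the model is compiled and the file below is written; every later step, and any restart that still has the file, takes only the load branch.

import os


def precompiled_path(model_name: str, model_flavor: str, rank: int) -> str:
    # Same keying as _get_precompiled_function_path above.
    return os.path.join(
        "/tmp", f"compiled_fn_{model_name}_{model_flavor}_rank_{rank}.pt"
    )


assert precompiled_path("llama3", "8B", 0) == "/tmp/compiled_fn_llama3_8B_rank_0.pt"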

    def batch_generator(
        self, data_iterable: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
    ) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]:
@@ -524,8 +574,14 @@ def forward_backward_step(
        with self.train_context(optional_context_parallel_ctx):
            assert len(model_parts) == 1
            with self.maybe_enable_amp:
-                pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
-                loss = self.loss_fn(pred, labels)
+                if self.job_config.compile.enable_precompilation:
+                    pred = self._load_or_compile_model(
+                        model_parts[0], inputs, extra_inputs, extra_kwargs
+                    )
+                    loss = self.loss_fn(pred, labels)
+                else:
+                    pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
+                    loss = self.loss_fn(pred, labels)
            # need to free pred before bwd to avoid peaking memory
            del pred
            loss.backward()