-
Notifications
You must be signed in to change notification settings - Fork 619
[precompile] add ability to precompile torchtitan models #2092
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -74,6 +74,9 @@ def __init__(self, job_config: JobConfig): | |||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.job_config = job_config | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if job_config.compile.enable_precompilation: | ||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. qq. Is this for simplefsdp-only or also works for fsdp2+block-level compile? maybe you want to add this config to torchtitan/torchtitan/models/llama3/infra/parallelize.py Lines 236 to 247 in cbdb311
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently only simplefsdp but this should work with fsdp2+block-level compile with some additional work. |
||||||||||||||||||||||||||
| torch._dynamo.config.enable_aot_compile = True | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| logger.info(f"Starting job: {job_config.job.description}") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if job_config.experimental.custom_import: | ||||||||||||||||||||||||||
|
|
@@ -380,6 +383,53 @@ def init_distributed(self) -> ParallelDims: | |||||||||||||||||||||||||
| world_size=world_size, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _get_precompiled_function_path(self) -> str: | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| Generate a unique path for the precompiled function based on model configuration. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||
| Path to the precompiled function file. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| rank = int(os.environ["RANK"]) | ||||||||||||||||||||||||||
| model_name = self.job_config.model.name | ||||||||||||||||||||||||||
| model_flavor = self.job_config.model.flavor | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Create a unique filename based on model configuration and rank | ||||||||||||||||||||||||||
| filename = f"compiled_fn_{model_name}_{model_flavor}_rank_{rank}.pt" | ||||||||||||||||||||||||||
| return os.path.join("/tmp", filename) | ||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This isn't a realistic file path for training on FB infra, as the tmp is cleared if you restart training
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed. For FB infra, we would either package the artifact into the conda or fbpkg build, or place it in oilfs and keep a reference to it. For Torchtitan, using /tmp seemed acceptable, though I can make the location configurable through an environment variable. Did you have a different approach in mind? |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _load_or_compile_model( | ||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||
| model: torch.nn.Module, | ||||||||||||||||||||||||||
| inputs: torch.Tensor, | ||||||||||||||||||||||||||
| extra_inputs: dict[str, torch.Tensor], | ||||||||||||||||||||||||||
| extra_kwargs: dict[str, Any], | ||||||||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| Load a precompiled model function or compile and save it if not available. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||
| model: The model to compile. | ||||||||||||||||||||||||||
| inputs: Main input tensor. | ||||||||||||||||||||||||||
| extra_inputs: Additional input tensors. | ||||||||||||||||||||||||||
| extra_kwargs: Additional keyword arguments. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||
| Model output predictions. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| compiled_fn_path = self._get_precompiled_function_path() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if not os.path.exists(compiled_fn_path): | ||||||||||||||||||||||||||
| logger.info(f"Compiling model and saving to {compiled_fn_path}") | ||||||||||||||||||||||||||
| model.forward \ | ||||||||||||||||||||||||||
| .aot_compile(((inputs,), {**extra_inputs, **extra_kwargs})) \ | ||||||||||||||||||||||||||
| .save_compiled_function(compiled_fn_path) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| with open(compiled_fn_path, "rb") as f: | ||||||||||||||||||||||||||
| return torch.compiler.load_compiled_function(f)( | ||||||||||||||||||||||||||
| model._orig_mod, inputs, **extra_inputs, **extra_kwargs | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def batch_generator( | ||||||||||||||||||||||||||
| self, data_iterable: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] | ||||||||||||||||||||||||||
| ) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]: | ||||||||||||||||||||||||||
|
|
@@ -524,8 +574,14 @@ def forward_backward_step( | |||||||||||||||||||||||||
| with self.train_context(optional_context_parallel_ctx): | ||||||||||||||||||||||||||
| assert len(model_parts) == 1 | ||||||||||||||||||||||||||
| with self.maybe_enable_amp: | ||||||||||||||||||||||||||
| pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) | ||||||||||||||||||||||||||
| loss = self.loss_fn(pred, labels) | ||||||||||||||||||||||||||
| if self.job_config.compile.enable_precompilation: | ||||||||||||||||||||||||||
| pred = self._load_or_compile_model( | ||||||||||||||||||||||||||
| model_parts[0], inputs, extra_inputs, extra_kwargs | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| loss = self.loss_fn(pred, labels) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) | ||||||||||||||||||||||||||
| loss = self.loss_fn(pred, labels) | ||||||||||||||||||||||||||
| # need to free pred before bwd to avoid peaking memory | ||||||||||||||||||||||||||
| del pred | ||||||||||||||||||||||||||
| loss.backward() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please see also https://docs.google.com/document/d/1FqUXYCaoHTQy40anvKVSAv9Ci7yIWCUdZxbWMwHN6is/edit?tab=t.0#heading=h.aq5mvgrni90o