From df1bc6ab18296632bab61a7481767254d3fc511f Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 31 Jan 2025 10:40:33 -0800 Subject: [PATCH 01/17] Update [ghstack-poisoned] --- torchtitan/config_manager.py | 16 ++++++ torchtitan/model_spec.py | 27 +++++++++ torchtitan/models/__init__.py | 87 ++++++++++++++++++++++++++--- torchtitan/models/llama/__init__.py | 17 +++++- torchtitan/models/llama/model.py | 5 +- train.py | 21 +++---- 6 files changed, 150 insertions(+), 23 deletions(-) create mode 100644 torchtitan/model_spec.py diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index d59e34bc66..f5c01f78ae 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import importlib import sys from collections import defaultdict from typing import Tuple, Union @@ -375,6 +376,12 @@ def __init__(self): The default value is 'allgather'. """, ) + self.parser.add_argument( + "--experimental.model_module_path", + type=str, + default="", + help="", + ) self.parser.add_argument( "--training.mixed_precision_param", type=str, @@ -638,6 +645,15 @@ def parse_args(self, args_list: list = sys.argv[1:]): exp["pipeline_parallel_split_points"] ) + if ( + "experimental" in args_dict + and "model_module_path" in args_dict["experimental"] + and args_dict["experimental"]["model_module_path"] + ): + from torchtitan.models import add_model_spec_path + + add_model_spec_path(args_dict["experimental"]["model_module_path"]) + # override args dict with cmd_args cmd_args_dict = self._args_to_two_level_dict(cmd_args) for section, section_args in cmd_args_dict.items(): diff --git a/torchtitan/model_spec.py b/torchtitan/model_spec.py new file mode 100644 index 0000000000..74cc69c210 --- /dev/null +++ b/torchtitan/model_spec.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +from typing import Callable, Dict, List, Protocol, Tuple, Type + +import torch.nn as nn +from torch.distributed.pipelining.schedules import _PipelineSchedule + +@dataclass +class BaseModelArgs: + _enforced: str = "This field is used to enforce all fields have defaults." + + +class ModelProtocol(Protocol): + def from_model_args(self, args: BaseModelArgs) -> nn.Module: + ... + + +@dataclass +class ModelSpec: + name: str + cls: Type[nn.Module] + config: Dict[str, BaseModelArgs] + # As for now, this is a string. So it will have to be built-in to the + # TorchTitan library. In the future, we can make this a defined class + # that can be extended like ModelSpec. + tokenizer: str + parallelize_fn: Callable[[nn.Module], None] + pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]] diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py index c666b06553..ee9bf801b4 100644 --- a/torchtitan/models/__init__.py +++ b/torchtitan/models/__init__.py @@ -4,14 +4,85 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtitan.models.llama import llama3_configs, Transformer +import importlib -models_config = { - "llama3": llama3_configs, -} +import os +import pkgutil +from typing import Dict, Set -model_name_to_cls = {"llama3": Transformer} +import torchtitan.models as models +from torchtitan.model_spec import ModelSpec -model_name_to_tokenizer = { - "llama3": "tiktoken", -} + +_model_specs_path: Set[str] = set() + + +def _load_module(path: str): + path = os.path.expanduser(path) + + # 1. 
Check if path is an existing file or directory path. + if os.path.exists(path): + if os.path.isdir(path): + init_file = os.path.join(path, "__init__.py") + if os.path.isfile(init_file): + return _load_module_from_init(path) + + raise ImportError( + f"Directory '{path}' is not a Python package because it does not " + "contain an __init__.py file." + ) + else: + raise ImportError(f"Path '{path}' is not a directory.") + + # 2. If not a valid path, assume it's a dotted module name. + return importlib.import_module(path) + + +def _load_module_from_init(path: str): + module_name = os.path.basename(os.path.normpath(path)) + init_file = os.path.join(path, "__init__.py") + + spec = importlib.util.spec_from_file_location(module_name, init_file) + if spec is None: + raise ImportError(f"Could not create spec from '{init_file}'") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +for _, name, _ in pkgutil.iter_modules(models.__path__): + full_module_name = f"{models.__name__}.{name}" + _model_specs_path.add(full_module_name) + # model_module = importlib.import_module(full_module_name) + # load_spec_from_module(model_module) + + +def add_model_spec_path(path: str): + global _model_specs_path + _model_specs_path.add(path) + + +def build_model_specs() -> Dict[str, ModelSpec]: + """ + Load all model specs from the `models` package. + """ + global _model_specs_path + model_specs = {} + for path in _model_specs_path: + module = _load_module(path) + model_spec = getattr(module, "model_spec", None) + if model_spec is not None: + model_specs[model_spec.name] = model_spec + # We would like to just use `model_spec` but current torchtitan parallelize + # functions depend on ModelArgs and can cause circular imports. + # As a result, we have to use `build_model_spec` as a workaround. + build_model_spec = getattr(module, "build_model_spec", None) + if build_model_spec: + model_spec = build_model_spec() + model_specs[model_spec.name] = model_spec + + return model_specs + + +__all__ = [add_model_spec_path, build_model_specs] diff --git a/torchtitan/models/llama/__init__.py b/torchtitan/models/llama/__init__.py index 3bb430d2cb..e61538b987 100644 --- a/torchtitan/models/llama/__init__.py +++ b/torchtitan/models/llama/__init__.py @@ -6,9 +6,9 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
+from torchtitan.model_spec import ModelSpec from torchtitan.models.llama.model import ModelArgs, Transformer -__all__ = ["Transformer"] llama3_configs = { "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000), @@ -40,3 +40,18 @@ rope_theta=500000, ), } + + +def build_model_spec() -> ModelSpec: + # Avoid circular import + from torchtitan.parallelisms.parallelize_llama import parallelize_llama + from torchtitan.parallelisms.pipeline_llama import pipeline_llama + + return ModelSpec( + name="llama3", + cls=Transformer, + config=llama3_configs, + tokenizer="tiktoken", + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + ) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 641ef6de95..d60447e46d 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -13,11 +13,12 @@ import torch import torch.nn.functional as F from torch import nn +from torchtitan.model_spec import BaseModelArgs, ModelProtocol from torchtitan.models.norms import build_norm @dataclass -class ModelArgs: +class ModelArgs(BaseModelArgs): dim: int = 4096 n_layers: int = 32 n_heads: int = 32 @@ -258,7 +259,7 @@ def init_weights(self, init_std: float): nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) -class TransformerBlock(nn.Module): +class TransformerBlock(nn.Module, ModelProtocol): """ TransformerBlock Module diff --git a/train.py b/train.py index bac2277228..7f7cf2e183 100644 --- a/train.py +++ b/train.py @@ -19,13 +19,9 @@ from torchtitan.float8 import Float8Handler from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_device_memory_monitor, build_metric_logger -from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config +from torchtitan.models import build_model_specs from torchtitan.optimizer import build_lr_schedulers, build_optimizers -from torchtitan.parallelisms import ( - models_parallelize_fns, - models_pipelining_fns, - ParallelDims, -) +from torchtitan.parallelisms import ParallelDims from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling from torchtitan.utils import device_module, device_type @@ -80,9 +76,10 @@ def main(job_config: JobConfig): world_mesh, device, job_config.training.seed, job_config.training.deterministic ) model_name = job_config.model.name + model_spec = build_model_specs()[model_name] # build tokenizer - tokenizer_type = model_name_to_tokenizer[model_name] + tokenizer_type = model_spec.tokenizer tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) # build dataloader data_loader = build_hf_data_loader( @@ -96,8 +93,8 @@ def main(job_config: JobConfig): ) # build model (using meta init) - model_cls = model_name_to_cls[model_name] - model_config = models_config[model_name][job_config.model.flavor] + model_cls = model_spec.cls + model_config = model_spec.config[job_config.model.flavor] # set the model configs from training inputs: # 1. norm type to decide which norm layer to use # 2. 
vocab size from tokenizer @@ -151,7 +148,7 @@ def loss_fn(pred, labels): # apply parallelisms and initialization if parallel_dims.pp_enabled: # apply PT-D Pipeline Parallel - pp_schedule, model_parts = models_pipelining_fns[model_name]( + pp_schedule, model_parts = model_spec.pipelining_fn( model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn ) # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead @@ -162,14 +159,14 @@ def loss_fn(pred, labels): # optimizer, and checkpointing for m in model_parts: # apply SPMD-style PT-D techniques - models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config) + model_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config) m.to_empty(device=init_device) with torch.no_grad(): m.init_weights(buffer_device=buffer_device) m.train() else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel - models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config) + model_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) model.to_empty(device=init_device) with torch.no_grad(): model.init_weights(buffer_device=buffer_device) From dfc1649a9d79244ceb83a7923c05717f42a6231b Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 31 Jan 2025 10:46:46 -0800 Subject: [PATCH 02/17] Update [ghstack-poisoned] --- torchtitan/config_manager.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index f5c01f78ae..9741e42a39 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import argparse -import importlib import sys from collections import defaultdict from typing import Tuple, Union @@ -377,10 +376,18 @@ def __init__(self): """, ) self.parser.add_argument( - "--experimental.model_module_path", + "--experimental.custom_model_path", type=str, default="", - help="", + help=""" + The --custom_model_path option allows to specify a custom path to a model module + + that is not natively implemented within TorchTitan. + + Acceptable values are the file system path to the module (e.g., my_models/model_x) + + dotted import module (e.g., some_package.model_x). + """ ) self.parser.add_argument( "--training.mixed_precision_param", From 720f12a51bfa96727a8b118cf6228a529cc7daa4 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 31 Jan 2025 10:49:46 -0800 Subject: [PATCH 03/17] Update [ghstack-poisoned] --- torchtitan/config_manager.py | 2 +- torchtitan/model_spec.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 9741e42a39..7b8684e8e5 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -387,7 +387,7 @@ def __init__(self): Acceptable values are the file system path to the module (e.g., my_models/model_x) dotted import module (e.g., some_package.model_x). - """ + """, ) self.parser.add_argument( "--training.mixed_precision_param", diff --git a/torchtitan/model_spec.py b/torchtitan/model_spec.py index 74cc69c210..08efc7beb7 100644 --- a/torchtitan/model_spec.py +++ b/torchtitan/model_spec.py @@ -4,14 +4,14 @@ import torch.nn as nn from torch.distributed.pipelining.schedules import _PipelineSchedule + @dataclass class BaseModelArgs: _enforced: str = "This field is used to enforce all fields have defaults." 
class ModelProtocol(Protocol): - def from_model_args(self, args: BaseModelArgs) -> nn.Module: - ... + def from_model_args(self, args: BaseModelArgs) -> nn.Module: ... @dataclass From 225bfcc371b64b56794d43fa801f7c71924452a3 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 31 Jan 2025 10:58:16 -0800 Subject: [PATCH 04/17] Update [ghstack-poisoned] --- torchtitan/model_spec.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchtitan/model_spec.py b/torchtitan/model_spec.py index 08efc7beb7..5ac9ea8ba2 100644 --- a/torchtitan/model_spec.py +++ b/torchtitan/model_spec.py @@ -11,7 +11,8 @@ class BaseModelArgs: class ModelProtocol(Protocol): - def from_model_args(self, args: BaseModelArgs) -> nn.Module: ... + def from_model_args(self, args: BaseModelArgs) -> nn.Module: + ... @dataclass From 650152e7b92173c2f43dc323cc530613ad6b8b64 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 31 Jan 2025 11:00:17 -0800 Subject: [PATCH 05/17] Update [ghstack-poisoned] --- torchtitan/model_spec.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchtitan/model_spec.py b/torchtitan/model_spec.py index 5ac9ea8ba2..28c050b4a1 100644 --- a/torchtitan/model_spec.py +++ b/torchtitan/model_spec.py @@ -1,3 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + + from dataclasses import dataclass from typing import Callable, Dict, List, Protocol, Tuple, Type From 6a51325f359b6a0a88583aff8ec9f4ff67d66e6e Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 6 Feb 2025 15:27:32 -0800 Subject: [PATCH 06/17] Update [ghstack-poisoned] --- torchtitan/models/llama/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index d60447e46d..67c5426051 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -259,7 +259,7 @@ def init_weights(self, init_std: float): nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) -class TransformerBlock(nn.Module, ModelProtocol): +class TransformerBlock(nn.Module): """ TransformerBlock Module @@ -332,7 +332,7 @@ def init_weights(self): self.feed_forward.init_weights(self.weight_init_std) -class Transformer(nn.Module): +class Transformer(nn.Module, ModelProtocol): """ Transformer Module From 2e569d7d15811747d48d2e981b64eedc3702ddea Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 6 Feb 2025 15:41:35 -0800 Subject: [PATCH 07/17] Update [ghstack-poisoned] --- torchtitan/config_manager.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index e4c58d03f4..0cda86b440 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -405,11 +405,8 @@ def __init__(self): default="", help=""" The --custom_model_path option allows to specify a custom path to a model module - that is not natively implemented within TorchTitan. - Acceptable values are the file system path to the module (e.g., my_models/model_x) - dotted import module (e.g., some_package.model_x). 
""", ) From bab9bf5b85623ab0636e2fed1be12b5517199511 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 6 Feb 2025 16:11:12 -0800 Subject: [PATCH 08/17] Update [ghstack-poisoned] --- torchtitan/__init__.py | 12 ++++++++++++ torchtitan/optimizer.py | 1 - torchtitan/utils.py | 19 +++++++++---------- 3 files changed, 21 insertions(+), 11 deletions(-) create mode 100644 torchtitan/__init__.py diff --git a/torchtitan/__init__.py b/torchtitan/__init__.py new file mode 100644 index 0000000000..d39e084f14 --- /dev/null +++ b/torchtitan/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +# Import the built-in models here so that the corresponding register_model_spec() +# will be called. +import torchtitan.models + diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index 631cbea5da..9c8d749cdb 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import functools -from abc import ABC from typing import Any, Callable, Dict, Iterable, List import torch diff --git a/torchtitan/utils.py b/torchtitan/utils.py index 8e43618eb3..8ff0cd2df1 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -439,17 +439,16 @@ def import_module_from_path(path: str): # 1. Check if path is an existing file or directory path. if os.path.exists(path): - if os.path.isdir(path): - init_file = os.path.join(path, "__init__.py") - if os.path.isfile(init_file): - return _import_module_from_init(path) - - raise ImportError( - f"Directory '{path}' is not a Python package because it does not " - "contain an __init__.py file." - ) - else: + if not os.path.isdir(path): raise ImportError(f"Path '{path}' is not a directory.") + init_file = os.path.join(path, "__init__.py") + if os.path.isfile(init_file): + return _import_module_from_init(path) + + raise ImportError( + f"Directory '{path}' is not a Python package because it does not " + "contain an __init__.py file." + ) # 2. If not a valid path, assume it's a dotted module name. return importlib.import_module(path) From 210707ad67385811b9490a17cb20d9f3f43ba872 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 11 Feb 2025 10:57:29 -0800 Subject: [PATCH 09/17] Update [ghstack-poisoned] --- tests/unit_tests/test_model_spec.py | 124 ++++++++++++++++++++++++++++ torchtitan/model_spec.py | 27 ++++-- 2 files changed, 145 insertions(+), 6 deletions(-) create mode 100644 tests/unit_tests/test_model_spec.py diff --git a/tests/unit_tests/test_model_spec.py b/tests/unit_tests/test_model_spec.py new file mode 100644 index 0000000000..3f83464dfd --- /dev/null +++ b/tests/unit_tests/test_model_spec.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from functools import partial + +import pytest +import torch +import torch.nn as nn +from torchtitan.config_manager import JobConfig +from torchtitan.model_spec import ( + apply_to_model_specs, + BaseModelArgs, + get_model_spec, + ModelProtocol, + ModelSpec, + register_model_spec, +) +from torchtitan.models.llama import parallelize_llama, pipeline_llama +from torchtitan.optimizer import ( + build_lr_schedulers, + build_optimizers, + OptimizersContainer, +) + + +class FakeModel(ModelProtocol): + @staticmethod + def from_model_args(args: BaseModelArgs) -> nn.Module: + return nn.Linear(8, 8) + + +def fake_build_optimizers( + model_parts: list[nn.Module], job_config: JobConfig +) -> OptimizersContainer: + optimizer_kwargs = { + "lr": 0.1, + "betas": (0.9, 0.95), + "weight_decay": 0.1, + "fused": True, + "foreach": False, + } + return OptimizersContainer( + model_parts=model_parts, + optimizer_kwargs=optimizer_kwargs, + name="Adam", + ) + + +class TestModelSpec: + def test_register_model_spec(self): + fake_config = {"fake": None} + spec = ModelSpec( + name="fake", + cls=FakeModel, + config=fake_config, + tokenizer="tiktoken", + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + ) + register_model_spec(spec) + new_spec = get_model_spec("fake") + assert new_spec == spec + + with pytest.raises(ValueError): + new_spec = get_model_spec("fake2") + + def test_optim_hook(self): + fake_config = {"fake": None} + spec = ModelSpec( + name="fake2", + cls=FakeModel, + config=fake_config, + tokenizer="tiktoken", + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=fake_build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + ) + register_model_spec(spec) + new_spec = get_model_spec("fake2") + + # Demonstrate how to register a optimizer hook for all model specs + hook_called = False + + def my_hook( + optimizer: torch.optim.Optimizer, + args, + kwargs, + model_parts: list[nn.Module], + ) -> None: + nonlocal hook_called + hook_called = True + + def register_optimizer_hook_to_spec(spec: ModelSpec) -> ModelSpec: + # Create a closure to capture the original spec.build_optimizers_fn + original_build_optimizers_fn = spec.build_optimizers_fn + + def my_build_optimizer_fn( + model_parts: list[nn.Module], job_config: JobConfig + ) -> OptimizersContainer: + optimizers = original_build_optimizers_fn(model_parts, job_config) + optimizers.register_step_post_hook( + partial(my_hook, model_parts=model_parts) + ) + return optimizers + + spec.build_optimizers_fn = my_build_optimizer_fn + + apply_to_model_specs(register_optimizer_hook_to_spec) + + model = new_spec.cls.from_model_args(BaseModelArgs()) + model_parts = [model] + optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig()) + assert optimizers.optimizers[0].__class__.__name__ == "Adam" + batch = torch.randn(8, 8) + model(batch).sum().backward() + assert not hook_called + optimizers.step() + assert hook_called diff --git a/torchtitan/model_spec.py b/torchtitan/model_spec.py index c8ee56e5d6..7e1ba3e896 100644 --- a/torchtitan/model_spec.py +++ b/torchtitan/model_spec.py @@ -8,10 +8,11 @@ from dataclasses import dataclass -from typing import Callable, Dict, List, Protocol, Tuple, Type +from typing import Callable, Dict, List, Protocol, Tuple, Type, TypeAlias import torch.nn as nn from torch.distributed.pipelining.schedules import _PipelineSchedule + from torchtitan.config_manager import JobConfig from 
torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer @@ -35,7 +36,16 @@ class ModelProtocol(Protocol): """ @staticmethod - def from_model_args(self, args: BaseModelArgs) -> nn.Module: ... + def from_model_args(args: BaseModelArgs) -> nn.Module: ... + + +OptimizersBuilder: TypeAlias = Callable[ + [List[nn.Module], JobConfig], OptimizersContainer +] +OptimizerBuilderWrapper: TypeAlias = Callable[ + [List[nn.Module], JobConfig, OptimizersContainer], OptimizersContainer +] +LRSchedulersBuilder: TypeAlias = Callable[[OptimizersContainer], LRSchedulersContainer] @dataclass @@ -51,10 +61,8 @@ class ModelSpec: tokenizer: str parallelize_fn: Callable[[nn.Module], None] pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]] - build_optimizers_fn: Callable[[List[nn.Module], JobConfig], OptimizersContainer] - build_lr_schedulers_fn: Callable[ - [List[nn.Module], JobConfig], LRSchedulersContainer - ] + build_optimizers_fn: OptimizersBuilder + build_lr_schedulers_fn: LRSchedulersBuilder # TODO: Add a FQN convert fn to allow users to load checkpoints from # HuggingFace or other sources that have different FQN conventions. @@ -67,6 +75,7 @@ def register_model_spec(model_spec: ModelSpec) -> None: global _model_specs if model_spec.name in _model_specs: raise ValueError(f"Model {model_spec.name} is already registered.") + _model_specs[model_spec.name] = model_spec @@ -75,3 +84,9 @@ def get_model_spec(name: str) -> ModelSpec: if name not in _model_specs: raise ValueError(f"Model {name} is not registered.") return _model_specs[name] + + +def apply_to_model_specs(func: Callable[[ModelSpec], ModelSpec]) -> None: + global _model_specs + for name, model_spec in _model_specs.items(): + _model_specs[name] = func(model_spec) From 6fb1d743979d08f780bdb8a490f138058293e4d7 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 11 Feb 2025 11:32:02 -0800 Subject: [PATCH 10/17] Update [ghstack-poisoned] --- ...{test_model_spec.py => test_train_spec.py} | 32 +++++++++---------- torchtitan/checkpoint.py | 1 + torchtitan/models/llama/__init__.py | 26 +++++++++------ torchtitan/models/llama/model.py | 24 +++++++------- torchtitan/models/llama/pipeline_llama.py | 16 ++++------ torchtitan/optimizer.py | 13 ++++---- torchtitan/parallelisms/__init__.py | 12 +------ torchtitan/parallelisms/parallel_dims.py | 1 + .../{pipelining_utils.py => pipeline.py} | 0 torchtitan/{model_spec.py => train_spec.py} | 30 ++++++++--------- torchtitan/utils.py | 1 + 11 files changed, 78 insertions(+), 78 deletions(-) rename tests/unit_tests/{test_model_spec.py => test_train_spec.py} (85%) rename torchtitan/parallelisms/{pipelining_utils.py => pipeline.py} (100%) rename torchtitan/{model_spec.py => train_spec.py} (80%) diff --git a/tests/unit_tests/test_model_spec.py b/tests/unit_tests/test_train_spec.py similarity index 85% rename from tests/unit_tests/test_model_spec.py rename to tests/unit_tests/test_train_spec.py index 3f83464dfd..9be25dfbff 100644 --- a/tests/unit_tests/test_model_spec.py +++ b/tests/unit_tests/test_train_spec.py @@ -10,13 +10,13 @@ import torch import torch.nn as nn from torchtitan.config_manager import JobConfig -from torchtitan.model_spec import ( - apply_to_model_specs, +from torchtitan.train_spec import ( + apply_to_train_specs, BaseModelArgs, - get_model_spec, + get_train_spec, ModelProtocol, - ModelSpec, - register_model_spec, + TrainSpec, + register_train_spec, ) from torchtitan.models.llama import parallelize_llama, pipeline_llama from torchtitan.optimizer import ( @@ 
-49,10 +49,10 @@ def fake_build_optimizers( ) -class TestModelSpec: - def test_register_model_spec(self): +class TestTrainSpec: + def test_register_train_spec(self): fake_config = {"fake": None} - spec = ModelSpec( + spec = TrainSpec( name="fake", cls=FakeModel, config=fake_config, @@ -62,16 +62,16 @@ def test_register_model_spec(self): build_optimizers_fn=build_optimizers, build_lr_schedulers_fn=build_lr_schedulers, ) - register_model_spec(spec) - new_spec = get_model_spec("fake") + register_train_spec(spec) + new_spec = get_train_spec("fake") assert new_spec == spec with pytest.raises(ValueError): - new_spec = get_model_spec("fake2") + new_spec = get_train_spec("fake2") def test_optim_hook(self): fake_config = {"fake": None} - spec = ModelSpec( + spec = TrainSpec( name="fake2", cls=FakeModel, config=fake_config, @@ -81,8 +81,8 @@ def test_optim_hook(self): build_optimizers_fn=fake_build_optimizers, build_lr_schedulers_fn=build_lr_schedulers, ) - register_model_spec(spec) - new_spec = get_model_spec("fake2") + register_train_spec(spec) + new_spec = get_train_spec("fake2") # Demonstrate how to register a optimizer hook for all model specs hook_called = False @@ -96,7 +96,7 @@ def my_hook( nonlocal hook_called hook_called = True - def register_optimizer_hook_to_spec(spec: ModelSpec) -> ModelSpec: + def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec: # Create a closure to capture the original spec.build_optimizers_fn original_build_optimizers_fn = spec.build_optimizers_fn @@ -111,7 +111,7 @@ def my_build_optimizer_fn( spec.build_optimizers_fn = my_build_optimizer_fn - apply_to_model_specs(register_optimizer_hook_to_spec) + apply_to_train_specs(register_optimizer_hook_to_spec) model = new_spec.cls.from_model_args(BaseModelArgs()) model_parts = [model] diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 860d82d6aa..c751115788 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -26,6 +26,7 @@ ) from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import DataLoader + from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging import init_logger, logger from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer diff --git a/torchtitan/models/llama/__init__.py b/torchtitan/models/llama/__init__.py index cfd1a14272..c4cee7100d 100644 --- a/torchtitan/models/llama/__init__.py +++ b/torchtitan/models/llama/__init__.py @@ -6,19 +6,27 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
-from torchtitan.model_spec import ModelSpec, register_model_spec -from torchtitan.models.llama.model import ModelArgs, Transformer +from torchtitan.train_spec import TrainSpec, register_train_spec +from torchtitan.models.llama.model import Transformer, TransformerModelArgs from torchtitan.optimizer import build_lr_schedulers, build_optimizers from .parallelize_llama import parallelize_llama from .pipeline_llama import pipeline_llama -__all__ = ["parallelize_llama", "pipeline_llama", "ModelArgs", "Transformer"] +__all__ = [ + "parallelize_llama", + "pipeline_llama", + "TransformerModelArgs", + "Transformer", + "llama3_configs", +] llama3_configs = { - "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000), - "8B": ModelArgs( + "debugmodel": TransformerModelArgs( + dim=256, n_layers=8, n_heads=16, rope_theta=500000 + ), + "8B": TransformerModelArgs( dim=4096, n_layers=32, n_heads=32, @@ -27,7 +35,7 @@ multiple_of=1024, rope_theta=500000, ), - "70B": ModelArgs( + "70B": TransformerModelArgs( dim=8192, n_layers=80, n_heads=64, @@ -36,7 +44,7 @@ multiple_of=4096, rope_theta=500000, ), - "405B": ModelArgs( + "405B": TransformerModelArgs( dim=16384, n_layers=126, n_heads=128, @@ -48,8 +56,8 @@ } -register_model_spec( - ModelSpec( +register_train_spec( + TrainSpec( name="llama3", cls=Transformer, config=llama3_configs, diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 67c5426051..6519b4446c 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -13,12 +13,12 @@ import torch import torch.nn.functional as F from torch import nn -from torchtitan.model_spec import BaseModelArgs, ModelProtocol +from torchtitan.train_spec import BaseModelArgs, ModelProtocol from torchtitan.models.norms import build_norm @dataclass -class ModelArgs(BaseModelArgs): +class TransformerModelArgs(BaseModelArgs): dim: int = 4096 n_layers: int = 32 n_heads: int = 32 @@ -131,7 +131,7 @@ class Attention(nn.Module): Multi-head attention module. Args: - model_args (ModelArgs): Model configuration arguments. + model_args (TransformerModelArgs): Model configuration arguments. Attributes: n_kv_heads (int): Number of key and value heads. @@ -145,7 +145,7 @@ class Attention(nn.Module): """ - def __init__(self, model_args: ModelArgs): + def __init__(self, model_args: TransformerModelArgs): super().__init__() self.n_heads = model_args.n_heads self.n_kv_heads = ( @@ -265,7 +265,7 @@ class TransformerBlock(nn.Module): Args: layer_id (int): Identifier for the layer. - model_args (ModelArgs): Model configuration arguments. + model_args (TransformerModelArgs): Model configuration arguments. Attributes: n_heads (int): Number of attention heads. @@ -279,7 +279,7 @@ class TransformerBlock(nn.Module): """ - def __init__(self, layer_id: int, model_args: ModelArgs): + def __init__(self, layer_id: int, model_args: TransformerModelArgs): super().__init__() self.n_heads = model_args.n_heads self.dim = model_args.dim @@ -337,10 +337,10 @@ class Transformer(nn.Module, ModelProtocol): Transformer Module Args: - model_args (ModelArgs): Model configuration arguments. + model_args (TransformerModelArgs): Model configuration arguments. Attributes: - model_args (ModelArgs): Model configuration arguments. + model_args (TransformerModelArgs): Model configuration arguments. vocab_size (int): Vocabulary size. n_layers (int): Number of layers in the model. tok_embeddings (ParallelEmbedding): Token embeddings. 
@@ -351,7 +351,7 @@ class Transformer(nn.Module, ModelProtocol): """ - def __init__(self, model_args: ModelArgs): + def __init__(self, model_args: TransformerModelArgs): super().__init__() self.model_args = model_args self.vocab_size = model_args.vocab_size @@ -447,12 +447,12 @@ def forward(self, tokens: torch.Tensor): return output @classmethod - def from_model_args(cls, model_args: ModelArgs) -> "Transformer": + def from_model_args(cls, model_args: TransformerModelArgs) -> "Transformer": """ - Initialize a Transformer model from a ModelArgs object. + Initialize a Transformer model from a TransformerModelArgs object. Args: - model_args (ModelArgs): Model configuration arguments. + model_args (TransformerModelArgs): Model configuration arguments. Returns: Transformer: Transformer model. diff --git a/torchtitan/models/llama/pipeline_llama.py b/torchtitan/models/llama/pipeline_llama.py index 8bb080ff23..856fa5d39f 100644 --- a/torchtitan/models/llama/pipeline_llama.py +++ b/torchtitan/models/llama/pipeline_llama.py @@ -7,7 +7,7 @@ # This file applies the PT-D pipeline parallelism to the Llama model. import copy -from typing import Callable, Union +from typing import Callable, Union, Optional import torch import torch.nn as nn @@ -18,14 +18,12 @@ from torchtitan.config_manager import JobConfig from torchtitan.logging import logger -from torchtitan.parallelisms import ( - build_pipeline_schedule, - generate_split_points, - ParallelDims, - stage_ids_this_rank, +from torchtitan.parallelisms.pipeline import ( + build_pipeline_schedule, generate_split_points, stage_ids_this_rank, ) +from torchtitan.parallelisms import ParallelDims -from .model import ModelArgs +from .model import TransformerModelArgs DeviceType = Union[int, str, torch.device] @@ -37,7 +35,7 @@ def pipeline_llama( parallel_dims: ParallelDims, job_config: JobConfig, device: DeviceType, - model_config: ModelArgs, + model_config: TransformerModelArgs, loss_fn: Callable[..., torch.Tensor], ) -> tuple[_PipelineSchedule, list[nn.Module]]: stages, models = pipeline_llama_manual_split( @@ -55,7 +53,7 @@ def pipeline_llama_manual_split( parallel_dims: ParallelDims, job_config: JobConfig, device: DeviceType, - model_config: ModelArgs, + model_config: TransformerModelArgs, ) -> tuple[list[PipelineStage], list[nn.Module]]: """ This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index 9c8d749cdb..e351fd1321 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import copy import functools from typing import Any, Callable, Dict, Iterable, List @@ -17,6 +18,7 @@ from torch.distributed.checkpoint.stateful import Stateful from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, LRScheduler + from torchtitan.config_manager import JobConfig @@ -71,8 +73,6 @@ class OptimizersContainer(Optimizer): def __init__( self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str ) -> None: - # We need to call super().__init__() to initialize some necessary optimizer - # functionality such as hooks. 
all_params = [] self.optimizers: List[Optimizer] = [] self.model_parts = model_parts @@ -124,6 +124,8 @@ def _validate_length(self, expected_length: int) -> None: def _post_init( self, all_params: list[nn.Parameter], optimizer_kwargs: dict[str, Any] ) -> None: + # We need to call Optimizer.__init__() to initialize some necessary optimizer + # functionality such as hooks. Optimizer.__init__(self, all_params, optimizer_kwargs) @@ -188,7 +190,7 @@ def build_optimizers( **Note** Users who want to customize the optimizer behavior can create their own ``OptimizersContainer`` subclass and ``build_optimizers``. Passing the - customized ``build_optimizers`` to ``ModelSpec`` will create the customized + customized ``build_optimizers`` to ``TrainSpec`` will create the customized ``OptimizersContainer``. Args: @@ -273,9 +275,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: # that is immutable. As long as ``training.steps`` and ``training.warmup_steps`` # in ``job_config`` remain unchanged when resuming from a checkpoint, this # approach is safe. We call ``copy()`` here to ensure extra safety. - # TODO: Should we deepcopy the state_dict? for scheduler in self.schedulers: - scheduler.load_state_dict(state_dict.copy()) + scheduler.load_state_dict(copy.deepcopy(state_dict)) def build_lr_schedulers( @@ -289,7 +290,7 @@ def build_lr_schedulers( **Note** Users who want to customize the lr scheduler behavior can create their own ``LRSchedulersContainer`` subclass and ``build_lr_scheduler``. Passing the - customized ``build_lr_schedulers`` to ``ModelSpec`` will create the customized + customized ``build_lr_schedulers`` to ``TrainSpec`` will create the customized ``LRSchedulersContainer``. diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py index 58d5434817..1a187282e1 100644 --- a/torchtitan/parallelisms/__init__.py +++ b/torchtitan/parallelisms/__init__.py @@ -6,16 +6,6 @@ from torchtitan.parallelisms.parallel_dims import ParallelDims -from torchtitan.parallelisms.pipelining_utils import ( - build_pipeline_schedule, - generate_split_points, - stage_ids_this_rank, -) -__all__ = [ - "ParallelDims", - "build_pipeline_schedule", - "generate_split_points", - "stage_ids_this_rank", -] +__all__ = ["ParallelDims"] diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index d039fc33f3..f5e6a0e4c2 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -8,6 +8,7 @@ from functools import cached_property from torch.distributed.device_mesh import init_device_mesh + from torchtitan.logging import logger diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipeline.py similarity index 100% rename from torchtitan/parallelisms/pipelining_utils.py rename to torchtitan/parallelisms/pipeline.py diff --git a/torchtitan/model_spec.py b/torchtitan/train_spec.py similarity index 80% rename from torchtitan/model_spec.py rename to torchtitan/train_spec.py index 7e1ba3e896..22bfda9a72 100644 --- a/torchtitan/model_spec.py +++ b/torchtitan/train_spec.py @@ -49,7 +49,7 @@ def from_model_args(args: BaseModelArgs) -> nn.Module: ... @dataclass -class ModelSpec: +class TrainSpec: name: str cls: Type[nn.Module] config: Dict[str, BaseModelArgs] @@ -68,25 +68,25 @@ class ModelSpec: # HuggingFace or other sources that have different FQN conventions. 
-_model_specs = {} +_train_specs = {} -def register_model_spec(model_spec: ModelSpec) -> None: - global _model_specs - if model_spec.name in _model_specs: - raise ValueError(f"Model {model_spec.name} is already registered.") +def register_train_spec(train_spec: TrainSpec) -> None: + global _train_specs + if train_spec.name in _train_specs: + raise ValueError(f"Model {train_spec.name} is already registered.") - _model_specs[model_spec.name] = model_spec + _train_specs[train_spec.name] = train_spec -def get_model_spec(name: str) -> ModelSpec: - global _model_specs - if name not in _model_specs: +def get_train_spec(name: str) -> TrainSpec: + global _train_specs + if name not in _train_specs: raise ValueError(f"Model {name} is not registered.") - return _model_specs[name] + return _train_specs[name] -def apply_to_model_specs(func: Callable[[ModelSpec], ModelSpec]) -> None: - global _model_specs - for name, model_spec in _model_specs.items(): - _model_specs[name] = func(model_spec) +def apply_to_train_specs(func: Callable[[TrainSpec], TrainSpec]) -> None: + global _train_specs + for name, train_spec in _train_specs.items(): + _train_specs[name] = func(train_spec) diff --git a/torchtitan/utils.py b/torchtitan/utils.py index 8ff0cd2df1..122a406f68 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -22,6 +22,7 @@ from torch._utils import _get_available_device_type, _get_device_module from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor + from torchtitan.logging import logger From 02c87b2f31d50fed588388b5565e9829e891d5e2 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 11 Feb 2025 11:34:17 -0800 Subject: [PATCH 11/17] Update [ghstack-poisoned] --- torchtitan/__init__.py | 3 +-- torchtitan/train_spec.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/__init__.py b/torchtitan/__init__.py index d39e084f14..be0d95f3c5 100644 --- a/torchtitan/__init__.py +++ b/torchtitan/__init__.py @@ -8,5 +8,4 @@ # Import the built-in models here so that the corresponding register_model_spec() # will be called. -import torchtitan.models - +import torchtitan.models # noqa: F401 diff --git a/torchtitan/train_spec.py b/torchtitan/train_spec.py index 22bfda9a72..d72ea3773d 100644 --- a/torchtitan/train_spec.py +++ b/torchtitan/train_spec.py @@ -36,7 +36,8 @@ class ModelProtocol(Protocol): """ @staticmethod - def from_model_args(args: BaseModelArgs) -> nn.Module: ... + def from_model_args(args: BaseModelArgs) -> nn.Module: + ... 
OptimizersBuilder: TypeAlias = Callable[ From 4234a26a2be7a6a9d975271818101b722f6438ee Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 11 Feb 2025 11:48:46 -0800 Subject: [PATCH 12/17] Update [ghstack-poisoned] --- tests/unit_tests/test_train_spec.py | 2 -- torchtitan/models/__init__.py | 5 ++++- torchtitan/models/llama/__init__.py | 3 +-- torchtitan/train_spec.py | 8 ++------ train.py | 25 +++++++++++++------------ 5 files changed, 20 insertions(+), 23 deletions(-) diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py index 9be25dfbff..447f94d964 100644 --- a/tests/unit_tests/test_train_spec.py +++ b/tests/unit_tests/test_train_spec.py @@ -56,7 +56,6 @@ def test_register_train_spec(self): name="fake", cls=FakeModel, config=fake_config, - tokenizer="tiktoken", parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=build_optimizers, @@ -75,7 +74,6 @@ def test_optim_hook(self): name="fake2", cls=FakeModel, config=fake_config, - tokenizer="tiktoken", parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=fake_build_optimizers, diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py index acbd15fbcc..16d940d229 100644 --- a/torchtitan/models/__init__.py +++ b/torchtitan/models/__init__.py @@ -7,4 +7,7 @@ # Import the built-in models here so that the corresponding register_model_spec() # will be called. -import torchtitan.models.llama # noqa +import torchtitan.models.llama # noqa: F401 + + +model_name_to_tokenizer = {"llama3": "tiktoken"} diff --git a/torchtitan/models/llama/__init__.py b/torchtitan/models/llama/__init__.py index c4cee7100d..5cdedb0839 100644 --- a/torchtitan/models/llama/__init__.py +++ b/torchtitan/models/llama/__init__.py @@ -6,9 +6,9 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. -from torchtitan.train_spec import TrainSpec, register_train_spec from torchtitan.models.llama.model import Transformer, TransformerModelArgs from torchtitan.optimizer import build_lr_schedulers, build_optimizers +from torchtitan.train_spec import register_train_spec, TrainSpec from .parallelize_llama import parallelize_llama from .pipeline_llama import pipeline_llama @@ -61,7 +61,6 @@ name="llama3", cls=Transformer, config=llama3_configs, - tokenizer="tiktoken", parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=build_optimizers, diff --git a/torchtitan/train_spec.py b/torchtitan/train_spec.py index d72ea3773d..f199d2e033 100644 --- a/torchtitan/train_spec.py +++ b/torchtitan/train_spec.py @@ -54,17 +54,13 @@ class TrainSpec: name: str cls: Type[nn.Module] config: Dict[str, BaseModelArgs] - # TODO: Add a ``build_dataloader_fn`` - # As for now, this is a string. So it will have to be built-in to the - # TorchTitan library. A better way would be to have a dataloader class - # and a ``build_dataloader`` function that take job_config to consume - # the different dataloader and tokenizer configs. - tokenizer: str parallelize_fn: Callable[[nn.Module], None] pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]] build_optimizers_fn: OptimizersBuilder build_lr_schedulers_fn: LRSchedulersBuilder + # TODO: Add a ``build_dataloader_fn`` + # TODO: Add a FQN convert fn to allow users to load checkpoints from # HuggingFace or other sources that have different FQN conventions. 
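(For reference, the TrainSpec dataclass now reads in full — reassembled from the
hunk above. Note the tokenizer field is gone; the name-to-tokenizer mapping moves
to torchtitan/models/__init__.py in this same patch:)

    @dataclass
    class TrainSpec:
        name: str
        cls: Type[nn.Module]
        config: Dict[str, BaseModelArgs]
        parallelize_fn: Callable[[nn.Module], None]
        pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
        build_optimizers_fn: OptimizersBuilder
        build_lr_schedulers_fn: LRSchedulersBuilder

        # TODO: Add a ``build_dataloader_fn``

        # TODO: Add a FQN convert fn to allow users to load checkpoints from
        # HuggingFace or other sources that have different FQN conventions.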
diff --git a/train.py b/train.py index da63273281..a278a14358 100644 --- a/train.py +++ b/train.py @@ -19,9 +19,10 @@ from torchtitan.float8 import Float8Handler from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_device_memory_monitor, build_metric_logger -from torchtitan.model_spec import get_model_spec +from torchtitan.models import model_name_to_tokenizer from torchtitan.parallelisms import ParallelDims from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling +from torchtitan.train_spec import get_train_spec from torchtitan.utils import device_module, device_type, import_module_from_path @@ -77,10 +78,10 @@ def main(job_config: JobConfig): utils.set_determinism( world_mesh, device, job_config.training.seed, job_config.training.deterministic ) - model_spec = get_model_spec(job_config.model.name) + train_spec = get_train_spec(job_config.model.name) # build tokenizer - tokenizer_type = model_spec.tokenizer + tokenizer_type = model_name_to_tokenizer[train_spec.name] tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) # build dataloader data_loader = build_hf_data_loader( @@ -94,8 +95,8 @@ def main(job_config: JobConfig): ) # build model (using meta init) - model_cls = model_spec.cls - model_config = model_spec.config[job_config.model.flavor] + model_cls = train_spec.cls + model_config = train_spec.config[job_config.model.flavor] # set the model configs from training inputs: # 1. norm type to decide which norm layer to use # 2. vocab size from tokenizer @@ -105,7 +106,7 @@ def main(job_config: JobConfig): model_config.max_seq_len = job_config.training.seq_len logger.info( - f"Building {model_spec.name} {job_config.model.flavor} with {model_config}" + f"Building {train_spec.name} {job_config.model.flavor} with {model_config}" ) with torch.device("meta"): model = model_cls.from_model_args(model_config) @@ -123,7 +124,7 @@ def main(job_config: JobConfig): job_config.training.seq_len, ) logger.info( - f"{color.blue}Model {model_spec.name} {job_config.model.flavor} " + f"{color.blue}Model {train_spec.name} {job_config.model.flavor} " f"{color.red}size: {model_param_count:,} total parameters{color.reset}" ) @@ -151,7 +152,7 @@ def loss_fn(pred, labels): # apply parallelisms and initialization if parallel_dims.pp_enabled: # apply PT-D Pipeline Parallel - pp_schedule, model_parts = model_spec.pipelining_fn( + pp_schedule, model_parts = train_spec.pipelining_fn( model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn ) # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead @@ -162,14 +163,14 @@ def loss_fn(pred, labels): # optimizer, and checkpointing for m in model_parts: # apply SPMD-style PT-D techniques - model_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config) + train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config) m.to_empty(device=init_device) with torch.no_grad(): m.init_weights(buffer_device=buffer_device) m.train() else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel - model_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) + train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) model.to_empty(device=init_device) with torch.no_grad(): model.init_weights(buffer_device=buffer_device) @@ -185,8 +186,8 @@ def loss_fn(pred, labels): ) # build optimizer after applying parallelisms to the model - optimizers = model_spec.build_optimizers_fn(model_parts, 
job_config) - lr_schedulers = model_spec.build_lr_schedulers_fn(optimizers, job_config) + optimizers = train_spec.build_optimizers_fn(model_parts, job_config) + lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config) train_state = TrainState() From a5491da51217f4ff7ebe85781ab79822471177ea Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 11 Feb 2025 12:37:44 -0800 Subject: [PATCH 13/17] Update [ghstack-poisoned] --- scripts/estimate/estimation.py | 15 +++++++++------ tests/unit_tests/test_train_spec.py | 14 +++++++------- torchtitan/models/llama/model.py | 3 ++- torchtitan/models/llama/pipeline_llama.py | 8 +++++--- torchtitan/parallelisms/pipeline.py | 1 + 5 files changed, 24 insertions(+), 17 deletions(-) diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 81ea10b8d4..cd7cdef5c1 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -19,9 +19,10 @@ from torchtitan.datasets import build_tokenizer from torchtitan.float8 import Float8Handler from torchtitan.logging import init_logger, logger -from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config +from torchtitan.models import model_name_to_tokenizer from torchtitan.optimizer import build_lr_schedulers, build_optimizers -from torchtitan.parallelisms import models_parallelize_fns, ParallelDims +from torchtitan.parallelisms import ParallelDims +from torchtitan.train_spec import get_train_spec def estimate_memory(job_config: JobConfig): @@ -74,6 +75,8 @@ def estimate_memory(job_config: JobConfig): "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store ) + train_spec = get_train_spec(job_config.model.name) + # build meshes world_mesh = parallel_dims.build_mesh(device_type="cuda") @@ -95,8 +98,8 @@ def loss_fn(pred, labels): ) # build model (using meta init) - model_cls = model_name_to_cls[model_name] - model_config = models_config[model_name][job_config.model.flavor] + model_cls = train_spec.cls + model_config = train_spec.config # set the model configs from training inputs: # 1. norm type to decide which norm layer to use # 2. 
vocab size from tokenizer @@ -112,7 +115,7 @@ def loss_fn(pred, labels): ): logger.info( - f"Building {model_name} {job_config.model.flavor} with {model_config}" + f"Building {train_spec.name} {job_config.model.flavor} with {model_config}" ) with torch.device("meta"): model = model_cls.from_model_args(model_config) @@ -123,7 +126,7 @@ def loss_fn(pred, labels): float8_handler.convert_to_float8_training(model) # apply PT-D DP/TP parallelisms and activation checkpointing - models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config) + train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) model.to_empty(device="cuda") if not active_fake_mode(): diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py index 447f94d964..4c01d74bde 100644 --- a/tests/unit_tests/test_train_spec.py +++ b/tests/unit_tests/test_train_spec.py @@ -10,19 +10,19 @@ import torch import torch.nn as nn from torchtitan.config_manager import JobConfig +from torchtitan.models.llama import parallelize_llama, pipeline_llama +from torchtitan.optimizer import ( + build_lr_schedulers, + build_optimizers, + OptimizersContainer, +) from torchtitan.train_spec import ( apply_to_train_specs, BaseModelArgs, get_train_spec, ModelProtocol, - TrainSpec, register_train_spec, -) -from torchtitan.models.llama import parallelize_llama, pipeline_llama -from torchtitan.optimizer import ( - build_lr_schedulers, - build_optimizers, - OptimizersContainer, + TrainSpec, ) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 6519b4446c..0a96445119 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -13,8 +13,9 @@ import torch import torch.nn.functional as F from torch import nn -from torchtitan.train_spec import BaseModelArgs, ModelProtocol + from torchtitan.models.norms import build_norm +from torchtitan.train_spec import BaseModelArgs, ModelProtocol @dataclass diff --git a/torchtitan/models/llama/pipeline_llama.py b/torchtitan/models/llama/pipeline_llama.py index 856fa5d39f..1ede183f91 100644 --- a/torchtitan/models/llama/pipeline_llama.py +++ b/torchtitan/models/llama/pipeline_llama.py @@ -7,7 +7,7 @@ # This file applies the PT-D pipeline parallelism to the Llama model. 
import copy -from typing import Callable, Union, Optional +from typing import Callable, Optional, Union import torch import torch.nn as nn @@ -18,10 +18,12 @@ from torchtitan.config_manager import JobConfig from torchtitan.logging import logger +from torchtitan.parallelisms import ParallelDims from torchtitan.parallelisms.pipeline import ( - build_pipeline_schedule, generate_split_points, stage_ids_this_rank, + build_pipeline_schedule, + generate_split_points, + stage_ids_this_rank, ) -from torchtitan.parallelisms import ParallelDims from .model import TransformerModelArgs diff --git a/torchtitan/parallelisms/pipeline.py b/torchtitan/parallelisms/pipeline.py index 90502b9a5e..aa47189c7b 100644 --- a/torchtitan/parallelisms/pipeline.py +++ b/torchtitan/parallelisms/pipeline.py @@ -14,6 +14,7 @@ PipelineScheduleSingle, ) from torch.distributed.pipelining.stage import PipelineStage + from torchtitan.config_manager import JobConfig from torchtitan.logging import logger From b5cd485b2de7eb9bc92f2bd0e86c2c437b2c1968 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 11 Feb 2025 12:58:04 -0800 Subject: [PATCH 14/17] Update [ghstack-poisoned] --- scripts/estimate/estimation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index cd7cdef5c1..8ec614d338 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -99,7 +99,7 @@ def loss_fn(pred, labels): # build model (using meta init) model_cls = train_spec.cls - model_config = train_spec.config + model_config = train_spec.config[job_config.model.flavor] # set the model configs from training inputs: # 1. norm type to decide which norm layer to use # 2. vocab size from tokenizer From 078d4adeac2c43c0462e2f6161f79f00bb5c3833 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 11 Feb 2025 13:59:15 -0800 Subject: [PATCH 15/17] Update [ghstack-poisoned] --- scripts/generate/test_generate.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index 210bd67389..23da678305 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -26,13 +26,14 @@ ) from torchtitan import utils - from torchtitan.config_manager import JobConfig from torchtitan.datasets import build_tokenizer from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_device_memory_monitor -from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config +from torchtitan.models import model_name_to_tokenizer from torchtitan.parallelisms import ParallelDims + +from torchtitan.train_spec import get_train_spec from torchtitan.utils import device_module, device_type # support running w/o installing as package @@ -102,21 +103,21 @@ def test_generate( device_module.set_device(device) device_memory_monitor = build_device_memory_monitor() - model_name = config.model.name + train_spec = get_train_spec(job_config.model.name) logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}") # Tokenizer setup tokenizer = build_tokenizer( - model_name_to_tokenizer[model_name], config.model.tokenizer_path + model_name_to_tokenizer[train_spec.name], config.model.tokenizer_path ) - model_config = models_config[model_name][config.model.flavor] + model_config = train_spec.config[config.model.flavor] model_config.norm_type = config.model.norm_type model_config.max_seq_len = config.training.seq_len 
model_config.vocab_size = tokenizer.n_words - model_cls = model_name_to_cls[model_name] + model_cls = train_spec.cls init_device = "meta" if world_size > 1 else device with torch.device(init_device): logger.info(f"Init model on init_device: {init_device}") From caf5b97bb9191c3a1880d77bab7237acf8c0908d Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 11 Feb 2025 14:23:30 -0800 Subject: [PATCH 16/17] Update [ghstack-poisoned] --- scripts/generate/test_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index 23da678305..2d016e0e88 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -103,7 +103,7 @@ def test_generate( device_module.set_device(device) device_memory_monitor = build_device_memory_monitor() - train_spec = get_train_spec(job_config.model.name) + train_spec = get_train_spec(config.model.name) logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}") From 2f4d1ce250f9ff294a0175a0a2f33c5c3a8a40c9 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 11 Feb 2025 17:29:23 -0800 Subject: [PATCH 17/17] Update [ghstack-poisoned] --- torchtitan/models/llama/pipeline_llama.py | 6 +++++- torchtitan/train_spec.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/torchtitan/models/llama/pipeline_llama.py b/torchtitan/models/llama/pipeline_llama.py index bdd6281374..6a3622bacd 100644 --- a/torchtitan/models/llama/pipeline_llama.py +++ b/torchtitan/models/llama/pipeline_llama.py @@ -13,7 +13,11 @@ import torch.nn as nn from torch.distributed import DeviceMesh from torch.distributed.pipelining import PipelineStage -from torch.distributed.pipelining.schedules import _PipelineSchedule, get_schedule_class, ScheduleZBVZeroBubble +from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, + get_schedule_class, + ScheduleZBVZeroBubble, +) from torchtitan.config_manager import JobConfig from torchtitan.logging import logger diff --git a/torchtitan/train_spec.py b/torchtitan/train_spec.py index 222ff97f92..f76b1a8d17 100644 --- a/torchtitan/train_spec.py +++ b/torchtitan/train_spec.py @@ -36,7 +36,8 @@ class ModelProtocol(Protocol): """ @staticmethod - def from_model_args(args: BaseModelArgs) -> nn.Module: ... + def from_model_args(args: BaseModelArgs) -> nn.Module: + ... OptimizersBuilder: TypeAlias = Callable[
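
End-to-end sketch of what this stack enables. This is illustrative only: the
package name `my_models`, `MyModelArgs`, and `MyModel` below are hypothetical,
not part of the patches. An out-of-tree model registers a TrainSpec from its
package `__init__.py`; per the help text of --experimental.custom_model_path,
the option accepts either a file-system path to the package or a dotted module
name, both of which `import_module_from_path` in torchtitan/utils.py handles.
The parallelize/pipelining/optimizer builders are reused from the built-in
llama3 spec purely for brevity (the unit test in this stack does the same); a
real custom model would typically supply its own. Note also that the tokenizer
lookup still goes through `model_name_to_tokenizer` in
torchtitan/models/__init__.py, so a new model name needs an entry there as well.

    # my_models/__init__.py
    from dataclasses import dataclass

    import torch.nn as nn

    from torchtitan.models.llama import parallelize_llama, pipeline_llama
    from torchtitan.optimizer import build_lr_schedulers, build_optimizers
    from torchtitan.train_spec import (
        BaseModelArgs,
        ModelProtocol,
        register_train_spec,
        TrainSpec,
    )


    @dataclass
    class MyModelArgs(BaseModelArgs):
        # Hypothetical config; BaseModelArgs enforces that all fields have defaults.
        dim: int = 256


    class MyModel(nn.Module, ModelProtocol):
        # Hypothetical model satisfying ModelProtocol's from_model_args contract.
        def __init__(self, args: MyModelArgs):
            super().__init__()
            self.proj = nn.Linear(args.dim, args.dim)

        def forward(self, x):
            return self.proj(x)

        @classmethod
        def from_model_args(cls, args: MyModelArgs) -> "MyModel":
            return cls(args)


    # Runs at import time, mirroring how torchtitan/models/llama/__init__.py
    # registers the built-in llama3 spec.
    register_train_spec(
        TrainSpec(
            name="my_model",  # selected with --model.name my_model
            cls=MyModel,
            config={"debugmodel": MyModelArgs()},
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
    )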