diff --git a/docs/float8.md b/docs/float8.md
index ccfd7fbd5d..f076ed4d0a 100644
--- a/docs/float8.md
+++ b/docs/float8.md
@@ -7,9 +7,9 @@ USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
 
 Launch training job with the following command (or alternatively set configs in toml files)
 ```
-CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
+CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
 ```
-* `--float8.enable_float8_linear`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
+* `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
 * `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
 * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter.
 
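The `--model.converters` flag replaces the old boolean and accepts a comma-separated list, so several converters can be chained. A minimal sanity check of the parsing behavior, mirroring the unit test added in `tests/unit_tests/test_job_config.py` later in this patch:

```python
# Sketch: the new flag parses into a list of converter names.
from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args(["--model.converters", "float8"])
assert config.model.converters == ["float8"]
```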
diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py
index af445d2bf6..ea0c7a59c6 100644
--- a/scripts/estimate/estimation.py
+++ b/scripts/estimate/estimation.py
@@ -9,15 +9,16 @@ import os
 
 import torch
+
+import torchtitan.float8  # noqa
 from torch._guards import active_fake_mode
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
 from torch.testing._internal.distributed.fake_pg import FakeStore
-
 from torchtitan import utils
 from torchtitan.config_manager import JobConfig
-from torchtitan.float8 import Float8Handler
 from torchtitan.logging import init_logger, logger
+from torchtitan.model_converter import build_model_converters
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.parallelisms import ParallelDims
 from torchtitan.train_spec import get_train_spec
 
@@ -117,10 +118,9 @@ def loss_fn(pred, labels):
     with torch.device("meta"):
         model = model_cls.from_model_args(model_config)
 
-    # a no-op hander if float8 is not enabled
-    float8_handler = Float8Handler(job_config, parallel_dims)
-    # swap to Float8Linear based on float8 configs
-    float8_handler.convert_to_float8_training(model)
+    # Build the collection of model converters. No-op if `model.converters` is empty.
+    model_converters = build_model_converters(job_config, parallel_dims)
+    model_converters.convert(model)
 
     # apply PT-D DP/TP parallelisms and activation checkpointing
     train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
@@ -170,9 +170,10 @@ def loss_fn(pred, labels):
             # optimizer step
             optimizers.step()
             lr_schedulers.step()
-            # calculate float8 dynamic amax/scale for all-parameter for FSDP2
+            # Post-optimizer model converters hook.
+            # e.g. calculate float8 dynamic amax/scale for all parameters for FSDP2
             # it issues a single all-reduce for all parameters at once for better performance
-            float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
+            model_converters.post_optimizer_hook(model)
             optimizers.zero_grad()
             print(f"Peak Memory at iter: {iter_idx}")
             fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
diff --git a/tests/unit_tests/test_job_config.py b/tests/unit_tests/test_job_config.py
index ae9bb56359..e29d4712b8 100644
--- a/tests/unit_tests/test_job_config.py
+++ b/tests/unit_tests/test_job_config.py
@@ -187,6 +187,15 @@ def test_parse_exclude_from_loading(self):
         assert (
             config.checkpoint.exclude_from_loading == cmdline_splits
         ), config.checkpoint.exclude_from_loading
 
+    def test_job_config_model_converters_split(self):
+        config = JobConfig()
+        config.parse_args([])
+        assert config.model.converters == []
+
+        config = JobConfig()
+        config.parse_args(["--model.converters", "float8,mxfp"])
+        assert config.model.converters == ["float8", "mxfp"]
+
     def test_print_help(self):
         config = JobConfig()
         parser = config.parser
diff --git a/tests/unit_tests/test_model_converter.py b/tests/unit_tests/test_model_converter.py
new file mode 100644
index 0000000000..ea8e9af310
--- /dev/null
+++ b/tests/unit_tests/test_model_converter.py
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtitan.config_manager import JobConfig
+from torchtitan.float8 import Float8Converter
+from torchtitan.model_converter import build_model_converters, ModelConvertersContainer
+from torchtitan.parallelisms import ParallelDims
+
+
+def build_parallel_dims(job_config, world_size):
+    parallel_dims = ParallelDims(
+        dp_shard=job_config.training.data_parallel_shard_degree,
+        dp_replicate=job_config.training.data_parallel_replicate_degree,
+        cp=job_config.experimental.context_parallel_degree,
+        tp=job_config.training.tensor_parallel_degree,
+        pp=job_config.experimental.pipeline_parallel_degree,
+        world_size=world_size,
+        enable_loss_parallel=not job_config.training.disable_loss_parallel,
+    )
+    return parallel_dims
+
+
+def test_build_model_converters_empty_list():
+    config = JobConfig()
+    config.parse_args([])
+    parallel_dims = build_parallel_dims(config, 1)
+
+    model_converters = build_model_converters(config, parallel_dims)
+    assert isinstance(model_converters, ModelConvertersContainer)
+    assert model_converters.converters == []
+
+
+def test_build_model_converters_float8_converter():
+    config = JobConfig()
+    config.parse_args(["--model.converters", "float8"])
+    parallel_dims = build_parallel_dims(config, 1)
+
+    model_converters = build_model_converters(config, parallel_dims)
+    assert isinstance(model_converters, ModelConvertersContainer)
+    assert len(model_converters.converters) == 1
+    assert isinstance(model_converters.converters[0], Float8Converter)
diff --git a/torchtitan/__init__.py b/torchtitan/__init__.py
index be0d95f3c5..26515fe9d3 100644
--- a/torchtitan/__init__.py
+++ b/torchtitan/__init__.py
@@ -6,6 +6,9 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
+# Import to register Float8Converter.
+import torchtitan.float8  # noqa: F401
+
 # Import the built-in models here so that the corresponding register_model_spec()
 # will be called.
 import torchtitan.models  # noqa: F401
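The `import torchtitan.float8  # noqa` lines above are side-effect imports: loading the module executes the `register_model_converter(Float8Converter, "float8")` call at its bottom (see `torchtitan/float8.py` below). A quick way to see the effect; note that `_registry_model_converter_cls` is a private name from `torchtitan/model_converter.py`, used here only for illustration:

```python
# Sketch: importing the torchtitan package registers built-in converters.
import torchtitan  # noqa: F401  # runs `import torchtitan.float8` as a side effect

from torchtitan.model_converter import _registry_model_converter_cls

assert "float8" in _registry_model_converter_cls
```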
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index db9e290030..126e04161f 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -26,9 +26,22 @@
 
 
 def string_list(raw_arg):
+    """Comma-separated string list argument."""
     return [s.strip() for s in raw_arg.split(",") if s.strip()]
 
 
+def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
+    section, name = fullargname.split(".")
+    # Split string lists that are still raw strings.
+    if (
+        section in args_dict
+        and name in args_dict[section]
+        and isinstance(args_dict[section][name], str)
+    ):
+        sec = args_dict[section]
+        sec[name] = string_list(sec[name])
+
+
 class JobConfig:
     """
     A helper class to manage the train configuration.
@@ -183,6 +196,19 @@ def __init__(self):
             default="./torchtitan/datasets/tokenizer/tokenizer.model",
             help="Tokenizer path",
         )
+        self.parser.add_argument(
+            "--model.converters",
+            type=string_list,
+            nargs="+",
+            default=[],
+            help="""
+                Comma-separated list of converters to apply to the model.
+
+                For instance, the `float8` converter swaps `torch.nn.Linear`
+                with `Float8Linear`. This feature requires you to install 'torchao'
+                which can be found here: https://github.com/pytorch/ao
+            """,
+        )
 
         # optimizer configs
         self.parser.add_argument(
@@ -575,15 +601,6 @@ def __init__(self):
         )
 
         # float8 configs
-        self.parser.add_argument(
-            "--float8.enable_float8_linear",
-            action="store_true",
-            help="""
-                If true, swaps `torch.nn.Linear` with `Float8Linear`.
-                This feature requires you to install 'torchao' which can be found
-                here: https://github.com/pytorch/ao
-            """,
-        )
         self.parser.add_argument(
             "--float8.enable_fsdp_float8_all_gather",
             action="store_true",
@@ -652,25 +669,11 @@ def parse_args(self, args_list: list = sys.argv[1:]):
             logger.exception(f"Error details: {str(e)}")
             raise e
 
+        # Check that string-list arguments are properly split into lists;
         # if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
-        if (
-            "experimental" in args_dict
-            and "pipeline_parallel_split_points" in args_dict["experimental"]
-            and isinstance(
-                args_dict["experimental"]["pipeline_parallel_split_points"], str
-            )
-        ):
-            exp = args_dict["experimental"]
-            exp["pipeline_parallel_split_points"] = string_list(
-                exp["pipeline_parallel_split_points"]
-            )
-        if (
-            "checkpoint" in args_dict
-            and "exclude_from_loading" in args_dict["checkpoint"]
-            and isinstance(args_dict["checkpoint"]["exclude_from_loading"], str)
-        ):
-            ckpt = args_dict["checkpoint"]
-            ckpt["exclude_from_loading"] = string_list(ckpt["exclude_from_loading"])
+        string_list_argnames = self._get_string_list_argument_names()
+        for n in string_list_argnames:
+            check_string_list_argument(args_dict, n)
 
         # override args dict with cmd_args
         cmd_args_dict = self._args_to_two_level_dict(cmd_args)
@@ -698,6 +701,13 @@ def _validate_config(self) -> None:
         assert self.model.flavor
         assert self.model.tokenizer_path
 
+    def _get_string_list_argument_names(self) -> list[str]:
+        """Get the parser argument names of type `string_list`."""
+        string_list_args = [
+            v.dest for v in self.parser._actions if v.type is string_list
+        ]
+        return string_list_args
+
     def parse_args_from_command_line(
         self, args_list
     ) -> Tuple[argparse.Namespace, argparse.Namespace]:
@@ -705,6 +715,7 @@ def parse_args_from_command_line(
         Parse command line arguments and return the parsed args and the command line only args
         """
         args = self.parser.parse_args(args_list)
+        string_list_argnames = set(self._get_string_list_argument_names())
 
         # aux parser to parse the command line only args, with no defaults from main parser
         aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
@@ -713,14 +724,11 @@ def parse_args_from_command_line(
                 aux_parser.add_argument(
                     "--" + arg, action="store_true" if val else "store_false"
                 )
-            elif arg == "experimental.pipeline_parallel_split_points":
+            elif arg in string_list_argnames:
                 # without this special case, type inference breaks here,
                 # since the inferred type is just 'list' and it ends up flattening
                 # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
                 aux_parser.add_argument("--" + arg, type=string_list)
-            elif arg == "checkpoint.exclude_from_loading":
-                # similar to the case above
-                aux_parser.add_argument("--" + arg, type=string_list)
             else:
                 aux_parser.add_argument("--" + arg, type=type(val))
 
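For reference, `string_list` (which now backs every such argument) splits on commas, strips whitespace, and drops empty items:

```python
# Illustrative behavior of the string_list helper defined above.
from torchtitan.config_manager import string_list

assert string_list("float8,mxfp") == ["float8", "mxfp"]
assert string_list(" float8 , mxfp ") == ["float8", "mxfp"]  # whitespace is stripped
assert string_list("") == []  # no empty entries
```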
diff --git a/torchtitan/float8.py b/torchtitan/float8.py
index 849ac378fe..d97606fa22 100644
--- a/torchtitan/float8.py
+++ b/torchtitan/float8.py
@@ -20,6 +20,7 @@
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
+from torchtitan.model_converter import ModelConverter, register_model_converter
 from torchtitan.parallelisms import ParallelDims
 
 
@@ -28,13 +29,11 @@ def _is_sm89_or_later():
     return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
-class Float8Handler:
+class Float8Converter(ModelConverter):
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = False
 
         float8_config = job_config.float8
-        if not float8_config.enable_float8_linear:
-            return
         if not _is_sm89_or_later():
             logger.warning(
                 "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
@@ -66,6 +65,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
 
         logger.info("Float8 training active")
 
+    def convert(self, model: nn.Module):
+        return self.convert_to_float8_training(model)
+
+    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
+        return self.precompute_float8_dynamic_scale_for_fsdp(model)
+
     def convert_to_float8_training(self, model: nn.Module):
         """
         This function converts the linear layers of `model` to `Float8Linear`.
@@ -102,3 +107,6 @@ def precompute_float8_dynamic_scale_for_fsdp(
         models = [model] if isinstance(model, nn.Module) else model
         for m in models:
             precompute_float8_dynamic_scale_for_fsdp(m)
+
+
+register_model_converter(Float8Converter, "float8")
diff --git a/torchtitan/metrics.py b/torchtitan/metrics.py
index a42c887d0f..bcd03f2448 100644
--- a/torchtitan/metrics.py
+++ b/torchtitan/metrics.py
@@ -11,6 +11,7 @@
 
 import torch
 from torch.utils.tensorboard import SummaryWriter
+
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
 from torchtitan.parallelisms import ParallelDims
diff --git a/torchtitan/model_converter.py b/torchtitan/model_converter.py
new file mode 100644
index 0000000000..2719238d9e
--- /dev/null
+++ b/torchtitan/model_converter.py
@@ -0,0 +1,80 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict, List, Protocol, Union
+
+import torch.nn as nn
+
+from torchtitan.config_manager import JobConfig
+from torchtitan.parallelisms import ParallelDims
+
+
+class ModelConverter(Protocol):
+    """General model converter interface.
+
+    A model converter applies a modification to a PyTorch model.
+    Typical use cases are:
+    - Quantization: using QAT, FP8, ... specialized linear layers;
+    - Fused optimized layers (e.g. flash-attention, norms, ...)
+    """
+
+    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+        ...
+
+    def convert(self, model: nn.Module):
+        """In-place conversion of the model."""
+        ...
+
+    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
+        """Post-optimizer (optional) hook (e.g. compute weight statistics)."""
+        ...
+
+
+_registry_model_converter_cls: Dict[str, type[ModelConverter]] = {}
+"""Registry of model converter classes.
+"""
+
+
+def register_model_converter(converter_cls: type[ModelConverter], name: str):
+    """Register a model converter class.
+
+    A registered model converter can be applied to any model
+    using the `model.converters` config parameter.
+    """
+    assert (
+        name not in _registry_model_converter_cls
+    ), f"A model converter '{name}' is already registered."
+    _registry_model_converter_cls[name] = converter_cls
+
+
+class ModelConvertersContainer(ModelConverter):
+    """Model converters sequential container.
+
+    This class builds the sequence of model converters defined in the
+    `model.converters` job config, and applies them to the model sequentially.
+    """
+
+    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+        converter_classes = [
+            _registry_model_converter_cls[name] for name in job_config.model.converters
+        ]
+        self.converters = [
+            mh_cls(job_config, parallel_dims) for mh_cls in converter_classes
+        ]
+
+    def convert(self, model: nn.Module):
+        for mh in self.converters:
+            mh.convert(model)
+
+    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
+        for mh in self.converters:
+            mh.post_optimizer_hook(model)
+
+
+def build_model_converters(
+    job_config: JobConfig, parallel_dims: ParallelDims
+) -> ModelConvertersContainer:
+    """Build the collection of model converters to apply to the model."""
+    return ModelConvertersContainer(job_config, parallel_dims)
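Any converter registered through this interface becomes selectable via `model.converters`. A minimal sketch of a custom converter; the `IdentityConverter` name and its no-op behavior are hypothetical, not part of this patch:

```python
# Hypothetical example: a registered converter that leaves the model unchanged.
from typing import List, Union

import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.model_converter import ModelConverter, register_model_converter
from torchtitan.parallelisms import ParallelDims


class IdentityConverter(ModelConverter):
    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        pass

    def convert(self, model: nn.Module):
        pass  # a real converter would rewrite modules in place here

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        pass  # optional per-step work after optimizer.step()


register_model_converter(IdentityConverter, "identity")
# Now selectable with --model.converters="identity" (or converters = "identity" in toml).
```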
diff --git a/torchtitan/models/llama/parallelize_llama.py b/torchtitan/models/llama/parallelize_llama.py
index 27c89feb0d..e5c03e8f33 100644
--- a/torchtitan/models/llama/parallelize_llama.py
+++ b/torchtitan/models/llama/parallelize_llama.py
@@ -56,11 +56,12 @@ def parallelize_llama(
             and not job_config.training.compile
         ):
             raise RuntimeError("Async TP requires --training.compile")
+        enable_float8_linear = "float8" in job_config.model.converters
         apply_tp(
             model,
             world_mesh["tp"],
             loss_parallel=parallel_dims.loss_parallel_enabled,
-            enable_float8=job_config.float8.enable_float8_linear,
+            enable_float8=enable_float8_linear,
             enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
         )
 
diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py
index f85ae2b12e..8a9146821c 100644
--- a/torchtitan/profiling.py
+++ b/torchtitan/profiling.py
@@ -10,6 +10,7 @@
 import time
 
 import torch
+
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
 
diff --git a/train.py b/train.py
index eeb3705f96..3097fd9bc0 100644
--- a/train.py
+++ b/train.py
@@ -15,15 +15,14 @@
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
 from torchtitan.config_manager import JobConfig
-from torchtitan.float8 import Float8Handler
 from torchtitan.logging import init_logger, logger
 from torchtitan.metrics import build_device_memory_monitor, build_metric_logger
+from torchtitan.model_converter import build_model_converters
 from torchtitan.parallelisms import ParallelDims
 from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
 from torchtitan.train_spec import get_train_spec
 from torchtitan.utils import device_module, device_type, import_module_from_path
 
-
 # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
 @record
 def main(job_config: JobConfig):
@@ -107,10 +106,9 @@ def main(job_config: JobConfig):
     with torch.device("meta"):
         model = model_cls.from_model_args(model_config)
 
-    # a no-op hander if float8 is not enabled
-    float8_handler = Float8Handler(job_config, parallel_dims)
-    # swap to Float8Linear based on float8 configs
-    float8_handler.convert_to_float8_training(model)
+    # Build the collection of model converters. No-op if `model.converters` is empty.
+    model_converters = build_model_converters(job_config, parallel_dims)
+    model_converters.convert(model)
 
     # log model size
     model_param_count = utils.get_num_params(model)
@@ -324,9 +322,10 @@ def loss_fn(pred, labels):
             optimizers.step()
             lr_schedulers.step()
 
-            # calculate float8 dynamic amax/scale for all-parameter for FSDP2
+            # Post-optimizer model converters hook.
+            # e.g. calculate float8 dynamic amax/scale for all parameters for FSDP2
             # it issues a single all-reduce for all parameters at once for better performance
-            float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)
+            model_converters.post_optimizer_hook(model_parts)
 
             # log metrics
             if (
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index cbbbd12c8e..8f4a40dd6c 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -26,6 +26,7 @@ flavor = "debugmodel"
 norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
 # test tokenizer.model, for debug purpose only
 tokenizer_path = "./tests/assets/test_tiktoken.model"
+# converters = "float8"
 
 [optimizer]
 name = "AdamW"
@@ -63,4 +64,5 @@ mode = 'selective' # ['none', 'selective', 'full']
 selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]
-enable_float8_linear = false
+enable_fsdp_float8_all_gather = false
+precompute_float8_dynamic_scale_for_fsdp = false
diff --git a/train_configs/llama3_405b.toml b/train_configs/llama3_405b.toml
index 26405603db..9028f32223 100644
--- a/train_configs/llama3_405b.toml
+++ b/train_configs/llama3_405b.toml
@@ -20,6 +20,7 @@ name = "llama3"
 flavor = "405B"
 norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
+converters = "float8"
 
 [optimizer]
 name = "AdamW"
@@ -55,6 +56,5 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 mode = 'full' # ['none', 'selective', 'full']
 
 [float8]
-enable_float8_linear = true
 enable_fsdp_float8_all_gather = true
 precompute_float8_dynamic_scale_for_fsdp = true
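In toml files the value arrives as a plain string (e.g. `converters = "float8"` above); during `parse_args` it is normalized into a list by the `check_string_list_argument` helper added in `config_manager.py`. A sketch of that path:

```python
# Sketch: raw toml string values are split into lists before configs are merged.
from torchtitan.config_manager import check_string_list_argument

args_dict = {"model": {"converters": "float8"}}  # as loaded from a toml file
check_string_list_argument(args_dict, "model.converters")
assert args_dict["model"]["converters"] == ["float8"]
```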
diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml
index e73e4b9457..f54a57fb67 100644
--- a/train_configs/llama3_70b.toml
+++ b/train_configs/llama3_70b.toml
@@ -20,6 +20,7 @@ name = "llama3"
 flavor = "70B"
 norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
+# converters = "float8"
 
 [optimizer]
 name = "AdamW"
@@ -54,4 +55,5 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 mode = 'full'
 
 [float8]
-enable_float8_linear = false
+enable_fsdp_float8_all_gather = false
+precompute_float8_dynamic_scale_for_fsdp = false
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index c616403629..e78b474ddf 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -20,6 +20,7 @@ name = "llama3"
 flavor = "8B"
 norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
+# converters = "float8"
 
 [optimizer]
 name = "AdamW"
@@ -55,4 +56,5 @@ mode = 'selective' # ['none', 'selective', 'full']
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]
-enable_float8_linear = false
+enable_fsdp_float8_all_gather = false
+precompute_float8_dynamic_scale_for_fsdp = false