diff --git a/docs/torchft.md b/docs/torchft.md
index 68733ec1ec..b39eddb6db 100644
--- a/docs/torchft.md
+++ b/docs/torchft.md
@@ -68,12 +68,12 @@ The `--training.global_batch_size` parameter refers to global batch size that wi
 
 #### Replica Group 0
 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0 --fault_tolerance.semi_sync_method="diloco" --experimental.custom_args_module=torchtitan.components.ft.config
+CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0
 ```
 
 #### Replica Group 1
 ```bash
-CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1 --fault_tolerance.semi_sync_method="diloco" --experimental.custom_args_module=torchtitan.components.ft.config
+CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1
 ```
 
 ## Fault Tolerance Configuration Options
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index eb477941ca..138672c739 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -34,6 +34,20 @@ class Profiling:
     profile_freq: int = 10
     """How often to collect profile traces, in iterations"""
 
+    profiler_active: int = 1
+    """
+    The number of steps the profiler is active for in each profiling cycle.
+
+    This is used to configure torch.profiler.schedule.
+    """
+
+    profiler_warmup: int = 3
+    """
+    The number of warmup steps before the active steps in each profiling cycle.
+
+    This is used to configure torch.profiler.schedule.
+ """ + enable_memory_snapshot: bool = False """Whether to dump memory snapshot""" diff --git a/torchtitan/models/llama3_ft/train_configs/debug_model.toml b/torchtitan/models/llama3_ft/train_configs/debug_model.toml new file mode 100644 index 0000000000..000af2ac06 --- /dev/null +++ b/torchtitan/models/llama3_ft/train_configs/debug_model.toml @@ -0,0 +1,95 @@ +[job] +dump_folder = "./outputs" +description = "Llama 3 debug training" +print_args = false + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 10 +profiler_active = 10 +profiler_warmup = 0 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "llama3" +flavor = "debugmodel" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 8 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 100 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = "2" # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable = false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 5 +steps = 10 + +[comm] +train_timeout_seconds = 15 + +[fault_tolerance] +enable = true +sync_steps = 10 +num_fragments = 2 +semi_sync_method = "diloco" +process_group = "nccl" +process_group_timeout_ms = 10000 + +[experimental] +custom_args_module = "torchtitan.components.ft.config" diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 0e851d335a..f398dba9b5 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -14,9 +14,6 @@ from torchtitan.config import Profiling as ProfilingConfig from torchtitan.tools.logging import logger -# the number of warmup steps before the active step in each profiling cycle -WARMUP = 3 - # how much memory allocation/free ops to record in memory snapshots MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 @@ -34,7 +31,11 @@ def maybe_enable_profiling( if enable_profiling: trace_dir = os.path.join(base_folder, profiling_config.save_traces_folder) - profile_freq = profiling_config.profile_freq + profile_freq, warmup, active = ( + profiling_config.profile_freq, + profiling_config.profiler_warmup, + profiling_config.profiler_active, + ) rank = torch.distributed.get_rank() @@ -58,7 +59,6 @@ def trace_handler(prof): if not 
diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py
index 0e851d335a..f398dba9b5 100644
--- a/torchtitan/tools/profiling.py
+++ b/torchtitan/tools/profiling.py
@@ -14,9 +14,6 @@
 from torchtitan.config import Profiling as ProfilingConfig
 from torchtitan.tools.logging import logger
 
-# the number of warmup steps before the active step in each profiling cycle
-WARMUP = 3
-
 # how much memory allocation/free ops to record in memory snapshots
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
 
@@ -34,7 +31,11 @@ def maybe_enable_profiling(
     if enable_profiling:
         trace_dir = os.path.join(base_folder, profiling_config.save_traces_folder)
 
-        profile_freq = profiling_config.profile_freq
+        profile_freq, warmup, active = (
+            profiling_config.profile_freq,
+            profiling_config.profiler_warmup,
+            profiling_config.profiler_active,
+        )
 
         rank = torch.distributed.get_rank()
 
@@ -58,7 +59,6 @@ def trace_handler(prof):
         if not os.path.exists(trace_dir):
             os.makedirs(trace_dir, exist_ok=True)
 
-        warmup, active = WARMUP, 1
         wait = profile_freq - (active + warmup)
         assert (
             wait >= 0
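For context (not part of this patch): the `torch.profiler` call that consumes `warmup` and `active` is not visible in this diff. A minimal sketch, assuming the values feed `torch.profiler.schedule` as the new `job_config.py` docstrings state, using the new defaults (`profile_freq=10`, `profiler_warmup=3`, `profiler_active=1`):

```python
# Minimal sketch (assumption, not the patch's exact call site): how the new
# config knobs would drive torch.profiler.schedule.
import torch

profile_freq = 10  # Profiling.profile_freq
warmup = 3         # Profiling.profiler_warmup (new default)
active = 1         # Profiling.profiler_active (new default)

wait = profile_freq - (active + warmup)  # idle steps per cycle; here 6
assert wait >= 0, "profile_freq must cover warmup + active steps"

prof = torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
)
with prof:
    for _ in range(2 * profile_freq):  # two full profiling cycles
        # one training step would run here
        prof.step()
```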