4 changes: 2 additions & 2 deletions docs/torchft.md
@@ -68,12 +68,12 @@ The `--training.global_batch_size` parameter refers to global batch size that wi

#### Replica Group 0
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0 --fault_tolerance.semi_sync_method="diloco" --experimental.custom_args_module=torchtitan.components.ft.config
+CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0
```

#### Replica Group 1
```bash
-CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1 --fault_tolerance.semi_sync_method="diloco" --experimental.custom_args_module=torchtitan.components.ft.config
+CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1
```

## Fault Tolerance Configuration Options
14 changes: 14 additions & 0 deletions torchtitan/config/job_config.py
@@ -34,6 +34,20 @@ class Profiling:
profile_freq: int = 10
"""How often to collect profile traces, in iterations"""

profiler_active: int = 1
"""
The number of steps the profiler is active for in each profiling cycle.

This is used to configure torch.profiler.schedule.
"""

profiler_warmup: int = 3
"""
The number of warmup steps before the active steps in each profiling cycle.

This is used to configure torch.profiler.schedule.
"""

enable_memory_snapshot: bool = False
"""Whether to dump memory snapshot"""

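The two new fields replace the hard-coded `WARMUP` constant that is removed from `torchtitan/tools/profiling.py` further down in this diff. A minimal sketch, not part of the PR, of how they determine the wait/warmup/active budget of each profiling cycle; it assumes every field of the `Profiling` dataclass has a default value so it can be instantiated with no arguments:

```python
# Minimal sketch: derive the per-cycle wait budget from the new fields,
# mirroring the computation in torchtitan/tools/profiling.py below.
from torchtitan.config import Profiling

cfg = Profiling()  # profile_freq=10, profiler_warmup=3, profiler_active=1 by default
wait = cfg.profile_freq - (cfg.profiler_active + cfg.profiler_warmup)
assert wait >= 0, "profile_freq must be >= profiler_warmup + profiler_active"
print(wait)  # 6 idle steps before the 3 warmup steps and 1 active step per cycle
```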
95 changes: 95 additions & 0 deletions torchtitan/models/llama3_ft/train_configs/debug_model.toml
@@ -0,0 +1,95 @@
[job]
dump_folder = "./outputs"
description = "Llama 3 debug training"
print_args = false

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
profiler_active = 10
profiler_warmup = 0
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "llama3"
flavor = "debugmodel"
# test folder with tokenizer.json, for debug purpose only
hf_assets_path = "./tests/assets/tokenizer"
# converters = ["float8"]

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
min_lr_factor = 0.0

[training]
local_batch_size = 8
seq_len = 2048
max_norm = 1.0 # grad norm clipping
steps = 100
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default" # default / never / always
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable = false
folder = "checkpoint"
interval = 10
last_save_model_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
selective_ac_option = "2" # 'int' = ac every positive int layer or 'op', ac based on ops policy

[compile]
enable = false
components = ["model", "loss"]

[quantize.linear.float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]

[validation]
enable = false
dataset = "c4_validation"
freq = 5
steps = 10

[comm]
train_timeout_seconds = 15

[fault_tolerance]
enable = true
sync_steps = 10
num_fragments = 2
semi_sync_method = "diloco"
process_group = "nccl"
process_group_timeout_ms = 10000

[experimental]
custom_args_module = "torchtitan.components.ft.config"
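
With `profile_freq = 10`, `profiler_active = 10`, and `profiler_warmup = 0`, the derived wait is `10 - (10 + 0) = 0`, so this debug config records a trace on every training step. A quick illustrative check, not part of the PR, assuming Python 3.11+ (for `tomllib`) and the repository root as the working directory:

```python
# Hypothetical sanity check: load the new debug config and verify the
# schedule constraint enforced in profiling.py (wait >= 0).
import tomllib

with open("torchtitan/models/llama3_ft/train_configs/debug_model.toml", "rb") as f:
    cfg = tomllib.load(f)

prof = cfg["profiling"]
wait = prof["profile_freq"] - (prof["profiler_active"] + prof["profiler_warmup"])
assert wait >= 0  # 10 - (10 + 0) == 0: the profiler is active on every step
```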
10 changes: 5 additions & 5 deletions torchtitan/tools/profiling.py
@@ -14,9 +14,6 @@
from torchtitan.config import Profiling as ProfilingConfig
from torchtitan.tools.logging import logger

-# the number of warmup steps before the active step in each profiling cycle
-WARMUP = 3

# how much memory allocation/free ops to record in memory snapshots
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000

@@ -34,7 +31,11 @@ def maybe_enable_profiling(

if enable_profiling:
trace_dir = os.path.join(base_folder, profiling_config.save_traces_folder)
-profile_freq = profiling_config.profile_freq
+profile_freq, warmup, active = (
+    profiling_config.profile_freq,
+    profiling_config.profiler_warmup,
+    profiling_config.profiler_active,
+)

rank = torch.distributed.get_rank()

@@ -58,7 +59,6 @@ def trace_handler(prof):
if not os.path.exists(trace_dir):
os.makedirs(trace_dir, exist_ok=True)

-warmup, active = WARMUP, 1
wait = profile_freq - (active + warmup)
assert (
wait >= 0
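
For context, a minimal self-contained sketch of the schedule these knobs control; the handler body and the loop are placeholders, not the torchtitan implementation (torchtitan's `trace_handler` writes traces under `save_traces_folder` per rank):

```python
# Sketch: the computed wait/warmup/active values feed torch.profiler.schedule,
# so a trace is captured once every `profile_freq` steps.
import torch

profile_freq, warmup, active = 10, 3, 1  # defaults from job_config.py
wait = profile_freq - (active + warmup)
assert wait >= 0

def trace_handler(prof: torch.profiler.profile) -> None:
    # Placeholder handler: export one Chrome trace per completed cycle.
    prof.export_chrome_trace(f"trace_step{prof.step_num}.json")

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
    on_trace_ready=trace_handler,
) as prof:
    for _ in range(2 * profile_freq):  # two full profiling cycles
        # ... one training step would run here ...
        prof.step()
```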