From 546847a0c632d38f3b39598e58bc3887eca698d7 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 12 Nov 2025 13:12:09 -0800 Subject: [PATCH 1/9] Update (base update) [ghstack-poisoned] --- .gitignore | 3 ++ scripts/dry_run.py | 3 ++ .../deepseek_v3/parallelize.py | 5 +- .../compiler_toolkit/graph_utils.py | 50 ++++++++++++++----- .../compiler_toolkit/llama3/parallelize.py | 5 +- torchtitan/train.py | 3 +- 6 files changed, 53 insertions(+), 16 deletions(-) diff --git a/.gitignore b/.gitignore index 45a8f5752a..415631ff9c 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,6 @@ Sessionx.vim # env files .env + +# Vibe coding +.claude diff --git a/scripts/dry_run.py b/scripts/dry_run.py index 2552ca0d78..fa8e1b4c17 100644 --- a/scripts/dry_run.py +++ b/scripts/dry_run.py @@ -151,6 +151,9 @@ def __init__(self, job_config: JobConfig): logger.info("Configuration is ready for training execution.") logger.info("=" * 80) + def train(self): + return + if __name__ == "__main__": main(DryRunTrainer) diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py index bc6859af61..20ad17f301 100644 --- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py @@ -80,7 +80,9 @@ def parallelize_deepseekv3( compiler_passes = get_compiler_passes_from_config(job_config) # Create compilers with specified passes (defaults to no passes) - fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes) + fw_compiler, bw_compiler = make_compiler_with_passes( + compiler_passes, dump_folder=job_config.job.dump_folder + ) # Create custom joint_graph_builder with deepseekv3-specific compilers deepseekv3_joint_graph_builder = functools.partial( @@ -88,6 +90,7 @@ def parallelize_deepseekv3( fw_compiler=fw_compiler, bw_compiler=bw_compiler, joint_custom_pass=validate_flex_attention_annotation, + dump_folder=job_config.job.dump_folder, ) # TODO: CompiledModule should take sample input as well, so that we can diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index aee089cad9..db998aa170 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import contextlib +from pathlib import Path from typing import Callable, List, Optional import torch @@ -21,8 +22,18 @@ from torchtitan.tools.logging import logger +def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None: + # TODO: make the dump rank configurable + if not dump_folder or torch.distributed.get_rank() != 0: + return + + output_path = Path(dump_folder) / "compiler" / f"{name}.txt" + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(gm.print_readable(print_output=False)) + + def export_joint( - model, args, kwargs=None + model, args, kwargs=None, dump_folder: str | None = None ) -> tuple[JointWithDescriptors, TracingContext]: if kwargs is None: kwargs = {} @@ -35,8 +46,10 @@ def export_joint( torch.fx.traceback.preserve_node_meta(), ): gm = dynamo_graph_capture_for_export(model)(*args, **kwargs) - logger.info("Dynamo gm:") - logger.info(gm.print_readable(print_output=False)) + logger.debug("Dynamo gm:") + logger.debug(gm.print_readable(print_output=False)) + _dump_gm(dump_folder, gm, "dynamo_gm") + tracing_context = gm.meta["tracing_context"] with tracing(tracing_context): @@ -68,6 +81,7 @@ def joint_graph_builder( fw_compiler: Optional[Callable] = None, bw_compiler: Optional[Callable] = None, joint_custom_pass: Optional[Callable] = None, + dump_folder: str | None = None, ): """ Build a joint forward-backward graph for the model with optional custom compilers. @@ -79,16 +93,17 @@ def joint_graph_builder( fw_compiler: Optional custom forward compiler function bw_compiler: Optional custom backward compiler function joint_custom_pass: Optional custom pass to run on the joint graph + dump_folder: Optional folder to dump the graph to """ assert isinstance(model_args, tuple) - for arg in model_args: - assert isinstance(arg, DTensor) + for idx, arg in enumerate(model_args): + assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}" # get joint graph ( joint_with_descriptors, tracing_context, - ) = export_joint(model, model_args, model_kwargs) + ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder) # Optional validation if joint_custom_pass is not None: @@ -179,6 +194,7 @@ def compiler( gm: torch.fx.GraphModule, example_inputs, passes: List[Callable] = None, + dump_folder: str | None = None, ): """ Compile a graph module by applying a sequence of compiler passes. @@ -194,19 +210,23 @@ def compiler( if passes is None: passes = DEFAULT_COMPILER_PASSES - logger.info(f"{name} before compiler:") - logger.info(gm.print_readable(print_output=False)) + logger.debug(f"{name} before compiler:") + logger.debug(gm.print_readable(print_output=False)) + _dump_gm(dump_folder, gm, f"{name}_before_compiler") for pass_fn in passes: logger.info(f"Applying pass: {pass_fn.__name__}") gm = pass_fn(gm, example_inputs) - logger.info(f"{name} after compiler:") - logger.info(gm.print_readable(print_output=False)) + logger.debug(f"{name} after compiler:") + logger.debug(gm.print_readable(print_output=False)) + _dump_gm(dump_folder, gm, f"{name}_after_compiler") return gm -def make_compiler_with_passes(passes: List[Callable] = None): +def make_compiler_with_passes( + passes: List[Callable] = None, dump_folder: str | None = None +): """ Create forward and backward compilers with specified passes. @@ -218,10 +238,14 @@ def make_compiler_with_passes(passes: List[Callable] = None): """ def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("fwd_gm", gm, example_inputs, passes=passes) + return compiler( + "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + ) def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("bwd_gm", gm, example_inputs, passes=passes) + return compiler( + "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + ) return fw_compiler, bw_compiler diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index e3dca203e9..0ffbe61b89 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -64,7 +64,9 @@ def parallelize_llama( model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config) # Get compiler passes from config - compiler_passes = get_compiler_passes_from_config(job_config) + compiler_passes = get_compiler_passes_from_config( + job_config, dump_folder=job_config.job.dump_folder + ) # Create compilers with specified passes (defaults to no passes) fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes) @@ -75,6 +77,7 @@ def parallelize_llama( fw_compiler=fw_compiler, bw_compiler=bw_compiler, joint_custom_pass=validate_flex_attention_annotation, + dump_folder=job_config.job.dump_folder, ) # TODO: CompiledModule should take sample input as well, so that we can diff --git a/torchtitan/train.py b/torchtitan/train.py index 18a876c4bb..5cfab998b2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -735,7 +735,8 @@ def main(trainer_class: type[Trainer]) -> None: raise else: trainer.close() - torch.distributed.destroy_process_group() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() logger.info("Process group destroyed") From dd979ef1a6af3b62c490059c835e851bfef7c663 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 12 Nov 2025 13:12:09 -0800 Subject: [PATCH 2/9] Update [ghstack-poisoned] --- .../integration_test_8gpu_features.yaml | 5 + scripts/loss_compare.py | 785 ++++++++++++++++++ 2 files changed, 790 insertions(+) create mode 100644 scripts/loss_compare.py diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index c6e8ed30d5..5978eaf4a1 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -76,5 +76,10 @@ jobs: export TEST_WITH_ROCM=$([[ "${{ matrix.gpu-arch-type }}" == "rocm" ]] && echo 1 || echo 0) python -m tests.integration_tests.run_tests --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 + # Verify the accuracy. + export baseline_cmd='CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh' + export baseline_cmd='CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --parallelism.data_parallel_replicate_degree=2' + python3 scripts/loss_compare.py . . --baseline-cmd=${baseline_cmd} --test-cmd=${test_cmd} --no-seed-checkpoint --steps=10 + rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint rm -rf artifacts-to-be-uploaded/*/checkpoint diff --git a/scripts/loss_compare.py b/scripts/loss_compare.py new file mode 100644 index 0000000000..12a2d70514 --- /dev/null +++ b/scripts/loss_compare.py @@ -0,0 +1,785 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script compares training losses between different git commits and/or different training configurations. +--debug.deterministic is always enabled and seed checkpoint is also enabled by default for reproducible +comparisons. You can disable seed checkpoint with --no-seed-checkpoint if you don't need it to speed up comparisons. +If --output-folder is specified, all outputs are organized in that folder with detailed analysis and statistical summaries. + +The --assert-equal flag can be used for CI testing to verify that losses are identical between runs. +If losses differ, the script will exit with a non-zero status code. + +Example usages: +1. Compare losses between two different git commits with default command: + loss_compare.py main my_branch + +2. Compare losses between two commits with custom command and save results: + loss_compare.py main my_branch \ + --baseline-cmd="CONFIG_FILE='./custom.toml' ./run_train.sh --parallelism.tensor_parallel_degree=2" \ + --output-folder=my_comparison + +3. Compare commits with the same command but skip seed checkpoint for faster execution: + loss_compare.py main my_branch --no-seed-checkpoint + +4. Compare the same commit with different training configurations: + loss_compare.py . . \ + --baseline-cmd="CONFIG_FILE='./llama3_8b.toml' ./run_train.sh --parallelism.dp=1" \ + --test-cmd="CONFIG_FILE='./llama3_8b.toml' ./run_train.sh --parallelism.dp=2" + +5. Assert that losses are equal (for CI testing): + loss_compare.py main my_branch --assert-equal +""" + +import argparse +import os +import re +import shutil +import subprocess +import sys +import unittest +from typing import Any + +# ============================================================================= +# GLOBAL CONFIGURATION +# ============================================================================= + +LOG_PREFIX = "[LOSS_COMPARE]" + +# Default configuration values +DEFAULT_RUN_CMD = "CONFIG_FILE='./torchtitan/models/llama3/train_configs/llama3_8b.toml' ./run_train.sh" +DEFAULT_STEPS = 100 + +# Fixed options that are always appended +FIXED_OPTIONS = "--debug.deterministic --debug.seed=42" + + +# ============================================================================= +# UTILITY FUNCTIONS +# ============================================================================= + + +def log_print(message: str = "") -> None: + """Print message with LOG_PREFIX.""" + if message: + print(f"{LOG_PREFIX} {message}") + else: + print(f"{LOG_PREFIX}") + + +def get_log_path(scenario: str, output_folder: str | None) -> str: + """Get log file path for a scenario.""" + if output_folder: + return f"{output_folder}/{scenario}_training.log" + return f"/tmp/{scenario}_training.log" + + +def get_loss_file_path(scenario: str, output_folder: str) -> str: + """Get loss file path for a scenario.""" + return f"{output_folder}/{scenario}_losses.txt" + + +def get_clean_log_path(scenario: str, output_folder: str) -> str: + """Get cleaned log file path for a scenario.""" + return f"{output_folder}/{scenario}_training_clean.log" + + +def extract_config_file(cmd: str) -> str | None: + """Extract CONFIG_FILE value from command string.""" + # Match CONFIG_FILE=value with optional quotes + patterns = [ + r"CONFIG_FILE='([^']+)'", # Single quotes + r'CONFIG_FILE="([^"]+)"', # Double quotes + r"CONFIG_FILE=(\S+)", # No quotes + ] + + for pattern in patterns: + match = re.search(pattern, cmd) + if match: + return match.group(1) + + return None + + +def strip_ansi_codes(input_file: str, output_file: str) -> None: + """Strip ANSI escape codes from log files.""" + ansi_escape = re.compile(r"\x1b\[[0-9;]*m") + with open(input_file, "r") as f_in: + with open(output_file, "w") as f_out: + for line in f_in: + f_out.write(ansi_escape.sub("", line)) + + +def run_with_realtime_output(cmd: str, logfile: str, env: dict[str, Any]) -> None: + """Run command with real-time output to both console and log file.""" + log_print(f"Executing: {cmd}") + + # Set PYTHONUNBUFFERED for better output handling + env["PYTHONUNBUFFERED"] = "1" + + # Run command and tee output to both stdout and log file + with open(logfile, "w") as log_f: + process = subprocess.Popen( + cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + bufsize=1, + ) + + for line in process.stdout: + print(line, end="") + log_f.write(line) + log_f.flush() + + process.wait() + + if process.returncode != 0: + raise subprocess.CalledProcessError(process.returncode, cmd) + + +def log_and_save(message: str, stats_file: str) -> None: + """Output message to both stdout and stats file.""" + print(message) + with open(stats_file, "a") as f: + f.write(message + "\n") + + +# ============================================================================= +# VALIDATION FUNCTIONS +# ============================================================================= + + +def validate_arguments( + baseline_commit: str, test_commit: str, baseline_cmd: str, test_cmd: str, steps: int +) -> None: + """Validate command line arguments.""" + # Validate commit arguments - if one is ".", both must be "." + if (baseline_commit == "." and test_commit != ".") or ( + baseline_commit != "." and test_commit == "." + ): + log_print("Error: If one commit is '.', both commits must be '.'") + log_print(f" Got baseline: '{baseline_commit}', test: '{test_commit}'") + log_print( + " Use '.' for both commits to compare different configurations on current working directory" + ) + sys.exit(1) + + # Validate CONFIG_FILE is specified in both commands + baseline_config = extract_config_file(baseline_cmd) + if not baseline_config: + log_print("Error: CONFIG_FILE not found in baseline command") + log_print(f" Baseline command: {baseline_cmd}") + log_print( + " Please specify CONFIG_FILE in the format: CONFIG_FILE='path/to/config.toml' ./run_train.sh" + ) + sys.exit(1) + + test_config = extract_config_file(test_cmd) + if not test_config: + log_print("Error: CONFIG_FILE not found in test command") + log_print(f" Test command: {test_cmd}") + log_print( + " Please specify CONFIG_FILE in the format: CONFIG_FILE='path/to/config.toml' ./run_train.sh" + ) + sys.exit(1) + + # Validate that commits and commands are not both identical + if baseline_commit == test_commit and baseline_cmd == test_cmd: + log_print("Error: Both commits and commands are identical") + log_print(" Cannot compare identical configurations") + log_print( + " Please provide either different commits or different commands" + ) + sys.exit(1) + + # Validate steps is a positive integer + if steps <= 0: + log_print(f"Error: --steps must be a positive integer, got: {steps}") + sys.exit(1) + + +# ============================================================================= +# SETUP FUNCTIONS +# ============================================================================= + + +def setup_output_directory(output_folder: str | None) -> str | None: + """Setup output directory and return stats file path. Returns None if no output folder specified.""" + if not output_folder: + return None + + # Check if output folder already exists + if os.path.exists(output_folder): + log_print(f"Error: Output folder '{output_folder}' already exists") + log_print(f"Please delete it first: rm -rf {output_folder}") + sys.exit(1) + + # Create the output folder + log_print(f"Creating output folder: {output_folder}") + os.makedirs(output_folder) + + # Set statistics file path + stats_file = os.path.join(output_folder, "comparison_statistics.txt") + return stats_file + + +def build_training_command( + base_cmd: str, steps: int, enable_seed_checkpoint: bool +) -> str: + """Build the final training command with all options.""" + cmd = f"{base_cmd} {FIXED_OPTIONS} --training.steps={steps}" + if enable_seed_checkpoint: + cmd += " --checkpoint.enable --checkpoint.export_dtype=bfloat16 --checkpoint.load_only" + return cmd + + +def print_configuration( + baseline_commit: str, + test_commit: str, + baseline_cmd: str, + test_cmd: str, + steps: int, + enable_seed_checkpoint: bool, +) -> None: + """Print configuration summary.""" + log_print( + f"Starting loss comparison between baseline commit: {baseline_commit} and test commit: {test_commit}" + ) + log_print(f"Training steps: {steps}") + log_print(f"Seed checkpoint enabled: {enable_seed_checkpoint}") + log_print() + + # Build and display final commands + baseline_final_cmd = build_training_command( + baseline_cmd, steps, enable_seed_checkpoint + ) + test_final_cmd = build_training_command(test_cmd, steps, enable_seed_checkpoint) + + log_print("Baseline command:") + log_print(f" {baseline_final_cmd}") + log_print() + log_print("Test command:") + log_print(f" {test_final_cmd}") + log_print() + + +# ============================================================================= +# GIT OPERATIONS +# ============================================================================= + + +def checkout_commit(commit: str, commit_name: str) -> None: + """Checkout git commit.""" + if commit != ".": + log_print(f"Checking out {commit_name} commit: {commit}") + subprocess.run(["git", "checkout", commit], check=True) + else: + log_print(f"Using current working directory for {commit_name} (commit: '.')") + + +# ============================================================================= +# TRAINING OPERATIONS +# ============================================================================= + + +def create_seed_checkpoint( + enable_seed_checkpoint: bool, baseline_cmd: str, output_folder: str | None +) -> None: + """Create seed checkpoint.""" + if enable_seed_checkpoint: + log_file = get_log_path("seed_checkpoint", output_folder) + log_print(f"Creating seed checkpoint and logging output to {log_file}") + + # Extract CONFIG_FILE from baseline command + config_file = extract_config_file(baseline_cmd) + if not config_file: + log_print("Warning: Could not extract CONFIG_FILE from baseline command") + log_print(f" Baseline command: {baseline_cmd}") + sys.exit(1) + + # Build seed checkpoint command + seed_cmd = ( + f"CONFIG_FILE={config_file} ./run_train.sh " + f"--checkpoint.create_seed_checkpoint --checkpoint.enable {FIXED_OPTIONS}" + ) + + env = os.environ.copy() + env["NGPU"] = "1" + + run_with_realtime_output(seed_cmd, log_file, env) + + # Backup the seed checkpoint + if output_folder: + shutil.copytree("outputs", f"{output_folder}/seed_checkpoint_outputs") + + +def restore_seed_checkpoint( + enable_seed_checkpoint: bool, output_folder: str | None +) -> None: + """Restore seed checkpoint.""" + if enable_seed_checkpoint and output_folder: + if os.path.exists("outputs"): + shutil.rmtree("outputs") + shutil.copytree(f"{output_folder}/seed_checkpoint_outputs", "outputs") + + +def run_training( + scenario: str, + cmd: str, + steps: int, + enable_seed_checkpoint: bool, + output_folder: str | None, +) -> str: + """Run training for a specific scenario. Returns the log file path.""" + log_file = get_log_path(scenario, output_folder) + log_print( + f"Running training with {scenario} commit and logging output to {log_file}" + ) + + # Build the final command + full_cmd = build_training_command(cmd, steps, enable_seed_checkpoint) + + env = os.environ.copy() + + run_with_realtime_output(full_cmd, log_file, env) + + # Backup the outputs + if output_folder: + shutil.move("outputs", f"{output_folder}/{scenario}_outputs") + else: + # Clean up outputs if not saving + if os.path.exists("outputs"): + shutil.rmtree("outputs") + + return log_file + + +# ============================================================================= +# LOG PROCESSING AND ANALYSIS +# ============================================================================= + + +def extract_losses_from_log(log_file: str) -> dict[int, float]: + """Extract step and loss pairs from a log file.""" + losses = {} + step_loss_pattern = re.compile(r"step:\s*(\d+)\s*loss:\s*(\d+\.\d+)") + ansi_escape = re.compile(r"\x1b\[[0-9;]*m") + + with open(log_file, "r") as f: + for line in f: + # Strip ANSI codes before matching + clean_line = ansi_escape.sub("", line) + match = step_loss_pattern.search(clean_line) + if match: + step, loss = match.groups() + losses[int(step)] = float(loss) + + return losses + + +def read_losses_from_file(loss_file: str) -> dict[int, float]: + """Read losses from a processed loss file.""" + losses = {} + with open(loss_file, "r") as f: + for line in f: + step, loss = line.strip().split() + losses[int(step)] = float(loss) + return losses + + +def extract_loss_data(output_folder: str | None) -> None: + """Extract loss data from logs.""" + if not output_folder: + return + + log_print("Cleaning ANSI escape codes from log files...") + + # Strip ANSI escape codes from log files before processing + scenarios = ["baseline", "test"] + for scenario in scenarios: + strip_ansi_codes( + get_log_path(scenario, output_folder), + get_clean_log_path(scenario, output_folder), + ) + + # Extract step and loss from cleaned logs + step_loss_pattern = re.compile(r"step:\s*(\d+)\s*loss:\s*(\d+\.\d+)") + + for scenario in scenarios: + with open(get_clean_log_path(scenario, output_folder), "r") as f_in: + with open(get_loss_file_path(scenario, output_folder), "w") as f_out: + for line in f_in: + match = step_loss_pattern.search(line) + if match: + step, loss = match.groups() + f_out.write(f"{step} {loss}\n") + + +def generate_step_comparison(output_folder: str, stats_file: str) -> None: + """Generate step-by-step comparison.""" + log_and_save("", stats_file) + log_and_save(f"{LOG_PREFIX} Step-by-step loss comparison:", stats_file) + log_and_save( + f"{LOG_PREFIX} Step Baseline Loss Test Loss Difference", stats_file + ) + log_and_save( + f"{LOG_PREFIX} ---- ------------- --------- ----------", stats_file + ) + + # Read baseline and test losses + baseline_losses = read_losses_from_file( + get_loss_file_path("baseline", output_folder) + ) + test_losses = read_losses_from_file(get_loss_file_path("test", output_folder)) + + # Generate comparison for common steps + for step in sorted(set(baseline_losses.keys()) & set(test_losses.keys())): + baseline_loss = baseline_losses[step] + test_loss = test_losses[step] + diff = test_loss - baseline_loss + + formatted_line = f"{LOG_PREFIX} {step:<6} {baseline_loss:<13} {test_loss:<14} {diff:.6f}" + log_and_save(formatted_line, stats_file) + + +def generate_summary_statistics(output_folder: str, stats_file: str) -> None: + """Generate summary statistics.""" + log_and_save(f"{LOG_PREFIX}", stats_file) + log_and_save(f"{LOG_PREFIX} Summary statistics:", stats_file) + + # Calculate average losses + def calculate_average(losses: dict[int, float]) -> float | None: + """Calculate average loss from losses dict.""" + if not losses: + return None + return sum(losses.values()) / len(losses) + + baseline_losses = read_losses_from_file( + get_loss_file_path("baseline", output_folder) + ) + test_losses = read_losses_from_file(get_loss_file_path("test", output_folder)) + + baseline_avg = calculate_average(baseline_losses) + test_avg = calculate_average(test_losses) + + baseline_avg_str = f"{baseline_avg}" if baseline_avg is not None else "N/A" + test_avg_str = f"{test_avg}" if test_avg is not None else "N/A" + + log_and_save(f"{LOG_PREFIX} Average baseline loss: {baseline_avg_str}", stats_file) + log_and_save(f"{LOG_PREFIX} Average test loss: {test_avg_str}", stats_file) + + # Calculate overall difference if both averages are available + if baseline_avg is not None and test_avg is not None: + avg_diff = test_avg - baseline_avg + log_and_save(f"{LOG_PREFIX} Average difference: {avg_diff:.6f}", stats_file) + + +def perform_loss_analysis(output_folder: str | None, stats_file: str | None) -> None: + """Perform loss comparison analysis.""" + if not output_folder or not stats_file: + log_print("Skipping loss analysis (no output folder specified)") + return + + # Initialize stats file and add header + log_and_save(f"{LOG_PREFIX} ==========================================", stats_file) + log_and_save(f"{LOG_PREFIX} LOSS COMPARISON ANALYSIS", stats_file) + log_and_save(f"{LOG_PREFIX} ==========================================", stats_file) + + # Extract loss data from training logs + extract_loss_data(output_folder) + + # Check if loss files were created successfully + scenarios = ["baseline", "test"] + for scenario in scenarios: + loss_path = get_loss_file_path(scenario, output_folder) + if not os.path.exists(loss_path) or os.path.getsize(loss_path) == 0: + log_and_save( + f"{LOG_PREFIX} Warning: Could not extract loss data from training logs.", + stats_file, + ) + log_and_save( + f"{LOG_PREFIX} Please check that the training completed successfully.", + stats_file, + ) + return + + # Generate comparison outputs + generate_step_comparison(output_folder, stats_file) + generate_summary_statistics(output_folder, stats_file) + + +def assert_losses_equal(baseline_log: str, test_log: str) -> None: + """Assert that losses are equal between baseline and test using unittest.""" + log_print("Asserting losses are equal...") + log_print(f"Baseline log: {baseline_log}") + log_print(f"Test log: {test_log}") + + # Extract losses from both logs + baseline_losses = extract_losses_from_log(baseline_log) + test_losses = extract_losses_from_log(test_log) + + log_print(f"Extracted {len(baseline_losses)} steps from baseline log") + log_print(f"Extracted {len(test_losses)} steps from test log") + + if not baseline_losses: + log_print("Error: No losses found in baseline log") + sys.exit(1) + + if not test_losses: + log_print("Error: No losses found in test log") + sys.exit(1) + + # Create a test case + class LossEqualityTest(unittest.TestCase): + def test_losses_equal(self): + # Check that both have the same steps + baseline_steps = set(baseline_losses.keys()) + test_steps = set(test_losses.keys()) + + self.assertEqual( + baseline_steps, + test_steps, + f"Steps mismatch: baseline has {len(baseline_steps)} steps, test has {len(test_steps)} steps", + ) + + # Check that losses are equal for each step + for step in sorted(baseline_steps): + baseline_loss = baseline_losses[step] + test_loss = test_losses[step] + self.assertEqual( + baseline_loss, + test_loss, + f"Loss mismatch at step {step}: baseline={baseline_loss}, test={test_loss}", + ) + + # Run the test + suite = unittest.TestLoader().loadTestsFromTestCase(LossEqualityTest) + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + if not result.wasSuccessful(): + log_print("Loss assertion failed!") + sys.exit(1) + else: + log_print("All losses are equal. Assertion passed!") + + +def cleanup_temp_files(output_folder: str | None) -> None: + """Cleanup temporary files.""" + if not output_folder: + return + + scenarios = ["baseline", "test"] + for scenario in scenarios: + for temp_file in [ + get_loss_file_path(scenario, output_folder), + get_clean_log_path(scenario, output_folder), + ]: + if os.path.exists(temp_file): + os.remove(temp_file) + + +# ============================================================================= +# OUTPUT FUNCTIONS +# ============================================================================= + + +def print_completion_summary( + output_folder: str | None, enable_seed_checkpoint: bool +) -> None: + """Print completion summary.""" + log_print() + if output_folder: + log_print(f"Loss comparison complete. Results saved in {output_folder}/:") + log_print(" - baseline_outputs/") + log_print(" - test_outputs/") + if enable_seed_checkpoint: + log_print(" - seed_checkpoint_outputs/") + log_print() + log_print(f"Training logs saved in {output_folder}/:") + if enable_seed_checkpoint: + log_print(" - seed_checkpoint.log") + log_print(" - baseline_training.log") + log_print(" - test_training.log") + log_print() + log_print(f"All outputs organized in: {output_folder}/") + else: + log_print( + "Loss comparison complete. No results saved (no output folder specified)." + ) + + +# ============================================================================= +# MAIN EXECUTION +# ============================================================================= + + +def parse_arguments() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Compare training losses between different git commits and/or different training configurations.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s abc123 def456 + %(prog)s abc123 def456 --steps=200 + %(prog)s abc123 def456 --baseline-cmd="CONFIG_FILE='./custom.toml' ./run_train.sh" --steps=50 + %(prog)s abc123 def456 --no-seed-checkpoint + %(prog)s . . --baseline-cmd="CONFIG_FILE='./llama3_8b.toml' ./run_train.sh --parallelism.dp=1" \\ + --test-cmd="CONFIG_FILE='./llama3_8b.toml' ./run_train.sh --parallelism.dp=2" --steps=30 + """, + ) + + parser.add_argument("baseline_commit", help="Git commit hash for baseline") + parser.add_argument("test_commit", help="Git commit hash for test") + parser.add_argument( + "--baseline-cmd", + default="", + help=f"Full command for baseline run (default: {DEFAULT_RUN_CMD})", + ) + parser.add_argument( + "--test-cmd", + default="", + help="Full command for test run (default: uses baseline-cmd)", + ) + parser.add_argument( + "--steps", + type=int, + default=DEFAULT_STEPS, + help=f"Number of training steps (default: {DEFAULT_STEPS})", + ) + parser.add_argument( + "--no-seed-checkpoint", + action="store_true", + help="Disable seed checkpoint creation and checkpoint functionality", + ) + parser.add_argument( + "--output-folder", + default="", + help="Output folder for results (optional, if not specified, results will not be saved)", + ) + parser.add_argument( + "--assert-equal", + action="store_true", + help="Assert that all losses are equal (for CI testing). Script exits with error if losses differ.", + ) + + args = parser.parse_args() + + # Set default commands if not provided + if not args.baseline_cmd: + args.baseline_cmd = DEFAULT_RUN_CMD + + if not args.test_cmd: + args.test_cmd = args.baseline_cmd + log_print("Note: Using baseline command for both baseline and test runs") + + # Convert empty output_folder to None + if not args.output_folder: + args.output_folder = None + + return args + + +def run_scenario( + scenario: str, + commit: str, + cmd: str, + steps: int, + enable_seed_checkpoint: bool, + output_folder: str | None, + is_baseline: bool = False, +) -> str: + """Run training for a specific scenario (baseline or test). + + Args: + scenario: Name of the scenario ("baseline" or "test") + commit: Git commit to checkout + cmd: Command to run + steps: Number of training steps + enable_seed_checkpoint: Whether to use seed checkpoint + output_folder: Output folder for results + is_baseline: Whether this is the baseline run (handles seed checkpoint creation) + + Returns: + Path to the log file + """ + checkout_commit(commit, scenario) + + if is_baseline: + create_seed_checkpoint(enable_seed_checkpoint, cmd, output_folder) + + log_file = run_training(scenario, cmd, steps, enable_seed_checkpoint, output_folder) + + if is_baseline: + restore_seed_checkpoint(enable_seed_checkpoint, output_folder) + + return log_file + + +def main() -> None: + """Main function that orchestrates the entire comparison process.""" + # Parse and validate arguments + args = parse_arguments() + validate_arguments( + args.baseline_commit, + args.test_commit, + args.baseline_cmd, + args.test_cmd, + args.steps, + ) + + # Setup environment + stats_file = setup_output_directory(args.output_folder) + enable_seed_checkpoint = not args.no_seed_checkpoint + print_configuration( + args.baseline_commit, + args.test_commit, + args.baseline_cmd, + args.test_cmd, + args.steps, + enable_seed_checkpoint, + ) + + # Run baseline and test training + baseline_log = run_scenario( + "baseline", + args.baseline_commit, + args.baseline_cmd, + args.steps, + enable_seed_checkpoint, + args.output_folder, + is_baseline=True, + ) + + test_log = run_scenario( + "test", + args.test_commit, + args.test_cmd, + args.steps, + enable_seed_checkpoint, + args.output_folder, + is_baseline=False, + ) + log_print() + + # Assert losses are equal if requested + if args.assert_equal: + assert_losses_equal(baseline_log, test_log) + + # Analysis and reporting + perform_loss_analysis(args.output_folder, stats_file) + cleanup_temp_files(args.output_folder) + print_completion_summary(args.output_folder, enable_seed_checkpoint) + + +if __name__ == "__main__": + main() From 0e78d18692adf2299e272493b3aec5834f427de3 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 12 Nov 2025 13:30:20 -0800 Subject: [PATCH 3/9] Update (base update) [ghstack-poisoned] --- .../experiments/compiler_toolkit/llama3/parallelize.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index 0ffbe61b89..62def3ef00 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -64,12 +64,12 @@ def parallelize_llama( model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config) # Get compiler passes from config - compiler_passes = get_compiler_passes_from_config( - job_config, dump_folder=job_config.job.dump_folder - ) + compiler_passes = get_compiler_passes_from_config(job_config) # Create compilers with specified passes (defaults to no passes) - fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes) + fw_compiler, bw_compiler = make_compiler_with_passes( + compiler_passes, dump_folder=job_config.job.dump_folder + ) # Create custom joint_graph_builder with llama-specific compilers and validation llama_joint_graph_builder = functools.partial( From fa43b2893467961b98df1a90cf82c5da8678d739 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 12 Nov 2025 15:21:59 -0800 Subject: [PATCH 4/9] Update (base update) [ghstack-poisoned] --- .../experiments/compiler_toolkit/graph_utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index db998aa170..fe246c1af0 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -6,7 +6,7 @@ import contextlib from pathlib import Path -from typing import Callable, List, Optional +from typing import Any, Callable, List, Optional import torch from torch._dynamo.functional_export import dynamo_graph_capture_for_export @@ -168,6 +168,18 @@ def __delattr__(self, name: str) -> None: else: super().__delattr__(name) + def state_dict(self, *args, **kwargs) -> Any: + return self.inner.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs) -> Any: + return self.inner.load_state_dict(*args, **kwargs) + + def name_parameters(self, *args, **kwargs) -> Any: + return self.inner.named_parameters(*args, **kwargs) + + def parameters(self, *args, **kwargs) -> Any: + return self.inner.parameters(*args, **kwargs) + def forward(self, *args, **kwargs): assert "forward" not in self._overrides, "forward cannot be overridden" From e393a13ff6e9afbb397318c2293b8609c177a3ab Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Nov 2025 13:06:15 -0800 Subject: [PATCH 5/9] Update [ghstack-poisoned] --- .github/workflows/integration_test_8gpu_features.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index d8c8159175..d4c2cf0f01 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -79,7 +79,7 @@ jobs: # Verify the accuracy. export baseline_options="--parallelism.data_parallel_replicate_degree=1 --job.dump_folder=${RUNNER_TEMP}/artifacts/accuracy_comparison_outputs" export test_options="--parallelism.data_parallel_replicate_degree=4 --job.dump_folder=${RUNNER_TEMP}/artifacts/accuracy_comparison_outputs" - python3 scripts/loss_compare.py . . --baseline-options=${baseline_options} --test-options=${test_options} --steps=10 + python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --steps=10 # Cleanup the checkpoints so that we don't waste network bandwidth and time. rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint From a53269fab99b0ae07fc125aac5a609928d6b5c47 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Nov 2025 13:43:56 -0800 Subject: [PATCH 6/9] Update [ghstack-poisoned] --- .../integration_test_8gpu_features.yaml | 6 ++-- scripts/loss_compare.py | 30 ++++++++++++++++--- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index d4c2cf0f01..35c40201a0 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -77,9 +77,9 @@ jobs: python -m tests.integration_tests.run_tests --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 # Verify the accuracy. - export baseline_options="--parallelism.data_parallel_replicate_degree=1 --job.dump_folder=${RUNNER_TEMP}/artifacts/accuracy_comparison_outputs" - export test_options="--parallelism.data_parallel_replicate_degree=4 --job.dump_folder=${RUNNER_TEMP}/artifacts/accuracy_comparison_outputs" - python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --steps=10 + export baseline_options="--parallelism.data_parallel_replicate_degree=1" + export test_options="--parallelism.data_parallel_replicate_degree=4" + python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --steps=10 --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" # Cleanup the checkpoints so that we don't waste network bandwidth and time. rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint diff --git a/scripts/loss_compare.py b/scripts/loss_compare.py index 2493c4b24a..1360f34936 100644 --- a/scripts/loss_compare.py +++ b/scripts/loss_compare.py @@ -95,9 +95,12 @@ def get_clean_log_path(scenario: str, output_folder: str) -> str: return f"{output_folder}/{scenario}_training_clean.log" -def build_base_command(config_file: str, train_file: str, options: str) -> str: +def build_base_command( + config_file: str, train_file: str, options: str, job_dump_folder: str +) -> str: """Build the base command from config file, train file, and options.""" cmd = f"TRAIN_FILE='{train_file}' CONFIG_FILE='{config_file}' ./run_train.sh" + cmd += f" --job.dump_folder={job_dump_folder}" if options: cmd += f" {options}" return cmd @@ -232,9 +235,10 @@ def build_training_command( options: str, steps: int, enable_seed_checkpoint: bool, + job_dump_folder: str, ) -> str: """Build the final training command with all options.""" - base_cmd = build_base_command(config_file, train_file, options) + base_cmd = build_base_command(config_file, train_file, options, job_dump_folder) cmd = f"{base_cmd} {FIXED_OPTIONS} --training.steps={steps}" if enable_seed_checkpoint: cmd += ( @@ -255,6 +259,7 @@ def print_configuration( test_options: str, steps: int, enable_seed_checkpoint: bool, + job_dump_folder: str, ) -> None: """Print configuration summary.""" log_print( @@ -272,6 +277,7 @@ def print_configuration( baseline_options, steps, enable_seed_checkpoint, + job_dump_folder, ) test_final_cmd = build_training_command( test_config, @@ -279,6 +285,7 @@ def print_configuration( test_options, steps, enable_seed_checkpoint, + job_dump_folder, ) log_print("Baseline command:") @@ -313,6 +320,7 @@ def create_seed_checkpoint( config_file: str, train_file: str, output_folder: str | None, + job_dump_folder: str, ) -> None: """Create seed checkpoint.""" if enable_seed_checkpoint: @@ -322,7 +330,8 @@ def create_seed_checkpoint( # Build seed checkpoint command seed_cmd = ( f"TRAIN_FILE='{train_file}' CONFIG_FILE='{config_file}' " - f"./run_train.sh --checkpoint.create_seed_checkpoint " + f"./run_train.sh --job.dump_folder={job_dump_folder} " + f"--checkpoint.create_seed_checkpoint " f"--checkpoint.enable {FIXED_OPTIONS}" ) @@ -340,6 +349,7 @@ def run_training( steps: int, enable_seed_checkpoint: bool, output_folder: str | None, + job_dump_folder: str, ) -> str: """Run training for a specific scenario. Returns the log file path.""" log_file = get_log_path(scenario, output_folder) @@ -349,7 +359,7 @@ def run_training( # Build the final command full_cmd = build_training_command( - config_file, train_file, options, steps, enable_seed_checkpoint + config_file, train_file, options, steps, enable_seed_checkpoint, job_dump_folder ) env = os.environ.copy() @@ -715,6 +725,11 @@ def parse_arguments() -> argparse.Namespace: "Script exits with error if losses differ." ), ) + parser.add_argument( + "--job-dump-folder", + default="outputs", + help="Job dump folder path (default: outputs)", + ) args = parser.parse_args() @@ -741,6 +756,7 @@ def run_scenario( steps: int, enable_seed_checkpoint: bool, output_folder: str | None, + job_dump_folder: str, ) -> str: """Run training for a specific scenario (baseline or test). @@ -753,6 +769,7 @@ def run_scenario( steps: Number of training steps enable_seed_checkpoint: Whether to use seed checkpoint output_folder: Output folder for results + job_dump_folder: Job dump folder path Returns: Path to the log file @@ -767,6 +784,7 @@ def run_scenario( steps, enable_seed_checkpoint, output_folder, + job_dump_folder, ) return log_file @@ -802,6 +820,7 @@ def main() -> None: args.test_options, args.steps, enable_seed_checkpoint, + args.job_dump_folder, ) create_seed_checkpoint( @@ -809,6 +828,7 @@ def main() -> None: args.baseline_config, args.baseline_train_file, args.output_folder, + args.job_dump_folder, ) # Run baseline and test training baseline_log = run_scenario( @@ -820,6 +840,7 @@ def main() -> None: args.steps, enable_seed_checkpoint, args.output_folder, + args.job_dump_folder, ) test_log = run_scenario( @@ -831,6 +852,7 @@ def main() -> None: args.steps, enable_seed_checkpoint, args.output_folder, + args.job_dump_folder, ) log_print() From 88f37f65da86eec80bc0247908bd25dff913d781 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 18 Nov 2025 16:44:07 -0800 Subject: [PATCH 7/9] Update [ghstack-poisoned] --- .../integration_test_8gpu_features.yaml | 5 +++-- scripts/loss_compare.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index 35c40201a0..362272d6cc 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -77,9 +77,10 @@ jobs: python -m tests.integration_tests.run_tests --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 # Verify the accuracy. + echo "Checking FSDP4 v.s. HSDP2FSDP2TP2 accuracy parity" export baseline_options="--parallelism.data_parallel_replicate_degree=1" - export test_options="--parallelism.data_parallel_replicate_degree=4" - python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --steps=10 --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" + export test_options="--parallelism.data_parallel_replicate_degree=2 --parallelism.tensor_parallel_degree=2" + python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --steps=10 --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --baseline-ngpus=4 --test-ngpus=8 --steps=1 # Cleanup the checkpoints so that we don't waste network bandwidth and time. rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint diff --git a/scripts/loss_compare.py b/scripts/loss_compare.py index 1360f34936..31573c3bce 100644 --- a/scripts/loss_compare.py +++ b/scripts/loss_compare.py @@ -350,6 +350,7 @@ def run_training( enable_seed_checkpoint: bool, output_folder: str | None, job_dump_folder: str, + ngpus: int, ) -> str: """Run training for a specific scenario. Returns the log file path.""" log_file = get_log_path(scenario, output_folder) @@ -363,6 +364,7 @@ def run_training( ) env = os.environ.copy() + env["NGPU"] = str(ngpus) run_with_realtime_output(full_cmd, log_file, env) @@ -730,6 +732,18 @@ def parse_arguments() -> argparse.Namespace: default="outputs", help="Job dump folder path (default: outputs)", ) + parser.add_argument( + "--baseline-ngpus", + type=int, + default=8, + help="Number of GPUs for baseline run (default: 8)", + ) + parser.add_argument( + "--test-ngpus", + type=int, + default=8, + help="Number of GPUs for test run (default: 8)", + ) args = parser.parse_args() @@ -757,6 +771,7 @@ def run_scenario( enable_seed_checkpoint: bool, output_folder: str | None, job_dump_folder: str, + ngpus: int, ) -> str: """Run training for a specific scenario (baseline or test). @@ -770,6 +785,7 @@ def run_scenario( enable_seed_checkpoint: Whether to use seed checkpoint output_folder: Output folder for results job_dump_folder: Job dump folder path + ngpus: Number of GPUs to use Returns: Path to the log file @@ -785,6 +801,7 @@ def run_scenario( enable_seed_checkpoint, output_folder, job_dump_folder, + ngpus, ) return log_file @@ -841,6 +858,7 @@ def main() -> None: enable_seed_checkpoint, args.output_folder, args.job_dump_folder, + args.baseline_ngpus, ) test_log = run_scenario( @@ -853,6 +871,7 @@ def main() -> None: enable_seed_checkpoint, args.output_folder, args.job_dump_folder, + args.test_ngpus, ) log_print() From 67708b93e4e53ba50283b8f401b4359098a52230 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 18 Nov 2025 17:37:31 -0800 Subject: [PATCH 8/9] Update [ghstack-poisoned] --- .github/workflows/integration_test_8gpu_features.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index 362272d6cc..47db741098 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -80,7 +80,7 @@ jobs: echo "Checking FSDP4 v.s. HSDP2FSDP2TP2 accuracy parity" export baseline_options="--parallelism.data_parallel_replicate_degree=1" export test_options="--parallelism.data_parallel_replicate_degree=2 --parallelism.tensor_parallel_degree=2" - python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --steps=10 --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --baseline-ngpus=4 --test-ngpus=8 --steps=1 + python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --baseline-ngpus=4 --test-ngpus=8 --steps=2 # Cleanup the checkpoints so that we don't waste network bandwidth and time. rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint From 1425d8a746f2af8bc9d4bc2a96fdc6cbdc0f2841 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 18 Nov 2025 18:30:50 -0800 Subject: [PATCH 9/9] Update [ghstack-poisoned] --- .github/workflows/integration_test_8gpu_features.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index 47db741098..b95fa5be05 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -80,7 +80,7 @@ jobs: echo "Checking FSDP4 v.s. HSDP2FSDP2TP2 accuracy parity" export baseline_options="--parallelism.data_parallel_replicate_degree=1" export test_options="--parallelism.data_parallel_replicate_degree=2 --parallelism.tensor_parallel_degree=2" - python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --baseline-ngpus=4 --test-ngpus=8 --steps=2 + python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --baseline-ngpus=4 --test-ngpus=8 --steps=1 # Cleanup the checkpoints so that we don't waste network bandwidth and time. rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint