From d2d12b317d7485a95090ad614e16d9eaf526d010 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Tue, 11 Nov 2025 13:17:31 -0800 Subject: [PATCH] add tests and scripts for numerics check --- .../scripts/check_numerics.py | 126 ++++++++ .../compiler_toolkit/tests/numerics_utils.py | 270 ++++++++++++++++++ .../compiler_toolkit/tests/test_numerics.py | 71 +++++ 3 files changed, 467 insertions(+) create mode 100644 torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py create mode 100644 torchtitan/experiments/compiler_toolkit/tests/numerics_utils.py create mode 100644 torchtitan/experiments/compiler_toolkit/tests/test_numerics.py diff --git a/torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py b/torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py new file mode 100644 index 0000000000..06c1717957 --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
"""CLI that runs paired eager/compiled training jobs and compares their metrics."""

import argparse
import sys
from pathlib import Path

# Add the parent directory to sys.path so `tests.numerics_utils` resolves
# when this script is run directly from the scripts/ folder.
sys.path.insert(0, str(Path(__file__).parent.parent))


def build_parser():
    """Build the CLI argument parser.

    Kept separate from main() so the parser can be constructed (and tested)
    without importing the torchtitan test package or parsing real argv.
    """
    parser = argparse.ArgumentParser(
        description="Run two training jobs and compare their tensorboard metrics"
    )
    parser.add_argument(
        "--ngpu",
        type=int,
        required=True,
        help="Number of GPUs to use",
    )
    parser.add_argument(
        "--config-file",
        type=str,
        required=True,
        help="Path to config file",
    )
    parser.add_argument(
        "--dp-shard-degree",
        type=int,
        default=1,
        help="Data parallel shard degree",
    )
    parser.add_argument(
        "--tp-degree",
        type=int,
        default=1,
        help="Tensor parallel degree",
    )
    parser.add_argument(
        "--cp-degree",
        type=int,
        default=1,
        help="Context parallel degree",
    )
    parser.add_argument(
        "--ep-degree",
        type=int,
        default=1,
        help="Expert parallel degree",
    )
    parser.add_argument(
        "--ac-mode",
        type=str,
        default="selective",
        choices=["selective", "none", "full"],
        help="Activation checkpoint mode",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=50,
        help="Number of training steps",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for deterministic training",
    )
    parser.add_argument(
        "--eager-tb-folder",
        type=str,
        default="tb/eager_run",
        help="Tensorboard folder for eager run",
    )
    parser.add_argument(
        "--compiled-tb-folder",
        type=str,
        default="tb/compiled_run",
        help="Tensorboard folder for compiled run",
    )
    parser.add_argument(
        "--metrics",
        nargs="+",
        default=["loss_metrics/global_avg_loss", "grad_norm"],
        help="Metrics to compare",
    )
    parser.add_argument(
        "--passes",
        type=str,
        default=None,
        help=(
            "Comma-separated list of compiler passes to apply "
            "(e.g., 'autobucketing_reordering' or 'autobucketing_reordering,regional_inductor')"
        ),
    )
    return parser


def main():
    """Parse CLI args, run the numerics test, and return a process exit code.

    Returns:
        0 if all compared metrics matched between the two runs, 1 otherwise.
    """
    # Imported lazily (after the sys.path fix-up above) so that build_parser()
    # stays importable without the torchtitan test package present.
    from tests.numerics_utils import run_numerics_test

    args = build_parser().parse_args()

    success = run_numerics_test(
        ngpu=args.ngpu,
        config_file=args.config_file,
        dp_shard_degree=args.dp_shard_degree,
        tp_degree=args.tp_degree,
        cp_degree=args.cp_degree,
        ep_degree=args.ep_degree,
        ac_mode=args.ac_mode,
        steps=args.steps,
        seed=args.seed,
        eager_tb_folder=args.eager_tb_folder,
        compiled_tb_folder=args.compiled_tb_folder,
        metrics=args.metrics,
        passes=args.passes,
    )

    return 0 if success else 1


if __name__ == "__main__":
    # sys.exit (not the site-provided exit()) is the proper way to set the
    # process exit status from a script.
    sys.exit(main())
"""Shared utilities for numerics testing.

Helpers to run paired torchtitan training jobs (eager vs. compiled), load the
tensorboard scalars each job wrote, and verify that the two runs produced
bitwise-identical metrics.
"""

import glob
import os
import subprocess

import torch


def load_metrics(event_path, metric_names):
    """Load scalar metrics from tensorboard event files.

    Args:
        event_path: Directory containing the tensorboard event file(s).
        metric_names: Iterable of scalar tags to extract.

    Returns:
        Dict mapping each requested metric name to a ``{step: value}`` dict.
        A metric missing from the event file maps to an empty dict (with a
        warning printed) rather than raising.
    """
    # Imported lazily so the pure comparison helpers in this module remain
    # usable without tensorboard installed.
    from tensorboard.backend.event_processing.event_accumulator import (
        EventAccumulator,
    )

    event_acc = EventAccumulator(event_path)
    event_acc.Reload()

    metrics = {}
    for metric_name in metric_names:
        try:
            scalars = event_acc.Scalars(metric_name)
        except KeyError:
            print(f"Warning: Metric {metric_name!r} not found in event file")
            metrics[metric_name] = {}
        else:
            metrics[metric_name] = {scalar.step: scalar.value for scalar in scalars}

    return metrics


def compare_metrics(metrics1, metrics2, label1="Eager", label2="Compiled"):
    """Compare two sets of metrics and verify bitwise equivalence.

    Uses ``torch.equal()`` on the per-step values, so the check is exact
    (bitwise), not approximate.

    Args:
        metrics1, metrics2: ``{metric_name: {step: value}}`` dicts, as
            returned by :func:`load_metrics`.
        label1, label2: Human-readable run names used in the printed report.

    Returns:
        True only if every metric has identical steps and identical values
        in both runs.
    """
    all_metrics = set(metrics1.keys()) | set(metrics2.keys())
    all_match = True

    for metric_name in sorted(all_metrics):
        # A metric may be present on only one side (load_metrics already
        # warned); .get() keeps this from raising KeyError in that case.
        steps1 = set(metrics1.get(metric_name, {}).keys())
        steps2 = set(metrics2.get(metric_name, {}).keys())

        if steps1 != steps2:
            print(" ERROR: Step mismatch!")
            print(f" {label1} steps: {sorted(steps1)}")
            print(f" {label2} steps: {sorted(steps2)}")
            all_match = False
            continue

        # Compare values step-by-step in a deterministic (sorted) order.
        ordered_steps = sorted(steps1)
        values1 = [metrics1[metric_name][step] for step in ordered_steps]
        values2 = [metrics2[metric_name][step] for step in ordered_steps]

        if torch.equal(torch.tensor(values1), torch.tensor(values2)):
            print(
                f" ✓ PASS: All {len(ordered_steps)} steps match exactly (bitwise equivalent)"
            )
        else:
            mismatches = [
                (step, v1, v2, abs(v1 - v2))
                for step, v1, v2 in zip(ordered_steps, values1, values2)
                if v1 != v2
            ]
            print(
                f" ERROR: Found {len(mismatches)} mismatches out of {len(ordered_steps)} steps"
            )
            # BUG FIX: a value mismatch must fail the overall comparison.
            # Previously only the step-mismatch branch set all_match = False,
            # so differing values still returned True.
            all_match = False

    return all_match


def find_latest_event_dir(base_path):
    """Return the most recently modified subdirectory of *base_path*.

    torchtitan writes tensorboard events into timestamped subdirectories; the
    newest one corresponds to the latest run. If there are no subdirectories,
    *base_path* itself is returned.

    Raises:
        ValueError: If *base_path* does not exist.
    """
    if not os.path.exists(base_path):
        raise ValueError(f"Path does not exist: {base_path}")

    subdirs = [d for d in glob.glob(os.path.join(base_path, "*")) if os.path.isdir(d)]
    if not subdirs:
        return base_path

    return max(subdirs, key=os.path.getmtime)


def run_training(
    ngpu,
    config_file,
    model_name,
    dp_shard_degree,
    tp_degree,
    cp_degree,
    ep_degree,
    ac_mode,
    steps,
    seed,
    deterministic,
    tb_folder,
    passes=None,
):
    """Run one training job via ``./run_train.sh`` with the given settings.

    NGPU and CONFIG_FILE are passed through the environment (that is what
    run_train.sh reads); everything else goes on the command line.

    Returns:
        True if the training subprocess exited with status 0, False otherwise.
    """
    print(f"\nStarting training: {model_name}")

    env = os.environ.copy()
    env["NGPU"] = str(ngpu)
    env["CONFIG_FILE"] = config_file

    cmd = [
        "./run_train.sh",
        "--model.name",
        model_name,
        "--parallelism.data_parallel_shard_degree",
        str(dp_shard_degree),
        "--parallelism.tensor_parallel_degree",
        str(tp_degree),
    ]

    # CP/EP flags are only passed when those parallelisms are actually enabled.
    if cp_degree > 1:
        cmd.extend(["--parallelism.context_parallel_degree", str(cp_degree)])
    if ep_degree > 1:
        cmd.extend(["--parallelism.expert_parallel_degree", str(ep_degree)])

    cmd.extend(
        [
            "--activation_checkpoint.mode",
            ac_mode,
            "--training.steps",
            str(steps),
            "--debug.seed",
            str(seed),
        ]
    )
    # BUG FIX: honor the `deterministic` parameter instead of ignoring it
    # (the flag used to be appended unconditionally).
    if deterministic:
        cmd.append("--debug.deterministic")
    cmd.extend(
        [
            "--metrics.enable_tensorboard",
            "--metrics.save_tb_folder",
            tb_folder,
        ]
    )

    if passes:
        cmd.extend(
            [
                "--job.custom_config_module",
                "torchtitan.experiments.compiler_toolkit.job_config",
                "--compile.passes",
                passes,
            ]
        )

    print(f"Environment: NGPU={env['NGPU']}, CONFIG_FILE={env['CONFIG_FILE']}")
    print(f"Running command: {' '.join(cmd)}")

    try:
        subprocess.run(
            cmd,
            env=env,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"✗ Training failed: {model_name}")
        # stdout holds both streams because stderr was redirected to it.
        print(f"Error output:\n{e.stdout}")
        return False
    print(f"✓ Training completed: {model_name}")
    return True


def determine_model_names(config_file):
    """Derive the (eager, compiled) model registry names from a config path.

    Returns:
        Tuple ``(eager_model, compiled_model)`` — the simple_fsdp and
        compiler_toolkit registrations of the same base model.

    Raises:
        ValueError: If the config file path names no recognized model family.
    """
    if "deepseek" in config_file:
        model_name = "deepseek_v3"
    elif "llama3" in config_file:
        model_name = "llama3"
    else:
        raise ValueError(
            f"Unable to determine model names from config file: {config_file}"
        )

    return f"simple_fsdp.{model_name}", f"compiler_toolkit.{model_name}"


def run_numerics_test(
    ngpu,
    config_file,
    dp_shard_degree,
    tp_degree,
    cp_degree,
    ep_degree,
    ac_mode,
    steps,
    seed,
    eager_tb_folder,
    compiled_tb_folder,
    metrics,
    passes=None,
):
    """
    Run numerics test by training both eager and compiled models and comparing metrics.

    Returns:
        bool: True if all metrics match, False otherwise.
    """
    eager_model, compiled_model = determine_model_names(config_file)

    # Everything except the model name, tb folder, and compiler passes is
    # identical between the two runs.
    common_kwargs = dict(
        ngpu=ngpu,
        config_file=config_file,
        dp_shard_degree=dp_shard_degree,
        tp_degree=tp_degree,
        cp_degree=cp_degree,
        ep_degree=ep_degree,
        ac_mode=ac_mode,
        steps=steps,
        seed=seed,
        deterministic=True,
    )

    if not run_training(
        model_name=eager_model, tb_folder=eager_tb_folder, **common_kwargs
    ):
        print("✗ Eager training failed")
        return False

    if not run_training(
        model_name=compiled_model,
        tb_folder=compiled_tb_folder,
        passes=passes,
        **common_kwargs,
    ):
        print("✗ Compiled training failed")
        return False

    # Tensorboard events land in timestamped subdirectories under
    # ./outputs/<tb_folder>; compare the latest run on each side.
    eager_path = find_latest_event_dir(f"./outputs/{eager_tb_folder}")
    compiled_path = find_latest_event_dir(f"./outputs/{compiled_tb_folder}")

    eager_metrics = load_metrics(eager_path, metrics)
    compiled_metrics = load_metrics(compiled_path, metrics)

    all_match = compare_metrics(eager_metrics, compiled_metrics)

    if all_match:
        print("✓ SUCCESS: All metrics are bitwise equivalent")
    else:
        print("✗ FAILURE: Metrics differ between runs")

    return all_match
"grad_norm"], + passes="autobucketing_reordering", + ) + + def test_deepseek_v3_fsdp_tp_ep(self): + """Test DeepSeek V3 with FSDP + TP + EP configuration.""" + result = run_numerics_test( + ngpu=4, + config_file="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml", + dp_shard_degree=2, + tp_degree=2, + cp_degree=1, + ep_degree=4, + ac_mode="none", + steps=10, + seed=42, + eager_tb_folder="tb/test_deepseek_v3_fsdp_tp_ep_eager", + compiled_tb_folder="tb/test_deepseek_v3_fsdp_tp_ep_compiled", + metrics=["loss_metrics/global_avg_loss", "grad_norm"], + ) + self.assertTrue(result, "DeepSeek V3 FSDP+TP+EP numerics test failed") + + +if __name__ == "__main__": + unittest.main()