From ac1a6cde3fa34526abbf4132666a7e997f86ee71 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Thu, 4 Jul 2024 20:26:41 -0700 Subject: [PATCH 1/4] Run performance test non-alternately Summary: By default, performance tests (speedup experiments) will run the baseline and test backend alternately. However, this does not work for the torchao backend, which will change the model in-place, therefore the baseline run will also run with torchao backend since the model has already been quantized. Add a new experiment "latency_experiment" to run performance tests non-alternately (first run baseline for a few iterations, then run the test backend). other changes: need to add torch.compiler.cudagraph_mark_step_begin() to avoid the slowdown from # Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards also updated the torchao APIs to the current versions Test Plan: python run_benchmark.py torchao --only AlbertForMaskedLM --quantization noquant --performance --inference --bfloat16 --inductor-compile-mode max-autotune python run_benchmark.py torchao --only BartForCausalLM --quantization noquant --performance --inference --bfloat16 --inductor-compile-mode max-autotune python run_benchmark.py torchao --only timm_efficientnet --quantization noquant --performance --inference --bfloat16 --inductor-compile-mode max-autotune (should all be ~1.0 0.997x 1.006x 0.994x Reviewers: Subscribers: Tasks: Tags: --- userbenchmark/dynamo/dynamobench/common.py | 28 ++++++-------- .../dynamo/dynamobench/torchao_backend.py | 38 +++++++++---------- 2 files changed, 30 insertions(+), 36 deletions(-) diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py index 7378959eb1..19aa3b8b72 100644 --- a/userbenchmark/dynamo/dynamobench/common.py +++ b/userbenchmark/dynamo/dynamobench/common.py @@ -699,7 +699,6 @@ def latency_experiment(args, model_iter_fn, model, example_inputs, mark, **kwarg should_randomize_input = args.randomize_input import contextlib - from torch._inductor.utils import maybe_profile @contextlib.contextmanager @@ -2964,10 +2963,8 @@ def run_performance_test_non_alternate( self, name, model, example_inputs, optimize_ctx, experiment, tag=None ): "Run performance test in non-alternately." - assert ( - experiment.func is latency_experiment - ), "Must run with latency_experiment." - + assert experiment.func is latency_experiment, \ + f"Must run with latency_experiment." def warmup(fn, model, example_inputs, mode, niters=10): peak_mem = 0 start_stats = get_dynamo_stats() @@ -3018,7 +3015,6 @@ def warmup(fn, model, example_inputs, mode, niters=10): if tag is not None: experiment_kwargs["tag"] = tag results = [] - with maybe_snapshot_memory( self.args.snapshot_memory, f"eager_{self.args.only}" ): @@ -3030,9 +3026,7 @@ def warmup(fn, model, example_inputs, mode, niters=10): self.model_iter_fn, model, example_inputs, "eager", niters=1 ) - baseline_timings = experiment( - model, example_inputs, mark="expected", **experiment_kwargs - ) + baseline_timings = experiment(model, example_inputs, mark="expected", **experiment_kwargs) if self.args.export_aot_inductor: t_0 = time.perf_counter() @@ -3110,13 +3104,9 @@ def warmup(fn, model, example_inputs, mode, niters=10): experiment = functools.partial( experiment, optimized_model_iter_fn.context.onnx_model ) - backend_timings = experiment( - model, example_inputs, mark="expected", **experiment_kwargs - ) + backend_timings = experiment(model, example_inputs, mark="expected", **experiment_kwargs) timings = np.stack((baseline_timings, backend_timings), axis=1) - result_summary = latency_experiment_summary( - self.args, model, timings, **experiment_kwargs - ) + result_summary = latency_experiment_summary(self.args, model, timings, **experiment_kwargs) if not hasattr(model, name): model.name = name results.append(result_summary) @@ -4430,7 +4420,13 @@ def run(runner, args, original_dir=None): fullgraph=args.nopython, mode=args.inductor_compile_mode, ) - runner.model_iter_fn = baseline_ctx(runner.model_iter_fn) + model_iter_fn = baseline_ctx(runner.model_iter_fn) + # needed to avoid error that causes inconsistent timing due to: + # Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards + def model_iter_fn_and_mark_step(*args, **kwargs): + torch.compiler.cudagraph_mark_step_begin() + model_iter_fn(*args, **kwargs) + runner.model_iter_fn = model_iter_fn_and_mark_step optimize_ctx = torchao_optimize_ctx(args.quantization) else: optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython) diff --git a/userbenchmark/dynamo/dynamobench/torchao_backend.py b/userbenchmark/dynamo/dynamobench/torchao_backend.py index f02672928b..332c40bc33 100644 --- a/userbenchmark/dynamo/dynamobench/torchao_backend.py +++ b/userbenchmark/dynamo/dynamobench/torchao_backend.py @@ -4,48 +4,46 @@ def setup_baseline(): - torch._dynamo.epilogue_fusion = False + from torchao.quantization.utils import recommended_inductor_config_setter + recommended_inductor_config_setter() torch._dynamo.config.automatic_dynamic_shapes = False - torch._dynamo.config.force_parameter_static_shapes = False torch._dynamo.config.cache_size_limit = 10000 - torch._inductor.config.force_fuse_int_mm_with_mul = True - torch._inductor.config.use_mixed_mm = True def torchao_optimize_ctx(quantization: str): import torchao from torchao.quantization.quant_api import ( - change_linear_weights_to_int4_woqtensors, - change_linear_weights_to_int8_dqtensors, - change_linear_weights_to_int8_woqtensors, + autoquant, + quantize_, + int8_dynamic_activation_int8_weight, + int4_weight_only, + int8_weight_only, ) + from torchao.utils import unwrap_tensor_subclass def inner(model_iter_fn: Callable): def _torchao_apply(module: torch.nn.Module, example_inputs: Any): if getattr(module, "_quantized", None) is None: if quantization == "int8dynamic": - change_linear_weights_to_int8_dqtensors(module) + quantize_(module, int8_dynamic_activation_int8_weight(), set_inductor_config=False) elif quantization == "int8weightonly": - change_linear_weights_to_int8_woqtensors(module) + quantize_(module, int8_weight_only(), set_inductor_config=False) elif quantization == "int4weightonly": - change_linear_weights_to_int4_woqtensors(module) - elif quantization == "autoquant": - torchao.autoquant(module, error_on_unseen=False) + quantize_(module, int4_weight_only(), set_inductor_config=False) + if quantization == "autoquant": + torchao.autoquant(module, error_on_unseen=False, set_inductor_config=False) if isinstance(example_inputs, dict): module(**example_inputs) else: module(*example_inputs) from torchao.quantization.autoquant import AUTOQUANT_CACHE - assert ( - len(AUTOQUANT_CACHE) > 0 - ), f"Err: found no autoquantizable layers in model {type(module)}, stopping autoquantization" - elif quantization == "noquant": - pass + if len(AUTOQUANT_CACHE) == 0: + raise Exception("NotAutoquantizable" + f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run" + ) else: - raise AssertionError( - f"Unsupposed quantization mode {quantization}." - ) + unwrap_tensor_subclass(module) setattr(module, "_quantized", True) # noqa: B010 model_iter_fn(module, example_inputs) From 50db22899638387b80ae04fd8cb23924dc877c87 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 30 Jul 2024 19:17:06 -0700 Subject: [PATCH 2/4] Linting Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- userbenchmark/dynamo/dynamobench/common.py | 19 ++++++++++++++----- .../dynamo/dynamobench/torchao_backend.py | 16 +++++++++++----- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py index 19aa3b8b72..af9097c984 100644 --- a/userbenchmark/dynamo/dynamobench/common.py +++ b/userbenchmark/dynamo/dynamobench/common.py @@ -2963,8 +2963,10 @@ def run_performance_test_non_alternate( self, name, model, example_inputs, optimize_ctx, experiment, tag=None ): "Run performance test in non-alternately." - assert experiment.func is latency_experiment, \ - f"Must run with latency_experiment." + assert ( + experiment.func is latency_experiment + ), "Must run with latency_experiment." + def warmup(fn, model, example_inputs, mode, niters=10): peak_mem = 0 start_stats = get_dynamo_stats() @@ -3026,7 +3028,9 @@ def warmup(fn, model, example_inputs, mode, niters=10): self.model_iter_fn, model, example_inputs, "eager", niters=1 ) - baseline_timings = experiment(model, example_inputs, mark="expected", **experiment_kwargs) + baseline_timings = experiment( + model, example_inputs, mark="expected", **experiment_kwargs + ) if self.args.export_aot_inductor: t_0 = time.perf_counter() @@ -3104,9 +3108,13 @@ def warmup(fn, model, example_inputs, mode, niters=10): experiment = functools.partial( experiment, optimized_model_iter_fn.context.onnx_model ) - backend_timings = experiment(model, example_inputs, mark="expected", **experiment_kwargs) + backend_timings = experiment( + model, example_inputs, mark="expected", **experiment_kwargs + ) timings = np.stack((baseline_timings, backend_timings), axis=1) - result_summary = latency_experiment_summary(self.args, model, timings, **experiment_kwargs) + result_summary = latency_experiment_summary( + self.args, model, timings, **experiment_kwargs + ) if not hasattr(model, name): model.name = name results.append(result_summary) @@ -4421,6 +4429,7 @@ def run(runner, args, original_dir=None): mode=args.inductor_compile_mode, ) model_iter_fn = baseline_ctx(runner.model_iter_fn) + # needed to avoid error that causes inconsistent timing due to: # Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards def model_iter_fn_and_mark_step(*args, **kwargs): diff --git a/userbenchmark/dynamo/dynamobench/torchao_backend.py b/userbenchmark/dynamo/dynamobench/torchao_backend.py index 332c40bc33..8e9739e770 100644 --- a/userbenchmark/dynamo/dynamobench/torchao_backend.py +++ b/userbenchmark/dynamo/dynamobench/torchao_backend.py @@ -5,6 +5,7 @@ def setup_baseline(): from torchao.quantization.utils import recommended_inductor_config_setter + recommended_inductor_config_setter() torch._dynamo.config.automatic_dynamic_shapes = False torch._dynamo.config.cache_size_limit = 10000 @@ -14,10 +15,10 @@ def torchao_optimize_ctx(quantization: str): import torchao from torchao.quantization.quant_api import ( autoquant, - quantize_, - int8_dynamic_activation_int8_weight, int4_weight_only, + int8_dynamic_activation_int8_weight, int8_weight_only, + quantize_, ) from torchao.utils import unwrap_tensor_subclass @@ -25,13 +26,17 @@ def inner(model_iter_fn: Callable): def _torchao_apply(module: torch.nn.Module, example_inputs: Any): if getattr(module, "_quantized", None) is None: if quantization == "int8dynamic": - quantize_(module, int8_dynamic_activation_int8_weight(), set_inductor_config=False) + quantize_( + module, + int8_dynamic_activation_int8_weight(), + set_inductor_config=False, + ) elif quantization == "int8weightonly": quantize_(module, int8_weight_only(), set_inductor_config=False) elif quantization == "int4weightonly": quantize_(module, int4_weight_only(), set_inductor_config=False) if quantization == "autoquant": - torchao.autoquant(module, error_on_unseen=False, set_inductor_config=False) + autoquant(module, error_on_unseen=False, set_inductor_config=False) if isinstance(example_inputs, dict): module(**example_inputs) else: @@ -39,7 +44,8 @@ def _torchao_apply(module: torch.nn.Module, example_inputs: Any): from torchao.quantization.autoquant import AUTOQUANT_CACHE if len(AUTOQUANT_CACHE) == 0: - raise Exception("NotAutoquantizable" + raise Exception( + "NotAutoquantizable" f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run" ) else: From 4ee24638870c5cb309f9ffec44c5a245d2312a9b Mon Sep 17 00:00:00 2001 From: HDCharles Date: Fri, 2 Aug 2024 13:18:25 -0700 Subject: [PATCH 3/4] lint fixes Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- userbenchmark/dynamo/dynamobench/common.py | 3 +++ userbenchmark/dynamo/dynamobench/torchao_backend.py | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py index af9097c984..c8698707aa 100644 --- a/userbenchmark/dynamo/dynamobench/common.py +++ b/userbenchmark/dynamo/dynamobench/common.py @@ -699,6 +699,7 @@ def latency_experiment(args, model_iter_fn, model, example_inputs, mark, **kwarg should_randomize_input = args.randomize_input import contextlib + from torch._inductor.utils import maybe_profile @contextlib.contextmanager @@ -3017,6 +3018,7 @@ def warmup(fn, model, example_inputs, mode, niters=10): if tag is not None: experiment_kwargs["tag"] = tag results = [] + with maybe_snapshot_memory( self.args.snapshot_memory, f"eager_{self.args.only}" ): @@ -4435,6 +4437,7 @@ def run(runner, args, original_dir=None): def model_iter_fn_and_mark_step(*args, **kwargs): torch.compiler.cudagraph_mark_step_begin() model_iter_fn(*args, **kwargs) + runner.model_iter_fn = model_iter_fn_and_mark_step optimize_ctx = torchao_optimize_ctx(args.quantization) else: diff --git a/userbenchmark/dynamo/dynamobench/torchao_backend.py b/userbenchmark/dynamo/dynamobench/torchao_backend.py index 8e9739e770..c8adb8bb3a 100644 --- a/userbenchmark/dynamo/dynamobench/torchao_backend.py +++ b/userbenchmark/dynamo/dynamobench/torchao_backend.py @@ -12,7 +12,6 @@ def setup_baseline(): def torchao_optimize_ctx(quantization: str): - import torchao from torchao.quantization.quant_api import ( autoquant, int4_weight_only, @@ -44,7 +43,7 @@ def _torchao_apply(module: torch.nn.Module, example_inputs: Any): from torchao.quantization.autoquant import AUTOQUANT_CACHE if len(AUTOQUANT_CACHE) == 0: - raise Exception( + raise Exception( # noqa: TRY002` "NotAutoquantizable" f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run" ) From 4ca5a04d31f6883f937c417d8fd7df2e15f8e37c Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 6 Aug 2024 17:46:16 -0700 Subject: [PATCH 4/4] fix lint Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- userbenchmark/dynamo/dynamobench/torchao_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/userbenchmark/dynamo/dynamobench/torchao_backend.py b/userbenchmark/dynamo/dynamobench/torchao_backend.py index c8adb8bb3a..3854853784 100644 --- a/userbenchmark/dynamo/dynamobench/torchao_backend.py +++ b/userbenchmark/dynamo/dynamobench/torchao_backend.py @@ -43,7 +43,7 @@ def _torchao_apply(module: torch.nn.Module, example_inputs: Any): from torchao.quantization.autoquant import AUTOQUANT_CACHE if len(AUTOQUANT_CACHE) == 0: - raise Exception( # noqa: TRY002` + raise Exception( # noqa: TRY002` "NotAutoquantizable" f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run" )