From 081e35cd0916afd50f1db51c34b1df964fed798c Mon Sep 17 00:00:00 2001 From: David Berard Date: Thu, 29 Sep 2022 16:19:43 -0700 Subject: [PATCH 1/5] Add flag for dynamo+ddp optimizations Add a flag that can be used to turn dynamo+ddp optimizations on. This will be used to compare how dynamo+ddp performs with and without the additional graph break strategy for improving dynamo+ddp compute/communication overlap. [ghstack-poisoned] --- torchbenchmark/util/extra_args.py | 14 +++++ userbenchmark/ddp_experiments/__init__.py | 69 ++++++++++++---------- userbenchmark/ddp_experiments/parse_ddp.py | 28 +++++---- 3 files changed, 68 insertions(+), 43 deletions(-) diff --git a/torchbenchmark/util/extra_args.py b/torchbenchmark/util/extra_args.py index a5002fbc87..53aa9f9c7a 100644 --- a/torchbenchmark/util/extra_args.py +++ b/torchbenchmark/util/extra_args.py @@ -111,6 +111,7 @@ def parse_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', opt_args: parser.add_argument("--torch_trt", action='store_true', help="enable torch_tensorrt") parser.add_argument("--flops", choices=["fvcore", "dcgm"], help="Return the flops result") parser.add_argument("--use_cosine_similarity", action='store_true', help="use cosine similarity for correctness check") + parser.add_argument("--optimize_dynamo_ddp", action='store_true', help="enable extra optimizations for DDP + dynamo") args, extra_args = parser.parse_known_args(opt_args) if model.jit: args.backend = "torchscript" @@ -146,3 +147,16 @@ def apply_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', args: argp module, exmaple_inputs = model.get_module() precision = 'fp16' if not model.dargs.precision == "fp32" else 'fp32' model.set_module(enable_torchtrt(precision=precision, model=module, example_inputs=exmaple_inputs)) + + if args.optimize_dynamo_ddp: + import torchdynamo + @contextlib.contextmanager + def optimize_ddp_ctx(val: bool): + old_value = torchdynamo.config.optimize_ddp + try: + torchdynamo.config.optimize_ddp = val + yield + finally: + torchdynamo.config.optimize_ddp = old_value + + model.add_context(lambda: optimize_ddp_ctx(True)) diff --git a/userbenchmark/ddp_experiments/__init__.py b/userbenchmark/ddp_experiments/__init__.py index 0e452de88b..7b162e9ed9 100644 --- a/userbenchmark/ddp_experiments/__init__.py +++ b/userbenchmark/ddp_experiments/__init__.py @@ -1,6 +1,7 @@ import argparse import importlib import os +import copy import csv import io import submitit @@ -131,7 +132,7 @@ class TrainerWrapper(object): def __init__(self, args, model_args): self.args = args self.args.output_dir = args.job_dir - + # extra args just passed to the Trainer class ctor self.model_args=model_args @@ -184,7 +185,7 @@ def main(): # Note that the folder will depend on the job_id, to easily track experiments executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=3000) - + executor.update_parameters( gpus_per_node=args.ngpus, # one task per GPU @@ -199,7 +200,7 @@ def main(): executor.update_parameters(name="distbench", slurm_array_parallelism=1, timeout_min=1000) - + # args.dist_url = get_init_file(args).as_uri() # args.output_dir = args.job_dir # job = executor.submit(TrainerWrapper(args)) @@ -247,33 +248,41 @@ def get_backend_name(model_args): for nodes in node_list: for model_name in models: for model_args in model_args_configs: - batch_size = model_batch_size[model_name] - args.model = model_name - args.batch_size = batch_size - args.nodes = nodes - args.dist_url = get_init_file(args).as_uri() - args.output_dir = args.job_dir - executor.update_parameters( - gpus_per_node=args.ngpus, - # one task per GPU - tasks_per_node=args.ngpus, - cpus_per_task=10, - nodes=args.nodes, - timeout_min=args.timeout, - # Below are cluster dependent parameters - slurm_partition=args.partition, - slurm_signal_delay_s=120, - ) - job = executor.submit(TrainerWrapper(args, model_args)) - - # print ID of the Slurm job - backend_name = get_backend_name(model_args) - print(f"{model_name}_{backend_name}_{nodes}: {job.job_id}") - output_csv( - args.index_file, - ("model", "backend", "nodes", "job_id"), - (model_name, backend_name, nodes, job.job_id), - ) + for has_breaks in [True, False]: + # copy the model args so we can add more arguments without modifying + # the original model_args list. + copied_model_args = copy.copy(model_args) + breakname = "withbreaks" if has_breaks else "nobreaks" + if has_breaks: + copied_model_args.append("--optimize_dynamo_ddp") + batch_size = model_batch_size[model_name] + args.model = model_name + args.batch_size = batch_size + args.nodes = nodes + args.dist_url = get_init_file(args).as_uri() + args.output_dir = args.job_dir + executor.update_parameters( + gpus_per_node=args.ngpus, + # one task per GPU + tasks_per_node=args.ngpus, + cpus_per_task=10, + nodes=args.nodes, + timeout_min=args.timeout, + # Below are cluster dependent parameters + slurm_partition=args.partition, + slurm_signal_delay_s=120, + slurm_exclude=args.exclude, + ) + job = executor.submit(TrainerWrapper(args, copied_model_args)) + + # print ID of the Slurm job + backend_name = get_backend_name(model_args) + print(f"{model_name}_{backend_name}_{nodes}_{breakname}: {job.job_id}") + output_csv( + args.index_file, + ("model", "backend", "nodes", "has_breaks", "job_id"), + (model_name, backend_name, nodes, has_breaks, job.job_id), + ) # waits for completion and returns output print(job.results()) diff --git a/userbenchmark/ddp_experiments/parse_ddp.py b/userbenchmark/ddp_experiments/parse_ddp.py index 154ba23686..71de996053 100644 --- a/userbenchmark/ddp_experiments/parse_ddp.py +++ b/userbenchmark/ddp_experiments/parse_ddp.py @@ -25,10 +25,10 @@ def get_job_result(args, job_id, worker_rank=0): elif desc == "success": # print(f"Success: {payload}") return True, payload - + # print(f"Unknown result: {dat}") return False, dat - + return False, None def parse_data(args): @@ -36,17 +36,18 @@ def parse_data(args): Schema: model_data["model"]["backend"][#nodes] = latency_median """ - model_data = defaultdict(lambda: defaultdict(lambda: defaultdict(float))) + model_data = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float)))) with open(args.csv) as f: runs = csv.DictReader(f) for row in runs: model = row["model"] backend = row["backend"] nodes = row["nodes"] + has_breaks = row["has_breaks"] job_id = row["job_id"] result_code, result_data = get_job_result(args, job_id) latency = f"{result_data['latency_median']:.3f}" if result_code else str(result_data)[:10] - model_data[model][backend][nodes] = latency + model_data[model][backend][nodes][has_breaks] = latency return model_data def model_name(model): @@ -62,14 +63,15 @@ def print_model_table(args, model, model_data): for node in model_data[backend]: node_counts[node] = node # hack orderedset rows = [] - for backend in model_data: - row = [backend, ] - for node in node_counts: - if node in model_data[backend]: - row.append(model_data[backend][node]) - else: - row.append("-") - rows.append(row) + for has_breaks in [False, True]: + for backend in model_data: + row = [f"{backend} {'w/' if has_breaks else 'wo/'}breaks", ] + for node in node_counts: + if node in model_data[backend]: + row.append(model_data[backend][node][str(has_breaks)]) + else: + row.append("-") + rows.append(row) hdr = ("backend", ) + tuple(f"{node}_latency" for node in node_counts) print(f"{model_name(model)}:") @@ -89,4 +91,4 @@ def main(): print_results(args, data) if __name__ == "__main__": - main() \ No newline at end of file + main() From 2294a3574e1df1c138a496d200c2ff44bea5d5d9 Mon Sep 17 00:00:00 2001 From: David Berard Date: Thu, 29 Sep 2022 16:39:18 -0700 Subject: [PATCH 2/5] Update on "Add flag for dynamo+ddp optimizations" Add a flag that can be used to turn dynamo+ddp optimizations on. This will be used to compare how dynamo+ddp performs with and without the additional graph break strategy for improving dynamo+ddp compute/communication overlap. [ghstack-poisoned] --- torchbenchmark/util/extra_args.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchbenchmark/util/extra_args.py b/torchbenchmark/util/extra_args.py index 53aa9f9c7a..8064174df2 100644 --- a/torchbenchmark/util/extra_args.py +++ b/torchbenchmark/util/extra_args.py @@ -1,4 +1,5 @@ import argparse +import contextlib from typing import List, Optional, Tuple from torchbenchmark.util.backends import list_backends, BACKENDS From 37783a2ca9a2d770b265719c13242aeb32af1d2a Mon Sep 17 00:00:00 2001 From: David Berard Date: Fri, 30 Sep 2022 10:29:29 -0700 Subject: [PATCH 3/5] Update on "Add flag for dynamo+ddp optimizations" Add a flag that can be used to turn dynamo+ddp optimizations on. This will be used to compare how dynamo+ddp performs with and without the additional graph break strategy for improving dynamo+ddp compute/communication overlap. [ghstack-poisoned] --- userbenchmark/ddp_experiments/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/userbenchmark/ddp_experiments/__init__.py b/userbenchmark/ddp_experiments/__init__.py index 7b162e9ed9..281b737ac1 100644 --- a/userbenchmark/ddp_experiments/__init__.py +++ b/userbenchmark/ddp_experiments/__init__.py @@ -249,6 +249,9 @@ def get_backend_name(model_args): for model_name in models: for model_args in model_args_configs: for has_breaks in [True, False]: + backend_name = get_backend_name(model_args) + if backend_name == "eager" and has_breaks: + continue # copy the model args so we can add more arguments without modifying # the original model_args list. copied_model_args = copy.copy(model_args) @@ -276,7 +279,6 @@ def get_backend_name(model_args): job = executor.submit(TrainerWrapper(args, copied_model_args)) # print ID of the Slurm job - backend_name = get_backend_name(model_args) print(f"{model_name}_{backend_name}_{nodes}_{breakname}: {job.job_id}") output_csv( args.index_file, From 62b1b8d865340742446b16d9c941f0d1dd5dbdd8 Mon Sep 17 00:00:00 2001 From: David Berard Date: Mon, 3 Oct 2022 16:07:23 -0700 Subject: [PATCH 4/5] Update on "Add flag for dynamo+ddp optimizations" Add a flag that can be used to turn dynamo+ddp optimizations on. This will be used to compare how dynamo+ddp performs with and without the additional graph break strategy for improving dynamo+ddp compute/communication overlap. Differential Revision: [D39976005](https://our.internmc.facebook.com/intern/diff/D39976005) [ghstack-poisoned] --- torchbenchmark/util/backends/torchdynamo.py | 46 ++++++++++++++++----- torchbenchmark/util/extra_args.py | 13 ------ 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py index fde479cee5..710e709d7b 100644 --- a/torchbenchmark/util/backends/torchdynamo.py +++ b/torchbenchmark/util/backends/torchdynamo.py @@ -2,6 +2,7 @@ Support TorchDynamo(https://github.com/facebookresearch/torchdynamo) backends """ import argparse +import contextlib from typing import List import torchdynamo @@ -14,19 +15,42 @@ def parse_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', dy parser.add_argument( "--extra-py-args", type=str, help="Extra Python args to evaluate." ) + parser.add_argument( + "--optimize_dynamo_ddp", + action='store_true', + help="enable extra optimizations for DDP + dynamo" + ) args, extra_args = parser.parse_known_args(dynamo_args) return args, extra_args def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', args: argparse.Namespace, precision: str): - if args.torchdynamo == "fx2trt" and precision == "fp16": - dynamo_optimizer = torchdynamo.optimize(torchdynamo.optimizations.backends.fx2trt_compiler_fp16) - else: - dynamo_optimizer = torchdynamo.optimize(args.torchdynamo) - # evaluate extra python code passed by the user - if args.extra_py_args: - exec(args.extra_py_args) - if model.test == "train": - model.train = dynamo_optimizer(model.train) - else: - model.eval = dynamo_optimizer(model.eval) + optimize_ddp_context = contextlib.nullcontext + + if args.optimize_dynamo_ddp: + import torchdynamo + @contextlib.contextmanager + def optimize_ddp_ctx(val: bool): + old_value = torchdynamo.config.optimize_ddp + try: + torchdynamo.config.optimize_ddp = val + yield + finally: + torchdynamo.config.optimize_ddp = old_value + optimize_ddp_context = lambda: optimize_ddp_ctx(True) + + with optimize_ddp_context(): + if args.torchdynamo == "fx2trt" and precision == "fp16": + dynamo_optimizer = torchdynamo.optimize(torchdynamo.optimizations.backends.fx2trt_compiler_fp16) + else: + dynamo_optimizer = torchdynamo.optimize(args.torchdynamo) + # evaluate extra python code passed by the user + if args.extra_py_args: + exec(args.extra_py_args) + if model.test == "train": + model.train = dynamo_optimizer(model.train) + else: + model.eval = dynamo_optimizer(model.eval) + + model.add_context(optimize_ddp_context) + torchdynamo.reset() diff --git a/torchbenchmark/util/extra_args.py b/torchbenchmark/util/extra_args.py index 8064174df2..998a566c25 100644 --- a/torchbenchmark/util/extra_args.py +++ b/torchbenchmark/util/extra_args.py @@ -112,7 +112,6 @@ def parse_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', opt_args: parser.add_argument("--torch_trt", action='store_true', help="enable torch_tensorrt") parser.add_argument("--flops", choices=["fvcore", "dcgm"], help="Return the flops result") parser.add_argument("--use_cosine_similarity", action='store_true', help="use cosine similarity for correctness check") - parser.add_argument("--optimize_dynamo_ddp", action='store_true', help="enable extra optimizations for DDP + dynamo") args, extra_args = parser.parse_known_args(opt_args) if model.jit: args.backend = "torchscript" @@ -148,16 +147,4 @@ def apply_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', args: argp module, exmaple_inputs = model.get_module() precision = 'fp16' if not model.dargs.precision == "fp32" else 'fp32' model.set_module(enable_torchtrt(precision=precision, model=module, example_inputs=exmaple_inputs)) - - if args.optimize_dynamo_ddp: - import torchdynamo - @contextlib.contextmanager - def optimize_ddp_ctx(val: bool): - old_value = torchdynamo.config.optimize_ddp - try: - torchdynamo.config.optimize_ddp = val - yield - finally: - torchdynamo.config.optimize_ddp = old_value - model.add_context(lambda: optimize_ddp_ctx(True)) From 6e839db2258fcfb5f6182191fdafb684affc906d Mon Sep 17 00:00:00 2001 From: David Berard Date: Mon, 3 Oct 2022 16:27:07 -0700 Subject: [PATCH 5/5] Update on "Add flag for dynamo+ddp optimizations" Add a flag that can be used to turn dynamo+ddp optimizations on. This will be used to compare how dynamo+ddp performs with and without the additional graph break strategy for improving dynamo+ddp compute/communication overlap. Differential Revision: [D39976005](https://our.internmc.facebook.com/intern/diff/D39976005) [ghstack-poisoned] --- torchbenchmark/util/backends/torchdynamo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py index 710e709d7b..1b705d07b7 100644 --- a/torchbenchmark/util/backends/torchdynamo.py +++ b/torchbenchmark/util/backends/torchdynamo.py @@ -27,7 +27,6 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar optimize_ddp_context = contextlib.nullcontext if args.optimize_dynamo_ddp: - import torchdynamo @contextlib.contextmanager def optimize_ddp_ctx(val: bool): old_value = torchdynamo.config.optimize_ddp