18 changes: 18 additions & 0 deletions torchbenchmark/util/backends/torchdynamo.py
@@ -2,6 +2,7 @@
Support TorchDynamo(https://github.com/facebookresearch/torchdynamo) backends
"""
import argparse
import contextlib
from typing import List
import torchdynamo

@@ -14,6 +15,11 @@ def parse_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', dy
    parser.add_argument(
        "--tritonmm", type=str, help="torchinductor.config.triton.mm configuration"
    )
    parser.add_argument(
        "--optimize_dynamo_ddp",
        action='store_true',
        help="enable extra optimizations for DDP + dynamo"
    )
    args, extra_args = parser.parse_known_args(dynamo_args)
    return args, extra_args
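For reference, a small self-contained sketch (standard argparse only; the extra flag name below is made up) of how parse_known_args, used above, splits the flags this parser knows about from arguments that are passed through untouched:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tritonmm", type=str)
parser.add_argument("--optimize_dynamo_ddp", action="store_true")

# Unrecognized flags (the hypothetical --some_model_flag) come back in extra_args
# instead of raising an error, so they can be forwarded to the model.
args, extra_args = parser.parse_known_args(["--optimize_dynamo_ddp", "--some_model_flag", "3"])
print(args.optimize_dynamo_ddp)  # True
print(extra_args)                # ['--some_model_flag', '3']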

@@ -32,4 +38,16 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
        model.train = dynamo_optimizer(model.train)
    else:
        model.eval = dynamo_optimizer(model.eval)

    if args.optimize_dynamo_ddp:
        @contextlib.contextmanager
        def optimize_ddp_ctx(val: bool):
            old_value = torchdynamo.config.optimize_ddp
            try:
                torchdynamo.config.optimize_ddp = val
                yield
            finally:
                torchdynamo.config.optimize_ddp = old_value
        model.add_context(lambda: optimize_ddp_ctx(True))

    torchdynamo.reset()
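A minimal standalone sketch of the pattern this hunk adds: temporarily override a module-level config flag and restore it afterwards. A plain dict stands in for torchdynamo.config so the example runs anywhere; model.add_context is assumed to call the factory once per measured region.

import contextlib

config = {"optimize_ddp": False}  # stand-in for torchdynamo.config

@contextlib.contextmanager
def optimize_ddp_ctx(val: bool):
    old_value = config["optimize_ddp"]
    try:
        config["optimize_ddp"] = val
        yield
    finally:
        config["optimize_ddp"] = old_value  # restored even if the body raises

with optimize_ddp_ctx(True):
    assert config["optimize_ddp"] is True   # optimization active inside the block
assert config["optimize_ddp"] is False      # original value restored on exit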
65 changes: 38 additions & 27 deletions userbenchmark/ddp_experiments/__init__.py
@@ -1,6 +1,7 @@
import argparse
import importlib
import os
import copy
import csv
import io
import submitit
@@ -254,33 +255,43 @@ def get_backend_name(model_args):
    for nodes in node_list:
        for model_name in models:
            for model_args in model_args_configs:
                batch_size = model_batch_size[model_name]
                args.model = model_name
                args.batch_size = batch_size
                args.nodes = nodes
                args.dist_url = get_init_file(args).as_uri()
                args.output_dir = args.job_dir
                executor.update_parameters(
                    gpus_per_node=args.ngpus,
                    # one task per GPU
                    tasks_per_node=args.ngpus,
                    cpus_per_task=10,
                    nodes=args.nodes,
                    timeout_min=args.timeout,
                    # Below are cluster dependent parameters
                    slurm_partition=args.partition,
                    slurm_signal_delay_s=120,
                )
                job = executor.submit(TrainerWrapper(args, model_args))

                # print ID of the Slurm job
                backend_name = get_backend_name(model_args)
                print(f"{model_name}_{backend_name}_{nodes}: {job.job_id}")
                output_csv(
                    args.index_file,
                    ("model", "backend", "nodes", "job_id"),
                    (model_name, backend_name, nodes, job.job_id),
                )
                for has_breaks in [True, False]:
                    backend_name = get_backend_name(model_args)
                    if backend_name == "eager" and has_breaks:
                        continue
                    # copy the model args so we can add more arguments without modifying
                    # the original model_args list.
                    copied_model_args = copy.copy(model_args)
                    breakname = "withbreaks" if has_breaks else "nobreaks"
                    if has_breaks:
                        copied_model_args.append("--optimize_dynamo_ddp")
                    batch_size = model_batch_size[model_name]
                    args.model = model_name
                    args.batch_size = batch_size
                    args.nodes = nodes
                    args.dist_url = get_init_file(args).as_uri()
                    args.output_dir = args.job_dir
                    executor.update_parameters(
                        gpus_per_node=args.ngpus,
                        # one task per GPU
                        tasks_per_node=args.ngpus,
                        cpus_per_task=10,
                        nodes=args.nodes,
                        timeout_min=args.timeout,
                        # Below are cluster dependent parameters
                        slurm_partition=args.partition,
                        slurm_signal_delay_s=120,
                        slurm_exclude=args.exclude,
                    )
                    job = executor.submit(TrainerWrapper(args, copied_model_args))

                    # print ID of the Slurm job
                    print(f"{model_name}_{backend_name}_{nodes}_{breakname}: {job.job_id}")
                    output_csv(
                        args.index_file,
                        ("model", "backend", "nodes", "has_breaks", "job_id"),
                        (model_name, backend_name, nodes, has_breaks, job.job_id),
                    )

    # waits for completion and returns output
    print(job.results())
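A hedged sketch of the sweep the new loop performs: each (nodes, model, model_args) combination is submitted with and without the graph-break optimization, the eager backend skips the with-breaks variant, and the flag is appended to a copy so the original argument list stays untouched. detect_backend below is an illustrative stand-in for get_backend_name.

import copy
import itertools

def detect_backend(model_args):
    # stand-in: assumes dynamo runs carry a --torchdynamo flag
    return "torchdynamo" if "--torchdynamo" in model_args else "eager"

def enumerate_configs(models, model_args_configs, node_list):
    for nodes, model_name, model_args in itertools.product(node_list, models, model_args_configs):
        for has_breaks in [True, False]:
            backend_name = detect_backend(model_args)
            if backend_name == "eager" and has_breaks:
                continue  # the DDP optimization only applies to dynamo backends
            copied_model_args = copy.copy(model_args)
            if has_breaks:
                copied_model_args.append("--optimize_dynamo_ddp")
            yield model_name, backend_name, nodes, has_breaks, copied_model_args

for cfg in enumerate_configs(["resnet50"], [["--torchdynamo", "inductor"], []], [1, 2]):
    print(cfg)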
28 changes: 15 additions & 13 deletions userbenchmark/ddp_experiments/parse_ddp.py
@@ -25,28 +25,29 @@ def get_job_result(args, job_id, worker_rank=0):
        elif desc == "success":
            # print(f"Success: {payload}")
            return True, payload

        # print(f"Unknown result: {dat}")
        return False, dat

    return False, None

def parse_data(args):
    """
    Schema:
    model_data["model"]["backend"][#nodes] = latency_median
    """
    model_data = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
    model_data = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
    with open(args.csv) as f:
        runs = csv.DictReader(f)
        for row in runs:
            model = row["model"]
            backend = row["backend"]
            nodes = row["nodes"]
            has_breaks = row["has_breaks"]
            job_id = row["job_id"]
            result_code, result_data = get_job_result(args, job_id)
            latency = f"{result_data['latency_median']:.3f}" if result_code else str(result_data)[:10]
            model_data[model][backend][nodes] = latency
            model_data[model][backend][nodes][has_breaks] = latency
    return model_data

def model_name(model):
@@ -62,14 +63,15 @@ def print_model_table(args, model, model_data):
        for node in model_data[backend]:
            node_counts[node] = node # hack orderedset
    rows = []
    for backend in model_data:
        row = [backend, ]
        for node in node_counts:
            if node in model_data[backend]:
                row.append(model_data[backend][node])
            else:
                row.append("-")
        rows.append(row)
    for has_breaks in [False, True]:
        for backend in model_data:
            row = [f"{backend} {'w/' if has_breaks else 'wo/'}breaks", ]
            for node in node_counts:
                if node in model_data[backend]:
                    row.append(model_data[backend][node][str(has_breaks)])
                else:
                    row.append("-")
            rows.append(row)

    hdr = ("backend", ) + tuple(f"{node}_latency" for node in node_counts)
    print(f"{model_name(model)}:")
@@ -89,4 +91,4 @@ def main():
    print_results(args, data)

if __name__ == "__main__":
    main()
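A small sketch of the updated result schema: latencies are now keyed by has_breaks as well, i.e. model_data[model][backend][nodes][has_breaks] = latency. csv.DictReader yields strings, so the stored key is "True"/"False", which is why print_model_table looks it up via str(has_breaks). Values below are illustrative.

from collections import defaultdict

model_data = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))

# parse_data() stores formatted latency strings keyed by the CSV's string values:
model_data["resnet50"]["torchdynamo"]["2"]["True"] = "123.456"

# print_model_table() reads them back using str(has_breaks):
has_breaks = True
print(model_data["resnet50"]["torchdynamo"]["2"][str(has_breaks)])  # 123.456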