Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 0 additions & 89 deletions benchmarks/dynamo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,84 +348,6 @@ def randomize_input(inputs):
)


def cold_start_experiment(args, model_iter_fn, model, example_inputs, optimize_ctx):
compile_iters = 2
total_iters = compile_iters + 2
timings = np.zeros((total_iters, 2), np.float64)
# if we randomize the input, we should also check the result is correct
should_check_result = should_randomize_input = args.randomize_input
is_correct = True

optimized_model_iter_fn = optimize_ctx(model_iter_fn)
for rep in range(total_iters):
inputs = (
randomize_input(copy.deepcopy(example_inputs))
if should_randomize_input
else example_inputs
)

# interleave the runs to handle frequency scaling and load changes
timings[rep, 0], expected_output = timed(
model, model_iter_fn, inputs, return_result=True
)
timings[rep, 1], actual_output = timed(
model, optimized_model_iter_fn, inputs, return_result=True
)
if should_check_result:
is_correct = is_correct and same(expected_output, actual_output)
pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue
worst = np.max(timings, axis=0)

def breakeven(dynamo_times, eager_times):
"""
Solve for the number of iterations it takes dynamo to 'catch up' with eager,
taking into account the time it spent compiling. Assumes all compilation
happens up front and the model is static thereafter, which is definitely not
true in general but might be across torchbench.

dc1, dc2 = dynamo compilation iterations (with Prof Exec)
d, e = dynamo, eager warmed up iteration
B = num iters to break even
dc1 + dc2 + (B-2)d = B*e
B = (dc1 + dc2 - 2d) / (e - d)
"""
dc1, dc2, d = dynamo_times[0], dynamo_times[1], np.median(dynamo_times[2:])
e = np.median(eager_times)
if d < e:
return (dc1 + dc2 + 2 * d) / (e - d)
else:
# if optimized dynamo is not faster than eager we'll compute
# a nonsense negative number
return 0

speedup = worst[0] / worst[1]
eager_times, dynamo_times = timings[:, 0], timings[:, 1]
output_csv(
output_filename,
("dev", "name", "batch_size", "cold-start speedup", "breakeven iters"),
[
current_device,
current_name,
current_batch_size,
float(speedup),
breakeven(dynamo_times, eager_times),
],
)

def format_speedup(
speedup, pvalue, breakeven_iters, is_correct=True, pvalue_threshold=0.1
):
if not is_correct:
return "ERROR"
if pvalue > pvalue_threshold:
return f"{speedup:.3f}x breakeven={breakeven_iters:.2f} iters SAME"
return f"{speedup:.3f}x breakeven={breakeven_iters:.2f} iters p={pvalue:.2f}"

return format_speedup(
speedup, pvalue, breakeven(dynamo_times, eager_times), is_correct=is_correct
)


def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
"""
Measure speedups over eager.
Expand Down Expand Up @@ -1538,9 +1460,6 @@ def parse_args():
action="store_true",
help="speedup using the ltc backend without reusing compiled graph",
)
group.add_argument(
"--cold-start", action="store_true", help=help(cold_start_experiment)
)
group.add_argument(
"--overhead", action="store_true", help=help(overhead_experiment)
)
Expand Down Expand Up @@ -1785,14 +1704,6 @@ def main(runner, original_dir=None):
optimize_ctx = torch._dynamo.optimize(dummy_fx_compile, nopython=args.nopython)
experiment = speedup_experiment
output_filename = "overheads.csv"
elif args.cold_start:
optimize_ctx = torch._dynamo.optimize("aot_nvfuser", nopython=args.nopython)
experiment = cold_start_experiment
assert args.nvfuser, "TODO - Add another aot string for mem fusion with NNC"
backend_str = "nvfuser" if args.nvfuser else "nnc"
output_filename = f"cold_start_{backend_str}.csv"
# TODO(whc) should we move this to a more general part of the script?
torch.backends.cuda.matmul.allow_tf32 = True
elif args.inductor or args.inductor_dynamic:
from torch._inductor import config as inductor_config

Expand Down