Skip to content

Commit 8e5da36

Browse files
committed
add nativert benchmark
1 parent 882d50c commit 8e5da36

File tree

1 file changed

+61
-2
lines changed

1 file changed

+61
-2
lines changed

benchmarks/dynamo/common.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import signal
2222
import subprocess
2323
import sys
24+
import tempfile
2425
import time
2526
import weakref
2627
from contextlib import contextmanager
@@ -41,6 +42,7 @@
4142
import torch.distributed
4243
import torch.multiprocessing as mp
4344
from torch._C import _has_cuda as HAS_CUDA, _has_xpu as HAS_XPU
45+
from torch._C._nativert import PyModelRunner
4446
from torch._dynamo.profiler import fx_insert_profiling, Profiler
4547
from torch._dynamo.testing import (
4648
dummy_fx_compile,
@@ -1100,6 +1102,8 @@ def maybe_mark_profile(*args, **kwargs):
11001102
frozen_model_iter_fn = export_aot_inductor(
11011103
model, example_inputs, args.inductor_compile_mode
11021104
)
1105+
elif args.export_nativert:
1106+
frozen_model_iter_fn = export_nativert(model, example_inputs)
11031107
else:
11041108
frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
11051109

@@ -1446,6 +1450,38 @@ def get_excess_memory(cls, model) -> float:
14461450
return cls.cache.get(weakref.ref(model), (None, 0.0))[1]
14471451

14481452

1453+
class NativeRTCache:
    """Cache of NativeRT runners, keyed by a weak reference to the model.

    ``load`` exports a model once, packages the exported program into a PT2
    archive on disk, and wraps that archive in a ``PyModelRunner``.  Later
    calls look the runner up under ``weakref.ref(model)`` instead of
    re-exporting.
    """

    # Maps weakref.ref(model) -> the PyModelRunner built for that model.
    cache: dict[weakref.ref, Any] = {}

    @staticmethod
    def _remove_archive(path):
        """Delete the temporary archive at *path* (best effort)."""
        try:
            os.remove(path)
        except OSError:
            # Cleanup runs at interpreter exit; a file that is already gone
            # (or cannot be removed) must not turn into an exit-time error.
            pass

    @classmethod
    def load(cls, model, example_inputs):
        """Return the cached NativeRT runner for *model*, building it if needed.

        Args:
            model: The module to export.
            example_inputs: Benchmark inputs; split into positional and
                keyword arguments by ``_normalize_bench_inputs``.

        Returns:
            The ``PyModelRunner`` loaded from the packaged ``"forward"``
            exported program.
        """
        import atexit

        from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path

        key = weakref.ref(model)
        if key not in cls.cache:
            example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
            # Run the model once so its output type can be registered before
            # export traces it.
            example_outputs = model(*example_args, **example_kwargs)
            _register_dataclass_output_as_pytree(example_outputs)

            combined_args = _combine_args(model, example_args, example_kwargs)
            dynamic_shapes = _tree_map_with_path(
                _produce_dynamic_shapes_for_export, combined_args
            )

            ep = torch.export.export(
                model, example_args, example_kwargs, dynamic_shapes=dynamic_shapes
            )
            ep = ep.run_decompositions()
            # delete=False keeps the archive on disk after the handle is
            # closed so PyModelRunner can open it by name.  Nothing else
            # removes it, so schedule removal at interpreter exit; this is
            # registered before packaging so a failed export is cleaned up too.
            with tempfile.NamedTemporaryFile(delete=False) as f:
                filename = f.name
                atexit.register(cls._remove_archive, filename)
                torch.export.pt2_archive._package.package_pt2(
                    f, exported_programs={"forward": ep}
                )
            cls.cache[key] = PyModelRunner(filename, "forward")

        return cls.cache[key]
1483+
1484+
14491485
def export(model, example_inputs):
14501486
from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
14511487

@@ -1472,6 +1508,16 @@ def opt_export(_, example_inputs):
14721508
return opt_export
14731509

14741510

1511+
def export_nativert(model, example_inputs):
    """Build a benchmark iteration function that runs *model* through NativeRT.

    The runner is obtained once, up front, from ``NativeRTCache.load``; the
    returned callable only normalizes its inputs and forwards them to it.
    """
    runner = NativeRTCache.load(model, example_inputs)

    def opt_nativert(_, example_inputs, collect_outputs=False):
        # The first argument and ``collect_outputs`` are accepted but unused.
        normalized = _normalize_bench_inputs(example_inputs)
        positional, keyword = normalized
        return runner.run(*positional, **keyword)

    return opt_nativert
1519+
1520+
14751521
def export_aot_inductor(model, example_inputs, mode):
14761522
optimized = AOTInductorModelCache.load(model, example_inputs, mode)
14771523

@@ -2228,7 +2274,11 @@ def record_status(accuracy_status, dynamo_start_stats):
22282274
try:
22292275
model_copy = self.deepcopy_and_maybe_parallelize(model)
22302276
self.init_optimizer(name, current_device, model_copy.parameters())
2231-
if self.args.export or self.args.export_aot_inductor:
2277+
if (
2278+
self.args.export
2279+
or self.args.export_aot_inductor
2280+
or self.args.export_nativert
2281+
):
22322282
# apply export on module directly
22332283
# no need for n iterations
22342284
# the logic should be the same to self.model_iter_fn (forward_pass)
@@ -2624,7 +2674,7 @@ def warmup(fn, model, example_inputs, mode, niters=5):
26242674
niters=1,
26252675
)
26262676

2627-
if self.args.export_aot_inductor:
2677+
if self.args.export_aot_inductor or self.args.export_nativert:
26282678
optimized_model_iter_fn = optimize_ctx
26292679
else:
26302680
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
@@ -3377,6 +3427,11 @@ def get_example_inputs(self):
33773427
action="store_true",
33783428
help="Measure pass rate with Export+AOTInductor",
33793429
)
3430+
group.add_argument(
3431+
"--export-nativert",
3432+
action="store_true",
3433+
help="Measure pass rate with Export+NativeRT",
3434+
)
33803435
group.add_argument(
33813436
"--xla", action="store_true", help="Compare TorchXLA to eager PyTorch"
33823437
)
@@ -3818,6 +3873,10 @@ def run(runner, args, original_dir=None):
38183873
optimize_ctx = export
38193874
experiment = speedup_experiment
38203875
output_filename = "export.csv"
3876+
elif args.export_nativert:
3877+
optimize_ctx = export_nativert
3878+
experiment = speedup_experiment
3879+
output_filename = "export_nativert.csv"
38213880
elif args.xla:
38223881
(dev,) = args.devices
38233882
os.environ["PJRT_DEVICE"] = {"cuda": "GPU", "cpu": "CPU"}[dev]

0 commit comments

Comments
 (0)