2121import signal
2222import subprocess
2323import sys
24+ import tempfile
2425import time
2526import weakref
2627from contextlib import contextmanager
4142import torch .distributed
4243import torch .multiprocessing as mp
4344from torch ._C import _has_cuda as HAS_CUDA , _has_xpu as HAS_XPU
45+ from torch ._C ._nativert import PyModelRunner
4446from torch ._dynamo .profiler import fx_insert_profiling , Profiler
4547from torch ._dynamo .testing import (
4648 dummy_fx_compile ,
@@ -1100,6 +1102,8 @@ def maybe_mark_profile(*args, **kwargs):
11001102 frozen_model_iter_fn = export_aot_inductor (
11011103 model , example_inputs , args .inductor_compile_mode
11021104 )
1105+ elif args .export_nativert :
1106+ frozen_model_iter_fn = export_nativert (model , example_inputs )
11031107 else :
11041108 frozen_model_iter_fn = torch ._dynamo .run (model_iter_fn )
11051109
@@ -1446,6 +1450,38 @@ def get_excess_memory(cls, model) -> float:
14461450 return cls .cache .get (weakref .ref (model ), (None , 0.0 ))[1 ]
14471451
14481452
1453+ class NativeRTCache :
1454+ cache : dict [weakref .ref , Any ] = {}
1455+
1456+ @classmethod
1457+ def load (cls , model , example_inputs ):
1458+ from torch .export .dynamic_shapes import _combine_args , _tree_map_with_path
1459+
1460+ key = weakref .ref (model )
1461+ if key not in cls .cache :
1462+ example_args , example_kwargs = _normalize_bench_inputs (example_inputs )
1463+ example_outputs = model (* example_args , ** example_kwargs )
1464+ _register_dataclass_output_as_pytree (example_outputs )
1465+
1466+ combined_args = _combine_args (model , example_args , example_kwargs )
1467+ dynamic_shapes = _tree_map_with_path (
1468+ _produce_dynamic_shapes_for_export , combined_args
1469+ )
1470+
1471+ ep = torch .export .export (
1472+ model , example_args , example_kwargs , dynamic_shapes = dynamic_shapes
1473+ )
1474+ ep = ep .run_decompositions ()
1475+ with tempfile .NamedTemporaryFile (delete = False ) as f :
1476+ torch .export .pt2_archive ._package .package_pt2 (
1477+ f , exported_programs = {"forward" : ep }
1478+ )
1479+ filename = f .name
1480+ cls .cache [key ] = PyModelRunner (filename , "forward" )
1481+
1482+ return cls .cache [key ]
1483+
1484+
14491485def export (model , example_inputs ):
14501486 from torch .export .dynamic_shapes import _combine_args , _tree_map_with_path
14511487
@@ -1472,6 +1508,16 @@ def opt_export(_, example_inputs):
14721508 return opt_export
14731509
14741510
1511+ def export_nativert (model , example_inputs ):
1512+ optimized = NativeRTCache .load (model , example_inputs )
1513+
1514+ def opt_nativert (_ , example_inputs , collect_outputs = False ):
1515+ example_args , example_kwargs = _normalize_bench_inputs (example_inputs )
1516+ return optimized .run (* example_args , ** example_kwargs )
1517+
1518+ return opt_nativert
1519+
1520+
14751521def export_aot_inductor (model , example_inputs , mode ):
14761522 optimized = AOTInductorModelCache .load (model , example_inputs , mode )
14771523
@@ -2228,7 +2274,11 @@ def record_status(accuracy_status, dynamo_start_stats):
22282274 try :
22292275 model_copy = self .deepcopy_and_maybe_parallelize (model )
22302276 self .init_optimizer (name , current_device , model_copy .parameters ())
2231- if self .args .export or self .args .export_aot_inductor :
2277+ if (
2278+ self .args .export
2279+ or self .args .export_aot_inductor
2280+ or self .args .export_nativert
2281+ ):
22322282 # apply export on module directly
22332283 # no need for n iterations
22342284 # the logic should be the same to self.model_iter_fn (forward_pass)
@@ -2624,7 +2674,7 @@ def warmup(fn, model, example_inputs, mode, niters=5):
26242674 niters = 1 ,
26252675 )
26262676
2627- if self .args .export_aot_inductor :
2677+ if self .args .export_aot_inductor or self . args . export_nativert :
26282678 optimized_model_iter_fn = optimize_ctx
26292679 else :
26302680 optimized_model_iter_fn = optimize_ctx (self .model_iter_fn )
@@ -3377,6 +3427,11 @@ def get_example_inputs(self):
33773427 action = "store_true" ,
33783428 help = "Measure pass rate with Export+AOTInductor" ,
33793429 )
3430+ group .add_argument (
3431+ "--export-nativert" ,
3432+ action = "store_true" ,
3433+ help = "Measure pass rate with Export+NativeRT" ,
3434+ )
33803435 group .add_argument (
33813436 "--xla" , action = "store_true" , help = "Compare TorchXLA to eager PyTorch"
33823437 )
@@ -3818,6 +3873,10 @@ def run(runner, args, original_dir=None):
38183873 optimize_ctx = export
38193874 experiment = speedup_experiment
38203875 output_filename = "export.csv"
3876+ elif args .export_nativert :
3877+ optimize_ctx = export_nativert
3878+ experiment = speedup_experiment
3879+ output_filename = "export_nativert.csv"
38213880 elif args .xla :
38223881 (dev ,) = args .devices
38233882 os .environ ["PJRT_DEVICE" ] = {"cuda" : "GPU" , "cpu" : "CPU" }[dev ]
0 commit comments