diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py
index c061bb5d81..2a4e1f9d5f 100644
--- a/examples/apps/flux_demo.py
+++ b/examples/apps/flux_demo.py
@@ -121,7 +121,7 @@ def forward_loop(mod):
 
     settings = {
         "strict": False,
-        "allow_complex_guards_as_runtime_asserts": True,
+        "prefer_deferred_runtime_asserts_over_guards": True,
         "enabled_precisions": enabled_precisions,
         "truncate_double": True,
         "min_block_size": 1,
diff --git a/examples/dynamo/torch_export_flux_dev.py b/examples/dynamo/torch_export_flux_dev.py
index 8f471668f1..e46ad9ba46 100644
--- a/examples/dynamo/torch_export_flux_dev.py
+++ b/examples/dynamo/torch_export_flux_dev.py
@@ -92,7 +92,7 @@
     kwargs=dummy_inputs,
     dynamic_shapes=dynamic_shapes,
     strict=False,
-    allow_complex_guards_as_runtime_asserts=True,
+    prefer_deferred_runtime_asserts_over_guards=True,
 )
 
 # %%
diff --git a/examples/dynamo/weight_streaming_example.py b/examples/dynamo/weight_streaming_example.py
index 601292ba95..c477ba6df8 100644
--- a/examples/dynamo/weight_streaming_example.py
+++ b/examples/dynamo/weight_streaming_example.py
@@ -65,7 +65,7 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
             kwargs={"position_ids": position_ids},
             dynamic_shapes=({1: seq_len}, {1: seq_len}),
             strict=False,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
 
     return ep
diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
index 258449ad7b..1cffec77c2 100644
--- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -68,7 +68,7 @@ def __init__(
         use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
         immutable_weights: bool = False,
         strict: bool = True,
-        allow_complex_guards_as_runtime_asserts: bool = False,
+        prefer_deferred_runtime_asserts_over_guards: bool = False,
         weight_streaming_budget: Optional[int] = None,
         enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None,
         **kwargs: Any,
@@ -134,8 +134,8 @@ def __init__(
         self.kwarg_inputs: dict[str, Any] = {}
         self.additional_settings = kwargs
         self.strict = strict
-        self.allow_complex_guards_as_runtime_asserts = (
-            allow_complex_guards_as_runtime_asserts
+        self.prefer_deferred_runtime_asserts_over_guards = (
+            prefer_deferred_runtime_asserts_over_guards
         )
         self.use_python_runtime = use_python_runtime
         self.trt_device = to_torch_tensorrt_device(device)
@@ -312,14 +312,14 @@ def refit_gm(self) -> None:
     def get_exported_program(self) -> torch.export.ExportedProgram:
 
         def export_fn() -> torch.export.ExportedProgram:
-            if self.allow_complex_guards_as_runtime_asserts:
+            if self.prefer_deferred_runtime_asserts_over_guards:
                 return _export(
                     self.original_model,
                     self.arg_inputs,
                     kwargs=self.kwarg_inputs,
                     dynamic_shapes=self._get_total_dynamic_shapes(),
                     strict=self.strict,
-                    allow_complex_guards_as_runtime_asserts=self.allow_complex_guards_as_runtime_asserts,
+                    prefer_deferred_runtime_asserts_over_guards=self.prefer_deferred_runtime_asserts_over_guards,
                 )
             else:
                 return torch.export.export(
diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py
index 61b5d74679..5e310900aa 100644
--- a/tests/py/dynamo/models/test_engine_cache.py
+++ b/tests/py/dynamo/models/test_engine_cache.py
@@ -856,7 +856,7 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
             (inputs,),
             dynamic_shapes=({1: seq_len},),
             strict=False,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
 
     return ep
diff --git a/tools/llm/test_llama_components.py b/tools/llm/test_llama_components.py
index ef7e59cd72..9adb51d324 100644
--- a/tools/llm/test_llama_components.py
+++ b/tools/llm/test_llama_components.py
@@ -79,7 +79,7 @@ def test_llama_attention(args):
         args=(hidden_states, position_embeddings, None),
         dynamic_shapes=dynamic_shapes,
         strict=False,
-        allow_complex_guards_as_runtime_asserts=True,
+        prefer_deferred_runtime_asserts_over_guards=True,
     )
 
     with torch_tensorrt.logging.debug() if args.debug else nullcontext():
@@ -463,7 +463,7 @@ def test_llama_model(args):
         kwargs=kwarg_inputs,
         dynamic_shapes=dynamic_shapes,
         strict=False,
-        allow_complex_guards_as_runtime_asserts=True,
+        prefer_deferred_runtime_asserts_over_guards=True,
    )
 
     with torch_tensorrt.logging.debug() if args.debug else nullcontext():
diff --git a/tools/llm/utils.py b/tools/llm/utils.py
index 2c3434b0ed..842d2a597a 100644
--- a/tools/llm/utils.py
+++ b/tools/llm/utils.py
@@ -41,7 +41,7 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
             kwargs={"position_ids": position_ids},
             dynamic_shapes=({1: seq_len}, {1: seq_len}),
             strict=False,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
 
     return ep
diff --git a/tools/perf/utils.py b/tools/perf/utils.py
index b0bed6ff0e..13d7deac43 100644
--- a/tools/perf/utils.py
+++ b/tools/perf/utils.py
@@ -228,7 +228,7 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
             (inputs,),
             dynamic_shapes=({1: seq_len},),
             strict=False,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
 
     return ep
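Note: this diff is a mechanical rename of torch.export's allow_complex_guards_as_runtime_asserts keyword to its newer name, prefer_deferred_runtime_asserts_over_guards; both route complex symbolic-shape guards into deferred runtime asserts rather than trace-time guards. For reference, a minimal sketch of the call pattern after the rename, using the private torch.export._trace._export entry point that the mutable-module path above calls. The toy model, input shape, and dimension bounds are illustrative assumptions, not part of this change:

import torch
from torch.export import Dim
from torch.export._trace import _export  # private entry point, as used in the diff


class ToyModel(torch.nn.Module):  # hypothetical model for illustration only
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


seq_len = Dim("seq_len", min=1, max=16)  # illustrative bounds

ep = _export(
    ToyModel(),
    (torch.randn(1, 8),),          # example input; dim 1 is dynamic
    dynamic_shapes=({1: seq_len},),
    strict=False,
    # Renamed from allow_complex_guards_as_runtime_asserts: guards that are
    # too complex to express at trace time become runtime asserts instead.
    prefer_deferred_runtime_asserts_over_guards=True,
)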