Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/apps/flux_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def forward_loop(mod):

settings = {
"strict": False,
"allow_complex_guards_as_runtime_asserts": True,
"prefer_deferred_runtime_asserts_over_guards": True,
"enabled_precisions": enabled_precisions,
"truncate_double": True,
"min_block_size": 1,
Expand Down
2 changes: 1 addition & 1 deletion examples/dynamo/torch_export_flux_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
kwargs=dummy_inputs,
dynamic_shapes=dynamic_shapes,
strict=False,
allow_complex_guards_as_runtime_asserts=True,
prefer_deferred_runtime_asserts_over_guards=True,
)

# %%
Expand Down
2 changes: 1 addition & 1 deletion examples/dynamo/weight_streaming_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
kwargs={"position_ids": position_ids},
dynamic_shapes=({1: seq_len}, {1: seq_len}),
strict=False,
allow_complex_guards_as_runtime_asserts=True,
prefer_deferred_runtime_asserts_over_guards=True,
)

return ep
Expand Down
10 changes: 5 additions & 5 deletions py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
immutable_weights: bool = False,
strict: bool = True,
allow_complex_guards_as_runtime_asserts: bool = False,
prefer_deferred_runtime_asserts_over_guards: bool = False,
weight_streaming_budget: Optional[int] = None,
enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None,
**kwargs: Any,
Expand Down Expand Up @@ -134,8 +134,8 @@ def __init__(
self.kwarg_inputs: dict[str, Any] = {}
self.additional_settings = kwargs
self.strict = strict
self.allow_complex_guards_as_runtime_asserts = (
allow_complex_guards_as_runtime_asserts
self.prefer_deferred_runtime_asserts_over_guards = (
prefer_deferred_runtime_asserts_over_guards
)
self.use_python_runtime = use_python_runtime
self.trt_device = to_torch_tensorrt_device(device)
Expand Down Expand Up @@ -312,14 +312,14 @@ def refit_gm(self) -> None:
def get_exported_program(self) -> torch.export.ExportedProgram:

def export_fn() -> torch.export.ExportedProgram:
if self.allow_complex_guards_as_runtime_asserts:
if self.prefer_deferred_runtime_asserts_over_guards:
return _export(
self.original_model,
self.arg_inputs,
kwargs=self.kwarg_inputs,
dynamic_shapes=self._get_total_dynamic_shapes(),
strict=self.strict,
allow_complex_guards_as_runtime_asserts=self.allow_complex_guards_as_runtime_asserts,
prefer_deferred_runtime_asserts_over_guards=self.prefer_deferred_runtime_asserts_over_guards,
)
else:
return torch.export.export(
Expand Down
2 changes: 1 addition & 1 deletion tests/py/dynamo/models/test_engine_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
(inputs,),
dynamic_shapes=({1: seq_len},),
strict=False,
allow_complex_guards_as_runtime_asserts=True,
prefer_deferred_runtime_asserts_over_guards=True,
)

return ep
Expand Down
4 changes: 2 additions & 2 deletions tools/llm/test_llama_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_llama_attention(args):
args=(hidden_states, position_embeddings, None),
dynamic_shapes=dynamic_shapes,
strict=False,
allow_complex_guards_as_runtime_asserts=True,
prefer_deferred_runtime_asserts_over_guards=True,
)

with torch_tensorrt.logging.debug() if args.debug else nullcontext():
Expand Down Expand Up @@ -463,7 +463,7 @@ def test_llama_model(args):
kwargs=kwarg_inputs,
dynamic_shapes=dynamic_shapes,
strict=False,
allow_complex_guards_as_runtime_asserts=True,
prefer_deferred_runtime_asserts_over_guards=True,
)

with torch_tensorrt.logging.debug() if args.debug else nullcontext():
Expand Down
2 changes: 1 addition & 1 deletion tools/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
kwargs={"position_ids": position_ids},
dynamic_shapes=({1: seq_len}, {1: seq_len}),
strict=False,
allow_complex_guards_as_runtime_asserts=True,
prefer_deferred_runtime_asserts_over_guards=True,
)

return ep
Expand Down
2 changes: 1 addition & 1 deletion tools/perf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
(inputs,),
dynamic_shapes=({1: seq_len},),
strict=False,
allow_complex_guards_as_runtime_asserts=True,
prefer_deferred_runtime_asserts_over_guards=True,
)

return ep
Expand Down
Loading