Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tests/py/dynamo/conversion/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,17 @@
post_lowering,
pre_export_lowering,
)
from torch_tensorrt.dynamo.lowering.passes import remove_num_users_is_0_nodes
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
from torch_tensorrt.dynamo.utils import ATOL, RTOL, get_model_device, get_torch_inputs

_LOGGER: logging.Logger = logging.getLogger(__name__)

# this is the post lowering pass list for the converter test
post_lowering_pass_list_for_converter_test = [
remove_num_users_is_0_nodes,
]


# this method is only used in our converter test to infer the module output dtypes via dummy inference
# which is due to fx.symbolic_trace does not have the meta['val'] info in the node
Expand Down Expand Up @@ -435,6 +441,8 @@ def run_test(
settings=compilation_settings,
)

for pass_func in post_lowering_pass_list_for_converter_test:
mod = pass_func(mod, compilation_settings)
num_inputs = len(inputs)
trt_inputs = inputs
dtype_to_change = []
Expand Down
Loading