diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 93ffc8b451..646f034a79 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -26,11 +26,17 @@ post_lowering, pre_export_lowering, ) +from torch_tensorrt.dynamo.lowering.passes import remove_num_users_is_0_nodes from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule from torch_tensorrt.dynamo.utils import ATOL, RTOL, get_model_device, get_torch_inputs _LOGGER: logging.Logger = logging.getLogger(__name__) +# this is the post lowering pass list for the converter test +post_lowering_pass_list_for_converter_test = [ + remove_num_users_is_0_nodes, +] + # this method is only used in our converter test to infer the module output dtypes via dummy inference # which is due to fx.symbolic_trace does not have the meta['val'] info in the node @@ -435,6 +441,8 @@ def run_test( settings=compilation_settings, ) + for pass_func in post_lowering_pass_list_for_converter_test: + mod = pass_func(mod, compilation_settings) num_inputs = len(inputs) trt_inputs = inputs dtype_to_change = []