diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
index 0feec63316..dcfc70373f 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -99,8 +99,17 @@ class _TorchTensorRTConstantFolder(ConstantFolder):  # type: ignore[misc]
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
 
-    # TODO: Update this function when quantization is added
     def is_impure(self, node: torch.fx.node.Node) -> bool:
-        if node.target in (torch.ops.tensorrt.quantize_op.default,):
+        # Set of known quantization ops to be excluded from constant folding.
+        # Currently, we exclude all quantization ops coming from the modelopt library.
+        quantization_ops = set()
+        try:
+            # The modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered.
+            import modelopt.torch.quantization as mtq  # noqa: F401
+            assert torch.ops.tensorrt.quantize_op.default
+            quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
+        except Exception:
+            pass
+        if quantization_ops and node.target in quantization_ops:
             return True
         return False
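
For context, the sketch below (not part of the patch; the `collect_quantization_ops` helper is a hypothetical name used only for illustration) isolates the guarded-import pattern the new `is_impure` relies on: importing modelopt is what registers `torch.ops.tensorrt.quantize_op` with the PyTorch op registry, so the exclusion set is built inside a try/except and stays empty when the optional dependency is missing.

```python
# Minimal, standalone sketch of the guarded-import pattern used in the patch.
# Assumption: collect_quantization_ops is a hypothetical helper, not an API in
# torch_tensorrt; only torch.ops.tensorrt.quantize_op.default and the modelopt
# import are taken from the diff above.
from typing import Set

import torch


def collect_quantization_ops() -> Set["torch._ops.OpOverload"]:
    ops: Set["torch._ops.OpOverload"] = set()
    try:
        # Importing modelopt registers the custom TensorRT quantize op as a side effect.
        import modelopt.torch.quantization as mtq  # noqa: F401

        ops.add(torch.ops.tensorrt.quantize_op.default)
    except Exception:
        # modelopt is not installed or the op is not registered: nothing to exclude.
        pass
    return ops


# A node whose target is in this set is reported as impure, so the constant
# folder leaves it (and the Q/DQ structure around it) in the graph, e.g.:
#     if node.target in collect_quantization_ops(): return True
```

Using a `set()` (rather than a dict literal) matters here: `.add()` on a `{}` dict would raise inside the try block and be swallowed by the broad `except`, silently leaving the exclusion set empty even when modelopt is installed.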