diff --git a/torchbenchmark/util/backends/fx2trt.py b/torchbenchmark/util/backends/fx2trt.py index b8165c9ea4..9ea4c09e07 100644 --- a/torchbenchmark/util/backends/fx2trt.py +++ b/torchbenchmark/util/backends/fx2trt.py @@ -62,11 +62,12 @@ def lower_to_trt( """ from fx2trt_oss.fx import LowerSetting from fx2trt_oss.fx.lower import Lowerer + from fx2trt_oss.fx.utils import LowerPrecision lower_setting = LowerSetting( max_batch_size=max_batch_size, max_workspace_size=max_workspace_size, explicit_batch_dimension=explicit_batch_dimension, - fp16_mode=fp16_mode, + lower_precision=LowerPrecision.FP16, enable_fuse=enable_fuse, verbose_log=verbose_log, timing_cache_prefix=timing_cache_prefix,