diff --git a/torchbenchmark/util/fx2trt.py b/torchbenchmark/util/fx2trt.py index b1e624c511..1eb44c5e77 100644 --- a/torchbenchmark/util/fx2trt.py +++ b/torchbenchmark/util/fx2trt.py @@ -1,6 +1,6 @@ import torch -from torch.fx.experimental.fx2trt import LowerSetting -from torch.fx.experimental.fx2trt.lower import Lowerer +from fx2trt_oss.fx import LowerSetting +from fx2trt_oss.fx.lower import Lowerer """ The purpose of this example is to demostrate the onverall flow of lowering a PyTorch model @@ -31,7 +31,7 @@ def lower_to_trt( explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension. fp16_mode: fp16 config given to TRTModule. enable_fuse: Enable pass fusion during lowering if set to true. l=Lowering will try to find pattern defined - in torch.fx.experimental.fx2trt.passes from original module, and replace with optimized pass before apply lowering. + in fx2trt_oss.fx.passes from original module, and replace with optimized pass before apply lowering. verbose_log: Enable verbose log for TensorRT if set True. timing_cache_prefix: Timing cache file name for timing cache used by fx2trt. save_timing_cache: Update timing cache with current timing cache data if set to True.