
[FX] remove op_lowering_disallow_list and format revert #1261

Merged: 6 commits, Aug 12, 2022

Conversation

frank-wei
Contributor

Description

  1. Revert formatting to the FB-internal style to avoid FB code linter issues.
  2. Remove op_lowering_disallow_list (see the sketch below).

Fixes # (issue)
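
To make item 2 concrete, here is a minimal, hypothetical sketch of the pattern being removed. The names (lower_ops, the disallow_list parameter) are illustrative only, not the torch_tensorrt API: before this change, ops on a disallow list were held back from lowering; after it, every op the splitter supports is lowered.

from typing import Iterable, List, Set

def lower_ops(ops: Iterable[str], disallow_list: Set[str] = frozenset()) -> List[str]:
    # With op_lowering_disallow_list removed, disallow_list is always
    # empty and this reduces to list(ops).
    return [op for op in ops if op not in disallow_list]

# Before: "gelu" was held back from lowering.
print(lower_ops(["conv2d", "gelu"], disallow_list={"gelu"}))  # ['conv2d']
# After: both ops are handed to the TRT splitter.
print(lower_ops(["conv2d", "gelu"]))  # ['conv2d', 'gelu']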

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

@frank-wei changed the title from "Fb sync wwei6" to "[FX] remove op_lowering_disallow_list and format revert" on Aug 12, 2022
@github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:
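
(The hunks below are in black's --check --diff output format: lines prefixed with - show the file as committed, lines prefixed with + show what the configured formatter would produce. A hedged sketch of reproducing the check locally, assuming the formatter is black with the repository's line-length configuration:)

# Hedged sketch: rerun the Python style check locally. Assumes the linter
# is black in --check --diff mode, which matches the shape of the hunks
# below; the path is the one reported in the diff headers.
import subprocess

result = subprocess.run(
    ["black", "--check", "--diff", "py/torch_tensorrt/fx/"],
    capture_output=True,
    text=True,
)
print(result.stdout)  # unified diffs for files that would be reformatted
# result.returncode == 1 means at least one file does not conform.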

--- py/torch_tensorrt/fx/input_tensor_spec.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/input_tensor_spec.py	2022-08-12 18:52:24.650475 +0000
@@ -6,14 +6,11 @@
from .utils import get_dynamic_dims


def generate_input_specs(inputs, lower_setting, additional_inputs=None):
    # dynamic_batch is TRT only flag.
-    if (
-        not lower_setting.explicit_batch_dimension
-        or lower_setting.dynamic_batch is False
-    ):
+    if not lower_setting.explicit_batch_dimension or lower_setting.dynamic_batch is False:
        return InputTensorSpec.from_tensors(inputs)

    # If we don't have additional inputs, we assume the first dimension
    # is the dynamic batch dimension. Otherwise, we use the additional
    # inputs to determine the batch dimension.
@@ -33,20 +30,16 @@
        for i, j in zip(inputs, additional_inputs):
            found_batch_dim = False

            for idx, values in enumerate(zip(i.shape, j.shape)):
                if values[0] != values[1]:
-                    assert (
-                        found_batch_dim is False
-                    ), f"We've already found a batch dim, {i.shape}, {j.shape}."
+                    assert found_batch_dim is False, f"We've already found a batch dim, {i.shape}, {j.shape}."
                    batch_dims.append(idx)
                    found_batch_dim = True

            if not found_batch_dim:
-                raise RuntimeError(
-                    f"Failed to find batch dimension because shapes are the same, {i.shape}"
-                )
+                raise RuntimeError(f"Failed to find batch dimension because shapes are the same, {i.shape}")

        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
            inputs,
            (
                0,
@@ -158,13 +151,11 @@
                batch_dim
            ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
            shape = list(tensor.shape)
            shape[batch_dim] = -1
            shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
-            input_specs.append(
-                cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
-            )
+            input_specs.append(cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges))

        return input_specs

    def to_random_tensor(self):
        shape = tuple(self.shape)
--- py/torch_tensorrt/fx/lower.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/lower.py	2022-08-12 18:52:24.739763 +0000
@@ -77,13 +77,11 @@
    lower_setting: LowerSetting
    timing_cache_manager: TimingCacheManager

    @classmethod
    def create(cls, lower_setting):
-        timing_cache_manager = TimingCacheManager(
-            lower_setting.timing_cache_prefix, lower_setting.save_timing_cache
-        )
+        timing_cache_manager = TimingCacheManager(lower_setting.timing_cache_prefix, lower_setting.save_timing_cache)
        return LowerTrtInterpreter(lower_setting, timing_cache_manager)

    def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
        assert self.lower_setting.input_specs, "Can't find input specs for lowering!"
        logger.info(f"{split_name=} {self.lower_setting.input_specs=}")
@@ -103,13 +101,11 @@
        interpreter = TRTInterpreter(
            mod,
            input_specs=self.lower_setting.input_specs,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
-            logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
-            else trt.Logger.WARNING,
+            logger_level=trt.Logger.VERBOSE if self.lower_setting.verbose_log else trt.Logger.WARNING,
        )

        interp_result: TRTInterpreterResult = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
@@ -129,13 +125,11 @@
            self.timing_cache_manager.update_timing_cache(split_name, timing_cache)

        return interp_result


-def default_split_function(
-    model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting
-) -> SplitResult:
+def default_split_function(model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting) -> SplitResult:
    splitter_setting = TRTSplitterSetting()
    splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
    splitter_setting.min_acc_module_size = lower_setting.min_acc_module_size
    splitter = TRTSplitter(model, inputs, settings=splitter_setting)
    splitter.node_support_preview()
@@ -147,13 +141,11 @@


def default_lower_pass(
    create_trt_interpreter: Callable[[LowerSetting], LowerTrtInterpreter],
) -> PassFunc:
-    def lower_pass(
-        mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str
-    ) -> nn.Module:
+    def lower_pass(mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str) -> nn.Module:
        """
        Create a module transformation pass which lowers an `fx.GraphModule` into a
        `TRTModule`
        """
        interpreter = create_trt_interpreter(lower_setting)
@@ -223,21 +215,13 @@
        inputs: Input,
        additional_inputs: Optional[Input] = None,
    ) -> nn.Module:
        module.eval()

-        if (
-            self.lower_pass_manager_builder.lower_setting.lower_precision
-            == LowerPrecision.FP16
-        ):
+        if self.lower_pass_manager_builder.lower_setting.lower_precision == LowerPrecision.FP16:
            module.half()
-            inputs = tuple(
-                x.half() if x is not None and x.dtype == torch.float32 else x
-                for x in inputs
-            )
-        pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
-            inputs, additional_inputs
-        )
+            inputs = tuple(x.half() if x is not None and x.dtype == torch.float32 else x for x in inputs)
+        pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(inputs, additional_inputs)

        lower_result = pm(module)

        return lower_result
--- py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2022-08-12 18:52:24.961256 +0000
@@ -35,23 +35,17 @@
# >>> with FUSE_PASSES_POST_OBSERVER.add(print_module_and_input):
# >>>     # print_module_and_input will be called right after the fuse passes
# >>>     lower(module, sample_input)

# Observer for the model after the fuse passes.
-FUSE_PASSES_POST_OBSERVER: Observer[Callable[[nn.Module, Input], None]] = Observer(
-    "FUSE_PASSES_POST_OBSERVER"
-)
+FUSE_PASSES_POST_OBSERVER: Observer[Callable[[nn.Module, Input], None]] = Observer("FUSE_PASSES_POST_OBSERVER")

# Observer for the TRT split submodules before lowering
-LOWER_SPLIT_PRE_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer(
-    "LOWER_SPLIT_PRE_OBSERVER"
-)
+LOWER_SPLIT_PRE_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer("LOWER_SPLIT_PRE_OBSERVER")

# Observer for the TRT split submodules after lowering
-LOWER_SPLIT_POST_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer(
-    "LOWER_SPLIT_POST_OBSERVER"
-)
+LOWER_SPLIT_POST_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer("LOWER_SPLIT_POST_OBSERVER")
# ----------------------------------------------------------------------


def wrapper(fn: Callable, input) -> Callable:
    @wraps(fn)
@@ -103,22 +97,16 @@
            passes.append(wrapper(p, self._input))
        for p in self.lower_setting.lower_basic_fuse_pass.passes:
            passes.append(wrapper(p, self._input))

        passes.append(inplace_wrapper(common_subexpression_elimination))
-        passes.append(
-            inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input))
-        )
+        passes.append(inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)))

        return PassManager.build_from_passlist(passes)

    def _split_pass(self) -> PassManager:
-        passes = [
-            partial(
-                self._split_func, inputs=self._input, lower_setting=self.lower_setting
-            )
-        ]
+        passes = [partial(self._split_func, inputs=self._input, lower_setting=self.lower_setting)]
        passes.append(
            inplace_wrapper(
                lambda split_result: remove_duplicate_output_args(
                    split_result.split_module, split_result.submodule_inputs.keys()
                )
@@ -152,21 +140,15 @@
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.input_specs = generate_input_specs(
                        submod_inputs,
                        self.lower_setting,
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        additional_submodule_inputs[submod_name] if additional_submodule_inputs else None,
                    )
-                    lowered_module = self._lower_func(
-                        submod, submod_inputs, self.lower_setting, submod_name
-                    )
+                    lowered_module = self._lower_func(submod, submod_inputs, self.lower_setting, submod_name)
                    setattr(split_result.split_module, submod_name, lowered_module)
-                    LOWER_SPLIT_POST_OBSERVER.observe(
-                        submod_name, lowered_module, submod_inputs
-                    )
+                    LOWER_SPLIT_POST_OBSERVER.observe(submod_name, lowered_module, submod_inputs)
                    _LOGGER.info(
                        f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
                    )

            return split_result.split_module
@@ -184,28 +166,22 @@
                # Only acc submodules will be lowered.
                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                    _LOGGER.info(f"Now lowering submodule {submod_name}")
                    lowering_start_time = datetime.datetime.now()

-                    lowered_module = self._lower_func(
-                        submod, submod_inputs, self.lower_setting, submod_name
-                    )
+                    lowered_module = self._lower_func(submod, submod_inputs, self.lower_setting, submod_name)
                    setattr(split_result.split_module, submod_name, lowered_module)
-                    LOWER_SPLIT_POST_OBSERVER.observe(
-                        submod_name, lowered_module, submod_inputs
-                    )
+                    LOWER_SPLIT_POST_OBSERVER.observe(submod_name, lowered_module, submod_inputs)
                    _LOGGER.info(
                        f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
                    )

            return split_result.split_module

        return PassManager.build_from_passlist([lower_func])

-    def build_trt_lower_pipeline(
-        self, input: Input, additional_input: Optional[Input] = None
-    ) -> PassManager:
+    def build_trt_lower_pipeline(self, input: Input, additional_input: Optional[Input] = None) -> PassManager:
        self._input = input
        self._additional_input = additional_input
        passes = []

        passes.append(self._const_fold_pass())
@@ -214,13 +190,11 @@
        passes.append(self._trt_lower_pass())

        pm = PassManager.build_from_passlist(passes)
        return pm

-    def build_default_lower_pipeline(
-        self, input: Input, additional_input: Optional[Input] = None
-    ) -> PassManager:
+    def build_default_lower_pipeline(self, input: Input, additional_input: Optional[Input] = None) -> PassManager:
        self._input = input
        self._additional_input = additional_input
        passes = []

        passes.append(self._const_fold_pass())
--- py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py	2022-08-12 18:52:25.154686 +0000
@@ -27,13 +27,11 @@
        count_include_pad=True,
    ):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
-                self.avg_pool = torch.nn.AvgPool1d(
-                    kernel_size, stride, padding, ceil_mode, count_include_pad
-                )
+                self.avg_pool = torch.nn.AvgPool1d(kernel_size, stride, padding, ceil_mode, count_include_pad)

            def forward(self, x):
                return self.avg_pool(x)

        inputs = [torch.randn(1, 3, 224)]
@@ -60,13 +58,11 @@
        count_include_pad=True,
    ):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
-                self.avg_pool = torch.nn.AvgPool1d(
-                    kernel_size, stride, padding, ceil_mode, count_include_pad
-                )
+                self.avg_pool = torch.nn.AvgPool1d(kernel_size, stride, padding, ceil_mode, count_include_pad)

            def forward(self, x):
                return self.avg_pool(x)

        input_specs = [
@@ -75,13 +71,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d})

    def test_avg_pool2d_with_dynamic_shape_four_dimensions(
        self,
        test_name="default",
        kernel_size=1,
@@ -112,13 +106,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d})

    @parameterized.expand(
        [
            ("default", 1),
            ("kernal_size", 3),
@@ -254,12 +246,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py	2022-08-12 18:52:25.193116 +0000
@@ -32,13 +32,11 @@
                dtype=torch.float32,
                shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.batch_norm}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.batch_norm})

    def test_batchnorm_with_dynamic_shape(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -53,13 +51,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.batch_norm}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.batch_norm})

    # Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm.


if __name__ == "__main__":
--- py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py	2022-08-12 18:52:25.320784 +0000
@@ -51,12 +51,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (5, 5, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.clamp}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.clamp})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py	2022-08-12 18:52:25.418225 +0000
@@ -27,13 +27,11 @@
        bias=True,
    ):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
-                self.conv = torch.nn.Conv1d(
-                    3, 6, kernel_size, stride, padding, dilation, groups, bias
-                )
+                self.conv = torch.nn.Conv1d(3, 6, kernel_size, stride, padding, dilation, groups, bias)

            def forward(self, x):
                return self.conv(x)

        inputs = [torch.randn(1, 3, 32)]
@@ -60,13 +58,11 @@
        bias=True,
    ):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
-                self.conv = torch.nn.Conv1d(
-                    3, 6, kernel_size, stride, padding, dilation, groups, bias
-                )
+                self.conv = torch.nn.Conv1d(3, 6, kernel_size, stride, padding, dilation, groups, bias)

            def forward(self, x):
                return self.conv(x)

        input_specs = [
@@ -75,13 +71,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.conv1d}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.conv1d})

    @parameterized.expand(
        [
            ("default", 1),
            param("no_bias", 1, bias=False),
@@ -102,13 +96,11 @@
        bias=True,
    ):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
-                self.conv = torch.nn.Conv2d(
-                    3, 6, kernel_size, stride, padding, dilation, groups, bias
-                )
+                self.conv = torch.nn.Conv2d(3, 6, kernel_size, stride, padding, dilation, groups, bias)

            def forward(self, x):
                return self.conv(x)

        inputs = [torch.randn(1, 3, 32, 32)]
@@ -131,13 +123,11 @@
                shape=(-1, 3, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))],
            ),
        ]
-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.conv2d}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.conv2d})

    @parameterized.expand(
        [
            ("default", 1),
            param("no_bias", 1, bias=False),
@@ -158,13 +148,11 @@
        bias=True,
    ):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
-                self.conv = torch.nn.Conv3d(
-                    3, 6, kernel_size, stride, padding, dilation, groups, bias
-                )
+                self.conv = torch.nn.Conv3d(3, 6, kernel_size, stride, padding, dilation, groups, bias)

            def forward(self, x):
                return self.conv(x)

        inputs = [torch.randn(1, 3, 32, 32, 32)]
@@ -187,12 +175,10 @@
                shape=(-1, 3, -1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))],
            ),
        ]
-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.conv3d}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.conv3d})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py	2022-08-12 18:52:25.578605 +0000
@@ -5,13 +5,11 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec


-@unittest.skip(
-    reason="Could not find CustomGeluPluginDynamic. Enable it once we upgrade TRT to 8.4"
-)
+@unittest.skip(reason="Could not find CustomGeluPluginDynamic. Enable it once we upgrade TRT to 8.4")
class TestGELU(AccTestCase):
    def test_gelu(self):
        class TestModule(nn.Module):
            def forward(self, x):
                return nn.functional.gelu(x)
@@ -34,13 +32,11 @@
                shape=(-1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
            ),
        ]
-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.gelu}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.gelu})

    def test_gelu_with_dynamic_shape_four_dimensions(self):
        class TestModule(nn.Module):
            def forward(self, x):
                return nn.functional.gelu(x)
@@ -51,12 +47,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.gelu}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.gelu})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py	2022-08-12 18:52:25.793707 +0000
@@ -131,12 +131,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            Interpolate(), input_specs, expected_ops={acc_ops.interpolate}
-        )
+        self.run_test_with_dynamic_shape(Interpolate(), input_specs, expected_ops={acc_ops.interpolate})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py	2022-08-12 18:52:26.083066 +0000
@@ -71,15 +71,11 @@
        class MatMul(nn.Module):
            def forward(self, input, other):
                return torch.matmul(input, other)

        inputs = [torch.randn(*input_shape), torch.randn(*other_shape)]
-        test_implicit_batch_dim = (
-            input_shape[0] == other_shape[0]
-            and len(input_shape) > 2
-            and len(other_shape) > 2
-        )
+        test_implicit_batch_dim = input_shape[0] == other_shape[0] and len(input_shape) > 2 and len(other_shape) > 2
        self.run_test(
            MatMul(),
            inputs,
            expected_ops={acc_ops.matmul},
            test_implicit_batch_dim=test_implicit_batch_dim,
@@ -106,12 +102,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 3, 3), (9, 4, 3, 3), (9, 4, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            Matmul(), input_specs, expected_ops={acc_ops.matmul}
-        )
+        self.run_test_with_dynamic_shape(Matmul(), input_specs, expected_ops={acc_ops.matmul})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_max.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_max.py	2022-08-12 18:52:26.198919 +0000
@@ -102,13 +102,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            MaxDimReduce(), input_specs, expected_ops={acc_ops.max_dim_reduce}
-        )
+        self.run_test_with_dynamic_shape(MaxDimReduce(), input_specs, expected_ops={acc_ops.max_dim_reduce})

    def test_max_full_reduce(
        self,
    ):
        class MaxFullReduce(torch.nn.Module):
@@ -124,13 +122,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            MaxFullReduce(), input_specs, expected_ops={acc_ops.max_full_reduce}
-        )
+        self.run_test_with_dynamic_shape(MaxFullReduce(), input_specs, expected_ops={acc_ops.max_full_reduce})

    def test_max_method(self):
        class MaxMethod(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -149,12 +145,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            MaxMethod(), input_specs, expected_ops={acc_ops.maximum}
-        )
+        self.run_test_with_dynamic_shape(MaxMethod(), input_specs, expected_ops={acc_ops.maximum})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_min.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_min.py	2022-08-12 18:52:26.358565 +0000
@@ -101,13 +101,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            MinDimReduce(), input_specs, expected_ops={acc_ops.min_dim_reduce}
-        )
+        self.run_test_with_dynamic_shape(MinDimReduce(), input_specs, expected_ops={acc_ops.min_dim_reduce})

    def test_min_full_reduce(
        self,
    ):
        class MinFullReduce(torch.nn.Module):
@@ -123,13 +121,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            MinFullReduce(), input_specs, expected_ops={acc_ops.min_full_reduce}
-        )
+        self.run_test_with_dynamic_shape(MinFullReduce(), input_specs, expected_ops={acc_ops.min_full_reduce})

    def test_min_method(self):
        class MinMethod(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -148,12 +144,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            MinMethod(), input_specs, expected_ops={acc_ops.minimum}
-        )
+        self.run_test_with_dynamic_shape(MinMethod(), input_specs, expected_ops={acc_ops.minimum})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py	2022-08-12 18:52:26.408067 +0000
@@ -23,13 +23,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 3, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            Narrow(), input_specs, expected_ops={acc_ops.slice_tensor}
-        )
+        self.run_test_with_dynamic_shape(Narrow(), input_specs, expected_ops={acc_ops.slice_tensor})


class TestNarrowConverter(AccTestCase):
    @parameterized.expand(
        [
--- py/torch_tensorrt/fx/converters/acc_ops_converters.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/converters/acc_ops_converters.py	2022-08-12 18:52:26.424854 +0000
@@ -34,14 +34,11 @@
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Conv received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Conv received input {input_val} that is not part " "of the TensorRT region!")

    # Process 1d input with unsqueeze -> conv2d -> squeeze to calculated conv1d
    unsqueeze_layer = network.add_shuffle(input=input_val)
    unsqueeze_layer.reshape_dims = tuple([*input_val.shape, 1])
    set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
@@ -52,13 +49,11 @@

    # for now we'll assume bias is constant Tensor or None,
    # and bias being ITensor is not supported in TensorRT api
    # right now
    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]"
-        )
+        raise RuntimeError(f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]")
    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]
    if bias is not None:
        bias = bias[None]
    weight = kwargs["weight"]

@@ -82,13 +77,11 @@
        )

        layer.set_input(1, weight)
    else:
        if not isinstance(kwargs["weight"], torch.Tensor):
-            raise RuntimeError(
-                f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]"
-            )
+            raise RuntimeError(f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]")
        weight = to_numpy(weight)
        weight = np.expand_dims(weight, -1)
        layer = network.add_convolution_nd(
            input=input_val,
            num_output_maps=weight.shape[0],
@@ -126,25 +119,20 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Conv received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Conv received input {input_val} that is not part " "of the TensorRT region!")

    if has_dynamic_shape(input_val.shape):
        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."

    # for now we'll assume bias is constant Tensor or None,
    # and bias being ITensor is not supported in TensorRT api
    # right now
    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]"
-        )
+        raise RuntimeError(f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]")
    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]

    if network.has_explicit_precision:
        weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
        weight_shape = tuple(kwargs["weight"].shape)  # type: ignore[union-attr]
@@ -160,13 +148,11 @@
        )

        layer.set_input(1, weight)
    else:
        if not isinstance(kwargs["weight"], torch.Tensor):
-            raise RuntimeError(
-                f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]"
-            )
+            raise RuntimeError(f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]")
        weight = to_numpy(kwargs["weight"])
        layer = network.add_convolution_nd(
            input=input_val,
            num_output_maps=weight.shape[0],
            kernel_shape=weight.shape[2:],
@@ -194,27 +180,20 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Transpose conv received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Transpose conv received input {input_val} that is not part " "of the TensorRT region!")

    if has_dynamic_shape(input_val.shape):
-        assert (
-            input_val.shape[1] != -1
-        ), "Channel dim can't be dynamic for transpose convolution."
+        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for transpose convolution."

    # for now we'll assume bias is constant Tensor or None,
    # and bias being ITensor is not supported in TensorRT api
    # right now
    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"ConvTranspose {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
-        )
+        raise RuntimeError(f"ConvTranspose {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]")
    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]

    if network.has_explicit_precision:
        weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
        weight_shape = tuple(kwargs["weight"].shape)  # type: ignore[union-attr]
@@ -232,13 +211,11 @@
        )

        layer.set_input(1, weight)
    else:
        if not isinstance(kwargs["weight"], torch.Tensor):
-            raise RuntimeError(
-                f"conv {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
-            )
+            raise RuntimeError(f"conv {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]")
        weight = to_numpy(kwargs["weight"])
        # nn.ConvTranspose2d/3d weight size is (in_channels, out_channels/groups, kernel_0, kernel_1, [kernel_2])
        layer = network.add_deconvolution_nd(
            input=input_val,
            num_output_maps=weight.shape[1] * kwargs["groups"],
@@ -270,29 +247,20 @@
    mode = kwargs["mode"]
    value = kwargs["value"] if kwargs["value"] is not None else 0
    rank = len(input_val.shape)  # type: ignore[union-attr]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"pad received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"pad received input {input_val} that is not part " "of the TensorRT region!")

    if mode != "constant":
-        raise RuntimeError(
-            f"Currently we only support constant mode for pad, got {mode}."
-        )
+        raise RuntimeError(f"Currently we only support constant mode for pad, got {mode}.")

    if len(pad) / 2 > rank:
-        raise RuntimeError(
-            f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension."
-        )
+        raise RuntimeError(f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension.")

    if value != 0:
-        raise RuntimeError(
-            f"Currently we only support padding value of 0, got {value}."
-        )
+        raise RuntimeError(f"Currently we only support padding value of 0, got {value}.")

    if len(pad) > 4:
        raise RuntimeError("Currently we only support padding last two dimensions.")

    pre_padding = tuple(pad[len(pad) - i - 2] for i in range(0, len(pad), 2))
@@ -320,38 +288,28 @@
    mode = kwargs["mode"]
    value = kwargs["value"] if kwargs["value"] is not None else 0
    rank = len(input_val.shape)  # type: ignore[union-attr]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"pad received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"pad received input {input_val} that is not part " "of the TensorRT region!")

    if mode != "constant":
-        raise RuntimeError(
-            f"Currently we only support constant mode for pad, got {mode}."
-        )
+        raise RuntimeError(f"Currently we only support constant mode for pad, got {mode}.")

    if len(pad) / 2 > rank:
-        raise RuntimeError(
-            f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension."
-        )
+        raise RuntimeError(f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension.")

    # cast value to TRTensor
    dt = torch_dtype_from_trt(input_val.dtype)
    value = 0 if value == None else value
-    value_const = get_trt_tensor(
-        network, torch.tensor([value], dtype=dt), f"{name}_value"
-    )
+    value_const = get_trt_tensor(network, torch.tensor([value], dtype=dt), f"{name}_value")

    input_shape = input_val.shape
    pre_start = tuple(i - 1 for i in input_shape)
    prefix_len = len(input_shape) - len(pad) // 2
    pre_shape = tuple(
-        input_shape[i] + (pad[-(i - prefix_len) * 2 - 2] if i >= prefix_len else 0)
-        for i in range(0, len(input_shape))
+        input_shape[i] + (pad[-(i - prefix_len) * 2 - 2] if i >= prefix_len else 0) for i in range(0, len(input_shape))
    )
    pre_stride = [-1] * len(input_shape)

    layer = network.add_slice(
        input_val,
@@ -374,12 +332,11 @@
    transpose_output = layer.get_output(0)

    shape = transpose_output.shape
    post_start = tuple([0] * len(shape))
    post_shape = tuple(
-        shape[i] + (pad[-(i - prefix_len) * 2 - 1] if i >= prefix_len else 0)
-        for i in range(0, len(shape))
+        shape[i] + (pad[-(i - prefix_len) * 2 - 1] if i >= prefix_len else 0) for i in range(0, len(shape))
    )
    post_stride = tuple([1] * len(shape))

    layer = network.add_slice(transpose_output, post_start, post_shape, post_stride)
    layer.set_input(4, value_const)
@@ -397,22 +354,15 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"flatten received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"flatten received input {input_val} that is not part " "of the TensorRT region!")

    num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
-    start_dim = get_positive_dim(
-        cast(int, kwargs["start_dim"] if "start_dim" in kwargs else 0), num_dims
-    )
-    end_dim = get_positive_dim(
-        cast(int, kwargs["end_dim"] if "end_dim" in kwargs else -1), num_dims
-    )
+    start_dim = get_positive_dim(cast(int, kwargs["start_dim"] if "start_dim" in kwargs else 0), num_dims)
+    end_dim = get_positive_dim(cast(int, kwargs["end_dim"] if "end_dim" in kwargs else -1), num_dims)

    if network.has_implicit_batch_dimension:
        assert start_dim != 0, "Can't flatten batch dimension when it's implicit."
        start_dim -= 1
        end_dim -= 1
@@ -511,24 +461,18 @@
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    if type(input_t) == torch.nn.Parameter or type(input_t) == torch.Tensor:
-        if (
-            not has_dynamic_shape(input_t.shape)
-            and network.has_implicit_batch_dimension
-        ):
+        if not has_dynamic_shape(input_t.shape) and network.has_implicit_batch_dimension:
            return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_t.shape))
        return input_t.shape

    # input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
    input_val = input_t
    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"size received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"size received input {input_val} that is not part " "of the TensorRT region!")

    if not has_dynamic_shape(input_val.shape):
        if network.has_implicit_batch_dimension:
            return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_val.shape))
        return torch.Size(input_val.shape)
@@ -547,14 +491,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"size received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"size received input {input_val} that is not part " "of the TensorRT region!")

    if has_dynamic_shape(input_val.shape):
        raise RuntimeError(f"numel does not support dynamic shapes.")

    numel = np.prod(input_val.shape)
@@ -572,29 +513,20 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"BatchNorm2d received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"BatchNorm2d received input {input_val} that is not part " "of the TensorRT region!")

    if has_dynamic_shape(input_val.shape):
        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for batch norm."

-    scale = cast(
-        torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["weight"]))
-    ) / np.sqrt(
-        cast(torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["running_var"])))
-        + cast(float, kwargs["eps"])
+    scale = cast(torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["weight"]))) / np.sqrt(
+        cast(torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["running_var"]))) + cast(float, kwargs["eps"])
    )

-    bias = (
-        to_numpy(cast(torch.Tensor, kwargs["bias"]))
-        - to_numpy(cast(torch.Tensor, kwargs["running_mean"])) * scale
-    )
+    bias = to_numpy(cast(torch.Tensor, kwargs["bias"])) - to_numpy(cast(torch.Tensor, kwargs["running_mean"])) * scale
    power = np.ones_like(scale)

    # For BatchNorm1d, reshape 1d to 2d
    output_shape = input_val.shape
    if not network.has_implicit_batch_dimension and len(input_val.shape) < 4:
@@ -628,44 +560,33 @@
@tensorrt_converter(acc_ops.layer_norm)
def acc_ops_layer_norm(network, target, args, kwargs, name):
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"LayerNorm received input {input_val} that is not part " "of the TensorRT region!")

    gamma = kwargs["weight"].detach().cpu().float().numpy()
    gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
    beta = kwargs["bias"].detach().cpu().float().numpy()
    beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
-    eps_field = trt.PluginField(
-        "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32
-    )
+    eps_field = trt.PluginField("eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32)
    try:
        normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32)
    except TypeError:
        _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
        normalized_shape = np.array([], dtype=np.int32)

-    normalized_shape_filed = trt.PluginField(
-        "normalized_shape", normalized_shape, trt.PluginFieldType.INT32
-    )
-    field_collection = trt.PluginFieldCollection(
-        [gamma_field, beta_field, eps_field, normalized_shape_filed]
-    )
+    normalized_shape_filed = trt.PluginField("normalized_shape", normalized_shape, trt.PluginFieldType.INT32)
+    field_collection = trt.PluginFieldCollection([gamma_field, beta_field, eps_field, normalized_shape_filed])

    try:
        if network.has_implicit_batch_dimension:
            plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt")
        else:
            plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
    except AssertionError:
-        _LOGGER.error(
-            "Unable to find layer norm plugin, fall back to TensorRT implementation."
-        )
+        _LOGGER.error("Unable to find layer norm plugin, fall back to TensorRT implementation.")
        return layer_norm(network, target, args, kwargs, name)
    layer = network.add_plugin_v2([input_val], plugin)
    layer.name = name
    return layer.get_output(0)

@@ -678,14 +599,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"LayerNorm received input {input_val} that is not part " "of the TensorRT region!")

    shape = kwargs["weight"].shape  # type: ignore[union-attr]
    broadcasted_shape = (1,) * (len(input_val.shape) - len(shape)) + shape
    gamma = to_numpy(kwargs["weight"].reshape(*shape))  # type: ignore[union-attr]
    beta = to_numpy(kwargs["bias"].reshape(*shape))  # type: ignore[union-attr]
@@ -694,13 +612,11 @@
    axes = 0
    for d in range(len(shape)):
        axes |= 1 << (len(input_val.shape) - d - 1)

    # E[x]
-    mean_expected_layer = network.add_reduce(
-        input_val, trt.ReduceOperation.AVG, axes, keep_dims=True
-    )
+    mean_expected_layer = network.add_reduce(input_val, trt.ReduceOperation.AVG, axes, keep_dims=True)
    set_layer_name(mean_expected_layer, target, f"{name}_mean_expected")

    # X-E[x]
    sub_trt = add_binary_elementwise_layer(
        network,
@@ -722,13 +638,11 @@
        pow_tensor.get_output(0),
        trt.ElementWiseOperation.POW,
        target,
        f"{name}_pow_var",
    )
-    mean_trt_layer = network.add_reduce(
-        pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True
-    )
+    mean_trt_layer = network.add_reduce(pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True)
    set_layer_name(mean_trt_layer, target, f"{name}_mean")
    # Variance + eps
    eps_tensor = network.add_constant(
        (1,) * len(input_val.shape),
        trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
@@ -741,13 +655,11 @@
        trt.ElementWiseOperation.SUM,
        target,
        f"{name}_add",
    )
    # SQRT((Var + eps))
-    sqrt_trt = add_unary_layer(
-        network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt"
-    )
+    sqrt_trt = add_unary_layer(network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt")
    # (x - E[x]) / sqrt((var + eps))
    div_trt = add_binary_elementwise_layer(
        network,
        sub_trt,
        sqrt_trt,
@@ -791,14 +703,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"softmax received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"softmax received input {input_val} that is not part " "of the TensorRT region!")

    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
    def get_softmax_dim(ndim: int) -> int:
        if ndim == 0 or ndim == 1 or ndim == 3:
            ret = 0
@@ -832,13 +741,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    input_val = get_trt_tensor(network, input_t, f"{name}_input")

    dims = tuple(cast(Sequence[int], kwargs["dims"]))
-    n_input_dims = len(input_val.shape) + (
-        1 if network.has_implicit_batch_dimension else 0
-    )
+    n_input_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)

    if len(dims) > n_input_dims:
        assert not network.has_implicit_batch_dimension
        layer = network.add_shuffle(input_val)
        layer.name = f"{name}_reshape"
@@ -849,20 +756,16 @@
            input_shape_layer.name = f"{name}_input_shape"
            preceding_ones = network.add_constant(
                (num_preceding_ones,),
                np.ascontiguousarray([1] * num_preceding_ones, np.int32),
            ).get_output(0)
-            reshape_layer = network.add_concatenation(
-                [preceding_ones, input_shape_layer.get_output(0)]
-            )
+            reshape_layer = network.add_concatenation([preceding_ones, input_shape_layer.get_output(0)])
            reshape_layer.axis = 0
            reshape_layer.name = f"{name}_reshape_dims"
            layer.set_input(1, reshape_layer.get_output(0))
        else:
-            layer.reshape_dims = (1,) * (len(dims) - n_input_dims) + tuple(
-                input_val.shape
-            )
+            layer.reshape_dims = (1,) * (len(dims) - n_input_dims) + tuple(input_val.shape)
        input_val = layer.get_output(0)
    else:
        dims = (1,) * (n_input_dims - len(dims)) + dims

    if network.has_implicit_batch_dimension:
@@ -898,17 +801,15 @@
    layer = network.add_slice(input_val, starts, shapes, strides)
    layer.mode = trt.SliceMode.WRAP
    set_layer_name(layer, target, name)

    if has_dynamic_shape(input_val.shape):  # type: ignore[union-attr]
-        starts_tensor = network.add_constant(
-            (len(dims),), np.ascontiguousarray([0] * len(dims), np.int32)
-        ).get_output(0)
+        starts_tensor = network.add_constant((len(dims),), np.ascontiguousarray([0] * len(dims), np.int32)).get_output(
+            0
+        )
        if all(isinstance(d, int) for d in dims):
-            dims_tensor = network.add_constant(
-                (len(dims),), np.ascontiguousarray(dims, np.int32)
-            ).get_output(0)
+            dims_tensor = network.add_constant((len(dims),), np.ascontiguousarray(dims, np.int32)).get_output(0)
        else:
            assert all(isinstance(d, TRTTensor) for d in dims)
            concat_dims_layer = network.add_concatenation(inputs=dims)
            concat_dims_layer.axis = 0
            concat_dims_layer.name = f"{name}_tile_dim"
@@ -969,13 +870,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
    negative_slope = kwargs["negative_slope"]
    operation_type = trt.ActivationType.LEAKY_RELU
-    return add_activation_layer(
-        network, input_val, operation_type, target, name, negative_slope
-    )
+    return add_activation_layer(network, input_val, operation_type, target, name, negative_slope)


@tensorrt_converter(acc_ops.elu)
def acc_ops_elu(
    network: TRTNetwork,
@@ -1243,51 +1142,40 @@
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
-    return add_reduce_layer(
-        network, target, args, kwargs, trt.ReduceOperation.SUM, name
-    )
+    return add_reduce_layer(network, target, args, kwargs, trt.ReduceOperation.SUM, name)


@tensorrt_converter(acc_ops.prod)
def acc_ops_prod(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
-    return add_reduce_layer(
-        network, target, args, kwargs, trt.ReduceOperation.PROD, name
-    )
+    return add_reduce_layer(network, target, args, kwargs, trt.ReduceOperation.PROD, name)


@tensorrt_converter(acc_ops.mean)
def acc_ops_mean(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
-    return add_reduce_layer(
-        network, target, args, kwargs, trt.ReduceOperation.AVG, name
-    )
+    return add_reduce_layer(network, target, args, kwargs, trt.ReduceOperation.AVG, name)


def add_acc_ops_full_reduce(network, target, args, kwargs, name, reduce_op):
    input_val = kwargs["input"]
    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"max received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-    assert (
-        not network.has_implicit_batch_dimension
-    ), "Do not support max over all the elements for implicit batch."
+        raise RuntimeError(f"max received input {input_val} that is not part " "of the TensorRT region!")
+    assert not network.has_implicit_batch_dimension, "Do not support max over all the elements for implicit batch."

    dim = range(len(input_val.shape))

    layer = network.add_reduce(
        input_val,
@@ -1307,25 +1195,21 @@
        new_kwargs["largest"] = True
    elif reduce_op == trt.ReduceOperation.MIN:
        new_kwargs["largest"] = False
    new_kwargs["sorted"] = False

-    topk_out0, topk_out1 = acc_ops_topk(
-        network, target, args, new_kwargs, name + "_topk"
-    )
+    topk_out0, topk_out1 = acc_ops_topk(network, target, args, new_kwargs, name + "_topk")

    topk_out0.name = f"{name}_topk0"
    topk_out1.name = f"{name}_topk1"

    if "keepdim" in new_kwargs and new_kwargs["keepdim"]:
        return topk_out0, topk_out1

    dim = new_kwargs["dim"]
    if network.has_implicit_batch_dimension:
-        assert (
-            dim != 0
-        ), "can't reduce on dim == 0 when network has implicit batch dimension"
+        assert dim != 0, "can't reduce on dim == 0 when network has implicit batch dimension"
        # we remove the first dim in the shape tuple when it is implicit
        dim -= 1
    input_val = topk_out0
    shape = input_val.shape

@@ -1355,52 +1239,44 @@
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return add_acc_ops_full_reduce(
-        network, target, args, kwargs, name, trt.ReduceOperation.MAX
-    )
+    return add_acc_ops_full_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MAX)


@tensorrt_converter(acc_ops.min_full_reduce, no_implicit_batch_dim=True)
def acc_ops_min_full_reduce(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return add_acc_ops_full_reduce(
-        network, target, args, kwargs, name, trt.ReduceOperation.MIN
-    )
+    return add_acc_ops_full_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MIN)


@tensorrt_converter(acc_ops.max_dim_reduce)
def acc_ops_max_dim_reduce(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return add_acc_ops_dim_reduce(
-        network, target, args, kwargs, name, trt.ReduceOperation.MAX
-    )
+    return add_acc_ops_dim_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MAX)


@tensorrt_converter(acc_ops.min_dim_reduce)
def acc_ops_min_dim_reduce(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return add_acc_ops_dim_reduce(
-        network, target, args, kwargs, name, trt.ReduceOperation.MIN
-    )
+    return add_acc_ops_dim_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MIN)


@tensorrt_converter(acc_ops.maximum)
def acc_ops_maximum(
    network: TRTNetwork,
@@ -1503,32 +1379,24 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `logical_and` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `logical_and` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]
    # we only support the case where both inputs are bool
    if target == acc_ops.bitwise_and:

        def check_is_bool(input_t):
            if isinstance(input_t, TRTTensor):
-                assert (
-                    input_t.dtype == trt.bool
-                ), "We currently do not support input is non-bool"
+                assert input_t.dtype == trt.bool, "We currently only support bool inputs"
            elif isinstance(input_t, torch.Tensor):
-                assert (
-                    input_t.dtype == torch.bool
-                ), "We currently do not support input is non-bool"
+                assert input_t.dtype == torch.bool, "We currently only support bool inputs"
            else:
-                assert isinstance(
-                    input_t.bool
-                ), "We currently do not support input is non-bool"
+                assert isinstance(input_t, bool), "We currently only support bool inputs"

        check_is_bool(input_t)
        check_is_bool(other_t)

    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
@@ -1536,13 +1404,11 @@

    if input_t.dtype != trt.bool:
        input_t = type_cast(network, target, f"{name}_input", input_t, trt.bool)
    if other_t.dtype != trt.bool:
        other_t = type_cast(network, target, f"{name}_other", other_t, trt.bool)
-    return add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.AND, target, name
-    )
+    return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.AND, target, name)


@tensorrt_converter(acc_ops.ne, no_implicit_batch_dim=True)
def acc_ops_ne(
    network: TRTNetwork,
@@ -1550,24 +1416,20 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `ne` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `ne` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]

    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")

    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
-    eq_t = add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name
-    )
+    eq_t = add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name)

    return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name)
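
The converter builds `ne` as NOT(EQUAL) rather than using a dedicated layer; a quick eager-mode check of that identity:

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 0, 3])
assert torch.equal(torch.ne(a, b), torch.logical_not(torch.eq(a, b)))
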


@tensorrt_converter(acc_ops.eq, no_implicit_batch_dim=True)
@@ -1577,24 +1439,20 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `eq` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `eq` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]

    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")

    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
-    return add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name
-    )
+    return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name)


@tensorrt_converter(acc_ops.gt, no_implicit_batch_dim=True)
def acc_ops_gt(
    network: TRTNetwork,
@@ -1602,24 +1460,20 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `gt` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `gt` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]

    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")

    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
-    return add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name
-    )
+    return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name)


@tensorrt_converter(acc_ops.lt, no_implicit_batch_dim=True)
def acc_ops_lt(
    network: TRTNetwork,
@@ -1627,24 +1481,20 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `le` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `le` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]

    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")

    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
-    return add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name
-    )
+    return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name)


@tensorrt_converter(acc_ops.logical_or, no_implicit_batch_dim=True)
def acc_ops_logical_or(
    network: TRTNetwork,
@@ -1652,13 +1502,11 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `logical_or` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `logical_or` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]
    if isinstance(other_t, (torch.Tensor, bool)):
        if isinstance(other_t, bool):
@@ -1675,13 +1523,11 @@
        layer_o = network.add_identity(other_t)
        layer_o.set_output_type(0, trt.bool)
        set_layer_name(layer_o, target, f"{name}_other_dtype_change")
        other_t = layer_o.get_output(0)

-    return add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.OR, target, name
-    )
+    return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.OR, target, name)


@tensorrt_converter(acc_ops.logical_xor, no_implicit_batch_dim=True)
def acc_ops_logical_xor(
    network: TRTNetwork,
@@ -1689,13 +1535,11 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `logical_xor` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `logical_xor` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]
    if isinstance(other_t, (torch.Tensor, bool)):
        if isinstance(other_t, bool):
@@ -1712,13 +1556,11 @@
        layer_o = network.add_identity(other_t)
        layer_o.set_output_type(0, trt.bool)
        set_layer_name(layer_o, target, f"{name}_other_dtype_change")
        other_t = layer_o.get_output(0)

-    return add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name
-    )
+    return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name)


# T113156424 Have some accuracy problems in hf_T5.
# [TRT] [W] Weights [name=isinf_1_inf_t]: Converted FP32 value in weights (either FP32 infinity or FP32 value outside FP16 range) to corresponding FP16 infinity. If this is not the desired behavior, please modify the weights or retrain with regularization to reduce the magnitude of the weights.
# @tensorrt_converter(acc_ops.isinf)
@@ -1764,26 +1606,19 @@
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    if not isinstance(input_t, TRTTensor):
-        raise RuntimeError(
-            f"isinf received input {input_t} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"isinf received input {input_t} that is not part " "of the TensorRT region!")

    if input_t.dtype in (trt.float32, trt.float16, trt.int32):
-        comp_t = torch.zeros(tuple([*input_t.shape])).to(
-            torch_dtype_from_trt(input_t.dtype)
-        )
+        comp_t = torch.zeros(tuple([*input_t.shape])).to(torch_dtype_from_trt(input_t.dtype))
        comp_t = get_trt_tensor(network, comp_t, f"{name}_comp_t")
        kwargs_new = {"input": input_t, "other": comp_t}
        eq_output = acc_ops_eq(network, target, None, kwargs_new, name + "_eq")
        kwargs_new = {"input": eq_output}
-        not_output = acc_ops_logical_not(
-            network, target, None, kwargs_new, name + "_not"
-        )
+        not_output = acc_ops_logical_not(network, target, None, kwargs_new, name + "_not")
    else:
        not_output = input_t
    # cast bool result to int
    int_output = type_cast(network, target, f"{name}_cast_int", not_output, trt.int32)
    # sum
@@ -1809,13 +1644,11 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # NOTE: TRT doesn't currently implement fmod so we need multiple operations to perform it
-    trunc_div_value = trunc_div(
-        kwargs["input"], kwargs["other"], network, target, name + "_trunc_div"
-    )
+    trunc_div_value = trunc_div(kwargs["input"], kwargs["other"], network, target, name + "_trunc_div")
    prod_value = add_binary_elementwise_layer(
        network,
        trunc_div_value,
        kwargs["other"],
        trt.ElementWiseOperation.PROD,
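
The emulation relies on the identity fmod(a, b) = a - trunc(a / b) * b, composed here from the trunc-div helper and an elementwise multiply (with a subtraction completing the identity). A minimal eager-mode check:

import torch

a = torch.tensor([5.5, -5.5, 7.0])
b = torch.tensor([2.0, 2.0, -3.0])
assert torch.allclose(a - torch.trunc(a / b) * b, torch.fmod(a, b))
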
@@ -1907,14 +1740,11 @@
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_trt = kwargs["input"]
    if not isinstance(input_trt, TRTTensor):
-        raise RuntimeError(
-            f"Max_pool1d received input {input_trt} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Max_pool1d received input {input_trt} that is not part " "of the TensorRT region!")

    # adds unsqueeze layer -> max pool 2d -> squeeze layer to emulate max pool 1d.
    unsqueeze_layer = network.add_shuffle(input=input_trt)
    unsqueeze_layer.reshape_dims = tuple([*input_trt.shape, 1])
    set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
@@ -1929,25 +1759,16 @@
    ceil_mode = kwargs["ceil_mode"]

    if len(stride) == 0 or stride[0] == None:
        stride = kernel_size

-    if any(
-        [
-            not isinstance(param, int)
-            for param in [kernel_size[0], stride[0], padding[0], dilation[0]]
-        ]
-    ):
-        raise RuntimeError(
-            f"Parameters kernel_size, stride, padding, and dilation should be of type int."
-        )
+    if any(not isinstance(param, int) for param in [kernel_size[0], stride[0], padding[0], dilation[0]]):
+        raise RuntimeError("Parameters kernel_size, stride, padding, and dilation should be of type int.")
    if dilation[0] != 1:
        raise RuntimeError(f"Only support dilation=1 for maxpool, but got {dilation}")

-    max_pooling_layer = network.add_pooling(
-        input=input_trt, type=trt.PoolingType.MAX, window_size=(kernel_size[0], 1)
-    )
+    max_pooling_layer = network.add_pooling(input=input_trt, type=trt.PoolingType.MAX, window_size=(kernel_size[0], 1))
    max_pooling_layer.stride_nd = stride + (1,)
    max_pooling_layer.padding_nd = padding + (0,)
    set_layer_name(max_pooling_layer, target, name)
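
The unsqueeze -> max_pool2d -> squeeze trick above emulates 1d pooling with a 2d layer whose trailing window, stride, and padding are all size 1. A minimal eager-mode sketch of the same equivalence:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8)     # (N, C, L)
x2d = x.unsqueeze(-1)        # (N, C, L, 1): append a width-1 dimension
y2d = F.max_pool2d(x2d, kernel_size=(2, 1), stride=(2, 1))
assert torch.equal(y2d.squeeze(-1), F.max_pool1d(x, kernel_size=2, stride=2))
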

    if ceil_mode:
@@ -1969,14 +1790,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"MaxPool2d received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"MaxPool2d received input {input_val} that is not part " "of the TensorRT region!")
    extend_len = 2 if target == acc_ops.max_pool2d else 3
    kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], extend_len)
    stride = extend_attr_to_tuple(kwargs["stride"], extend_len)
    padding = extend_attr_to_tuple(kwargs["padding"], extend_len)
    dilation = extend_attr_to_tuple(kwargs["dilation"], extend_len)
@@ -1985,17 +1803,13 @@
    if len(stride) == 0 or stride[0] == None:
        stride = kernel_size

    ones = (1,) * extend_len
    if dilation != ones:
-        raise RuntimeError(
-            f"Only support dilation=(1, 1) for maxpool, but got {dilation}"
-        )
-
-    layer = network.add_pooling_nd(
-        input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size
-    )
+        raise RuntimeError(f"Only support dilation=(1, 1) for maxpool, but got {dilation}")
+
+    layer = network.add_pooling_nd(input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size)
    layer.stride_nd = stride
    layer.padding_nd = padding
    set_layer_name(layer, target, name)

    if ceil_mode:
@@ -2013,23 +1827,18 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"squeeze received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"squeeze received input {input_val} that is not part " "of the TensorRT region!")

    dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
    # dim, which is a very rare case. For now we just claim not supporting dim=None.
    assert dim is not None, "We don't support dim=None right now for squeeze."

-    dim = get_positive_dim(
-        dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
-    )
+    dim = get_positive_dim(dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0))
    if network.has_implicit_batch_dimension:
        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
        dim -= 1

    assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
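
get_positive_dim normalizes a possibly negative axis against the rank (plus one when the batch dimension is implicit). A minimal sketch of that normalization, using a hypothetical stand-in and the usual Python semantics for negative axes:

def positive_dim(dim, rank):
    # hypothetical stand-in for get_positive_dim
    return dim + rank if dim < 0 else dim

assert positive_dim(-1, 4) == 3 and positive_dim(2, 4) == 2
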
@@ -2176,35 +1985,26 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"unsqueeze received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"unsqueeze received input {input_val} that is not part " "of the TensorRT region!")

    dim = cast(int, kwargs["dim"])
    input_shape = input_val.shape
-    input_shape_size = (
-        len(input_val.shape) + 1
-        if network.has_implicit_batch_dimension
-        else len(input_val.shape)
-    )
+    input_shape_size = len(input_val.shape) + 1 if network.has_implicit_batch_dimension else len(input_val.shape)
    dim = get_positive_dim(dim, input_shape_size + 1)

    if network.has_implicit_batch_dimension:
        assert dim != 0
        dim -= 1

    assert (
        len(get_dynamic_dims(input_val.shape)) <= 1
    ), "Currently we don't support unsqueeze with more than one dynamic dims."
    layer = network.add_shuffle(input_val)
-    layer.reshape_dims = (
-        tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
-    )
+    layer.reshape_dims = tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
    set_layer_name(layer, target, name)
    return layer.get_output(0)


@tensorrt_converter(acc_ops.topk)
@@ -2216,14 +2016,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"topk received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"topk received input {input_val} that is not part " "of the TensorRT region!")

    if kwargs["sorted"] and kwargs["k"] != 1:
        raise RuntimeError("Currently we don't support sorted=True in topk.")

    if not network.has_implicit_batch_dimension and len(input_val.shape) <= 1:
@@ -2253,40 +2050,28 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"AdaptiveAvgPool2d received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"AdaptiveAvgPool2d received input {input_val} that is not part " "of the TensorRT region!")

    extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3
    assert all(
        input_val.shape[-(i + 1)] != -1 for i in range(extend_len)
    ), "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."

-    output_size = cast(
-        Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len)
-    )
+    output_size = cast(Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len))
    for input_dim, output_dim in zip(input_val.shape[-extend_len:], output_size):
        if input_dim % output_dim != 0:
            raise RuntimeError(
                "For AdaptiveAvgPool, input dim has to be integer multiple of output dim."
                f"Got input dim {input_dim}, output dim {output_dim}"
            )

-    stride = tuple(
-        input_val.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
-    )
-    kernel_size = tuple(
-        input_val.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
-        for i in range(extend_len)
-    )
-    layer = network.add_pooling_nd(
-        input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
-    )
+    stride = tuple(input_val.shape[-extend_len + i] // output_size[i] for i in range(extend_len))
+    kernel_size = tuple(input_val.shape[-extend_len + i] - (output_size[i] - 1) * stride[i] for i in range(extend_len))
+    layer = network.add_pooling_nd(input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
    layer.stride_nd = stride
    set_layer_name(layer, target, name)

    return layer.get_output(0)
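
Because each input dim must be an integer multiple of its output dim, adaptive average pooling reduces to a plain pooling layer with stride = in // out and kernel = in - (out - 1) * stride, as computed above. A minimal eager-mode check of those formulas:

import torch
import torch.nn.functional as F

in_h, in_w, out_h, out_w = 8, 12, 4, 3
stride = (in_h // out_h, in_w // out_w)                                    # (2, 4)
kernel = (in_h - (out_h - 1) * stride[0], in_w - (out_w - 1) * stride[1])  # (2, 4)
x = torch.randn(1, 1, in_h, in_w)
assert torch.allclose(F.avg_pool2d(x, kernel, stride), F.adaptive_avg_pool2d(x, (out_h, out_w)))
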

@@ -2300,14 +2085,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"AvgPool1d received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"AvgPool1d received input {input_val} that is not part " "of the TensorRT region!")

    kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], 1)
    stride = extend_attr_to_tuple(kwargs["stride"], 1)
    padding = extend_attr_to_tuple(kwargs["padding"], 1)
    ceil_mode = kwargs["ceil_mode"]
@@ -2319,13 +2101,11 @@
    shuffle_layer = network.add_shuffle(input_val)
    shuffle_layer.reshape_dims = tuple(input_val.shape) + (1,)
    set_layer_name(shuffle_layer, target, name + "_shuffle1")
    shuffle_out = shuffle_layer.get_output(0)

-    layer = network.add_pooling_nd(
-        input=shuffle_out, type=trt.PoolingType.AVERAGE, window_size=(kernel_size[0], 1)
-    )
+    layer = network.add_pooling_nd(input=shuffle_out, type=trt.PoolingType.AVERAGE, window_size=(kernel_size[0], 1))

    layer.stride_nd = stride + (1,)
    layer.padding_nd = padding + (0,)
    layer.average_count_excludes_padding = not count_include_pad
    set_layer_name(layer, target, name)
@@ -2349,14 +2129,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"AvgPool2d received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"AvgPool2d received input {input_val} that is not part " "of the TensorRT region!")

    kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], 2)
    stride = extend_attr_to_tuple(kwargs["stride"], 2)
    padding = extend_attr_to_tuple(kwargs["padding"], 2)
    ceil_mode = kwargs["ceil_mode"]
@@ -2367,13 +2144,11 @@
        stride = kernel_size

    if divisor_override:
        raise RuntimeError("TensorRT does not support divisor_override.")

-    layer = network.add_pooling(
-        input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
-    )
+    layer = network.add_pooling(input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
    layer.stride = stride
    layer.padding = padding
    layer.average_count_excludes_padding = not count_include_pad
    set_layer_name(layer, target, name)

@@ -2433,23 +2208,18 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"slice_tensor received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"slice_tensor received input {input_val} that is not part " "of the TensorRT region!")

    ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
    dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
    dynamic_shape = has_dynamic_shape(input_val.shape)
    if network.has_implicit_batch_dimension:
        if dim == 0:
-            raise RuntimeError(
-                f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
-            )
+            raise RuntimeError(f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!")
        dim = dim - 1
    else:
        if dynamic_shape:
            # Check whether slice target dim is dynamic shape dim
            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
@@ -2463,13 +2233,11 @@
    stride[dim] = step_int
    output_shape = list(input_val.shape)
    output_shape[dim] = (stop_int - start_int) // step_int

    if dynamic_shape:
-        output_shape = get_shape_with_dynamic_shape(
-            network, output_shape, input_val, target, name
-        )
+        output_shape = get_shape_with_dynamic_shape(network, output_shape, input_val, target, name)
    layer = network.add_slice(
        input_val,
        start=start,
        shape=[] if dynamic_shape else output_shape,
        stride=stride,
@@ -2502,13 +2270,11 @@
    shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]

    inshape = tuple(input_val.shape)
    shape = tuple(shape)
    start = tuple([0] * ranks)
-    stride = tuple(
-        [int(i == o) for i, o in zip(inshape, shape)]
-    )  # stride == 1 if dimensions match, 0 otherwise
+    stride = tuple([int(i == o) for i, o in zip(inshape, shape)])  # stride == 1 if dimensions match, 0 otherwise
    layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
    set_layer_name(layer, target, name)
    return layer.get_output(0)
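
The stride-0 slice mirrors how expand works as a view in PyTorch: a broadcast dimension repeats its single element by giving it stride 0. A quick check:

import torch

x = torch.randn(1, 3)
expanded = x.expand(4, 3)
assert expanded.stride() == (0, 1)   # dim 0 is broadcast, dim 1 stays contiguous
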


@@ -2615,13 +2381,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    mask_t = kwargs["mask"]
    value_t = kwargs["value"]
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "We don't support masked_fill with implicit batch dimension due to select layer!"
-        )
+        raise RuntimeError("We don't support masked_fill with implicit batch dimension due to select layer!")

    shape = list(input_t.shape)
    mask_shape = list(mask_t.shape)

    assert type(value_t) in (
@@ -2674,14 +2438,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"split received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"split received input {input_val} that is not part " "of the TensorRT region!")

    dim = cast(int, kwargs["dim"])
    dynamic_shape = has_dynamic_shape(input_val.shape)
    if network.has_implicit_batch_dimension:
        assert dim != 0, "Can't split on batch dim when it's implicit!"
@@ -2695,28 +2456,22 @@
    start = [0] * len(input_val.shape)
    stride = [1] * len(start)
    offset = 0
    num_splits = (input_val.shape[dim] + split_size - 1) // split_size
    if num_splits < 1:
-        raise RuntimeError(
-            f"Invalid split: {input_val.shape[dim]} with split_size={split_size}"
-        )
+        raise RuntimeError(f"Invalid split: {input_val.shape[dim]} with split_size={split_size}")

    max_offset = input_val.shape[dim]
    # add slice layers
    output = []
    for i in range(num_splits):
        shape = list(input_val.shape)
        shape[dim] = min(split_size, cast(int, max_offset - offset))
        start[dim] = offset
        if dynamic_shape:
-            shape = get_shape_with_dynamic_shape(
-                network, shape, input_val, target, f"{name}_shape_{i}"
-            )
-        layer = network.add_slice(
-            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
-        )
+            shape = get_shape_with_dynamic_shape(network, shape, input_val, target, f"{name}_shape_{i}")
+        layer = network.add_slice(input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride)
        if dynamic_shape:
            layer.set_input(2, shape)
        offset += split_size
        set_layer_name(layer, target, f"{name}_{i}")
        output.append(layer.get_output(0))
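
num_splits is a ceiling division, so the last chunk may be shorter, matching torch.split semantics; a quick check:

import torch

chunks = torch.split(torch.arange(10), 3)   # ceil(10 / 3) = 4 chunks
assert [len(c) for c in chunks] == [3, 3, 3, 1]
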
@@ -2732,19 +2487,15 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Linear received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Linear received input {input_val} that is not part " "of the TensorRT region!")

    dynamic_dims = get_dynamic_dims(input_val.shape)
    assert len(dynamic_dims) < 2 and input_val.shape[-1] != -1, (
-        "Currently we only support one dynmaic "
-        "dim for linear and it can't be the last dim."
+        "Currently we only support one dynmaic " "dim for linear and it can't be the last dim."
    )

    if isinstance(kwargs["weight"], torch.Tensor):
        weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
        weight_op = trt.MatrixOperation.NONE
@@ -2760,13 +2511,11 @@
        preset_diff -= 1
        input_op = trt.MatrixOperation.VECTOR
    else:
        input_op = trt.MatrixOperation.NONE

-    input_val, weight = broadcast(
-        network, input_val, weight, f"{name}_input", f"{name}_weight", preset_diff
-    )
+    input_val, weight = broadcast(network, input_val, weight, f"{name}_input", f"{name}_weight", preset_diff)
    matmul_layer = network.add_matrix_multiply(input_val, input_op, weight, weight_op)
    set_layer_name(matmul_layer, target, f"{name}_matmul")
    res = matmul_layer.get_output(0)

    if kwargs["bias"] is not None:
@@ -2782,16 +2531,11 @@
    return res
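
The converter lowers linear to a matrix multiply against the transposed weight followed by a bias add; the eager-mode identity it reproduces:

import torch
import torch.nn.functional as F

x, w, b = torch.randn(2, 4), torch.randn(3, 4), torch.randn(3)
assert torch.allclose(F.linear(x, w, b), x @ w.t() + b)
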


def add_clamp(network, input, val, op):
    acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
-    acc_ops_clamp_tensor = (
-        val
-        * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
-        .cpu()
-        .numpy()
-    )
+    acc_ops_clamp_tensor = val * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)).cpu().numpy()
    acc_ops_clamp_trt = network.add_constant(acc_ops_clamp_shape, acc_ops_clamp_tensor)
    layer = network.add_elementwise(input, acc_ops_clamp_trt.get_output(0), op)

    return layer

@@ -2807,25 +2551,18 @@
    input_val = kwargs["input"]
    min_val = kwargs["min"]
    max_val = kwargs["max"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Clamp received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Clamp received input {input_val} that is not part " "of the TensorRT region!")

    if min_val is not None:
-        clamp_min_layer = add_clamp(
-            network, input_val, min_val, trt.ElementWiseOperation.MAX
-        )
+        clamp_min_layer = add_clamp(network, input_val, min_val, trt.ElementWiseOperation.MAX)
        set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
        input_val = clamp_min_layer.get_output(0)
    if max_val is not None:
-        clamp_max_layer = add_clamp(
-            network, input_val, max_val, trt.ElementWiseOperation.MIN
-        )
+        clamp_max_layer = add_clamp(network, input_val, max_val, trt.ElementWiseOperation.MIN)
        set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
        input_val = clamp_max_layer.get_output(0)

    return input_val
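
add_clamp decomposes clamp into an elementwise MAX against min_val followed by an elementwise MIN against max_val, each broadcast from a constant; the equivalent eager-mode composition:

import torch

x = torch.tensor([-2.0, 0.5, 3.0])
lo, hi = torch.tensor(-1.0), torch.tensor(1.0)
assert torch.equal(torch.minimum(torch.maximum(x, lo), hi), torch.clamp(x, min=-1.0, max=1.0))
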

@@ -2883,30 +2620,22 @@

    def slice_to_trt_params(py_slice, dim_size):
        """
        Convert python slice to TensorRT slice layer parameters.
        """
-        start = (
-            get_positive_dim(py_slice.start, dim_size) if py_slice.start != None else 0
-        )
+        start = get_positive_dim(py_slice.start, dim_size) if py_slice.start != None else 0
        stride = py_slice.step if py_slice.step != None else 1
-        stop = (
-            get_positive_dim(py_slice.stop, dim_size)
-            if py_slice.stop != None
-            else dim_size
-        )
+        stop = get_positive_dim(py_slice.stop, dim_size) if py_slice.stop != None else dim_size
        size = math.ceil((stop - start) * 1.0 / stride)
        return start, size, stride
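
In (start, size, stride) form, size = ceil((stop - start) / stride). A minimal sketch of the same conversion in plain Python, using a hypothetical helper and ignoring negative indices (which the converter handles via get_positive_dim):

import math

def slice_params(py_slice, dim_size):
    # hypothetical mirror of slice_to_trt_params, non-negative indices only
    start = py_slice.start if py_slice.start is not None else 0
    stride = py_slice.step if py_slice.step is not None else 1
    stop = py_slice.stop if py_slice.stop is not None else dim_size
    return start, math.ceil((stop - start) / stride), stride

assert slice_params(slice(1, 8, 3), 10) == (1, 3, 3)   # picks elements 1, 4, 7
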

    if network.has_implicit_batch_dimension:
        # Raise an error if it's trying to subscript batch dimension unless it's
        # slice(None, None, None).
        batch_subscript = slices[0]
        if batch_subscript not in [slice(None, None, None), slice(0, None, None)]:
-            raise RuntimeError(
-                f"{name}: Can't subscript batch dimension when it's implicit. Got {slices}"
-            )
+            raise RuntimeError(f"{name}: Can't subscript batch dimension when it's implicit. Got {slices}")

        # Remove batch_dim subscript
        slices = slices[1:]

    # Replace ellipsis with expanded slices.
@@ -2995,13 +2724,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    tensors = kwargs["tensors"]
    dim = kwargs["dim"]

    if any(not isinstance(t, TRTTensor) for t in tensors):  # type: ignore[union-attr]
-        raise RuntimeError(
-            f"cat received inputs {tensors} that is not part " "of the TensorRT region!"
-        )
+        raise RuntimeError(f"cat received inputs {tensors} that is not part " "of the TensorRT region!")
    layer = network.add_concatenation(inputs=tensors)
    if dim < 0:
        if network.has_implicit_batch_dimension:
            dim = len(tensors[0].shape) + 1 + dim
        else:
@@ -3023,13 +2750,11 @@
    input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
    other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other")

    for i in [input_val, other_val]:
        if not isinstance(i, TRTTensor):
-            raise RuntimeError(
-                f"matmul received input {i} that is not part of the TensorRT region!"
-            )
+            raise RuntimeError(f"matmul received input {i} that is not part of the TensorRT region!")

    input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
    preset_diff = 0

    if len(input_val.shape) == 1:
@@ -3038,16 +2763,12 @@

    if len(other_val.shape) == 1:
        preset_diff += 1
        other_matrix_op = trt.MatrixOperation.VECTOR

-    input_val, other_val = broadcast(
-        network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff
-    )
-    layer = network.add_matrix_multiply(
-        input_val, input_matrix_op, other_val, other_matrix_op
-    )
+    input_val, other_val = broadcast(network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff)
+    layer = network.add_matrix_multiply(input_val, input_matrix_op, other_val, other_matrix_op)
    set_layer_name(layer, target, name)
    return layer.get_output(0)


@tensorrt_converter(acc_ops.hardsigmoid)
@@ -3059,14 +2780,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Hard sigmoid received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Hard sigmoid received input {input_val} that is not part " "of the TensorRT region!")

    return add_activation_layer(
        network,
        input_val,
        trt.ActivationType.HARD_SIGMOID,
@@ -3086,18 +2804,13 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Sigmoid received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    return add_activation_layer(
-        network, input_val, trt.ActivationType.SIGMOID, target, name
-    )
+        raise RuntimeError(f"Sigmoid received input {input_val} that is not part " "of the TensorRT region!")
+
+    return add_activation_layer(network, input_val, trt.ActivationType.SIGMOID, target, name)


@tensorrt_converter(acc_ops.permute)
def acc_ops_permute(
    network: TRTNetwork,
@@ -3113,14 +2826,11 @@
    else:
        index = kwargs["permutation"]
    permutation = [get_positive_dim(i, ranks) for i in cast(Sequence[int], index)]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"permute received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"permute received input {input_val} that is not part " "of the TensorRT region!")

    if network.has_implicit_batch_dimension:
        assert permutation[0] == 0, "Can't permute batch dimension when it's implicit."
        permutation = [i - 1 for i in permutation[1:]]

@@ -3139,14 +2849,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"{name} received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"{name} received input {input_val} that is not part " "of the TensorRT region!")

    qparams = kwargs["acc_out_ty"].qparams  # type: ignore[misc]
    q_scale = qparams["scale"]
    q_zero_point = qparams["zero_point"]
    dtype = kwargs["acc_out_ty"].dtype  # type: ignore[misc]
@@ -3157,13 +2864,11 @@
        )

    if q_zero_point != 0:
        raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")

-    scale_layer = network.add_constant(
-        (1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32))
-    )
+    scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32)))
    scale_layer.name = input_val.name + ".per_tensor_quant.scale"
    scale = scale_layer.get_output(0)
    # assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
    # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
    layer = network.add_quantize(input=input_val, scale=scale)
@@ -3181,14 +2886,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"{name} received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"{name} received input {input_val} that is not part " "of the TensorRT region!")

    qparams = kwargs["acc_out_ty"].qparams  # type: ignore[misc]
    q_per_channel_scales = qparams["scale"]
    q_per_channel_zero_points = qparams["zero_point"]
    q_per_channel_axis = qparams["axis"]
@@ -3201,17 +2903,13 @@

    # Make sure zero_points are all 0 because only symmetric quantization
    # is supported in TensorRT
    if not torch.equal(
        q_per_channel_zero_points,
-        torch.zeros(
-            q_per_channel_zero_points.shape, dtype=q_per_channel_zero_points.dtype
-        ),
+        torch.zeros(q_per_channel_zero_points.shape, dtype=q_per_channel_zero_points.dtype),
    ):
-        raise RuntimeError(
-            f"Only support zero_point == 0, get {q_per_channel_zero_points}"
-        )
+        raise RuntimeError(f"Only support zero_point == 0, get {q_per_channel_zero_points}")

    if not torch.all(torch.ge(q_per_channel_scales, 0)):
        raise RuntimeError(f"All scale values must be >= 0, get {q_per_channel_scales}")

    scale_layer = network.add_constant(
@@ -3238,14 +2936,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
    input_val_tensor_meta = kwargs["_itensor_to_tensor_meta"][input_val]  # type: ignore[index]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"{name} received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"{name} received input {input_val} that is not part " "of the TensorRT region!")

    qparams = input_val_tensor_meta.qparams  # type: ignore[misc]
    qscheme = qparams["qscheme"]
    if qscheme == torch.per_tensor_affine:
        q_scale = qparams["scale"]
@@ -3256,30 +2951,25 @@
            raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")
    elif qscheme == torch.per_channel_affine:
        q_scale = qparams["scale"]
        q_zero_point = qparams["zero_point"]
        q_axis = qparams["axis"]
-        assert isinstance(
-            q_scale, immutable_list
-        ), "expected q_scale to be immutable_list got {}".format(type(q_scale))
+        assert isinstance(q_scale, immutable_list), "expected q_scale to be immutable_list got {}".format(type(q_scale))
        scale_shape = (len(q_scale),)
        if any(x != 0 for x in q_zero_point):
            raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")
    else:
        raise RuntimeError("Unsupported qscheme in dequantize: {qscheme}")

    dtype = input_val_tensor_meta.dtype  # type: ignore[misc]

    if dtype not in (torch.quint8, torch.qint8, torch.qint32):
        raise RuntimeError(
-            "Only support (torch.quint8, torch.qint8, torch.qint32) "
-            f"quantized type in dequantize, get {dtype}."
+            "Only support (torch.quint8, torch.qint8, torch.qint32) " f"quantized type in dequantize, get {dtype}."
        )

-    scale_layer = network.add_constant(
-        scale_shape, trt.Weights(np.ascontiguousarray(q_scale, dtype=np.float32))
-    )
+    scale_layer = network.add_constant(scale_shape, trt.Weights(np.ascontiguousarray(q_scale, dtype=np.float32)))
    scale_layer.name = input_val.name + ".dequant.scale"
    scale = scale_layer.get_output(0)
    # assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in "
    # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
    layer = network.add_dequantize(input=input_val, scale=scale)
@@ -3296,24 +2986,17 @@
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"GELU received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"GELU received input {input_val} that is not part " "of the TensorRT region!")
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "GeLU converter currently doesn't support implicit batch dimension"
-        )
+        raise RuntimeError("GeLU converter currently doesn't support implicit batch dimension")

    plugin_name = "CustomGeluPluginDynamic"
    # type_id 0 for float32, 1 for float16
-    type_id = trt.PluginField(
-        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
-    )
+    type_id = trt.PluginField("type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32)
    field_collection = TRTPluginFieldCollection([type_id])
    plugin_version = "1"

    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)

@@ -3334,14 +3017,11 @@
    chunks = cast(int, kwargs["chunks"])
    dim = cast(int, kwargs["dim"])
    input_dim_size = len(input_val.shape)  # type: ignore[union-attr]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"chunk received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"chunk received input {input_val} that is not part " "of the TensorRT region!")

    dynamic_shape = has_dynamic_shape(input_val.shape)
    if network.has_implicit_batch_dimension:
        input_dim_size += 1
        dim = get_positive_dim(dim, input_dim_size)
@@ -3371,17 +3051,13 @@
    output = []
    for i in range(chunks):
        shape = list(input_val.shape)
        shape[dim] = min(split_size, max_offset - offset)
        if dynamic_shape:
-            shape = get_shape_with_dynamic_shape(
-                network, shape, input_val, target, f"{name}_{i}"
-            )
+            shape = get_shape_with_dynamic_shape(network, shape, input_val, target, f"{name}_{i}")
        start[dim] = offset
-        layer = network.add_slice(
-            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
-        )
+        layer = network.add_slice(input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride)
        if dynamic_shape:
            layer.set_input(2, shape)
        offset += split_size
        set_layer_name(layer, target, f"{name}_{i}")
        output.append(layer.get_output(0))
@@ -3400,18 +3076,13 @@
    dim = cast(int, kwargs["dim"])
    input_shape = input_val.shape  # type: ignore[union-attr]
    input_dim_size = len(input_val.shape)  # type: ignore[union-attr]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"cumsum received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"cumsum received input {input_val} that is not part " "of the TensorRT region!")
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "cumsum converter currently doesn't support implicit batch dimension"
-        )
+        raise RuntimeError("cumsum converter currently doesn't support implicit batch dimension")
    dim = get_positive_dim(dim, input_dim_size)
    loop = network.add_loop()
    trip_limit = None
    if input_shape[dim] > 0:
        axis = torch.tensor(input_shape[dim], dtype=torch.int32)
@@ -3427,13 +3098,11 @@
    loop.add_trip_limit(trip_limit, trt.TripLimit(0))
    iterator = loop.add_iterator(input_val, dim, False)
    data = iterator.get_output(0)
    new_dims = tuple(data.shape)
    zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype))
-    zero_tensor = network.add_constant(
-        zero_tensor.shape, to_numpy(zero_tensor)
-    ).get_output(0)
+    zero_tensor = network.add_constant(zero_tensor.shape, to_numpy(zero_tensor)).get_output(0)

    running_sum = loop.add_recurrence(zero_tensor)
    set_layer_name(running_sum, target, f"{name}_running_sum_1")
    running_sum_tensor = running_sum.get_output(0)
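
The loop carries a running sum along `dim`, seeded with the zero tensor built above; an eager-mode sketch of the same recurrence:

import torch

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
running, out = torch.zeros(x.shape[0]), []
for i in range(x.shape[1]):
    running = running + x[:, i]
    out.append(running)
assert torch.equal(torch.stack(out, dim=1), torch.cumsum(x, dim=1))
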

@@ -3476,14 +3145,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"hardtanh received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"hardtanh received input {input_val} that is not part " "of the TensorRT region!")

    return add_activation_layer(
        network,
        input_val,
        trt.ActivationType.CLIP,
@@ -3507,26 +3173,19 @@
    scale_factor = kwargs["scale_factor"]
    mode = kwargs["mode"]
    align_corners = kwargs["align_corners"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"interpolate received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"interpolate received input {input_val} that is not part " "of the TensorRT region!")

    dim = input_val.shape
    ranks = len(input_val.shape)
    if network.has_implicit_batch_dimension:
-        assert (
-            ranks >= 2 and ranks <= 4
-        ), "Interpolate expects inputs are 3D,4D,5D in shape"
+        assert ranks >= 2 and ranks <= 4, "Interpolate expects 3D, 4D, or 5D inputs"
        ranks = ranks - 1
    else:
-        assert (
-            ranks >= 3 and ranks <= 5
-        ), "Interpolate expects inputs are 3D,4D,5D in shape"
+        assert ranks >= 3 and ranks <= 5, "Interpolate expects 3D, 4D, or 5D inputs"
        ranks = ranks - 2

    layer = network.add_resize(input_val)
    if network.has_implicit_batch_dimension:
        if size != None:
@@ -3555,13 +3214,11 @@
        layer.resize_mode = trt.ResizeMode.LINEAR
    else:
        layer.resize_mode = trt.ResizeMode.NEAREST

    if align_corners != None:
-        layer.coordinate_transformation = (
-            trt.ResizeCoordinateTransformation.ALIGN_CORNERS
-        )
+        layer.coordinate_transformation = trt.ResizeCoordinateTransformation.ALIGN_CORNERS

    set_layer_name(layer, target, name)
    return layer.get_output(0)


@@ -3579,13 +3236,11 @@
    if dtype_val is None:
        dtype_val = input_val.dtype
        dtype_val = torch_dtype_from_trt(dtype_val)

    device_val = kwargs.get("device")
-    assert (
-        device_val == "cuda" or device_val == None
-    ), f"device is not `cuda` but {device_val}"
+    assert device_val == "cuda" or device_val == None, f"device is not `cuda` but {device_val}"

    weight = torch.ones(size_val, dtype=dtype_val)
    return get_trt_tensor(network, weight, f"{name}_weight")


@@ -3603,13 +3258,11 @@
    if dtype_val is None:
        dtype_val = input_val.dtype
        dtype_val = torch_dtype_from_trt(dtype_val)

    device_val = kwargs.get("device")
-    assert (
-        device_val == "cuda" or device_val == None
-    ), f"device is not `cuda` but {device_val}"
+    assert device_val == "cuda" or device_val == None, f"device is not `cuda` but {device_val}"

    weight = torch.zeros(size_val, dtype=dtype_val)
    return get_trt_tensor(network, weight, f"{name}_weight")


@@ -3634,13 +3287,11 @@
        input_val[i] = get_trt_tensor(network, input_source, name + f"_input_source{i}")

    if const_flag:
        for i, input_source in enumerate(input_val):
            if input_source.dtype != trt.float32:
-                input_val[i] = type_cast(
-                    network, target, f"{name}_input_cast{i}", input_source, trt.float32
-                )
+                input_val[i] = type_cast(network, target, f"{name}_input_cast{i}", input_source, trt.float32)
    einsum_layer = network.add_einsum(inputs=input_val, equation=equation)
    return einsum_layer.get_output(0)


@tensorrt_converter(acc_ops.as_strided)
--- py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py	2022-08-12 18:52:26.557706 +0000
@@ -70,13 +70,11 @@
            inputs,
            expected_ops={expected_acc_op},
            test_implicit_batch_dim=(dim != 0),
        )

-    @parameterized.expand(
-        [(f"{acc_ops.prod.__name__}_no_dim_no_keepdim", torch.prod, acc_ops.prod)]
-    )
+    @parameterized.expand([(f"{acc_ops.prod.__name__}_no_dim_no_keepdim", torch.prod, acc_ops.prod)])
    def test_prod_all_dims(
        self,
        test_name,
        op,
        expected_acc_op,
@@ -107,12 +105,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            Prod(), input_specs, expected_ops={acc_ops.prod}
-        )
+        self.run_test_with_dynamic_shape(Prod(), input_specs, expected_ops={acc_ops.prod})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py	2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py	2022-08-12 18:52:26.565275 +0000
@@ -50,16 +50,11 @@
            inputs,
            expected_ops={expected_acc_op},
            test_implicit_batch_dim=(dim != 0),
        )

-    @parameterized.expand(
-        [
-            (f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op)
-            for op, acc_op in reduce_ops
-        ]
-    )
+    @parameterized.expand([(f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op) for op, acc_op in reduce_ops])
    def test_reduce_all_dims(
        self,
        test_name,
        op,
        expected_acc_op,
@@ -74,16 +69,11 @@
            inputs,
            expected_ops={expected_acc_op},
            test_implicit_batch_dim=False,
        )

-    @parameterized.expand(
-        [
-            (f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op)
-            for op, acc_op in reduce_ops
-        ]
-    )
+    @parameterized.expand([(f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op) for op, acc_op in reduce_ops])
    def test_reduce_all_dims_with_dynamic_shape_four_dimensions(
        self,
        test_name,
        op,
        expected_acc_op,
@@ -97,12 +87,10 @@
                shape=(-1, -1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]
-        self.run_test_with_dynamic_shape(
-            Reduce(), input_specs, expected_ops={expected_acc_op}
-        )
+        self.run_test_with_dynamic_shape(Reduce(), input_specs, expected_ops={expected_acc_op})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py	2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py	2022-08-12 18:52:26.764386 +0000
@@ -26,14 +26,11 @@
        inputs = [torch.randn(*input_shape)]
        self.run_test(
            Tile(dims),
            inputs,
            expected_ops={acc_ops.tile},
-            test_implicit_batch_dim=(
-                len(input_shape) > len(dims)
-                or (len(input_shape) == len(dims) and dims[0] == 1)
-            ),
+            test_implicit_batch_dim=(len(input_shape) > len(dims) or (len(input_shape) == len(dims) and dims[0] == 1)),
        )

    @parameterized.expand(
        [
            ("same_num_dims", (-1, 2, 3), (1, 2, 2)),
@@ -62,13 +59,11 @@
                        tuple(i if i != -1 else 3 for i in shape),
                    )
                ],
            ),
        ]
-        self.run_test_with_dynamic_shape(
-            Tile(dims), input_specs, expected_ops={acc_ops.tile}
-        )
+        self.run_test_with_dynamic_shape(Tile(dims), input_specs, expected_ops={acc_ops.tile})

    @parameterized.expand(
        [
            ("all_dynamic_dim", (-1, -1), (1, 2, 2, 1)),
        ]
@@ -88,13 +83,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            Tile(dims), input_specs, expected_ops={acc_ops.tile}
-        )
+        self.run_test_with_dynamic_shape(Tile(dims), input_specs, expected_ops={acc_ops.tile})

    def test_tile_non_int_dims(self):
        class Tile(nn.Module):
            def __init__(self):
                super().__init__()
@@ -103,13 +96,11 @@
                y = y * 2
                return torch.tile(x, (1, y.shape[1], y.shape[1]))

        inputs = [torch.randn(2, 2, 3), torch.randn(2, 2, 3)]
        batch_size_range = (1, 2, 3)
-        input_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
-            inputs, batch_size_range
-        )
+        input_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(inputs, batch_size_range)
        self.run_test_with_dynamic_shape(
            Tile(),
            input_specs,
            expected_ops={acc_ops.tile},
        )
@@ -134,12 +125,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            Tile(), input_specs, expected_ops={acc_ops.tile}
-        )
+        self.run_test_with_dynamic_shape(Tile(), input_specs, expected_ops={acc_ops.tile})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py	2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py	2022-08-12 18:52:26.814449 +0000
@@ -51,13 +51,11 @@

        input = torch.randn(2, 2).to(torch.float16)
        inputs = [
            input,
        ]
-        self.run_test(
-            To(), inputs, expected_ops={acc_ops.to_dtype}, test_implicit_batch_dim=False
-        )
+        self.run_test(To(), inputs, expected_ops={acc_ops.to_dtype}, test_implicit_batch_dim=False)

    def test_cuda_fp16(self):
        class To(torch.nn.Module):
            def forward(self, x):
                return x.to(torch.device("cuda:0"), torch.float16)
@@ -106,13 +104,11 @@
                dtype=torch.float16,
                shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add}
-        )
+        self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add})

    def test_device(self):
        class To(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -152,13 +148,11 @@
                dtype=torch.float16,
                shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add}
-        )
+        self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add})

    def test_device_fp16(self):
        class To(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -244,13 +238,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            To(), input_specs, expected_ops={acc_ops.to_dtype}
-        )
+        self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype})

    # Half is not suitable for dynamic shape
    # Error: assert engine

    # tensor.half()
@@ -307,12 +299,10 @@
                dtype=torch.int,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            To(), input_specs, expected_ops={acc_ops.to_dtype}
-        )
+        self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py	2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py	2022-08-12 18:52:26.820155 +0000
@@ -24,13 +24,11 @@
                self.dim = dim
                self.largest = largest

            def forward(self, x):
                if self.dim is not None:
-                    out = torch.topk(
-                        x, k=self.k, dim=self.dim, largest=self.largest, sorted=False
-                    )
+                    out = torch.topk(x, k=self.k, dim=self.dim, largest=self.largest, sorted=False)
                else:
                    out = torch.topk(x, k=self.k, largest=self.largest, sorted=False)
                return out[0], out[1]

        inputs = [torch.randn(1, 2, 3, 4)]
@@ -58,13 +56,11 @@
                self.dim = dim
                self.largest = largest

            def forward(self, x):
                if self.dim is not None:
-                    out = torch.topk(
-                        x, k=self.k, dim=self.dim, largest=self.largest, sorted=False
-                    )
+                    out = torch.topk(x, k=self.k, dim=self.dim, largest=self.largest, sorted=False)
                else:
                    out = torch.topk(x, k=self.k, largest=self.largest, sorted=False)
                return out[0], out[1]

        input_specs = [
@@ -73,12 +69,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TopK(k, dim), input_specs, expected_ops={acc_ops.topk}
-        )
+        self.run_test_with_dynamic_shape(TopK(k, dim), input_specs, expected_ops={acc_ops.topk})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py	2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py	2022-08-12 18:52:26.956415 +0000
@@ -62,13 +62,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(orig_op), input_specs, expected_ops={expected_op}
-        )
+        self.run_test_with_dynamic_shape(TestModule(orig_op), input_specs, expected_ops={expected_op})


class TestUnaryOpNotConverters(AccTestCase):
    @parameterized.expand(
        [
@@ -87,13 +85,11 @@
                x = self.orig_op(x)
                return self.orig_op(x)

        m = TestModule(orig_op)
        inputs = [torch.randn(2, 2, 3).to(input_dtype)]
-        self.run_test(
-            m, inputs, expected_ops={expected_op}, test_implicit_batch_dim=False
-        )
+        self.run_test(m, inputs, expected_ops={expected_op}, test_implicit_batch_dim=False)


class TestUnaryOpNotConvertersWithDynamicShapeFourDimensions(AccTestCase):
    @parameterized.expand(
        [
@@ -118,13 +114,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(orig_op), input_specs, expected_ops={expected_op}
-        )
+        self.run_test_with_dynamic_shape(TestModule(orig_op), input_specs, expected_ops={expected_op})


class TestUnaryRSQRTConverters(AccTestCase):
    def test_unary_ops(self):
        class TestModule(nn.Module):
@@ -148,12 +142,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py	2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py	2022-08-12 18:52:26.977250 +0000
@@ -35,26 +35,22 @@
            self._validate_spec(spec, tensor)

    def test_from_tensors_with_dynamic_batch_size(self):
        tensors = [torch.randn(1, 2, 3), torch.randn(1, 4)]
        batch_size_range = [2, 3, 4]
-        specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
-            tensors, batch_size_range
-        )
+        specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(tensors, batch_size_range)
        for spec, tensor in zip(specs, tensors):
            self._validate_spec(spec, tensor, dynamic_dims=[0])

            for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]):
                self.assertEqual(batch_size, shape[0])
                self.assertSequenceEqual(tensor.shape[1:], shape[1:])

    def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self):
        tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)]
        batch_size_range = [2, 3, 4]
-        specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
-            tensors, batch_size_range, batch_dims=[0, 1]
-        )
+        specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(tensors, batch_size_range, batch_dims=[0, 1])
        for i, spec_and_tensor in enumerate(zip(specs, tensors)):
            spec, tensor = spec_and_tensor
            self._validate_spec(spec, tensor, dynamic_dims=[i])

            for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]):
@@ -62,13 +58,11 @@
                tensor_shape = list(tensor.shape)
                tensor_shape[i] = batch_size
                self.assertSequenceEqual(tensor_shape, shape)

    def test_generate_input_specs(self):
-        lower_setting = LowerSetting(
-            explicit_batch_dimension=False, max_batch_size=256, opt_profile_replica=2
-        )
+        lower_setting = LowerSetting(explicit_batch_dimension=False, max_batch_size=256, opt_profile_replica=2)

        # Implicit batch dim.
        inputs = [torch.randn(1, 2, 3)]
        specs = generate_input_specs(inputs, lower_setting)
        for spec, tensor in zip(specs, inputs):
--- py/torch_tensorrt/fx/test/quant/test_quant_trt.py	2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/test/quant/test_quant_trt.py	2022-08-12 18:52:27.493933 +0000
@@ -46,13 +46,11 @@
            shape_ranges=shape_ranges,
            has_batch_dim=True,
        )
    ]

-    interp = TRTInterpreter(
-        model, input_specs, explicit_batch_dimension=True, explicit_precision=True
-    )
+    interp = TRTInterpreter(model, input_specs, explicit_batch_dimension=True, explicit_precision=True)
    result = interp.run(lower_precision=LowerPrecision.INT8)
    trt_mod = TRTModule(result.engine, result.input_names, result.output_names)
    return trt_mod


@@ -65,13 +63,11 @@
            ),
            weight=torch.ao.quantization.default_weight_observer,
        )
        self.trt_backend_config_dict = get_tensorrt_backend_config_dict()

-    def _test_quantized_inputs_outputs(
-        self, prepare_custom_config_dict, prepare_count_check, convert_count_check
-    ):
+    def _test_quantized_inputs_outputs(self, prepare_custom_config_dict, prepare_count_check, convert_count_check):
        """
        Test the option to have inputs and outputs of the graph quantized
        """

        class M(torch.nn.Module):
@@ -113,13 +109,11 @@
            # output of ref conv1 and output of ref conv2
            ns.call_function(torch.quantize_per_tensor): 2,
            # input of ref conv1 and input of ref conv2
            ns.call_method("dequantize"): 2,
        }
-        self._test_quantized_inputs_outputs(
-            prepare_custom_config_dict, prepare_count_check, convert_count_check
-        )
+        self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def test_fp32_input_quantized_output(self):
        prepare_custom_config_dict = {"output_quantized_idxs": [0]}
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
@@ -128,13 +122,11 @@
            # input, output of conv1 and output of conv2
            ns.call_function(torch.quantize_per_tensor): 3,
            # input of conv1, conv2
            ns.call_method("dequantize"): 2,
        }
-        self._test_quantized_inputs_outputs(
-            prepare_custom_config_dict, prepare_count_check, convert_count_check
-        )
+        self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def test_quantized_input_fp32_output(self):
        prepare_custom_config_dict = {"input_quantized_idxs": [0]}
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
@@ -143,26 +135,22 @@
            # output of conv1, conv2
            ns.call_function(torch.quantize_per_tensor): 2,
            # input of ref conv1, input of ref conv2, final output
            ns.call_method("dequantize"): 3,
        }
-        self._test_quantized_inputs_outputs(
-            prepare_custom_config_dict, prepare_count_check, convert_count_check
-        )
+        self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def test_fp32_input_fp32_output(self):
        prepare_custom_config_dict = {}
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
        }
        convert_count_check = {
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_method("dequantize"): 3,
        }
-        self._test_quantized_inputs_outputs(
-            prepare_custom_config_dict, prepare_count_check, convert_count_check
-        )
+        self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def _test_standalone_module(
        self,
        interface_config,
        prepare_count_check,
@@ -213,20 +201,14 @@

        data = torch.randn(1, 1, 1, 1)
        # instantiate M and RefM and align the parameters
        original_m = M().eval()
        original_ref_m = RefM().eval()
-        original_ref_m.conv1.weight = torch.nn.Parameter(
-            original_m.conv.weight.detach()
-        )
+        original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
        original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
-        original_ref_m.conv2.weight = torch.nn.Parameter(
-            original_m.standalone.conv.weight.detach()
-        )
-        original_ref_m.conv2.bias = torch.nn.Parameter(
-            original_m.standalone.conv.bias.detach()
-        )
+        original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
+        original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())

        sm_example_inputs = (data,)
        prepare_config = {
            "standalone_module_name": [
                (
@@ -253,20 +235,16 @@
            backend_config=backend_config_dict,
        )
        # calibration
        m(data)
        self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
-        self.checkGraphModuleNodes(
-            m.standalone, expected_node_occurrence=standalone_prepare_count_check
-        )
+        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)

        # check converted/quantized model
        m = convert_to_reference_fx(m, backend_config=backend_config_dict)
        self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
-        self.checkGraphModuleNodes(
-            m.standalone, expected_node_occurrence=standalone_convert_count_check
-        )
+        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
        res = m(data)

        # quantize the reference model
        ref_m = prepare_fx(
            original_ref_m_copy,
@@ -285,17 +263,13 @@
            "output_quantized_idxs": [],  # float output
        }
        interface_config = float_interface_config
        # input and output of first conv, observer for standalone module
        # will be inserted in the standalone module itself
-        prepare_count_check = {
-            ns.call_module(torch.ao.quantization.HistogramObserver): 2
-        }
+        prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 2}
        # for input and output of conv in the standalone module
-        standalone_prepare_count_check = {
-            ns.call_module(torch.ao.quantization.HistogramObserver): 2
-        }
+        standalone_prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 2}
        convert_count_check = {
            # input and output of reference conv
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_module(nnqr.Conv2d): 1,
            ns.call_method("dequantize"): 2,
@@ -351,17 +325,13 @@
            "root_module": torch.nn.Conv2d,
            "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
        }
        custom_backend_config_dict = {"configs": [conv_module_config]}
        # observer for input and output of first conv
-        prepare_count_check = {
-            ns.call_module(torch.ao.quantization.HistogramObserver): 2
-        }
+        prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 2}
        # for output of conv in the standalone module
-        standalone_prepare_count_check = {
-            ns.call_module(torch.ao.quantization.HistogramObserver): 1
-        }
+        standalone_prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 1}
        convert_count_check = {
            # quantizing input/output for reference conv
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_module(nnqr.Conv2d): 1,
            # dequantize the input of reference conv and
@@ -400,13 +370,11 @@
            ),
            weight=torch.ao.quantization.default_weight_observer,
        )
        self.trt_backend_config_dict = get_tensorrt_backend_config_dict()

-    def _test_module(
-        self, m, inputs, shape_ranges, no_prepare=None, no_convert=None, is_qat=False
-    ):
+    def _test_module(self, m, inputs, shape_ranges, no_prepare=None, no_convert=None, is_qat=False):
        """
        Args:
          m: the float module we want to test
          inputs: list of inputs for the module
          shape_ranges: a list of shape_range, where every shape_range is a tuple of
@@ -468,13 +436,11 @@

            def forward(self, x):
                return self.relu(self.conv(x))

        # just testing conv2d since conv1d and conv3d are not supported in fx2trt
-        for dim, has_relu, f_relu, is_qat in itertools.product(
-            [1, 2], [True, False], [True, False], [True, False]
-        ):
+        for dim, has_relu, f_relu, is_qat in itertools.product([1, 2], [True, False], [True, False], [True, False]):
            # when has_relu=False, we have torch.nn.Identity, which would introduce
            # extra quant-dequat pair
            no_convert = {
                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
                ns.call_method("dequantize"): 2 + int(not has_relu),
@@ -510,13 +476,11 @@
                return self.relu(self.linear(x))

        linear_input = torch.rand(8, 5)

        shape_ranges = [((1, 5), (5, 5), (10, 5))]
-        for has_relu, f_relu, is_qat in itertools.product(
-            [True, False], [True, False], [True, False]
-        ):
+        for has_relu, f_relu, is_qat in itertools.product([True, False], [True, False], [True, False]):
            # when has_relu=False, we have torch.nn.Identity, which would introduce
            # extra quant-dequat pair
            no_convert = {
                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
                ns.call_method("dequantize"): 2 + int(not has_relu),
@@ -662,13 +626,11 @@
            ns.call_function(torch.addmm): 1,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)

-    @unittest.skip(
-        "This is not supported yet, we can enable the test after it's supported"
-    )
+    @unittest.skip("This is not supported yet, we can enable the test after it's supported")
    def test_conv_add(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
@@ -828,13 +790,11 @@
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        standalone_node_occurrence = {
            # output of the standalone module
            ns.call_module(torch.ao.quantization.HistogramObserver): 1,
        }
-        self.checkGraphModuleNodes(
-            m.standalone, expected_node_occurrence=standalone_node_occurrence
-        )
+        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
        m = convert_to_reference_fx(m, backend_config=backend_config_dict)
        node_occurrence = {
            # two inputs for standalone module
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_module(nn.Conv2d): 1,
@@ -847,13 +807,11 @@
            ns.call_module(nn.Conv2d): 1,
            ns.call_module(torch.nn.ReLU): 1,
            # two input and one output for the pattern in standalone module
            ns.call_method("dequantize"): 3,
        }
-        self.checkGraphModuleNodes(
-            m.standalone, expected_node_occurrence=standalone_node_occurrence
-        )
+        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)

    def test_quant_dequant_not_fold(self):
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
--- py/torch_tensorrt/fx/tools/common_fx2trt.py	2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/tools/common_fx2trt.py	2022-08-12 18:52:27.791019 +0000
@@ -29,13 +29,11 @@
    """
    target_atoms = target.split(".")
    attr_itr = mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
-            raise RuntimeError(
-                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
-            )
+            raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr


@unittest.skipIf(not torch.cuda.is_available(), "Skip because CUDA is not available")
@@ -82,13 +80,11 @@
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            outputs = trt_mod(*cuda_inputs)
            end_event.record()
            torch.cuda.synchronize()
-            _LOGGER.info(
-                f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}"
-            )
+            _LOGGER.info(f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}")

            if isinstance(outputs, torch.Tensor):
                ref_outputs = [ref_outputs]
                outputs = [outputs]
            for out, ref in zip(outputs, ref_outputs):
@@ -126,26 +122,22 @@
            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)

            interpreter_result = interpreter.run(
-                lower_precision=LowerPrecision.FP16
-                if fp16_mode
-                else LowerPrecision.FP32
+                lower_precision=LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
            )
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
            )
            res_trt = trt_mod(*cuda_inputs).cpu()
            res_cpu = mod(*inputs)
            assert len(res_trt) == len(res_cpu)
            assert len(res_cpu) == len(comparators)
-            for output_trt, output_cpu, comparator in zip(
-                res_trt, res_cpu, comparators
-            ):
+            for output_trt, output_cpu, comparator in zip(res_trt, res_cpu, comparators):
                comp_func = comparator[0]
                args = comparator[1]
                self.assertTrue(comp_func(output_trt, output_cpu, *args))

    def run_test_with_error(self, mod, inputs, interpreter, expect_error):
@@ -165,13 +157,11 @@
            if node.op == "call_module":
                ops_in_mod.add(type(fetch_attr(mod, node.target)))
            elif node.op in {"call_function", "call_method"}:
                ops_in_mod.add(node.target)

-        self.assertTrue(
-            ops_in_mod >= ops, f"expected ops {ops}, actuall ops {ops_in_mod}"
-        )
+        self.assertTrue(ops_in_mod >= ops, f"expected ops {ops}, actuall ops {ops_in_mod}")

    def assert_unexpected_op(self, mod, ops):
        for node in mod.graph.nodes:
            if node.op == "call_module":
                if type(fetch_attr(mod, node.target)) in ops:
@@ -204,13 +194,11 @@
        # after we refactor the internal callsites to use this file
        mod = torch.fx.symbolic_trace(mod)
        shape_prop.ShapeProp(mod).propagate(*inputs)
        mod = NormalizeArgs(mod).transform()
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
-        super().run_test_custom_compare_results(
-            mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode
-        )
+        super().run_test_custom_compare_results(mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode)


class AccTestCase(TRTTestCase):
    def run_test(
        self,
@@ -233,41 +221,31 @@
            pass_tracer = chain_passes(*apply_passes)
            mod = pass_tracer(mod, inputs)

        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
-            super().run_test(
-                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
-            )
+            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision)

        if test_explicit_batch_dim:
-            interp = TRTInterpreter(
-                mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
-            )
-            super().run_test(
-                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
-            )
+            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True)
+            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision)

        if test_explicit_precision:
            interp = TRTInterpreter(
                mod,
                InputTensorSpec.from_tensors(inputs),
                explicit_precision=test_explicit_precision,
            )
-            super().run_test(
-                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol
-            )
+            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol)

            interp = TRTInterpreter(
                mod,
                InputTensorSpec.from_tensors(inputs),
                explicit_batch_dimension=True,
                explicit_precision=test_explicit_precision,
            )
-            super().run_test(
-                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
-            )
+            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision)

    def run_test_with_assert_error(
        self,
        mod,
        inputs,
@@ -281,13 +259,11 @@
        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
            super().run_test_with_error(mod, inputs, interp, expect_error)

        if test_explicit_batch_dim:
-            interp = TRTInterpreter(
-                mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
-            )
+            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True)
            super().run_test_with_error(mod, inputs, interp, expect_error)

    def run_test_with_dynamic_shape(
        self,
        mod,
--- py/torch_tensorrt/fx/tools/trt_splitter.py	2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/tools/trt_splitter.py	2022-08-12 18:52:27.864356 +0000
@@ -72,13 +72,11 @@
            operator_support,
            settings,
            non_acc_submodule_name="_run_on_gpu_",
        )

-    def _lower_model_to_backend(
-        self, mod: torch.fx.GraphModule, inputs: Iterable[torch.Tensor]
-    ):
+    def _lower_model_to_backend(self, mod: torch.fx.GraphModule, inputs: Iterable[torch.Tensor]):
        """
        Lower a GraphModule `mod` to TensorRT with `inputs`.
        """
        # Current code for lowering is place-holder, subject to future change
        # based on feeds model's actual status
--- py/torch_tensorrt/fx/tools/trt_minimizer.py	2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/tools/trt_minimizer.py	2022-08-12 18:52:27.879445 +0000
@@ -8,16 +8,12 @@
from .. import InputTensorSpec, TRTInterpreter, TRTModule

_LOGGER: logging.Logger = logging.getLogger(__name__)


-def lower_mod_default(
-    mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048
-) -> TRTModule:
-    interp = TRTInterpreter(
-        mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
-    )
+def lower_mod_default(mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048) -> TRTModule:
+    interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True)
    interpreter_result = interp.run(max_batch_size=batch_size)
    res_mod = TRTModule(
        interpreter_result.engine,
        interpreter_result.input_names,
        interpreter_result.output_names,
@@ -37,13 +33,11 @@
        module: torch.fx.GraphModule,
        sample_input: Tensors,
        compare_fn: Callable[[Any, Any, Any], Tuple[float, bool]],
        settings: TensorRTMinizerSetting = TensorRTMinizerSetting(),
        max_batch_size: Any = 2048,
-        lower_fn: Callable[
-            [torch.fx.GraphModule, Tensors, Any], TRTModule
-        ] = lower_mod_default,
+        lower_fn: Callable[[torch.fx.GraphModule, Tensors, Any], TRTModule] = lower_mod_default,
    ):
        self.lower_fn = lower_fn
        self.max_batch_size = max_batch_size
        super().__init__(module, sample_input, compare_fn, settings)

@@ -56,13 +50,11 @@
        mod.eval()
        try:
            mod = self.lower_fn(mod, inputs, self.max_batch_size)
            output = mod(*inputs)
        except RuntimeError as e:
-            raise net_min_base.FxNetMinimizerRunFuncError(
-                f"Encounter an error when processing \n{mod.graph}\n {e}"
-            )
+            raise net_min_base.FxNetMinimizerRunFuncError(f"Encounter an error when processing \n{mod.graph}\n {e}")
        else:
            return output

    def get_nodes(self, start=None, end=None, enable_print=False):
        nodes = self._collect_nodes(start, end)
--- py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py	2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py	2022-08-12 18:52:28.165165 +0000
@@ -41,13 +41,11 @@
    def __init__(self):
        super().__init__()
        self.exceptions_rewritten: Set[Type[Exception]] = set()
        self.exceptions_bool_rewritten: Set[Type[Exception]] = set()

-    def rewrite(
-        self, fn: FunctionType
-    ) -> Tuple[FunctionType, Set[Type[Exception]], Set[Type[Exception]]]:
+    def rewrite(self, fn: FunctionType) -> Tuple[FunctionType, Set[Type[Exception]], Set[Type[Exception]]]:

        # Normalize the source lines
        sourcelines, _ = inspect.getsourcelines(fn)
        sourcelines = normalize_source_lines(sourcelines)
        source = "".join(sourcelines)
@@ -139,12 +137,11 @@
            return if_node

        # Check that we actually have a builtin exception.
        if (
            not issubclass(exc_type, Exception)
-            or getattr(getattr(exc_type, "__class__", None), "__module__", None)
-            != "builtins"
+            or getattr(getattr(exc_type, "__class__", None), "__module__", None) != "builtins"
        ):
            return if_node

        # We need a ConditionalExceptionWrapper specialized for every kind of
        # exception, so add it to exceptions_rewritten to remember for later to
@@ -156,23 +153,17 @@
        # the If with, with args set as the If's condition and the string of the
        # exception. The call to the self._conditional_exception_wrapper_*Error
        # module is safe because the RewrittenModule will add it as an attr
        # based on the returned exceptions_rewritten, and we assume we are
        # currently modifying the AST of a method from a RewrittenModule.
-        exc_wrapper_node = ast.parse(
-            f"self.{_get_exception_wrapper_attr_name(exc_type)}()", mode="eval"
-        )
+        exc_wrapper_node = ast.parse(f"self.{_get_exception_wrapper_attr_name(exc_type)}()", mode="eval")
        assert isinstance(exc_wrapper_node, ast.Expression)
        exc_wrapper_call_node = exc_wrapper_node.body
        assert isinstance(exc_wrapper_call_node, ast.Call)
-        if isinstance(if_node.test, ast.BoolOp) and isinstance(
-            if_node.test.op, ast.And
-        ):
+        if isinstance(if_node.test, ast.BoolOp) and isinstance(if_node.test.op, ast.And):
            self.exceptions_bool_rewritten.add(exc_type)
-            bool_wrapper_node = ast.parse(
-                f"self.{_get_exception_wrapper_attr_name(exc_type)}_bool()", mode="eval"
-            )
+            bool_wrapper_node = ast.parse(f"self.{_get_exception_wrapper_attr_name(exc_type)}_bool()", mode="eval")
            assert isinstance(exc_wrapper_node, ast.Expression)
            bool_wrapper_call_node = bool_wrapper_node.body
            assert isinstance(exc_wrapper_call_node, ast.Call)
            bool_wrapper_call_node.args = if_node.test.values
            exc_wrapper_call_node.args = [
@@ -323,13 +314,11 @@
            name_target[-1] == "_"
            and name_target[0] != "_"
            and not (name_target in allow_list)
            and kind != "placeholder"
        ):
-            raise RuntimeError(
-                f"Tried to trace mutable operation {name_target}. FX only supports functional code"
-            )
+            raise RuntimeError(f"Tried to trace mutable operation {name_target}. FX only supports functional code")

        return self.graph.create_node(kind, target, args, kwargs, name, type_expr)


# List of modules that need rewriting to be supported for tracing.
@@ -384,13 +373,11 @@
            # Write all of the non-dunder or special methods from base_class
            # into RewrittenModule.
            for method_name in dir(base_class):
                method = getattr(base_class, method_name, None)
                if method is None and method_name not in {"__doc__"}:
-                    _LOGGER.warning(
-                        f"{__qualname__} does not have attribute {method_name}"
-                    )
+                    _LOGGER.warning(f"{__qualname__} does not have attribute {method_name}")

                if builtins.type(method) is not FunctionType:
                    continue

                # Always skip rewriting dunder methods, as they haven't (yet) been
@@ -437,13 +424,11 @@
                # Recursively rewrite and copy all module attrs of this module.
                for k, v in orig.__dict__.items():
                    if k == "_modules":
                        for mod_k, mod_v in v.items():
                            if getattr(mod_v, "_base_class_origin", type(mod_v)) in leaf_module_list:  # type: ignore[operator]
-                                _LOGGER.info(
-                                    f"Skip rewriting leaf module {type(mod_v)}"
-                                )
+                                _LOGGER.info(f"Skip rewriting leaf module {type(mod_v)}")
                                self._modules[mod_k] = mod_v
                            else:
                                self._modules[mod_k] = rewrite_module(mod_v)
                    else:
                        self.__dict__[k] = v
@@ -475,25 +460,21 @@
    """
    changed = False
    for node in reversed(gm.graph.nodes):
        if node.op == "call_module" and (
            isinstance(gm.get_submodule(node.target), ConditionalExceptionWrapper)
-            or isinstance(
-                gm.get_submodule(node.target), ConditionalExceptionBoolCondWrapper
-            )
+            or isinstance(gm.get_submodule(node.target), ConditionalExceptionBoolCondWrapper)
        ):
            gm.graph.erase_node(node)
            changed = True
    return changed


def _replace_tensor_meta_with_rank(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        if node.op != "output" and "tensor_meta" in node.meta:
-            node.meta["tensor_rank"] = acc_utils.map_tensor_metadata(
-                node.meta["tensor_meta"], lambda x: len(x.shape)
-            )
+            node.meta["tensor_rank"] = acc_utils.map_tensor_metadata(node.meta["tensor_meta"], lambda x: len(x.shape))
            del node.meta["tensor_meta"]


def rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list):
    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(

@narendasan
Collaborator

@frank-wei can you share the fb lint config or something so that we can use a consistent code style?

@frank-wei frank-wei marked this pull request as ready for review August 12, 2022 19:15
@frank-wei
Contributor Author

> @frank-wei can you share the fb lint config or something so that we can use a consistent code style?

They are using black, but it looks like there is more to it than that.

@github-actions github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- py/torch_tensorrt/fx/input_tensor_spec.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/input_tensor_spec.py	2022-08-12 19:19:58.709915 +0000
@@ -6,14 +6,11 @@
from .utils import get_dynamic_dims


def generate_input_specs(inputs, lower_setting, additional_inputs=None):
    # dynamic_batch is TRT only flag.
-    if (
-        not lower_setting.explicit_batch_dimension
-        or lower_setting.dynamic_batch is False
-    ):
+    if not lower_setting.explicit_batch_dimension or lower_setting.dynamic_batch is False:
        return InputTensorSpec.from_tensors(inputs)

    # If we don't have additional inputs, we assume the first dimension
    # is the dynamic batch dimension. Otherwise, we use the additional
    # inputs to determine the batch dimension.
@@ -33,20 +30,16 @@
        for i, j in zip(inputs, additional_inputs):
            found_batch_dim = False

            for idx, values in enumerate(zip(i.shape, j.shape)):
                if values[0] != values[1]:
-                    assert (
-                        found_batch_dim is False
-                    ), f"We've already found a batch dim, {i.shape}, {j.shape}."
+                    assert found_batch_dim is False, f"We've already found a batch dim, {i.shape}, {j.shape}."
                    batch_dims.append(idx)
                    found_batch_dim = True

            if not found_batch_dim:
-                raise RuntimeError(
-                    f"Failed to find batch dimension because shapes are the same, {i.shape}"
-                )
+                raise RuntimeError(f"Failed to find batch dimension because shapes are the same, {i.shape}")

        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
            inputs,
            (
                0,
@@ -158,13 +151,11 @@
                batch_dim
            ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
            shape = list(tensor.shape)
            shape[batch_dim] = -1
            shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
-            input_specs.append(
-                cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
-            )
+            input_specs.append(cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges))

        return input_specs

    def to_random_tensor(self):
        shape = tuple(self.shape)
--- py/torch_tensorrt/fx/lower.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/lower.py	2022-08-12 19:19:58.801249 +0000
@@ -77,13 +77,11 @@
    lower_setting: LowerSetting
    timing_cache_manager: TimingCacheManager

    @classmethod
    def create(cls, lower_setting):
-        timing_cache_manager = TimingCacheManager(
-            lower_setting.timing_cache_prefix, lower_setting.save_timing_cache
-        )
+        timing_cache_manager = TimingCacheManager(lower_setting.timing_cache_prefix, lower_setting.save_timing_cache)
        return LowerTrtInterpreter(lower_setting, timing_cache_manager)

    def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
        assert self.lower_setting.input_specs, "Can't find input specs for lowering!"
        logger.info(f"{split_name=} {self.lower_setting.input_specs=}")
@@ -103,13 +101,11 @@
        interpreter = TRTInterpreter(
            mod,
            input_specs=self.lower_setting.input_specs,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
-            logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
-            else trt.Logger.WARNING,
+            logger_level=trt.Logger.VERBOSE if self.lower_setting.verbose_log else trt.Logger.WARNING,
        )

        interp_result: TRTInterpreterResult = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
@@ -129,13 +125,11 @@
            self.timing_cache_manager.update_timing_cache(split_name, timing_cache)

        return interp_result


-def default_split_function(
-    model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting
-) -> SplitResult:
+def default_split_function(model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting) -> SplitResult:
    splitter_setting = TRTSplitterSetting()
    splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
    splitter_setting.min_acc_module_size = lower_setting.min_acc_module_size
    splitter = TRTSplitter(model, inputs, settings=splitter_setting)
    splitter.node_support_preview()
@@ -147,13 +141,11 @@


def default_lower_pass(
    create_trt_interpreter: Callable[[LowerSetting], LowerTrtInterpreter],
) -> PassFunc:
-    def lower_pass(
-        mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str
-    ) -> nn.Module:
+    def lower_pass(mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str) -> nn.Module:
        """
        Create a module transformation pass which lowers an `fx.GraphModule` into a
        `TRTModule`
        """
        interpreter = create_trt_interpreter(lower_setting)
@@ -223,21 +215,13 @@
        inputs: Input,
        additional_inputs: Optional[Input] = None,
    ) -> nn.Module:
        module.eval()

-        if (
-            self.lower_pass_manager_builder.lower_setting.lower_precision
-            == LowerPrecision.FP16
-        ):
+        if self.lower_pass_manager_builder.lower_setting.lower_precision == LowerPrecision.FP16:
            module.half()
-            inputs = tuple(
-                x.half() if x is not None and x.dtype == torch.float32 else x
-                for x in inputs
-            )
-        pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
-            inputs, additional_inputs
-        )
+            inputs = tuple(x.half() if x is not None and x.dtype == torch.float32 else x for x in inputs)
+        pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(inputs, additional_inputs)

        lower_result = pm(module)

        return lower_result
--- py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2022-08-12 19:19:59.035315 +0000
@@ -35,23 +35,17 @@
# >>> with FUSE_PASSES_POST_OBSERVER.add(print_module_and_input):
# >>>     # print_module_and_input will be called right after the fuse passes
# >>>     lower(module, sample_input)

# Observer for the model after the fuse passes.
-FUSE_PASSES_POST_OBSERVER: Observer[Callable[[nn.Module, Input], None]] = Observer(
-    "FUSE_PASSES_POST_OBSERVER"
-)
+FUSE_PASSES_POST_OBSERVER: Observer[Callable[[nn.Module, Input], None]] = Observer("FUSE_PASSES_POST_OBSERVER")

# Observer for the TRT split submodules before lowering
-LOWER_SPLIT_PRE_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer(
-    "LOWER_SPLIT_PRE_OBSERVER"
-)
+LOWER_SPLIT_PRE_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer("LOWER_SPLIT_PRE_OBSERVER")

# Observer for the TRT split submodules after lowering
-LOWER_SPLIT_POST_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer(
-    "LOWER_SPLIT_POST_OBSERVER"
-)
+LOWER_SPLIT_POST_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer("LOWER_SPLIT_POST_OBSERVER")
# ----------------------------------------------------------------------


def wrapper(fn: Callable, input) -> Callable:
    @wraps(fn)
@@ -103,22 +97,16 @@
            passes.append(wrapper(p, self._input))
        for p in self.lower_setting.lower_basic_fuse_pass.passes:
            passes.append(wrapper(p, self._input))

        passes.append(inplace_wrapper(common_subexpression_elimination))
-        passes.append(
-            inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input))
-        )
+        passes.append(inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)))

        return PassManager.build_from_passlist(passes)

    def _split_pass(self) -> PassManager:
-        passes = [
-            partial(
-                self._split_func, inputs=self._input, lower_setting=self.lower_setting
-            )
-        ]
+        passes = [partial(self._split_func, inputs=self._input, lower_setting=self.lower_setting)]
        passes.append(
            inplace_wrapper(
                lambda split_result: remove_duplicate_output_args(
                    split_result.split_module, split_result.submodule_inputs.keys()
                )
@@ -152,21 +140,15 @@
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.input_specs = generate_input_specs(
                        submod_inputs,
                        self.lower_setting,
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        additional_submodule_inputs[submod_name] if additional_submodule_inputs else None,
                    )
-                    lowered_module = self._lower_func(
-                        submod, submod_inputs, self.lower_setting, submod_name
-                    )
+                    lowered_module = self._lower_func(submod, submod_inputs, self.lower_setting, submod_name)
                    setattr(split_result.split_module, submod_name, lowered_module)
-                    LOWER_SPLIT_POST_OBSERVER.observe(
-                        submod_name, lowered_module, submod_inputs
-                    )
+                    LOWER_SPLIT_POST_OBSERVER.observe(submod_name, lowered_module, submod_inputs)
                    _LOGGER.info(
                        f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
                    )

            return split_result.split_module
@@ -184,28 +166,22 @@
                # Only acc submodules will be lowered.
                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                    _LOGGER.info(f"Now lowering submodule {submod_name}")
                    lowering_start_time = datetime.datetime.now()

-                    lowered_module = self._lower_func(
-                        submod, submod_inputs, self.lower_setting, submod_name
-                    )
+                    lowered_module = self._lower_func(submod, submod_inputs, self.lower_setting, submod_name)
                    setattr(split_result.split_module, submod_name, lowered_module)
-                    LOWER_SPLIT_POST_OBSERVER.observe(
-                        submod_name, lowered_module, submod_inputs
-                    )
+                    LOWER_SPLIT_POST_OBSERVER.observe(submod_name, lowered_module, submod_inputs)
                    _LOGGER.info(
                        f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
                    )

            return split_result.split_module

        return PassManager.build_from_passlist([lower_func])

-    def build_trt_lower_pipeline(
-        self, input: Input, additional_input: Optional[Input] = None
-    ) -> PassManager:
+    def build_trt_lower_pipeline(self, input: Input, additional_input: Optional[Input] = None) -> PassManager:
        self._input = input
        self._additional_input = additional_input
        passes = []

        passes.append(self._const_fold_pass())
@@ -214,13 +190,11 @@
        passes.append(self._trt_lower_pass())

        pm = PassManager.build_from_passlist(passes)
        return pm

-    def build_default_lower_pipeline(
-        self, input: Input, additional_input: Optional[Input] = None
-    ) -> PassManager:
+    def build_default_lower_pipeline(self, input: Input, additional_input: Optional[Input] = None) -> PassManager:
        self._input = input
        self._additional_input = additional_input
        passes = []

        passes.append(self._const_fold_pass())
--- py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py	2022-08-12 19:19:59.232592 +0000
@@ -27,13 +27,11 @@
        count_include_pad=True,
    ):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
-                self.avg_pool = torch.nn.AvgPool1d(
-                    kernel_size, stride, padding, ceil_mode, count_include_pad
-                )
+                self.avg_pool = torch.nn.AvgPool1d(kernel_size, stride, padding, ceil_mode, count_include_pad)

            def forward(self, x):
                return self.avg_pool(x)

        inputs = [torch.randn(1, 3, 224)]
@@ -60,13 +58,11 @@
        count_include_pad=True,
    ):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
-                self.avg_pool = torch.nn.AvgPool1d(
-                    kernel_size, stride, padding, ceil_mode, count_include_pad
-                )
+                self.avg_pool = torch.nn.AvgPool1d(kernel_size, stride, padding, ceil_mode, count_include_pad)

            def forward(self, x):
                return self.avg_pool(x)

        input_specs = [
@@ -75,13 +71,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d})

    def test_avg_pool2d_with_dynamic_shape_four_dimensions(
        self,
        test_name="default",
        kernel_size=1,
@@ -112,13 +106,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d})

    @parameterized.expand(
        [
            ("default", 1),
            ("kernal_size", 3),
@@ -254,12 +246,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py	2022-08-12 19:19:59.271217 +0000
@@ -32,13 +32,11 @@
                dtype=torch.float32,
                shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.batch_norm}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.batch_norm})

    def test_batchnorm_with_dynamic_shape(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -53,13 +51,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.batch_norm}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.batch_norm})

    # Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm.


if __name__ == "__main__":
--- py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py	2022-08-12 19:19:59.400355 +0000
@@ -51,12 +51,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (5, 5, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.clamp}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.clamp})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py	2022-08-12 19:19:59.500742 +0000
@@ -27,13 +27,11 @@
        bias=True,
    ):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
-                self.conv = torch.nn.Conv1d(
-                    3, 6, kernel_size, stride, padding, dilation, groups, bias
-                )
+                self.conv = torch.nn.Conv1d(3, 6, kernel_size, stride, padding, dilation, groups, bias)

            def forward(self, x):
                return self.conv(x)

        inputs = [torch.randn(1, 3, 32)]
@@ -60,13 +58,11 @@
        bias=True,
    ):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
-                self.conv = torch.nn.Conv1d(
-                    3, 6, kernel_size, stride, padding, dilation, groups, bias
-                )
+                self.conv = torch.nn.Conv1d(3, 6, kernel_size, stride, padding, dilation, groups, bias)

            def forward(self, x):
                return self.conv(x)

        input_specs = [
@@ -75,13 +71,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.conv1d}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.conv1d})

    @parameterized.expand(
        [
            ("default", 1),
            param("no_bias", 1, bias=False),
@@ -102,13 +96,11 @@
        bias=True,
    ):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
-                self.conv = torch.nn.Conv2d(
-                    3, 6, kernel_size, stride, padding, dilation, groups, bias
-                )
+                self.conv = torch.nn.Conv2d(3, 6, kernel_size, stride, padding, dilation, groups, bias)

            def forward(self, x):
                return self.conv(x)

        inputs = [torch.randn(1, 3, 32, 32)]
@@ -131,13 +123,11 @@
                shape=(-1, 3, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))],
            ),
        ]
-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.conv2d}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.conv2d})

    @parameterized.expand(
        [
            ("default", 1),
            param("no_bias", 1, bias=False),
@@ -158,13 +148,11 @@
        bias=True,
    ):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
-                self.conv = torch.nn.Conv3d(
-                    3, 6, kernel_size, stride, padding, dilation, groups, bias
-                )
+                self.conv = torch.nn.Conv3d(3, 6, kernel_size, stride, padding, dilation, groups, bias)

            def forward(self, x):
                return self.conv(x)

        inputs = [torch.randn(1, 3, 32, 32, 32)]
@@ -187,12 +175,10 @@
                shape=(-1, 3, -1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))],
            ),
        ]
-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.conv3d}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.conv3d})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py	2022-08-12 19:19:59.662120 +0000
@@ -5,13 +5,11 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec


-@unittest.skip(
-    reason="Could not find CustomGeluPluginDynamic. Enable it once we upgrade TRT to 8.4"
-)
+@unittest.skip(reason="Could not find CustomGeluPluginDynamic. Enable it once we upgrade TRT to 8.4")
class TestGELU(AccTestCase):
    def test_gelu(self):
        class TestModule(nn.Module):
            def forward(self, x):
                return nn.functional.gelu(x)
@@ -34,13 +32,11 @@
                shape=(-1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
            ),
        ]
-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.gelu}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.gelu})

    def test_gelu_with_dynamic_shape_four_dimensions(self):
        class TestModule(nn.Module):
            def forward(self, x):
                return nn.functional.gelu(x)
@@ -51,12 +47,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.gelu}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.gelu})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py	2022-08-12 19:19:59.880992 +0000
@@ -131,12 +131,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            Interpolate(), input_specs, expected_ops={acc_ops.interpolate}
-        )
+        self.run_test_with_dynamic_shape(Interpolate(), input_specs, expected_ops={acc_ops.interpolate})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py	2022-08-12 19:20:00.162992 +0000
@@ -71,15 +71,11 @@
        class MatMul(nn.Module):
            def forward(self, input, other):
                return torch.matmul(input, other)

        inputs = [torch.randn(*input_shape), torch.randn(*other_shape)]
-        test_implicit_batch_dim = (
-            input_shape[0] == other_shape[0]
-            and len(input_shape) > 2
-            and len(other_shape) > 2
-        )
+        test_implicit_batch_dim = input_shape[0] == other_shape[0] and len(input_shape) > 2 and len(other_shape) > 2
        self.run_test(
            MatMul(),
            inputs,
            expected_ops={acc_ops.matmul},
            test_implicit_batch_dim=test_implicit_batch_dim,
@@ -106,12 +102,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 3, 3), (9, 4, 3, 3), (9, 4, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            Matmul(), input_specs, expected_ops={acc_ops.matmul}
-        )
+        self.run_test_with_dynamic_shape(Matmul(), input_specs, expected_ops={acc_ops.matmul})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_max.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_max.py	2022-08-12 19:20:00.274826 +0000
@@ -102,13 +102,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            MaxDimReduce(), input_specs, expected_ops={acc_ops.max_dim_reduce}
-        )
+        self.run_test_with_dynamic_shape(MaxDimReduce(), input_specs, expected_ops={acc_ops.max_dim_reduce})

    def test_max_full_reduce(
        self,
    ):
        class MaxFullReduce(torch.nn.Module):
@@ -124,13 +122,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            MaxFullReduce(), input_specs, expected_ops={acc_ops.max_full_reduce}
-        )
+        self.run_test_with_dynamic_shape(MaxFullReduce(), input_specs, expected_ops={acc_ops.max_full_reduce})

    def test_max_method(self):
        class MaxMethod(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -149,12 +145,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            MaxMethod(), input_specs, expected_ops={acc_ops.maximum}
-        )
+        self.run_test_with_dynamic_shape(MaxMethod(), input_specs, expected_ops={acc_ops.maximum})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_min.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_min.py	2022-08-12 19:20:00.440385 +0000
@@ -101,13 +101,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            MinDimReduce(), input_specs, expected_ops={acc_ops.min_dim_reduce}
-        )
+        self.run_test_with_dynamic_shape(MinDimReduce(), input_specs, expected_ops={acc_ops.min_dim_reduce})

    def test_min_full_reduce(
        self,
    ):
        class MinFullReduce(torch.nn.Module):
@@ -123,13 +121,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            MinFullReduce(), input_specs, expected_ops={acc_ops.min_full_reduce}
-        )
+        self.run_test_with_dynamic_shape(MinFullReduce(), input_specs, expected_ops={acc_ops.min_full_reduce})

    def test_min_method(self):
        class MinMethod(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -148,12 +144,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            MinMethod(), input_specs, expected_ops={acc_ops.minimum}
-        )
+        self.run_test_with_dynamic_shape(MinMethod(), input_specs, expected_ops={acc_ops.minimum})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py	2022-08-12 19:20:00.490650 +0000
@@ -23,13 +23,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 3, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            Narrow(), input_specs, expected_ops={acc_ops.slice_tensor}
-        )
+        self.run_test_with_dynamic_shape(Narrow(), input_specs, expected_ops={acc_ops.slice_tensor})


class TestNarrowConverter(AccTestCase):
    @parameterized.expand(
        [
--- py/torch_tensorrt/fx/converters/acc_ops_converters.py	2022-08-12 19:16:11.708868 +0000
+++ py/torch_tensorrt/fx/converters/acc_ops_converters.py	2022-08-12 19:20:00.578833 +0000
@@ -34,14 +34,11 @@
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Conv received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Conv received input {input_val} that is not part " "of the TensorRT region!")

    # Process 1d input with unsqueeze -> conv2d -> squeeze to calculate conv1d
    unsqueeze_layer = network.add_shuffle(input=input_val)
    unsqueeze_layer.reshape_dims = tuple([*input_val.shape, 1])
    set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
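
A minimal PyTorch sketch of the unsqueeze -> conv2d -> squeeze trick this hunk refers to (shapes and modules below are illustrative, not taken from the converter):

import torch

x = torch.randn(1, 3, 32)                      # (N, C, L)
conv2d = torch.nn.Conv2d(3, 6, kernel_size=(3, 1))
y = conv2d(x.unsqueeze(-1)).squeeze(-1)        # (N, C, L, 1) -> (N, 6, L', 1) -> (N, 6, L')

# With shared weights, a real Conv1d agrees with the emulated path.
conv1d = torch.nn.Conv1d(3, 6, kernel_size=3)
conv1d.weight.data = conv2d.weight.data.squeeze(-1)   # (6, 3, 3, 1) -> (6, 3, 3)
conv1d.bias.data = conv2d.bias.data
assert torch.allclose(conv1d(x), y, atol=1e-5)
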
@@ -52,13 +49,11 @@

    # for now we'll assume bias is constant Tensor or None,
    # and bias being ITensor is not supported in TensorRT api
    # right now
    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]"
-        )
+        raise RuntimeError(f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]")
    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]
    if bias is not None:
        bias = bias[None]
    weight = kwargs["weight"]

@@ -82,13 +77,11 @@
        )

        layer.set_input(1, weight)
    else:
        if not isinstance(kwargs["weight"], torch.Tensor):
-            raise RuntimeError(
-                f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]"
-            )
+            raise RuntimeError(f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]")
        weight = to_numpy(weight)
        weight = np.expand_dims(weight, -1)
        layer = network.add_convolution_nd(
            input=input_val,
            num_output_maps=weight.shape[0],
@@ -126,25 +119,20 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Conv received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Conv received input {input_val} that is not part " "of the TensorRT region!")

    if has_dynamic_shape(input_val.shape):
        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."

    # for now we'll assume bias is constant Tensor or None,
    # and bias being ITensor is not supported in TensorRT api
    # right now
    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]"
-        )
+        raise RuntimeError(f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]")
    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]

    if network.has_explicit_precision:
        weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
        weight_shape = tuple(kwargs["weight"].shape)  # type: ignore[union-attr]
@@ -160,13 +148,11 @@
        )

        layer.set_input(1, weight)
    else:
        if not isinstance(kwargs["weight"], torch.Tensor):
-            raise RuntimeError(
-                f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]"
-            )
+            raise RuntimeError(f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]")
        weight = to_numpy(kwargs["weight"])
        layer = network.add_convolution_nd(
            input=input_val,
            num_output_maps=weight.shape[0],
            kernel_shape=weight.shape[2:],
@@ -194,27 +180,20 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Transpose conv received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Transpose conv received input {input_val} that is not part " "of the TensorRT region!")

    if has_dynamic_shape(input_val.shape):
-        assert (
-            input_val.shape[1] != -1
-        ), "Channel dim can't be dynamic for transpose convolution."
+        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for transpose convolution."

    # for now we'll assume bias is constant Tensor or None,
    # and bias being ITensor is not supported in TensorRT api
    # right now
    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"ConvTranspose {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
-        )
+        raise RuntimeError(f"ConvTranspose {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]")
    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]

    if network.has_explicit_precision:
        weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
        weight_shape = tuple(kwargs["weight"].shape)  # type: ignore[union-attr]
@@ -232,13 +211,11 @@
        )

        layer.set_input(1, weight)
    else:
        if not isinstance(kwargs["weight"], torch.Tensor):
-            raise RuntimeError(
-                f"conv {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
-            )
+            raise RuntimeError(f"conv {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]")
        weight = to_numpy(kwargs["weight"])
        # nn.ConvTranspose2d/3d weight size is (in_channels, out_channels/groups, kernel_0, kernel_1, [kernel_2])
        layer = network.add_deconvolution_nd(
            input=input_val,
            num_output_maps=weight.shape[1] * kwargs["groups"],
@@ -270,29 +247,20 @@
    mode = kwargs["mode"]
    value = kwargs["value"] if kwargs["value"] is not None else 0
    rank = len(input_val.shape)  # type: ignore[union-attr]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"pad received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"pad received input {input_val} that is not part " "of the TensorRT region!")

    if mode != "constant":
-        raise RuntimeError(
-            f"Currently we only support constant mode for pad, got {mode}."
-        )
+        raise RuntimeError(f"Currently we only support constant mode for pad, got {mode}.")

    if len(pad) / 2 > rank:
-        raise RuntimeError(
-            f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension."
-        )
+        raise RuntimeError(f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension.")

    if value != 0:
-        raise RuntimeError(
-            f"Currently we only support padding value of 0, got {value}."
-        )
+        raise RuntimeError(f"Currently we only support padding value of 0, got {value}.")

    if len(pad) > 4:
        raise RuntimeError("Currently we only support padding last two dimensions.")

    pre_padding = tuple(pad[len(pad) - i - 2] for i in range(0, len(pad), 2))
@@ -320,38 +288,28 @@
    mode = kwargs["mode"]
    value = kwargs["value"] if kwargs["value"] is not None else 0
    rank = len(input_val.shape)  # type: ignore[union-attr]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"pad received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"pad received input {input_val} that is not part " "of the TensorRT region!")

    if mode != "constant":
-        raise RuntimeError(
-            f"Currently we only support constant mode for pad, got {mode}."
-        )
+        raise RuntimeError(f"Currently we only support constant mode for pad, got {mode}.")

    if len(pad) / 2 > rank:
-        raise RuntimeError(
-            f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension."
-        )
+        raise RuntimeError(f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension.")

    # cast value to a TRTTensor
    dt = torch_dtype_from_trt(input_val.dtype)
    value = 0 if value is None else value
-    value_const = get_trt_tensor(
-        network, torch.tensor([value], dtype=dt), f"{name}_value"
-    )
+    value_const = get_trt_tensor(network, torch.tensor([value], dtype=dt), f"{name}_value")

    input_shape = input_val.shape
    pre_start = tuple(i - 1 for i in input_shape)
    prefix_len = len(input_shape) - len(pad) // 2
    pre_shape = tuple(
-        input_shape[i] + (pad[-(i - prefix_len) * 2 - 2] if i >= prefix_len else 0)
-        for i in range(0, len(input_shape))
+        input_shape[i] + (pad[-(i - prefix_len) * 2 - 2] if i >= prefix_len else 0) for i in range(0, len(input_shape))
    )
    pre_stride = [-1] * len(input_shape)

    layer = network.add_slice(
        input_val,
@@ -374,12 +332,11 @@
    transpose_output = layer.get_output(0)

    shape = transpose_output.shape
    post_start = tuple([0] * len(shape))
    post_shape = tuple(
-        shape[i] + (pad[-(i - prefix_len) * 2 - 1] if i >= prefix_len else 0)
-        for i in range(0, len(shape))
+        shape[i] + (pad[-(i - prefix_len) * 2 - 1] if i >= prefix_len else 0) for i in range(0, len(shape))
    )
    post_stride = tuple([1] * len(shape))

    layer = network.add_slice(transpose_output, post_start, post_shape, post_stride)
    layer.set_input(4, value_const)
@@ -397,22 +354,15 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"flatten received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"flatten received input {input_val} that is not part " "of the TensorRT region!")

    num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
-    start_dim = get_positive_dim(
-        cast(int, kwargs["start_dim"] if "start_dim" in kwargs else 0), num_dims
-    )
-    end_dim = get_positive_dim(
-        cast(int, kwargs["end_dim"] if "end_dim" in kwargs else -1), num_dims
-    )
+    start_dim = get_positive_dim(cast(int, kwargs["start_dim"] if "start_dim" in kwargs else 0), num_dims)
+    end_dim = get_positive_dim(cast(int, kwargs["end_dim"] if "end_dim" in kwargs else -1), num_dims)

    if network.has_implicit_batch_dimension:
        assert start_dim != 0, "Can't flatten batch dimension when it's implicit."
        start_dim -= 1
        end_dim -= 1
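
For reference, the `get_positive_dim` helper used throughout these hunks maps negative dims to positive indices. A minimal sketch of its contract (illustrative, not the library's exact implementation):

def get_positive_dim(dim: int, num_dims: int) -> int:
    # A negative dim counts from the end: -1 with 4 dims -> 3.
    return dim % num_dims if dim < 0 else dim

assert get_positive_dim(-1, 4) == 3
assert get_positive_dim(2, 4) == 2
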
@@ -511,24 +461,18 @@
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    if type(input_t) == torch.nn.Parameter or type(input_t) == torch.Tensor:
-        if (
-            not has_dynamic_shape(input_t.shape)
-            and network.has_implicit_batch_dimension
-        ):
+        if not has_dynamic_shape(input_t.shape) and network.has_implicit_batch_dimension:
            return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_t.shape))
        return input_t.shape

    # input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
    input_val = input_t
    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"size received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"size received input {input_val} that is not part " "of the TensorRT region!")

    if not has_dynamic_shape(input_val.shape):
        if network.has_implicit_batch_dimension:
            return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_val.shape))
        return torch.Size(input_val.shape)
@@ -547,14 +491,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"size received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"size received input {input_val} that is not part " "of the TensorRT region!")

    if has_dynamic_shape(input_val.shape):
        raise RuntimeError(f"numel does not support dynamic shapes.")

    numel = np.prod(input_val.shape)
@@ -572,29 +513,20 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"BatchNorm2d received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"BatchNorm2d received input {input_val} that is not part " "of the TensorRT region!")

    if has_dynamic_shape(input_val.shape):
        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for batch norm."

-    scale = cast(
-        torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["weight"]))
-    ) / np.sqrt(
-        cast(torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["running_var"])))
-        + cast(float, kwargs["eps"])
+    scale = cast(torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["weight"]))) / np.sqrt(
+        cast(torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["running_var"]))) + cast(float, kwargs["eps"])
    )

-    bias = (
-        to_numpy(cast(torch.Tensor, kwargs["bias"]))
-        - to_numpy(cast(torch.Tensor, kwargs["running_mean"])) * scale
-    )
+    bias = to_numpy(cast(torch.Tensor, kwargs["bias"])) - to_numpy(cast(torch.Tensor, kwargs["running_mean"])) * scale
    power = np.ones_like(scale)

    # For BatchNorm1d, reshape 1d to 2d
    output_shape = input_val.shape
    if not network.has_implicit_batch_dimension and len(input_val.shape) < 4:
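
The scale/bias computation in this hunk folds batch norm into a single affine op. A NumPy check of that identity (all values below are made up):

import numpy as np

gamma, beta = np.array([2.0]), np.array([0.5])
mean, var, eps = np.array([1.0]), np.array([4.0]), 1e-5
x = np.array([3.0])

scale = gamma / np.sqrt(var + eps)   # matches `scale` above
bias = beta - mean * scale           # matches `bias` above
ref = gamma * (x - mean) / np.sqrt(var + eps) + beta
assert np.allclose(scale * x + bias, ref)
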
@@ -628,44 +560,33 @@
@tensorrt_converter(acc_ops.layer_norm)
def acc_ops_layer_norm(network, target, args, kwargs, name):
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"LayerNorm received input {input_val} that is not part " "of the TensorRT region!")

    gamma = kwargs["weight"].detach().cpu().float().numpy()
    gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
    beta = kwargs["bias"].detach().cpu().float().numpy()
    beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
-    eps_field = trt.PluginField(
-        "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32
-    )
+    eps_field = trt.PluginField("eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32)
    try:
        normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32)
    except TypeError:
        _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
        normalized_shape = np.array([], dtype=np.int32)

-    normalized_shape_field = trt.PluginField(
-        "normalized_shape", normalized_shape, trt.PluginFieldType.INT32
-    )
-    field_collection = trt.PluginFieldCollection(
-        [gamma_field, beta_field, eps_field, normalized_shape_field]
-    )
+    normalized_shape_field = trt.PluginField("normalized_shape", normalized_shape, trt.PluginFieldType.INT32)
+    field_collection = trt.PluginFieldCollection([gamma_field, beta_field, eps_field, normalized_shape_field])

    try:
        if network.has_implicit_batch_dimension:
            plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt")
        else:
            plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
    except AssertionError:
-        _LOGGER.error(
-            "Unable to find layer norm plugin, fall back to TensorRT implementation."
-        )
+        _LOGGER.error("Unable to find layer norm plugin, fall back to TensorRT implementation.")
        return layer_norm(network, target, args, kwargs, name)
    layer = network.add_plugin_v2([input_val], plugin)
    layer.name = name
    return layer.get_output(0)

@@ -678,14 +599,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"LayerNorm received input {input_val} that is not part " "of the TensorRT region!")

    shape = kwargs["weight"].shape  # type: ignore[union-attr]
    broadcasted_shape = (1,) * (len(input_val.shape) - len(shape)) + shape
    gamma = to_numpy(kwargs["weight"].reshape(*shape))  # type: ignore[union-attr]
    beta = to_numpy(kwargs["bias"].reshape(*shape))  # type: ignore[union-attr]
@@ -694,13 +612,11 @@
    axes = 0
    for d in range(len(shape)):
        axes |= 1 << (len(input_val.shape) - d - 1)

    # E[x]
-    mean_expected_layer = network.add_reduce(
-        input_val, trt.ReduceOperation.AVG, axes, keep_dims=True
-    )
+    mean_expected_layer = network.add_reduce(input_val, trt.ReduceOperation.AVG, axes, keep_dims=True)
    set_layer_name(mean_expected_layer, target, f"{name}_mean_expected")

    # X-E[x]
    sub_trt = add_binary_elementwise_layer(
        network,
@@ -722,13 +638,11 @@
        pow_tensor.get_output(0),
        trt.ElementWiseOperation.POW,
        target,
        f"{name}_pow_var",
    )
-    mean_trt_layer = network.add_reduce(
-        pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True
-    )
+    mean_trt_layer = network.add_reduce(pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True)
    set_layer_name(mean_trt_layer, target, f"{name}_mean")
    # Variance + eps
    eps_tensor = network.add_constant(
        (1,) * len(input_val.shape),
        trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
@@ -741,13 +655,11 @@
        trt.ElementWiseOperation.SUM,
        target,
        f"{name}_add",
    )
    # SQRT((Var + eps))
-    sqrt_trt = add_unary_layer(
-        network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt"
-    )
+    sqrt_trt = add_unary_layer(network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt")
    # (x - E[x]) / sqrt((var + eps))
    div_trt = add_binary_elementwise_layer(
        network,
        sub_trt,
        sqrt_trt,
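
This fallback builds layer norm out of reduce/elementwise/unary layers. The reference computation it mirrors, as a NumPy sketch (eps and axis are illustrative defaults):

import numpy as np

def layer_norm_ref(x, gamma, beta, eps=1e-5, axis=-1):
    mean = x.mean(axis=axis, keepdims=True)                  # E[x]
    var = ((x - mean) ** 2).mean(axis=axis, keepdims=True)   # mean of (x - E[x])^2
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

out = layer_norm_ref(np.random.randn(2, 4), np.ones(4), np.zeros(4))
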
@@ -791,14 +703,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"softmax received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"softmax received input {input_val} that is not part " "of the TensorRT region!")

    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
    def get_softmax_dim(ndim: int) -> int:
        if ndim == 0 or ndim == 1 or ndim == 3:
            ret = 0
@@ -832,13 +741,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    input_val = get_trt_tensor(network, input_t, f"{name}_input")

    dims = tuple(cast(Sequence[int], kwargs["dims"]))
-    n_input_dims = len(input_val.shape) + (
-        1 if network.has_implicit_batch_dimension else 0
-    )
+    n_input_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)

    if len(dims) > n_input_dims:
        assert not network.has_implicit_batch_dimension
        layer = network.add_shuffle(input_val)
        layer.name = f"{name}_reshape"
@@ -849,20 +756,16 @@
            input_shape_layer.name = f"{name}_input_shape"
            preceding_ones = network.add_constant(
                (num_preceding_ones,),
                np.ascontiguousarray([1] * num_preceding_ones, np.int32),
            ).get_output(0)
-            reshape_layer = network.add_concatenation(
-                [preceding_ones, input_shape_layer.get_output(0)]
-            )
+            reshape_layer = network.add_concatenation([preceding_ones, input_shape_layer.get_output(0)])
            reshape_layer.axis = 0
            reshape_layer.name = f"{name}_reshape_dims"
            layer.set_input(1, reshape_layer.get_output(0))
        else:
-            layer.reshape_dims = (1,) * (len(dims) - n_input_dims) + tuple(
-                input_val.shape
-            )
+            layer.reshape_dims = (1,) * (len(dims) - n_input_dims) + tuple(input_val.shape)
        input_val = layer.get_output(0)
    else:
        dims = (1,) * (n_input_dims - len(dims)) + dims

    if network.has_implicit_batch_dimension:
@@ -898,17 +801,15 @@
    layer = network.add_slice(input_val, starts, shapes, strides)
    layer.mode = trt.SliceMode.WRAP
    set_layer_name(layer, target, name)

    if has_dynamic_shape(input_val.shape):  # type: ignore[union-attr]
-        starts_tensor = network.add_constant(
-            (len(dims),), np.ascontiguousarray([0] * len(dims), np.int32)
-        ).get_output(0)
+        starts_tensor = network.add_constant((len(dims),), np.ascontiguousarray([0] * len(dims), np.int32)).get_output(
+            0
+        )
        if all(isinstance(d, int) for d in dims):
-            dims_tensor = network.add_constant(
-                (len(dims),), np.ascontiguousarray(dims, np.int32)
-            ).get_output(0)
+            dims_tensor = network.add_constant((len(dims),), np.ascontiguousarray(dims, np.int32)).get_output(0)
        else:
            assert all(isinstance(d, TRTTensor) for d in dims)
            concat_dims_layer = network.add_concatenation(inputs=dims)
            concat_dims_layer.axis = 0
            concat_dims_layer.name = f"{name}_tile_dim"
@@ -969,13 +870,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
    negative_slope = kwargs["negative_slope"]
    operation_type = trt.ActivationType.LEAKY_RELU
-    return add_activation_layer(
-        network, input_val, operation_type, target, name, negative_slope
-    )
+    return add_activation_layer(network, input_val, operation_type, target, name, negative_slope)


@tensorrt_converter(acc_ops.elu)
def acc_ops_elu(
    network: TRTNetwork,
@@ -1243,51 +1142,40 @@
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
-    return add_reduce_layer(
-        network, target, args, kwargs, trt.ReduceOperation.SUM, name
-    )
+    return add_reduce_layer(network, target, args, kwargs, trt.ReduceOperation.SUM, name)


@tensorrt_converter(acc_ops.prod)
def acc_ops_prod(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
-    return add_reduce_layer(
-        network, target, args, kwargs, trt.ReduceOperation.PROD, name
-    )
+    return add_reduce_layer(network, target, args, kwargs, trt.ReduceOperation.PROD, name)


@tensorrt_converter(acc_ops.mean)
def acc_ops_mean(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
-    return add_reduce_layer(
-        network, target, args, kwargs, trt.ReduceOperation.AVG, name
-    )
+    return add_reduce_layer(network, target, args, kwargs, trt.ReduceOperation.AVG, name)


def add_acc_ops_full_reduce(network, target, args, kwargs, name, reduce_op):
    input_val = kwargs["input"]
    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"max received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-    assert (
-        not network.has_implicit_batch_dimension
-    ), "Do not support max over all the elements for implicit batch."
+        raise RuntimeError(f"max received input {input_val} that is not part " "of the TensorRT region!")
+    assert not network.has_implicit_batch_dimension, "Do not support max over all the elements for implicit batch."

    dim = range(len(input_val.shape))

    layer = network.add_reduce(
        input_val,
@@ -1307,25 +1195,21 @@
        new_kwargs["largest"] = True
    elif reduce_op == trt.ReduceOperation.MIN:
        new_kwargs["largest"] = False
    new_kwargs["sorted"] = False

-    topk_out0, topk_out1 = acc_ops_topk(
-        network, target, args, new_kwargs, name + "_topk"
-    )
+    topk_out0, topk_out1 = acc_ops_topk(network, target, args, new_kwargs, name + "_topk")

    topk_out0.name = f"{name}_topk0"
    topk_out1.name = f"{name}_topk1"

    if "keepdim" in new_kwargs and new_kwargs["keepdim"]:
        return topk_out0, topk_out1

    dim = new_kwargs["dim"]
    if network.has_implicit_batch_dimension:
-        assert (
-            dim != 0
-        ), "can't reduce on dim == 0 when network has implicit batch dimension"
+        assert dim != 0, "can't reduce on dim == 0 when network has implicit batch dimension"
        # we remove the first dim in the shape tuple when it is implicit
        dim -= 1
    input_val = topk_out0
    shape = input_val.shape
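
The dim reduce here rides on top-k with k=1, then drops the reduced dim unless keepdim is set. The equivalence, sketched in plain PyTorch with an arbitrary tensor:

import torch

x = torch.tensor([[1.0, 3.0, 2.0], [5.0, 4.0, 6.0]])
values, indices = torch.topk(x, k=1, dim=1)
assert torch.equal(values.squeeze(1), x.max(dim=1).values)    # tensor([3., 6.])
assert torch.equal(indices.squeeze(1), x.max(dim=1).indices)  # tensor([1, 2])
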

@@ -1355,52 +1239,44 @@
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return add_acc_ops_full_reduce(
-        network, target, args, kwargs, name, trt.ReduceOperation.MAX
-    )
+    return add_acc_ops_full_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MAX)


@tensorrt_converter(acc_ops.min_full_reduce, no_implicit_batch_dim=True)
def acc_ops_min_full_reduce(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return add_acc_ops_full_reduce(
-        network, target, args, kwargs, name, trt.ReduceOperation.MIN
-    )
+    return add_acc_ops_full_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MIN)


@tensorrt_converter(acc_ops.max_dim_reduce)
def acc_ops_max_dim_reduce(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return add_acc_ops_dim_reduce(
-        network, target, args, kwargs, name, trt.ReduceOperation.MAX
-    )
+    return add_acc_ops_dim_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MAX)


@tensorrt_converter(acc_ops.min_dim_reduce)
def acc_ops_min_dim_reduce(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return add_acc_ops_dim_reduce(
-        network, target, args, kwargs, name, trt.ReduceOperation.MIN
-    )
+    return add_acc_ops_dim_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MIN)


@tensorrt_converter(acc_ops.maximum)
def acc_ops_maximum(
    network: TRTNetwork,
@@ -1503,32 +1379,24 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `logical_and` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `logical_and` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]
    # we only support the case where both inputs are of bool type
    if target == acc_ops.bitwise_and:

        def check_is_bool(input_t):
            if isinstance(input_t, TRTTensor):
-                assert (
-                    input_t.dtype == trt.bool
-                ), "We currently do not support non-bool inputs"
+                assert input_t.dtype == trt.bool, "We currently do not support non-bool inputs"
            elif isinstance(input_t, torch.Tensor):
-                assert (
-                    input_t.dtype == torch.bool
-                ), "We currently do not support non-bool inputs"
+                assert input_t.dtype == torch.bool, "We currently do not support non-bool inputs"
            else:
-                assert isinstance(
-                    input_t, bool
-                ), "We currently do not support non-bool inputs"
+                assert isinstance(input_t, bool), "We currently do not support non-bool inputs"

        check_is_bool(input_t)
        check_is_bool(other_t)

    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
@@ -1536,13 +1404,11 @@

    if input_t.dtype != trt.bool:
        input_t = type_cast(network, target, f"{name}_input", input_t, trt.bool)
    if other_t.dtype != trt.bool:
        other_t = type_cast(network, target, f"{name}_other", other_t, trt.bool)
-    return add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.AND, target, name
-    )
+    return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.AND, target, name)


@tensorrt_converter(acc_ops.ne, no_implicit_batch_dim=True)
def acc_ops_ne(
    network: TRTNetwork,
@@ -1550,24 +1416,20 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `ne` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `ne` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]

    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")

    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
-    eq_t = add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name
-    )
+    eq_t = add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name)

    return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name)


@tensorrt_converter(acc_ops.eq, no_implicit_batch_dim=True)
@@ -1577,24 +1439,20 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `eq` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `eq` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]

    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")

    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
-    return add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name
-    )
+    return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name)


@tensorrt_converter(acc_ops.gt, no_implicit_batch_dim=True)
def acc_ops_gt(
    network: TRTNetwork,
@@ -1602,24 +1460,20 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `gt` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `gt` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]

    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")

    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
-    return add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name
-    )
+    return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name)


@tensorrt_converter(acc_ops.lt, no_implicit_batch_dim=True)
def acc_ops_lt(
    network: TRTNetwork,
@@ -1627,24 +1481,20 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `le` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `le` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]

    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")

    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
-    return add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name
-    )
+    return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name)


@tensorrt_converter(acc_ops.logical_or, no_implicit_batch_dim=True)
def acc_ops_logical_or(
    network: TRTNetwork,
@@ -1652,13 +1502,11 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `logical_or` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `logical_or` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]
    if isinstance(other_t, (torch.Tensor, bool)):
        if isinstance(other_t, bool):
@@ -1675,13 +1523,11 @@
        layer_o = network.add_identity(other_t)
        layer_o.set_output_type(0, trt.bool)
        set_layer_name(layer_o, target, f"{name}_other_dtype_change")
        other_t = layer_o.get_output(0)

-    return add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.OR, target, name
-    )
+    return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.OR, target, name)


@tensorrt_converter(acc_ops.logical_xor, no_implicit_batch_dim=True)
def acc_ops_logical_xor(
    network: TRTNetwork,
@@ -1689,13 +1535,11 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `logical_xor` function should be called with explicit batch dimension."
-        )
+        raise RuntimeError("The `logical_xor` function should be called with explicit batch dimension.")

    input_t = kwargs["input"]
    other_t = kwargs["other"]
    if isinstance(other_t, (torch.Tensor, bool)):
        if isinstance(other_t, bool):
@@ -1712,13 +1556,11 @@
        layer_o = network.add_identity(other_t)
        layer_o.set_output_type(0, trt.bool)
        set_layer_name(layer_o, target, f"{name}_other_dtype_change")
        other_t = layer_o.get_output(0)

-    return add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name
-    )
+    return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name)


# T113156424 Have some accuracy problems in hf_T5.
# [TRT] [W] Weights [name=isinf_1_inf_t]: Converted FP32 value in weights (either FP32 infinity or FP32 value outside FP16 range) to corresponding FP16 infinity. If this is not the desired behavior, please modify the weights or retrain with regularization to reduce the magnitude of the weights.
# @tensorrt_converter(acc_ops.isinf)
@@ -1764,26 +1606,19 @@
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    if not isinstance(input_t, TRTTensor):
-        raise RuntimeError(
-            f"isinf received input {input_t} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"isinf received input {input_t} that is not part " "of the TensorRT region!")

    if input_t.dtype in (trt.float32, trt.float16, trt.int32):
-        comp_t = torch.zeros(tuple([*input_t.shape])).to(
-            torch_dtype_from_trt(input_t.dtype)
-        )
+        comp_t = torch.zeros(tuple([*input_t.shape])).to(torch_dtype_from_trt(input_t.dtype))
        comp_t = get_trt_tensor(network, comp_t, f"{name}_comp_t")
        kwargs_new = {"input": input_t, "other": comp_t}
        eq_output = acc_ops_eq(network, target, None, kwargs_new, name + "_eq")
        kwargs_new = {"input": eq_output}
-        not_output = acc_ops_logical_not(
-            network, target, None, kwargs_new, name + "_not"
-        )
+        not_output = acc_ops_logical_not(network, target, None, kwargs_new, name + "_not")
    else:
        not_output = input_t
    # cast bool result to int
    int_output = type_cast(network, target, f"{name}_cast_int", not_output, trt.int32)
    # sum
@@ -1809,13 +1644,11 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # NOTE: TRT doesn't currently implement fmod so we need multiple operations to perform it
-    trunc_div_value = trunc_div(
-        kwargs["input"], kwargs["other"], network, target, name + "_trunc_div"
-    )
+    trunc_div_value = trunc_div(kwargs["input"], kwargs["other"], network, target, name + "_trunc_div")
    prod_value = add_binary_elementwise_layer(
        network,
        trunc_div_value,
        kwargs["other"],
        trt.ElementWiseOperation.PROD,
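
The decomposition being wired up here is fmod(a, b) = a - trunc(a / b) * b. A quick check in plain Python (values are arbitrary):

import math

a, b = 7.5, 2.0
assert math.fmod(a, b) == a - math.trunc(a / b) * b  # both give 1.5
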
@@ -1907,14 +1740,11 @@
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_trt = kwargs["input"]
    if not isinstance(input_trt, TRTTensor):
-        raise RuntimeError(
-            f"Max_pool1d received input {input_trt} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Max_pool1d received input {input_trt} that is not part " "of the TensorRT region!")

    # adds unsqueeze layer -> max pool 2d -> squeeze layer to emulate max pool 1d.
    unsqueeze_layer = network.add_shuffle(input=input_trt)
    unsqueeze_layer.reshape_dims = tuple([*input_trt.shape, 1])
    set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
@@ -1929,25 +1759,16 @@
    ceil_mode = kwargs["ceil_mode"]

    if len(stride) == 0 or stride[0] is None:
        stride = kernel_size

-    if any(
-        [
-            not isinstance(param, int)
-            for param in [kernel_size[0], stride[0], padding[0], dilation[0]]
-        ]
-    ):
-        raise RuntimeError(
-            f"Parameters kernel_size, stride, padding, and dilation should be of type int."
-        )
+    if any([not isinstance(param, int) for param in [kernel_size[0], stride[0], padding[0], dilation[0]]]):
+        raise RuntimeError(f"Parameters kernel_size, stride, padding, and dilation should be of type int.")
    if dilation[0] != 1:
        raise RuntimeError(f"Only support dilation=1 for maxpool, but got {dilation}")

-    max_pooling_layer = network.add_pooling(
-        input=input_trt, type=trt.PoolingType.MAX, window_size=(kernel_size[0], 1)
-    )
+    max_pooling_layer = network.add_pooling(input=input_trt, type=trt.PoolingType.MAX, window_size=(kernel_size[0], 1))
    max_pooling_layer.stride_nd = stride + (1,)
    max_pooling_layer.padding_nd = padding + (0,)
    set_layer_name(max_pooling_layer, target, name)

    if ceil_mode:
@@ -1969,14 +1790,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"MaxPool2d received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"MaxPool2d received input {input_val} that is not part " "of the TensorRT region!")
    extend_len = 2 if target == acc_ops.max_pool2d else 3
    kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], extend_len)
    stride = extend_attr_to_tuple(kwargs["stride"], extend_len)
    padding = extend_attr_to_tuple(kwargs["padding"], extend_len)
    dilation = extend_attr_to_tuple(kwargs["dilation"], extend_len)
@@ -1985,17 +1803,13 @@
    if len(stride) == 0 or stride[0] is None:
        stride = kernel_size

    ones = (1,) * extend_len
    if dilation != ones:
-        raise RuntimeError(
-            f"Only support dilation=(1, 1) for maxpool, but got {dilation}"
-        )
-
-    layer = network.add_pooling_nd(
-        input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size
-    )
+        raise RuntimeError(f"Only support dilation=(1, 1) for maxpool, but got {dilation}")
+
+    layer = network.add_pooling_nd(input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size)
    layer.stride_nd = stride
    layer.padding_nd = padding
    set_layer_name(layer, target, name)

    if ceil_mode:
@@ -2013,23 +1827,18 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"squeeze received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"squeeze received input {input_val} that is not part " "of the TensorRT region!")

    dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
    # dim, which is a very rare case. For now we just claim not supporting dim=None.
    assert dim is not None, "We don't support dim=None right now for squeeze."

-    dim = get_positive_dim(
-        dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
-    )
+    dim = get_positive_dim(dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0))
    if network.has_implicit_batch_dimension:
        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
        dim -= 1

    assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
@@ -2176,35 +1985,26 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"unsqueeze received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"unsqueeze received input {input_val} that is not part " "of the TensorRT region!")

    dim = cast(int, kwargs["dim"])
    input_shape = input_val.shape
-    input_shape_size = (
-        len(input_val.shape) + 1
-        if network.has_implicit_batch_dimension
-        else len(input_val.shape)
-    )
+    input_shape_size = len(input_val.shape) + 1 if network.has_implicit_batch_dimension else len(input_val.shape)
    dim = get_positive_dim(dim, input_shape_size + 1)

    if network.has_implicit_batch_dimension:
        assert dim != 0
        dim -= 1

    assert (
        len(get_dynamic_dims(input_val.shape)) <= 1
    ), "Currently we don't support unsqueeze with more than one dynamic dims."
    layer = network.add_shuffle(input_val)
-    layer.reshape_dims = (
-        tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
-    )
+    layer.reshape_dims = tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
    set_layer_name(layer, target, name)
    return layer.get_output(0)


@tensorrt_converter(acc_ops.topk)
@@ -2216,14 +2016,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"topk received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"topk received input {input_val} that is not part " "of the TensorRT region!")

    if kwargs["sorted"] and kwargs["k"] != 1:
        raise RuntimeError("Currently we don't support sorted=True in topk.")

    if not network.has_implicit_batch_dimension and len(input_val.shape) <= 1:
@@ -2253,40 +2050,28 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"AdaptiveAvgPool2d received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"AdaptiveAvgPool2d received input {input_val} that is not part " "of the TensorRT region!")

    extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3
    assert all(
        input_val.shape[-(i + 1)] != -1 for i in range(extend_len)
    ), "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."

-    output_size = cast(
-        Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len)
-    )
+    output_size = cast(Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len))
    for input_dim, output_dim in zip(input_val.shape[-extend_len:], output_size):
        if input_dim % output_dim != 0:
            raise RuntimeError(
                "For AdaptiveAvgPool, input dim has to be integer multiple of output dim."
                f"Got input dim {input_dim}, output dim {output_dim}"
            )

-    stride = tuple(
-        input_val.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
-    )
-    kernel_size = tuple(
-        input_val.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
-        for i in range(extend_len)
-    )
-    layer = network.add_pooling_nd(
-        input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
-    )
+    stride = tuple(input_val.shape[-extend_len + i] // output_size[i] for i in range(extend_len))
+    kernel_size = tuple(input_val.shape[-extend_len + i] - (output_size[i] - 1) * stride[i] for i in range(extend_len))
+    layer = network.add_pooling_nd(input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
    layer.stride_nd = stride
    set_layer_name(layer, target, name)

    return layer.get_output(0)
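
The stride/kernel derivation above, worked through for one dim (assuming input size I divisible by output size O): stride = I // O and kernel = I - (O - 1) * stride, which tiles the input into O non-overlapping windows. A PyTorch check with made-up sizes:

import torch

x = torch.arange(8.0).view(1, 1, 8)     # I = 8, O = 4 -> stride 2, kernel 2
adaptive = torch.nn.AdaptiveAvgPool1d(4)(x)
manual = torch.nn.AvgPool1d(kernel_size=2, stride=2)(x)
assert torch.allclose(adaptive, manual)
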

@@ -2300,14 +2085,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"AvgPool1d received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"AvgPool1d received input {input_val} that is not part " "of the TensorRT region!")

    kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], 1)
    stride = extend_attr_to_tuple(kwargs["stride"], 1)
    padding = extend_attr_to_tuple(kwargs["padding"], 1)
    ceil_mode = kwargs["ceil_mode"]
@@ -2319,13 +2101,11 @@
    shuffle_layer = network.add_shuffle(input_val)
    shuffle_layer.reshape_dims = tuple(input_val.shape) + (1,)
    set_layer_name(shuffle_layer, target, name + "_shuffle1")
    shuffle_out = shuffle_layer.get_output(0)

-    layer = network.add_pooling_nd(
-        input=shuffle_out, type=trt.PoolingType.AVERAGE, window_size=(kernel_size[0], 1)
-    )
+    layer = network.add_pooling_nd(input=shuffle_out, type=trt.PoolingType.AVERAGE, window_size=(kernel_size[0], 1))

    layer.stride_nd = stride + (1,)
    layer.padding_nd = padding + (0,)
    layer.average_count_excludes_padding = False if count_include_pad else True
    set_layer_name(layer, target, name)
@@ -2349,14 +2129,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"AvgPool2d received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"AvgPool2d received input {input_val} that is not part " "of the TensorRT region!")

    kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], 2)
    stride = extend_attr_to_tuple(kwargs["stride"], 2)
    padding = extend_attr_to_tuple(kwargs["padding"], 2)
    ceil_mode = kwargs["ceil_mode"]
@@ -2367,13 +2144,11 @@
        stride = kernel_size

    if divisor_override:
        raise RuntimeError("TensorRT does not support divisor_override.")

-    layer = network.add_pooling(
-        input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
-    )
+    layer = network.add_pooling(input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
    layer.stride = stride
    layer.padding = padding
    layer.average_count_excludes_padding = not count_include_pad
    set_layer_name(layer, target, name)

@@ -2433,23 +2208,18 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"slice_tensor received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"slice_tensor received input {input_val} that is not part " "of the TensorRT region!")

    ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
    dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
    dynamic_shape = has_dynamic_shape(input_val.shape)
    if network.has_implicit_batch_dimension:
        if dim == 0:
-            raise RuntimeError(
-                f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
-            )
+            raise RuntimeError(f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!")
        dim = dim - 1
    else:
        if dynamic_shape:
            # Check whether slice target dim is dynamic shape dim
            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
@@ -2463,13 +2233,11 @@
    stride[dim] = step_int
    output_shape = list(input_val.shape)
    output_shape[dim] = (stop_int - start_int) // step_int

    if dynamic_shape:
-        output_shape = get_shape_with_dynamic_shape(
-            network, output_shape, input_val, target, name
-        )
+        output_shape = get_shape_with_dynamic_shape(network, output_shape, input_val, target, name)
    layer = network.add_slice(
        input_val,
        start=start,
        shape=[] if dynamic_shape else output_shape,
        stride=stride,
@@ -2502,13 +2270,11 @@
    shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]

    inshape = tuple(input_val.shape)
    shape = tuple(shape)
    start = tuple([0] * ranks)
-    stride = tuple(
-        [int(i == o) for i, o in zip(inshape, shape)]
-    )  # stride == 1 if dimensions match, 0 otherwise
+    stride = tuple([int(i == o) for i, o in zip(inshape, shape)])  # stride == 1 if dimensions match, 0 otherwise
    layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
    set_layer_name(layer, target, name)
    return layer.get_output(0)
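Side note: the `int(i == o)` stride above is the usual zero-stride broadcast trick: a slice with stride 0 along a dim re-reads the same element, which is exactly what torch.Tensor.expand does. A quick illustration of the intended semantics (PyTorch only, hypothetical shapes):

import torch

x = torch.randn(1, 3)
y = x.expand(4, 3)       # no copy: the expanded dim gets stride 0
assert y.stride(0) == 0  # matches the stride the slice layer computes above
assert torch.equal(y[0], y[3])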


@@ -2615,13 +2381,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    mask_t = kwargs["mask"]
    value_t = kwargs["value"]
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "We don't support masked_fill with implicit batch dimension due to select layer!"
-        )
+        raise RuntimeError("We don't support masked_fill with implicit batch dimension due to select layer!")

    shape = list(input_t.shape)
    mask_shape = list(mask_t.shape)

    assert type(value_t) in (
@@ -2674,14 +2438,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"split received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"split received input {input_val} that is not part " "of the TensorRT region!")

    dim = cast(int, kwargs["dim"])
    dynamic_shape = has_dynamic_shape(input_val.shape)
    if network.has_implicit_batch_dimension:
        assert dim != 0, "Can't split on batch dim when it's implicit!"
@@ -2695,28 +2456,22 @@
    start = [0] * len(input_val.shape)
    stride = [1] * len(start)
    offset = 0
    num_splits = (input_val.shape[dim] + split_size - 1) // split_size
    if num_splits < 1:
-        raise RuntimeError(
-            f"Invalid split: {input_val.shape[dim]} with split_size={split_size}"
-        )
+        raise RuntimeError(f"Invalid split: {input_val.shape[dim]} with split_size={split_size}")

    max_offset = input_val.shape[dim]
    # add slice layers
    output = []
    for i in range(num_splits):
        shape = list(input_val.shape)
        shape[dim] = min(split_size, cast(int, max_offset - offset))
        start[dim] = offset
        if dynamic_shape:
-            shape = get_shape_with_dynamic_shape(
-                network, shape, input_val, target, f"{name}_shape_{i}"
-            )
-        layer = network.add_slice(
-            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
-        )
+            shape = get_shape_with_dynamic_shape(network, shape, input_val, target, f"{name}_shape_{i}")
+        layer = network.add_slice(input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride)
        if dynamic_shape:
            layer.set_input(2, shape)
        offset += split_size
        set_layer_name(layer, target, f"{name}_{i}")
        output.append(layer.get_output(0))
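For context, the ceil-division above sizes the split outputs so that every slice carries split_size elements except a possibly smaller final one. The same logic in plain Python (sizes only, hypothetical helper):

def split_sizes(dim_size: int, split_size: int):
    num_splits = (dim_size + split_size - 1) // split_size  # ceil division, as above
    return [min(split_size, dim_size - i * split_size) for i in range(num_splits)]

assert split_sizes(10, 4) == [4, 4, 2]  # mirrors the section sizes of torch.split(torch.empty(10), 4)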
@@ -2732,19 +2487,15 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Linear received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Linear received input {input_val} that is not part " "of the TensorRT region!")

    dynamic_dims = get_dynamic_dims(input_val.shape)
    assert len(dynamic_dims) < 2 and input_val.shape[-1] != -1, (
-        "Currently we only support one dynmaic "
-        "dim for linear and it can't be the last dim."
+        "Currently we only support one dynmaic " "dim for linear and it can't be the last dim."
    )

    if isinstance(kwargs["weight"], torch.Tensor):
        weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
        weight_op = trt.MatrixOperation.NONE
@@ -2760,13 +2511,11 @@
        preset_diff -= 1
        input_op = trt.MatrixOperation.VECTOR
    else:
        input_op = trt.MatrixOperation.NONE

-    input_val, weight = broadcast(
-        network, input_val, weight, f"{name}_input", f"{name}_weight", preset_diff
-    )
+    input_val, weight = broadcast(network, input_val, weight, f"{name}_input", f"{name}_weight", preset_diff)
    matmul_layer = network.add_matrix_multiply(input_val, input_op, weight, weight_op)
    set_layer_name(matmul_layer, target, f"{name}_matmul")
    res = matmul_layer.get_output(0)

    if kwargs["bias"] is not None:
@@ -2782,16 +2531,11 @@
    return res


def add_clamp(network, input, val, op):
    acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
-    acc_ops_clamp_tensor = (
-        val
-        * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
-        .cpu()
-        .numpy()
-    )
+    acc_ops_clamp_tensor = val * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)).cpu().numpy()
    acc_ops_clamp_trt = network.add_constant(acc_ops_clamp_shape, acc_ops_clamp_tensor)
    layer = network.add_elementwise(input, acc_ops_clamp_trt.get_output(0), op)

    return layer
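Note: add_clamp builds clamp from elementwise MAX/MIN against a broadcast constant, relying on clamp(x, lo, hi) == minimum(maximum(x, lo), hi). A quick numerical check in plain PyTorch:

import torch

x = torch.tensor([-2.0, 0.5, 3.0])
lo, hi = -1.0, 1.0
expected = torch.clamp(x, lo, hi)
actual = torch.minimum(torch.maximum(x, torch.tensor(lo)), torch.tensor(hi))
assert torch.equal(actual, expected)  # both give [-1.0, 0.5, 1.0]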

@@ -2807,25 +2551,18 @@
    input_val = kwargs["input"]
    min_val = kwargs["min"]
    max_val = kwargs["max"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Clamp received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Clamp received input {input_val} that is not part " "of the TensorRT region!")

    if min_val is not None:
-        clamp_min_layer = add_clamp(
-            network, input_val, min_val, trt.ElementWiseOperation.MAX
-        )
+        clamp_min_layer = add_clamp(network, input_val, min_val, trt.ElementWiseOperation.MAX)
        set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
        input_val = clamp_min_layer.get_output(0)
    if max_val is not None:
-        clamp_max_layer = add_clamp(
-            network, input_val, max_val, trt.ElementWiseOperation.MIN
-        )
+        clamp_max_layer = add_clamp(network, input_val, max_val, trt.ElementWiseOperation.MIN)
        set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
        input_val = clamp_max_layer.get_output(0)

    return input_val

@@ -2883,30 +2620,22 @@

    def slice_to_trt_params(py_slice, dim_size):
        """
        Convert python slice to TensorRT slice layer parameters.
        """
-        start = (
-            get_positive_dim(py_slice.start, dim_size) if py_slice.start != None else 0
-        )
+        start = get_positive_dim(py_slice.start, dim_size) if py_slice.start is not None else 0
        stride = py_slice.step if py_slice.step is not None else 1
-        stop = (
-            get_positive_dim(py_slice.stop, dim_size)
-            if py_slice.stop != None
-            else dim_size
-        )
+        stop = get_positive_dim(py_slice.stop, dim_size) if py_slice.stop is not None else dim_size
        size = math.ceil((stop - start) * 1.0 / stride)
        return start, size, stride

    if network.has_implicit_batch_dimension:
        # Raise an error if it's trying to subscript batch dimension unless it's
        # slice(None, None, None).
        batch_subscript = slices[0]
        if batch_subscript not in [slice(None, None, None), slice(0, None, None)]:
-            raise RuntimeError(
-                f"{name}: Can't subscript batch dimension when it's implicit. Got {slices}"
-            )
+            raise RuntimeError(f"{name}: Can't subscript batch dimension when it's implicit. Got {slices}")

        # Remove batch_dim subscript
        slices = slices[1:]

    # Replace ellipsis with expanded slices.
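Reviewer note: slice_to_trt_params maps a Python slice onto TensorRT's (start, size, stride) triple. A standalone sketch with the get_positive_dim normalization stubbed out (assumes non-negative indices, hypothetical values):

import math

def slice_to_trt_params(py_slice: slice, dim_size: int):
    start = py_slice.start if py_slice.start is not None else 0
    stride = py_slice.step if py_slice.step is not None else 1
    stop = py_slice.stop if py_slice.stop is not None else dim_size
    size = math.ceil((stop - start) / stride)
    return start, size, stride

assert slice_to_trt_params(slice(1, 7, 2), 10) == (1, 3, 2)  # picks elements 1, 3, 5
assert slice_to_trt_params(slice(None, None, None), 10) == (0, 10, 1)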
@@ -2995,13 +2724,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    tensors = kwargs["tensors"]
    dim = kwargs["dim"]

    if any(not isinstance(t, TRTTensor) for t in tensors):  # type: ignore[union-attr]
-        raise RuntimeError(
-            f"cat received inputs {tensors} that is not part " "of the TensorRT region!"
-        )
+        raise RuntimeError(f"cat received inputs {tensors} that is not part " "of the TensorRT region!")
    layer = network.add_concatenation(inputs=tensors)
    if dim < 0:
        if network.has_implicit_batch_dimension:
            dim = len(tensors[0].shape) + 1 + dim
        else:
@@ -3023,13 +2750,11 @@
    input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
    other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other")

    for i in [input_val, other_val]:
        if not isinstance(i, TRTTensor):
-            raise RuntimeError(
-                f"matmul received input {i} that is not part of the TensorRT region!"
-            )
+            raise RuntimeError(f"matmul received input {i} that is not part of the TensorRT region!")

    input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
    preset_diff = 0

    if len(input_val.shape) == 1:
@@ -3038,16 +2763,12 @@

    if len(other_val.shape) == 1:
        preset_diff += 1
        other_matrix_op = trt.MatrixOperation.VECTOR

-    input_val, other_val = broadcast(
-        network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff
-    )
-    layer = network.add_matrix_multiply(
-        input_val, input_matrix_op, other_val, other_matrix_op
-    )
+    input_val, other_val = broadcast(network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff)
+    layer = network.add_matrix_multiply(input_val, input_matrix_op, other_val, other_matrix_op)
    set_layer_name(layer, target, name)
    return layer.get_output(0)


@tensorrt_converter(acc_ops.hardsigmoid)
@@ -3059,14 +2780,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Hard sigmoid received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"Hard sigmoid received input {input_val} that is not part " "of the TensorRT region!")

    return add_activation_layer(
        network,
        input_val,
        trt.ActivationType.HARD_SIGMOID,
@@ -3086,18 +2804,13 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Sigmoid received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    return add_activation_layer(
-        network, input_val, trt.ActivationType.SIGMOID, target, name
-    )
+        raise RuntimeError(f"Sigmoid received input {input_val} that is not part " "of the TensorRT region!")
+
+    return add_activation_layer(network, input_val, trt.ActivationType.SIGMOID, target, name)


@tensorrt_converter(acc_ops.permute)
def acc_ops_permute(
    network: TRTNetwork,
@@ -3113,14 +2826,11 @@
    else:
        index = kwargs["permutation"]
    permutation = [get_positive_dim(i, ranks) for i in cast(Sequence[int], index)]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"permute received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"permute received input {input_val} that is not part " "of the TensorRT region!")

    if network.has_implicit_batch_dimension:
        assert permutation[0] == 0, "Can't permute batch dimension when it's implicit."
        permutation = [i - 1 for i in permutation[1:]]

@@ -3139,14 +2849,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"{name} received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"{name} received input {input_val} that is not part " "of the TensorRT region!")

    qparams = kwargs["acc_out_ty"].qparams  # type: ignore[misc]
    q_scale = qparams["scale"]
    q_zero_point = qparams["zero_point"]
    dtype = kwargs["acc_out_ty"].dtype  # type: ignore[misc]
@@ -3157,13 +2864,11 @@
        )

    if q_zero_point != 0:
        raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")

-    scale_layer = network.add_constant(
-        (1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32))
-    )
+    scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32)))
    scale_layer.name = input_val.name + ".per_tensor_quant.scale"
    scale = scale_layer.get_output(0)
    # assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
    # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
    layer = network.add_quantize(input=input_val, scale=scale)
@@ -3181,14 +2886,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"{name} received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"{name} received input {input_val} that is not part " "of the TensorRT region!")

    qparams = kwargs["acc_out_ty"].qparams  # type: ignore[misc]
    q_per_channel_scales = qparams["scale"]
    q_per_channel_zero_points = qparams["zero_point"]
    q_per_channel_axis = qparams["axis"]
@@ -3201,17 +2903,13 @@

    # Make sure zero_points are all 0 because only symmetric quantization
    # is supported in TensorRT
    if not torch.equal(
        q_per_channel_zero_points,
-        torch.zeros(
-            q_per_channel_zero_points.shape, dtype=q_per_channel_zero_points.dtype
-        ),
+        torch.zeros(q_per_channel_zero_points.shape, dtype=q_per_channel_zero_points.dtype),
    ):
-        raise RuntimeError(
-            f"Only support zero_point == 0, get {q_per_channel_zero_points}"
-        )
+        raise RuntimeError(f"Only support zero_point == 0, get {q_per_channel_zero_points}")

    if not torch.all(torch.ge(q_per_channel_scales, 0)):
        raise RuntimeError(f"All scale values must be >= 0, get {q_per_channel_scales}")

    scale_layer = network.add_constant(
@@ -3238,14 +2936,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
    input_val_tensor_meta = kwargs["_itensor_to_tensor_meta"][input_val]  # type: ignore[index]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"{name} received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"{name} received input {input_val} that is not part " "of the TensorRT region!")

    qparams = input_val_tensor_meta.qparams  # type: ignore[misc]
    qscheme = qparams["qscheme"]
    if qscheme == torch.per_tensor_affine:
        q_scale = qparams["scale"]
@@ -3256,30 +2951,25 @@
            raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")
    elif qscheme == torch.per_channel_affine:
        q_scale = qparams["scale"]
        q_zero_point = qparams["zero_point"]
        q_axis = qparams["axis"]
-        assert isinstance(
-            q_scale, immutable_list
-        ), "expected q_scale to be immutable_list got {}".format(type(q_scale))
+        assert isinstance(q_scale, immutable_list), "expected q_scale to be immutable_list got {}".format(type(q_scale))
        scale_shape = (len(q_scale),)
        if any(x != 0 for x in q_zero_point):
            raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")
    else:
        raise RuntimeError("Unsupported qscheme in dequantize: {qscheme}")

    dtype = input_val_tensor_meta.dtype  # type: ignore[misc]

    if dtype not in (torch.quint8, torch.qint8, torch.qint32):
        raise RuntimeError(
-            "Only support (torch.quint8, torch.qint8, torch.qint32) "
-            f"quantized type in dequantize, get {dtype}."
+            "Only support (torch.quint8, torch.qint8, torch.qint32) " f"quantized type in dequantize, get {dtype}."
        )

-    scale_layer = network.add_constant(
-        scale_shape, trt.Weights(np.ascontiguousarray(q_scale, dtype=np.float32))
-    )
+    scale_layer = network.add_constant(scale_shape, trt.Weights(np.ascontiguousarray(q_scale, dtype=np.float32)))
    scale_layer.name = input_val.name + ".dequant.scale"
    scale = scale_layer.get_output(0)
    # assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in "
    # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
    layer = network.add_dequantize(input=input_val, scale=scale)
@@ -3296,24 +2986,17 @@
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"GELU received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"GELU received input {input_val} that is not part " "of the TensorRT region!")
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "GeLU converter currently doesn't support implicit batch dimension"
-        )
+        raise RuntimeError("GeLU converter currently doesn't support implicit batch dimension")

    plugin_name = "CustomGeluPluginDynamic"
    # type_id 0 for float32, 1 for float16
-    type_id = trt.PluginField(
-        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
-    )
+    type_id = trt.PluginField("type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32)
    field_collection = TRTPluginFieldCollection([type_id])
    plugin_version = "1"

    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)

@@ -3334,14 +3017,11 @@
    chunks = cast(int, kwargs["chunks"])
    dim = cast(int, kwargs["dim"])
    input_dim_size = len(input_val.shape)  # type: ignore[union-attr]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"chunk received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"chunk received input {input_val} that is not part " "of the TensorRT region!")

    dynamic_shape = has_dynamic_shape(input_val.shape)
    if network.has_implicit_batch_dimension:
        input_dim_size += 1
        dim = get_positive_dim(dim, input_dim_size)
@@ -3371,17 +3051,13 @@
    output = []
    for i in range(chunks):
        shape = list(input_val.shape)
        shape[dim] = min(split_size, max_offset - offset)
        if dynamic_shape:
-            shape = get_shape_with_dynamic_shape(
-                network, shape, input_val, target, f"{name}_{i}"
-            )
+            shape = get_shape_with_dynamic_shape(network, shape, input_val, target, f"{name}_{i}")
        start[dim] = offset
-        layer = network.add_slice(
-            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
-        )
+        layer = network.add_slice(input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride)
        if dynamic_shape:
            layer.set_input(2, shape)
        offset += split_size
        set_layer_name(layer, target, f"{name}_{i}")
        output.append(layer.get_output(0))
@@ -3400,18 +3076,13 @@
    dim = cast(int, kwargs["dim"])
    input_shape = input_val.shape  # type: ignore[union-attr]
    input_dim_size = len(input_val.shape)  # type: ignore[union-attr]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"cumsum received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"cumsum received input {input_val} that is not part " "of the TensorRT region!")
    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "cumsum converter currently doesn't support implicit batch dimension"
-        )
+        raise RuntimeError("cumsum converter currently doesn't support implicit batch dimension")
    dim = get_positive_dim(dim, input_dim_size)
    loop = network.add_loop()
    trip_limit = None
    if input_shape[dim] > 0:
        axis = torch.tensor(input_shape[dim], dtype=torch.int32)
@@ -3427,13 +3098,11 @@
    loop.add_trip_limit(trip_limit, trt.TripLimit(0))
    iterator = loop.add_iterator(input_val, dim, False)
    data = iterator.get_output(0)
    new_dims = tuple(data.shape)
    zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype))
-    zero_tensor = network.add_constant(
-        zero_tensor.shape, to_numpy(zero_tensor)
-    ).get_output(0)
+    zero_tensor = network.add_constant(zero_tensor.shape, to_numpy(zero_tensor)).get_output(0)

    running_sum = loop.add_recurrence(zero_tensor)
    set_layer_name(running_sum, target, f"{name}_running_sum_1")
    running_sum_tensor = running_sum.get_output(0)
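For context, the loop/iterator/recurrence wiring above is a scan; stripped of the TensorRT plumbing, the recurrence computing cumsum is just:

def cumsum(xs):
    out, running = [], 0
    for x in xs:               # loop.add_iterator feeds one slice per trip
        running = running + x  # the recurrence value updated every iteration
        out.append(running)    # loop output concatenated along the scanned dim
    return out

assert cumsum([1, 2, 3]) == [1, 3, 6]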

@@ -3476,14 +3145,11 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"hardtanh received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"hardtanh received input {input_val} that is not part " "of the TensorRT region!")

    return add_activation_layer(
        network,
        input_val,
        trt.ActivationType.CLIP,
@@ -3507,26 +3173,19 @@
    scale_factor = kwargs["scale_factor"]
    mode = kwargs["mode"]
    align_corners = kwargs["align_corners"]

    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"interpolate received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+        raise RuntimeError(f"interpolate received input {input_val} that is not part " "of the TensorRT region!")

    dim = input_val.shape
    ranks = len(input_val.shape)
    if network.has_implicit_batch_dimension:
-        assert (
-            ranks >= 2 and ranks <= 4
-        ), "Interpolate expects inputs are 3D,4D,5D in shape"
+        assert 2 <= ranks <= 4, "Interpolate expects inputs to be 3D, 4D, or 5D in shape"
        ranks = ranks - 1
    else:
-        assert (
-            ranks >= 3 and ranks <= 5
-        ), "Interpolate expects inputs are 3D,4D,5D in shape"
+        assert 3 <= ranks <= 5, "Interpolate expects inputs to be 3D, 4D, or 5D in shape"
        ranks = ranks - 2

    layer = network.add_resize(input_val)
    if network.has_implicit_batch_dimension:
        if size is not None:
@@ -3555,13 +3214,11 @@
        layer.resize_mode = trt.ResizeMode.LINEAR
    else:
        layer.resize_mode = trt.ResizeMode.NEAREST

    if align_corners is not None:
-        layer.coordinate_transformation = (
-            trt.ResizeCoordinateTransformation.ALIGN_CORNERS
-        )
+        layer.coordinate_transformation = trt.ResizeCoordinateTransformation.ALIGN_CORNERS

    set_layer_name(layer, target, name)
    return layer.get_output(0)


@@ -3579,13 +3236,11 @@
    if dtype_val is None:
        dtype_val = input_val.dtype
        dtype_val = torch_dtype_from_trt(dtype_val)

    device_val = kwargs.get("device")
-    assert (
-        device_val == "cuda" or device_val == None
-    ), f"device is not `cuda` but {device_val}"
+    assert device_val == "cuda" or device_val == None, f"device is not `cuda` but {device_val}"

    weight = torch.ones(size_val, dtype=dtype_val)
    return get_trt_tensor(network, weight, f"{name}_weight")


@@ -3603,13 +3258,11 @@
    if dtype_val is None:
        dtype_val = input_val.dtype
        dtype_val = torch_dtype_from_trt(dtype_val)

    device_val = kwargs.get("device")
-    assert (
-        device_val == "cuda" or device_val == None
-    ), f"device is not `cuda` but {device_val}"
+    assert device_val == "cuda" or device_val == None, f"device is not `cuda` but {device_val}"

    weight = torch.zeros(size_val, dtype=dtype_val)
    return get_trt_tensor(network, weight, f"{name}_weight")


@@ -3634,13 +3287,11 @@
        input_val[i] = get_trt_tensor(network, input_source, name + f"_input_source{i}")

    if const_flag:
        for i, input_source in enumerate(input_val):
            if input_source.dtype != trt.float32:
-                input_val[i] = type_cast(
-                    network, target, f"{name}_input_cast{i}", input_source, trt.float32
-                )
+                input_val[i] = type_cast(network, target, f"{name}_input_cast{i}", input_source, trt.float32)
    einsum_layer = network.add_einsum(inputs=input_val, equation=equation)
    return einsum_layer.get_output(0)


@tensorrt_converter(acc_ops.as_strided)
--- py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py	2022-08-12 19:20:00.656595 +0000
@@ -70,13 +70,11 @@
            inputs,
            expected_ops={expected_acc_op},
            test_implicit_batch_dim=(dim != 0),
        )

-    @parameterized.expand(
-        [(f"{acc_ops.prod.__name__}_no_dim_no_keepdim", torch.prod, acc_ops.prod)]
-    )
+    @parameterized.expand([(f"{acc_ops.prod.__name__}_no_dim_no_keepdim", torch.prod, acc_ops.prod)])
    def test_prod_all_dims(
        self,
        test_name,
        op,
        expected_acc_op,
@@ -107,12 +105,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            Prod(), input_specs, expected_ops={acc_ops.prod}
-        )
+        self.run_test_with_dynamic_shape(Prod(), input_specs, expected_ops={acc_ops.prod})


if __name__ == "__main__":
    run_tests()
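A note on the shape_ranges triples that recur in these dynamic-shape tests: each entry is a (min, opt, max) shape for one TensorRT optimization profile, and -1 marks the dynamic dims in the companion shape. A hypothetical spec mirroring the pattern above (import path assumed from this repo's layout):

import torch
from torch_tensorrt.fx.input_tensor_spec import InputTensorSpec

spec = InputTensorSpec(
    shape=(-1, -1, -1, -1),  # all four dims dynamic
    dtype=torch.float32,
    shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],  # (min, opt, max)
)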
--- py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py	2022-08-12 19:20:00.693911 +0000
@@ -50,16 +50,11 @@
            inputs,
            expected_ops={expected_acc_op},
            test_implicit_batch_dim=(dim != 0),
        )

-    @parameterized.expand(
-        [
-            (f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op)
-            for op, acc_op in reduce_ops
-        ]
-    )
+    @parameterized.expand([(f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op) for op, acc_op in reduce_ops])
    def test_reduce_all_dims(
        self,
        test_name,
        op,
        expected_acc_op,
@@ -74,16 +69,11 @@
            inputs,
            expected_ops={expected_acc_op},
            test_implicit_batch_dim=False,
        )

-    @parameterized.expand(
-        [
-            (f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op)
-            for op, acc_op in reduce_ops
-        ]
-    )
+    @parameterized.expand([(f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op) for op, acc_op in reduce_ops])
    def test_reduce_all_dims_with_dynamic_shape_four_dimensions(
        self,
        test_name,
        op,
        expected_acc_op,
@@ -97,12 +87,10 @@
                shape=(-1, -1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]
-        self.run_test_with_dynamic_shape(
-            Reduce(), input_specs, expected_ops={expected_acc_op}
-        )
+        self.run_test_with_dynamic_shape(Reduce(), input_specs, expected_ops={expected_acc_op})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py	2022-08-12 19:20:00.870688 +0000
@@ -26,14 +26,11 @@
        inputs = [torch.randn(*input_shape)]
        self.run_test(
            Tile(dims),
            inputs,
            expected_ops={acc_ops.tile},
-            test_implicit_batch_dim=(
-                len(input_shape) > len(dims)
-                or (len(input_shape) == len(dims) and dims[0] == 1)
-            ),
+            test_implicit_batch_dim=(len(input_shape) > len(dims) or (len(input_shape) == len(dims) and dims[0] == 1)),
        )

    @parameterized.expand(
        [
            ("same_num_dims", (-1, 2, 3), (1, 2, 2)),
@@ -62,13 +59,11 @@
                        tuple(i if i != -1 else 3 for i in shape),
                    )
                ],
            ),
        ]
-        self.run_test_with_dynamic_shape(
-            Tile(dims), input_specs, expected_ops={acc_ops.tile}
-        )
+        self.run_test_with_dynamic_shape(Tile(dims), input_specs, expected_ops={acc_ops.tile})

    @parameterized.expand(
        [
            ("all_dynamic_dim", (-1, -1), (1, 2, 2, 1)),
        ]
@@ -88,13 +83,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            Tile(dims), input_specs, expected_ops={acc_ops.tile}
-        )
+        self.run_test_with_dynamic_shape(Tile(dims), input_specs, expected_ops={acc_ops.tile})

    def test_tile_non_int_dims(self):
        class Tile(nn.Module):
            def __init__(self):
                super().__init__()
@@ -103,13 +96,11 @@
                y = y * 2
                return torch.tile(x, (1, y.shape[1], y.shape[1]))

        inputs = [torch.randn(2, 2, 3), torch.randn(2, 2, 3)]
        batch_size_range = (1, 2, 3)
-        input_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
-            inputs, batch_size_range
-        )
+        input_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(inputs, batch_size_range)
        self.run_test_with_dynamic_shape(
            Tile(),
            input_specs,
            expected_ops={acc_ops.tile},
        )
@@ -134,12 +125,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            Tile(), input_specs, expected_ops={acc_ops.tile}
-        )
+        self.run_test_with_dynamic_shape(Tile(), input_specs, expected_ops={acc_ops.tile})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py	2022-08-12 19:20:00.924215 +0000
@@ -24,13 +24,11 @@
                self.dim = dim
                self.largest = largest

            def forward(self, x):
                if self.dim is not None:
-                    out = torch.topk(
-                        x, k=self.k, dim=self.dim, largest=self.largest, sorted=False
-                    )
+                    out = torch.topk(x, k=self.k, dim=self.dim, largest=self.largest, sorted=False)
                else:
                    out = torch.topk(x, k=self.k, largest=self.largest, sorted=False)
                return out[0], out[1]

        inputs = [torch.randn(1, 2, 3, 4)]
@@ -58,13 +56,11 @@
                self.dim = dim
                self.largest = largest

            def forward(self, x):
                if self.dim is not None:
-                    out = torch.topk(
-                        x, k=self.k, dim=self.dim, largest=self.largest, sorted=False
-                    )
+                    out = torch.topk(x, k=self.k, dim=self.dim, largest=self.largest, sorted=False)
                else:
                    out = torch.topk(x, k=self.k, largest=self.largest, sorted=False)
                return out[0], out[1]

        input_specs = [
@@ -73,12 +69,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TopK(k, dim), input_specs, expected_ops={acc_ops.topk}
-        )
+        self.run_test_with_dynamic_shape(TopK(k, dim), input_specs, expected_ops={acc_ops.topk})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py	2022-08-12 19:20:00.924826 +0000
@@ -51,13 +51,11 @@

        input = torch.randn(2, 2).to(torch.float16)
        inputs = [
            input,
        ]
-        self.run_test(
-            To(), inputs, expected_ops={acc_ops.to_dtype}, test_implicit_batch_dim=False
-        )
+        self.run_test(To(), inputs, expected_ops={acc_ops.to_dtype}, test_implicit_batch_dim=False)

    def test_cuda_fp16(self):
        class To(torch.nn.Module):
            def forward(self, x):
                return x.to(torch.device("cuda:0"), torch.float16)
@@ -106,13 +104,11 @@
                dtype=torch.float16,
                shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add}
-        )
+        self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add})

    def test_device(self):
        class To(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -152,13 +148,11 @@
                dtype=torch.float16,
                shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add}
-        )
+        self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add})

    def test_device_fp16(self):
        class To(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -244,13 +238,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            To(), input_specs, expected_ops={acc_ops.to_dtype}
-        )
+        self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype})

    # Half is not suitable for dynamic shape
    # Error: assert engine

    # tensor.half()
@@ -307,12 +299,10 @@
                dtype=torch.int,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            To(), input_specs, expected_ops={acc_ops.to_dtype}
-        )
+        self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py	2022-08-12 19:20:01.054432 +0000
@@ -62,13 +62,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(orig_op), input_specs, expected_ops={expected_op}
-        )
+        self.run_test_with_dynamic_shape(TestModule(orig_op), input_specs, expected_ops={expected_op})


class TestUnaryOpNotConverters(AccTestCase):
    @parameterized.expand(
        [
@@ -87,13 +85,11 @@
                x = self.orig_op(x)
                return self.orig_op(x)

        m = TestModule(orig_op)
        inputs = [torch.randn(2, 2, 3).to(input_dtype)]
-        self.run_test(
-            m, inputs, expected_ops={expected_op}, test_implicit_batch_dim=False
-        )
+        self.run_test(m, inputs, expected_ops={expected_op}, test_implicit_batch_dim=False)


class TestUnaryOpNotConvertersWithDynamicShapeFourDimensions(AccTestCase):
    @parameterized.expand(
        [
@@ -118,13 +114,11 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(orig_op), input_specs, expected_ops={expected_op}
-        )
+        self.run_test_with_dynamic_shape(TestModule(orig_op), input_specs, expected_ops={expected_op})


class TestUnaryRSQRTConverters(AccTestCase):
    def test_unary_ops(self):
        class TestModule(nn.Module):
@@ -148,12 +142,10 @@
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
            ),
        ]

-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal}
-        )
+        self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal})


if __name__ == "__main__":
    run_tests()
--- py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py	2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py	2022-08-12 19:20:01.105902 +0000
@@ -35,26 +35,22 @@
            self._validate_spec(spec, tensor)

    def test_from_tensors_with_dynamic_batch_size(self):
        tensors = [torch.randn(1, 2, 3), torch.randn(1, 4)]
        batch_size_range = [2, 3, 4]
-        specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
-            tensors, batch_size_range
-        )
+        specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(tensors, batch_size_range)
        for spec, tensor in zip(specs, tensors):
            self._validate_spec(spec, tensor, dynamic_dims=[0])

            for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]):
                self.assertEqual(batch_size, shape[0])
                self.assertSequenceEqual(tensor.shape[1:], shape[1:])

    def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self):
        tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)]
        batch_size_range = [2, 3, 4]
-        specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
-            tensors, batch_size_range, batch_dims=[0, 1]
-        )
+        specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(tensors, batch_size_range, batch_dims=[0, 1])
        for i, spec_and_tensor in enumerate(zip(specs, tensors)):
            spec, tensor = spec_and_tensor
            self._validate_spec(spec, tensor, dynamic_dims=[i])

            for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]):
@@ -62,13 +58,11 @@
                tensor_shape = list(tensor.shape)
                tensor_shape[i] = batch_size
                self.assertSequenceEqual(tensor_shape, shape)

    def test_generate_input_specs(self):
-        lower_setting = LowerSetting(
-            explicit_batch_dimension=False, max_batch_size=256, opt_profile_replica=2
-        )
+        lower_setting = LowerSetting(explicit_batch_dimension=False, max_batch_size=256, opt_profile_replica=2)

        # Implicit batch dim.
        inputs = [torch.randn(1, 2, 3)]
        specs = generate_input_specs(inputs, lower_setting)
        for spec, tensor in zip(specs, inputs):
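To make the expected behavior concrete: from_tensors_with_dynamic_batch_size swaps the batch dim for -1 and enumerates one shape per requested batch size, as the assertions above verify. A worked example with hypothetical sizes (same assumed import path as earlier):

import torch
from torch_tensorrt.fx.input_tensor_spec import InputTensorSpec

specs = InputTensorSpec.from_tensors_with_dynamic_batch_size([torch.randn(1, 2, 3)], (2, 3, 4))
# Per the test above:
#   specs[0].shape        == (-1, 2, 3)
#   specs[0].shape_ranges == [((2, 2, 3), (3, 2, 3), (4, 2, 3))]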
--- py/torch_tensorrt/fx/test/quant/test_quant_trt.py	2022-08-12 19:16:11.716868 +0000
+++ py/torch_tensorrt/fx/test/quant/test_quant_trt.py	2022-08-12 19:20:01.635277 +0000
@@ -46,13 +46,11 @@
            shape_ranges=shape_ranges,
            has_batch_dim=True,
        )
    ]

-    interp = TRTInterpreter(
-        model, input_specs, explicit_batch_dimension=True, explicit_precision=True
-    )
+    interp = TRTInterpreter(model, input_specs, explicit_batch_dimension=True, explicit_precision=True)
    result = interp.run(lower_precision=LowerPrecision.INT8)
    trt_mod = TRTModule(result.engine, result.input_names, result.output_names)
    return trt_mod


@@ -65,13 +63,11 @@
            ),
            weight=torch.ao.quantization.default_weight_observer,
        )
        self.trt_backend_config_dict = get_tensorrt_backend_config_dict()

-    def _test_quantized_inputs_outputs(
-        self, prepare_custom_config_dict, prepare_count_check, convert_count_check
-    ):
+    def _test_quantized_inputs_outputs(self, prepare_custom_config_dict, prepare_count_check, convert_count_check):
        """
        Test the option to have inputs and outputs of the graph quantized
        """

        class M(torch.nn.Module):
@@ -113,13 +109,11 @@
            # output of ref conv1 and output of ref conv2
            ns.call_function(torch.quantize_per_tensor): 2,
            # input of ref conv1 and input of ref conv2
            ns.call_method("dequantize"): 2,
        }
-        self._test_quantized_inputs_outputs(
-            prepare_custom_config_dict, prepare_count_check, convert_count_check
-        )
+        self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def test_fp32_input_quantized_output(self):
        prepare_custom_config_dict = {"output_quantized_idxs": [0]}
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
@@ -128,13 +122,11 @@
            # input, output of conv1 and output of conv2
            ns.call_function(torch.quantize_per_tensor): 3,
            # input of conv1, conv2
            ns.call_method("dequantize"): 2,
        }
-        self._test_quantized_inputs_outputs(
-            prepare_custom_config_dict, prepare_count_check, convert_count_check
-        )
+        self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def test_quantized_input_fp32_output(self):
        prepare_custom_config_dict = {"input_quantized_idxs": [0]}
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
@@ -143,26 +135,22 @@
            # output of conv1, conv2
            ns.call_function(torch.quantize_per_tensor): 2,
            # input of ref conv1, input of ref conv2, final output
            ns.call_method("dequantize"): 3,
        }
-        self._test_quantized_inputs_outputs(
-            prepare_custom_config_dict, prepare_count_check, convert_count_check
-        )
+        self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def test_fp32_input_fp32_output(self):
        prepare_custom_config_dict = {}
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
        }
        convert_count_check = {
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_method("dequantize"): 3,
        }
-        self._test_quantized_inputs_outputs(
-            prepare_custom_config_dict, prepare_count_check, convert_count_check
-        )
+        self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def _test_standalone_module(
        self,
        interface_config,
        prepare_count_check,
@@ -213,20 +201,14 @@

        data = torch.randn(1, 1, 1, 1)
        # instantiate M and RefM and align the parameters
        original_m = M().eval()
        original_ref_m = RefM().eval()
-        original_ref_m.conv1.weight = torch.nn.Parameter(
-            original_m.conv.weight.detach()
-        )
+        original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
        original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
-        original_ref_m.conv2.weight = torch.nn.Parameter(
-            original_m.standalone.conv.weight.detach()
-        )
-        original_ref_m.conv2.bias = torch.nn.Parameter(
-            original_m.standalone.conv.bias.detach()
-        )
+        original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
+        original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())

        sm_example_inputs = (data,)
        prepare_config = {
            "standalone_module_name": [
                (
@@ -253,20 +235,16 @@
            backend_config=backend_config_dict,
        )
        # calibration
        m(data)
        self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
-        self.checkGraphModuleNodes(
-            m.standalone, expected_node_occurrence=standalone_prepare_count_check
-        )
+        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)

        # check converted/quantized model
        m = convert_to_reference_fx(m, backend_config=backend_config_dict)
        self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
-        self.checkGraphModuleNodes(
-            m.standalone, expected_node_occurrence=standalone_convert_count_check
-        )
+        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
        res = m(data)

        # quantize the reference model
        ref_m = prepare_fx(
            original_ref_m_copy,
@@ -285,17 +263,13 @@
            "output_quantized_idxs": [],  # float output
        }
        interface_config = float_interface_config
        # input and output of first conv, observer for standalone module
        # will be inserted in the standalone module itself
-        prepare_count_check = {
-            ns.call_module(torch.ao.quantization.HistogramObserver): 2
-        }
+        prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 2}
        # for input and output of conv in the standalone module
-        standalone_prepare_count_check = {
-            ns.call_module(torch.ao.quantization.HistogramObserver): 2
-        }
+        standalone_prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 2}
        convert_count_check = {
            # input and output of reference conv
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_module(nnqr.Conv2d): 1,
            ns.call_method("dequantize"): 2,
@@ -351,17 +325,13 @@
            "root_module": torch.nn.Conv2d,
            "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
        }
        custom_backend_config_dict = {"configs": [conv_module_config]}
        # observer for input and output of first conv
-        prepare_count_check = {
-            ns.call_module(torch.ao.quantization.HistogramObserver): 2
-        }
+        prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 2}
        # for output of conv in the standalone module
-        standalone_prepare_count_check = {
-            ns.call_module(torch.ao.quantization.HistogramObserver): 1
-        }
+        standalone_prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 1}
        convert_count_check = {
            # quantizing input/output for reference conv
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_module(nnqr.Conv2d): 1,
            # dequantize the input of reference conv and
@@ -400,13 +370,11 @@
            ),
            weight=torch.ao.quantization.default_weight_observer,
        )
        self.trt_backend_config_dict = get_tensorrt_backend_config_dict()

-    def _test_module(
-        self, m, inputs, shape_ranges, no_prepare=None, no_convert=None, is_qat=False
-    ):
+    def _test_module(self, m, inputs, shape_ranges, no_prepare=None, no_convert=None, is_qat=False):
        """
        Args:
          m: the float module we want to test
          inputs: list of inputs for the module
          shape_ranges: a list of shape_range, where every shape_range is a tuple of
@@ -468,13 +436,11 @@

            def forward(self, x):
                return self.relu(self.conv(x))

        # just testing conv2d since conv1d and conv3d are not supported in fx2trt
-        for dim, has_relu, f_relu, is_qat in itertools.product(
-            [1, 2], [True, False], [True, False], [True, False]
-        ):
+        for dim, has_relu, f_relu, is_qat in itertools.product([1, 2], [True, False], [True, False], [True, False]):
            # when has_relu=False, we have torch.nn.Identity, which would introduce
            # extra quant-dequant pair
            no_convert = {
                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
                ns.call_method("dequantize"): 2 + int(not has_relu),
@@ -510,13 +476,11 @@
                return self.relu(self.linear(x))

        linear_input = torch.rand(8, 5)

        shape_ranges = [((1, 5), (5, 5), (10, 5))]
-        for has_relu, f_relu, is_qat in itertools.product(
-            [True, False], [True, False], [True, False]
-        ):
+        for has_relu, f_relu, is_qat in itertools.product([True, False], [True, False], [True, False]):
            # when has_relu=False, we have torch.nn.Identity, which would introduce
            # extra quant-dequant pair
            no_convert = {
                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
                ns.call_method("dequantize"): 2 + int(not has_relu),
@@ -662,13 +626,11 @@
            ns.call_function(torch.addmm): 1,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)

-    @unittest.skip(
-        "This is not supported yet, we can enable the test after it's supported"
-    )
+    @unittest.skip("This is not supported yet, we can enable the test after it's supported")
    def test_conv_add(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
@@ -828,13 +790,11 @@
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        standalone_node_occurrence = {
            # output of the standalone module
            ns.call_module(torch.ao.quantization.HistogramObserver): 1,
        }
-        self.checkGraphModuleNodes(
-            m.standalone, expected_node_occurrence=standalone_node_occurrence
-        )
+        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
        m = convert_to_reference_fx(m, backend_config=backend_config_dict)
        node_occurrence = {
            # two inputs for standalone module
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_module(nn.Conv2d): 1,
@@ -847,13 +807,11 @@
            ns.call_module(nn.Conv2d): 1,
            ns.call_module(torch.nn.ReLU): 1,
            # two input and one output for the pattern in standalone module
            ns.call_method("dequantize"): 3,
        }
-        self.checkGraphModuleNodes(
-            m.standalone, expected_node_occurrence=standalone_node_occurrence
-        )
+        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)

    def test_quant_dequant_not_fold(self):
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
--- py/torch_tensorrt/fx/tools/common_fx2trt.py	2022-08-12 19:16:11.716868 +0000
+++ py/torch_tensorrt/fx/tools/common_fx2trt.py	2022-08-12 19:20:01.932333 +0000
@@ -29,13 +29,11 @@
    """
    target_atoms = target.split(".")
    attr_itr = mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
-            raise RuntimeError(
-                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
-            )
+            raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr


@unittest.skipIf(not torch.cuda.is_available(), "Skip because CUDA is not available")
@@ -82,13 +80,11 @@
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            outputs = trt_mod(*cuda_inputs)
            end_event.record()
            torch.cuda.synchronize()
-            _LOGGER.info(
-                f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}"
-            )
+            _LOGGER.info(f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}")

            if isinstance(outputs, torch.Tensor):
                ref_outputs = [ref_outputs]
                outputs = [outputs]
            for out, ref in zip(outputs, ref_outputs):
@@ -126,26 +122,22 @@
            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)

            interpreter_result = interpreter.run(
-                lower_precision=LowerPrecision.FP16
-                if fp16_mode
-                else LowerPrecision.FP32
+                lower_precision=LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
            )
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
            )
            res_trt = trt_mod(*cuda_inputs).cpu()
            res_cpu = mod(*inputs)
            assert len(res_trt) == len(res_cpu)
            assert len(res_cpu) == len(comparators)
-            for output_trt, output_cpu, comparator in zip(
-                res_trt, res_cpu, comparators
-            ):
+            for output_trt, output_cpu, comparator in zip(res_trt, res_cpu, comparators):
                comp_func = comparator[0]
                args = comparator[1]
                self.assertTrue(comp_func(output_trt, output_cpu, *args))

    def run_test_with_error(self, mod, inputs, interpreter, expect_error):
@@ -165,13 +157,11 @@
            if node.op == "call_module":
                ops_in_mod.add(type(fetch_attr(mod, node.target)))
            elif node.op in {"call_function", "call_method"}:
                ops_in_mod.add(node.target)

-        self.assertTrue(
-            ops_in_mod >= ops, f"expected ops {ops}, actuall ops {ops_in_mod}"
-        )
+        self.assertTrue(ops_in_mod >= ops, f"expected ops {ops}, actual ops {ops_in_mod}")

    def assert_unexpected_op(self, mod, ops):
        for node in mod.graph.nodes:
            if node.op == "call_module":
                if type(fetch_attr(mod, node.target)) in ops:
@@ -204,13 +194,11 @@
        # after we refactor the internal callsites to use this file
        mod = torch.fx.symbolic_trace(mod)
        shape_prop.ShapeProp(mod).propagate(*inputs)
        mod = NormalizeArgs(mod).transform()
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
-        super().run_test_custom_compare_results(
-            mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode
-        )
+        super().run_test_custom_compare_results(mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode)


class AccTestCase(TRTTestCase):
    def run_test(
        self,
@@ -233,41 +221,31 @@
            pass_tracer = chain_passes(*apply_passes)
            mod = pass_tracer(mod, inputs)

        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
-            super().run_test(
-                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
-            )
+            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision)

        if test_explicit_batch_dim:
-            interp = TRTInterpreter(
-                mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
-            )
-            super().run_test(
-                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
-            )
+            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True)
+            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision)

        if test_explicit_precision:
            interp = TRTInterpreter(
                mod,
                InputTensorSpec.from_tensors(inputs),
                explicit_precision=test_explicit_precision,
            )
-            super().run_test(
-                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol
-            )
+            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol)

            interp = TRTInterpreter(
                mod,
                InputTensorSpec.from_tensors(inputs),
                explicit_batch_dimension=True,
                explicit_precision=test_explicit_precision,
            )
-            super().run_test(
-                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
-            )
+            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision)

    def run_test_with_assert_error(
        self,
        mod,
        inputs,
@@ -281,13 +259,11 @@
        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
            super().run_test_with_error(mod, inputs, interp, expect_error)

        if test_explicit_batch_dim:
-            interp = TRTInterpreter(
-                mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
-            )
+            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True)
            super().run_test_with_error(mod, inputs, interp, expect_error)

    def run_test_with_dynamic_shape(
        self,
        mod,
--- py/torch_tensorrt/fx/tools/trt_minimizer.py	2022-08-12 19:16:11.716868 +0000
+++ py/torch_tensorrt/fx/tools/trt_minimizer.py	2022-08-12 19:20:01.988637 +0000
@@ -8,16 +8,12 @@
from .. import InputTensorSpec, TRTInterpreter, TRTModule

_LOGGER: logging.Logger = logging.getLogger(__name__)


-def lower_mod_default(
-    mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048
-) -> TRTModule:
-    interp = TRTInterpreter(
-        mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
-    )
+def lower_mod_default(mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048) -> TRTModule:
+    interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True)
    interpreter_result = interp.run(max_batch_size=batch_size)
    res_mod = TRTModule(
        interpreter_result.engine,
        interpreter_result.input_names,
        interpreter_result.output_names,
@@ -37,13 +33,11 @@
        module: torch.fx.GraphModule,
        sample_input: Tensors,
        compare_fn: Callable[[Any, Any, Any], Tuple[float, bool]],
        settings: TensorRTMinizerSetting = TensorRTMinizerSetting(),
        max_batch_size: Any = 2048,
-        lower_fn: Callable[
-            [torch.fx.GraphModule, Tensors, Any], TRTModule
-        ] = lower_mod_default,
+        lower_fn: Callable[[torch.fx.GraphModule, Tensors, Any], TRTModule] = lower_mod_default,
    ):
        self.lower_fn = lower_fn
        self.max_batch_size = max_batch_size
        super().__init__(module, sample_input, compare_fn, settings)

@@ -56,13 +50,11 @@
        mod.eval()
        try:
            mod = self.lower_fn(mod, inputs, self.max_batch_size)
            output = mod(*inputs)
        except RuntimeError as e:
-            raise net_min_base.FxNetMinimizerRunFuncError(
-                f"Encounter an error when processing \n{mod.graph}\n {e}"
-            )
+            raise net_min_base.FxNetMinimizerRunFuncError(f"Encounter an error when processing \n{mod.graph}\n {e}")
        else:
            return output

    def get_nodes(self, start=None, end=None, enable_print=False):
        nodes = self._collect_nodes(start, end)
--- py/torch_tensorrt/fx/tools/trt_splitter.py	2022-08-12 19:16:11.716868 +0000
+++ py/torch_tensorrt/fx/tools/trt_splitter.py	2022-08-12 19:20:02.057670 +0000
@@ -72,13 +72,11 @@
            operator_support,
            settings,
            non_acc_submodule_name="_run_on_gpu_",
        )

-    def _lower_model_to_backend(
-        self, mod: torch.fx.GraphModule, inputs: Iterable[torch.Tensor]
-    ):
+    def _lower_model_to_backend(self, mod: torch.fx.GraphModule, inputs: Iterable[torch.Tensor]):
        """
        Lower a GraphModule `mod` to TensorRT with `inputs`.
        """
        # Current code for lowering is place-holder, subject to future change
        # based on feeds model's actual status
--- py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py	2022-08-12 19:16:11.716868 +0000
+++ py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py	2022-08-12 19:20:02.310172 +0000
@@ -41,13 +41,11 @@
    def __init__(self):
        super().__init__()
        self.exceptions_rewritten: Set[Type[Exception]] = set()
        self.exceptions_bool_rewritten: Set[Type[Exception]] = set()

-    def rewrite(
-        self, fn: FunctionType
-    ) -> Tuple[FunctionType, Set[Type[Exception]], Set[Type[Exception]]]:
+    def rewrite(self, fn: FunctionType) -> Tuple[FunctionType, Set[Type[Exception]], Set[Type[Exception]]]:

        # Normalize the source lines
        sourcelines, _ = inspect.getsourcelines(fn)
        sourcelines = normalize_source_lines(sourcelines)
        source = "".join(sourcelines)
@@ -139,12 +137,11 @@
            return if_node

        # Check that we actually have a builtin exception.
        if (
            not issubclass(exc_type, Exception)
-            or getattr(getattr(exc_type, "__class__", None), "__module__", None)
-            != "builtins"
+            or getattr(getattr(exc_type, "__class__", None), "__module__", None) != "builtins"
        ):
            return if_node

        # We need a ConditionalExceptionWrapper specialized for every kind of
        # exception, so add it to exceptions_rewritten to remember for later to
@@ -156,23 +153,17 @@
        # the If with, with args set as the If's condition and the string of the
        # exception. The call to the self._conditional_exception_wrapper_*Error
        # module is safe because the RewrittenModule will add it as an attr
        # based on the returned exceptions_rewritten, and we assume we are
        # currently modifying the AST of a method from a RewrittenModule.
-        exc_wrapper_node = ast.parse(
-            f"self.{_get_exception_wrapper_attr_name(exc_type)}()", mode="eval"
-        )
+        exc_wrapper_node = ast.parse(f"self.{_get_exception_wrapper_attr_name(exc_type)}()", mode="eval")
        assert isinstance(exc_wrapper_node, ast.Expression)
        exc_wrapper_call_node = exc_wrapper_node.body
        assert isinstance(exc_wrapper_call_node, ast.Call)
-        if isinstance(if_node.test, ast.BoolOp) and isinstance(
-            if_node.test.op, ast.And
-        ):
+        if isinstance(if_node.test, ast.BoolOp) and isinstance(if_node.test.op, ast.And):
            self.exceptions_bool_rewritten.add(exc_type)
-            bool_wrapper_node = ast.parse(
-                f"self.{_get_exception_wrapper_attr_name(exc_type)}_bool()", mode="eval"
-            )
+            bool_wrapper_node = ast.parse(f"self.{_get_exception_wrapper_attr_name(exc_type)}_bool()", mode="eval")
            assert isinstance(exc_wrapper_node, ast.Expression)
            bool_wrapper_call_node = bool_wrapper_node.body
            assert isinstance(exc_wrapper_call_node, ast.Call)
            bool_wrapper_call_node.args = if_node.test.values
            exc_wrapper_call_node.args = [
@@ -323,13 +314,11 @@
            name_target[-1] == "_"
            and name_target[0] != "_"
            and not (name_target in allow_list)
            and kind != "placeholder"
        ):
-            raise RuntimeError(
-                f"Tried to trace mutable operation {name_target}. FX only supports functional code"
-            )
+            raise RuntimeError(f"Tried to trace mutable operation {name_target}. FX only supports functional code")

        return self.graph.create_node(kind, target, args, kwargs, name, type_expr)


# List of modules that need rewriting to be supported for tracing.
@@ -384,13 +373,11 @@
            # Write all of the non-dunder or special methods from base_class
            # into RewrittenModule.
            for method_name in dir(base_class):
                method = getattr(base_class, method_name, None)
                if method is None and method_name not in {"__doc__"}:
-                    _LOGGER.warning(
-                        f"{__qualname__} does not have attribute {method_name}"
-                    )
+                    _LOGGER.warning(f"{__qualname__} does not have attribute {method_name}")

                if builtins.type(method) is not FunctionType:
                    continue

                # Always skip rewriting dunder methods, as they haven't (yet) been
@@ -437,13 +424,11 @@
                # Recursively rewrite and copy all module attrs of this module.
                for k, v in orig.__dict__.items():
                    if k == "_modules":
                        for mod_k, mod_v in v.items():
                            if getattr(mod_v, "_base_class_origin", type(mod_v)) in leaf_module_list:  # type: ignore[operator]
-                                _LOGGER.info(
-                                    f"Skip rewriting leaf module {type(mod_v)}"
-                                )
+                                _LOGGER.info(f"Skip rewriting leaf module {type(mod_v)}")
                                self._modules[mod_k] = mod_v
                            else:
                                self._modules[mod_k] = rewrite_module(mod_v)
                    else:
                        self.__dict__[k] = v
@@ -475,25 +460,21 @@
    """
    changed = False
    for node in reversed(gm.graph.nodes):
        if node.op == "call_module" and (
            isinstance(gm.get_submodule(node.target), ConditionalExceptionWrapper)
-            or isinstance(
-                gm.get_submodule(node.target), ConditionalExceptionBoolCondWrapper
-            )
+            or isinstance(gm.get_submodule(node.target), ConditionalExceptionBoolCondWrapper)
        ):
            gm.graph.erase_node(node)
            changed = True
    return changed


def _replace_tensor_meta_with_rank(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        if node.op != "output" and "tensor_meta" in node.meta:
-            node.meta["tensor_rank"] = acc_utils.map_tensor_metadata(
-                node.meta["tensor_meta"], lambda x: len(x.shape)
-            )
+            node.meta["tensor_rank"] = acc_utils.map_tensor_metadata(node.meta["tensor_meta"], lambda x: len(x.shape))
            del node.meta["tensor_meta"]


def rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list):
    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(

@narendasan
Copy link
Collaborator

@frank-wei can you share the fb lint config or something so that we can use a consistent code style?

They are using black but looks like it is more than that.

This is our current configuration https://github.com/pytorch/TensorRT/blob/master/pyproject.toml

Maybe it's the line width?

@narendasan
Collaborator

It seems like in pytorch/pytorch they don't use the line-length argument: https://github.com/pytorch/pytorch/blob/2c089290b676a221817e48c7de42d1b2bd13609a/pyproject.toml#L22
Can you try commenting it out and seeing if it passes both your internal checks and ours?
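
For reference, the suggested edit would look roughly like this in pyproject.toml (a sketch only; the exact value currently set and the other keys in the repo's [tool.black] table may differ):

[tool.black]
# Commenting out the explicit limit makes black fall back to its default of 88,
# which is effectively what pytorch/pytorch uses.
# line-length = 120  # hypothetical current value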

@github-actions github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions github-actions bot left a comment

Code conforms to C++ style guidelines

@frank-wei
Contributor Author

frank-wei commented Aug 12, 2022

It seems like in pytorch/pytorch they don't use the line-length argument: https://github.com/pytorch/pytorch/blob/2c089290b676a221817e48c7de42d1b2bd13609a/pyproject.toml#L22 Can you try commenting it out and seeing if it passes both your internal checks and ours?

That is also what I found internally:

# NOTICE: Python Foundation strongly recommends NOT having
# project-specific configuration. There are many benefits
# to having uniform formatting across every project.  

So black's default line length is 88. Let's try the default.
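
To make the difference concrete, here is a toy example (not code from this repo) showing how black's line-length setting changes the output:

# Toy function whose call site is about 95 characters wide.
def run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision):
    return precision

# With black's default line-length of 88, the call is split:
# run_test(
#     "mod", "inputs", "expected_ops", "unexpected_ops", "interp", 1e-3, 1e-3, "fp32"
# )
#
# With line-length = 120, black keeps it on one line:
# run_test("mod", "inputs", "expected_ops", "unexpected_ops", "interp", 1e-3, 1e-3, "fp32")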

@narendasan
Collaborator

Yeah, that is fine with me. If the fx changes pass both our checks and your internal ones, I can handle reformatting the rest of the Python code.

@narendasan
Collaborator

narendasan commented Aug 12, 2022

Alternatively, if you run pre-commit run --all-files, you can update the rest of the Python files; that might make it easier to tell whether it works.
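
For anyone trying this locally, the workflow is roughly as follows (this assumes the repo's existing pre-commit configuration; install pre-commit first if needed):

pip install pre-commit
# Install the git hooks defined by the repo, then run every hook on every file:
pre-commit install
pre-commit run --all-files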

@github-actions github-actions bot added the component: tests and documentation labels Aug 12, 2022
@github-actions github-actions bot left a comment

Code conforms to Python style guidelines

@github-actions github-actions bot left a comment

Code conforms to C++ style guidelines

@frank-wei
Contributor Author

@narendasan could you take a look at the pybind issue in the build-x86_64-pyt-nightly job?

@narendasan
Collaborator

narendasan commented Aug 12, 2022

@narendasan could you take a look at the pybind issue in the build-x86_64-pyt-nightly?

@peri044 I think you looked at these classes of errors previously on the nightly channel (it might have been in another FX PR).

@github-actions github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions github-actions bot left a comment

Code conforms to Python style guidelines

@frank-wei frank-wei merged commit 22451ce into master Aug 12, 2022
@frank-wei frank-wei deleted the fb-sync-wwei6 branch August 12, 2022 23:47
Labels
cla signed, component: api [Python], component: fx, component: tests, documentation, fx