❓ [Question] operator being decomposed rather than being converted when a corresponding converter exists? #2665

Open
HolyWu opened this issue Feb 28, 2024 · 2 comments

HolyWu commented Feb 28, 2024

❓ Question

From the debug log below, it seems that the aten.grid_sampler_2d operator is decomposed into several lower-level operators. Isn't there a corresponding converter that should be used instead?
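
One way to picture what the log suggests is a lowering step that expands an op before converter lookup ever sees it. The following is a toy model only, assuming decompositions are applied first and can be switched off per op. Every name in it (`DECOMPOSITIONS`, `CONVERTERS`, `lower`, `keep_whole`) is hypothetical and is not a Torch-TensorRT API. The op kinds are a simplified version of the leaky_relu log later in this thread.

```python
from typing import Dict, FrozenSet, List

# Hypothetical decomposition table: high-level op -> lower-level ops.
DECOMPOSITIONS: Dict[str, List[str]] = {
    "aten.leaky_relu": ["aten.gt", "aten.mul", "aten.where"],
}

# Hypothetical converter registry: ops that have a direct converter.
CONVERTERS: FrozenSet[str] = frozenset(
    {"aten.leaky_relu", "aten.gt", "aten.mul", "aten.where"}
)


def lower(graph: List[str], keep_whole: FrozenSet[str] = frozenset()) -> List[str]:
    """Expand every op that has a decomposition, unless it is in keep_whole."""
    lowered: List[str] = []
    for op in graph:
        if op in DECOMPOSITIONS and op not in keep_whole:
            lowered.extend(DECOMPOSITIONS[op])
        else:
            lowered.append(op)
    return lowered


def ops_reaching_converters(
    graph: List[str], keep_whole: FrozenSet[str] = frozenset()
) -> List[str]:
    """Ops that are looked up in the converter registry after lowering."""
    return [op for op in lower(graph, keep_whole) if op in CONVERTERS]


graph = ["aten.leaky_relu"]
decomposed = ops_reaching_converters(graph)
preserved = ops_reaching_converters(graph, frozenset({"aten.leaky_relu"}))
print(decomposed)
print(preserved)
```

In this model the leaky_relu converter exists but is only reached when the op is excluded from decomposition, which matches the symptom described here. Whether the real pipeline works this way is the open question of this issue.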

What you have already tried

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, input, grid):
        return F.grid_sample(input, grid, mode="bilinear", padding_mode="border", align_corners=True)
    
model = MyModule().eval().cuda()

inputs = [
    torch.randn((1, 3, 8, 8), dtype=torch.float, device="cuda"),
    torch.randn((1, 16, 16, 2), dtype=torch.float, device="cuda")
]

optimized_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.float},
    debug=True,
    min_block_size=1,
    truncate_long_and_double=True,
    output_format="fx",
)

Debug log:
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_1 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_1 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_2 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_2 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_3 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_3 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_4 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_4 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_5 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_5 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_6 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_6 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_7 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_7 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.reshape.default + Operator Count: 13
- torch.ops.aten.expand.default + Operator Count: 1
- torch.ops.aten.select.int + Operator Count: 2
- torch.ops.aten.mul.Tensor + Operator Count: 10
- torch.ops.aten.add.Tensor + Operator Count: 7
- torch.ops.aten.clamp.default + Operator Count: 2
- torch.ops.aten.floor.default + Operator Count: 2
- torch.ops.aten.sub.Tensor + Operator Count: 8
- torch.ops.aten.ge.Scalar + Operator Count: 8
- torch.ops.aten.lt.Scalar + Operator Count: 8
- torch.ops.aten.logical_and.default + Operator Count: 12
- torch.ops.aten.where.self + Operator Count: 12
- torch.ops.aten.index.Tensor + Operator Count: 4

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 8

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 89 operators out of 97 in subgraph.
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_1 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_1 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_2 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_2 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_3 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_3 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_4 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_4 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_5 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_5 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_6 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_6 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_7 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_7 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 2
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.reshape.default + Operator Count: 13
- torch.ops.aten.expand.default + Operator Count: 1
- torch.ops.aten.select.int + Operator Count: 2
- torch.ops.aten.mul.Tensor + Operator Count: 10
- torch.ops.aten.add.Tensor + Operator Count: 7
- torch.ops.aten.clamp.default + Operator Count: 2
- torch.ops.aten.floor.default + Operator Count: 2
- torch.ops.aten.sub.Tensor + Operator Count: 8
- torch.ops.aten.ge.Scalar + Operator Count: 8
- torch.ops.aten.lt.Scalar + Operator Count: 8
- torch.ops.aten.logical_and.default + Operator Count: 12
- torch.ops.aten.where.self + Operator Count: 12
- torch.ops.aten.index.Tensor + Operator Count: 4

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 8

++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 97 Total Operators, of which 89 operators are supported, 91.75% coverage

The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:
 torch.ops.aten._to_copy.default: 8

The following nodes are currently set to run in Torch:
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_1
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_2
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_3
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_4
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_5
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_6
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_7
Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner

Compiled with: CompilationSettings(precision=torch.float32, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=True, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.DEFAULT: 0>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, output_format='fx')

  Graph Structure:

   Inputs: List[Tensor: (1, 3, 8, 8)@float32, Tensor: (1, 16, 16, 2)@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 16, 16, 2)@float32]
     Number of Operators in Engine: 58
     Engine Outputs: Tuple(Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@float32)
    ...
    TRT Engine #2 - Submodule name: _run_on_acc_2
     Engine Inputs: List[Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 8, 8)@float32, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@float32]
     Number of Operators in Engine: 31
     Engine Outputs: Tensor: (1, 3, 16, 16)@float32
    ...
   Outputs: List[Tensor: (1, 3, 16, 16)@float32]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 44.5
   Most Operators in a TRT Engine: 58

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=58 which would generate 1 TRT engine(s)
   - For moderate graph segmentation, select min_block_size=45 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=31 which generates 2 TRT engine(s)
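
The numbers in the dry-run report can be cross-checked by hand. The sketch below copies the operator counts and engine sizes from the log above and recomputes the coverage, the average engine size, and the engine count for each recommended min_block_size. The helper `engines_kept` is hypothetical and rests on one assumption: a block stays a TRT engine only if it holds at least min_block_size operators.

```python
# Operator counts copied from the "Supported Nodes" list in the log.
supported = {
    "reshape.default": 13,
    "expand.default": 1,
    "select.int": 2,
    "mul.Tensor": 10,
    "add.Tensor": 7,
    "clamp.default": 2,
    "floor.default": 2,
    "sub.Tensor": 8,
    "ge.Scalar": 8,
    "lt.Scalar": 8,
    "logical_and.default": 12,
    "where.self": 12,
    "index.Tensor": 4,
}
unsupported = {"_to_copy.default": 8}

n_supported = sum(supported.values())
n_total = n_supported + sum(unsupported.values())
coverage = round(100 * n_supported / n_total, 2)

# Operator counts of the two engines listed under "Graph Structure".
engine_sizes = [58, 31]
average_ops = sum(engine_sizes) / len(engine_sizes)


def engines_kept(sizes, min_block_size):
    """Assumption: a block survives as a TRT engine if size >= min_block_size."""
    return sum(1 for size in sizes if size >= min_block_size)


engines_by_threshold = {k: engines_kept(engine_sizes, k) for k in (58, 45, 31)}
print(n_supported, n_total, coverage, average_ops)
print(engines_by_threshold)
```

The two engines together hold all 89 supported operators, so the only reason the graph is split in two is the eight _to_copy nodes that fall back to Torch.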

Environment

  • PyTorch Version (e.g., 1.0): 2.3.0.dev20240221+cu121
  • CPU Architecture: x64
  • OS (e.g., Linux): Ubuntu 22.04 LTS
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.10.12
  • CUDA version: 12.1
  • GPU models and configuration: RTX 3050
  • Any other relevant information:

HolyWu added the "question" (Further information is requested) label on Feb 28, 2024.

HolyWu commented Mar 2, 2024

The same happens with aten.leaky_relu: the log shows gt, mul, and where nodes instead of a single leaky_relu node.

import torch
import torch.nn as nn
import torch_tensorrt

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = nn.LeakyReLU()

    def forward(self, x):
        return self.m(x)

model = MyModule().eval().cuda().half()
inputs = [torch.randn((1, 3, 4, 4), dtype=torch.half, device="cuda")]

optimized_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.half},
    debug=True,
    min_block_size=1,
)

Debug log:
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 2
- torch.ops.aten.gt.Scalar + Operator Count: 1
- torch.ops.aten.mul.Tensor + Operator Count: 1
- torch.ops.aten.where.self + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 5 operators out of 5 in subgraph.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 2
- torch.ops.aten.gt.Scalar + Operator Count: 1
- torch.ops.aten.mul.Tensor + Operator Count: 1
- torch.ops.aten.where.self + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 5 Total Operators, of which 5 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(precision=torch.float16, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.DEFAULT: 0>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, output_format='exported_program')

  Graph Structure:

   Inputs: List[Tensor: (1, 3, 4, 4)@float16]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 3, 4, 4)@float16]
     Number of Operators in Engine: 5
     Engine Outputs: Tensor: (1, 3, 4, 4)@float16
    ...
   Outputs: List[Tensor: (1, 3, 4, 4)@float16]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 5.0
   Most Operators in a TRT Engine: 5

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=5 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=5 which generates 1 TRT engine(s)
WARNING: [Torch-TensorRT] - Using default stream in enqueue()/enqueueV2()/enqueueV3() may lead to performance issues due to additional cudaDeviceSynchronize() calls by TensorRT to ensure correct synchronizations. Please use non-default stream instead.
C:\Python311\Lib\site-packages\torch\export\exported_program.py:740: UserWarning: Unable to execute the generated python source code from the graph. The graph module will no longer be directly callable, but you can still run the ExportedProgram, and if needed, you can run the graph module eagerly using torch.fx.Interpreter.
  warnings.warn(
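
For reference, the three compute ops in the log (gt, mul, where) are enough to rebuild leaky_relu. The sketch below is plain Python on lists rather than tensors, assuming the default negative_slope of 0.01 used by nn.LeakyReLU(). The function names are made up for illustration, and the two _to_copy nodes from the log are not modeled.

```python
from typing import List


def leaky_relu_direct(x: float, negative_slope: float = 0.01) -> float:
    """Reference definition: x if x > 0, otherwise negative_slope * x."""
    return x if x > 0 else negative_slope * x


def leaky_relu_decomposed(xs: List[float], negative_slope: float = 0.01) -> List[float]:
    """Same result built from the three op kinds seen in the log."""
    mask = [x > 0 for x in xs]                  # plays the role of aten.gt.Scalar
    scaled = [x * negative_slope for x in xs]   # plays the role of aten.mul.Tensor
    # plays the role of aten.where.self: pick x where mask is true, else scaled
    return [x if m else s for m, x, s in zip(mask, xs, scaled)]


values = [-2.0, -0.5, 0.0, 3.0]
decomposed_out = leaky_relu_decomposed(values)
direct_out = [leaky_relu_direct(v) for v in values]
print(decomposed_out == direct_out)
```

Both paths compute the same values, so the decomposition is numerically fine. The concern in this issue is only that the dedicated converter is bypassed.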


HolyWu commented Mar 3, 2024

The same happens with aten.upsample_bilinear2d.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)


model = MyModule().eval().cuda().half()
inputs = [
    torch.randn((1, 3, 128, 128), dtype=torch.half, device="cuda"),
]

optimized_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.half},
    debug=True,
    min_block_size=1,
)

Debug log:
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 2
- torch.ops.aten.index.Tensor + Operator Count: 4
- torch.ops.aten.sub.Tensor + Operator Count: 3
- torch.ops.aten.mul.Tensor + Operator Count: 3
- torch.ops.aten.add.Tensor + Operator Count: 3

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 15 operators out of 15 in subgraph.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 2
- torch.ops.aten.index.Tensor + Operator Count: 4
- torch.ops.aten.sub.Tensor + Operator Count: 3
- torch.ops.aten.mul.Tensor + Operator Count: 3
- torch.ops.aten.add.Tensor + Operator Count: 3

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 15 Total Operators, of which 15 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(precision=torch.float16, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=True, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.DEFAULT: 0>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, output_format='exported_program')

  Graph Structure:

   Inputs: List[Tensor: (1, 3, 128, 128)@float16]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 3, 128, 128)@float16]
     Number of Operators in Engine: 15
     Engine Outputs: Tensor: (1, 3, 256, 256)@float16
    ...
   Outputs: List[Tensor: (1, 3, 256, 256)@float16]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 15.0
   Most Operators in a TRT Engine: 15

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=15 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=15 which generates 1 TRT engine(s)
WARNING: [Torch-TensorRT] - Using default stream in enqueue()/enqueueV2()/enqueueV3() may lead to performance issues due to additional cudaDeviceSynchronize() calls by TensorRT to ensure correct synchronizations. Please use non-default stream instead.
C:\Python311\Lib\site-packages\torch\export\exported_program.py:740: UserWarning: Unable to execute the generated python source code from the graph. The graph module will no longer be directly callable, but you can still run the ExportedProgram, and if needed, you can run the graph module eagerly using torch.fx.Interpreter.
  warnings.warn(
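
The op counts in this log (4 index, 3 sub, 3 mul, 3 add) are consistent with bilinear interpolation written as four gathers followed by three linear interpolations, each costing one sub, one mul, and one add. The sketch below shows that structure in plain Python for a single channel with align_corners=True. It is an illustration under that reading of the log, not the actual decomposition used by PyTorch, and the function names are hypothetical.

```python
import math
from typing import List

Image = List[List[float]]


def lerp(a: float, b: float, w: float) -> float:
    """One sub, one mul, one add: a + w * (b - a)."""
    return a + w * (b - a)


def upsample_bilinear_align_corners(img: Image, scale: int) -> Image:
    in_h, in_w = len(img), len(img[0])
    out_h, out_w = in_h * scale, in_w * scale
    # align_corners=True maps the output corners exactly onto the input corners.
    step_y = (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
    step_x = (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
    out: Image = []
    for oy in range(out_h):
        sy = oy * step_y
        y0 = int(math.floor(sy))
        y1 = min(y0 + 1, in_h - 1)
        wy = sy - y0
        row: List[float] = []
        for ox in range(out_w):
            sx = ox * step_x
            x0 = int(math.floor(sx))
            x1 = min(x0 + 1, in_w - 1)
            wx = sx - x0
            # Four gathers, matching the four index ops in the log.
            v00, v01 = img[y0][x0], img[y0][x1]
            v10, v11 = img[y1][x0], img[y1][x1]
            # Three lerps: two along x, one along y.
            top = lerp(v00, v01, wx)
            bottom = lerp(v10, v11, wx)
            row.append(lerp(top, bottom, wy))
        out.append(row)
    return out


small = [[0.0, 1.0], [2.0, 3.0]]
big = upsample_bilinear_align_corners(small, 2)
print(len(big), len(big[0]))
```

With align_corners=True the four corner values of the input are reproduced at the four corners of the output, which is a quick sanity check for any implementation of this mode.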

HolyWu changed the title from "❓ [Question] grid_sampler_2d converter doesn't get used?" to "❓ [Question] operator being decomposed rather than being converted when a corresponding converter exists?" on Mar 3, 2024.