
🐛 [Bug] RuntimeError: Unhandled FakeTensor Device Propagation for torchvision.deform_conv2d.default, found two different devices cuda:0, cpu #3556

Closed
@xfeep

Description

Bug Description

When using deform_conv2d from torchvision.ops, torch_tensorrt.compile raises the following error:

RuntimeError: Unhandled FakeTensor Device Propagation for torchvision.deform_conv2d.default, found two different devices cuda:0, cpu

To Reproduce

Steps to reproduce the behavior:

  1. Save the following source code as mysdn.py:
import torch
import torch.nn as nn
from torchvision.ops import DeformConv2d, deform_conv2d
import torch_tensorrt


class SimpleDeformNet(nn.Module):
    def __init__(self):
        super(SimpleDeformNet, self).__init__()
        # 2 * kernel_h * kernel_w = 2 * 3 * 3 = 18 offset channels
        self.offset_conv = nn.Conv2d(3, 18, kernel_size=3, padding=1)
        self.deform_conv = DeformConv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        offset = self.offset_conv(x)
        # No explicit mask is passed; torchvision substitutes a dummy
        # tensor internally (the %zeros node in the traced graph below)
        out = deform_conv2d(x, offset, self.deform_conv.weight, self.deform_conv.bias,
                            stride=1, padding=1)
        return out


model = SimpleDeformNet().eval().cuda()

input_data = torch.randn(1, 3, 224, 224).cuda()

trt_model = torch_tensorrt.compile(
    model,
    inputs=[
        torch_tensorrt.Input(
            min_shape=[1, 3, 224, 224],
            opt_shape=[1, 3, 224, 224],
            max_shape=[1, 3, 224, 224],
            dtype=torch.float32
        )
    ],
    enabled_precisions={torch.float32},
    workspace_size=1 << 22
)

output = trt_model(input_data)
print(output.shape)
  2. Run the script:

python mysdn.py

  3. The following error is produced (an analysis sketch follows the log):
WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
INFO:torch_tensorrt._compile:ir was set to default, using dynamo frontend
WARNING:py.warnings:.\site-packages\torch\_subclasses\functional_tensor.py:362: UserWarning: At pre-dispatch tracing, we will assume that any custom op that is marked with CompositeImplicitAutograd and functional are safe to not decompose. We found torchvision.deform_conv2d.default to be one such op.
  warnings.warn(

Traceback (most recent call last):
  File ".\mysdn.py", line 25, in <module>
    trt_model = torch_tensorrt.compile(
  File "pyl\lib\site-packages\torch_tensorrt\_compile.py", line 249, in compile
    trt_graph_module = dynamo_compile(
  File "pyl\lib\site-packages\torch_tensorrt\dynamo\_compiler.py", line 185, in compile
    exported_program = exported_program.run_decompositions(
  File "pyl\lib\site-packages\torch\export\exported_program.py", line 89, in wrapper
    return fn(*args, **kwargs)
  File "pyl\lib\site-packages\torch\export\exported_program.py", line 567, in run_decompositions
    gm, graph_signature = aot_export_module(
  File "pyl\lib\site-packages\torch\_functorch\aot_autograd.py", line 1131, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
  File "pyl\lib\site-packages\torch\_functorch\aot_autograd.py", line 1350, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
  File "pyl\lib\site-packages\torch\_dynamo\utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "pyl\lib\site-packages\torch\_functorch\aot_autograd.py", line 687, in create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\jit_compile_runtime_wrappers.py", line 95, in aot_dispatch_export
    graph, _, _ = aot_dispatch_base_graph(
  File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\dispatch_and_compile_graph.py", line 138, in aot_dispatch_base_graph
    fw_module = _create_graph(
  File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\dispatch_and_compile_graph.py", line 46, in _create_graph
    fx_g = make_fx(
  File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 1421, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 1367, in trace
    return self._trace_inner(f, *args)
  File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 1354, in _trace_inner
    t = dispatch_trace(
  File "pyl\lib\site-packages\torch\_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
  File "pyl\lib\site-packages\torch\_dynamo\eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 642, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 1019, in trace
    res = super().trace(root, concrete_args)
  File "pyl\lib\site-packages\torch\_dynamo\eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "pyl\lib\site-packages\torch\fx\_symbolic_trace.py", line 822, in trace
    (self.create_arg(fn(*args)),),
  File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 660, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\traced_function_transforms.py", line 388, in _functionalized_f_helper
    f_outs = fn(*f_args)
  File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\traced_function_transforms.py", line 72, in inner_fn
    outs = fn(*args)
  File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\utils.py", line 178, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\traced_function_transforms.py", line 744, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
  File "pyl\lib\site-packages\torch\fx\interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
  File "pyl\lib\site-packages\torch\fx\experimental\symbolic_shapes.py", line 5461, in run_node
    result = super().run_node(n)
  File "pyl\lib\site-packages\torch\fx\interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "pyl\lib\site-packages\torch\fx\interpreter.py", line 275, in call_function
    return target(*args, **kwargs)
  File "pyl\lib\site-packages\torch\_ops.py", line 667, in __call__
    return self_._op(*args, **kwargs)
  File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 705, in __torch_function__
    return func(*args, **kwargs)
  File "pyl\lib\site-packages\torch\_ops.py", line 667, in __call__
    return self_._op(*args, **kwargs)
  File "pyl\lib\site-packages\torch\_subclasses\functional_tensor.py", line 468, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
  File "pyl\lib\site-packages\torch\utils\_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 755, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 790, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 467, in proxy_call
    out = func(*args, **kwargs)
  File "pyl\lib\site-packages\torch\_ops.py", line 667, in __call__
    return self_._op(*args, **kwargs)
  File "pyl\lib\site-packages\torch\utils\_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1061, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1450, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1153, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1765, in _dispatch_impl
    self.wrap_meta_outputs_with_default_device_logic(
  File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1875, in wrap_meta_outputs_with_default_device_logic
    return tree_map(wrap, r)
  File "pyl\lib\site-packages\torch\utils\_pytree.py", line 948, in tree_map
    return treespec.unflatten(map(func, *flat_args))
  File "pyl\lib\site-packages\torch\utils\_pytree.py", line 787, in unflatten
    leaves = list(leaves)
  File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1853, in wrap
    ) = FakeTensor._find_common_device(func, flat_args)
  File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 775, in _find_common_device
    merge_devices(arg)
  File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 770, in merge_devices
    raise RuntimeError(
RuntimeError: Unhandled FakeTensor Device Propagation for torchvision.deform_conv2d.default, found two different devices cuda:0, cpu

While executing %deform_conv2d : [num_users=1] = call_function[target=torch.ops.torchvision.deform_conv2d.default](args = (%x, %p_deform_conv_weight, %conv2d, %zeros, %p_deform_conv_bias, 1, 1, 1, 1, 1, 1, 1, 1, False), kwargs = {})
Original traceback:
  File ".\mysdn.py", line 16, in forward
    out = deform_conv2d(x, offset, self.deform_conv.weight, self.deform_conv.bias,
  File "pyl\lib\site-packages\torchvision\ops\deform_conv.py", line 92, in deform_conv2d
    return torch.ops.torchvision.deform_conv2d(
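
Note the %zeros argument in the failing graph node: when no mask is supplied, torchvision's deform_conv2d creates a dummy mask tensor internally, and during fake-tensor tracing that tensor ends up on cpu while the remaining arguments are on cuda:0, which trips FakeTensor._find_common_device. A possible workaround is to pass an explicit modulation mask created on the input's device; this is an untested sketch, not a confirmed fix, and it assumes the explicit-mask path avoids the internal dummy tensor:

    def forward(self, x):
        offset = self.offset_conv(x)
        n, _, h, w = offset.shape  # output spatial size matches the offset map here
        # mask shape: (N, offset_groups * kernel_h * kernel_w, H_out, W_out) = (N, 9, H, W);
        # a mask of ones reproduces plain (unmodulated) deformable convolution
        mask = torch.ones(n, 9, h, w, device=x.device, dtype=x.dtype)
        return deform_conv2d(x, offset, self.deform_conv.weight, self.deform_conv.bias,
                             stride=1, padding=1, mask=mask)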

Expected behavior

The script should compile successfully and print the output shape, torch.Size([1, 16, 224, 224]).
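
For reference, the same module runs fine in eager mode on the GPU (a minimal sanity check, assuming a working CUDA build of torchvision), which suggests the failure is specific to the export/compile path rather than the op itself:

model = SimpleDeformNet().eval().cuda()
x = torch.randn(1, 3, 224, 224, device="cuda")
with torch.no_grad():
    out = model(x)  # eager execution, no torch_tensorrt involved
print(out.shape)    # torch.Size([1, 16, 224, 224])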

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.4.0+cu118
  • PyTorch Version (e.g. 1.0): 2.4.1+cu118
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Windows 11
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): None
  • Are you using local sources or building from archives: No
  • Python version: 3.10.14
  • CUDA version: cu118
  • GPU models and configuration: RTX 3080
  • Any other relevant information: torchvision 0.19.1+cu118

Additional context

None

Labels: bug (Something isn't working)