Bug Description
When using deform_conv2d from torchvision.ops, torch_tensorrt.compile fails with the following error:
RuntimeError: Unhandled FakeTensor Device Propagation for torchvision.deform_conv2d.default, found two different devices cuda:0, cpu
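Note: the traceback below points at exported_program.run_decompositions(), so this may not be TensorRT-specific. The following is a minimal sketch (my assumption, not verified in this environment) that exercises only torch.export on the same kind of model:

import torch
import torch.nn as nn
from torch.export import export
from torchvision.ops import deform_conv2d

class TinyDeform(nn.Module):
    """Minimal stand-in for the SimpleDeformNet model used in the full repro below."""
    def __init__(self):
        super().__init__()
        self.offset_conv = nn.Conv2d(3, 18, kernel_size=3, padding=1)  # 2*3*3 offset channels
        self.weight = nn.Parameter(torch.randn(16, 3, 3, 3))           # 3x3 deformable kernel

    def forward(self, x):
        return deform_conv2d(x, self.offset_conv(x), self.weight, stride=1, padding=1)

m = TinyDeform().eval().cuda()
ep = export(m, (torch.randn(1, 3, 224, 224).cuda(),))
ep.run_decompositions()  # presumably raises the same FakeTensor device-propagation error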
To Reproduce
Steps to reproduce the behavior:
- Source code (mysdn.py):

import torch
import torch.nn as nn
from torchvision.ops import DeformConv2d, deform_conv2d
import torch_tensorrt


class SimpleDeformNet(nn.Module):
    def __init__(self):
        super(SimpleDeformNet, self).__init__()
        self.offset_conv = nn.Conv2d(3, 18, kernel_size=3, padding=1)
        self.deform_conv = DeformConv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        offset = self.offset_conv(x)
        out = deform_conv2d(x, offset, self.deform_conv.weight, self.deform_conv.bias,
                            stride=1, padding=1)
        return out


model = SimpleDeformNet().eval().cuda()
input_data = torch.randn(1, 3, 224, 224).cuda()

trt_model = torch_tensorrt.compile(
    model,
    inputs=[
        torch_tensorrt.Input(
            min_shape=[1, 3, 224, 224],
            opt_shape=[1, 3, 224, 224],
            max_shape=[1, 3, 224, 224],
            dtype=torch.float32,
        )
    ],
    enabled_precisions={torch.float32},
    workspace_size=1 << 22,
)

output = trt_model(input_data)
print(output.shape)
- Run the above code
python mysdn.py
- The following error is raised:
WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
INFO:torch_tensorrt._compile:ir was set to default, using dynamo frontend
WARNING:py.warnings:.\site-packages\torch\_subclasses\functional_tensor.py:362: UserWarning: At pre-dispatch tracing, we will assume that any custom op that is marked with CompositeImplicitAutograd and functional are safe to not decompose. We found torchvision.deform_conv2d.default to be one such op.
warnings.warn(
Traceback (most recent call last):
File ".\mysdn.py", line 25, in <module>
trt_model = torch_tensorrt.compile(
File "pyl\lib\site-packages\torch_tensorrt\_compile.py", line 249, in compile
trt_graph_module = dynamo_compile(
File "pyl\lib\site-packages\torch_tensorrt\dynamo\_compiler.py", line 185, in compile
exported_program = exported_program.run_decompositions(
File "pyl\lib\site-packages\torch\export\exported_program.py", line 89, in wrapper
return fn(*args, **kwargs)
File "pyl\lib\site-packages\torch\export\exported_program.py", line 567, in run_decompositions
gm, graph_signature = aot_export_module(
File "pyl\lib\site-packages\torch\_functorch\aot_autograd.py", line 1131, in aot_export_module
fx_g, metadata, in_spec, out_spec = _aot_export_function(
File "pyl\lib\site-packages\torch\_functorch\aot_autograd.py", line 1350, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
File "pyl\lib\site-packages\torch\_dynamo\utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
File "pyl\lib\site-packages\torch\_functorch\aot_autograd.py", line 687, in create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\jit_compile_runtime_wrappers.py", line 95, in aot_dispatch_export
graph, _, _ = aot_dispatch_base_graph(
File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\dispatch_and_compile_graph.py", line 138, in aot_dispatch_base_graph
fw_module = _create_graph(
File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\dispatch_and_compile_graph.py", line 46, in _create_graph
fx_g = make_fx(
File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 1421, in wrapped
return make_fx_tracer.trace(f, *args)
File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 1367, in trace
return self._trace_inner(f, *args)
File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 1354, in _trace_inner
t = dispatch_trace(
File "pyl\lib\site-packages\torch\_compile.py", line 31, in inner
return disable_fn(*args, **kwargs)
File "pyl\lib\site-packages\torch\_dynamo\eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 642, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 1019, in trace
res = super().trace(root, concrete_args)
File "pyl\lib\site-packages\torch\_dynamo\eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
File "pyl\lib\site-packages\torch\fx\_symbolic_trace.py", line 822, in trace
(self.create_arg(fn(*args)),),
File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 660, in wrapped
out = f(*tensors)
File "<string>", line 1, in <lambda>
File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\traced_function_transforms.py", line 388, in _functionalized_f_helper
f_outs = fn(*f_args)
File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\traced_function_transforms.py", line 72, in inner_fn
outs = fn(*args)
File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\utils.py", line 178, in flat_fn
tree_out = fn(*args, **kwargs)
File "pyl\lib\site-packages\torch\_functorch\_aot_autograd\traced_function_transforms.py", line 744, in functional_call
out = PropagateUnbackedSymInts(mod).run(
File "pyl\lib\site-packages\torch\fx\interpreter.py", line 146, in run
self.env[node] = self.run_node(node)
File "pyl\lib\site-packages\torch\fx\experimental\symbolic_shapes.py", line 5461, in run_node
result = super().run_node(n)
File "pyl\lib\site-packages\torch\fx\interpreter.py", line 203, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "pyl\lib\site-packages\torch\fx\interpreter.py", line 275, in call_function
return target(*args, **kwargs)
File "pyl\lib\site-packages\torch\_ops.py", line 667, in __call__
return self_._op(*args, **kwargs)
File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 705, in __torch_function__
return func(*args, **kwargs)
File "pyl\lib\site-packages\torch\_ops.py", line 667, in __call__
return self_._op(*args, **kwargs)
File "pyl\lib\site-packages\torch\_subclasses\functional_tensor.py", line 468, in __torch_dispatch__
outs_unwrapped = func._op_dk(
File "pyl\lib\site-packages\torch\utils\_stats.py", line 21, in wrapper
return fn(*args, **kwargs)
File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 755, in __torch_dispatch__
return self.inner_torch_dispatch(func, types, args, kwargs)
File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 790, in inner_torch_dispatch
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
File "pyl\lib\site-packages\torch\fx\experimental\proxy_tensor.py", line 467, in proxy_call
out = func(*args, **kwargs)
File "pyl\lib\site-packages\torch\_ops.py", line 667, in __call__
return self_._op(*args, **kwargs)
File "pyl\lib\site-packages\torch\utils\_stats.py", line 21, in wrapper
return fn(*args, **kwargs)
File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1061, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1450, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1153, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1765, in _dispatch_impl
self.wrap_meta_outputs_with_default_device_logic(
File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1875, in wrap_meta_outputs_with_default_device_logic
return tree_map(wrap, r)
File "pyl\lib\site-packages\torch\utils\_pytree.py", line 948, in tree_map
return treespec.unflatten(map(func, *flat_args))
File "pyl\lib\site-packages\torch\utils\_pytree.py", line 787, in unflatten
leaves = list(leaves)
File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1853, in wrap
) = FakeTensor._find_common_device(func, flat_args)
File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 775, in _find_common_device
merge_devices(arg)
File "pyl\lib\site-packages\torch\_subclasses\fake_tensor.py", line 770, in merge_devices
raise RuntimeError(
RuntimeError: Unhandled FakeTensor Device Propagation for torchvision.deform_conv2d.default, found two different devices cuda:0, cpu
While executing %deform_conv2d : [num_users=1] = call_function[target=torch.ops.torchvision.deform_conv2d.default](args = (%x, %p_deform_conv_weight, %conv2d, %zeros, %p_deform_conv_bias, 1, 1, 1, 1, 1, 1, 1, 1, False), kwargs = {})
Original traceback:
File ".\mysdn.py", line 16, in forward
out = deform_conv2d(x, offset, self.deform_conv.weight, self.deform_conv.bias,
File "pyl\lib\site-packages\torchvision\ops\deform_conv.py", line 92, in deform_conv2d
return torch.ops.torchvision.deform_conv2d(
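Possible workaround (untested, just an idea): the failing graph node receives a %zeros tensor for the optional mask argument of deform_conv2d, so supplying the mask explicitly on the input's device might keep all arguments on cuda:0. A sketch of the modified module (an all-ones mask is equivalent to no modulation):

import torch
import torch.nn as nn
from torchvision.ops import DeformConv2d, deform_conv2d

class SimpleDeformNetMasked(nn.Module):
    """Same as SimpleDeformNet, but passes an explicit all-ones mask on x.device."""
    def __init__(self):
        super().__init__()
        self.offset_conv = nn.Conv2d(3, 18, kernel_size=3, padding=1)
        self.deform_conv = DeformConv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        offset = self.offset_conv(x)
        n, _, h, w = offset.shape
        # mask shape: (N, offset_groups * kH * kW, H_out, W_out) = (N, 9, H, W) for a 3x3 kernel
        mask = torch.ones(n, 9, h, w, device=x.device, dtype=x.dtype)
        return deform_conv2d(x, offset, self.deform_conv.weight, self.deform_conv.bias,
                             stride=1, padding=1, mask=mask)

Even if this avoids the device-propagation error, Torch-TensorRT presumably still has no converter for torchvision::deform_conv2d, so the op would likely fall back to PyTorch execution rather than run inside the TensorRT engine.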
Expected behavior
The compiled module should run without error and print the output shape.
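For reference, the deformable conv here has 16 output channels and preserves the 224x224 spatial size (3x3 kernel, stride 1, padding 1), so the expected print is torch.Size([1, 16, 224, 224]). A quick check in eager mode (assuming eager execution is unaffected; reuses model and input_data from mysdn.py):

# Sanity check without torch_tensorrt; the compiled module should print the same shape.
with torch.no_grad():
    print(model(input_data).shape)  # torch.Size([1, 16, 224, 224])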
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.4.0+cu118
- PyTorch Version (e.g. 1.0): 2.4.1+cu118
- CPU Architecture: x86 64bit
- OS (e.g., Linux): Windows 11
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source): None
- Are you using local sources or building from archives: No
- Python version: 3.10.14
- CUDA version: cu118
- GPU models and configuration: RTX 3080
- Any other relevant information: torchvision 0.19.1+cu118
Additional context
None