Skip to content

Commit 29972a9

Browse files
committed
[dynamo] Reserve the tensorrt backend name for torch-tensorrt
1 parent 3fb0819 commit 29972a9

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

torch/_dynamo/backends/onnxrt.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,4 @@ def _call(*initial_args):
115115
binding.copy_outputs_to_cpu()
116116
return outputs
117117

118-
return _call
119-
120-
121-
@register_backend
122-
def tensorrt(gm, example_inputs):
123-
return onnxrt(gm, example_inputs, provider="TensorrtExecutionProvider")
118+
return _call

torch/_dynamo/backends/tensorrt.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import importlib
2+
import os
3+
import tempfile
4+
5+
import torch
6+
from .common import device_from_inputs, fake_tensor_unsupported
7+
from .registry import register_backend
8+
9+
'''
10+
Placeholder for TensorRT backend for dynamo via torch-tensorrt
11+
'''
12+
13+
# @register_backend
14+
# def tensorrt(gm, example_inputs):
15+
# import torch_tensorrt # type: ignore[import]
16+
# pass

0 commit comments

Comments
 (0)