chore: enable DS support for converters #2775

Merged
107 commits merged on May 22, 2024
Changes from 103 commits
Commits (107)
d06c74a
chore: Switch to new export apis
peri044 Oct 3, 2023
47e0997
chore: rebase with main
peri044 Oct 9, 2023
fd29fe0
Merge branch 'main' into export_2.2
peri044 Oct 11, 2023
ad3b031
feat: Add support for dynamic shapes and remove constraints API
peri044 Oct 19, 2023
1582b72
chore: add dynamic shape support for certain converters
peri044 Oct 23, 2023
4d01545
chore: minor updates
peri044 Oct 25, 2023
6731a57
chore: updates
peri044 Oct 26, 2023
a8a194b
chore: rebase with main
peri044 Nov 15, 2023
0b60aae
chore: add sym int converter
peri044 Nov 15, 2023
634612f
feat: Replace the existing shape propagation with symbolic shape prop…
peri044 Nov 16, 2023
93edba4
chore: fix imports
peri044 Nov 16, 2023
7ad9272
chore: fix imports
peri044 Nov 16, 2023
f444d54
chore: updates
peri044 Nov 21, 2023
6e5c582
chore: change device calls
peri044 Nov 28, 2023
83791f8
chore: fix metadata check
peri044 Dec 5, 2023
8375996
chore: rebase with main
peri044 Dec 15, 2023
aba91fa
Merge branch 'main' into dyn_2.2
peri044 Dec 22, 2023
16394d9
chore: minor fixes
peri044 Jan 7, 2024
b9a7ccd
chore: Add sym_size converter tests
peri044 Jan 8, 2024
15cc643
chore: Update test utilities
peri044 Jan 8, 2024
5234d74
chore: add testcase for sym_size.int
peri044 Jan 8, 2024
fd2dae1
Merge branch 'main' into dyn_2.2
peri044 Jan 26, 2024
51e8bb7
chore: revert output type change
peri044 Jan 26, 2024
19c3fad
chore: add update_metadata utility
peri044 Jan 27, 2024
ed48551
chore: change debug to warning if the graph does not have metadata
peri044 Jan 27, 2024
18b7e11
feat: add lowering passes to support dynamic shapes for torch.compile
peri044 Jan 30, 2024
3a39d27
chore: add test case
peri044 Jan 30, 2024
abb2677
chore: add view test case
peri044 Feb 2, 2024
9aff04b
chore: gpt2 changes + linting
peri044 Feb 7, 2024
440fcd5
chore: gpt2 changes + linting
peri044 Feb 7, 2024
a2d38f3
chore: rebase with main
peri044 Feb 7, 2024
002db3c
chore: add fallback option if val is missing in metadata
peri044 Feb 7, 2024
00cd17b
chore: tmp changes
peri044 Feb 13, 2024
6ac70cd
chore: tmp changes
peri044 Feb 13, 2024
b827070
Merge branch 'main' into dyn_2.2
peri044 Feb 16, 2024
8f9bca0
Merge branch 'main' into dyn_2.2
peri044 Feb 21, 2024
4399d57
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Feb 21, 2024
39615a2
chore: fixes
peri044 Feb 26, 2024
cd86660
feat: Add save API for torch-trt compiled models
peri044 Mar 14, 2024
3ece71b
chore: resolve merge conflicts
peri044 Mar 15, 2024
1fa1771
Merge branch 'main' into dyn_2.2
peri044 Mar 15, 2024
febf05b
Merge branch 'save' into dyn_2.2
peri044 Mar 15, 2024
eab0dba
chore: Fix save failures
peri044 Mar 18, 2024
b191d62
chore: update to 2.3 rc build
peri044 Mar 18, 2024
380477b
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Mar 19, 2024
5f34d4f
chore: minor fixes
peri044 Mar 19, 2024
ce606fe
chore: rebase with release/2.3 branch
peri044 Mar 19, 2024
8674a3c
chore: minor fixes
peri044 Mar 19, 2024
f4e8fe9
chore: remove duplicate bert test case
peri044 Mar 20, 2024
4ae6ab9
chore: remove comments
peri044 Mar 20, 2024
c14f28d
Merge branch 'save' into dyn_2.2
peri044 Mar 20, 2024
3295c02
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Mar 20, 2024
4188173
chore: rebase with release/2.3
peri044 Apr 2, 2024
f6b758e
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Apr 2, 2024
78f7eb5
chore: updates
peri044 Apr 2, 2024
fe13c2a
chore: Update mypy type for sample_inputs
peri044 Apr 2, 2024
e9b649d
chore: revert changes
peri044 Apr 5, 2024
03ecc61
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Apr 5, 2024
978c039
Merge branch 'release/2.3' into dyn_2.2
peri044 Apr 5, 2024
ccb88c8
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Apr 5, 2024
5a62761
chore: patches for llamav2
peri044 Apr 5, 2024
22150ec
Merge branch 'dyn_llama' of https://github.com/pytorch/TensorRT into …
peri044 Apr 5, 2024
91d1a59
chore: rebase
peri044 Apr 23, 2024
c1574be
chore: updates
peri044 Apr 24, 2024
8c68359
chore: add consistent graph log
peri044 Apr 25, 2024
c34582e
Merge branch 'release/2.3' into dyn_llama
peri044 Apr 26, 2024
1481ad3
chore: updates
peri044 Apr 27, 2024
7a59e63
feat: Add validators for dynamic shapes in converter registration
peri044 Apr 29, 2024
f55d41a
chore: updates
peri044 Apr 30, 2024
87da1c1
chore: updates
peri044 Apr 30, 2024
6baad7f
Merge branch 'release/2.3' into dyn_llama
peri044 Apr 30, 2024
b45b8d7
Merge branch 'release/2.3' into dyn_validator
peri044 Apr 30, 2024
5aff3c1
Merge branch 'dyn_validator' into dyn_llama
peri044 Apr 30, 2024
a1604a1
Merge branch 'release/2.3' into dyn_validator
peri044 May 1, 2024
e3e7927
chore: updates
peri044 May 1, 2024
8ec68da
chore: address failures and implement flag to enable all converters
peri044 May 2, 2024
151fc40
chore: update docstring
peri044 May 2, 2024
a2ed092
chore: add testcase
peri044 May 2, 2024
c1f5d15
chore: updates
peri044 May 2, 2024
19aa788
Merge branch 'dyn_validator' into dyn_llama
peri044 May 2, 2024
649b79d
chore: rename disable_dynamic_converter_checks to assume_dynamic_shap…
peri044 May 2, 2024
6f945fa
chore: updates
peri044 May 3, 2024
b531573
chore: updates
peri044 May 7, 2024
e0415a5
chore: updates
peri044 May 13, 2024
6dd2c90
chore: updates
peri044 May 14, 2024
798aa30
chore: updates
peri044 May 14, 2024
f947509
chore: remove dyn shape support for split converter
peri044 May 14, 2024
477a49b
chore: updates
peri044 May 14, 2024
1828d5b
chore: updates
peri044 May 14, 2024
66c7b19
chore: updates
peri044 May 15, 2024
afa85fc
chore: updates
peri044 May 15, 2024
ac4feba
chore: roll back GHA changes
peri044 May 15, 2024
aa8e0b5
chore: updates
peri044 May 16, 2024
5b98915
chore: updates
peri044 May 16, 2024
40dbbff
chore: updates
peri044 May 16, 2024
80a2e9e
chore: updates
peri044 May 16, 2024
3b2245e
chore: updates
peri044 May 16, 2024
382ea09
chore: updates
peri044 May 16, 2024
89d3e8d
chore: rebase
peri044 May 16, 2024
6dc40e2
chore: rebase with 2.3
peri044 May 16, 2024
31bf8ed
chore: updates
peri044 May 17, 2024
18c0b4b
chore: fix tests
peri044 May 17, 2024
c6f7b4a
chore: updates
peri044 May 17, 2024
bb5d30d
chore: updates
peri044 May 17, 2024
da72508
chore: address review comments
peri044 May 21, 2024
e33976a
chore: remove gpt2 example
peri044 May 21, 2024
f28684c
chore: updates
peri044 May 22, 2024
23 changes: 8 additions & 15 deletions .github/workflows/build-test.yml
@@ -78,16 +78,15 @@ jobs:
script: |
export USE_HOST_DEPS=1
export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH
-export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/modules
# Don't use requirements.txt here as it contains tensorrt and torch which should have been installed by now.
-${CONDA_RUN} python -m pip install numpy packaging pyyaml transformers timm pybind11==2.6.2
+${CONDA_RUN} python -m pip install numpy packaging pyyaml transformers==4.40.2 timm==0.9.16 pybind11==2.6.2
${CONDA_RUN} python hub.py
popd
pushd .
cd tests/py/ts
-${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
+${CONDA_RUN} python -m pip install --pre pytest timm==0.9.16 transformers==4.40.2 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_api_test_results.xml api/
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_models_test_results.xml models/
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_integrations_test_results.xml integrations/
@@ -115,10 +114,9 @@ jobs:
pre-script: ${{ matrix.pre-script }}
script: |
export USE_HOST_DEPS=1
-export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
-${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
+${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.40.2 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/
popd

@@ -144,10 +142,9 @@ jobs:
pre-script: ${{ matrix.pre-script }}
script: |
export USE_HOST_DEPS=1
-export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
-${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
+${CONDA_RUN} python -m pip install --pre pytest timm==0.9.16 transformers==4.40.2 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
popd
@@ -174,10 +171,9 @@ jobs:
pre-script: ${{ matrix.pre-script }}
script: |
export USE_HOST_DEPS=1
-export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
-${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
+${CONDA_RUN} python -m pip install --pre pytest timm==0.9.16 transformers==4.40.2 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
popd

@@ -203,10 +199,9 @@ jobs:
pre-script: ${{ matrix.pre-script }}
script: |
export USE_HOST_DEPS=1
-export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
-${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
+${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.40.2 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py
@@ -234,10 +229,9 @@ jobs:
pre-script: ${{ matrix.pre-script }}
script: |
export USE_HOST_DEPS=1
-export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
-${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
+${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.40.2 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/
@@ -264,9 +258,8 @@ jobs:
pre-script: ${{ matrix.pre-script }}
script: |
export USE_HOST_DEPS=1
-export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/core
-${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
+${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.40.2 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml .
popd
26 changes: 23 additions & 3 deletions core/runtime/execute_engine.cpp
@@ -124,6 +124,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
}
}

+// this is a buffer to store shape tensor input addresses throughout the runtime scope
+std::list<std::vector<int32_t>> inputShapeTensorValues;
{
std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
if (compiled_engine->profile_execution) {
@@ -142,12 +144,30 @@
auto dims = core::util::toDims(inputs[i].sizes());
auto shape = core::util::toVec(dims);
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
-compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
-compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr());
+if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
+  // Shape tensor inputs are cast to int32 explicitly.
+  // Refer to
+  // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
+  auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt32);
+  std::vector<int32_t> inputs_cpu_vec(
+      input_cpu.data_ptr<int32_t>(), input_cpu.data_ptr<int32_t>() + input_cpu.numel());
+  inputShapeTensorValues.emplace_back(inputs_cpu_vec);
+  compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data());
+} else {
+  compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
+  compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr());
+}
}

+// Check if input shapes can be inferred.
+int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
+std::vector<char const*> names(io_size);
+int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data());
TORCHTRT_CHECK(
-    compiled_engine->exec_ctx->allInputShapesSpecified(), "Not enough inputs provided (runtime.RunCudaEngine)");
+    nbNames == 0,
+    "The shapes of the inputs: "
+        << names
+        << " cannot be inferred. This could happen if the input tensor addresses/shapes haven't been configured correctly");
}

std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
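The runtime change above routes shape tensor inputs, whose host-side values drive shape inference, differently from ordinary execution tensors, and keeps their int32 copies alive until execution finishes. A rough Python rendering of the same dispatch, offered as a sketch that assumes the TensorRT Python bindings mirror the C++ calls used here (is_shape_inference_io, set_input_shape, set_tensor_address, infer_shapes); bind_inputs and its arguments are hypothetical:

import tensorrt as trt  # assumption: TensorRT 10.x Python bindings
import torch

def bind_inputs(engine, context, input_names, inputs):
    # Host-side buffer that must outlive execution, mirroring the
    # inputShapeTensorValues list in the C++ runtime above.
    shape_tensor_values = []
    for name, tensor in zip(input_names, inputs):
        if engine.is_shape_inference_io(name):
            # Shape tensors are consumed by value on the host as int32.
            values = tensor.cpu().to(torch.int32).contiguous()
            shape_tensor_values.append(values)
            context.set_tensor_address(name, values.data_ptr())
        else:
            context.set_input_shape(name, tuple(tensor.shape))
            context.set_tensor_address(name, tensor.contiguous().data_ptr())
    # infer_shapes returns the names of tensors it could not resolve.
    unresolved = context.infer_shapes()
    assert len(unresolved) == 0, f"Cannot infer shapes for: {unresolved}"
    return shape_tensor_values  # caller keeps this alive until execution ends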
68 changes: 68 additions & 0 deletions examples/gpt2_tc.py
@@ -0,0 +1,68 @@
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
from transformers.generation.stopping_criteria import (
EosTokenCriteria,
MaxLengthCriteria,
)

# Define tokenizer and model
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = (
AutoModelForCausalLM.from_pretrained(
"gpt2", pad_token_id=tokenizer.eos_token_id, use_cache=False
)
.eval()
.to(torch_device)
)

# Input prompt
model_inputs = tokenizer("I enjoy walking with my cute dog", return_tensors="pt").to(
torch_device
)
input_ids = model_inputs["input_ids"]
max_tokens = 40

# Pyt model outputs
greedy_output = model.generate(**model_inputs, max_new_tokens=max_tokens)
print(
"Pytorch model generated text: ",
tokenizer.decode(greedy_output[0], skip_special_tokens=True),
)

# Compile Torch-TRT model
torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
model.forward = torch.compile(
model.forward,
backend="tensorrt",
dynamic=None,
options={
"debug": False,
"enabled_precisions": {torch.float},
"torch_executed_ops": {"torch.ops.aten.slice.Tensor"},
"use_python_runtime": True,
},
)

# Auto-regressive generation loop for greedy search
stopping_criteria = StoppingCriteriaList(
[
MaxLengthCriteria(max_length=max_tokens),
EosTokenCriteria(eos_token_id=tokenizer.eos_token_id),
]
)
while True:
trt_outputs = model(input_ids)
logits = trt_outputs.logits
next_token_logits = logits[:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if stopping_criteria(input_ids, logits).item():
break

# Decode the sentence
print(
"TensorRT model generated text: ",
tokenizer.decode(input_ids[0], skip_special_tokens=True),
)
10 changes: 9 additions & 1 deletion py/torch_tensorrt/_Input.py
@@ -47,6 +47,7 @@ class _ShapeMode(Enum):
high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
torch_tensor: torch.Tensor = None
name: str = ""
+is_shape_tensor: bool = False

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""__init__ Method for torch_tensorrt.Input
@@ -161,6 +162,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
else:
self._explicit_set_dtype = False

if "is_shape_tensor" in kwargs:
self.is_shape_tensor = kwargs["is_shape_tensor"]

if "format" in kwargs:
self.format = memory_format._from(kwargs["format"])

@@ -174,7 +178,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
if "torch_tensor" in kwargs:
self.torch_tensor = kwargs["torch_tensor"]
else:
-if self.shape_mode == Input._ShapeMode.DYNAMIC:
+if self.is_shape_tensor:
+    self.torch_tensor = torch.tensor(
+        kwargs["opt_shape"], dtype=kwargs["dtype"]
+    )
+elif self.shape_mode == Input._ShapeMode.DYNAMIC:
self.torch_tensor = self.example_tensor("opt_shape")
else:
self.torch_tensor = self.example_tensor()
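With the new is_shape_tensor flag, a shape tensor input can be declared next to a regular dynamic-shape input. A minimal sketch, assuming the usual min/opt/max keyword form of torch_tensorrt.Input; the concrete values are illustrative only:

import torch
import torch_tensorrt

# Regular dynamic input: min/opt/max describe tensor *dimensions*.
data_input = torch_tensorrt.Input(
    min_shape=(1, 3, 224, 224),
    opt_shape=(8, 3, 224, 224),
    max_shape=(16, 3, 224, 224),
    dtype=torch.float32,
)

# Shape tensor input: min/opt/max describe the *values* the tensor holds at
# runtime; per the diff, torch_tensor defaults to torch.tensor(opt_shape, dtype).
size_input = torch_tensorrt.Input(
    min_shape=(1, 224, 224),
    opt_shape=(8, 224, 224),
    max_shape=(16, 224, 224),
    dtype=torch.int32,
    is_shape_tensor=True,
)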
10 changes: 3 additions & 7 deletions py/torch_tensorrt/dynamo/_tracer.py
@@ -58,13 +58,9 @@ def trace(

device = to_torch_device(kwargs.get("device", default_device()))
torch_inputs = get_torch_inputs(inputs, device)
-dynamic_shapes = {}
+dynamic_shapes = []
for input in inputs:
if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC:
-if not input.name:
-    raise AssertionError(
-        f"Expected a name for a dynamic input with shape {input.shape} but found none"
-    )
min_shape = input.shape["min_shape"]
opt_shape = input.shape["opt_shape"]
max_shape = input.shape["max_shape"]
@@ -80,8 +76,8 @@
max=max_shape[dim],
)

-dynamic_shapes[input.name] = dynamic_dims
+dynamic_shapes.append(dynamic_dims)

-exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=dynamic_shapes)
+exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=tuple(dynamic_shapes))

return exp_program
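The tracer now collects per-input dynamic dimension specs positionally into a tuple, matching the order of the inputs, instead of keying them by input name (so unnamed inputs no longer raise). Roughly what torch.export receives after this change, as a sketch; the module and the Dim name are illustrative:

import torch
from torch.export import Dim, export

class MatMul(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ x.transpose(-1, -2)

# One {dim_index: Dim} mapping per positional input, gathered into a tuple.
batch = Dim("batch", min=2, max=16)
dynamic_shapes = ({0: batch},)  # dim 0 of the first (only) input is dynamic

exp_program = export(MatMul(), (torch.randn(4, 8, 8),), dynamic_shapes=dynamic_shapes)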
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -96,6 +96,8 @@ def _pretraced_backend(

gm = apply_lowering_passes(gm, torch_inputs)

logger.debug("Lowered Input graph:\n " + str(gm.graph))

torchtrt_inputs = prepare_inputs(
torch_inputs, disable_memory_format_check=True
)
36 changes: 24 additions & 12 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

import numpy as np
+import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
@@ -25,7 +26,6 @@
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

-import tensorrt as trt
from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -365,18 +365,29 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
max_shape = current_input.shape["max_shape"]
# TODO: Does not support disjoint optimization profiles?
assert self.optimization_profiles is not None
-self.optimization_profiles[0].set_shape(
-    target, min_shape, opt_shape, max_shape
-)
+if current_input.is_shape_tensor:
+    # For shape_tensors, min/opt/max_shapes correspond to actual values
+    # of the shapes provided during runtime
+    self.optimization_profiles[0].set_shape_input(
+        target, min_shape, opt_shape, max_shape
+    )
+    shape.append(1)
+else:
+    self.optimization_profiles[0].set_shape(
+        target, min_shape, opt_shape, max_shape
+    )

-assert len(min_shape) == len(opt_shape) == len(max_shape)
-for i in range(len(min_shape)):
-    if min_shape[i] == opt_shape[i] == max_shape[i]:
-        shape.append(min_shape[i])
-    else:
-        # -1 to represent the dynamic dimension
-        shape.append(-1)
-elif current_input.shape_mode == Input._ShapeMode.STATIC:
+    assert len(min_shape) == len(opt_shape) == len(max_shape)
+    for i in range(len(min_shape)):
+        if min_shape[i] == opt_shape[i] == max_shape[i]:
+            shape.append(min_shape[i])
+        else:
+            # -1 to represent the dynamic dimension
+            shape.append(-1)
+elif (
+    not current_input.is_shape_tensor
+    and current_input.shape_mode == Input._ShapeMode.STATIC
+):
assert isinstance(current_input.shape, tuple)
shape = list(current_input.shape)
else:
@@ -388,6 +399,7 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
_LOGGER.debug(
f"Adding input to in-progress INetwork: {target} [shape={shape}, dtype={trt_input_dtype}]"
)

return self.ctx.net.add_input(
name=target,
shape=tuple(shape),
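The set_shape versus set_shape_input split mirrors TensorRT's optimization-profile semantics: for execution tensors the min/opt/max triple bounds the input dimensions, while for shape tensors it bounds the runtime values. A hedged sketch using the TensorRT Python API (network construction elided; tensor names are illustrative):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
profile = builder.create_optimization_profile()

# Execution tensor: min/opt/max are input *dimensions*.
profile.set_shape("input_ids", (1, 2), (1, 512), (1, 1023))

# Shape tensor: min/opt/max are the *values* the tensor may carry at runtime.
profile.set_shape_input("sizes", (1, 224, 224), (8, 224, 224), (16, 224, 224))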
16 changes: 8 additions & 8 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -4,7 +4,9 @@
import logging
from typing import List, Sequence

+import tensorrt as trt
import torch
+from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
@@ -17,8 +19,6 @@
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_torch_inputs

-import tensorrt as trt

logger = logging.getLogger(__name__)


@@ -28,12 +28,12 @@ def infer_module_output_dtypes(
device: Device,
truncate_double: bool = False,
) -> List[dtype]:
-torch_inputs = get_torch_inputs(inputs, device)
-module = module.to(device.to(torch.device))
-module_outputs = module(*torch_inputs)
-
-if not isinstance(module_outputs, (list, tuple)):
-    module_outputs = [module_outputs]
+with maybe_disable_fake_tensor_mode():
+    torch_inputs = get_torch_inputs(inputs, device)
+    module = module.to(device.to(torch.device))
+    module_outputs = module(*torch_inputs)
+    if not isinstance(module_outputs, (list, tuple)):
+        module_outputs = [module_outputs]

# Int64 outputs can sometimes be generated from within other operators
# such as aten.sum - such outputs can be truncated
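infer_module_output_dtypes runs the module eagerly just to observe the output dtypes; wrapping that run in maybe_disable_fake_tensor_mode lets it execute on real tensors even when the caller is inside an active FakeTensorMode. A minimal sketch of the pattern; the helper name is illustrative:

import torch
from typing import List
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode

def output_dtypes(module: torch.nn.Module, *inputs: torch.Tensor) -> List[torch.dtype]:
    # Temporarily exit any active FakeTensorMode so the module runs on
    # real tensors and yields concrete outputs.
    with maybe_disable_fake_tensor_mode():
        outputs = module(*inputs)
        if not isinstance(outputs, (list, tuple)):
            outputs = [outputs]
        return [out.dtype for out in outputs]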