Skip to content

Commit

Permalink
chore: Make from and to methods use the same TRT API (#2858)
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan authored and laikhtewari committed May 24, 2024
1 parent 1fb8c1d commit 55f1149
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Build and test linux wheels
name: Build and test Linux wheels

on:
pull_request:
Expand Down Expand Up @@ -86,7 +86,7 @@ jobs:
popd
pushd .
cd tests/py/ts
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_api_test_results.xml api/
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_models_test_results.xml models/
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_integrations_test_results.xml integrations/
Expand Down Expand Up @@ -117,7 +117,7 @@ jobs:
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/
popd
Expand Down Expand Up @@ -146,7 +146,7 @@ jobs:
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
popd
Expand Down Expand Up @@ -176,7 +176,7 @@ jobs:
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
popd
Expand Down Expand Up @@ -205,7 +205,7 @@ jobs:
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py
Expand Down Expand Up @@ -236,7 +236,7 @@ jobs:
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/
Expand Down Expand Up @@ -266,6 +266,6 @@ jobs:
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/core
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml .
popd
8 changes: 4 additions & 4 deletions .github/workflows/build-test-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ jobs:
export USE_HOST_DEPS=1
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/
popd
Expand All @@ -98,7 +98,7 @@ jobs:
export USE_HOST_DEPS=1
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
popd
Expand All @@ -125,7 +125,7 @@ jobs:
export USE_HOST_DEPS=1
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
popd
Expand All @@ -152,7 +152,7 @@ jobs:
export USE_HOST_DEPS=1
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/
Expand Down
16 changes: 8 additions & 8 deletions py/torch_tensorrt/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,21 @@ def _from(
f"Provided an unsupported data type as a data type for translation (support: bool, int, long, half, float, bfloat16), got: {t}"
)
elif isinstance(t, trt.DataType):
if t == trt.uint8:
if t == trt.DataType.UINT8:
return dtype.u8
elif t == trt.int8:
elif t == trt.DataType.INT8:
return dtype.i8
elif t == trt.int32:
elif t == trt.DataType.INT32:
return dtype.i32
elif t == trt.int64:
elif t == trt.DataType.INT64:
return dtype.i64
elif t == trt.float16:
elif t == trt.DataType.HALF:
return dtype.f16
elif t == trt.float32:
elif t == trt.DataType.FLOAT:
return dtype.f32
elif t == trt.bool:
elif t == trt.DataType.BOOL:
return dtype.b
elif t == trt.bf16:
elif t == trt.DataType.BF16:
return dtype.bf16
else:
raise TypeError(
Expand Down
6 changes: 6 additions & 0 deletions tests/py/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pytest>=8.2.1
pytest-xdist>=3.6.1
timm>=1.0.3
transformers==4.39.3
parameterized>=0.2.0
expecttest==0.1.6

0 comments on commit 55f1149

Please sign in to comment.