Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 51 additions & 48 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,26 @@
__tensorrt_rtx_version__: str = "0.0"

LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")
# CI_PIPELINE_ID is the environment variable set by DLFW ci build
IS_DLFW_CI = os.environ.get("CI_PIPELINE_ID") is not None


def get_root_dir() -> Path:
    """Return the absolute path of the repository root.

    Derives the root from the location of this file rather than asking git,
    so it also works in build environments without git (e.g. DLFW CI).
    """
    # The merged diff left the old git-based implementation as unreachable
    # code before this return; only the path-based lookup is kept.
    return Path(__file__).parent.absolute()


def get_git_revision_short_hash() -> str:
    """Return the short git hash of HEAD, or a placeholder when git fails.

    The DLFW CI build environment does not have git, so any failure to
    invoke git (missing binary, not a repository, non-zero exit) falls
    back to a fixed default hash instead of aborting the build.
    """
    try:
        return (
            subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
            .decode("ascii")
            .strip()
        )
    # Narrowed from a bare `except:` -- a bare except also swallows
    # KeyboardInterrupt/SystemExit. CalledProcessError covers a failing git
    # command; OSError covers a missing git binary (FileNotFoundError).
    except (subprocess.CalledProcessError, OSError):
        print("WARNING: Could not get git revision short hash, using default one")
        # in release/ngc/25.10 branch this is the commit hash of the pytorch
        # commit that is used for dlfw package
        return "0000000"


def get_base_version() -> str:
Expand Down Expand Up @@ -718,58 +722,57 @@ def run(self):
with open(os.path.join(get_root_dir(), "README.md"), "r", encoding="utf-8") as fh:
long_description = fh.read()

# Dependencies shared by every build flavor (JetPack, SBSA, standard).
base_requirements = [
    "packaging>=23",
    "typing-extensions>=4.7.0",
    "dllist",
]


def get_requirements():
    """Return the install_requires list for the current build flavor.

    Dispatches on the platform flags: JetPack and SBSA builds delegate to
    their dedicated helpers; everything else gets the standard Linux/Windows
    requirements. In DLFW CI builds (detected via CI_PIPELINE_ID) torch and
    tensorrt are provided by the container, so they are not pinned here.

    Note: the merged diff contained two interleaved definitions of this
    function; this is the reconstructed new version.
    """
    if IS_JETPACK:
        requirements = get_jetpack_requirements()
    elif IS_SBSA:
        requirements = get_sbsa_requirements()
    else:
        # standard linux and windows requirements
        requirements = base_requirements + ["numpy"]
        if not IS_DLFW_CI:
            requirements = requirements + ["torch>=2.9.0.dev,<2.10.0"]
            if USE_TRT_RTX:
                requirements = requirements + [
                    "tensorrt_rtx>=1.0.0.21",
                ]
            else:
                requirements = requirements + [
                    "tensorrt>=10.12.0,<10.13.0",
                    "tensorrt-cu12>=10.12.0,<10.13.0",
                    "tensorrt-cu12-bindings>=10.12.0,<10.13.0",
                    "tensorrt-cu12-libs>=10.12.0,<10.13.0",
                ]
    return requirements


def get_jetpack_requirements():
    """Return the dependency list for JetPack builds.

    DLFW CI containers ship their own torch/tensorrt, so those pins are
    only added for non-CI builds.
    """
    reqs = base_requirements + ["numpy<2.0.0"]
    if not IS_DLFW_CI:
        reqs += ["torch>=2.8.0,<2.9.0", "tensorrt>=10.3.0,<10.4.0"]
    return reqs


def get_sbsa_requirements():
    """Return the dependency list for SBSA (ARM server) builds.

    DLFW CI containers ship their own torch/tensorrt, so those pins are
    only added for non-CI builds.
    """
    reqs = base_requirements + ["numpy"]
    if not IS_DLFW_CI:
        reqs += [
            "torch>=2.9.0.dev,<2.10.0",
            "tensorrt>=10.12.0,<10.13.0",
            "tensorrt-cu12>=10.12.0,<10.13.0",
            "tensorrt-cu12-bindings>=10.12.0,<10.13.0",
            "tensorrt-cu12-libs>=10.12.0,<10.13.0",
        ]
    return reqs


setup(
name="torch_tensorrt",
ext_modules=ext_modules,
Expand Down
47 changes: 0 additions & 47 deletions tests/modules/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,53 +39,6 @@
# "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
}

# Optional model-zoo entries, registered only when torchvision is installed.
# NOTE(review): timm is imported under the torchvision-only guard --
# presumably timm is always installed alongside torchvision in this test
# environment; if not, this raises ImportError. Consider guarding timm
# separately with its own find_spec check.
if importlib.util.find_spec("torchvision"):
    import timm
    import torchvision.models as models

    # Mapping: model name -> {"model": instantiated pretrained model,
    # "path": which conversion path(s) to exercise ("both"/"script")}.
    torchvision_models = {
        "alexnet": {"model": models.alexnet(pretrained=True), "path": "both"},
        "vgg16": {"model": models.vgg16(pretrained=True), "path": "both"},
        "squeezenet": {"model": models.squeezenet1_0(pretrained=True), "path": "both"},
        "densenet": {"model": models.densenet161(pretrained=True), "path": "both"},
        "inception_v3": {"model": models.inception_v3(pretrained=True), "path": "both"},
        "shufflenet": {
            "model": models.shufflenet_v2_x1_0(pretrained=True),
            "path": "both",
        },
        "mobilenet_v2": {"model": models.mobilenet_v2(pretrained=True), "path": "both"},
        "resnext50_32x4d": {
            "model": models.resnext50_32x4d(pretrained=True),
            "path": "both",
        },
        "wideresnet50_2": {
            "model": models.wide_resnet50_2(pretrained=True),
            "path": "both",
        },
        "mnasnet": {"model": models.mnasnet1_0(pretrained=True), "path": "both"},
        # resnet18/resnet50 come from torch.hub rather than the local
        # torchvision install (pinned to the v0.9.0 vision release).
        "resnet18": {
            "model": torch.hub.load(
                "pytorch/vision:v0.9.0", "resnet18", pretrained=True
            ),
            "path": "both",
        },
        "resnet50": {
            "model": torch.hub.load(
                "pytorch/vision:v0.9.0", "resnet50", pretrained=True
            ),
            "path": "both",
        },
        # timm-sourced models are script-path only.
        "efficientnet_b0": {
            "model": timm.create_model("efficientnet_b0", pretrained=True),
            "path": "script",
        },
        "vit": {
            "model": timm.create_model("vit_base_patch16_224", pretrained=True),
            "path": "script",
        },
    }
    to_test_models.update(torchvision_models)


def get(n, m, manifest):
print("Downloading {}".format(n))
Expand Down
9 changes: 5 additions & 4 deletions tests/py/dynamo/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
assertions = unittest.TestCase()

# torchvision and timm are independent optional test dependencies; guard
# each import separately so a missing timm does not raise ImportError and
# break the torchvision-only tests (the merged diff left `import timm`
# under the torchvision guard).
if importlib.util.find_spec("torchvision"):
    import torchvision.models as models

if importlib.util.find_spec("timm"):
    import timm


@pytest.mark.unit
Expand Down Expand Up @@ -132,11 +133,11 @@ def test_resnet18_torch_exec_ops(ir):


@pytest.mark.unit
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@unittest.skipIf(
not importlib.util.find_spec("torchvision"),
"torchvision is not installed",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_mobilenet_v2(ir, dtype):
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
pytest.skip("TensorRT-RTX does not support bfloat16")
Expand Down Expand Up @@ -174,11 +175,11 @@ def test_mobilenet_v2(ir, dtype):


@pytest.mark.unit
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@unittest.skipIf(
not importlib.util.find_spec("timm") or not importlib.util.find_spec("torchvision"),
"timm or torchvision not installed",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_efficientnet_b0(ir, dtype):
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
pytest.skip("TensorRT-RTX does not support bfloat16")
Expand Down Expand Up @@ -221,11 +222,11 @@ def test_efficientnet_b0(ir, dtype):


@pytest.mark.unit
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@unittest.skipIf(
not importlib.util.find_spec("transformers"),
"transformers is required to run this test",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_bert_base_uncased(ir, dtype):
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
pytest.skip("TensorRT-RTX does not support bfloat16")
Expand Down
4 changes: 3 additions & 1 deletion tests/py/dynamo/models/test_models_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
from packaging.version import Version

# torchvision and timm are independent optional test dependencies; guard
# each import separately so a missing timm does not raise ImportError and
# break the torchvision-only tests (the merged diff left `import timm`
# under the torchvision guard).
if importlib.util.find_spec("torchvision"):
    import torchvision.models as models

if importlib.util.find_spec("timm"):
    import timm

assertions = unittest.TestCase()


Expand Down
Loading