diff --git a/setup.py b/setup.py
index 3ba8352f0b..8aec510eb2 100644
--- a/setup.py
+++ b/setup.py
@@ -31,22 +31,26 @@
 __tensorrt_rtx_version__: str = "0.0"
 
 LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")
+# CI_PIPELINE_ID is the environment variable set by the DLFW CI build
+IS_DLFW_CI = os.environ.get("CI_PIPELINE_ID") is not None
 
 
 def get_root_dir() -> Path:
-    return Path(
-        subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
-        .decode("ascii")
-        .strip()
-    )
+    return Path(__file__).parent.absolute()
 
 
 def get_git_revision_short_hash() -> str:
-    return (
-        subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
-        .decode("ascii")
-        .strip()
-    )
+    # the DLFW CI build does not have git available
+    try:
+        return (
+            subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
+            .decode("ascii")
+            .strip()
+        )
+    except Exception:
+        print("WARNING: Could not get git revision short hash, using default one")
+        # in the release/ngc/25.10 branch this is the commit hash of the PyTorch commit used for the DLFW package
+        return "0000000"
 
 
 def get_base_version() -> str:
@@ -718,58 +722,57 @@ def run(self):
 with open(os.path.join(get_root_dir(), "README.md"), "r", encoding="utf-8") as fh:
     long_description = fh.read()
 
+base_requirements = [
+    "packaging>=23",
+    "typing-extensions>=4.7.0",
+    "dllist",
+]
 
-def get_requirements():
-    requirements = [
-        "packaging>=23",
-        "typing-extensions>=4.7.0",
-        "dllist",
-    ]
 
+def get_requirements():
     if IS_JETPACK:
-        requirements.extend(
-            [
-                "torch>=2.8.0,<2.9.0",
-                "tensorrt>=10.3.0,<10.4.0",
-                "numpy<2.0.0",
-            ]
-        )
+        requirements = get_jetpack_requirements()
     elif IS_SBSA:
-        requirements.extend(
-            [
-                "torch>=2.9.0.dev,<2.10.0",
-                "tensorrt>=10.12.0,<10.13.0",
-                "tensorrt-cu12>=10.12.0,<10.13.0",
-                "tensorrt-cu12-bindings>=10.12.0,<10.13.0",
-                "tensorrt-cu12-libs>=10.12.0,<10.13.0",
-                "numpy",
-            ]
-        )
+        requirements = get_sbsa_requirements()
     else:
-        requirements.extend(
-            [
-                "torch>=2.9.0.dev,<2.10.0",
-                "numpy",
-            ]
-        )
-        if USE_TRT_RTX:
-            requirements.extend(
-                [
-                    "tensorrt-rtx>=1.0.0.21",
+        # standard Linux and Windows requirements
+        requirements = base_requirements + ["numpy"]
+        if not IS_DLFW_CI:
+            requirements = requirements + ["torch>=2.9.0.dev,<2.10.0"]
+            if USE_TRT_RTX:
+                requirements = requirements + [
+                    "tensorrt_rtx>=1.0.0.21",
                 ]
-            )
-        else:
-            requirements.extend(
-                [
+            else:
+                requirements = requirements + [
                     "tensorrt>=10.12.0,<10.13.0",
                     "tensorrt-cu12>=10.12.0,<10.13.0",
                     "tensorrt-cu12-bindings>=10.12.0,<10.13.0",
                     "tensorrt-cu12-libs>=10.12.0,<10.13.0",
                 ]
-            )
 
     return requirements
 
 
+def get_jetpack_requirements():
+    jetpack_requirements = base_requirements + ["numpy<2.0.0"]
+    if IS_DLFW_CI:
+        return jetpack_requirements
+    return jetpack_requirements + ["torch>=2.8.0,<2.9.0", "tensorrt>=10.3.0,<10.4.0"]
+
+
+def get_sbsa_requirements():
+    sbsa_requirements = base_requirements + ["numpy"]
+    if IS_DLFW_CI:
+        return sbsa_requirements
+    return sbsa_requirements + [
+        "torch>=2.9.0.dev,<2.10.0",
+        "tensorrt>=10.12.0,<10.13.0",
+        "tensorrt-cu12>=10.12.0,<10.13.0",
+        "tensorrt-cu12-bindings>=10.12.0,<10.13.0",
+        "tensorrt-cu12-libs>=10.12.0,<10.13.0",
+    ]
+
+
 setup(
     name="torch_tensorrt",
     ext_modules=ext_modules,
diff --git a/tests/modules/hub.py b/tests/modules/hub.py
index 80dd37cd4a..aaa7a86293 100644
--- a/tests/modules/hub.py
+++ b/tests/modules/hub.py
@@ -39,53 +39,6 @@
     # "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
 }
 
-if importlib.util.find_spec("torchvision"):
-    import timm
-    import torchvision.models as models
-
-    torchvision_models = {
-        "alexnet": {"model": models.alexnet(pretrained=True), "path": "both"},
-        "vgg16": {"model": models.vgg16(pretrained=True), "path": "both"},
-        "squeezenet": {"model": models.squeezenet1_0(pretrained=True), "path": "both"},
-        "densenet": {"model": models.densenet161(pretrained=True), "path": "both"},
-        "inception_v3": {"model": models.inception_v3(pretrained=True), "path": "both"},
-        "shufflenet": {
-            "model": models.shufflenet_v2_x1_0(pretrained=True),
-            "path": "both",
-        },
-        "mobilenet_v2": {"model": models.mobilenet_v2(pretrained=True), "path": "both"},
-        "resnext50_32x4d": {
-            "model": models.resnext50_32x4d(pretrained=True),
-            "path": "both",
-        },
-        "wideresnet50_2": {
-            "model": models.wide_resnet50_2(pretrained=True),
-            "path": "both",
-        },
-        "mnasnet": {"model": models.mnasnet1_0(pretrained=True), "path": "both"},
-        "resnet18": {
-            "model": torch.hub.load(
-                "pytorch/vision:v0.9.0", "resnet18", pretrained=True
-            ),
-            "path": "both",
-        },
-        "resnet50": {
-            "model": torch.hub.load(
-                "pytorch/vision:v0.9.0", "resnet50", pretrained=True
-            ),
-            "path": "both",
-        },
-        "efficientnet_b0": {
-            "model": timm.create_model("efficientnet_b0", pretrained=True),
-            "path": "script",
-        },
-        "vit": {
-            "model": timm.create_model("vit_base_patch16_224", pretrained=True),
-            "path": "script",
-        },
-    }
-    to_test_models.update(torchvision_models)
-
 
 def get(n, m, manifest):
     print("Downloading {}".format(n))
diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py
index 6f76d9510d..c52b732c42 100644
--- a/tests/py/dynamo/models/test_models.py
+++ b/tests/py/dynamo/models/test_models.py
@@ -14,8 +14,9 @@
 assertions = unittest.TestCase()
 
 if importlib.util.find_spec("torchvision"):
-    import timm
     import torchvision.models as models
+if importlib.util.find_spec("timm"):
+    import timm
 
 
 @pytest.mark.unit
@@ -132,11 +133,11 @@ def test_resnet18_torch_exec_ops(ir):
 
 
 @pytest.mark.unit
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
 @unittest.skipIf(
     not importlib.util.find_spec("torchvision"),
     "torchvision is not installed",
 )
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
 def test_mobilenet_v2(ir, dtype):
     if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
         pytest.skip("TensorRT-RTX does not support bfloat16")
@@ -174,11 +175,11 @@ def test_mobilenet_v2(ir, dtype):
 
 
 @pytest.mark.unit
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
 @unittest.skipIf(
     not importlib.util.find_spec("timm")
     or not importlib.util.find_spec("torchvision"),
     "timm or torchvision not installed",
 )
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
 def test_efficientnet_b0(ir, dtype):
     if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
         pytest.skip("TensorRT-RTX does not support bfloat16")
@@ -221,11 +222,11 @@ def test_efficientnet_b0(ir, dtype):
 
 
 @pytest.mark.unit
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
 @unittest.skipIf(
     not importlib.util.find_spec("transformers"),
     "transformers is required to run this test",
 )
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
 def test_bert_base_uncased(ir, dtype):
     if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
         pytest.skip("TensorRT-RTX does not support bfloat16")
diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
index e8c3933d00..4620176523 100644
--- a/tests/py/dynamo/models/test_models_export.py
+++ b/tests/py/dynamo/models/test_models_export.py
@@ -12,9 +12,11 @@
 from packaging.version import Version
 
 if importlib.util.find_spec("torchvision"):
-    import timm
     import torchvision.models as models
+if importlib.util.find_spec("timm"):
+    import timm
+
 
 assertions = unittest.TestCase()
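
Note: the snippet below is a minimal sketch (not part of the patch; the is_dlfw_ci helper is hypothetical) of how the IS_DLFW_CI gate added to setup.py behaves. CI_PIPELINE_ID is only defined inside the DLFW CI pipeline, so a local build keeps the torch/tensorrt pins while a DLFW CI build installs only the lightweight base requirements and relies on the packages already present in the container.

    import os

    def is_dlfw_ci() -> bool:
        # mirrors IS_DLFW_CI in setup.py: CI_PIPELINE_ID is set only by the DLFW CI pipeline
        return os.environ.get("CI_PIPELINE_ID") is not None

    # base requirements are always installed; the heavyweight pins are skipped under DLFW CI
    base_requirements = ["packaging>=23", "typing-extensions>=4.7.0", "dllist"]
    requirements = base_requirements + ["numpy"]
    if not is_dlfw_ci():
        requirements += ["torch>=2.9.0.dev,<2.10.0"]
    print(requirements)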