diff --git a/setup.py b/setup.py index 2779a29a152..cc999538c8c 100644 --- a/setup.py +++ b/setup.py @@ -88,14 +88,17 @@ def get_extensions(): sources = main_file + source_cpu extension = CppExtension - test_dir = os.path.join(this_dir, 'test') - models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models') - test_file = glob.glob(os.path.join(test_dir, '*.cpp')) - source_models = glob.glob(os.path.join(models_dir, '*.cpp')) - - test_file = [os.path.join(test_dir, s) for s in test_file] - source_models = [os.path.join(models_dir, s) for s in source_models] - tests = test_file + source_models + compile_cpp_tests = os.getenv('WITH_CPP_MODELS_TEST', '0') == '1' + if compile_cpp_tests: + test_dir = os.path.join(this_dir, 'test') + models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models') + test_file = glob.glob(os.path.join(test_dir, '*.cpp')) + source_models = glob.glob(os.path.join(models_dir, '*.cpp')) + + test_file = [os.path.join(test_dir, s) for s in test_file] + source_models = [os.path.join(models_dir, s) for s in source_models] + tests = test_file + source_models + tests_include_dirs = [test_dir, models_dir] define_macros = [] @@ -123,7 +126,6 @@ def get_extensions(): sources = [os.path.join(extensions_dir, s) for s in sources] include_dirs = [extensions_dir] - tests_include_dirs = [test_dir, models_dir] ffmpeg_exe = distutils.spawn.find_executable('ffmpeg') has_ffmpeg = ffmpeg_exe is not None @@ -143,15 +145,18 @@ def get_extensions(): include_dirs=include_dirs, define_macros=define_macros, extra_compile_args=extra_compile_args, - ), - extension( - 'torchvision._C_tests', - tests, - include_dirs=tests_include_dirs, - define_macros=define_macros, - extra_compile_args=extra_compile_args, - ), + ) ] + if compile_cpp_tests: + ext_modules.append( + extension( + 'torchvision._C_tests', + tests, + include_dirs=tests_include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ) if has_ffmpeg: ext_modules.append( CppExtension( diff --git a/test/test_cpp_models.py b/test/test_cpp_models.py index 0b1a0756a00..b6654a0278d 100644 --- a/test/test_cpp_models.py +++ b/test/test_cpp_models.py @@ -1,12 +1,17 @@ import torch import os import unittest -from torchvision import models, transforms, _C_tests +from torchvision import models, transforms import sys from PIL import Image import torchvision.transforms.functional as F +try: + from torchvision import _C_tests +except ImportError: + _C_tests = None + def process_model(model, tensor, func, name): model.eval()