diff --git a/setup.py b/setup.py index 48e9c92e127..c75c46424ea 100644 --- a/setup.py +++ b/setup.py @@ -143,7 +143,7 @@ def get_extensions(): if torch.__version__ >= "1.5": from torch.utils.cpp_extension import ROCM_HOME - is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False + is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None) if is_rocm_pytorch: from torch.utils.hipify import hipify_python @@ -159,7 +159,6 @@ def get_extensions(): # Copy over additional files for file in glob.glob(r"torchvision/csrc/ops/cuda/*.h"): shutil.copy(file, "torchvision/csrc/ops/hip") - else: source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))