diff --git a/setup.py b/setup.py index 00ea774c..4b70666d 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ import os.path as osp from itertools import product from setuptools import setup, find_packages +import platform import torch from torch.__config__ import parallel_info @@ -66,6 +67,11 @@ def get_extensions(): else: print('Compiling without OpenMP...') + # Compile for mac arm64 + if (sys.platform == 'darwin' and platform.machine() == 'arm64'): + extra_compile_args['cxx'] += ['-arch', 'arm64'] + extra_link_args += ['-arch', 'arm64'] + if suffix == 'cuda': define_macros += [('WITH_CUDA', None)] nvcc_flags = os.getenv('NVCC_FLAGS', '')