diff --git a/README.md b/README.md index 1f7a0f2e..47b8f151 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,11 @@ When running in a docker container without nvidia driver, PyTorch needs to evalu export TORCH_CUDA_ARCH_LIST = "6.0 6.1 7.2+PTX 7.5+PTX" ``` +### Windows + +If you are installing this on Windows specifically, **you will need to point the setup to your Visual Studio installation** for some neccessary libraries and header files. +To do this, add the include and library paths of your installation to the path lists in setup.py as described in the respective comments in the code. + If you are running into any installation problems, please create an [issue](https://github.com/rusty1s/pytorch_scatter/issues). Be sure to import `torch` first before using this package to resolve symbols the dynamic linker must see. diff --git a/setup.py b/setup.py index ac4c4300..378993c9 100644 --- a/setup.py +++ b/setup.py @@ -7,8 +7,19 @@ import torch from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME + +# Windows users: Edit both of these to contain your VS include path, i.e. +# cxx_extra_compile_args = ['-I{VISUAL_STUDIO_DIR}\\include'] +# nvcc_extra_compile_args = [..., '-I{VISUAL_STUDIO_DIR}\\include'] cxx_extra_compile_args = [] nvcc_extra_compile_args = ['-arch=sm_35', '--expt-relaxed-constexpr'] + +# Windows users: Edit both of these to contain your VS library path, i.e. +# cxx_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}'] +# nvcc_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}'] +cxx_extra_link_args = [] +nvcc_extra_link_args = [] + if platform.system() != 'Windows': cxx_extra_compile_args += ['-Wno-unused-variable'] TORCH_MAJOR = int(torch.__version__.split('.')[0]) @@ -23,7 +34,8 @@ ext_modules += [ CppExtension( f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'], - extra_compile_args=cxx_extra_compile_args) for ext in exts + extra_compile_args=cxx_extra_compile_args, + extra_link_args=cxx_extra_link_args) for ext in exts ] if CUDA_HOME is not None and '--cpu' not in argv: @@ -35,7 +47,8 @@ extra_compile_args={ 'cxx': cxx_extra_compile_args, 'nvcc': nvcc_extra_compile_args, - }) for ext in exts + }, + extra_link_args=nvcc_extra_link_args) for ext in exts ] __version__ = '1.5.0'