Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ When running in a docker container without nvidia driver, PyTorch needs to evalu
export TORCH_CUDA_ARCH_LIST = "6.0 6.1 7.2+PTX 7.5+PTX"
```

### Windows

If you are installing this on Windows specifically, **you will need to point the setup to your Visual Studio installation** for some neccessary libraries and header files.
To do this, add the include and library paths of your installation to the path lists in setup.py as described in the respective comments in the code.

If you are running into any installation problems, please create an [issue](https://github.com/rusty1s/pytorch_scatter/issues).
Be sure to import `torch` first before using this package to resolve symbols the dynamic linker must see.

Expand Down
17 changes: 15 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,19 @@
import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME


# Windows users: Edit both of these to contain your VS include path, i.e.
# cxx_extra_compile_args = ['-I{VISUAL_STUDIO_DIR}\\include']
# nvcc_extra_compile_args = [..., '-I{VISUAL_STUDIO_DIR}\\include']
cxx_extra_compile_args = []
nvcc_extra_compile_args = ['-arch=sm_35', '--expt-relaxed-constexpr']

# Windows users: Edit both of these to contain your VS library path, i.e.
# cxx_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}']
# nvcc_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}']
cxx_extra_link_args = []
nvcc_extra_link_args = []

if platform.system() != 'Windows':
cxx_extra_compile_args += ['-Wno-unused-variable']
TORCH_MAJOR = int(torch.__version__.split('.')[0])
Expand All @@ -23,7 +34,8 @@
ext_modules += [
CppExtension(
f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'],
extra_compile_args=cxx_extra_compile_args) for ext in exts
extra_compile_args=cxx_extra_compile_args,
extra_link_args=cxx_extra_link_args) for ext in exts
]

if CUDA_HOME is not None and '--cpu' not in argv:
Expand All @@ -35,7 +47,8 @@
extra_compile_args={
'cxx': cxx_extra_compile_args,
'nvcc': nvcc_extra_compile_args,
}) for ext in exts
},
extra_link_args=nvcc_extra_link_args) for ext in exts
]

__version__ = '1.5.0'
Expand Down