Skip to content

Commit

Permalink
Link conda packages with cusparselt
Browse files Browse the repository at this point in the history
Fixes pytorch/pytorch#115085

(cherry picked from commit c55c58b)
  • Loading branch information
malfet committed Dec 28, 2023
1 parent 31d77df commit b5527e4
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 14 deletions.
1 change: 0 additions & 1 deletion conda/build_pytorch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,6 @@ for py_ver in "${DESIRED_PYTHON[@]}"; do
PYTORCH_GITHUB_ROOT_DIR="$pytorch_rootdir" \
PYTORCH_BUILD_STRING="$build_string" \
PYTORCH_MAGMA_CUDA_VERSION="$cuda_nodot" \
USE_CUSPARSELT=0 \
conda build -c "$ANACONDA_USER" \
${NO_TEST:-} \
--no-anaconda-upload \
Expand Down
21 changes: 9 additions & 12 deletions conda/pytorch-nightly/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,18 @@ fi
if [[ -n "$build_with_cuda" ]]; then
export TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
TORCH_CUDA_ARCH_LIST="5.0;6.0;6.1;7.0;7.5;8.0;8.6"
export USE_STATIC_CUDNN=1 # links cudnn statically (driven by tools/setup_helpers/cudnn.py)
export USE_STATIC_CUDNN=0 # link with cudnn dynamically
export USE_CUSPARSELT=1 # link with cusparselt

if [[ $CUDA_VERSION == 11.8* ]]; then
TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;3.7+PTX;9.0"
#for cuda 11.8 we use cudnn 8.7
#which does not have single static libcudnn_static.a deliverable to link with
export USE_STATIC_CUDNN=0
#for cuda 11.8 include all dynamic loading libraries
DEPS_LIST=(/usr/local/cuda/lib64/libcudnn*.so.8 /usr/local/cuda-11.8/extras/CUPTI/lib64/libcupti.so.11.8)
TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;3.7+PTX;9.0"
#for cuda 11.8 include all dynamic loading libraries
DEPS_LIST=(/usr/local/cuda/lib64/libcudnn*.so.8 /usr/local/cuda-11.8/extras/CUPTI/lib64/libcupti.so.11.8 /usr/local/cuda/lib64/libcusparseLt.so.0)
elif [[ $CUDA_VERSION == 12.1* ]]; then
# cuda 12 does not support sm_3x
TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;9.0"
# for cuda 12.1 we use cudnn 8.8 and include all dynamic loading libraries
export USE_STATIC_CUDNN=0
DEPS_LIST=(/usr/local/cuda/lib64/libcudnn*.so.8 /usr/local/cuda-12.1/extras/CUPTI/lib64/libcupti.so.12)
# cuda 12 does not support sm_3x
TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;9.0"
# for cuda 12.1 we use cudnn 8.8 and include all dynamic loading libraries
DEPS_LIST=(/usr/local/cuda/lib64/libcudnn*.so.8 /usr/local/cuda-12.1/extras/CUPTI/lib64/libcupti.so.12 /usr/local/cuda/lib64/libcusparseLt.so.0)
fi
if [[ -n "$OVERRIDE_TORCH_CUDA_ARCH_LIST" ]]; then
TORCH_CUDA_ARCH_LIST="$OVERRIDE_TORCH_CUDA_ARCH_LIST"
Expand Down
1 change: 0 additions & 1 deletion conda/pytorch-nightly/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ build:
- _GLIBCXX_USE_CXX11_ABI # [unix]
- MAX_JOBS # [unix]
- OVERRIDE_TORCH_CUDA_ARCH_LIST
- USE_CUSPARSELT

test:
imports:
Expand Down

0 comments on commit b5527e4

Please sign in to comment.