Status: Closed
Labels: module: cpp-extensions, module: cuda, triaged
Description
🐛 Bug
A CUDA extension that uses `TORCH_CUDABLAS_CHECK` fails to import with an undefined symbol error.
To Reproduce
Create the following two files to build a CUDA extension:

- cublas_ext.cpp
```cpp
#include <cstdio>                  // printf
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>  // defines TORCH_CUDABLAS_CHECK
#include <cublas_v2.h>

// Create and destroy a cuBLAS handle, checking each returned status
// with TORCH_CUDABLAS_CHECK.
void a_cublas_function() {
    printf("hello world\n");
    cublasHandle_t handle;
    TORCH_CUDABLAS_CHECK(cublasCreate(&handle));
    TORCH_CUDABLAS_CHECK(cublasDestroy(handle));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("a_cublas_function", &a_cublas_function, "a cublas function");
}
```
- setup.py
```python
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Build cublas_ext.cpp as a CUDA extension module named cublas_ext.
cublas_module = CUDAExtension(
    name='cublas_ext',
    sources=['cublas_ext.cpp'],
)

setup(
    name='cublas_ext_root',
    version='0.1',
    ext_modules=[cublas_module],
    cmdclass={
        # Build without ninja.
        'build_ext': BuildExtension.with_options(use_ninja=False),
    },
)
```
Build the CUDA extension with:

```shell
pip install -v --no-cache-dir .
```

Run the test with:

```shell
python -c 'import torch; import cublas_ext; cublas_ext.a_cublas_function();'
```
Importing the extension fails with the following error:

```
Traceback (most recent call last):
  File "<string>", line 1, in <module>
ImportError: /home/xwang/.local/lib/python3.9/site-packages/cublas_ext.cpython-39-x86_64-linux-gnu.so: undefined symbol: _ZN2at4cuda4blas19_cublasGetErrorEnumE14cublasStatus_t
```
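The undefined symbol is an Itanium C++ ABI mangled name; `c++filt` decodes it to `at::cuda::blas::_cublasGetErrorEnum(cublasStatus_t)`, presumably the helper that `TORCH_CUDABLAS_CHECK` calls to turn a `cublasStatus_t` into an error string. The sketch below (the names `demangle_simple` and `_read_source_name` are hypothetical) decodes only the narrow subset of the mangling grammar this symbol uses: a nested name `N...E` built from length-prefixed identifiers, followed by parameter types that are either length-prefixed names or a few builtin codes. It illustrates how the name is assembled and is not a substitute for a real demangler.

```python
from __future__ import annotations

# A handful of Itanium C++ ABI builtin type codes (far from the full set).
BUILTIN_TYPES = {
    "v": "void", "b": "bool", "c": "char", "i": "int", "j": "unsigned int",
    "l": "long", "m": "unsigned long", "f": "float", "d": "double",
}


def _read_source_name(s: str, pos: int) -> tuple[str, int]:
    """Parse <length><identifier> at s[pos]; return (identifier, next position)."""
    start = pos
    while pos < len(s) and s[pos].isdigit():
        pos += 1
    if pos == start:
        raise ValueError(f"expected a length prefix at offset {start}")
    length = int(s[start:pos])
    name = s[pos:pos + length]
    if len(name) != length:
        raise ValueError("identifier is truncated")
    return name, pos + length


def demangle_simple(mangled: str) -> str:
    """Demangle _ZN<source-name>+E<param>* (no templates, pointers or qualifiers)."""
    if not mangled.startswith("_ZN"):
        raise ValueError("only nested names of the form _ZN...E are supported")
    pos = 3
    scopes = []
    # Nested name: namespaces followed by the function name, closed by 'E'.
    while pos < len(mangled) and mangled[pos] != "E":
        name, pos = _read_source_name(mangled, pos)
        scopes.append(name)
    if pos >= len(mangled):
        raise ValueError("nested name is missing its closing 'E'")
    pos += 1
    # Parameter types: builtin codes or length-prefixed type names.
    params = []
    while pos < len(mangled):
        code = mangled[pos]
        if code in BUILTIN_TYPES:
            params.append(BUILTIN_TYPES[code])
            pos += 1
        else:
            name, pos = _read_source_name(mangled, pos)
            params.append(name)
    if params == ["void"]:  # a lone 'v' encodes an empty parameter list
        params = []
    return "::".join(scopes) + "(" + ", ".join(params) + ")"


if __name__ == "__main__":
    symbol = "_ZN2at4cuda4blas19_cublasGetErrorEnumE14cublasStatus_t"
    print(demangle_simple(symbol))
    # prints at::cuda::blas::_cublasGetErrorEnum(cublasStatus_t)
```

For example, `demangle_simple("_ZN3foo3barEid")` returns `foo::bar(int, double)`.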
Expected behavior
`TORCH_CUDABLAS_CHECK` should work in CUDA extensions without an undefined symbol error.
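One way to check whether a particular PyTorch build exports the symbol the extension fails to resolve is to ask the dynamic loader directly. The sketch below uses only the standard `ctypes` module and assumes Linux-style shared libraries; `has_exported_symbol` is a hypothetical helper name. It looks the symbol up with `dlsym()` (via `CDLL.__getitem__`), which only sees symbols in the library's dynamic symbol table, so a symbol compiled with hidden visibility is reported as missing.

```python
import ctypes
import ctypes.util


def has_exported_symbol(lib_path: str, symbol: str) -> bool:
    """Return True if dlsym() can resolve `symbol` in the library at `lib_path`."""
    lib = ctypes.CDLL(lib_path)
    try:
        lib[symbol]  # raises AttributeError when the symbol is not exported
    except AttributeError:
        return False
    return True


if __name__ == "__main__":
    # Sanity check against the C library, which exports printf.
    libc = ctypes.util.find_library("c")
    print(has_exported_symbol(libc, "printf"))
    print(has_exported_symbol(libc, "definitely_not_a_real_symbol_12345"))
```

To apply it to this issue, run it after `import torch` with the path of the CUDA library in the `torch/lib` directory (for example `libtorch_cuda.so`; that file name is an assumption, and builds that split the CUDA libraries use different names) and the mangled name from the traceback. A `False` result would be consistent with the helper not being exported from that library.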
Environment
PyTorch is built from source at the latest master commit using GCC 10.3.
Collecting environment information...
PyTorch version: 1.11.0a0+gitf56a1a5
Is debug build: False
CUDA used to build PyTorch: 11.4
ROCM used to build PyTorch: N/A
OS: Manjaro Linux (x86_64)
GCC version: (GCC) 11.1.0
Clang version: Could not collect
CMake version: version 3.21.1
Libc version: glibc-2.33
Python version: 3.9.6 (default, Jun 30 2021, 10:22:16) [GCC 11.1.0] (64-bit runtime)
Python platform: Linux-5.10.60-1-MANJARO-x86_64-with-glibc2.33
Is CUDA available: True
CUDA runtime version: 11.4.100
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2070 SUPER
GPU 1: NVIDIA GeForce GTX 1070 Ti
Nvidia driver version: 470.63.01
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.2.2
/usr/lib/libcudnn_adv_infer.so.8.2.2
/usr/lib/libcudnn_adv_train.so.8.2.2
/usr/lib/libcudnn_cnn_infer.so.8.2.2
/usr/lib/libcudnn_cnn_train.so.8.2.2
/usr/lib/libcudnn_ops_infer.so.8.2.2
/usr/lib/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] mypy==0.812
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.19.5
[pip3] torch==1.11.0a0+gitf56a1a5
[pip3] torch-tb-profiler==0.2.0
[conda] Could not collect
Additional context
N/A