
Bizarre "no kernel image" error for pytorch built from source #32759

Closed
aluo-x opened this issue Jan 29, 2020 · 7 comments
Labels
module: binaries — Anything related to official binaries that we release to users
module: cuda — Related to torch.cuda, and CUDA support in general
triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@aluo-x

aluo-x commented Jan 29, 2020

🐛 Bug

A workstation with an NVIDIA K40c (compute capability 3.5), Anaconda 2019.10, Python 3.7.6, Ubuntu 19.10, GCC 8.3.0, and driver 440.48.02.

CUDA SDK version 10.1.243, and cudnn 7.6.5.

PyTorch 1.4.0 is built from source following the instructions available on GitHub, with magma-cuda101 installed.
export TORCH_CUDA_ARCH_LIST="3.5" is set.

Compilation succeeds without fatal errors.

The following works:

import torch
tmp = torch.randn(3,3).cuda()
d = torch.det(tmp)
print(d)
# prints tensor(-0.4732, device='cuda:0')

The following fails:

import torch
tmp = torch.randn(1,3,3).cuda()
# Note that the working example uses shape (3,3), while using shape (1,3,3) fails
d = torch.det(tmp)
# RuntimeError: CUDA error: no kernel image is available for execution on the device

It is super odd that adding a single dimension causes the determinant to fail. I can reproduce the issue with PyTorch 1.3.1 and 1.4.0. The same operation works on the CPU. More information can be provided if needed.
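Until the underlying build is fixed, one stopgap is to route the failing op through the CPU implementation and move the result back to the original device. A minimal sketch (the helper name is my own, not from the issue):

```python
import torch

# Workaround sketch: run det on the CPU and move the result back to the
# tensor's original device, sidestepping the missing CUDA kernel.
def det_cpu_fallback(t):
    return torch.det(t.cpu()).to(t.device)

tmp = torch.randn(1, 3, 3)
print(det_cpu_fallback(tmp).shape)  # torch.Size([1])
```

This trades a host/device round-trip for correctness, so it is only reasonable for small matrices or infrequent calls.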

import torch
print(torch.__version__)
# gives '1.4.0a0+7f73f1d'
print(torch.cuda.get_device_capability(0))
# gives (3, 5)
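The failure mode is consistent with an architecture mismatch: the PyTorch build itself covers sm_35, but the ops that fail dispatch to a separately compiled library that may not. A pure-Python sketch of the check (the arch lists below are illustrative assumptions, not values read from the real binaries):

```python
# Sketch: a CUDA binary can launch a kernel on a device only if the device's
# compute capability is among the architectures it was compiled for
# (ignoring forward-compatible PTX JIT).
def has_kernel_for(device_cap, compiled_archs):
    return device_cap in compiled_archs

k40c = (3, 5)
pytorch_archs = [(3, 5)]                         # TORCH_CUDA_ARCH_LIST="3.5"
magma_archs = [(3, 7), (5, 0), (6, 0), (7, 0)]   # hypothetical: K80 and newer

print(has_kernel_for(k40c, pytorch_archs))  # True  -> most ops work
print(has_kernel_for(k40c, magma_archs))    # False -> "no kernel image"
```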

Edit:
The vast majority of operations work. I can run complex 100M plus parameter networks with instance/layer/spectral/batch norm, as well as applying gradient penalties to a discriminator.

Edit:

import torch
tmp = torch.randn(3,3).cuda()
print(torch.inverse(tmp))
# CUDA error: no kernel image is available for execution on the device

cc @ezyang @ngimel

@mrshenli added the module: cuda, module: operators, and triaged labels Jan 29, 2020
@ngimel added the module: binaries label Jan 29, 2020
@ezyang
Contributor

ezyang commented Jan 29, 2020

det is implemented in MAGMA, so this might be related to how you compiled against MAGMA. Could you please run the environment collection script?

@aluo-x
Author

aluo-x commented Jan 29, 2020

Very insightful. I remember seeing some documentation on the MAGMA GitLab page about specifying CUDA compute capabilities; I'm using the pre-built package from the pytorch conda channel.

I will try and build without MAGMA.

Collecting environment information...
PyTorch version: 1.4.0a0+7f73f1d
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 19.10
GCC version: (Ubuntu 8.3.0-23ubuntu2) 8.3.0
CMake version: version 3.14.0

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: GeForce GTX TITAN Black
GPU 1: Tesla K40c

Nvidia driver version: 440.48.02
cuDNN version: /usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7

Versions of relevant libraries:
[pip] neural-renderer-pytorch==1.1.3
[pip] numpy==1.18.1
[pip] numpydoc==0.9.2
[pip] pytorch3d==0.1
[pip] torch==1.4.0a0+7f73f1d
[pip] torchvision==0.5.0
[conda] blas                      1.0                         mkl
[conda] magma-cuda101             2.5.1                         1    pytorch
[conda] mkl                       2019.4                      243
[conda] mkl-include               2019.4                      243
[conda] mkl-service               2.3.0            py37he904b0f_0
[conda] mkl_fft                   1.0.15           py37ha843d7b_0
[conda] mkl_random                1.1.0            py37hd6b4f25_0
[conda] neural-renderer-pytorch   1.1.3                    pypi_0    pypi
[conda] pytorch3d                 0.1                       dev_0    <develop>
[conda] torchvision               0.5.0                    pypi_0    pypi

@aluo-x
Author

aluo-x commented Jan 29, 2020

It does seem to be a MAGMA problem.
The build scripts for the conda MAGMA package in the pytorch channel specify the minimum compute capability as the K80 generation (newer than the K40c).

Closing now.

aluo-x closed this as completed Jan 29, 2020
@aluo-x
Author

aluo-x commented Jan 29, 2020

Modified the patch code to add support for compute capability 3.5
Rebuilt using the provided script

conda build .
conda install OUTPUT_TAR.GZ_LOCATION

Then I rebuilt PyTorch, and it worked like a charm.

@rajeshroy402

I followed this document - https://github.com/pytorch/pytorch/#from-source - and built it.
I will share the specs:
I have Nvidia driver 470.x.x with CUDA 11.4.x installed on my Ubuntu 20.04 LTS system.
I didn't find `conda install -c pytorch magma-cuda114`, so I went with `conda install -c pytorch magma-cuda112`.
The rest of the installation was smooth.
Still, if I try to run something, I get the errors quoted below:

RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

home/assertst/.local/lib/python3.8/site-packages/torch/cuda/__init__.py:83: UserWarning: 
    Found GPU%d %s which is of cuda capability %d.%d.
    PyTorch no longer supports this GPU because it is too old.
    The minimum cuda capability supported by this library is %d.%d.
    
  warnings.warn(old_gpu_warn.format(d, name, major, minor, min_arch // 10, min_arch % 10))
Device Used:  NVIDIA GeForce GT 730
Capability:  (3, 5)
<class 'darknet.Darknet'>
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/home/assertst/anaconda3/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/assertst/anaconda3/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "nephro.py", line 88, in processResult
    output = self.model(Variable(img), self.CUDA)
  File "/home/assertst/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/assertst/nephroplus/darknet.py", line 319, in forward
    x = self.module_list[i](x)
  File "/home/assertst/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/assertst/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
File "/home/assertst/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/assertst/.local/lib/python3.8/site-packages/torch/nn/modules/activation.py", line 756, in forward
    return F.leaky_relu(input, self.negative_slope, self.inplace)
  File "/home/assertst/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 1472, in leaky_relu
    result = torch._C._nn.leaky_relu_(input, negative_slope)
RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
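Incidentally, the UserWarning above packs the minimum supported capability into a single integer, and `min_arch // 10` / `min_arch % 10` recover the major and minor numbers. A small illustration, assuming hypothetically that the minimum were sm_37:

```python
# min_arch encodes compute capability major.minor as major*10 + minor.
min_arch = 37  # hypothetical value, e.g. Kepler K80 (sm_37)
major, minor = divmod(min_arch, 10)
print(f"minimum supported capability: {major}.{minor}")  # minimum supported capability: 3.7
```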

@rajeshroy402

Hi @aluo-x - #32759 (comment)
Can you help me with my GT 730?

@aluo-x
Author

aluo-x commented Jul 27, 2021

Our cluster has retired GPUs from the K80 generation; I believe we target Nvidia 1080 Ti through 3090 now. So this is no longer a problem for me, and I no longer build PyTorch/MAGMA from source.

But installing magma from conda will not work for you; you will need to clone it from here and then specify the correct arch here.
