No GPU/TPU found, falling back to CPU. #30

Open
ontheway-arch opened this issue Apr 23, 2023 · 8 comments

Comments

@ontheway-arch

Configuration:
GPU:3090
Driver Version: 525.105.17
CUDA Version: 12.0

But I found that it never finds the GPU. How can I solve this problem?

@ontheway-arch
Author

I have added the following lines to my code:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
If I delete them, it doesn't run at all.
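For reference, here is a minimal sketch of the same settings with a real number in place of `.XX`. In the JAX GPU memory documentation, `.XX` is a placeholder for a fraction (e.g. `.50` for 50%), not a literal string. The helper `xla_memory_env` is hypothetical. JAX reads these variables when it initialises its GPU backend, so the sketch sets them before JAX is imported. Note that they only change how GPU memory is allocated; they do not help JAX find a GPU in the first place. Per the JAX docs, the `platform` allocator is also very slow and mainly meant for debugging out-of-memory problems.

```python
import os


def xla_memory_env(preallocate=False, mem_fraction=None, allocator=None):
    """Build XLA_PYTHON_CLIENT_* settings (hypothetical helper).

    mem_fraction must be a real number in (0, 1]; the ".XX" in the JAX
    docs is a placeholder, not a value to copy literally.
    """
    env = {"XLA_PYTHON_CLIENT_PREALLOCATE": "true" if preallocate else "false"}
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError(f"mem_fraction must be in (0, 1], got {mem_fraction}")
        # Written as a decimal string, e.g. 0.5 -> "0.50"
        env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
    if allocator is not None:
        env["XLA_PYTHON_CLIENT_ALLOCATOR"] = allocator
    return env


# Apply before JAX initialises its GPU backend, i.e. before `import jax`.
os.environ.update(xla_memory_env(preallocate=False, allocator="platform"))
```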

@RaulKite

RaulKite commented Apr 24, 2023

Same error here

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Using device:", device)

Using device: cuda

from whisper_jax import FlaxWhisperPipline

# instantiate pipeline
pipeline = FlaxWhisperPipline("openai/whisper-large-v2")

# JIT compile the forward call - slow, but we only do once
text = pipeline("chris_3_first_minutes.mp3")

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
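PyTorch reporting `cuda` does not mean JAX can use the GPU: PyTorch wheels typically ship their own CUDA libraries, while JAX depends on a separately built jaxlib. A minimal diagnostic sketch (the function name `backend_report` is hypothetical) that follows the warning's advice by setting `TF_CPP_MIN_LOG_LEVEL=0` before JAX is imported, then reports what each framework sees. Missing packages show up as `None` instead of crashing the script.

```python
import importlib
import os

# As the warning suggests: more verbose logs. Must be set before jax is imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"


def backend_report():
    """Report what PyTorch and JAX each see; packages that are absent map to None."""
    report = {"torch_cuda": None, "jax_backend": None, "jax_devices": None}
    try:
        torch = importlib.import_module("torch")
        report["torch_cuda"] = torch.cuda.is_available()
    except ImportError:
        pass
    try:
        jax = importlib.import_module("jax")
        report["jax_backend"] = jax.default_backend()  # "cpu" means the fallback happened
        report["jax_devices"] = [d.device_kind for d in jax.devices()]
    except ImportError:
        pass
    return report


if __name__ == "__main__":
    for key, value in backend_report().items():
        print(f"{key}: {value}")
```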

@RaulKite

RaulKite commented Apr 24, 2023

Stupid me. It was TensorFlow!

pip install tensorflow

and it's fixed.

@vakkov

vakkov commented Apr 24, 2023

For people whose TensorFlow installation is not the problem: install a jaxlib compiled against the appropriate CUDA and cuDNN libraries, e.g.:

pip install -U jaxlib==0.4.7+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
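To check which build you actually have, here is a small sketch that reads the installed jaxlib version and splits a local tag like `+cuda11.cudnn86` into the CUDA major version and the cuDNN version (8.6). The helper names are hypothetical, and the tag scheme is the one used by the jaxlib wheels of this period; later JAX releases moved CUDA support into separate plugin packages, in which case the parser simply returns `None`.

```python
import re
from importlib import metadata

# Local version tags on CUDA jaxlib wheels of this era, e.g. "0.4.7+cuda11.cudnn86"
_TAG = re.compile(r"^(?P<base>[\d.]+)\+cuda(?P<cuda>\d+)\.cudnn(?P<cudnn>\d{2,})$")


def parse_jaxlib_version(version):
    """Split a jaxlib version into (base, cuda_major, (cudnn_major, cudnn_minor)).

    Returns None for versions without a CUDA tag (e.g. a CPU-only jaxlib).
    """
    m = _TAG.match(version)
    if m is None:
        return None
    cudnn = m.group("cudnn")
    # "86" -> (8, 6): first digit is the cuDNN major, the rest the minor
    return m.group("base"), int(m.group("cuda")), (int(cudnn[0]), int(cudnn[1:]))


def installed_jaxlib():
    """Return the installed jaxlib version string, or None if jaxlib is absent."""
    try:
        return metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_jaxlib()
    print(v, parse_jaxlib_version(v) if v else "jaxlib not installed")
```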

@sanchit-gandhi
Owner

sanchit-gandhi commented Apr 24, 2023

Hey all! Super sorry to hear it's been so tough to install JAX on GPU. I would definitely recommend reading through the official JAX installation guide to see if there are any hardware-specific instructions: https://github.com/google/jax#installation

If your problems persist, by all means open an issue on the JAX repo and request help for your GPU set-up.

You can quickly test that you've got a proper installation of JAX on GPU/TPU with the following Python code:

import jax

num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")

I would only proceed to install Whisper JAX once this check lists your GPU (or TPU) devices rather than the CPU.
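The same check can be turned into a guard that stops early when JAX has fallen back to the CPU. This is a sketch; the helper names are hypothetical, and the functions only assume that each device exposes `.platform` and `.device_kind`, which the entries returned by `jax.devices()` do. Keeping the logic separate from the `jax` import makes it easy to test without a GPU.

```python
def describe_devices(devices):
    """Summarise JAX-like devices (anything with .platform and .device_kind)."""
    if not devices:
        return "No devices found."
    kinds = sorted({d.device_kind for d in devices})
    platforms = sorted({d.platform for d in devices})
    return f"Found {len(devices)} devices on {platforms} of type {kinds}."


def has_accelerator(devices):
    """True if any device is something other than the CPU fallback."""
    return any(d.platform != "cpu" for d in devices)


if __name__ == "__main__":
    import jax

    devs = jax.devices()
    print(describe_devices(devs))
    if not has_accelerator(devs):
        raise SystemExit("JAX is running on CPU only; fix the jaxlib/CUDA install first.")
```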

@luisroque

Simply using CUDA and CUDNN installed from pip wheels worked for me (NVIDIA GPU):

pip install --upgrade pip

# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

@sanchit-gandhi
Owner

This is how I previously installed JAX on GPU as well, @luisroque: you just need to match your JAX installation to your CUDA version!
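As a rough sketch of "matching" in practice, the snippet below reads the `CUDA Version` field printed by `nvidia-smi` and picks between the `cuda11_pip` and `cuda12_pip` commands shown earlier in the thread. Note that this field is the newest CUDA version the driver supports, not necessarily an installed toolkit, so treat the result as a starting point; the JAX installation guide remains the authority on minimum driver versions. The helper names are hypothetical, and the extra names are the ones used in this thread, which later JAX releases may have renamed.

```python
import re
import subprocess

# Extras used in the install commands above (names as of this thread).
EXTRAS = {11: "cuda11_pip", 12: "cuda12_pip"}
FIND_LINKS = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"


def driver_cuda_version(nvidia_smi_text):
    """Return the (major, minor) 'CUDA Version' reported by nvidia-smi.

    This is the newest CUDA the driver supports, not the installed toolkit.
    """
    m = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_text)
    if m is None:
        raise ValueError("no 'CUDA Version' field in nvidia-smi output")
    return int(m.group(1)), int(m.group(2))


def pip_command(nvidia_smi_text):
    """Build the pip command for the newest listed extra the driver can run."""
    major, _minor = driver_cuda_version(nvidia_smi_text)
    usable = [cuda for cuda in EXTRAS if cuda <= major]
    if not usable:
        raise ValueError(f"driver reports CUDA {major}; no listed jax extra fits")
    return f'pip install --upgrade "jax[{EXTRAS[max(usable)]}]" -f {FIND_LINKS}'


if __name__ == "__main__":
    smi = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=True)
    print(pip_command(smi.stdout))
```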

@themanyone

Assuming you have matching video drivers, CUDA, and cuDNN from the NVIDIA website:

Don't install torch!

It will currently downgrade cuDNN to an incompatible version.
But you can get it back:

pip3 install --upgrade nvidia-cudnn-cu11

This is why we use venv or conda: put torch in a separate environment.
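One way to catch the downgrade described above is to record the installed versions before and after `pip install torch` and compare the cuDNN lines. A minimal sketch using only the standard library; the helper name is hypothetical, and the package list is just the PyPI distributions discussed in this thread, so which of them you have depends on your install.

```python
from importlib import metadata

# PyPI distributions discussed in this thread; absent ones are reported as None.
PACKAGES = ("jax", "jaxlib", "torch", "nvidia-cudnn-cu11", "nvidia-cudnn-cu12")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:20} {version or '-'}")
```

Running it before and after installing torch makes a changed `nvidia-cudnn-cu11` version easy to spot.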
