Error when installing requirements #6

Open
kesevone opened this issue Mar 17, 2024 · 24 comments

@kesevone

I have installed Python 3.10 and a venv. Trying to run "pip install -r requirements.txt":

ERROR: Ignored the following versions that require a different python version: 1.6.2 Requires-Python >=3.7,<3.10; 1.6.3 Requires-Python >=3.7,<3.10; 1.7.0 Requires-Python >=3.7,<3.10; 1.7.1 Requires-Python >=3.7,<3.10
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip" (from jax[cuda12-pip]) (from versions: 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25)
ERROR: No matching distribution found for jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip"

@ibab
Contributor

ibab commented Mar 17, 2024

I believe this should be fixed now. Can you try again?

@kesevone
Author

kesevone commented Mar 17, 2024

> I believe this should be fixed now. Can you try again?

No, it doesn't work

WARNING: jax 0.4.25 does not provide the extra 'cuda12-pip'
INFO: pip is looking at multiple versions of jax[cuda12-pip] to determine which version is compatible with other requirements. This could take a while.
ERROR: Ignored the following versions that require a different python version: 1.6.2 Requires-Python >=3.7,<3.10; 1.6.3 Requires-Python >=3.7,<3.10; 1.7.0 Requires-Python >=3.7,<3.10; 1.7.1 Requires-Python >=3.7,<3.10
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip" (from jax[cuda12-pip]) (from versions: 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25)
ERROR: No matching distribution found for jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip"
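For context: the jaxlib==0.4.25+cuda12.cudnn89 wheel is not published on PyPI; it is only hosted on Google's jax-releases index, so a plain pip install -r requirements.txt cannot resolve it. The workaround that several later comments report working is to install the CUDA build from that index first and then retry the requirements, e.g.:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -r requirements.txt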

@kesevone
Author

I think this is a problem with my system. I have Windows, and as I understand it there is no CUDA support in jax for Windows.

@alpernae

Hello @ibab,

I'm getting the same error while installing the requirements in WSL 2 (Kali). It looks like the fix doesn't work, or I'm making some kind of mistake while installing the requirements. The error message is below.

Used command: pip install -r requirements.txt

ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip" (from jax[cuda12-pip]) (from versions: 0.4.3, 0.4.4, 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25)
ERROR: No matching distribution found for jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip"

@alpernae

> I think this is a problem with my system. I have Windows, and as I understand it there is no CUDA support in jax for Windows.

Hi @kesevone, are you installing via WSL or trying to install on Windows? I use WSL and get the same error.

@kesevone
Author

kesevone commented Mar 17, 2024

> I think this is a problem with my system. I have Windows, and as I understand it there is no CUDA support in jax for Windows.

> Hi @kesevone, are you installing via WSL or trying to install on Windows? I use WSL and get the same error.

I'm trying to install on Windows; now I'll try WSL.

@alpernae

> I think this is a problem with my system. I have Windows, and as I understand it there is no CUDA support in jax for Windows.

> Hi @kesevone, are you installing via WSL or trying to install on Windows? I use WSL and get the same error.

> I'm trying to install on Windows; now I'll try WSL.

If you successfully install on WSL, can you tell me too?

@yarodevuci

yarodevuci commented Mar 17, 2024

pip install dm-haiku

In requirements.txt it's dm_haiku==0.0.12, with an underscore ...

@ahsan3219

I got the same error and tried this:
pip install git+https://github.com/deepmind/dm-haiku

It worked in my case.

@hidenway

I get an error on startup
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda':
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 8) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 1 devices in mesh
Traceback (most recent call last):
  File "c:\Users\Maksim\Desktop\grok-1\run.py", line 72, in <module>
    main()
  File "c:\Users\Maksim\Desktop\grok-1\run.py", line 63, in main
    inference_runner.initialize()
  File "c:\Users\Maksim\Desktop\grok-1\runners.py", line 282, in initialize
    runner.initialize(
  File "c:\Users\Maksim\Desktop\grok-1\runners.py", line 181, in initialize
    self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Maksim\Desktop\grok-1\runners.py", line 586, in make_mesh
    device_mesh = mesh_utils.create_hybrid_device_mesh(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Maksim\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\experimental\mesh_utils.py", line 373, in create_hybrid_device_mesh
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Maksim\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\experimental\mesh_utils.py", line 373, in <listcomp>
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Maksim\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\experimental\mesh_utils.py", line 302, in create_device_mesh
    raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)

@AndreSlavescu

That just means you don't have the expected number of devices. mesh_shape is the configuration for what is expected, in this case 8 devices to distribute the model over and run inference on. If you don't allocate exactly 8 GPUs, it will not work; and in any case, running inference with this model will require a minimum of 8 large GPUs anyway.
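To illustrate the check that raises this error, here is a minimal sketch (not the repository's code; it only uses the standard jax.device_count() and jax.experimental.mesh_utils APIs, and the (1, 8) shape comes from the log output above):

import jax
from jax.experimental import mesh_utils

mesh_shape = (1, 8)  # the local_mesh_config printed in the logs above
required = mesh_shape[0] * mesh_shape[1]   # 8 devices expected
detected = jax.device_count()              # devices JAX can actually see

if detected == required:
    devices = mesh_utils.create_device_mesh(mesh_shape)
else:
    # This mismatch is what create_device_mesh reports as
    # "Number of devices ... must equal the product of mesh_shape"
    print(f"Cannot build mesh: detected {detected} devices, need {required}")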

@yarodevuci

@AndreSlavescu I tried with 1: I used mesh_shape (1, 1) instead of mesh_shape (1, 8).

@lbg-686

lbg-686 commented Mar 18, 2024

I also have the same problem. Have you fixed it now?

@felifri

felifri commented Mar 18, 2024

I got the same error, but running only:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

worked for me, so it's probably because of the dm_haiku problem as described above

@AndreSlavescu

> @AndreSlavescu I tried with 1: I used mesh_shape (1, 1) instead of mesh_shape (1, 8).

In theory that should get you past that point, but you won't have enough memory to load the model onto a single GPU, so you will get a CUDA OOM error.
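Rough numbers behind the OOM, as a back-of-the-envelope sketch (it assumes roughly one byte per parameter for the released int8 checkpoint, which matches the ~300 GB download mentioned later in the thread):

# Back-of-the-envelope estimate; assumes ~314B parameters at ~1 byte each
# (the int8 checkpoint), ignoring activations and runtime buffers.
params = 314e9
bytes_per_param = 1
weights_gb = params * bytes_per_param / 1e9   # ~314 GB of weights
single_a100_gb = 80
cluster_gb = 8 * single_a100_gb               # 640 GB across 8 x A100 80GB
print(f"weights ~{weights_gb:.0f} GB vs {single_a100_gb} GB on one GPU, {cluster_gb} GB on eight")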

@imfunniee

same issue

@yarodevuci

> @AndreSlavescu I tried with 1: I used mesh_shape (1, 1) instead of mesh_shape (1, 8).

> In theory that should get you past that point, but you won't have enough memory to load the model onto a single GPU, so you will get a CUDA OOM error.

Yep, I gave up... 300 GB downloaded for nothing :D

@ywiyogo

ywiyogo commented Mar 19, 2024

Running these commands after the error fixes the installation issue:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -r requirements.txt

@AndreSlavescu

> @AndreSlavescu I tried with 1: I used mesh_shape (1, 1) instead of mesh_shape (1, 8).

> In theory that should get you past that point, but you won't have enough memory to load the model onto a single GPU, so you will get a CUDA OOM error.

> Yep, I gave up... 300 GB downloaded for nothing :D

If you want to try it and don't have access to an 8-GPU cluster, there are cloud compute options such as AWS SageMaker EC2 instances, Lambda Labs, CoreWeave, and a few more, where you might be able to get an 8xA100 80GB (640GB total) allocation.

@guobi777

Change requirements.txt to:
dm-haiku==0.0.12
jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
numpy==1.26.4
sentencepiece==0.2.0

@Jintao97

INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': 
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 8) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 1 devices in mesh
Traceback (most recent call last):
  File "/workspace/grok-1/run.py", line 72, in <module>
    main()
  File "/workspace/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/workspace/grok-1/runners.py", line 282, in initialize
    runner.initialize(
  File "/workspace/grok-1/runners.py", line 181, in initialize
    self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/grok-1/runners.py", line 586, in make_mesh
    device_mesh = mesh_utils.create_hybrid_device_mesh(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.pyenv_mirror/user/current/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 373, in create_hybrid_device_mesh
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.pyenv_mirror/user/current/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 302, in create_device_mesh
    raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)


How do you solve this?

@user-matth

After changing requirements.txt to:

dm-haiku==0.0.12
jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
numpy==1.26.4
sentencepiece==0.2.0

and then running pip install -r requirements.txt, it ~worked~. But when I run python3 run.py I just got this new issue:

INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda':
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 8) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 1 devices in mesh
Traceback (most recent call last):
  File "/Users/matheuscardoso/Projects/grok-1/run.py", line 72, in <module>
    main()
  File "/Users/matheuscardoso/Projects/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/Users/matheuscardoso/Projects/grok-1/runners.py", line 282, in initialize
    runner.initialize(
  File "/Users/matheuscardoso/Projects/grok-1/runners.py", line 181, in initialize
    self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/matheuscardoso/Projects/grok-1/runners.py", line 586, in make_mesh
    device_mesh = mesh_utils.create_hybrid_device_mesh(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 373, in create_hybrid_device_mesh
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 302, in create_device_mesh
    raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)

Did anyone solve it?

@cw-lucasgabriel

I was able to run Grok-1 yesterday. As people have commented, what did the trick for us at CloudWalk (a Brazilian fintech) was using our K8s cluster with at least 8xA100 GPUs (80 GB family). Grok-1 uses almost all the memory of the GPUs (so using only 1 or 2 GPUs will not give you enough memory).

Another thing that solved our problems was running: pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Then, we just needed to run python run.py, and voilà.

@sbhavani

sbhavani commented May 1, 2024

You can also pull this container to run grok: ghcr.io/nvidia/jax:grok from JAX Toolbox
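A minimal sketch of pulling and running it, assuming Docker with the NVIDIA Container Toolkit on the host (the mount path and interactive shell are illustrative choices, not instructions from the image's docs):

# Pull the JAX Toolbox image and start an interactive, GPU-enabled container
docker pull ghcr.io/nvidia/jax:grok
docker run --gpus all -it --rm -v "$PWD:/workspace" ghcr.io/nvidia/jax:grok bash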
