Running grok-1 with Hugging Face int8 weights with ROCm 6.0 fails #130

Closed
Spurthi-Bhat-ScalersAI opened this issue Mar 18, 2024 · 7 comments

@Spurthi-Bhat-ScalersAI

Steps to reproduce:

  1. Clone the grok-1 repo.

  2. Install the JAX library:

         pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

  3. Install the requirements:

         pip install -r requirements.txt

  4. Download the Hugging Face weights:

         huggingface-cli download xai-org/grok-1 --repo-type model --include ckpt/tensor* --local-dir checkpoints/ckpt-0 --local-dir-use-symlinks False

  5. Change the local mesh config in run.py (see the sketch after this list):

         local_mesh_config=(1, 1)

  6. Run run.py:

         python3 run.py
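For reference, the edit in step 5 touches a single keyword argument in run.py. A sketch of the surrounding call, with the runner construction and its other arguments assumed from the upstream repo:

    # run.py (sketch -- only local_mesh_config is edited; the other
    # arguments of the runner construction are left as shipped)
    inference_runner = InferenceRunner(
        ...,                          # unchanged upstream arguments
        local_mesh_config=(1, 1),     # changed from the multi-GPU default
        between_hosts_config=(1, 1),
    )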

Running it fails with the error below:

Traceback (most recent call last):
  File "/home/user/grok/grok-1/run.py", line 72, in <module>
    main()
  File "/home/user/grok/grok-1/run.py", line 67, in main
    print(f"Output for prompt: {inp}", sample_from_model(gen, inp,max_len=1, temperature=0.5))
  File "/home/user/grok/grok-1/runners.py", line 596, in sample_from_model
    next(server)
  File "/home/user/grok/grok-1/runners.py", line 480, in run
    rngs, last_output, memory, settings = self.prefill_memory(
  File "/home/user/grok/virtual/lib/python3.10/site-packages/haiku/_src/multi_transform.py", line 314, in apply_fn
    return f.apply(params, None, *args, **kwargs)
  File "/home/user/grok/virtual/lib/python3.10/site-packages/haiku/_src/transform.py", line 183, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "/home/user/grok/virtual/lib/python3.10/site-packages/haiku/_src/transform.py", line 456, in apply_fn
    out = f(*args, **kwargs)
  File "/home/user/grok/grok-1/runners.py", line 351, in hk_prefill_memory
    settings = jax.tree_map(
  File "/home/user/grok/grok-1/runners.py", line 352, in <lambda>
    lambda o,v: jax.lax.dynamic_update_index_in_dim(o,v, i, axis=0),
TypeError: dynamic_update_slice update shape must be smaller than operand shape, got update shape (1,) for operand shape (0,).
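The shape error itself reduces to something like this standalone sketch (illustrative only, not the model code):

    import jax.numpy as jnp
    from jax import lax

    # A zero-length operand cannot absorb a length-1 update -- the same
    # mismatch the traceback reports for the settings buffer.
    operand = jnp.zeros((0,))
    update = jnp.zeros((1,))
    lax.dynamic_update_index_in_dim(operand, update, 0, axis=0)
    # TypeError: dynamic_update_slice update shape must be smaller than
    # operand shape, got update shape (1,) for operand shape (0,).

This suggests one of the per-device buffers ends up empty, e.g. when the mesh is configured for fewer devices than the model state expects.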

Questions:

  1. Does run.py work with the int8 Hugging Face weights?
  2. What needs to be done to resolve the above error?
  3. Is there support for running the grok-1 model on AMD ROCm 6.0?
@ibab (Contributor) commented Mar 18, 2024

The current code assumes that we use 8 GPUs, hence local_mesh_config=(1, 8). Are you sure you have enough memory on your device to run the model?
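A quick generic check (not from the repo) that JAX actually sees all eight devices before running:

    import jax

    # local_mesh_config=(1, 8) assumes 8 local accelerators are visible:
    assert jax.local_device_count() == 8, jax.devices()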

@rahulbatra85

@Spurthi-Bhat-ScalersAI A few things:

- requirements.txt hardcodes JAX CUDA wheels. To run on ROCm you would need to install the ROCm wheels (https://github.com/ROCm/jax/releases); a sketch of the install flow follows below. Please note the 0.4.25 release for ROCm is not out yet, but should be soon.
- Also, currently only AMD Instinct GPUs have JAX support. Which AMD GPU do you have?
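The install flow sketched (the wheel filename below is a placeholder; pick the asset matching your Python and ROCm versions from the releases page):

    pip uninstall -y jax jaxlib
    # placeholder name -- substitute the real wheel from https://github.com/ROCm/jax/releases
    pip install jaxlib-<version>+rocm<rocm-version>-<python-tag>.whl
    pip install jax==<matching-version>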

Konard commented Mar 18, 2024

@ibab how much memory should be enough?

rahulbatra85 commented Mar 18, 2024

@Spurthi-Bhat-ScalersAI

Please see the steps below to install for ROCm:

  1. Install the JAX and jaxlib wheels for ROCm from https://github.com/ROCm/jax/releases/tag/jaxlib-v0.4.25. Alternatively, you can use the provided Docker image.

  2. Install the remaining requirements from requirements.txt, with the jax line commented out:

         dm_haiku==0.0.12
         #jax[cuda12_pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
         numpy==1.26.4
         sentencepiece==0.2.0
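After installing, a quick generic sanity check that the ROCm build is active:

    python3 -c "import jax; print(jax.devices())"
    # should list GPU devices; a CPU-only list means the ROCm jaxlib
    # is not being picked up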

Sample output: (screenshot: Groq1)

@Spurthi-Bhat-ScalersAI (Author)

> The current code assumes that we use 8 GPUs, hence local_mesh_config=(1, 8). Are you sure you have enough memory on your device to run the model?

@ibab Yes, I have 8 GPUs with around 190 GB of VRAM each, which I assume is enough to run the model?
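As a rough back-of-envelope check: grok-1 has about 314B parameters, so the int8 weights alone are on the order of 314 GB; eight devices at ~190 GB each give ~1.5 TB in total, which should leave ample headroom for activations and KV memory, assuming the weights shard evenly across the mesh.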

@Spurthi-Bhat-ScalersAI (Author)

> @Spurthi-Bhat-ScalersAI A few things: requirements.txt hardcodes JAX CUDA wheels. To run on ROCm you would need to install the ROCm wheels (https://github.com/ROCm/jax/releases). Please note the 0.4.25 release for ROCm is not out yet, but should be soon. Also, currently only AMD Instinct GPUs have JAX support. Which AMD GPU do you have?

Yes, I am using the AMD Instinct MI300X GPU.

> @Spurthi-Bhat-ScalersAI
>
> Please see the steps below to install for ROCm:
>
> 1. Install the JAX and jaxlib wheels for ROCm from https://github.com/ROCm/jax/releases/tag/jaxlib-v0.4.25. Alternatively, you can use the provided Docker image.
> 2. Install the remaining requirements from requirements.txt, with the jax line commented out.
>
> Sample output: (screenshot: Groq1)

@rahulbatra85 Thank you for the quick support regarding the ROCm-JAX!

@Spurthi-Bhat-ScalersAI (Author)

@rahulbatra85 Got it working with the ROCm-JAX latest release! Thanks once again for the release and detailed steps!
@ibab Thanks for the clarification regarding the local mesh config. I am able to run the script successfully!
