ValueError: Received incompatible devices for pjitted computation #1

Closed
wimjan123 opened this issue Apr 8, 2023 · 3 comments

wimjan123 commented Apr 8, 2023

Awesome repo! I have one question though: whenever I try running this code on my own TPU v4-8, I get the following error:

WARNING:absl:Tiling device assignment mesh by hosts, which may lead to reduced XLA collective performance. To avoid this, modify the model parallel submesh or run with more tasks per host.
Traceback (most recent call last):
  File "fastapi_app.py", line 17, in <module>
    pipeline.shard_params()
  File "/root/ai/whisper-jax/whisper_jax/pipeline.py", line 127, in shard_params
    self.params = p_shard_params(freeze(self.params))
  File "/root/ai/whisper-jax/whisper_jax/partitioner.py", line 787, in __call__
    return self._pjitted_fn(*args)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 238, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 193, in _python_pjit_helper
    raise ValueError(msg) from None
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Received incompatible devices for pjitted computation. Got argument params['model']['decoder']['embed_positions']['embedding'] of FlaxPreTrainedModel.to_bf16 with shape float32[448,1280] and device ids [0] on platform CPU and pjit's devices with device ids [0, 2, 1, 3] on platform TPU

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "fastapi_app.py", line 17, in <module>
    pipeline.shard_params()
  File "/root/ai/whisper-jax/whisper_jax/pipeline.py", line 127, in shard_params
    self.params = p_shard_params(freeze(self.params))
  File "/root/ai/whisper-jax/whisper_jax/partitioner.py", line 787, in __call__
    return self._pjitted_fn(*args)
ValueError: Received incompatible devices for pjitted computation. Got argument params['model']['decoder']['embed_positions']['embedding'] of FlaxPreTrainedModel.to_bf16 with shape float32[448,1280] and device ids [0] on platform CPU and pjit's devices with device ids [0, 2, 1, 3] on platform TPU

Any idea how I can fix it?

@sanchit-gandhi
Owner

Hey @wimjan123! The issue originates because we now load Flax weights on CPU by default in Transformers: huggingface/transformers#15295

Currently, the easiest workaround is to comment out the following lines: https://github.com/huggingface/transformers/blob/fe1f5a639d93c9272856c670cff3b0e1a10d5b2b/src/transformers/modeling_flax_utils.py#L836-L838

This lets the Flax weights default to the accelerator device you have available. This should be fixed properly by the time the repo is announced!
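If you'd rather not patch your Transformers install, a possible alternative (my own hedged sketch, not the workaround verified above) is to strip the CPU device commitment from the loaded params by round-tripping them through NumPy before sharding. Plain NumPy arrays are uncommitted, so pjit is free to place them on the TPU mesh itself. Whether this resolves the mesh mismatch can depend on your JAX version, so treat it as something to try:

```python
# Hedged sketch: remove the CPU device commitment from the param pytree by
# converting every leaf back to a plain NumPy array. Uncommitted arrays let
# pjit choose device placement rather than raising an incompatible-devices error.
import jax
import numpy as np


def uncommit(params):
    # tree_map visits every array leaf in the (possibly frozen) param pytree
    return jax.tree_util.tree_map(np.asarray, params)


# Usage (hypothetical, before sharding):
# pipeline.params = uncommit(pipeline.params)
# pipeline.shard_params()
```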

@wimjan123
Author

Everything seems to work great, and the speed is absolutely crazy: 50 minutes of audio transcribed in 30 seconds. I know this is still a WIP, but if I can give one suggestion: maybe add a way to export the output as txt and srt files? Awesome repo and a game changer for Whisper in terms of speed.

@sanchit-gandhi
Owner

Awesome, glad to hear that @wimjan123! Let me know if you encounter any other issues! So far, the repo has only really been tested through my own experimenting.

We're otherwise more or less ready for release though. Getting quite similar numbers to you in my benchmark tests on a GPU: https://github.com/sanchit-gandhi/whisper-jax#benchmarks (note that this does not include time to load the audio file, which is the same for all three repos but can be a significant proportion of the overall transcription time)

Thanks for the tip! I'll look into how we could export the output as txt/srt files. Currently, the easiest way is to write to a file manually:

pred_str = pipeline(...)

with open("output.txt", "w") as text_file:
    text_file.write(pred_str)
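For srt output, here is a hedged sketch of what an exporter could look like. It assumes the pipeline is called with timestamps enabled and returns a dict with a `"text"` string plus `"chunks"`, each holding a `"text"` entry and a `(start, end)` `"timestamp"` pair in seconds (the Transformers-style chunk format); check your pipeline's actual output schema before relying on it:

```python
# Hedged sketch: export transcription output to .txt and .srt files.
# Assumes `outputs` is a dict like:
#   {"text": "...", "chunks": [{"text": "...", "timestamp": (0.0, 1.5)}, ...]}
# with timestamps in seconds. The schema is an assumption, not confirmed here.


def fmt_ts(seconds: float) -> str:
    # SRT timestamps are formatted as HH:MM:SS,mmm
    ms = int(round(seconds * 1000))
    h, rem = divmod(ms, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, ms = divmod(rem, 1000)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"


def export(outputs: dict, stem: str = "output") -> None:
    # Plain-text transcript
    with open(f"{stem}.txt", "w") as f:
        f.write(outputs["text"])
    # SubRip subtitles: index, time range, text, blank line per cue
    with open(f"{stem}.srt", "w") as f:
        for i, chunk in enumerate(outputs["chunks"], start=1):
            start, end = chunk["timestamp"]
            f.write(f"{i}\n{fmt_ts(start)} --> {fmt_ts(end)}\n{chunk['text'].strip()}\n\n")
```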
