
Recreate Benchmarks on A100 #5

Open

AndrewZhaoLuo opened this issue Apr 20, 2023 · 8 comments

AndrewZhaoLuo commented Apr 20, 2023

Hey all,

Very interesting work! I am trying to recreate some of the results you have in table 1.

Do you happen to have the script + audio used on hand? I am having trouble matching it on my machine:

from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
import time 
import librosa

SAMPLING_RATE = 16000
audio, sr = librosa.load('test_audio.mp3', sr=SAMPLING_RATE)

# instantiate pipeline in float16
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.float16, batch_size=32)

print("Warmup compiling forward pass")
text = pipeline(audio)


start_time = time.time()
for i in range(10):
    print(f"Go iter {i}")
    text = pipeline(audio)
end_time = time.time()
print(text)
print(f"Took {end_time - start_time} s")

# Took 330.93562269210815 s

test_audio.mp3 is a 13-minute TED talk clip. I get about 30 s per transcription iteration with this. It could be a bunch of things, but I just want to know whether this code would be expected to reproduce the benchmark results under an optimal config.

AndrewZhaoLuo changed the title from "Recreate Results on A100" to "Recreate Benchmarks on A100" on Apr 20, 2023
sanchit-gandhi (Owner) commented Apr 21, 2023

This looks more or less correct! The benchmarks we ran were from a bunch of YouTube videos (I can give you the URLs), and transcription time is somewhat dependent on the audio file. The slower transcription time could be because Whisper is getting caught in a hallucination in one of the batches, causing it to generate until it hits the max length (448 tokens).

You could check whether the text has repetitions, or try instantiating the pipeline with a lower max length (we set it to 128 and still got complete transcriptions):

# instantiate pipeline in float16
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.float16, batch_size=32, max_length=128)
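A quick way to check for hallucinated repetitions is to scan the transcription for n-grams that repeat back-to-back. A minimal sketch, assuming the pipeline returns a dict with a "text" key (the helper name and thresholds here are made up for illustration):

def has_repetitions(text, n=5, min_repeats=3):
    # Flag any n-gram that immediately repeats min_repeats or more times,
    # a typical signature of Whisper getting stuck in a loop
    words = text.split()
    for i in range(len(words) - n):
        ngram = words[i : i + n]
        repeats = 1
        j = i + n
        while words[j : j + n] == ngram:
            repeats += 1
            j += n
        if repeats >= min_repeats:
            return True
    return False

output = pipeline(audio)
print(has_repetitions(output["text"]))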

ahxxm commented Apr 23, 2023

Reproduced the hallucination with this audio file on Hugging Face:

[screenshot of the hallucinated transcription]

it's impressively fast

But 16 GB of memory doesn't seem to be enough for the statement jax = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16). How much memory does it require to instantiate the pipeline? From a very rough observation, the GPU memory (Tesla T4, 14 GB) was filled instantly, then memory grew slowly until it hit 16 GB, at which point the process was OOM-killed.

I just followed the discussion in #7 and the Transformers issue; it seems we haven't found the cause yet.
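One thing that might explain the "filled instantly" observation (an assumption about JAX defaults, not something confirmed in this thread): JAX preallocates a large fraction of GPU memory up front. The standard XLA environment variables control this; a minimal sketch:

import os

# Must be set before jax is imported anywhere in the process
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand instead
# or cap the preallocated fraction:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8"

from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)

This only changes allocation behavior; if the model plus activations genuinely need more memory than the card has, a smaller batch_size is the other obvious lever.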

sanchit-gandhi (Owner) commented
It's also worth making sure your audio is already at 16 kHz so that we don't resample it in the Flax Whisper pipeline (which can be lengthy for long audio files).
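A sketch of what that can look like, assuming the pipeline accepts a dict with "array" and "sampling_rate" keys (as the Transformers ASR pipeline does) so it can skip its own resampling:

import librosa

SAMPLING_RATE = 16000

# Resample once, up front, rather than inside the pipeline on every call
audio, sr = librosa.load("test_audio.mp3", sr=SAMPLING_RATE)

# Passing the sampling rate alongside the array tells the pipeline
# no further resampling is needed
text = pipeline({"array": audio, "sampling_rate": SAMPLING_RATE})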

sanchit-gandhi (Owner) commented
The absolute transcription time is somewhat dependent on the audio sample: since it's proportional to the number of tokens generated, it'll depend on speaking rate, propensity to hallucinate, speech-to-silence ratio, etc. Since what we really care about is the relative time between systems (rather than necessarily the absolute ones), it would be cool to benchmark the same audio file using OpenAI's Whisper and Transformers' Whisper on GPU to see what we're aiming for.
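A rough sketch of that comparison using the openai-whisper package, mirroring the timing loop from the original post (warmup run included so model loading isn't counted):

import time

import whisper  # the openai-whisper package

model = whisper.load_model("large-v2")

# Warmup run
result = model.transcribe("test_audio.mp3")

start_time = time.time()
for i in range(10):
    result = model.transcribe("test_audio.mp3")
end_time = time.time()
print(f"OpenAI Whisper: {(end_time - start_time) / 10:.1f} s per iteration")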

AndrewZhaoLuo (Author) commented Apr 25, 2023

One more question so I can do a fair comparison across libraries. If I am reading the codebase correctly, this is doing greedy search (i.e. beam_size=1). Is that correct?

sanchit-gandhi (Owner) commented
Correct!
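For an apples-to-apples comparison, the other implementations can be pinned to greedy decoding too. A sketch under the assumption that the goal is matching beam_size=1 (openai-whisper first, then Transformers; note the Transformers snippet only decodes the first 30 s window, since it skips the chunked long-form pipeline):

import librosa
import torch
import whisper  # the openai-whisper package
from transformers import WhisperForConditionalGeneration, WhisperProcessor

audio, _ = librosa.load("test_audio.mp3", sr=16000)

# openai-whisper: temperature 0 with no beam search is greedy decoding
oai_model = whisper.load_model("large-v2")
result = oai_model.transcribe("test_audio.mp3", temperature=0.0, beam_size=None)

# Transformers: num_beams=1 with do_sample=False is greedy decoding
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
hf_model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2", torch_dtype=torch.float16
).to("cuda")
input_features = processor(
    audio, sampling_rate=16000, return_tensors="pt"
).input_features.to("cuda", dtype=torch.float16)
pred_ids = hf_model.generate(input_features, num_beams=1, do_sample=False)
print(processor.batch_decode(pred_ids, skip_special_tokens=True)[0])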

AndrewZhaoLuo (Author) commented Apr 26, 2023

Thanks for all your help.

Finally, it might be good to just have the audio you used to benchmark. @sanchit-gandhi, can you direct me to the YouTube video?

s-tomar commented Nov 23, 2023

Hi,

On a CPU-only system (no TPU/GPU), the following deteriorates the overall performance. For an audio file under 10 minutes, it consumes almost 25% more time:

SAMPLING_RATE = 16000
audio, sr = librosa.load('test_audio.mp3', sr=SAMPLING_RATE)

I guess there are quite a few parameters to tune to achieve good/best performance, and improper tuning can worsen the situation 🤔
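If the overhead is indeed the resampling inside librosa.load (an assumption, not something profiled here), one cheap knob is its res_type parameter, which selects a faster, slightly lower-quality resampler:

import librosa

SAMPLING_RATE = 16000

# Depending on the librosa version the default resampler is "kaiser_best"
# or "soxr_hq"; "kaiser_fast" trades some quality for a faster CPU resample
audio, sr = librosa.load("test_audio.mp3", sr=SAMPLING_RATE, res_type="kaiser_fast")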
