## Whisper JAX ⚡️

This Kaggle notebook demonstrates how to run Whisper JAX on a TPU v3-8. Whisper JAX is a highly optimised JAX implementation of the Whisper model by OpenAI, largely built on the 🤗 Hugging Face Transformers Whisper implementation. Compared to OpenAI's PyTorch code, Whisper JAX runs over **70x faster**, making it the fastest Whisper implementation available.

The Whisper JAX model is also running as a [demo](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) on the Hugging Face Hub.

We'll start by installing the required Python packages:

```python
!pip install --quiet "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install --quiet git+https://github.com/sanchit-gandhi/whisper-jax.git datasets soundfile librosa
```


Let's verify that we've been assigned a TPU. Run the following to see the TPU devices we have available:

```python
import jax
jax.devices()
```

Cool! We've got 8 TPU devices packaged into one overall accelerator.

## Loading the Pipeline

The recommended way of running Whisper JAX is through the [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) class. This class handles all the necessary pre- and post-processing for the model, as well as wrapping the generate method for data parallelism across all available accelerator devices.

Whisper JAX makes use of JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function for data parallelism across GPU/TPU devices. The pmapped function is Just In Time (JIT) compiled the first time it is called. The compiled function is then cached, so subsequent calls with the same input shapes skip compilation and run much faster.


Let's load the large-v2 model in bfloat16 (half-precision). Half-precision speeds up the computation considerably by storing intermediate tensors in half-precision. The precision of the model weights is unchanged.
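
bfloat16 keeps float32's 8 exponent bits but only 7 of its 23 mantissa bits. It therefore covers the same numeric range as float32, with roughly 2 to 3 significant decimal digits. The pure-Python sketch below emulates float32-to-bfloat16 rounding with `struct` (round to nearest, ties to even; NaN handling is omitted), purely for illustration. `to_bfloat16` is a hypothetical helper, and this is not how JAX or the TPU perform the cast.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a value to the nearest bfloat16 (ties to even) and return it as a Python float."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))  # raw float32 bit pattern
    # bfloat16 keeps the top 16 bits: 1 sign bit, 8 exponent bits, 7 mantissa bits
    lsb = (bits >> 16) & 1                               # lowest bit that survives truncation
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000            # round to nearest, ties to even
    return struct.unpack("<f", struct.pack("<I", bits))[0]


print(to_bfloat16(1.0))         # 1.0
print(to_bfloat16(3.14159265))  # 3.140625
print(to_bfloat16(70000.0))     # 70144.0
```

The value 70000.0 survives the conversion (as 70144.0). In float16, whose largest finite value is 65504, it would overflow. That range-over-precision trade-off is why bfloat16 is a common choice for half-precision computation on TPUs.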

We'll also make use of _batching_ for single audio inputs. The audio is first chunked into 30-second segments, and the chunks are then dispatched to the model to be transcribed in parallel. Transcribing the chunks in parallel rather than sequentially gives a ~10x speed-up.
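
The chunk-and-batch scheme can be sketched with sample indices alone. The sketch makes several assumptions:

* the input is 16 kHz audio, the rate Whisper models expect;
* the chunks are non-overlapping 30-second windows;
* a short final batch is padded so that each of the 8 TPU devices receives `batch_size / 8` chunks.

The real pipeline may overlap neighbouring chunks and pad differently. `chunk_boundaries` and `plan_batches` are hypothetical helpers.

```python
SAMPLING_RATE = 16_000  # Whisper models expect 16 kHz audio
CHUNK_LENGTH_S = 30     # each chunk covers at most 30 seconds
BATCH_SIZE = 16         # matches batch_size=16 used for the pipeline
NUM_DEVICES = 8         # a TPU v3-8 exposes 8 devices to JAX


def chunk_boundaries(num_samples, sampling_rate=SAMPLING_RATE, chunk_length_s=CHUNK_LENGTH_S):
    """Return (start, end) sample indices of consecutive, non-overlapping chunks."""
    chunk_len = chunk_length_s * sampling_rate
    return [(s, min(s + chunk_len, num_samples)) for s in range(0, num_samples, chunk_len)]


def plan_batches(chunks, batch_size=BATCH_SIZE, num_devices=NUM_DEVICES):
    """Group chunks into batches, each laid out as num_devices rows of equal length.

    Empty slots in the last batch are padded with None (a stand-in for a dummy
    chunk), because pmap needs the same per-device shape on every device.
    """
    if batch_size % num_devices != 0:
        raise ValueError("batch size must divide evenly across devices")
    per_device = batch_size // num_devices
    plans = []
    for i in range(0, len(chunks), batch_size):
        batch = chunks[i : i + batch_size]
        batch = batch + [None] * (batch_size - len(batch))  # pad the final batch
        plans.append([batch[d * per_device : (d + 1) * per_device] for d in range(num_devices)])
    return plans


five_min = chunk_boundaries(5 * 60 * SAMPLING_RATE)
thirty_min = chunk_boundaries(30 * 60 * SAMPLING_RATE)
print(len(five_min), len(plan_batches(five_min)))      # 10 chunks -> 1 batch
print(len(thirty_min), len(plan_batches(thirty_min)))  # 60 chunks -> 4 batches
```

With a batch size of 16 on 8 devices, each device transcribes 2 chunks per forward pass. This is why the 30-minute file needs only 4 passes rather than 60 sequential ones.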

```python
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# instantiate the pipeline: large-v2 checkpoint, bfloat16 computation, 16 chunks per batch
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
```

## Load an audio file

Let's load up a long audio file for our tests. We provide 5-minute and 30-minute audio files created by concatenating consecutive samples of the [LibriSpeech ASR](https://huggingface.co/datasets/librispeech_asr) corpus. We can load them in one line with Hugging Face Datasets' [`load_dataset`](https://huggingface.co/docs/datasets/loading#load) function. Note that you can also pass any `.mp3`, `.wav` or `.flac` audio file directly to the Whisper JAX pipeline, and it will take care of loading the audio file for you.

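```python
from datasets import load_dataset

# without a split, load_dataset returns a DatasetDict; we assume the files sit in the default "train" split
test_dataset = load_dataset("sanchit-gandhi/whisper-jax-test-file", split="train")
audio = test_dataset[0]["audio"]  # first sample (5 mins): a dict with "array" and "sampling_rate" keys
```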

Let's take a listen to the audio file we've loaded. You'll see that it's approximately 5 mins long:

```python
from IPython.display import Audio

Audio(audio["array"], rate=audio["sampling_rate"])
```
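
You can also confirm the length numerically. `duration_minutes` is a hypothetical convenience function that relies only on the `"array"` and `"sampling_rate"` keys of the audio dict used above. Called on the notebook's `audio`, it should return roughly 5 for the first sample.

```python
def duration_minutes(audio: dict) -> float:
    """Duration in minutes of an audio dict with "array" and "sampling_rate" keys."""
    return len(audio["array"]) / audio["sampling_rate"] / 60


# Synthetic example: 90 seconds of silence sampled at 16 kHz
example = {"array": [0.0] * (90 * 16_000), "sampling_rate": 16_000}
print(duration_minutes(example))  # 1.5
```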

## Run the model

Now we're ready to transcribe! We'll need to compile the `pmap` function the first time we use it. You can expect compilation to take ~2 minutes on a TPU v3-8 with a batch size of 16.

Thereafter, we can use our cached `pmap` function, which you'll see is amazingly fast.

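```python
# JIT compile the forward call - slow, but we only do it once
# (we pass a copy of the audio dict, since the pipeline may pop keys from its input)
%time text = pipeline(audio.copy())

# use the cached function thereafter - super fast!!
%time text = pipeline(audio.copy())
```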

Now let's step it up a notch. Let's try transcribing 30 minutes of audio from the LibriSpeech dataset:

```python
audio = test_dataset[1]["audio"]  # load the second sample (30 mins) as an audio dict
Audio(audio["array"], rate=audio["sampling_rate"])
```

```python
# transcribe using the cached function (passing a copy of the audio dict)
%time text = pipeline(audio.copy())
```

Just X mins to transcribe 30 mins of audio! We can also get timestamps from the model by passing `return_timestamps=True`, but this will require a recompilation:

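```python
# return_timestamps=True changes the generation call, so this triggers a recompilation
outputs = pipeline(audio.copy(), return_timestamps=True)
text = outputs["text"]  # transcription
chunks = outputs["chunks"]  # transcription + timestamps
```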
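
The timestamped `chunks` can be turned into a readable transcript. This sketch assumes each chunk is a dict with a `"text"` string and a `"timestamp"` tuple of `(start, end)` seconds, mirroring the 🤗 Transformers speech-recognition pipeline output. It also allows `end` to be `None` for a final chunk without a predicted end time. `format_timestamp` and `format_chunks` are hypothetical helpers, not part of Whisper JAX.

```python
def format_timestamp(seconds):
    """Format seconds as mm:ss.ss, or a placeholder when the time is missing."""
    if seconds is None:
        return "??:??.??"
    minutes, secs = divmod(seconds, 60)
    return f"{int(minutes):02d}:{secs:05.2f}"


def format_chunks(chunks):
    """Render each chunk as '[start -> end] text', one chunk per line."""
    lines = []
    for chunk in chunks:
        start, end = chunk["timestamp"]
        lines.append(f"[{format_timestamp(start)} -> {format_timestamp(end)}] {chunk['text'].strip()}")
    return "\n".join(lines)


# Synthetic chunks in the assumed output format
example_chunks = [
    {"text": " The first sentence.", "timestamp": (0.0, 2.5)},
    {"text": " The second sentence.", "timestamp": (62.25, None)},
]
print(format_chunks(example_chunks))
# [00:00.00 -> 00:02.50] The first sentence.
# [01:02.25 -> ??:??.??] The second sentence.
```

In the notebook, the same call would be `print(format_chunks(chunks))`, provided the real output matches the assumed structure.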

We've shown how you can transcribe an audio file in English. The pipeline also accepts two further arguments that you can use to control the generation process:
* `task`: Task to use for generation, either `"transcribe"` or `"translate"`. Defaults to `"transcribe"`.
* `language`: Language token to use for generation, which can be given in the form `"<|en|>"`, `"en"` or `"english"`. Defaults to `None`, meaning the language is automatically inferred from the audio input. Optional, and only relevant if the source audio language is known a priori.
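
The three accepted spellings of `language` all name the same language token. The sketch below shows how such inputs can be normalised to the `"<|xx|>"` token form. `normalise_language` and its three-entry language table are hypothetical and for illustration only; Whisper JAX handles this internally with Whisper's full language list.

```python
# Hypothetical subset of Whisper's language table, for illustration only
LANGUAGES = {"en": "english", "fr": "french", "de": "german"}
NAME_TO_CODE = {name: code for code, name in LANGUAGES.items()}


def normalise_language(language):
    """Map '<|en|>', 'en' or 'english' (any case) to the token form '<|en|>'."""
    if language is None:
        return None  # leave language detection to the model
    lang = language.strip().lower()
    if lang.startswith("<|") and lang.endswith("|>"):
        lang = lang[2:-2]  # strip the special-token delimiters
    code = NAME_TO_CODE.get(lang, lang)  # full names map to their codes
    if code not in LANGUAGES:
        raise ValueError(f"unsupported language: {language!r}")
    return f"<|{code}|>"


print(normalise_language("<|fr|>"), normalise_language("fr"), normalise_language("French"))
# <|fr|> <|fr|> <|fr|>
```

Any of these forms can then be passed alongside `task`, for example `task="translate"` to have Whisper translate the speech into English.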