## Whisper JAX ⚡️

This Kaggle notebook demonstrates how to run Whisper JAX on a TPU v3-8. Whisper JAX is a highly optimised JAX implementation of the Whisper model by OpenAI, largely built on the 🤗 Hugging Face Transformers Whisper implementation. Compared to OpenAI's PyTorch code, Whisper JAX runs over **70x faster**, making it the fastest Whisper implementation available.

The Whisper JAX model is also running as a [demo](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) on the Hugging Face Hub.

We'll start by installing the required Python packages:

In [None]:
!pip install --quiet --upgrade pip
!pip install --quiet "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install --quiet git+https://github.com/sanchit-gandhi/whisper-jax.git datasets soundfile librosa

Now let's connect to a TPU! Open up the settings menu in the Notebook editor, and select ‘TPU v3-8’ in the Accelerator menu. Refer to the guide [Introducing TPUs to Kaggle](https://www.kaggle.com/product-feedback/129828) for more information on choosing a TPU in Kaggle.

We then need to do some pre-set-up to register our TPU with JAX. Note that this step is not required for Cloud TPUs.

In [None]:
import os  # needed for the os.environ lookups below
import requests
from jax.config import config
    
if "TPU_NAME" in os.environ and "KAGGLE_DATA_PROXY_TOKEN" in os.environ:
    use_tpu = True

    if "TPU_DRIVER_MODE" not in globals():
        url = "http:" + os.environ["TPU_NAME"].split(":")[1] + ":8475/requestversion/tpu_driver_nightly"
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1

    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ["TPU_NAME"]

    # Enforce bfloat16 multiplication
    config.update("jax_default_matmul_precision", "bfloat16")
    print("Registered (Kaggle) TPU:", config.FLAGS.jax_backend_target)
else:
    use_tpu = False

Let's verify that we've been assigned a TPU. Run the following to see the TPU devices we have available:

In [3]:
import jax
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

Cool! We've got 8 TPU devices packaged into one overall accelerator.

## Loading the Pipeline

The recommended way of running Whisper JAX is through the [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) class. This class handles all the necessary pre- and post-processing for the model, as well as wrapping the generate method for data parallelism across all available accelerator devices.

Whisper JAX makes use of JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function for data parallelism across GPU/TPU devices. This function is Just In Time (JIT) compiled the first time it is called. Thereafter, the compiled function is cached, so subsequent calls run very quickly.
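To make the compile-once behaviour concrete, here is a minimal sketch of `jax.pmap` applied to a toy function rather than the Whisper forward pass. The function name `scale_and_sum` and the input shape are illustrative assumptions; the snippet should run on however many local devices JAX finds, including a single CPU.

```python
import jax
import jax.numpy as jnp

# Number of devices attached to this host (1 on a plain CPU machine).
n_devices = jax.local_device_count()


def scale_and_sum(x):
    # Toy stand-in for a model forward pass; runs independently on each device.
    return jnp.sum(x * 2.0)


# pmap replicates the function across devices. The leading axis of the input
# is split so that each device receives one slice.
parallel_fn = jax.pmap(scale_and_sum)

batch = jnp.ones((n_devices, 4))

first = parallel_fn(batch)   # first call: traced and compiled, so it is slow
second = parallel_fn(batch)  # same shapes and dtypes: reuses the compiled function

print(first.shape)  # one result per device
```

The leading dimension of the input has to match the number of devices, which is why the pipeline takes care of splitting each batch before dispatching it.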


Let's load the large-v2 model in bfloat16 (half-precision). Using half-precision speeds up the computation considerably by storing intermediate tensors in half-precision. There is no change to the precision of the model weights.

We'll also make use of _batching_ for single audio inputs: the audio is first chunked into 30 second segments, and the chunks are then dispatched to the model to be transcribed in parallel. By batching an audio input and transcribing it in parallel, we get a ~10x speed-up compared to transcribing the chunks sequentially.

In [2]:
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)

  from .autonotebook import tqdm as notebook_tqdm
tcmalloc: large alloc 6173270016 bytes == 0x10c2c2000 @  0x7fd11ee7f680 0x7fd11eea0824 0x5fa131 0x649f21 0x5c4f26 0x4f30be 0x64ec18 0x5050d3 0x56bbdf 0x569cea 0x50b2b0 0x56cbd1 0x569cea 0x5f6a13 0x59c757 0x5f6fbf 0x5715a2 0x569cea 0x68e7b7 0x601174 0x5c52f0 0x56b9fd 0x500a78 0x56d3fd 0x500a78 0x56d3fd 0x500a78 0x5047d6 0x56bbdf 0x5f6836 0x56b9fd
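The chunk-then-batch idea described above can be sketched in plain Python. This is a simplified illustration, not the pipeline's actual code: the helper names `chunk_bounds` and `group_into_batches` are hypothetical, the 16 kHz sampling rate is the rate Whisper expects, and any overlap the real pipeline may apply between neighbouring chunks is left out.

```python
def chunk_bounds(num_samples, sampling_rate=16_000, chunk_length_s=30.0):
    """Return (start, end) sample indices for consecutive fixed-length chunks."""
    chunk_len = int(chunk_length_s * sampling_rate)
    return [
        (start, min(start + chunk_len, num_samples))
        for start in range(0, num_samples, chunk_len)
    ]


def group_into_batches(chunks, batch_size):
    """Group chunks into lists of at most `batch_size` items."""
    return [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]


# A 5 minute clip sampled at 16 kHz
five_minutes = 5 * 60 * 16_000
chunks = chunk_bounds(five_minutes)
batches = group_into_batches(chunks, batch_size=16)

print(len(chunks), len(batches))
```

Under these assumptions a 5 minute clip fits into a single batch of 16, whereas a 30 minute clip needs several batches, which is where a larger batch size starts to pay off.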


We'll then initialise a compilation cache, which will speed up compilation if we close our kernel and want to compile the model again:

In [3]:
from jax.experimental.compilation_cache import compilation_cache as cc

cc.initialize_cache("./jax_cache")

Initialized persistent compilation cache at ./jax_cache


## 🎶 Load an audio file

Let's load up a long audio file for our tests. We provide 5 and 30 min audio files created by concatenating consecutive samples of the [LibriSpeech ASR](https://huggingface.co/datasets/librispeech_asr) corpus, which we can load in one line through Hugging Face Datasets' [`load_dataset`](https://huggingface.co/docs/datasets/loading#load) function. Note that you can also pass any `.mp3`, `.wav` or `.flac` audio file directly to the Whisper JAX pipeline, and it will take care of loading the audio file for you.

In [4]:
from datasets import load_dataset

test_dataset = load_dataset("sanchit-gandhi/whisper-jax-test-files", split="train")
audio = test_dataset[0]["audio"]  # load the first sample (5 mins) and get the audio array

Found cached dataset parquet (/home/sanchitgandhi/.cache/huggingface/datasets/sanchit-gandhi___parquet/sanchit-gandhi--whisper-jax-test-files-95479fe55e88baac/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


We can take a listen to the audio file that we've loaded. You'll see that it's approximately 5 mins long:

In [None]:
from IPython.display import Audio

Audio(audio["array"], rate=audio["sampling_rate"])

## Run the model

Now we're ready to transcribe! We'll need to compile the `pmap` function the first time we use it. You can expect compilation to take ~2 minutes on a TPU v3-8 with a batch size of 16. Enough time to grab a coffee ☕️

Thereafter, we can use our cached `pmap` function, which you'll see is amazingly fast.

In [6]:
# JIT compile the forward call - slow, but we only do it once
%time text = pipeline(audio)

CPU times: user 4min 18s, sys: 7min 53s, total: 12min 12s
Wall time: 1min 35s


In [7]:
# use the cached function thereafter - super fast!
%time text = pipeline(audio)

CPU times: user 12.7 s, sys: 39.5 s, total: 52.2 s
Wall time: 2.93 s


In [8]:
# let's check our transcription - looks spot on!
print(text)

{'text': " Chapter 16. I might have told you of the beginning of this liaison in a few lines, but I wanted you to see every step by which we came, I to agree to whatever Marguerite wished, Marguerite to be unable to live apart from me. It was the day after the evening when she came to see me that I sent her Manon Lescate. From that time, seeing that I could not change my mistress's life, I changed my own. I wished above all not to leave myself time to think over the position I had accepted, for, in spite of myself, it was a great distress to me. Thus my life, generally so calm, assumed all at once an appearance of noise and disorder. Never believe, however disinterested the love of a kept woman may be, that it will cost one nothing. Nothing is so expensive as their caprices, flowers, boxes at the theatre, suppers, days in the country, which one can never refuse to one's mistress. As I have told you, I had little money. My father was, and still is, Receiver General at sea. He has a grea

## Run it again!

Now let's step it up a notch. Let's try transcribing 30 minutes of audio from the LibriSpeech dataset. We'll first load up and listen to the second sample from our dataset, which corresponds to the 30 min audio file. We'll then pass the audio to the model for transcription, again timing how long the forward pass takes:

In [None]:
audio = test_dataset[1]["audio"]  # load the second sample (30 mins) and get the audio array

Audio(audio["array"], rate=audio["sampling_rate"])

In [10]:
# transcribe using cached function
%time text = pipeline(audio)

CPU times: user 1min 12s, sys: 3min 56s, total: 5min 8s
Wall time: 14.6 s


Just 14s to transcribe 30 mins of audio! At that rate, you could transcribe an entire 2 hour movie in under 1 minute 🤯 By increasing the batch size, we could also reduce the transcription time for long audio files further: increasing the batch size by 2x roughly decreases the transcription time by 2x, provided the batch size does not exceed the number of 30 second chunks in the audio.
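The back-of-the-envelope numbers can be checked with a few lines of arithmetic. This sketch takes the 14.6 s wall time from the cell output above and assumes throughput scales linearly with audio length, which is an approximation rather than a measured result.

```python
audio_seconds = 30 * 60  # 30 minutes of audio
wall_seconds = 14.6      # measured wall time from the cell above

# Seconds of audio transcribed per second of wall time
speed_factor = audio_seconds / wall_seconds

# Projected wall time for a 2 hour recording at the same rate
movie_seconds = 2 * 60 * 60
projected_seconds = movie_seconds / speed_factor

print(f"{speed_factor:.0f}x faster than real time")
print(f"2 hours of audio in about {projected_seconds:.1f} s")
```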

## ⏰ Timestamps and more

We can also get timestamps from the model by passing `return_timestamps=True`, but this will require a recompilation since we change the signature of the forward pass. 

The timestamps compilation takes longer than the non-timestamps one. Luckily, because we initialised our compilation cache above, we're not starting from scratch in compiling this time. This is the last compilation we need to do!

In [11]:
%time outputs = pipeline(audio, return_timestamps=True)
text = outputs["text"]  # transcription
chunks = outputs["chunks"]  # transcription + timestamps

CPU times: user 5min 34s, sys: 10min 38s, total: 16min 12s
Wall time: 1min 55s


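Once the timestamped chunks are available, they can be rendered as subtitle-style lines. The sketch below assumes each chunk is a dict with a `"timestamp"` pair of `(start, end)` seconds and a `"text"` string, and that both timestamps are numbers; check the actual structure of `outputs["chunks"]` before relying on it. The helpers `format_timestamp` and `format_chunks` are hypothetical, and the example data is made up, not real model output.

```python
def format_timestamp(seconds):
    """Convert a number of seconds to an HH:MM:SS.mmm string."""
    millis = round(seconds * 1000)
    hours, millis = divmod(millis, 3_600_000)
    minutes, millis = divmod(millis, 60_000)
    secs, millis = divmod(millis, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"


def format_chunks(chunks):
    """Render one 'start --> end  text' line per chunk."""
    lines = []
    for chunk in chunks:
        start, end = chunk["timestamp"]
        text = chunk["text"].strip()
        lines.append(f"{format_timestamp(start)} --> {format_timestamp(end)}  {text}")
    return "\n".join(lines)


# Made-up chunks for illustration only (assumed structure, not real output)
example_chunks = [
    {"timestamp": (0.0, 4.5), "text": " first segment"},
    {"timestamp": (4.5, 71.25), "text": " second segment"},
]
print(format_chunks(example_chunks))
```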

We've shown how you can transcribe an audio file in English. The pipeline also accepts two further arguments that you can use to control the generation process. It's perfectly fine to omit them if you want speech transcription and want the Whisper model to detect the audio language automatically. Otherwise, you can change them depending on your task/language:


* `task`: task to use for generation, either `"transcribe"` or `"translate"`. Defaults to `"transcribe"`.
* `language`: language token to use for generation, given as `"<|en|>"`, `"en"` or `"english"`. Defaults to `None`, meaning the language is automatically inferred from the audio input. Optional, and only relevant if the source audio language is known a priori.
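As a rough sketch of how these two arguments might be assembled before calling the pipeline, the hypothetical helper below validates `task` and only includes `language` when it is known. The helper name `build_generation_kwargs` is an assumption made for illustration and is not part of Whisper JAX; the pipeline call itself is shown only as a comment.

```python
VALID_TASKS = ("transcribe", "translate")


def build_generation_kwargs(task="transcribe", language=None):
    """Validate and collect the optional generation arguments."""
    if task not in VALID_TASKS:
        raise ValueError(f"task must be one of {VALID_TASKS}, got {task!r}")
    kwargs = {"task": task}
    # Leave `language` out entirely so the model infers it from the audio
    if language is not None:
        kwargs["language"] = language
    return kwargs


# Example: French speech translated into English text
kwargs = build_generation_kwargs(task="translate", language="fr")
print(kwargs)

# With the pipeline loaded earlier, this would be used as:
#     outputs = pipeline(audio, **kwargs)
```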