## Whisper JAX ⚡️

This Kaggle notebook demonstrates how to run Indic Whisper JAX on a TPU v3-8. Indic Whisper JAX is a highly optimised JAX implementation of the Indic Whisper model by AI4Bharat. According to the Whisper JAX project, it runs over **70x faster** than OpenAI's PyTorch code, making it the fastest Whisper implementation available.

You can find the code [here](https://github.com/sanchit-gandhi/whisper-jax).

## Let's get started!

The first thing we need to do is connect to a TPU. Kaggle offers 20 hours of TPU v3-8 usage per month for free, which we'll make use of for this notebook. Refer to the guide [Introducing TPUs to Kaggle](https://www.kaggle.com/product-feedback/129828) for more information on TPU quotas in Kaggle.

You will need to register a Kaggle account and verify your phone number if you haven't done so already. Once verified, open up the settings menu in the Notebook editor (the small arrow in the bottom right). Then under _Notebook options_, select ‘TPU VM v3-8’ from the _Accelerator_ menu. You will also need to toggle the internet switch so that it is set to "on".

Once we've got a TPU allocated (there might be a queue to get one!), we can run the following to see the TPU devices we have available:

```python
import jax
jax.devices()
```

```
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
```

Cool! We've got 8 TPU devices packaged into one accelerator: 4 chips (the distinct `coords`) with 2 cores each (`core_on_chip` 0 and 1).

Kaggle TPUs come with JAX pre-installed, so we can directly install the remaining Python packages. If you're running the notebook on a Cloud TPU, ensure you have installed JAX according to the official [installation guide](https://github.com/google/jax#pip-installation-google-cloud-tpu). 

We'll install [Whisper JAX](https://github.com/sanchit-gandhi/whisper-jax) from main, as well as `datasets` and `librosa` for loading audio files:

```
!pip install --quiet --upgrade pip
!pip install --quiet git+https://github.com/sanchit-gandhi/whisper-jax.git datasets librosa tqdm
```

## Loading the Pipeline

The recommended way of running Whisper JAX is through the [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) class. This class handles all the necessary pre- and post-processing for the model, and wraps the generate method for data parallelism across all available accelerator devices.
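To make the data-parallel idea concrete, here is a minimal NumPy sketch of the reshape that `pmap`-style parallelism relies on: the batch gets a new leading axis with one slice per device. It illustrates the idea rather than reproducing the pipeline's own code. The helper name `shard_batch` is hypothetical, the device count of 8 comes from the `jax.devices()` output above, and the `(80, 3000)` feature shape assumes Whisper's 80-bin log-mel spectrogram for a 30-second window.

```python
import numpy as np


def shard_batch(batch: np.ndarray, n_devices: int) -> np.ndarray:
    """Reshape (batch, ...) to (n_devices, batch // n_devices, ...).

    jax.pmap maps a function over the leading axis, so each device gets one slice.
    """
    if batch.shape[0] % n_devices != 0:
        raise ValueError(
            f"batch size {batch.shape[0]} is not divisible by {n_devices} devices"
        )
    return batch.reshape((n_devices, -1) + batch.shape[1:])


# 16 log-mel feature windows (assumed 80 mel bins x 3000 frames each)
features = np.zeros((16, 80, 3000), dtype=np.float32)
sharded = shard_batch(features, n_devices=8)  # on the TPU: jax.device_count()
print(sharded.shape)  # (8, 2, 80, 3000)
```

With `batch_size=16` on the 8 cores of a TPU v3-8, each core would process 2 chunks per forward pass.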


Let's load the Hindi Indic Whisper checkpoint `parthiv11/indic_whisper_hi_multi_gpu` in bfloat16 (half-precision). Half-precision speeds up the computation considerably because intermediate tensors are stored in half-precision. The precision of the model weights themselves is not changed.

We'll also make use of _batching_ for single audio inputs: the audio is first chunked into 30-second segments, and the chunks are then dispatched to the model to be transcribed in parallel. Transcribing the chunks of one audio input in parallel gives a ~10x speed-up compared to transcribing them sequentially.
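The chunk-then-batch step can be sketched in a few lines of NumPy. This is a simplified illustration, not the Whisper JAX implementation: `chunk_audio`, `make_batches` and the `stride_s` parameter are hypothetical names. Chunked long-form transcription typically also overlaps neighbouring chunks by a small stride and merges the overlapping text afterwards; `stride_s` only mimics that overlap and does no merging.

```python
import numpy as np

SAMPLE_RATE = 16_000  # Whisper models expect 16 kHz mono audio


def chunk_audio(signal: np.ndarray, chunk_s: float = 30.0, stride_s: float = 0.0,
                sample_rate: int = SAMPLE_RATE) -> list:
    """Split a 1-D waveform into windows of at most `chunk_s` seconds.

    Neighbouring windows overlap by 2 * stride_s seconds, so each window
    starts (chunk_s - 2 * stride_s) seconds after the previous one.
    """
    chunk_len = int(chunk_s * sample_rate)
    step = chunk_len - 2 * int(stride_s * sample_rate)
    if step <= 0:
        raise ValueError("stride_s is too large for chunk_s")
    chunks = []
    for start in range(0, len(signal), step):
        chunks.append(signal[start:start + chunk_len])
        if start + chunk_len >= len(signal):
            break  # this window already reaches the end of the audio
    return chunks


def make_batches(chunks: list, batch_size: int) -> list:
    """Group chunks into batches of at most `batch_size` for parallel decoding."""
    return [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]


# Same length as the ~61 s clip used later in this notebook (975,744 samples)
clip = np.zeros(975_744, dtype=np.float32)
chunks = chunk_audio(clip)
print([len(c) for c in chunks])       # [480000, 480000, 15744]
print(len(make_batches(chunks, 16)))  # 1
```

Whisper's feature extractor pads every input to 30 s, so the short final chunk is not a problem. With `batch_size=16`, a one-minute clip fits in a single batch, whereas an 81-minute recording yields 162 non-overlapping chunks, i.e. 11 batches.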

```python
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# load the checkpoint with bfloat16 computation and 16 audio chunks per batch
pipeline = FlaxWhisperPipline("parthiv11/indic_whisper_hi_multi_gpu", dtype=jnp.bfloat16, batch_size=16)
```

```
Some of the weights of FlaxWhisperForConditionalGeneration were initialized in float16 precision from the model checkpoint at parthiv11/indic_whisper_hi_multi_gpu:
[('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'embed_tokens', 'embedding'), ...
```

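We'll then initialise a compilation cache, which speeds up compilation if we restart the kernel and need to compile the model again:

```python
%cd /kaggle/working
from jax.experimental.compilation_cache import compilation_cache as cc

# persist compiled programs in ./jax_cache so a restarted kernel can reuse them
# (newer JAX releases document jax.config.update("jax_compilation_cache_dir", "./jax_cache") instead)
cc.initialize_cache("./jax_cache")
```

```
/kaggle/working
```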


## 🎶 Load an audio file

Let's load an audio file for our tests. Note that you can also pass any `.mp3`, `.wav` or `.flac` audio file directly to the Whisper JAX pipeline, and it will take care of loading the audio file for you.

We can take a listen to the audio file first; it's approximately 1 minute long:

```python
# ~1 min audio
file_path = '/kaggle/input/cpgram-audios/7020587774.mp3'
```

```python
from IPython.display import Audio

Audio(file_path)
```

```python
import librosa

# Load the audio file, resampling it to 16 kHz (the rate Whisper expects)
signal, sample_rate = librosa.load(file_path, sr=16000)

# Print waveform shape and sample rate
print("Waveform shape:", signal.shape)
print("Sample rate:", sample_rate)
```

```
Waveform shape: (975744,)
Sample rate: 16000
```
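As a quick sanity check on the clip length, the duration is simply the number of samples divided by the sample rate. The helper below is a hypothetical illustration; in the notebook, `len(signal) / sample_rate` gives the same number.

```python
def duration_seconds(num_samples: int, sample_rate: int) -> float:
    """Length of a waveform in seconds."""
    return num_samples / sample_rate


# Values from the output above: 975,744 samples at 16 kHz
seconds = duration_seconds(975_744, 16_000)
print(f"{seconds:.3f} s = {seconds / 60:.2f} min")  # 60.984 s = 1.02 min
```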


## Run the model

Now we're ready to transcribe! We'll need to compile the `pmap` function the first time we use it. You can expect compilation to take a minute or two on a TPU v3-8 with a batch size of 16 (the first call below took about 50 s of wall time). Enough time to grab a coffee ☕️

Thereafter, we can use our cached `pmap` function, which you'll see is amazingly fast.

```python
# JIT compile the forward call - slow, but we only do it once
%time text = pipeline(signal)
```

```
CPU times: user 57.6 s, sys: 20.7 s, total: 1min 18s
Wall time: 49.4 s
```


```python
# use the cached function thereafter - super fast!
%time text = pipeline(signal)
```

```
CPU times: user 12 s, sys: 20.5 s, total: 32.5 s
Wall time: 3.33 s
```


```python
# let's check our transcription (a dict with a "text" key)
print(text)
```

```
{'text': 'नमस्कार मैं विथापनी ऑफ़ एडमिनिस्ट्रेटिव फॉर्म फॉइन पब्लिक के लिए अब सब की तरफ की प्रश्ना बात करती हूँ बोली बोली मैं आपका यह कॉल आपकी ग्राह ऐसी न ने आपको सब का कॉल किया गया है जैसे एम सरकार प्रायः आपकी शिकायत दर जब तक दो अगस्त दो हज़ार बाईस तेईस को आपका फॉसिएशिएटिव दू बाई है मिनिस्ट्री ऑफ़ लेबर एंड इम्प्लॉयेंट अथपर जैसे हम क्या कर प्रकार हैं की अपनी शिकायत को बंद कर दिया जैसे इम्प्लायमेंट फण्ड ऑफ़ अर्गनाइजेशन की तरफ ऐसी इसमें आपको जो रिप्लाई दिया गया आपने चेक किया था तो सिर के आपको क्या जवाब दिया गया क्या समाधान मिला हाँ हाँ ठीक किया न ब्लैंड मैंने जी तो आपको जो रिप्लाई दिया गया आपने ऑफिसर सतिस्फाइड है और नॉट सतिस्फाइपको जो है रिप्लाई दिया की आपने आपसे सतिफाईड है और नॉट सतिफाईड है आप आपसे सतिफाईड है थैंक यू बेरिमस्टर कोई अब फीडबाइट गैस है इससे कोई सजेशन नहीं ब्रह्म दो मौली रूपए दो मिला था अच्छा है मैं रख रहा हूँ थैंक यू बेरिमस्टर कितना कीमती समय देने आरोप आपको जो हो'}
```


## Run it again!

Now let's step it up a notch. Let's try transcribing 81 minutes of audio.

```
!pip install y2mate-api
```

```
Collecting argparse>=1.1 (from y2mate-api)
  Using cached argparse-1.4.0-py2.py3-none-any.whl (23 kB)
Installing collected packages: argparse
Successfully installed argparse-1.4.0
```

```python
from y2mate_api import Handler

api = Handler("CcY4GWjyl2U")  # YouTube video ID of the audio to download
# hard-coded conversion metadata for this video; the 'dlink' URL is session-specific and may no longer work
api.save({'size': '75.4 MB', 'f': 'mp3', 'q': '128kbps', 'q_text': 'MP3 - 128kbps', 'k': 'joBHWYWzwJKgasL176uEqMDxAl7j4vImhNgqlAxyU/NQ9dhhhr32bZgdeelenN4=', 'status': 'ok', 'mess': '', 'c_status': 'CONVERTED', 'vid': 'CcY4GWjyl2U', 'title': "India's #1 of the BEST Life Changing Seminar by Mohammad Shakeel (Motivational Speech)", 'ftype': 'mp3', 'fquality': '128', 'dlink': 'https://dl204.filemate22.shop/?file=M3R4SUNiN3JsOHJ6WWQ2a3NQS1Y5ZGlxVlZIOCtyZ05sdEV6eGdOb1VPQkRvTVk3MytIckFleDhHdXdqeEl5bld2MWM5RERmZU42TWV6eU11NUVvUTJISzljOTJsRERIOG9NdFdNMDZaUjc1a09PbW1ucGJoaFA4YU5uWkhMZFliSHN3a0ZCbTFpV2JndnpFNmxQK3VuR29tRjJDZUN4WDkya3RKUHJGNVlwSzBDeVpTZnIwZ05WWGkzRGFzTHhxMmNuSmt6YjgycjVzalpKNFRoY3lkSlZTaEsvR21mWFV0VTRMaFkwVjEwajUrTFgyWDh4dFJmekxLbkZqTnpnTXRMaTJERWhHbTNaSHZuK29xNnNudWc9PQ%3D%3D'})
```

```
10:42:17 - ERROR : __init__() got an unexpected keyword argument 'exit_on_error'
India's #1 of the BEST Life Changing Seminar by Mohammad Shakeel (Motivational Speech) CcY4GWjyl2U_128.mp3
78 MB

"India's #1 of the BEST Life Changing Seminar by Mohammad Shakeel (Motivational Speech) CcY4GWjyl2U_128.mp3"
```

```
!mv "/kaggle/working/India's #1 of the BEST Life Changing Seminar by Mohammad Shakeel (Motivational Speech) CcY4GWjyl2U_128.mp3" /kaggle/working/81min.mp3
```

```python
signal, sample_rate = librosa.load('/kaggle/working/81min.mp3', sr=16000)
```

```python
# transcribe the 81-minute audio using the cached function
%time text = pipeline(signal)
# 1 min 19 s (79 s) of wall time for 81 min of audio - super fast, right?
```

```
CPU times: user 11min 6s, sys: 18min 51s, total: 29min 57s
Wall time: 1min 19s
```


Just **79 s** (the 1 min 19 s wall time above) to transcribe **81 mins** of audio! At that rate, you could transcribe an entire **2 hour movie in under 2.5 minutes 🤯** By increasing the batch size, we could reduce the transcription time for long audio files even further: doubling the batch size roughly halves the transcription time, provided the audio is long enough to fill the larger batch (i.e. it yields more 30-second chunks than the batch size).
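The arithmetic behind these claims can be checked with a back-of-the-envelope sketch. It uses the measured 79 s wall time from above, assumes time scales linearly with audio length, and counts 30-second chunks without any overlap between them, so treat the numbers as rough estimates; `num_batches` is a hypothetical helper.

```python
import math

CHUNK_S = 30  # Whisper transcribes audio in 30-second windows


def num_batches(audio_s: float, batch_size: int, chunk_s: float = CHUNK_S) -> int:
    """Forward batches needed for `audio_s` seconds, ignoring chunk overlap."""
    n_chunks = math.ceil(audio_s / chunk_s)
    return math.ceil(n_chunks / batch_size)


# Measured above: 1 min 19 s of wall time for 81 min of audio
audio_s = 81 * 60
wall_s = 79
speed = audio_s / wall_s                # seconds of audio per second of wall time
movie_estimate_s = 2 * 60 * 60 / speed  # assumes time grows linearly with length

print(f"{speed:.1f}x real time; 2 h movie ~ {movie_estimate_s:.0f} s")  # 61.5x real time; 2 h movie ~ 117 s
for bs in (16, 32):
    # 81-min recording vs. the ~61 s clip from earlier
    print(bs, num_batches(audio_s, bs), num_batches(61, bs))  # "16 11 1", then "32 6 1"
```

Doubling the batch size from 16 to 32 cuts the 81-minute recording from 11 forward batches to 6, while the one-minute clip needs a single batch either way, which is why larger batches only help once the audio yields more chunks than the batch size.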

If you're fortunate enough to have access to a TPU v4, you'll find that transcription is roughly a factor of 2 faster than on a v3; you can quickly see how we can get super fast transcription times using Whisper JAX on TPU!

## ⏰ Timestamps and more

We can also get timestamps from the model by passing `return_timestamps=True`, but this will require a recompilation since we change the signature of the forward pass. 

The timestamps compilation takes longer than the non-timestamps one. Because we initialised our compilation cache above, this compiled program is also saved to disk, so a restarted kernel won't have to compile it from scratch. This is the last compilation we need to do!

```python
# compile the forward call with timestamps - slow, but we only do it once
%time outputs = pipeline(signal, return_timestamps=True)
text = outputs["text"]  # transcription
chunks = outputs["chunks"]  # transcription + timestamps
```

```python
# use cached timestamps function - super fast!
%time outputs = pipeline(signal, return_timestamps=True)
text = outputs["text"]
chunks = outputs["chunks"]
```
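A common next step is to turn the timestamped chunks into subtitles. The sketch below writes SRT, assuming each entry of `outputs["chunks"]` has the Hugging Face pipeline shape `{"text": ..., "timestamp": (start, end)}` with times in seconds. `to_srt` and `format_srt_time` are hypothetical helpers, and the example chunks are made-up placeholders rather than model output.

```python
def format_srt_time(seconds: float) -> str:
    """Format seconds as an SRT timestamp, HH:MM:SS,mmm."""
    total_ms = int(round(seconds * 1000))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, ms = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"


def to_srt(chunks: list) -> str:
    """Convert timestamped chunks into the text of an .srt subtitle file."""
    entries = []
    for index, chunk in enumerate(chunks, start=1):
        start, end = chunk["timestamp"]
        if end is None:  # assumption: an open-ended final chunk ends where it starts
            end = start
        entries.append(
            f"{index}\n{format_srt_time(start)} --> {format_srt_time(end)}\n"
            f"{chunk['text'].strip()}\n"
        )
    return "\n".join(entries)


# Placeholder chunks in the assumed format (not real model output)
example_chunks = [
    {"text": " first segment", "timestamp": (0.0, 2.5)},
    {"text": " second segment", "timestamp": (2.5, 65.04)},
]
print(to_srt(example_chunks))
```

In the notebook, `to_srt(chunks)` could then be written to a file such as `transcript.srt` (a hypothetical name).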

We've shown how you can transcribe an audio file in Hindi. The pipeline also accepts two further arguments that control the generation process. You can omit both if you want speech transcription and want the Whisper model to detect the language of the audio automatically. Otherwise, set them according to your task and language:

* `task`: task to use for generation, either `"transcribe"` or `"translate"`. Defaults to `"transcribe"`.
* `language`: language token to use for generation, in the form `"<|en|>"`, `"en"` or `"english"`. Defaults to `None`, meaning the language is automatically inferred from the audio input. Optional, and only relevant if the source audio language is known a priori.
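A small sketch of how these two arguments might be assembled before calling the pipeline. `generation_kwargs` is a hypothetical helper that only validates `task` and leaves out `language` when it is `None`; the keyword names themselves come from the list above.

```python
VALID_TASKS = ("transcribe", "translate")


def generation_kwargs(task: str = "transcribe", language=None) -> dict:
    """Build the optional generation arguments for the Whisper JAX pipeline."""
    if task not in VALID_TASKS:
        raise ValueError(f"task must be one of {VALID_TASKS}, got {task!r}")
    kwargs = {"task": task}
    if language is not None:  # None lets Whisper detect the language itself
        kwargs["language"] = language
    return kwargs


print(generation_kwargs())                       # {'task': 'transcribe'}
print(generation_kwargs("transcribe", "hindi"))  # {'task': 'transcribe', 'language': 'hindi'}
# In the notebook: outputs = pipeline(signal, **generation_kwargs("transcribe", "hi"))
```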