## Whisper JAX ⚡️

This Kaggle notebook demonstrates how to run Indic Whisper JAX on a TPU v3-8. Indic Whisper JAX is a highly optimised JAX implementation of the Indic Whisper model by AI4Bharat. Compared to OpenAI's PyTorch code, Whisper JAX runs over **70x faster**, making it the fastest Whisper implementation available.

You can find the code [here](https://github.com/sanchit-gandhi/whisper-jax).

## Let's get started!

The first thing we need to do is connect to a TPU. Kaggle offers 20 hours of TPU v3-8 usage per month for free, which we'll make use of for this notebook. Refer to the guide [Introducing TPUs to Kaggle](https://www.kaggle.com/product-feedback/129828) for more information on TPU quotas in Kaggle.

You will need to register a Kaggle account and verify your phone number if you haven't done so already. Once verified, open up the settings menu in the Notebook editor (the small arrow in the bottom right). Then under _Notebook options_, select ‘TPU VM v3-8’ from the _Accelerator_ menu. You will also need to toggle the internet switch so that it is set to "on".

Once we've got a TPU allocated (there might be a queue to get one!), we can run the following to see the TPU devices we have available:

In [1]:
import jax
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Cool! We've got 8 TPU devices packaged into one accelerator.

Kaggle TPUs come with JAX pre-installed, so we can directly install the remaining Python packages. If you're running the notebook on a Cloud TPU, ensure you have installed JAX according to the official [installation guide](https://github.com/google/jax#pip-installation-google-cloud-tpu). 

We'll install [Whisper JAX](https://github.com/sanchit-gandhi/whisper-jax) from main, as well as `datasets` and `librosa` for loading audio files:

In [None]:
!pip install --quiet --upgrade pip
!pip install --quiet git+https://github.com/sanchit-gandhi/whisper-jax.git datasets librosa tqdm

## Loading the Pipeline

The recommended way of running Whisper JAX is through the [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) class. This class handles all the necessary pre- and post-processing for the model, and wraps the generate method for data parallelism across all available accelerator devices.
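With `pmap`-style data parallelism, the leading axis of each batch is split across devices. The NumPy sketch below shows only that reshaping step, assuming 8 devices (as listed by `jax.devices()` above) and a batch of 16 inputs. The per-example shape of 80 mel bins × 3000 frames is an illustrative assumption, and `shard_for_devices` is a hypothetical helper, not part of Whisper JAX.

```python
import numpy as np

def shard_for_devices(batch: np.ndarray, num_devices: int) -> np.ndarray:
    """Hypothetical helper: reshape (batch, ...) -> (num_devices, batch // num_devices, ...).

    jax.pmap maps a function over the leading axis, so each device
    receives one slice along axis 0 of the returned array.
    """
    if batch.shape[0] % num_devices != 0:
        raise ValueError(
            f"batch size {batch.shape[0]} is not divisible by {num_devices} devices"
        )
    per_device = batch.shape[0] // num_devices
    return batch.reshape((num_devices, per_device) + batch.shape[1:])

# Illustrative batch: 16 examples of 80 mel bins x 3000 frames (assumed shapes).
features = np.zeros((16, 80, 3000), dtype=np.float32)
sharded = shard_for_devices(features, num_devices=8)
print(sharded.shape)  # (8, 2, 80, 3000)
```

A real implementation also has to handle a final batch that does not divide evenly across devices, for example by padding; this sketch simply raises an error instead.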


Let's load the Indic Whisper checkpoint (`parthiv11/indic_whisper_hi_multi_gpu`) in bfloat16 (half-precision). Using half-precision speeds up the computation considerably by storing intermediate tensors in half-precision. There is no change to the precision of the model weights.
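A quick back-of-the-envelope calculation shows why half-precision intermediate tensors help. Assuming an activation of shape 16 × 1500 × 1280 (batch size 16, 1500 encoder frames, hidden size 1280, roughly a Whisper large encoder; the actual dimensions of this checkpoint may differ), bfloat16 halves its memory footprint compared with float32:

```python
import math

def tensor_megabytes(shape, bytes_per_element):
    """Size of a dense tensor in megabytes (10**6 bytes)."""
    return math.prod(shape) * bytes_per_element / 1e6

# Assumed activation shape: (batch, encoder frames, hidden size)
activation_shape = (16, 1500, 1280)

fp32_mb = tensor_megabytes(activation_shape, 4)  # float32: 4 bytes per element
bf16_mb = tensor_megabytes(activation_shape, 2)  # bfloat16: 2 bytes per element
print(f"float32: {fp32_mb:.2f} MB, bfloat16: {bf16_mb:.2f} MB")
# float32: 122.88 MB, bfloat16: 61.44 MB
```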

We'll also make use of _batching_ for single audio inputs: the audio is first chunked into 30-second segments, and the chunks are then dispatched to the model to be transcribed in parallel. By batching an audio input and transcribing it in parallel, we get a ~10x speed-up compared to transcribing the audio samples sequentially.
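The chunk-and-batch idea can be sketched in NumPy. This simplified version assumes 16 kHz input and non-overlapping 30-second chunks: it zero-pads the last chunk and groups the chunks into batches. The real pipeline handles chunk boundaries more carefully (with overlapping strides, as far as we can tell), which this sketch leaves out; `chunk_and_batch` is a hypothetical helper.

```python
import numpy as np

SAMPLE_RATE = 16_000  # assumed: Whisper models expect 16 kHz audio
CHUNK_SECONDS = 30    # Whisper's fixed input window

def chunk_and_batch(waveform: np.ndarray, batch_size: int) -> list:
    """Split a 1-D waveform into 30 s chunks and group them into batches."""
    chunk_len = SAMPLE_RATE * CHUNK_SECONDS
    num_chunks = -(-len(waveform) // chunk_len)  # ceiling division
    # Zero-pad so the waveform length is an exact multiple of the chunk length.
    padded = np.zeros(num_chunks * chunk_len, dtype=waveform.dtype)
    padded[: len(waveform)] = waveform
    chunks = padded.reshape(num_chunks, chunk_len)
    # Each batch holds up to batch_size chunks; the last batch may be smaller.
    return [chunks[i : i + batch_size] for i in range(0, num_chunks, batch_size)]

# 5 minutes of silence as a stand-in for real audio.
waveform = np.zeros(5 * 60 * SAMPLE_RATE, dtype=np.float32)
batches = chunk_and_batch(waveform, batch_size=16)
print(len(batches), batches[0].shape)  # 1 (10, 480000)
```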

In [None]:
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

pipeline = FlaxWhisperPipline("parthiv11/indic_whisper_hi_multi_gpu", dtype=jnp.bfloat16, batch_size=16)

We'll then initialise a compilation cache, which will speed up compilation if we close our kernel and want to compile the model again:

In [None]:
%cd /kaggle/working
from jax.experimental.compilation_cache import compilation_cache as cc

cc.initialize_cache("./jax_cache")

## 🎶 Load an audio file

Let's load up a long audio file for our tests. Note that you can also pass any `.mp3`, `.wav` or `.flac` audio file directly to the Whisper JAX pipeline, and it will take care of loading the audio file for you.

We can take a listen to the audio file that we've loaded - we'll see that it's approximately 5 mins long:

In [97]:
# note: '_6900884053.mp3' is 0 seconds long and does not work, so we use a different file
file_path = '/kaggle/working/2_hours_hanuman.wav/Rasraj Ji Maharaj - Hanuman Chalisa Bajrang Baan Sundarkand Paath Bhakti & More  TRSH 232.mp4'

In [None]:
from IPython.display import Audio

Audio(file_path)

In [None]:
import librosa

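# Load the audio file, resampled to 16 kHz: Whisper's feature extractor expects 16 kHz audio,
# while librosa.load otherwise resamples to its default of 22,050 Hz
waveform, sample_rate = librosa.load(file_path, sr=16000)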

# Print waveform shape and sample rate
print("Waveform shape:", waveform.shape)
print("Sample rate:", sample_rate)


## Run the model

Now we're ready to transcribe! We'll need to compile the `pmap` function the first time we use it. You can expect compilation to take ~2 minutes on a TPU v3-8 with a batch size of 16. Enough time to grab a coffee ☕️

Thereafter, we can use our cached `pmap` function, which you'll see is amazingly fast.

In [99]:
# JIT compile the forward call - slow, but we only do once
%time text = pipeline(waveform)



CPU times: user 23.3 s, sys: 39.9 s, total: 1min 3s
Wall time: 3.93 s


In [109]:
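# use the cached function thereafter - super fast!
# ('<|gu|>' is the Gujarati token; '<|hi|>' looks like the better match for this Hindi checkpoint)
%time text = pipeline(waveform, task='transcribe', language='<|gu|>')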
text



CPU times: user 23.5 s, sys: 39.9 s, total: 1min 3s
Wall time: 3.92 s


{'text': 'हेलो नमस्कार सर मैं डिपार्टमेंट ऑफ़ सेडमिनिस्टेटी डिपार्टमेंट संपत्ति जिसकी तरफ ऐसी सभ्य का बात करूँ क्या मेरी बस श्याम विश्वरियों से नहीं वो उसको पेरिस बोल रहा हूँ जी सर मैं उन्होंने में लिखी अपनी शिकायत दर्ज किया था सत्ताईस अगस्त को दो हजार तेईस में क्या आपको जानकारी होगी इसके बारे में हाँ थोड़ा है बोले जी सर इसी का शिकायत संख्या है जीरो वन सेव डबल सेम फाइव थ्री सिर आप ऐसी जानना चाहेंगे की आपको प्रदेश का रिप्लाई जवाब डिपार्टमेंट के द्वारा ग्यारह सितम्बर को दिया गया है की आपने ऐसी चेक कर लिया और अनकोश ठाकुर रूपया इससे और कोई अवार्ड के लिए न मैडम इसलिए ये किया था जी जी जी सबको हम बता दते हैं की उन्होंने किस प्रकार में ये शिकायत दर्ज की तो यह दर्ज दया दया जी जी और वार्ड के लिए किया था न जी जी बिल्कुल सर और वार्ड से रिलेटेड में की शिकायत थी हाँ हाँ जी जी सर तो आपको जिसका रिस्प्लाई दिजवाब आरोप डिपार्टमेंट के द्वारा ग्यारह सितम्बर को दिया गया है की आ आपने उसे चेक कर लिया है आप चेक नहीं किया नहीं किया अब भी चेक नहीं किया है आपने जी सिर यदि आपने उसे चेक नहीं किया तो हम आपको बता देत

In [83]:
# let's check our transcription
print(text)

{'text': 'हेलो नमस्कार सर मैं डिपार्टमेंट ऑफ़ सेडमिनिस्टेटी डिपार्टमेंट संपत्ति जिसकी तरफ ऐसी सभ्य का बात करूँ क्या मेरी बस ऐसी आम जी को रिपोर्टर नहीं वो उसको पेरिस बोल रहो है जी सर मैं उन्होंने निश्चित रिपोर्टर लेटीज अपनी शिकायत दर्ज किया था सत्ताईस अगस्त को दो हजार तेईस में क्या आपको जानकारी होगी इसके बारे में हाँ थोड़ा है बोले जी सर इसी का शिकायत संख्या इसकी शिकायत संख्या है जीरो वन सेमन डबल फाइव थ्री तरह आपको जानना च मिले इसके बारे में है थोड़ा है बोले जी सिर इसी का शिकायत संख्या है जीरो वन सेमन डबल सेम फाइव थ्री सिर आपको जानना चाहेंगे की आपको सदस्य का रिप्लाई जी जवाब डिपार्टमेंट के द्वारा ग्यारह सितम्बर को दिया गया है की आपने उसे चेक कर लिया और आप अनकोष ठारू रूपया इससे और कोई अवार्ड के लिए न मैडम इसलिए ये किया था जी जी जी सबको हम बता देते है की उन्होंने किस प्रकार में ये शिकायत दर्ज की तो या दर्ज दर्ज जी जी यह वार्ड के लिए किया था न जी जी बिल्कुल सर और वार्ड से रिलेटेड में की शिकायत थी हाँ हाँ जी जी सर तो आपको जिसका रिस्प्लाई गुजवाब आरोप डिपार्टमेंट के द्वारा ग्यारह सितम्बर को दिय

## Run it again!

Now let's step it up a notch. Let's try transcribing 30 minutes of audio from the LibriSpeech dataset. We'll first load up the second sample from our dataset, which corresponds to the 30 min audio file. We'll then pass the audio to the model for transcription, again timing how long the forward pass takes:

In [None]:
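from datasets import load_dataset

# assumption: the test-file dataset used in the original Whisper JAX notebook
test_dataset = load_dataset("sanchit-gandhi/whisper-jax-test-files", split="train")
audio = test_dataset[1]["audio"]  # load the second sample (30 mins) and get the audio array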

audio_length_in_mins = len(audio["array"]) / audio["sampling_rate"] / 60
print(f"Audio is {audio_length_in_mins} mins.")

In [None]:
# transcribe using cached function
%time text = pipeline(audio)

Just 35s to transcribe 30 mins of audio! That means you could transcribe an entire 2-hour movie in under 2.5 minutes 🤯 By increasing the batch size, we could reduce the transcription time for long audio files even further: doubling the batch size roughly halves the transcription time, provided the batch still holds fewer 30-second chunks than the audio file contains in total.
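The batch-size rule of thumb follows from counting batched forward passes. Assuming non-overlapping 30-second chunks (a simplification) and that each batched pass takes roughly the same time, 30 minutes of audio gives 60 chunks: 4 passes at batch size 16, 2 at batch size 32, and a single pass once the batch size reaches 60. Beyond that point, a larger batch brings no further speed-up. `num_forward_passes` is a hypothetical helper:

```python
import math

def num_forward_passes(audio_seconds, batch_size, chunk_seconds=30):
    """Number of batched forward passes needed for one audio file."""
    num_chunks = math.ceil(audio_seconds / chunk_seconds)
    return math.ceil(num_chunks / batch_size)

for bs in (16, 32, 64, 128):
    print(bs, num_forward_passes(30 * 60, bs))
# 16 4
# 32 2
# 64 1
# 128 1
```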

If you're fortunate enough to have access to a TPU v4, you'll find that transcription is a factor of 2 faster than on a v3. You can quickly see how we can get super fast transcription times using Whisper JAX on TPU!

## ⏰ Timestamps and more

We can also get timestamps from the model by passing `return_timestamps=True`, but this will require a recompilation since we change the signature of the forward pass. 

The timestamps compilation takes longer than the non-timestamps one. Luckily, because we initialised our compilation cache above, we're not starting from scratch in compiling this time. This is the last compilation we need to do!

In [None]:
# compile the forward call with timestamps - slow but we only do once
%time outputs = pipeline(audio, return_timestamps=True)
text = outputs["text"]  # transcription
chunks = outputs["chunks"]  # transcription + timestamps

In [None]:
# use cached timestamps function - super fast!
%time outputs = pipeline(audio, return_timestamps=True)
text = outputs["text"] 
chunks = outputs["chunks"]
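
The `chunks` output pairs each piece of text with a `(start, end)` timestamp in seconds. Assuming that structure (a list of dicts with `"timestamp"` and `"text"` keys), a short helper can turn it into an SRT subtitle file. The example chunks below are made up for illustration. The helpers `format_srt_time` and `chunks_to_srt` are hypothetical, and a missing end time (`None`, which may occur for the final segment) falls back to the start time.

```python
def format_srt_time(seconds):
    """Format seconds as an SRT timestamp, HH:MM:SS,mmm."""
    total_ms = int(round(seconds * 1000))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, ms = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"

def chunks_to_srt(chunks):
    """Convert [{"timestamp": (start, end), "text": ...}, ...] to SRT text."""
    lines = []
    for i, chunk in enumerate(chunks, start=1):
        start, end = chunk["timestamp"]
        if end is None:  # fall back if the end time is missing
            end = start
        lines.append(str(i))
        lines.append(f"{format_srt_time(start)} --> {format_srt_time(end)}")
        lines.append(chunk["text"].strip())
        lines.append("")  # a blank line separates SRT entries
    return "\n".join(lines)

# Made-up example chunks in the assumed output format.
example_chunks = [
    {"timestamp": (0.0, 4.5), "text": " Hello and welcome."},
    {"timestamp": (4.5, 62.25), "text": " This is a test."},
]
print(chunks_to_srt(example_chunks))
```

To save real results, write `chunks_to_srt(chunks)` to a `.srt` file opened with `encoding="utf-8"` so that Devanagari text is preserved.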

We've shown how you can transcribe an audio file. The pipeline also accepts two further arguments that control the generation process. You can omit them if you want speech transcription and want the Whisper model to detect the language of the audio automatically. Otherwise, set them according to your task and language:


* `task`: task to use for generation, either `"transcribe"` or `"translate"`. Defaults to `"transcribe"`.
* `language`: language token to use for generation, either in the form `"<|en|>"`, `"en"` or `"english"`. Defaults to `None`, meaning the language is automatically inferred from the audio input. Optional, and only relevant if the source audio language is known a priori.
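The three accepted spellings all name the same language. As a hypothetical illustration of how they could be normalised to one token (this is not the pipeline's own code, and the lookup table covers only a handful of languages), consider:

```python
# Hypothetical, partial lookup table; Whisper supports many more languages.
LANGUAGE_NAMES = {"english": "en", "hindi": "hi", "gujarati": "gu", "marathi": "mr"}

def to_language_token(language):
    """Normalise "<|hi|>", "hi" or "hindi" to the token "<|hi|>"."""
    if language is None:
        return None  # let the model detect the language itself
    lang = language.strip().lower()
    if lang.startswith("<|") and lang.endswith("|>"):
        lang = lang[2:-2]  # "<|hi|>" -> "hi"
    code = LANGUAGE_NAMES.get(lang, lang)  # "hindi" -> "hi"; codes map to themselves
    if code not in LANGUAGE_NAMES.values():
        raise ValueError(f"unsupported language: {language!r}")
    return f"<|{code}|>"

print(to_language_token("hindi"), to_language_token("gu"), to_language_token("<|en|>"))
```

Normalising like this makes it easy to spot mismatches, for example a Gujarati token (`"<|gu|>"`) passed to a Hindi checkpoint.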

## Transcribing the cpgrams audio dataset

In [None]:
import os
import json

import librosa
from tqdm import tqdm  # import the function, not the module, so tqdm(...) is callable


In [None]:
# preprocessing
# Function to preprocess audio file to 16kHz waveform
def preprocess_audio(audio_file):
    signal, _ = librosa.load(audio_file, sr=16000)
    return signal

In [None]:
files=os.listdir('/kaggle/input/cpgram-audios')
%cd /kaggle/input/cpgram-audios

In [None]:
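%%time
transcript = {}
for f in tqdm(files):
    # load each file as a 16 kHz waveform, then transcribe it
    waveform = preprocess_audio(os.path.join('/kaggle/input/cpgram-audios', f))
    transcript[f] = pipeline(waveform)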

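In [None]:
print(len(transcript))
# ensure_ascii=False keeps the Devanagari text human-readable in the JSON file
with open('/kaggle/working/transcript.json', "w", encoding="utf-8") as json_file:
    json.dump(transcript, json_file, ensure_ascii=False)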