<a href="https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/Flax_vs_PyTorch_Whisper.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Flax Whisper vs PyTorch Whisper

In this notebook, we demonstrate how you can run inference with Whisper up to **10x faster** in Flax than PyTorch on a Colab GPU with largely the same code.

## Set up the environment

First of all, let's try to secure a decent GPU for our Colab! To get a GPU, click _Runtime_ -> _Change runtime type_, then change _Hardware accelerator_ from _None_ to _GPU_.

We can verify that we've been assigned a GPU and view its specifications:

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Thu Mar  2 11:01:21 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   58C    P0    27W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

Next, we install [Flax](https://flax.readthedocs.io/en/latest/) and [Datasets](https://github.com/huggingface/datasets) from PyPI, and [Transformers](https://github.com/huggingface/transformers) from main:

In [2]:
# Quote the version specifier so the shell does not treat ">" as an output redirect
!pip install --quiet flax "datasets>=2.6.1" git+https://github.com/huggingface/transformers

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.9.0 requires jedi>=0.10, which is not installed.

And finally, import all the required packages:

In [3]:
import jax
import jax.numpy as jnp
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import FlaxWhisperForConditionalGeneration, WhisperForConditionalGeneration, WhisperProcessor

## Benchmark Specifications

We'll evaluate both the PyTorch and Flax Whisper models using the same checkpoint, defined below. You can change this to a checkpoint of your choice, including larger pre-trained checkpoints or fine-tuned variants. Refer to the [HF Hub](https://huggingface.co/models?sort=downloads&search=whisper) for a full list of checkpoints:

In [4]:
model_id = "openai/whisper-tiny.en"  # change to a model checkpoint of your choice

## Prepare Dataset

We'll load a small dataset consisting of 73 samples from the [LibriSpeech ASR](https://huggingface.co/datasets/librispeech_asr) validation-clean split. This amounts to ~9MB of data, so it's very lightweight and quick to download on device:

In [5]:
librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")



Let's preprocess the data by computing the log-mel features. These features will be the same for our PyTorch model and Flax one, so we only have to do this once:

In [6]:
processor = WhisperProcessor.from_pretrained(model_id)

def preprocess(batch):
    batch["input_features"] = processor(batch["audio"]["array"], sampling_rate=16000, return_tensors="np").input_features[0]
    return batch

dataset_processed = librispeech.map(preprocess, remove_columns=librispeech.column_names)
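
Every sample ends up with the same feature shape, because Whisper pads or truncates the audio to a fixed window before computing the log-mel spectrogram. The arithmetic can be sketched as follows. The constants below are the usual Whisper feature-extractor settings for this checkpoint, written out by hand as assumptions rather than read from `processor`:

```python
# Assumed Whisper feature-extractor settings (hard-coded here, not read from `processor`)
SAMPLING_RATE = 16_000   # audio samples per second
CHUNK_LENGTH_S = 30      # audio is padded / truncated to 30 seconds
HOP_LENGTH = 160         # samples between successive spectrogram frames
NUM_MEL_BINS = 80        # mel filterbank channels

num_audio_samples = SAMPLING_RATE * CHUNK_LENGTH_S  # samples in one padded window
num_frames = num_audio_samples // HOP_LENGTH        # spectrogram frames per window
feature_shape = (NUM_MEL_BINS, num_frames)

print(f"Audio samples per window: {num_audio_samples}")
print(f"Log-mel feature shape per example: {feature_shape}")
```

A fixed input shape matters later on: a JIT-compiled function is specialised to its input shapes, so identical shapes mean the compiled code can be re-used for every sample.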



## PyTorch Benchmark

First, let's load the PyTorch model and move it to the GPU. We'll follow the official inference recommendations and evaluate in half (float16) precision:

In [7]:
model = WhisperForConditionalGeneration.from_pretrained(model_id)
model.to("cuda")
model.eval()
model.half();

Next, we define our dataloader. For this benchmark, we'll run inference one sample at a time (`batch_size=1`):

In [8]:
dataloader = DataLoader(dataset_processed.with_format("torch"), batch_size=1)

Finally, we perform inference over our dataset:

In [9]:
for batch in tqdm(dataloader):
    input_features = batch["input_features"].to("cuda").half()
    pred_ids = model.generate(input_features, max_length=128)

100%|██████████| 73/73 [00:20<00:00,  3.60it/s]


Depending on the GPU allocated to the Colab, you can expect inference to take ~20sec for 73 samples on a T4 or ~8sec on a V100.

## Flax Benchmark

We now repeat the same steps as in the PyTorch benchmark, this time in Flax. First, we load the model. JAX automatically places the model on the accelerator device, so we just need to specify the `dtype` as fp16. We set the flag `from_pt=True` to load the PyTorch weights and convert them to Flax:

In [10]:
model = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, dtype=jnp.float16, from_pt=True)
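
If you want to double-check which device JAX has picked up, a quick sanity check along these lines should do. It is only a sketch: on a GPU Colab we would expect a GPU backend, while on a machine without an accelerator JAX falls back to the CPU.

```python
import jax
import jax.numpy as jnp

# The backend JAX uses by default, and the devices it can see
backend = jax.default_backend()
devices = jax.devices()
print(f"Default backend: {backend}")
print(f"Visible devices: {devices}")

# Arrays are created on the default device without an explicit `.to(...)` call
x = jnp.ones((2, 2), dtype=jnp.float16)
print(f"Array dtype: {x.dtype}")
```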

Again, we define our dataloader for inference one sample at a time:

In [11]:
dataloader = DataLoader(dataset_processed.with_format("numpy"), batch_size=1)

We [JIT](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html#just-in-time-compilation-with-jax) the generate function so that we can compile it and re-use the cached kernels:

In [12]:
jit_generate = jax.jit(model.generate, static_argnames=["max_length"])

We run the compile step on the first batch of data. We only have to do this compilation once. Afterwards, we can re-use the kernels for fast inference:

In [13]:
batch = next(iter(dataloader))
input_features = jnp.array(batch["input_features"], dtype=jnp.float16)
pred_ids = jit_generate(input_features, max_length=128)
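
To get a feel for the compile-once, run-many pattern, here is a minimal sketch on a toy function rather than on Whisper. The function `scaled_sum` and its argument `num_steps` are hypothetical names made up for this illustration; `num_steps` plays the role of `max_length`, in that it has to be known at trace time and is therefore marked static. Since JAX dispatches work asynchronously, we call `block_until_ready()` before stopping the timer:

```python
import time

import jax
import jax.numpy as jnp


def scaled_sum(x, num_steps):
    # `num_steps` drives a Python loop, so it must be a static (trace-time) value
    total = jnp.zeros_like(x)
    for _ in range(num_steps):
        total = total + x
    return total


jit_scaled_sum = jax.jit(scaled_sum, static_argnames=["num_steps"])
x = jnp.ones((4,), dtype=jnp.float32)

# First call: traces and compiles the function, then runs it
start = time.perf_counter()
first = jit_scaled_sum(x, num_steps=3).block_until_ready()
first_call_time = time.perf_counter() - start

# Second call with the same shapes and static values: re-uses the compiled code
start = time.perf_counter()
second = jit_scaled_sum(x, num_steps=3).block_until_ready()
second_call_time = time.perf_counter() - start

print(f"First call (includes compilation): {first_call_time:.4f}s")
print(f"Second call (cached):              {second_call_time:.4f}s")
```

Calling the function with a different static value (say `num_steps=5`) or a different input shape would trigger a fresh compilation, which is why we keep `max_length` and the input shapes fixed in the benchmark.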

Finally, we perform inference over our dataset:

In [14]:
for batch in tqdm(dataloader):
    input_features = jnp.array(batch["input_features"], dtype=jnp.float16)
    pred_ids = jit_generate(input_features, max_length=128)

100%|██████████| 73/73 [00:02<00:00, 29.56it/s]


Depending on the GPU allocated to the Colab, you can expect inference to take ~2sec for 73 samples. That's roughly **10x faster** than PyTorch on a T4 ⚡️
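
For the specific run logged in this notebook, we can work out the speed-up directly from the two `tqdm` progress bars. The snippet below just plugs in the iteration rates printed above (3.60 it/s for PyTorch and 29.56 it/s for Flax); your own numbers will vary with the GPU you are allocated:

```python
# Figures copied from the tqdm outputs in this notebook
num_samples = 73
pytorch_rate = 3.60   # samples per second, PyTorch run
flax_rate = 29.56     # samples per second, Flax run

pytorch_time = num_samples / pytorch_rate  # total seconds for PyTorch
flax_time = num_samples / flax_rate        # total seconds for Flax
speedup = flax_rate / pytorch_rate         # ratio of throughputs

print(f"PyTorch: {pytorch_time:.1f}s, Flax: {flax_time:.1f}s, speed-up: {speedup:.1f}x")
```

For this particular run the ratio comes out a little above 8x, which is in line with the "up to 10x" figure quoted at the top of the notebook.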