In [None]:
"""

| Size   | Layers | Width | Heads | Parameters | English-only                                         | Multilingual                                      |
|--------|--------|-------|-------|------------|------------------------------------------------------|---------------------------------------------------|
| tiny   | 4      | 384   | 6     | 39 M       | [✓](https://huggingface.co/openai/whisper-tiny.en)   | [✓](https://huggingface.co/openai/whisper-tiny)   |
| base   | 6      | 512   | 8     | 74 M       | [✓](https://huggingface.co/openai/whisper-base.en)   | [✓](https://huggingface.co/openai/whisper-base)   |
| small  | 12     | 768   | 12    | 244 M      | [✓](https://huggingface.co/openai/whisper-small.en)  | [✓](https://huggingface.co/openai/whisper-small)  |
| medium | 24     | 1024  | 16    | 769 M      | [✓](https://huggingface.co/openai/whisper-medium.en) | [✓](https://huggingface.co/openai/whisper-medium) |
| large  | 32     | 1280  | 20    | 1550 M     | x                                                    | [✓](https://huggingface.co/openai/whisper-large)  |

"""

Imports

In [1]:
import os
from os.path import dirname, abspath
from pathlib import Path
from functools import partial
import logging
import argparse
from tqdm.auto import tqdm
import time, math
import shutil

import numpy as np
import jax
import jax.numpy as jnp
import optax
from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
import orbax
from flax.training import (
    train_state,
    orbax_utils
)
from flax.training.common_utils import (
    onehot,
    shard,
    shard_prng_key,
    get_metrics
)

import transformers
from transformers import (
    GenerationConfig,
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    FlaxWhisperForConditionalGeneration,
)
from transformers import (
    set_seed,
    is_tensorboard_available,
)
from transformers.utils import send_example_telemetry

import datasets
from datasets import (
    Dataset,
    load_dataset,
    DatasetDict,
    Audio
)

import evaluate

  from .autonotebook import tqdm as notebook_tqdm


Logger

In [2]:
# module-level logger for this notebook
logger = logging.getLogger(__name__)

# sending telemetry
# tracking the example usage helps Hugging Face better allocate resources to maintain the examples
# the information sent is the example name and framework passed below, along with library versions
# this notebook fine-tunes Whisper for speech recognition with Flax, so it is reported under
# the seq2seq speech-recognition example name rather than the summarization one
send_example_telemetry("run_speech_recognition_seq2seq", framework="flax")

Constants

In [3]:
# get root directory
# only works if run from within 'speech-processing' directory
# else replace `root` with correct path
root = abspath(__file__)
while root.split('/')[-1] != 'speech-processing':
    root = dirname(root)

# constants
LANG_TO_ID = {"hindi" : "<|hi|>"}

Arguments

In [4]:
# --- reproducibility ---
seed = 42

# --- flags ---
# when True, the encoder-freezing branch of the model setup is executed
freeze_encoder = False

# --- data ---
# dataset identifier and language configuration passed to `load_dataset`
data_dir = 'mozilla-foundation/common_voice_11_0'
data_lang = 'hi'
# optional caps on the number of examples per split (None keeps the full split)
max_train_samples = None
max_test_samples = None

# --- model ---
# pretrained checkpoint, target language and task used for the tokenizer/processor
model_name_or_path = 'openai/whisper-small'
model_lang = 'hindi'
task = 'transcribe'
# name of the `jax.numpy` dtype used to load the model (alternative: 'float16')
dtype = 'float32'

TrainState

In [5]:
class TrainState(train_state.TrainState):
    """Train state that additionally carries the PRNG key used for dropout."""

    # PRNG key for dropout; sharded by `replicate`
    dropout_rng: jnp.ndarray

    def replicate(self):
        """Return a replicated copy of this state whose `dropout_rng` is sharded with `shard_prng_key`."""
        replicated_state = jax_utils.replicate(self)
        sharded_rng = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=sharded_rng)

Data Loader

In [6]:
def data_loader(rng: jax.random.PRNGKey, dataset: "Dataset", batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Yield batches of size `batch_size` from `dataset` as dicts of numpy arrays.

    Args:
        rng: PRNG key used to permute the sample order; only read when `shuffle` is `True`.
        dataset: object supporting `len()` and indexing with an array of indices, returning
            a mapping from column name to values.
        batch_size: number of samples per batch; must be positive.
        shuffle: if `True`, samples are visited in a random order derived from `rng`.
        drop_last: if `True`, a trailing incomplete batch is skipped. If `False`, every batch
            has exactly `batch_size` samples except the final one, which holds the remaining
            1 to `batch_size` samples.

    Yields:
        dict mapping each column name to a numpy array for the current batch.

    Raises:
        ValueError: if `batch_size` is not positive (raised when iteration starts).
    """
    if batch_size <= 0:
        raise ValueError(f"`batch_size` must be a positive integer, got {batch_size}")

    num_samples = len(dataset)
    if shuffle:
        sample_idx = np.asarray(jax.random.permutation(rng, num_samples))
    else:
        sample_idx = np.arange(num_samples)

    if drop_last:
        # Skip incomplete batch.
        num_samples = (num_samples // batch_size) * batch_size
        sample_idx = sample_idx[:num_samples]

    # Step through the indices in strides of `batch_size`, so only the last batch can be
    # short (an even split would shrink several batches) and an empty dataset yields nothing.
    for start in range(0, num_samples, batch_size):
        idx = sample_idx[start : start + batch_size]
        batch = dataset[idx]
        batch = {k: np.array(v) for k, v in batch.items()}

        yield batch


# Flax seq2seq models take `decoder_input_ids` rather than `labels`,
# so the decoder inputs have to be built here from the label ids.
# `shift_tokens_right` is copied from
# transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """
    Build decoder inputs by moving every token one position to the right.

    The first column becomes `decoder_start_token_id`, the last token of each row is
    dropped, and any `-100` entries in the result are replaced by `pad_token_id`.
    `input_ids` itself is left unmodified.
    """
    # every position is overwritten below, so the buffer needs no initial fill
    shifted = np.empty_like(input_ids)
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = input_ids[:, :-1]

    return np.where(shifted == -100, pad_token_id, shifted)

Model, Feature extractor, Tokenizer, Processor

In [7]:
# Load the feature extractor and tokenizer from the same pretrained checkpoint as the model.
# The tokenizer is configured with the target language and task.
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=model_lang, task=task)

# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
# NOTE(review): language/task are already passed to `from_pretrained` above, so this call
# looks redundant — confirm before removing.
tokenizer.set_prefix_tokens(language=model_lang, task=task)
# processor loaded from the same checkpoint with the same language/task settings
processor = WhisperProcessor.from_pretrained(model_name_or_path, language=model_lang, task=task)
    
# model
# call chain, for reference:
# FlaxWhisperForConditionalGeneration uses the FlaxWhisperPreTrainedModel forward method,
# overrides the __call__ special method
# FlaxWhisperForConditionalGeneration -> module_class = FlaxWhisperForConditionalGenerationModule
# FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel)
# FlaxWhisperPreTrainedModel -> module = self.module_class
# FlaxWhisperPreTrainedModel -> __call__ -> self.module.apply
# FlaxWhisperForConditionalGenerationModule -> __call__ -> self.model -> FlaxWhisperModule
# input_shape: typing.Tuple[int] = (b, 80, 3000)
# `dtype` is the string from the arguments cell, resolved to the matching `jax.numpy` dtype
model = FlaxWhisperForConditionalGeneration.from_pretrained(
    model_name_or_path,
    seed=seed,
    dtype=getattr(jnp, dtype)
)

# `shift_tokens_right` needs a decoder start token, so fail early if the config lacks one
if model.config.decoder_start_token_id is None:
    raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")  
# NOTE(review): `freeze_encoder()` and `model.model.encoder` look like the PyTorch Whisper API —
# confirm the Flax model exposes them before setting `freeze_encoder = True`.
if freeze_encoder:
    model.freeze_encoder()
    model.model.encoder.gradient_checkpointing = False

Dataset

In [None]:
# columns that are dropped from every split after loading
unused_columns = [
    "accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"
]

# the train and validation splits are merged into a single training set
common_voice = DatasetDict()
for split_name, split_spec in (("train", "train+validation"), ("test", "test")):
    common_voice[split_name] = load_dataset(data_dir, data_lang, split=split_spec, use_auth_token=True)

# remove unused columns
common_voice = common_voice.remove_columns(unused_columns)

# optionally keep only the first N examples of each split (small dataset for testing)
for split_name, sample_cap in (("train", max_train_samples), ("test", max_test_samples)):
    if sample_cap is not None:
        common_voice[split_name] = common_voice[split_name].select(range(sample_cap))

# cast the audio column to the sampling rate expected by the feature extractor (16kHz)
target_audio = Audio(sampling_rate=feature_extractor.sampling_rate)
common_voice = common_voice.cast_column("audio", target_audio)

Preprocess dataset

In [None]:
# tokenizer and generation max length
# no `args` object exists in this notebook, so the optional override is a plain variable:
# set `generation_max_length` to an int to override the model's configured max length
generation_max_length = None
max_length = (
    generation_max_length if generation_max_length is not None else model.config.max_length
)

# Vectorizes one dataset example.
# Flax models take decoder_input_ids instead of labels, and jitted functions
# need fixed-length inputs, hence the padding to `max_length`.
# Feature extractor reference:
# https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/feature_extraction_whisper.py#L254
#   if return_attention_mask:
#       rescale from sample (48000) to feature (3000)
def prepare_dataset(batch):
    """
    Add model inputs to a single example and return it.

    Reads `batch["audio"]` (with "array" and "sampling_rate") and `batch["sentence"]`, and
    writes four new entries: "input_features", "labels", "decoder_input_ids" and
    "decoder_attention_mask". The last three are flattened to 1-D.
    """
    # audio entry of the example
    waveform = batch["audio"]

    # log-Mel input features computed from the raw audio array (80 x 3000)
    extracted = feature_extractor(waveform["array"], sampling_rate=waveform["sampling_rate"])
    batch["input_features"] = extracted.input_features[0]

    # target text encoded to label ids, padded to `max_length`
    encoded = tokenizer(
        batch["sentence"],
        padding="max_length",
        max_length=max_length,
        return_tensors="np"
    )
    label_ids = encoded["input_ids"]

    # labels used to compute the loss (1 x generation length or max length before flattening)
    batch["labels"] = label_ids.flatten()

    # decoder_input_ids fed into the flax model: the labels shifted one token to the right
    shifted_ids = shift_tokens_right(
        label_ids, model.config.pad_token_id, model.config.decoder_start_token_id
    )
    batch["decoder_input_ids"] = np.asarray(shifted_ids).flatten()

    # decoder_attention_mask lets pad tokens be ignored in the loss
    # it completely masks the padding of decoder_input_ids, but
    # leaves the first pad token (after the input ids) unmasked in labels
    # open question: do the labels need a different mask?
    batch["decoder_attention_mask"] = encoded["attention_mask"].flatten()

    return batch

# vectorize dataset
# the original columns are removed, leaving only the ones written by `prepare_dataset`:
# input_features, decoder_input_ids, decoder_attention_mask, labels
original_columns = common_voice.column_names["train"]
common_voice = common_voice.map(
    prepare_dataset,
    remove_columns=original_columns,
    desc="vectorize dataset"
) #, num_proc=2)

# train and test datasets
train_dataset, test_dataset = common_voice["train"], common_voice["test"]