In [None]:
"""

| Size   | Layers | Width | Heads | Parameters | English-only                                         | Multilingual                                      |
|--------|--------|-------|-------|------------|------------------------------------------------------|---------------------------------------------------|
| tiny   | 4      | 384   | 6     | 39 M       | [✓](https://huggingface.co/openai/whisper-tiny.en)   | [✓](https://huggingface.co/openai/whisper-tiny)   |
| base   | 6      | 512   | 8     | 74 M       | [✓](https://huggingface.co/openai/whisper-base.en)   | [✓](https://huggingface.co/openai/whisper-base)   |
| small  | 12     | 768   | 12    | 244 M      | [✓](https://huggingface.co/openai/whisper-small.en)  | [✓](https://huggingface.co/openai/whisper-small)  |
| medium | 24     | 1024  | 16    | 769 M      | [✓](https://huggingface.co/openai/whisper-medium.en) | [✓](https://huggingface.co/openai/whisper-medium) |
| large  | 32     | 1280  | 20    | 1550 M     | x                                                    | [✓](https://huggingface.co/openai/whisper-large)  |

"""

Imports

In [1]:
import os
from os.path import dirname, abspath
from pathlib import Path
from functools import partial
import logging
import argparse
from tqdm.auto import tqdm
import time, math
import shutil

import numpy as np
import jax
import jax.numpy as jnp
import optax
from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
import orbax
from flax.training import (
    train_state,
    orbax_utils
)
from flax.training.common_utils import (
    onehot,
    shard,
    shard_prng_key,
    get_metrics
)

import transformers
from transformers import (
    GenerationConfig,
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    FlaxWhisperForConditionalGeneration,
)
from transformers import (
    set_seed,
)
from transformers.utils import send_example_telemetry

import datasets
from datasets import (
    Dataset,
    load_dataset,
    DatasetDict,
    Audio
)

import evaluate

  from .autonotebook import tqdm as notebook_tqdm


Logger

In [2]:
# Module-level logger for this notebook.
logger = logging.getLogger(__name__)

# sending telemetry
# tracking the example usage helps us better allocate resources to maintain them
# the information sent is the one passed as arguments along with your Python/PyTorch versions
# NOTE(review): the example name "run_summarization" looks carried over from a
# summarization example, while this notebook evaluates a Whisper speech model —
# confirm whether a different label should be reported here.
send_example_telemetry("run_summarization", framework="flax")

Constants

In [None]:
# get root directory
# `root` is the nearest directory named 'speech-processing' at or above the
# starting location. `__file__` is not defined inside a notebook kernel, so the
# current working directory is used as the starting point in that case.
try:
    _search_start = Path(abspath(__file__)).parent
except NameError:
    _search_start = Path(abspath(os.getcwd()))

# Walk the starting directory and its ancestors. `Path.parents` is finite, so
# unlike a `dirname` loop this cannot spin forever at the filesystem root.
root = next(
    (
        str(candidate)
        for candidate in (_search_start, *_search_start.parents)
        if candidate.name == 'speech-processing'
    ),
    None,
)
if root is None:
    # Not run from within 'speech-processing': fall back to the starting
    # directory and say so; replace `root` with the correct path if needed.
    root = str(_search_start)
    logging.getLogger(__name__).warning(
        "No 'speech-processing' directory found at or above %s; using it as `root`.", root
    )

# constants
# maps the human-readable model language to the Whisper language token
LANG_TO_ID = {"hindi" : "<|hi|>"}

Arguments

In [4]:
# Seed shared by `set_seed`, the model loader (`seed=seed`) and the JAX PRNG key
# created in the "State and rng" cell.
seed = 42
# set seed
set_seed(seed)

# model
# checkpoint identifier passed to every `from_pretrained` call in this notebook
model_name_or_path = 'openai/whisper-small'
# must be a key of LANG_TO_ID; also passed as `language` to the tokenizer/processor
model_lang = 'hindi'
task = 'transcribe'
# name of a jax.numpy dtype; resolved with `getattr(jnp, dtype)` when the model is loaded
dtype = 'float32'  # float16
# used both as the tokenizer padding length for labels and as the generation `max_length`
generation_max_length = 225
per_device_eval_batch_size = 4
# global eval batch size = per-device batch size x number of visible JAX devices
eval_batch_size = int(per_device_eval_batch_size) * jax.device_count()
num_beams = 1
# forwarded to the eval loss; 0.0 disables label smoothing
label_smoothing_factor =0.0

# data
# dataset identifier and configuration name passed to `load_dataset`
data_dir = 'mozilla-foundation/common_voice_11_0'
data_lang = 'hi'
# not referenced by the cells visible in this notebook
max_train_samples = None
# when set, only the first `max_test_samples` test examples are evaluated
max_test_samples = None

# flags
# when True, the model cell calls `model.freeze_encoder()`
freeze_encoder = False

TrainState

In [5]:
class TrainState(train_state.TrainState):
    """Flax train state extended with the PRNG key used for dropout."""

    # PRNG key carried alongside the parameters and optimizer state
    dropout_rng: jnp.ndarray

    def replicate(self):
        """Return `jax_utils.replicate(self)` with `dropout_rng` replaced by `shard_prng_key(self.dropout_rng)`."""
        replicated_state = jax_utils.replicate(self)
        per_device_keys = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=per_device_keys)

Data Loader

In [6]:
def data_loader(rng: jax.random.PRNGKey, dataset: "Dataset", batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Yield batches of `batch_size` examples from `dataset` as dicts of numpy arrays.

    Args:
        rng: JAX PRNG key used to permute the example order; only read when `shuffle` is True.
        dataset: object supporting `len()` and indexing with an integer index array,
            returning a mapping from column name to values.
        batch_size: number of examples per batch; must be positive.
        shuffle: if True, visit the examples in a random order derived from `rng`.
        drop_last: if True, the trailing incomplete batch is skipped. If False, every
            batch is full except possibly the final one, which holds the remaining
            1 to `batch_size` examples.

    Yields:
        dict mapping each column name to an `np.ndarray` holding the batch values.

    Raises:
        ValueError: if `batch_size` is not positive (raised when iteration starts).
    """
    if batch_size <= 0:
        raise ValueError(f"`batch_size` must be a positive integer, got {batch_size}")

    num_samples = len(dataset)
    if shuffle:
        sample_idx = np.asarray(jax.random.permutation(rng, num_samples))
    else:
        sample_idx = np.arange(num_samples)

    # Index one past the last example that will be served.
    if drop_last:
        stop = (num_samples // batch_size) * batch_size  # Skip incomplete batch.
    else:
        stop = num_samples

    # Slicing in strides of `batch_size` keeps every batch full except the final
    # one, and yields nothing at all for an empty dataset.
    for start in range(0, stop, batch_size):
        idx = sample_idx[start:min(start + batch_size, stop)]
        batch = dataset[idx]
        yield {k: np.array(v) for k, v in batch.items()}


# Flax seq2seq models take `decoder_input_ids` rather than `labels`, so the
# decoder inputs are built here from the label ids.
# copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """
    Build decoder inputs by moving every row of `input_ids` one position to the right.

    Position 0 of each row becomes `decoder_start_token_id`, the last token of each
    row is dropped, and any remaining -100 entries are replaced by `pad_token_id`.
    """
    shifted = np.empty_like(input_ids)
    # every position is written below, so the uninitialised buffer is never read
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = input_ids[:, :-1]

    ignore_positions = shifted == -100
    return np.where(ignore_positions, pad_token_id, shifted)

Model, Feature extractor, Tokenizer, Processor

In [7]:
# Feature extractor and tokenizer are loaded from the same checkpoint as the model.
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=model_lang, task=task)

# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
tokenizer.set_prefix_tokens(language=model_lang, task=task)
# The processor is used later by `compute_metrics` to decode predictions and labels.
processor = WhisperProcessor.from_pretrained(model_name_or_path, language=model_lang, task=task)
    
# model
# FlaxWhisperForConditionalGeneration uses the FlaxWhisperPreTrainedModel forward method,
# overrides the __call__ special method
# FlaxWhisperForConditionalGeneration -> module_class = FlaxWhisperForConditionalGenerationModule
# FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel)
# FlaxWhisperPreTrainedModel -> module = self.module_class
# FlaxWhisperPreTrainedModel -> __call__ -> self.module.apply
# FlaxWhisperForConditionalGenerationModule -> __call__ -> self.model -> FlaxWhisperModule
# input_shape: typing.Tuple[int] = (b, 80, 3000)
# `dtype` is the string from the arguments cell, resolved to the matching jax.numpy dtype.
model = FlaxWhisperForConditionalGeneration.from_pretrained(
    model_name_or_path,
    seed=seed,
    dtype=getattr(jnp, dtype)
)

# `shift_tokens_right` needs a start token to build decoder inputs, so fail early without one.
if model.config.decoder_start_token_id is None:
    raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")  
# NOTE(review): this branch is skipped while `freeze_encoder` is False. `freeze_encoder()`
# and `model.model.encoder` look like the PyTorch Whisper API — confirm the Flax class
# exposes them before enabling the flag.
if freeze_encoder:
    model.freeze_encoder()
    model.model.encoder.gradient_checkpointing = False

Dataset

In [None]:
# Only the "test" split is needed for evaluation.
common_voice = DatasetDict(
    {"test": load_dataset(data_dir, data_lang, split="test", use_auth_token=True)}
)

# Drop the metadata columns that the preprocessing step does not read.
common_voice = common_voice.remove_columns(
    [
        "accent",
        "age",
        "client_id",
        "down_votes",
        "gender",
        "locale",
        "path",
        "segment",
        "up_votes",
    ]
)

# Optionally keep only the first `max_test_samples` examples for a quick run.
if max_test_samples is not None:
    common_voice["test"] = common_voice["test"].select(range(max_test_samples))

# Decode audio at the sampling rate the feature extractor expects.
common_voice = common_voice.cast_column(
    "audio",
    Audio(sampling_rate=feature_extractor.sampling_rate),
)

Preprocess dataset

In [None]:
# Length used both for padding the tokenized labels and for generation;
# falls back to the model config when `generation_max_length` is unset.
max_length = model.config.max_length if generation_max_length is None else generation_max_length

# function to vectorize dataset
# flax models need decoder_input_ids instead of labels
# we need fixed length inputs for jitted functions
# https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/feature_extraction_whisper.py#L254
#if return_attention_mask:
    # rescale from sample (48000) to feature (3000)
def prepare_dataset(batch):
    """
    Convert one raw example into the arrays consumed by the eval and generation steps.

    Reads `batch["audio"]` (a dict with "array" and "sampling_rate") and
    `batch["sentence"]`, and adds four entries to `batch`:
    "input_features", "labels", "decoder_input_ids" and "decoder_attention_mask".
    The three token-level entries are flattened to 1-D before being stored.

    Relies on the module-level `feature_extractor`, `tokenizer`, `model` and `max_length`.
    Returns the same `batch` mapping, mutated in place.
    """
    # audio was already cast to the feature extractor's sampling rate in the dataset cell
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    # 80 x 3000 (per the original author's note)
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids, padded up to `max_length`
    # NOTE(review): no truncation is requested, so a sentence longer than `max_length`
    # tokens would presumably yield a longer row — confirm this cannot occur.
    labels = tokenizer(
        batch["sentence"],
        padding="max_length",
        max_length=max_length,
        return_tensors="np"
    )

    # labels to compute loss
    # 1 x generation length or max length
    batch["labels"] = labels["input_ids"].flatten()
    # NOTE(review): if the tokenizer already places the decoder start token at the
    # front of `labels`, shifting right adds a second one — confirm against the
    # training pipeline before changing.
    decoder_input_ids = shift_tokens_right(
        labels["input_ids"], model.config.pad_token_id, model.config.decoder_start_token_id
    )
    # decoder_input_ids to feed into the flax model
    batch["decoder_input_ids"] = np.asarray(decoder_input_ids).flatten()

    # we need decoder_attention_mask so we can ignore pad tokens from loss
    # The mask comes from tokenizing the labels, so it lines up with `labels`,
    # not with the right-shifted `decoder_input_ids`.
    # Open questions from the original author, still unresolved:
    #   - completely masks decoder_input_ids
    #   - leaves first pad token (after input ids) unmasked in labels
    #   - need different mask for labels?
    batch["decoder_attention_mask"] = labels["attention_mask"].flatten()

    return batch

# vectorize dataset
# columns after mapping: input_features, decoder_input_ids, decoder_attention_mask, labels
# `common_voice` only holds the "test" split, so the raw columns to drop are
# looked up under "test"; indexing "train" raises a KeyError.
common_voice = common_voice.map(
    prepare_dataset,
    remove_columns=common_voice.column_names["test"],
    desc="vectorize dataset"
) #, num_proc=2)

# test dataset
test_dataset = common_voice["test"]

# eval steps in test_dataset, counting a trailing partial batch as one step
# (`data_loader` skips that partial batch when called with its default `drop_last=True`)
eval_steps = math.ceil(len(test_dataset) / eval_batch_size)

Metrics

In [None]:
# metrics: character error rate and word error rate
cer = evaluate.load("cer")
wer = evaluate.load("wer")

def compute_metrics(preds, labels):
    """
    Decode predicted and reference token ids to text and score them.

    Args:
        preds: predicted token id sequences.
        labels: reference token id sequences.

    Returns:
        dict with the keys "cer" and "wer".
    """
    # decoding options shared by predictions and references
    decode_options = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)

    predictions = processor.batch_decode(preds, **decode_options)
    references = processor.batch_decode(labels, group_tokens=False, **decode_options)

    return {
        "cer": cer.compute(predictions=predictions, references=references),
        "wer": wer.compute(predictions=predictions, references=references),
    }

Optimizer for state

In [None]:
# Optimizer hyperparameters.
# This cell used to read them from an `args` namespace that is never defined in
# the notebook. The visible cells only run evaluation and generation with
# `state.params` and never apply a gradient update, so the optimizer exists to
# satisfy `TrainState.create`; these are placeholder values — adjust as needed.
learning_rate = 5e-5
warmup_steps = 0
train_steps = 1000
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
weight_decay = 0.0

# create learning rate schedule: linear warmup from 0 to `learning_rate`,
# followed by linear decay back to 0 over the remaining steps
warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
)
decay_fn = optax.linear_schedule(
    init_value=learning_rate,
    end_value=0,
    transition_steps=train_steps - warmup_steps,
)
linear_decay_lr_schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
)


# we use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# the mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    """Return a pytree of booleans shaped like `params`; True marks parameters to decay."""
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters: collect the last two path components of
    # every parameter whose joined, lower-cased path contains a LayerNorm-like name
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    # decay everything except biases and the LayerNorm parameters found above
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)


# create adam optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)

State and rng

In [None]:
# Base PRNG key built from the global seed, then split so dropout gets its own key.
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)

# setup train state
# FlaxWhisperForConditionalGenerationModule -> __call__ -> self.model -> FlaxWhisperModule
# NOTE(review): `state` is created un-replicated here; `TrainState.replicate` is
# defined above but is not called in this cell.
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

Loss function

In [None]:
# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """
    Label-smoothed softmax cross entropy, summed over the unmasked positions.

    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

    Args:
        logits: model scores; the last axis is the vocabulary.
        labels: target token ids.
        padding_mask: multiplied element-wise with the per-position loss, so
            positions with value 0 do not contribute.
        label_smoothing_factor: probability mass moved from the target token to
            the other vocabulary entries.

    Returns:
        (loss, num_labels): the masked loss summed over all positions, and the
        sum of `padding_mask`, i.e. the number of positions that were counted.
    """
    num_classes = logits.shape[-1]
    on_value = 1.0 - label_smoothing_factor
    off_value = (1.0 - on_value) / (num_classes - 1)

    # Constant offset computed from the smoothed target distribution; it is
    # subtracted from every position's loss. The 1e-20 guards log(0) when no
    # smoothing is applied.
    offset = -(
        on_value * jnp.log(on_value) + (num_classes - 1) * off_value * jnp.log(off_value + 1e-20)
    )
    smoothed_targets = onehot(labels, num_classes, on_value=on_value, off_value=off_value)

    per_position_loss = optax.softmax_cross_entropy(logits, smoothed_targets) - offset

    # ignore padded tokens from loss
    summed_loss = (per_position_loss * padding_mask).sum()
    counted_positions = padding_mask.sum()

    return summed_loss, counted_positions

Eval step

In [None]:
# define eval fn
def eval_step(params, batch, label_smoothing_factor=0.0):
    """
    Compute the token-averaged loss for one batch on one device.

    Removes "labels" from `batch`, runs the model on the remaining entries in
    inference mode, and normalises the loss summed across the "batch" pmap axis
    by the mask total summed across the same axis.

    Returns:
        dict with the single key "loss".
    """
    targets = batch.pop("labels")
    outputs = model(**batch, params=params, train=False)
    logits = outputs[0]

    summed_loss, token_count = loss_fn(logits, targets, batch["decoder_attention_mask"], label_smoothing_factor)

    # aggregate over all devices taking part in the pmap
    token_count = jax.lax.psum(token_count, "batch")
    summed_loss = jax.lax.psum(summed_loss, "batch")

    # true loss = total loss / total samples
    mean_loss = jax.tree_util.tree_map(lambda total: total / token_count, summed_loss)

    return {"loss": mean_loss}


# parallel version with the notebook-level label smoothing factor bound in
p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")

Generation step

In [None]:
# generation functions

def make_generation_config(supress_en=False):
    """
    Build the generation config for `model_name_or_path` with a "language" entry added.

    Args:
        supress_en: if True, tokens that occur in both the English-only and the
            current tokenizer vocabularies and consist only of alphabetic
            characters are appended to `suppress_tokens`.

    Returns:
        A `GenerationConfig` rebuilt from the modified dictionary.
    """
    config_fields = GenerationConfig.from_pretrained(model_name_or_path).to_dict()

    # the stored generation config has no "language" entry, but generate() tries to use it
    config_fields["language"] = LANG_TO_ID[model_lang]

    if supress_en:
        # English tokens to suppress from the multilingual vocab
        en_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")  # change if loaded locally
        multilingual_vocab = tokenizer.encoder
        english_tokens = [
            token
            for token in en_tokenizer.encoder
            if token in multilingual_vocab and token.isalpha()
        ]
        config_fields['suppress_tokens'].extend(
            tokenizer.encode(english_tokens, add_special_tokens=False)
        )

    # reload with the new attributes
    return GenerationConfig.from_dict(config_fields)


# fall back to the model's configured beam count when none was chosen
if num_beams is None:
    num_beams = model.config.num_beams

# `max_length` comes from the preprocessing cell
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)

# generation config without English-token suppression
generation_config = make_generation_config(supress_en=False)

# batch -> input_features, decoder_input_ids, decoder_attention_mask, labels
def generate_step(params, batch):
    """
    Generate token ids for `batch["input_features"]` using `params`.

    Assigns `params` to the module-level `model` before generating and returns
    the `sequences` attribute of the generation output.
    """
    model.params = params
    language_token = LANG_TO_ID[model_lang]  # set lang here
    generated = model.generate(
        batch["input_features"],
        generation_config=generation_config,
        task=task,
        language=language_token,
        is_multilingual=True,
        **gen_kwargs
    )
    return generated.sequences


# parallel version of the generation step
p_generate_step = jax.pmap(generate_step, "batch")

Eval

In [None]:
eval_metrics = []
eval_preds = []
eval_labels = []
result_dict = {}

# `input_rng` was never defined in this notebook. Derive it from `rng` without
# reassigning `rng`, so re-running this cell yields the same example order.
input_rng = jax.random.fold_in(rng, 1)
eval_loader = data_loader(input_rng, test_dataset, eval_batch_size, shuffle=True)

# `data_loader` defaults to `drop_last=True`, so only full batches are served
# and the examples of a trailing partial batch are not evaluated.
num_eval_batches = len(test_dataset) // eval_batch_size
if num_eval_batches == 0:
    raise ValueError(
        f"test dataset has {len(test_dataset)} examples, fewer than one eval batch of {eval_batch_size}"
    )

# pmapped functions expect every argument to carry a leading device axis, so
# the un-replicated parameters are replicated once before the loop
eval_params = jax_utils.replicate(state.params)

# eval progress bar
eval_bar = tqdm(eval_loader, total=num_eval_batches, position=0)
for batch in eval_bar:
    # keep the un-sharded labels: after `shard` the leading axis is the device
    # axis, and indexing it with [0] would keep only the first device's labels
    labels = batch["labels"]  # b, seq_len

    batch = shard(batch)
    metrics = p_eval_step(eval_params, batch) # dict {'loss' : loss}
    eval_metrics.append(metrics)

    generated_ids = p_generate_step(eval_params, batch)
    # merge the device axis back into the batch axis; row order matches `labels`
    preds = jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))  # ndarray
    # labels padded to batch seq length, pred padded to max gen length
    eval_preds.extend(preds)  # b, gen_len
    eval_labels.extend(labels)  # b, seq_len

# eval metrics (loss)
eval_metrics = get_metrics(eval_metrics)  # dict
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)  # dict
# cer, wer
result = compute_metrics(eval_preds, eval_labels)
eval_metrics.update(result)

# collect results together
result_dict['eval_loss'] = eval_metrics['loss']
result_dict['cer'] = eval_metrics['cer']
result_dict['wer'] = eval_metrics['wer']

# write to terminal
for key, val in result_dict.items():
    print('{} : {}'.format(key, val))