In [None]:
"""

| Size   | Layers | Width | Heads | Parameters | English-only                                         | Multilingual                                      |
|--------|--------|-------|-------|------------|------------------------------------------------------|---------------------------------------------------|
| tiny   | 4      | 384   | 6     | 39 M       | [✓](https://huggingface.co/openai/whisper-tiny.en)   | [✓](https://huggingface.co/openai/whisper-tiny)   |
| base   | 6      | 512   | 8     | 74 M       | [✓](https://huggingface.co/openai/whisper-base.en)   | [✓](https://huggingface.co/openai/whisper-base)   |
| small  | 12     | 768   | 12    | 244 M      | [✓](https://huggingface.co/openai/whisper-small.en)  | [✓](https://huggingface.co/openai/whisper-small)  |
| medium | 24     | 1024  | 16    | 769 M      | [✓](https://huggingface.co/openai/whisper-medium.en) | [✓](https://huggingface.co/openai/whisper-medium) |
| large  | 32     | 1280  | 20    | 1550 M     | x                                                    | [✓](https://huggingface.co/openai/whisper-large)  |

"""

import os
from os.path import dirname, abspath
from pathlib import Path
from functools import partial
import logging
import argparse
from tqdm.auto import tqdm
import time, math
import shutil

import numpy as np
import jax
import jax.numpy as jnp
import optax
from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate, pad_shard_unpad
import orbax
from flax.training import (
    train_state,
    orbax_utils
)
from flax.training.common_utils import (
    onehot,
    shard,
    get_metrics
)

import transformers
from transformers import (
    GenerationConfig,
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    FlaxWhisperForConditionalGeneration,
)
from transformers import (
    set_seed,
    is_tensorboard_available,
)
from transformers.utils import send_example_telemetry

import datasets
from datasets import (
    Dataset,
    load_dataset,
    DatasetDict,
    Audio
)

import evaluate

from multiprocess import set_start_method
# presumably needed for the multi-process Dataset.map below (num_proc=num_workers);
# force=True lets this cell be re-run in the same kernel, where a second plain call
# raises "context has already been set"
set_start_method("spawn", force=True)


#jax.config.update('jax_array', False) -> only works below jax and jaxlib 0.4.6
logger = logging.getLogger(__name__)
# without any handler, logger.info(...) records are dropped (Python's last-resort
# handler only emits WARNING and above). basicConfig attaches a stream handler to the
# root logger (no-op if one exists); the root level is left at its default so other
# libraries' loggers are not made chattier.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)

# setup logging, we only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()
else:
    datasets.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()



# sending telemetry
# tracking the example usage helps us better allocate resources to maintain them
# the information sent is the one passed as arguments along with your Python/PyTorch versions
send_example_telemetry("run_summarization", framework="flax")


# get root directory
# `__file__` is undefined when this code runs as a notebook cell, so fall back to the
# current working directory, then walk up until the 'speech-processing' folder is found
try:
    root = abspath(__file__)
except NameError:
    root = os.getcwd()
while os.path.basename(root) != 'speech-processing':
    parent = dirname(root)
    if parent == root:
        # reached the filesystem root; without this check the loop would never terminate
        raise FileNotFoundError(
            "could not find a 'speech-processing' directory above the notebook / script location"
        )
    root = parent

# constants
# model language name -> Whisper language token
LANG_TO_ID = {"hindi" : "<|hi|>",
              "chinese" : "<|zh|>"}


seed = 42
# set seed
set_seed(seed)

# model
model_name_or_path = 'openai/whisper-tiny'
model_lang = 'hindi'
task = 'transcribe'
dtype = 'float32'  # float16
sampling_rate = 16000
generation_max_length = 225
per_device_train_batch_size = 4
per_device_eval_batch_size = 4
eval_batch_size = int(per_device_eval_batch_size) * jax.device_count()
num_beams = 1
label_smoothing_factor = 0.0
learning_rate = 1e-5
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-6
weight_decay = 0.0
train_steps = 2000
warmup_steps = 0
max_to_keep = 3

# flags
freeze_encoder = False

# data
data_dir = 'mozilla-foundation/common_voice_11_0'
data_lang = 'hi'
max_train_samples = 100  # None
max_test_samples = 20  # None
num_workers = os.cpu_count()

# output / checkpoint directory
model_str = model_name_or_path.split('/')[-1]
data_str = data_dir.split('/')[-1]
output_dir = root+'/models/whisper/'+model_str+'_jax_'+data_str

In [None]:
# simulate multi gpu

#import os
#flags = os.environ.get('XLA_FLAGS', '')
#os.environ['XLA_FLAGS'] = flags + " --xla_force_host_platform_device_count=4"

In [None]:

#class TrainState(train_state.TrainState):
    #dropout_rng: jnp.ndarray

    #def replicate(self):
        #return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng.reshape(-1)))
    

def data_loader(rng: jax.random.PRNGKey, dataset: "Dataset", batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Yield batches of size `batch_size` from `dataset` as dicts of numpy arrays.

    If `shuffle` is True the sample order is a permutation drawn with `rng` (otherwise
    `rng` is not used). With `drop_last=True` an incomplete trailing batch is skipped.
    With `drop_last=False` every batch holds exactly `batch_size` samples except
    possibly the last one, which holds the remaining 1..`batch_size` samples.

    Args:
        rng: JAX PRNG key, only used when `shuffle` is True.
        dataset: object indexable by an integer index array that returns a dict of
            columns (e.g. a `datasets.Dataset`).
        batch_size: number of samples per batch; must be positive.
        shuffle: shuffle the sample order before batching.
        drop_last: drop the final incomplete batch.

    Yields:
        dict mapping each column name to an `np.ndarray` whose first axis is the batch.

    Raises:
        ValueError: if `batch_size` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_samples = len(dataset)
    if num_samples == 0:
        return

    if shuffle:
        batch_idx = np.asarray(jax.random.permutation(rng, num_samples))
    else:
        batch_idx = np.arange(num_samples)

    if drop_last:
        steps_per_epoch = num_samples // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        # split at multiples of batch_size so only the final batch can be short;
        # np.array_split would instead spread the remainder (e.g. 10 / 4 -> 4, 3, 3)
        batch_idx = np.split(batch_idx, np.arange(batch_size, num_samples, batch_size))

    for idx in batch_idx:
        batch = dataset[idx]
        yield {k: np.array(v) for k, v in batch.items()}
    


# in Flax, for seq2seq models we need to pass `decoder_input_ids`
# as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
# `shift_tokens_right` function
# copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """
    Build decoder inputs by shifting `input_ids` one position to the right.

    Column 0 becomes `decoder_start_token_id`, column j takes `input_ids[:, j - 1]`
    and the last input column is dropped. Any -100 (ignore-index) entries left after
    the shift are replaced by `pad_token_id`. Returns a new array of the same dtype.
    """
    shifted = np.zeros_like(input_ids)
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = input_ids[:, :-1]
    # -100 marks ignored label positions; the decoder needs a real token id there
    shifted[shifted == -100] = pad_token_id
    return shifted


# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """
    Label-smoothed cross entropy, summed over the tokens selected by `padding_mask`.

    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

    Args:
        logits: model outputs whose last axis is the vocabulary.
        labels: target token ids (assumed to match `logits.shape[:-1]` — TODO confirm).
        padding_mask: 1 for positions that count towards the loss, 0 for padding.
        label_smoothing_factor: probability mass spread uniformly over non-target tokens.

    Returns:
        (summed loss, number of counted positions); callers divide the first by the second.
    """
    n_classes = logits.shape[-1]
    on_value = 1.0 - label_smoothing_factor
    off_value = (1.0 - on_value) / (n_classes - 1)
    # entropy of the smoothed target distribution; subtracting it makes the minimum loss 0
    entropy_offset = -(
        on_value * jnp.log(on_value) + (n_classes - 1) * off_value * jnp.log(off_value + 1e-20)
    )
    targets = onehot(labels, n_classes, on_value=on_value, off_value=off_value)

    per_token = optax.softmax_cross_entropy(logits, targets) - entropy_offset
    # ignore padded tokens from loss
    masked = per_token * padding_mask
    return masked.sum(), padding_mask.sum()

In [None]:
# extractor, tokenizer, processor
# NOTE: this notebook has no argparse namespace; all settings are plain module-level
# variables from the config cell (referencing `args.*` raised NameError)
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=model_lang, task=task)

# we only need to set the task id when the language is specified (i.e. in a multilingual setting)
tokenizer.set_prefix_tokens(language=model_lang, task=task)
processor = WhisperProcessor.from_pretrained(model_name_or_path, language=model_lang, task=task)


# model
# FlaxWhisperForConditionalGeneration uses the FlaxWhisperPreTrainedModel forward method,
# overrides the __call__ special method
# FlaxWhisperForConditionalGeneration -> module_class = FlaxWhisperForConditionalGenerationModule
# FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel)
# FlaxWhisperPreTrainedModel -> module = self.module_class
# FlaxWhisperPreTrainedModel -> __call__ -> self.module.apply
# FlaxWhisperForConditionalGenerationModule -> __call__ -> self.model -> FlaxWhisperModule
# input_shape: typing.Tuple[int] = (b, 80, 3000)
model = FlaxWhisperForConditionalGeneration.from_pretrained(
    model_name_or_path,
    seed=seed,
    dtype=getattr(jnp, dtype)
)

if model.config.decoder_start_token_id is None:
    raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
if freeze_encoder:
    # NOTE(review): freeze_encoder()/model.model.encoder follow the PyTorch Whisper API;
    # verify the Flax wrapper exposes them (Flax freezing is usually done with an
    # optax.multi_transform mask) before enabling this flag
    model.freeze_encoder()
    model.model.encoder.gradient_checkpointing = False


# dataset
# data_dir / data_lang come from the config cell (there is no `args` namespace here)
common_voice = DatasetDict()
common_voice["train"] = load_dataset(data_dir, data_lang, split="train+validation", use_auth_token=True)
common_voice["test"] = load_dataset(data_dir, data_lang, split="test", use_auth_token=True)

# remove unused columns
common_voice = common_voice.remove_columns(
    [
        "accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"
    ]
)

# select small dataset for testing
if max_train_samples is not None:
    common_voice["train"] = common_voice["train"].select(range(max_train_samples))

if max_test_samples is not None:
    common_voice["test"] = common_voice["test"].select(range(max_test_samples))

# resample to 16kHz
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=sampling_rate))


# tokenizer and generation max length
max_length = (
    generation_max_length if generation_max_length is not None else model.config.max_length
)

# function to vectorize dataset
# flax models need decoder_input_ids instead of labels
# we need fixed length inputs for jitted functions
# https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/feature_extraction_whisper.py#L254
def prepare_dataset(batch, rank):
    """Vectorize one Common Voice example for the Flax Whisper model.

    Adds:
        input_features: log-Mel features from `feature_extractor` (80 x 3000 per the
            shape note below — TODO confirm for other checkpoints).
        labels: token ids of `batch["sentence"]`, padded to `max_length`.
        decoder_input_ids: `labels` shifted right with `decoder_start_token_id` first.
        decoder_attention_mask: the tokenizer's attention mask for the padded labels
            (1 for real tokens, 0 for padding); train/eval steps pass it to `loss_fn`
            as the loss mask.

    Args:
        batch: one example with "audio" (decoded by the `Audio` feature) and "sentence".
        rank: worker index supplied by `Dataset.map(with_rank=True)`.

    Returns:
        `batch` with the fields above added.
    """
    # pin this worker to one device. `rank` may be None (e.g. when map runs without
    # worker processes) and `None % n` raises TypeError, so only pin when it is set.
    if rank is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(rank % jax.device_count())

    # audio was already decoded and resampled to `sampling_rate` by cast_column(Audio(...))
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    # 80 x 3000
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids, padded to a fixed length for the jitted steps
    labels = tokenizer(
        batch["sentence"],
        padding="max_length",
        max_length=max_length,
        return_tensors="np")

    # labels to compute loss
    # 1 x generation length or max length
    batch["labels"] = labels["input_ids"].flatten()
    decoder_input_ids = shift_tokens_right(
        labels["input_ids"], model.config.pad_token_id, model.config.decoder_start_token_id
    )
    # decoder_input_ids to feed into the flax model
    batch["decoder_input_ids"] = np.asarray(decoder_input_ids).flatten()

    # mask over the padded labels; used both as decoder attention mask and loss mask
    batch["decoder_attention_mask"] = labels["attention_mask"].flatten()

    return batch


# vectorize every split: adds input_features, decoder_input_ids,
# decoder_attention_mask and labels, and drops the raw columns
## multithreading errors ##
common_voice = common_voice.map(
    prepare_dataset,
    desc="vectorize dataset",
    remove_columns=common_voice.column_names["train"],
    with_rank=True,
    num_proc=num_workers,
)

# train and test datasets
train_dataset, test_dataset = common_voice["train"], common_voice["test"]


# metrics: character and word error rate
cer, wer = evaluate.load("cer"), evaluate.load("wer")

def compute_metrics(preds, labels):
    """Decode predicted and reference token ids and score them.

    Args:
        preds: generated token id sequences.
        labels: reference token id sequences.

    Returns:
        dict with "cer" and "wer" computed over the decoded strings.
    """
    shared_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    predictions = processor.batch_decode(preds, **shared_kwargs)
    references = processor.batch_decode(labels, group_tokens=False, **shared_kwargs)
    # compute cer, wer
    return {
        "cer": cer.compute(predictions=predictions, references=references),
        "wer": wer.compute(predictions=predictions, references=references),
    }
        

# enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
if has_tensorboard and jax.process_index() == 0:
    try:
        from flax.metrics.tensorboard import SummaryWriter
        # output_dir comes from the config cell (there is no `args` namespace here)
        summary_writer = SummaryWriter(log_dir=Path(output_dir))
    except ImportError as ie:
        has_tensorboard = False
        logger.warning(
            f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
        )
elif not has_tensorboard:
    # only report a missing package when it really is missing; non-master processes
    # that have tensorboard installed simply don't write summaries
    logger.warning(
        "Unable to display metrics through TensorBoard because the package is not installed: "
        "Please run pip install tensorboard to enable."
    )


# compute effective batch size (per-device size x number of devices)
train_batch_size = int(per_device_train_batch_size) * jax.device_count()
eval_batch_size = int(per_device_eval_batch_size) * jax.device_count()

# number of batches needed to cover the test set (the last one may be partial)
eval_steps = math.ceil(len(common_voice["test"]) / eval_batch_size)

In [None]:
# create learning rate schedule
# phase 1: linear warmup from 0 up to `learning_rate` over `warmup_steps` steps
warmup_fn = optax.linear_schedule(
    init_value=0.0,
    end_value=learning_rate,
    transition_steps=warmup_steps,
)
# phase 2: linear decay from `learning_rate` down to 0 over the remaining steps
decay_fn = optax.linear_schedule(
    init_value=learning_rate,
    end_value=0,
    transition_steps=train_steps - warmup_steps,
)
# switch from phase 1 to phase 2 at step `warmup_steps`
linear_decay_lr_schedule_fn = optax.join_schedules([warmup_fn, decay_fn], [warmup_steps])

# Optax "masking" lets us skip weight decay for bias and LayerNorm parameters.
def decay_mask_fn(params):
    """Return a boolean tree shaped like `params`: True where weight decay applies.

    A leaf is excluded (False) if it is named "bias" or if its last two path
    components match a parameter whose flattened path mentions "layernorm",
    "layer_norm" or "ln" (case-insensitive).
    """
    flat_params = traverse_util.flatten_dict(params)
    norm_markers = ("layernorm", "layer_norm", "ln")

    # (module name, param name) suffixes of every LayerNorm-like parameter
    norm_suffixes = set()
    for path in flat_params:
        joined = "".join(path).lower()
        if any(marker in joined for marker in norm_markers):
            norm_suffixes.add(path[-2:])

    flat_mask = {}
    for path in flat_params:
        flat_mask[path] = path[-1] != "bias" and path[-2:] not in norm_suffixes
    return traverse_util.unflatten_dict(flat_mask)
    
# AdamW driven by the warmup/decay schedule; weight decay is applied only to
# the leaves that decay_mask_fn marks True
adamw = optax.adamw(
    mask=decay_mask_fn,
    weight_decay=weight_decay,
    learning_rate=linear_decay_lr_schedule_fn,
    eps=adam_epsilon,
    b1=adam_beta1,
    b2=adam_beta2,
)

In [None]:
# setup train state
# FlaxWhisperForConditionalGenerationModule -> __call__ -> self.model -> FlaxWhisperModule
#state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
# bundles the model's forward function, the pretrained params and the AdamW optimizer;
# the dropout key is not stored in the state but passed to train_step explicitly
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
    

# define gradient update step fn
# batch -> input_features, decoder_input_ids, decoder_attention_mask, labels
# values cannot be printed inside a jit/pmap compiled function

# pmap -> replicate your model on devices, shard your data,
# and have each device compute its individual loss and gradients.
# here the per-device sums are combined with psum and then normalized

# use pjit for model sharding
def train_step(state, batch, dropout_rng, label_smoothing_factor=label_smoothing_factor):
    """One optimizer step on this device's shard; meant to run under `jax.pmap` with axis name "batch".

    Args:
        state: replicated `TrainState` (params + optimizer state).
        batch: this device's shard with "input_features", "decoder_input_ids",
            "decoder_attention_mask" and "labels" ("labels" is popped in `compute_loss`).
        dropout_rng: this device's PRNG key; split into the key used in this step and
            the key returned for the next step.
        label_smoothing_factor: forwarded to `loss_fn`.

    Returns:
        (new_state, metrics, new_dropout_rng). `metrics` holds the token-averaged
        "loss", the scheduled "learning_rate" at `state.step`, and the global "num_labels".
    """
    #dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
    # one key for dropout in this step, one handed back for the next step
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def compute_loss(params):
        # labels are not a model input, so remove them before apply_fn(**batch)
        labels = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        # decoder_attention_mask is the tokenizer's mask for the padded labels,
        # so it selects exactly the non-padding label positions for the loss
        loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
        return loss, num_labels

    # value_and_grad
    # creates a function that evaluates both fun and the gradient of fun
    # and returns them as a pair.
    # argnums -> which positional argument(s) to differentiate with respect to (default 0).
    # with has_aux=True, fun returns (value, aux) and the result is ((value, aux), gradient);
    # only the summed loss is differentiated, num_labels rides along as aux
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, num_labels), grad = grad_fn(state.params)
    num_labels = jax.lax.psum(num_labels, "batch")  # AllReduce

    # mean loss = total loss across devices / total counted tokens across devices
    loss = jax.lax.psum(loss, "batch")  # AllReduce
    loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

    # mean grad = gradient of the summed loss across devices / total counted tokens
    grad = jax.lax.psum(grad, "batch")  # AllReduce
    # divide replicated grad by num_labels (number of actual values)
    grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
    #new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
    new_state = state.apply_gradients(grads=grad)

    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step), "num_labels": num_labels}
    return new_state, metrics, new_dropout_rng


# define eval fn
def eval_step(params, batch, label_smoothing_factor=0.0):
    """Token-averaged eval loss for one shard; runs under `jax.pmap` with axis name "batch".

    Pops "labels" from `batch`, runs the model with `train=False` on the remaining keys
    and returns {"loss": loss summed over devices / label count summed over devices}.
    """
    labels = batch.pop("labels")
    logits = model(**batch, params=params, train=False)[0]
    loss_sum, label_count = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)

    # combine across devices so every replica reports the same global mean
    total_labels = jax.lax.psum(label_count, "batch")  # AllReduce
    total_loss = jax.lax.psum(loss_sum, "batch")  # AllReduce
    return {"loss": jax.tree_util.tree_map(lambda x: x / total_labels, total_loss)}
    

# generation functions

def make_generation_config(supress_en=False):
    """Load the model's GenerationConfig and add the settings this notebook needs.

    Always sets `language` to the Whisper language token for `model_lang`. When
    `supress_en` is True, it also appends to `suppress_tokens` the ids of every purely
    alphabetic token that the English-only tokenizer shares with the multilingual
    `tokenizer`. This discourages English output.

    Args:
        supress_en: whether to suppress English tokens during generation.

    Returns:
        a new `GenerationConfig` built from the modified dict.
    """
    generation_config = GenerationConfig.from_pretrained(model_name_or_path)
    gen_dict = generation_config.to_dict()
    # generation_config does not have "language", but generate() tries to use it
    gen_dict["language"] = LANG_TO_ID[model_lang]
    if supress_en:
        # en tokens to suppress from multilingual vocab
        en_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")  # change if loaded locally
        suppress_en_list = [
            token for token in en_tokenizer.encoder
            if token in tokenizer.encoder and token.isalpha()
        ]
        # these are vocabulary entries, not text: map them to ids directly
        suppress_en_ids = tokenizer.convert_tokens_to_ids(suppress_en_list)
        # `suppress_tokens` may be absent or None; keep existing ids and avoid duplicates
        existing = list(gen_dict.get("suppress_tokens") or [])
        already = set(existing)
        gen_dict["suppress_tokens"] = existing + [i for i in suppress_en_ids if i not in already]

    # reload with new attributes
    return GenerationConfig.from_dict(gen_dict)


# max_length is defined in the dataset cell; fall back to the model's beam count if unset
if num_beams is None:
    num_beams = model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
# generation config
generation_config = make_generation_config(supress_en=False)

# batch -> input_features, decoder_input_ids, decoder_attention_mask, labels
def generate_step(params, batch):
    """Generate token ids for one shard; runs under `jax.pmap` with axis name "batch".

    Args:
        params: model parameters on this device.
        batch: dict whose "input_features" are passed to `model.generate`.

    Returns:
        the `sequences` array produced by `model.generate`.
    """
    # pass params explicitly rather than assigning `model.params = params`: inside pmap
    # the params are tracers, and storing them on the global model leaks them out of the
    # traced function (later reads of `model.params` would hit leaked tracers)
    output_ids = model.generate(
        batch["input_features"],
        params=params,
        generation_config=generation_config,
        task=task,
        language=LANG_TO_ID[model_lang],  # set lang here
        is_multilingual=True,
        **gen_kwargs
    )
    return output_ids.sequences

# create parallel versions of the train, eval and generate steps
# applying pmap() to a function compiles it with XLA (similarly to jit()),
# then executes it in parallel on XLA devices, such as multiple GPUs or multiple TPU cores.
# it replicates the function and executes each replica on its own XLA device in parallel.
# "batch" is the axis name the steps use in jax.lax.psum.
# donate_argnums -> specify which positional argument buffers are "donated" to the computation.
# it is safe to donate argument buffers if you no longer need them once the computation has finished.
# you should not reuse buffers that you donate to a computation,
# jax will raise an error if you try to.
# donate_argnums only work for positional arguments; here the incoming train state (arg 0) is donated.
# NOTE: eval also receives the training label_smoothing_factor (both are 0.0 in this config)
p_train_step = jax.pmap(
    partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
)
p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
p_generate_step = jax.pmap(generate_step, "batch")

In [None]:
# optimizer steps taken so far and accumulated training wall-clock seconds
global_step, train_time = 0, 0

In [None]:
# init checkpointer
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=max_to_keep, create=True)
# checkpoint manager
checkpoint_manager = orbax.checkpoint.CheckpointManager(
    output_dir,
    orbax_checkpointer,
    options
)

# resume from the latest checkpoint if one exists, otherwise train from scratch.
# With create=True the manager creates `output_dir`, so testing for the directory was
# always true, and on a fresh run latest_step() is None; the old code then failed
# instead of starting training.
step = checkpoint_manager.latest_step()  # or choose step
if step is not None:
    print('checkpoints found')
    print('restoring step : {}'.format(step))

    # empty state and config to load state into
    empty_state = train_state.TrainState.create(
        apply_fn=model.__call__,
        params=jax.tree_util.tree_map(np.zeros_like, model.params),  # values of the tree leaf doesn't matter
        tx=adamw,
    )
    empty_config = model.config
    target = {'state': empty_state, 'config': empty_config}  # state or model -> automate maybe

    # restore
    restored = checkpoint_manager.restore(step, items=target)
    state = restored['state']
    global_step = step
else:
    print('no checkpoint found, training from scratch')


# write fixed hyperparameters to tensorboard (master process only)
if has_tensorboard and jax.process_index() == 0:
    for hparam_name, hparam_value in (
        ("train_batch_size", train_batch_size),
        ("eval_batch_size", eval_batch_size),
    ):
        summary_writer.scalar(hparam_name, hparam_value, global_step + 1)

run_summary = (
    "***** Running training *****",
    f"  Num examples = {len(common_voice['train'])}",
    f"  Num steps = {train_steps}",
    f"  Instantaneous batch size per device = {per_device_train_batch_size}",
    f"  Total train batch size (w. parallel & distributed) = {train_batch_size}",
)
for summary_line in run_summary:
    logger.info(summary_line)

# copy the train state to every device for the pmapped steps
state = jax_utils.replicate(state)

Training

In [None]:
# Training

# initialize training
rng = jax.random.PRNGKey(seed)
# one dropout key per device; p_train_step returns the next set every step
dropout_rngs = jax.random.split(rng, jax.device_count())  # jax.local_device_count()

# main progress bar
progress_bar = tqdm(range(global_step, train_steps), position=0)

train_start = time.time()
train_metrics = []


def _evaluate(params):
    """Run eval loss and generation over the whole test set.

    Args:
        params: replicated model parameters (e.g. `state.params`).

    Returns:
        (eval_metrics, eval_preds, eval_labels): the per-batch outputs of `p_eval_step`,
        the generated token ids, and the matching reference label ids.
    """
    eval_metrics, eval_preds, eval_labels = [], [], []
    # drop_last=False: pad_shard_unpad pads a short final batch, so every test sample is
    # scored (drop_last=True silently skipped it). Order does not affect the metrics,
    # so no shuffling (and no PRNG key) is needed.
    eval_loader = data_loader(None, test_dataset, eval_batch_size, shuffle=False, drop_last=False)

    # eval progress bar
    eval_bar = tqdm(range(eval_steps), position=1)
    for batch in eval_loader:
        labels = batch["labels"]
        metrics = pad_shard_unpad(p_eval_step, static_return=True)(
            params, batch, min_device_batch=per_device_eval_batch_size)
        eval_metrics.append(metrics)

        # generation
        generated_ids = pad_shard_unpad(p_generate_step)(params, batch)
        eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
        eval_labels.extend(labels)

        eval_bar.update(1)
    eval_bar.close()
    return eval_metrics, eval_preds, eval_labels


def train():
    """Train until `train_steps` optimizer steps have been taken.

    Every `eval_steps` steps it logs the latest train metric and evaluates on the test
    set. It then reports train/eval loss, cer and wer (terminal + tensorboard) and saves
    a checkpoint. It updates the notebook's module-level training state, so later cells
    see the trained `state`.

    Raises:
        ValueError: if `train_dataset` is too small to form a single batch.
    """
    # all of these are rebound below; without `global` Python treats them as locals of
    # train() and the first read (e.g. `jax.random.split(rng)`) raises UnboundLocalError
    global rng, state, dropout_rngs, global_step, train_time, train_start, train_metrics

    while global_step < train_steps:

        # train
        # create sampling rng
        rng, input_rng = jax.random.split(rng)
        train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)

        steps_this_epoch = 0
        for batch in train_loader:
            steps_this_epoch += 1
            # input_features : b x 80 x 3000
            # decoder_input_ids : b x max_length
            # decoder_attention_mask : b X max_length
            # labels : b X max_length

            # shard adds a leading device axis for pmap
            batch = shard(batch)
            state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(train_metric)

            progress_bar.update(1)

            # eval
            # eval_loss with eval_step
            # cer, wer with generate_step
            if eval_steps > 0 and (global_step + 1) % eval_steps == 0:
                # count only the time spent training since the last evaluation
                train_time += time.time() - train_start
                result_dict = {}

                train_metric = unreplicate(train_metric)

                progress_bar.write(
                    f"Step ({global_step + 1} | Loss: {train_metric['loss']}, Learning Rate:"
                    f" {train_metric['learning_rate']})"
                )

                eval_metrics, eval_preds, eval_labels = _evaluate(state.params)

                # train metrics (loss), averaged over every step since the last evaluation
                train_metrics = get_metrics(train_metrics)  # dict
                train_metrics = jax.tree_util.tree_map(jnp.mean, train_metrics)  # dict
                # eval metrics (loss)
                eval_metrics = get_metrics(eval_metrics)  # dict
                eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)  # dict
                # cer, wer
                result = compute_metrics(eval_preds, eval_labels)
                eval_metrics.update(result)

                # collect results together
                result_dict['train_time'] = train_time
                result_dict['train_loss'] = train_metrics['loss']
                result_dict['eval_loss'] = eval_metrics['loss']
                result_dict['cer'] = eval_metrics['cer']
                result_dict['wer'] = eval_metrics['wer']

                # write to terminal and tensorboard
                for key, val in result_dict.items():
                    print('{} : {}'.format(key, val))
                if has_tensorboard and jax.process_index() == 0:
                    for key, val in result_dict.items():
                        summary_writer.scalar(key, val, global_step + 1)

                # save the model, optimizer, lr_scheduler, and seed states
                ckpt = {'state': unreplicate(state), 'config': model.config}
                save_args = orbax_utils.save_args_from_target(ckpt)
                checkpoint_manager.save(global_step + 1, ckpt, save_kwargs={'save_args': save_args})

                # start a new accumulation window. Previously the metrics were cleared after
                # every step, so 'train_loss' only reflected the last step.
                train_metrics = []
                train_start = time.time()

            global_step += 1

            if global_step >= train_steps:
                return

        if steps_this_epoch == 0:
            # otherwise the while-loop would spin forever without taking a step
            raise ValueError(
                "train_dataset has fewer samples than train_batch_size; no batch can be formed"
            )

In [None]:
# run the training loop defined in the previous cell
train()

Eval

In [None]:
eval_metrics = []
eval_preds = []
eval_labels = []
result_dict = {}

rng, input_rng = jax.random.split(rng)

# drop_last=False so a final, possibly incomplete batch is evaluated too;
# pad_shard_unpad pads it before sharding across devices (drop_last=True skipped it)
eval_loader = data_loader(input_rng, test_dataset, eval_batch_size, shuffle=True, drop_last=False)

# eval progress bar
eval_bar = tqdm(range(eval_steps), position=0)
for batch in eval_loader:
    labels = batch["labels"]

    metrics = pad_shard_unpad(p_eval_step, static_return=True)(
        state.params, batch, min_device_batch=per_device_eval_batch_size)  # new_state
    eval_metrics.append(metrics)

    # generation
    generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)  # new_state
    eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
    eval_labels.extend(labels)

    eval_bar.update(1)
eval_bar.close()