### 1. Set the JAX platform (CPU/TPU) and matmul precision (if on TPU)

In [1]:
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
#os.environ["JAX_DEFAULT_MATMUL_PRECISION"]="float32"

### 2. Import libraries

In [2]:
import datasets
import numpy as np
from datasets import DatasetDict, load_dataset
from dataclasses import dataclass
from transformers import AutoConfig, AutoModelForSpeechSeq2Seq, FlaxAutoModelForSpeechSeq2Seq, AutoFeatureExtractor, AutoTokenizer, AutoProcessor, FlaxSpeechEncoderDecoderModel, SpeechEncoderDecoderModel
from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel as CustomFlaxSpeechEncoderDecoderModel
import flax
import optax
import jax.numpy as jnp
import jax
from typing import Any, Callable, Dict, List, Optional, Union
from numpy.random import default_rng
import tempfile
from flax.traverse_util import flatten_dict, unflatten_dict
import torch
from flax.training.common_utils import onehot

  from .autonotebook import tqdm as notebook_tqdm
I0000 00:00:1650545753.222788   84139 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.


### 3. Set model, training and data args

In [3]:
# model args
tiny = False

if tiny:
    encoder_id = "hf-internal-testing/tiny-random-wav2vec2"
    decoder_id = "hf-internal-testing/tiny-random-bart"

else:
    encoder_id = "facebook/wav2vec2-large-lv60"
    decoder_id = "patrickvonplaten/bart-large-fp32"
    
# training args
batch_size_per_update = 2
gradient_accumulation_steps = 1

# data args
dataset_name = "librispeech_asr"
dataset_config_name = "clean"
train_split_name = "train.100[:5%]"
eval_split_name = "validation[:5%]"
dataset_cache_dir = "/home/sanchitgandhi/cache/huggingface/datasets"
audio_column_name = "audio"
text_column_name = "text"
do_lower_case = True

max_duration_in_seconds = 5
min_duration_in_seconds = 0
max_target_length = 32
min_target_length = 0
pad_input_to_multiple_of = 32000
pad_target_to_multiple_of = None
max_train_samples = max_eval_samples = None
preprocessing_num_workers = num_workers = 1

### 4. Load dataset

In [4]:
raw_datasets = DatasetDict()
raw_datasets["train"] = load_dataset(
            dataset_name,
            dataset_config_name,
            split=train_split_name,
            cache_dir=dataset_cache_dir,
        )

Reusing dataset librispeech_asr (/home/sanchitgandhi/cache/huggingface/datasets/librispeech_asr/clean/2.1.0/1f4602f6b5fed8d3ab3e3382783173f2e12d9877e98775e34d7780881175096c)


In [5]:
raw_datasets["eval"] = load_dataset(
            dataset_name,
            dataset_config_name,
            split=eval_split_name,
            cache_dir=dataset_cache_dir,
        )

Reusing dataset librispeech_asr (/home/sanchitgandhi/cache/huggingface/datasets/librispeech_asr/clean/2.1.0/1f4602f6b5fed8d3ab3e3382783173f2e12d9877e98775e34d7780881175096c)


### 5. Load pretrained model, tokenizer, and feature extractor

In [6]:
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
fx_model = CustomFlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=tiny, decoder_from_pt=True)

fx_model.config.decoder_start_token_id = fx_model.config.decoder.bos_token_id
fx_model.config.pad_token_id = fx_model.config.decoder.pad_token_id
fx_model.config.eos_token_id = fx_model.config.decoder.eos_token_id
fx_model.config.processor_class = "Wav2Vec2Processor"

# check if generation works
fx_out = fx_model.generate(jnp.ones((1, 2000)))

encoder checkpointing: False
encoder scan: True


tcmalloc: large alloc 1269579776 bytes == 0x62c98000 @  0x7f76e53be680 0x7f76e53df824 0x5f8a01 0x648cf1 0x5c4676 0x4f290e 0x64f718 0x5048b3 0x56b1da 0x56939a 0x5f6a13 0x50aa2c 0x5f3547 0x56c8cd 0x56939a 0x50aaa0 0x56c28c 0x56939a 0x68d047 0x6003a4 0x5c4a40 0x56b0ae 0x5002d8 0x56cadf 0x5002d8 0x56cadf 0x5002d8 0x503fb6 0x56b1da 0x5f6836 0x56b0ae
Some weights of the model checkpoint at facebook/wav2vec2-large-lv60 were not used when initializing FlaxWav2Vec2Model: {('encoder', 'layers', '21', 'attention', 'out_proj', 'bias'), ('encoder', 'layers', '2', 'layer_norm', 'scale'), ('encoder', 'layers', '3', 'feed_forward', 'intermediate_dense', 'kernel'), ('encoder', 'layers', '16', 'feed_forward', 'intermediate_dense', 'bias'), ('encoder', 'layers', '20', 'attention', 'v_proj', 'bias'), ('encoder', 'layers', '8', 'attention', 'out_proj', 'bias'), ('encoder', 'layers', '8', 'feed_forward', 'intermediate_dense', 'bias'), ('encoder', 'layers', '7', 'attention', 'out_proj', 'kernel'), ('encoder'

decoder checkpointing: False
decoder scan: True


Some weights of the model checkpoint at patrickvonplaten/bart-large-fp32 were not used when initializing FlaxBartForCausalLM: {('decoder', 'layers', '3', 'self_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '4', 'fc1', 'kernel'), ('decoder', 'layers', '2', 'encoder_attn_layer_norm', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '2', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '9', 'self_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '8', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '7', 'fc2', 'bias'), ('decoder', 'layers', '4', 'encoder_attn', 'out_proj', 'bias'), ('encoder', 'layers', '9', 'final_layer_norm', 'bias'), ('encoder', 'layers', '10', 'self_attn_layer_norm', 'kernel'), ('decoder', 'layers', '5', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '6', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '11', 'encoder_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '6', 'fc2', 'bias'),

encoder checkpointing: False
encoder scan: True


2022-04-21 12:56:22.149184: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


decoder checkpointing: False
decoder scan: True
encoder checkpointing: False
encoder scan: True
decoder checkpointing: False
decoder scan: True
decoder checkpointing: False
decoder scan: True


In [7]:
pt_model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id)
pt_model.config.decoder_start_token_id = pt_model.config.decoder.bos_token_id
pt_model.config.pad_token_id = pt_model.config.decoder.pad_token_id
pt_model.config.eos_token_id = pt_model.config.decoder.eos_token_id
pt_model.config.processor_class = "Wav2Vec2Processor"

# check if generation works
pt_out = pt_model.generate(torch.ones((1, 2000)))

Some weights of the model checkpoint at facebook/wav2vec2-large-lv60 were not used when initializing Wav2Vec2Model: ['project_hid.weight', 'project_q.weight', 'project_q.bias', 'quantizer.codevectors', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'project_hid.bias']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at patrickvonplaten/bart-large-fp32 were not used when initializing BartForCausalLM: ['encoder.layers.6.self_attn.k_proj.bias', 'encoder.layers.0.self_attn.v_proj.bias', 'encoder.layers.5.final_layer_norm.bias', 'encod

In [8]:
feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
processor = AutoProcessor.from_pretrained(encoder_id)
tokenizer = AutoTokenizer.from_pretrained(decoder_id)

In [9]:
if fx_model.config.decoder_start_token_id or pt_model.config.decoder_start_token_id is None:
    raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

### 6. Convert unrolled weights to scanned

In [10]:
# Convert the PT model to FX to enable manipulation of param dicts (PT state dict -> FX param dict)
with tempfile.TemporaryDirectory() as tmpdirname:
    pt_model.save_pretrained(tmpdirname)
    pt_model_to_fx = FlaxAutoModelForSpeechSeq2Seq.from_pretrained(tmpdirname, from_pt=True)
    
def unrolled_to_scanned(model):
    params = model.params
    new_enc_params = {}
    # get the key of a scanned module
    for k in flatten_dict(params['encoder']['encoder']['layers']['0']):
        # stack the weights for each layer of the scanned module into one matrix
        new_enc_params[k] = jnp.stack([flatten_dict(params['encoder']['encoder']['layers'][str(i)])[k] for i in range(model.config.encoder.num_hidden_layers)])
    # append the correct prefix to the scanned modules' keys
    new_enc_params = unflatten_dict({('encoder', 'layers', 'FlaxWav2Vec2EncoderLayers'): unflatten_dict(new_enc_params)})
    
    # repeat for the decoder (note that the key 'layers' appears one index to the right than in the encoder, thus we'll treat the encoder and decoder independently for now)
    new_dec_params = {}
    for k in flatten_dict(params['decoder']['model']['decoder']['layers']['0']):
        new_dec_params[k] = jnp.stack([flatten_dict(params['decoder']['model']['decoder']['layers'][str(i)])[k] for i in range(model.config.decoder.decoder_layers)])
    new_dec_params = unflatten_dict({('model', 'decoder', 'layers', 'FlaxBartDecoderLayers'): unflatten_dict(new_dec_params)})
    
    # combine the encoder and decoder parameters
    new_params = {'encoder': new_enc_params, 'decoder': new_dec_params}
    new_params = flatten_dict(new_params)
    
    # append parameters for non-scanned modules (i.e. all modules that do not contain the key 'layers')
    for k in flatten_dict(params):
        if 'layers' not in k:
            new_params[k] = flatten_dict(params)[k]

    return unflatten_dict(new_params)

fx_model.params = unrolled_to_scanned(pt_model_to_fx)

Some weights of the model checkpoint at /tmp/tmpw4933bqz were not used when initializing FlaxSpeechEncoderDecoderModel: {('decoder', 'lm_head', 'kernel')}
- This IS expected if you are initializing FlaxSpeechEncoderDecoderModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxSpeechEncoderDecoderModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### 7. Resample speech dataset if necessary

In [11]:
# We use Torch audio in this resampling step for convinience
dataset_sampling_rate = next(iter(raw_datasets.values())).features[audio_column_name].sampling_rate
if dataset_sampling_rate != feature_extractor.sampling_rate:
    raw_datasets = raw_datasets.cast_column(
        audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
    )

### 8. Preprocessing the datasets

In [12]:
# Define some constants
max_input_length = int(max_duration_in_seconds * feature_extractor.sampling_rate)
min_input_length = int(min_duration_in_seconds * feature_extractor.sampling_rate)

model_input_name = feature_extractor.model_input_names[0]

In [13]:
# Truncate data to max_samples
if max_train_samples is not None:
        raw_datasets["train"] = raw_datasets["train"].select(range(max_train_samples))

if max_eval_samples is not None:
    raw_datasets["eval"] = raw_datasets["eval"].select(range(max_eval_samples))

In [14]:
def prepare_dataset(batch):
    # process audio
    sample = batch[audio_column_name]
    inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
    # process audio length
    batch[model_input_name] = inputs.input_values[0]
    batch["input_length"] = len(batch["input_values"])

    # process targets
    input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
    batch["labels"] = tokenizer(input_str).input_ids
    batch["labels_length"] = len(batch["labels"])
    return batch

In [15]:
vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            remove_columns=next(iter(raw_datasets.values())).column_names,
            num_proc=preprocessing_num_workers,
            desc="preprocess train dataset",
        )

preprocess train dataset: 100%|█████████████████████████████████████████████████████| 1427/1427 [00:16<00:00, 84.58ex/s]
preprocess train dataset: 100%|██████████████████████████████████████████████████████| 135/135 [00:00<00:00, 193.54ex/s]


In [16]:
# filter data with inputs shorter than min_input_length or longer than max_input_length
def is_audio_in_length_range(length):
    return length > min_input_length and length < max_input_length

vectorized_datasets = vectorized_datasets.filter(
    is_audio_in_length_range,
    num_proc=num_workers,
    input_columns=["input_length"],
)

100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 30.71ba/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 733.78ba/s]


In [17]:
# filter data with targets shorter than min_target_length or longer than max_target_length
def is_labels_in_length_range(length):
    return length > min_target_length and length < max_target_length

vectorized_datasets = vectorized_datasets.filter(
    is_labels_in_length_range,
    num_proc=num_workers,
    input_columns=["labels_length"],
)

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 299.72ba/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 563.83ba/s]


### 9. Define DataCollators

In [18]:
# PyTorch DataCollator
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [19]:
# Flax DataCollator
@flax.struct.dataclass
class FlaxDataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int
    input_padding: Union[bool, str] = "longest"
    target_padding: Union[bool, str] = "max_length"
    max_input_length: Optional[float] = None
    max_target_length: Optional[int] = None
    pad_input_to_multiple_of: Optional[int] = None
    pad_target_to_multiple_of: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # reformat list to dict and set to pytorch format
        batch = self.processor.feature_extractor.pad(
            input_features,
            max_length=self.max_input_length,
            padding=self.input_padding,
            pad_to_multiple_of=self.pad_input_to_multiple_of,
            return_tensors="np",
        )

        labels_batch = self.processor.tokenizer.pad(
            label_features,
            max_length=self.max_target_length,
            padding=self.target_padding,
            pad_to_multiple_of=self.pad_target_to_multiple_of,
            return_tensors="np",
        )

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        labels = labels_batch["input_ids"]
        if (labels[:, 0] == self.decoder_start_token_id).all().item():
            labels = labels[:, 1:]
            labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]

        decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)

        # replace padding with -100 to ignore correctly when computing the loss
        labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
        labels = labels.filled(fill_value=-100)

        batch["inputs"] = batch.pop("input_values")
        batch["labels"] = labels
        batch["decoder_input_ids"] = decoder_input_ids

        return batch
    
def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
    """
    Shift label ids one token to the right.
    """
    shifted_label_ids = np.zeros_like(label_ids)
    shifted_label_ids[:, 1:] = label_ids[:, :-1]
    shifted_label_ids[:, 0] = decoder_start_token_id

    return shifted_label_ids

### 10. Define length grouped sampler (PT and FX compatible)

In [20]:
def get_grouped_indices(
    dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
) -> np.array:
    lengths = dataset["input_length"]

    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
    if mega_batch_mult is None:
        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
        # Just in case, for tiny datasets
        if mega_batch_mult == 0:
            mega_batch_mult = 1

    # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
    num_samples = len(lengths)
    indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)

    megabatch_size = mega_batch_mult * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]

    # The rest is to get the biggest batch first.
    # Since each megabatch is sorted by descending length, the longest element is the first
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
    max_idx = np.argmax(megabatch_maximums).item()
    # Switch to put the longest batch in first position
    # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
    megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]

    megabatches = np.array([i for megabatch in megabatches for i in megabatch])

    return megabatches

In [21]:
# Function to group samples into batch splits
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
    num_samples = len(samples_idx)
    samples_to_remove = num_samples % batch_size

    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    sections_split = num_samples // batch_size
    batch_idx = np.split(samples_idx, sections_split)
    return batch_idx

### 11. Helper funcitons for our analysis

In [22]:
def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 1e-2):
    """Assert whether the maximum absolute difference between two NumPy arrays a and b is within a given tolerance tol. 
    Due to the pad_to_multiple_of nature of the FlaxDataCollator, the length of the Flax array a will always be greater than 
    or equal to the length of the PyTorch array b. If a and b are of different lengths, array a (Flax, padded) will be 
    reshaped to the shape of b (PyTorch)."""
    if a.shape != b.shape:
        a = a[:, :b.shape[1]]
    
    diff = np.abs((a - b)).max()
    if diff < tol:
        print(f"✅ Difference between Flax and PyTorch is {diff} (< {tol})")
    else:
        print(f"❌ Difference between Flax and PyTorch is {diff} (>= {tol}),")

In [23]:
def assert_dict_equal(a: dict, b: dict, tol: float = 1e-2):
    if a.keys() != b.keys():
        print("❌ Dictionary keys for PyTorch and Flax do not match")
    results_fail = []
    results_correct = []

    results_fail_rel = []
    results_correct_rel = []
    for k in a:
        ak_norm = np.linalg.norm(a[k])
        bk_norm = np.linalg.norm(b[k])
        diff = np.abs(ak_norm - bk_norm)
        diff_rel = np.abs(ak_norm - bk_norm) / np.abs(ak_norm)
        if diff < tol:
            results_correct.append(f"✅ Layer {k} diff is {diff} < {tol}).")
        else:
            results_fail.append(f"❌ Layer {k} has PT grad norm {bk_norm} and flax grad norm {ak_norm}.")
        if diff_rel < tol:
            results_correct_rel.append(f"✅ Layer {k} rel diff is {diff} < {tol}).")
        else:
            results_fail_rel.append(f"❌ Layer {k} has PT grad norm {bk_norm} and flax grad norm {ak_norm}.")
    return results_fail_rel, results_correct_rel, results_fail, results_correct

In [24]:
def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

def kl_divergence(a: np.ndarray, b:np.ndarray, epsilon=1e-6, tol: float = 1e-2):
    """Epsilon is used here to avoid conditional code for checking that neither p(a) nor p(b) is equal to 0."""
    if a.shape[1] != b.shape[1]:
        a = a[:, :b.shape[1], :]
        
    p_a = softmax(a) + epsilon
    p_b = softmax(b) + epsilon
    divergence = np.sum(p_b * np.log(p_b / p_a))
    if divergence < tol:
        print(f"✅ KL divergence between Flax and PyTorch is {divergence} (< {tol})")
    else:
        print(f"❌ KL divergence between Flax and PyTorch is {divergence} (>= {tol})")

### 12. Instantiate data collators and generate batches

In [25]:
# Instantiate the PT and FX DataCollators
pt_data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=pt_model.config.decoder_start_token_id,
    )

fx_data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=fx_model.config.decoder_start_token_id,
        input_padding="longest",
        target_padding="max_length",
        max_target_length=max_target_length,
        pad_input_to_multiple_of=pad_input_to_multiple_of,
        pad_target_to_multiple_of=pad_target_to_multiple_of,
    )

In [26]:
# Set JAX seed and generate PRNG for stochasic operations
seed = 0
rng = jax.random.PRNGKey(seed)
rng, input_rng = jax.random.split(rng)

In [27]:
# We'll naively create our batches through random shuffling and no grouping by length
num_train_samples = len(vectorized_datasets["train"])
train_samples_idx = jax.random.permutation(input_rng, np.arange(num_train_samples))
train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)

In [28]:
# Alt: we'll use the grouped sampler
train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size_per_update, input_rng)
train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)

In [39]:
# Treat the first training batch
batch_idx = train_batch_idx[0]
samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]

In [40]:
fx_batch = fx_data_collator(samples)
pt_batch = pt_data_collator(samples)

In [41]:
# Convert the Flax inputs to PyTorch (optional)
#pt_batch = {k: torch.tensor(v.tolist()) for k, v in fx_batch.items()}

### 13. Check that the inputs are equal

In [42]:
expected_fx_keys = ["inputs", "labels", "decoder_input_ids"]
expected_pt_keys = ["input_values", "labels"]

for expected_fx_key in expected_fx_keys:
    assert expected_fx_key in fx_batch, f"{expected_fx_key} not in Flax batched inputs"

for expected_pt_key in expected_pt_keys:
    assert expected_pt_key in pt_batch, f"{expected_pt_key} not in PyTorch batched inputs"    

# Expect the keys between Flax and PyTorch to be different, this is just for observation
fx_batch.keys(), pt_batch.keys()

(dict_keys(['attention_mask', 'inputs', 'labels', 'decoder_input_ids']),
 dict_keys(['input_values', 'attention_mask', 'labels']))

In [43]:
assert_almost_equals(fx_batch['inputs'], pt_batch['input_values'].numpy())
assert_almost_equals(fx_batch['labels'], pt_batch['labels'].numpy())
if 'attention_mask' in fx_batch.keys() and pt_batch.keys():
    assert_almost_equals(fx_batch['attention_mask'], pt_batch['attention_mask'].numpy())

✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
✅ Difference between Flax and PyTorch is 0 (< 0.01)
✅ Difference between Flax and PyTorch is 0 (< 0.01)


### 14. Run a training step

In [44]:
pt_outputs = pt_model(**pt_batch)
pt_logits = pt_outputs.logits
pt_loss = pt_outputs.loss
pt_loss.backward()

In [45]:
# Flax cross entropy loss
def loss_fn(logits, labels):
    vocab_size = logits.shape[-1]
    loss = optax.softmax_cross_entropy(logits, onehot(labels, vocab_size))
    # ignore padded tokens from loss, i.e. where labels are not set to -100
    padding = labels >= 0
    loss = loss * padding
    loss = loss.sum() / padding.sum()
    return loss

In [46]:
# Flax training step (single device)
def fx_train_step(fx_model, fx_batch):
    def compute_loss(params):
        labels = fx_batch.pop("labels")
        outputs = fx_model(**fx_batch, params=params)
        loss = loss_fn(outputs.logits, labels)
        return loss, outputs

    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, outputs), grad = grad_fn(fx_model.params)
    
    return loss, outputs, grad

In [47]:
fx_loss, fx_outputs, fx_grad = fx_train_step(fx_model, fx_batch)

encoder checkpointing: False
encoder scan: True
decoder checkpointing: False
decoder scan: True


### 15. Compare outputs for the forward pass

In [48]:
print("--------------------------Checking encoder last hidden states match--------------------------")
assert_almost_equals(fx_outputs.encoder_last_hidden_state, pt_outputs.encoder_last_hidden_state.detach().numpy())
    
print("--------------------------Checking logits match--------------------------")
print(f"Flax logits shape: {fx_outputs.logits.shape}, PyTorch logits shape: {pt_logits.shape}")
assert_almost_equals(fx_outputs.logits, pt_logits.detach().numpy())
kl_divergence(fx_outputs.logits, pt_logits.detach().numpy())

print("--------------------------Checking losses match--------------------------")
print(f"Flax loss: {fx_loss}, PyTorch loss: {pt_loss}")
assert_almost_equals(fx_loss, pt_loss.detach().numpy())

--------------------------Checking encoder last hidden states match--------------------------
✅ Difference between Flax and PyTorch is 0.00025206804275512695 (< 0.01)
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 31, 50265), PyTorch logits shape: torch.Size([2, 17, 50265])
✅ Difference between Flax and PyTorch is 0.00012969970703125 (< 0.01)
✅ KL divergence between Flax and PyTorch is 0.0002297908067703247 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 24.808961868286133, PyTorch loss: 24.80891227722168
✅ Difference between Flax and PyTorch is 4.9591064453125e-05 (< 0.01)


In [None]:
print("--------------------------Checking encoder last hidden states match--------------------------")
assert_almost_equals(fx_outputs.encoder_last_hidden_state, pt_outputs.encoder_last_hidden_state.detach().numpy())
    
print("--------------------------Checking logits match--------------------------")
print(f"Flax logits shape: {fx_outputs.logits.shape}, PyTorch logits shape: {pt_logits.shape}")
assert_almost_equals(fx_outputs.logits, pt_logits.detach().numpy())
kl_divergence(fx_outputs.logits, pt_logits.detach().numpy())

print("--------------------------Checking losses match--------------------------")
print(f"Flax loss: {fx_loss}, PyTorch loss: {pt_loss}")
assert_almost_equals(fx_loss, pt_loss.detach().numpy())

--------------------------Checking encoder last hidden states match--------------------------
✅ Difference between Flax and PyTorch is 0.0006010532379150391 (< 0.01)
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 31, 50265), PyTorch logits shape: torch.Size([2, 21, 50265])
✅ Difference between Flax and PyTorch is 4.9591064453125e-05 (< 0.01)
✅ KL divergence between Flax and PyTorch is 0.000494436826556921 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 25.75644302368164, PyTorch loss: 25.756454467773438
✅ Difference between Flax and PyTorch is 1.1444091796875e-05 (< 0.01)
