### 1. Set the JAX platform (CPU/TPU) and matmul precision (if on TPU)

In [1]:
import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["JAX_DEFAULT_MATMUL_PRECISION"]="float32"

### 2. Import libraries

In [2]:
import datasets
import numpy as np
from datasets import DatasetDict, load_dataset
from dataclasses import dataclass
from transformers import AutoConfig, AutoModelForSpeechSeq2Seq, FlaxAutoModelForSpeechSeq2Seq, AutoFeatureExtractor, AutoTokenizer, AutoProcessor, FlaxSpeechEncoderDecoderModel, SpeechEncoderDecoderModel
import flax
import optax
import jax.numpy as jnp
import jax
from typing import Any, Callable, Dict, List, Optional, Union
from numpy.random import default_rng
import tempfile
from flax.traverse_util import flatten_dict
import torch
from flax.training.common_utils import onehot

  from .autonotebook import tqdm as notebook_tqdm


### 3. Set model,training and data args

In [3]:
# model args
encoder_id = "hf-internal-testing/tiny-random-wav2vec2"
decoder_id = "hf-internal-testing/tiny-random-bart"

encoder_id = "facebook/wav2vec2-large-lv60"
decoder_id = "facebook/bart-large"

# training args
batch_size_per_update = 2
gradient_accumulation_steps = 1

# data args
dataset_name = "librispeech_asr"
dataset_config_name = "clean"
train_split_name = "train.100[:5%]"
eval_split_name = "validation[:5%]"
dataset_cache_dir = "/home/sanchitgandhi/cache/huggingface/datasets"
audio_column_name = "audio"
text_column_name = "text"
do_lower_case = True

max_duration_in_seconds = 5
min_duration_in_seconds = 0
max_target_length = 32
min_target_length = 0
pad_input_to_multiple_of = 32000
pad_target_to_multiple_of = None
max_train_samples = max_eval_samples = None
preprocessing_num_workers = num_workers = 1

### 4. Load dataset

In [4]:
raw_datasets = DatasetDict()
raw_datasets["train"] = load_dataset(
            dataset_name,
            dataset_config_name,
            split=train_split_name,
            cache_dir=dataset_cache_dir,
        )

Reusing dataset librispeech_asr (/home/sanchitgandhi/cache/huggingface/datasets/librispeech_asr/clean/2.1.0/1f4602f6b5fed8d3ab3e3382783173f2e12d9877e98775e34d7780881175096c)


In [5]:
raw_datasets["eval"] = load_dataset(
            dataset_name,
            dataset_config_name,
            split=eval_split_name,
            cache_dir=dataset_cache_dir,
        )

Reusing dataset librispeech_asr (/home/sanchitgandhi/cache/huggingface/datasets/librispeech_asr/clean/2.1.0/1f4602f6b5fed8d3ab3e3382783173f2e12d9877e98775e34d7780881175096c)


### 5. Load pretrained model, tokenizer, and feature extractor

In [6]:
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
fx_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)

fx_model.config.decoder_start_token_id = fx_model.config.decoder.bos_token_id
fx_model.config.pad_token_id = fx_model.config.decoder.pad_token_id
fx_model.config.eos_token_id = fx_model.config.decoder.eos_token_id
fx_model.config.processor_class = "Wav2Vec2Processor"

pt_model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id)
pt_model.config.decoder_start_token_id = pt_model.config.decoder.bos_token_id
pt_model.config.pad_token_id = pt_model.config.decoder.pad_token_id
pt_model.config.eos_token_id = pt_model.config.decoder.eos_token_id

feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
processor = AutoProcessor.from_pretrained(encoder_id)
tokenizer = AutoTokenizer.from_pretrained(decoder_id)

if fx_model.config.decoder_start_token_id or pt_model.config.decoder_start_token_id is None:
    raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

Some weights of the model checkpoint at facebook/wav2vec2-large-lv60 were not used when initializing FlaxWav2Vec2Model: {('project_q', 'bias'), ('quantizer', 'weight_proj', 'bias'), ('project_hid', 'kernel'), ('quantizer', 'codevectors'), ('project_q', 'kernel'), ('quantizer', 'weight_proj', 'kernel'), ('project_hid', 'bias')}
- This IS expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/bart-large were not used when initializing FlaxBartForCausalLM: {('encoder', 'layers', '4', 'self_attn', 'out_proj', 'bias'), ('encoder', 'layers', '1', 

### 6. Check that the PT and FX weights are identical

In [7]:
# Convert the PT model to FX to enable comparison of param dicts (PT state dict -> FX param dict)
with tempfile.TemporaryDirectory() as tmpdirname:
    pt_model.save_pretrained(tmpdirname)
    pt_model_to_fx = FlaxAutoModelForSpeechSeq2Seq.from_pretrained(tmpdirname, from_pt=True)
    
# It's easier to view the PyTree param dict when flattened
fx_params = flatten_dict(fx_model.params)
pt_params_to_fx = flatten_dict(pt_model_to_fx.params)

# Check that the keys match
assert fx_params.keys() == pt_params_to_fx.keys()

# Check that the parameters are precisely equal
for param in fx_params:
    assert (fx_params[param] == pt_params_to_fx[param]).all(), f"{param} weights are not equal between Flax and PyTorch"
    
# Free CPU memory
del fx_params, pt_params_to_fx, pt_model_to_fx

Some weights of the model checkpoint at /tmp/tmp61gywpfv were not used when initializing FlaxSpeechEncoderDecoderModel: {('decoder', 'lm_head', 'kernel')}
- This IS expected if you are initializing FlaxSpeechEncoderDecoderModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxSpeechEncoderDecoderModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### 7. Resample speech dataset if necessary

In [8]:
# We use Torch audio in this resampling step for convinience
dataset_sampling_rate = next(iter(raw_datasets.values())).features[audio_column_name].sampling_rate
if dataset_sampling_rate != feature_extractor.sampling_rate:
    raw_datasets = raw_datasets.cast_column(
        audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
    )

### 8. Preprocessing the datasets

In [9]:
# Define some constants
max_input_length = int(max_duration_in_seconds * feature_extractor.sampling_rate)
min_input_length = int(min_duration_in_seconds * feature_extractor.sampling_rate)

model_input_name = feature_extractor.model_input_names[0]

In [10]:
# Truncate data to max_samples
if max_train_samples is not None:
        raw_datasets["train"] = raw_datasets["train"].select(range(max_train_samples))

if max_eval_samples is not None:
    raw_datasets["eval"] = raw_datasets["eval"].select(range(max_eval_samples))

In [11]:
def prepare_dataset(batch):
    # process audio
    sample = batch[audio_column_name]
    inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
    # process audio length
    batch[model_input_name] = inputs.input_values[0]
    batch["input_length"] = len(batch["input_values"])

    # process targets
    input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
    batch["labels"] = tokenizer(input_str).input_ids
    batch["labels_length"] = len(batch["labels"])
    return batch

In [12]:
vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            remove_columns=next(iter(raw_datasets.values())).column_names,
            num_proc=preprocessing_num_workers,
            desc="preprocess train dataset",
        )



In [13]:
# filter data with inputs shorter than min_input_length or longer than max_input_length
def is_audio_in_length_range(length):
    return length > min_input_length and length < max_input_length

vectorized_datasets = vectorized_datasets.filter(
    is_audio_in_length_range,
    num_proc=num_workers,
    input_columns=["input_length"],
)



In [14]:
# filter data with targets shorter than min_target_length or longer than max_target_length
def is_labels_in_length_range(length):
    return length > min_target_length and length < max_target_length

vectorized_datasets = vectorized_datasets.filter(
    is_labels_in_length_range,
    num_proc=num_workers,
    input_columns=["labels_length"],
)



### 9. Define DataCollators

In [15]:
# PyTorch DataCollator
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [16]:
# Flax DataCollator
@flax.struct.dataclass
class FlaxDataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int
    input_padding: Union[bool, str] = "longest"
    target_padding: Union[bool, str] = "max_length"
    max_input_length: Optional[float] = None
    max_target_length: Optional[int] = None
    pad_input_to_multiple_of: Optional[int] = None
    pad_target_to_multiple_of: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # reformat list to dict and set to pytorch format
        batch = self.processor.feature_extractor.pad(
            input_features,
            max_length=self.max_input_length,
            padding=self.input_padding,
            pad_to_multiple_of=self.pad_input_to_multiple_of,
            return_tensors="np",
        )

        labels_batch = self.processor.tokenizer.pad(
            label_features,
            max_length=self.max_target_length,
            padding=self.target_padding,
            pad_to_multiple_of=self.pad_target_to_multiple_of,
            return_tensors="np",
        )

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        labels = labels_batch["input_ids"]
        if (labels[:, 0] == self.decoder_start_token_id).all().item():
            labels = labels[:, 1:]
            labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]

        decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)

        # replace padding with -100 to ignore correctly when computing the loss
        labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
        labels = labels.filled(fill_value=-100)

        batch["inputs"] = batch.pop("input_values")
        batch["labels"] = labels
        batch["decoder_input_ids"] = decoder_input_ids

        return batch
    
def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
    """
    Shift label ids one token to the right.
    """
    shifted_label_ids = np.zeros_like(label_ids)
    shifted_label_ids[:, 1:] = label_ids[:, :-1]
    shifted_label_ids[:, 0] = decoder_start_token_id

    return shifted_label_ids

### 10. Define length grouped sampler (PT and FX compatible)

In [17]:
def get_grouped_indices(
    dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
) -> np.array:
    lengths = dataset["input_length"]

    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
    if mega_batch_mult is None:
        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
        # Just in case, for tiny datasets
        if mega_batch_mult == 0:
            mega_batch_mult = 1

    # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
    num_samples = len(lengths)
    indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)

    megabatch_size = mega_batch_mult * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]

    # The rest is to get the biggest batch first.
    # Since each megabatch is sorted by descending length, the longest element is the first
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
    max_idx = np.argmax(megabatch_maximums).item()
    # Switch to put the longest batch in first position
    # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
    megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]

    megabatches = np.array([i for megabatch in megabatches for i in megabatch])

    return megabatches

In [18]:
# Function to group samples into batch splits
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
    num_samples = len(samples_idx)
    samples_to_remove = num_samples % batch_size

    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    sections_split = num_samples // batch_size
    batch_idx = np.split(samples_idx, sections_split)
    return batch_idx

### 11. Helper funcitons for our analysis

In [19]:
def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 1e-2):
    """Assert whether the maximum absolute difference between two NumPy arrays a and b is within a given tolerance tol. 
    Due to the pad_to_multiple_of nature of the FlaxDataCollator, the length of the Flax array a will always be greater than 
    or equal to the length of the PyTorch array b. If a and b are of different lengths, array a (Flax, padded) will be 
    reshaped to the shape of b (PyTorch)."""
    if a.shape != b.shape:
        a = a[:, :b.shape[1]]
    
    diff = np.abs((a - b))
    if diff.max() < tol:
        print(f"✅ Difference between Flax and PyTorch is {diff.max()} (< {tol}), avg is {diff.mean()}")
    else:
        print(f"❌ Difference between Flax and PyTorch is {diff.max()} (>= {tol}), avg is {diff.mean()}")

In [20]:
def assert_dict_equal(a: dict, b: dict, tol: float = 1e-2):
    if a.keys() != b.keys():
        print("❌ Dictionary keys for PyTorch and Flax do not match")
    results_fail = []
    results_correct = []

    results_fail_rel = []
    results_correct_rel = []
    for k in a:
        ak_norm = np.linalg.norm(a[k])
        bk_norm = np.linalg.norm(b[k])
        diff = np.abs(ak_norm - bk_norm)
        diff_rel = np.abs(ak_norm - bk_norm) / np.abs(ak_norm)
        if diff < tol:
            results_correct.append(f"✅ Layer {k} diff is {diff} < {tol}).")
        else:
            results_fail.append(f"❌ Layer {k} has PT grad norm {bk_norm} and flax grad norm {ak_norm}.")
        if diff_rel < tol:
            results_correct_rel.append(f"✅ Layer {k} rel diff is {diff} < {tol}).")
        else:
            results_fail_rel.append(f"❌ Layer {k} has PT grad norm {bk_norm} and flax grad norm {ak_norm}.")
    return results_fail_rel, results_correct_rel, results_fail, results_correct

In [21]:
def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

def kl_divergence(a: np.ndarray, b:np.ndarray, epsilon=1e-6, tol: float = 1e-2):
    """Epsilon is used here to avoid conditional code for checking that neither p(a) nor p(b) is equal to 0."""
    if a.shape[1] != b.shape[1]:
        a = a[:, :b.shape[1], :]
        
    p_a = softmax(a) + epsilon
    p_b = softmax(b) + epsilon
    divergence = np.sum(p_b * np.log(p_b / p_a))
    if divergence < tol:
        print(f"✅ KL divergence between Flax and PyTorch is {divergence} (< {tol})")
    else:
        print(f"❌ KL divergence between Flax and PyTorch is {divergence} (>= {tol})")

### 12. Instantiate data collators and generate batches

In [22]:
# Instantiate the PT and FX DataCollators
pt_data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=pt_model.config.decoder_start_token_id,
    )

fx_data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=fx_model.config.decoder_start_token_id,
        input_padding="longest",
        target_padding="max_length",
        max_target_length=max_target_length,
        pad_input_to_multiple_of=pad_input_to_multiple_of,
        pad_target_to_multiple_of=pad_target_to_multiple_of,
    )

In [23]:
# Set JAX seed and generate PRNG for stochasic operations
seed = 0
rng = jax.random.PRNGKey(seed)
rng, input_rng = jax.random.split(rng)

In [24]:
# We'll naively create our batches through random shuffling and no grouping by length
num_train_samples = len(vectorized_datasets["train"])
train_samples_idx = jax.random.permutation(input_rng, np.arange(num_train_samples))
train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)

In [25]:
# Alt: we'll use the grouped sampler
train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size_per_update, input_rng)
train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)

In [26]:
# Treat the first training batch
batch_idx = train_batch_idx[1]
samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]

In [27]:
fx_batch = fx_data_collator(samples)
pt_batch = pt_data_collator(samples)

In [28]:
# Convert the Flax inputs to PyTorch (optional)
#pt_batch = {k: torch.tensor(v.tolist()) for k, v in fx_batch.items()}

### 13. Check that the inputs are equal

In [29]:
expected_fx_keys = ["inputs", "labels", "decoder_input_ids"]
expected_pt_keys = ["input_values", "labels"]

for expected_fx_key in expected_fx_keys:
    assert expected_fx_key in fx_batch, f"{expected_fx_key} not in Flax batched inputs"

for expected_pt_key in expected_pt_keys:
    assert expected_pt_key in pt_batch, f"{expected_pt_key} not in PyTorch batched inputs"    

# Expect the keys between Flax and PyTorch to be different, this is just for observation
fx_batch.keys(), pt_batch.keys()

(dict_keys(['attention_mask', 'inputs', 'labels', 'decoder_input_ids']),
 dict_keys(['input_values', 'attention_mask', 'labels']))

In [30]:
assert_almost_equals(fx_batch['inputs'], pt_batch['input_values'].numpy())
assert_almost_equals(fx_batch['labels'], pt_batch['labels'].numpy())
if 'attention_mask' in fx_batch.keys() and pt_batch.keys():
    assert_almost_equals(fx_batch['attention_mask'], pt_batch['attention_mask'].numpy())

✅ Difference between Flax and PyTorch is 0.0 (< 0.01), avg is 0.0
✅ Difference between Flax and PyTorch is 0 (< 0.01), avg is 0.0
✅ Difference between Flax and PyTorch is 0 (< 0.01), avg is 0.0


### 14. Run a training step

In [31]:
pt_outputs = pt_model(**pt_batch, output_hidden_states=True)
pt_logits = pt_outputs.logits
pt_loss = pt_outputs.loss
pt_loss.backward()

In [32]:
# Flax cross entropy loss
def loss_fn(logits, labels):
    vocab_size = logits.shape[-1]
    loss = optax.softmax_cross_entropy(logits, onehot(labels, vocab_size))
    # ignore padded tokens from loss, i.e. where labels are not set to -100
    padding = labels >= 0
    loss = loss * padding
    loss = loss.sum() / padding.sum()
    return loss

In [33]:
# Flax training step (single device)
def fx_train_step(fx_model, fx_batch, output_hidden_states=True):
    def compute_loss(params):
        labels = fx_batch.pop("labels")
        outputs = fx_model(**fx_batch, params=params, output_hidden_states=output_hidden_states)
        loss = loss_fn(outputs.logits, labels)
        return loss, outputs

    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, outputs), grad = grad_fn(fx_model.params)
    
    return loss, outputs, grad

In [34]:
fx_loss, fx_outputs, fx_grad = fx_train_step(fx_model, fx_batch, output_hidden_states=True)

### 15. Compare outputs for the forward pass

In [35]:
print("--------------------------Checking encoder hidden states match--------------------------")
for fx_state, pt_state in zip(fx_outputs.encoder_hidden_states, pt_outputs.encoder_hidden_states):
    assert_almost_equals(fx_state, pt_state.detach().numpy())

print("--------------------------Checking encoder last hidden states match--------------------------")
assert_almost_equals(fx_outputs.encoder_last_hidden_state, pt_outputs.encoder_last_hidden_state.detach().numpy())

print("--------------------------Checking decoder hidden states match--------------------------")
for fx_state, pt_state in zip(fx_outputs.decoder_hidden_states, pt_outputs.decoder_hidden_states):
    assert_almost_equals(fx_state, pt_state.detach().numpy())
    
print("--------------------------Checking logits match--------------------------")
print(f"Flax logits shape: {fx_outputs.logits.shape}, PyTorch logits shape: {pt_logits.shape}")
assert_almost_equals(fx_outputs.logits, pt_logits.detach().numpy())
kl_divergence(fx_outputs.logits, pt_logits.detach().numpy())

print("--------------------------Checking losses match--------------------------")
print(f"Flax loss: {fx_loss}, PyTorch loss: {pt_loss}")
assert_almost_equals(fx_loss, pt_loss.detach().numpy())

--------------------------Checking encoder hidden states match--------------------------
✅ Difference between Flax and PyTorch is 0.00087738037109375 (< 0.01), avg is 2.060942460957449e-05
✅ Difference between Flax and PyTorch is 0.0009002685546875 (< 0.01), avg is 3.303395351395011e-05
✅ Difference between Flax and PyTorch is 0.000911712646484375 (< 0.01), avg is 3.521700273267925e-05
✅ Difference between Flax and PyTorch is 0.000858306884765625 (< 0.01), avg is 3.61947895726189e-05
✅ Difference between Flax and PyTorch is 0.0008144378662109375 (< 0.01), avg is 3.6649136745836586e-05
✅ Difference between Flax and PyTorch is 0.00079345703125 (< 0.01), avg is 3.6703062505694106e-05
✅ Difference between Flax and PyTorch is 0.0012302398681640625 (< 0.01), avg is 3.7468114896910265e-05
❌ Difference between Flax and PyTorch is 0.0205078125 (>= 0.01), avg is 3.865711187245324e-05
❌ Difference between Flax and PyTorch is 0.021484375 (>= 0.01), avg is 3.895111512974836e-05
❌ Difference between

### 16. Compare outputs for the backward pass

In [36]:
pt_grad_dict = {k: v.grad if v.grad is not None else torch.zeros(v.shape) for k, v in pt_model.named_parameters()}
missing_grads = [k for k in pt_model.state_dict().keys() if k not in pt_grad_dict]

missing_keys, unexpected_keys = pt_model.load_state_dict(pt_grad_dict, strict=False)

In [37]:
assert missing_grads == missing_keys, f"Error with either grads {missing_keys} or keys {unexpected_keys}"

In [38]:
with tempfile.TemporaryDirectory() as tmpdirname:
    pt_model.save_pretrained(tmpdirname)
    pt_grad_model_to_fx = FlaxAutoModelForSpeechSeq2Seq.from_pretrained(tmpdirname, from_pt=True)

Some weights of the model checkpoint at /tmp/tmpttac__gx were not used when initializing FlaxSpeechEncoderDecoderModel: {('decoder', 'lm_head', 'kernel')}
- This IS expected if you are initializing FlaxSpeechEncoderDecoderModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxSpeechEncoderDecoderModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [39]:
pt_grad_to_fx = pt_grad_model_to_fx.params
fx_grad = flatten_dict(fx_grad)
pt_grad_to_fx = flatten_dict(pt_grad_to_fx)

In [40]:
results_fail_rel, results_correct_rel, results_fail, results_correct = assert_dict_equal(fx_grad, pt_grad_to_fx)

  diff_rel = np.abs(ak_norm - bk_norm) / np.abs(ak_norm)
  diff_rel = np.abs(ak_norm - bk_norm) / np.abs(ak_norm)


In [41]:
print("--------------------------Checking gradients match--------------------------")
if len(results_fail) == 0:
    print("✅ All grads pass")
else:
    print("\n".join(results_fail))

--------------------------Checking gradients match--------------------------
❌ Layer ('decoder', 'model', 'decoder', 'embed_tokens', 'embedding') has PT grad norm 39.86210250854492 and flax grad norm 39.875.
❌ Layer ('decoder', 'model', 'decoder', 'layers', '10', 'encoder_attn', 'out_proj', 'kernel') has PT grad norm 83.90470123291016 and flax grad norm 83.875.
❌ Layer ('decoder', 'model', 'decoder', 'layers', '10', 'encoder_attn', 'v_proj', 'kernel') has PT grad norm 58.8936653137207 and flax grad norm 58.875.
❌ Layer ('decoder', 'model', 'decoder', 'layers', '10', 'fc1', 'kernel') has PT grad norm 552.0279541015625 and flax grad norm inf.
❌ Layer ('decoder', 'model', 'decoder', 'layers', '10', 'fc2', 'kernel') has PT grad norm 93.37858581542969 and flax grad norm 93.0625.
❌ Layer ('decoder', 'model', 'decoder', 'layers', '10', 'final_layer_norm', 'scale') has PT grad norm 144.46214294433594 and flax grad norm 144.5.
❌ Layer ('decoder', 'model', 'decoder', 'layers', '10', 'self_attn',

In [42]:
print("--------------------------Checking rel gradients match--------------------------")

if len(results_fail_rel) == 0:
    print("✅ All rel grads pass")
else:
    print("\n".join(results_fail_rel))

--------------------------Checking rel gradients match--------------------------
❌ Layer ('decoder', 'model', 'decoder', 'embed_positions', 'embedding') has PT grad norm 0.0009617977775633335 and flax grad norm 0.0009765625.
❌ Layer ('decoder', 'model', 'decoder', 'layernorm_embedding', 'scale') has PT grad norm 0.0005227156216278672 and flax grad norm 0.0005459785461425781.
❌ Layer ('decoder', 'model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 3.740085324777098e-11 and flax grad norm 0.0.
❌ Layer ('decoder', 'model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'kernel') has PT grad norm 5.6070362916216254e-05 and flax grad norm 0.0.
❌ Layer ('decoder', 'model', 'decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'bias') has PT grad norm 0.0001576704962644726 and flax grad norm 0.0.
❌ Layer ('decoder', 'model', 'decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'bias') has PT grad norm 2.3031611817714293e-06 and flax grad norm 0.0.
❌ Layer 