### 1. Set the JAX platform (CPU/TPU) and matmul precision (if on TPU)

In [1]:
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
#os.environ["JAX_DEFAULT_MATMUL_PRECISION"]="float32"

### 2. Import libraries

In [2]:
import torch
import numpy as np
from transformers import Wav2Vec2Model, FlaxWav2Vec2Model, SpeechEncoderDecoderModel, FlaxSpeechEncoderDecoderModel
from flax.traverse_util import flatten_dict
import random
import tempfile

  from .autonotebook import tqdm as notebook_tqdm


### 3. Load pretrained 'tiny random' models

In [3]:
encoder_id = "hf-internal-testing/tiny-random-wav2vec2"
decoder_id = "hf-internal-testing/tiny-random-bart"

In [4]:
fx_enc_model = FlaxWav2Vec2Model.from_pretrained(encoder_id, from_pt=True)
fx_enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)

Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing FlaxWav2Vec2Model: {('project_hid', 'bias'), ('project_q', 'kernel'), ('quantizer', 'codevectors'), ('project_hid', 'kernel'), ('lm_head', 'kernel'), ('quantizer', 'weight_proj', 'kernel'), ('project_q', 'bias'), ('lm_head', 'bias'), ('quantizer', 'weight_proj', 'bias')}
- This IS expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxWav2Vec2Model were not initialized from the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 and are newly initial

In [5]:
pt_enc_model = Wav2Vec2Model.from_pretrained(encoder_id)
pt_enc_dec_model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id)

Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing Wav2Vec2Model: ['project_hid.weight', 'project_q.bias', 'project_q.weight', 'project_hid.bias', 'lm_head.bias', 'quantizer.weight_proj.bias', 'lm_head.weight', 'quantizer.codevectors', 'quantizer.weight_proj.weight']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 and are newly initialized: ['wav2vec2.feature_extractor.conv_layers.2.layer_norm.weight'

### 4. Check that the weights of the FlaxWav2Vec2 model match those of the FlaxSpeechEncoderDecoderModel

In [6]:
# It's easier to work with flattened dictionaries of parameters
fx_enc_params = flatten_dict(fx_enc_model.params)
fx_enc_dec_params = flatten_dict(fx_enc_dec_model.params['encoder'])  # require just the encoder parameters

# Check that all keys match
assert fx_enc_params.keys() == fx_enc_dec_params.keys()

# Check that all the weights are **precisely** equal - verifies that the encoder module is loaded correctly into the SpeechEncoderDecoder framework
for param in fx_enc_params:
    assert (fx_enc_params[param] == fx_enc_dec_params[param]).all(), param

### 5. Check that the weights of the SpeechEncoderDecoderModel match those of the FlaxSpeechEncoderDecoderModel

In [7]:
# Convert the PT model to FX 
with tempfile.TemporaryDirectory() as tmpdirname:
    pt_enc_dec_model.save_pretrained(tmpdirname)
    pt_enc_dec_model_to_fx = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
    
pt_params_to_fx = flatten_dict(pt_enc_dec_model_to_fx.params)
fx_enc_dec_params = flatten_dict(fx_enc_dec_model.params)  # need all the parameters this time, not just encoder

# Check that all keys match
assert fx_enc_dec_params.keys() == pt_params_to_fx.keys()

# Check that all the weights are **precisely** equal
for param in pt_params_to_fx:
    assert (fx_enc_dec_params[param] == pt_params_to_fx[param]).all(), param
    
# Free CPU memory 
del fx_enc_params, fx_enc_dec_params, pt_enc_dec_model_to_fx, pt_params_to_fx

Some weights of the model checkpoint at /tmp/tmpegm4xyev were not used when initializing FlaxSpeechEncoderDecoderModel: {('decoder', 'lm_head', 'kernel')}
- This IS expected if you are initializing FlaxSpeechEncoderDecoderModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxSpeechEncoderDecoderModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### 6. Create synthetic input data

In [8]:
def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    output = np.array(values).reshape(shape)

    return output


def random_attention_mask(shape, rng=None):
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # make sure that at least one token is attended to for each batch
    attn_mask[:, -1] = 1
    return attn_mask


def floats_tensor(shape, scale=1.0, rng=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.random() * scale)

    return np.array(values, dtype=np.float32).reshape(shape)


def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id

    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids

In [9]:
batch_size = 2
encoder_input_length = 5000
decoder_input_length = 16

input_ids = floats_tensor([batch_size, encoder_input_length])
attention_mask = random_attention_mask([batch_size, encoder_input_length])
label_ids = ids_tensor([batch_size, decoder_input_length], fx_enc_dec_model.config.decoder.vocab_size)
decoder_input_ids = shift_tokens_right(input_ids=label_ids, pad_token_id=fx_enc_dec_model.config.decoder.pad_token_id,
                                       decoder_start_token_id=fx_enc_dec_model.config.decoder.decoder_start_token_id or fx_enc_dec_model.config.decoder.bos_token_id)

fx_inputs = {
    "inputs": input_ids,
    "decoder_input_ids": decoder_input_ids
}

pt_inputs = {k: torch.tensor(v.tolist()) for k, v in fx_inputs.items()}
pt_inputs["labels"] = torch.tensor(label_ids.tolist())

### 7. FlaxWav2Vec2Model vs FlaxSpeechEncoderDecoderModel's encoder

In [10]:
# Compare the FlaxWav2Vec2Model outputs to those of the FlaxSpeechEncoderDecoderModel's encoder - they should be equal!
fx_enc_outputs = fx_enc_model(fx_inputs["inputs"], output_hidden_states=True)
fx_enc_dec_outputs = fx_enc_dec_model(**fx_inputs, output_hidden_states=True)

In [11]:
fx_enc_outputs.keys(), fx_enc_dec_outputs.keys()

(odict_keys(['last_hidden_state', 'extract_features', 'hidden_states']),
 odict_keys(['logits', 'decoder_hidden_states', 'encoder_last_hidden_state', 'encoder_hidden_states']))

In [12]:
# define a helper function for our analysis
def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 1e-2):
    diff = np.abs((a - b)).max()
    if diff < tol:
        print(f"✅ Difference between Flax and PyTorch is {diff} (< {tol})")
    else:
        print(f"❌ Difference between Flax and PyTorch is {diff} (>= {tol})")

In [13]:
print("--------------------------Checking encoder hidden states match--------------------------")
for fx_enc_dec_state, fx_enc_state in zip(fx_enc_dec_outputs.encoder_hidden_states, fx_enc_outputs.hidden_states):
    assert fx_enc_dec_state.shape == fx_enc_state.shape
    assert_almost_equals(fx_enc_dec_state, fx_enc_state)

print("--------------------------Checking encoder last hidden states match--------------------------")
print(f"Encoder-decoder output shape: {fx_enc_dec_outputs.encoder_last_hidden_state.shape}, encoder-only output shape: {fx_enc_outputs.last_hidden_state.shape}")
assert_almost_equals(fx_enc_dec_outputs.encoder_last_hidden_state, fx_enc_outputs.last_hidden_state)

--------------------------Checking encoder hidden states match--------------------------
❌ Difference between Flax and PyTorch is 0.408400297164917 (>= 0.01)
❌ Difference between Flax and PyTorch is 0.41299885511398315 (>= 0.01)
❌ Difference between Flax and PyTorch is 0.4098670184612274 (>= 0.01)
❌ Difference between Flax and PyTorch is 0.41270017623901367 (>= 0.01)
❌ Difference between Flax and PyTorch is 2.5930871963500977 (>= 0.01)
--------------------------Checking encoder last hidden states match--------------------------
Encoder-decoder output shape: (2, 76, 16), encoder-only output shape: (2, 76, 16)
❌ Difference between Flax and PyTorch is 2.5930871963500977 (>= 0.01)


#### Comments:
The above results clearly demonstrate that there is something wrong with how the encoder module is called in the FlaxSpeechEncoderDecoderModel. Let's probe a little deeper by comparing:

8. **PyTorch Encoder vs PyTorch Encoder-Decoder Outputs** - assert that the encoder is loaded correctly in PyTorch and highlight the results we expect to see in Flax ✅
9. **Flax Encoder vs PyTorch Encoder Outputs** - assert that the Wav2Vec2 Model is correctly implemented in Flax ✅
10. **Flax Encoder-Decoder vs PyTorch Encoder-Decoder Outputs** - sanity check to make sure the same results hold when we compare the same hidden-state values across PyTorch and Flax

### 8. Wav2Vec2Model vs SpeechEncoderDecoderModel's encoder module

In [14]:
# Compare the outputs of the PyTorch Wav2Vec2 encoder only model with the output of the PyTorch SpeechEncoderDecoder's encoder module - they should be equal!
pt_enc_outputs = pt_enc_model(pt_inputs["inputs"], output_hidden_states=True)
pt_enc_dec_outputs = pt_enc_dec_model(**pt_inputs, output_hidden_states=True)

In [15]:
print("--------------------------Checking encoder hidden states match--------------------------")
for pt_enc_dec_state, pt_enc_state in zip(pt_enc_dec_outputs.encoder_hidden_states, pt_enc_outputs.hidden_states):
    assert pt_enc_dec_state.shape == pt_enc_state.shape
    assert_almost_equals(pt_enc_dec_state.detach().numpy(), pt_enc_state.detach().numpy())

print("--------------------------Checking encoder last hidden states match--------------------------")
print(f"Encoder-decoder output shape: {pt_enc_dec_outputs.encoder_last_hidden_state.shape}, encoder-only output shape: {pt_enc_outputs.last_hidden_state.shape}")
assert_almost_equals(pt_enc_dec_outputs.encoder_last_hidden_state.detach().numpy(), pt_enc_outputs.last_hidden_state.detach().numpy())

--------------------------Checking encoder hidden states match--------------------------
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
--------------------------Checking encoder last hidden states match--------------------------
Encoder-decoder output shape: torch.Size([2, 76, 16]), encoder-only output shape: torch.Size([2, 76, 16])
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)


### 9. FlaxWav2Vec2Model vs Wav2Vec2Model

In [16]:
# Compare the outputs of the FlaxWav2Vec2Model with the PyTorch Wav2Vec2Model 

print("--------------------------Checking extract features match--------------------------")
assert_almost_equals(fx_enc_outputs.extract_features, pt_enc_outputs.extract_features.detach().numpy())

print("--------------------------Checking encoder hidden states match--------------------------")
for fx_enc_state, pt_enc_state in zip(fx_enc_outputs.hidden_states, pt_enc_outputs.hidden_states):
    assert fx_enc_state.shape == pt_enc_state.shape
    assert_almost_equals(fx_enc_state, pt_enc_state.detach().numpy())

print("--------------------------Checking encoder last hidden states match--------------------------")
print(f"Encoder-decoder output shape: {fx_enc_outputs.last_hidden_state.shape}, encoder-only output shape: {pt_enc_outputs.last_hidden_state.shape}")
assert_almost_equals(fx_enc_outputs.last_hidden_state, pt_enc_outputs.last_hidden_state.detach().numpy())

--------------------------Checking extract features match--------------------------
✅ Difference between Flax and PyTorch is 4.0531158447265625e-06 (< 0.01)
--------------------------Checking encoder hidden states match--------------------------
✅ Difference between Flax and PyTorch is 3.501772880554199e-07 (< 0.01)
✅ Difference between Flax and PyTorch is 3.501772880554199e-07 (< 0.01)
✅ Difference between Flax and PyTorch is 3.5390257835388184e-07 (< 0.01)
✅ Difference between Flax and PyTorch is 3.5390257835388184e-07 (< 0.01)
✅ Difference between Flax and PyTorch is 3.7550926208496094e-06 (< 0.01)
--------------------------Checking encoder last hidden states match--------------------------
Encoder-decoder output shape: (2, 76, 16), encoder-only output shape: torch.Size([2, 76, 16])
✅ Difference between Flax and PyTorch is 3.7550926208496094e-06 (< 0.01)


### 10. FlaxSpeechEncoderDecoderModel's encoder module vs SpeechEncoderDecoderModel's encoder module

In [17]:
# Compare the outputs of the FlaxSpeechEncoderDecoderModel's encoder outputs to that of the PyTorch SpeechEncoderDecoderModel's encoder

print("--------------------------Checking encoder hidden states match--------------------------")
for fx_enc_dec_state, pt_enc_state in zip(fx_enc_dec_outputs.encoder_hidden_states, pt_enc_dec_outputs.encoder_hidden_states):
    assert fx_enc_dec_state.shape == pt_enc_dec_state.shape
    assert_almost_equals(fx_enc_dec_state, pt_enc_dec_state.detach().numpy())

print("--------------------------Checking encoder last hidden states match--------------------------")
print(f"Encoder-decoder output shape: {fx_enc_dec_outputs.encoder_last_hidden_state.shape}, encoder-only output shape: {pt_enc_dec_outputs.encoder_last_hidden_state.shape}")
assert_almost_equals(fx_enc_dec_outputs.encoder_last_hidden_state, pt_enc_dec_outputs.encoder_last_hidden_state.detach().numpy())

--------------------------Checking encoder hidden states match--------------------------
❌ Difference between Flax and PyTorch is 2.5930871963500977 (>= 0.01)
❌ Difference between Flax and PyTorch is 2.5930871963500977 (>= 0.01)
❌ Difference between Flax and PyTorch is 2.5930871963500977 (>= 0.01)
❌ Difference between Flax and PyTorch is 2.5930871963500977 (>= 0.01)
❌ Difference between Flax and PyTorch is 2.5930871963500977 (>= 0.01)
--------------------------Checking encoder last hidden states match--------------------------
Encoder-decoder output shape: (2, 76, 16), encoder-only output shape: torch.Size([2, 76, 16])
❌ Difference between Flax and PyTorch is 2.5930871963500977 (>= 0.01)


### Conclusions

1. The Flax encoder module weights are loaded correctly into the speech encoder-decoder model framework.
2. The Flax speech encoder-decoder model weights are identical to those of their PyTorch counterpart.
3. The Flax encoder model outputs do not match those of the encoder-decoder model's encoder outputs for randomly distributed inputs - there is something wrong with the way the encoder module is implemented in the Flax speech encoder-decoder model framework!
4. The PyTorch encoder model outputs match those of the PyTorch encoder-decoder model's encoder outputs - the encoder module is correctly implemented in the PyTorch speech encoder-decoder model framework.
5. The Flax encoder model outputs match those of the PyTorch encoder model outputs - the Wav2Vec2 model is correctly implemented in Flax.
6. The Flax encoder-decoder outputs do not match those of the PyTorch encoder-decoder outputs - further evidence there is something wrong with the Flax encoder-decoder implementation