In [1]:
from transformers import FlaxSpeechEncoderDecoderModel
from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel as CustomFlaxSpeechEncoderDecoderModel
import numpy as np



In [2]:
# Hub ids of the small test checkpoints used for both the HF and the custom model.
encoder_id, decoder_id = (
    'hf-internal-testing/tiny-random-wav2vec2',  # encoder checkpoint
    'hf-internal-testing/tiny-random-bart',  # decoder checkpoint
)

In [3]:
# Build the reference (HF Transformers) model and the custom model from the same
# encoder/decoder checkpoints. `*_from_pt=True` presumably loads the PyTorch
# weights of each checkpoint — confirm against the Transformers docs.
hf_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_id,
    decoder_id,
    encoder_from_pt=True,
    decoder_from_pt=True,
)
custom_model = CustomFlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_id,
    decoder_id,
    encoder_from_pt=True,
    decoder_from_pt=True,
)

Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing FlaxWav2Vec2Model: {('quantizer', 'weight_proj', 'bias'), ('lm_head', 'bias'), ('project_q', 'bias'), ('quantizer', 'codevectors'), ('quantizer', 'weight_proj', 'kernel'), ('lm_head', 'kernel'), ('project_hid', 'bias'), ('project_q', 'kernel'), ('project_hid', 'kernel')}
- This IS expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxWav2Vec2Model were not initialized from the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 and are newly initial

In [4]:
# Create reproducible dummy data. A fixed seed means the encoder inputs (and every
# difference printed below) are identical on each Restart & Run All.
#   inputs:            (2, 2000) float64 standard-normal values, passed as encoder input
#   decoder_input_ids: (2, 50) token ids 0..99
SEED = 0
rng = np.random.default_rng(SEED)
inputs = rng.standard_normal((2, 2000))
decoder_input_ids = np.arange(100).reshape(2, 50)

In [5]:
# Reference ("ground-truth") forward pass through the HF Transformers 🤗 model.
# Hidden states are requested so they can be compared layer by layer below.
hf_outputs = hf_model(
    inputs,
    decoder_input_ids=decoder_input_ids,
    output_hidden_states=True,
)

In [6]:
# Run the custom model's `encode` on its own. The result is passed back into the
# full forward pass in the next cell as `extract_features`.
# NOTE(review): the exact contents of these "features" are defined by the custom
# `encode(output_features=True)` — confirm there.
extract_features = custom_model.encode(
    inputs,
    output_features=True,
)

In [7]:
# Full forward pass of the custom model. It reuses the features computed above
# and returns hidden states for the layer-by-layer comparison.
custom_outputs = custom_model(
    inputs,
    decoder_input_ids=decoder_input_ids,
    extract_features=extract_features,
    output_hidden_states=True,
)

In [8]:
# define a helper function for our analysis
def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 1e-9, raise_on_mismatch: bool = False) -> None:
    """Print whether two arrays agree element-wise to within ``tol``.

    Despite the name, by default this only *reports* (✅/❌), so that every
    comparison in a cell is shown. Pass ``raise_on_mismatch=True`` to make a
    mismatch stop execution.

    Args:
        a: Reference array (e.g. the HF Transformers output). Anything accepted
            by ``np.asarray`` works.
        b: Array to compare against ``a`` (e.g. the custom-model output).
        tol: Largest allowed absolute element-wise difference (inclusive).
        raise_on_mismatch: If True, raise instead of only printing on failure.

    Raises:
        AssertionError: only when ``raise_on_mismatch`` is True and either the
            shapes differ or the max difference exceeds ``tol`` (or is NaN).
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        # Without this guard NumPy would broadcast e.g. (2, 1) against (2, 5)
        # and silently compare the wrong elements.
        message = f"❌ Shape mismatch between HF and custom model: {a.shape} vs {b.shape}"
        ok = False
    else:
        # An empty comparison has no differing elements; `.max()` would raise on it.
        diff = float(np.abs(a - b).max()) if a.size else 0.0
        ok = diff <= tol  # NaN compares False, so it is reported as a failure
        if ok:
            message = f"✅ Difference between HF and custom model is {diff} (<= {tol})"
        else:
            message = f"❌ Difference between HF and custom model is {diff} (> {tol})"
    print(message)
    if raise_on_mismatch and not ok:
        raise AssertionError(message)

In [9]:
def _check_hidden_states(hf_states, custom_states) -> None:
    """Compare the per-layer hidden states of the HF and custom models.

    Unlike a bare ``zip``, this first checks that both models returned the same
    number of hidden states, so a missing layer cannot make the check pass
    vacuously. Each pair must also have identical shapes before its values are
    compared with ``assert_almost_equals``.

    Raises:
        AssertionError: if hidden states are missing, differ in count, or any
            pair differs in shape.
    """
    assert hf_states is not None and custom_states is not None, "hidden states were not returned"
    assert len(hf_states) == len(custom_states), (
        f"HF returned {len(hf_states)} hidden states, custom returned {len(custom_states)}"
    )
    for idx, (hf_state, custom_state) in enumerate(zip(hf_states, custom_states)):
        assert hf_state.shape == custom_state.shape, (
            f"hidden state {idx}: HF shape {hf_state.shape} != custom shape {custom_state.shape}"
        )
        assert_almost_equals(hf_state, custom_state)


print("--------------------------Checking encoder hidden states match--------------------------")
_check_hidden_states(hf_outputs.encoder_hidden_states, custom_outputs.encoder_hidden_states)

print("--------------------------Checking encoder last hidden states match--------------------------")
print(f"HF output shape: {hf_outputs.encoder_last_hidden_state.shape}, custom output shape: {custom_outputs.encoder_last_hidden_state.shape}")
# Assert (not just print) the shapes, consistent with the per-layer checks above.
assert hf_outputs.encoder_last_hidden_state.shape == custom_outputs.encoder_last_hidden_state.shape
assert_almost_equals(hf_outputs.encoder_last_hidden_state, custom_outputs.encoder_last_hidden_state)

print("--------------------------Checking decoder hidden states match--------------------------")
_check_hidden_states(hf_outputs.decoder_hidden_states, custom_outputs.decoder_hidden_states)

print("--------------------------Checking logits match--------------------------")
print(f"HF logits shape: {hf_outputs.logits.shape}, Custom logits shape: {custom_outputs.logits.shape}")
assert hf_outputs.logits.shape == custom_outputs.logits.shape
assert_almost_equals(hf_outputs.logits, custom_outputs.logits)

--------------------------Checking encoder hidden states match--------------------------
✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)
✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)
✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)
✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)
✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)
--------------------------Checking encoder last hidden states match--------------------------
HF output shape: (2, 29, 16), custom output shape: (2, 29, 16)
✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)
--------------------------Checking decoder hidden states match--------------------------
✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)
✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)
✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)
--------------------------Checking logits match--------------------------
HF logits shape: (2, 50, 1000), Custom logits shape: (2, 50, 1000)
✅ Difference between Flax and