In [1]:
from transformers import FlaxSpeechEncoderDecoderModel
from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel as CustomFlaxSpeechEncoderDecoderModel
from flax.traverse_util import flatten_dict, unflatten_dict
import numpy as np
import jax.numpy as jnp



In [2]:
# checkpoint identifiers handed to `from_encoder_decoder_pretrained` for both the reference and the custom model
encoder_id = "hf-internal-testing/tiny-random-wav2vec2"  # speech encoder checkpoint
decoder_id = "hf-internal-testing/tiny-random-bart"  # text decoder checkpoint

In [3]:
# reference model: the stock Transformers implementation; both sub-models are loaded with `*_from_pt=True`
hf_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_id,
    decoder_id,
    encoder_from_pt=True,
    decoder_from_pt=True,
)

Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing FlaxWav2Vec2Model: {('project_q', 'bias'), ('project_q', 'kernel'), ('project_hid', 'bias'), ('project_hid', 'kernel'), ('quantizer', 'codevectors'), ('quantizer', 'weight_proj', 'bias'), ('quantizer', 'weight_proj', 'kernel'), ('lm_head', 'bias'), ('lm_head', 'kernel')}
- This IS expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxWav2Vec2Model were not initialized from the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 and are newly initial

In [4]:
# layer counts come from the reference model's config; the weight conversion further down
# stacks exactly this many per-layer weights for the encoder and the decoder
config = hf_model.config
num_encoder_layers = config.encoder.num_hidden_layers
num_decoder_layers = config.decoder.decoder_layers

In [5]:
# model under test: the local implementation, built from the same checkpoints as `hf_model`
# but with `encoder_use_scan=True` / `decoder_use_scan=True`
custom_model = CustomFlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_id,
    decoder_id,
    encoder_use_scan=True,
    decoder_use_scan=True,
    encoder_from_pt=True,
    decoder_from_pt=True,
)

encoder checkpointing: False
encoder scan: True


Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing FlaxWav2Vec2Model: {('encoder', 'layers', '0', 'feed_forward', 'intermediate_dense', 'kernel'), ('encoder', 'layers', '2', 'attention', 'out_proj', 'bias'), ('encoder', 'layers', '2', 'layer_norm', 'kernel'), ('encoder', 'layers', '0', 'feed_forward', 'output_dense', 'kernel'), ('encoder', 'layers', '3', 'attention', 'q_proj', 'bias'), ('encoder', 'layers', '1', 'feed_forward', 'intermediate_dense', 'bias'), ('encoder', 'layers', '1', 'attention', 'k_proj', 'kernel'), ('project_hid', 'bias'), ('encoder', 'layers', '3', 'layer_norm', 'kernel'), ('encoder', 'layers', '2', 'attention', 'v_proj', 'bias'), ('encoder', 'layers', '3', 'feed_forward', 'intermediate_dense', 'kernel'), ('encoder', 'layers', '2', 'feed_forward', 'output_dense', 'kernel'), ('encoder', 'layers', '3', 'feed_forward', 'output_dense', 'kernel'), ('encoder', 'layers', '0', 'attention', 'q_proj', 'bias'), ('e

decoder checkpointing: False
decoder scan: True


Some weights of the model checkpoint at hf-internal-testing/tiny-random-bart were not used when initializing FlaxBartForCausalLM: {('qa_outputs', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('final_logits_bias',), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'decoder', 'layers', '0', 'encoder_attn_layer_norm', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc1', 'bias'), ('model', 'decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'decoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('

encoder checkpointing: False
encoder scan: True
decoder checkpointing: False
decoder scan: True


In [6]:
# nested parameter tree of the reference model (kept nested: it is indexed by key path later on)
hf_params = hf_model.params
# parameters of the custom model, flattened to a {key-path tuple: array} mapping
custom_params = flatten_dict(custom_model.params)

In [7]:
def unrolled_to_scanned(vs, num_layers=None):
    """Stack the per-layer ``SinDot_{i}`` weights of ``vs`` into a single ``scanned_layer`` entry.

    NOTE(review): the next cell defines another function with this same name, which
    replaces this one once that cell has run. This version only handles a variable
    collection whose layers are named ``SinDot_0``, ``SinDot_1``, ...

    Args:
        vs: variable collection with a ``'params'`` entry that holds one sub-dict per
            layer (``SinDot_0`` ... ``SinDot_{n-1}``), each with a ``'W'`` and a ``'b'`` array.
        num_layers: number of layers to stack. When ``None`` (the default) it is taken
            to be the number of ``SinDot_*`` entries found under ``vs['params']``.

    Returns:
        A frozen dict ``{'params': {'scanned_layer': {'W': Ws, 'b': bs}}}`` where ``Ws`` /
        ``bs`` are the per-layer arrays stacked along a new leading axis. Any other
        entry of ``vs`` is not carried over.
    """
    # imported here because the notebook's import cell does not bring these names into scope
    from flax.core import freeze, unfreeze

    layer_params = unfreeze(vs)['params']
    if num_layers is None:
        # the original code relied on an undefined global `L`; count the layers instead
        num_layers = sum(1 for name in layer_params if name.startswith('SinDot_'))

    Ws = jnp.stack([layer_params[f'SinDot_{i}']['W'] for i in range(num_layers)])
    bs = jnp.stack([layer_params[f'SinDot_{i}']['b'] for i in range(num_layers)])
    return freeze({'params': {'scanned_layer': {'W': Ws, 'b': bs}}})

In [8]:
def _stack_layer_params(layers, num_layers):
    """Stack every per-layer weight in ``layers`` (keyed '0', '1', ...) along a new leading axis.

    Returns a flat ``{param path tuple: stacked array}`` mapping whose paths are those of layer '0'.
    """
    # flatten the whole `layers` subtree once; keys look like (layer index as str, *param path)
    flat_layers = flatten_dict(layers)
    param_paths = [k[1:] for k in flat_layers if k[0] == '0']
    return {
        path: jnp.stack([flat_layers[(str(i),) + path] for i in range(num_layers)])
        for path in param_paths
    }


def unrolled_to_scanned(hf_params):
    """Convert the unrolled (one sub-tree per layer) HF parameters to the scanned layout.

    For the encoder and the decoder, the weights of layers '0' ... 'n-1' are stacked
    into a single array per parameter and stored under the scanned module's key
    (``FlaxWav2Vec2EncoderLayers`` / ``FlaxBartDecoderLayers``). Every parameter whose
    key path does not contain ``'layers'`` is copied over unchanged.

    Reads the notebook globals ``num_encoder_layers`` and ``num_decoder_layers`` for
    the number of layers to stack.

    Args:
        hf_params: nested parameter tree of the reference model, with per-layer
            sub-trees at ``['encoder']['encoder']['layers'][str(i)]`` and
            ``['decoder']['model']['decoder']['layers'][str(i)]``.

    Returns:
        A nested dict with the same non-layer parameters and the stacked layer parameters.
    """
    # encoder: stack the layer weights, then prepend the key path of the scanned module
    enc_stacked = _stack_layer_params(hf_params['encoder']['encoder']['layers'], num_encoder_layers)
    new_enc_params = unflatten_dict({('encoder', 'layers', 'FlaxWav2Vec2EncoderLayers'): unflatten_dict(enc_stacked)})

    # decoder: same procedure; its 'layers' key sits one level deeper than in the encoder,
    # which is why the two are handled separately
    dec_stacked = _stack_layer_params(hf_params['decoder']['model']['decoder']['layers'], num_decoder_layers)
    new_dec_params = unflatten_dict({('model', 'decoder', 'layers', 'FlaxBartDecoderLayers'): unflatten_dict(dec_stacked)})

    # combine the encoder and decoder parameters
    new_params = flatten_dict({'encoder': new_enc_params, 'decoder': new_dec_params})

    # copy the parameters of the non-scanned modules (every key path without a 'layers' component);
    # the full tree is flattened once rather than once per key
    for k, v in flatten_dict(hf_params).items():
        if 'layers' not in k:
            new_params[k] = v

    return unflatten_dict(new_params)

In [9]:
# replace the custom model's parameters with the reference model's parameters,
# converted by `unrolled_to_scanned` into the stacked-per-layer layout
custom_model.params = unrolled_to_scanned(hf_params)

In [10]:
# check that every unrolled encoder layer weight equals the matching slice of the stacked weight
match = []
mismatch = []

flat_hf_params = flatten_dict(hf_params['encoder']['encoder']['layers'])
flat_custom_params = flatten_dict(custom_model.params['encoder']['encoder']['layers']['FlaxWav2Vec2EncoderLayers'])

for k, hf_weight in flat_hf_params.items():
    # k = (layer index as str, *param path); the stacked weight is indexed by layer on its first axis
    custom_weight = flat_custom_params[k[1:]][int(k[0])]
    assert custom_weight.shape == hf_weight.shape, "shapes do not match"

    # a weight matches only when every element is equal; a single differing element
    # makes it a mismatch (previously only fully different weights were counted)
    if (custom_weight == hf_weight).all():
        match.append(k)
    else:
        mismatch.append(k)

print(len(match) + len(mismatch), len(match), len(mismatch))

64 64 0


In [11]:
# check that every unrolled decoder layer weight equals the matching slice of the stacked weight
match = []
mismatch = []

flat_hf_params = flatten_dict(hf_params['decoder']['model']['decoder']['layers'])
flat_custom_params = flatten_dict(custom_model.params['decoder']['model']['decoder']['layers']['FlaxBartDecoderLayers'])

for k, hf_weight in flat_hf_params.items():
    # k = (layer index as str, *param path); the stacked weight is indexed by layer on its first axis
    custom_weight = flat_custom_params[k[1:]][int(k[0])]
    assert custom_weight.shape == hf_weight.shape, "shapes do not match"

    # a weight matches only when every element is equal; a single differing element
    # makes it a mismatch (previously only fully different weights were counted)
    if (custom_weight == hf_weight).all():
        match.append(k)
    else:
        mismatch.append(k)

print(len(match) + len(mismatch), len(match), len(mismatch))

52 52 0


In [12]:
# create some dummy data; the generator is seeded so that a fresh run of the notebook
# feeds both models the same inputs every time
inputs = np.random.default_rng(0).standard_normal((1, 5000))
decoder_input_ids = np.arange(100).reshape(1, 100)

In [13]:
# get ground-truth outputs from Transformers 🤗 model
# (same `inputs` / `decoder_input_ids` are later passed to the custom model for comparison)
hf_outputs = hf_model(inputs, decoder_input_ids=decoder_input_ids)

In [14]:
# run the custom model on the same dummy inputs as the reference model
custom_outputs = custom_model(inputs, decoder_input_ids=decoder_input_ids)

encoder checkpointing: False
encoder scan: True
decoder checkpointing: False
decoder scan: False


In [24]:
# define a helper function for our analysis
def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 1e-5):
    """Report whether ``a`` and ``b`` agree to within ``tol``.

    Computes the largest element-wise absolute difference between the two arrays and
    prints one line: ✅ when it is at most ``tol``, ❌ otherwise. Despite its name this
    helper never raises; it only prints.

    Args:
        a: first array.
        b: second array, subtracted element-wise from ``a``.
        tol: largest absolute difference that still counts as a match.
    """
    diff = np.abs((a - b)).max()
    # the printed comparator mirrors the test below (`<=` passes, `>` fails)
    if diff <= tol:
        print(f"✅ Difference between HF and custom is {diff} (<= {tol})")
    else:
        print(f"❌ Difference between HF and custom is {diff} (> {tol})")

In [25]:
# encoder side: print both shapes, then the max absolute difference of the last hidden states
print("--------------------------Checking encoder last hidden states match--------------------------")
print(
    f"HF output shape: {hf_outputs.encoder_last_hidden_state.shape}, "
    f"custom output shape: {custom_outputs.encoder_last_hidden_state.shape}"
)
assert_almost_equals(
    hf_outputs.encoder_last_hidden_state,
    custom_outputs.encoder_last_hidden_state,
)

# decoder side: same report for the output logits
print("--------------------------Checking logits match--------------------------")
print(
    f"HF logits shape: {hf_outputs.logits.shape}, "
    f"Custom logits shape: {custom_outputs.logits.shape}"
)
assert_almost_equals(
    hf_outputs.logits,
    custom_outputs.logits,
)

--------------------------Checking encoder last hidden states match--------------------------
HF output shape: (2, 76, 16), custom output shape: (2, 76, 16)
✅ Difference between HF and custom is 3.5762786865234375e-07 (< 1e-05)
--------------------------Checking logits match--------------------------
HF logits shape: (2, 50, 1000), Custom logits shape: (2, 50, 1000)
✅ Difference between HF and custom is 8.940696716308594e-08 (< 1e-05)
