# Freezing a Subset of Model Parameters

In [1]:
from transformers import FlaxSpeechEncoderDecoderModel
import random
import numpy as np
import jax
import jax.numpy as jnp
import optax
from flax import traverse_util
from flax.training import train_state
from flax.core import freeze, frozen_dict
from flax.core.frozen_dict import FrozenDict
from flax.training.common_utils import onehot
from typing import Mapping

  from .autonotebook import tqdm as notebook_tqdm


## 1. Load a pre-trained 'tiny' speech encoder-decoder model

In [2]:
encoder_id = "hf-internal-testing/tiny-random-wav2vec2"
decoder_id = "hf-internal-testing/tiny-random-bart"

In [3]:
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)
params = freeze(model.params)

Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing FlaxWav2Vec2Model: {('quantizer', 'weight_proj', 'bias'), ('lm_head', 'kernel'), ('project_hid', 'kernel'), ('project_q', 'kernel'), ('lm_head', 'bias'), ('project_hid', 'bias'), ('quantizer', 'weight_proj', 'kernel'), ('project_q', 'bias'), ('quantizer', 'codevectors')}
- This IS expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxWav2Vec2Model were not initialized from the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 and are newly initial

## 2. Utility functions for our analysis

In [4]:
def print_tree(d, depth=0, print_value=False):
    for k in d.keys():
        if isinstance(d[k], FrozenDict):
            print('  ' * depth, k)
            print_tree(d[k], depth + 1, print_value)
        else:
            if print_value:
                print('  ' * depth, k, d[k])
            else:
                print('  ' * depth, k)

In [5]:
def compare_params(lhs, rhs, depth=0):
    for k in lhs.keys():
        if isinstance(lhs[k], FrozenDict):
            print('  ' * depth, k)
            compare_params(lhs[k], rhs[k], depth + 1)
        else:
            print('  ' * depth, k, jnp.mean(jnp.abs(lhs[k] - rhs[k])))

## 3. View the parameter tree

In [6]:
# view the param tree for the full model (note: long output!)
print_tree(params)

 decoder
   model
     decoder
       embed_tokens
         embedding
       embed_positions
         embedding
       layers
         0
           self_attn
             k_proj
               kernel
               bias
             v_proj
               kernel
               bias
             q_proj
               kernel
               bias
             out_proj
               kernel
               bias
           self_attn_layer_norm
             scale
             bias
           encoder_attn
             k_proj
               kernel
               bias
             v_proj
               kernel
               bias
             q_proj
               kernel
               bias
             out_proj
               kernel
               bias
           encoder_attn_layer_norm
             scale
             bias
           fc1
             kernel
             bias
           fc2
             kernel
             bias
           final_layer_norm
             scale
             bias
      

As we can see, the param tree is somewhat large! And this is just for a 'tiny' seq2seq model... To make our lives easier, we'll focus on two specific modules from the encoder: the `feature_extractor` and the `feature_projection`. For the former, we have added a `jax.lax.stop_gradient` operator within the [modelling code](https://github.com/huggingface/transformers/blob/d57da992371c1c8258dc683275b4711dee949d20/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py#L409-L410), which is triggered if the argument `freeze_feature_encoder` is set to `True`. We can use this argument to freeze the feature extractor layers during training - by action of the `jax.lax.stop_gradient` operator, the gradients of these layers are clamped to precisely zero.

In [7]:
print("feature_extractor")
print_tree(params['encoder']['feature_extractor'], depth=1)

feature_extractor
   conv_layers
     0
       conv
         kernel
       layer_norm
         scale
         bias
     1
       conv
         kernel
       layer_norm
         bias
         scale
     2
       conv
         kernel
       layer_norm
         bias
         scale


In [8]:
print("feature_projection")
print_tree(params['encoder']['feature_projection'], depth=1)

feature_projection
   layer_norm
     scale
     bias
   projection
     kernel
     bias


Lets get an idea of how many of the model parameters we are freezing:

In [9]:
# get num params
num_params = sum(jax.tree_leaves(jax.tree_map(lambda x: jnp.size(x), params)))
# get num frozen params
num_frozen_params = sum(jax.tree_leaves(jax.tree_map(lambda x: jnp.size(x), params['encoder']['feature_extractor'])))
# accumulate statistics
prop_frozen_params = num_frozen_params / num_params * 100
print(f"Num params: {num_params}, num frozen params: {num_frozen_params}, proportion of model params frozen: {prop_frozen_params:.1f}%")

Num params: 49368, num frozen params: 16832, proportion of model params frozen: 34.1%


## 4. Create synthetic data

In [10]:
def ids_tensor(shape, vocab_size, key):
    """Creates a random int32 tensor of the shape within the vocab size."""
    return jax.random.randint(key, shape, 0, vocab_size - 1)


def random_attention_mask(shape, key):
    """Creates a random binary int32 tensor."""
    attn_mask = ids_tensor(shape, vocab_size=2, key=key)
    # make sure that at least one token is attended to for each batch
    attn_mask = attn_mask.at[:, -1].set(1)
    return attn_mask


def floats_tensor(shape, key, scale=1.0):
    """Creates a random float32 tensor."""
    return jax.random.normal(key, shape=shape) * scale


def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """Shifts input ids one token to the right."""
    shifted_input_ids = jnp.zeros_like(input_ids)
    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids

In [11]:
def create_batch(key, batch_size=2, encoder_input_length=96000, decoder_input_length=31):
    """Function for creating a dummy batch of data of specified dimensions."""
    input_ids = floats_tensor([batch_size, encoder_input_length], key)
    attention_mask = random_attention_mask([batch_size, encoder_input_length], key)
    label_ids = ids_tensor([batch_size, decoder_input_length], model.config.decoder.vocab_size, key)
    decoder_input_ids = shift_tokens_right(input_ids=label_ids, pad_token_id=model.config.decoder.pad_token_id,
                                           decoder_start_token_id=model.config.decoder.decoder_start_token_id or model.config.decoder.bos_token_id)

    batch_inputs = {
        "inputs": input_ids,
        "attention_mask": attention_mask,
        "decoder_input_ids": decoder_input_ids,
        "labels": label_ids,
    }
    return batch_inputs

In [12]:
rng = jax.random.PRNGKey(0)

# create batches of synthetic input data
inputs = []
for i in range(1000):
    key, rng = jax.random.split(rng, num=2)
    inputs.append(create_batch(key))

## 5. Define a train step

In [13]:
# cross-entropy loss function
def loss_fn(logits, labels):
    vocab_size = logits.shape[-1]
    loss = optax.softmax_cross_entropy(logits, onehot(labels, vocab_size))
    # ignore padded tokens from loss, i.e. where labels are not set to -100
    padding = labels >= 0
    loss = loss * padding
    loss = loss.sum() / padding.sum()
    return loss

In [14]:
# define train step
def train_step(state, batch, rng, freeze_feature_encoder=False):
    dropout_rng, new_dropout_rng = jax.random.split(rng)
    
    def compute_loss(params):
        labels = batch.pop('labels')
        outputs = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, freeze_feature_encoder=freeze_feature_encoder, train=True)
        loss = loss_fn(outputs.logits, labels)
        return loss, outputs

    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, outputs), grads = grad_fn(state.params)
    
    new_state = state.apply_gradients(grads=grads)
    
    return new_state, new_dropout_rng

# jit training step
j_train_step = jax.jit(train_step, static_argnums=(3,))

## 6. Optimizer without freezing

### Define the Optimizer

In [15]:
tx = optax.adam(0.1)

state = train_state.TrainState.create(apply_fn=model.__call__,
                                      params=params,
                                      tx=tx)

### Run one training step

In [16]:
state, rng = j_train_step(state, inputs[0], rng, freeze_feature_encoder=False)

### Compare parameters after optimization

In [17]:
# the change for the (non-frozen) feature extractor parameters should be greater than zero
compare_params(params['encoder']['feature_extractor'], state.params['encoder']['feature_extractor'])

 conv_layers
   0
     conv
       kernel 0.09710271
     layer_norm
       scale 0.09954791
       bias 0.099674284
   1
     conv
       kernel 0.09776986
     layer_norm
       bias 0.09932138
       scale 0.099454924
   2
     conv
       kernel 0.09607415
     layer_norm
       bias 0.099158935
       scale 0.09879182


In [18]:
# the change for the (non-frozen) feature projection parameters should be greater than zero
compare_params(params['encoder']['feature_projection'], state.params['encoder']['feature_projection'])

 layer_norm
   scale 0.09655121
   bias 0.09949996
 projection
   kernel 0.099958576
   bias 0.10006119


### Benchmark an non-frozen train loop

In [19]:
%%timeit
for batch in inputs:
    new_state, new_rng = j_train_step(state, batch, rng, freeze_feature_encoder=False)

13.5 s ± 269 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## 7. Optimizer with `set_to_zero` and parameter mask

### Create a mask for the optimizer

In [20]:
def grad_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {path: 'zero' if path[1] == 'feature_extractor' else 'adam' for path in flat_params}
    mask = traverse_util.unflatten_dict(flat_mask)
    return freeze(mask)

In [21]:
mask = grad_mask_fn(params)

# check that we are masking the right params (feature_extractor and not others)
print("feature_extractor")
print_tree(mask['encoder']['feature_extractor'], depth=1, print_value=True)
print("feature_projection")
print_tree(mask['encoder']['feature_projection'], depth=1, print_value=True)

feature_extractor
   conv_layers
     0
       conv
         kernel zero
       layer_norm
         scale zero
         bias zero
     1
       conv
         kernel zero
       layer_norm
         bias zero
         scale zero
     2
       conv
         kernel zero
       layer_norm
         bias zero
         scale zero
feature_projection
   layer_norm
     scale adam
     bias adam
   projection
     kernel adam
     bias adam


### Define the optimizer

In [22]:
tx = optax.multi_transform({'adam': optax.adam(0.1), 'zero': optax.set_to_zero()},
                           mask)

state = train_state.TrainState.create(apply_fn=model.__call__,
                                      params=params,
                                      tx=tx)

### Run one training step

In [23]:
# note that we are freezing the layers with the grad mask and **not** with the `freeze_feature_encoder` argument
state, rng = j_train_step(state, inputs[0], rng, freeze_feature_encoder=False)

### Compare parameters after optimization

In [24]:
# the change for the frozen feature extractor parameters should be precisely zero
compare_params(params['encoder']['feature_extractor'], state.params['encoder']['feature_extractor'])

 conv_layers
   0
     conv
       kernel 0.0
     layer_norm
       scale 0.0
       bias 0.0
   1
     conv
       kernel 0.0
     layer_norm
       bias 0.0
       scale 0.0
   2
     conv
       kernel 0.0
     layer_norm
       bias 0.0
       scale 0.0


In [25]:
# the change for the (non-frozen) feature projection parameters should be greater than zero
compare_params(params['encoder']['feature_projection'], state.params['encoder']['feature_projection'])

 layer_norm
   scale 0.095603384
   bias 0.09918707
 projection
   kernel 0.099999696
   bias 0.10006401


### Benchmark a `set_to_zero` train loop

In [26]:
%%timeit
for batch in inputs:
    new_state, new_rng = j_train_step(state, batch, rng, freeze_feature_encoder=False)

8.5 s ± 442 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## 8. Optimizer with `jax.lax.stop_gradient`

### Define the Optimizer

In [27]:
tx = optax.adam(0.1)

state = train_state.TrainState.create(apply_fn=model.__call__,
                                      params=params,
                                      tx=tx)

### Run one training step

In [28]:
state, rng = j_train_step(state, inputs[0], rng, freeze_feature_encoder=True)

### Compare parameters after optimization

In [29]:
# the change for the frozen feature extractor parameters should be precisely zero
compare_params(params['encoder']['feature_extractor'], state.params['encoder']['feature_extractor'])

 conv_layers
   0
     conv
       kernel 0.0
     layer_norm
       scale 0.0
       bias 0.0
   1
     conv
       kernel 0.0
     layer_norm
       bias 0.0
       scale 0.0
   2
     conv
       kernel 0.0
     layer_norm
       bias 0.0
       scale 0.0


In [30]:
# the change for the (non-frozen) feature projection parameters should be greater than zero
compare_params(params['encoder']['feature_projection'], state.params['encoder']['feature_projection'])

 layer_norm
   scale 0.09889982
   bias 0.09973503
 projection
   kernel 0.09995212
   bias 0.100052685


### Benchmark a `jax.lax.stop_gradient` train loop

In [31]:
%%timeit
for batch in inputs:
    new_state, new_rng = j_train_step(state, batch, rng, freeze_feature_encoder=True)

8.5 s ± 243 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## 9. Differentiate non-frozen params only

### Define filter and merge functions

In [32]:
def bool_grad_mask_fn(params):
    """Create a boolean mask over parameters for filtering.
    Returns:
        mask (frozen_dict): `True` for non-frozen parameters, `False` for frozen parameters (i.e. the feature encoder).
    """
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {path: path[1] != 'feature_extractor' for path in flat_params}
    mask = traverse_util.unflatten_dict(flat_mask)
    return freeze(mask)

def filter_params(params, mask):
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = traverse_util.flatten_dict(mask)
    filtered_params = {}
    for name, value in flat_params.items():
        if flat_mask[name]:
            filtered_params[name] = flat_params[name]
    return freeze(traverse_util.unflatten_dict(filtered_params))

def merge_params(params: Mapping, updates: Mapping) -> dict:
    output = params.unfreeze()

    for name, update_value in updates.items():
        current_value = params.get(name, None)
        if isinstance(current_value, Mapping) and isinstance(update_value, Mapping):
            output[name] = merge_params(current_value, update_value)
        else:
            output[name] = update_value

    return freeze(output)

In [33]:
# check that the masking function works as intended
bool_mask = bool_grad_mask_fn(params)

# check that we are masking the right params (feature_extractor and not others)
print("feature_extractor")
print_tree(bool_mask['encoder']['feature_extractor'], depth=1, print_value=True)
print("feature_projection")
print_tree(bool_mask['encoder']['feature_projection'], depth=1, print_value=True)

feature_extractor
   conv_layers
     0
       conv
         kernel False
       layer_norm
         scale False
         bias False
     1
       conv
         kernel False
       layer_norm
         bias False
         scale False
     2
       conv
         kernel False
       layer_norm
         bias False
         scale False
feature_projection
   layer_norm
     scale True
     bias True
   projection
     kernel True
     bias True


In [34]:
# check that the filter function works as intended
filtered_params = filter_params(params, bool_mask)

assert 'feature_extractor' not in filtered_params['encoder'], "Feature extractor params not filtered by filter function"
print(f"Modules for filtered encoder: {filtered_params['encoder'].keys()}")

Modules for filtered encoder: frozen_dict_keys(['masked_spec_embed', 'feature_projection', 'encoder'])


### Define a modified train step - only take grads of differentiable params

In [35]:
def train_step(state, batch, rng):
    dropout_rng, new_dropout_rng = jax.random.split(rng)
    
    def compute_loss(differentiable_params, params):
        params = merge_params(params, differentiable_params)
        labels = batch.pop("labels")
        outputs = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)
        loss = loss_fn(outputs.logits, labels)
        return loss, outputs

    differentiable_params = filter_params(state.params, bool_grad_mask_fn(state.params))
    
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, outputs), grads = grad_fn(differentiable_params, state.params)

    updates, opt_state = state.tx.update(grads, state.opt_state)
    differentiable_params = optax.apply_updates(differentiable_params, updates)
    params = merge_params(state.params, differentiable_params)
    
    new_state = state.replace(
        params=params,
        opt_state=opt_state,
    )
    
    return new_state, new_dropout_rng

# jit training step
j_train_step = jax.jit(train_step)

### Define the Optimizer

In [36]:
tx = optax.adam(0.1)

state = train_state.TrainState.create(apply_fn=model.__call__,
                                      params=params,
                                      tx=tx)

state = state.replace(
    opt_state=state.tx.init(filter_params(state.params, bool_grad_mask_fn(state.params)))
)

### Run one training step

In [37]:
state, rng = j_train_step(state, inputs[0], rng)

### Compare parameters after optimization

In [38]:
# the change for the frozen feature extractor parameters should be precisely zero
compare_params(params['encoder']['feature_extractor'], state.params['encoder']['feature_extractor'])

 conv_layers
   0
     conv
       kernel 0.0
     layer_norm
       scale 0.0
       bias 0.0
   1
     conv
       kernel 0.0
     layer_norm
       bias 0.0
       scale 0.0
   2
     conv
       kernel 0.0
     layer_norm
       bias 0.0
       scale 0.0


In [39]:
# the change for the (non-frozen) feature projection parameters should be greater than zero
compare_params(params['encoder']['feature_projection'], state.params['encoder']['feature_projection'])

 layer_norm
   scale 0.09865255
   bias 0.09918152
 projection
   kernel 0.09994524
   bias 0.10005053


### Benchmark a filtered train loop

In [40]:
%%timeit
for batch in inputs:
    new_state, new_rng = j_train_step(state, batch, rng)

8.5 s ± 278 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
