# Implementation of Freezing Strategies for Refinement/Transfer Learning
This notebook tries out different freezing mechanisms and controls using TensorFlow and Keras. The goal is to have a few easy-to-use functions for further training, refinement and transfer learning of the Prosit models, which can ultimately be implemented in DLOmix.

### Imports

In [1]:
import dlomix
import tensorflow as tf
from tensorflow import keras 
import yaml
from dlomix.losses import masked_spectral_distance, masked_pearson_correlation_distance
from tensorflow.keras.callbacks import EarlyStopping



### Prepare stuff for training

##### Read the config file

In [3]:
config_path = "baseline_training/config_files/noptm_baseline_small_bs1024.yaml"

with open(config_path, 'r') as yaml_file:
    config = yaml.safe_load(yaml_file)
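For reference, the notebook only reads four keys from this config. A hypothetical minimal config with placeholder values (not the contents of the real `noptm_baseline_small_bs1024.yaml`) would look like this:

```python
import yaml

# Hypothetical example config: only the keys this notebook reads are shown,
# and all values are placeholders, not taken from the real config file.
example_yaml = """
dataset:
  processed_path: path/to/processed_dataset
  seq_length: 30
training:
  learning_rate: 0.0001
  num_epochs: 100
"""

example_config = yaml.safe_load(example_yaml)

# the notebook accesses these nested keys further down
print(example_config['dataset']['processed_path'])
print(example_config['dataset']['seq_length'])
print(example_config['training']['learning_rate'])
print(example_config['training']['num_epochs'])
```

One PyYAML pitfall worth knowing: it only resolves a scalar as a float if it contains a decimal point, so `learning_rate: 1e-4` comes back as the string `'1e-4'`. Writing `0.0001` or `1.0e-4` avoids passing a string to the optimizer.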


##### Load the dataset and the PTM alphabet

In [4]:
# load dataset
from dlomix.data import FragmentIonIntensityDataset

# from misc import PTMS_ALPHABET
from dlomix.constants import PTMS_ALPHABET

from dlomix.data import load_processed_dataset
dataset = load_processed_dataset(config['dataset']['processed_path'])




Avaliable feature extractors are (use the key of the following dict and pass it to features_to_extract in the Dataset Class):
{
   "atom_count": "Atom count of PTM.",
   "delta_mass": "Delta mass of PTM.",
   "mod_gain": "Gain of atoms due to PTM.",
   "mod_loss": "Loss of atoms due to PTM.",
   "red_smiles": "Reduced SMILES representation of PTM."
}.
When writing your own feature extractor, you can either
    (1) use the FeatureExtractor class or
    (2) write a function that can be mapped to the Hugging Face dataset.
In both cases, you can access the parsed sequence information from the dataset using the following keys, which all provide python lists:
    - _parsed_sequence: parsed sequence
    - _n_term_mods: N-terminal modifications
    - _c_term_mods: C-terminal modifications



##### Initialize the optimizer and callbacks

In [5]:
optimizer = tf.keras.optimizers.Adam(learning_rate=config['training']['learning_rate'])

early_stopping = EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,
    patience=20,
    restore_best_weights=True)




##### Import the model

In [6]:
from dlomix.models import PrositIntensityPredictor

input_mapping = {
        "SEQUENCE_KEY": "modified_sequence",
        "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed",
        "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot",
        "FRAGMENTATION_TYPE_KEY": "method_nbr",
    }

meta_data_keys=["collision_energy_aligned_normed", "precursor_charge_onehot", "method_nbr"]

model = PrositIntensityPredictor(
    seq_length=config['dataset']['seq_length'],
    alphabet=PTMS_ALPHABET,
    use_prosit_ptm_features=False,
    with_termini=False,
    input_keys=input_mapping,
    meta_data_keys=meta_data_keys
)

model.compile(
    optimizer=optimizer,
    loss=masked_spectral_distance,
    metrics=[masked_pearson_correlation_distance]
)

In [7]:
model

<dlomix.models.prosit.PrositIntensityPredictor at 0x7f36270e5f90>

### Initial training

In [8]:
model.fit(
    dataset.tensor_train_data,
    validation_data=dataset.tensor_val_data,
    epochs=config['training']['num_epochs'],
    callbacks=[early_stopping]
)





<keras.src.callbacks.History at 0x7f362445dcf0>

### Freezing

Check whether the weights are correctly assigned as trainable or non-trainable:


In [9]:
print(f'weight tensors: {len(model.weights)}')
print(f'trainable weight tensors: {len(model.trainable_weights)}')
print(f'non-trainable weight tensors: {len(model.non_trainable_weights)}')

weight tensors: 21
trainable weight tensors: 21
non-trainable weight tensors: 0
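Note that these numbers count weight tensors (embedding matrices, kernels, recurrent kernels, biases), not layers, which is why a model with 7 top-level layers reports 21. A small hypothetical helper that breaks the counts down per top-level layer, relying only on the standard Keras layer attributes `name`, `trainable`, `trainable_weights` and `non_trainable_weights`:

```python
def weight_summary(model):
    """Print and return, per top-level layer, its trainable flag and tensor counts.

    Hypothetical helper: uses only the Keras attributes `layers`, `name`,
    `trainable`, `trainable_weights` and `non_trainable_weights`.
    """
    rows = []
    for i, layer in enumerate(model.layers):
        n_train = len(layer.trainable_weights)
        n_frozen = len(layer.non_trainable_weights)
        rows.append((i, layer.name, layer.trainable, n_train, n_frozen))
        print(f'{i}: {layer.name:<25} trainable={layer.trainable!s:<5} '
              f'trainable tensors={n_train:>2} non-trainable tensors={n_frozen:>2}')
    return rows
```

Calling `weight_summary(model)` after `model.trainable = False` should show 0 trainable tensors for every layer, since Keras propagates the flag to all sublayers.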


Freeze the whole model:

In [10]:
model.trainable = False

In [11]:
print(f'weight tensors: {len(model.weights)}')
print(f'trainable weight tensors: {len(model.trainable_weights)}')
print(f'non-trainable weight tensors: {len(model.non_trainable_weights)}')

weight tensors: 21
trainable weight tensors: 0
non-trainable weight tensors: 21


Unfreeze again:

In [12]:
model.trainable = True

Freeze individual layers:

In [13]:
from collections.abc import Iterable

# function to freeze certain top-level layers (indicated by index)
def freeze_layers(model: keras.Model, layers: Iterable[int]) -> None:
    # unfreeze everything first so that previous freezing configurations are reset
    model.trainable = True
    for l in layers:
        model.layers[l].trainable = False
    
    # compile the model again to make changes take effect
    # (uses the global `optimizer` defined above)
    model.compile(
        optimizer=optimizer,
        loss=masked_spectral_distance,
        metrics=[masked_pearson_correlation_distance]
    )
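Indexing layers by position is brittle if the architecture changes. A hypothetical name-based variant of the index-based `freeze_layers` is sketched below; the names are whatever Keras assigned (check them with `[layer.name for layer in model.layers]`), and unknown names raise an error instead of being silently ignored:

```python
from collections.abc import Iterable


def freeze_layers_by_name(model, names: Iterable[str]) -> list:
    """Freeze the top-level layers whose `name` is in `names`; unfreeze all others.

    Hypothetical helper based on the Keras `layer.name` / `layer.trainable`
    attributes. Returns the resulting trainable flags per top-level layer.
    """
    names = set(names)
    existing = {layer.name for layer in model.layers}
    missing = names - existing
    if missing:
        raise ValueError(f'unknown layer names: {sorted(missing)}')

    # reset previous freezing configurations (a frozen model hides all sublayer weights)
    model.trainable = True
    for layer in model.layers:
        layer.trainable = layer.name not in names
    return [layer.trainable for layer in model.layers]
```

As with the index-based version, the model has to be recompiled afterwards for the change to take effect in `fit`.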


Testing the function:

In [14]:
freeze_layers(model, [0,1,2])
[model.layers[i].trainable for i in range(0, len(model.layers))]

[False, False, False, True, True, True, True]

In [15]:
freeze_layers(model, [2,3])
[model.layers[i].trainable for i in range(0, len(model.layers))]

[True, True, False, False, True, True, True]

In [16]:
# freeze all layers except the first layer (index 0)
freeze_layers(model, range(1, len(model.layers)))
[model.layers[i].trainable for i in range(0, len(model.layers))]

[True, False, False, False, False, False, False]

In [17]:
# freeze all layers except last layer
freeze_layers(model, range(0, len(model.layers)-1))
[model.layers[i].trainable for i in range(0, len(model.layers))]

[False, False, False, False, False, False, True]

In [18]:
# freeze all layers 
freeze_layers(model, range(0, len(model.layers)))
[model.layers[i].trainable for i in range(0, len(model.layers))]

[False, False, False, False, False, False, False]

In [129]:
# freeze all layers except the second layer 
layers_to_freeze = list(range(0, len(model.layers)))
layers_to_freeze.pop(1)
freeze_layers(model, layers_to_freeze)
[model.layers[i].trainable for i in range(0, len(model.layers))]

[False, True, False, False, False, False, False]
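Building the index list with `pop(1)` works only because the list starts at 0, so position 1 happens to hold index 1. A small hypothetical helper makes the intent explicit and also allows keeping several layers trainable:

```python
from collections.abc import Iterable


def all_indices_except(n_layers: int, keep: Iterable[int]) -> list[int]:
    """Return every layer index in range(n_layers) that is not in `keep`."""
    keep = set(keep)
    return [i for i in range(n_layers) if i not in keep]


# e.g. freeze_layers(model, all_indices_except(len(model.layers), [1]))
# would reproduce the "all layers except the second layer" configuration above
print(all_indices_except(7, [1]))
```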

### Continue Training with frozen layers

In [19]:
original_weights = model.get_weights()
original_weights[0][0]

array([ 0.03843875, -0.04833046,  0.0237368 ,  0.0021462 ,  0.04882205,
        0.00462924, -0.04129234, -0.04895075,  0.02260053,  0.02259899,
       -0.00010501,  0.00306951, -0.04859496, -0.04148008,  0.00248486,
       -0.03462191], dtype=float32)

In [20]:
freeze_layers(model, [0])

In [21]:
# train again with only the first layer (index 0) frozen
model.fit(
    dataset.tensor_train_data,
    validation_data=dataset.tensor_val_data,
    epochs=1,
    callbacks=[early_stopping]
)



<keras.src.callbacks.History at 0x7f35e03bd150>

In [22]:
# check which weights have changed
# note: the weight-index-to-layer mapping below is hard-coded, and weight tensors 18-20 are not checked
retrained_weights = model.get_weights()
print(f'Weights in layer 0 stayed the same: {(retrained_weights[0] == original_weights[0]).all()}')
print(f'Weights in layer 1 stayed the same: {(retrained_weights[1] == original_weights[1]).all() & (retrained_weights[2] == original_weights[2]).all() & (retrained_weights[3] == original_weights[3]).all() & (retrained_weights[4] == original_weights[4]).all()}')
print(f'Weights in layer 2 stayed the same: {(retrained_weights[5] == original_weights[5]).all() & (retrained_weights[6] == original_weights[6]).all() & (retrained_weights[7] == original_weights[7]).all()}')
print(f'Weights in layer 3 stayed the same: {(retrained_weights[8] == original_weights[8]).all() & (retrained_weights[9] == original_weights[9]).all() & (retrained_weights[10] == original_weights[10]).all()}')
print(f'Weights in layer 4 stayed the same: {(retrained_weights[11] == original_weights[11]).all()}')
print(f'Weights in layer 5 stayed the same: {(retrained_weights[12] == original_weights[12]).all() & (retrained_weights[13] == original_weights[13]).all()}')
print(f'Weights in layer 6 stayed the same: {(retrained_weights[14] == original_weights[14]).all() & (retrained_weights[15] == original_weights[15]).all() & (retrained_weights[16] == original_weights[16]).all() & (retrained_weights[17] == original_weights[17]).all()}')

Weights in layer 0 stayed the same: True
Weights in layer 1 stayed the same: False
Weights in layer 2 stayed the same: False
Weights in layer 3 stayed the same: False
Weights in layer 4 stayed the same: False
Weights in layer 5 stayed the same: False
Weights in layer 6 stayed the same: False
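The hard-coded comparison above depends on knowing which entries of `model.get_weights()` belong to which layer. A sketch that snapshots the weights per top-level layer instead, using the Keras `Layer.get_weights()` method (which returns NumPy arrays), so every weight tensor is covered and layers without weights simply count as unchanged. The function names are hypothetical:

```python
import numpy as np


def snapshot_weights(model) -> dict:
    """Copy the current weights of every top-level layer, keyed by layer index."""
    return {i: [np.array(w, copy=True) for w in layer.get_weights()]
            for i, layer in enumerate(model.layers)}


def unchanged_layers(before: dict, after: dict) -> dict:
    """Map each layer index to True if all of its weight tensors are identical."""
    return {i: all(np.array_equal(a, b) for a, b in zip(before[i], after[i]))
            for i in before}


# Hypothetical usage with the model and dataset from above:
# before = snapshot_weights(model)
# model.fit(dataset.tensor_train_data, validation_data=dataset.tensor_val_data, epochs=1)
# after = snapshot_weights(model)
# for i, same in unchanged_layers(before, after).items():
#     print(f'Weights in layer {i} stayed the same: {same}')
```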


### Freezing of individual sublayers

*Function to see which layers and sublayers are trainable* 

In [27]:
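def check_trainability(model, sublayers=False):
    for lay in model.layers:
        if not sublayers:
            # only list the top-level layers
            print(f'{lay} trainable: {lay.trainable}')
            continue

        # list the sublayers one level down, separated by blank lines;
        # layers without a `layers` attribute (e.g. Embedding) are printed themselves
        print()
        sub = getattr(lay, 'layers', None)
        if sub is None:
            print(f'{lay} trainable: {lay.trainable}')
        else:
            for lay2 in sub:
                print(f'{lay2} trainable: {lay2.trainable}')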
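`check_trainability` only looks one level deep. A recursive sketch that walks any number of nesting levels: it descends into every object that exposes a `layers` attribute (Keras `Model`/`Sequential`) and treats everything else as a leaf, so wrappers without a `layers` attribute are listed but not expanded. The function name is hypothetical:

```python
def trainability_tree(layer, depth=0, lines=None):
    """Recursively list all nested sublayers of `layer` with their `trainable` flag.

    Each nesting level is indented by two spaces; the root object itself is not listed.
    """
    if lines is None:
        lines = []
    for sub in getattr(layer, 'layers', []):
        lines.append(f"{'  ' * depth}{sub.name}: trainable={sub.trainable}")
        trainability_tree(sub, depth + 1, lines)
    return lines


# e.g. print('\n'.join(trainability_tree(model)))
```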

In [28]:
check_trainability(model)

<keras.src.layers.core.embedding.Embedding object at 0x7f36246d4610> trainable: False
<keras.src.engine.sequential.Sequential object at 0x7f36246d7850> trainable: True
<keras.src.engine.sequential.Sequential object at 0x7f36245cc6d0> trainable: True
<keras.src.engine.sequential.Sequential object at 0x7f36245cd9f0> trainable: True
<dlomix.layers.attention.AttentionLayer object at 0x7f36245cddb0> trainable: True
<keras.src.engine.sequential.Sequential object at 0x7f36245ce710> trainable: True
<keras.src.engine.sequential.Sequential object at 0x7f36245cf6a0> trainable: True


In [29]:
check_trainability(model, sublayers=True)


<keras.src.layers.core.embedding.Embedding object at 0x7f36246d4610> trainable: False

<keras.src.layers.rnn.bidirectional.Bidirectional object at 0x7f36246d5300> trainable: True
<keras.src.layers.regularization.dropout.Dropout object at 0x7f36246d6980> trainable: True
<keras.src.layers.rnn.gru.GRU object at 0x7f36246d6c50> trainable: True
<keras.src.layers.regularization.dropout.Dropout object at 0x7f36246d7580> trainable: True

<keras.src.layers.merging.concatenate.Concatenate object at 0x7f36246d7cd0> trainable: True
<keras.src.layers.core.dense.Dense object at 0x7f36245cc070> trainable: True
<keras.src.layers.regularization.dropout.Dropout object at 0x7f36245cc460> trainable: True

<keras.src.layers.rnn.gru.GRU object at 0x7f36245cca60> trainable: True
<keras.src.layers.regularization.dropout.Dropout object at 0x7f36245cd570> trainable: True
<dlomix.layers.attention.DecoderAttentionLayer object at 0x7f36245cd840> trainable: True

<dlomix.layers.attention.AttentionLayer object at 0

In [125]:
model.layers[0]

<keras.src.layers.core.embedding.Embedding at 0x7f7c801244f0>

In [52]:
# sublayers are accessed via .layers, e.g. model.layers[1].layers[0]
model.layers[1]

<keras.src.engine.sequential.Sequential at 0x7f7c80127730>

*Work in progress: the recursive freezing functions below do not work yet.*

In [116]:
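# function to recursively freeze layers (work in progress)
def freeze_layers_rec(layer, layers_to_freeze) -> None:
    # TODO: make the function able to go arbitrarily deep into the model layers to freeze only specific parts
    # known issue: the recursive call passes the loop index `l` instead of the remaining nested
    # index list, so it only ever freezes whole first-level layers; a flat list such as [0, 2]
    # fails because its integer entries are indexed with [0]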
    if isinstance(layers_to_freeze, int):
        print(layers_to_freeze)
        print(layer)
        layer.trainable = False
    else:
        for l in range(0, len(layers_to_freeze)):
            print(layers_to_freeze[l])
            freeze_layers_rec(layer.layers[layers_to_freeze[l][0]], l)


# function to freeze certain layers of a model (indicated by index)
# note: this redefines the index-based freeze_layers function from above
def freeze_layers(model, layers_to_freeze) -> None:
    # reset previous freezing configurations
    model.trainable = True

    # call the recursive function on the whole model
    freeze_layers_rec(model, layers_to_freeze)

    # compile the model again to make changes take effect
    model.compile(
        optimizer=optimizer,
        loss=masked_spectral_distance,
        metrics=[masked_pearson_correlation_distance]
    )
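One way to get arbitrarily deep freezing is to address layers by index paths instead of nested lists. The sketch below is a hypothetical alternative to `freeze_layers_rec`, not a fix of it: a path like `(1, 0)` means `model.layers[1].layers[0]`, and everything is unfrozen first so earlier configurations do not leak through:

```python
from collections.abc import Sequence


def set_trainable_recursive(layer, value: bool) -> None:
    """Set `trainable` on `layer` and, explicitly, on all nested sublayers."""
    layer.trainable = value
    for sub in getattr(layer, 'layers', []):
        set_trainable_recursive(sub, value)


def freeze_by_paths(model, paths: Sequence[Sequence[int]]) -> None:
    """Unfreeze everything, then freeze the layer at each index path.

    A path such as (1, 0) means model.layers[1].layers[0]; a path of length one
    freezes a whole top-level layer, so any nesting depth can be addressed.
    """
    set_trainable_recursive(model, True)
    for path in paths:
        target = model
        for idx in path:
            target = target.layers[idx]
        target.trainable = False
```

For example, `freeze_by_paths(model, [(0,), (1, 0)])` would freeze the embedding and the `Bidirectional` layer inside the first `Sequential` block (per the sublayer listing above). The model still has to be recompiled afterwards.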


In [130]:
for lay in model.layers:
    print()
    try:
        for lay2 in lay.layers:
            print(f'{lay2} trainable: {lay2.trainable}')
    except Exception:
        print(f'{lay} trainable: {lay.trainable}')



<keras.src.layers.core.embedding.Embedding object at 0x7f7c801244f0> trainable: False

<keras.src.layers.rnn.bidirectional.Bidirectional object at 0x7f7c801251e0> trainable: True
<keras.src.layers.regularization.dropout.Dropout object at 0x7f7c80126860> trainable: True
<keras.src.layers.rnn.gru.GRU object at 0x7f7c80126b30> trainable: True
<keras.src.layers.regularization.dropout.Dropout object at 0x7f7c80127460> trainable: True

<keras.src.layers.merging.concatenate.Concatenate object at 0x7f7c80127bb0> trainable: False
<keras.src.layers.core.dense.Dense object at 0x7f7c80127f10> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f7c80220340> trainable: False

<keras.src.layers.rnn.gru.GRU object at 0x7f7c80220940> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f7c80221450> trainable: False
<dlomix.layers.attention.DecoderAttentionLayer object at 0x7f7c80221720> trainable: False

<dlomix.layers.attention.AttentionLayer objec