# Implementation of Freezing Strategies for Refinement/Transfer Learning
This notebook tries out different freezing mechanisms and controls using TensorFlow and Keras. The goal is an easy-to-use function for further training, refinement and transfer learning of the Prosit models, which can ultimately be implemented in DLOmix.

### Imports

In [1]:
import dlomix
import tensorflow as tf
from tensorflow import keras 
import yaml
from dlomix.losses import masked_spectral_distance, masked_pearson_correlation_distance
from tensorflow.keras.callbacks import EarlyStopping



### Prepare stuff for training

##### Read the config file

In [2]:
config_path = "baseline_training/config_files/noptm_baseline_small_bs1024.yaml"

with open(config_path, 'r') as yaml_file:
    config = yaml.safe_load(yaml_file)
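Later cells read four nested values from this config: `dataset.processed_path`, `dataset.seq_length`, `training.learning_rate` and `training.num_epochs`. The full schema of the DLOmix baseline config files is not shown here, so the following is only a minimal sketch, assuming these four keys are all this notebook needs. `REQUIRED_CONFIG_KEYS` and `check_config` are hypothetical helpers, and the example values are placeholders, not the contents of `noptm_baseline_small_bs1024.yaml`.

```python
# Hypothetical helper: the nested keys that later cells of this notebook access.
REQUIRED_CONFIG_KEYS = [
    ("dataset", "processed_path"),
    ("dataset", "seq_length"),
    ("training", "learning_rate"),
    ("training", "num_epochs"),
]


def check_config(config: dict) -> list:
    """Return the dotted names of required keys that are missing from `config`."""
    missing = []
    for section, key in REQUIRED_CONFIG_KEYS:
        # `or {}` also covers a section that is present but empty (None) in the YAML file
        if key not in (config.get(section) or {}):
            missing.append(f"{section}.{key}")
    return missing


# Placeholder values, not taken from the real config file.
example_config = {
    "dataset": {"processed_path": "path/to/processed_dataset", "seq_length": 30},
    "training": {"learning_rate": 1e-4, "num_epochs": 10},
}
print(check_config(example_config))  # []
print(check_config({"dataset": {"seq_length": 30}}))
```

Calling `check_config(config)` right after `yaml.safe_load` would fail early with a readable list of missing keys instead of a `KeyError` several cells later.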


##### Load the dataset and the PTM alphabet

In [3]:
# load dataset
from dlomix.data import FragmentIonIntensityDataset

# from misc import PTMS_ALPHABET
from dlomix.constants import PTMS_ALPHABET

from dlomix.data import load_processed_dataset
dataset = load_processed_dataset(config['dataset']['processed_path'])




Avaliable feature extractors are (use the key of the following dict and pass it to features_to_extract in the Dataset Class):
{
   "atom_count": "Atom count of PTM.",
   "delta_mass": "Delta mass of PTM.",
   "mod_gain": "Gain of atoms due to PTM.",
   "mod_loss": "Loss of atoms due to PTM.",
   "red_smiles": "Reduced SMILES representation of PTM."
}.
When writing your own feature extractor, you can either
    (1) use the FeatureExtractor class or
    (2) write a function that can be mapped to the Hugging Face dataset.
In both cases, you can access the parsed sequence information from the dataset using the following keys, which all provide python lists:
    - _parsed_sequence: parsed sequence
    - _n_term_mods: N-terminal modifications
    - _c_term_mods: C-terminal modifications



##### Initialize the optimizer and callbacks

In [4]:
optimizer = tf.keras.optimizers.Adam(learning_rate=config['training']['learning_rate'])

early_stopping = EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,
    patience=20,
    restore_best_weights=True)


2024-06-12 12:32:20.244576: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
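With `monitor="val_loss"`, `min_delta=0.001` and `patience=20`, training stops once the validation loss has failed to improve by more than 0.001 for 20 consecutive epochs, and `restore_best_weights=True` asks Keras to go back to the weights of the best epoch. The function below is a simplified, framework-free sketch of that stopping rule as we understand Keras applies it; it ignores `baseline`, `start_from_epoch` and the weight restoration itself, and `simulate_early_stopping` is a hypothetical helper, not part of Keras.

```python
def simulate_early_stopping(val_losses, min_delta=0.001, patience=20):
    """Replay a sequence of validation losses through a simplified early-stopping rule.

    Returns (stopped_epoch, best_epoch); stopped_epoch is None if training
    would run through all epochs.
    """
    best = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses):
        wait += 1
        # Only a decrease of more than min_delta counts as an improvement.
        if loss + min_delta < best:
            best, best_epoch, wait = loss, epoch, 0
            continue
        # Stop after `patience` consecutive epochs without improvement.
        if wait >= patience and epoch > 0:
            return epoch, best_epoch
    return None, best_epoch


# Epochs 2 and 3 improve by less than min_delta, so with patience=2 training stops at epoch 3.
print(simulate_early_stopping([1.0, 0.9, 0.8995, 0.8999], patience=2))  # (3, 1)
```

For the single-epoch run with frozen layers further down (`epochs=1`), a patience of 20 can never be reached, so the callback cannot stop that run early.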


##### Import the model

In [5]:
from dlomix.models import PrositIntensityPredictor

input_mapping = {
        "SEQUENCE_KEY": "modified_sequence",
        "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed",
        "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot",
        "FRAGMENTATION_TYPE_KEY": "method_nbr",
    }

meta_data_keys=["collision_energy_aligned_normed", "precursor_charge_onehot", "method_nbr"]

model = PrositIntensityPredictor(
    seq_length=config['dataset']['seq_length'],
    alphabet=PTMS_ALPHABET,
    use_prosit_ptm_features=False,
    with_termini=False,
    input_keys=input_mapping,
    meta_data_keys=meta_data_keys
)

model.compile(
    optimizer=optimizer,
    loss=masked_spectral_distance,
    metrics=[masked_pearson_correlation_distance]
)

### Initial training

In [6]:
model.fit(
    dataset.tensor_train_data,
    validation_data=dataset.tensor_val_data,
    epochs=config['training']['num_epochs'],
    callbacks=[early_stopping]
)



<keras.src.callbacks.History at 0x7f5b843d0fa0>

In [7]:
model.summary()

Model: "prosit_intensity_predictor"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       multiple                  928       
                                                                 
 sequential (Sequential)     (None, 30, 512)           1996800   
                                                                 
 sequential_1 (Sequential)   multiple                  4608      
                                                                 
 sequential_2 (Sequential)   (None, 29, 512)           1576806   
                                                                 
 encoder_att (AttentionLaye  multiple                  542       
 r)                                                              
                                                                 
 sequential_3 (Sequential)   multiple                  0         
                                        

### Freezing

##### Check if weights are correctly assigned to trainable/non-trainable:


In [8]:
print(f'weight tensors: {len(model.weights)}')
print(f'trainable weight tensors: {len(model.trainable_weights)}')
print(f'non-trainable weight tensors: {len(model.non_trainable_weights)}')

weight tensors: 21
trainable weight tensors: 21
non-trainable weight tensors: 0


##### Freeze the whole model:

In [9]:
model.trainable = False

In [10]:
print(f'weight tensors: {len(model.weights)}')
print(f'trainable weight tensors: {len(model.trainable_weights)}')
print(f'non-trainable weight tensors: {len(model.non_trainable_weights)}')

weight tensors: 21
trainable weight tensors: 0
non-trainable weight tensors: 21


Unfreeze again:

In [11]:
model.trainable = True

##### Freeze the model and only keep the first and/or the last layer trainable:

In [12]:
# Freeze all layers except (optionally) the first and/or the last layer
def freeze_model(model: dlomix.models.prosit.PrositIntensityPredictor,
                 trainable_first_layer: bool = False,
                 trainable_last_layer: bool = False) -> None:

    # Reset everything to trainable first: 'model.trainable = False'
    # overrides the trainable flags of all sublayers
    model.trainable = True

    # Set trainable=False at the lowest level so the flags of individual
    # sublayers are not overridden by their parent
    for lay in model.layers:
        try:
            for sublay in lay.layers:
                sublay.trainable = False
        except AttributeError:
            # plain layer without sublayers (e.g. the embedding)
            lay.trainable = False

    if trainable_first_layer:
        first_layer = model.get_layer(name="embedding")
        first_layer.trainable = True

    if trainable_last_layer:
        # NOTE: "sequential_4" is an auto-generated name; it can change if
        # other Sequential models were created earlier in the session
        last_layer = model.get_layer(name="sequential_4").get_layer(name="time_dense")
        last_layer.trainable = True

    # Recompile so the changed trainable flags take effect; reuse the
    # optimizer the model was compiled with instead of a global variable
    model.compile(
        optimizer=model.optimizer,
        loss=masked_spectral_distance,
        metrics=[masked_pearson_correlation_distance]
    )
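The comments in `freeze_model` rely on how Keras handles the `trainable` flag: setting it on a layer propagates it to all nested sublayers, and a layer that is not trainable reports no trainable weights, even if one of its sublayers is later set back to `trainable = True`. The toy class below is a framework-free sketch of that behaviour under exactly these assumptions. `ToyLayer` and the small layer tree are hypothetical stand-ins, not Keras objects; the names only loosely mirror the Prosit model.

```python
class ToyLayer:
    """Minimal stand-in for a Keras layer tree (hypothetical, for illustration only)."""

    def __init__(self, name, n_weights=0, sublayers=()):
        self.name = name
        self.n_weights = n_weights
        self.sublayers = list(sublayers)
        self._trainable = True

    def flatten(self):
        """Yield this layer and all nested sublayers, depth first."""
        yield self
        for sub in self.sublayers:
            yield from sub.flatten()

    @property
    def trainable(self):
        return self._trainable

    @trainable.setter
    def trainable(self, value):
        # Mimics Keras: the flag is pushed down to every nested sublayer.
        for layer in self.flatten():
            layer._trainable = value

    def trainable_weight_count(self):
        # Mimics Keras: a frozen layer hides the weights of all its sublayers.
        if not self._trainable:
            return 0
        return self.n_weights + sum(s.trainable_weight_count() for s in self.sublayers)


head = ToyLayer("time_dense", n_weights=2)
toy_model = ToyLayer("model", sublayers=[
    ToyLayer("embedding", n_weights=1),
    ToyLayer("sequential_4", sublayers=[head]),
])

# Freezing the top level and then unfreezing only the head does not work ...
toy_model.trainable = False
head.trainable = True
print(toy_model.trainable_weight_count())  # 0

# ... so, as in freeze_model, unfreeze the top level and freeze only the leaves.
toy_model.trainable = True
for layer in toy_model.flatten():
    if not layer.sublayers:
        layer.trainable = False
head.trainable = True
print(toy_model.trainable_weight_count())  # 2
```

The toy uses `toy_model` rather than `model` so that running it inside this notebook does not overwrite the Prosit model.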

In [13]:
def check_trainability(model, sublayers=False):
    """Print the trainable flag of every layer and, optionally, of its sublayers."""
    for lay in model.layers:
        if sublayers:
            print()
            if hasattr(lay, "layers"):
                # container layer (e.g. Sequential): print it and its sublayers
                print(f'Sequential {lay} trainable: {lay.trainable}')
                for lay2 in lay.layers:
                    print(f'{lay2} trainable: {lay2.trainable}')
            else:
                print(f'{lay} trainable: {lay.trainable}')
        else:
            print(f'{lay} trainable: {lay.trainable}')


##### Testing the function:

*Freeze all layers except the first layer:*

In [14]:
freeze_model(model, trainable_first_layer=True)
check_trainability(model, sublayers=True)


<keras.src.layers.core.embedding.Embedding object at 0x7f5b846589a0> trainable: True

Sequential <keras.src.engine.sequential.Sequential object at 0x7f5b8465bbe0> trainable: True
<keras.src.layers.rnn.bidirectional.Bidirectional object at 0x7f5b84659690> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f5b8465ad10> trainable: False
<keras.src.layers.rnn.gru.GRU object at 0x7f5b8465afe0> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f5b8465b910> trainable: False

Sequential <keras.src.engine.sequential.Sequential object at 0x7f5b84740a60> trainable: True
<keras.src.layers.merging.concatenate.Concatenate object at 0x7f5b8465bc10> trainable: False
<keras.src.layers.core.dense.Dense object at 0x7f5b84740400> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f5b847407f0> trainable: False

Sequential <keras.src.engine.sequential.Sequential object at 0x7f5b84741d80> trainable: True
<keras.src.layers.r

*Freeze all layers except the last layer:*

In [15]:
freeze_model(model, trainable_last_layer=True)
check_trainability(model, sublayers=True)


<keras.src.layers.core.embedding.Embedding object at 0x7f5b846589a0> trainable: False

Sequential <keras.src.engine.sequential.Sequential object at 0x7f5b8465bbe0> trainable: True
<keras.src.layers.rnn.bidirectional.Bidirectional object at 0x7f5b84659690> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f5b8465ad10> trainable: False
<keras.src.layers.rnn.gru.GRU object at 0x7f5b8465afe0> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f5b8465b910> trainable: False

Sequential <keras.src.engine.sequential.Sequential object at 0x7f5b84740a60> trainable: True
<keras.src.layers.merging.concatenate.Concatenate object at 0x7f5b8465bc10> trainable: False
<keras.src.layers.core.dense.Dense object at 0x7f5b84740400> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f5b847407f0> trainable: False

Sequential <keras.src.engine.sequential.Sequential object at 0x7f5b84741d80> trainable: True
<keras.src.layers.

*Freeze all layers except the first and the last layer:*

In [16]:
freeze_model(model, trainable_first_layer=True, trainable_last_layer=True)
check_trainability(model, sublayers=True)


<keras.src.layers.core.embedding.Embedding object at 0x7f5b846589a0> trainable: True

Sequential <keras.src.engine.sequential.Sequential object at 0x7f5b8465bbe0> trainable: True
<keras.src.layers.rnn.bidirectional.Bidirectional object at 0x7f5b84659690> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f5b8465ad10> trainable: False
<keras.src.layers.rnn.gru.GRU object at 0x7f5b8465afe0> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f5b8465b910> trainable: False

Sequential <keras.src.engine.sequential.Sequential object at 0x7f5b84740a60> trainable: True
<keras.src.layers.merging.concatenate.Concatenate object at 0x7f5b8465bc10> trainable: False
<keras.src.layers.core.dense.Dense object at 0x7f5b84740400> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f5b847407f0> trainable: False

Sequential <keras.src.engine.sequential.Sequential object at 0x7f5b84741d80> trainable: True
<keras.src.layers.r

### Continue Training with frozen layers

In [17]:
freeze_model(model, trainable_first_layer=True, trainable_last_layer=True)
check_trainability(model, sublayers=True)


<keras.src.layers.core.embedding.Embedding object at 0x7f5b846589a0> trainable: True

Sequential <keras.src.engine.sequential.Sequential object at 0x7f5b8465bbe0> trainable: True
<keras.src.layers.rnn.bidirectional.Bidirectional object at 0x7f5b84659690> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f5b8465ad10> trainable: False
<keras.src.layers.rnn.gru.GRU object at 0x7f5b8465afe0> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f5b8465b910> trainable: False

Sequential <keras.src.engine.sequential.Sequential object at 0x7f5b84740a60> trainable: True
<keras.src.layers.merging.concatenate.Concatenate object at 0x7f5b8465bc10> trainable: False
<keras.src.layers.core.dense.Dense object at 0x7f5b84740400> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7f5b847407f0> trainable: False

Sequential <keras.src.engine.sequential.Sequential object at 0x7f5b84741d80> trainable: True
<keras.src.layers.r

In [18]:
original_weights = model.get_weights()

In [19]:
# train again while only the first layer and the last layer are trainable
model.fit(
    dataset.tensor_train_data,
    validation_data=dataset.tensor_val_data,
    epochs=1,
    callbacks=[early_stopping]
)



<keras.src.callbacks.History at 0x7f5b283bd030>

In [20]:
# check which weight tensors have changed
retrained_weights = model.get_weights()
for i, (before, after) in enumerate(zip(original_weights, retrained_weights)):
    print(f'weights {i} stayed the same: {(before == after).all()}')

weights 0 stayed the same: False
weights 1 stayed the same: True
weights 2 stayed the same: True
weights 3 stayed the same: True
weights 4 stayed the same: True
weights 5 stayed the same: True
weights 6 stayed the same: True
weights 7 stayed the same: True
weights 8 stayed the same: True
weights 9 stayed the same: True
weights 10 stayed the same: True
weights 11 stayed the same: True
weights 12 stayed the same: True
weights 13 stayed the same: True
weights 14 stayed the same: True
weights 15 stayed the same: True
weights 16 stayed the same: True
weights 17 stayed the same: True
weights 18 stayed the same: True
weights 19 stayed the same: False
weights 20 stayed the same: False
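The check above prints one line per tensor. A variant that reports only the tensors that changed, together with their shapes, is easier to map back to layers. This is a sketch assuming, as in the cell above, two lists of NumPy arrays as returned by `model.get_weights()`; `changed_weight_indices` is a hypothetical helper and the arrays in the example are made up.

```python
import numpy as np


def changed_weight_indices(before, after):
    """Return (index, shape) for every weight tensor that differs between two snapshots."""
    return [
        (i, old.shape)
        for i, (old, new) in enumerate(zip(before, after))
        if not np.array_equal(old, new)
    ]


# Made-up snapshots: three tensors, only the last one is updated.
before = [np.zeros((4, 2)), np.ones(3), np.zeros(6)]
after = [w.copy() for w in before]
after[2] += 0.5
print(changed_weight_indices(before, after))  # [(2, (6,))]
```

Applied to `original_weights` and `retrained_weights`, it should list indices 0, 19 and 20, matching the output above.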


Besides the embedding matrix (weights 0, the first layer), two weight tensors changed: 19 and 20. Both belong to the last layer, `time_dense`, and are its kernel and bias:

In [21]:
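print(retrained_weights[20])  # bias of time_dense: 6 values
print(len(retrained_weights[19]))  # kernel rows: 512
print([len(x) for x in retrained_weights[19]])  # each kernel row has 6 entries
print(512 * 6 + 6)  # expected parameter count of time_dense
model.get_layer(name="sequential_4").summary()  # summary() prints itself and returns None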


[ 0.00050617  0.00042338  0.00032402  0.00051246  0.00016971 -0.00047507]
512
[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 
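The arithmetic in the cell above (`512 * 6 + 6`) is the parameter count of a dense layer that maps 512 inputs to 6 outputs: a 512 × 6 kernel plus one bias value per output. The NumPy sketch below assumes the shapes inferred from the printed output, `(512, 6)` for weights 19 and `(6,)` for weights 20, and uses zero arrays as placeholders for the real weights; `dense_param_count` is a hypothetical helper.

```python
import numpy as np


def dense_param_count(n_inputs, n_units, use_bias=True):
    """Parameters of a fully connected layer: kernel (n_inputs x n_units) plus optional bias."""
    return n_inputs * n_units + (n_units if use_bias else 0)


# Placeholder tensors with the shapes of time_dense's kernel and bias.
kernel = np.zeros((512, 6))
bias = np.zeros(6)

n_params = kernel.size + bias.size
print(n_params)  # 3078
print(n_params == dense_param_count(512, 6))  # True
```

If the summary of `sequential_4` lists 3078 parameters for `time_dense`, the two changed tensors cover that layer completely.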