# Tutorial for freezing the Prosit Model
This tutorial shows, how to freeze the Prosit Intensity Predictor model and only let the first and last layer remain trainable for refinement and transfer learning.

### Imports

In [1]:
import dlomix
import tensorflow as tf
import yaml
from dlomix.losses import masked_spectral_distance, masked_pearson_correlation_distance

2024-06-12 12:30:14.892621: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-12 12:30:14.892658: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-12 12:30:14.893911: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-12 12:30:14.901013: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Load a pretrained model

In [2]:
model_path = "/cmnfs/proj/prosit_astral/bmpc_dlomix_group/models/baseline_models/noptm_baseline_full_bs1024/"
model = tf.keras.models.load_model(model_path + "85c6c918-4a2a-42e5-aab1-e666121c69a6.keras")
model.summary()

  from .autonotebook import tqdm as notebook_tqdm



Avaliable feature extractors are (use the key of the following dict and pass it to features_to_extract in the Dataset Class):
{
   "atom_count": "Atom count of PTM.",
   "delta_mass": "Delta mass of PTM.",
   "mod_gain": "Gain of atoms due to PTM.",
   "mod_loss": "Loss of atoms due to PTM.",
   "red_smiles": "Reduced SMILES representation of PTM."
}.
When writing your own feature extractor, you can either
    (1) use the FeatureExtractor class or
    (2) write a function that can be mapped to the Hugging Face dataset.
In both cases, you can access the parsed sequence information from the dataset using the following keys, which all provide python lists:
    - _parsed_sequence: parsed sequence
    - _n_term_mods: N-terminal modifications
    - _c_term_mods: C-terminal modifications



2024-06-12 12:30:19.274046: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


Model: "prosit_intensity_predictor"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       multiple                  928       
                                                                 
 sequential (Sequential)     (None, 30, 512)           1996800   
                                                                 
 sequential_1 (Sequential)   multiple                  4608      
                                                                 
 sequential_2 (Sequential)   (None, 29, 512)           1576806   
                                                                 
 encoder_att (AttentionLaye  multiple                  542       
 r)                                                              
                                                                 
 sequential_3 (Sequential)   multiple                  0         
                                        

##### Initialize the optimizer 

In [3]:
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

### Freeze the model

In [4]:
# function to freeze all layers except first and/or last layer
def freeze_model(model:dlomix.models.prosit.PrositIntensityPredictor, trainable_first_layer:bool = False, trainable_last_layer:bool = False) -> None:
    
    # reset everything to trainable, 'model.trainable = False' overshadowes trainable arguments of sublayers
    model.trainable = True 
    
    # go through layers and set trainable to False at lowest level so trainable argument is not overshadowed
    for lay in model.layers:
        try:
            for sublay in lay.layers:
                sublay.trainable = False
        except (AttributeError):
            lay.trainable = False

    if (trainable_first_layer):
        first_layer = model.get_layer(name="embedding")
        first_layer.trainable = True

    if (trainable_last_layer):
        last_layer = model.get_layer(name = "sequential_4").get_layer(name = "time_dense")
        last_layer.trainable = True

    # compile the model again to make changes take effect
    model.compile(
        optimizer=optimizer,
        loss=masked_spectral_distance,
        metrics=[masked_pearson_correlation_distance]
    )

In [5]:
# function to print the trainable attribute of every layer
def check_trainability(model, sublayers = False):
    for lay in model.layers:
        if(sublayers):
            print()
            try:
                lay.layers
                print(f'Sequential {lay} trainable: {lay.trainable}')
                for lay2 in lay.layers:
                    print(f'{lay2} trainable: {lay2.trainable}')
            except(AttributeError):
                print(f'{lay} trainable: {lay.trainable}')
        else:
            print(f'{lay} trainable: {lay.trainable}')


*Freeze all layers except the first and the last layer:*

In [6]:
freeze_model(model, trainable_first_layer=True, trainable_last_layer=True)
check_trainability(model, sublayers=True)


<keras.src.layers.core.embedding.Embedding object at 0x7fa0c2f23880> trainable: True

Sequential <keras.src.engine.sequential.Sequential object at 0x7fa0b87c13c0> trainable: True
<keras.src.layers.rnn.bidirectional.Bidirectional object at 0x7fa0c2f230d0> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7fa0b87c03a0> trainable: False
<keras.src.layers.rnn.gru.GRU object at 0x7fa0b87c06a0> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7fa0b87c1090> trainable: False

Sequential <keras.src.engine.sequential.Sequential object at 0x7fa0b87c2320> trainable: True
<keras.src.layers.merging.concatenate.Concatenate object at 0x7fa0b87c18d0> trainable: False
<keras.src.layers.core.dense.Dense object at 0x7fa0b87c1c60> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7fa0b87c2050> trainable: False

Sequential <keras.src.engine.sequential.Sequential object at 0x7fa0b87c3820> trainable: True
<keras.src.layers.r

### Prepare everything for training

##### Load the dataset and the PTM alphabet

In [7]:
from dlomix.data import load_processed_dataset
dataset = load_processed_dataset("/cmnfs/proj/prosit_astral/bmpc_dlomix_group/datasets/processed/noptm_baseline_small_bs1024")

### Continue Training with frozen layers

In [8]:
freeze_model(model, trainable_first_layer=True, trainable_last_layer=True)
check_trainability(model, sublayers=True)


<keras.src.layers.core.embedding.Embedding object at 0x7fa0c2f23880> trainable: True

Sequential <keras.src.engine.sequential.Sequential object at 0x7fa0b87c13c0> trainable: True
<keras.src.layers.rnn.bidirectional.Bidirectional object at 0x7fa0c2f230d0> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7fa0b87c03a0> trainable: False
<keras.src.layers.rnn.gru.GRU object at 0x7fa0b87c06a0> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7fa0b87c1090> trainable: False

Sequential <keras.src.engine.sequential.Sequential object at 0x7fa0b87c2320> trainable: True
<keras.src.layers.merging.concatenate.Concatenate object at 0x7fa0b87c18d0> trainable: False
<keras.src.layers.core.dense.Dense object at 0x7fa0b87c1c60> trainable: False
<keras.src.layers.regularization.dropout.Dropout object at 0x7fa0b87c2050> trainable: False

Sequential <keras.src.engine.sequential.Sequential object at 0x7fa0b87c3820> trainable: True
<keras.src.layers.r

In [9]:
original_weights = model.get_weights()

In [10]:
# train again while only the first layer and the last layer are trainable
model.fit(
    dataset.tensor_train_data,
    validation_data=dataset.tensor_val_data,
    epochs=1
)



<keras.src.callbacks.History at 0x7fa0a011ca30>

In [11]:
# check which weights have changed
retrained_weights = model.get_weights()
for i, w in enumerate(zip(original_weights, retrained_weights)):
    print(f'weights {i} stayed the same: {(w[0]==w[1]).all()}')
    

weights 0 stayed the same: False
weights 1 stayed the same: True
weights 2 stayed the same: True
weights 3 stayed the same: True
weights 4 stayed the same: True
weights 5 stayed the same: True
weights 6 stayed the same: True
weights 7 stayed the same: True
weights 8 stayed the same: True
weights 9 stayed the same: True
weights 10 stayed the same: True
weights 11 stayed the same: True
weights 12 stayed the same: True
weights 13 stayed the same: True
weights 14 stayed the same: True
weights 15 stayed the same: True
weights 16 stayed the same: True
weights 17 stayed the same: True
weights 18 stayed the same: True
weights 19 stayed the same: False
weights 20 stayed the same: False


Two weight tensors changed for the last layer. Both tensors belong to the last time_dense layer:

In [12]:
print(retrained_weights[20]) # 6
print(len(retrained_weights[19])) # 512
print([len(x) for x in retrained_weights[19]]) # 6
print(512 * 6 + 6) # 3078
print(model.get_layer(name="sequential_4").summary())

[0.15456411 0.05034332 0.01185355 0.07343634 0.02754287 0.00198995]
512
[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 