# Implementation of Freezing Strategies for Refinement/Transfer Learning
This notebook tries out different layer-freezing mechanisms and controls using TensorFlow and Keras. The goal is to have easy-to-use functions for further training, refinement, and transfer learning of the Prosit models, and ultimately to implement them in DLOmix.

### Imports

In [None]:
import dlomix
import tensorflow as tf
from tensorflow import keras 
import yaml
from dlomix.losses import masked_spectral_distance, masked_pearson_correlation_distance
from tensorflow.keras.callbacks import EarlyStopping

### Prepare for training

##### Read the config file

In [None]:
config_path = "baseline_training/config_files/noptm_baseline_small_bs1024.yaml"

# Read the training configuration. Decode explicitly as UTF-8 instead of
# relying on the platform's default locale encoding (which differs e.g. on
# Windows), so the same config loads identically everywhere.
with open(config_path, 'r', encoding='utf-8') as yaml_file:
    config = yaml.safe_load(yaml_file)


##### Load the dataset and the PTM alphabet

In [None]:
# load dataset
# NOTE(review): FragmentIonIntensityDataset is not referenced in the cells
# shown here; it appears to be unused in this notebook.
from dlomix.data import FragmentIonIntensityDataset

# PTMS_ALPHABET is passed to the model constructor further below. The
# commented-out import presumably refers to an earlier local copy — confirm
# before deleting it.
# from misc import PTMS_ALPHABET
from dlomix.constants import PTMS_ALPHABET

# Load the already-processed dataset from the path given in the config. Its
# `tensor_train_data` / `tensor_val_data` attributes feed `model.fit` below.
from dlomix.data import load_processed_dataset
dataset = load_processed_dataset(config['dataset']['processed_path'])

##### Initialize the optimizer and callbacks

In [None]:
# Adam optimizer using the learning rate from the config. This one instance is
# passed to every `model.compile` call in this notebook, so the same optimizer
# object (and whatever state it has accumulated) is reused across runs.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=config['training']['learning_rate'],
)

# Stop training once val_loss has failed to improve by at least 0.001 for
# 20 consecutive epochs, restoring the best weights when stopping early.
early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=20,
    min_delta=0.001,
    restore_best_weights=True,
)


##### Import the model

In [None]:
from dlomix.models import PrositIntensityPredictor

# Map the model's expected input keys onto column names of the processed dataset.
input_mapping = {
    "SEQUENCE_KEY": "modified_sequence",
    "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed",
    "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot",
    "FRAGMENTATION_TYPE_KEY": "method_nbr",
}

# Every mapped column except the peptide sequence is passed as metadata
# (same values and order as listing the column names by hand).
meta_data_keys = [
    input_mapping["COLLISION_ENERGY_KEY"],
    input_mapping["PRECURSOR_CHARGE_KEY"],
    input_mapping["FRAGMENTATION_TYPE_KEY"],
]

# Prosit intensity predictor without PTM features and without termini tokens.
model = PrositIntensityPredictor(
    seq_length=config['dataset']['seq_length'],
    alphabet=PTMS_ALPHABET,
    use_prosit_ptm_features=False,
    with_termini=False,
    input_keys=input_mapping,
    meta_data_keys=meta_data_keys,
)

# Spectral-distance loss, with Pearson-correlation distance tracked as a metric.
model.compile(
    optimizer=optimizer,
    loss=masked_spectral_distance,
    metrics=[masked_pearson_correlation_distance],
)

In [None]:
# Print the model's layers and parameter counts.
# NOTE(review): if PrositIntensityPredictor is a subclassed Keras model, summary()
# only works once the model has been built (called on data or via build()) —
# confirm this cell runs as intended before the first fit.
model.summary()

### Initial training

In [None]:
# Initial training of all layers on the training split, validated on the
# validation split; early_stopping may end training before num_epochs.
model.fit(
    dataset.tensor_train_data,
    validation_data=dataset.tensor_val_data,
    epochs=config['training']['num_epochs'],
    callbacks=[early_stopping]
)

### Freezing

##### Check if weights are correctly assigned to trainable/non-trainable:


In [None]:
# These counts are weight *tensors* (variables), not layers: a single layer can
# own several tensors, e.g. a kernel and a bias.
print(f'weight tensors: {len(model.weights)}')
print(f'trainable weight tensors: {len(model.trainable_weights)}')
print(f'non-trainable weight tensors: {len(model.non_trainable_weights)}')

##### Freeze the whole model:

In [None]:
# Setting `trainable` on the model itself applies to all nested layers, so
# every weight tensor becomes non-trainable.
model.trainable = False

In [None]:
# Same counts as before freezing; these are weight *tensors*, not layers.
# With the whole model frozen, all tensors should be reported as non-trainable.
print(f'weight tensors: {len(model.weights)}')
print(f'trainable weight tensors: {len(model.trainable_weights)}')
print(f'non-trainable weight tensors: {len(model.non_trainable_weights)}')

Unfreeze again:

In [None]:
# Make every layer trainable again before experimenting with selective freezing.
model.trainable = True

##### Freeze the model and only keep the first and/or the last layer trainable:

In [None]:
# function to freeze all layers except first and/or last layer
def freeze_model(
    model: "dlomix.models.prosit.PrositIntensityPredictor",
    trainable_first_layer: bool = False,
    trainable_last_layer: bool = False,
    *,
    first_layer_name: str = "embedding",
    last_layer_name: str = "time_dense",
    compile_kwargs=None,
) -> None:
    """Freeze every layer of ``model`` except, optionally, the first and/or last one.

    The model is first made fully trainable, then every *leaf* layer (at any
    nesting depth) is set to non-trainable. Freezing happens at the leaf level
    rather than on container layers because a container with
    ``trainable = False`` overshadows the ``trainable`` flag of everything
    nested inside it, which would make re-enabling a nested layer impossible.

    The first/last layers are located by name anywhere in the layer tree, so
    the lookup does not depend on auto-generated container names such as
    ``sequential_4``. Keras derives those from a session-wide counter, so they
    change when other models have been created.

    Args:
        model: model to modify in place; it is recompiled at the end.
        trainable_first_layer: keep the layer named ``first_layer_name`` trainable.
        trainable_last_layer: keep the layer named ``last_layer_name`` trainable.
        first_layer_name: name of the first layer. An exact match is preferred;
            otherwise a de-duplicated variant ``<name>_<n>`` is accepted.
        last_layer_name: name of the last layer, matched the same way.
        compile_kwargs: keyword arguments for ``model.compile``. If omitted, the
            notebook-level ``optimizer``, ``masked_spectral_distance`` loss and
            ``masked_pearson_correlation_distance`` metric are used.

    Raises:
        ValueError: if a requested layer cannot be found. Lookups happen before
            any trainable flag is touched, so the model is left unchanged then.
    """
    # Resolve the layers to keep trainable first, so a missing name cannot
    # leave the model half-frozen.
    keep_trainable = []
    if trainable_first_layer:
        keep_trainable.append(_find_nested_layer(model, first_layer_name))
    if trainable_last_layer:
        keep_trainable.append(_find_nested_layer(model, last_layer_name))

    # Reset everything to trainable: on a Keras model this propagates to all
    # nested layers and clears freezing left over from earlier calls.
    model.trainable = True

    # Freeze at the lowest level so no frozen container overshadows a nested
    # layer that is re-enabled below.
    _freeze_leaf_layers(model)

    for layer in keep_trainable:
        layer.trainable = True

    if compile_kwargs is None:
        # Default to the objects used for the initial compile. Reusing the same
        # `optimizer` instance keeps its accumulated state; pass compile_kwargs
        # with a fresh optimizer to start from scratch.
        compile_kwargs = {
            "optimizer": optimizer,
            "loss": masked_spectral_distance,
            "metrics": [masked_pearson_correlation_distance],
        }

    # Compile the model again so the trainable changes take effect.
    model.compile(**compile_kwargs)


def _iter_nested_layers(container):
    """Yield every layer nested under ``container``, depth-first, in declaration order."""
    for layer in getattr(container, "layers", ()):
        yield layer
        yield from _iter_nested_layers(layer)


def _find_nested_layer(model, name):
    """Return the nested layer called ``name`` (or ``<name>_<n>``); raise ValueError if absent."""
    candidates = list(_iter_nested_layers(model))
    for layer in candidates:
        if layer.name == name:
            return layer
    # Keras de-duplicates auto-generated names by appending "_<n>".
    prefix = name + "_"
    for layer in candidates:
        if layer.name.startswith(prefix) and layer.name[len(prefix):].isdigit():
            return layer
    raise ValueError(f"No layer named {name!r} (or {name!r} with a '_<n>' suffix) found in the model")


def _freeze_leaf_layers(container) -> None:
    """Set ``trainable = False`` on every leaf layer nested under ``container``."""
    for layer in container.layers:
        if hasattr(layer, "layers"):
            _freeze_leaf_layers(layer)
        else:
            layer.trainable = False

In [None]:
def check_trainability(model, sublayers=False):
    """Print the ``trainable`` flag of each top-level layer of ``model``.

    Flags are printed as stored on each layer. A layer nested inside a
    container whose flag is False is effectively frozen even if its own flag
    reads True.

    Args:
        model: object exposing a ``layers`` list (e.g. a Keras model).
        sublayers: if True, print a blank line before each top-level layer and,
            for layers that themselves have ``layers``, also print the flag of
            each direct sublayer (one level deep).
    """
    for layer in model.layers:
        if not sublayers:
            print(f'{layer} trainable: {layer.trainable}')
            continue

        print()
        # Only the attribute probe decides whether `layer` is a container. An
        # AttributeError raised while printing children is a real error and is
        # no longer misreported as "this layer has no sublayers".
        if not hasattr(layer, "layers"):
            print(f'{layer} trainable: {layer.trainable}')
            continue

        print(f'Sequential {layer} trainable: {layer.trainable}')
        for sublayer in layer.layers:
            print(f'{sublayer} trainable: {sublayer.trainable}')


##### Testing the function:

*Freeze all layers except the first layer:*

In [None]:
# Only the first layer should remain trainable.
freeze_model(model, trainable_first_layer=True)
check_trainability(model, sublayers=True)

*Freeze all layers except the last layer:*

In [None]:
# Only the last layer should remain trainable.
freeze_model(model, trainable_last_layer=True)
check_trainability(model, sublayers=True)

*Freeze all layers except the first and the last layer:*

In [None]:
# Both the first and the last layer should remain trainable.
freeze_model(model, trainable_first_layer=True, trainable_last_layer=True)
check_trainability(model, sublayers=True)

### Continue Training with frozen layers

In [None]:
# Apply the freezing configuration used for the retraining run below. This
# repeats the previous cell so this section does not depend on it.
freeze_model(model, trainable_first_layer=True, trainable_last_layer=True)
check_trainability(model, sublayers=True)

In [None]:
# Snapshot the weights before retraining so they can be compared afterwards.
# get_weights() returns NumPy arrays, so further training does not alter them.
original_weights = model.get_weights()

In [None]:
# train again while only the first layer and the last layer are trainable
# train again while only the first layer and the last layer are trainable
# NOTE(review): with epochs=1, early_stopping (patience=20) cannot trigger here.
model.fit(
    dataset.tensor_train_data,
    validation_data=dataset.tensor_val_data,
    epochs=1,
    callbacks=[early_stopping]
)

In [None]:
# check which weights have changed
# Compare each weight tensor after retraining with its pre-training snapshot;
# only tensors of trainable layers are expected to differ.
retrained_weights = model.get_weights()
for idx, (before, after) in enumerate(zip(original_weights, retrained_weights)):
    unchanged = (before == after).all()
    print(f'weights {idx} stayed the same: {unchanged}')
    

Two weight tensors changed for the last layer; both belong to the final `time_dense` layer (a 512 × 6 kernel and a 6-element bias):

In [None]:
# Inspect the two changed tensors. NOTE(review): indices 19 and 20 are taken
# from the comparison output above and depend on the model's weight order.
print(retrained_weights[20])  # bias: 6 values
print(len(retrained_weights[19]))  # kernel rows: 512
print([len(x) for x in retrained_weights[19]])  # each kernel row has 6 values
print(512 * 6 + 6)  # expected parameter count of the layer: kernel + bias

# Keras auto-names the container holding `time_dense` (it was "sequential_4"
# in the original run) from a session-wide counter, so find it by content.
regressor = next(
    layer for layer in model.layers
    if any(sub.name == "time_dense" for sub in getattr(layer, "layers", []))
)
# summary() prints the table itself and returns None; wrapping it in print()
# would emit a stray "None".
regressor.summary()
