# Creating Callbacks for Model Refinement and Transfer Learning
This notebook explores the callback mechanisms available in TensorFlow and Keras. The objective is to develop easy-to-use functions that enable further training, refinement, and transfer learning of Prosit models, with the intention of integrating them into DLOmix.

### Configuration file

In [None]:
import yaml

# Path to the YAML configuration file (Note: change path according to your directories)
config_file_path = '/nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning_utils/config_files/baseline_noptm_baseline_small_bs1024.yaml'

# Read the file explicitly as UTF-8 so that parsing does not depend on the
# platform's default text encoding.
with open(config_file_path, 'r', encoding='utf-8') as yaml_file:
    config = yaml.safe_load(yaml_file)

# Show the loaded configuration data
print(config)


In [None]:
# Export the cache locations from the config as environment variables
# (HF_HOME / HF_DATASETS_CACHE are presumably read by the Hugging Face libraries)
import os

for env_var, config_key in (
    ('HF_HOME', 'hf_home'),
    ('HF_DATASETS_CACHE', 'hf_cache'),
):
    os.environ[env_var] = config['dataset'][config_key]

# Disabled GPU selection; `args` is not defined in the visible cells
# os.environ["CUDA_VISIBLE_DEVICES"] = args.tf_device_nr

### Weights and Biases

In [None]:
# uuid provides a unique identifier for every run
import uuid
# Weights & Biases experiment tracking
import wandb
# (older import path: from wandb.keras import WandbCallback)
from wandb.integration.keras import WandbCallback

# Store the run id as a string: a raw UUID object is not a JSON-friendly value
# for the logged config, while its string form produces the same text when it
# is formatted into the model file name later on.
config['run_id'] = str(uuid.uuid4())

# Set up the wandb run for this project; the whole config is attached to the
# run and the dataset name is used as a tag.
project_name = 'callback model training'
wandb.init(
    project=project_name,
    config=config,
    tags=[config['dataset']['name']],
    entity='mapra_dlomix',
)


### Dataset

In [None]:
# Dataset class shipped with DLOmix
from dlomix.data import FragmentIonIntensityDataset

# Loader for a dataset that has already been processed
from dlomix.data import load_processed_dataset

# Look up the dataset location in the run config, then load it
processed_dataset_path = wandb.config['dataset']['processed_path']
dataset = load_processed_dataset(processed_dataset_path)


### Optimizer

In [None]:
# TensorFlow provides the Keras optimizers used below
import tensorflow as tf

# Adam optimizer configured with the learning rate from the training config
configured_learning_rate = wandb.config['training']['learning_rate']
optimizer = tf.keras.optimizers.Adam(learning_rate=configured_learning_rate)

### Loss functions

In [None]:
# Define loss functions
from dlomix.losses import masked_spectral_distance, masked_pearson_correlation_distance

### Callbacks

In [None]:
# Import callbacks
from tensorflow.keras.callbacks import (
    EarlyStopping, ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau,
    LambdaCallback, TerminateOnNaN, CSVLogger
)

In [None]:
# Early Stopping Callback: every argument is taken from the run config
early_stopping_cfg = wandb.config['callbacks']['early_stopping']
early_stopping = EarlyStopping(
    monitor=early_stopping_cfg['monitor'],
    min_delta=early_stopping_cfg['min_delta'],
    patience=early_stopping_cfg['patience'],
    restore_best_weights=early_stopping_cfg['restore_best_weights'],
)

In [None]:
# Reduce LR on Plateau Callback: forward the configured options as keyword arguments
reduce_lr_cfg = wandb.config['callbacks']['reduce_lr']
reduce_lr = ReduceLROnPlateau(
    **{
        option: reduce_lr_cfg[option]
        for option in ('monitor', 'factor', 'patience', 'min_lr')
    }
)

In [None]:
# Learning Rate Scheduler Callback
def _exponential_decay_schedule(epoch):
    """Return initial_lr * decay_rate ** epoch, reading both values from the run config."""
    scheduler_cfg = wandb.config['callbacks']['learning_rate_scheduler']
    return scheduler_cfg['initial_lr'] * scheduler_cfg['decay_rate'] ** epoch


learning_rate_scheduler = LearningRateScheduler(schedule=_exponential_decay_schedule)

In [None]:
# Terminate On NaN Callback: Callback that terminates training when a NaN loss is encountered.
# It is created without arguments, so nothing is read from the config here.
terminate_on_nan = TerminateOnNaN()

In [None]:
# Lambda Callback: runs a small custom hook during training
def _announce_epoch_start(epoch, logs):
    """Print the 1-based number of the epoch that is about to start."""
    print(f"Starting epoch {epoch + 1}")


# Only the epoch-begin hook is used; the remaining hooks are left unset:
# on_epoch_end, on_train_begin, on_train_end,
# on_train_batch_begin, on_train_batch_end
lambda_callback = LambdaCallback(on_epoch_begin=_announce_epoch_start)

In [None]:
# CSV Logger Callback (Note: not necessary when using wandb)
# The output file name and the append flag both come from the run config.
csv_logger_cfg = wandb.config['callbacks']['csv_logger']
csv_logger = CSVLogger(
    filename=csv_logger_cfg['filename'],
    append=csv_logger_cfg['append'],
)

In [None]:
# Model Checkpoint Callback: each listed option is read from the run config
# and passed on under the same keyword name.
checkpoint_cfg = wandb.config['callbacks']['model_checkpoint']
checkpoint_options = (
    'filepath',
    'monitor',
    'save_best_only',
    'save_weights_only',
    'mode',
    'save_freq',
    'verbose',
)
model_checkpoint = ModelCheckpoint(
    **{option: checkpoint_cfg[option] for option in checkpoint_options}
)

### Initialized Model

In [None]:
# Model class and alphabet are provided by DLOmix
from dlomix.models import PrositIntensityPredictor  # predictor for intensity
from dlomix.constants import PTMS_ALPHABET  # alphabet with PTMs (can be adapted based on the data)

# Presumably maps the model's input roles to the dataset column names
input_mapping = dict(
    SEQUENCE_KEY="modified_sequence",
    COLLISION_ENERGY_KEY="collision_energy_aligned_normed",
    PRECURSOR_CHARGE_KEY="precursor_charge_onehot",
    FRAGMENTATION_TYPE_KEY="method_nbr",
)

# Every mapped column except the sequence one is handed over as meta data
meta_data_keys = [
    column for role, column in input_mapping.items() if role != "SEQUENCE_KEY"
]

# Build the Prosit model; the sequence length comes from the run config
configured_seq_length = wandb.config['dataset']['seq_length']
model = PrositIntensityPredictor(
    seq_length=configured_seq_length,
    alphabet=PTMS_ALPHABET,
    use_prosit_ptm_features=False,
    with_termini=False,
    input_keys=input_mapping,
    meta_data_keys=meta_data_keys,
)

In [None]:
# Compile the model: masked spectral distance is the training loss and the
# masked Pearson correlation distance is tracked as an additional metric.
model.compile(
    loss=masked_spectral_distance,
    metrics=[masked_pearson_correlation_distance],
    optimizer=optimizer,
)

### Train Model

In [None]:
# Train the model on the training split, validating on the validation split.
# The number of epochs comes from the training config.
#
# NOTE(review): LearningRateScheduler sets the learning rate at the start of
# every epoch, so it presumably overrides any reduction applied by
# ReduceLROnPlateau — confirm that combining both is intended.
model.fit(
    dataset.tensor_train_data,
    validation_data=dataset.tensor_val_data,
    epochs=wandb.config['training']['num_epochs'],
    callbacks=[
        # log_batch_frequency is an integer batch interval; 1 logs every batch
        # (the equivalent of the boolean True that was passed before)
        WandbCallback(save_model=False, log_batch_frequency=1),
        early_stopping,
        reduce_lr,
        learning_rate_scheduler,
        terminate_on_nan,
        lambda_callback,
        csv_logger,  # (Note: not necessary when using wandb; shown for completeness)
        model_checkpoint,
    ],
)

### Save Model

In [None]:
import os

# Path the model is saved to (Note: The file needs to end with the .keras extension.)
model_path = f"{wandb.config['model']['save_dir']}/{wandb.config['dataset']['name']}/{wandb.config['run_id']}.keras"

# Create the <save_dir>/<dataset name>/ directory first so that saving does not
# fail when it does not exist yet; an existing directory is left untouched.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# save the model
model.save(model_path)

print(f"Model saved to: {model_path}")

In [None]:
# Finish the wandb run; the remaining cells only reload and inspect saved models
wandb.finish()

### Load Model

In [None]:
import keras

# Load the trained model back from the path it was saved to.
# NOTE(review): training went through tf.keras while loading uses the
# standalone keras package — confirm both resolve to the same Keras version.
reconstructed_model = keras.models.load_model(model_path)

In [None]:
# Model summary of the reloaded model: each entry pairs the printed label with
# the attribute that is read from the model.
for label, attribute_name in (
    # constructor parameters
    ("Embedding Output Dimension:", "embedding_output_dim"),
    ("Sequence Length:", "seq_length"),
    ("Alphabet Dictionary:", "alphabet"),
    ("Dropout Rate:", "dropout_rate"),
    ("Latent Dropout Rate:", "latent_dropout_rate"),
    ("Recurrent Layers Sizes:", "recurrent_layers_sizes"),
    ("Regressor Layer Size:", "regressor_layer_size"),
    ("Use Prosit PTM Features:", "use_prosit_ptm_features"),
    ("Input Keys:", "input_keys"),
    # upper-case attributes
    ("Default Input Keys:", "DEFAULT_INPUT_KEYS"),
    ("Meta Data Keys (Attribute):", "META_DATA_KEYS"),
    ("PTM Input Keys:", "PTM_INPUT_KEYS"),
):
    print(label, getattr(reconstructed_model, attribute_name))


In [None]:
import keras

# Checkpoint file to restore (Note: change path according to your directories)
checkpoint_path = "/nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning_utils/checkpoints/model-01-0.67.keras"

# Load the model as it was stored at that checkpoint
checkpoint_model = keras.models.load_model(checkpoint_path)

In [None]:
# Model summary of the checkpoint model: each entry pairs the printed label
# with the attribute that is read from the model.
for label, attribute_name in (
    # constructor parameters
    ("Embedding Output Dimension:", "embedding_output_dim"),
    ("Sequence Length:", "seq_length"),
    ("Alphabet Dictionary:", "alphabet"),
    ("Dropout Rate:", "dropout_rate"),
    ("Latent Dropout Rate:", "latent_dropout_rate"),
    ("Recurrent Layers Sizes:", "recurrent_layers_sizes"),
    ("Regressor Layer Size:", "regressor_layer_size"),
    ("Use Prosit PTM Features:", "use_prosit_ptm_features"),
    ("Input Keys:", "input_keys"),
    # upper-case attributes
    ("Default Input Keys:", "DEFAULT_INPUT_KEYS"),
    ("Meta Data Keys (Attribute):", "META_DATA_KEYS"),
    ("PTM Input Keys:", "PTM_INPUT_KEYS"),
):
    print(label, getattr(checkpoint_model, attribute_name))
