# Creating Callbacks for Model Refinement and Transfer Learning
This notebook explores the callback mechanisms available in TensorFlow and Keras. The objective is to develop easy-to-use functions for continued training, refinement, and transfer learning of Prosit models, with the intention of integrating them into DLOmix.

### Configuration file

In [2]:
import yaml

# Manually specify the path to the configuration file (Note: change path according to your directories)
config_file_path = '/nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning/config_files/baseline_noptm_baseline_small_bs1024.yaml'

# Read the file as UTF-8 explicitly so that parsing does not depend on the
# locale of the machine running the notebook.
with open(config_file_path, 'r', encoding='utf-8') as yaml_file:
    # safe_load builds plain Python objects (dicts, lists, scalars) only.
    config = yaml.safe_load(yaml_file)

# Show config containing the configuration data.
# Note: exponent-only values such as 1e-6 are loaded as strings (the printed
# config shows 'min_lr': '1e-6'), so convert them with float() where a number
# is required.
print(config)


{'dataset': {'name': 'noptm_baseline_small_bs1024', 'hf_home': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/datasets', 'hf_cache': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/datasets/hf_cache', 'parquet_path': '/cmnfs/data/proteomics/Prosit_PTMs/Transformer_Train/clean', 'processed_path': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/datasets/processed/noptm_baseline_small_bs1024', 'seq_length': 30, 'batch_size': 1024}, 'training': {'learning_rate': 0.0001, 'num_epochs': 2}, 'model': {'save_dir': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/models/callback_models'}, 'processing': {'num_proc': 40}, 'callbacks': {'early_stopping': {'monitor': 'val_loss', 'min_delta': 0.001, 'patience': 20, 'restore_best_weights': True}, 'reduce_lr': {'monitor': 'val_loss', 'factor': 0.1, 'patience': 10, 'min_lr': '1e-6', 'mode': 'auto', 'min_delta': '1e-4', 'verbose': 1}, 'learning_rate_scheduler': {'initial_lr': 0.0001, 'decay_rate': 0.9}, 'lambda_callback': {'on_epoch_end': "lambda epoch, logs: print(

In [3]:
# Configure the environment: point the Hugging Face libraries at the
# directories given in the 'dataset' section of the config.
import os

for env_var, dataset_key in (('HF_HOME', 'hf_home'), ('HF_DATASETS_CACHE', 'hf_cache')):
    os.environ[env_var] = config['dataset'][dataset_key]

# Optional GPU selection, left disabled (`args` is not defined in the visible notebook):
# os.environ["CUDA_VISIBLE_DEVICES"] = args.tf_device_nr

### Weights and Biases

In [4]:
# uuid to ensure unique identifiers
import uuid
# initialize weights and biases
import wandb
# import WandbCallback
from wandb.integration.keras import WandbCallback

# Unique id for this run. It is stored as a plain string (not a UUID object)
# so the config only holds simple, serialisable values and the id can be used
# directly when building file names later on.
config['run_id'] = str(uuid.uuid4())

# Set up wandb for this project; the whole config dict is logged with the run
# and the dataset name is attached as a tag.
project_name = 'callback model training'
wandb.init(
    project=project_name,
    config=config,
    tags=[config['dataset']['name']],
    entity='mapra_dlomix'
)


2024-06-19 11:04:23.982480: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-19 11:04:23.982629: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-19 11:04:24.274614: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-19 11:04:25.658767: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Failed to detect the name of this notebook, you can s

### Dataset

In [5]:
# DLOmix dataset class and the loader for an already processed dataset
from dlomix.data import FragmentIonIntensityDataset, load_processed_dataset

# Load our own processed dataset from the location stored in the run config
processed_dataset_path = wandb.config['dataset']['processed_path']
dataset = load_processed_dataset(processed_dataset_path)



Avaliable feature extractors are (use the key of the following dict and pass it to features_to_extract in the Dataset Class):
{
   "atom_count": "Atom count of PTM.",
   "delta_mass": "Delta mass of PTM.",
   "mod_gain": "Gain of atoms due to PTM.",
   "mod_loss": "Loss of atoms due to PTM.",
   "red_smiles": "Reduced SMILES representation of PTM."
}.
When writing your own feature extractor, you can either
    (1) use the FeatureExtractor class or
    (2) write a function that can be mapped to the Hugging Face dataset.
In both cases, you can access the parsed sequence information from the dataset using the following keys, which all provide python lists:
    - _parsed_sequence: parsed sequence
    - _n_term_mods: N-terminal modifications
    - _c_term_mods: C-terminal modifications



### Optimizer

In [6]:
# Import TensorFlow and build the Adam optimizer with the configured learning rate
import tensorflow as tf

configured_learning_rate = wandb.config['training']['learning_rate']
optimizer = tf.keras.optimizers.Adam(learning_rate=configured_learning_rate)

2024-06-19 11:05:06.376816: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


### Loss functions

In [7]:
# Import the loss function and metric used for training
from dlomix.losses import masked_spectral_distance, masked_pearson_correlation_distance

### Callbacks

In [8]:
# Import callbacks
from tensorflow.keras.callbacks import (
    EarlyStopping, ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau,
    LambdaCallback, TerminateOnNaN, CSVLogger
)

In [9]:
# Early Stopping Callback, configured from the 'early_stopping' section of the run config
early_stopping_config = wandb.config['callbacks']['early_stopping']
early_stopping = EarlyStopping(
    monitor=early_stopping_config['monitor'],
    min_delta=early_stopping_config['min_delta'],
    patience=early_stopping_config['patience'],
    restore_best_weights=early_stopping_config['restore_best_weights'],
)

In [10]:
# Reduce LR on Plateau Callback, configured from the 'reduce_lr' section of the run config
reduce_lr_config = wandb.config['callbacks']['reduce_lr']
reduce_lr = ReduceLROnPlateau(
    monitor=reduce_lr_config['monitor'],
    factor=reduce_lr_config['factor'],
    patience=reduce_lr_config['patience'],
    # The config holds min_lr / min_delta as strings (the printed config shows
    # 'min_lr': '1e-6', 'min_delta': '1e-4'), so convert them to floats before
    # handing them to Keras, which compares them against numeric values.
    min_lr=float(reduce_lr_config['min_lr']),
    # Honour the remaining configured options; the fallbacks are the Keras defaults.
    min_delta=float(reduce_lr_config.get('min_delta', 1e-4)),
    mode=reduce_lr_config.get('mode', 'auto'),
    verbose=reduce_lr_config.get('verbose', 0),
)

In [11]:
# Learning Rate Scheduler Callback
def exponential_decay(epoch):
    """Return initial_lr * decay_rate ** epoch, read from the run config on every call."""
    scheduler_config = wandb.config['callbacks']['learning_rate_scheduler']
    return scheduler_config['initial_lr'] * scheduler_config['decay_rate'] ** epoch


learning_rate_scheduler = LearningRateScheduler(schedule=exponential_decay)

In [12]:
# Terminate On NaN Callback: terminates training when a NaN loss is encountered.
# It takes no configuration, so it is created with its defaults.
terminate_on_nan = TerminateOnNaN()

In [13]:
# Lambda Callback
def announce_epoch_start(epoch, logs):
    """Print 'Starting epoch N', where N is the received epoch index plus one."""
    print(f"Starting epoch {epoch + 1}")


# Only the epoch-begin hook is used here. Other hooks that can be supplied in the
# same way: on_epoch_end, on_train_begin, on_train_end, on_train_batch_begin,
# on_train_batch_end.
lambda_callback = LambdaCallback(on_epoch_begin=announce_epoch_start)

In [14]:
# CSV Logger Callback (Note: not necessary when using wandb)
csv_logger_config = wandb.config['callbacks']['csv_logger']
csv_logger = CSVLogger(
    csv_logger_config['filename'],
    append=csv_logger_config['append'],
)

In [15]:
# Model Checkpoint Callback: forward exactly these options from the
# 'model_checkpoint' section of the run config to Keras.
checkpoint_config = wandb.config['callbacks']['model_checkpoint']
checkpoint_option_names = (
    'filepath',
    'monitor',
    'save_best_only',
    'save_weights_only',
    'mode',
    'save_freq',
    'verbose',
)
model_checkpoint = ModelCheckpoint(
    **{option: checkpoint_config[option] for option in checkpoint_option_names}
)

### Initialize Model

In [16]:
# Initialize the model
from dlomix.constants import PTMS_ALPHABET  # alphabet with PTMs (can be adapted based on the data)
from dlomix.models import PrositIntensityPredictor  # predictor for intensity

# Maps the input roles to the dataset column names used for them
input_mapping = dict(
    SEQUENCE_KEY="modified_sequence",
    COLLISION_ENERGY_KEY="collision_energy_aligned_normed",
    PRECURSOR_CHARGE_KEY="precursor_charge_onehot",
    FRAGMENTATION_TYPE_KEY="method_nbr",
)

# Every mapped column except the sequence itself, in the same order as above
meta_data_keys = [
    input_mapping[role]
    for role in ("COLLISION_ENERGY_KEY", "PRECURSOR_CHARGE_KEY", "FRAGMENTATION_TYPE_KEY")
]

# initialize prosit model
prosit_options = dict(
    seq_length=wandb.config['dataset']['seq_length'],
    alphabet=PTMS_ALPHABET,
    use_prosit_ptm_features=False,
    with_termini=False,
    input_keys=input_mapping,
    meta_data_keys=meta_data_keys,
)
model = PrositIntensityPredictor(**prosit_options)

In [17]:
# Compile the model: masked spectral distance is the training loss and the
# masked Pearson correlation distance is tracked as an additional metric.
compile_options = dict(
    optimizer=optimizer,
    loss=masked_spectral_distance,
    metrics=[masked_pearson_correlation_distance],
)
model.compile(**compile_options)

### Train Model

In [18]:
# Train the model with every callback defined above.
# NOTE(review): LearningRateScheduler sets the learning rate from its schedule at
# the start of each epoch, so any reduction made by ReduceLROnPlateau is
# overwritten on the next epoch. Both are listed here only to demonstrate the
# callbacks; pick one of the two for a real training run.
model.fit(
    dataset.tensor_train_data,
    validation_data=dataset.tensor_val_data,
    epochs=wandb.config['training']['num_epochs'],
    callbacks=[
        # log_batch_frequency is a batch interval: 1 logs training metrics after
        # every batch (the previous value True was interpreted the same way).
        WandbCallback(save_model=False, log_batch_frequency=1),
        early_stopping,
        reduce_lr,
        learning_rate_scheduler,
        terminate_on_nan,
        lambda_callback,
        csv_logger,  # (Note: not necessary when using wandb; shown for completeness)
        model_checkpoint,
    ]
)

Starting epoch 1
Epoch 1/2
Epoch 1: val_loss improved from inf to 0.67107, saving model to /nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning/checkpoints/model-01-0.67.keras
Starting epoch 2
Epoch 2/2
Epoch 2: val_loss improved from 0.67107 to 0.66755, saving model to /nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning/checkpoints/model-02-0.67.keras


<keras.src.callbacks.History at 0x7ff9b82b5fc0>

### Save Model

In [19]:
# Path the model is saved to (Note: the file name needs to end with the .keras extension.)
# Layout: <save_dir>/<dataset name>/<run id>.keras
model_path = f"{wandb.config['model']['save_dir']}/{wandb.config['dataset']['name']}/{wandb.config['run_id']}.keras"

# Make sure the per-dataset target directory exists before saving, so the first
# run for a new dataset name does not depend on the folder having been created
# by hand (`os` is imported in the environment configuration cell).
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# save the model
model.save(model_path)

print(f"Model saved to: {model_path}")

Model saved to: /cmnfs/proj/prosit_astral/bmpc_dlomix_group/models/callback_models/noptm_baseline_small_bs1024/b2877e59-3e23-4ced-a751-ff5bcbf2c331.keras


In [20]:
# Finish the wandb run that was started with wandb.init() earlier in the notebook.
wandb.finish()

VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁█
loss,█▅▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁
masked_pearson_correlation_distance,█▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁
val_loss,█▁
val_masked_pearson_correlation_distance,█▁

0,1
best_epoch,1.0
best_val_loss,0.66755
epoch,1.0
loss,0.66934
masked_pearson_correlation_distance,0.56202
val_loss,0.66755
val_masked_pearson_correlation_distance,0.55755


### Load Model

In [21]:
import keras

# Load the trained model back from `model_path`, the file written in the save step.
# NOTE(review): training used `tf.keras`, loading uses the `keras` package directly —
# confirm both resolve to the same Keras version in your environment.
reconstructed_model = keras.models.load_model(model_path)

In [22]:
# Model summary of the reloaded model

# Label -> attribute name, listed in the order they are printed
parameter_labels = {
    "Embedding Output Dimension": "embedding_output_dim",
    "Sequence Length": "seq_length",
    "Alphabet Dictionary": "alphabet",
    "Dropout Rate": "dropout_rate",
    "Latent Dropout Rate": "latent_dropout_rate",
    "Recurrent Layers Sizes": "recurrent_layers_sizes",
    "Regressor Layer Size": "regressor_layer_size",
    "Use Prosit PTM Features": "use_prosit_ptm_features",
    "Input Keys": "input_keys",
}
attribute_labels = {
    "Default Input Keys": "DEFAULT_INPUT_KEYS",
    "Meta Data Keys (Attribute)": "META_DATA_KEYS",
    "PTM Input Keys": "PTM_INPUT_KEYS",
}

# Print parameters first, then attributes
for label, attribute_name in {**parameter_labels, **attribute_labels}.items():
    print(f"{label}:", getattr(reconstructed_model, attribute_name))


Embedding Output Dimension: 16
Sequence Length: 30
Alphabet Dictionary: {'A': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'I': 8, 'K': 9, 'L': 10, 'M': 11, 'N': 12, 'P': 13, 'Q': 14, 'R': 15, 'S': 16, 'T': 17, 'V': 18, 'W': 19, 'Y': 20, '[]-': 21, '-[]': 22, '[UNIMOD:737]-': 56, 'M[UNIMOD:35]': 23, 'S[UNIMOD:21]': 24, 'T[UNIMOD:21]': 25, 'Y[UNIMOD:21]': 26, 'R[UNIMOD:7]': 27, 'Q[UNIMOD:7]': 4, 'N[UNIMOD:7]': 3, 'K[UNIMOD:1]': 28, 'K[UNIMOD:121]': 29, 'Q[UNIMOD:28]': 30, 'R[UNIMOD:34]': 31, 'K[UNIMOD:34]': 32, 'T[UNIMOD:43]': 35, 'S[UNIMOD:43]': 36, 'C[UNIMOD:4]': 37, '[UNIMOD:1]-': 38, 'E[UNIMOD:27]': 39, 'K[UNIMOD:36]': 40, 'K[UNIMOD:37]': 41, 'K[UNIMOD:122]': 42, 'K[UNIMOD:58]': 43, 'K[UNIMOD:1289]': 44, 'K[UNIMOD:747]': 45, 'K[UNIMOD:64]': 46, 'K[UNIMOD:1848]': 47, 'K[UNIMOD:1363]': 48, 'K[UNIMOD:1849]': 49, 'K[UNIMOD:3]': 50, 'K[UNIMOD:737]': 55, 'R[UNIMOD:36]': 51, 'R[UNIMOD:36a]': 52, 'P[UNIMOD:35]': 53, 'Y[UNIMOD:354]': 54}
Dropout Rate: 0.2
Latent Dropout Rate: 0.1
Recur

In [23]:
# Load the model at a certain checkpoint (Note: change path according to your directories)
checkpoint_path = "/nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning/checkpoints/model-01-0.67.keras"
checkpoint_model = keras.models.load_model(checkpoint_path)

In [24]:
# Model summary of the model restored from the checkpoint

# (label, attribute name) pairs in print order: parameters first, then attributes
checkpoint_summary_fields = [
    ("Embedding Output Dimension", "embedding_output_dim"),
    ("Sequence Length", "seq_length"),
    ("Alphabet Dictionary", "alphabet"),
    ("Dropout Rate", "dropout_rate"),
    ("Latent Dropout Rate", "latent_dropout_rate"),
    ("Recurrent Layers Sizes", "recurrent_layers_sizes"),
    ("Regressor Layer Size", "regressor_layer_size"),
    ("Use Prosit PTM Features", "use_prosit_ptm_features"),
    ("Input Keys", "input_keys"),
    ("Default Input Keys", "DEFAULT_INPUT_KEYS"),
    ("Meta Data Keys (Attribute)", "META_DATA_KEYS"),
    ("PTM Input Keys", "PTM_INPUT_KEYS"),
]

for field_label, field_name in checkpoint_summary_fields:
    print(f"{field_label}:", getattr(checkpoint_model, field_name))


Embedding Output Dimension: 16
Sequence Length: 30
Alphabet Dictionary: {'A': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'I': 8, 'K': 9, 'L': 10, 'M': 11, 'N': 12, 'P': 13, 'Q': 14, 'R': 15, 'S': 16, 'T': 17, 'V': 18, 'W': 19, 'Y': 20, '[]-': 21, '-[]': 22, '[UNIMOD:737]-': 56, 'M[UNIMOD:35]': 23, 'S[UNIMOD:21]': 24, 'T[UNIMOD:21]': 25, 'Y[UNIMOD:21]': 26, 'R[UNIMOD:7]': 27, 'Q[UNIMOD:7]': 4, 'N[UNIMOD:7]': 3, 'K[UNIMOD:1]': 28, 'K[UNIMOD:121]': 29, 'Q[UNIMOD:28]': 30, 'R[UNIMOD:34]': 31, 'K[UNIMOD:34]': 32, 'T[UNIMOD:43]': 35, 'S[UNIMOD:43]': 36, 'C[UNIMOD:4]': 37, '[UNIMOD:1]-': 38, 'E[UNIMOD:27]': 39, 'K[UNIMOD:36]': 40, 'K[UNIMOD:37]': 41, 'K[UNIMOD:122]': 42, 'K[UNIMOD:58]': 43, 'K[UNIMOD:1289]': 44, 'K[UNIMOD:747]': 45, 'K[UNIMOD:64]': 46, 'K[UNIMOD:1848]': 47, 'K[UNIMOD:1363]': 48, 'K[UNIMOD:1849]': 49, 'K[UNIMOD:3]': 50, 'K[UNIMOD:737]': 55, 'R[UNIMOD:36]': 51, 'R[UNIMOD:36a]': 52, 'P[UNIMOD:35]': 53, 'Y[UNIMOD:354]': 54}
Dropout Rate: 0.2
Latent Dropout Rate: 0.1
Recur