# Implementation of Creating Callbacks for Refinement/Transfer Learning
This script tries out different callback mechanisms using TensorFlow and Keras. The goal is to have easy-to-use helper functions for further training, refinement, and transfer learning of the Prosit models, and ultimately to implement them in DLOmix.

### Parser

In [None]:
# import argparse

# Parsing the configuration file (required when using a script instead of a notebook)
# parser = argparse.ArgumentParser(prog='Baseline Model Training')
# parser.add_argument('--config', type=str, required=True)
# parser.add_argument('--tf-device-nr', type=str, required=True)
# args = parser.parse_args()


### Configuration file

In [249]:
import yaml

# Manually specify the path to the configuration file
config_file_path = '/nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning_utils/config_files/baseline_noptm_baseline_small_bs1024.yaml'

# Open the file with an explicit UTF-8 encoding so that parsing does not
# depend on the locale default encoding of the machine running the notebook.
with open(config_file_path, 'r', encoding='utf-8') as yaml_file:
    # `config` is the parsed YAML document; the cells below index it as a
    # nested dict (e.g. config['dataset']['hf_home']).
    config = yaml.safe_load(yaml_file)

# Show config containing the configuration data
print(config)


{'dataset': {'name': 'noptm_baseline_small_bs1024', 'hf_home': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/datasets', 'hf_cache': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/datasets/hf_cache', 'parquet_path': '/cmnfs/data/proteomics/Prosit_PTMs/Transformer_Train/clean', 'processed_path': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/datasets/processed/noptm_baseline_small_bs1024', 'seq_length': 30, 'batch_size': 1024}, 'training': {'learning_rate': 0.0001, 'num_epochs': 2}, 'model': {'save_dir': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/models/callback_models'}, 'processing': {'num_proc': 40}, 'callbacks': {'early_stopping': {'monitor': 'val_loss', 'min_delta': 0.001, 'patience': 20, 'restore_best_weights': True}, 'model_checkpoint': {'filepath': '/nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning_utils/saved_models/checkpoints/model-{epoch:02d}-{val_loss:.2f}.hdf5', 'monitor': 'val_loss', 'save_best_only': False, 'save_weights_only': True, 'm

In [250]:
import os

# Export the two Hugging Face directories from the 'dataset' section of the
# configuration as environment variables.
# NOTE(review): presumably this has to run before the Hugging Face / dlomix
# libraries are imported so that they pick the values up - confirm.
for env_name, dataset_key in (
    ('HF_HOME', 'hf_home'),
    ('HF_DATASETS_CACHE', 'hf_cache'),
):
    os.environ[env_name] = config['dataset'][dataset_key]

# os.environ["CUDA_VISIBLE_DEVICES"] = args.tf_device_nr

### Weights and Biases

In [251]:
import uuid
# initialize weights and biases
import wandb
# from wandb.keras import WandbCallback
from wandb.integration.keras import WandbCallback


# Give this run a random unique identifier; it is read back from the wandb
# config later to build the file name of the saved model. It is stored as a
# plain string so that the config dict only holds simple, serializable values
# instead of a uuid.UUID object.
config['run_id'] = str(uuid.uuid4())

# Plain string literal: nothing is interpolated, so no f-string is needed.
project_name = 'callback model training'

# Start the wandb run. The complete configuration is attached to the run and
# the dataset name is used as a tag.
wandb.init(
    project=project_name,
    config=config,
    tags=[config['dataset']['name']],
    entity='mapra_dlomix',
)


VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

### Dataset

In [252]:
# Load dataset

# DLOmix dataset class and the loader for an already processed dataset
from dlomix.data import FragmentIonIntensityDataset, load_processed_dataset

# Own dataset: the location of the processed data comes from the run configuration
processed_dataset_path = wandb.config['dataset']['processed_path']
dataset = load_processed_dataset(processed_dataset_path)


### Optimizer

In [253]:
# Initialize TensorFlow and the optimizer
import tensorflow as tf

# Adam optimizer using the learning rate from the 'training' config section.
configured_learning_rate = wandb.config['training']['learning_rate']
optimizer = tf.keras.optimizers.Adam(learning_rate=configured_learning_rate)

### Loss functions

In [254]:
# Define loss functions
from dlomix.losses import masked_spectral_distance, masked_pearson_correlation_distance

### Callbacks

In [255]:
# Define callbacks
from tensorflow.keras.callbacks import (
    EarlyStopping, ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau,
    LambdaCallback, TerminateOnNaN, CSVLogger
)

In [256]:
# Early Stopping Callback
# Every argument is taken from the 'callbacks' -> 'early_stopping' section of
# the run configuration, using the same key names as the keyword arguments.
early_stopping_cfg = wandb.config['callbacks']['early_stopping']
early_stopping = EarlyStopping(
    **{
        option: early_stopping_cfg[option]
        for option in ('monitor', 'min_delta', 'patience', 'restore_best_weights')
    }
)

In [257]:
# Model Checkpoint Callback
# Every argument is taken from the 'callbacks' -> 'model_checkpoint' section of
# the run configuration, using the same key names as the keyword arguments.
checkpoint_cfg = wandb.config['callbacks']['model_checkpoint']
checkpoint_options = (
    'filepath',
    'monitor',
    'save_best_only',
    'save_weights_only',
    'mode',
    'save_freq',
    'verbose',
)
model_checkpoint = ModelCheckpoint(
    **{option: checkpoint_cfg[option] for option in checkpoint_options}
)


In [258]:
# Reduce LR on Plateau Callback
# Settings come from the 'callbacks' -> 'reduce_lr' section of the run configuration.
reduce_lr_cfg = wandb.config['callbacks']['reduce_lr']
reduce_lr = ReduceLROnPlateau(
    monitor=reduce_lr_cfg['monitor'],
    factor=reduce_lr_cfg['factor'],
    patience=reduce_lr_cfg['patience'],
    min_lr=reduce_lr_cfg['min_lr'],
)

In [259]:
# CSV Logger Callback
# The output file name comes from the 'callbacks' -> 'csv_logger' config section.
csv_log_filename = wandb.config['callbacks']['csv_logger']['filename']
csv_logger = CSVLogger(filename=csv_log_filename)

In [260]:
# Learning Rate Scheduler Callback
def _exponential_decay_schedule(epoch):
    """Return ``initial_lr * decay_rate ** epoch`` for the given epoch index.

    Both values are looked up in the 'callbacks' -> 'learning_rate_scheduler'
    section of ``wandb.config`` on every call, i.e. when the schedule is
    evaluated rather than when the callback is created.
    """
    scheduler_cfg = wandb.config['callbacks']['learning_rate_scheduler']
    return scheduler_cfg['initial_lr'] * scheduler_cfg['decay_rate'] ** epoch


learning_rate_scheduler = LearningRateScheduler(schedule=_exponential_decay_schedule)

In [261]:
# Terminate On NaN Callback: Callback that terminates training when a NaN loss is encountered.
# It is constructed without arguments, so it needs no entry in the configuration file.
terminate_on_nan = TerminateOnNaN()


In [262]:
# Lambda Callback (example: logging epoch start)
# Only the hook that is actually used is passed; hooks that are left out keep
# the callback's default behaviour.
# Explicit `None` values are avoided on purpose: the tf.keras 2.x
# LambdaCallback does not declare `on_train_batch_begin` / `on_train_batch_end`
# as parameters and copies unknown keyword arguments onto the instance, so
# passing `None` for them would replace those hook methods with `None` and
# break training as soon as the callback is passed to fit().
# NOTE(review): verify against the installed Keras version.
lambda_callback = LambdaCallback(
    on_epoch_begin=lambda epoch, logs: print(f"Starting epoch {epoch}")
)

### Initialized Model

In [263]:
# Initialize the model
from dlomix.models import PrositIntensityPredictor
from dlomix.constants import PTMS_ALPHABET

# Names passed to the model as `input_keys`.
# NOTE(review): the values are presumably the dataset column names - confirm.
input_mapping = dict(
    SEQUENCE_KEY="modified_sequence",
    COLLISION_ENERGY_KEY="collision_energy_aligned_normed",
    PRECURSOR_CHARGE_KEY="precursor_charge_onehot",
    FRAGMENTATION_TYPE_KEY="method_nbr",
)

# Every mapped input except the sequence, in the order given above.
meta_data_keys = [
    column for key, column in input_mapping.items() if key != "SEQUENCE_KEY"
]

# Collect the constructor arguments first, then build the model from them.
model_settings = dict(
    seq_length=wandb.config['dataset']['seq_length'],
    alphabet=PTMS_ALPHABET,
    use_prosit_ptm_features=False,
    with_termini=False,
    input_keys=input_mapping,
    meta_data_keys=meta_data_keys,
)
model = PrositIntensityPredictor(**model_settings)

In [264]:
# Compile the model
# The masked spectral distance is the training loss; the masked Pearson
# correlation distance is tracked as an additional metric.
compile_settings = {
    'optimizer': optimizer,
    'loss': masked_spectral_distance,
    'metrics': [masked_pearson_correlation_distance],
}
model.compile(**compile_settings)

### Train Model

In [265]:
# train model
# Callbacks that are active for this run. The commented-out entries are the
# other callbacks defined in this notebook; uncomment them to enable them.
training_callbacks = [
    # log_batch_frequency expects an integer batch interval; 1 logs the
    # training metrics after every batch. The model itself is not uploaded.
    WandbCallback(save_model=False, log_batch_frequency=1),
    early_stopping,
    reduce_lr,
    # learning_rate_scheduler,
    terminate_on_nan,
    # lambda_callback,
    # csv_logger,
    # model_checkpoint
]

# Train on the training split and validate on the validation split for the
# number of epochs given in the 'training' section of the configuration.
model.fit(
    dataset.tensor_train_data,
    validation_data=dataset.tensor_val_data,
    epochs=wandb.config['training']['num_epochs'],
    callbacks=training_callbacks,
)

Epoch 1/2
Epoch 2/2


<keras.src.callbacks.History at 0x7fe5fcce52a0>

### Save Model

In [266]:
import os

# Target file: <save_dir>/<dataset name>/<run id>.keras
model_path = f"{wandb.config['model']['save_dir']}/{wandb.config['dataset']['name']}/{wandb.config['run_id']}.keras"

# Create the per-dataset output directory if it does not exist yet, so that
# the first run for a new dataset does not fail because of a missing folder.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

model.save(model_path)  # The file needs to end with the .keras extension

In [267]:
# Finish the wandb run
# This closes the run that was started with wandb.init() earlier in this notebook.
wandb.finish()

VBox(children=(Label(value='0.018 MB of 0.018 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁█
loss,█▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁
masked_pearson_correlation_distance,█▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁
val_loss,█▁
val_masked_pearson_correlation_distance,█▁

0,1
best_epoch,1.0
best_val_loss,0.66794
epoch,1.0
loss,0.6695
masked_pearson_correlation_distance,0.55861
val_loss,0.66794
val_masked_pearson_correlation_distance,0.55546


### Load Model

In [268]:
import keras

# Load the model back from the path it was saved to.
# NOTE(review): no custom_objects are passed, so this presumably relies on
# dlomix making its model and loss classes available for deserialization - confirm.
reconstructed_model = keras.models.load_model(model_path)

In [269]:
# Model summary

# (label, attribute name) pairs of the reloaded model, printed in this order:
# first the constructor parameters, then the class-level key attributes.
summary_fields = [
    # Print parameters
    ("Embedding Output Dimension:", "embedding_output_dim"),
    ("Sequence Length:", "seq_length"),
    ("Alphabet Dictionary:", "alphabet"),
    ("Dropout Rate:", "dropout_rate"),
    ("Latent Dropout Rate:", "latent_dropout_rate"),
    ("Recurrent Layers Sizes:", "recurrent_layers_sizes"),
    ("Regressor Layer Size:", "regressor_layer_size"),
    ("Use Prosit PTM Features:", "use_prosit_ptm_features"),
    ("Input Keys:", "input_keys"),
    # Print attributes
    ("Default Input Keys:", "DEFAULT_INPUT_KEYS"),
    ("Meta Data Keys (Attribute):", "META_DATA_KEYS"),
    ("PTM Input Keys:", "PTM_INPUT_KEYS"),
]

for label, attribute_name in summary_fields:
    print(label, getattr(reconstructed_model, attribute_name))


Embedding Output Dimension: 16
Sequence Length: 30
Alphabet Dictionary: {'A': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'I': 8, 'K': 9, 'L': 10, 'M': 11, 'N': 12, 'P': 13, 'Q': 14, 'R': 15, 'S': 16, 'T': 17, 'V': 18, 'W': 19, 'Y': 20, '[]-': 21, '-[]': 22, '[UNIMOD:737]-': 56, 'M[UNIMOD:35]': 23, 'S[UNIMOD:21]': 24, 'T[UNIMOD:21]': 25, 'Y[UNIMOD:21]': 26, 'R[UNIMOD:7]': 27, 'Q[UNIMOD:7]': 4, 'N[UNIMOD:7]': 3, 'K[UNIMOD:1]': 28, 'K[UNIMOD:121]': 29, 'Q[UNIMOD:28]': 30, 'R[UNIMOD:34]': 31, 'K[UNIMOD:34]': 32, 'T[UNIMOD:43]': 35, 'S[UNIMOD:43]': 36, 'C[UNIMOD:4]': 37, '[UNIMOD:1]-': 38, 'E[UNIMOD:27]': 39, 'K[UNIMOD:36]': 40, 'K[UNIMOD:37]': 41, 'K[UNIMOD:122]': 42, 'K[UNIMOD:58]': 43, 'K[UNIMOD:1289]': 44, 'K[UNIMOD:747]': 45, 'K[UNIMOD:64]': 46, 'K[UNIMOD:1848]': 47, 'K[UNIMOD:1363]': 48, 'K[UNIMOD:1849]': 49, 'K[UNIMOD:3]': 50, 'K[UNIMOD:737]': 55, 'R[UNIMOD:36]': 51, 'R[UNIMOD:36a]': 52, 'P[UNIMOD:35]': 53, 'Y[UNIMOD:354]': 54}
Dropout Rate: 0.2
Latent Dropout Rate: 0.1
Recur