# Creating Callbacks for Refinement/Transfer Learning
This notebook tries out different callback mechanisms in TensorFlow/Keras. The goal is to provide easy-to-use functions for further training, refinement and transfer learning of the Prosit models, which will ultimately be implemented in DLOmix.

### Imports

In [58]:
# Imports
import argparse
import yaml

### Parser

In [59]:
# Parsing the configuration file (required when using a script instead of a notebook)
#parser = argparse.ArgumentParser(prog='Extended Model Training')
#parser.add_argument('--config', type=str, required=True)
#args = parser.parse_args()


### Configuration file

In [60]:
# Path to the YAML configuration file. It is specified manually because the
# argparse-based --config parser is only used when running as a script, not
# as a notebook.
config_file_path = '/nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning_utils/config_files/baseline_noptm_baseline_small_bs1024.yaml'

# Open with an explicit encoding so parsing does not depend on the host locale.
# NOTE: PyYAML follows YAML 1.1, so exponent-only numbers without a decimal point
# (e.g. `1e-6`) are loaded as *strings*, not floats -- the printed config shows
# this for callbacks.reduce_lr.min_lr. Cast such values with float() where used.
with open(config_file_path, 'r', encoding='utf-8') as yaml_file:
    config = yaml.safe_load(yaml_file)

# Show config containing the configuration data
print(config)


{'dataset': {'hf_home': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/datasets', 'hf_cache': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/datasets/hf_cache', 'parquet_path': '/cmnfs/data/proteomics/Prosit_PTMs/Transformer_Train/clean', 'processed_path': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/datasets/processed/noptm_baseline_small_bs1024', 'seq_length': 30, 'batch_size': 1024}, 'training': {'learning_rate': 0.0001, 'num_epochs': 2}, 'processing': {'num_proc': 40}, 'callbacks': {'early_stopping': {'monitor': 'val_loss', 'min_delta': 0.001, 'patience': 20, 'restore_best_weights': True}, 'model_checkpoint': {'filepath': '/nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning_utils/saved_models/checkpoints/model-{epoch:02d}-{val_loss:.2f}.hdf5', 'monitor': 'val_loss', 'save_best_only': False, 'save_weights_only': True, 'mode': 'auto', 'save_freq': 'epoch', 'verbose': 1}, 'reduce_lr': {'monitor': 'val_loss', 'factor': 0.1, 'patience': 10, 'min_lr': '1e

### Weights and Biases

In [61]:
# Initialize wandb for experiment tracking
# import wandb
# from wandb.integration.keras import WandbCallback

# project_name = 'extended_model_training'
# wandb.init(project=project_name)
# wandb.config = config

### Dataset

In [62]:
from dlomix.data import load_processed_dataset

# Load the processed dataset stored at config['dataset']['processed_path'].
dataset = load_processed_dataset(
    config['dataset']['processed_path'],
)

### Optimizer

In [63]:
# Initialize TensorFlow and the Adam optimizer.
import tensorflow as tf

# float() guards against PyYAML loading exponent-only literals such as `1e-4`
# (no decimal point) as strings; a string is not a valid learning rate.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=float(config['training']['learning_rate'])
)

### Loss functions

In [64]:
# Define loss functions
from dlomix.losses import masked_spectral_distance, masked_pearson_correlation_distance

### Callbacks

In [65]:
# Define callbacks
from tensorflow.keras.callbacks import (
    EarlyStopping, ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau,
    LambdaCallback, TerminateOnNaN, CSVLogger
)

In [66]:
# Early Stopping Callback: every constructor argument is taken verbatim from the
# config section of the same name (looked up in the order listed below).
early_stopping = EarlyStopping(**{
    option: config['callbacks']['early_stopping'][option]
    for option in ('monitor', 'min_delta', 'patience', 'restore_best_weights')
})

In [67]:
# Model Checkpoint Callback: all arguments come straight from
# config['callbacks']['model_checkpoint']. The filepath is a template whose
# {epoch} / {val_loss} placeholders are filled in when a checkpoint is written.
model_checkpoint = ModelCheckpoint(**{
    option: config['callbacks']['model_checkpoint'][option]
    for option in (
        'filepath',
        'monitor',
        'save_best_only',
        'save_weights_only',
        'mode',
        'save_freq',
        'verbose',
    )
})



In [68]:
# Reduce LR on Plateau Callback: scale the learning rate by `factor` when the
# monitored metric stops improving for `patience` epochs, bounded below by `min_lr`.
# min_lr (and factor, for consistency) are cast with float() because PyYAML
# (YAML 1.1) parses exponent-only literals such as `1e-6` as strings -- the
# printed config shows min_lr as a string. Keras compares min_lr numerically
# with the current learning rate, which raises TypeError for a str as soon as
# a reduction is triggered.
reduce_lr = ReduceLROnPlateau(
    monitor=config['callbacks']['reduce_lr']['monitor'],
    factor=float(config['callbacks']['reduce_lr']['factor']),
    patience=config['callbacks']['reduce_lr']['patience'],
    min_lr=float(config['callbacks']['reduce_lr']['min_lr'])
)

In [69]:
# CSV Logger Callback: log per-epoch metrics to the file named in
# config['callbacks']['csv_logger']['filename'].
csv_logger = CSVLogger(config['callbacks']['csv_logger']['filename'])

In [70]:
# Learning Rate Scheduler Callback: exponential decay,
#   lr(epoch) = initial_lr * decay_rate ** epoch
# The config values are read lazily (each time Keras calls the schedule) and
# cast with float(), because PyYAML loads exponent-only literals such as `1e-4`
# as strings, which would make the multiplication below raise TypeError.
learning_rate_scheduler = LearningRateScheduler(
    schedule=lambda epoch: (
        float(config['callbacks']['learning_rate_scheduler']['initial_lr'])
        * float(config['callbacks']['learning_rate_scheduler']['decay_rate']) ** epoch
    )
)

In [71]:
# TerminateOnNaN: stops training as soon as the loss becomes NaN.
# It takes no configuration, so nothing is read from the config here.
terminate_on_nan = TerminateOnNaN()


In [72]:
# Lambda Callback (example): announce the start of every epoch.
# Keras passes a 0-based epoch index to on_epoch_begin, while its own progress
# output is 1-based ("Epoch 1/2"), so add 1 to keep both messages consistent.
# All other hooks keep LambdaCallback's default of None (not used).
lambda_callback = LambdaCallback(
    on_epoch_begin=lambda epoch, logs: print(f"Starting epoch {epoch + 1}")
)

### Model initialization

In [73]:
# Initialize the Prosit intensity predictor with PTM features and termini disabled.
from dlomix.models import PrositIntensityPredictor
from dlomix.constants import PTMS_ALPHABET

# Model input slot -> feature name in the processed dataset
# (presumably the dataset's column names -- confirm against the dataset schema).
input_mapping = {
    "SEQUENCE_KEY": "modified_sequence",
    "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed",
    "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot",
    "FRAGMENTATION_TYPE_KEY": "method_nbr",
}

# Every mapped feature except the sequence is passed as metadata, in mapping order.
meta_data_keys = [
    feature for slot, feature in input_mapping.items() if slot != "SEQUENCE_KEY"
]

model = PrositIntensityPredictor(
    input_keys=input_mapping,
    meta_data_keys=meta_data_keys,
    seq_length=config['dataset']['seq_length'],
    alphabet=PTMS_ALPHABET,
    use_prosit_ptm_features=False,
    with_termini=False,
)

In [74]:
# Compile: masked spectral distance is the training loss, masked Pearson
# correlation distance is tracked as an extra metric, using the Adam optimizer
# created earlier.
model.compile(
    loss=masked_spectral_distance,
    metrics=[masked_pearson_correlation_distance],
    optimizer=optimizer,
)

### Model training

In [75]:
# Train the model with the selected callbacks (uncomment entries to enable them).
# The returned History object is kept so per-epoch metrics (history.history)
# can be inspected or plotted after training.
history = model.fit(
    dataset.tensor_train_data,
    validation_data=dataset.tensor_val_data,
    epochs=config['training']['num_epochs'],
    callbacks=[
        # WandbCallback(),
        early_stopping,
        model_checkpoint,
        # reduce_lr,
        # csv_logger,
        # NOTE: the scheduler's schedule depends only on the epoch and resets the
        # learning rate at the start of every epoch, so enabling it together with
        # reduce_lr overrides ReduceLROnPlateau's reductions -- enable only one.
        # learning_rate_scheduler,
        # terminate_on_nan,
        # lambda_callback
    ]
)

Epoch 1/2
Epoch 1: saving model to /nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning_utils/saved_models/checkpoints/model-01-0.67.hdf5
Epoch 2/2
Epoch 2: saving model to /nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning_utils/saved_models/checkpoints/model-02-0.67.hdf5


<keras.src.callbacks.History at 0x7fe600669c00>

In [76]:
# Finish the wandb run
# wandb.finish()