# Creating Callbacks for Refinement and Transfer Learning
This notebook tries out different callback mechanisms in TensorFlow/Keras. The goal is to build a set of easy-to-use functions for further training, refinement, and transfer learning of the Prosit models, which will ultimately be integrated into DLOmix.

### Imports

In [1]:
# Imports
import argparse
import yaml
import uuid

### Parser

In [2]:
# Parsing the configuration file (required when using a script instead of a notebook)
# parser = argparse.ArgumentParser(prog='Baseline Model Training')
# parser.add_argument('--config', type=str, required=True)
# parser.add_argument('--tf-device-nr', type=str, required=True)
# args = parser.parse_args()


### Configuration file

In [3]:
# Path to the YAML run configuration. Set the DLOMIX_CONFIG environment variable
# to use a different file without editing this cell (the default is unchanged).
import os

config_file_path = os.environ.get(
    'DLOMIX_CONFIG',
    '/nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning_utils/config_files/baseline_noptm_baseline_small_bs1024.yaml',
)

with open(config_file_path, 'r', encoding='utf-8') as yaml_file:
    config = yaml.safe_load(yaml_file)

# Fail fast with a readable message instead of a KeyError several cells later.
# These are the top-level sections the rest of the notebook reads.
required_sections = ('dataset', 'training', 'model', 'callbacks')
missing_sections = [section for section in required_sections if section not in (config or {})]
if missing_sections:
    raise ValueError(
        f"Config file {config_file_path} is missing required section(s): {missing_sections}"
    )

# Show config containing the configuration data
print(config)


{'dataset': {'name': 'noptm_baseline_small_bs1024', 'hf_home': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/datasets', 'hf_cache': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/datasets/hf_cache', 'parquet_path': '/cmnfs/data/proteomics/Prosit_PTMs/Transformer_Train/clean', 'processed_path': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/datasets/processed/noptm_baseline_small_bs1024', 'seq_length': 30, 'batch_size': 1024}, 'training': {'learning_rate': 0.0001, 'num_epochs': 2}, 'model': {'save_dir': '/cmnfs/proj/prosit_astral/bmpc_dlomix_group/models/callback_models'}, 'processing': {'num_proc': 40}, 'callbacks': {'early_stopping': {'monitor': 'val_loss', 'min_delta': 0.001, 'patience': 20, 'restore_best_weights': True}, 'model_checkpoint': {'filepath': '/nfs/home/students/s.baier/mapra/dlomix/bmpc_shared_scripts/refinement_transfer_learning_utils/saved_models/checkpoints/model-{epoch:02d}-{val_loss:.2f}.hdf5', 'monitor': 'val_loss', 'save_best_only': False, 'save_weights_only': True, 'm

In [4]:
import os

# Point the Hugging Face libraries at the cache locations from the run config.
# This cell runs before the dlomix/datasets imports further down.
hf_cache_env = {
    'HF_HOME': config['dataset']['hf_home'],
    'HF_DATASETS_CACHE': config['dataset']['hf_cache'],
}
os.environ.update(hf_cache_env)

# Enable when running as a script with --tf-device-nr (see the parser cell):
# os.environ["CUDA_VISIBLE_DEVICES"] = args.tf_device_nr

### Weights and Biases

In [5]:
# Initialize Weights & Biases run tracking
import wandb
# Legacy import path used by older wandb versions: `from wandb.keras import WandbCallback`
from wandb.integration.keras import WandbCallback

# Unique id for this run. Store it as a plain string so it serializes cleanly
# into the wandb config and can be used directly in file paths.
config['run_id'] = str(uuid.uuid4())

project_name = 'callback model training'
wandb.init(
    project=project_name,
    config=config,
    tags=[config['dataset']['name']],
    entity='mapra_dlomix',
)


2024-06-04 11:45:31.206756: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-04 11:45:31.275143: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-04 11:45:31.275200: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-04 11:45:31.276817: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-04 11:45:31.286358: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
Failed to detect the name of this notebook, you can set it manually with the 

### Dataset

In [6]:
# Load the already-processed dataset from disk.
# FragmentIonIntensityDataset (the DLOmix dataset class) is imported alongside
# but is currently unused in this notebook.
from dlomix.data import FragmentIonIntensityDataset, load_processed_dataset

processed_dataset_path = wandb.config['dataset']['processed_path']
dataset = load_processed_dataset(processed_dataset_path)



Avaliable feature extractors are (use the key of the following dict and pass it to features_to_extract in the Dataset Class):
{
   "atom_count": "Atom count of PTM.",
   "delta_mass": "Delta mass of PTM.",
   "mod_gain": "Gain of atoms due to PTM.",
   "mod_loss": "Loss of atoms due to PTM.",
   "red_smiles": "Reduced SMILES representation of PTM."
}.
When writing your own feature extractor, you can either
    (1) use the FeatureExtractor class or
    (2) write a function that can be mapped to the Hugging Face dataset.
In both cases, you can access the parsed sequence information from the dataset using the following keys, which all provide python lists:
    - _parsed_sequence: parsed sequence
    - _n_term_mods: N-terminal modifications
    - _c_term_mods: C-terminal modifications



### Optimizer

In [7]:
# Adam optimizer; the learning rate is taken from the run configuration.
import tensorflow as tf

adam_learning_rate = wandb.config['training']['learning_rate']
optimizer = tf.keras.optimizers.Adam(learning_rate=adam_learning_rate)

### Loss functions

In [8]:
# Define loss functions
from dlomix.losses import masked_spectral_distance, masked_pearson_correlation_distance

### Callbacks

In [9]:
# Define callbacks
from tensorflow.keras.callbacks import (
    EarlyStopping, ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau,
    LambdaCallback, TerminateOnNaN, CSVLogger
)

In [10]:
# Early Stopping Callback: stop training once the monitored metric has not improved
# by at least `min_delta` for `patience` epochs. All settings come from the run config.
early_stopping_cfg = wandb.config['callbacks']['early_stopping']
early_stopping = EarlyStopping(
    monitor=early_stopping_cfg['monitor'],
    min_delta=early_stopping_cfg['min_delta'],
    patience=early_stopping_cfg['patience'],
    restore_best_weights=early_stopping_cfg['restore_best_weights'],
)

In [11]:
# Model Checkpoint Callback: periodically writes the model (or only its weights)
# to a templated file path. Every option is read from the run config.
checkpoint_cfg = wandb.config['callbacks']['model_checkpoint']
model_checkpoint = ModelCheckpoint(
    filepath=checkpoint_cfg['filepath'],
    monitor=checkpoint_cfg['monitor'],
    save_best_only=checkpoint_cfg['save_best_only'],
    save_weights_only=checkpoint_cfg['save_weights_only'],
    mode=checkpoint_cfg['mode'],
    save_freq=checkpoint_cfg['save_freq'],
    verbose=checkpoint_cfg['verbose'],
)


In [12]:
# Reduce LR on Plateau Callback: multiply the learning rate by `factor` when the
# monitored metric stalls for `patience` epochs, never going below `min_lr`.
reduce_lr_cfg = wandb.config['callbacks']['reduce_lr']
reduce_lr = ReduceLROnPlateau(
    monitor=reduce_lr_cfg['monitor'],
    factor=reduce_lr_cfg['factor'],
    patience=reduce_lr_cfg['patience'],
    min_lr=reduce_lr_cfg['min_lr'],
)

In [13]:
# CSV Logger Callback: streams per-epoch results to the CSV file named in the run config.
csv_log_file = wandb.config['callbacks']['csv_logger']['filename']
csv_logger = CSVLogger(filename=csv_log_file)

In [14]:
# Learning Rate Scheduler Callback: exponential decay,
#   lr(epoch) = initial_lr * decay_rate ** epoch
# The config is looked up inside the schedule (i.e. when Keras calls it during
# training), so this cell itself does not require a 'learning_rate_scheduler' section.
def exponential_lr_schedule(epoch):
    scheduler_cfg = wandb.config['callbacks']['learning_rate_scheduler']
    return scheduler_cfg['initial_lr'] * scheduler_cfg['decay_rate'] ** epoch

learning_rate_scheduler = LearningRateScheduler(schedule=exponential_lr_schedule)

In [15]:
# Terminate On NaN Callback: Callback that terminates training when a NaN loss is encountered.
# Stops a diverged run immediately instead of spending the remaining epochs on it.
terminate_on_nan = TerminateOnNaN()


In [16]:
# Lambda Callback (example: logging epoch start).
# Only the epoch-begin hook is supplied; every other LambdaCallback hook keeps its
# default of None, so nothing else runs.
def announce_epoch_start(epoch, logs):
    print(f"Starting epoch {epoch}")

lambda_callback = LambdaCallback(on_epoch_begin=announce_epoch_start)

### Initialize Model

In [17]:
# Initialize the model
from dlomix.models import PrositIntensityPredictor
from dlomix.constants import PTMS_ALPHABET

# Model input key -> feature name in the dataset
# (assumed to match the processed dataset's column names; verify when switching datasets).
input_mapping = dict(
    SEQUENCE_KEY="modified_sequence",
    COLLISION_ENERGY_KEY="collision_energy_aligned_normed",
    PRECURSOR_CHARGE_KEY="precursor_charge_onehot",
    FRAGMENTATION_TYPE_KEY="method_nbr",
)

# Metadata inputs are every mapped feature except the sequence itself, in mapping order.
meta_data_keys = [
    feature for input_key, feature in input_mapping.items() if input_key != "SEQUENCE_KEY"
]

model = PrositIntensityPredictor(
    seq_length=wandb.config['dataset']['seq_length'],
    alphabet=PTMS_ALPHABET,
    use_prosit_ptm_features=False,
    with_termini=False,
    input_keys=input_mapping,
    meta_data_keys=meta_data_keys,
)

In [18]:
# Compile the model: masked spectral distance is the training loss; the masked
# Pearson correlation distance is tracked as an extra metric (train and validation).
training_metrics = [masked_pearson_correlation_distance]
model.compile(
    loss=masked_spectral_distance,
    optimizer=optimizer,
    metrics=training_metrics,
)

### Train Model

In [19]:
# Train the model. Only the uncommented callbacks below are active.
# The returned History object is kept for later inspection (history.history).
history = model.fit(
    dataset.tensor_train_data,
    validation_data=dataset.tensor_val_data,
    epochs=wandb.config['training']['num_epochs'],
    callbacks=[
        # log_batch_frequency is an int batch interval; 1 = log every batch
        # (the same value the previous `True` evaluated to).
        WandbCallback(save_model=False, log_batch_frequency=1),
        early_stopping,
        reduce_lr,
        # NOTE: LearningRateScheduler sets the LR at the start of every epoch and would
        # override reduce_lr's reductions -- enable at most one of the two.
        # learning_rate_scheduler,
        terminate_on_nan,
        # lambda_callback,
        # csv_logger,
        # model_checkpoint,
    ],
)


Epoch 1/2
Epoch 2/2


<keras.src.callbacks.History at 0x7fb21c0799f0>

### Save Model

In [20]:
# Save the trained model as <save_dir>/<dataset name>/<run_id>.keras
model_dir = os.path.join(wandb.config['model']['save_dir'], wandb.config['dataset']['name'])
model_path = os.path.join(model_dir, f"{wandb.config['run_id']}.keras")

# Create the per-dataset directory first; saving fails if it does not exist yet.
os.makedirs(model_dir, exist_ok=True)

model.save(model_path)  # The file needs to end with the .keras extension

In [21]:
# Finish the wandb run: uploads any remaining logged data and closes the run
# (the upload progress and run summary below are printed by this call).
wandb.finish()

VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁█
loss,█▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁
masked_pearson_correlation_distance,█▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁
val_loss,█▁
val_masked_pearson_correlation_distance,█▁

0,1
best_epoch,1.0
best_val_loss,0.6666
epoch,1.0
loss,0.66762
masked_pearson_correlation_distance,0.55736
val_loss,0.6666
val_masked_pearson_correlation_distance,0.55492


### Load Model

In [22]:
# Reload the model saved above from its .keras file to check that it deserializes.
import keras

saved_model_file = model_path
reconstructed_model = keras.models.load_model(saved_model_file)

In [23]:
# Display the reloaded model; its repr shows which class it was deserialized as.
reconstructed_model

<dlomix.models.prosit.PrositIntensityPredictor at 0x7fb21c24e410>