In [None]:
import os
# Suppress TensorFlow's C++ (CUDA/cuDNN) log output. This must be set BEFORE
# TensorFlow is imported: TF's native logger may read the variable during import,
# so assigning it afterwards can leave the startup messages un-suppressed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from pyopenms import *

import pandas as pd
import numpy as np
import tensorflow as tf
from datetime import datetime

from tqdm.notebook import tqdm

# Restrict TensorFlow to a single 2048 MB logical device on the first GPU.
# This must run before any GPU is initialised, hence before the ionmob imports below.
# On CPU-only machines `gpus` is empty and indexing gpus[0] would raise IndexError,
# so the configuration is skipped there.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2048)])

from matplotlib import pyplot as plt
from ionmob.alignment.experiment import Experiment

from ionmob.models.deep_models import ProjectToInitialSqrtCCS, DeepRecurrentModel, DeepRecurrentConvModel
from ionmob.preprocess.data import get_tf_dataset, partition_tf_dataset, to_tf_dataset
from ionmob.preprocess.helpers import get_sqrt_slopes_and_intercepts, sequence_to_tokens, fit_tokenizer
from ionmob.preprocess.helpers import tokenizer_to_json, tokenizer_from_json, split_dataset

In [None]:
def is_phospho(seq):
    """Return True if any element of ``seq`` contains the phospho marker ``'<PH>'``.

    ``seq`` is iterated element by element and each element is searched for the
    substring ``'<PH>'``. It is intended for a tokenized peptide sequence (an
    iterable of string tokens). A plain string is iterated character by
    character, so it can never match the four-character marker.
    """
    return any('<PH>' in token for token in seq)

In [None]:
# Load the token vocabulary stored with the pretrained models.
TOKENIZER_PATH = '../pretrained-models/tokenizers/tokenizer.json'
tokenizer = tokenizer_from_json(TOKENIZER_PATH)

In [None]:
# TRAIN data
# Meier et al.
meier = pd.read_parquet('../data/Meier.parquet')

# Tenzer data
zepeda = pd.read_parquet('../data/Zepeda_unique.parquet')
tenzer_p = pd.read_parquet('../data/Tenzer-phospho-train.parquet')
# Keep only phosphorylated peptides. is_phospho must receive each row's token
# sequence, not the column name: passing the literal ['sequence-tokenized']
# flags every row False and empties tenzer_p. The astype(bool) keeps the mask
# boolean even when the frame is empty (an empty apply() yields object dtype).
tenzer_p['is_phos'] = tenzer_p['sequence-tokenized'].apply(is_phospho).astype(bool)
tenzer_p = tenzer_p[tenzer_p.is_phos]

# validation data
tenzer = pd.read_parquet('../data/Tenzer_unique.parquet')
tenzer_p_valid = pd.read_parquet('../data/Tenzer-phospho-valid_unique.parquet')

# TEST data
chang = pd.read_parquet('../data/Chang_unique.parquet')
sara = pd.read_parquet('../data/Sara_unique.parquet')
ogata = pd.read_parquet('../data/Ogata_unique.parquet')

# Training set: all training frames concatenated and shuffled once in pandas.
TRAIN = pd.concat([meier, zepeda, tenzer_p]).sample(frac=1.0)
# Validation set: Tenzer peptides plus the held-out phospho split.
VALID = pd.concat([tenzer, tenzer_p_valid])


def _frame_to_tf_dataset(frame):
    """Convert a peptide frame into a tf.data pipeline via to_tf_dataset(batch=False).

    Reads the ``mz``, ``charge``, ``sequence-tokenized`` and ``ccs`` columns of
    ``frame``. Each tokenized sequence is turned into a plain Python list first.
    """
    token_lists = [list(tokens) for tokens in frame['sequence-tokenized'].values]
    return to_tf_dataset(frame.mz.values, frame.charge.values, token_lists,
                         frame.ccs.values, tokenizer, batch=False)


# Training pipeline: shuffle over the full set (buffer > #rows), then batch and prefetch.
train = (_frame_to_tf_dataset(TRAIN)
         .shuffle(TRAIN.shape[0] + 1)
         .batch(128)
         .prefetch(5))

validation = _frame_to_tf_dataset(VALID).batch(2024).prefetch(5)

In [None]:
# Fit the per-charge slopes/intercepts on the training set only; they are
# passed to DeepRecurrentModel when the model is constructed.
slopes, intercepts = get_sqrt_slopes_and_intercepts(
    TRAIN['mz'].values,
    TRAIN['charge'].values,
    TRAIN['ccs'].values,
    fit_charge_state_one=True,
)

In [None]:
# Metric watched by early stopping, checkpointing and LR reduction:
# the validation loss of the model's first output.
MONITOR = 'val_output_1_loss'

# ModelCheckpoint and CSVLogger write below this directory; create it up front
# so training does not fail on a fresh checkout where it does not exist yet.
TRAINING_DIR = 'training/rnn'
os.makedirs(TRAINING_DIR, exist_ok=True)

# One TensorBoard log directory per run, named by start time.
logdir = "logs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)

# Stop when the monitored loss has not improved for 5 epochs.
early_stopper = tf.keras.callbacks.EarlyStopping(
    monitor=MONITOR,
    patience=5
)

# Keep only the best (lowest monitored loss) weights on disk.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=f'{TRAINING_DIR}/checkpoint',
    monitor=MONITOR,
    save_best_only=True,
    mode='min'
)

# append=True: re-running the notebook appends to the existing CSV log.
csv_logger = tf.keras.callbacks.CSVLogger(
    filename=f'{TRAINING_DIR}/training.csv',
    separator=',',
    append=True
)

# Divide the learning rate by 10 after 3 stagnant epochs, down to 1e-5.
# (`mode` was previously misspelled `monde` and silently ignored.)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor=MONITOR,
    factor=1e-1,
    patience=3,
    mode='auto',
    min_delta=1e-5,
    cooldown=0,
    min_lr=1e-5
)

cbs = [early_stopper, checkpoint, csv_logger, reduce_lr, tensorboard_callback]

# Recurrent CCS model built from the fitted slopes/intercepts: two GRU layers of
# 128 units and dropout 0.2, with a vocabulary size taken from the tokenizer.
model = DeepRecurrentModel(
    slopes,
    intercepts,
    gru_1=128,
    gru_2=128,
    num_tokens=len(tokenizer.word_index),
    do=0.2,
)

# Three inputs. Presumably these are m/z (1), one-hot charge (4) and token ids
# padded to length 50 — TODO confirm against to_tf_dataset.
model.build([(None, 1), (None, 4), (None, 50)])

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss=tf.keras.losses.MeanAbsoluteError(),
    # Only the first output contributes to the training loss.
    loss_weights=[1., 0.0],
    metrics=['mae', 'mean_absolute_percentage_error'],
)

In [None]:
# Train for at most 100 epochs. The callbacks in `cbs` handle early stopping,
# checkpointing, CSV/TensorBoard logging and learning-rate reduction.
history = model.fit(
    train,
    epochs=100,
    validation_data=validation,
    callbacks=cbs,
    verbose=True,
)

In [None]:
# Per-epoch MAE of the first model output on the training and validation data.
fig, ax = plt.subplots(figsize=(8, 4), dpi=200)
ax.plot(history.history['output_1_mae'], label='training')
ax.plot(history.history['val_output_1_mae'], label='validation')
ax.set_xlabel('epoch')
ax.set_ylabel('MAE')
ax.legend()
plt.show()