### Imports

In [1]:
import tensorflow as tf
import numpy as np
from tape import TAPETokenizer
from PrositTransformer.DataHandler import pad_sequences

### Data pre-processors

##### Charge pre-processors

In [2]:
CHARGES = [1, 2, 3, 4, 5, 6]
def get_precursor_charge_onehot(charges):
    """One-hot encode precursor charge states.

    Args:
        charges: sequence of integer charge states, each expected to lie in
            ``min(CHARGES)..max(CHARGES)``.

    Returns:
        Integer array of shape ``(len(charges), max(CHARGES))`` in which row
        ``i`` holds a single 1 in column ``charges[i] - 1``.

    Raises:
        IndexError: if any charge is outside the supported range. Charges above
            the maximum already failed this way; charges below the minimum used
            to wrap around (negative indexing) and set the wrong column.
    """
    lowest, highest = min(CHARGES), max(CHARGES)
    array = np.zeros([len(charges), highest], dtype=int)
    for i, precursor_charge in enumerate(charges):
        # Explicit range check: without it, charge 0 would index column -1 and
        # be silently encoded as the highest charge.
        if not lowest <= precursor_charge <= highest:
            raise IndexError(
                f"precursor charge {precursor_charge} at position {i} is outside "
                f"the supported range {lowest}..{highest}"
            )
        array[i, precursor_charge - 1] = 1
    return array



##### Peptide pre-processors

In [3]:
tokenizer = TAPETokenizer()
def TokenizePeptides(peptides):
    """Tokenize peptide sequences into a padded id array plus a mask.

    Args:
        peptides: iterable of peptide sequences; each one is passed to
            ``tokenizer.encode``.

    Returns:
        Tuple ``(input_ids, input_mask)``. ``input_ids`` is the result of
        ``pad_sequences`` applied to the per-peptide encodings, and
        ``input_mask`` is an array of ones with the same shape and dtype.
    """
    # Encode the sequences handed in by the caller. The earlier version read the
    # notebook-level `peptide_sequences` here and ignored the argument.
    input_ids = pad_sequences([tokenizer.encode(p) for p in peptides])
    # NOTE(review): the mask is all ones, so any padded positions are marked the
    # same as real tokens -- confirm this is what the model expects.
    return input_ids, np.ones_like(input_ids)

### Toy data

In [4]:
# Toy inputs for two peptides, one value per peptide in each variable.
collision_energy = np.array([0.24, 0.36])
charge = get_precursor_charge_onehot([1, 3])
peptide_sequences = ["AGM", "QPSEP"]

In [5]:
# Turn the toy peptide strings into padded token ids and the matching mask.
input_ids, input_mask = TokenizePeptides(peptides=peptide_sequences)

In [6]:
# Keyword arguments for the model call, cast to explicit dtypes:
# float32 for collision energy and charge, int64 for token ids and mask.
toy_data = dict(
    collision_energy=collision_energy.astype(np.float32),
    charge=charge.astype(np.float32),
    input_ids=input_ids.astype(np.int64),
    input_mask=input_mask.astype(np.int64),
)

### Get model

In [7]:
# Location of the exported model, relative to this notebook.
MODEL_PATH = '../tf_model/model.pb'
model = tf.keras.models.load_model(MODEL_PATH)

### Predict

In [8]:
# Run the model on the toy inputs, passing each dict entry as a keyword argument.
prediction = model(**toy_data)
# Show element 0 of the result as a NumPy array.
prediction[0].numpy()

array([[-4.21143621e-02, -2.90691927e-02, -2.68871225e-02,
        -2.09045168e-02, -2.90644504e-02, -2.68876180e-02,
        -9.80034247e-02, -2.89469808e-02, -2.67909467e-02,
        -1.67832077e-01, -2.08292156e-02, -2.68639848e-02,
        -5.99781498e-02, -3.19968574e-02, -2.61583254e-02,
        -9.90706086e-02, -1.49603300e-02, -2.66273338e-02,
        -1.79886222e-02, -3.22533697e-02, -2.61589102e-02,
        -7.34628066e-02,  2.97741145e-02, -2.61702426e-02,
        -4.58282083e-02, -3.89691852e-02, -2.63916329e-02,
        -8.14941153e-02, -2.59717591e-02, -2.46916898e-02,
        -1.09315217e-02, -2.59439722e-02, -2.64688823e-02,
        -6.87325001e-02, -3.25174667e-02, -1.79833677e-02,
        -1.18325800e-02,  5.97056225e-02, -2.22636219e-02,
        -4.08420078e-02, -3.13223936e-02,  1.97623335e-02,
        -1.52658671e-02, -2.75287963e-02,  7.15924501e-02,
        -7.57510215e-03, -2.54580192e-02,  1.05185464e-01,
        -1.22905895e-02, -1.83833838e-02,  1.27561092e-0