### Imports

In [1]:
import os

import numpy as np
import torch
from tape import ProteinBertForValuePredictionFragmentationProsit
from tape import TAPETokenizer

from PrositTransformer.DataHandler import pad_sequences

### Data pre-processors

##### Charge pre-processors

In [2]:
CHARGES = [1, 2, 3, 4, 5, 6]
def get_precursor_charge_onehot(charges):
    """One-hot encode a sequence of precursor charges.

    Args:
        charges: sequence of integer precursor charges; each must be one of
            ``CHARGES``.

    Returns:
        ``np.ndarray`` of dtype ``int`` and shape ``(len(charges), max(CHARGES))``
        where row ``i`` holds a single 1 in column ``charges[i] - 1``.

    Raises:
        ValueError: if a charge is not in ``CHARGES``. Without this check a
            charge of 0 (or any negative value) would silently set a column
            counted from the end via negative indexing, e.g. 0 -> charge 6.
    """
    onehot = np.zeros([len(charges), max(CHARGES)], dtype=int)
    for i, precursor_charge in enumerate(charges):
        if precursor_charge not in CHARGES:
            raise ValueError(
                f"Unsupported precursor charge {precursor_charge!r}; "
                f"expected one of {CHARGES}"
            )
        onehot[i, precursor_charge - 1] = 1
    return onehot



##### Peptide pre-processors

In [None]:
tokenizer = TAPETokenizer()
def TokenizePeptides(peptides):
    """Tokenize and pad a batch of peptide sequences for the model.

    Args:
        peptides: iterable of peptide strings, e.g. ``["AGM", "QPSEP"]``.

    Returns:
        ``(input_ids, input_mask)``: ``input_ids`` is the result of
        ``pad_sequences`` over the per-peptide token ids produced by the
        module-level ``tokenizer``; ``input_mask`` is an all-ones array with
        the same shape and dtype as ``input_ids``.
    """
    # Encode the ``peptides`` argument (not a notebook global) so the function
    # tokenizes whatever batch it is given.
    input_ids = pad_sequences([tokenizer.encode(p) for p in peptides])
    # NOTE(review): the mask is 1 everywhere, including any padding positions
    # pad_sequences may add. Kept unchanged because it must match how the
    # checkpoint was trained — confirm whether padded positions should be 0.
    return input_ids, np.ones_like(input_ids)

### Toy data

In [3]:
# Two toy examples, given column-wise: collision energy, one-hot precursor
# charge, and peptide sequence.
collision_energy = np.array([0.24, 0.36])
charge = get_precursor_charge_onehot([1, 3])
peptide_sequences = ["AGM", "QPSEP"]

In [4]:
# Padded token ids plus the matching input mask for the toy peptides.
(input_ids, input_mask) = TokenizePeptides(peptide_sequences)

In [5]:
# Model keyword arguments: float32 tensors for the numeric features,
# int64 tensors for the token ids and mask.
toy_data = dict(
    collision_energy=torch.from_numpy(collision_energy.astype(np.float32)),
    charge=torch.from_numpy(charge.astype(np.float32)),
    input_ids=torch.from_numpy(input_ids.astype(np.int64)),
    input_mask=torch.from_numpy(input_mask.astype(np.int64)),
)

### Load the pretrained model

In [6]:
# Checkpoint directory. Set the PROSIT_MODEL_DIR environment variable to point
# at a local copy instead of editing this machine-specific absolute path.
MODEL_DIR = os.environ.get("PROSIT_MODEL_DIR", "/sdd/berzelius/delta_-0.025")
model = ProteinBertForValuePredictionFragmentationProsit.from_pretrained(MODEL_DIR)


### Move the model and inputs to the GPU

In [7]:
# Use the first GPU when available; fall back to CPU so the notebook still
# runs end-to-end on machines without CUDA.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(DEVICE)
toy_data = {name: tensor.to(DEVICE, non_blocking=True)
            for name, tensor in toy_data.items()}

### Predict

In [8]:
# Inference only: make sure dropout is disabled and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    # Element 0 of the model outputs holds the predictions (a 2-D array in the
    # recorded output below).
    predictions = model(**toy_data)[0].cpu().numpy()
predictions

array([[-4.21155244e-02, -2.90694050e-02, -2.68873107e-02,
        -2.09047627e-02, -2.90646609e-02, -2.68878061e-02,
        -9.80041921e-02, -2.89471783e-02, -2.67911367e-02,
        -1.67832643e-01, -2.08295155e-02, -2.68641748e-02,
        -5.99793792e-02, -3.19970027e-02, -2.61585228e-02,
        -9.90708619e-02, -1.49606913e-02, -2.66275220e-02,
        -1.79907382e-02, -3.22534889e-02, -2.61591058e-02,
        -7.34631866e-02,  2.97731757e-02, -2.61704400e-02,
        -4.58306670e-02, -3.89692672e-02, -2.63918173e-02,
        -8.14944059e-02, -2.59719733e-02, -2.46919021e-02,
        -1.09346211e-02, -2.59443931e-02, -2.64690481e-02,
        -6.87328503e-02, -3.25176343e-02, -1.79836620e-02,
        -1.18352324e-02,  5.97040802e-02, -2.22637262e-02,
        -4.08426449e-02, -3.13226245e-02,  1.97615884e-02,
        -1.52677596e-02, -2.75294073e-02,  7.15912282e-02,
        -7.57587701e-03, -2.54582744e-02,  1.05183810e-01,
        -1.22921839e-02, -1.83842182e-02,  1.27559304e-0