# DeepNovoV2 - Jupyter Notebook Implementation


As part of the _Introduction to Deep Learning_ module at ESPCI Paris-PSL, Simon Chardin and Samuel Diebolt took apart the DeepNovoV2 model by [Qiao et al.](https://arxiv.org/abs/1904.08514) and adapted it to run in a Jupyter notebook environment. It was then modified to work with any MGF spectrum file and annotation results from the Mascot software, and trained with data provided by the _Spectrométrie de Masse Biologique et Protéomique (SMBP)_ laboratory at ESPCI Paris-PSL.

This notebook is compute-intensive and was designed to run with an available GPU or virtual GPU. It should work well on Google Colab with moderately sized datasets.

The steps that will be run by the notebook are defined by a list in the following cell, with the following options:

- `train` : training mode;
- `valid` : validation of the previously trained model. This step is performed during training;
- `denovo` : _de novo_ peptide sequencing;
- `test` : evaluation of the _de novo_ results with the theoretical sequences.

After selecting the desired steps, run all cells.

In [1]:
option = ['train', 'denovo', 'valid', 'test']

In [2]:
# Python libraries.
import torch
from torch.utils.data import Dataset
from torch import optim
import numpy as np
import pickle
import csv
import re
import sys
from dataclasses import dataclass
import torch.nn as nn
import torch.nn.functional as F
from enum import Enum
from tqdm import tqdm
import time
import math
import os

%qtconsole

# Table of Contents
1. [Model Configuration](#modelconfiguration)
2. [Feature Class, Helper Functions](#featureclass)
3. [Train Function](#train_func)
4. [Neural Networks Models](#model)
5. [Inference Model Wrapper](#inference) 
6. [Build Model](#build_func)
7. [Data Reader Function](#reader)
8. [Accessory Functions](#accesories)
    1. [Collate Function](#collate)  
    2. [Extract & Move Data](#extract)  
    3. [Focal Loss Function](#loss)
9. [Validation Function](#valid_func)
10. [Cython Integration](#cython)
11. [Training Launch Sequence](#train)
12. [Validation of the Training](#valid)
13. [Denovo Part](#denovo)
    1. [Denovo File Path](#denovo_path)  
    2. [Denovo Data Reader](#denovo_datareader)  
    3. [Results Writer Function](#denovo_writer)  
    4. [Knapsack Implementation](#knapsack)   
    5. [ION CNN Denovo](#ioncnn)  
    6. [Denovo Launch Sequence](#denovo_launch)
14. [Test of the Prediction](#test)
    1. [Testing File path selection](#test_path)  
    2. [Worker Test function](#test_worker)  
    3. [Read Feature Accuracy](#test_accuracy)  
    4. [Score cutoff function](#test_cutoff)  
    5. [Testing Launch Sequence](#test_launch)

## Model Configuration <a name="modelconfiguration"></a>

All parameters contained in the original configuration file `deepnovo_config.py` are declared in the following cells. 

**Caution**: the `batch_size` parameter needs to be assigned in the [`train` function](#train_func) due to an unresolved issue (line 14).

### Global Variables and Vocabulary

In [3]:
# Specify whether the model uses LSTM layers.
use_lstm = True

# Directory where the training parameters will be saved. Must be
# created by the user beforehand.
train_dir = 'train'

# Name of the training parameters saved to the previous directory.
forward_model_save_name = 'forward_deepnovo.pth'
backward_model_save_name = 'backward_deepnovo.pth'
init_net_save_name = 'init_net.pth'

# Activation function for the model.
activation_func = F.relu

# Enable CUDA if available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

class Direction(Enum):
    forward = 1
    backward = 2

cuda:0


In [4]:
# ==============================================================================
# GLOBAL VARIABLES for VOCABULARY
# ==============================================================================

# Special vocabulary symbols - we always put them at the start.
_PAD = "_PAD"
_GO = "_GO"
_EOS = "_EOS"
_START_VOCAB = [_PAD, _GO, _EOS]

PAD_ID = 0
GO_ID = 1
EOS_ID = 2
assert PAD_ID == 0

# Amino acids vocabulary. This must match the vocabulary used in the
# sequences available in the features file.
vocab_reverse = ['A',
                 'R',
                 'N',
                 'N(Deamidated)',
                 'D',
                 'C',
                 'C(Carboxymethyl)',
                 'E',
                 'Q',
                 'Q(Deamidated)',
                 'G',
                 'H',
                 'I',
                 'L',
                 'K',
                 'M',
                 'M(Oxidation)',
                 'F',
                 'P',
                 'S',
                 'T',
                 'W',
                 'Y',
                 'V']

vocab_reverse = _START_VOCAB + vocab_reverse
print("vocab_reverse ", vocab_reverse)

vocab = dict([(x, y) for (y, x) in enumerate(vocab_reverse)])
print("vocab ", vocab)

vocab_size = len(vocab_reverse)
print("vocab_size ", vocab_size)


# ==============================================================================
# GLOBAL VARIABLES for THEORETICAL MASS
# ==============================================================================
mass_H = 1.0078
mass_H2O = 18.0106
mass_NH3 = 17.0265
mass_N_terminus = 1.0078
mass_C_terminus = 17.0027
mass_CO = 27.9949

# Masses associated to the vocabulary.
mass_AA = {'_PAD': 0.0,
           '_GO': mass_N_terminus-mass_H,
           '_EOS': mass_C_terminus+mass_H,
           'A': 71.03711,
           'R': 156.10111,
           'N': 114.04293,
           'N(Deamidated)': 115.02695,
           'D': 115.02694,
           'C': 103.00919,
           'C(Carboxymethyl)': 161.01919,
           'E': 129.04259,
           'Q': 128.05858,
           'Q(Deamidated)': 129.0426,
           'G': 57.02146,
           'H': 137.05891,
           'I': 113.08406,
           'L': 113.08406, 
           'K': 128.09496,
           'M': 131.04049,
           'M(Oxidation)': 147.0354,
           'F': 147.06841,
           'P': 97.05276,
           'S': 87.03203,
           'T': 101.04768,
           'W': 186.07931,
           'Y': 163.06333,
           'V': 99.06841,
          }

mass_ID = [mass_AA[vocab_reverse[x]] for x in range(vocab_size)]
mass_ID_np = np.array(mass_ID, dtype=np.float32)

mass_AA_min = mass_AA["G"] # 57.02146


# ==============================================================================
# GLOBAL VARIABLES for PRECISION, RESOLUTION, temp-Limits of MASS & LEN
# ==============================================================================

WINDOW_SIZE = 10 # 10 bins
print("WINDOW_SIZE ", WINDOW_SIZE)

# Maximum peptide m/z ratio allowed.
MZ_MAX = 3000.0

# Number of top peaks selected in each spectrum.
MAX_NUM_PEAK = 500

# Knapsack dynamic programming parameters.
knapsack_file = "knapsack.npy"
KNAPSACK_AA_RESOLUTION = 10000 # 0.0001 Da
mass_AA_min_round = int(round(mass_AA_min * KNAPSACK_AA_RESOLUTION)) # 570215
KNAPSACK_MASS_PRECISION_TOLERANCE = 100 # 0.01 Da
num_position = 0

PRECURSOR_MASS_PRECISION_TOLERANCE = 0.01

# Tolerance for an amino acid prediction match.
AA_MATCH_PRECISION = 0.1

# Maximum peptide length.
MAX_LEN = 30 
print("MAX_LEN ", MAX_LEN)


# ==============================================================================
# HYPER-PARAMETERS of the NEURAL NETWORKS
# ==============================================================================
num_ion = 12
print("num_ion ", num_ion)

weight_decay = 0.0  # no weight decay leads to better results.
print("weight_decay ", weight_decay)

#~ encoding_cnn_size = 4 * (RESOLUTION//10) # 4 # proportion to RESOLUTION
#~ encoding_cnn_filter = 4
#~ print("encoding_cnn_size ", encoding_cnn_size)
#~ print("encoding_cnn_filter ", encoding_cnn_filter)

embedding_size = 512
print("embedding_size ", embedding_size)

# LSTM parameters.
num_lstm_layers = 2
num_units = 64
lstm_hidden_units = 512
print("num_lstm_layers ", num_lstm_layers)
print("num_units ", num_units)

# Dropout isn't applied to the LSTM input, as the authors found it hindered
# performance. It is only applied to the concatenated features before the
# output layer.
dropout_rate = 0.25

# Batch size. Setting this value too high will cause memory issues.
batch_size = 24

num_workers = 2
print("batch_size ", batch_size)

# Number of epochs to use.
num_epoch = 5

# Initial learning rate.
init_lr = 1e-3

# Number of steps per validation.
steps_per_validation = 10
print("steps_per_validation ", steps_per_validation)

# De novo beam search parameter.
beam_size_param = 5
print("Beam size = ", beam_size_param)

# Feature file column format.
col_feature_id = "spec_group_id"
col_precursor_mz = "m/z"
col_precursor_charge = "z"
col_rt_mean = "rt_mean"
col_raw_sequence = "seq"
col_scan_list = "scans"
col_feature_area = "feature area"

# Predicted file column format.
pcol_feature_id = 0
pcol_feature_area = 1
pcol_sequence = 2
pcol_score = 3
pcol_position_score = 4
pcol_precursor_mz = 5
pcol_precursor_charge = 6
pcol_protein_id = 7
pcol_scan_list_middle = 8
pcol_scan_list_original = 9
pcol_score_max = 10

# Parameters for the Cython module.
distance_scale_factor = 100.
sinusoid_base = 30000.
spectrum_reso = 10
n_position = int(MZ_MAX) * spectrum_reso

vocab_reverse  ['_PAD', '_GO', '_EOS', 'A', 'R', 'N', 'N(Deamidated)', 'D', 'C', 'C(Carboxymethyl)', 'E', 'Q', 'Q(Deamidated)', 'G', 'H', 'I', 'L', 'K', 'M', 'M(Oxidation)', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
vocab  {'_PAD': 0, '_GO': 1, '_EOS': 2, 'A': 3, 'R': 4, 'N': 5, 'N(Deamidated)': 6, 'D': 7, 'C': 8, 'C(Carboxymethyl)': 9, 'E': 10, 'Q': 11, 'Q(Deamidated)': 12, 'G': 13, 'H': 14, 'I': 15, 'L': 16, 'K': 17, 'M': 18, 'M(Oxidation)': 19, 'F': 20, 'P': 21, 'S': 22, 'T': 23, 'W': 24, 'Y': 25, 'V': 26}
vocab_size  27
WINDOW_SIZE  10
MAX_LEN  30
num_ion  12
weight_decay  0.0
embedding_size  512
num_lstm_layers  2
num_units  64
batch_size  24
steps_per_validation  10
Beam size =  5
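
The mass constants above can be combined into a theoretical peptide mass. The following is a minimal sketch rather than code from the notebook: the helpers `peptide_mass` and `precursor_mz` are hypothetical names, and the convention used (sum of residue masses plus both termini, with one `mass_H` added per charge) is an assumption that is consistent with the `_GO` and `_EOS` entries of `mass_AA`.

```python
# Subset of the monoisotopic masses declared in the configuration cell.
mass_H = 1.0078
mass_N_terminus = 1.0078
mass_C_terminus = 17.0027
mass_AA = {
    'G': 57.02146,
    'A': 71.03711,
    'S': 87.03203,
    'M(Oxidation)': 147.0354,
}


def peptide_mass(residues):
    """Neutral peptide mass: residue masses plus both termini (hypothetical helper)."""
    return sum(mass_AA[aa] for aa in residues) + mass_N_terminus + mass_C_terminus


def precursor_mz(neutral_mass, charge):
    """m/z of the peptide with `charge` protons, assuming m/z = M / z + mass_H."""
    return neutral_mass / charge + mass_H


mass = peptide_mass(['G', 'A', 'S'])
print(f"neutral mass: {mass:.4f}")                    # 233.1011
print(f"m/z at z=2:   {precursor_mz(mass, 2):.4f}")
```

Tokens such as `M(Oxidation)` are looked up as a single vocabulary entry, which is why the sequences in the features file must use exactly the same spelling as `vocab_reverse`.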


### Input Files

In [5]:
# ==============================================================================
# DATASETS Path
# ==============================================================================

input_spectrum_file_train = "../smbp_data/spectrum_smbp.mgf"
input_feature_file_train = "../smbp_data/features_smbp.csv.train"
input_spectrum_file_valid = "../smbp_data/spectrum_smbp.mgf"
input_feature_file_valid = "../smbp_data/features_smbp.csv.valid"

## Feature Class, Helper Functions <a name="featureclass"></a>

- The `Feature` class is used to store data obtained from the features file.
- The `perplexity` function computes the perplexity, i.e. the exponential of the log loss, during training and validation. It measures how well the model predicts the sample: the lower, the better.
- The `adjust_learning_rate` function divides the learning rate by 10 every three epochs during training.
- The `save_model` function will save the training parameters of the best model.

In [6]:
#------------------------------------------- Feature Class ------------------------------
@dataclass
class Feature:
    spec_id: str
    mz: str
    z: str
    rt_mean: str
    seq: str
    scan: str

    def to_list(self):
        """Convert the feature to a list."""
        return [self.spec_id, self.mz, self.z, self.rt_mean, self.seq, self.scan, "0.0:1.0", "1.0"]

In [7]:
def perplexity(log_loss):
    """Compute the perplexity from the log loss of the model."""
    return math.exp(log_loss) if log_loss < 300 else float('inf')


def adjust_learning_rate(optimizer, epoch):
    """Set the learning rate to the initial LR divided by 10 every 3 epochs (first decay at epoch index 2)."""
    lr = init_lr * (0.1 ** ((epoch + 1) // 3))
    print(f"epoch: {epoch}\tlr: {lr}")
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def save_model(forward_deepnovo, backward_deepnovo, init_net):
    """
    Save the model parameters of the forward, backward and initialisation
    networks to the location set in the global variables.
    """
    torch.save(forward_deepnovo.state_dict(), os.path.join(train_dir,
                                                           forward_model_save_name))
    torch.save(backward_deepnovo.state_dict(), os.path.join(train_dir,
                                                            backward_model_save_name))
    if use_lstm:
        torch.save(init_net.state_dict(), os.path.join(train_dir,
                                                   init_net_save_name))
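
To make the decay schedule concrete, the sketch below evaluates the formula used by `adjust_learning_rate` for the first six epochs and checks `perplexity` on a few losses. It is an illustration only: `lr_for_epoch` is a hypothetical helper that isolates the formula, and `init_lr` is copied from the configuration cell.

```python
import math

init_lr = 1e-3  # same value as in the configuration cell


def lr_for_epoch(epoch, base_lr=init_lr):
    """Learning rate that `adjust_learning_rate` assigns at a given epoch index."""
    return base_lr * (0.1 ** ((epoch + 1) // 3))


def perplexity(log_loss):
    """Same definition as in the cell above."""
    return math.exp(log_loss) if log_loss < 300 else float('inf')


for epoch in range(6):
    print(f"epoch {epoch}: lr = {lr_for_epoch(epoch):.0e}")

print(perplexity(0.0))    # 1.0: a loss of zero corresponds to a perfect prediction
print(perplexity(300.0))  # inf: very large losses are clamped to avoid overflow
```

Because the exponent is `(epoch + 1) // 3`, the first division by 10 happens at epoch index 2, i.e. after two full epochs, and the next one at epoch index 5.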

## Train Function <a name="train_func"></a>

The train function performs the model training and validation. It will print out the training and validation perplexity regularly and will save the optimal model parameters using the previously defined `save_model` function.

In [8]:
training_perp_tab = []
valid_perp_tab = []

def train():
    """
    Train the model and validate it at regular intervals.
    Uses the global variables:
        input_feature_file_train, input_spectrum_file_train,
        input_feature_file_valid, input_spectrum_file_valid
    """
    # Training dataset creation.
    train_set = DeepNovoTrainDataset(input_feature_file_train,
                                     input_spectrum_file_train)
    num_train_features = len(train_set)
    global batch_size
    steps_per_epoch = int(num_train_features / batch_size)
    print(steps_per_epoch,'steps per epoch')
    train_data_loader = torch.utils.data.DataLoader(dataset=train_set,
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    num_workers=num_workers,
                                                    collate_fn=collate_func)
    # Validation dataset creation
    valid_set = DeepNovoTrainDataset(input_feature_file_valid,
                                     input_spectrum_file_valid)
    valid_data_loader = torch.utils.data.DataLoader(dataset=valid_set,
                                                    batch_size=batch_size,
                                                    shuffle=False,
                                                    num_workers=num_workers,
                                                    collate_fn=collate_func)
    
    forward_deepnovo, backward_deepnovo, init_net = build_model()
    dense_params = list(forward_deepnovo.parameters()) + list(backward_deepnovo.parameters())
    dense_optimizer = optim.Adam(dense_params,
                                 lr=init_lr,
                                 weight_decay=weight_decay)
    dense_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(dense_optimizer, 'min', factor=0.5, verbose=True,
                                                                 threshold=1e-4, cooldown=10, min_lr=1e-5)

    best_valid_loss = float("inf")
    # train loop
    best_epoch = None
    best_step = None
    start_time = time.time()
    for epoch in tqdm(range(num_epoch)):
        # Adjust the learning rate every 3 epochs.
        adjust_learning_rate(dense_optimizer, epoch)
        for i, data in enumerate(train_data_loader):
            dense_optimizer.zero_grad()
            peak_location, \
            peak_intensity, \
            spectrum_representation, \
            batch_forward_id_target, \
            batch_backward_id_target, \
            batch_forward_ion_index, \
            batch_backward_ion_index, \
            batch_forward_id_input, \
            batch_backward_id_input = extract_and_move_data(data)
            batch_size = batch_backward_id_target.size(0)

            if use_lstm:
                initial_state_tuple = init_net(spectrum_representation)
                forward_logit, _ = forward_deepnovo(batch_forward_ion_index, peak_location, peak_intensity,
                                                    batch_forward_id_input, initial_state_tuple)
                backward_logit, _ = backward_deepnovo(batch_backward_ion_index, peak_location, peak_intensity,
                                                      batch_backward_id_input, initial_state_tuple)
            else:
                forward_logit = forward_deepnovo(batch_forward_ion_index, peak_location, peak_intensity)
                backward_logit = backward_deepnovo(batch_backward_ion_index, peak_location, peak_intensity)

            forward_loss, _ = focal_loss(forward_logit, batch_forward_id_target, ignore_index=0, gamma=2.)
            backward_loss, _ = focal_loss(backward_logit, batch_backward_id_target, ignore_index=0, gamma=2.)
            total_loss = (forward_loss + backward_loss) / 2.
            # compute gradient
            total_loss.backward()
            dense_optimizer.step()
            # Every `steps_per_validation` steps, measure the average step time
            # and report the training and validation perplexity.
            if (i + 1) % steps_per_validation == 0:
                duration = time.time() - start_time
                step_time = duration / steps_per_validation
                loss_cpu = total_loss.item()
                # evaluation mode
                forward_deepnovo.eval()
                backward_deepnovo.eval()
                validation_loss = validation(forward_deepnovo, backward_deepnovo, init_net, valid_data_loader)
                dense_scheduler.step(validation_loss)

                # Training and validation perplexities are saved to plot their evolution.
                training_perp_tab.append(perplexity(loss_cpu))
                valid_perp_tab.append(perplexity(validation_loss))

                print(f"epoch {epoch} step {i}/{steps_per_epoch}, "
                            f"train perplexity: {perplexity(loss_cpu)}\t"
                            f"validation perplexity: {perplexity(validation_loss)}\tstep time: {step_time}")

                if validation_loss < best_valid_loss:
                    best_valid_loss = validation_loss
                    print('best valid loss achieved at epoch',epoch, 'step', i)
                    best_epoch = epoch
                    best_step = i
                    # Save the model when a new best validation loss is achieved.
                    # Careful: if steps_per_validation is larger than the number
                    # of batches per epoch, validation never runs during training
                    # and the model is never saved.
                    save_model(forward_deepnovo, backward_deepnovo, init_net)

                # back to train mode
                forward_deepnovo.train()
                backward_deepnovo.train()

                start_time = time.time()
            # Most of the GPU memory was observed to be unoccupied cache, so the cache is cleared after each batch.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    print('best model at epoch',best_epoch,'step',best_step)
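
The interaction between `steps_per_validation` and the dataset size can be checked without running the training. The sketch below reproduces the trigger condition `(i + 1) % steps_per_validation == 0` from the loop above; `validation_steps` and its `num_batches` argument are hypothetical names introduced for illustration.

```python
def validation_steps(num_batches, steps_per_validation):
    """Return the 0-based batch indices at which the training loop validates."""
    return [i for i in range(num_batches)
            if (i + 1) % steps_per_validation == 0]


# 25 batches per epoch, validation every 10 steps.
print(validation_steps(25, 10))  # [9, 19]

# Only 8 batches per epoch: validation never runs, so no model is saved.
print(validation_steps(8, 10))   # []
```

With a small dataset, lowering `steps_per_validation` below the number of batches per epoch ensures that the model is evaluated, and therefore saved, at least once per epoch.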

## Neural Network Models <a name="model"></a>

All models used in DeepNovoV2 are defined below.

### DeepNovoPointNet with LSTM

In [9]:
class DeepNovoPointNetWithLSTM(nn.Module):
    def __init__(self):
        super(DeepNovoPointNetWithLSTM, self).__init__()
        self.t_net = TNet(with_lstm=True)
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embedding_size)
        self.lstm = nn.LSTM(embedding_size,
                            lstm_hidden_units,
                            num_layers=num_lstm_layers,
                            batch_first=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.output_layer = nn.Linear(num_units + lstm_hidden_units,
                                      vocab_size)

    def forward(self, location_index, peaks_location, peaks_intensity, aa_input=None, state_tuple=None):
        """

        :param location_index: [batch, T, vocab_size, num_ion] long
        :param peaks_location: [batch, N] N stands for MAX_NUM_PEAK, long
        :param peaks_intensity: [batch, N], float32
        :param aa_input: [batch, T]
        :param state_tuple: (h0, c0), where each is a [num_lstm_layers, batch_size, lstm_hidden_units] tensor
        :return:
            logits: [batch, T, vocab_size]
            new_state_tuple: updated (h, c) LSTM state
        """
        assert aa_input is not None
        N = peaks_location.size(1)
        assert N == peaks_intensity.size(1)
        batch_size, T, vocab_size, num_ion = location_index.size()

        peaks_location = peaks_location.view(batch_size, 1, N, 1)
        peaks_intensity = peaks_intensity.view(batch_size, 1, N, 1)
        peaks_location = peaks_location.expand(-1, T, -1, -1)  # [batch, T, N, 1]
        peaks_location_mask = (peaks_location > 1e-5).float()
        peaks_intensity = peaks_intensity.expand(-1, T, -1, -1)  # [batch, T, N, 1]

        location_index = location_index.view(batch_size, T, 1, vocab_size * num_ion)
        location_index_mask = (location_index > 1e-5).float()

        location_exp_minus_abs_diff = torch.exp(
            -torch.abs(
                (peaks_location - location_index) * distance_scale_factor
            )
        )
        # [batch, T, N, vocab_size * num_ion]

        location_exp_minus_abs_diff = location_exp_minus_abs_diff * peaks_location_mask * location_index_mask

        input_feature = torch.cat((location_exp_minus_abs_diff, peaks_intensity), dim=3)
        input_feature = input_feature.view(batch_size * T, N, vocab_size * num_ion + 1)
        input_feature = input_feature.transpose(1, 2)

        ion_feature = self.t_net(input_feature).view(batch_size, T, num_units)  # attention on peaks

        # embedding
        aa_embedded = self.embedding(aa_input)
        lstm_input = aa_embedded  # [batch, T, embedding_size]
        # Dropout on the LSTM input is disabled: it did not appear to help
        # in this configuration.
        # lstm_input = self.dropout(lstm_input)
        output_feature, new_state_tuple = self.lstm(lstm_input, state_tuple)
        output_feature = torch.cat((ion_feature, activation_func(output_feature)), dim=2)
        output_feature = self.dropout(output_feature)
        logit = self.output_layer(output_feature)
        return logit, new_state_tuple
  

### T-Net Module

Same model as presented in the [PointNet article](https://arxiv.org/abs/1612.00593).

In [10]:
class TNet(nn.Module):
    """
    the T-net structure in the Point Net paper
    """
    def __init__(self, with_lstm=False):
        super(TNet, self).__init__()
        self.with_lstm = with_lstm
        self.conv1 = nn.Conv1d(vocab_size * num_ion + 1, num_units, 1)
        self.conv2 = nn.Conv1d(num_units, 2*num_units, 1)
        self.conv3 = nn.Conv1d(2*num_units, 4*num_units, 1)
        self.fc1 = nn.Linear(4*num_units, 2*num_units)
        self.fc2 = nn.Linear(2*num_units, num_units)
        if not with_lstm:
            self.output_layer = nn.Linear(num_units, vocab_size)
        self.relu = nn.ReLU()

        self.input_batch_norm = nn.BatchNorm1d(vocab_size * num_ion + 1)

        self.bn1 = nn.BatchNorm1d(num_units)
        self.bn2 = nn.BatchNorm1d(2*num_units)
        self.bn3 = nn.BatchNorm1d(4*num_units)
        self.bn4 = nn.BatchNorm1d(2*num_units)
        self.bn5 = nn.BatchNorm1d(num_units)

    def forward(self, x):
        """

        :param x: [batch * T, vocab_size * num_ion + 1, N]
        :return:
            [batch * T, num_units] features if with_lstm,
            else logits of shape [batch * T, vocab_size]
        """
        x = self.input_batch_norm(x)
        x = activation_func(self.bn1(self.conv1(x)))
        x = activation_func(self.bn2(self.conv2(x)))
        x = activation_func(self.bn3(self.conv3(x)))
        x, _ = torch.max(x, dim=2)  # global max pooling
        assert x.size(1) == 4*num_units

        x = activation_func(self.bn4(self.fc1(x)))
        x = activation_func(self.bn5(self.fc2(x)))
        if not self.with_lstm:
            x = self.output_layer(x)  # [batch * T, vocab_size]
        return x

### Initialization Layer of the LSTM


In [11]:
class InitNet(nn.Module):
    def __init__(self):
        super(InitNet, self).__init__()
        self.init_state_layer = nn.Linear(embedding_size, 2 * lstm_hidden_units)

    def forward(self, spectrum_representation):
        """

        :param spectrum_representation: [batch_size, embedding_size]
        :return:
            (h_0, c_0), each of shape [num_lstm_layers, batch_size, lstm_hidden_units]
        """
        x = torch.tanh(self.init_state_layer(spectrum_representation))
        h_0, c_0 = torch.split(x, lstm_hidden_units, dim=1)
        h_0 = torch.unsqueeze(h_0, dim=0)
        h_0 = h_0.repeat(num_lstm_layers, 1, 1)
        c_0 = torch.unsqueeze(c_0, dim=0)
        c_0 = c_0.repeat(num_lstm_layers, 1, 1)
        return h_0, c_0

### DeepNovoPointNet Model, without LSTM

In [12]:
class DeepNovoPointNet(nn.Module):
    def __init__(self):
        super(DeepNovoPointNet, self).__init__()
        self.t_net = TNet(with_lstm=False)
        self.distance_scale_factor = distance_scale_factor

    def forward(self, location_index, peaks_location, peaks_intensity):
        """

        :param location_index: [batch, T, vocab_size, num_ion] long
        :param peaks_location: [batch, N] N stands for MAX_NUM_PEAK, long
        :param peaks_intensity: [batch, N], float32
        :return:
            logits: [batch, T, vocab_size]
        """

        N = peaks_location.size(1)
        assert N == peaks_intensity.size(1)
        batch_size, T, vocab_size, num_ion = location_index.size()

        peaks_location = peaks_location.view(batch_size, 1, N, 1)
        peaks_intensity = peaks_intensity.view(batch_size, 1, N, 1)
        peaks_location = peaks_location.expand(-1, T, -1, -1)  # [batch, T, N, 1]
        peaks_location_mask = (peaks_location > 1e-5).float()
        peaks_intensity = peaks_intensity.expand(-1, T, -1, -1)  # [batch, T, N, 1]

        location_index = location_index.view(batch_size, T, 1, vocab_size*num_ion)
        location_index_mask = (location_index > 1e-5).float()

        location_exp_minus_abs_diff = torch.exp(
            -torch.abs(
                (peaks_location - location_index) * self.distance_scale_factor
            )
        )
        # [batch, T, N, vocab_size * num_ion]

        location_exp_minus_abs_diff = location_exp_minus_abs_diff * peaks_location_mask * location_index_mask

        input_feature = torch.cat((location_exp_minus_abs_diff, peaks_intensity), dim=3)
        input_feature = input_feature.view(batch_size*T, N, vocab_size*num_ion + 1)
        input_feature = input_feature.transpose(1, 2)

        result = self.t_net(input_feature).view(batch_size, T, vocab_size)
        return result


if use_lstm:  # global variable
    DeepNovoModel = DeepNovoPointNetWithLSTM
else:
    DeepNovoModel = DeepNovoPointNet
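
Both models build their input from the term `exp(-|peak - ion| * distance_scale_factor)`, masked when either location is padding. The scalar sketch below shows how quickly this feature decays with the m/z distance between an observed peak and a theoretical ion location. It is a plain-Python illustration, not the tensor code used by the models, and `peak_ion_feature` is a hypothetical helper.

```python
import math

distance_scale_factor = 100.0  # same value as in the configuration cell


def peak_ion_feature(peak_mz, ion_mz, scale=distance_scale_factor):
    """exp(-|peak - ion| * scale), set to 0 when either location is padding."""
    if peak_mz <= 1e-5 or ion_mz <= 1e-5:
        return 0.0  # mirrors peaks_location_mask * location_index_mask
    return math.exp(-abs(peak_mz - ion_mz) * scale)


theoretical_ion = 500.25
for peak in (500.25, 500.26, 500.30, 501.25, 0.0):
    value = peak_ion_feature(peak, theoretical_ion)
    print(f"peak {peak:7.2f} -> feature {value:.4f}")
```

With a scale factor of 100, a peak 0.01 away from the theoretical location still contributes about 0.37, a peak 0.05 away contributes less than 0.01, and a peak one unit away contributes essentially nothing.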


## Inference Model Wrapper <a name="inference"></a>


In [13]:
class InferenceModelWrapper(object):
    """
    A wrapper class so that the beam search code is the same for the models with and without LSTM.
    """
    def __init__(self, forward_model: DeepNovoModel, backward_model: DeepNovoModel, init_net: InitNet=None):
        self.forward_model = forward_model
        self.backward_model = backward_model
        # make sure all models are in eval mode
        self.forward_model.eval()
        self.backward_model.eval()
        if use_lstm:
            assert init_net is not None
            self.init_net = init_net
            self.init_net.eval()

    def step(self, candidate_location, peaks_location, peaks_intensity, aa_input, state_tuple, direction):
        """
        :param state_tuple: tuple of ([num_lstm_layers, batch_size, lstm_hidden_units], [num_lstm_layers, batch_size, lstm_hidden_units])
        :param aa_input: [batch, 1]
        :param candidate_location: [batch, 1, vocab_size, num_ion]
        :param peaks_location: [batch, N]
        :param peaks_intensity: [batch, N]
        :param direction: enum class, whether forward or backward
        :return: (log_prob, new_hidden_state)
        log_prob: the predicted log probabilities, of shape [batch, vocab_size]
        """
        if direction == Direction.forward:
            model = self.forward_model
        else:
            model = self.backward_model

        with torch.no_grad():
            if use_lstm:
                logit, new_state_tuple = model(candidate_location, peaks_location, peaks_intensity, aa_input,
                                               state_tuple)
            else:
                logit = model(candidate_location, peaks_location, peaks_intensity)
                new_state_tuple = None
            logit = torch.squeeze(logit, dim=1)
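            # Normalise over the vocabulary axis; an explicit dim avoids the
            # implicit-dimension deprecation warning.
            log_prob = F.log_softmax(logit, dim=1)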
        return log_prob, new_state_tuple

    def initial_hidden_state(self, spectrum_representation):
        """

        :param spectrum_representation: [batch, embedding_size]
        :return:
            (h_0, c_0), each of shape [num_lstm_layers, batch_size, lstm_hidden_units]
        """
        with torch.no_grad():
            h_0, c_0 = self.init_net(spectrum_representation)
            return h_0.to(device), c_0.to(device)

## Build Model  <a name="build_func"></a>

The `build_model` function uses the previously defined classes and the model parameters to create the final model (with or without LSTM, etc.).

In [14]:
def build_model(training=True):
    """
    :param training: True to build the model for training, False for prediction
        (pretrained weights must then exist in `train_dir`)
    :return:
        forward_deepnovo, backward_deepnovo, init_net: moved to `device`, with the
        pretrained parameters loaded if present.
    """
    forward_deepnovo = DeepNovoModel()
    backward_deepnovo = DeepNovoModel()
    if use_lstm:
        init_net = InitNet() #initialisation of the LSTM
    else:
        init_net = None

    # load pretrained params if they exist
    if os.path.exists(os.path.join(train_dir, forward_model_save_name)):
        assert os.path.exists(os.path.join(train_dir, backward_model_save_name))
        print("load pretrained model")
        forward_deepnovo.load_state_dict(torch.load(os.path.join(train_dir, forward_model_save_name),
                                                    map_location=device))
        backward_deepnovo.load_state_dict(torch.load(os.path.join(train_dir, backward_model_save_name),
                                                     map_location=device))
        if use_lstm:
            init_net.load_state_dict(torch.load(os.path.join(train_dir, init_net_save_name),
                                                map_location=device))
    else:
        assert training, "building model for testing, but could not find weights under directory " \
                         f"{train_dir}"
        print("initialize a set of new parameters")

    if use_lstm:
        # share embedding matrix
        backward_deepnovo.embedding.weight = forward_deepnovo.embedding.weight

    backward_deepnovo = backward_deepnovo.to(device)
    forward_deepnovo = forward_deepnovo.to(device)
    if use_lstm:
        init_net = init_net.to(device)
    return forward_deepnovo, backward_deepnovo, init_net

## Data Reader <a name="reader"></a>

The data reader module is used to parse the input files and generate the PyTorch dataset.

### Raw Sequence Parser

The original raw sequence parser of DeepNovoV2 contained hardcoded PTMs and was usable only with the sequence format from PEAKS. It has been bypassed, since our Python module ensures that the sequences in the features file will contain the same vocabulary as defined in the model configuration.

In [15]:
#-----------------------------------Data_reader------------------------------------------
def parse_raw_sequence(raw_sequence: str):
    return True, re.findall(r'[A-Z](?:\(.+?\))?', raw_sequence)
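
As an illustration of how this regular expression tokenizes a sequence, the sketch below parses a peptide containing two modified residues and maps the tokens to vocabulary indices. The example sequence and the reduced vocabulary are made up for the demonstration; in the notebook the full `vocab` from the configuration cell is used.

```python
import re

# Reduced vocabulary for illustration only (the notebook defines 27 tokens).
vocab_reverse = ['_PAD', '_GO', '_EOS',
                 'A', 'C(Carboxymethyl)', 'K', 'M', 'M(Oxidation)']
vocab = {token: index for index, token in enumerate(vocab_reverse)}


def parse_raw_sequence(raw_sequence: str):
    """Split a sequence into residues, keeping a '(modification)' suffix attached."""
    return True, re.findall(r'[A-Z](?:\(.+?\))?', raw_sequence)


ok, peptide = parse_raw_sequence("AM(Oxidation)C(Carboxymethyl)K")
print(peptide)  # ['A', 'M(Oxidation)', 'C(Carboxymethyl)', 'K']

# A token missing from the vocabulary would raise a KeyError here.
peptide_ids = [vocab[token] for token in peptide]
print(peptide_ids)  # [3, 7, 4, 5]
```

The capital letters inside the parentheses are consumed by the optional group, so `M(Oxidation)` is returned as one token rather than as `M` followed by `O`.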

### Dataset Classes

In [16]:
#-----------------------------------Class Definition------------------------------------------

@dataclass
class DDAFeature:
    feature_id: str
    mz: float
    z: float
    rt_mean: float
    peptide: list
    scan: str
    mass: float
    feature_area: str

@dataclass
class DenovoData:
    peak_location: np.ndarray
    peak_intensity: np.ndarray
    spectrum_representation: np.ndarray
    original_dda_feature: DDAFeature


@dataclass
class TrainData:
    peak_location: np.ndarray
    peak_intensity: np.ndarray
    spectrum_representation: np.ndarray
    forward_id_target: list
    backward_id_target: list
    forward_ion_location_index_list: list
    backward_ion_location_index_list: list
    forward_id_input: list
    backward_id_input: list


class DeepNovoTrainDataset(Dataset):
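    def __init__(self, feature_filename, spectrum_filename, transform=None):
        """
        Read all feature information and store it in memory.
        :param feature_filename: comma-separated file of annotated features
        :param spectrum_filename: MGF file containing the spectra
        :param transform: optional transform, stored as self.transform
        """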
        print(f"input spectrum file: {spectrum_filename}")
        print(f"input feature file: {feature_filename}")
        self.spectrum_filename = spectrum_filename
        self.input_spectrum_handle = None
        self.feature_list = []
        self.spectrum_location_dict = {}
        self.transform = transform
        # read spectrum location file
        spectrum_location_file = spectrum_filename + '.location.pytorch.pkl'
        if os.path.exists(spectrum_location_file):
            print("read cached spectrum locations")
            with open(spectrum_location_file, 'rb') as fr:
                self.spectrum_location_dict = pickle.load(fr)
        else:
            print("build spectrum location from scratch")
            spectrum_location_dict = {}
            line = True
            with open(spectrum_filename, 'r') as f:
                while line:
                    current_location = f.tell()
                    line = f.readline()
                    if "BEGIN IONS" in line:
                        spectrum_location = current_location
                    elif "SCANS=" in line:
                        scan = re.split('[=\r\n]', line)[1]
                        spectrum_location_dict[scan] = spectrum_location
            self.spectrum_location_dict = spectrum_location_dict
            with open(spectrum_location_file, 'wb') as fw:
                pickle.dump(self.spectrum_location_dict, fw)

        # read feature file
        skipped_by_mass = 0
        skipped_by_ptm = 0
        skipped_by_length = 0
        with open(feature_filename, 'r') as fr:
            reader = csv.reader(fr, delimiter=',')
            header = next(reader)
            feature_id_index = header.index(col_feature_id)
            mz_index = header.index(col_precursor_mz)
            z_index = header.index(col_precursor_charge)
            rt_mean_index = header.index(col_rt_mean)
            seq_index = header.index(col_raw_sequence)
            scan_index = header.index(col_scan_list)
            feature_area_index = header.index(col_feature_area)
            for line in reader:
                mass = (float(line[mz_index]) - mass_H) * float(line[z_index])
                ok, peptide = parse_raw_sequence(line[seq_index])
                if not ok:
                    skipped_by_ptm += 1
                    print(f"{line[seq_index]} skipped by ptm")
                    continue
                if mass > MZ_MAX:
                    skipped_by_mass += 1
                    continue
                if len(peptide) >= MAX_LEN:
                    skipped_by_length += 1
                    continue
                new_feature = DDAFeature(feature_id=line[feature_id_index],
                                         mz=float(line[mz_index]),
                                         z=float(line[z_index]),
                                         rt_mean=float(line[rt_mean_index]),
                                         peptide=peptide,
                                         scan=line[scan_index],
                                         mass=mass,
                                         feature_area=line[feature_area_index])
                self.feature_list.append(new_feature)
        print(f"read {len(self.feature_list)} features, {skipped_by_mass} skipped by mass, "
                    f"{skipped_by_ptm} skipped by unknown modification, {skipped_by_length} skipped by length")

    def __len__(self):
        return len(self.feature_list)

    def close(self):
        self.input_spectrum_handle.close()

    def _parse_spectrum_ion(self):
        mz_list = []
        intensity_list = []
        line = self.input_spectrum_handle.readline()
        while "END IONS" not in line:
            mz, intensity = re.split(' |\r|\n', line)[:2]
            mz_float = float(mz)
            intensity_float = float(intensity)
            # skip an ion if its mass > MZ_MAX
            if mz_float > MZ_MAX:
                line = self.input_spectrum_handle.readline()
                continue
            mz_list.append(mz_float)
            intensity_list.append(intensity_float)
            line = self.input_spectrum_handle.readline()
        return mz_list, intensity_list

    def _get_feature(self, feature: DDAFeature) -> TrainData:
        spectrum_location = self.spectrum_location_dict[feature.scan]
        self.input_spectrum_handle.seek(spectrum_location)
        # parse header lines
        line = self.input_spectrum_handle.readline()
        assert "BEGIN IONS" in line, "Error: wrong input BEGIN IONS"
        line = self.input_spectrum_handle.readline()
        assert "TITLE=" in line, "Error: wrong input TITLE="
        line = self.input_spectrum_handle.readline()
        assert "PEPMASS=" in line, "Error: wrong input PEPMASS="
        line = self.input_spectrum_handle.readline()
        assert "CHARGE=" in line, "Error: wrong input CHARGE="
        line = self.input_spectrum_handle.readline()
        assert "SCANS=" in line, "Error: wrong input SCANS="
        line = self.input_spectrum_handle.readline()
        assert "RTINSECONDS=" in line, "Error: wrong input RTINSECONDS="
        mz_list, intensity_list = self._parse_spectrum_ion()
        peak_location, peak_intensity, spectrum_representation = process_peaks(mz_list, intensity_list, feature.mass)

        assert np.max(peak_intensity) < 1.0 + 1e-5
        # encode the target peptide (not needed by the de novo dataset, which overrides this method)
        peptide_id_list = [vocab[x] for x in feature.peptide]
        forward_id_input = [GO_ID] + peptide_id_list
        forward_id_target = peptide_id_list + [EOS_ID]
        forward_ion_location_index_list = []
        prefix_mass = 0.
        for i, id in enumerate(forward_id_input):
            prefix_mass += mass_ID[id]
            ion_location = get_ion_index(feature.mass, prefix_mass, 0)
            forward_ion_location_index_list.append(ion_location)

        backward_id_input = [EOS_ID] + peptide_id_list[::-1]
        backward_id_target = peptide_id_list[::-1] + [GO_ID]
        backward_ion_location_index_list = []
        suffix_mass = 0
        for i, id in enumerate(backward_id_input):
            suffix_mass += mass_ID[id]
            ion_location = get_ion_index(feature.mass, suffix_mass, 1)
            backward_ion_location_index_list.append(ion_location)

        return TrainData(peak_location=peak_location,
                         peak_intensity=peak_intensity,
                         spectrum_representation=spectrum_representation,
                         forward_id_target=forward_id_target,
                         backward_id_target=backward_id_target,
                         forward_ion_location_index_list=forward_ion_location_index_list,
                         backward_ion_location_index_list=backward_ion_location_index_list,
                         forward_id_input=forward_id_input,
                         backward_id_input=backward_id_input)

    def __getitem__(self, idx):
        if self.input_spectrum_handle is None:
            self.input_spectrum_handle = open(self.spectrum_filename, 'r')
        feature = self.feature_list[idx]
        return self._get_feature(feature)
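
The spectrum location cache built in `__init__` is what makes random access by scan possible: each scan is mapped to the file offset of its `BEGIN IONS` line, and `_get_feature` later seeks to that offset. A minimal sketch of the same idea on a two-spectrum toy file follows; `build_scan_index` and the file contents are hypothetical and only mimic the header layout that `_get_feature` asserts.

```python
import os
import re
import tempfile

# Toy MGF content (hypothetical values) with the header order expected above.
MGF_TEXT = """BEGIN IONS
TITLE=spectrum_1
PEPMASS=500.25
CHARGE=2+
SCANS=F1:101
RTINSECONDS=12.5
100.1 10.0
200.2 20.0
END IONS
BEGIN IONS
TITLE=spectrum_2
PEPMASS=600.30
CHARGE=2+
SCANS=F1:102
RTINSECONDS=13.0
150.5 5.0
END IONS
"""

def build_scan_index(path):
    """Map each scan id to the file offset of its BEGIN IONS line."""
    index = {}
    with open(path, "r") as handle:
        while True:
            # readline() is used instead of iterating over the file,
            # because tell() is disabled during file iteration.
            offset = handle.tell()
            line = handle.readline()
            if not line:
                break
            if "BEGIN IONS" in line:
                block_start = offset
            elif line.startswith("SCANS="):
                scan = re.split("[=\r\n]", line)[1]
                index[scan] = block_start
    return index

with tempfile.NamedTemporaryFile("w", suffix=".mgf", delete=False) as tmp:
    tmp.write(MGF_TEXT)
    path = tmp.name

scan_index = build_scan_index(path)
with open(path, "r") as handle:
    handle.seek(scan_index["F1:102"])  # jump straight to the second spectrum
    first_line = handle.readline().strip()
    title_line = handle.readline().strip()
os.remove(path)
print(scan_index, first_line, title_line)
```

Only the offsets are kept in memory, so the spectrum file itself is read lazily, one spectrum per `__getitem__` call.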

## Collate Function <a name="collate"></a>
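The collate function merges the list of `TrainData` samples returned by the map-style dataset into a single batch, padding the sequence-dependent arrays to the longest peptide in the batch.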

In [17]:
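def collate_func(train_data_list):
    """
    Merge a list of TrainData samples into padded batch tensors.

    :param train_data_list: list of TrainData
    :return:
        peak_location: [batch, N]
        peak_intensity: [batch, N]
        spectrum_representation: [batch, embed_size]
        batch_forward_id_target: [batch, T]
        batch_backward_id_target: [batch, T]
        batch_forward_ion_index: [batch, T, vocab_size, num_ion]
        batch_backward_ion_index: [batch, T, vocab_size, num_ion]
        batch_forward_id_input: [batch, T]
        batch_backward_id_input: [batch, T]
    """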
    # sort data by seq length (decreasing order)
    train_data_list.sort(key=lambda x: len(x.forward_id_target), reverse=True)
    batch_max_seq_len = len(train_data_list[0].forward_id_target)
    ion_index_shape = train_data_list[0].forward_ion_location_index_list[0].shape
    assert ion_index_shape == (vocab_size, num_ion)

    peak_location = [x.peak_location for x in train_data_list]
    peak_location = np.stack(peak_location) # [batch_size, N]
    peak_location = torch.from_numpy(peak_location)

    peak_intensity = [x.peak_intensity for x in train_data_list]
    peak_intensity = np.stack(peak_intensity) # [batch_size, N]
    peak_intensity = torch.from_numpy(peak_intensity)

    spectrum_representation = [x.spectrum_representation for x in train_data_list]
    spectrum_representation = np.stack(spectrum_representation)  # [batch_size, embed_size]
    spectrum_representation = torch.from_numpy(spectrum_representation)

    batch_forward_ion_index = []
    batch_forward_id_target = []
    batch_forward_id_input = []
    for data in train_data_list:
        ion_index = np.zeros((batch_max_seq_len, ion_index_shape[0], ion_index_shape[1]),
                               np.float32)
        forward_ion_index = np.stack(data.forward_ion_location_index_list)
        ion_index[:forward_ion_index.shape[0], :, :] = forward_ion_index
        batch_forward_ion_index.append(ion_index)

        f_target = np.zeros((batch_max_seq_len,), np.int64)
        forward_target = np.array(data.forward_id_target, np.int64)
        f_target[:forward_target.shape[0]] = forward_target
        batch_forward_id_target.append(f_target)

        f_input = np.zeros((batch_max_seq_len,), np.int64)
        forward_input = np.array(data.forward_id_input, np.int64)
        f_input[:forward_input.shape[0]] = forward_input
        batch_forward_id_input.append(f_input)



    batch_forward_id_target = torch.from_numpy(np.stack(batch_forward_id_target))  # [batch_size, T]
    batch_forward_ion_index = torch.from_numpy(np.stack(batch_forward_ion_index))  # [batch, T, vocab_size, num_ion]
    batch_forward_id_input = torch.from_numpy(np.stack(batch_forward_id_input))

    batch_backward_ion_index = []
    batch_backward_id_target = []
    batch_backward_id_input = []
    for data in train_data_list:
        ion_index = np.zeros((batch_max_seq_len, ion_index_shape[0], ion_index_shape[1]),
                             np.float32)
        backward_ion_index = np.stack(data.backward_ion_location_index_list)
        ion_index[:backward_ion_index.shape[0], :, :] = backward_ion_index
        batch_backward_ion_index.append(ion_index)

        b_target = np.zeros((batch_max_seq_len,), np.int64)
        backward_target = np.array(data.backward_id_target, np.int64)
        b_target[:backward_target.shape[0]] = backward_target
        batch_backward_id_target.append(b_target)

        b_input = np.zeros((batch_max_seq_len,), np.int64)
        backward_input = np.array(data.backward_id_input, np.int64)
        b_input[:backward_input.shape[0]] = backward_input
        batch_backward_id_input.append(b_input)

    batch_backward_id_target = torch.from_numpy(np.stack(batch_backward_id_target))  # [batch_size, T]
    batch_backward_ion_index = torch.from_numpy(np.stack(batch_backward_ion_index))  # [batch, T, vocab_size, num_ion]
    batch_backward_id_input = torch.from_numpy(np.stack(batch_backward_id_input))

    return (peak_location,
            peak_intensity,
            spectrum_representation,
            batch_forward_id_target,
            batch_backward_id_target,
            batch_forward_ion_index,
            batch_backward_ion_index,
            batch_forward_id_input,
            batch_backward_id_input
            )
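
The padding step used for the id sequences in `collate_func` can be sketched in isolation. The helper name `pad_id_sequences` and the token ids are hypothetical; only NumPy is used, so the final conversion to a tensor is left out.

```python
import numpy as np

def pad_id_sequences(sequences, pad_id=0):
    """Right-pad variable-length id lists to the longest one in the batch."""
    max_len = max(len(seq) for seq in sequences)
    batch = np.full((len(sequences), max_len), pad_id, dtype=np.int64)
    for row, seq in enumerate(sequences):
        batch[row, :len(seq)] = seq
    return batch

# Hypothetical target ids for two peptides of different lengths.
padded = pad_id_sequences([[3, 4, 5, 2], [6, 2]])
print(padded)
```

Padding with zeros is consistent with the `ignore_index=0` passed to `focal_loss` in the validation function, so padded positions do not contribute to the loss.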

## Extract & Move Data <a name="extract"></a>

In [18]:
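def extract_and_move_data(data):
    """
    Unpack a batch produced by the DataLoader and move every tensor to the
    device in use (CPU or GPU), so that the model and its inputs are on the same device.
    :param data: batch tuple returned by the DataLoader (see collate_func)
    :return: the same tuple, with every tensor moved to the device
    """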
    peak_location, \
    peak_intensity, \
    spectrum_representation,\
    batch_forward_id_target, \
    batch_backward_id_target, \
    batch_forward_ion_index, \
    batch_backward_ion_index, \
    batch_forward_id_input, \
    batch_backward_id_input = data

    # move to device
    peak_location = peak_location.to(device)
    peak_intensity = peak_intensity.to(device)
    spectrum_representation = spectrum_representation.to(device)
    batch_forward_id_target = batch_forward_id_target.to(device)
    batch_backward_id_target = batch_backward_id_target.to(device)
    batch_forward_ion_index = batch_forward_ion_index.to(device)
    batch_backward_ion_index = batch_backward_ion_index.to(device)
    batch_forward_id_input = batch_forward_id_input.to(device)
    batch_backward_id_input = batch_backward_id_input.to(device)
    return (peak_location,
            peak_intensity,
            spectrum_representation,
            batch_forward_id_target,
            batch_backward_id_target,
            batch_forward_ion_index,
            batch_backward_ion_index,
            batch_forward_id_input,
            batch_backward_id_input
            )


## Focal Loss Function <a name="loss"></a>
Focal loss down-weights easy, well-classified examples so that training concentrates on a sparse set of hard ones. It was introduced for dense object detection, where the vast number of easy negatives would otherwise overwhelm the detector during training. See [Focal Loss for Dense Object Detection](https://arxiv.org/pdf/1708.02002.pdf).

In [19]:
def focal_loss(logits, labels, ignore_index=-100, gamma=2.):
    """
    :param logits: float tensor of shape [batch, T, 26]
    :param labels: long tensor of shape [batch, T]
    :param ignore_index: ignore the loss of those tokens
    :param gamma: focusing parameter; larger values down-weight well-classified tokens more strongly
    :return: average loss, num_valid_token
    """
    valid_token_mask = (labels != ignore_index).float()  # [batch, T]
    num_valid_token = torch.sum(valid_token_mask)
    batch_size, T, num_classes = logits.size()
    sigmoid_p = torch.sigmoid(logits) #sigmoid used instead of soft_max
    target_tensor = to_one_hot(labels, n_dims=num_classes).float().to(device)
    zeros = torch.zeros_like(sigmoid_p)
    pos_p_sub = torch.where(target_tensor >= sigmoid_p, target_tensor - sigmoid_p, zeros)  # [batch, T, 26]
    neg_p_sub = torch.where(target_tensor > zeros, zeros, sigmoid_p)  # [batch, T, 26]

    per_token_loss = - (pos_p_sub ** gamma) * torch.log(torch.clamp(sigmoid_p, 1e-8, 1.0)) - \
                     (neg_p_sub ** gamma) * torch.log(torch.clamp(1.0 - sigmoid_p, 1e-8, 1.0))
    per_entry_loss = torch.sum(per_token_loss, dim=2)  # [batch, T]
    per_entry_loss = per_entry_loss * valid_token_mask  # masking out loss from pad tokens

    per_entry_average_loss = torch.sum(per_entry_loss) / (num_valid_token + 1e-6)
    return per_entry_average_loss, num_valid_token
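
To see the effect of the modulating factor, the per-class term of the loss above can be written for a single logit. This is a scalar sketch using only the standard library, not the batched implementation; `binary_focal_term` is a hypothetical name, and the logits are made-up values.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def binary_focal_term(logit, target, gamma=2.0, eps=1e-8):
    """Focal term for one class: target is 1 (true class) or 0 (other class)."""
    p = sigmoid(logit)
    if target == 1:
        # confident correct predictions (p close to 1) are scaled by (1 - p) ** gamma
        return -((1.0 - p) ** gamma) * math.log(max(p, eps))
    # confident correct rejections (p close to 0) are scaled by p ** gamma
    return -(p ** gamma) * math.log(max(1.0 - p, eps))

easy = binary_focal_term(4.0, 1)    # true class predicted with high confidence
hard = binary_focal_term(-4.0, 1)   # true class predicted with low confidence
plain = binary_focal_term(4.0, 1, gamma=0.0)  # in this sketch, gamma = 0 gives binary cross-entropy
print(easy, hard, plain)
```

With `gamma=2`, the easy example contributes far less than it would under plain binary cross-entropy, while the hard example keeps most of its weight.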

### to_one_hot

In [20]:
def to_one_hot(y, n_dims=None):
    """ Take integer y with n dims and convert it to 1-hot representation with n+1 dims. 
    c.f. Report 2.5
    """
    y_tensor = y.data
    y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)
    n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1
    y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)
    y_one_hot = y_one_hot.view(*y.shape, -1)
    return y_one_hot
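
The shape transformation performed by `to_one_hot` can be checked on a small array. The sketch below is a NumPy equivalent written for illustration (the name `one_hot` and the label values are hypothetical); PyTorch also provides `torch.nn.functional.one_hot`, which could serve the same purpose.

```python
import numpy as np

def one_hot(labels, n_classes):
    """Turn integer labels of shape [...] into one-hot vectors of shape [..., n_classes]."""
    labels = np.asarray(labels, dtype=np.int64)
    # indexing the identity matrix with the labels selects one row per label
    return np.eye(n_classes, dtype=np.float32)[labels]

# Hypothetical labels of shape [batch=2, T=3] with 4 classes.
encoded = one_hot([[0, 1, 3], [2, 2, 0]], n_classes=4)
print(encoded.shape)
```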

## Validation Function <a name="valid_func"></a>

In [21]:
def validation(forward_deepnovo, backward_deepnovo, init_net, valid_loader) -> float:
    """ Compute and return the loss of the model on the validation dataset (valid_loader), averaged over valid tokens """
    with torch.no_grad():
        valid_loss = 0
        num_valid_samples = 0
        for data in valid_loader:
            peak_location, \
            peak_intensity, \
            spectrum_representation, \
            batch_forward_id_target, \
            batch_backward_id_target, \
            batch_forward_ion_index, \
            batch_backward_ion_index, \
            batch_forward_id_input, \
            batch_backward_id_input = extract_and_move_data(data)
            batch_size = batch_backward_id_target.size(0)
            #evaluate the validation data through the model
            if use_lstm:
                initial_state_tuple = init_net(spectrum_representation)
                forward_logit, _ = forward_deepnovo(batch_forward_ion_index, peak_location, peak_intensity,
                                                    batch_forward_id_input, initial_state_tuple)
                backward_logit, _ = backward_deepnovo(batch_backward_ion_index, peak_location, peak_intensity,
                                                      batch_backward_id_input, initial_state_tuple)
            else:
                forward_logit = forward_deepnovo(batch_forward_ion_index, peak_location, peak_intensity)
                backward_logit = backward_deepnovo(batch_backward_ion_index, peak_location, peak_intensity)
            #compute the loss
            forward_loss, f_num = focal_loss(forward_logit, batch_forward_id_target, ignore_index=0, gamma=2.)
            backward_loss, b_num = focal_loss(backward_logit, batch_backward_id_target, ignore_index=0, gamma=2.)
            valid_loss += forward_loss.item() * f_num.item() + backward_loss.item() * b_num.item()
            num_valid_samples += f_num.item() + b_num.item()
    #average the loss on all validation samples        
    average_valid_loss = valid_loss / (num_valid_samples + 1e-6)
    return float(average_valid_loss)
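
The accumulation in `validation` weights each batch's mean loss by its number of valid tokens, so that batches with longer peptides count proportionally more. A minimal sketch of that averaging, with hypothetical per-batch values and a hypothetical helper name:

```python
def token_weighted_mean(batch_losses, batch_token_counts, eps=1e-6):
    """Average per-batch mean losses, weighting each by its valid-token count."""
    total_loss = 0.0
    total_tokens = 0.0
    for loss, count in zip(batch_losses, batch_token_counts):
        total_loss += loss * count
        total_tokens += count
    return total_loss / (total_tokens + eps)

# Hypothetical per-batch mean losses and valid-token counts.
losses = [0.30, 0.50]
counts = [100, 300]
weighted = token_weighted_mean(losses, counts)
unweighted = sum(losses) / len(losses)
print(weighted, unweighted)
```

The two averages differ as soon as batches contain different numbers of valid tokens, which is the usual case with variable-length peptides.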

## Cython Integration <a name="cython"></a>

The following functions were initially designed to be built with Cython. Since Cython isn't reliable when working with Jupyter notebooks, they have been redesigned to work in plain Python. This impacts performance, but shouldn't prevent the use of this notebook on moderately sized datasets.

In [None]:
#---------------------------------cython------------------------------------
def get_sinusoid_encoding_table(n_position, embed_size, padding_idx=0):
    """ Sinusoid position encoding table
    n_position: maximum integer that the embedding op could receive
    embed_size: embed size
    return
      a embedding matrix of shape [n_position, embed_size]
    """

    def cal_angle(position, hid_idx):
        return position / np.power(sinusoid_base, 2 * (hid_idx // 2) / embed_size)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(embed_size)]

    sinusoid_matrix = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position + 1)], dtype=np.float32)

    sinusoid_matrix[:, 0::2] = np.sin(sinusoid_matrix[:, 0::2])  # dim 2i
    sinusoid_matrix[:, 1::2] = np.cos(sinusoid_matrix[:, 1::2])  # dim 2i+1

    sinusoid_matrix[padding_idx] = 0.
    return sinusoid_matrix

sinusoid_matrix = get_sinusoid_encoding_table(n_position, embedding_size,
                                              padding_idx=PAD_ID)

# The original Cython decorators (boundscheck, wraparound) are omitted:
# they have no effect in plain Python and would require importing cython.
def get_ion_index(peptide_mass, prefix_mass, direction):
  """
  :param peptide_mass: neutral mass of a peptide
  :param prefix_mass: cumulative mass of the prefix (forward) or suffix (backward) read so far
  :param direction: 0 for forward, 1 for backward
  :return: a float32 ndarray of shape [vocab_size, num_ion] holding, for each candidate
  next token, the masses of the 12 ion types built below (b, y and a ions, each with
  H2O loss, NH3 loss and a doubly charged variant). Out-of-bound masses are set to 0
  """
  if direction == 0:
    candidate_b_mass = prefix_mass + mass_ID_np
    candidate_y_mass = peptide_mass - candidate_b_mass
  elif direction == 1:
    candidate_y_mass = prefix_mass + mass_ID_np
    candidate_b_mass = peptide_mass - candidate_y_mass
  candidate_a_mass = candidate_b_mass - mass_CO

  # b-ions
  candidate_b_H2O = candidate_b_mass - mass_H2O
  candidate_b_NH3 = candidate_b_mass - mass_NH3
  candidate_b_plus2_charge1 = ((candidate_b_mass + 2 * mass_H) / 2
                               - mass_H)

  # a-ions
  candidate_a_H2O = candidate_a_mass - mass_H2O
  candidate_a_NH3 = candidate_a_mass - mass_NH3
  candidate_a_plus2_charge1 = ((candidate_a_mass + 2 * mass_H) / 2
                               - mass_H)

  # y-ions
  candidate_y_H2O = candidate_y_mass - mass_H2O
  candidate_y_NH3 = candidate_y_mass - mass_NH3
  candidate_y_plus2_charge1 = ((candidate_y_mass + 2 * mass_H) / 2
                               - mass_H)

  # 12 candidate ion types: 4 b-ions, 4 y-ions, 4 a-ions
  b_ions = [candidate_b_mass,
            candidate_b_H2O,
            candidate_b_NH3,
            candidate_b_plus2_charge1]
  y_ions = [candidate_y_mass,
            candidate_y_H2O,
            candidate_y_NH3,
            candidate_y_plus2_charge1]
  a_ions = [candidate_a_mass,
            candidate_a_H2O,
            candidate_a_NH3,
            candidate_a_plus2_charge1]
  ion_mass_list = b_ions + y_ions + a_ions
  ion_mass = np.array(ion_mass_list, dtype=np.float32)  # [num_ion, vocab_size]

  # ion locations
  in_bound_mask = np.logical_and(
      ion_mass > 0,
      ion_mass <= MZ_MAX).astype(np.float32)
  ion_location = ion_mass * in_bound_mask  # [num_ion, vocab_size], out-of-bound entries are set to 0
  return ion_location.transpose()  # [vocab_size, num_ion]


def pad_to_length(data: list, length, pad_token=0.):
  """
  pad data to length if len(data) is smaller than length
  :param data:
  :param length:
  :param pad_token:
  :return:
  """
  for i in range(length - len(data)):
    data.append(pad_token)
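
The nested list comprehensions in `get_sinusoid_encoding_table` can also be expressed with NumPy broadcasting, which avoids the Python-level loops. The sketch below is an illustrative alternative, not the notebook's implementation; the default `base=10000.0` is an assumption, since the value of `sinusoid_base` is set in the model configuration.

```python
import numpy as np

def sinusoid_table(n_position, embed_size, base=10000.0, padding_idx=0):
    """Vectorized sinusoid encoding table of shape [n_position + 1, embed_size]."""
    positions = np.arange(n_position + 1, dtype=np.float64)[:, None]  # [P, 1]
    hid_idx = np.arange(embed_size)[None, :]                          # [1, E]
    # same angle formula as cal_angle, evaluated for all pairs at once
    angles = positions / np.power(base, 2 * (hid_idx // 2) / embed_size)
    table = np.empty_like(angles)
    table[:, 0::2] = np.sin(angles[:, 0::2])  # even dimensions
    table[:, 1::2] = np.cos(angles[:, 1::2])  # odd dimensions
    table[padding_idx] = 0.0                  # padding row carries no signal
    return table.astype(np.float32)

small_table = sinusoid_table(n_position=5, embed_size=4)
print(small_table.shape)
```

The result should match the loop-based table up to floating-point rounding, since the angles here are computed in double precision before the final cast.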

In [23]:
def process_peaks(spectrum_mz_list, spectrum_intensity_list, peptide_mass):
  """

  :param spectrum_mz_list:
  :param spectrum_intensity_list:
  :param peptide_mass: peptide neutral mass
  :return:
    peak_location: float32, [N], neutral masses (0 for padding and out-of-bound peaks)
    peak_intensity: float32, [N]
    spectrum_representation: float32 [embedding_size]
  """
  charge = 1.0
  spectrum_intensity_max = np.max(spectrum_intensity_list)
  # charge 1 peptide location
  spectrum_mz_list.append(peptide_mass + charge*mass_H)
  spectrum_intensity_list.append(spectrum_intensity_max)

  # N-terminal, b-ion, peptide_mass_C
  # append N-terminal
  mass_N = mass_N_terminus - mass_H
  spectrum_mz_list.append(mass_N + charge*mass_H)
  spectrum_intensity_list.append(spectrum_intensity_max)
  # append peptide_mass_C
  mass_C = mass_C_terminus + mass_H
  peptide_mass_C = peptide_mass - mass_C
  spectrum_mz_list.append(peptide_mass_C + charge*mass_H)
  spectrum_intensity_list.append(spectrum_intensity_max)

  # C-terminal, y-ion, peptide_mass_N
  # append C-terminal
  mass_C = mass_C_terminus + mass_H
  spectrum_mz_list.append(mass_C + charge*mass_H)
  spectrum_intensity_list.append(spectrum_intensity_max)


  pad_to_length(spectrum_mz_list, MAX_NUM_PEAK)
  pad_to_length(spectrum_intensity_list, MAX_NUM_PEAK)

  spectrum_mz = np.array(spectrum_mz_list, dtype=np.float32)
  spectrum_mz_location = np.ceil(spectrum_mz * spectrum_reso).astype(np.int32)

  neutral_mass = spectrum_mz - charge*mass_H
  in_bound_mask = np.logical_and(neutral_mass > 0., neutral_mass < MZ_MAX)
  neutral_mass[~in_bound_mask] = 0.
  # intensity
  spectrum_intensity = np.array(spectrum_intensity_list, dtype=np.float32)
  norm_intensity = spectrum_intensity / spectrum_intensity_max

  spectrum_representation = np.zeros(embedding_size, dtype=np.float32)
  for i, loc in enumerate(spectrum_mz_location):
    if loc < 0.5 or loc > n_position:
      continue
    else:
      spectrum_representation += sinusoid_matrix[loc] * norm_intensity[i]

  top_N_indices = np.argpartition(norm_intensity, -MAX_NUM_PEAK)[-MAX_NUM_PEAK:]
  intensity = norm_intensity[top_N_indices]
  mass_location = neutral_mass[top_N_indices]

  return mass_location, intensity, spectrum_representation
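
The last step of `process_peaks` keeps a fixed number of peaks per spectrum: short spectra are padded, and long ones are reduced to their most intense peaks with `np.argpartition`. A simplified sketch of that selection, with a hypothetical helper name and made-up peak values (it omits the terminal peaks and the spectrum representation):

```python
import numpy as np

def keep_top_n_peaks(mz_values, intensities, n_peaks):
    """Normalize intensities, pad to n_peaks if needed, keep the n_peaks most intense."""
    mz = np.asarray(mz_values, dtype=np.float32)
    intensity = np.asarray(intensities, dtype=np.float32)
    norm_intensity = intensity / intensity.max()
    if len(mz) < n_peaks:
        pad = n_peaks - len(mz)
        mz = np.pad(mz, (0, pad))                          # zero padding
        norm_intensity = np.pad(norm_intensity, (0, pad))
    # argpartition places the n_peaks largest values last, in arbitrary order
    keep = np.argpartition(norm_intensity, -n_peaks)[-n_peaks:]
    return mz[keep], norm_intensity[keep]

kept_mz, kept_intensity = keep_top_n_peaks(
    [100.0, 200.0, 300.0, 400.0, 500.0], [5.0, 50.0, 1.0, 20.0, 10.0], n_peaks=3)
padded_mz, padded_intensity = keep_top_n_peaks([100.0, 200.0], [1.0, 2.0], n_peaks=4)
print(sorted(kept_mz.tolist()), padded_mz.shape)
```

Because the terminal peaks appended in `process_peaks` are given the maximum intensity, they survive this selection even for crowded spectra.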

## Training <a name="train"></a>
This cell performs the training using the `train` function.

In [24]:
if 'train' in option:
  train()

import matplotlib.pyplot as plt
plt.plot(training_perp_tab)

input spectrum file: ../smbp_data/spectrum_smbp.mgf
input feature file: ../smbp_data/features_smbp.csv.train
read cached spectrum locations
read 5416 features, 3 skipped by mass, 0 skipped by unknown modification, 1 skipped by length
225 steps per epoch
input spectrum file: ../smbp_data/spectrum_smbp.mgf
input feature file: ../smbp_data/features_smbp.csv.valid
read cached spectrum locations
read 535 features, 3 skipped by mass, 0 skipped by unknown modification, 0 skipped by length
initialize a set of new parameters


  0%|          | 0/5 [00:00<?, ?it/s]

epoch: 0	lr: 0.001
epoch 0 step 9/225, train perplexity: 2.793949438907112	validation perplexity: 3.202474555941306	step time: 0.3664912939071655
best valid loss achieved at epoch 0 step 9
epoch 0 step 19/225, train perplexity: 2.3534467970607693	validation perplexity: 3.0489873399027534	step time: 0.3393018007278442
best valid loss achieved at epoch 0 step 19
epoch 0 step 29/225, train perplexity: 2.032802122952982	validation perplexity: 3.010252387204791	step time: 0.3487901210784912
best valid loss achieved at epoch 0 step 29
epoch 0 step 39/225, train perplexity: 1.9092910071362985	validation perplexity: 2.8932988400621875	step time: 0.3349908351898193
best valid loss achieved at epoch 0 step 39
epoch 0 step 49/225, train perplexity: 1.7135304974646226	validation perplexity: 2.845209593261391	step time: 0.34039170742034913
best valid loss achieved at epoch 0 step 49
epoch 0 step 59/225, train perplexity: 1.5732570229532956	validation perplexity: 2.700769023000199	step time: 0.32280

 20%|██        | 1/5 [02:22<09:31, 142.94s/it]

epoch: 1	lr: 0.001
epoch 1 step 9/225, train perplexity: 1.172215155559618	validation perplexity: 1.3334662389133904	step time: 0.5906629800796509
best valid loss achieved at epoch 1 step 9
epoch 1 step 19/225, train perplexity: 1.1772523921050537	validation perplexity: 1.3518329869320636	step time: 0.34897186756134035
epoch 1 step 29/225, train perplexity: 1.1558845733010679	validation perplexity: 1.3523456532605407	step time: 0.35559179782867434
epoch 1 step 39/225, train perplexity: 1.1241122589462436	validation perplexity: 1.3650543767004955	step time: 0.3615119218826294
epoch 1 step 49/225, train perplexity: 1.111399195906675	validation perplexity: 1.341771420536697	step time: 0.33811635971069337
epoch 1 step 59/225, train perplexity: 1.120864433144194	validation perplexity: 1.343696545088079	step time: 0.3493918180465698
epoch 1 step 69/225, train perplexity: 1.1057151261201954	validation perplexity: 1.3465057765158925	step time: 0.3385584592819214
epoch 1 step 79/225, train perp

 40%|████      | 2/5 [04:46<07:09, 143.16s/it]

epoch: 2	lr: 0.0001
epoch 2 step 9/225, train perplexity: 1.0980431424647514	validation perplexity: 1.3502115444916425	step time: 0.552337908744812
epoch 2 step 19/225, train perplexity: 1.0558672985385762	validation perplexity: 1.3519723678809081	step time: 0.37793865203857424
epoch 2 step 29/225, train perplexity: 1.0924419086348152	validation perplexity: 1.3595432416661584	step time: 0.36386895179748535
epoch 2 step 39/225, train perplexity: 1.0600182439531747	validation perplexity: 1.3609946956874723	step time: 0.33583109378814696
epoch 2 step 49/225, train perplexity: 1.1212722827409791	validation perplexity: 1.3585266934074272	step time: 0.346697473526001
epoch 2 step 59/225, train perplexity: 1.0697108736686187	validation perplexity: 1.3594798895822757	step time: 0.357592511177063
epoch 2 step 69/225, train perplexity: 1.087192794726473	validation perplexity: 1.361020876805671	step time: 0.3531265020370483
epoch 2 step 79/225, train perplexity: 1.052450996801551	validation perpl

 60%|██████    | 3/5 [07:08<04:45, 142.88s/it]

epoch: 3	lr: 0.0001
epoch 3 step 9/225, train perplexity: 1.0846711127040323	validation perplexity: 1.3633509942369744	step time: 0.5644603729248047
epoch 3 step 19/225, train perplexity: 1.0613432399536489	validation perplexity: 1.3640216368589415	step time: 0.3451084613800049
epoch 3 step 29/225, train perplexity: 1.0549861729861696	validation perplexity: 1.3655636501363124	step time: 0.3754934310913086
epoch 3 step 39/225, train perplexity: 1.0454554172437502	validation perplexity: 1.3687365954872266	step time: 0.3483560562133789
epoch 3 step 49/225, train perplexity: 1.0576810326132808	validation perplexity: 1.3671143645738422	step time: 0.3267146348953247
epoch 3 step 59/225, train perplexity: 1.039424345051049	validation perplexity: 1.3623808595622635	step time: 0.3526721954345703
epoch 3 step 69/225, train perplexity: 1.075568632611243	validation perplexity: 1.3615443104887683	step time: 0.33495924472808836
epoch 3 step 79/225, train perplexity: 1.057473051548844	validation perp

 80%|████████  | 4/5 [09:30<02:22, 142.57s/it]

epoch: 4	lr: 0.0001
epoch 4 step 9/225, train perplexity: 1.0632316808938074	validation perplexity: 1.3631925384064778	step time: 0.5842698097229004
epoch 4 step 19/225, train perplexity: 1.0780149365996794	validation perplexity: 1.3603414359087938	step time: 0.36088201999664304
epoch 4 step 29/225, train perplexity: 1.0442115174988875	validation perplexity: 1.3531789437816026	step time: 0.3615744113922119
epoch 4 step 39/225, train perplexity: 1.04607077375233	validation perplexity: 1.347204982760071	step time: 0.34669697284698486
epoch 4 step 49/225, train perplexity: 1.0576768718031175	validation perplexity: 1.3544894955700817	step time: 0.3287489652633667
epoch 4 step 59/225, train perplexity: 1.0833791853195465	validation perplexity: 1.358630538979704	step time: 0.36172804832458494
epoch 4 step 69/225, train perplexity: 1.0466224225858682	validation perplexity: 1.3605016921395763	step time: 0.34485957622528074
epoch 4 step 79/225, train perplexity: 1.0784034262293736	validation pe

100%|██████████| 5/5 [11:52<00:00, 142.43s/it]


best model at epoch 1 step 89


[<matplotlib.lines.Line2D at 0x7fbd394ce150>]

## Validation

In [27]:
if 'valid' in option:
  # build the validation dataset and its DataLoader
  valid_set = DeepNovoTrainDataset(input_feature_file_valid,
                                   input_spectrum_file_valid)
  valid_data_loader = torch.utils.data.DataLoader(dataset=valid_set,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=num_workers,
                                                  collate_fn=collate_func)
  # load the trained model and switch it to evaluation mode
  forward_deepnovo, backward_deepnovo, init_net = build_model(training=False)
  forward_deepnovo.eval()
  backward_deepnovo.eval()
  # compute the validation loss and report it as a perplexity
  validation_loss = validation(forward_deepnovo, backward_deepnovo, init_net, valid_data_loader)
  print(f"validation perplexity: {perplexity(validation_loss)}")

input spectrum file: ../smbp_data/spectrum_smbp.mgf
input feature file: ../smbp_data/features_smbp.csv.valid
read cached spectrum locations
read 535 features, 3 skipped by mass, 0 skipped by unknown modification, 0 skipped by length
load pretrained model
validation perplexity: 1.3306849783659758
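
The definition of `perplexity` is not shown in this section. Assuming it is the usual exponential of the mean loss, the reported value can be converted back to a loss as sketched here; `perplexity_from_loss` is a hypothetical stand-in for the notebook's function.

```python
import math

def perplexity_from_loss(mean_loss):
    # Assumption: perplexity is exp(mean loss); the notebook's own
    # definition may differ (for example, by guarding against overflow).
    return math.exp(mean_loss)

reported_perplexity = 1.3306849783659758  # value printed by the validation cell
implied_loss = math.log(reported_perplexity)
print(implied_loss, perplexity_from_loss(implied_loss))
```

Under that assumption, a perplexity of 1 would correspond to a loss of zero, so values close to 1 indicate a low average loss per token.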


# Denovo Prediction <a name="denovo"></a>

## Denovo Path <a name="denovo_path"></a>

In [28]:
# denovo path files
denovo_input_spectrum_file = "../smbp_data/spectrum_smbp.mgf"
denovo_input_feature_file = "../smbp_data/features_smbp.csv.test"

#Path of the denovo output predictions
denovo_output_file = denovo_input_feature_file + ".deepnovo_denovo"

predicted_format = "deepnovo"
target_file = denovo_input_feature_file
predicted_file = denovo_output_file

#Name of the files produced after denovo prediction
accuracy_file = predicted_file + ".accuracy"
denovo_only_file = predicted_file + ".denovo_only"
scan2fea_file = predicted_file + ".scan2fea"
multifea_file = predicted_file + ".multifea"

## Denovo Dataset Data reader <a name="denovo_datareader"></a>

In [30]:
class DeepNovoDenovoDataset(DeepNovoTrainDataset):
    # override the _get_feature method of DeepNovoTrainDataset to return DenovoData (no target sequence is built)
    def _get_feature(self, feature: DDAFeature) -> DenovoData:
        spectrum_location = self.spectrum_location_dict[feature.scan]
        self.input_spectrum_handle.seek(spectrum_location)
        # parse header lines
        line = self.input_spectrum_handle.readline()
        assert "BEGIN IONS" in line, "Error: wrong input BEGIN IONS"
        line = self.input_spectrum_handle.readline()
        assert "TITLE=" in line, "Error: wrong input TITLE="
        line = self.input_spectrum_handle.readline()
        assert "PEPMASS=" in line, "Error: wrong input PEPMASS="
        line = self.input_spectrum_handle.readline()
        assert "CHARGE=" in line, "Error: wrong input CHARGE="
        line = self.input_spectrum_handle.readline()
        assert "SCANS=" in line, "Error: wrong input SCANS="
        line = self.input_spectrum_handle.readline()
        assert "RTINSECONDS=" in line, "Error: wrong input RTINSECONDS="
        mz_list, intensity_list = self._parse_spectrum_ion()
        peak_location, peak_intensity, spectrum_representation = process_peaks(mz_list, intensity_list, feature.mass)

        return DenovoData(peak_location=peak_location,
                          peak_intensity=peak_intensity,
                          spectrum_representation=spectrum_representation,
                          original_dda_feature=feature)
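
The reader above asserts that each spectrum block lists its header lines in a fixed order: `BEGIN IONS`, `TITLE=`, `PEPMASS=`, `CHARGE=`, `SCANS=`, `RTINSECONDS=`. The sketch below shows what such a block could look like and how the order check behaves. The spectrum values, the `check_mgf_header` function and its simple peak parser are made up for illustration; the notebook's own `_parse_spectrum_ion` is defined elsewhere and may differ.

```python
import io

# Header keys in the order the reader above asserts them.
EXPECTED_HEADERS = ["BEGIN IONS", "TITLE=", "PEPMASS=", "CHARGE=", "SCANS=", "RTINSECONDS="]


def check_mgf_header(handle) -> list:
    """Validate the header order of one spectrum block, then return its peaks."""
    for key in EXPECTED_HEADERS:
        line = handle.readline()
        if key not in line:
            raise ValueError(f"expected {key!r}, got {line.strip()!r}")
    peaks = []
    for line in handle:
        line = line.strip()
        if line == "END IONS":
            break
        if line:
            # Hypothetical peak format: "<m/z> <intensity>" separated by whitespace.
            mz, intensity = line.split()[:2]
            peaks.append((float(mz), float(intensity)))
    return peaks


# Hypothetical spectrum block; all values are invented for the example.
example_block = """BEGIN IONS
TITLE=example_spectrum
PEPMASS=500.25
CHARGE=2+
SCANS=1
RTINSECONDS=12.5
100.5 10.0
200.75 20.0
END IONS
"""

peaks = check_mgf_header(io.StringIO(example_block))
print(peaks)
```

An MGF file whose header lines appear in a different order would fail the assertions in `_get_feature`, so files exported by other tools may need to be reordered first.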

## Writer Functions <a name="denovo_writer"></a>

In [31]:
@dataclass
class BeamSearchedSequence:
    sequence: list  # list of aa id
    position_score: list
    score: float  # average by length score


class DenovoWriter(object):
    def __init__(self, denovo_output_file):
        self.output_handle = open(denovo_output_file, 'w')
        header_list = ["feature_id",
                       "feature_area",
                       "predicted_sequence",
                       "predicted_score",
                       "predicted_position_score",
                       "precursor_mz",
                       "precursor_charge",
                       "protein_access_id",
                       "scan_list_middle",
                       "scan_list_original",
                       "predicted_score_max"]
        header_row = "\t".join(header_list)
        print(header_row, file=self.output_handle, end='\n')

    def close(self):
        self.output_handle.close()

    def write(self, dda_original_feature: DDAFeature, searched_sequence: BeamSearchedSequence):
        """
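        Write one prediction row produced by the beam search to the output file.
        :param dda_original_feature: feature (precursor) the prediction belongs to
        :param searched_sequence: best sequence returned by the beam search
        :return: None; the tab-separated row is printed to the output handle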
        """
        feature_id = dda_original_feature.feature_id
        feature_area = dda_original_feature.feature_area
        precursor_mz = str(dda_original_feature.mz)
        precursor_charge = str(dda_original_feature.z)
        scan_list_middle = dda_original_feature.scan
        scan_list_original = dda_original_feature.scan
        if searched_sequence.sequence:
            predicted_sequence = ','.join([vocab_reverse[aa_id] for
                                           aa_id in searched_sequence.sequence])
            predicted_score = "{:.2f}".format(searched_sequence.score)
            predicted_score_max = predicted_score
            predicted_position_score = ','.join(['{0:.2f}'.format(x) for x in searched_sequence.position_score])
            protein_access_id = 'DENOVO'
        else:
            predicted_sequence = ""
            predicted_score = ""
            predicted_score_max = ""
            predicted_position_score = ""
            protein_access_id = ""
        predicted_row = "\t".join([feature_id,
                                   feature_area,
                                   predicted_sequence,
                                   predicted_score,
                                   predicted_position_score,
                                   precursor_mz,
                                   precursor_charge,
                                   protein_access_id,
                                   scan_list_middle,
                                   scan_list_original,
                                   predicted_score_max])
        print(predicted_row, file=self.output_handle, end="\n")

    def __del__(self):
        self.close()

## KnapSack Implementation <a name="knapsack"></a>
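If the knapsack matrix file (`knapsack_file`, a NumPy array written with `np.save`) is not present at the expected path, the notebook builds a new one, which takes about 40 minutes on Colab.
Because the matrix depends only on fixed global parameters (vocabulary and masses), it can instead be uploaded from the included data.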

In [32]:
class Direction(Enum):
    forward = 1
    backward = 2

@dataclass
class BeamSearchStartPoint:
    prefix_mass: float
    suffix_mass: float
    mass_tolerance: float
    direction: Direction


@dataclass
class DenovoResult:
    dda_feature: DDAFeature
    best_beam_search_sequence: BeamSearchedSequence


class KnapsackSearcher(object):
    """
    Implementation of the knapsack algorithm.
    If the knapsack file does not exist, a new matrix is built from the parameters set
    in global variables (vocabulary and masses) and saved to that file.
    """
    def __init__(self, MZ_MAX, knapsack_file):
        self.knapsack_file = knapsack_file
        self.MZ_MAX = MZ_MAX
        self.knapsack_aa_resolution = KNAPSACK_AA_RESOLUTION
        if os.path.isfile(knapsack_file):
            print("KnapsackSearcher.__init__(): load knapsack matrix")
            self.knapsack_matrix = np.load(knapsack_file)
        else:
            print("KnapsackSearcher.__init__(): build knapsack matrix from scratch")
            self.knapsack_matrix = self._build_knapsack()

    def _build_knapsack(self):
        max_mass = self.MZ_MAX - mass_N_terminus - mass_C_terminus
        max_mass_round = int(round(max_mass * self.knapsack_aa_resolution))
        max_mass_upperbound = max_mass_round + self.knapsack_aa_resolution
        knapsack_matrix = np.zeros(shape=(vocab_size, max_mass_upperbound), dtype=bool)
        for aa_id in tqdm(range(3, vocab_size)):
            mass_aa = int(round(mass_ID[aa_id] * self.knapsack_aa_resolution))

            for col in range(max_mass_upperbound):
                current_mass = col + 1
                if current_mass < mass_aa:
                    knapsack_matrix[aa_id, col] = False
                elif current_mass == mass_aa:
                    knapsack_matrix[aa_id, col] = True
                elif current_mass > mass_aa:
                    sub_mass = current_mass - mass_aa
                    sub_col = sub_mass - 1
                    if np.sum(knapsack_matrix[:, sub_col]) > 0:
                        knapsack_matrix[aa_id, col] = True
                        knapsack_matrix[:, col] = np.logical_or(knapsack_matrix[:, col], knapsack_matrix[:, sub_col])
                    else:
                        knapsack_matrix[aa_id, col] = False
        np.save(self.knapsack_file, knapsack_matrix)
        return knapsack_matrix

    def search_knapsack(self, mass, knapsack_tolerance):
        mass_round = int(round(mass * self.knapsack_aa_resolution))
        mass_upperbound = mass_round + knapsack_tolerance
        mass_lowerbound = mass_round - knapsack_tolerance
        if mass_upperbound < mass_AA_min_round:
            return []
        mass_lowerbound_col = mass_lowerbound - 1
        mass_upperbound_col = mass_upperbound - 1
        candidate_aa_id = np.flatnonzero(np.any(self.knapsack_matrix[:, mass_lowerbound_col:(mass_upperbound_col + 1)],
                                                axis=1))
        return candidate_aa_id.tolist()
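
To make the recurrence in `_build_knapsack` and the lookup in `search_knapsack` easier to follow, here is a toy version in pure Python. It is only a sketch: the residues `A` and `B` and their integer masses are invented, masses are used directly as integers instead of being scaled by `KNAPSACK_AA_RESOLUTION`, and the lookup clamps the tolerance window to the table bounds, which the class above does not do.

```python
def build_knapsack(masses: dict, max_mass: int) -> dict:
    """Return {residue: [bool for each integer mass 1..max_mass]}.

    Column index `col` stands for mass `col + 1`, as in the class above.
    """
    table = {aa: [False] * max_mass for aa in masses}
    for aa, mass_aa in masses.items():
        for col in range(max_mass):
            current_mass = col + 1
            if current_mass == mass_aa:
                table[aa][col] = True
            elif current_mass > mass_aa:
                sub_col = current_mass - mass_aa - 1
                # The remaining mass must itself be reachable by some residue.
                if any(table[other][sub_col] for other in table):
                    table[aa][col] = True
                    # Residues that reach the smaller mass also take part in this one.
                    for other in table:
                        table[other][col] = table[other][col] or table[other][sub_col]
    return table


def search_knapsack(table: dict, mass: int, tolerance: int) -> list:
    """Return the residues that can take part in a sum within mass +/- tolerance."""
    width = len(next(iter(table.values())))
    low = max(mass - tolerance, 1)
    high = min(mass + tolerance, width)
    if high < low:
        return []
    return [aa for aa, row in table.items() if any(row[low - 1:high])]


toy_masses = {"A": 3, "B": 5}  # hypothetical residues and masses
toy_table = build_knapsack(toy_masses, max_mass=12)
print(search_knapsack(toy_table, mass=8, tolerance=0))
print(search_knapsack(toy_table, mass=7, tolerance=0))
```

In the beam search, this lookup is what restricts the next amino acid to residues that can still complete the remaining mass, and an empty result forces the path to terminate.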



## ION CNN Denovo <a name="ioncnn"></a>

In [33]:
@dataclass
class SearchPath:
    aa_id_list: list
    aa_seq_mass: float
    score_list: list
    score_sum: float
    lstm_state: tuple  # state tuple store in search path is of shape [num_lstm_layers, num_units]
    direction: Direction

@dataclass
class SearchEntry:
    feature_index: int
    current_path_list: list  # list of search paths
    spectrum_state: tuple  # tuple of (peak_location, peak_intensity)


class IonCNNDenovo(object):
    def __init__(self, MZ_MAX, knapsack_file, beam_size):
        self.MZ_MAX = MZ_MAX
        self.beam_size = beam_size
        self.knapsack_searcher = KnapsackSearcher(MZ_MAX, knapsack_file)

    def _beam_search(self, model_wrapper: InferenceModelWrapper,
                     feature_dp_batch: list, start_point_batch: list) -> list:
        """

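        Run a batched beam search in a single direction (forward or backward).
        :param model_wrapper: wrapper around the forward and backward models
        :param feature_dp_batch: list of DenovoData
        :param start_point_batch: list of BeamSearchStartPoint, all with the same direction
        :return: list (one entry per feature) of lists of BeamSearchedSequence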
        """
        num_features = len(feature_dp_batch)
        top_path_batch = [[] for _ in range(num_features)]

        direction_cint_map = {Direction.forward: 0, Direction.backward: 1}

        direction = start_point_batch[0].direction
        if direction == Direction.forward:
            get_start_mass = lambda x: x.prefix_mass
            first_label = GO_ID
            last_label = EOS_ID
        elif direction == Direction.backward:
            get_start_mass = lambda x: x.suffix_mass
            first_label = EOS_ID
            last_label = GO_ID
        else:
            raise ValueError('direction neither forward nor backward')

        # step 1: extract original spectrum
        batch_peak_location = np.array([x.peak_location for x in feature_dp_batch])
        batch_peak_intensity = np.array([x.peak_intensity for x in feature_dp_batch])
        batch_spectrum_representation = np.array([x.spectrum_representation for x in feature_dp_batch])

        batch_peak_location = torch.from_numpy(batch_peak_location).to(device)
        batch_peak_intensity = torch.from_numpy(batch_peak_intensity).to(device)
        batch_spectrum_representation = torch.from_numpy(batch_spectrum_representation).to(device)

        initial_hidden_state_tuple = model_wrapper.initial_hidden_state(batch_spectrum_representation) if \
            use_lstm else None

        # initialize active search list
        active_search_list = []
        for feature_index in range(num_features):
            # all features in the same batch should have the same direction
            assert direction == start_point_batch[feature_index].direction

            spectrum_state = (batch_peak_location[feature_index], batch_peak_intensity[feature_index])

            if use_lstm:
                lstm_state_temp = (initial_hidden_state_tuple[0][:, feature_index, :],
                                   initial_hidden_state_tuple[1][:, feature_index, :]
                                   )
            else:
                lstm_state_temp = None

            path = SearchPath(
                aa_id_list=[first_label],
                aa_seq_mass=get_start_mass(start_point_batch[feature_index]),
                score_list=[0.0],
                score_sum=0.0,
                lstm_state=lstm_state_temp,
                direction=direction,
            )
            search_entry = SearchEntry(
                feature_index=feature_index,
                current_path_list=[path],
                spectrum_state=spectrum_state,
            )
            active_search_list.append(search_entry)

        # repeat STEP 2, 3, 4 until the active_search_list is empty.
        while True:
            # STEP 2: gather data from active search entries and group into blocks.

            # model input
            block_aa_id_input = []
            block_ion_location = []
            block_peak_location = []
            block_peak_intensity = []
            block_lstm_h = []
            block_lstm_c = []
            # data stored in path
            block_aa_id_list = []
            block_aa_seq_mass = []
            block_score_list = []
            block_score_sum = []
            block_knapsack_candidates = []

            # store the number of paths of each search entry in the big blocks
            #     to retrieve the info of each search entry later in STEP 4.
            search_entry_size = [0] * len(active_search_list)

            for entry_index, search_entry in enumerate(active_search_list):
                feature_index = search_entry.feature_index
                current_path_list = search_entry.current_path_list
                precursor_mass = feature_dp_batch[feature_index].original_dda_feature.mass
                peak_mass_tolerance = start_point_batch[feature_index].mass_tolerance

                for path in current_path_list:
                    aa_id_list = path.aa_id_list
                    aa_id = aa_id_list[-1]
                    score_sum = path.score_sum
                    aa_seq_mass = path.aa_seq_mass
                    score_list = path.score_list
                    original_spectrum_tuple = search_entry.spectrum_state
                    lstm_state_tuple = path.lstm_state

                    if aa_id == last_label:
                        if abs(aa_seq_mass - precursor_mass) <= peak_mass_tolerance:
                            seq = aa_id_list[1:-1]
                            trunc_score_list = score_list[1:-1]
                            if direction == Direction.backward:
                                seq = seq[::-1]
                                trunc_score_list = trunc_score_list[::-1]

                            top_path_batch[feature_index].append(
                                BeamSearchedSequence(sequence=seq,
                                                     position_score=trunc_score_list,
                                                     score=path.score_sum / len(seq))
                            )
                        continue

                    ion_location = get_ion_index(precursor_mass, aa_seq_mass, direction_cint_map[direction])  # [26,8]

                    residual_mass = precursor_mass - aa_seq_mass - mass_ID[last_label]
                    knapsack_tolerance = int(round(peak_mass_tolerance * KNAPSACK_AA_RESOLUTION))
                    knapsack_candidates = self.knapsack_searcher.search_knapsack(residual_mass, knapsack_tolerance)

                    if not knapsack_candidates:
                        # if no amino acid is possible, force the path to stop.
                        knapsack_candidates.append(last_label)

                    block_ion_location.append(ion_location)
                    block_aa_id_input.append(aa_id)
                    # get hidden state block
                    block_peak_location.append(original_spectrum_tuple[0])
                    block_peak_intensity.append(original_spectrum_tuple[1])
                    if use_lstm:
                        block_lstm_h.append(lstm_state_tuple[0])
                        block_lstm_c.append(lstm_state_tuple[1])

                    block_aa_id_list.append(aa_id_list)
                    block_aa_seq_mass.append(aa_seq_mass)
                    block_score_list.append(score_list)
                    block_score_sum.append(score_sum)
                    block_knapsack_candidates.append(knapsack_candidates)
                    # record the size of each search entry in the blocks
                    search_entry_size[entry_index] += 1

            # STEP 3: run model on data blocks to predict next AA.
            #     output is stored in current_log_prob
            # assert block_aa_id_list, 'IonCNNDenovo._beam_search(): aa_id_list is empty.'
            if not block_ion_location:
                # all search entry finished in the previous step
                break

            block_ion_location = torch.from_numpy(np.array(block_ion_location)).to(device)  # [batch, 26, 8]
            block_ion_location = torch.unsqueeze(block_ion_location, dim=1)  # [batch, 1, 26, 8]
            block_peak_location = torch.stack(block_peak_location, dim=0).contiguous()
            block_peak_intensity = torch.stack(block_peak_intensity, dim=0).contiguous()
            if use_lstm:
                block_lstm_h = torch.stack(block_lstm_h, dim=1).contiguous()
                block_lstm_c = torch.stack(block_lstm_c, dim=1).contiguous()
                block_state_tuple = (block_lstm_h, block_lstm_c)
                block_aa_id_input = torch.from_numpy(np.array(block_aa_id_input, dtype=np.int64)).unsqueeze(1).to(
                    device)
            else:
                block_state_tuple = None
                block_aa_id_input = None

            current_log_prob, new_state_tuple = model_wrapper.step(block_ion_location,
                                                                   block_peak_location,
                                                                   block_peak_intensity,
                                                                   block_aa_id_input,
                                                                   block_state_tuple,
                                                                   direction)
            # transfer log_prob back to cpu
            current_log_prob = current_log_prob.cpu().numpy()

            # STEP 4: retrieve data from blocks to update the active_search_list
            #     with knapsack dynamic programming and beam search.
            block_index = 0
            for entry_index, search_entry in enumerate(active_search_list):
                new_path_list = []
                direction = search_entry.current_path_list[0].direction
                for index in range(block_index, block_index + search_entry_size[entry_index]):
                    for aa_id in block_knapsack_candidates[index]:
                        if aa_id > 2:
                            # do not add score of GO, EOS, PAD
                            new_score_list = block_score_list[index] + [current_log_prob[index][aa_id]]
                            new_score_sum = block_score_sum[index] + current_log_prob[index][aa_id]
                        else:
                            new_score_list = block_score_list[index] + [0.0]
                            new_score_sum = block_score_sum[index] + 0.0

                        if use_lstm:
                            new_path_state_tuple = (new_state_tuple[0][:, index, :], new_state_tuple[1][:, index, :])
                        else:
                            new_path_state_tuple = None

                        new_path = SearchPath(
                            aa_id_list=block_aa_id_list[index] + [aa_id],
                            aa_seq_mass=block_aa_seq_mass[index] + mass_ID[aa_id],
                            score_list=new_score_list,
                            score_sum=new_score_sum,
                            lstm_state=new_path_state_tuple,
                            direction=direction
                        )
                        new_path_list.append(new_path)
                if len(new_path_list) > self.beam_size:
                    new_path_score = np.array([x.score_sum for x in new_path_list])
                    top_k_index = np.argpartition(-new_path_score, self.beam_size)[:self.beam_size]
                    search_entry.current_path_list = [new_path_list[ii] for ii in top_k_index]
                else:
                    search_entry.current_path_list = new_path_list

                block_index += search_entry_size[entry_index]

            active_search_list = [x for x in active_search_list if x.current_path_list]

            if not active_search_list:
                break
        return top_path_batch

    @staticmethod
    def _get_start_point(feature_dp_batch: list) -> tuple:
        mass_GO = mass_ID[GO_ID]
        forward_start_point_lists = [BeamSearchStartPoint(prefix_mass=mass_GO,
                                                          suffix_mass=feature_dp.original_dda_feature.mass - mass_GO,
                                                          mass_tolerance=PRECURSOR_MASS_PRECISION_TOLERANCE,
                                                          direction=Direction.forward)
                                     for feature_dp in feature_dp_batch]

        mass_EOS = mass_ID[EOS_ID]
        backward_start_point_lists = [BeamSearchStartPoint(prefix_mass=feature_dp.original_dda_feature.mass - mass_EOS,
                                                           suffix_mass=mass_EOS,
                                                           mass_tolerance=PRECURSOR_MASS_PRECISION_TOLERANCE,
                                                           direction=Direction.backward)
                                      for feature_dp in feature_dp_batch]
        return forward_start_point_lists, backward_start_point_lists

    @staticmethod
    def _select_path(feature_dp_batch: list, top_candidate_batch: list) -> list:
        """
        for each feature, select the best denovo sequence given by DeepNovo model
        :param feature_dp_batch: list of DenovoData
        :param top_candidate_batch: defined in _search_denovo_batch
        :return:
        list of DenovoResult
        """
        feature_batch_size = len(feature_dp_batch)

        refine_batch = [[] for x in range(feature_batch_size)]
        for feature_index in range(feature_batch_size):
            precursor_mass = feature_dp_batch[feature_index].original_dda_feature.mass
            candidate_list = top_candidate_batch[feature_index]
            for beam_search_sequence in candidate_list:
                sequence = beam_search_sequence.sequence
                sequence_mass = sum(mass_ID[x] for x in sequence)
                sequence_mass += mass_ID[GO_ID] + mass_ID[EOS_ID]
                if abs(sequence_mass - precursor_mass) <= PRECURSOR_MASS_PRECISION_TOLERANCE:
                    refine_batch[feature_index].append(beam_search_sequence)
        predicted_batch = []
        for feature_index in range(feature_batch_size):
            candidate_list = refine_batch[feature_index]
            if not candidate_list:
                best_beam_search_sequence = BeamSearchedSequence(
                    sequence=[],
                    position_score=[],
                    score=-float('inf')
                )
            else:
                # sort candidate sequence by average position score
                best_beam_search_sequence = max(candidate_list, key=lambda x: x.score)

            denovo_result = DenovoResult(
                dda_feature=feature_dp_batch[feature_index].original_dda_feature,
                best_beam_search_sequence=best_beam_search_sequence
            )
            predicted_batch.append(denovo_result)
        return predicted_batch

    def _search_denovo_batch(self, feature_dp_batch: list, model_wrapper: InferenceModelWrapper) -> list:
        start_time = time.time()
        feature_batch_size = len(feature_dp_batch)
        start_points_tuple = self._get_start_point(feature_dp_batch)
        top_candidate_batch = [[] for x in range(feature_batch_size)]

        for start_points in start_points_tuple:
            beam_search_result_batch = self._beam_search(model_wrapper, feature_dp_batch, start_points)
            for feature_index in range(feature_batch_size):
                top_candidate_batch[feature_index].extend(beam_search_result_batch[feature_index])
        predicted_batch = self._select_path(feature_dp_batch, top_candidate_batch)
        test_time = time.time() - start_time
        print("beam_search(): batch time {}s".format(test_time))
        return predicted_batch

    def search_denovo(self, model_wrapper: InferenceModelWrapper,
                      beam_search_reader: DeepNovoDenovoDataset, denovo_writer: DenovoWriter):
        print("start beam search denovo")
        predicted_denovo_list = []

        test_set_iter = chunks(list(range(len(beam_search_reader))), n=batch_size)
        total_batch_num = int(len(beam_search_reader) / batch_size)
        for index, feature_batch_index in enumerate(test_set_iter):
            feature_dp_batch = [beam_search_reader[i] for i in feature_batch_index]
            print("Read {}th/{} batches".format(index, total_batch_num))
            predicted_batch = self._search_denovo_batch(feature_dp_batch, model_wrapper)
            predicted_denovo_list += predicted_batch
            for denovo_result in predicted_batch:
                denovo_writer.write(denovo_result.dda_feature, denovo_result.best_beam_search_sequence)

        return predicted_denovo_list
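
The pruning step at the end of STEP 4 keeps only the `beam_size` paths with the highest cumulative log-probability. The sketch below reproduces that idea with the standard library instead of `np.argpartition`. The function `prune_beam` and the `(sequence, score_sum)` tuples are illustrative stand-ins for `SearchPath` objects, and the scores are made up.

```python
import heapq


def prune_beam(paths: list, beam_size: int) -> list:
    """Keep the beam_size paths with the highest cumulative score."""
    if len(paths) <= beam_size:
        # Nothing to prune: every path stays in the beam.
        return list(paths)
    return heapq.nlargest(beam_size, paths, key=lambda path: path[1])


# Hypothetical candidate paths as (sequence, summed log-probability).
candidates = [("PEPT", -1.2), ("PEPS", -0.4), ("PEPA", -3.1), ("PEPG", -0.9)]
kept = prune_beam(candidates, beam_size=2)
print(kept)
```

One difference worth noting: `heapq.nlargest` returns the survivors sorted by score, whereas `np.argpartition` only guarantees that the selected indices belong to the top `beam_size` and leaves their order unspecified. The beam search does not rely on that order.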


## Chunk function

In [34]:
def chunks(l, n: int):
    for i in range(0, len(l), n):
        yield l[i:i + n]
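
A short usage sketch of `chunks`, with made-up sizes, shows that the last batch can be shorter than `n` and that the number of batches is the ceiling of `len(l) / n`. Since `search_denovo` computes `total_batch_num` with `int(len(...) / batch_size)`, the denominator in its progress message can be one less than the actual number of batches.

```python
import math


def chunks(l, n: int):
    # Same generator as in the cell above, repeated so the example is self-contained.
    for i in range(0, len(l), n):
        yield l[i:i + n]


indices = list(range(10))  # hypothetical feature indices
batches = list(chunks(indices, n=4))
print(batches)

expected_batches = math.ceil(len(indices) / 4)  # counts the final partial batch
floored_batches = int(len(indices) / 4)         # what a floor division reports
print(expected_batches, floored_batches)
```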

## Launch DeepNovo denovo prediction <a name="denovo_launch"></a>

In [35]:
beam_size = beam_size_param #by default 5
if 'denovo' in option:
  data_reader = DeepNovoDenovoDataset(feature_filename=denovo_input_feature_file,
                                              spectrum_filename=denovo_input_spectrum_file)
  print('data_reader done \n')
  print('IonCNNDenovo starting \n')
  denovo_worker = IonCNNDenovo(MZ_MAX,
                                      knapsack_file,
                                      beam_size=beam_size)
  print('Building Model ... \n')
  forward_deepnovo, backward_deepnovo, init_net = build_model(training=False)
  print('Wrapper \n')
  model_wrapper = InferenceModelWrapper(forward_deepnovo, backward_deepnovo, init_net)
  print('Writing ... \n')
  writer = DenovoWriter(denovo_output_file)

  denovo_worker.search_denovo(model_wrapper, data_reader, writer)

input spectrum file: ../smbp_data/spectrum_smbp.mgf
input feature file: ../smbp_data/features_smbp.csv.test
read cached spectrum locations
read 469 features, 0 skipped by mass, 0 skipped by unknown modification, 0 skipped by length
data_reader done 

IonCNNDenovo starting 

KnapsackSearcher.__init__(): load knapsack matrix
Building Model ... 

load pretrained model
Wrapper 

Writing ... 

start beam search denovo
Read 0th/29 batches




beam_search(): batch time 0.6671297550201416s
Read 1th/29 batches
beam_search(): batch time 0.7019972801208496s
Read 2th/29 batches
beam_search(): batch time 0.6950774192810059s
Read 3th/29 batches
beam_search(): batch time 0.6168060302734375s
Read 4th/29 batches
beam_search(): batch time 0.8356964588165283s
Read 5th/29 batches
beam_search(): batch time 0.753415584564209s
Read 6th/29 batches
beam_search(): batch time 0.7738871574401855s
Read 7th/29 batches
beam_search(): batch time 0.9844949245452881s
Read 8th/29 batches
beam_search(): batch time 0.6814014911651611s
Read 9th/29 batches
beam_search(): batch time 0.6944026947021484s
Read 10th/29 batches
beam_search(): batch time 0.538715124130249s
Read 11th/29 batches
beam_search(): batch time 0.579709529876709s
Read 12th/29 batches
beam_search(): batch time 0.5671288967132568s
Read 13th/29 batches
beam_search(): batch time 0.7232954502105713s
Read 14th/29 batches
beam_search(): batch time 0.6646101474761963s
Read 15th/29 batches
beam_se

# Testing Model <a name="test"></a>

## Testing File Path <a name="test_path"></a>

Select the path of the testing dataset (usually a part of the initial dataset, which is randomly split into training, validation and testing parts).
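
A random split of this kind can be sketched as follows. This is only an illustration under stated assumptions: the function `split_dataset`, the 80/10/10 proportions, the seed and the feature identifiers are all hypothetical and do not describe how the SMBP files were actually produced.

```python
import random


def split_dataset(items, valid_frac=0.1, test_frac=0.1, seed=0):
    """Shuffle items reproducibly and split them into train, valid and test lists."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # local generator, global state untouched
    n_valid = int(len(shuffled) * valid_frac)
    n_test = int(len(shuffled) * test_frac)
    valid = shuffled[:n_valid]
    test = shuffled[n_valid:n_valid + n_test]
    train = shuffled[n_valid + n_test:]
    return train, valid, test


feature_ids = [f"F{i}" for i in range(100)]  # hypothetical feature identifiers
train, valid, test = split_dataset(feature_ids)
print(len(train), len(valid), len(test))
```

Each part would then be written to its own feature file, while all three keep pointing to the same spectrum file.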

In [36]:
input_spectrum_file_test = "../smbp_data/spectrum_smbp.mgf"
input_feature_file_test = "../smbp_data/features_smbp.csv.test"

## Worker Test function <a name="test_worker"></a>

In [38]:
class WorkerTest(object):
  """
     WorkerTest compares the sequences predicted by the denovo step with the
     original (target) sequences, when available, and reports the accuracy of
     the trained model on these 'denovo' testing sequences.
  """


  def __init__(self):
    print("".join(["="] * 80)) # section-separating line
    print("WorkerTest.__init__()")

    #get all the variables needed
    self.MZ_MAX = MZ_MAX

    self.target_file = target_file #denovo input
    self.predicted_file = predicted_file #denovo output
    self.predicted_format = predicted_format # tag "deepnovo"
    
    #get the path for the output of the function
    self.accuracy_file = accuracy_file 
    self.denovo_only_file = denovo_only_file
    self.scan2fea_file = scan2fea_file
    self.multifea_file = multifea_file

    print("target_file = {0:s}".format(self.target_file))
    print("predicted_file = {0:s}".format(self.predicted_file))
    print("predicted_format = {0:s}".format(self.predicted_format))
    print("accuracy_file = {0:s}".format(self.accuracy_file))
    print("denovo_only_file = {0:s}".format(self.denovo_only_file))
    print("scan2fea_file = {0:s}".format(self.scan2fea_file))
    print("multifea_file = {0:s}".format(self.multifea_file))

    self.target_dict = {}
    self.predicted_list = []


  def test_accuracy(self, db_peptide_list=None):
    """"""

    print("".join(["="] * 80)) # section-separating line
    print("WorkerTest.test_accuracy()")

    # write the accuracy of predicted peptides
    accuracy_handle = open(self.accuracy_file, 'w')
    header_list = ["feature_id",
                   "feature_area",
                   "target_sequence",
                   "predicted_sequence",
                   "predicted_score",
                   "recall_AA",
                   "predicted_len",
                   "target_len",
                   "scan_list_middle",
                   "scan_list_original"]
    header_row = "\t".join(header_list)
    print(header_row, file=accuracy_handle, end="\n")

    # write denovo_only peptides (sequence with a reference to compare)
    denovo_only_handle = open(self.denovo_only_file, 'w')
    header_list = ["feature_id",
                   "feature_area",
                   "predicted_sequence",
                   "predicted_score",
                   "predicted_score_max",
                   "scan_list_middle",
                   "scan_list_original"]
    header_row = "\t".join(header_list)
    print(header_row, file=denovo_only_handle, end="\n")

    self._get_target()
    target_count_total = len(self.target_dict)
    target_len_total = sum([len(x) for x in self.target_dict.values()])
    target_dict_db = {}

    # Deprecated when PEAKS DB is not used to check the predicted sequences (the SMBP data uses Mascot instead)

    # this part is tricky!
    # some target peptides are reported by PEAKS DB but not found in
    #   db_peptide_list due to mistakes in cleavage rules.
    # if db_peptide_list is given, we only consider those target peptides,
    #   otherwise, use all target peptides
    if db_peptide_list is not None:
      for feature_id, target in self.target_dict.items():
        target_simplied = target
        # remove the extension 'mod' from variable modifications
        target_simplied = ['M' if x=='M(Oxidation)' else x for x in target_simplied]
        target_simplied = ['N' if x=='N(Deamidation)' else x for x in target_simplied]
        target_simplied = ['Q' if x=='Q(Deamidation)' else x for x in target_simplied]

        #target_simplied = ['M' if x=='M(Oxidated)' else x for x in target_simplied]
        #target_simplied = ['Q' if x=='Q(Deamidated)' else x for x in target_simplied]
        #target_simplied = ['C' if x=='C(Carboxymethyl)' else x for x in target_simplied]

        if target_simplied in db_peptide_list:
          target_dict_db[feature_id] = target
        else:
          print("target not found: ", target_simplied)
    #==========================================================================
    else:
      target_dict_db = self.target_dict
    target_count_db = len(target_dict_db)
    target_len_db = sum([len(x) for x in target_dict_db.values()])

    # we also skip target peptides with precursor_mass > MZ_MAX
    target_dict_db_mass = {}
    for feature_id, peptide in target_dict_db.items():
      if self._compute_peptide_mass(peptide) <= self.MZ_MAX:
        target_dict_db_mass[feature_id] = peptide
    target_count_db_mass = len(target_dict_db_mass)
    target_len_db_mass = sum([len(x) for x in target_dict_db_mass.values()])

    # read predicted peptides from deepnovo or peaks
    if self.predicted_format == "deepnovo":
      self._get_predicted()
    else:
      self._get_predicted_peaks()

    # note that the prediction has already skipped precursor_mass > MZ_MAX
    # we also skip predicted peptides whose feature_id's are not in target_dict_db_mass
    predicted_count_mass = len(self.predicted_list)
    predicted_count_mass_db = 0
    predicted_len_mass_db = 0
    predicted_only = 0
    # the recall is calculated on remaining peptides
    recall_AA_total = 0.0
    recall_peptide_total = 0.0

    # record scan with multiple features
    scan_dict = {}

    for index, predicted in enumerate(self.predicted_list):

      feature_id = predicted["feature_id"]
      feature_area = str(predicted["feature_area"])
      feature_scan_list_middle = predicted["scan_list_middle"]
      feature_scan_list_original = predicted["scan_list_original"]
      if feature_scan_list_original:
        for scan in re.split(';|\r|\n', feature_scan_list_original):
          if scan in scan_dict:
            scan_dict[scan]["feature_count"] += 1
            scan_dict[scan]["feature_list"].append(feature_id)
          else:
            scan_dict[scan] = {}
            scan_dict[scan]["feature_count"] = 1
            scan_dict[scan]["feature_list"] = [feature_id]

      if feature_id in target_dict_db_mass:

        predicted_count_mass_db += 1

        target = target_dict_db_mass[feature_id]
        target_len= len(target)

        # if >= 1 denovo peptides reported, calculate the best accuracy
        best_recall_AA = 0
        best_predicted_sequence = predicted["sequence"][0]
        best_predicted_score = predicted["score"][0]
        for predicted_sequence, predicted_score in zip(predicted["sequence"], predicted["score"]):
          predicted_AA_id = [vocab[x] for x in predicted_sequence]
          target_AA_id = [vocab[x] for x in target]
          recall_AA = self._match_AA_novor(target_AA_id, predicted_AA_id)
          if (recall_AA > best_recall_AA
              or (recall_AA == best_recall_AA and predicted_score > best_predicted_score)):
            best_recall_AA = recall_AA
            best_predicted_sequence = predicted_sequence[:]
            best_predicted_score = predicted_score
        recall_AA = best_recall_AA
        predicted_sequence = best_predicted_sequence[:]
        predicted_score = best_predicted_score

        recall_AA_total += recall_AA
        if recall_AA == target_len:
          recall_peptide_total += 1
        predicted_len= len(predicted_sequence)
        predicted_len_mass_db += predicted_len

        # convert to string format to print out
        target_sequence = ",".join(target)
        predicted_sequence = ",".join(predicted_sequence)
        predicted_score = "{0:.2f}".format(predicted_score)
        recall_AA = "{0:d}".format(recall_AA)
        predicted_len = "{0:d}".format(predicted_len)
        target_len = "{0:d}".format(target_len)
        print_list = [feature_id,
                      feature_area,
                      target_sequence,
                      predicted_sequence,
                      predicted_score,
                      recall_AA,
                      predicted_len,
                      target_len,
                      feature_scan_list_middle,
                      feature_scan_list_original]
        print_row = "\t".join(print_list)
        print(print_row, file=accuracy_handle, end="\n")
      else:
        predicted_only += 1
        predicted_sequence = ';'.join([','.join(x) for x in predicted["sequence"]])
        predicted_score = ';'.join(['{0:.2f}'.format(x) for x in predicted["score"]])
        if predicted["score"]:
          predicted_score_max = '{0:.2f}'.format(np.max(predicted["score"]))
        else:
          predicted_score_max = ''
        print_list = [feature_id,
                      feature_area,
                      predicted_sequence,
                      predicted_score,
                      predicted_score_max,
                      feature_scan_list_middle,
                      feature_scan_list_original]
        print_row = "\t".join(print_list)
        print(print_row, file=denovo_only_handle, end="\n")

    accuracy_handle.close()
    denovo_only_handle.close()

    multifea_dict = {}
    for scan_id, value in scan_dict.items():
      feature_count = value["feature_count"]
      feature_list = value["feature_list"]
      if feature_count > 1:
        for feature_id in feature_list:
          if feature_id in multifea_dict:
            multifea_dict[feature_id].append(scan_id + ':' + str(feature_count))
          else:
            multifea_dict[feature_id] = [scan_id + ':' + str(feature_count)]

    # scan2fea_file: for each scan, the number and the list of features that reference it
    with open(self.scan2fea_file, 'w') as handle:
      header_list = ["scan_id",
                     "feature_count",
                     "feature_list"]
      header_row = "\t".join(header_list)
      print(header_row, file=handle, end="\n")
      for scan_id, value in scan_dict.items():
        print_list = [scan_id,
                      str(value["feature_count"]),
                      ";".join(value["feature_list"])]
        print_row = "\t".join(print_list)
        print(print_row, file=handle, end="\n")
    # multifea_file: for each feature, the scans it shares with other features (scan_id:feature_count)
    with open(self.multifea_file, 'w') as handle:
      header_list = ["feature_id",
                     "scan_list"]
      header_row = "\t".join(header_list)
      print(header_row, file=handle, end="\n")
      for feature_id, scan_list in multifea_dict.items():
        print_list = [feature_id,
                      ";".join(scan_list)]
        print_row = "\t".join(print_list)
        print(print_row, file=handle, end="\n")

    print("target_count_total = {0:d}".format(target_count_total))
    print("target_len_total = {0:d}".format(target_len_total))
    print("target_count_db = {0:d}".format(target_count_db))
    print("target_len_db = {0:d}".format(target_len_db))
    print("target_count_db_mass: {0:d}".format(target_count_db_mass))
    print("target_len_db_mass: {0:d}".format(target_len_db_mass))

    print("predicted_count_mass: {0:d}".format(predicted_count_mass))
    print("predicted_count_mass_db: {0:d}".format(predicted_count_mass_db))
    print("predicted_len_mass_db: {0:d}".format(predicted_len_mass_db))
    print("predicted_only: {0:d}".format(predicted_only))

    # Guard each ratio against a zero denominator (e.g. after a failed training run).
    if target_len_total != 0:
      print("recall_AA_total = {0:.4f}".format(recall_AA_total / target_len_total))
    if target_len_db != 0:
      print("recall_AA_db = {0:.4f}".format(recall_AA_total / target_len_db))
    if target_len_db_mass != 0:
      print("recall_AA_db_mass = {0:.4f}".format(recall_AA_total / target_len_db_mass))
    if target_count_total != 0:
      print("recall_peptide_total = {0:.4f}".format(recall_peptide_total / target_count_total))
    if target_count_db != 0:
      print("recall_peptide_db = {0:.4f}".format(recall_peptide_total / target_count_db))
    if target_count_db_mass != 0:
      print("recall_peptide_db_mass = {0:.4f}".format(recall_peptide_total / target_count_db_mass))
    if predicted_len_mass_db != 0:
      print("precision_AA_mass_db  = {0:.4f}".format(recall_AA_total / predicted_len_mass_db))
    if predicted_count_mass_db != 0:
      print("precision_peptide_mass_db  = {0:.4f}".format(recall_peptide_total / predicted_count_mass_db))
  
  
  def _compute_peptide_mass(self, peptide):
    """
    Return the peptide mass: the sum of the residue masses plus the
    N-terminus and C-terminus masses.
    """
    #print("".join(["="] * 80)) # section-separating line
    #print("WorkerTest._compute_peptide_mass()")
    peptide_mass = (mass_N_terminus
                    + sum(mass_AA[aa] for aa in peptide)
                    + mass_C_terminus)

    return peptide_mass


  def _get_predicted(self):
    print("".join(["="] * 80)) # section-separating line
    print("WorkerTest._get_predicted()")

    predicted_list = []
    col_feature_id = pcol_feature_id
    col_feature_area = pcol_feature_area
    col_sequence = pcol_sequence
    col_score = pcol_score
    col_scan_list_middle = pcol_scan_list_middle
    col_scan_list_original = pcol_scan_list_original
    with open(self.predicted_file, 'r') as handle:
      # header
      handle.readline()
      for line in handle:
        line_split = re.split('\t|\n', line)
        predicted = {}
        predicted["feature_id"] = line_split[col_feature_id]
        predicted["feature_area"] = float(line_split[col_feature_area])
        predicted["scan_list_middle"] = line_split[col_scan_list_middle]
        predicted["scan_list_original"] = line_split[col_scan_list_original]
        if line_split[col_sequence]: # not empty sequence
          predicted["sequence"] = [re.split(',', x)
                                   for x in re.split(';', line_split[col_sequence])]
          predicted["score"] = [float(x)
                                for x in re.split(';', line_split[col_score])]
        else: 
          predicted["sequence"] = [[]]
          predicted["score"] = [-999]
        predicted_list.append(predicted)

    self.predicted_list = predicted_list


  def _get_predicted_peaks(self):
    print("".join(["="] * 80)) # section-separating line
    print("WorkerTest._get_predicted_peaks()")

    predicted_list = []
    col_fraction_id = 0
    fraction_id_map = {'1':'1',
                       '2':'10',
                       '3':'11',
                       '4':'12',
                       '5':'2',
                       '6':'3',
                       '7':'4',
                       '8':'5',
                       '9':'6',
                       '10':'7',
                       '11':'8',
                       '12':'9',
                      }
    col_scan_id = 1
    col_sequence = 3
    with open(self.predicted_file, 'r') as handle:
      # header
      handle.readline()
      for line in handle:
        line_split = re.split(',|\n', line)
        predicted = {}
        predicted["feature_id"] = "F" + line_split[col_fraction_id] + ":" + line_split[col_scan_id]
        raw_sequence = line_split[col_sequence]
        assert raw_sequence, "Error: wrong format."
        predicted["sequence"] = self._parse_sequence(raw_sequence)
        
        # skip peptides with precursor_mass > MZ_MAX
        if self._compute_peptide_mass(predicted["sequence"]) > self.MZ_MAX:
          continue
        predicted["feature_area"] = 0
        predicted["scan_list_middle"] = ""
        predicted["scan_list_original"] = ""
        predicted["sequence"] = [predicted["sequence"]]
        predicted["score"] = [-999]
        predicted_list.append(predicted)

    self.predicted_list = predicted_list


  def _get_target(self):
    print("".join(["="] * 80)) # section-separating line
    print("WorkerTest._get_target()")

    target_dict = {}
    with open(self.target_file, 'r') as handle:
      header_line = handle.readline()
      header = header_line.strip().split(',')
      raw_sequence_index = header.index(col_raw_sequence)
      for line in handle:
        line = re.split(',|\r|\n', line)
        feature_id = line[0]
        raw_sequence = line[raw_sequence_index]
        assert raw_sequence, "Error: wrong target format."
        peptide = self._parse_sequence(raw_sequence)
        target_dict[feature_id] = peptide
    self.target_dict = target_dict


  def _parse_sequence(self, raw_sequence):

    #print("".join(["="] * 80)) # section-separating line
    #print("WorkerTest._parse_sequence()")

    return re.findall(r'[A-Z](?:\(.+?\))?', raw_sequence)

    """ # Deprecated when using the MASCOT processing script. Uncomment to use the
    data from the original paper of Tran et al. 2019.

    raw_sequence_len = len(raw_sequence)
    peptide = []
    index = 0
    while index < raw_sequence_len:
      if raw_sequence[index] == "(":
        if peptide[-1] == "C" and raw_sequence[index:index+8] == "(+57.02)":
          peptide[-1] = "C(Carbamidomethylation)"
          index += 8
        elif peptide[-1] == 'M' and raw_sequence[index:index+8] == "(+15.99)":
          peptide[-1] = 'M(Oxidation)'
          index += 8
        elif peptide[-1] == 'N' and raw_sequence[index:index+6] == "(+.98)":
          peptide[-1] = 'N(Deamidation)'
          index += 6
        elif peptide[-1] == 'Q' and raw_sequence[index:index+6] == "(+.98)":
          peptide[-1] = 'Q(Deamidation)'
          index += 6
        else: # unknown modification
          print("ERROR: unknown modification!")
          print("raw_sequence = ", raw_sequence)
          sys.exit()
      else:
        peptide.append(raw_sequence[index])
        index += 1

    return peptide
    """

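  def _match_AA_novor(self, target, predicted):
    """
    Count the predicted amino acids that match the target.

    Both sequences are lists of amino-acid IDs. A pair counts as a match when
    the cumulative (prefix) masses agree within 0.5 Da and the residue masses
    agree within 0.1 Da.
    """
  
    #print("".join(["="] * 80)) # section-separating line
    #print("WorkerTest._match_AA_novor()")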

    num_match = 0
    target_len = len(target)
    predicted_len = len(predicted)
    target_mass = [mass_ID[x] for x in target]
    target_mass_cum = np.cumsum(target_mass)
    predicted_mass = [mass_ID[x] for x in predicted]
    predicted_mass_cum = np.cumsum(predicted_mass)
  
    i = 0
    j = 0
    while i < target_len and j < predicted_len:
      if abs(target_mass_cum[i] - predicted_mass_cum[j]) < 0.5:
        if abs(target_mass[i] - predicted_mass[j]) < 0.1:
        #~ if  decoder_input[index_aa] == output[index_aa]:
          num_match += 1
        i += 1
        j += 1
      elif target_mass_cum[i] < predicted_mass_cum[j]:
        i += 1
      else:
        j += 1

    return num_match
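
To see what `_match_AA_novor` counts, the minimal sketch below reruns the same two-pointer comparison on amino-acid letters instead of vocabulary IDs. `match_by_mass` and `TOY_MASS` are hypothetical names introduced for this illustration (the notebook itself looks masses up in `mass_ID`), and the table only holds the standard monoisotopic residue masses of a few amino acids.

```python
import numpy as np

# Toy table (assumption for this sketch): monoisotopic residue masses in Da.
TOY_MASS = {"G": 57.02146, "A": 71.03711, "L": 113.08406,
            "I": 113.08406, "N": 114.04293, "K": 128.09496}


def match_by_mass(target, predicted, mass=TOY_MASS):
    """Count residues whose prefix mass and own mass both agree."""
    target_mass = [mass[aa] for aa in target]
    predicted_mass = [mass[aa] for aa in predicted]
    target_cum = np.cumsum(target_mass)
    predicted_cum = np.cumsum(predicted_mass)

    num_match = 0
    i = j = 0
    while i < len(target) and j < len(predicted):
        if abs(target_cum[i] - predicted_cum[j]) < 0.5:
            # Prefix masses line up: compare the two residues themselves.
            if abs(target_mass[i] - predicted_mass[j]) < 0.1:
                num_match += 1
            i += 1
            j += 1
        elif target_cum[i] < predicted_cum[j]:
            i += 1  # target prefix is lighter: advance the target pointer
        else:
            j += 1  # predicted prefix is lighter: advance the predicted pointer
    return num_match


print(match_by_mass(list("ANL"), list("AGGL")))
print(match_by_mass(list("AIK"), list("ALK")))
```

Because only masses are compared, isobaric residues such as I and L count as matches, and a residue replaced by two lighter ones (N by G,G) costs a single match without shifting the alignment of the residues that follow.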

## Read feature accuracy <a name="test_accuracy"></a>

In [39]:
def read_feature_accuracy(input_file, split_char):
  """
  Read the accuracy file and return a list with one dict per feature
  (feature ID, log10 feature area, predicted score, recall_AA, predicted length).
  """
  feature_list = []
  with open(input_file, 'r') as handle:
    header_line = handle.readline()
    for line in handle:
      line = re.split(split_char, line)
      feature = {}
      feature["feature_id"] = line[0]
      feature["feature_area"] = math.log10(float(line[1]) + 1e-5)
      feature["predicted_score"] = float(line[4])
      feature["recall_AA"] = float(line[5])
      feature["predicted_len"] = float(line[6])
      feature_list.append(feature)
  return feature_list
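
The column indices used here follow the order in which `test_accuracy` writes each row of the accuracy file: feature ID, feature area, target sequence, predicted sequence, predicted score, `recall_AA`, predicted length, target length, then the two scan lists. The sketch below applies the same conversion to a single made-up row; `parse_accuracy_row` and every value in `row` are assumptions for illustration only.

```python
import math
import re


def parse_accuracy_row(line, split_char="\t|\r|\n"):
    """Convert one tab-separated accuracy row into a feature dict."""
    fields = re.split(split_char, line)
    return {
        "feature_id": fields[0],
        # The small offset keeps log10 defined when the area is 0.
        "feature_area": math.log10(float(fields[1]) + 1e-5),
        "predicted_score": float(fields[4]),
        "recall_AA": float(fields[5]),
        "predicted_len": float(fields[6]),
    }


# Made-up row, laid out like the rows written by test_accuracy().
row = "F1:100\t1000.0\tA,G,K\tA,G,L\t-0.25\t2\t3\t3\t100\t100\n"
feature = parse_accuracy_row(row)
print(feature)
```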

## Find score cutoff <a name="test_cutoff"></a>

In [40]:
def find_score_cutoff(accuracy_file, accuracy_cutoff):
  """
  Sort the predictions by decreasing score and return the score at which the
  cumulative amino-acid accuracy (matched residues / predicted residues) first
  drops below accuracy_cutoff.
  """
  print("".join(["="] * 80)) # section-separating line
  print("find_score_cutoff()")

  feature_list = read_feature_accuracy(accuracy_file, '\t|\r|\n')
  feature_list_sorted = sorted(feature_list, key=lambda k: k['predicted_score'], reverse=True)
  recall_cumsum = np.cumsum([f['recall_AA'] for f in feature_list_sorted])
  predicted_len_cumsum = np.cumsum([f['predicted_len'] for f in feature_list_sorted])
  accuracy_cumsum = recall_cumsum / predicted_len_cumsum
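  below_cutoff = np.flatnonzero(accuracy_cumsum < accuracy_cutoff)
  # If the accuracy never drops below the cutoff, fall back to the last (lowest) score.
  cutoff_index = below_cutoff[0] if below_cutoff.size else len(feature_list_sorted) - 1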
  cutoff_score = feature_list_sorted[cutoff_index]['predicted_score']
  print('cutoff_index = ', cutoff_index)
  print('cutoff_score = ', cutoff_score)
  print('100 * exp(cutoff_score) = ', 100*math.exp(cutoff_score))

  return cutoff_score
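
The cutoff search can be followed by hand on a few made-up predictions. The sketch below is a simplified stand-in, not the function above: `toy_score_cutoff` and `toy_features` are hypothetical, each tuple holds (predicted score, `recall_AA`, predicted length), and returning `None` when the accuracy never falls below the cutoff is a choice made for this sketch.

```python
import numpy as np


def toy_score_cutoff(features, accuracy_cutoff):
    """Return (cutoff score or None, cumulative accuracy per rank)."""
    # Best-scoring predictions first.
    ordered = sorted(features, key=lambda f: f[0], reverse=True)
    recall_cumsum = np.cumsum([f[1] for f in ordered])
    predicted_len_cumsum = np.cumsum([f[2] for f in ordered])
    # Accuracy over all predictions at or above each rank.
    accuracy_cumsum = recall_cumsum / predicted_len_cumsum
    below = np.flatnonzero(accuracy_cumsum < accuracy_cutoff)
    if below.size == 0:
        return None, accuracy_cumsum
    return ordered[below[0]][0], accuracy_cumsum


# Made-up values: (predicted_score, recall_AA, predicted_len).
toy_features = [(-0.9, 5, 10), (-0.1, 10, 10), (-1.5, 2, 10), (-0.3, 9, 10)]
cutoff_score, accuracy_cumsum = toy_score_cutoff(toy_features, 0.9)
print(accuracy_cumsum)
print(cutoff_score)
```

With these values the cumulative accuracy stays at or above 0.9 for the two best-scoring predictions and drops below it at the third, so the score of that third prediction is reported as the cutoff.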

## Testing Launch section <a name="test_launch"></a>

In [41]:
if 'test' in option:
  worker_test = WorkerTest()
  worker_test.test_accuracy()

  # show the score threshold for 95% accuracy
  accuracy_cutoff = 0.95
  score_cutoff = find_score_cutoff(accuracy_file, accuracy_cutoff)

WorkerTest.__init__()
target_file = ../smbp_data/features_smbp.csv.test
predicted_file = ../smbp_data/features_smbp.csv.test.deepnovo_denovo
predicted_format = deepnovo
accuracy_file = ../smbp_data/features_smbp.csv.test.deepnovo_denovo.accuracy
denovo_only_file = ../smbp_data/features_smbp.csv.test.deepnovo_denovo.denovo_only
scan2fea_file = ../smbp_data/features_smbp.csv.test.deepnovo_denovo.scan2fea
multifea_file = ../smbp_data/features_smbp.csv.test.deepnovo_denovo.multifea
WorkerTest.test_accuracy()
WorkerTest._get_target()
WorkerTest._get_predicted()
target_count_total = 467
target_len_total = 6669
target_count_db = 467
target_len_db = 6669
target_count_db_mass: 467
target_len_db_mass: 6669
predicted_count_mass: 452
predicted_count_mass_db: 452
predicted_len_mass_db: 6450
predicted_only: 0
recall_AA_total = 0.6839
recall_AA_db = 0.6839
recall_AA_db_mass = 0.6839
recall_peptide_total = 0.3340
recall_peptide_db = 0.3340
recall_peptide_db_mass = 0.3340
precision_AA_mass_db  = 0.7071