In [None]:
#| default_exp model.ccs

In [None]:
#| hide
%reload_ext autoreload
%autoreload 2

# CCS model and interface

In [None]:
#| export
import torch
import pandas as pd
import numpy as np

from tqdm import tqdm

from alphabase.peptide.fragment import update_precursor_mz

from alphabase.peptide.mobility import (
    ccs_to_mobility_for_df,
    mobility_to_ccs_for_df
)

from peptdeep.model.featurize import (
    mod_feature_size
)

import peptdeep.model.base as model_base


In [None]:
#| export
class Model_CCS_Bert(torch.nn.Module):
    """
    Transformer-based network that predicts one CCS value per peptide.

    The amino-acid/modification encoding is built with width `hidden-2`;
    in `forward` the charge is appended to every sequence position as two
    extra channels so that the transformer input has width `hidden`.

    Parameters
    ----------
    dropout : float
        Dropout probability shared by the input, hidden and output stages.
    nlayers : int
        Number of layers passed to `Hidden_HFace_Transformer`.
    hidden : int
        Width of the transformer input and of the read-out head.
    output_attentions : bool
        If True, the second element returned by the hidden transformer is
        stored in `self.attentions` after each forward pass.
    **kwargs
        Accepted and ignored.
    """
    def __init__(self,
        dropout = 0.1,
        nlayers = 4,
        hidden = 128,
        output_attentions=False,
        **kwargs,
    ):
        super().__init__()

        self.dropout = torch.nn.Dropout(dropout)

        # Two channels are left free for the charge appended in forward().
        encoding_width = hidden - 2
        self.input_nn = model_base.AATransformerEncoding(encoding_width)

        self._output_attentions = output_attentions

        self.hidden_nn = model_base.Hidden_HFace_Transformer(
            hidden,
            nlayers=nlayers,
            dropout=dropout,
            output_attentions=output_attentions,
        )

        # Read-out head: attention-weighted sum over the sequence, then a
        # single linear unit. The dropout module is the shared one above.
        readout_layers = [
            model_base.SeqAttentionSum(hidden),
            torch.nn.PReLU(),
            self.dropout,
            torch.nn.Linear(hidden, 1),
        ]
        self.output_nn = torch.nn.Sequential(*readout_layers)

    @property
    def output_attentions(self):
        """Whether attentions are kept in `self.attentions` by `forward`."""
        return self._output_attentions

    @output_attentions.setter
    def output_attentions(self, val:bool):
        # Keep the wrapper flag and the hidden transformer's flag in sync.
        self._output_attentions = val
        self.hidden_nn.output_attentions = val

    def forward(self, 
        aa_indices, 
        mod_x,
        charges:torch.Tensor,
    ):
        """
        Predict CCS for a batch.

        `aa_indices` and `mod_x` are forwarded unchanged to the input
        encoder. `charges` is assumed to be shaped (batch, 1) -- TODO
        confirm against callers; it is repeated along the sequence axis
        and duplicated into two channels before concatenation.

        Returns the output of the read-out head with dimension 1 squeezed.
        """
        encoded = self.input_nn(aa_indices, mod_x)
        x = self.dropout(encoded)

        seq_len = x.size(1)
        charge_channels = charges.unsqueeze(1).repeat(1, seq_len, 2)
        x = torch.cat((x, charge_channels), 2)

        hidden_out = self.hidden_nn(x)
        # Only touch the attention output when it was requested.
        self.attentions = hidden_out[1] if self.output_attentions else None

        # Scaled skip connection around the transformer stack.
        x = self.dropout(hidden_out[0] + x*0.2)

        return self.output_nn(x).squeeze(1)

In [None]:
#| hide
# Smoke test: seed the RNG, build a default Model_CCS_Bert and run a single
# forward pass on one sequence of six amino-acid indices with all-zero
# modification features and a charge of 1.
torch.manual_seed(1337)
model = Model_CCS_Bert()
mod_hidden = mod_feature_size
model(
    torch.LongTensor([[1,2,3,4,5,6]]), 
    torch.tensor([[[0.0]*mod_hidden]*6]), 
    torch.tensor([[1]])
)

tensor([0.1206], grad_fn=<SqueezeBackward1>)

In [None]:
#| export
class Model_CCS_LSTM(torch.nn.Module):
    """
    CNN/LSTM-based network that predicts one CCS value per peptide.

    An encoder turns (amino acids, modifications, charge) into a fixed-size
    vector; the charge is then concatenated once more to that vector before
    the linear decoder produces the prediction.

    Parameters
    ----------
    dropout : float
        Dropout probability applied to the encoder output.
    """
    def __init__(self,
        dropout=0.1
    ):
        super().__init__()

        self.dropout = torch.nn.Dropout(dropout)

        encoder_width = 256

        self.ccs_encoder = model_base.Encoder_26AA_Mod_Charge_CNN_LSTM_AttnSum(
            encoder_width
        )

        # One extra input feature: the charge appended in forward().
        self.ccs_decoder = model_base.Decoder_Linear(encoder_width + 1, 1)

    def forward(self, 
        aa_indices, 
        mod_x,
        charges,
    ):
        """
        Predict CCS for a batch.

        `charges` is concatenated to the encoder output along dimension 1,
        so it is assumed to be shaped (batch, 1) -- TODO confirm against
        callers. Returns the decoder output with dimension 1 squeezed.
        """
        encoded = self.dropout(
            self.ccs_encoder(aa_indices, mod_x, charges)
        )
        decoder_input = torch.cat((encoded, charges), 1)
        prediction = self.ccs_decoder(decoder_input)
        return prediction.squeeze(1)

In [None]:
#| export

def ccs_to_mobility_pred_df(
    precursor_df:pd.DataFrame
)->pd.DataFrame:
    """
    Add 'mobility_pred' into precursor_df inplace, converted from the
    'ccs_pred' column.

    If `precursor_df` has no 'precursor_mz' column yet, it is computed
    first with `update_precursor_mz`, because the CCS-to-mobility
    conversion needs the precursor m/z.

    Parameters
    ----------
    precursor_df : pd.DataFrame
        Must contain 'ccs_pred'. When 'precursor_mz' is absent, it must
        also contain whatever columns `update_precursor_mz` requires.

    Returns
    -------
    pd.DataFrame
        The same `precursor_df` object, with 'mobility_pred' (and, if it
        was missing, 'precursor_mz') added.
    """
    # Make sure the m/z needed by the conversion exists; never overwrite
    # a 'precursor_mz' column the caller already provided.
    if 'precursor_mz' not in precursor_df.columns:
        update_precursor_mz(precursor_df)

    precursor_df['mobility_pred'] = ccs_to_mobility_for_df(
        precursor_df, 'ccs_pred'
    )
    return precursor_df

def mobility_to_ccs_df_(
    precursor_df:pd.DataFrame
)->pd.DataFrame:
    """
    Add 'ccs' into precursor_df inplace, converted from the 'mobility'
    column.

    If `precursor_df` has no 'precursor_mz' column yet, it is computed
    first with `update_precursor_mz`, because the mobility-to-CCS
    conversion needs the precursor m/z.

    Parameters
    ----------
    precursor_df : pd.DataFrame
        Must contain 'mobility'. When 'precursor_mz' is absent, it must
        also contain whatever columns `update_precursor_mz` requires.

    Returns
    -------
    pd.DataFrame
        The same `precursor_df` object, with 'ccs' (and, if it was
        missing, 'precursor_mz') added.
    """
    # Make sure the m/z needed by the conversion exists; never overwrite
    # a 'precursor_mz' column the caller already provided.
    if 'precursor_mz' not in precursor_df.columns:
        update_precursor_mz(precursor_df)

    precursor_df['ccs'] = mobility_to_ccs_for_df(
        precursor_df, 'mobility'
    )
    return precursor_df

In [None]:
#| export

class AlphaCCSModel(model_base.ModelInterface):
    """
    `ModelInterface` wrapper around a CCS network
    (`Model_CCS_LSTM` by default, or `Model_CCS_Bert`).

    Parameters
    ----------
    dropout : float
        Forwarded to the model class when the network is built.
    model_class : torch.nn.Module
        Network class to build.
    device : str
        Device name handed to `ModelInterface.__init__`.
    **kwargs
        Extra keyword arguments forwarded to `self.build`.
    """
    def __init__(self, 
        dropout=0.1,
        model_class:torch.nn.Module=Model_CCS_LSTM,
        device:str='gpu',
        **kwargs,
    ):
        super().__init__(device=device)
        self.build(model_class, dropout=dropout, **kwargs)

        # Charges are multiplied by this factor before entering the network.
        self.charge_factor = 0.1

        # Column written on prediction / column read as training target.
        self.target_column_to_predict = 'ccs_pred'
        self.target_column_to_train = 'ccs'

    def _get_features_from_batch_df(self, 
        batch_df: pd.DataFrame,
    ):
        """
        Build the network inputs for one batch.

        Returns a tuple `(aa_indices, mod_x, charges)` where `charges` is
        the 'charge' column as a column tensor scaled by
        `self.charge_factor`.
        """
        aa_indices = self._get_26aa_indice_features(batch_df)
        mod_x = self._get_mod_features(batch_df)

        charge_column = self._as_tensor(
            batch_df['charge'].values
        ).unsqueeze(1)
        scaled_charges = charge_column*self.charge_factor

        return aa_indices, mod_x, scaled_charges

    def ccs_to_mobility_pred(self,
        precursor_df:pd.DataFrame
    )->pd.DataFrame:
        """Delegate to `ccs_to_mobility_pred_df` and return its result."""
        converted_df = ccs_to_mobility_pred_df(precursor_df)
        return converted_df

In [None]:
#| hide
# Smoke test: seed the RNG, build a default AlphaCCSModel, force it onto the
# CPU, then run the wrapped network once on a single sequence of six
# amino-acid indices with all-zero modification features and a charge of 1.
torch.manual_seed(1337)
model = AlphaCCSModel()
model.device = torch.device('cpu')
model.model.to(model.device)
mod_hidden = mod_feature_size
model.model(
    torch.LongTensor([[1,2,3,4,5,6]]), 
    torch.tensor([[[0.0]*mod_hidden]*6]), 
    torch.tensor([[1]])
)

tensor([0.0029], grad_fn=<SqueezeBackward1>)

In [None]:
# Presumably the total parameter count of the wrapped network -- see the
# recorded integer output below.
model.get_parameter_num()

713452

In [None]:
# Display the module structure of the underlying torch network.
model.model

Model_CCS_LSTM(
  (dropout): Dropout(p=0.1, inplace=False)
  (ccs_encoder): Encoder_AA_Mod_Charge_CNN_LSTM_AttnSum(
    (mod_nn): Mod_Embedding_FixFirstK(
      (nn): Linear(in_features=103, out_features=2, bias=False)
    )
    (input_cnn): SeqCNN(
      (cnn_short): Conv1d(36, 36, kernel_size=(3,), stride=(1,), padding=(1,))
      (cnn_medium): Conv1d(36, 36, kernel_size=(5,), stride=(1,), padding=(2,))
      (cnn_long): Conv1d(36, 36, kernel_size=(7,), stride=(1,), padding=(3,))
    )
    (hidden_nn): SeqLSTM(
      (rnn): LSTM(144, 128, num_layers=2, batch_first=True, bidirectional=True)
    )
    (attn_sum): SeqAttentionSum(
      (attn): Sequential(
        (0): Linear(in_features=256, out_features=1, bias=False)
        (1): Softmax(dim=1)
      )
    )
  )
  (ccs_decoder): Decoder_Linear(
    (nn): Sequential(
      (0): Linear(in_features=257, out_features=64, bias=True)
      (1): PReLU(num_parameters=1)
      (2): Linear(in_features=64, out_features=1, bias=True)
    )
  )
)

In [None]:
# Toy input: the same modified peptide repeated `repeat` times, each with
# charge 2 and a dummy 'ccs' target of 1.
repeat = 10
precursor_df = pd.DataFrame({
    'sequence': ['AGHCEWQMKYR']*repeat,
    'mods': ['Acetyl@Protein N-term;Carbamidomethyl@C;Oxidation@M']*repeat,
    'mod_sites': ['0;4;8']*repeat,
    'nAA': [11]*repeat,
    'charge': [2]*repeat,
    'ccs': [1]*repeat
})
precursor_df

Unnamed: 0,sequence,mods,mod_sites,nAA,charge,ccs
0,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
1,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
2,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
3,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
4,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
5,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
6,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
7,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
8,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
9,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1


In [None]:
# Fit the model on the toy precursor_df for 5 epochs.
model.train(precursor_df, epoch=5)

In [None]:
# Run prediction; the recorded output shows a new 'ccs_pred' column.
model.predict(precursor_df)

Unnamed: 0,sequence,mods,mod_sites,nAA,charge,ccs,ccs_pred
0,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652
1,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652
2,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652
3,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652
4,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652
5,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652
6,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652
7,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652
8,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652
9,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652


In [None]:
# Convert predicted CCS to ion mobility; the recorded output shows
# 'precursor_mz' and 'mobility_pred' columns in the result.
model.ccs_to_mobility_pred(precursor_df)

Unnamed: 0,sequence,mods,mod_sites,nAA,charge,ccs,ccs_pred,precursor_mz,mobility_pred
0,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652,762.329553,8.1e-05
1,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652,762.329553,8.1e-05
2,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652,762.329553,8.1e-05
3,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652,762.329553,8.1e-05
4,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652,762.329553,8.1e-05
5,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652,762.329553,8.1e-05
6,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652,762.329553,8.1e-05
7,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652,762.329553,8.1e-05
8,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652,762.329553,8.1e-05
9,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,0.032652,762.329553,8.1e-05
