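Affinity prediction  
Affinity describes how strongly an antibody binds its antigen: the higher the affinity, the stronger the binding.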
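Quantitatively, affinity is commonly reported as the equilibrium dissociation constant $K_D$ of the antibody-antigen complex. This is standard background and is not stated in the original notebook, so it may not match the exact quantity stored in a given training file:

$$
\mathrm{Ab} + \mathrm{Ag} \rightleftharpoons \mathrm{AbAg}, \qquad
K_D = \frac{[\mathrm{Ab}]\,[\mathrm{Ag}]}{[\mathrm{AbAg}]} = \frac{k_{\text{off}}}{k_{\text{on}}}
$$

A smaller $K_D$ means more of the complex stays bound at equilibrium, i.e. higher affinity. For example, an antibody with $K_D = 1\,\text{nM}$ binds more tightly than one with $K_D = 100\,\text{nM}$.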

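This section predicts affinity from input files in CSV or FASTA format.  
The file used for model training has the columns below (shown for the antibody heavy chain).  
For the antigen, only the corresponding antigen sequence needs to be supplied.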

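| Column   | Description                                                                                   |
| -------- | --------------------------------------------------------------------------------------------- |
| heavy    | Full amino-acid sequence of the heavy chain.                                                  |
| cdr1     | Amino-acid sequence of heavy-chain complementarity-determining region 1 (CDR1).               |
| cdr2     | Amino-acid sequence of heavy-chain complementarity-determining region 2 (CDR2).               |
| cdr3     | Amino-acid sequence of heavy-chain complementarity-determining region 3 (CDR3).               |
| affinity | Antibody-antigen binding affinity, usually given as a concentration in units such as pM or nM (for a dissociation constant $K_D$, a smaller value means stronger binding). |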
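The following is a minimal sketch of loading a training CSV with the columns above using Python's standard `csv` module. It assumes a header row named exactly `heavy,cdr1,cdr2,cdr3,affinity`. The helpers `load_training_rows` and `to_pkd` and the `UNIT_TO_MOLAR` table are hypothetical names introduced here, not part of the original code. `to_pkd` further assumes the affinity column holds a dissociation constant in the unit you pass in, and converts it to $pK_D = -\log_{10}(K_D\ \text{in M})$ so that larger values mean stronger binding.

```python
import csv
import io
import math

REQUIRED_COLUMNS = ("heavy", "cdr1", "cdr2", "cdr3", "affinity")

# Hypothetical multipliers converting a concentration in the given unit to molar (M).
UNIT_TO_MOLAR = {"M": 1.0, "mM": 1e-3, "uM": 1e-6, "nM": 1e-9, "pM": 1e-12}


def load_training_rows(handle):
    """Read rows from a CSV file object, checking that every required column exists."""
    reader = csv.DictReader(handle)
    missing = [c for c in REQUIRED_COLUMNS if c not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    rows = []
    for row in reader:
        rows.append({
            "heavy": row["heavy"].strip(),
            "cdr1": row["cdr1"].strip(),
            "cdr2": row["cdr2"].strip(),
            "cdr3": row["cdr3"].strip(),
            "affinity": float(row["affinity"]),
        })
    return rows


def to_pkd(kd_value, unit="nM"):
    """Convert a dissociation constant to pK_D = -log10(K_D in M); higher means tighter binding."""
    return -math.log10(kd_value * UNIT_TO_MOLAR[unit])


if __name__ == "__main__":
    # Made-up placeholder sequences, for illustration only.
    sample = io.StringIO(
        "heavy,cdr1,cdr2,cdr3,affinity\n"
        "EVQLVESGGG,GFTFSSYA,ISGSGGST,AKDRGYSSGW,2.5\n"
    )
    rows = load_training_rows(sample)
    print(rows[0]["cdr3"], to_pkd(rows[0]["affinity"], "nM"))
```

Working on a log scale such as $pK_D$ is one common choice for regression targets because raw $K_D$ values can span several orders of magnitude; whether the original model does this is not stated.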


```python
import torch
import torch.nn as nn
from transformers import BertModel

# BaseModel comes from this project's local `base` package and must subclass nn.Module;
# fall back to plain nn.Module when that package is not on the path.
try:
    from base import BaseModel
except ImportError:
    BaseModel = nn.Module


class BERTBinding(BaseModel):
    """Predict a binding-affinity score from a BERT encoding of the antibody (receptor) sequence."""

    def __init__(self, ReceptorBert_dir, emb_dim, dropout):
        super().__init__()
        # Pretrained BERT encoder; emb_dim must equal its hidden_size.
        self.ReceptorBert = BertModel.from_pretrained(ReceptorBert_dir)
        # Regression head: Linear -> Tanh -> Dropout -> Linear, one score per sequence.
        self.binding_predict = nn.Sequential(
            nn.Linear(in_features=emb_dim, out_features=emb_dim),
            nn.Tanh(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=emb_dim, out_features=1),
        )
        # Deeper alternative head (hard-coded for emb_dim == 32), kept for reference:
        # self.binding_predict = nn.Sequential(
        #     nn.Linear(in_features=32, out_features=32*2),
        #     nn.Tanh(),
        #     nn.Linear(in_features=32*2, out_features=32*4),
        #     nn.Tanh(),
        #     nn.Linear(in_features=32*4, out_features=32*2),
        #     nn.Tanh(),
        #     nn.Linear(in_features=32*2, out_features=32),
        #     nn.Tanh(),
        #     nn.Linear(in_features=32, out_features=1)
        # )

    def forward(self, epitope, receptor):
        # `receptor` is a dict of tokenizer outputs (input_ids, attention_mask, ...).
        # `epitope` is accepted but unused: the score depends on the receptor only.
        # shape: [batch_size, seq_length, emb_dim]
        receptor_encoded = self.ReceptorBert(**receptor).last_hidden_state

        # Use the [CLS] token embedding as the sequence representation, as in the
        # Hugging Face NextSentencePrediction head:
        # https://github.com/huggingface/transformers/issues/7540
        # https://huggingface.co/transformers/v2.0.0/_modules/transformers/modeling_bert.html
        # shape: [batch_size, emb_dim]
        receptor_cls = receptor_encoded[:, 0, :]
        # receptor_cls = torch.squeeze(torch.sum(receptor_encoded, dim=1))
        # shape: [batch_size, 1]
        output = self.binding_predict(receptor_cls)
        # output = self.binding_predict(receptor['input_ids'].type(torch.float))

        return output
```
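The following sketch shows what `forward` expects and how the CLS-pooled head could be trained. Instead of loading `ReceptorBert_dir`, it builds a tiny, randomly initialised BERT from `BertConfig`, so nothing is downloaded, and then runs one optimisation step. The original notebook does not show its training loop. The mean-squared-error loss, the AdamW optimiser, the model sizes, the dropout rate, and the random labels are all assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel

torch.manual_seed(0)

# Tiny randomly initialised BERT (no download); all sizes are illustrative assumptions.
config = BertConfig(vocab_size=25, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64)
encoder = BertModel(config)

# Same head layout as BERTBinding.binding_predict, with emb_dim equal to hidden_size.
emb_dim = config.hidden_size
head = nn.Sequential(
    nn.Linear(emb_dim, emb_dim), nn.Tanh(), nn.Dropout(p=0.1), nn.Linear(emb_dim, 1)
)

# A fake tokenised batch shaped like the `receptor` dict that forward() unpacks.
batch_size, seq_len = 4, 12
receptor = {
    "input_ids": torch.randint(0, config.vocab_size, (batch_size, seq_len)),
    "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
}
targets = torch.randn(batch_size, 1)  # placeholder affinity labels

params = list(encoder.parameters()) + list(head.parameters())
optimizer = torch.optim.AdamW(params, lr=1e-4)
loss_fn = nn.MSELoss()  # assumed regression objective

encoder.train()
head.train()
cls = encoder(**receptor).last_hidden_state[:, 0, :]  # [batch_size, emb_dim]
pred = head(cls)                                       # [batch_size, 1]
loss = loss_fn(pred, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(tuple(pred.shape), loss.item())
```

With a real checkpoint, `receptor` would come from the matching tokenizer (for example `tokenizer(seqs, padding=True, return_tensors="pt")`). `emb_dim` passed to `BERTBinding` must equal the checkpoint's `hidden_size`.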

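There is also a model trained on ACE/SPR (surface plasmon resonance) data. It requires the heavy-chain, light-chain, and antigen sequences as input, and input files in either FASTA or plain-text (txt) format should work.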
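The sketch below reads sequences from either format using only the standard library. `read_sequences` is a hypothetical helper, not part of the original code. It assumes that a FASTA file starts with a `>` header line and that a plain-text file holds one sequence per non-empty line. In the usage example, the record names `heavy`, `light`, and `antigen` are an assumed naming convention, and the sequences are made-up fragments.

```python
def read_sequences(text):
    """Parse FASTA or plain-text input into a list of (name, sequence) pairs.

    FASTA: records start with '>' header lines; following lines are concatenated.
    Plain text: each non-empty line is one sequence, named seq1, seq2, ...
    Sequences are upper-cased; blank lines are ignored.
    """
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    if not lines:
        return []
    if not lines[0].startswith(">"):
        return [(f"seq{i}", ln.upper()) for i, ln in enumerate(lines, start=1)]
    records = []
    name, chunks = None, []
    for ln in lines:
        if ln.startswith(">"):
            if name is not None:
                records.append((name, "".join(chunks).upper()))
            name, chunks = ln[1:].strip(), []
        else:
            chunks.append(ln)
    records.append((name, "".join(chunks).upper()))
    return records


if __name__ == "__main__":
    example = """>heavy
EVQLVESGGGLVQPGG
>light
DIQMTQSPSSLSASVG
>antigen
MKTIIALSYIFCLVFA
"""
    chains = dict(read_sequences(example))
    print(chains["heavy"], chains["light"], chains["antigen"])
```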

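```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Model ID as given in the original notebook (not verified here).
tokenizer = AutoTokenizer.from_pretrained("Absci/affinity_predictor")
affinity_model = AutoModelForSequenceClassification.from_pretrained("Absci/affinity_predictor").eval()

@torch.no_grad()
def predict_affinity(heavy_seq, light_seq, antigen_seq):
    # A tokenizer call encodes at most a sequence pair, so the three chains are joined with
    # the separator token (an assumed input format; check the model card).
    text = f" {tokenizer.sep_token} ".join([heavy_seq, light_seq, antigen_seq])
    inputs = tokenizer(text, return_tensors="pt")
    return affinity_model(**inputs).logits
```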