## **ESM-2 representation**

#### This notebook uses the ESM-2 model to represent antibody and antigen sequences, then trains a Bi-LSTM model to learn the relationship between these sequence representations and the binding interface.

#### **1. Loading the training dataset**

In [1]:
import pandas as pd

In [2]:
train_data_path = "./data/train.csv"
train_data = pd.read_csv(train_data_path,sep=',')

#### **2. Representing sequence features with ESM-2**

In [4]:
import torch
import esm

In [5]:
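# Load the ESM-2 model (33 layers, 650M parameters) and its alphabet/tokenizer
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # inference mode: disables dropout for deterministic results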

In [30]:
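# 2. Build (label, sequence) pairs; label = <column 0>_<chain column>, e.g. '1afv_A'
# antibody: columns 1 (chain) and 2 (sequence); antigen: columns 3 (chain) and 4 (sequence)
data_train_ab = list(zip(train_data.iloc[:, 0].astype(str) + '_' + train_data.iloc[:, 1].astype(str), train_data.iloc[:, 2].astype(str)))
data_train_ag = list(zip(train_data.iloc[:, 0].astype(str) + '_' + train_data.iloc[:, 3].astype(str), train_data.iloc[:, 4].astype(str)))

# data_train = data_train_ab + data_train_ag
# this cell embeds only the antigen sequences; antibodies are embedded in a later cell
data_train = data_train_ag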

print(type(data_train))
print(len(data_train))
print(data_train[0])

df = pd.DataFrame(data_train, columns=['protein_name', 'seq'])
df.to_csv('/data/databases/epitope_prediction/sab/esm-rep-data_ag.csv', index=False)

<class 'list'>
124
('1afv_A', 'PIVQNLQGQMVHQAISPRTLNAWVKVVEEKAFSPEVIPMFSALSEGATPQDLNTMLNTVGGHQAAMQMLKETINEEAAEWDRLHPVHAGPIAPGQMREPRGSDIAGTTSTLQEQIGWMTHNPPIPVGEIYKRWIILGLNKIVRMYSPTSIL')
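The pair-building cell above depends on column positions only. Judging from the printed example, column 0 appears to hold the PDB ID, columns 1/3 the antibody/antigen chain IDs and columns 2/4 the sequences; that layout is an assumption here. A minimal sketch that factors the two `zip` lines into one helper and checks it on a toy DataFrame (the helper `make_pairs`, the column names and the rows are all hypothetical):

```python
import pandas as pd

# Hypothetical toy rows mimicking the assumed train.csv layout
# (column names are illustrative only; the notebook indexes by position)
toy = pd.DataFrame(
    [["pdb1", "H", "EVQLV", "A", "MKTAY"],
     ["pdb2", "L", "DIQMT", "B", "GSHMA"]],
    columns=["pdb_id", "ab_chain", "ab_seq", "ag_chain", "ag_seq"],
)

def make_pairs(df, chain_col, seq_col):
    """Return (label, sequence) tuples, label = '<column 0>_<chain>', the input format of ESM's batch converter."""
    labels = df.iloc[:, 0].astype(str) + "_" + df.iloc[:, chain_col].astype(str)
    return list(zip(labels, df.iloc[:, seq_col].astype(str)))

ab_pairs = make_pairs(toy, 1, 2)  # antibody chains
ag_pairs = make_pairs(toy, 3, 4)  # antigen chains
print(ag_pairs)  # [('pdb1_A', 'MKTAY'), ('pdb2_B', 'GSHMA')]
```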


In [31]:
# Extract per-residue representations for each sequence
representations_list = []
for i in range(len(data_train)):
    data = [data_train[i]]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

    # Extract per-residue representations (on CPU); contact maps are not needed
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=False)
    token_representations = results["representations"][33]
    # drop the BOS (first) and EOS (last) tokens, keeping one vector per residue
    representations_list.append(token_representations[0, 1 : batch_lens[0] - 1])
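The loop above runs one sequence per forward pass, so `[1:-1]` is enough to remove BOS and EOS. If several sequences are batched instead, `batch_converter` pads them to a common length and `[1:-1]` would keep padding positions for the shorter ones; the per-sequence count in `batch_lens` should bound the slice. A sketch of that slicing step, assuming each row is laid out as `[BOS, residues..., EOS, PAD...]`; the function name `strip_special_tokens` is hypothetical and a random tensor stands in for `results["representations"][33]`:

```python
import torch

def strip_special_tokens(token_reps, batch_lens):
    """Split a padded (B, T, D) ESM output into per-sequence (L_i, D) tensors.

    Assumes batch_lens[i] (number of non-padding tokens) equals L_i + 2,
    i.e. residues plus one BOS and one EOS token.
    """
    return [token_reps[i, 1 : int(n) - 1] for i, n in enumerate(batch_lens)]

# Stand-in for results["representations"][33]: 2 sequences padded to T = 7 tokens, D = 4
fake_reps = torch.randn(2, 7, 4)
fake_lens = torch.tensor([7, 5])  # residue lengths 5 and 3, plus BOS/EOS
per_residue = strip_special_tokens(fake_reps, fake_lens)
print([t.shape for t in per_residue])  # [torch.Size([5, 4]), torch.Size([3, 4])]

# With the real model (not run here), a batch of 8 would look like:
# batch_labels, batch_strs, batch_tokens = batch_converter(data_train[i:i + 8])
# batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
# with torch.no_grad():
#     out = model(batch_tokens, repr_layers=[33])
# representations_list.extend(strip_special_tokens(out["representations"][33], batch_lens))
```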

In [32]:
# Zero-pad every sequence to the longest one so all embeddings can be stacked
# into a single (N, max_len, 1280) tensor, then save it
max_cols = max([tensor.size(0) for tensor in representations_list])
print('the max length of protein seq is: ', max_cols)
for i, tensor in enumerate(representations_list):
    if tensor.size(0) < max_cols:
        padding = torch.zeros(max_cols - tensor.size(0), tensor.size(1))
        representations_list[i] = torch.cat([tensor, padding], dim=0)

representations_tensor = torch.stack(representations_list)

torch.save(representations_tensor, './data/traindata_esm_ag.pt')
print('successfully save the representation data of antigen sequences!')

the max length of protein seq is:  624
successfully save the representation data of antigen sequences!
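Zero-padding to 624 residues makes the saved antigen tensor fairly large (esm2_t33_650M_UR50D produces 1280-dimensional float32 embeddings, so 124 × 624 × 1280 × 4 bytes ≈ 396 MB) and discards each sequence's true length, which the downstream Bi-LSTM needs in order to ignore padding. A minimal sketch that pads with `torch.nn.utils.rnn.pad_sequence` (same zero-padding result as the manual loop) and also returns the lengths and a residue mask; the helper name `pad_with_lengths` and the dict save format are assumptions, not what the notebook does:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_with_lengths(tensors):
    """Zero-pad a list of (L_i, D) tensors to (N, max_L, D) and keep the original lengths."""
    lengths = torch.tensor([t.size(0) for t in tensors], dtype=torch.long)
    padded = pad_sequence(tensors, batch_first=True, padding_value=0.0)
    # Boolean mask: True where a position holds a real residue, False for padding
    mask = torch.arange(padded.size(1)).unsqueeze(0) < lengths.unsqueeze(1)
    return padded, lengths, mask

toy = [torch.ones(3, 2), torch.ones(5, 2), torch.ones(1, 2)]
padded, lengths, mask = pad_with_lengths(toy)
print(padded.shape, lengths.tolist())  # torch.Size([3, 5, 2]) [3, 5, 1]

# Hypothetical save format that keeps the lengths next to the embeddings:
# torch.save({"reps": padded, "lengths": lengths}, "./data/traindata_esm_ag_with_lengths.pt")
```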


In [27]:
# Extract per-residue representations for each antibody sequence
representations_list = []
for i in range(len(data_train_ab)):
    data = [data_train_ab[i]]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

    # Extract per-residue representations (on CPU); contact maps are not needed
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=False)
    token_representations = results["representations"][33]
    # drop the BOS (first) and EOS (last) tokens, keeping one vector per residue
    representations_list.append(token_representations[0, 1 : batch_lens[0] - 1])

In [29]:
# Zero-pad every sequence to the longest one so all embeddings can be stacked
# into a single (N, max_len, 1280) tensor, then save it
max_cols = max([tensor.size(0) for tensor in representations_list])
print('the max length of protein seq is: ', max_cols)
for i, tensor in enumerate(representations_list):
    if tensor.size(0) < max_cols:
        padding = torch.zeros(max_cols - tensor.size(0), tensor.size(1))
        representations_list[i] = torch.cat([tensor, padding], dim=0)

representations_tensor = torch.stack(representations_list)

torch.save(representations_tensor, './data/traindata_esm_ab.pt')
print('successfully save the representation data of antibody sequences!')

the max length of protein seq is:  251
successfully save the representation data of antibody sequences!
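The saved tensors are meant to feed the Bi-LSTM named at the top of this notebook. A minimal sketch of a per-residue Bi-LSTM tagger over such embeddings, assuming one binary interface label per residue and known sequence lengths (the class name `BiLSTMInterfaceTagger`, the hidden size and the label format are hypothetical; the notebook does not define them). `pack_padded_sequence` keeps the LSTM from reading the zero padding, and the loss is masked to real residues:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class BiLSTMInterfaceTagger(nn.Module):
    """Hypothetical per-residue classifier over ESM-2 embeddings (D = 1280 for esm2_t33_650M_UR50D)."""

    def __init__(self, embed_dim=1280, hidden_dim=128):
        super().__init__()
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.head = nn.Linear(2 * hidden_dim, 1)  # one interface logit per residue

    def forward(self, reps, lengths):
        # pack so the LSTM skips padded positions; lengths must be on the CPU
        packed = pack_padded_sequence(reps, lengths.cpu(), batch_first=True, enforce_sorted=False)
        out, _ = self.lstm(packed)
        out, _ = pad_packed_sequence(out, batch_first=True, total_length=reps.size(1))
        return self.head(out).squeeze(-1)  # (N, max_L) logits

# Toy batch: 2 sequences of lengths 6 and 3 padded to 6, with a small D for speed
reps = torch.randn(2, 6, 16)
lengths = torch.tensor([6, 3])
labels = torch.randint(0, 2, (2, 6)).float()  # random stand-in interface labels
mask = torch.arange(6).unsqueeze(0) < lengths.unsqueeze(1)

tagger = BiLSTMInterfaceTagger(embed_dim=16, hidden_dim=8)
logits = tagger(reps, lengths)
# Masked binary cross-entropy: padded positions do not contribute to the loss
loss = nn.functional.binary_cross_entropy_with_logits(logits[mask], labels[mask])
loss.backward()
```

With the real data, `reps` would be `torch.load('./data/traindata_esm_ag.pt')` (or the antibody file) and `lengths` would have to be recorded when padding, since the saved tensors do not store them.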
