# ESM

In [152]:
import torch
import esm

# Load the pretrained ESM-2 checkpoint `esm2_t33_650M_UR50D` together with its alphabet
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

# Prepare data: the first two sequences are from ESMStructuralSplitDataset superfamily / 4;
# the third entry is a short example containing a <mask> token.
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
# Number of tokens in each row that differ from the padding index.
# NOTE(review): `batch_lens` is not used by any cell shown in this notebook.
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    # repr_layers=[33] asks for the layer-33 output; contacts are requested too,
    # although only the representations are read below.
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
# Layer-33 representations for the whole batch (shape printed two cells below)
token_representations = results["representations"][33]


In [153]:
# Residue count of protein2 (the longest entry in `data`), for comparison with the
# position axis of `token_representations` printed in the next cell.
print(len("KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"))


71


In [154]:
# Shape of the layer-33 representations. The second axis (73) is two longer than the
# 71-residue protein2 — presumably special tokens added by the batch converter; TODO confirm.
print(token_representations.size())


torch.Size([3, 73, 1280])


# AntiBERTy

In [157]:
from antiberty import AntiBERTyRunner

# Instantiate the AntiBERTy runner with its default settings
antiberty = AntiBERTyRunner()

# Two example sequences to embed
sequences = [
    "EVQLVQSGPEVKKPGTSVKVSCKASGFTFMSSAVQWVRQARGQRLEWIGWIVIGSGNTNYAQKFQERVTITRDMSTSTAYMELSSLRSEDTAVYYCAAPYCSSISCNDGFDIWGQGTMVTVS",
    "DVVMTQTPFSLPVSLGDQASISCRSSQSLVHSNGNTYLHWYLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYFCSQSTHVPYTFGGGTKLEIK",
]
# `embeddings` is indexed per input sequence (see `embeddings[1]` in the next cell)
embeddings = antiberty.embed(sequences)


In [162]:
# Embedding shape for the second sequence: 114 positions for a 112-residue input
# (length printed in the next cell) — presumably two special tokens; TODO confirm.
print(embeddings[1].size())


torch.Size([114, 512])


In [161]:
# Residue count of the second AntiBERTy example sequence, to compare with its embedding length
print(len("DVVMTQTPFSLPVSLGDQASISCRSSQSLVHSNGNTYLHWYLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYFCSQSTHVPYTFGGGTKLEIK"))


112


# ProtTrans (ProtT5)

In [155]:
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re

# Everything in this cell runs on the CPU
device = torch.device('cpu')

# Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)

# Load the model
# NOTE(review): this rebinds the global name `model`, which the ESM cell above also assigns;
# re-running the ESM forward pass after this cell would call the T5 encoder instead.
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(device)

# only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower)
model.to(torch.float32)

# prepare your protein sequences as a list
sequence_examples = ["PRTEINO", "SEQWENCE"]

# replace all rare/ambiguous amino acids by X and introduce white-space between all amino acids
sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples]

# tokenize sequences and pad up to the longest sequence in the batch
ids = tokenizer(sequence_examples, add_special_tokens=True, padding="longest")

# the tokenizer returned plain Python lists here, so convert them to tensors on the target device
input_ids = torch.tensor(ids['input_ids']).to(device)
attention_mask = torch.tensor(ids['attention_mask']).to(device)

# generate embeddings
with torch.no_grad():
    embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask)

# extract residue embeddings for the first ([0,:]) sequence in the batch and remove padded & special tokens ([0,:7])
emb_0 = embedding_repr.last_hidden_state[0,:7] # shape (7 x 1024)
# same for the second ([1,:]) sequence but taking into account different sequence lengths ([1,:8])
emb_1 = embedding_repr.last_hidden_state[1,:8] # shape (8 x 1024)

# if you want to derive a single representation (per-protein embedding) for the whole protein
# (mean over the residue axis)
emb_0_per_protein = emb_0.mean(dim=0) # shape (1024)


In [156]:
# Per-residue embedding of the first example after slicing off padding/special tokens
print(emb_0.size())


torch.Size([7, 1024])


# Creating a dataset

In [147]:
from pathlib import Path
from typing import Dict
import json
import ablang2
import numpy as np
import torch
import typer
from transformers import BertModel, BertTokenizer, T5EncoderModel, T5Tokenizer
import torch.nn.functional as F
import re
from antiberty import AntiBERTyRunner



# Typer application object.
# NOTE(review): no cell shown in this notebook registers a command on `app` or invokes it —
# presumably a leftover from a script version; confirm before removing.
app = typer.Typer(add_completion=False)
def _check_token_length(n_positions, max_length, model_name):
    """Raise ValueError when a model emits more positions than the padded length."""
    if n_positions > max_length:
        raise ValueError(
            f"{model_name} produced {n_positions} positions for at least one sequence, "
            f"which exceeds max_length={max_length}; increase max_length or shorten the input."
        )


def _pad_batch(embeddings, max_length, model_name):
    """Zero-pad axis 1 of a (batch, positions, features) tensor up to max_length."""
    # Without this check a negative pad would make F.pad silently crop the tensor.
    _check_token_length(embeddings.size(1), max_length, model_name)
    # F.pad lists the last axis first: (0, 0) leaves the feature axis untouched and
    # (0, n) appends n zero rows at the end of the position axis.
    return F.pad(embeddings, (0, 0, 0, max_length - embeddings.size(1)), mode='constant', value=0)


def _pad_and_stack(per_sequence_embeddings, max_length, model_name):
    """Zero-pad each (positions, features) array to max_length rows and stack into one tensor."""
    padded = []
    for each in per_sequence_embeddings:
        if isinstance(each, torch.Tensor):
            # np.pad cannot consume tensors that live on a GPU or track gradients.
            each = each.detach().cpu().numpy()
        _check_token_length(each.shape[0], max_length, model_name)
        padded.append(np.pad(each, ((0, max_length - each.shape[0]), (0, 0)), 'constant'))
    return torch.Tensor(np.stack(padded))


def _embed_paired_with_hf(tokeniser, model, paired_sequences, max_length, model_name):
    """Run a Hugging Face encoder on paired sequences and return its last hidden state."""
    tokens = tokeniser(
        paired_sequences,
        add_special_tokens=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
        return_special_tokens_mask=True,
    )
    # padding="max_length" never truncates, so an over-long input yields a wider tensor.
    _check_token_length(tokens["input_ids"].size(1), max_length, model_name)
    with torch.no_grad():
        output = model(
            input_ids=tokens["input_ids"], attention_mask=tokens["attention_mask"]
        )
    return output.last_hidden_state


def _embed_esm2(sequences, max_length):
    """Return zero-padded layer-33 ESM-2 representations for the given sequences."""
    # Imported here because the import list of this cell does not bring `esm` into scope;
    # without it the function only works if another cell happened to import esm first.
    import esm

    esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    esm_batch_converter = esm_alphabet.get_batch_converter()
    esm_model.eval()

    _, _, esm_batch_tokens = esm_batch_converter([("ab", seq) for seq in sequences])
    _check_token_length(esm_batch_tokens.size(1), max_length, "ESM-2")
    with torch.no_grad():
        # Only the representations are needed, so contact prediction is not requested.
        esm_results = esm_model(esm_batch_tokens, repr_layers=[33])
    return _pad_batch(esm_results["representations"][33], max_length, "ESM-2")


def _embed_prot_t5(sequences, max_length):
    """Return zero-padded ProtT5 encoder states for the given sequences (CPU, float32)."""
    device = torch.device('cpu')

    prot_t5_tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False)
    prot_t5_model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(device)
    # The checkpoint is stored in half precision; cast to float32 for CPU inference.
    prot_t5_model.to(torch.float32)

    # Map rare/ambiguous residues to X and separate residues with spaces.
    spaced = [" ".join(list(re.sub(r"[UZOB]", "X", seq))) for seq in sequences]
    prot_t5_ids = prot_t5_tokenizer(spaced, add_special_tokens=True, padding="longest", return_tensors="pt")
    _check_token_length(prot_t5_ids['input_ids'].size(1), max_length, "ProtT5")

    with torch.no_grad():
        prot_t5_output = prot_t5_model(
            input_ids=prot_t5_ids['input_ids'].to(device),
            attention_mask=prot_t5_ids['attention_mask'].to(device),
        )
    return _pad_batch(prot_t5_output.last_hidden_state, max_length, "ProtT5")


def create_embeddings_3(
        dataset_dict: Dict,
        save_path: Path = Path("/home/gathenes/paratope_model/test/test3/results"),
        max_length: int = 280,
    ):
    """Create per-position embeddings from six protein language models.

    The heavy and light sequence of every entry are embedded with IgBert, IgT5,
    AbLang2, ESM-2, AntiBERTy and ProtT5 (in that order). Every result is brought
    to ``max_length`` positions and the six are concatenated on the feature axis.

    NOTE(review): the models do not share a token layout (IgBert/IgT5 receive a
    "[SEP]"-joined pair, ESM-2/AntiBERTy/ProtT5 receive the plain concatenation,
    and the demo cells show a different number of extra positions per model), so
    position ``i`` presumably does not refer to the same residue in every slice of
    the concatenated tensor — confirm before using it for per-residue labels.

    Args:
        dataset_dict (Dict): Mapping from index to a dict holding the heavy
            sequence under "H_id sequence" and the light one under "L_id sequence".
        save_path (Path): Currently unused — nothing is written to disk. Kept so
            that existing callers keep working.
        max_length (int): Number of positions every embedding is padded to.

    Returns:
        Tuple of seven tensors shaped (n_sequences, max_length, features): the
        IgBert, IgT5, AbLang2, ESM-2, AntiBERTy and ProtT5 embeddings, followed by
        their concatenation along the feature axis.

    Raises:
        ValueError: If ``dataset_dict`` is empty, if a heavy+light pair has more
            residues than ``max_length``, or if any model emits more than
            ``max_length`` positions.
    """
    print("CREATING EMBEDDINGS")
    if not dataset_dict:
        raise ValueError("dataset_dict is empty: there are no sequences to embed.")

    sequence_heavy_emb = [dataset_dict[index]["H_id sequence"] for index in dataset_dict]
    sequence_light_emb = [dataset_dict[index]["L_id sequence"] for index in dataset_dict]

    # Fail before any model is loaded: a pair with more residues than max_length can
    # never fit, whatever special tokens the individual models add on top.
    for index, seq_heavy, seq_light in zip(dataset_dict, sequence_heavy_emb, sequence_light_emb):
        n_residues = len(seq_heavy) + len(seq_light)
        if n_residues > max_length:
            raise ValueError(
                f"Sequence pair {index!r} has {n_residues} residues, "
                f"more than max_length={max_length}."
            )

    # Space-separated residues with an explicit separator (IgBert / IgT5 input).
    paired_sequences = [
        " ".join(seq_heavy) + " [SEP] " + " ".join(seq_light)
        for seq_heavy, seq_light in zip(sequence_heavy_emb, sequence_light_emb)
    ]
    # Heavy directly followed by light (ESM-2 / AntiBERTy / ProtT5 input).
    concatenated_sequences = [
        "".join(seq_heavy) + "".join(seq_light)
        for seq_heavy, seq_light in zip(sequence_heavy_emb, sequence_light_emb)
    ]

    ########################################################
    ######################## BERT ##########################
    ########################################################
    # Models are created inside the call expressions / helpers so that each one can be
    # released as soon as its embeddings exist, instead of all six staying alive.
    bert_residue_embeddings = _embed_paired_with_hf(
        BertTokenizer.from_pretrained("Exscientia/IgBert", do_lower_case=False),
        BertModel.from_pretrained("Exscientia/IgBert", add_pooling_layer=False),
        paired_sequences,
        max_length,
        "IgBert",
    )

    ########################################################
    ###################### IGT5 ############################
    ########################################################
    igt5_residue_embeddings = _embed_paired_with_hf(
        T5Tokenizer.from_pretrained("Exscientia/IgT5", do_lower_case=False),
        T5EncoderModel.from_pretrained("Exscientia/IgT5"),
        paired_sequences,
        max_length,
        "IgT5",
    )

    ########################################################
    ##################### ABLANG ###########################
    ########################################################
    ablang = ablang2.pretrained()
    all_seqs = [[seq_heavy, seq_light] for seq_heavy, seq_light in zip(sequence_heavy_emb, sequence_light_emb)]
    ablang_embeddings = _pad_and_stack(
        ablang(all_seqs, mode='rescoding', stepwise_masking=False), max_length, "AbLang2"
    )
    del ablang

    ########################################################
    ######################## ESM ###########################
    ########################################################
    esm_embeddings = _embed_esm2(concatenated_sequences, max_length)

    ########################################################
    #################### ANTIBERTY #########################
    ########################################################
    antiberty_embeddings = _pad_and_stack(
        AntiBERTyRunner().embed(concatenated_sequences), max_length, "AntiBERTy"
    )

    ########################################################
    ####################### ProtT5 #########################
    ########################################################
    prot_t5_embeddings = _embed_prot_t5(concatenated_sequences, max_length)

    ########################################################
    ################# CONCATENATE EMBEDDINGS ###############
    ########################################################
    residue_embeddings = torch.cat([
        bert_residue_embeddings,
        igt5_residue_embeddings,
        ablang_embeddings,
        esm_embeddings,
        antiberty_embeddings,
        prot_t5_embeddings
    ], dim=2)

    return (
        bert_residue_embeddings,
        igt5_residue_embeddings,
        ablang_embeddings,
        esm_embeddings,
        antiberty_embeddings,
        prot_t5_embeddings,
        residue_embeddings
    )


In [148]:
# Load the JSON dictionary that is passed to create_embeddings_3 in the next cell.
# NOTE(review): absolute, user-specific path — the cell will not run on another machine as is.
with open ("/home/gathenes/paratope_model/test/test3/test/dict.json") as f:
    test_dict = json.load(f)


In [149]:
# Unpack the six per-model embeddings plus their concatenation (the function returns a 7-tuple)
bert_residue_embeddings, igt5_residue_embeddings, ablang_embeddings, esm_embeddings, antiberty_embeddings, prot_t5_embeddings, residue_embeddings= create_embeddings_3(test_dict)


CREATING EMBEDDINGS


Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  5.68it/s]
  torch.load(


In [150]:
# Print the shape of every per-model embedding; per the output below they share the first
# two axes (3 sequences, 280 positions) and differ only in the feature axis.
for each in [bert_residue_embeddings, igt5_residue_embeddings, ablang_embeddings, esm_embeddings, antiberty_embeddings, prot_t5_embeddings]:
    print(each.size())


torch.Size([3, 280, 1024])
torch.Size([3, 280, 1024])
torch.Size([3, 280, 480])
torch.Size([3, 280, 1280])
torch.Size([3, 280, 512])
torch.Size([3, 280, 1024])


In [151]:
# Feature axis of the concatenation is the sum of the per-model widths printed above:
# 1024 + 1024 + 480 + 1280 + 512 + 1024 = 5344
print(residue_embeddings.size())


torch.Size([3, 280, 5344])


In [29]:
from create_dataset import add_convex_hull_column
from utils import read_pdb_to_dataframe


In [32]:
# Keep only atoms of chains B and C whose residue number is below 129.
# NOTE(review): the captured output below is an empty DataFrame, so no row satisfied the
# filter — verify that the chain IDs exist in 5bv7.pdb and that `residue_number` compares
# as expected before relying on this frame.
chains=["B","C"]
df_pdb = (
    read_pdb_to_dataframe("/home/gathenes/all_structures/imgt_renumbered_pecan/5bv7.pdb")
    .query("chain_id.isin(@chains) and residue_number<129")
)
print(df_pdb)


Empty DataFrame
Columns: [record_name, atom_number, blank_1, atom_name, alt_loc, residue_name, blank_2, chain_id, residue_number, insertion, blank_3, x_coord, y_coord, z_coord, occupancy, b_factor, blank_4, segment_id, element_symbol, charge, line_idx, IMGT]
Index: []


In [None]:
# NOTE(review): this cell has no execution count (never run), it overwrites `df_pdb` in
# place, and the frame produced by the previous cell was empty.
df_pdb = add_convex_hull_column(df_pdb)
