# Compute embeddings for designed binders

Use ESM

In [10]:
%pip install biopython

Collecting biopython
  Using cached biopython-1.83-cp38-cp38-macosx_11_0_arm64.whl.metadata (13 kB)
Downloading biopython-1.83-cp38-cp38-macosx_11_0_arm64.whl (2.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m881.3 kB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.83
Note: you may need to restart the kernel to use updated packages.


In [11]:
import esm
import torch
from Bio import SeqIO

In [12]:
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model.eval()  # Disable dropout for evaluation


ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [13]:
batch_converter = alphabet.get_batch_converter()


In [14]:
def compute_perplexity(sequence):
    data = [("sequence", sequence)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(device='cpu')

    with torch.no_grad():
        logits = model(batch_tokens, repr_layers=[], return_contacts=False)["logits"]
    loss_fct = torch.nn.CrossEntropyLoss()
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = batch_tokens[..., 1:].contiguous()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    perplexity = torch.exp(loss).item()
    return perplexity


In [15]:
sequences = []
perplexities = []

fasta_files = [
    "../results/predictions/6y92_chainD_rational_designs.fasta",
    "../results/predictions/6vja_chainI_rational_designs.fasta",
    "../results/predictions/6y97_chainL_rational_designs.fasta"
]

for fasta_file in fasta_files:
    for record in SeqIO.parse(fasta_file, "fasta"):
        sequence = str(record.seq)
        perplexity = compute_perplexity(sequence)
        sequences.append({
            'id': record.id,
            'sequence': sequence,
            'perplexity': perplexity
        })


In [16]:
sequences.sort(key=lambda x: x['perplexity'])

In [21]:
for seq in sequences[:10]:
    print(f"ID: {seq['id']}, Perplexity: {seq['perplexity']}")


ID: 6vja_chainI_resseq_63-100, Perplexity: 41.93799591064453
ID: 6y97_chainL_resseq_1-43, Perplexity: 56.74275207519531
ID: 6y97_chainL_resseq_1-41, Perplexity: 61.050941467285156
ID: 6vja_chainI_resseq_1-39, Perplexity: 64.6700210571289
ID: 6y92_chainD_resseq_1-36, Perplexity: 81.5735855102539
ID: 6vja_chainI_resseq_60-100, Perplexity: 84.4903564453125
ID: 6vja_chainI_resseq_62-100, Perplexity: 95.07853698730469
ID: 6vja_chainI_resseq_71-104, Perplexity: 96.37775421142578
ID: 6vja_chainI_resseq_69-104, Perplexity: 108.69477844238281
ID: 6vja_chainI_resseq_46-82, Perplexity: 117.5959243774414


In [22]:
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
from Bio import SeqIO

# Output file path
output_fasta = "../results/predictions/ranked_CD20_binder_designs.fasta"

# Create a list of SeqRecord objects with the perplexity in the sequence ID
seq_records = []
for seq in sequences:
    # Create a new SeqRecord with the perplexity appended to the sequence ID
    record = SeqRecord(
        Seq(seq['sequence']), 
        id=f"{seq['id']}_perplexity_{seq['perplexity']:.2f}",  # Append perplexity to ID
        description=""  # Description can be left empty
    )
    seq_records.append(record)

# Write all sequences to the output FASTA file
with open(output_fasta, "w") as output_handle:
    SeqIO.write(seq_records, output_handle, "fasta")

print(f"Ranked sequences written to {output_fasta}")


Ranked sequences written to ../results/predictions/ranked_CD20_binder_designs.fasta
