# Compute embeddings for designed binders

Use ESM

In [1]:
%pip install biopython

Note: you may need to restart the kernel to use updated packages.


In [2]:
import esm
import torch
from Bio import SeqIO

In [3]:
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model.eval()  # Disable dropout for evaluation


ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [4]:
batch_converter = alphabet.get_batch_converter()

In [5]:
def compute_perplexity(sequence):
    data = [("sequence", sequence)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(device='cpu')

    with torch.no_grad():
        logits = model(batch_tokens, repr_layers=[], return_contacts=False)["logits"]
    loss_fct = torch.nn.CrossEntropyLoss()
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = batch_tokens[..., 1:].contiguous()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    perplexity = torch.exp(loss).item()
    return perplexity


In [6]:
sequences = []
perplexities = []

fasta_files = [
    "../results/predictions/6y92_chainD_rational_designs.fasta",
    "../results/predictions/6vja_chainI_rational_designs.fasta",
    "../results/predictions/6y97_chainL_rational_designs.fasta"
]

for fasta_file in fasta_files:
    for record in SeqIO.parse(fasta_file, "fasta"):
        sequence = str(record.seq)
        perplexity = compute_perplexity(sequence)
        sequences.append({
            'id': record.id,
            'sequence': sequence,
            'perplexity': perplexity
        })


In [7]:
sequences.sort(key=lambda x: x['perplexity'])

In [8]:
for seq in sequences[:10]:
    print(f"ID: {seq['id']}, Perplexity: {seq['perplexity']}")


ID: 6vja_chainI_resseq_67-100, Perplexity: 38.40760803222656
ID: 6vja_chainI_resseq_57-100, Perplexity: 43.096954345703125
ID: 6vja_chainI_resseq_58-100, Perplexity: 43.09773254394531
ID: 6vja_chainI_resseq_64-100, Perplexity: 43.51446533203125
ID: 6vja_chainI_resseq_56-100, Perplexity: 43.684654235839844
ID: 6vja_chainI_resseq_61-100, Perplexity: 46.58464050292969
ID: 6vja_chainI_resseq_51-100, Perplexity: 48.430362701416016
ID: 6vja_chainI_resseq_52-100, Perplexity: 48.536800384521484
ID: 6vja_chainI_resseq_47-100, Perplexity: 50.004390716552734
ID: 6vja_chainI_resseq_68-100, Perplexity: 53.6435546875


In [9]:
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
from Bio import SeqIO

# Output file path
output_fasta = "../results/predictions/ranked_CD20_binder_designs.fasta"

# Create a list of SeqRecord objects with the perplexity in the sequence ID
seq_records = []
for seq in sequences:
    # Create a new SeqRecord with the perplexity appended to the sequence ID
    record = SeqRecord(
        Seq(seq['sequence']), 
        id=f"{seq['id']}_perplexity_{seq['perplexity']:.2f}",  # Append perplexity to ID
        description=""  # Description can be left empty
    )
    seq_records.append(record)

# Write all sequences to the output FASTA file
with open(output_fasta, "w") as output_handle:
    SeqIO.write(seq_records, output_handle, "fasta")

print(f"Ranked sequences written to {output_fasta}")


Ranked sequences written to ../results/predictions/ranked_CD20_binder_designs.fasta


In [11]:
# Write the top 500 sequences to a separate file for submission
# Output file path
output_fasta = "../results/submission.fasta"

# Create a list of SeqRecord objects with the perplexity in the sequence ID
seq_records = []
for seq in sequences[:500]:
    # Create a new SeqRecord with the perplexity appended to the sequence ID
    record = SeqRecord(
        Seq(seq['sequence']), 
        id=f"{seq['id']}_perplexity_{seq['perplexity']:.2f}",  # Append perplexity to ID
        description=""  # Description can be left empty
    )
    seq_records.append(record)

# Write top 500 sequences to the output FASTA file
with open(output_fasta, "w") as output_handle:
    SeqIO.write(seq_records, output_handle, "fasta")

print(f"Ranked sequences written to {output_fasta}")


Ranked sequences written to ../results/submission.fasta
