# Protein Language Model Embeddings Generator


In [None]:
# @title 1. Install Required Dependencies
# @markdown Run this cell first to install the necessary packages
%%capture
!pip install h5py numpy pandas pyfaidx torch tqdm transformers esm huggingface_hub

In [None]:
# @title ### 2. Import Libraries and Setup
import re
from pathlib import Path

import h5py
import numpy as np
import pandas as pd
import torch
from esm.models.esm3 import ESM3
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig, SamplingConfig
from google.colab import drive, files, userdata
from huggingface_hub import HfFolder
from huggingface_hub import login as hf_login
from pyfaidx import Fasta
from tqdm.auto import tqdm
from transformers import AutoTokenizer, EsmModel, T5EncoderModel, T5Tokenizer

In [None]:
# @title ### Optional: Hugging Face Login (Needed for models like ESM3, ESMC)
# @markdown If you're using a model that requires authentication (e.g., native ESM models from EvolutionaryScale),
# @markdown run this cell and enter your Hugging Face token when prompted.
# @markdown You can get a token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
# @markdown Leave blank and run if you have already configured login in this Colab environment or the model is public.
hf_token_input = ""  # @param {type:"string"}

# Look up the optional Colab secret; any failure to read it is treated as
# "no secret configured".
try:
    hf_token_secret = userdata.get("HF_TOKEN")
except Exception:
    hf_token_secret = None

# Priority: token typed into the form > Colab secret > credentials already
# stored in this environment.
if not (hf_token_input or hf_token_secret):
    if HfFolder.get_token() is None:
        print("Not logged in. Native ESM models may fail if they require authentication.")
    else:
        print("Already logged in to Hugging Face Hub (found existing token/credentials).")
elif hf_token_input:
    hf_login(token=hf_token_input)
    print("Logged in with provided token.")
else:
    hf_login(token=hf_token_secret)
    print("Logged in with Colab secret 'HF_TOKEN'.")

In [None]:
# @title Optional: Mount Google Drive
# @markdown If you want to use a FASTA file from your Google Drive, run this cell to mount your drive.
# @markdown It will prompt for authorization the first time.
should_mount_drive = False  # @param {type:"boolean"}
if not should_mount_drive:
    print("Skipping Google Drive mount.")
else:
    # A failed mount is reported but does not abort the notebook.
    try:
        drive.mount("/content/drive")
    except Exception as mount_error:
        print(f"Error mounting Google Drive: {mount_error}")
    else:
        print("Google Drive mounted successfully.")

In [None]:
# @title ### 3. Select Model and Upload File

# @markdown Choose a protein language model:
model_name = "Rostlab/prot_t5_xl_half_uniref50-enc"  # @param ["Rostlab/prot_t5_xl_half_uniref50-enc", "Rostlab/ProstT5_fp16", "ElnaggarLab/ankh-base", "ElnaggarLab/ankh-large", "facebook/esm2_t6_8M_UR50D", "facebook/esm2_t12_35M_UR50D", "facebook/esm2_t30_150M_UR50D", "facebook/esm2_t33_650M_UR50D", "EvolutionaryScale/esm3-sm-open-v1", "EvolutionaryScale/esmc-300m-2024-12", "EvolutionaryScale/esmc-600m-2024-12"]

# @markdown Choose embedding type:
embedding_type = "per_prot"  # @param ["per_prot", "per_res"]

# @markdown Set maximum sequence length (longer sequences will be skipped):
max_sequence_length = 2000  # @param {type:"integer"}

# @markdown Max batch size for transformer models (auto-reduces on GPU OOM):
batch_size = 128  # @param {type:"integer"}

# @markdown Enter a Google Drive path or upload from your computer:
fasta_filename = ""  # @param {type:"string", placeholder:"Path to FASTA file in Google Drive (leave empty to upload)"}
if not fasta_filename:
    # No path given: ask for an upload from the local machine.
    uploaded = files.upload()
    # The result is used as a mapping keyed by file name; it is presumably
    # empty when the upload dialog is cancelled - fail with a clear message
    # instead of an IndexError.
    if not uploaded:
        raise ValueError(
            "No FASTA file was uploaded. Re-run this cell and choose a file, "
            "or enter a Google Drive path."
        )
    fasta_filename = next(iter(uploaded))
    if len(uploaded) > 1:
        print(f"Multiple files uploaded; using the first one: {fasta_filename}")
elif not Path(fasta_filename).is_file():
    # Catch a mistyped path (or an unmounted Drive) here rather than deep
    # inside the embedding step.
    raise FileNotFoundError(
        f"FASTA file not found: {fasta_filename}. "
        "If it lives in Google Drive, run the 'Mount Google Drive' cell first."
    )

In [None]:
# @title ### Functions { display-mode: "form" }

# Maps each selectable model checkpoint to a short identifier. The "Generate
# Embeddings" cell uses the short key as the suffix of the output file name
# (`<fasta stem>_<short key>.h5`); checkpoints missing from this table fall
# back to the full checkpoint name with "/" replaced by "_".
MODEL_SHORT_KEYS = {
    "Rostlab/prot_t5_xl_half_uniref50-enc": "prot_t5",
    "Rostlab/ProstT5_fp16": "prost_t5",
    "ElnaggarLab/ankh-base": "ankh_base",
    "ElnaggarLab/ankh-large": "ankh_large",
    "facebook/esm2_t6_8M_UR50D": "esm2_8m",
    "facebook/esm2_t12_35M_UR50D": "esm2_35m",
    "facebook/esm2_t30_150M_UR50D": "esm2_150m",
    "facebook/esm2_t33_650M_UR50D": "esm2_650m",
    "EvolutionaryScale/esm3-sm-open-v1": "esm3_open",
    "EvolutionaryScale/esmc-300m-2024-12": "esmc_300m",
    "EvolutionaryScale/esmc-600m-2024-12": "esmc_600m",
}


def preprocess_sequences(df, model_type):
    """Prepare sequences for the given model type and return a new DataFrame.

    The input frame is not modified. The returned copy has:

    * ``raw_sequence`` - the sequence exactly as given (used for length checks).
    * ``sequence`` - the model input: upper-cased, with the residues
      B, J, O, U and Z replaced by ``X``; space-separated for ``"prot_t5"`` and
      ``"prost_t5"``; additionally prefixed with ``"<AA2fold> "`` for
      ``"prost_t5"``.

    Args:
        df: DataFrame with a string column ``"sequence"``.
        model_type: Model family key as returned by ``setup_model``.

    Returns:
        A copy of ``df`` with the ``raw_sequence`` and ``sequence`` columns
        described above.
    """
    df = df.copy()
    df["raw_sequence"] = df["sequence"]

    # Upper-case before mapping rare residues: soft-masked FASTA files contain
    # lower-case residues, which would otherwise slip past the [BJOUZ]
    # replacement, and ProstT5 reserves lower-case letters for 3Di structure
    # tokens, so they must not reach an "<AA2fold>" amino-acid input.
    df["sequence"] = (
        df["sequence"].str.upper().str.replace("[BJOUZ]", "X", regex=True)
    )

    if model_type in ("prost_t5", "prot_t5"):
        # The T5-based models are fed one space-separated token per residue.
        df["sequence"] = df["sequence"].apply(lambda s: " ".join(s))
    if model_type == "prost_t5":
        # Direction prefix: amino-acid sequence in.
        df["sequence"] = "<AA2fold> " + df["sequence"]
    return df


def setup_model(checkpoint):
    """Load the model and tokenizer for ``checkpoint`` and move the model to the device.

    CUDA is used when available, otherwise the CPU.

    Args:
        checkpoint: Model identifier (one of the names offered in the model
            selector).

    Returns:
        Tuple ``(model, tokenizer, mod_type)``. ``tokenizer`` is ``None`` for
        the native ESM3/ESMC models. ``mod_type`` is one of
        ``"esm_transformer"``, ``"native_esm3"``, ``"native_esmc"``,
        ``"ankh"``, ``"prost_t5"`` or ``"prot_t5"``.

    Raises:
        ValueError: If ``checkpoint`` matches none of the supported families.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Setting up model: {checkpoint} on {device}")

    if "esm2" in checkpoint:
        mod_type = "esm_transformer"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = EsmModel.from_pretrained(checkpoint)
    elif "esm3-sm-open-v1" in checkpoint:
        # Native EvolutionaryScale models do their own tokenization.
        mod_type = "native_esm3"
        tokenizer = None
        model = ESM3.from_pretrained("esm3-open")
    elif checkpoint == "EvolutionaryScale/esmc-300m-2024-12":
        mod_type = "native_esmc"
        tokenizer = None
        model = ESMC.from_pretrained("esmc_300m")
    elif checkpoint == "EvolutionaryScale/esmc-600m-2024-12":
        mod_type = "native_esmc"
        tokenizer = None
        model = ESMC.from_pretrained("esmc_600m")
    elif "ankh" in checkpoint:
        mod_type = "ankh"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = T5EncoderModel.from_pretrained(checkpoint)
    elif "prot_t5" in checkpoint or "ProstT5" in checkpoint:
        is_prostt5 = "prostt5" in checkpoint.lower()
        mod_type = "prost_t5" if is_prostt5 else "prot_t5"
        # NOTE(review): the upstream ProtT5/ProstT5 examples pass
        # do_lower_case=False for both models - confirm that enabling it for
        # ProtT5 here is intentional.
        tokenizer = T5Tokenizer.from_pretrained(
            checkpoint, do_lower_case=not is_prostt5
        )
        # These checkpoints are stored in half precision. Keep float16 only on
        # GPU; on CPU load the weights as float32, because half-precision
        # inference on CPU is slow and not supported for every operation.
        dtype = torch.float16 if device.type == "cuda" else torch.float32
        model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=dtype)
    else:
        raise ValueError(f"Unknown model checkpoint: {checkpoint}")

    return model.to(device), tokenizer, mod_type


def read_fasta(file_path):
    """Parse a FASTA file into two parallel lists.

    Args:
        file_path: Location of the FASTA file (anything ``str()`` turns into a path).

    Returns:
        Tuple ``(headers, sequences)``: the record names and the matching
        sequences as strings, in file order.
    """
    with Fasta(str(file_path)) as records:
        # Materialise everything while the file is still open.
        pairs = [(entry.name, str(entry)) for entry in records]

    headers = [name for name, _ in pairs]
    sequences = [residues for _, residues in pairs]
    return headers, sequences


def _embed_native_esm(sequence, emb_type, mod_type, model, device):
    """Embed a single sequence with a native ESM3 or ESMC model.

    Args:
        sequence: Amino-acid sequence string.
        emb_type: ``"per_prot"`` for the mean over residues; any other value
            returns the per-residue matrix.
        mod_type: ``"native_esm3"`` selects the ESM3 call path; every other
            value uses the ESMC call path.
        model: Loaded native ESM model.
        device: Device the encoded protein is moved to.

    Returns:
        A float32 NumPy array: one vector for ``"per_prot"``, otherwise one
        row per residue (first and last token positions removed).
    """
    encoded = model.encode(ESMProtein(sequence=sequence)).to(device)

    with torch.inference_mode():
        if mod_type != "native_esm3":
            # ESMC: embeddings are returned alongside the logits.
            response = model.logits(
                encoded, LogitsConfig(sequence=True, return_embeddings=True)
            )
            raw = response.embeddings
        else:
            response = model.forward_and_sample(
                encoded, SamplingConfig(return_per_residue_embeddings=True)
            )
            raw = response.per_residue_embedding
        token_matrix = raw.squeeze(0).cpu().float().numpy()

    # Drop the BOS/EOS positions so only residue rows remain.
    residue_matrix = token_matrix[1:-1, :]
    if emb_type == "per_prot":
        return residue_matrix.mean(axis=0)
    return residue_matrix


def _embed_transformers_batch(sequences, emb_type, mod_type, model, tokenizer, max_len):
    """Compute embeddings for a batch using HuggingFace Transformers."""
    inputs = tokenizer(
        sequences,
        return_tensors="pt",
        max_length=max_len + 2,
        truncation=True,
        padding=True,
        add_special_tokens=True,
    ).to(model.device)

    with torch.inference_mode():
        hidden = model(**inputs).last_hidden_state.cpu().float().numpy()

    lengths = inputs["attention_mask"].sum(dim=1).tolist()

    results = []
    for i, seq_len in enumerate(lengths):
        # Strip special tokens: ESM2/ProstT5 have <cls>/<AA2fold> and <eos>,
        # ProtT5/Ankh have only <eos>
        if mod_type in ("esm_transformer", "prost_t5"):
            emb = hidden[i, 1 : seq_len - 1, :]
        else:  # prot_t5, ankh
            emb = hidden[i, : seq_len - 1, :]
        results.append(emb.mean(axis=0) if emb_type == "per_prot" else emb)

    return results


def _write_batch(hdf, headers, embeddings):
    """Write a batch of embeddings to an open HDF5 file."""
    for header, emb in zip(headers, embeddings, strict=False):
        hdf.create_dataset(name=header, data=emb.astype(np.float32))


def _run_native_esm(hdf, to_compute, emb_type, mod_type, model):
    """Embed sequences one at a time with a native ESM3/ESMC model and store them."""
    device = model.device
    for _, row in tqdm(to_compute.iterrows(), total=len(to_compute)):
        try:
            emb = _embed_native_esm(row["sequence"], emb_type, mod_type, model, device)
        except torch.cuda.OutOfMemoryError:
            # Same policy as the batched path: report and skip the sequence.
            torch.cuda.empty_cache()
            tqdm.write(
                f"Skipping {row['header']} (len={len(row['raw_sequence'])}, GPU OOM)"
            )
            continue
        hdf.create_dataset(name=row["header"], data=emb.astype(np.float32))


def _run_transformer_batches(
    hdf, to_compute, emb_type, mod_type, model, tokenizer, max_len, batch_size
):
    """Embed sequences in batches, halving the batch size whenever the GPU runs out of memory."""
    # A non-positive batch size would only ever produce empty batches.
    bs = max(1, int(batch_size))
    i = 0
    # Context manager: the progress bar is closed even if an error escapes.
    with tqdm(total=len(to_compute)) as pbar:
        while i < len(to_compute):
            batch = to_compute.iloc[i : i + bs]
            try:
                embeddings = _embed_transformers_batch(
                    batch["sequence"].tolist(),
                    emb_type,
                    mod_type,
                    model,
                    tokenizer,
                    max_len,
                )
            except torch.cuda.OutOfMemoryError:
                torch.cuda.empty_cache()
                if bs > 1:
                    # Sequences are sorted shortest-first, so a reduced batch
                    # size never needs to grow again.
                    bs = max(1, bs // 2)
                    tqdm.write(f"GPU OOM \u2014 reducing batch size to {bs}")
                else:
                    tqdm.write(
                        f"Skipping {batch.iloc[0]['header']} "
                        f"(len={len(batch.iloc[0]['raw_sequence'])}, OOM at batch_size=1)"
                    )
                    i += 1
                    pbar.update(1)
                continue
            _write_batch(hdf, batch["header"], embeddings)
            pbar.update(len(batch))
            i += len(batch)


def create_embedding(
    checkpoint,
    df,
    emb_type="per_prot",
    output_file="protein_embeddings.h5",
    max_len=2000,
    batch_size=32,
):
    """Generate embeddings for the sequences in ``df`` and write them to an HDF5 file.

    The output file is opened in append mode, so a previous run can be
    resumed: headers that already exist as datasets are skipped. Sequences
    longer than ``max_len`` are skipped, and when a header occurs more than
    once only its first occurrence is embedded.

    Args:
        checkpoint: Model identifier passed to ``setup_model``.
        df: DataFrame with string columns ``"header"`` and ``"sequence"``.
        emb_type: ``"per_prot"`` (one vector per protein) or ``"per_res"``
            (one row per residue).
        output_file: Path of the HDF5 file to create or extend.
        max_len: Maximum raw sequence length that is embedded.
        batch_size: Starting batch size for the Transformers models; it is
            halved automatically on GPU out-of-memory errors. Values below 1
            are treated as 1.

    Raises:
        ValueError: If any header contains ``"/"``, or if ``checkpoint`` is
            not supported (raised by ``setup_model``).
    """
    # Validate headers before the slow model load. '/' is the HDF5 path
    # separator and cannot be part of a flat dataset name.
    bad = df[df["header"].str.contains("/", regex=False)]["header"].tolist()
    if bad:
        raise ValueError(
            "Headers contain '/' (invalid for HDF5):\n"
            + "\n".join(bad)
            + "\nPlease fix these headers and try again."
        )

    print("Setting up model...")
    model, tokenizer, mod_type = setup_model(checkpoint)
    model.eval()

    try:
        df_proc = preprocess_sequences(df, mod_type)

        with h5py.File(output_file, "a") as hdf:
            # Skip embeddings that are already stored in the file.
            existing = set(hdf.keys())
            already_done = df_proc["header"].isin(existing)
            too_long = df_proc["raw_sequence"].str.len() > max_len
            to_compute = df_proc[~already_done & ~too_long].copy()

            n_existing = already_done.sum()
            n_too_long = (~already_done & too_long).sum()
            if n_existing:
                print(f"Skipping {n_existing} already computed embeddings")
            if n_too_long:
                print(f"Skipping {n_too_long} sequences exceeding max length {max_len}")

            # Writing the same dataset name twice makes h5py fail midway
            # through the run, so only the first occurrence of each header is
            # kept (the same "existing entry wins" rule as on resume).
            duplicated = to_compute["header"].duplicated(keep="first")
            n_duplicated = int(duplicated.sum())
            if n_duplicated:
                print(
                    f"Skipping {n_duplicated} sequences with duplicate headers "
                    "(keeping the first occurrence)"
                )
                to_compute = to_compute[~duplicated]

            if to_compute.empty:
                if n_too_long or n_duplicated:
                    print("No embeddings left to compute.")
                else:
                    print("All embeddings already computed!")
                return

            # Sort by sequence length for efficient batching (shortest first).
            to_compute = to_compute.sort_values(
                by="raw_sequence", key=lambda s: s.str.len()
            )

            print(f"Computing {len(to_compute)} embeddings...")
            if mod_type.startswith("native_esm"):
                # Native ESM: one sequence at a time.
                _run_native_esm(hdf, to_compute, emb_type, mod_type, model)
            else:
                _run_transformer_batches(
                    hdf,
                    to_compute,
                    emb_type,
                    mod_type,
                    model,
                    tokenizer,
                    max_len,
                    batch_size,
                )
    finally:
        # Release the model on every exit path, including the early return
        # above and any error raised during embedding.
        del model, tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
# @title ### 4. Generate Embeddings
# @markdown Click the play button to start generating embeddings

# @markdown Extract UniProt accession ID from FASTA header?
# @markdown If unchecked, the full FASTA header will be used as the identifier.
should_extract_uniprot_id = False  # @param {type:"boolean"}

fasta_path = Path(fasta_filename)

# Get the short key for the selected model_name, fallback to sanitized full name
short_model_key = MODEL_SHORT_KEYS.get(model_name, model_name.replace("/", "_"))
# The output file is written next to the FASTA file.
output_file = str(fasta_path.with_name(f"{fasta_path.stem}_{short_model_key}.h5"))

headers, sequences = read_fasta(fasta_path)
df = pd.DataFrame({"header": headers, "sequence": sequences})
print(f"Processing {len(df)} sequences for model {model_name}...")

# Regex to extract UniProt ID
uniprot_regex = r"[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}"

if should_extract_uniprot_id:
    # Compile once and search each header a single time; headers without a
    # match keep their original text.
    uniprot_pattern = re.compile(uniprot_regex)
    has_uniprot_id = df["header"].map(
        lambda header: uniprot_pattern.search(header) is not None
    )
    df["header"] = df["header"].map(
        lambda header: found.group(0)
        if (found := uniprot_pattern.search(header))
        else header
    )
    n_unmatched = int((~has_uniprot_id).sum())
    if n_unmatched:
        print(
            f"No UniProt ID found in {n_unmatched} header(s); "
            "keeping the original header for those."
        )
else:
    print("Skipping UniProt ID extraction. Using full FASTA headers.")

create_embedding(
    model_name,
    df,
    emb_type=embedding_type,
    output_file=output_file,
    max_len=max_sequence_length,
    batch_size=batch_size,
)

print(f"\nEmbeddings saved to {output_file}")

In [None]:
# @title ### 5. Download Results
# @markdown Run this cell to download your embeddings file
# `output_file` is defined by the "Generate Embeddings" cell, so that cell
# must have been run in this session first.
files.download(output_file)