# Protein Language Model Embeddings Generator

In [None]:
# @title 1. Install Required Dependencies
# @markdown Run this cell first to install the necessary packages
%%capture
# NOTE: %%capture hides all of pip's output; remove it temporarily when debugging a failed install.
!pip install h5py numpy pandas pyfaidx torch tqdm transformers esm huggingface_hub

In [None]:
# @title ### 2. Import Libraries and Setup
from pathlib import Path
import h5py
import numpy as np
import pandas as pd
import torch
from pyfaidx import Fasta
from tqdm import tqdm
from transformers import AutoTokenizer, EsmModel, T5EncoderModel, T5Tokenizer
from google.colab import drive, files, userdata

# Imports for Native ESM and HF Login
from esm.models.esm3 import ESM3
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, SamplingConfig, LogitsConfig
from huggingface_hub import login as hf_login, HfFolder

In [None]:
# @title ### Optional: Hugging Face Login (Needed for models like ESM3, ESMC)
# @markdown If you're using a model that requires authentication (e.g., native ESM models from EvolutionaryScale),
# @markdown run this cell and enter your Hugging Face token when prompted.
# @markdown You can get a token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
# @markdown Leave blank and run if you have already configured login in this Colab environment or the model is public.
hf_token_input = ""  # @param {type:"string"}

# Token precedence: explicit form input > Colab secret "HF_TOKEN" > credentials
# already stored by huggingface_hub.
if hf_token_input:
    hf_login(token=hf_token_input)
    print("Attempted login with provided token.")
else:
    # userdata.get raises (instead of returning None) when the secret does not
    # exist or this notebook has not been granted access to it, so it must not
    # be used directly as a branch condition.
    try:
        colab_secret_token = userdata.get("HF_TOKEN")
    except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
        colab_secret_token = None

    if colab_secret_token:
        # Actually perform the login the message below announces.
        hf_login(token=colab_secret_token)
        print("Attempted login with secret 'HF_TOKEN' token")
    elif HfFolder.get_token() is not None:
        print("Already logged in to Hugging Face Hub (found existing token/credentials).")
    else:
        print("Not logged in. Native ESM models may fail if they require authentication.")

In [None]:
# @title Optional: Mount Google Drive
# @markdown If you want to use a FASTA file from your Google Drive, run this cell to mount your drive.
# @markdown It will prompt for authorization the first time.
should_mount_drive = False  # @param {type:"boolean"}
if not should_mount_drive:
    print("Skipping Google Drive mount.")
else:
    # Mounting is best-effort: a failure is reported but does not stop the notebook.
    try:
        drive.mount("/content/drive")
    except Exception as mount_error:
        print(f"Error mounting Google Drive: {mount_error}")
    else:
        print("Google Drive mounted successfully.")

In [None]:
# @title ### 3. Select Model and Upload File

# @markdown Choose a protein language model:
model_name = "Rostlab/prot_t5_xl_half_uniref50-enc"  # @param ["Rostlab/prot_t5_xl_half_uniref50-enc", "Rostlab/ProstT5_fp16", "ElnaggarLab/ankh-base", "ElnaggarLab/ankh-large", "facebook/esm2_t6_8M_UR50D", "facebook/esm2_t12_35M_UR50D", "facebook/esm2_t30_150M_UR50D", "facebook/esm2_t33_650M_UR50D", "EvolutionaryScale/esm3-sm-open-v1", "EvolutionaryScale/esmc-300m-2024-12", "EvolutionaryScale/esmc-600m-2024-12"]

# @markdown Choose embedding type:
embedding_type = "per_prot"  # @param ["per_prot", "per_res"]

# @markdown Set maximum sequence length (longer sequences will be skipped):
max_sequence_length = 2000  # @param {type:"integer"}

# @markdown Enter a Google Drive path or upload from your computer:
fasta_filename = ""  # @param {type:"string", placeholder:"Path to FASTA file in Google Drive (leave empty to upload)"}
if not fasta_filename:
    # No path supplied: fall back to the Colab upload dialog.
    uploaded = files.upload()
    if not uploaded:
        # A cancelled dialog yields an empty mapping; fail with a clear message
        # instead of an opaque IndexError.
        raise ValueError(
            "No FASTA file was uploaded. Re-run this cell and select a file, "
            "or enter a Google Drive path."
        )
    # Only the first uploaded file is used.
    fasta_filename = next(iter(uploaded))

# Checkpoint identifier -> short key appended to the output file name.
MODEL_NAME_TO_SHORT_KEY_MAP = dict(
    [
        # Rostlab T5 encoders
        ("Rostlab/prot_t5_xl_half_uniref50-enc", "prot_t5"),
        ("Rostlab/ProstT5_fp16", "prost_t5"),
        # Ankh
        ("ElnaggarLab/ankh-base", "ankh_base"),
        ("ElnaggarLab/ankh-large", "ankh_large"),
        # ESM-2
        ("facebook/esm2_t6_8M_UR50D", "esm2_8m"),
        ("facebook/esm2_t12_35M_UR50D", "esm2_35m"),
        ("facebook/esm2_t30_150M_UR50D", "esm2_150m"),
        ("facebook/esm2_t33_650M_UR50D", "esm2_650m"),
        # EvolutionaryScale ESM3 / ESM C
        ("EvolutionaryScale/esm3-sm-open-v1", "esm3_open"),
        ("EvolutionaryScale/esmc-300m-2024-12", "esmc_300m"),
        ("EvolutionaryScale/esmc-600m-2024-12", "esmc_600m"),
    ]
)


def seq_preprocess(df, model_type="esm"):
    """Prepare the ``sequence`` column of ``df`` for tokenization.

    The frame is modified in place and also returned. The untouched input is
    kept in a new ``raw_sequence`` column.

    Args:
        df: DataFrame with a string ``sequence`` column.
        model_type: Model family key; "prot_t5" and "prost_t5" get
            space-separated residues, and "prost_t5" additionally gets the
            "<AA2fold> " prefix.

    Returns:
        The same DataFrame object that was passed in.
    """
    df["raw_sequence"] = df["sequence"]

    # Collapse the residues B, J, O, U and Z into the generic placeholder X.
    prepared = df["sequence"].str.replace("[BJOUZ]", "X", regex=True)

    if model_type in ("prost_t5", "prot_t5"):
        # These models expect one whitespace-separated token per residue.
        prepared = prepared.map(" ".join)

    if model_type == "prost_t5":
        prepared = prepared.map(lambda spaced: "<AA2fold> " + spaced)

    df["sequence"] = prepared
    return df


# @markdown ---
def setup_model(checkpoint):
    """Load the model (and tokenizer, when one is used) for ``checkpoint``.

    Args:
        checkpoint: Model identifier, e.g. one of the keys of
            ``MODEL_NAME_TO_SHORT_KEY_MAP``.

    Returns:
        A tuple ``(model, tokenizer, mod_type)``. ``model`` has been moved to
        CUDA when available and to CPU otherwise. ``tokenizer`` is ``None`` for
        the native ESM3 / ESM C models. ``mod_type`` is one of
        "esm_transformer", "native_esm3", "native_esmc", "ankh", "prot_t5" or
        "prost_t5".

    Raises:
        ValueError: If ``checkpoint`` matches none of the supported families.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = None

    print(f"Setting up model: {checkpoint} on {device}")

    if "esm2" in checkpoint:
        mod_type = "esm_transformer"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = EsmModel.from_pretrained(checkpoint)
    elif "esm3-sm-open-v1" in checkpoint:
        mod_type = "native_esm3"
        model = ESM3.from_pretrained("esm3-open")
    elif checkpoint == "EvolutionaryScale/esmc-300m-2024-12":
        mod_type = "native_esmc"
        model = ESMC.from_pretrained("esmc_300m")
    elif checkpoint == "EvolutionaryScale/esmc-600m-2024-12":
        mod_type = "native_esmc"
        model = ESMC.from_pretrained("esmc_600m")
    elif "ankh" in checkpoint:
        mod_type = "ankh"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = T5EncoderModel.from_pretrained(checkpoint)
    elif "prot_t5" in checkpoint or "ProstT5" in checkpoint:
        is_prostt5 = "prostt5" in checkpoint.lower()
        mod_type = "prost_t5" if is_prostt5 else "prot_t5"
        # Both model cards load the tokenizer with do_lower_case=False: the
        # amino-acid vocabulary is upper-case, so lower-casing must stay off.
        tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False)
        # Half precision is only intended for GPU inference; on CPU load the
        # weights in full precision instead.
        weights_dtype = torch.float16 if device.type == "cuda" else torch.float32
        model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=weights_dtype)
    else:
        raise ValueError(f"Unknown model checkpoint type: {checkpoint}")

    return model.to(device), tokenizer, mod_type


def read_fasta(file_path):
    """Read every record of a FASTA file.

    Args:
        file_path: Path to the FASTA file; converted with ``str`` before being
            handed to ``Fasta``.

    Returns:
        A tuple ``(headers, sequences)`` of two parallel lists holding each
        record's ``name`` and its sequence as a string, in file iteration order.
    """
    with Fasta(str(file_path)) as fasta_records:
        # Materialise inside the context manager while the file is still open.
        name_seq_pairs = [(record.name, str(record)) for record in fasta_records]

    headers = [name for name, _ in name_seq_pairs]
    sequences = [residues for _, residues in name_seq_pairs]
    return headers, sequences


def create_embedding(
    checkpoint,
    df,
    emb_type="per_prot",
    output_file="protein_embeddings.h5",
    max_len=2000,
):
    """Embed every sequence in ``df`` and store the results in an HDF5 file.

    Each embedding is written as a float32 dataset named after the sequence
    header. The file is opened in append mode, so headers that already exist in
    it are skipped, as are sequences longer than ``max_len``.

    Args:
        checkpoint: Model identifier passed to ``setup_model``.
        df: DataFrame with string columns ``header`` and ``sequence``. It is
            copied before preprocessing and not modified.
        emb_type: "per_prot" for the mean over residue embeddings, or "per_res"
            for the full per-residue matrix.
        output_file: Path of the HDF5 file to create or append to.
        max_len: Maximum raw sequence length to embed.

    Raises:
        ValueError: If ``emb_type`` is not supported, if any header contains
            '/', or if ``setup_model`` rejects ``checkpoint``.
    """
    # Validate the cheap things first so a bad argument does not cost a model
    # download and load before failing.
    if emb_type not in ("per_prot", "per_res"):
        raise ValueError(
            f"Invalid embedding type '{emb_type}'; expected 'per_prot' or 'per_res'."
        )

    invalid_headers = df.loc[
        df["header"].str.contains("/", regex=False), "header"
    ].tolist()
    if invalid_headers:
        error_msg = (
            "ERROR: The following sequence headers contain '/' which is not allowed in HDF5 dataset names (HDF5 treats '/' as a group separator):\n"
            + "\n".join(invalid_headers)
            + "\nPlease remove or replace '/' characters in these headers and try again."
        )
        raise ValueError(error_msg)

    print("Setting up model...")
    model_instance, tokenizer_instance, mod_type = setup_model(checkpoint)
    model_instance.eval()
    df_processed = seq_preprocess(df.copy(), mod_type)
    device = model_instance.device

    def compute_embedding(
        sequence, current_emb_type, current_mod_type, model, tokenizer
    ):
        """Return the embedding of one preprocessed sequence as a NumPy array."""
        if current_mod_type.startswith("native_esm"):
            protein = ESMProtein(sequence=sequence)
            # Inference only: make sure no autograd graph is built.
            with torch.no_grad():
                tokens = model.encode(protein).to(device)
                if current_mod_type == "native_esm3":
                    out = model.forward_and_sample(
                        tokens, SamplingConfig(return_per_residue_embeddings=True)
                    )
                    raw_tensor = out.per_residue_embedding
                elif current_mod_type == "native_esmc":
                    out = model.logits(
                        tokens, LogitsConfig(sequence=True, return_embeddings=True)
                    )
                    raw_tensor = out.embeddings
                else:
                    raise ValueError(
                        f"Unknown native ESM model type: {current_mod_type}"
                    )

            # Upcast before leaving torch: NumPy cannot represent bfloat16, so
            # a reduced-precision tensor would otherwise fail to convert.
            all_token_embeddings = (
                raw_tensor.squeeze(0).detach().float().cpu().numpy()
            )
            # Drop the BOS/EOS positions.
            residue_embeddings = all_token_embeddings[1:-1, :]
        else:  # Transformers-based models
            if tokenizer is None:
                raise ValueError(
                    f"Tokenizer not available for model type: {current_mod_type}"
                )

            inputs = tokenizer(
                sequence,
                return_tensors="pt",
                max_length=max_len + 2,
                truncation=True,
                padding=True,
                add_special_tokens=True,  # Adds EOS for T5 models
            ).to(device)

            with torch.no_grad():
                embeddings_tensor = model(**inputs).last_hidden_state.cpu()
            if embeddings_tensor.ndim == 3 and embeddings_tensor.shape[0] == 1:
                embeddings_tensor = embeddings_tensor.squeeze(0)
            all_token_embeddings = embeddings_tensor.numpy()

            if current_mod_type in ("esm_transformer", "prost_t5"):
                # ESM-2: skip <cls> and <eos>; ProstT5: skip <AA2fold> and <eos>.
                residue_embeddings = all_token_embeddings[1:-1, :]
            elif current_mod_type in ("prot_t5", "ankh"):
                # Only a trailing <eos> to skip.
                residue_embeddings = all_token_embeddings[:-1, :]
            else:  # Should not happen if mod_type is correctly set
                residue_embeddings = all_token_embeddings

        if current_emb_type == "per_res":
            return residue_embeddings

        # per_prot: average over residues (emb_type was validated up front).
        if residue_embeddings.shape[0] == 0:
            print(
                f"Warning: No token embeddings to average for sequence after slicing for model {current_mod_type}. Original shape: {all_token_embeddings.shape}"
            )
            return np.array([])
        return residue_embeddings.mean(axis=0)

    try:
        print("Generating embeddings...")
        with h5py.File(output_file, "a") as hdf:
            for _, row_data in tqdm(df_processed.iterrows(), total=len(df_processed)):
                sequence_val = row_data["sequence"]
                header_val = row_data["header"]
                # Length is measured on the raw sequence; the preprocessed one
                # may contain added spaces and a prefix.
                raw_length = len(row_data["raw_sequence"])

                # Skip sequences that exceed the maximum length
                if raw_length > max_len:
                    tqdm.write(
                        f"Skipping sequence {header_val} (length {raw_length}) as it exceeds max length {max_len}."
                    )
                    continue

                if header_val in hdf:
                    tqdm.write(f"Skipping existing embedding for {header_val}")
                    continue

                embedding_result = compute_embedding(
                    sequence_val, emb_type, mod_type, model_instance, tokenizer_instance
                )
                hdf.create_dataset(
                    name=header_val,
                    data=embedding_result.astype(np.float32),
                )
    finally:
        # Release the model even when embedding fails part-way, so a re-run in
        # the same session does not start with the GPU still occupied.
        del model_instance, tokenizer_instance
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
# @title ### 4. Generate Embeddings
# @markdown Click the play button to start generating embeddings

fasta_path = Path(fasta_filename)

# Use the curated short key when the model is known; otherwise derive one from
# the checkpoint name with '/' replaced so it is usable in a file name.
if model_name in MODEL_NAME_TO_SHORT_KEY_MAP:
    short_model_key = MODEL_NAME_TO_SHORT_KEY_MAP[model_name]
else:
    short_model_key = model_name.replace("/", "_")

# The HDF5 output is written next to the FASTA file.
output_file = str(fasta_path.with_name(fasta_path.stem + "_" + short_model_key + ".h5"))

headers, sequences = read_fasta(fasta_path)
df = pd.DataFrame(dict(header=headers, sequence=sequences))
print(f"Processing {len(df)} sequences for model {model_name}...")

create_embedding(
    checkpoint=model_name,
    df=df,
    emb_type=embedding_type,
    output_file=output_file,
    max_len=max_sequence_length,
)

print(f"\nEmbeddings saved to {output_file}")

In [None]:
# @title ### 5. Download Results
# @markdown Run this cell to download your embeddings file
# `output_file` is set by the "Generate Embeddings" cell, which must have been run first.
files.download(output_file)