In [3]:
import torch  # PyTorch for model loading and tensor operations
import pandas as pd  # For handling dataframes and CSV
import pathlib  # For working with file paths
from esm import pretrained, FastaBatchedDataset  # ESM utilities for loading model and dataset

def extract_esm_embeddings_to_csv(
    fasta_file_path,
    output_csv_path,
    model_name='esm2_t30_150M_UR50D',
    repr_layer=30,
    tokens_per_batch=4096,
    max_seq_len=1500
):
    """
    Extract mean-pooled ESM protein embeddings from a FASTA file and save them to a CSV file.

    Each output row holds the record ID (first whitespace-separated token of the
    FASTA header), the full sequence string, and one column per embedding
    dimension (integer column names assigned by pandas).

    Parameters:
    fasta_file_path : str or path-like - Path to input FASTA file
    output_csv_path : str or path-like - Path to output CSV file
    model_name : str - ESM model name passed to esm.pretrained.load_model_and_alphabet
    repr_layer : int - Layer number to extract embeddings from. Negative values
        count back from the last layer (-1 is the final layer).
    tokens_per_batch : int - Approximate token budget used to group sequences into batches
    max_seq_len : int or None - Sequences longer than this are truncated before
        embedding; their vector is the mean over the first max_seq_len residues
        only. A falsy value (None/0) disables truncation.

    Raises:
    ValueError - if repr_layer is outside the model's valid layer range.

    Notes:
    Rows are written in the order the batch sampler yields sequences.
    NOTE(review): that order comes from FastaBatchedDataset.get_batch_indices
    and may not match the FASTA order; sort by SeqID downstream if order matters.
    """

    # Load the pretrained ESM model and corresponding alphabet for tokenization
    print("Loading model...")
    model, alphabet = pretrained.load_model_and_alphabet(model_name)

    # Validate the requested layer before doing any work. Otherwise a bad layer
    # only surfaces as a KeyError after the first forward pass. Valid indices are
    # -(num_layers + 1)..num_layers, as in ESM's own extract script. Negative
    # indices are normalised to their positive equivalent.
    num_layers = model.num_layers
    if not -(num_layers + 1) <= repr_layer <= num_layers:
        raise ValueError(
            f"repr_layer={repr_layer} is out of range for {model_name} "
            f"(valid: {-(num_layers + 1)}..{num_layers})"
        )
    repr_layer = (repr_layer + num_layers + 1) % (num_layers + 1)

    # Evaluation mode disables dropout. Use the GPU when one is available.
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # The batch converter tokenizes sequences and truncates each one to max_seq_len residues.
    # The sequence strings it returns are NOT truncated.
    batch_converter = alphabet.get_batch_converter(max_seq_len)

    # Create a dataset from the input FASTA file
    dataset = FastaBatchedDataset.from_file(pathlib.Path(fasta_file_path))

    # Generate indices to form batches based on token length
    batches = dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)

    # Iterate using the precomputed batch indices; collate_fn turns each batch into tokens
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=batch_converter,
        batch_sampler=batches,
    )

    # Parallel lists that become the CSV columns
    sequence_ids = []  # sequence headers (SeqID)
    sequences = []     # raw (untruncated) sequence strings
    embeddings = []    # one pooled embedding vector per sequence

    print(f"Extracting embeddings using model: {model_name}, layer: {repr_layer}")

    # Inference only: no gradients needed
    with torch.no_grad():
        for batch_idx, (labels, seqs, tokens) in enumerate(data_loader):
            print(f"Processing batch {batch_idx + 1} of {len(batches)}...")

            tokens = tokens.to(device, non_blocking=True)

            # Forward pass, then pull the requested layer back to the CPU
            out = model(tokens, repr_layers=[repr_layer], return_contacts=False)
            reps = out["representations"][repr_layer].cpu()

            for i, (label, sequence) in enumerate(zip(labels, seqs)):
                # A blank header has no first token; keep the raw label instead of raising IndexError
                header_fields = label.split()
                entry_id = header_fields[0] if header_fields else label

                # Only the residues that survived truncation have embeddings.
                # Slicing by the full len(sequence) would average in the EOS and
                # padding positions of any sequence longer than max_seq_len.
                # Mirror the converter's own check: a falsy limit means no truncation.
                n_residues = len(sequence)
                if max_seq_len:
                    n_residues = min(n_residues, max_seq_len)

                # Mean-pool over residue positions only. Position 0 is skipped
                # (start token) and the slice stops before the end token.
                rep = reps[i, 1:n_residues + 1].mean(0).numpy()

                sequence_ids.append(entry_id)
                sequences.append(sequence)
                embeddings.append(rep)

    # One column per embedding dimension, with SeqID and Sequence prepended
    df = pd.DataFrame(embeddings)
    df.insert(0, 'Sequence', sequences)
    df.insert(0, 'SeqID', sequence_ids)

    # Save the DataFrame as a CSV file
    output_csv_path = pathlib.Path(output_csv_path)
    df.to_csv(output_csv_path, index=False)

    # Notify user of successful save
    print(f"Saved embeddings to: {output_csv_path}")

# === Example Usage ===
if __name__ == "__main__":
    # Demo run: embed every record in example.fasta with the
    # esm2_t30_150M_UR50D checkpoint, taking per-sequence vectors from layer 30,
    # and write them to esm2_t30_embeddings.csv.
    demo_config = dict(
        fasta_file_path="example.fasta",
        output_csv_path="esm2_t30_embeddings.csv",
        model_name="esm2_t30_150M_UR50D",
        repr_layer=30,
    )
    extract_esm_embeddings_to_csv(**demo_config)


Loading model...
Extracting embeddings using model: esm2_t30_150M_UR50D, layer: 30
Processing batch 1 of 1...
Saved embeddings to: esm2_t30_embeddings.csv
