In [None]:
#@title # Protein Language Model Embeddings Generator 🧬

In [None]:
#@title 1. Install Required Dependencies
#@markdown Run this cell first to install the necessary packages
%%capture
!pip install pyfaidx transformers torch pandas h5py numpy tqdm


In [None]:
#@title ### 2. Import Libraries and Setup
from pathlib import Path
import h5py
import numpy as np
import pandas as pd
import torch
from pyfaidx import Fasta
from tqdm import tqdm
from transformers import AutoTokenizer, EsmModel, T5EncoderModel, T5Tokenizer
from google.colab import files
import os

In [None]:
#@title ### 3. Select Model and Upload File

#@markdown Choose a protein language model:
model_name = "Rostlab/prot_t5_xl_half_uniref50-enc" #@param ["Rostlab/prot_t5_xl_half_uniref50-enc", "ElnaggarLab/ankh-base", "ElnaggarLab/ankh-large", "facebook/esm2_t6_8M_UR50D", "facebook/esm2_t12_35M_UR50D", "facebook/esm2_t30_150M_UR50D", "facebook/esm2_t33_650M_UR50D"]

#@markdown Choose embedding type:
embedding_type = "per_prot" #@param ["per_prot", "per_res"]

# @markdown Upload your FASTA file:
from google.colab import files
uploaded = files.upload()
# An empty mapping means nothing was uploaded (e.g. the dialog was cancelled);
# fail with a clear message instead of an IndexError on the lookup below.
if not uploaded:
    raise ValueError("No file was uploaded. Re-run this cell and select a FASTA file.")
# Only the first uploaded file is used.
fasta_filename = next(iter(uploaded))

#Hidden helper functions
def seq_preprocess(df, model_type="esm"):
    """Return a copy of ``df`` whose ``sequence`` column is cleaned for tokenization.

    The residue codes U, Z and O are replaced with X. When ``model_type`` is
    ``"pt"`` the residues of each sequence are additionally separated by single
    spaces.

    Args:
        df: DataFrame with a string ``sequence`` column.
        model_type: Model family tag; only ``"pt"`` triggers space separation.

    Returns:
        A new DataFrame with the processed ``sequence`` column. The caller's
        frame is left untouched, so processing the same frame twice cannot
        double-space its sequences.
    """
    cleaned = df["sequence"].str.replace("[UZO]", "X", regex=True)
    if model_type == "pt":
        # Element-wise join rather than a row-wise DataFrame.apply: it always
        # yields a Series, including for an empty frame.
        cleaned = cleaned.map(" ".join)
    return df.assign(sequence=cleaned)
#@markdown ---
def setup_model(checkpoint):
    """Load the tokenizer and encoder for ``checkpoint`` and move the model to the device.

    The model family is chosen from the checkpoint name (case-insensitive):
    names containing ``"esm"`` load an ``EsmModel``, names containing
    ``"ankh"`` load a ``T5EncoderModel`` with ``AutoTokenizer``, and anything
    else is treated as a ProtT5-style checkpoint (``T5Tokenizer`` +
    ``T5EncoderModel``).

    Args:
        checkpoint: Hugging Face model identifier or local path.

    Returns:
        Tuple ``(model, tokenizer, mod_type)`` where ``model`` already lives on
        CUDA when available (CPU otherwise) and ``mod_type`` is one of
        ``"esm"``, ``"ankh"`` or ``"pt"``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    name = checkpoint.lower()
    if "esm" in name:
        mod_type = "esm"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = EsmModel.from_pretrained(checkpoint)
    elif "ankh" in name:
        mod_type = "ankh"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = T5EncoderModel.from_pretrained(checkpoint)
    else:
        mod_type = "pt"
        tokenizer = T5Tokenizer.from_pretrained(checkpoint)
        # Half precision is only used on GPU; on CPU the weights are loaded in
        # full precision because fp16 inference there is unsupported/unreliable.
        dtype = torch.float16 if device.type == "cuda" else torch.float32
        model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=dtype)
    return model.to(device), tokenizer, mod_type

def read_fasta(file_path):
    """Read every record of a FASTA file into parallel header/sequence lists.

    Args:
        file_path: Path (``str`` or ``Path``) of the FASTA file.

    Returns:
        Tuple ``(headers, sequences)``: ``headers`` holds each record's
        ``name`` and ``sequences`` the matching sequence as a plain string,
        both in file iteration order.
    """
    # The context manager releases the underlying file handle once all records
    # have been copied out as plain strings.
    with Fasta(str(file_path)) as fasta:
        records = [(record.name, str(record)) for record in fasta]
    headers = [name for name, _ in records]
    sequences = [sequence for _, sequence in records]
    return headers, sequences

def create_embedding(checkpoint, df, emb_type="per_prot", output_file="protein_embeddings.h5"):
    """Embed every sequence in ``df`` and store the results in an HDF5 file.

    The file is opened in append mode and one dataset is written per row,
    named after the row's ``header``. Headers already present in the file are
    skipped, so an interrupted run can be resumed.

    Args:
        checkpoint: Model identifier passed to ``setup_model``.
        df: DataFrame with ``header`` and ``sequence`` columns.
        emb_type: ``"per_prot"`` stores one vector per sequence (mean over all
            token positions returned by the model, special tokens included);
            ``"per_res"`` stores a ``(residues, hidden_dim)`` matrix with the
            special tokens removed.
        output_file: Path of the HDF5 file to create or extend.

    Raises:
        ValueError: If ``emb_type`` is neither ``"per_prot"`` nor ``"per_res"``.
            Raised before the model is loaded.
    """
    # Validate up front so a typo does not cost a full model download/load.
    if emb_type not in ("per_prot", "per_res"):
        raise ValueError("Input valid embedding type")

    print("Setting up model...")
    model, tokenizer, mod_type = setup_model(checkpoint)
    model.eval()
    df = seq_preprocess(df, mod_type)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def compute_embedding(sequence, emb_type):
        """Run one sequence through the model and reduce it per ``emb_type``."""
        inputs = tokenizer(
            sequence,
            return_tensors="pt",
            max_length=10_000,
            truncation=True,
            padding=True,
            add_special_tokens=True,
        ).to(device)
        with torch.no_grad():
            # Shape (1, tokens, hidden_dim): a single sequence is one batch entry.
            hidden = model(**inputs).last_hidden_state.cpu().numpy()
        if emb_type == "per_prot":
            return hidden.mean(axis=1).flatten()
        # per_res: drop the batch axis first, then the special tokens. Slicing
        # the 3-D array directly would cut the batch axis and return nothing.
        residues = hidden[0]
        if mod_type == "esm":
            # ESM tokenizers wrap the sequence as <cls> ... <eos>.
            return residues[1:-1, :]
        # T5-style tokenizers (ProtT5, Ankh) only append an end-of-sequence token.
        return residues[:-1, :]

    try:
        print("Generating embeddings...")
        with h5py.File(output_file, "a") as hdf:
            pairs = zip(df["header"], df["sequence"])
            for header, sequence in tqdm(pairs, total=len(df)):
                if header in hdf:
                    # Already embedded in a previous (possibly interrupted) run.
                    continue
                embedding = compute_embedding(sequence, emb_type)
                hdf.create_dataset(name=header, data=embedding)
    finally:
        # Drop the local references even when embedding fails, so the cached
        # GPU memory can actually be returned by empty_cache().
        del model, tokenizer
        torch.cuda.empty_cache()

In [None]:
#@title ### 4. Generate Embeddings
#@markdown Click the play button to start generating embeddings

fasta_path = Path(fasta_filename)
# Encode the model and embedding type in the output name. create_embedding()
# appends to an existing file and skips headers that are already stored, so a
# name derived from the FASTA file alone would silently hand back embeddings
# from a previous run made with a different model or embedding type.
model_tag = model_name.split("/")[-1]
output_file = str(fasta_path.with_name(f"{fasta_path.stem}_{model_tag}_{embedding_type}.h5"))

# Create DataFrame and generate embeddings
headers, sequences = read_fasta(fasta_path)
df = pd.DataFrame({"header": headers, "sequence": sequences})
print(f"Processing {len(df)} sequences...")

create_embedding(
    model_name,
    df,
    emb_type=embedding_type,
    output_file=output_file
)

print(f"\nEmbeddings saved to {output_file}")

In [None]:
#@title ### 5. Download Results
#@markdown Run this cell to download your embeddings file
# `output_file` is the HDF5 path defined in step 4, so that cell must be run first.
files.download(output_file)