In [10]:
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import numpy as np
import os

In [11]:
# Constants
# Input table (read by load_protein_data: needs NCBI_gene_id and Sequence columns)
# and the .npz archive the embeddings are written to.
DATA_PATH = "./data/protein_info.csv"
OUTPUT_PATH = "./esm_bluebert/esm_embeddings.npz"

# Hugging Face checkpoint and inference settings.
MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
BATCH_SIZE = 16
MAX_SEQ_LENGTH = 1024  # tokenizer truncation length (in tokens)

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [12]:
def load_model_and_tokenizer(model_name, device):
    """Fetch the pretrained tokenizer and encoder named ``model_name``.

    The encoder is moved to ``device``; when ``device`` is exactly "cuda" its
    weights are additionally cast to float16.

    Returns:
        (tokenizer, model) tuple.
    """
    print(f"Loading model {model_name} on {device}...")
    tok = AutoTokenizer.from_pretrained(model_name)
    encoder = AutoModel.from_pretrained(model_name).to(device)
    if device == "cuda":
        # Half precision halves GPU memory for inference.
        encoder = encoder.half()
    return tok, encoder

def load_protein_data(data_path):
    """Load the protein table and return one (gene id, sequence) pair per gene.

    Reads the CSV at ``data_path``, drops rows with no ``NCBI_gene_id`` and keeps
    only the first row for each remaining gene id.

    Returns:
        (gene_ids, sequences): two equal-length lists of str. A missing ``Sequence``
        is returned as "" (never the literal string "nan"), so the caller's
        blank-sequence filter can discard it instead of embedding garbage.
    """
    print(f"Loading data from {data_path}...")
    # Read ids as text: a single blank id cell would otherwise make pandas parse the
    # whole column as float, and ids would stringify as "123.0" / "nan".
    proteins_raw = pd.read_csv(data_path, dtype={"NCBI_gene_id": str})
    proteins = (
        proteins_raw
        .dropna(subset=["NCBI_gene_id"])
        .drop_duplicates(subset=["NCBI_gene_id"], keep="first")
    )
    gene_ids = proteins["NCBI_gene_id"].tolist()
    # fillna("") before astype(str): astype(str) would turn NaN into the string "nan".
    sequences = proteins["Sequence"].fillna("").astype(str).tolist()
    print(f"Found {len(gene_ids)} unique proteins")
    return gene_ids, sequences

def generate_embeddings(batch_sequences, tokenizer, model, device, max_length, pooling="cls"):
    """Embed a batch of protein sequences, one vector per sequence.

    Args:
        batch_sequences: list of sequence strings, tokenized together (padded to the longest).
        tokenizer: Hugging Face tokenizer paired with ``model``.
        model: encoder whose output exposes ``last_hidden_state``.
        device: device string or ``torch.device``. Autocast is enabled for any CUDA
            device ("cuda", "cuda:1", ...), not only the literal string "cuda".
        max_length: maximum token count per sequence; longer inputs are truncated.
        pooling: "cls" (default, original behaviour) returns the hidden state of the
            first token; "mean" averages hidden states over attended, non-special
            tokens (excluding BOS/EOS-style special tokens and padding).

    Returns:
        2-D float32 NumPy array with one row per input sequence, or None if
        tokenization or the forward pass raised (the error is printed and the caller
        skips the batch).

    Raises:
        ValueError: if ``pooling`` is not "cls" or "mean". Checked before the try block
            so a programming error is not swallowed as a failed batch.
    """
    import contextlib

    if pooling not in ("cls", "mean"):
        raise ValueError(f"Unknown pooling {pooling!r}; expected 'cls' or 'mean'")

    try:
        # Compare the device *type* so indexed devices like "cuda:0" also get autocast.
        use_autocast = torch.device(device).type == "cuda"
        amp_ctx = (
            torch.amp.autocast(device_type="cuda") if use_autocast else contextlib.nullcontext()
        )

        inputs = tokenizer(
            batch_sequences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
            # The special-token mask is only needed to exclude BOS/EOS when mean-pooling.
            return_special_tokens_mask=(pooling == "mean"),
        ).to(device)

        with torch.no_grad(), amp_ctx:
            outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        hidden = outputs.last_hidden_state

        if pooling == "cls":
            pooled = hidden[:, 0, :]
        else:
            keep = inputs["attention_mask"].bool() & ~inputs["special_tokens_mask"].bool()
            keep = keep.unsqueeze(-1).to(torch.float32)
            # Accumulate in float32: long fp16 sums can overflow.
            summed = (hidden.to(torch.float32) * keep).sum(dim=1)
            pooled = summed / keep.sum(dim=1).clamp(min=1.0)

        return pooled.cpu().float().numpy()

    except Exception as e:
        # Best-effort: report and let the caller skip this batch.
        print(f"Error processing batch: {str(e)}")
        return None

def save_embeddings(output_path, gene_ids, embeddings):
    """Save gene ids and their embeddings to a NumPy ``.npz`` archive.

    Args:
        output_path: destination file (``np.savez`` appends ".npz" if missing). Its
            parent directory is created if needed, so a fresh checkout does not fail
            only after the embeddings have already been computed.
        gene_ids: ids, one per embedding row; stored under key "gene_ids".
        embeddings: array stored under key "embeddings".

    Raises:
        ValueError: if the number of ids differs from the number of embedding rows.
    """
    if len(gene_ids) != len(embeddings):
        raise ValueError(
            f"{len(gene_ids)} gene ids but {len(embeddings)} embedding rows"
        )
    print(f"Saving embeddings to {output_path}...")
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    np.savez(
        output_path,
        gene_ids=np.array(gene_ids),
        embeddings=embeddings
    )

def embeddings_exist(output_path):
    """Report whether anything (file or directory) already exists at ``output_path``.

    Returns:
        bool: True if the path exists, otherwise False.
    """
    found = os.path.exists(output_path)
    return found

Main execution

In [13]:
def main():
    """Embed every protein in DATA_PATH with MODEL_NAME and save the result to OUTPUT_PATH.

    Does nothing if OUTPUT_PATH already exists. Proteins with a blank or non-string
    sequence, and proteins in a batch for which ``generate_embeddings`` returned
    nothing, are left out of the saved file; their count is printed.

    Raises:
        ValueError: if no embeddings were generated at all.
    """
    print(f"Using device: {DEVICE}")

    # Cache guard: skip the expensive run when the output file is already there.
    if embeddings_exist(OUTPUT_PATH):
        print(f"Embeddings already exist at {OUTPUT_PATH}. Skipping generation.")
        return

    tokenizer, model = load_model_and_tokenizer(MODEL_NAME, DEVICE)
    gene_ids, sequences = load_protein_data(DATA_PATH)

    all_embeddings = []
    valid_gene_ids = []
    n_skipped = 0

    # Context manager closes the bar even if a batch raises or the kernel is interrupted.
    with tqdm(total=len(sequences), desc="Processing proteins") as progress:
        for start in range(0, len(sequences), BATCH_SIZE):
            batch_seqs = sequences[start:start + BATCH_SIZE]
            batch_ids = gene_ids[start:start + BATCH_SIZE]

            # Drop non-string / blank sequences while keeping ids paired with sequences.
            valid_pairs = [
                (gid, seq)
                for gid, seq in zip(batch_ids, batch_seqs)
                if isinstance(seq, str) and len(seq.strip()) > 0
            ]
            n_skipped += len(batch_seqs) - len(valid_pairs)

            if valid_pairs:
                valid_ids = [gid for gid, _ in valid_pairs]
                valid_seqs = [seq for _, seq in valid_pairs]
                embeddings = generate_embeddings(
                    valid_seqs,
                    tokenizer,
                    model,
                    DEVICE,
                    MAX_SEQ_LENGTH
                )
                if embeddings is not None and len(embeddings) > 0:
                    all_embeddings.append(embeddings)
                    valid_gene_ids.extend(valid_ids)
                else:
                    # Failed batch: generate_embeddings already printed the error.
                    n_skipped += len(valid_pairs)

            progress.update(len(batch_seqs))

    if len(all_embeddings) == 0:
        raise ValueError("No embeddings generated. Check input data and model.")

    final_embeddings = np.concatenate(all_embeddings, axis=0)
    # Ids and rows were appended in lockstep; fail loudly rather than save a misaligned file.
    assert final_embeddings.shape[0] == len(valid_gene_ids), "gene ids and embeddings are misaligned"
    save_embeddings(OUTPUT_PATH, valid_gene_ids, final_embeddings)

    if n_skipped:
        print(f"Skipped {n_skipped} proteins (blank sequence or failed batch)")
    print(f"Successfully processed {len(valid_gene_ids)} proteins")

# Inside a notebook kernel __name__ is "__main__", so executing this cell runs main().
if __name__ == "__main__":
    main()

Using device: cuda
Loading model facebook/esm2_t30_150M_UR50D on cuda...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading data from ./data/protein_info.csv...
Found 9820 unique proteins


Processing proteins: 100%|██████████| 9820/9820 [03:25<00:00, 47.72it/s] 

Saving embeddings to ./esm_bluebert/esm_embeddings.npz...
Successfully processed 9820 proteins



