In [1]:
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import numpy as np
import os

In [2]:
# Run configuration
# -- input / output locations
DATA_PATH = "./data/protein_info.csv"
OUTPUT_PATH = "./esm_bluebert/bluebert_embeddings.npz"

# -- pretrained checkpoint to load
MODEL_NAME = "bionlp/bluebert_pubmed_mimic_uncased_L-12_H-768_A-12"

# -- tokenization / batching limits
MAX_SEQ_LENGTH = 256
BATCH_SIZE = 32

# -- use the GPU when one is visible to torch, otherwise fall back to the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
def load_model_and_tokenizer():
    """Instantiate the pretrained tokenizer and encoder named by MODEL_NAME.

    The encoder is moved to DEVICE before it is handed back.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    print(f"Loading model {MODEL_NAME} on {DEVICE}...")
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    encoder = AutoModel.from_pretrained(MODEL_NAME)
    encoder = encoder.to(DEVICE)
    return tok, encoder

def load_and_preprocess_data(data_path):
    """Read the protein table and return parallel lists of gene IDs and names.

    Rows are de-duplicated on ``NCBI_gene_id`` (first occurrence wins). Missing
    protein names are returned as empty strings rather than the literal text
    ``"nan"``, so callers that filter on non-empty text can skip them instead
    of embedding a meaningless placeholder.

    Args:
        data_path: Path to a CSV file containing the ``NCBI_gene_id`` and
            ``Protein names`` columns.

    Returns:
        tuple[list[str], list[str]]: ``(gene_ids, texts)``, aligned by position.
    """
    print(f"Loading data from {data_path}...")
    df = pd.read_csv(data_path)
    df = df.drop_duplicates(subset=["NCBI_gene_id"], keep="first")

    gene_ids = df["NCBI_gene_id"].astype(str).tolist()
    # NaN.astype(str) yields "nan", which would pass a non-empty-string check
    # downstream; blank out missing names before the cast.
    texts = df["Protein names"].fillna("").astype(str).tolist()

    print(f"Found {len(gene_ids)} unique proteins")
    return gene_ids, texts

def generate_text_embeddings(tokenizer, model, batch_texts):
    """Embed a batch of strings and return one vector per string.

    Each text is tokenized (truncated to MAX_SEQ_LENGTH), run through the
    model without gradient tracking, and represented by the hidden state at
    the first token position (the [CLS] slot for BERT-style tokenizers).

    Args:
        tokenizer: Tokenizer callable compatible with ``model``.
        model: Encoder whose output exposes ``last_hidden_state``.
        batch_texts: List of strings to embed.

    Returns:
        numpy.ndarray | None: float32 array with one row per input text, or
        ``None`` if anything in tokenization or the forward pass raised (the
        error is printed so the caller can skip the batch and continue).
    """
    try:
        inputs = tokenizer(
            batch_texts,
            return_tensors="pt",
            # Pad only to the longest sequence in this batch. Padding every
            # batch to MAX_SEQ_LENGTH makes the forward pass pay for 256
            # positions even when the texts are a handful of tokens long;
            # padded positions are masked out either way.
            padding=True,
            truncation=True,
            max_length=MAX_SEQ_LENGTH,
            add_special_tokens=True,
        ).to(DEVICE)

        with torch.no_grad():
            if DEVICE == "cuda":
                # Mixed precision is only enabled on the GPU path.
                with torch.amp.autocast(device_type="cuda"):
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)

        # Autocast may produce half-precision activations; cast back to
        # float32 before converting to NumPy.
        return outputs.last_hidden_state[:, 0, :].cpu().float().numpy()

    except Exception as e:
        # Deliberate best-effort: report and let the caller drop this batch.
        print(f"Error processing batch: {str(e)}")
        return None

def process_texts_in_batches(tokenizer, model, gene_ids, texts):
    """Embed all texts in chunks of BATCH_SIZE and collect the results.

    Entries whose text is not a non-blank string are skipped, as are whole
    batches for which ``generate_text_embeddings`` returns ``None``. The
    returned IDs therefore only cover entries that actually got an embedding.

    Args:
        tokenizer: Tokenizer forwarded to ``generate_text_embeddings``.
        model: Model forwarded to ``generate_text_embeddings``.
        gene_ids: Identifiers aligned position-by-position with ``texts``.
        texts: Texts to embed.

    Returns:
        tuple[list, numpy.ndarray]: ``(valid_gene_ids, embeddings)`` where row
        ``i`` of ``embeddings`` belongs to ``valid_gene_ids[i]``.

    Raises:
        ValueError: If no batch produced any embedding.
    """
    all_embeddings = []
    valid_gene_ids = []

    # The context manager closes the bar even when a batch raises or the run
    # is interrupted, so a stale bar is not left behind in the notebook.
    with tqdm(total=len(texts), desc="Processing texts") as progress:
        for start in range(0, len(texts), BATCH_SIZE):
            batch_texts = texts[start:start + BATCH_SIZE]
            batch_ids = gene_ids[start:start + BATCH_SIZE]

            # Keep only entries with usable (non-blank string) text.
            valid_indices = [
                idx for idx, text in enumerate(batch_texts)
                if isinstance(text, str) and len(text.strip()) > 0
            ]
            valid_texts = [batch_texts[idx] for idx in valid_indices]
            valid_ids = [batch_ids[idx] for idx in valid_indices]

            if valid_texts:
                embeddings = generate_text_embeddings(tokenizer, model, valid_texts)
                # None signals a failed batch; drop it but keep going.
                if embeddings is not None and len(embeddings) > 0:
                    all_embeddings.append(embeddings)
                    valid_gene_ids.extend(valid_ids)

            progress.update(len(batch_texts))

    if len(all_embeddings) == 0:
        raise ValueError("No embeddings generated. Check input data and model.")

    # Make silent drops visible: the output may cover fewer entries than the input.
    skipped = len(texts) - len(valid_gene_ids)
    if skipped > 0:
        print(f"Skipped {skipped} of {len(texts)} entries (empty text or failed batch)")

    final_embeddings = np.concatenate(all_embeddings, axis=0)
    return valid_gene_ids, final_embeddings

def save_embeddings(output_path, gene_ids, embeddings):
    """Write gene IDs and their embeddings to an ``.npz`` archive.

    The archive holds two arrays: ``gene_ids`` and ``embeddings``. The parent
    directory of ``output_path`` is created if it does not exist yet.

    Note that ``np.savez`` appends ``.npz`` when the given filename lacks that
    extension, so pass a path ending in ``.npz`` if the same path is later
    used to check for the file.

    Args:
        output_path: Destination file path.
        gene_ids: Sequence of identifiers, one per embedding row.
        embeddings: Array of embeddings to store.
    """
    print(f"Saving embeddings to {output_path}...")

    # np.savez does not create missing directories; without this the whole
    # run fails at the very last step on a fresh checkout.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    np.savez(
        output_path,
        gene_ids=np.array(gene_ids),
        embeddings=embeddings
    )
    print(f"Successfully processed {len(gene_ids)} protein names")

def embeddings_exist(output_path):
    """Report whether something is already present at ``output_path``.

    Args:
        output_path: Filesystem path to probe.

    Returns:
        bool: ``True`` if the path exists, ``False`` otherwise.
    """
    already_present = os.path.exists(output_path)
    return already_present

In [4]:
def main():
    """Run the embedding pipeline end to end unless the output already exists.

    Steps: load the model and tokenizer, load the protein table, embed the
    protein names in batches, then write the result to OUTPUT_PATH.
    """
    # Cached result present: nothing to do.
    if embeddings_exist(OUTPUT_PATH):
        print(f"Embeddings already exist at {OUTPUT_PATH}. Skipping generation.")
        return

    tokenizer, model = load_model_and_tokenizer()
    gene_ids, texts = load_and_preprocess_data(DATA_PATH)

    # Only IDs that actually received an embedding come back from this step.
    kept_ids, embedding_matrix = process_texts_in_batches(
        tokenizer, model, gene_ids, texts
    )

    save_embeddings(OUTPUT_PATH, kept_ids, embedding_matrix)


if __name__ == "__main__":
    main()

Loading model bionlp/bluebert_pubmed_mimic_uncased_L-12_H-768_A-12 on cuda...
Loading data from ./data/protein_info.csv...
Found 9820 unique proteins


Processing texts: 100%|██████████| 9820/9820 [00:05<00:00, 1642.86it/s]

Saving embeddings to ./esm_bluebert/bluebert_embeddings.npz...
Successfully processed 9820 protein names



