In [None]:
!pip install transformers
!pip install sentencepiece
!pip install biopython
!pip install h5py
!pip install torch
!pip install tqdm

In [None]:
import torch
from Bio import SeqIO
from transformers import T5Tokenizer, T5EncoderModel
import numpy as np
from pathlib import Path
import gc
import pickle
import time
import os
from tqdm import tqdm
import shutil
import json
import re

In [None]:
# Model identifier: either a Hugging Face Hub ID (e.g. "Rostlab/prot_t5_xl_uniref50")
# or a filesystem path to a locally saved copy of the same model.
model_name = "Rostlab/prot_t5_xl_uniref50"  # or use local path like "./embedding_model/Rostlab_prot_t5_xl_uniref50"

# Tokenizer first (do_lower_case=False is forwarded to from_pretrained).
print(f"Loading tokenizer for {model_name}...")
tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False)

# Encoder-only T5 model: only hidden states are needed for embeddings.
print(f"Loading model {model_name}...")
model = T5EncoderModel.from_pretrained(model_name)

# Prefer the GPU when one is visible, otherwise run on the CPU; switch the
# model to evaluation mode (disables dropout) in the same expression.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device).eval()

print(f"Model loaded successfully on {device}")

In [None]:
# Input FASTA file with the protein sequences to embed.
fasta_path = "train_sequences.fasta"  # Update this path

# Materialize every record as an (id, sequence-string) pair, in file order.
sequences = [(record.id, str(record.seq)) for record in SeqIO.parse(fasta_path, "fasta")]

print(f"Total sequences loaded: {len(sequences)}")

In [None]:
def prepare_sequence(seq):
    """
    Prepare a protein sequence for the ProtT5 tokenizer.

    ProtT5 expects residues separated by single spaces. Following the
    ProtTrans reference preprocessing, the rare/ambiguous amino acids
    U, Z, O and B are mapped to X. The sequence is also upper-cased, since
    the ProtT5 vocabulary uses upper-case residue letters (lower-case /
    soft-masked residues would presumably become unknown tokens).

    Args:
        seq: Amino-acid sequence as a string (e.g. ``str(record.seq)``).

    Returns:
        The normalized, space-separated sequence, e.g. ``"MKUV"`` -> ``"M K X V"``.
    """
    normalized = re.sub(r"[UZOB]", "X", seq.upper())
    # Iterating a str yields its characters, so no intermediate list is needed.
    return " ".join(normalized)


def extract_embeddings_batch_with_progress(sequence_list, model, tokenizer, device, batch_size=8):
    """
    Extract one mean-pooled embedding per sequence, in batches, with a progress bar.

    Each sequence is converted with ``prepare_sequence``, tokenized with
    ``add_special_tokens=True`` and padding to the longest sequence in its
    batch, and passed through ``model``. The hidden states are then averaged
    over the residue positions only. Padding positions and the single trailing
    ``</s>`` token that the T5 tokenizer appends are both excluded, which
    matches the ProtTrans reference slicing ``last_hidden_state[i, :len(seq)]``.

    Args:
        sequence_list: List of tuples ``(id, sequence)``.
        model: ProtTrans T5 encoder model (already on ``device``).
        tokenizer: T5 tokenizer matching ``model``.
        device: torch device the input tensors are moved to.
        batch_size: Number of sequences per forward pass.

    Returns:
        Dictionary mapping sequence ID -> 1-D numpy array (mean over residues
        of ``last_hidden_state``). If an ID occurs more than once, the last
        occurrence wins.
    """
    embeddings_dict = {}

    # Context manager guarantees the bar is closed even if a batch raises (e.g. CUDA OOM).
    with tqdm(total=len(sequence_list), desc="Processing sequences",
              unit="seq", ncols=100) as pbar:
        for i in range(0, len(sequence_list), batch_size):
            batch = sequence_list[i:i + batch_size]

            # Prepare sequences (space-separated residues)
            batch_ids = [seq_id for seq_id, _ in batch]
            batch_seqs = [prepare_sequence(seq) for _, seq in batch]

            # Tokenize sequences
            ids = tokenizer(batch_seqs, add_special_tokens=True, padding="longest", return_tensors="pt")
            input_ids = ids['input_ids'].to(device)
            attention_mask = ids['attention_mask'].to(device)

            # Extract per-token hidden states without building an autograd graph
            with torch.no_grad():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                embeddings = outputs.last_hidden_state

            # Number of non-padding tokens per row = residues + one trailing </s>.
            # One host sync per batch instead of one per sequence.
            token_counts = attention_mask.sum(dim=1).tolist()

            for j, (seq_id, n_tokens) in enumerate(zip(batch_ids, token_counts)):
                # Drop the final </s> token (assumes right padding, the T5 default).
                # Keep at least one position so an empty input does not yield a
                # NaN mean; such an entry falls back to the </s> embedding.
                n_residues = max(int(n_tokens) - 1, 1)
                seq_embedding = embeddings[j, :n_residues].mean(dim=0)
                embeddings_dict[seq_id] = seq_embedding.cpu().numpy()

            pbar.update(len(batch))

            # Periodically release cached GPU memory and Python garbage
            if i % (batch_size * 10) == 0:
                torch.cuda.empty_cache()
                gc.collect()

    return embeddings_dict

In [None]:
from collections import defaultdict

def group_sequences_by_length(sequences, length_bins=(100, 500, 1000, 5000, 10000, 20000)):
    """
    Group sequences into length bins.

    A sequence goes into bin ``i`` for the first ``length_bins[i]`` that is
    >= its length. Sequences longer than every edge go into bin
    ``len(length_bins)``.

    Args:
        sequences: Iterable of ``(id, sequence)`` tuples.
        length_bins: Inclusive upper bin edges, checked in order. The default is
            an immutable tuple, which avoids the shared mutable-default pitfall.
            Lists are still accepted.

    Returns:
        ``defaultdict(list)`` mapping bin index -> list of ``(id, sequence)``
        tuples, preserving input order within each bin. Bins that received no
        sequences have no key.
    """
    overflow_bin = len(length_bins)  # index for sequences longer than the last edge
    grouped = defaultdict(list)

    for seq_id, seq in sequences:
        seq_len = len(seq)
        # First bin whose upper edge fits the sequence, else the overflow bin.
        bin_idx = next(
            (i for i, upper in enumerate(length_bins) if seq_len <= upper),
            overflow_bin,
        )
        grouped[bin_idx].append((seq_id, seq))

    return grouped

# Inclusive upper bin edges (in amino acids). Defined once and passed explicitly,
# so the grouping and the printed bin labels cannot drift apart.
bins = [100, 500, 1000, 5000, 10000, 20000]

# Group sequences
grouped_sequences = group_sequences_by_length(sequences, length_bins=bins)

# Print distribution (bin index len(bins) holds sequences longer than the last edge)
print("Sequence length distribution:")
for idx in sorted(grouped_sequences.keys()):
    count = len(grouped_sequences[idx])
    if idx < len(bins):
        print(f"  Bin {idx} (≤{bins[idx]} aa): {count} sequences")
    else:
        print(f"  Bin {idx} (>{bins[-1]} aa): {count} sequences")

In [None]:
def _atomic_write_bytes(path, write_fn):
    """Write ``path`` through a temporary sibling file, then atomically rename it into place."""
    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, 'wb') as f:
            write_fn(f)
        # os.replace is atomic on the same filesystem, so a reader (or a later
        # resume) sees either the previous complete file or the new complete file.
        os.replace(tmp_path, path)
    except BaseException:
        # Also runs on KeyboardInterrupt (notebook "interrupt kernel").
        # Best-effort cleanup of the partial temp file, then re-raise.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def save_embeddings(embeddings_dict, output_dir, bin_idx=None, is_cumulative=False, is_final=False):
    """
    Save embeddings as both a pickle (protocol 4) and a compressed ``.npz`` file.

    File stem precedence: ``is_final`` -> "embeddings_final", then
    ``is_cumulative`` -> "embeddings_cumulative", then ``bin_idx`` ->
    "embeddings_bin_<bin_idx>", otherwise "embeddings".

    Each file is written under a temporary name and renamed into place. An
    interruption mid-write therefore cannot corrupt an existing checkpoint,
    such as the "embeddings_cumulative.pkl" file that ``load_progress`` reads.

    Args:
        embeddings_dict: Mapping of sequence ID -> numpy array. Keys become
            array names inside the ``.npz`` archive.
        output_dir: Directory to write into (created if missing).
        bin_idx: Bin index used in the file name when not final/cumulative.
        is_cumulative: Write the cumulative checkpoint files.
        is_final: Write the final output files (takes precedence).

    Returns:
        Tuple ``(pkl_path, npz_path)`` of the written files.
    """
    os.makedirs(output_dir, exist_ok=True)

    if is_final:
        filename = "embeddings_final"
    elif is_cumulative:
        filename = "embeddings_cumulative"
    elif bin_idx is not None:
        filename = f"embeddings_bin_{bin_idx}"
    else:
        filename = "embeddings"

    # Pickle copy (full Python dict)
    pkl_path = os.path.join(output_dir, f"{filename}.pkl")
    _atomic_write_bytes(pkl_path, lambda f: pickle.dump(embeddings_dict, f, protocol=4))

    # NPZ copy for easier loading in other tools. Passing an open file object
    # (not a name) stops numpy from appending ".npz" to the temp file name.
    npz_path = os.path.join(output_dir, f"{filename}.npz")
    _atomic_write_bytes(npz_path, lambda f: np.savez_compressed(f, **embeddings_dict))

    return pkl_path, npz_path

In [None]:
def load_progress(output_dir):
    """
    Return ``(metadata, embeddings)`` from a previous run's checkpoint.

    Both ``progress_metadata.pkl`` and ``embeddings_cumulative.pkl`` must exist
    in ``output_dir``. If either is missing, ``(None, {})`` is returned. If
    loading raises any ``Exception``, a warning is printed and ``(None, {})``
    is returned.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file. Only point this at checkpoints this notebook wrote itself.
    """
    metadata_path = os.path.join(output_dir, "progress_metadata.pkl")
    cumulative_path = os.path.join(output_dir, "embeddings_cumulative.pkl")

    # Guard clause: a resumable checkpoint needs both files.
    if not all(os.path.exists(p) for p in (metadata_path, cumulative_path)):
        return None, {}

    try:
        with open(metadata_path, 'rb') as meta_file:
            metadata = pickle.load(meta_file)
        with open(cumulative_path, 'rb') as emb_file:
            embeddings = pickle.load(emb_file)
    except Exception as e:
        print(f"Warning: Could not load progress: {e}")
        return None, {}

    return metadata, embeddings

In [None]:
# Directory for per-bin, cumulative and final embedding files.
output_dir = "./extracted_embeddings/ProtTrans-embeddings/"

# Resume from a previous run's checkpoint when one is available.
print("Checking for previous progress...")
metadata, all_embeddings = load_progress(output_dir)

if not metadata:
    print("No previous progress found. Starting from scratch.")
    all_embeddings = {}
    start_bin = 0
else:
    print(f"✓ Found previous progress!")
    print(f"  Last completed bin: {metadata['last_bin']}")
    print(f"  Sequences processed: {metadata['processed_sequences']}/{metadata['total_sequences']}")
    print(f"  Timestamp: {metadata['timestamp']}")
    # Continue with the bin after the last one that was fully checkpointed.
    start_bin = metadata['last_bin'] + 1

In [None]:
# Batch size per length bin (keys are bin indices from group_sequences_by_length).
# Longer sequences get smaller batches; ProtTrans models typically need smaller
# batch sizes than ESM due to larger model size. Unlisted bins default to 1 later.
batch_sizes = {
    0: 32,  # ≤100 aa    short
    1: 8,   # ≤500 aa    medium
    2: 4,   # ≤1000 aa   long
    3: 2,   # ≤5000 aa   very long
    4: 1,   # ≤10000 aa  extremely long
    5: 1,   # ≤20000 aa  ultra long
    6: 1,   # >20000 aa  longest
}

# Every grouped sequence, summed across bins.
total_sequences = sum(map(len, grouped_sequences.values()))

print(f"Total sequences to process: {total_sequences}")
print(f"Starting from bin: {start_bin}")

In [None]:
# Embed each length bin in ascending order, checkpointing after every bin so an
# interrupted run can resume from `start_bin`.
for bin_idx, seqs in sorted(grouped_sequences.items()):
    # Skip already processed bins
    if bin_idx < start_bin:
        print(f"Skipping Bin {bin_idx} (already processed)")
        continue

    batch_size = batch_sizes.get(bin_idx, 1)
    total_batches = (len(seqs) + batch_size - 1) // batch_size  # ceil division

    print(f"\n{'='*80}")
    print(f"BIN {bin_idx}: Processing {len(seqs)} sequences")
    print(f"Batch size: {batch_size} | Total batches: {total_batches}")
    print(f"{'='*80}\n")

    # Extract embeddings for this bin
    bin_embeddings = extract_embeddings_batch_with_progress(
        seqs, model, tokenizer, device, batch_size=batch_size
    )

    # Merge with cumulative embeddings (a duplicate sequence ID overwrites the earlier entry,
    # so processed_sequences can end up below total_sequences)
    all_embeddings.update(bin_embeddings)

    # Save bin-specific results
    print(f"\n💾 Saving bin {bin_idx} results...")
    pkl_path, npz_path = save_embeddings(bin_embeddings, output_dir, bin_idx=bin_idx)
    print(f"✓ Bin {bin_idx} saved: {pkl_path}")

    # Save cumulative results
    print(f"💾 Saving cumulative results...")
    pkl_path, npz_path = save_embeddings(all_embeddings, output_dir, is_cumulative=True)
    print(f"✓ Cumulative embeddings saved: {pkl_path}")

    # Save metadata for progress tracking. It is written after the cumulative
    # file, so an interrupt in between only causes this bin to be redone.
    metadata = {
        'last_bin': bin_idx,
        'processed_sequences': len(all_embeddings),
        'total_sequences': total_sequences,
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
    }
    metadata_path = os.path.join(output_dir, "progress_metadata.pkl")
    tmp_metadata_path = metadata_path + ".tmp"
    with open(tmp_metadata_path, 'wb') as f:
        pickle.dump(metadata, f)
    # Atomic rename: an interrupt mid-write can no longer leave a truncated
    # metadata file that makes load_progress discard all saved progress.
    os.replace(tmp_metadata_path, metadata_path)

    print(f"✓ Progress: {len(all_embeddings)}/{total_sequences} sequences processed")

    # Clear memory
    del bin_embeddings
    gc.collect()
    torch.cuda.empty_cache()

print(f"\n{'='*80}")
print("✅ All bins processed successfully!")
print(f"{'='*80}")

In [None]:
# Write the complete embedding set under the "embeddings_final" file stem.
print("💾 Creating final output files...")
pkl_path, npz_path = save_embeddings(all_embeddings, output_dir, is_final=True)

print(f"✓ Final embeddings saved:")
for saved_path in (pkl_path, npz_path):
    print(f"  - {saved_path}")
print(f"\n✅ Processing complete! Total embeddings: {len(all_embeddings)}")