In [None]:
# Setup directories and install dependencies
!mkdir -p "protbert_embeddings"
%pip install biopython

In [None]:
# Import required libraries
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
from Bio.Seq import Seq
import re
from tqdm import tqdm

In [4]:
# ProtBERT embedding model class
class ProtBERTEmbeddingModel:
    """Translate DNA open reading frames to protein and embed them with ProtBERT.

    Each protein is tokenized as space-separated residues and run through the
    BERT encoder. It is then summarised as the attention-mask-weighted mean of
    the final hidden layer, giving one ``hidden_size`` vector per sequence.

    Args:
        model_name: Hugging Face model identifier or local path.
        max_length: Maximum number of tokens per sequence, including the
            special tokens the tokenizer adds; longer proteins are truncated.
    """

    # Residue codes mapped to "X" before tokenization:
    #   - U, Z, O, B: the rare/ambiguous amino acids the ProtBERT model card maps to X.
    #   - J: the Leu/Ile ambiguity code, which Biopython may emit for ambiguous
    #     codons and which is presumably absent from the ProtBERT vocabulary.
    #   - "*": any stop codon remaining inside the sequence.
    _UNSUPPORTED_RESIDUES = re.compile(r"[UZOBJ*]")

    def __init__(self, model_name="Rostlab/prot_bert", max_length=512):
        self.model_name = model_name
        self.max_length = max_length
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Keep residue letters upper-case, matching ProtBERT's vocabulary.
        self.tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)
        # Only last_hidden_state is consumed, so per-layer hidden states are not
        # requested: keeping every layer's output alive for each batch costs a
        # large amount of (GPU) memory for nothing.
        self.model = BertModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

    @classmethod
    def _clean_protein_sequence(cls, protein_sequence):
        """Drop trailing stop symbols and map unsupported residues to "X"."""
        # A terminal "*" is the ORF's stop codon, not a residue; encoding it as
        # "X" would append a spurious unknown residue to every protein.
        trimmed = protein_sequence.rstrip("*")
        return cls._UNSUPPORTED_RESIDUES.sub("X", trimmed)

    def translate_dna_to_protein(self, dna_sequence):
        """Translate a DNA ORF into a ProtBERT-ready protein string.

        Whitespace is removed and any trailing partial codon is dropped. The
        remainder is translated with Biopython's default (standard) codon
        table. A terminal stop codon is removed, and unsupported residues
        become "X".

        Args:
            dna_sequence: Nucleotide string. Non-string values (e.g. NaN from a
                missing CSV cell) are treated as absent.

        Returns:
            The cleaned protein sequence, or "" when there is no complete codon.
        """
        if not isinstance(dna_sequence, str):
            return ""
        # Embedded whitespace/newlines would otherwise break translation.
        dna_sequence = "".join(dna_sequence.split())
        # Keep a whole number of codons; Biopython warns about partial codons.
        usable_length = len(dna_sequence) - len(dna_sequence) % 3
        if usable_length == 0:
            return ""
        protein_sequence = str(Seq(dna_sequence[:usable_length]).translate())
        return self._clean_protein_sequence(protein_sequence)

    def format_protein_sequence(self, protein_sequence):
        """Space-separate residues ("MKV" -> "M K V") as ProtBERT's tokenizer expects."""
        if not protein_sequence:
            return ""
        return " ".join(protein_sequence)

    def encode_sequences(self, protein_sequences):
        """Tokenize protein sequences into a batch of tensors on ``self.device``.

        Sequences are padded only to the longest member of the batch rather
        than always to ``max_length``, and truncated at ``max_length`` tokens.
        Padding positions are masked out of both attention and pooling. This
        saves compute and should leave the pooled embeddings unchanged apart
        from floating-point noise.

        Returns:
            Dict of tokenizer outputs (e.g. ``input_ids``, ``attention_mask``)
            moved to ``self.device``.
        """
        formatted_sequences = [self.format_protein_sequence(seq) for seq in protein_sequences]

        encodings = self.tokenizer(
            formatted_sequences,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_length,
        )

        return {key: tensor.to(self.device) for key, tensor in encodings.items()}

    def get_embeddings(self, encodings):
        """Mean-pool the final hidden layer over non-padding tokens.

        The mask is 1 for every non-padding position, so the mean also covers
        the special tokens the tokenizer adds (e.g. [CLS]/[SEP]).

        Returns:
            NumPy array of shape (batch, hidden_size) on the CPU.
        """
        with torch.no_grad():
            outputs = self.model(**encodings)
            last_hidden_states = outputs.last_hidden_state
            # (batch, seq_len) -> (batch, seq_len, 1) so it broadcasts over the hidden dim.
            token_mask = encodings["attention_mask"].unsqueeze(-1).float()
            summed = (last_hidden_states * token_mask).sum(dim=1)
            # Clamp guards against division by zero for an all-padding row.
            token_counts = token_mask.sum(dim=1).clamp(min=1e-9)
            return (summed / token_counts).cpu().numpy()

    def process_batch(self, protein_sequences):
        """Tokenize and embed a list of protein strings; returns (batch, hidden_size)."""
        encodings = self.encode_sequences(protein_sequences)
        return self.get_embeddings(encodings)

In [None]:
# Initialize ProtBERT model and report where it runs and its embedding width
embedding_model = ProtBERTEmbeddingModel()
model_config = embedding_model.model.config
print(f"Using device: {embedding_model.device}")
print(f"Model hidden size: {model_config.hidden_size}")

In [None]:
# Load and preprocess data
DATA_PATH = "data/human_sequence_data.csv"
REQUIRED_COLUMNS = ["ORF"]

sequence_dataframe = pd.read_csv(DATA_PATH, usecols=REQUIRED_COLUMNS)

# Translate DNA sequences to protein sequences. Untranslatable rows (missing or
# too-short ORFs) become "" rather than being dropped, so embedding row i stays
# aligned with sequence_dataframe row i.
protein_sequences = [
    embedding_model.translate_dna_to_protein(orf_sequence)
    for orf_sequence in sequence_dataframe["ORF"]
]
assert len(protein_sequences) == len(sequence_dataframe)

nonempty_lengths = [len(seq) for seq in protein_sequences if seq]
empty_count = len(protein_sequences) - len(nonempty_lengths)
# The tokenizer spends two of the max_length slots on [CLS] and [SEP]; longer
# proteins are silently truncated when they are encoded.
residue_limit = embedding_model.max_length - 2
truncated_count = sum(length > residue_limit for length in nonempty_lengths)

print(f"Loaded {len(protein_sequences)} protein sequences")
if nonempty_lengths:
    print(f"Average protein length: {np.mean(nonempty_lengths):.1f}")
else:
    # np.mean([]) would print "nan" and emit a RuntimeWarning.
    print("No ORF could be translated into a protein sequence")
if empty_count:
    print(f"{empty_count} ORFs produced an empty protein sequence")
if truncated_count:
    print(f"{truncated_count} proteins exceed {residue_limit} residues and will be truncated")

In [None]:
# Configure batch processing parameters
BATCH_SIZE = 100

dataset_size = len(protein_sequences)
# Ceiling division via negated floor division: a final, partially filled batch
# still counts as one batch.
total_batches = -(-dataset_size // BATCH_SIZE)

print(f"Dataset size: {dataset_size}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Total batches: {total_batches}")

In [None]:
# Generate embeddings in batches, writing one .npy file per batch
import os

EMBEDDINGS_DIR = "protbert_embeddings"
# Set True to keep batch files already on disk (e.g. after an interrupted run).
# Only safe when the input data, BATCH_SIZE and model settings are unchanged
# since those files were written.
RESUME_FROM_EXISTING = False
# Zero-pad batch indices to a fixed width (at least 4 digits) so file names still
# sort in batch order when there are 10,000 or more batches.
index_width = max(4, len(str(max(total_batches - 1, 0))))

for batch_index in tqdm(range(total_batches), desc="Processing batches"):
    embeddings_path = f"{EMBEDDINGS_DIR}/batch_{batch_index:0{index_width}d}.npy"
    if RESUME_FROM_EXISTING and os.path.exists(embeddings_path):
        continue

    start_index = batch_index * BATCH_SIZE
    end_index = min(start_index + BATCH_SIZE, dataset_size)
    batch_protein_sequences = protein_sequences[start_index:end_index]

    batch_embeddings = embedding_model.process_batch(batch_protein_sequences)
    np.save(embeddings_path, batch_embeddings)

    # Release per-batch memory; the CUDA cache only matters when running on GPU.
    del batch_embeddings, batch_protein_sequences
    if embedding_model.device.type == "cuda":
        torch.cuda.empty_cache()

print("Embedding generation completed!")

In [None]:
# Verify generated embeddings
sample_embeddings = np.load("protbert_embeddings/batch_0000.npy")

print(f"Sample embeddings shape: {sample_embeddings.shape}")
print(f"Embedding dimension: {sample_embeddings.shape[1]}")
print(f"Batch size: {sample_embeddings.shape[0]}")
print(f"Sample embedding norm: {np.linalg.norm(sample_embeddings[0]):.3f}")

# Check the first batch against the run configuration, so stale or corrupt
# files fail loudly instead of only being printed.
expected_shape = (min(BATCH_SIZE, dataset_size), embedding_model.model.config.hidden_size)
assert sample_embeddings.shape == expected_shape, (
    f"expected {expected_shape}, got {sample_embeddings.shape}"
)
assert np.isfinite(sample_embeddings).all(), "embeddings contain NaN or inf"

In [None]:
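# Optional: combine the per-batch .npy files into one array using the external
# merge_batches.py script (not part of this notebook). Uncomment to run.
# !python merge_batches.py -b "protbert_embeddings/" -o "protbert_embeddings/merged_embeddings.npy"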