In [None]:
# Setup: create the output directory for the per-batch embedding files and install
# `multimolecule`, which provides the RnaTokenizer / RnaBertModel classes imported below.
# NOTE(review): the install is unpinned; pin a version (`%pip install multimolecule==X.Y.Z`)
# so a later Restart & Run All reproduces the same tokenizer/model code.
!mkdir -p "rnabert_embeddings"
%pip install multimolecule

In [None]:
# Import required libraries
import os
import re

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from multimolecule import RnaTokenizer, RnaBertModel
from tqdm import tqdm

In [3]:
# RNABERT embedding model class
class RNABERTEmbeddingModel:
    """Wrap a pretrained RNABERT model to produce one fixed-size embedding per RNA sequence.

    Pipeline per sequence:
    1. Normalise: uppercase, T -> U, any other non-AUCG symbol -> "N".
    2. Tokenize without BOS/EOS tokens, padded/truncated to ``max_length`` tokens.
    3. Run the model.
    4. Mean-pool the last hidden state over non-padding positions.

    Args:
        model_name: Model id or path passed to ``from_pretrained``.
        max_length: Maximum number of tokens per sequence. Longer inputs are cut by
            the tokenizer (``truncation=True``), so only part of them is embedded.
    """

    def __init__(self, model_name="multimolecule/rnabert", max_length=440):
        self.model_name = model_name
        self.max_length = max_length
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # BOS/EOS are disabled, presumably so that only nucleotide tokens enter the
        # mean pooling below. TODO confirm against the multimolecule docs.
        self.tokenizer = RnaTokenizer.from_pretrained(
            model_name,
            bos_token=None,
            eos_token=None,
        )
        self.model = RnaBertModel.from_pretrained(
            model_name,
            bos_token_id=None,
            eos_token_id=None,
        )
        self.model = self.model.to(self.device)
        self.model.eval()  # inference only: put dropout etc. into eval mode

    def preprocess_rna_sequence(self, rna_sequence):
        """Normalise a sequence to the alphabet {A, U, C, G, N}.

        The input is uppercased and DNA ``T`` is converted to ``U``. Every other
        symbol outside ``AUCG`` becomes ``"N"``.

        Args:
            rna_sequence: Raw sequence. Non-string values are accepted, e.g. a
                float NaN for a missing cell.

        Returns:
            The normalised sequence, or ``""`` for non-string or empty input.
        """
        if not isinstance(rna_sequence, str) or len(rna_sequence) == 0:
            return ""

        # T -> U must happen *after* uppercasing. Otherwise a lowercase "t" becomes
        # "T" and the regex below turns it into "N" instead of "U".
        normalised = rna_sequence.upper().replace("T", "U")
        return re.sub(r"[^AUCG]", "N", normalised)

    def encode_sequences(self, rna_sequences):
        """Normalise and tokenize a batch, then move the tensors to ``self.device``.

        Every sequence is padded or truncated to exactly ``self.max_length`` tokens.

        Args:
            rna_sequences: Iterable of raw sequences.

        Returns:
            Dict of tokenizer output tensors (including ``attention_mask``) on ``self.device``.
        """
        preprocessed_sequences = [self.preprocess_rna_sequence(seq) for seq in rna_sequences]

        encodings = self.tokenizer(
            preprocessed_sequences,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

        return {key: tensor.to(self.device) for key, tensor in encodings.items()}

    def get_embeddings(self, encodings):
        """Run the model and mean-pool its last hidden state over non-padding tokens.

        Args:
            encodings: Output of ``encode_sequences``; must contain ``attention_mask``.

        Returns:
            NumPy array with one row per sequence and ``hidden_size`` columns. A row
            whose attention mask is all zeros comes out as all zeros, because the token
            count is clamped rather than divided by zero. That is presumably the case
            for an empty input sequence.
        """
        with torch.no_grad():
            outputs = self.model(**encodings)
            last_hidden_states = outputs.last_hidden_state
            attention_mask = encodings["attention_mask"]

            # Broadcast the (batch, seq_len) mask over the hidden dimension so that
            # padding positions contribute nothing to the sum.
            mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
            sum_embeddings = torch.sum(last_hidden_states * mask_expanded, dim=1)
            # Clamp avoids 0/0 for rows with no real tokens.
            sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
            mean_embeddings = sum_embeddings / sum_mask

            return mean_embeddings.cpu().numpy()

    def process_batch(self, rna_sequences):
        """Embed a list of sequences.

        Args:
            rna_sequences: Sequence (e.g. list) of raw RNA/DNA strings.

        Returns:
            Array of shape ``(len(rna_sequences), hidden_size)``.
        """
        if len(rna_sequences) == 0:
            # Don't call the tokenizer/model on an empty batch; return a correctly
            # shaped empty result instead.
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)

        encodings = self.encode_sequences(rna_sequences)
        return self.get_embeddings(encodings)


In [None]:
# Build the embedding wrapper once; every later cell reuses this instance.
embedding_model = RNABERTEmbeddingModel()
_model_config = embedding_model.model.config
print(f"Using device: {embedding_model.device}")
print(f"Model hidden size: {_model_config.hidden_size}")

In [None]:
# Load ORF sequences and convert them to the RNA alphabet expected by RNABERT
DATA_PATH = "data/human_sequence_data.csv"
REQUIRED_COLUMNS = ["ORF"]

sequence_dataframe = pd.read_csv(DATA_PATH, usecols=REQUIRED_COLUMNS)

# Uppercase *before* T -> U so that lowercase "t" is converted as well.
# Non-string cells (e.g. NaN for a missing ORF) become "" so that list positions
# stay aligned with dataframe rows.
rna_sequences = [
    orf_sequence.upper().replace("T", "U") if isinstance(orf_sequence, str) else ""
    for orf_sequence in sequence_dataframe["ORF"]
]

sequence_lengths = [len(seq) for seq in rna_sequences if seq]
print(f"Loaded {len(rna_sequences)} RNA sequences")
print(f"Empty/missing sequences: {len(rna_sequences) - len(sequence_lengths)}")
if sequence_lengths:  # np.mean([]) would warn and return nan
    print(f"Average RNA length: {np.mean(sequence_lengths):.1f}")
    # NOTE(review): assumes one token per nucleotide. Longer sequences are truncated
    # by the tokenizer, so only their first max_length bases are embedded.
    n_truncated = sum(length > embedding_model.max_length for length in sequence_lengths)
    print(f"Sequences longer than max_length ({embedding_model.max_length}): {n_truncated}")

In [None]:
# Batch processing parameters: sequences are embedded BATCH_SIZE at a time.
BATCH_SIZE = 100

dataset_size = len(rna_sequences)
# Ceiling division; the last batch may hold fewer than BATCH_SIZE sequences.
total_batches = -(-dataset_size // BATCH_SIZE)

print(
    f"Dataset size: {dataset_size}\n"
    f"Batch size: {BATCH_SIZE}\n"
    f"Total batches: {total_batches}"
)

In [None]:
# Generate embeddings in batches; each batch is saved as rnabert_embeddings/batch_NNNN.npy
EMBEDDINGS_DIR = "rnabert_embeddings"
# Create the directory here too, so this cell does not rely solely on the shell
# `mkdir` in the setup cell.
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)

for batch_index in tqdm(range(total_batches), desc="Processing batches"):
    start_index = batch_index * BATCH_SIZE
    end_index = min(start_index + BATCH_SIZE, dataset_size)
    batch_rna_sequences = rna_sequences[start_index:end_index]

    batch_embeddings = embedding_model.process_batch(batch_rna_sequences)
    # Row i of the batch file must correspond to sequence start_index + i.
    assert batch_embeddings.shape[0] == len(batch_rna_sequences), (
        f"batch {batch_index}: got {batch_embeddings.shape[0]} embeddings "
        f"for {len(batch_rna_sequences)} sequences"
    )

    embeddings_path = os.path.join(EMBEDDINGS_DIR, f"batch_{batch_index:04d}.npy")
    np.save(embeddings_path, batch_embeddings)

    del batch_embeddings
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached GPU memory between batches

# This run does not overwrite batch files left by an earlier run with more batches
# (larger dataset or smaller BATCH_SIZE). A later merge step could pick them up,
# so flag them rather than silently deleting anything.
expected_batch_files = {f"batch_{i:04d}.npy" for i in range(total_batches)}
stale_batch_files = sorted(
    name
    for name in os.listdir(EMBEDDINGS_DIR)
    if re.fullmatch(r"batch_\d+\.npy", name) and name not in expected_batch_files
)
if stale_batch_files:
    print(
        f"WARNING: {len(stale_batch_files)} stale batch file(s) in {EMBEDDINGS_DIR}/ "
        f"from a previous run (e.g. {stale_batch_files[0]}); remove them before merging."
    )

print("Embedding generation completed!")

In [None]:
# Verify generated embeddings: sanity-check the first batch file written above
sample_embeddings = np.load("rnabert_embeddings/batch_0000.npy")

# Invariants: the array is 2-D (sequences x hidden), its width matches the model,
# and pooling produced no NaN/inf.
assert sample_embeddings.ndim == 2, f"expected a 2-D array, got shape {sample_embeddings.shape}"
assert sample_embeddings.shape[1] == embedding_model.model.config.hidden_size, (
    f"embedding width {sample_embeddings.shape[1]} != model hidden size "
    f"{embedding_model.model.config.hidden_size}"
)
assert np.isfinite(sample_embeddings).all(), "non-finite values found in embeddings"

print(f"Sample embeddings shape: {sample_embeddings.shape}")
print(f"Embedding dimension: {sample_embeddings.shape[1]}")
print(f"Batch size: {sample_embeddings.shape[0]}")
if sample_embeddings.shape[0] > 0:
    print(f"Sample embedding norm: {np.linalg.norm(sample_embeddings[0]):.3f}")

In [None]:
# Optional: combine the per-batch .npy files into a single array with the external
# `merge_batches.py` script (it is not defined in this notebook). Uncomment to run.
# NOTE(review): the merged output is written into the batch directory itself; confirm
# the script ignores `merged_embeddings.npy` on re-runs.
# !python merge_batches.py -b "rnabert_embeddings/" -o "rnabert_embeddings/merged_embeddings.npy"