In [3]:
import gc
import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, SequentialSampler
import transformers
from transformers import AutoTokenizer, AutoModel
from multimolecule import RnaTokenizer, RnaFmModel
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.base import clone
from sklearn.model_selection import KFold, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR
from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

import xgboost as xgb
import lightgbm as lgb

In [None]:
# Choose the backend used for RNA-FM feature extraction:
# Apple-Silicon GPU (MPS) first, then CUDA, falling back to CPU.
if torch.backends.mps.is_available():
    device_name, device_message = "mps", "Using MPS (Apple Silicon GPU)"
elif torch.cuda.is_available():
    device_name, device_message = "cuda", "Using CUDA for feature extraction"
else:
    device_name, device_message = "cpu", "Using CPU for feature extraction"

device = torch.device(device_name)
print(device_message)

In [None]:
# Raw dataset. The cells below read an siRNA sequence, an mRNA sequence and an
# inhibition value from every row.
df = pd.read_csv('data.csv')

# Fail fast (with a readable message) if a column used downstream is absent,
# instead of a KeyError several cells later.
REQUIRED_COLUMNS = ['siRNA_sequence', 'mRNA_sequence', 'inhibition_value']
missing_columns = [col for col in REQUIRED_COLUMNS if col not in df.columns]
assert not missing_columns, f"data.csv is missing required columns: {missing_columns}"

# Preview only; the full frame stays in `df`.
df.head()

In [13]:

# --- Pretrained encoder / feature extraction ---
MODEL_NAME = "multimolecule/rnafm"   # checkpoint for both the tokenizer and the RNA-FM encoder
MAX_LENGTH = 512                     # every input is padded/truncated to this many tokens
EXTRACT_BATCH_SIZE = 16              # DataLoader batch size during feature extraction

# --- Evaluation ---
NUM_FOLDS = 5                        # K for K-fold cross-validation
SEED = 42                            # random_state shared by KFold and the regressors

In [7]:

# Tokenizer matching the RNA-FM checkpoint. Its separator token is what
# combine_sequences places between siRNA and mRNA; "[SEP]" is used when the
# tokenizer reports no (or an empty) separator.
tokenizer = RnaTokenizer.from_pretrained(MODEL_NAME)
sep_token = tokenizer.sep_token or "[SEP]"

In [8]:
def combine_sequences(sirna, mrna, sep=None):
    """Join an siRNA and an mRNA sequence into one tokenizer input string.

    Args:
        sirna: siRNA sequence; converted with ``str()``.
        mrna: mRNA sequence; converted with ``str()``.
        sep: Separator placed between the two. When ``None`` (the default),
            the notebook-level ``sep_token`` is used, matching the original
            behaviour; pass it explicitly to avoid depending on that global.

    Returns:
        ``str(sirna) + sep + str(mrna)``.
    """
    if sep is None:
        sep = sep_token
    return f"{str(sirna)}{sep}{str(mrna)}"

In [None]:
# Refuse to build inputs from missing values: str(NaN) / str(None) would become
# the literal text "nan" / "None" and be fed to the tokenizer as if it were
# sequence, silently corrupting the extracted features and labels.
input_columns = ['siRNA_sequence', 'mRNA_sequence', 'inhibition_value']
null_counts = df[input_columns].isna().sum()
assert not null_counts.any(), f"Missing values in input columns:\n{null_counts[null_counts > 0]}"

# Pairwise zip avoids the per-row Series construction of df.apply(axis=1).
df['combined_sequence'] = [
    combine_sequences(sirna, mrna)
    for sirna, mrna in zip(df['siRNA_sequence'], df['mRNA_sequence'])
]
all_sequences = df['combined_sequence'].tolist()
all_labels = df['inhibition_value'].to_numpy()

print(f"\nPrepared {len(all_sequences)} sequences for feature extraction.")

In [10]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions where ``attention_mask`` is 1.

    Args:
        model_output: Object exposing ``last_hidden_state``, a rank-3 tensor
            (batch, seq_len, hidden) — rank is fixed by the expand below.
        attention_mask: Rank-2 (batch, seq_len) mask; non-zero marks tokens to
            include, zero marks padding.

    Returns:
        float tensor of shape (batch, hidden) holding the masked means. A row
        whose mask is all zeros yields zeros, because the token count is
        clamped to 1e-9 instead of dividing by zero.
    """
    hidden_states = model_output.last_hidden_state
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    token_totals = (hidden_states * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    return token_totals / token_counts

class ExtractionDataset(Dataset):
    """Map-style dataset that tokenizes one sequence per item for feature extraction.

    Every item is padded/truncated to ``max_len`` tokens, so items have a fixed
    length and can be stacked into batches by a DataLoader.
    """

    def __init__(self, sequences, tokenizer, max_len):
        self.sequences = sequences  # indexable collection; items are str()-converted on access
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Tokenize sequence ``idx`` into 1-D ``input_ids`` and ``attention_mask`` tensors."""
        sequence = str(self.sequences[idx])
        # Calling the tokenizer directly is the supported entry point;
        # ``encode_plus`` is deprecated in transformers.
        encoding = self.tokenizer(
            sequence,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # With return_tensors='pt' a single input comes back with a leading
        # batch axis; flatten to 1-D so the DataLoader can stack items itself.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
        }

def extract_rnafm_features(sequences, model, tokenizer, max_len, batch_size, device, log_every=50):
    """Embed sequences with an RNA-FM encoder and mean-pool each into a fixed-size vector.

    Args:
        sequences: Sized, indexable collection of inputs; each item is
            str()-converted and tokenized by ExtractionDataset.
        model: Encoder whose output exposes ``last_hidden_state``. It is put in
            eval mode and moved to ``device`` in place.
        tokenizer: Tokenizer passed to ExtractionDataset.
        max_len: Pad/truncate length passed to the tokenizer.
        batch_size: DataLoader batch size.
        device: torch device that the model and batches are moved to.
        log_every: Print progress every ``log_every`` batches; 0 or None disables it.

    Returns:
        np.ndarray with one pooled row per input, in input order
        (SequentialSampler preserves order).

    Raises:
        ValueError: if ``sequences`` is empty (previously this surfaced as an
            opaque ``np.concatenate`` error after setting everything up).
    """
    if len(sequences) == 0:
        raise ValueError("extract_rnafm_features received no sequences to embed")

    dataset = ExtractionDataset(sequences, tokenizer, max_len)
    data_loader = DataLoader(dataset, batch_size=batch_size, sampler=SequentialSampler(dataset))
    num_batches = len(data_loader)

    model.eval()  # inference behaviour for layers such as dropout
    model.to(device)
    all_features = []
    print(f"Starting feature extraction for {len(sequences)} sequences on device: {device}")

    start_time = time.time()
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )

            # Average only over non-padding tokens.
            pooled_output = mean_pooling(outputs, attention_mask)
            all_features.append(pooled_output.cpu().numpy())

            if log_every and (i + 1) % log_every == 0:
                elapsed = time.time() - start_time
                print(f"  Processed batch {i+1}/{num_batches} ({elapsed:.2f}s)")

    print(f"Feature extraction completed in {time.time() - start_time:.2f}s")
    features_array = np.concatenate(all_features, axis=0)
    print(f"Extracted features shape: {features_array.shape}")
    return features_array

In [None]:

# Load the pretrained RNA-FM encoder named by MODEL_NAME (same checkpoint as the tokenizer).
print(f"\nLoading RNA-FM model: {MODEL_NAME}")
rnafm_model = RnaFmModel.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)

In [None]:

# One pooled RNA-FM embedding per combined sequence. Extraction runs in
# sequential order, so row i of X_features matches all_labels[i].
X_features = extract_rnafm_features(
    sequences=all_sequences,
    model=rnafm_model,
    tokenizer=tokenizer,
    max_len=MAX_LENGTH,
    batch_size=EXTRACT_BATCH_SIZE,
    device=device,
)
y_labels = all_labels

In [None]:
# Unfitted template regressors evaluated on the RNA-FM features. Each one is
# wrapped after a StandardScaler in the cross-validation cell.
models = {
    "RandomForest": RandomForestRegressor(
        n_estimators=200,
        max_depth=20,
        min_samples_split=5,
        min_samples_leaf=2,
        n_jobs=-1,
        random_state=SEED
    ),

    "XGBoost": xgb.XGBRegressor(
        n_estimators=300,
        learning_rate=0.1,
        max_depth=5,
        subsample=0.8,
        colsample_bytree=0.8,
        objective='reg:squarederror',
        n_jobs=-1,
        random_state=SEED,
    ),

    "SVR_RBF": SVR(
        kernel='rbf',
        C=10.0,
        gamma='scale',
        epsilon=0.1
    ),

    "LightGBM": lgb.LGBMRegressor(
        n_estimators=300,
        learning_rate=0.1,
        max_depth=7,
        num_leaves=31,
        subsample=0.8,
        # LightGBM only performs row bagging when subsample_freq (bagging_freq)
        # is > 0; with the default of 0 the subsample=0.8 above is ignored.
        subsample_freq=1,
        colsample_bytree=0.8,
        objective='regression',
        n_jobs=-1,
        random_state=SEED
    ),

    "KNeighbors": KNeighborsRegressor(
        n_neighbors=7,
        weights='distance',
        p=2
    ),

    "Ridge": Ridge(
        alpha=1.0,
        random_state=SEED
    )
}

print(f"\nDefined {len(models)} classical ML models for evaluation.")

In [None]:
# K-fold CV: every model is trained on the same folds (shared KFold seed), and
# per-fold metrics are collected as records in cv_results.
kf = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SEED)
cv_results = []

print(f"\n--- Starting {NUM_FOLDS}-Fold Cross-Validation for Classical Models ---")

X = X_features
y = y_labels

for fold, (train_idx, val_idx) in enumerate(kf.split(X, y)):
    print(f"\n--- Fold {fold+1}/{NUM_FOLDS} ---")
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]

    for model_name, model_instance in models.items():
        print(f"  Training {model_name}...")
        start_time = time.time()

        # clone() gives each fold a fresh, unfitted copy with identical
        # hyperparameters, so no fitted state is shared across folds and the
        # template estimators in `models` are never mutated by this loop.
        # The scaler is fitted inside the pipeline on training rows only.
        pipeline = Pipeline([
            ('scaler', StandardScaler()),
            ('regressor', clone(model_instance))
        ])

        pipeline.fit(X_train, y_train)

        y_pred = pipeline.predict(X_val)

        mse = mean_squared_error(y_val, y_pred)
        mae = mean_absolute_error(y_val, y_pred)
        r2 = r2_score(y_val, y_pred)

        elapsed = time.time() - start_time
        print(f"    {model_name} Fold {fold+1} | MSE: {mse:.4f} | MAE: {mae:.4f} | R2: {r2:.4f} | Time: {elapsed:.2f}s")

        cv_results.append({
            'fold': fold + 1,
            'model': model_name,
            'mse': mse,
            'mae': mae,
            'r2': r2
        })

    gc.collect()

print("\n--- Cross-Validation Finished ---")

In [None]:
results_df = pd.DataFrame(cv_results)

print("\n--- Cross-Validation Summary (Classical Models on RNA-FM Features) ---")

# Mean and (sample) standard deviation of each metric across folds, one row per model.
summary = (
    results_df
    .groupby('model')
    .agg(
        avg_mse=('mse', 'mean'),
        std_mse=('mse', 'std'),
        avg_mae=('mae', 'mean'),
        std_mae=('mae', 'std'),
        avg_r2=('r2', 'mean'),
        std_r2=('r2', 'std'),
    )
    .reset_index()
)

# Human-readable "mean +/- std" column for each metric.
for display_col, metric in (('MSE', 'mse'), ('MAE', 'mae'), ('R2', 'r2')):
    summary[display_col] = [
        f"{mean_val:.4f} +/- {std_val:.4f}"
        for mean_val, std_val in zip(summary[f'avg_{metric}'], summary[f'std_{metric}'])
    ]

print(summary[['model', 'MSE', 'MAE', 'R2']].to_string(index=False))

# Higher R2 is better; lower MSE is better.
best_model_r2 = summary.loc[summary['avg_r2'].idxmax()]
print(f"\nBest model based on average R2 Score: {best_model_r2['model']} (R2 = {best_model_r2['R2']})")
best_model_mse = summary.loc[summary['avg_mse'].idxmin()]
print(f"Best model based on average MSE: {best_model_mse['model']} (MSE = {best_model_mse['MSE']})")