In [None]:
import gc
import os
import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, SequentialSampler
import transformers
from transformers import AutoTokenizer, AutoModel
from multimolecule import RnaTokenizer, RnaFmModel
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold, StratifiedKFold, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay

import xgboost as xgb
import lightgbm as lgb

In [None]:

# Resolve the compute backend once: MPS is preferred, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    backend, backend_label = "mps", "MPS (Apple Silicon GPU)"
elif torch.cuda.is_available():
    backend, backend_label = "cuda", "CUDA"
else:
    backend, backend_label = "cpu", "CPU"

device = torch.device(backend)
print(f"Using {backend_label} for feature extraction.")

In [None]:
# data.csv has four columns: siRNA_sequence, mRNA_sequence, inhibition_value
# (continuous, 0-1) and label_cls (0 or 1). The continuous target is dropped and
# the binary label is renamed to 'inhibition'.
df = (
    pd.read_csv('data.csv')
    .drop(columns='inhibition_value')
    .rename(columns={'label_cls': 'inhibition'})
)
df

In [4]:

# Model identifier passed to from_pretrained() for both the tokenizer and the RNA-FM model.
MODEL_NAME = "multimolecule/rnafm"
# Token length every sequence is padded / truncated to during tokenization.
MAX_LENGTH = 512
# Number of sequences per batch during feature extraction.
EXTRACT_BATCH_SIZE = 16 
# Number of cross-validation folds.
NUM_FOLDS = 5
# Random seed shared by the CV splitter and the seeded classifiers.
SEED = 42

In [None]:

# Load the tokenizer; use the literal "[SEP]" when it does not define a separator token.
tokenizer = RnaTokenizer.from_pretrained(MODEL_NAME)
sep_token = tokenizer.sep_token or "[SEP]"

def combine_sequences(sirna, mrna):
    """Join an siRNA and an mRNA into one string separated by the module-level ``sep_token``.

    Both inputs are converted with ``str()`` first, so non-string values
    (e.g. numbers or None) are embedded as their string representation.
    """
    return f"{sirna!s}{sep_token}{mrna!s}"

# Build one combined string per row from the (siRNA, mRNA) column pair.
df['combined_sequence'] = df[['siRNA_sequence', 'mRNA_sequence']].apply(
    lambda pair: combine_sequences(*pair),
    axis=1,
)

# Plain Python list of model inputs and an integer label vector in the same row order.
all_sequences = df['combined_sequence'].tolist()
all_labels = df['inhibition'].to_numpy(dtype=int)

print(f"\nPrepared {len(all_sequences)} sequences for feature extraction.")

In [None]:

def mean_pooling(model_output, attention_mask):
    """Average ``model_output.last_hidden_state`` over dimension 1, weighted by the mask.

    Positions where ``attention_mask`` is 0 contribute nothing to the sum, and the
    divisor is the per-row mask total, clamped to at least 1e-9 so that a fully
    masked row yields zeros instead of a division by zero.

    Args:
        model_output: object exposing a ``last_hidden_state`` tensor
            (presumably (batch, seq_len, hidden) -- confirm against the model).
        attention_mask: tensor with one entry per token position of the hidden states.

    Returns:
        Tensor of pooled embeddings with dimension 1 reduced away.
    """
    hidden_states = model_output.last_hidden_state
    # Broadcast the mask across the trailing feature axis so it can weight every value.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_total = (hidden_states * weights).sum(1)
    token_counts = weights.sum(1).clamp(min=1e-9)
    return masked_total / token_counts


class ExtractionDataset(Dataset):
    """Map-style dataset that tokenizes one sequence per item for feature extraction.

    Args:
        sequences: indexable collection of sequences; each item is passed through ``str()``.
        tokenizer: callable tokenizer returning a mapping with ``input_ids`` and
            ``attention_mask`` tensors.
        max_len: length every sequence is padded / truncated to.
    """

    def __init__(self, sequences, tokenizer, max_len):
        self.sequences = sequences
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Tokenize sequence ``idx`` and return 1-D ``input_ids`` / ``attention_mask`` tensors."""
        # Call the tokenizer directly: ``encode_plus`` is the deprecated entry point
        # and ``__call__`` accepts the same keyword arguments.
        encoding = self.tokenizer(
            str(self.sequences[idx]),
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # Flatten each returned tensor to 1-D so the DataLoader can stack items into a batch.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
        }


def extract_rnafm_features(sequences, model, tokenizer, max_len, batch_size, device):
    """Run ``model`` over ``sequences`` and return mean-pooled embeddings as one array.

    The sequences are tokenized through ``ExtractionDataset``, fed to the model in
    their original order under ``torch.no_grad()``, pooled with ``mean_pooling`` and
    concatenated along axis 0.

    Args:
        sequences: collection of input sequences.
        model: model put in eval mode and moved to ``device``; called with
            ``input_ids``, ``attention_mask`` and ``return_dict=True``.
        tokenizer: tokenizer handed to ``ExtractionDataset``.
        max_len: padding / truncation length.
        batch_size: number of sequences per forward pass.
        device: torch device used for the model and the batches.

    Returns:
        NumPy array with one row of pooled features per input sequence.
    """
    extraction_ds = ExtractionDataset(sequences, tokenizer, max_len)
    # Sequential sampling keeps feature rows aligned with the input order.
    loader = DataLoader(
        extraction_ds,
        batch_size=batch_size,
        sampler=SequentialSampler(extraction_ds),
    )

    model.eval()
    model.to(device)

    feature_chunks = []
    print(f"Starting feature extraction for {len(sequences)} sequences on device: {device}")
    with torch.no_grad():
        start_time = time.time()
        for step, batch in enumerate(loader, start=1):
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            outputs = model(input_ids=ids, attention_mask=mask, return_dict=True)
            # Move pooled vectors to host memory right away so device memory stays flat.
            feature_chunks.append(mean_pooling(outputs, mask).cpu().numpy())
            if step % 50 == 0:
                elapsed = time.time() - start_time
                print(f"  Processed batch {step}/{len(loader)} ({elapsed:.2f}s)")
    print(f"Feature extraction completed in {time.time() - start_time:.2f}s")

    features_array = np.concatenate(feature_chunks, axis=0)
    print(f"Extracted features shape: {features_array.shape}")
    return features_array


print(f"\nLoading RNA-FM model: {MODEL_NAME}")

# Try the generic loader first; fall back to the explicit RNA-FM class if it fails.
try:
    rnafm_model = AutoModel.from_pretrained(MODEL_NAME)
except Exception as exc:
    # Report why the generic loader failed, so a real problem (missing files,
    # no network, ...) is not hidden behind the fallback.
    print(f"AutoModel failed ({type(exc).__name__}: {exc}), trying RnaFmModel directly...")
    rnafm_model = RnaFmModel.from_pretrained(MODEL_NAME)



# Feature matrix: one pooled embedding row per combined sequence.
X_features = extract_rnafm_features(
    sequences=all_sequences,
    model=rnafm_model,
    tokenizer=tokenizer,
    max_len=MAX_LENGTH,
    batch_size=EXTRACT_BATCH_SIZE,
    device=device,
)
# Target vector (0 or 1), in the same row order as the features.
y_labels = all_labels


# Drop the model and free its memory before training the classical classifiers.
del rnafm_model
# Collect first: the allocator cache can only give back blocks whose tensors are gone.
gc.collect()
# Compare on device.type so the check does not depend on a device index.
if device.type == "mps":
    torch.mps.empty_cache()
elif device.type == "cuda":
    torch.cuda.empty_cache()
print("RNA-FM model removed from memory.")

In [None]:
# Negative/positive ratio, used as the positive-class weight for XGBoost and LightGBM.
neg_count = (y_labels == 0).sum()
pos_count = (y_labels == 1).sum()
if pos_count > 0:
    scale_pos_weight_val = neg_count / pos_count
else:
    # No positive samples: fall back to a neutral weight.
    scale_pos_weight_val = 1
print(f"Calculated scale_pos_weight for XGBoost: {scale_pos_weight_val:.2f}")

# Classifiers to compare, keyed by display name. Iteration order is the order in
# which they are trained and reported in each fold.
models = {
    "RandomForest": RandomForestClassifier(
        n_estimators=200,
        max_depth=20,
        min_samples_split=5,
        min_samples_leaf=2,
        class_weight='balanced',
        n_jobs=-1,
        random_state=SEED,
    ),

    "XGBoost": xgb.XGBClassifier(
        n_estimators=300,
        learning_rate=0.1,
        max_depth=5,
        subsample=0.8,
        colsample_bytree=0.8,
        objective='binary:logistic',
        eval_metric='logloss',
        # NOTE(review): newer XGBoost releases deprecate / ignore this flag --
        # confirm against the installed version before removing it.
        use_label_encoder=False,
        scale_pos_weight=scale_pos_weight_val,
        n_jobs=-1,
        random_state=SEED,
    ),

    "SVC_RBF": SVC(
        kernel='rbf',
        C=10.0,
        gamma='scale',
        class_weight='balanced',
        probability=True,  # needed for predict_proba / ROC AUC
        random_state=SEED,
    ),

    "LightGBM": lgb.LGBMClassifier(
        n_estimators=300,
        learning_rate=0.1,
        max_depth=7,
        num_leaves=31,
        subsample=0.8,
        # Row subsampling only takes effect when subsample_freq > 0; with the
        # default of 0 the subsample=0.8 setting above is silently ignored.
        subsample_freq=1,
        colsample_bytree=0.8,
        objective='binary',
        metric='auc',
        scale_pos_weight=scale_pos_weight_val,
        n_jobs=-1,
        random_state=SEED,
    ),

    "KNeighbors": KNeighborsClassifier(
        n_neighbors=7,
        weights='distance',
        p=2,
    ),

    "LogisticRegression": LogisticRegression(
        C=1.0,
        penalty='l2',
        solver='liblinear',
        class_weight='balanced',
        max_iter=1000,
        random_state=SEED,
    ),
}

print(f"\nDefined {len(models)} classical ML classifiers for evaluation.")

In [None]:
# Stratified splits keep the 0/1 label ratio roughly equal in every fold. Plain
# KFold can produce folds with very few (or no) positives on imbalanced labels,
# which makes per-fold F1 / ROC AUC noisy or undefined.
kf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SEED)
# One dict of metrics per (fold, model) pair is appended here by the CV loop.
cv_results = []

print(f"\n--- Starting {NUM_FOLDS}-Fold Cross-Validation for Classical Classifiers ---")

X = X_features
y = y_labels

# Outer loop: one train/validation split per fold. Inner loop: every classifier.
for fold, (train_idx, val_idx) in enumerate(kf.split(X, y)):
    fold_num = fold + 1
    print(f"\n--- Fold {fold_num}/{NUM_FOLDS} ---")
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]

    for model_name, model_instance in models.items():
        print(f"  Training {model_name}...")
        start_time = time.time()

        # The scaler lives inside the pipeline, so it is fit on the training fold only.
        pipeline = Pipeline(steps=[
            ('scaler', StandardScaler()),
            ('classifier', model_instance),
        ])
        pipeline.fit(X_train, y_train)
        y_pred = pipeline.predict(X_val)

        # Scores for ROC AUC: class-1 probability when available, otherwise the
        # decision function, otherwise the hard predictions.
        try:
            y_prob = pipeline.predict_proba(X_val)[:, 1]
        except AttributeError:
            print(f"    Note: {model_name} does not have predict_proba. ROC AUC calculated using decision function or predict.")
            try:
                y_scores = pipeline.decision_function(X_val)
                y_prob = y_scores[:, 1] if (y_scores.ndim > 1 and y_scores.shape[1] > 1) else y_scores
            except AttributeError:
                y_prob = y_pred

        accuracy = accuracy_score(y_val, y_pred)
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_val, y_pred, average='binary', zero_division=0
        )
        # roc_auc_score raises ValueError when it cannot be computed; record NaN instead.
        try:
            roc_auc = roc_auc_score(y_val, y_prob)
        except ValueError:
            roc_auc = float('nan')

        elapsed = time.time() - start_time
        print(f"    {model_name} Fold {fold_num} | Acc: {accuracy:.4f} | F1: {f1:.4f} | AUC: {roc_auc:.4f} | Time: {elapsed:.2f}s")

        cv_results.append(dict(
            fold=fold_num,
            model=model_name,
            accuracy=accuracy,
            precision=precision,
            recall=recall,
            f1=f1,
            roc_auc=roc_auc,
        ))

    gc.collect()

print("\n--- Cross-Validation Finished ---")

In [None]:
results_df = pd.DataFrame(cv_results)

print("\n--- Cross-Validation Summary (Classical Classifiers on RNA-FM Features) ---")

# (metric column in results_df, suffix of the avg_/std_ columns, display label)
metric_specs = [
    ('accuracy', 'acc', 'Accuracy'),
    ('precision', 'prec', 'Precision'),
    ('recall', 'rec', 'Recall'),
    ('f1', 'f1', 'F1 Score'),
    ('roc_auc', 'auc', 'ROC AUC'),
]

# Per-model mean and standard deviation of every metric across folds.
named_aggs = {}
for metric_col, suffix, label in metric_specs:
    named_aggs[f'avg_{suffix}'] = (metric_col, 'mean')
    named_aggs[f'std_{suffix}'] = (metric_col, 'std')
summary = results_df.groupby('model').agg(**named_aggs).reset_index()

# Human-readable "mean +/- std" column for each metric.
for metric_col, suffix, label in metric_specs:
    avg_col, std_col = f'avg_{suffix}', f'std_{suffix}'
    summary[label] = summary.apply(
        lambda row, a=avg_col, s=std_col: f"{row[a]:.4f} +/- {row[s]:.4f}",
        axis=1,
    )

display_cols = ['model'] + [label for _metric_col, _suffix, label in metric_specs]
print(summary[display_cols].to_string(index=False))

# Best model by mean F1 and by mean ROC AUC.
best_model_f1 = summary.loc[summary['avg_f1'].idxmax()]
print(f"\nBest model based on average F1 Score: {best_model_f1['model']} (F1 = {best_model_f1['F1 Score']})")
best_model_auc = summary.loc[summary['avg_auc'].idxmax()]
print(f"Best model based on average ROC AUC:  {best_model_auc['model']} (AUC = {best_model_auc['ROC AUC']})")