In [None]:
import copy
import gc
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import transformers
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import KFold, StratifiedKFold
from torch.utils.data import Dataset, DataLoader, Subset, SubsetRandomSampler
from transformers import AutoTokenizer
from multimolecule import RnaTokenizer, RnaFmModel

In [None]:
def _select_device():
    """Pick the preferred accelerator: Apple MPS first, then CUDA, else CPU.

    Returns a ``(torch.device, message)`` pair; the message is printed by the caller.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps"), "Using MPS (Apple Silicon GPU)"
    if torch.cuda.is_available():
        return torch.device("cuda"), "Using CUDA"
    return torch.device("cpu"), "Using CPU"


device, _device_message = _select_device()
print(_device_message)

In [None]:
# Load the raw table, drop the regression target and expose the binary label as
# `inhibition`.
df = (
    pd.read_csv('data.csv')
    .drop(columns='inhibition_value')
    .rename(columns={'label_cls': 'inhibition'})
)
df

In [6]:
MODEL_NAME = "multimolecule/rnafm"
MAX_LENGTH = 512
BATCH_SIZE = 8
NUM_FOLDS = 5
SEED = 42

# Seed every RNG the run draws from (Python, NumPy, torch). This makes classifier-head
# initialisation, dropout masks and SubsetRandomSampler ordering repeatable across
# runs (up to nondeterministic accelerator kernels). SEED is also passed to the
# fold splitter.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [7]:

# Tokenizer that matches the RNA-FM checkpoint named by MODEL_NAME.
tokenizer = RnaTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)

In [8]:
# Use the tokenizer's own separator token; fall back to a literal "[SEP]" when it
# does not define one (or defines an empty one).
sep_token = tokenizer.sep_token or "[SEP]"


def combine_sequences(sirna, mrna):
    """Concatenate an siRNA and an mRNA sequence with ``sep_token`` between them.

    Both inputs are passed through ``str`` first, so non-string cells (e.g. a
    missing value) are stringified rather than rejected.
    """
    return f"{sirna!s}{sep_token}{mrna!s}"

In [None]:

# Build the model input for every row from its (siRNA, mRNA) pair.
df['combined_sequence'] = df[['siRNA_sequence', 'mRNA_sequence']].apply(
    lambda pair: combine_sequences(*pair), axis=1
)
print("\nData with combined sequences:")
print(df[['combined_sequence', 'inhibition']].head())


In [10]:
class SiRnaClassificationDataset(Dataset):
    """Map-style dataset that tokenizes each sequence lazily in ``__getitem__``.

    Each item is a dict with:
      * ``input_ids`` / ``attention_mask``: the tokenizer's encodings with the
        batch dimension flattened away, padded/truncated to ``max_len``;
      * ``labels``: a 0-d float tensor (the target dtype ``BCEWithLogitsLoss`` expects).

    Args:
        sequences: indexable collection of input sequences (converted with ``str``).
        targets: indexable collection of binary labels aligned with ``sequences``.
        tokenizer: a Hugging Face-style callable tokenizer.
        max_len: length every encoding is padded or truncated to.

    Raises:
        ValueError: if ``sequences`` and ``targets`` have different lengths.
    """

    def __init__(self, sequences, targets, tokenizer, max_len):
        if len(sequences) != len(targets):
            raise ValueError(
                "sequences and targets must have the same length "
                f"(got {len(sequences)} and {len(targets)})"
            )
        self.sequences = sequences
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = str(self.sequences[idx])
        target = int(self.targets[idx])

        # Calling the tokenizer directly replaces the deprecated `encode_plus`;
        # for a single text input the keyword arguments behave the same way.
        encoding = self.tokenizer(
            sequence,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            # return_tensors='pt' adds a batch dimension of 1; flatten removes it.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(target, dtype=torch.float),
        }

In [11]:

# Plain Python lists of model inputs and binary targets, indexed positionally.
all_sequences, all_labels = (df[column].tolist() for column in ('combined_sequence', 'inhibition'))

In [12]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over positions where ``attention_mask`` is non-zero.

    Args:
        model_output: encoder output exposing ``last_hidden_state``
            (assumed ``(batch, seq_len, hidden)``).
        attention_mask: ``(batch, seq_len)`` mask, 1 for real tokens and 0 for padding.

    Returns:
        One pooled vector per sequence. A row whose mask is all zeros yields a zero
        vector, because the denominator is clamped away from zero.
    """
    token_embeddings = model_output.last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


def cls_pooling(model_output, attention_mask=None):
    """Return the embedding at the first token position of every sequence.

    ``attention_mask`` is accepted and ignored so this is a drop-in replacement
    for ``mean_pooling``.
    """
    return model_output.last_hidden_state[:, 0]

In [13]:

# Function that reduces the encoder's per-token embeddings to one vector per
# sequence. Any callable with mean_pooling's (model_output, attention_mask)
# signature can be substituted here.
pooling_strategy = mean_pooling

In [14]:
def create_simple_classifier(input_size):
    """Linear head mapping a pooled embedding of size ``input_size`` to one logit."""
    head = nn.Linear(in_features=input_size, out_features=1)
    return head

def create_medium_mlp_classifier(input_size, hidden_dim=256, dropout=0.1):
    """One-hidden-layer MLP head producing a single logit.

    Layout: Linear(input_size, hidden_dim) -> ReLU -> Dropout(dropout) -> Linear(hidden_dim, 1).

    Args:
        input_size: size of the pooled embedding fed to the head.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability applied after the activation.
    """
    return nn.Sequential(
        nn.Linear(input_size, hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, 1),
    )

def create_deep_mlp_classifier(input_size, hidden_dims=(512, 256, 128), dropout=0.1):
    """Stacked-MLP head producing a single logit.

    Each width in ``hidden_dims`` adds a Linear -> ReLU -> Dropout(dropout) stage,
    and a final Linear maps the last width to one output. An empty ``hidden_dims``
    yields a single Linear(input_size, 1).

    Args:
        input_size: size of the pooled embedding fed to the head.
        hidden_dims: iterable of hidden-layer widths. The default is a tuple so it
            cannot be mutated between calls.
        dropout: dropout probability applied after each hidden activation.
    """
    layers = []
    in_dim = input_size
    for out_dim in hidden_dims:
        layers += [nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Dropout(dropout)]
        in_dim = out_dim
    layers.append(nn.Linear(in_dim, 1))
    return nn.Sequential(*layers)

In [15]:
class RnaFmForSequenceClassification(nn.Module):
    """Frozen RNA-FM encoder + pooling + trainable head emitting one logit per sequence.

    Args:
        model_name: checkpoint passed to ``RnaFmModel.from_pretrained``.
        head_type: ``'simple'``, ``'medium'`` or ``'deep'``, selecting one of the
            ``create_*_classifier`` factories.
        pooling_func: callable ``(model_output, attention_mask)`` that collapses
            token embeddings into one vector per sequence.

    Raises:
        ValueError: if ``head_type`` is not one of the supported values.
    """

    def __init__(self, model_name, head_type='medium', pooling_func=mean_pooling):
        super().__init__()
        print(f"Loading base model: {model_name}")
        self.rna_fm = RnaFmModel.from_pretrained(model_name)
        hidden_size = self.rna_fm.config.hidden_size
        self.pooling = pooling_func

        print("Freezing base model parameters...")
        num_frozen = 0
        for param in self.rna_fm.parameters():
            param.requires_grad = False
            num_frozen += 1
        print(f"Froze {num_frozen} parameter tensors in the base model.")

        print(f"Creating classification head of type: {head_type}")
        head_factories = {
            'simple': create_simple_classifier,
            'medium': create_medium_mlp_classifier,
            'deep': create_deep_mlp_classifier,
        }
        if head_type not in head_factories:
            raise ValueError("Invalid head_type. Choose 'simple', 'medium', or 'deep'.")
        self.classifier = head_factories[head_type](hidden_size)

        print("Model initialization complete.")

    def train(self, mode=True):
        """Set train/eval mode for the head while always keeping the encoder in eval mode.

        The encoder's parameters are frozen above, so it serves as a fixed feature
        extractor. Keeping it in eval mode disables its dropout (if any), so
        training sees the same pooled features that evaluation does.
        """
        super().train(mode)
        self.rna_fm.eval()
        return self

    def forward(self, input_ids, attention_mask):
        """Return one logit per sequence (the head's trailing size-1 dim is squeezed)."""
        outputs = self.rna_fm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        pooled_output = self.pooling(outputs, attention_mask)

        logits = self.classifier(pooled_output)

        return logits.squeeze(-1)

In [16]:
def initialize_model_and_optimizer(model_name, head_type, learning_rate, weight_decay):
    """Create a fresh classifier on ``device`` together with its optimizer and loss.

    Only parameters with ``requires_grad`` set are handed to AdamW; their names
    and element counts are printed. Reads the notebook-level ``device`` and
    ``pooling_strategy``.

    Returns:
        ``(model, optimizer, criterion)`` with ``criterion = nn.BCEWithLogitsLoss()``.
    """
    model = RnaFmForSequenceClassification(model_name, head_type=head_type, pooling_func=pooling_strategy)
    model.to(device)

    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable, lr=learning_rate, weight_decay=weight_decay)

    print("\nOptimizer initialized. Training parameters:")
    trainable_count = 0
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        print(f"- {param_name} ({tensor.numel()})")
        trainable_count += tensor.numel()
    print(f"Total trainable parameters: {trainable_count}")

    return model, optimizer, nn.BCEWithLogitsLoss()

In [17]:
# Hyper-parameters shared by every cross-validation fold.
HEAD_TYPE = 'medium'   # classification head: 'simple' | 'medium' | 'deep'
LEARNING_RATE = 1e-4   # AdamW learning rate
WEIGHT_DECAY = 0.01    # AdamW weight decay
EPOCHS = 10            # training epochs per fold

In [18]:
def train_epoch(model, data_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``data_loader`` and return the mean training loss.

    The loss is averaged per sample: each batch's loss is weighted by its size, so
    a smaller final batch does not count as much as a full one. This assumes
    ``criterion`` uses mean reduction, as ``nn.BCEWithLogitsLoss()`` does by default.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)`` returning logits.
        data_loader: iterable of dicts with ``input_ids``, ``attention_mask`` and ``labels``.
        criterion: loss function taking ``(logits, labels)``.
        optimizer: optimizer over the model's trainable parameters.
        device: device each batch is moved to.

    Returns:
        Per-sample mean loss, or ``nan`` if the loader yielded no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    start_time = time.time()

    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    avg_loss = total_loss / num_samples if num_samples else float('nan')
    elapsed_time = time.time() - start_time
    print(f"  Train Epoch completed in {elapsed_time:.2f}s, Avg. Loss: {avg_loss:.4f}")
    return avg_loss

def evaluate_model(model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Predictions apply a 0.5 threshold to ``sigmoid(logits)``. The loss is averaged
    per sample, with batch losses weighted by batch size. The result therefore
    does not depend on how samples are grouped into batches, e.g. when a random
    sampler reorders them between epochs. This assumes ``criterion`` uses mean
    reduction.

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1, roc_auc, labels, predictions)``.
        ``labels`` and ``predictions`` are integer NumPy arrays in loader order.
        ``roc_auc`` is ``nan`` when sklearn cannot compute it (e.g. only one class
        is present). ``avg_loss`` is ``nan`` if the loader yielded no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    all_labels = []
    all_predictions = []
    all_probabilities = []
    start_time = time.time()

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask)

            batch_size = labels.size(0)
            total_loss += criterion(logits, labels).item() * batch_size
            num_samples += batch_size

            probabilities = torch.sigmoid(logits)
            predictions = (probabilities > 0.5).int()

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

    avg_loss = total_loss / num_samples if num_samples else float('nan')
    # Labels are stored as floats for the BCE loss; cast them to int so sklearn
    # compares the same {0, 1} label type as the predictions.
    all_labels = np.array(all_labels).astype(int)
    all_predictions = np.array(all_predictions)
    all_probabilities = np.array(all_probabilities)

    accuracy = accuracy_score(all_labels, all_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_predictions, average='binary', zero_division=0)
    try:
        roc_auc = roc_auc_score(all_labels, all_probabilities)
    except ValueError:
        roc_auc = float('nan')
        print("  Warning: ROC AUC could not be calculated (likely only one class in validation set).")

    elapsed_time = time.time() - start_time
    print(f"  Evaluation completed in {elapsed_time:.2f}s")
    return avg_loss, accuracy, precision, recall, f1, roc_auc, all_labels, all_predictions


def run_training_fold(model, optimizer, criterion, train_loader, val_loader, epochs, device, fold_num):
    """Train ``model`` on one CV fold, evaluating after every epoch.

    The model state with the lowest validation loss seen so far is deep-copied.
    After each epoch the MPS cache is emptied (when ``device`` is MPS) and the
    garbage collector is run.

    Returns:
        ``(best_model_state, train_losses, val_losses, metrics_history, best_val_loss)``.
        ``best_model_state`` is ``None`` if no epoch improved on ``inf``.
    """
    best_loss_so_far = float('inf')
    best_snapshot = None
    epoch_train_losses, epoch_val_losses, per_epoch_metrics = [], [], []
    on_mps = device == torch.device("mps")

    print(f"Starting training for Fold {fold_num}, {epochs} epochs...")
    fold_t0 = time.time()

    for epoch_no in range(1, epochs + 1):
        epoch_t0 = time.time()
        print(f"\n-- Epoch {epoch_no}/{epochs} --")

        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        epoch_train_losses.append(train_loss)

        val_loss, val_acc, val_prec, val_rec, val_f1, val_auc, _labels, _preds = evaluate_model(
            model, val_loader, criterion, device
        )
        epoch_val_losses.append(val_loss)

        per_epoch_metrics.append(dict(
            epoch=epoch_no,
            train_loss=train_loss,
            val_loss=val_loss,
            val_accuracy=val_acc,
            val_precision=val_prec,
            val_recall=val_rec,
            val_f1=val_f1,
            val_roc_auc=val_auc,
        ))

        print(f"Epoch {epoch_no} Summary | Time: {time.time() - epoch_t0:.2f}s")
        print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f} | Val AUC: {val_auc:.4f}")

        if val_loss < best_loss_so_far:
            best_loss_so_far = val_loss
            best_snapshot = copy.deepcopy(model.state_dict())
            print(f"  -> New best validation loss: {best_loss_so_far:.4f}. Saving model state.")

        if on_mps:
            torch.mps.empty_cache()
        gc.collect()

    print(f"\nTraining finished for Fold {fold_num}. Total time: {time.time() - fold_t0:.2f}s")

    return best_snapshot, epoch_train_losses, epoch_val_losses, per_epoch_metrics, best_loss_so_far

In [None]:
# Stratify folds on the label. Plain KFold ignores the class, so a fold can end up
# with a skewed (or single-class) validation set, which distorts F1 and leaves
# ROC AUC undefined.
kf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SEED)
fold_results = []
all_metrics_history = []

full_dataset = SiRnaClassificationDataset(all_sequences, all_labels, tokenizer, MAX_LENGTH)


def _plot_fold_curves(fold_num, train_losses, val_losses, metrics_history):
    """Plot one fold's per-epoch train/val loss (left) and validation F1 / ROC AUC (right)."""
    loss_epochs = range(1, len(train_losses) + 1)
    fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(10, 5))

    loss_ax.plot(loss_epochs, train_losses, label='Training Loss', marker='o')
    loss_ax.plot(loss_epochs, val_losses, label='Validation Loss', marker='x')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss (BCEWithLogits)')
    loss_ax.set_title(f'Fold {fold_num} - Training & Validation Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    metric_epochs = [m['epoch'] for m in metrics_history]
    metric_ax.plot(metric_epochs, [m['val_f1'] for m in metrics_history], label='Validation F1 Score', marker='s')
    metric_ax.plot(metric_epochs, [m['val_roc_auc'] for m in metrics_history], label='Validation ROC AUC', marker='^')
    metric_ax.set_xlabel('Epoch')
    metric_ax.set_ylabel('Metric Value')
    metric_ax.set_title(f'Fold {fold_num} - Validation F1 & ROC AUC')
    metric_ax.legend()
    metric_ax.grid(True)

    fig.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across folds


def _plot_confusion_matrix(fold_num, fold_labels, fold_preds):
    """Display one fold's confusion matrix; failures are reported, not raised."""
    try:
        # labels=[0, 1] keeps the matrix 2x2 even when a class is absent from the fold.
        cm = confusion_matrix(fold_labels, fold_preds, labels=[0, 1])
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
        disp.plot(cmap=plt.cm.Blues)
        plt.title(f'Fold {fold_num} - Confusion Matrix (Best Model)')
        plt.show()
    except Exception as e:  # plotting is best-effort; never abort the CV run
        print(f"Could not display confusion matrix for Fold {fold_num}: {e}")


print(f"\n--- Starting {NUM_FOLDS}-Fold Cross-Validation ---")
cv_start_time = time.time()

for fold, (train_idx, val_idx) in enumerate(kf.split(all_sequences, all_labels)):
    fold_num = fold + 1
    print(f"\n==================== Fold {fold_num}/{NUM_FOLDS} ====================")

    train_sampler = SubsetRandomSampler(train_idx)
    train_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=0)
    # Validation needs no shuffling; a fixed order means every epoch's evaluation
    # sees identical batches.
    val_loader = DataLoader(Subset(full_dataset, val_idx), batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    print(f"Fold {fold_num}: Train samples = {len(train_idx)}, Validation samples = {len(val_idx)}")

    model, optimizer, criterion = initialize_model_and_optimizer(
        MODEL_NAME, HEAD_TYPE, LEARNING_RATE, WEIGHT_DECAY
    )

    best_model_state, train_losses, val_losses, metrics_history, fold_best_val_loss = run_training_fold(
        model, optimizer, criterion, train_loader, val_loader, EPOCHS, device, fold_num
    )
    all_metrics_history.append(metrics_history)

    if best_model_state:
        model.load_state_dict(best_model_state)
        print("\nLoaded best model state from training for final evaluation.")
    else:
        print("\nWarning: No best model state saved (check training). Evaluating last state.")

    # NOTE(review): this fold is the same one used above to choose the best epoch,
    # so the metrics below are optimistically biased. An inner train/val split
    # would give an unbiased estimate.
    print(f"Performing final evaluation for Fold {fold_num}...")
    (final_val_loss, final_accuracy, final_precision, final_recall, final_f1,
     final_roc_auc, fold_labels, fold_preds) = evaluate_model(model, val_loader, criterion, device)

    print(f"\nFold {fold_num} Final Validation Metrics (Best Model):")
    print(f"  Loss:      {final_val_loss:.4f}")
    print(f"  Accuracy:  {final_accuracy:.4f}")
    print(f"  Precision: {final_precision:.4f}")
    print(f"  Recall:    {final_recall:.4f}")
    print(f"  F1 Score:  {final_f1:.4f}")
    print(f"  ROC AUC:   {final_roc_auc:.4f}")

    fold_results.append({
        'fold': fold_num,
        'accuracy': final_accuracy,
        'precision': final_precision,
        'recall': final_recall,
        'f1': final_f1,
        'roc_auc': final_roc_auc,
        'best_val_loss': fold_best_val_loss,
        'train_losses': train_losses,
        'val_losses': val_losses,
    })

    _plot_fold_curves(fold_num, train_losses, val_losses, metrics_history)
    _plot_confusion_matrix(fold_num, fold_labels, fold_preds)

    print(f"Cleaning up Fold {fold_num} resources...")
    del model, optimizer, criterion, train_loader, val_loader, train_sampler, best_model_state
    if device.type == "mps":
        torch.mps.empty_cache()
    elif device.type == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

cv_end_time = time.time()
print(f"\n--- Cross-Validation Finished --- Total Time: {cv_end_time - cv_start_time:.2f}s ---")

In [None]:

# One row per fold; the per-epoch loss lists are kept but not shown below.
results_df = pd.DataFrame(fold_results)

print("\n--- Cross-Validation Summary ---")

summary_columns = ['fold', 'accuracy', 'precision', 'recall', 'f1', 'roc_auc', 'best_val_loss']
summary_table = results_df.loc[:, summary_columns].round(4)
print(summary_table.to_string(index=False))

In [None]:

print("\nAverage Metrics Across Folds:")

# results_df column -> label printed in the summary (padded to a shared width below).
_metric_labels = {
    'accuracy': 'Accuracy:',
    'precision': 'Precision:',
    'recall': 'Recall:',
    'f1': 'F1 Score:',
    'roc_auc': 'ROC AUC:',
}
# Mean and standard deviation (pandas default, ddof=1) of each metric across folds.
_fold_stats = {
    column: (results_df[column].mean(), results_df[column].std())
    for column in _metric_labels
}

avg_accuracy, std_accuracy = _fold_stats['accuracy']
avg_precision, std_precision = _fold_stats['precision']
avg_recall, std_recall = _fold_stats['recall']
avg_f1, std_f1 = _fold_stats['f1']
avg_roc_auc, std_roc_auc = _fold_stats['roc_auc']

for column, label in _metric_labels.items():
    mean_value, std_value = _fold_stats[column]
    print(f"  {label:<11}{mean_value:.4f} +/- {std_value:.4f}")
