In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import transformers
from transformers import AutoTokenizer
from multimolecule import RnaTokenizer, RnaFmModel
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt
import copy
import gc

In [None]:
# Pick the compute device in priority order: Apple-Silicon MPS, then CUDA,
# then CPU. The conditional expression short-circuits, so the CUDA check only
# runs when MPS is unavailable.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
# Report which backend was selected.
print({
    "mps": "Using MPS (Apple Silicon GPU)",
    "cuda": "Using CUDA",
    "cpu": "Using CPU",
}[device.type])


In [None]:
# Load the raw dataset from the working directory.
df = pd.read_csv('data.csv')

# Fail fast if a column that later cells read is absent, instead of hitting a
# KeyError deep inside the pipeline.
required_columns = ['siRNA_sequence', 'mRNA_sequence', 'inhibition_value']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
    raise KeyError(f"data.csv is missing required column(s): {missing_columns}")

# Missing values would silently become the literal sequence "nan" (via str())
# or a NaN regression target that turns the MSE loss into NaN.
null_counts = df[required_columns].isna().sum()
if null_counts.any():
    raise ValueError(
        "data.csv contains missing values in required columns: "
        f"{null_counts[null_counts > 0].to_dict()}"
    )

# Preview only the first rows rather than dumping the whole frame.
df.head()

In [4]:
# Notebook-wide configuration.
MODEL_NAME = "multimolecule/rnafm"  # checkpoint id passed to from_pretrained for tokenizer and backbone
MAX_LENGTH = 512  # token length every sample is padded/truncated to by the dataset
BATCH_SIZE = 8  # batch size used by the train and validation DataLoaders
NUM_FOLDS = 5  # number of KFold splits for cross-validation
SEED = 42  # random_state for the KFold shuffle

In [5]:
# Load the tokenizer that belongs to the pretrained checkpoint named by MODEL_NAME.
tokenizer = RnaTokenizer.from_pretrained(MODEL_NAME)

In [6]:
# Separator placed between the siRNA and mRNA strings. Falls back to the
# literal "[SEP]" when the tokenizer's sep_token is unset or empty.
# NOTE(review): the fallback only helps if the tokenizer's vocabulary actually
# recognises "[SEP]" as a token — confirm against the RnaTokenizer vocab.
sep_token = tokenizer.sep_token if tokenizer.sep_token else "[SEP]"

In [7]:
def combine_sequences(sirna, mrna):
    """Join an siRNA and an mRNA sequence into one model input string.

    Both arguments are converted with ``str()`` and concatenated with the
    module-level ``sep_token`` between them (no extra whitespace).

    Args:
        sirna: siRNA sequence (any object; stringified).
        mrna: mRNA sequence (any object; stringified).

    Returns:
        str: ``"<sirna><sep_token><mrna>"``.
    """
    return "{}{}{}".format(str(sirna), sep_token, str(mrna))

In [None]:

# Build the single-string model input for every row. Iterating the two columns
# with zip avoids the per-row Series construction of df.apply(axis=1) and also
# behaves sensibly on an empty frame (apply(axis=1) would return a DataFrame
# there instead of a Series).
df['combined_sequence'] = [
    combine_sequences(sirna, mrna)
    for sirna, mrna in zip(df['siRNA_sequence'], df['mRNA_sequence'])
]
print("\nData with combined sequences:")
print(df[['combined_sequence', 'inhibition_value']].head())

In [12]:
class SiRnaDataset(Dataset):
    """Map-style dataset that tokenizes one sequence string per item on access.

    Args:
        sequences: Indexable collection of input strings (one per sample).
        targets: Indexable collection of regression targets, aligned with
            ``sequences``.
        tokenizer: Callable tokenizer; it is invoked as
            ``tokenizer(text, ...)`` and must return a mapping containing
            ``'input_ids'`` and ``'attention_mask'`` tensors of shape
            ``(1, max_len)``.
        max_len: Length every sample is padded or truncated to.

    Raises:
        ValueError: If ``sequences`` and ``targets`` differ in length.
    """

    def __init__(self, sequences, targets, tokenizer, max_len):
        # A mismatch would otherwise only surface as an IndexError mid-epoch.
        if len(sequences) != len(targets):
            raise ValueError(
                f"sequences and targets must have the same length "
                f"(got {len(sequences)} and {len(targets)})"
            )
        self.sequences = sequences
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of samples."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return model-ready tensors.

        Returns:
            dict: ``'input_ids'`` and ``'attention_mask'`` as 1-D tensors and
            ``'labels'`` as a scalar float tensor.
        """
        sequence = str(self.sequences[idx])
        target = float(self.targets[idx])

        # Call the tokenizer directly: ``encode_plus`` is deprecated in
        # Hugging Face tokenizers in favour of ``__call__``, which accepts the
        # same keyword arguments.
        encoding = self.tokenizer(
            sequence,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            # return_tensors='pt' adds a leading batch axis of size 1; drop it
            # so the DataLoader can stack samples into (batch, max_len).
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(target, dtype=torch.float)
        }

# Materialise the model inputs and regression targets as plain Python lists;
# these are the sequences/targets later handed to SiRnaDataset and KFold.
all_sequences = df['combined_sequence'].tolist()
all_labels = df['inhibition_value'].tolist()

In [13]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions selected by the mask.

    Args:
        model_output: Object exposing ``last_hidden_state``.
        attention_mask: Mask tensor; it is unsqueezed on the last axis and
            expanded to the shape of ``last_hidden_state``, so it must be
            broadcast-compatible with that (1 = keep, 0 = ignore).

    Returns:
        Tensor: Masked sum over dimension 1 divided by the mask count, with the
        count clamped to at least 1e-9 so an all-zero mask does not divide by
        zero.
    """
    hidden_states = model_output.last_hidden_state
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_total = (hidden_states * weights).sum(dim=1)
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    return masked_total / token_counts

def create_simple_regressor(input_size):
    """Return a single linear layer mapping ``input_size`` features to 1 output."""
    return nn.Linear(in_features=input_size, out_features=1)

def create_medium_mlp_head(input_size, hidden_dim=256):
    """Build a one-hidden-layer MLP regression head.

    Layout: Linear(input_size, hidden_dim) -> ReLU -> Dropout(0.1) ->
    Linear(hidden_dim, 1).

    Args:
        input_size: Number of input features.
        hidden_dim: Width of the hidden layer.

    Returns:
        nn.Sequential: The assembled head.
    """
    stages = [
        nn.Linear(input_size, hidden_dim),
        nn.ReLU(),
        nn.Dropout(p=0.1),
        nn.Linear(hidden_dim, 1),
    ]
    return nn.Sequential(*stages)

def create_deep_mlp_head(input_size, hidden_dims=(512, 256, 128)):
    """Build a multi-layer MLP regression head.

    For every width in ``hidden_dims`` the head stacks
    Linear -> ReLU -> Dropout(0.1), followed by a final Linear layer that
    projects to a single output.

    Args:
        input_size: Number of input features.
        hidden_dims: Iterable of hidden-layer widths, applied in order. The
            default is an immutable tuple so the default value cannot be
            mutated and shared between calls; lists are still accepted.
            An empty iterable yields a single Linear(input_size, 1).

    Returns:
        nn.Sequential: The assembled head.
    """
    layers = []
    current_dim = input_size
    for h_dim in hidden_dims:
        layers.extend([
            nn.Linear(current_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        ])
        current_dim = h_dim
    # Final projection to the scalar regression output.
    layers.append(nn.Linear(current_dim, 1))
    return nn.Sequential(*layers)

class RnaFmForRegression(nn.Module):
    """Frozen pretrained backbone with a trainable regression head.

    The backbone loaded via ``RnaFmModel.from_pretrained`` has every parameter
    frozen (``requires_grad = False``); only ``self.regressor`` is trainable.

    Args:
        model_name: Checkpoint identifier passed to
            ``RnaFmModel.from_pretrained``.
        head_type: ``'simple'``, ``'medium'`` or ``'deep'`` — selects which
            head factory builds ``self.regressor``.

    Raises:
        ValueError: If ``head_type`` is not one of the supported values.
    """

    def __init__(self, model_name, head_type='medium'):
        super().__init__()
        self.rna_fm = RnaFmModel.from_pretrained(model_name)
        hidden_size = self.rna_fm.config.hidden_size

        print("Freezing base model parameters...")
        for param in self.rna_fm.parameters():
            param.requires_grad = False
        print("Base model parameters frozen.")

        print(f"Creating regression head of type: {head_type}")
        if head_type == 'simple':
            self.regressor = create_simple_regressor(hidden_size)
        elif head_type == 'medium':
            self.regressor = create_medium_mlp_head(hidden_size)
        elif head_type == 'deep':
            self.regressor = create_deep_mlp_head(hidden_size)
        else:
            raise ValueError("Invalid head_type. Choose 'simple', 'medium', or 'deep'.")

        # The backbone is used purely as a fixed feature extractor.
        self.rna_fm.eval()

    def train(self, mode=True):
        """Switch the head's train/eval mode while keeping the backbone in eval.

        ``nn.Module.train`` propagates the mode to every child, which would put
        the frozen backbone into training mode and activate any dropout layers
        it contains, making the features fed to the head stochastic. The
        backbone is therefore forced back to eval mode after the propagation.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            RnaFmForRegression: ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.rna_fm.eval()
        return self

    def forward(self, input_ids, attention_mask):
        """Run the backbone, mean-pool its token embeddings and regress.

        Args:
            input_ids: Token id tensor passed to the backbone.
            attention_mask: Mask tensor passed to the backbone and used to
                weight the mean pooling.

        Returns:
            Tensor: Head output with its trailing size-1 axis squeezed away.
        """
        outputs = self.rna_fm(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        pooled_output = mean_pooling(outputs, attention_mask)

        logits = self.regressor(pooled_output)

        # squeeze(-1) drops only the output-feature axis, so a batch of size 1
        # still yields a 1-D tensor.
        return logits.squeeze(-1)

def initialize_model_and_optimizer(model_name, head_type, learning_rate, weight_decay):
    """Create a fresh model on the global ``device`` plus its optimizer and loss.

    Only parameters with ``requires_grad=True`` are handed to AdamW; their
    names are printed for inspection.

    Args:
        model_name: Checkpoint identifier forwarded to ``RnaFmForRegression``.
        head_type: Regression-head variant forwarded to ``RnaFmForRegression``.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        tuple: ``(model, optimizer, criterion)`` where ``criterion`` is
        ``nn.MSELoss``.
    """
    model = RnaFmForRegression(model_name, head_type=head_type)
    model.to(device)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(
        trainable_params,
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    print("\nOptimizer will train the following parameters:")
    trainable_names = [
        param_name
        for param_name, param in model.named_parameters()
        if param.requires_grad
    ]
    for param_name in trainable_names:
        print(param_name)

    return model, optimizer, nn.MSELoss()

In [14]:
# Training hyperparameters consumed by the cross-validation loop.
LEARNING_RATE = 1e-4  # AdamW learning rate
WEIGHT_DECAY = 0.01  # AdamW weight decay
EPOCHS = 10  # training epochs per fold
HEAD_TYPE = 'medium'  # regression head variant: 'simple', 'medium' or 'deep'

In [15]:
def train_epoch(model, data_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``.
        data_loader: Sized iterable of batches; each batch is a mapping with
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` tensors.
        criterion: Loss callable ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model's trainable parameters.
        device: Device the batch tensors are moved to.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    batch_losses = []

    for batch in data_loader:
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['labels'].to(device)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        predictions = model(input_ids=ids, attention_mask=mask)
        batch_loss = criterion(predictions, targets)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(data_loader)

def evaluate_model(model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``.
        data_loader: Sized iterable of batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        criterion: Loss callable ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        tuple: ``(avg_loss, mse, mae, r2, labels, predictions)`` where
        ``avg_loss`` is the sum of per-batch losses divided by the number of
        batches, the three metrics are computed over all collected samples,
        and ``labels`` / ``predictions`` are lists in iteration order.
    """
    model.eval()
    batch_losses = []
    predictions, targets = [], []

    with torch.no_grad():
        for batch in data_loader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            y_true = batch['labels'].to(device)

            y_pred = model(input_ids=ids, attention_mask=mask)
            batch_losses.append(criterion(y_pred, y_true).item())

            # Collect per-sample values on the CPU for the sklearn metrics.
            predictions.extend(y_pred.cpu().numpy())
            targets.extend(y_true.cpu().numpy())

    avg_loss = sum(batch_losses) / len(data_loader)
    mse = mean_squared_error(targets, predictions)
    mae = mean_absolute_error(targets, predictions)
    r2 = r2_score(targets, predictions)

    return avg_loss, mse, mae, r2, targets, predictions

def run_training_fold(model, optimizer, criterion, train_loader, val_loader, epochs, device):
    """Train for ``epochs`` epochs and remember the best regression-head state.

    After each epoch the model is evaluated on ``val_loader``; whenever the
    validation loss strictly improves, a deep copy of
    ``model.regressor.state_dict()`` is stored.

    Args:
        model: Model exposing a ``regressor`` sub-module.
        optimizer: Optimizer forwarded to ``train_epoch``.
        criterion: Loss forwarded to ``train_epoch`` / ``evaluate_model``.
        train_loader: Training batches.
        val_loader: Validation batches.
        epochs: Number of epochs to run.
        device: Device forwarded to the train/eval helpers.

    Returns:
        tuple: ``(best_model_state, train_losses, val_losses, best_val_loss)``;
        ``best_model_state`` is ``None`` if the validation loss never improved
        on ``inf``.
    """
    history_train, history_val = [], []
    best_model_state = None
    best_val_loss = float('inf')
    mps_device = torch.device("mps")

    print(f"Starting training for {epochs} epochs...")
    for epoch_number in range(1, epochs + 1):
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        val_metrics = evaluate_model(model, val_loader, criterion, device)
        val_loss, val_mse, val_mae, val_r2 = val_metrics[:4]

        history_train.append(train_loss)
        history_val.append(val_loss)

        print(f"Epoch {epoch_number}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val MSE: {val_mse:.4f} | Val MAE: {val_mae:.4f} | Val R2: {val_r2:.4f}")

        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            # Deep copy so later optimizer steps cannot alter the snapshot.
            best_model_state = copy.deepcopy(model.regressor.state_dict())
            print(f"  -> New best validation loss: {best_val_loss:.4f}. Saving model state.")

        # Release cached accelerator memory between epochs (MPS only).
        if device == mps_device:
            torch.mps.empty_cache()
        gc.collect()

    print("Training finished.")
    return best_model_state, history_train, history_val, best_val_loss

In [None]:
# K-fold cross-validation driver: for every fold, build fresh loaders and a
# fresh model, train the head, restore its best state, and record the
# validation metrics in ``fold_results``.
kf = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SEED)
fold_results = []

full_dataset = SiRnaDataset(all_sequences, all_labels, tokenizer, MAX_LENGTH)

print(f"\n--- Starting {NUM_FOLDS}-Fold Cross-Validation ---")

for fold, (train_idx, val_idx) in enumerate(kf.split(all_sequences)):
    print(f"\n--- Fold {fold+1}/{NUM_FOLDS} ---")

    # Seed torch's global RNG per fold so the head initialisation, the
    # sampler's batch order and dropout are reproducible across runs
    # (KFold alone only fixes the split).
    torch.manual_seed(SEED + fold)

    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)

    train_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
    val_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, sampler=val_sampler)

    print(f"Train samples: {len(train_idx)}, Validation samples: {len(val_idx)}")

    model, optimizer, criterion = initialize_model_and_optimizer(
        MODEL_NAME, HEAD_TYPE, LEARNING_RATE, WEIGHT_DECAY
    )

    best_head_state, train_losses, val_losses, best_val_loss = run_training_fold(
        model, optimizer, criterion, train_loader, val_loader, EPOCHS, device
    )

    # Compare against None explicitly: a state dict is a mapping, and relying
    # on its truthiness conflates "nothing saved" with "empty mapping".
    if best_head_state is not None:
        model.regressor.load_state_dict(best_head_state)
        print("Loaded best model state for final evaluation.")
    else:
        print("Warning: No best model state saved (likely only 1 epoch or no improvement). Evaluating last state.")

    final_val_loss, final_mse, final_mae, final_r2, fold_labels, fold_preds = evaluate_model(
        model, val_loader, criterion, device
    )

    print(f"\nFold {fold+1} Final Validation Metrics:")
    print(f"  MSE: {final_mse:.4f}")
    print(f"  MAE: {final_mae:.4f}")
    print(f"  R2 Score: {final_r2:.4f}")

    fold_results.append({
        'fold': fold + 1,
        'mse': final_mse,
        'mae': final_mae,
        'r2': final_r2,
        'best_val_loss': best_val_loss,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'labels': fold_labels,
        'predictions': fold_preds
    })

    # Derive the x-axis from the recorded history so the plot cannot fail with
    # a length mismatch if the number of recorded epochs differs from EPOCHS.
    epoch_axis = range(1, len(train_losses) + 1)
    plt.figure(figsize=(8, 4))
    plt.plot(epoch_axis, train_losses, label='Training Loss')
    plt.plot(epoch_axis, val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss (MSE)')
    plt.title(f'Fold {fold+1} - Training & Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

    # Drop this fold's model and loaders before the next fold loads a new
    # backbone, then release cached accelerator memory on either GPU backend.
    del model, optimizer, criterion, train_loader, val_loader, train_sampler, val_sampler
    gc.collect()
    if device.type == "mps":
        torch.mps.empty_cache()
    elif device.type == "cuda":
        torch.cuda.empty_cache()


print("\n--- Cross-Validation Finished ---")

In [None]:
# Tabulate the per-fold results and report mean +/- standard deviation
# (pandas' default sample standard deviation) for each metric.
results_df = pd.DataFrame(fold_results)
print("\n--- Cross-Validation Summary ---")
print(results_df[['fold', 'mse', 'mae', 'r2', 'best_val_loss']])

avg_mse, std_mse = results_df['mse'].mean(), results_df['mse'].std()
avg_mae, std_mae = results_df['mae'].mean(), results_df['mae'].std()
avg_r2, std_r2 = results_df['r2'].mean(), results_df['r2'].std()

print("\nAverage Metrics Across Folds:")
print("\n".join(
    f"  {metric_label}: {metric_mean:.4f} +/- {metric_std:.4f}"
    for metric_label, metric_mean, metric_std in (
        ("MSE", avg_mse, std_mse),
        ("MAE", avg_mae, std_mae),
        ("R2 Score", avg_r2, std_r2),
    )
))
