# Deep Learning Homework 2 - Question 2.1
## RNA Binding Protein (RBP) Interaction Prediction

## 1. Setup and Imports

In [None]:
# Install required packages.
# openpyxl is the engine pandas uses for `pd.read_excel` on .xlsx files; it is
# needed to read metadata.xlsx in RNACompeteLoader. `-q` keeps pip quiet.
!pip install openpyxl -q

In [None]:
import os
import random
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from typing import List, Tuple

# Choose the compute device once; models and batches are moved onto it later.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the name of the GPU at index 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Download Data

Download the data files from the Google Drive link provided in the homework:
- `norm_data.txt`
- `metadata.xlsx`

Upload them to Colab, or keep them in Google Drive and mount it (cells below). Either way, set `DATA_DIR` to the folder that contains both files.

In [None]:
# Colab only: mount Google Drive so files stored there become reachable under
# /content/drive (for example /content/drive/MyDrive/...). DATA_DIR has to be
# an absolute path inside this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Set the path to your data files: the folder that holds norm_data.txt and
# metadata.xlsx. Drive is mounted at /content/drive, so the path must be
# absolute and start with /content/drive/MyDrive. A bare 'MyDrive/...' is
# resolved relative to the working directory (/content on Colab) and would not
# find the files.
DATA_DIR = '/content/drive/MyDrive/path'  # Change 'path' to your folder path

## 3. Configuration and Utility Functions

In [None]:
from dataclasses import dataclass

@dataclass
class RNAConfig:
    """Global configuration for the RNAcompete Data Pipeline.

    Field defaults are evaluated once, when this cell defines the class, so
    DATA_PATH / METADATA_PATH capture whatever DATA_DIR held at that moment.
    Re-run this cell after editing DATA_DIR.
    """

    # Data Path - UPDATE THESE PATHS
    # Tab-separated intensity matrix (read with sep='\t' by RNACompeteLoader).
    DATA_PATH: str = f"{DATA_DIR}/norm_data.txt"
    # Excel workbook mapping Protein_name -> Motif_ID (the data-matrix column).
    METADATA_PATH: str = f"{DATA_DIR}/metadata.xlsx"
    # Worksheet of METADATA_PATH that holds that mapping.
    METADATA_SHEET: str = "Master List--Plasmid Info"

    # Save Path
    # Directory where processed per-protein, per-split tensors are cached.
    SAVE_DIR: str = "data"

    # Sequence Parameters
    # Sequences are truncated / zero-padded to exactly this many positions.
    SEQ_MAX_LEN: int = 41
    # NOTE(review): not referenced by RNACompeteLoader in this notebook; the
    # encoder uses its own hard-coded char_map.
    ALPHABET: str = "ACGUN"

    # Preprocessing
    # Observed intensities above this percentile are clipped down to it.
    CLIP_PERCENTILE: float = 99.95
    # Added inside the log transform and to the std when z-scoring.
    EPSILON: float = 1e-6

    # Split Identifiers
    # Values of the 'Probe_Set' column: TRAIN_SPLIT_ID rows feed train/val,
    # TEST_SPLIT_ID rows form the test split.
    TRAIN_SPLIT_ID: str = "SetA"
    TEST_SPLIT_ID: str = "SetB"

    # Fraction (not percent) of the TRAIN_SPLIT_ID rows held out for validation.
    VAL_SPLIT_PCT: float = 0.2
    # Seeds the train/val shuffle in RNACompeteLoader.get_data.
    SEED: int = 42

In [None]:
def configure_seed(seed):
    """Seed every random number generator the notebook relies on.

    Covers Python's ``random``, NumPy's global generator and PyTorch (CPU and,
    when CUDA is available, every GPU). On GPU it also forces deterministic
    cuDNN kernels and turns off the cuDNN autotuner.

    Note: setting ``PYTHONHASHSEED`` at runtime only affects child processes
    started afterwards; the running interpreter's hash seed is fixed at startup.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


configure_seed(42)

In [None]:
def masked_mse_loss(preds, targets, masks):
    """
    Calculates Mean Squared Error, ignoring padded elements.

    Args:
        preds: model outputs, any shape with the same number of elements as
            ``targets`` (e.g. ``(batch, 1)`` or ``(batch,)``).
        targets: ground-truth values, same number of elements as ``preds``.
        masks: 1/True where the target is valid, 0/False where it is missing;
            same number of elements as ``preds``.

    Returns:
        Scalar tensor: mean of ``(preds - targets) ** 2`` over the valid
        elements. When no element is valid the result is 0, computed from
        ``preds`` so it keeps their device/dtype and stays in the autograd
        graph. ``loss.backward()`` then produces zero gradients instead of a
        loss disconnected from the model.
    """
    # Flatten rather than squeeze(): squeeze() turns a (1, 1) batch into a 0-d
    # tensor and only aligns inputs whose singleton dims happen to match.
    # reshape(-1) always yields aligned 1-D views.
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)
    masks = masks.reshape(-1).bool()

    masked_preds = preds[masks]
    masked_targets = targets[masks]

    if masked_preds.numel() == 0:
        # The sum over an empty selection is 0 and still carries a grad_fn.
        return masked_preds.sum()

    squared_error = (masked_preds - masked_targets) ** 2
    return torch.mean(squared_error)


def _average_ranks(values):
    """Return float64 0-based ranks of a 1-D tensor; tied values share the mean of their ranks."""
    _, inverse, counts = torch.unique(values, sorted=True, return_inverse=True, return_counts=True)
    # Unique value k occupies the sorted positions [ends[k] - counts[k], ends[k]).
    ends = torch.cumsum(counts, dim=0).to(torch.float64)
    starts = ends - counts.to(torch.float64)
    return ((starts + ends - 1.0) / 2.0)[inverse]


def masked_spearman_correlation(preds, targets, masks):
    """
    Calculates Spearman Rank Correlation on masked data.

    Ties receive the average of the ranks they span, the standard definition
    used e.g. by ``scipy.stats.spearmanr``. Ordinal ranks from
    ``argsort().argsort()`` would break ties by position instead. A collapsed
    model that outputs a constant would then get an arbitrary non-zero score
    rather than 0.

    Args:
        preds, targets, masks: tensors with the same number of elements;
            ``masks`` is 1/True for valid positions.

    Returns:
        0-d float32 tensor. It is 0.0 when fewer than two elements are valid,
        and (numerically) 0 when either side is constant.
    """
    preds = preds.detach().reshape(-1)
    targets = targets.detach().reshape(-1)
    masks = masks.reshape(-1).bool()

    valid_preds = preds[masks]
    valid_targets = targets[masks]

    if valid_preds.numel() < 2:
        return torch.tensor(0.0)

    # Ranks are computed in float64: sums of squared rank deviations grow
    # like n^3 and lose precision in float32 for large probe sets.
    pred_dev = _average_ranks(valid_preds)
    pred_dev = pred_dev - pred_dev.mean()
    target_dev = _average_ranks(valid_targets)
    target_dev = target_dev - target_dev.mean()

    denom = torch.sqrt((pred_dev ** 2).sum() * (target_dev ** 2).sum() + 1e-8)
    correlation = (pred_dev * target_dev).sum() / denom

    return correlation.to(torch.float32)

## 4. Data Loader

In [None]:
class RNACompeteLoader:
    """Load, preprocess and cache RNAcompete probe data for one protein at a time.

    The metadata workbook and the tab-separated intensity matrix are read
    lazily, on the first ``get_data`` call that misses the on-disk cache.
    Each ``(protein, split)`` result is written to ``cfg.SAVE_DIR`` and
    reloaded from there by later calls.
    """

    def __init__(self, config: RNAConfig):
        self.cfg = config
        self.meta_df = None         # metadata sheet, filled by _ensure_data_loaded
        self.data_df = None         # intensity matrix, filled by _ensure_data_loaded
        self.protein_to_id = None   # Protein_name -> Motif_ID (a data_df column name)

        # One-hot nucleotide encoding. 'N' is a uniform 0.25 vector, and
        # _encode_sequence also uses it for any character not listed here.
        self.char_map = {
            'A': np.array([1, 0, 0, 0], dtype=np.float32),
            'C': np.array([0, 1, 0, 0], dtype=np.float32),
            'G': np.array([0, 0, 1, 0], dtype=np.float32),
            'U': np.array([0, 0, 0, 1], dtype=np.float32),
            'N': np.array([0.25, 0.25, 0.25, 0.25], dtype=np.float32)
        }
        # Fills positions past the end of a sequence shorter than SEQ_MAX_LEN.
        self.padding_vec = np.zeros(4, dtype=np.float32)

    def _ensure_data_loaded(self):
        """Read the metadata sheet and the intensity matrix once; no-op afterwards.

        Column names of both tables are stripped of surrounding whitespace.
        """
        if self.data_df is not None:
            return

        print(f"Loading Metadata from {self.cfg.METADATA_PATH}...")
        start_time = time.time()
        self.meta_df = pd.read_excel(
            self.cfg.METADATA_PATH,
            sheet_name=self.cfg.METADATA_SHEET
        )
        print(f"  > Metadata loaded in {time.time() - start_time:.2f} seconds.")

        self.meta_df.columns = [c.strip() for c in self.meta_df.columns]
        # If a protein name is listed more than once, the last row wins.
        self.protein_to_id = pd.Series(
            self.meta_df['Motif_ID'].values,
            index=self.meta_df['Protein_name']
        ).to_dict()

        print(f"Loading Data from {self.cfg.DATA_PATH}...")
        start_time = time.time()
        self.data_df = pd.read_csv(self.cfg.DATA_PATH, sep='\t', low_memory=False)
        print(f"  > Data Matrix loaded in {time.time() - start_time:.2f} seconds.")
        self.data_df.columns = [c.strip() for c in self.data_df.columns]

    def _encode_sequence(self, seq: str) -> np.ndarray:
        """One-hot encode ``seq`` into a float32 array of shape (SEQ_MAX_LEN, 4).

        The sequence is upper-cased and truncated to SEQ_MAX_LEN. Shorter
        sequences are right-padded with zero vectors. Non-string input (e.g. a
        NaN cell) becomes SEQ_MAX_LEN 'N's.
        """
        if not isinstance(seq, str):
            seq = "N" * self.cfg.SEQ_MAX_LEN
        seq = seq.upper()[:self.cfg.SEQ_MAX_LEN]
        encoded = [self.char_map.get(base, self.char_map['N']) for base in seq]
        pad_len = self.cfg.SEQ_MAX_LEN - len(encoded)
        if pad_len > 0:
            encoded.extend([self.padding_vec] * pad_len)
        return np.array(encoded, dtype=np.float32)

    def _preprocess_intensities(self, intensities: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Turn raw probe intensities into normalised regression targets.

        All statistics come from the observed (non-NaN) values only:
        1. clip the upper tail at the CLIP_PERCENTILE-th percentile;
        2. if the smallest value is <= 0, shift so it becomes 1, then apply
           log(x + shift + EPSILON);
        3. z-score with the mean and standard deviation of the logged values.

        Returns:
            (values, mask): ``values`` holds the normalised targets, with 0.0
            at missing positions. ``mask`` is float32 with 1.0 where the input
            was observed and 0.0 where it was NaN.
        """
        mask = (~np.isnan(intensities)).astype(np.float32)
        valid = mask == 1
        clean_vals = np.zeros(intensities.shape, dtype=np.float64)

        if not np.any(valid):
            # Nothing observed (or empty input): all-zero targets and mask.
            return clean_vals, mask

        vals = intensities[valid].astype(np.float64)
        clip_val = np.percentile(vals, self.cfg.CLIP_PERCENTILE)
        vals = np.minimum(vals, clip_val)

        # The shift comes from observed values only. Taking the minimum over
        # the 0 placeholders of missing entries as well would make the
        # transform depend on whether a split contains NaNs (log(x + 1) vs
        # log(x)), so train/val/test could be transformed differently.
        min_val = np.min(vals)
        shift = abs(min_val) + 1.0 if min_val <= 0 else 0.0
        vals = np.log(vals + shift + self.cfg.EPSILON)

        vals = (vals - np.mean(vals)) / (np.std(vals) + self.cfg.EPSILON)

        clean_vals[valid] = vals
        return clean_vals, mask

    def get_data(self, protein_name: str, split: str = 'train') -> TensorDataset:
        """Return ``(X, Y, mask)`` tensors for one protein and split.

        Args:
            protein_name: name looked up in the metadata ``Protein_name`` column.
            split: case-insensitive.
                'train' / 'val': a seeded random partition of the
                TRAIN_SPLIT_ID probes; VAL_SPLIT_PCT of them go to 'val'.
                'test': the TEST_SPLIT_ID probes.

        Returns:
            TensorDataset of X with shape (n, SEQ_MAX_LEN, 4), plus Y and mask,
            each with shape (n, 1).

        Raises:
            ValueError: unknown protein, protein column missing from the data
                matrix, or unknown split.

        Normalisation statistics are computed separately for each split.

        NOTE: the cache file name depends only on ``protein_name`` and
        ``split``, not on the config or the preprocessing code. Delete stale
        ``*_data.pt`` files in SAVE_DIR after changing either.
        """
        os.makedirs(self.cfg.SAVE_DIR, exist_ok=True)
        data_path = os.path.join(self.cfg.SAVE_DIR, f"{protein_name}_{split}_data.pt")

        if os.path.exists(data_path):
            print(f"Found cached data for {protein_name} ({split}). Loading...")
            tensors = torch.load(data_path, weights_only=True)
            return TensorDataset(*tensors)

        self._ensure_data_loaded()

        if protein_name not in self.protein_to_id:
            raise ValueError(f"Protein '{protein_name}' not found in metadata.")

        rncmpt_id = self.protein_to_id[protein_name]

        if rncmpt_id not in self.data_df.columns:
            raise ValueError(f"ID {rncmpt_id} for {protein_name} missing from data matrix.")

        s_lower = split.lower()

        if s_lower == 'test':
            subset = self.data_df[self.data_df['Probe_Set'] == self.cfg.TEST_SPLIT_ID].copy()
        elif s_lower in ['train', 'val']:
            full_set = self.data_df[self.data_df['Probe_Set'] == self.cfg.TRAIN_SPLIT_ID]
            # Fixed row order so the seeded shuffle gives the same partition
            # for the 'train' and 'val' calls.
            full_set = full_set.sort_index()
            n_samples = len(full_set)
            indices = np.arange(n_samples)
            rng = np.random.RandomState(self.cfg.SEED)
            rng.shuffle(indices)
            val_size = int(n_samples * self.cfg.VAL_SPLIT_PCT)
            if s_lower == 'val':
                subset_indices = indices[:val_size]
            else:
                subset_indices = indices[val_size:]
            subset = full_set.iloc[subset_indices].copy()
        else:
            raise ValueError(f"Unknown split '{split}'.")

        raw_seqs = subset['RNA_Seq'].values
        X = np.stack([self._encode_sequence(s) for s in raw_seqs])

        raw_intensities = pd.to_numeric(subset[rncmpt_id], errors='coerce').values
        Y, mask = self._preprocess_intensities(raw_intensities)

        dataset = TensorDataset(
            torch.FloatTensor(X),
            torch.FloatTensor(Y).unsqueeze(1),
            torch.FloatTensor(mask).unsqueeze(1)
        )

        print(f"Saving processed data to {data_path}...")
        torch.save(dataset.tensors, data_path)

        return dataset


def load_rnacompete_data(protein_name: str, split: str = 'train', config: RNAConfig = None):
    """Build a fresh ``RNACompeteLoader`` and return one split for ``protein_name``.

    Args:
        protein_name: name as listed in the metadata ``Protein_name`` column.
        split: 'train', 'val' or 'test' (case-insensitive).
        config: pipeline settings; ``RNAConfig()`` defaults are used when omitted.

    Returns:
        The ``TensorDataset`` produced by ``RNACompeteLoader.get_data``.
    """
    cfg = RNAConfig() if config is None else config
    return RNACompeteLoader(cfg).get_data(protein_name, split)

## 5. Model Definitions

### 5.1 CNN Model

In [None]:
class RNABindingCNN(nn.Module):
    """
    1D convolutional regressor mapping a one-hot RNA probe to one binding score.

    Layout:
    - three Conv1d -> BatchNorm1d -> ReLU -> Dropout stages
      (channels 4 -> 64 -> 128 -> 256, kernel sizes 5 / 7 / 9 with padding
      that keeps the sequence length unchanged)
    - global max pooling and global average pooling over the sequence axis,
      concatenated into a 512-dimensional vector
    - Linear(512, hidden_dim) -> ReLU -> Dropout -> Linear(hidden_dim, 1)

    ``seq_length`` is accepted but not used: global pooling makes the network
    independent of the input length.
    """

    def __init__(self, input_channels=4, seq_length=41, hidden_dim=128, dropout=0.3):
        super().__init__()

        # Submodules are created in exactly this order. The order fixes the
        # state_dict keys and the sequence in which initial weights are drawn.
        self.conv1 = nn.Conv1d(input_channels, 64, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=7, padding=3)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=9, padding=4)
        self.bn3 = nn.BatchNorm1d(256)

        self.dropout = nn.Dropout(dropout)

        self.global_max_pool = nn.AdaptiveMaxPool1d(1)
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

        pooled_dim = 256 * 2  # max-pooled + avg-pooled features
        self.fc1 = nn.Linear(pooled_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_length, 4) to predictions of shape (batch, 1)."""
        # Conv1d expects channels first: (batch, 4, seq_length).
        h = x.transpose(1, 2)

        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in stages:
            h = self.dropout(self.relu(norm(conv(h))))

        features = torch.cat(
            [self.global_max_pool(h).squeeze(-1), self.global_avg_pool(h).squeeze(-1)],
            dim=1,
        )
        hidden = self.dropout(self.relu(self.fc1(features)))
        return self.fc2(hidden)

### 5.2 LSTM Model

In [None]:
class RNABindingLSTM(nn.Module):
    """
    Recurrent regressor: LSTM over the one-hot probe, final hidden state(s) -> score.

    Layout:
    - ``num_layers``-deep LSTM (bidirectional by default), batch_first
    - final hidden state of the top layer (forward and backward concatenated
      when bidirectional) -> BatchNorm1d -> Dropout
    - Linear(hidden_dim * directions, hidden_dim) -> ReLU -> Dropout -> Linear(hidden_dim, 1)

    NOTE(review): sequences are not packed. If inputs carry trailing zero
    padding (as RNACompeteLoader._encode_sequence produces for short probes),
    the padding is read like real positions. Also, BatchNorm1d here sees a
    (batch, features) tensor, so it raises in training mode on a batch of one.
    """

    def __init__(self, input_dim=4, hidden_dim=128, num_layers=2, dropout=0.3, bidirectional=True):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        # Inter-layer dropout only exists when there is more than one layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional
        )

        summary_dim = hidden_dim * self.num_directions
        self.bn = nn.BatchNorm1d(summary_dim)
        self.dropout = nn.Dropout(dropout)

        self.fc1 = nn.Linear(summary_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_length, 4) to predictions of shape (batch, 1)."""
        _, (h_n, _) = self.lstm(x)

        # h_n is (num_layers * num_directions, batch, hidden_dim). Its last
        # num_directions rows are the top layer's final states (forward, then
        # backward when bidirectional).
        top_states = h_n[-self.num_directions:].unbind(0)
        summary = torch.cat(top_states, dim=1)

        summary = self.dropout(self.bn(summary))
        hidden = self.dropout(self.relu(self.fc1(summary)))
        return self.fc2(hidden)

## 6. Training and Evaluation Functions

In [None]:
def train_epoch(model, train_loader, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    For each batch: move it to ``device``, score it with ``masked_mse_loss``,
    back-propagate, clip the global gradient norm to 1.0 and take an optimizer
    step.

    Returns:
        The unweighted mean of the per-batch losses. An empty loader raises
        ZeroDivisionError.
    """
    model.train()
    batch_losses = []

    for inputs, labels, valid in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        valid = valid.to(device)

        optimizer.zero_grad()
        loss = masked_mse_loss(model(inputs), labels, valid)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(batch_losses)


def evaluate(model, data_loader, device):
    """Evaluate model and return loss and Spearman correlation.

    Both metrics are computed once, over every valid element of the whole
    loader. The loss is taken on the concatenated predictions rather than as
    a mean of per-batch means. A mean of batch means would over-weight a short
    final batch and count fully-masked batches (scored 0.0) as perfect.

    Returns:
        (masked MSE, masked Spearman correlation) as Python floats.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()

    all_preds = []
    all_targets = []
    all_masks = []

    with torch.no_grad():
        for x, y, mask in data_loader:
            predictions = model(x.to(device))
            # Accumulate on the CPU so the whole split is never held in GPU memory.
            all_preds.append(predictions.cpu())
            all_targets.append(y.cpu())
            all_masks.append(mask.cpu())

    if not all_preds:
        raise ValueError("evaluate() received an empty data loader")

    all_preds = torch.cat(all_preds, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    all_masks = torch.cat(all_masks, dim=0)

    loss = masked_mse_loss(all_preds, all_targets, all_masks)
    spearman_corr = masked_spearman_correlation(all_preds, all_targets, all_masks)

    return loss.item(), spearman_corr.item()


def train_model(model, train_loader, val_loader, optimizer, scheduler, device,
                num_epochs, model_name, patience=15, save_every=10):
    """Full training loop with early stopping and periodic checkpoints.

    Each epoch trains once over ``train_loader``, evaluates on ``val_loader``
    and, when given, calls ``scheduler.step(val_corr)``. The scheduler is
    therefore expected to be a metric-driven one such as
    ``ReduceLROnPlateau(mode='max')``.

    Files written to the working directory:
    - ``{model_name}_best.pth``: CPU copy of the best state_dict so far,
      rewritten whenever validation Spearman improves.
    - ``{model_name}_checkpoint_epoch{N}.pth`` every ``save_every`` epochs:
      model, optimizer and (if any) scheduler state, the metric history and
      the early-stopping counter, i.e. enough to resume training.

    Training stops after ``patience`` consecutive epochs without improvement.
    On return ``model`` holds the best weights seen (unless every validation
    correlation was NaN, in which case it keeps the last weights).

    Returns:
        dict with per-epoch 'train_losses', 'val_losses', 'val_correlations',
        plus 'best_val_corr' and wall-clock 'training_time' in seconds.
    """
    print(f"\n{'='*60}")
    print(f"Training {model_name}")
    print(f"{'='*60}")

    train_losses = []
    val_losses = []
    val_correlations = []

    best_val_corr = -float('inf')
    best_model_state = None
    epochs_without_improvement = 0

    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start = time.time()

        train_loss = train_epoch(model, train_loader, optimizer, device)
        val_loss, val_corr = evaluate(model, val_loader, device)

        if scheduler is not None:
            scheduler.step(val_corr)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_correlations.append(val_corr)

        if val_corr > best_val_corr:
            best_val_corr = val_corr
            # Independent CPU copy, so later optimisation does not change the snapshot.
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            epochs_without_improvement = 0
            # Save best model
            torch.save(best_model_state, f'{model_name}_best.pth')
            print(f"  → Saved new best model (Spearman: {val_corr:.4f})")
        else:
            epochs_without_improvement += 1

        epoch_time = time.time() - epoch_start

        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:3d}/{num_epochs} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Val Spearman: {val_corr:.4f} | "
                  f"Time: {epoch_time:.2f}s")

        # Periodic checkpoint every N epochs
        if (epoch + 1) % save_every == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses,
                'val_correlations': val_correlations,
                'best_val_corr': best_val_corr,
                # Without these, resuming would restart the LR schedule and
                # the early-stopping counter from scratch.
                'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
                'epochs_without_improvement': epochs_without_improvement,
            }
            torch.save(checkpoint, f'{model_name}_checkpoint_epoch{epoch+1}.pth')
            print(f"  → Checkpoint saved at epoch {epoch+1}")

        if epochs_without_improvement >= patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

    total_time = time.time() - start_time

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    print(f"\nTraining completed in {total_time:.2f}s")
    print(f"Best validation Spearman correlation: {best_val_corr:.4f}")

    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_correlations': val_correlations,
        'best_val_corr': best_val_corr,
        'training_time': total_time
    }

## 7. Load Data

In [None]:
# Hyperparameters
BATCH_SIZE = 64
NUM_EPOCHS = 50
LEARNING_RATE = 0.001
HIDDEN_DIM = 128
DROPOUT = 0.3
PROTEIN_NAME = 'RBFOX1'

# Load data
print(f"Loading data for protein: {PROTEIN_NAME}")
config = RNAConfig()

train_dataset = load_rnacompete_data(PROTEIN_NAME, split='train', config=config)
val_dataset = load_rnacompete_data(PROTEIN_NAME, split='val', config=config)
test_dataset = load_rnacompete_data(PROTEIN_NAME, split='test', config=config)

print("\nDataset sizes:")
print(f"  Train: {len(train_dataset)}")
print(f"  Val: {len(val_dataset)}")
print(f"  Test: {len(test_dataset)}")

# Create data loaders.
# BatchNorm1d applied to a (batch, features) tensor, as in the LSTM head,
# raises in training mode when a batch holds a single sample. Drop the final
# training batch only when it would contain exactly one sample. The loader
# reshuffles every epoch, so no sample is permanently excluded.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          drop_last=len(train_dataset) % BATCH_SIZE == 1)
# Evaluation runs in eval mode (running BatchNorm statistics), so keep every sample.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## 8. Train CNN Model

In [None]:
# Re-seed immediately before building the model so that weight initialisation
# and the training-loader shuffle order are reproducible for this run.
configure_seed(42)

cnn_model = RNABindingCNN(
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT
).to(device)

print("CNN Model Architecture:")
print(cnn_model)
print(f"\nTotal parameters: {sum(p.numel() for p in cnn_model.parameters()):,}")

# Adam with light L2 weight decay. The plateau scheduler watches validation
# Spearman (mode='max') and halves the learning rate (factor=0.5) once it has
# not improved for more than 7 epochs.
cnn_optimizer = optim.Adam(cnn_model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
cnn_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    cnn_optimizer, mode='max', factor=0.5, patience=7
)

# Up to NUM_EPOCHS epochs, stopped early after 15 epochs without a validation
# improvement. train_model restores the best weights into cnn_model.
cnn_history = train_model(
    cnn_model, train_loader, val_loader, cnn_optimizer, cnn_scheduler,
    device, num_epochs=NUM_EPOCHS, model_name='CNN', patience=15
)

In [None]:
# Score the CNN on the held-out test split. train_model has already restored
# the weights with the best validation Spearman.
cnn_test_loss, cnn_test_corr = evaluate(cnn_model, test_loader, device)
print("\nCNN Test Results:\n"
      f"  Test Loss: {cnn_test_loss:.4f}\n"
      f"  Test Spearman Correlation: {cnn_test_corr:.4f}")

## 9. Train LSTM Model

In [None]:
# Re-seed with the same value as the CNN run so this run is reproducible on
# its own, independent of whatever was executed before it.
configure_seed(42)  # Re-seed for fair comparison

lstm_model = RNABindingLSTM(
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT
).to(device)

print("LSTM Model Architecture:")
print(lstm_model)
print(f"\nTotal parameters: {sum(p.numel() for p in lstm_model.parameters()):,}")

# Same optimiser and plateau schedule as the CNN: Adam with weight decay 1e-5,
# learning rate halved once validation Spearman has not improved for more
# than 7 epochs.
lstm_optimizer = optim.Adam(lstm_model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
lstm_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    lstm_optimizer, mode='max', factor=0.5, patience=7
)

# Up to NUM_EPOCHS epochs with early stopping after 15 stagnant epochs;
# train_model restores the best weights into lstm_model.
lstm_history = train_model(
    lstm_model, train_loader, val_loader, lstm_optimizer, lstm_scheduler,
    device, num_epochs=NUM_EPOCHS, model_name='LSTM', patience=15
)

In [None]:
# Score the LSTM on the held-out test split. train_model has already restored
# the weights with the best validation Spearman.
lstm_test_loss, lstm_test_corr = evaluate(lstm_model, test_loader, device)
print("\nLSTM Test Results:\n"
      f"  Test Loss: {lstm_test_loss:.4f}\n"
      f"  Test Spearman Correlation: {lstm_test_corr:.4f}")

## 10. Plotting Results

In [None]:
# Comparison plot: CNN vs. LSTM. Left panel shows losses (solid = train,
# dashed = validation); right panel shows validation Spearman per epoch.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax1, ax2 = axes

colors = ['#2196F3', '#FF5722']  # Blue for CNN, Orange for LSTM
runs = list(zip(['CNN', 'LSTM'], [cnn_history, lstm_history]))

for color, (name, hist) in zip(colors, runs):
    epochs = range(1, len(hist['train_losses']) + 1)
    ax1.plot(epochs, hist['train_losses'], label=f'{name} - Train',
             color=color, linestyle='-', linewidth=2)
    ax1.plot(epochs, hist['val_losses'], label=f'{name} - Val',
             color=color, linestyle='--', linewidth=2)

for color, (name, hist) in zip(colors, runs):
    epochs = range(1, len(hist['val_correlations']) + 1)
    ax2.plot(epochs, hist['val_correlations'], label=name,
             color=color, linewidth=2)

panel_text = (
    (ax1, 'Loss (MSE)', 'Training and Validation Loss'),
    (ax2, 'Spearman Correlation', 'Validation Spearman Correlation'),
)
for panel, ylabel, title in panel_text:
    panel.set_xlabel('Epoch', fontsize=12)
    panel.set_ylabel(ylabel, fontsize=12)
    panel.set_title(title, fontsize=14)
    panel.legend(fontsize=10)
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves_comparison.pdf', bbox_inches='tight', dpi=150)
plt.show()

In [None]:
# Individual CNN loss plot: train vs. validation MSE per epoch.
fig, ax = plt.subplots(figsize=(10, 6))
epochs = range(1, len(cnn_history['train_losses']) + 1)
for series, label, color in (('train_losses', 'Train Loss', 'blue'),
                             ('val_losses', 'Validation Loss', 'orange')):
    ax.plot(epochs, cnn_history[series], label=label, color=color, linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss (MSE)', fontsize=12)
ax.set_title('CNN - Training and Validation Loss', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('cnn_loss.pdf', bbox_inches='tight', dpi=150)
plt.show()

In [None]:
# Individual LSTM loss plot: train vs. validation MSE per epoch.
fig, ax = plt.subplots(figsize=(10, 6))
epochs = range(1, len(lstm_history['train_losses']) + 1)
for series, label, color in (('train_losses', 'Train Loss', 'blue'),
                             ('val_losses', 'Validation Loss', 'orange')):
    ax.plot(epochs, lstm_history[series], label=label, color=color, linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss (MSE)', fontsize=12)
ax.set_title('LSTM - Training and Validation Loss', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('lstm_loss.pdf', bbox_inches='tight', dpi=150)
plt.show()

## 11. Final Summary

In [None]:
print("="*70)
print("FINAL SUMMARY")
print("="*70)
print(f"\n{'Model':<10} {'Val Spearman':<15} {'Test Spearman':<15} {'Test MSE':<12} {'Train Time':<12}")
print("-" * 64)
print(f"{'CNN':<10} {cnn_history['best_val_corr']:<15.4f} {cnn_test_corr:<15.4f} "
      f"{cnn_test_loss:<12.4f} {cnn_history['training_time']:<12.2f}s")
print(f"{'LSTM':<10} {lstm_history['best_val_corr']:<15.4f} {lstm_test_corr:<15.4f} "
      f"{lstm_test_loss:<12.4f} {lstm_history['training_time']:<12.2f}s")

# Save models
torch.save(cnn_model.state_dict(), 'cnn_model.pth')
torch.save(lstm_model.state_dict(), 'lstm_model.pth')
print("\nModels saved: cnn_model.pth, lstm_model.pth")

## 12. Hyperparameter Tuning

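Grid search over the values in `param_grid` (currently only the learning rate varies) for both the CNN and the LSTM. Each configuration is trained for 20 epochs with early stopping and ranked by its best validation Spearman correlation.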

In [None]:
# Hyperparameter grid. Only the learning rate is swept; hidden size and
# dropout stay at the values used for the main runs (128 and 0.3).
param_grid = dict(
    learning_rate=[0.01, 0.001, 0.0005, 0.0001],
    hidden_dim=[128],
    dropout=[0.3],
)

def hyperparameter_search(model_class, model_name, param_grid, train_loader, val_loader, device, epochs=20):
    """Grid search for hyperparameters.

    Every combination of ``learning_rate`` x ``hidden_dim`` x ``dropout`` is
    trained from seed 42 for ``epochs`` epochs with Adam (weight decay 1e-5),
    ReduceLROnPlateau (patience 3) and early stopping (patience 7). Each
    combination is scored by its best validation Spearman correlation.

    Each configuration is trained under its own run name,
    ``{model_name}_lr{lr}_h{hidden_dim}_d{dropout}``. Otherwise every run's
    ``*_best.pth`` and checkpoint files would overwrite the previous
    configuration's files and the main ``{model_name}_best.pth``.

    Returns:
        (best_config, results): the winning config dict (None if no run
        improved on -inf) and a list of
        ``{'config', 'best_val_corr', 'run_name'}`` dicts in search order.
    """
    best_config = None
    best_val_corr = -float('inf')
    results = []

    total_configs = len(param_grid['learning_rate']) * len(param_grid['hidden_dim']) * len(param_grid['dropout'])
    config_num = 0

    for lr in param_grid['learning_rate']:
        for hidden_dim in param_grid['hidden_dim']:
            for dropout in param_grid['dropout']:
                config_num += 1
                config = {'learning_rate': lr, 'hidden_dim': hidden_dim, 'dropout': dropout}
                print(f"\n[{config_num}/{total_configs}] Config: {config}")

                # Same seed for every configuration so runs differ only in
                # their hyperparameters.
                configure_seed(42)
                model = model_class(hidden_dim=hidden_dim, dropout=dropout).to(device)
                optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
                scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

                run_name = f"{model_name}_lr{lr}_h{hidden_dim}_d{dropout}"
                history = train_model(model, train_loader, val_loader, optimizer, scheduler,
                                      device, num_epochs=epochs, model_name=run_name, patience=7)

                results.append({'config': config, 'best_val_corr': history['best_val_corr'],
                                'run_name': run_name})

                if history['best_val_corr'] > best_val_corr:
                    best_val_corr = history['best_val_corr']
                    best_config = config

    print(f"\n{'='*60}")
    print(f"Best {model_name} configuration: {best_config}")
    print(f"Best validation Spearman: {best_val_corr:.4f}")

    return best_config, results

# Run hyperparameter search for both architectures on the same train/val loaders.
best_cnn_config, cnn_tuning_results = hyperparameter_search(
    RNABindingCNN, 'CNN', param_grid, train_loader, val_loader, device, epochs=20
)
best_lstm_config, lstm_tuning_results = hyperparameter_search(
    RNABindingLSTM, 'LSTM', param_grid, train_loader, val_loader, device, epochs=20
)

print(f"\n{'=' * 60}\nHYPERPARAMETER SEARCH SUMMARY\n{'=' * 60}")
print(f"Best CNN config: {best_cnn_config}\n"
      f"Best LSTM config: {best_lstm_config}")