# Deep Learning Homework 2 - Question 2.1
## RNA Binding Protein (RBP) Interaction Prediction

## 1. Setup and Imports

In [1]:
# Install required packages
!pip install openpyxl -q

In [2]:
import os
import random
import time
import itertools
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from typing import List, Tuple

# Every artefact produced below (checkpoints, JSON logs, plots) lives in this folder.
OUTPUT_DIR = 'Outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _use_cuda else torch.device('cpu')
print(f"Using device: {device}")
if _use_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cuda
GPU: NVIDIA GeForce RTX 4060 Laptop GPU


## 2. Download Data

Download the data files from the Google Drive link provided in the homework:
- `norm_data.txt`
- `metadata.xlsx`

Upload them to Colab, or mount your Google Drive (see the commented-out cell below).

In [3]:
#from google.colab import drive
#drive.mount('/content/drive')

In [4]:
# Set the path to your data files
DATA_DIR = "data"  # folder holding norm_data.txt / metadata.xlsx; edit to match your setup

## 3. Configuration and Utility Functions

In [5]:
from config import RNAConfig

In [6]:
from utils import configure_seed, masked_mse_loss, masked_spearman_correlation

# Seed the random number generators through the project helper (utils.configure_seed)
# so runs are repeatable; hyperparameter_search re-seeds with the same value before each run.
configure_seed(42)

## 4. Data Loader

In [7]:
from utils import RNACompeteLoader, load_rnacompete_data

## 5. Model Definitions

### 5.1 CNN Model

In [8]:
class RNABindingCNN(nn.Module):
    """
    1D convolutional regressor producing one binding score per RNA sequence.

    Layout:
    - three Conv1d -> BatchNorm1d -> ReLU -> Dropout stages
      (64, 128, 256 channels; kernel sizes 5, 7, 9; length-preserving padding)
    - global max pooling and global average pooling, concatenated (512 features)
    - Linear(512 -> hidden_dim) -> ReLU -> Dropout -> Linear(hidden_dim -> 1)

    Args:
        input_channels: features per sequence position (default 4, presumably
            one per nucleotide).
        seq_length: accepted for interface compatibility only; the adaptive
            global pooling makes the network independent of sequence length.
        hidden_dim: width of the hidden fully connected layer.
        dropout: dropout probability applied after every stage.
    """

    def __init__(self, input_channels=4, seq_length=41, hidden_dim=128, dropout=0.3):
        super().__init__()
        # Stride 1 with padding (K - 1) / 2 on an odd kernel K gives
        # L_out = floor((L_in + 2P - K) / S) + 1 = L_in, so each stage keeps the
        # sequence length (41 for the default input).
        # Layers are created in the same order as before so that random
        # initialisation consumes the RNG identically for a given seed.
        self.conv1 = nn.Conv1d(input_channels, 64, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=7, padding=3)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=9, padding=4)
        self.bn3 = nn.BatchNorm1d(256)

        self.dropout = nn.Dropout(dropout)

        # Strongest response of each filter anywhere along the sequence ...
        self.global_max_pool = nn.AdaptiveMaxPool1d(1)
        # ... and how strongly it is present on average.
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

        pooled_features = 256 * 2  # max-pooled channels + avg-pooled channels
        self.fc1 = nn.Linear(pooled_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map x of shape (batch, seq_length, channels) to scores of shape (batch, 1)."""
        # Conv1d expects channels before positions: (batch, channels, seq_length).
        h = x.permute(0, 2, 1)

        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            h = self.dropout(self.relu(norm(conv(h))))

        pooled = torch.cat(
            [self.global_max_pool(h).squeeze(-1), self.global_avg_pool(h).squeeze(-1)],
            dim=1,
        )

        hidden = self.dropout(self.relu(self.fc1(pooled)))
        return self.fc2(hidden)

### 5.2 LSTM Model

In [9]:
class RNABindingLSTM(nn.Module):
    """
    Recurrent regressor: stacked (optionally bidirectional) LSTM plus an MLP head.

    The final hidden state of the top LSTM layer (both directions concatenated
    when bidirectional) goes through BatchNorm1d and Dropout, then
    Linear(hidden_dim * num_directions -> hidden_dim) -> ReLU -> Dropout ->
    Linear(hidden_dim -> 1).

    Args:
        input_dim: features per sequence position.
        hidden_dim: LSTM hidden size per direction, also the MLP hidden width.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability for the head; also used between LSTM
            layers when num_layers > 1.
        bidirectional: whether to add a backward-running LSTM.
    """

    def __init__(self, input_dim=4, hidden_dim=128, num_layers=2, dropout=0.3, bidirectional=True):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        summary_dim = hidden_dim * self.num_directions

        # nn.LSTM's dropout acts only between stacked layers, so disable it for a single layer.
        inter_layer_dropout = 0 if num_layers <= 1 else dropout
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional,
        )

        self.bn = nn.BatchNorm1d(summary_dim)
        self.dropout = nn.Dropout(dropout)

        self.fc1 = nn.Linear(summary_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map x of shape (batch, seq_length, input_dim) to scores of shape (batch, 1)."""
        _, (h_n, _) = self.lstm(x)

        # h_n stacks (num_layers * num_directions) states of shape (batch, hidden_dim);
        # the last num_directions entries belong to the top layer.
        if self.bidirectional:
            # forward direction at index -2, backward direction at index -1
            summary = torch.cat((h_n[-2], h_n[-1]), dim=1)
        else:
            summary = h_n[-1]

        summary = self.dropout(self.bn(summary))
        hidden = self.dropout(self.relu(self.fc1(summary)))
        return self.fc2(hidden)

## 6. Training and Evaluation Functions

In [10]:
def train_epoch(model, train_loader, optimizer, device, max_grad_norm=1.0):
    """
    Run one optimisation pass over ``train_loader``.

    Each batch must be an ``(x, y, mask)`` triple. The loss is
    ``masked_mse_loss``, and gradients are clipped to ``max_grad_norm`` (global
    L2 norm) before every optimizer step.

    Args:
        model: network to train; it is switched to train mode.
        train_loader: iterable of (x, y, mask) batches.
        optimizer: optimizer that updates ``model``'s parameters.
        device: device every batch is moved to.
        max_grad_norm: gradient-clipping threshold (default 1.0, the value
            previously hard-coded).

    Returns:
        (mean per-batch loss, Spearman correlation over all training predictions).
        The correlation is computed from predictions made in train mode
        (dropout active) while the weights were still changing, so it is only
        a rough signal compared with ``evaluate``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    all_preds = []
    all_targets = []
    all_masks = []

    for x, y, mask in train_loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)

        optimizer.zero_grad()
        predictions = model(x)
        loss = masked_mse_loss(predictions, y, mask)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        # Detached CPU copies: the epoch-level metric must not keep the graph or GPU memory alive.
        all_preds.append(predictions.detach().cpu())
        all_targets.append(y.detach().cpu())
        all_masks.append(mask.detach().cpu())

    if num_batches == 0:
        # torch.cat on empty lists would fail with an opaque error further down.
        raise ValueError("train_loader yielded no batches; cannot compute epoch metrics")

    spearman_corr = masked_spearman_correlation(
        torch.cat(all_preds, dim=0),
        torch.cat(all_targets, dim=0),
        torch.cat(all_masks, dim=0),
    )
    return total_loss / num_batches, spearman_corr.item()


def evaluate(model, data_loader, device):
    """
    Score ``model`` on ``data_loader`` without updating it.

    The model is put in eval mode and run under ``torch.no_grad()``. Each batch
    must be an ``(x, y, mask)`` triple.

    Args:
        model: network to evaluate.
        data_loader: iterable of (x, y, mask) batches.
        device: device every batch is moved to.

    Returns:
        (mean per-batch ``masked_mse_loss``, Spearman correlation over all predictions).

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    all_preds = []
    all_targets = []
    all_masks = []

    with torch.no_grad():
        for x, y, mask in data_loader:
            x, y, mask = x.to(device), y.to(device), mask.to(device)

            predictions = model(x)
            loss = masked_mse_loss(predictions, y, mask)

            total_loss += loss.item()
            num_batches += 1

            all_preds.append(predictions.cpu())
            all_targets.append(y.cpu())
            all_masks.append(mask.cpu())

    if num_batches == 0:
        # torch.cat on empty lists would fail with an opaque error further down.
        raise ValueError("data_loader yielded no batches; cannot compute metrics")

    spearman_corr = masked_spearman_correlation(
        torch.cat(all_preds, dim=0),
        torch.cat(all_targets, dim=0),
        torch.cat(all_masks, dim=0),
    )
    return total_loss / num_batches, spearman_corr.item()


def train_model(model, train_loader, val_loader, optimizer, scheduler, device,
                num_epochs, model_name, patience=15, save_every=10):
    """
    Train ``model`` with early stopping on validation Spearman correlation.

    Each epoch runs ``train_epoch`` followed by ``evaluate``. Whenever the
    validation Spearman improves, a CPU copy of the weights is kept and written
    to ``{OUTPUT_DIR}/{model_name}_best.pth``. Every ``save_every`` epochs a
    full checkpoint is written to
    ``{OUTPUT_DIR}/{model_name}_checkpoint_epoch{N}.pth``. It holds the model,
    optimizer and scheduler state plus the metric history. Training stops
    after ``patience`` consecutive epochs without improvement. The best
    weights are loaded back into ``model`` before returning.

    Args:
        model, train_loader, val_loader, optimizer, device: training components.
        scheduler: metric-driven LR scheduler stepped with the validation
            Spearman every epoch (e.g. ``ReduceLROnPlateau(mode='max')``), or None.
        num_epochs: maximum number of epochs.
        model_name: label used in log lines and output file names.
        patience: epochs without improvement before stopping early.
        save_every: checkpoint period in epochs; 0 or less disables periodic checkpoints.

    Returns:
        dict with per-epoch ``train_losses``, ``train_correlations``,
        ``val_losses``, ``val_correlations``, plus ``best_val_corr``,
        ``best_epoch`` (1-based, 0 if no epoch improved) and
        ``training_time`` (seconds).
    """
    print(f"\n{'='*60}")
    print(f"Training {model_name}")
    print(f"{'='*60}")

    train_losses = []
    train_correlations = []
    val_losses = []
    val_correlations = []

    best_val_corr = -float('inf')
    best_epoch = 0
    best_model_state = None
    epochs_without_improvement = 0

    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start = time.time()

        train_loss, train_corr = train_epoch(model, train_loader, optimizer, device)
        val_loss, val_corr = evaluate(model, val_loader, device)

        if scheduler is not None:
            # Metric-driven step: the scheduler watches validation Spearman.
            scheduler.step(val_corr)

        train_losses.append(train_loss)
        train_correlations.append(train_corr)
        val_losses.append(val_loss)
        val_correlations.append(val_corr)

        # Strict '>' keeps the earliest epoch among ties; a NaN never counts as improvement.
        if val_corr > best_val_corr:
            best_val_corr = val_corr
            best_epoch = epoch + 1
            # Cloned CPU snapshot: unaffected by later optimizer steps, no GPU memory held.
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            epochs_without_improvement = 0
            torch.save(best_model_state, f'{OUTPUT_DIR}/{model_name}_best.pth')
            print(f"  → Saved new best model (Spearman: {val_corr:.4f})")
        else:
            epochs_without_improvement += 1

        epoch_time = time.time() - epoch_start

        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:3d}/{num_epochs} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Train Spearman: {train_corr:.4f} | "
                  f"Val Spearman: {val_corr:.4f} | "
                  f"Time: {epoch_time:.2f}s")

        # Periodic checkpoint every `save_every` epochs (guard avoids modulo by zero).
        if save_every > 0 and (epoch + 1) % save_every == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Needed to resume with the same (possibly already reduced) learning rate.
                'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
                'train_losses': train_losses,
                'train_correlations': train_correlations,
                'val_losses': val_losses,
                'val_correlations': val_correlations,
                'best_val_corr': best_val_corr,
                'best_epoch': best_epoch,
            }
            torch.save(checkpoint, f'{OUTPUT_DIR}/{model_name}_checkpoint_epoch{epoch+1}.pth')
            print(f"  → Checkpoint saved at epoch {epoch+1}")

        if epochs_without_improvement >= patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

    total_time = time.time() - start_time

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    print(f"\nTraining completed in {total_time:.2f}s")
    print(f"Best validation Spearman correlation: {best_val_corr:.4f}")

    return {
        'train_losses': train_losses,
        'train_correlations': train_correlations,
        'val_losses': val_losses,
        'val_correlations': val_correlations,
        'best_val_corr': best_val_corr,
        'best_epoch': best_epoch,
        'training_time': total_time
    }

## 7. Load Data

In [11]:
PROTEIN_NAME = 'RBFOX1'
BATCH_SIZE = 64

# Load the train/val/test splits for the selected protein
print(f"Loading data for protein: {PROTEIN_NAME}")
config = RNAConfig()

train_dataset = load_rnacompete_data(PROTEIN_NAME, split='train', config=config)
val_dataset = load_rnacompete_data(PROTEIN_NAME, split='val', config=config)
test_dataset = load_rnacompete_data(PROTEIN_NAME, split='test', config=config)

print(f"\nDataset sizes:")
print(f"  Train: {len(train_dataset)}")
print(f"  Val: {len(val_dataset)}")
print(f"  Test: {len(test_dataset)}")

# RNABindingLSTM applies BatchNorm1d to a (batch, features) tensor, which raises
# in train mode when a batch holds a single sample. Drop the last training batch
# only when it would contain exactly one sample, so other datasets keep every sample.
drop_single_tail = len(train_dataset) % BATCH_SIZE == 1

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          drop_last=drop_single_tail)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Loading data for protein: RBFOX1
Found cached data for RBFOX1 (train). Loading from data/RBFOX1_train_data.pt...
Found cached data for RBFOX1 (val). Loading from data/RBFOX1_val_data.pt...
Found cached data for RBFOX1 (test). Loading from data/RBFOX1_test_data.pt...

Dataset sizes:
  Train: 96261
  Val: 24065
  Test: 121031


## 8. Hyperparameter Search and Training

Here we define the hyperparameter search space and a function that trains one model per combination, keeping the configuration with the highest validation Spearman correlation.

In [12]:
# --- GRID SEARCH: 3 learning rates x 2 hidden sizes x 2 dropout rates = 12 runs ---
# (train_model early-stops with patience 7; ReduceLROnPlateau uses patience 5)
param_grid = dict(
    # previous best LSTM lr (1e-3), previous best CNN lr (1e-4), and an intermediate value
    learning_rate=[0.001, 0.0005, 0.0001],
    # does doubling the capacity help?
    hidden_dim=[128, 256],
    # standard vs. stronger regularisation (aimed at the larger models)
    dropout=[0.3, 0.5],
)

#simply fast grid to test if code runs well
# param_grid = {
#     # Covers best LSTM (0.001), best CNN (0.0001) and midpoint (0.0005)
#     'learning_rate': [0.0001],
#     # Tests if doubling capacity improves results
#     'hidden_dim': [2],
#     # Standard regularization vs Stronger regularization (for larger models)
#     'dropout': [0.8]
# }

def hyperparameter_search(model_class, model_name, param_grid, train_loader, val_loader, device, epochs=50):
    """
    Grid-search ``param_grid`` for ``model_class`` and keep the best run.

    Every combination in the Cartesian product of the grid values is trained
    from the same seed. Each run uses Adam (weight decay 1e-5) and a
    ``ReduceLROnPlateau`` scheduler (mode='max', factor 0.5, patience 5).
    ``train_model`` stops after 7 epochs without improvement.

    Outputs:
    - After each run, all results so far are rewritten to
      ``{OUTPUT_DIR}/{model_name}_search_results.json``, so an interrupted
      search keeps its progress.
    - The best weights seen so far are copied to
      ``{OUTPUT_DIR}/best_{model_name.lower()}_model.pth``.

    Args:
        model_class: model constructor accepting ``hidden_dim`` and ``dropout`` keywords.
        model_name: label used in log lines and output file names.
        param_grid: mapping with ``learning_rate``, ``hidden_dim`` and ``dropout``
            keys, each mapped to a list of candidate values.
        train_loader, val_loader: data loaders passed to ``train_model``.
        device: device the models are trained on.
        epochs: maximum epochs per run.

    Returns:
        (best configuration dict, training history of that run), or
        ``(None, None)`` if no run ever improved on -inf.
    """
    keys, values = zip(*param_grid.items())
    experiments = [dict(zip(keys, combo)) for combo in itertools.product(*values)]

    best_val_corr = -float('inf')
    best_config = None
    best_history = None

    print(f"Starting Hyperparameter Search for {model_name} with {len(experiments)} configurations...")

    results = []
    results_path = f'{OUTPUT_DIR}/{model_name}_search_results.json'
    final_model_path = f'{OUTPUT_DIR}/best_{model_name.lower()}_model.pth'

    for i, params in enumerate(experiments):
        print(f"\nRunning experiment {i+1}/{len(experiments)}: {params}")

        # Re-seed so every configuration starts from the same RNG state.
        configure_seed(42)

        # CNN and LSTM constructors both take hidden_dim and dropout keywords.
        model = model_class(
            hidden_dim=params['hidden_dim'],
            dropout=params['dropout'],
        ).to(device)

        optimizer = optim.Adam(model.parameters(), lr=params['learning_rate'], weight_decay=1e-5)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

        # Unique name per run keeps each run's best-weights file separate.
        temp_name = f"{model_name}_exp{i}"
        history = train_model(
            model, train_loader, val_loader, optimizer, scheduler,
            device, num_epochs=epochs, model_name=temp_name, patience=7, save_every=100
        )

        # --- SAVE PROGRESS IMMEDIATELY ---
        results.append({
            'experiment_index': i,
            'config': params,
            'best_val_corr': history['best_val_corr'],
            'training_time': history['training_time'],
            'history': history,  # full history for plotting/resuming
        })

        try:
            # Serialise first: a serialisation error must not truncate the previous progress file.
            payload = json.dumps(results, indent=4)
            with open(results_path, 'w') as f:
                f.write(payload)
            print(f"  → Search progress saved to {results_path}")
        except (OSError, TypeError, ValueError) as e:
            # Best effort: a failed progress dump should not abort the search.
            print(f"  → Warning: Could not save progress to JSON: {e}")

        if history['best_val_corr'] > best_val_corr:
            best_val_corr = history['best_val_corr']
            best_config = params
            best_history = history

            # Promote this run's best weights to the search-wide best-model file.
            best_state = torch.load(f'{OUTPUT_DIR}/{temp_name}_best.pth')
            torch.save(best_state, final_model_path)

    print(f"\n{'='*60}")
    print(f"Best {model_name} configuration: {best_config}")
    print(f"Best validation Spearman: {best_val_corr:.4f}")
    print(f"Best model saved to {final_model_path}")
    print(f"{'='*60}")

    return best_config, best_history


In [13]:
# 8.1 Grid-search the CNN hyperparameters
print("---- Tuning CNN ----")
best_cnn_config, cnn_history = hyperparameter_search(
    model_class=RNABindingCNN,
    model_name='CNN',
    param_grid=param_grid,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    epochs=50,
)

---- Tuning CNN ----
Starting Hyperparameter Search for CNN with 12 configurations...

Running experiment 1/12: {'learning_rate': 0.001, 'hidden_dim': 128, 'dropout': 0.3}

Training CNN_exp0
  → Saved new best model (Spearman: 0.5452)
Epoch   1/50 | Train Loss: 0.6327 | Val Loss: 0.5194 | Train Spearman: 0.4410 | Val Spearman: 0.5452 | Time: 23.46s
  → Saved new best model (Spearman: 0.5710)
  → Saved new best model (Spearman: 0.5893)
  → Saved new best model (Spearman: 0.6036)
  → Saved new best model (Spearman: 0.6060)
Epoch   5/50 | Train Loss: 0.4774 | Val Loss: 0.4930 | Train Spearman: 0.5460 | Val Spearman: 0.6060 | Time: 18.50s
  → Saved new best model (Spearman: 0.6115)
  → Saved new best model (Spearman: 0.6223)
  → Saved new best model (Spearman: 0.6268)
  → Saved new best model (Spearman: 0.6307)
Epoch  10/50 | Train Loss: 0.4424 | Val Loss: 0.4380 | Train Spearman: 0.5741 | Val Spearman: 0.6307 | Time: 20.13s
  → Saved new best model (Spearman: 0.6347)
  → Saved new best mo

In [None]:
# 8.2 Grid-search the LSTM hyperparameters
print("---- Tuning LSTM ----")
best_lstm_config, lstm_history = hyperparameter_search(
    model_class=RNABindingLSTM,
    model_name='LSTM',
    param_grid=param_grid,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    epochs=50,
)

---- Tuning LSTM ----
Starting Hyperparameter Search for LSTM with 12 configurations...

Running experiment 1/12: {'learning_rate': 0.001, 'hidden_dim': 128, 'dropout': 0.3}

Training LSTM_exp0
  → Saved new best model (Spearman: 0.4778)
Epoch   1/50 | Train Loss: 0.8795 | Val Loss: 0.6444 | Train Spearman: 0.3200 | Val Spearman: 0.4778 | Time: 28.15s
  → Saved new best model (Spearman: 0.5250)
  → Saved new best model (Spearman: 0.5390)
  → Saved new best model (Spearman: 0.5772)


## 9. Test Evaluation

In [None]:
# Rebuild the CNN with the winning hyperparameters and restore its best weights
cnn_model = RNABindingCNN(hidden_dim=best_cnn_config['hidden_dim'],
                          dropout=best_cnn_config['dropout']).to(device)
cnn_state = torch.load(f'{OUTPUT_DIR}/best_cnn_model.pth')
cnn_model.load_state_dict(cnn_state)

# Score it on the held-out test split
cnn_test_loss, cnn_test_corr = evaluate(cnn_model, test_loader, device)
print(f"\nBest CNN Test Results:")
print(f"  Test Loss: {cnn_test_loss:.4f}")
print(f"  Test Spearman Correlation: {cnn_test_corr:.4f}")

In [None]:
# Rebuild the LSTM with the winning hyperparameters and restore its best weights
lstm_model = RNABindingLSTM(hidden_dim=best_lstm_config['hidden_dim'],
                            dropout=best_lstm_config['dropout']).to(device)
lstm_state = torch.load(f'{OUTPUT_DIR}/best_lstm_model.pth')
lstm_model.load_state_dict(lstm_state)

# Score it on the held-out test split
lstm_test_loss, lstm_test_corr = evaluate(lstm_model, test_loader, device)
print(f"\nBest LSTM Test Results:")
print(f"  Test Loss: {lstm_test_loss:.4f}")
print(f"  Test Spearman Correlation: {lstm_test_corr:.4f}")

## 10. Plotting Results and Saving History

We save the loss and Spearman-correlation plots, plus the full experiment history as JSON, for reproducibility.

In [None]:
import json

def _experiment_summary(hist, params, test_res):
    """Build the JSON record for one model from its history, config and (loss, corr) test result."""
    test_loss, test_corr = test_res
    val_corrs = hist['val_correlations']
    return {
        "parameters": params,
        "training_time_seconds": hist['training_time'],
        # The "accuracy" keys hold Spearman correlations (regression task);
        # the names are kept for compatibility with previously written files.
        "best_val_accuracy": hist['best_val_corr'],
        # np.argmax returns the first maximum; guard against an empty history.
        "best_val_epoch": int(np.argmax(val_corrs) + 1) if len(val_corrs) > 0 else None,
        "test_loss": test_loss,
        "test_accuracy": test_corr,
        "history": {
            key: hist[key]
            for key in ("train_losses", "train_correlations", "val_losses", "val_correlations")
        },
    }


def _plot_curves(cnn_hist, lstm_hist, train_key, val_key, metric, ylabel, title, path):
    """Plot train (dashed) and validation (solid) curves for CNN (blue) and LSTM (red) and save to ``path``."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for name, hist, color in (("CNN", cnn_hist, 'blue'), ("LSTM", lstm_hist, 'red')):
            epochs = range(1, len(hist[train_key]) + 1)
            ax.plot(epochs, hist[train_key], label=f'{name} Train {metric}', color=color, linestyle='--')
            ax.plot(epochs, hist[val_key], label=f'{name} Val {metric}', color=color)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def save_and_plot_results(cnn_hist, lstm_hist, cnn_config, lstm_config, cnn_test_res, lstm_test_res):
    """
    Write the CNN/LSTM experiment summary to JSON and save comparison plots.

    Args:
        cnn_hist, lstm_hist: history dicts as returned by ``train_model``
            (per-epoch ``train_losses``, ``train_correlations``, ``val_losses``,
            ``val_correlations`` plus ``best_val_corr`` and ``training_time``).
        cnn_config, lstm_config: best hyperparameter dicts for each model.
        cnn_test_res, lstm_test_res: ``(test_loss, test_spearman)`` tuples.

    Files written under ``OUTPUT_DIR``: ``complete_experiment_data.json``,
    ``loss_plot.png`` and ``accuracy_plot.png`` (the latter holds Spearman curves).
    """
    # 1. JSON record for both experiments
    experiment_data = {
        "CNN": _experiment_summary(cnn_hist, cnn_config, cnn_test_res),
        "LSTM": _experiment_summary(lstm_hist, lstm_config, lstm_test_res),
    }
    json_path = f'{OUTPUT_DIR}/complete_experiment_data.json'
    with open(json_path, 'w') as f:
        json.dump(experiment_data, f, indent=4)
    print(f"Complete experiment data saved to '{json_path}'")

    # 2. Loss and Spearman-correlation curves
    loss_path = f'{OUTPUT_DIR}/loss_plot.png'
    corr_path = f'{OUTPUT_DIR}/accuracy_plot.png'
    _plot_curves(cnn_hist, lstm_hist, 'train_losses', 'val_losses', 'Loss',
                 'Loss (MSE)', 'Training and Validation Loss', loss_path)
    _plot_curves(cnn_hist, lstm_hist, 'train_correlations', 'val_correlations', 'Spearman',
                 'Spearman Correlation', 'Training and Validation Spearman Correlation', corr_path)

    # Report the real locations (the files live inside OUTPUT_DIR).
    print(f"Plots saved to '{loss_path}' and '{corr_path}'.")

# Persist the JSON summary and the comparison plots for both tuned models
save_and_plot_results(
    cnn_hist=cnn_history,
    lstm_hist=lstm_history,
    cnn_config=best_cnn_config,
    lstm_config=best_lstm_config,
    cnn_test_res=(cnn_test_loss, cnn_test_corr),
    lstm_test_res=(lstm_test_loss, lstm_test_corr),
)