# Phase 1: RNA-Protein Binding Prediction - Data Exploration (Colab Version)

This Colab-ready notebook mirrors the local `01_data_exploration.ipynb` and adds setup for Google Colab (Drive mount, pip installs, and paths).

In [None]:
# 0. Colab setup: detect the runtime, optionally mount Drive, install deps, choose base paths
import os
import subprocess
import sys
from pathlib import Path

# Colab is detected either from its already-loaded module or from its GPU environment marker.
IN_COLAB = any(('google.colab' in sys.modules, 'COLAB_GPU' in os.environ))
print('Running in Colab:', IN_COLAB)

DRIVE_MOUNTED = False
if IN_COLAB:
    # Mounting is best-effort: on any failure the base directory falls back to /content.
    try:
        from google.colab import drive  # type: ignore
        drive.mount('/content/drive', force_remount=False)
        DRIVE_MOUNTED = True
        print('Drive mounted at /content/drive')
    except Exception as e:
        print('Drive mount skipped or failed:', e)
        DRIVE_MOUNTED = False

    # Minimal package set, installed independently of any repo requirements.txt.
    pkgs = ['numpy', 'pandas', 'matplotlib', 'seaborn', 'tqdm', 'scipy', 'scikit-learn', 'torch']
    print('Installing packages:', pkgs)
    subprocess.check_call([sys.executable, '-m', 'pip', 'install'] + pkgs)

    # Use Drive when it is mounted, otherwise the VM's /content directory.
    BASE_DIR = Path('/content/drive/MyDrive' if DRIVE_MOUNTED else '/content')
else:
    BASE_DIR = Path.cwd()

RUNS_DIR = BASE_DIR.joinpath('dlcb_runs')  # where outputs will be saved
DATA_DIR = BASE_DIR.joinpath('data')       # where data files should exist or be uploaded to

for label, location in (('Base dir:', BASE_DIR), ('Runs dir:', RUNS_DIR), ('Data dir:', DATA_DIR)):
    print(label, location)

In [None]:
# 0.1 Local helper utilities (self-contained, no repo imports)
import json
from pathlib import Path

def create_run_directory(base_dir, run_name):
    """Create ``<base_dir>/<run_name>`` together with its ``models`` and ``plots`` subfolders.

    Directories that already exist are reused. Returns the run directory as a string.
    """
    run_root = Path(base_dir) / run_name
    # The run root comes first, then each output subfolder beneath it.
    for folder in (run_root, run_root / 'models', run_root / 'plots'):
        folder.mkdir(parents=True, exist_ok=True)
    return str(run_root)

def save_training_config(cfg, run_dir):
    """Write ``cfg`` as 2-space-indented JSON to ``config.json`` inside ``run_dir``."""
    config_path = Path(run_dir) / 'config.json'
    with config_path.open('w') as handle:
        json.dump(cfg, handle, indent=2)

def save_training_summary(summary, run_dir):
    """Write ``summary`` as 2-space-indented JSON to ``summary.json`` inside ``run_dir``.

    Training metrics are often NumPy scalars or arrays (for example a
    ``float32`` correlation), which the ``json`` module refuses to encode.
    Any object exposing ``tolist()`` or ``item()`` is converted to plain
    Python first; other unsupported objects still raise ``TypeError``.
    """
    def _to_builtin(value):
        # json calls this only for objects it cannot encode natively.
        for converter_name in ('tolist', 'item'):
            converter = getattr(value, converter_name, None)
            if callable(converter):
                return converter()
        raise TypeError(f'Object of type {type(value).__name__} is not JSON serializable')

    with open(Path(run_dir) / 'summary.json', 'w') as f:
        json.dump(summary, f, indent=2, default=_to_builtin)

print('Local helpers ready (project-independent)')

In [None]:
# 1. Data paths and basic checks
from pathlib import Path
import os

print('Looking for data files in:', DATA_DIR)
DATA_DIR.mkdir(parents=True, exist_ok=True)

required_files = ['training_seqs.txt', 'training_RBPs2.txt', 'training_data2.txt']

# Check each file once, reporting its status and remembering the ones not found.
missing = []
for p in required_files:
    full = Path(DATA_DIR)/p
    found = full.exists()
    print('-', p, 'exists:', found)
    if not found:
        missing.append(p)

if missing:
    print('\nMissing files:', missing)
    print('Options:')
    print('  • Place them under', DATA_DIR)
    # The upload helper is the cell that follows this one, not the one before it.
    print('  • Or run the next cell to upload them directly')

### Optional: Upload data files directly (no Drive)

In [None]:
# Optional alternative to Drive: pick the 3 data files in Colab's upload dialog
try:
    from google.colab import files  # type: ignore
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    print('Please select files: training_seqs.txt, training_RBPs2.txt, training_data2.txt')
    uploaded = files.upload()
    # Each uploaded entry is written into DATA_DIR under its uploaded file name.
    for fname, content in uploaded.items():
        out_path = DATA_DIR / fname
        out_path.write_bytes(content)
        print('Saved', out_path)
except Exception as e:
    # Covers running outside Colab (import fails) as well as any upload/write error.
    print('Upload helper not available. If in Colab, ensure the environment is active. Error:', e)

## 1. Import Required Libraries (Colab-safe)

In [None]:
# Core libraries
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import time
import signal
from collections import Counter
from datetime import datetime

# PyTorch for deep learning
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F

# For evaluation metrics
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, mean_absolute_error

# For progress bars
from tqdm import tqdm

# Plot appearance: default style, husl palette, larger default figure and font
plt.style.use('default')
sns.set_palette('husl')
plt.rcParams.update({'figure.figsize': (12, 8), 'font.size': 12})

# Fixed seeds for NumPy and PyTorch (CPU, plus CUDA when present)
for seed_fn in (np.random.seed, torch.manual_seed):
    seed_fn(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Timestamped run directory under RUNS_DIR
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = 'notebook_phase1_' + timestamp
run_dir = create_run_directory(RUNS_DIR, run_name)

print(
    'Libraries imported successfully!',
    f'PyTorch version: {torch.__version__}',
    f'CUDA available: {torch.cuda.is_available()}',
    f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}",
    f'Run directory: {run_dir}',
    sep='\n',
)

## 2. Data Loading and Exploration

In [None]:
# Data directory from the setup cell, as a plain string for os.path.join
data_dir = str(DATA_DIR)

# --- RNA sequences: one per line ---
print('Loading RNA sequences...')
rna_file = os.path.join(data_dir, 'training_seqs.txt')
with open(rna_file, 'r') as f:
    rna_sequences = [line.strip() for line in f]

print(f'Loaded {len(rna_sequences):,} RNA sequences')
print('First 5 RNA sequences:')
for i, seq in enumerate(rna_sequences[:5]):
    print(f'  {i+1}: {seq}')

print('\n' + '='*60)

# --- Protein sequences: one per line ---
print('Loading protein sequences...')
protein_file = os.path.join(data_dir, 'training_RBPs2.txt')
with open(protein_file, 'r') as f:
    protein_sequences = [line.strip() for line in f]

print(f'Loaded {len(protein_sequences):,} protein sequences')
print('First 3 protein sequences:')
for i, seq in enumerate(protein_sequences[:3]):
    # Long proteins are previewed by their first 60 residues only.
    print(f"  {i+1}: {seq[:60]}{'...' if len(seq) > 60 else ''} (length: {len(seq)})")

print('\n' + '='*60)

# --- Binding scores: peek at the first 5 lines to see the layout ---
print('Examining binding scores format...')
scores_file = os.path.join(data_dir, 'training_data2.txt')

with open(scores_file, 'r') as f:
    # next(f, '') yields an empty string once the file is exhausted.
    sample_lines = [next(f, '').strip() for _ in range(5)]

print('Sample binding score lines:')
for i, line in enumerate(sample_lines):
    scores = line.split()
    print(f'  Line {i+1}: {len(scores)} scores, first 10: {scores[:10]}')

print(
    f"\nData structure:",
    f"- Each line represents binding scores for one RNA sequence across all {len(protein_sequences)} proteins",
    f"- Total expected lines: {len(rna_sequences):,}",
    f"- Each line should have {len(protein_sequences)} scores",
    sep='\n',
)

In [None]:
# Sequence-level statistics and length histograms for RNA and protein inputs
print('Analyzing sequence characteristics...')

# Per-sequence lengths and pooled symbol counts for each molecule type
rna_lengths = list(map(len, rna_sequences))
protein_lengths = list(map(len, protein_sequences))
rna_nucleotide_counts = Counter(''.join(rna_sequences))
protein_aa_counts = Counter(''.join(protein_sequences))

# Display statistics
print('\nRNA Sequence Statistics:')
print(f'  Count: {len(rna_sequences):,}')
print(f'  Length - Min: {min(rna_lengths)}, Max: {max(rna_lengths)}, Mean: {np.mean(rna_lengths):.1f}')
print(f'  Length - Median: {np.median(rna_lengths):.1f}, Std: {np.std(rna_lengths):.1f}')
print(f'  Nucleotide composition: {dict(sorted(rna_nucleotide_counts.items()))}')

print('\nProtein Sequence Statistics:')
print(f'  Count: {len(protein_sequences):,}')
print(f'  Length - Min: {min(protein_lengths)}, Max: {max(protein_lengths)}, Mean: {np.mean(protein_lengths):.1f}')
print(f'  Length - Median: {np.median(protein_lengths):.1f}, Std: {np.std(protein_lengths):.1f}')

# Ten most frequent amino acids, most common first
top_aa = sorted(protein_aa_counts.items(), key=lambda pair: pair[1], reverse=True)[:10]
print(f'  Top 10 amino acids: {top_aa}')

# One histogram panel per molecule type; only data, bin count, colour and labels differ
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
length_panels = (
    (axes[0], rna_lengths, 50, 'skyblue',
     'RNA Sequence Length', 'Distribution of RNA Sequence Lengths'),
    (axes[1], protein_lengths, 30, 'lightcoral',
     'Protein Sequence Length', 'Distribution of Protein Sequence Lengths'),
)
for ax, lengths, n_bins, fill_color, x_label, panel_title in length_panels:
    ax.hist(lengths, bins=n_bins, alpha=0.7, color=fill_color, edgecolor='black')
    ax.axvline(np.mean(lengths), color='red', linestyle='--', label=f'Mean: {np.mean(lengths):.1f}')
    ax.axvline(np.median(lengths), color='orange', linestyle='--', label=f'Median: {np.median(lengths):.1f}')
    ax.set_xlabel(x_label)
    ax.set_ylabel('Frequency')
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Suggested padding/truncation lengths: the 95th percentile of each length distribution
rna_95th, protein_95th = (np.percentile(lengths, 95) for lengths in (rna_lengths, protein_lengths))

print(f'\nRecommended max lengths (95th percentile):')
print(f'  RNA: {rna_95th:.0f}')
print(f'  Protein: {protein_95th:.0f}')

# Binding-score analysis on a sample only (the score file is large)
print('Loading and analyzing binding scores...')

# Read up to the first 1000 lines (one line per RNA) and pool every score on them
with open(scores_file, 'r') as f:
    head_lines = [f.readline().strip() for _ in range(1000)]
sample_scores = np.array([float(token) for row in head_lines if row for token in row.split()])

print(f'Binding Score Statistics (sample of {len(sample_scores):,} scores):')
score_stats = (
    ('Min', sample_scores.min()),
    ('Max', sample_scores.max()),
    ('Mean', sample_scores.mean()),
    ('Median', np.median(sample_scores)),
    ('Std', sample_scores.std()),
    ('25th percentile', np.percentile(sample_scores, 25)),
    ('75th percentile', np.percentile(sample_scores, 75)),
)
for stat_name, stat_value in score_stats:
    print(f'  {stat_name}: {stat_value:.3f}')

# Left panel: histogram with mean/median markers; right panel: box plot
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
hist_ax, box_ax = axes[0], axes[1]

hist_ax.hist(sample_scores, bins=50, alpha=0.7, color='lightgreen', edgecolor='black')
hist_ax.axvline(sample_scores.mean(), color='red', linestyle='--', label=f'Mean: {sample_scores.mean():.3f}')
hist_ax.axvline(np.median(sample_scores), color='orange', linestyle='--', label=f'Median: {np.median(sample_scores):.3f}')
hist_ax.set_xlabel('Binding Score')
hist_ax.set_ylabel('Frequency')
hist_ax.set_title('Distribution of Binding Scores')
hist_ax.legend()
hist_ax.grid(True, alpha=0.3)

box_ax.boxplot(sample_scores, patch_artist=True,
               boxprops=dict(facecolor='lightblue', alpha=0.7))
box_ax.set_ylabel('Binding Score')
box_ax.set_title('Box Plot of Binding Scores')
box_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Basic sanity counts on the sampled scores
print(f'\nData Quality Check:')
quality_counts = (
    ('NaN values', np.isnan(sample_scores).sum()),
    ('Infinite values', np.isinf(sample_scores).sum()),
    ('Negative values', (sample_scores < 0).sum()),
    ('Values > 10', (sample_scores > 10).sum()),
)
for check_name, count in quality_counts:
    print(f'  {check_name}: {count}')

# Min-max scale the sample to [0, 1] to preview the normalized distribution
score_low, score_high = sample_scores.min(), sample_scores.max()
normalized_scores = (sample_scores - score_low) / (score_high - score_low)
print(f'\nNormalized scores (0-1 range):')
print(f'  Min: {normalized_scores.min():.3f}, Max: {normalized_scores.max():.3f}')
print(f'  Mean: {normalized_scores.mean():.3f}, Std: {normalized_scores.std():.3f}')

## 3. Data Preprocessing and One-Hot Encoding

In [None]:
def encode_rna_sequence(sequence, max_length=None):
    """Return a ``(length, 5)`` float one-hot matrix for an RNA sequence.

    Columns are A, U, G, C and a fifth channel. ``T`` is mapped onto the U
    column; ``N``, any unrecognised symbol and padding all use the fifth
    channel. Input is upper-cased first. A truthy ``max_length`` truncates or
    right-pads the sequence to exactly that many rows.
    """
    channel_of = {'A': 0, 'U': 1, 'G': 2, 'C': 3, 'T': 1, 'N': 4}
    fallback = 4  # shared by N, unknown symbols and padding
    indices = [channel_of.get(base, fallback) for base in sequence.upper()]
    if max_length:
        indices = indices[:max_length]
        # The repeat count is zero or negative when nothing needs padding, giving an empty list.
        indices += [fallback] * (max_length - len(indices))
    one_hot = np.zeros((len(indices), 5))
    rows = np.arange(len(indices))
    one_hot[rows, np.asarray(indices, dtype=np.intp)] = 1
    return one_hot


def encode_protein_sequence(sequence, max_length=None):
    """Return a ``(length, 21)`` float one-hot matrix for a protein sequence.

    The first 20 columns follow the order ``ACDEFGHIKLMNPQRSTVWY``; column 20
    is used for ``X``, any other symbol and padding. Input is upper-cased
    first. A truthy ``max_length`` truncates or right-pads to that many rows.
    """
    alphabet = 'ACDEFGHIKLMNPQRSTVWY'
    unknown = 20  # X, unrecognised residues and padding
    column_of = dict(zip(alphabet, range(len(alphabet))))
    column_of['X'] = unknown
    indices = [column_of.get(residue, unknown) for residue in sequence.upper()]
    if max_length:
        indices = indices[:max_length]
        indices += [unknown] * (max_length - len(indices))
    # Selecting rows of the identity matrix yields one one-hot row per residue.
    return np.eye(21)[np.asarray(indices, dtype=np.intp)]

# Smoke-test both encoders on the first loaded RNA and protein
print('Testing encoding functions...')

# RNA: encode to 50 rows and show the first 5 one-hot rows next to their symbols
test_rna = rna_sequences[0]
encoded_rna = encode_rna_sequence(test_rna, max_length=50)
print(f'Original RNA: {test_rna}')
print(f'Encoded shape: {encoded_rna.shape}')
print('First 5 positions (one-hot):')
for i, row in enumerate(encoded_rna[:5]):
    print(f"  Position {i}: {row} -> {test_rna[i] if i < len(test_rna) else 'PAD'}")

print('\n' + '-'*50)

# Protein: encode to 100 rows and show the first 5 one-hot rows
test_protein = protein_sequences[0]
encoded_protein = encode_protein_sequence(test_protein, max_length=100)
print(f"Original protein: {test_protein[:50]}...")
print(f'Encoded shape: {encoded_protein.shape}')
print('First 5 positions (one-hot):')
for i, row in enumerate(encoded_protein[:5]):
    print(f"  Position {i}: {row} -> {test_protein[i] if i < len(test_protein) else 'PAD'}")

# Heat maps of the first 20 positions (transposed so symbols run along the y axis)
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
rna_ax, protein_ax = axes[0], axes[1]

rna_ax.imshow(encoded_rna[:20].T, cmap='Blues', aspect='auto')
rna_ax.set(
    xlabel='Sequence Position',
    ylabel='Nucleotide (A,U,G,C,N)',
    title='RNA One-Hot Encoding (First 20 positions)',
)
rna_ax.set_yticks(range(5))
rna_ax.set_yticklabels(['A', 'U', 'G', 'C', 'N'])

protein_ax.imshow(encoded_protein[:20].T, cmap='Reds', aspect='auto')
protein_ax.set(
    xlabel='Sequence Position',
    ylabel='Amino Acid Index',
    title='Protein One-Hot Encoding (First 20 positions)',
)

plt.tight_layout()
plt.show()

## 4. Dataset and DataLoader Implementation

In [None]:
class RNAProteinDataset(Dataset):
    """Dataset of (RNA, protein, binding score) triples with pre-computed one-hot encodings.

    Args:
        rna_sequences: RNA sequence strings, one per pair.
        protein_sequences: protein sequence strings, aligned with ``rna_sequences``.
        binding_scores: one score per pair (NumPy array or plain sequence).
        rna_max_length: rows each RNA encoding is padded/truncated to.
        protein_max_length: rows each protein encoding is padded/truncated to.
        normalize_scores: when True, min-max scale the scores to [0, 1]
            (skipped if all scores are equal or there are none).

    Each item is a dict with float32 tensors under ``'rna'``, ``'protein'``
    and ``'score'`` (the score has shape ``(1,)``).

    Sequences are encoded once per distinct string and the resulting array is
    shared by every pair that uses it. The pair lists built in the notebook
    repeat the same proteins for every RNA, so this keeps memory and encoding
    time proportional to the number of unique sequences rather than pairs.
    """

    def __init__(self, rna_sequences, protein_sequences, binding_scores,
                 rna_max_length=50, protein_max_length=500, normalize_scores=True):
        self.rna_sequences = rna_sequences
        self.protein_sequences = protein_sequences
        # np.array copies, so the caller's scores are never modified, and it
        # lets plain Python sequences be passed as well as arrays.
        self.binding_scores = np.array(binding_scores)
        self.rna_max_length = rna_max_length
        self.protein_max_length = protein_max_length
        if normalize_scores and len(self.binding_scores) > 0:
            min_score = self.binding_scores.min()
            max_score = self.binding_scores.max()
            # Guard against division by zero when every score is identical.
            if max_score > min_score:
                self.binding_scores = (self.binding_scores - min_score) / (max_score - min_score)
        print('Pre-encoding sequences...')
        self.rna_encoded = []
        self.protein_encoded = []
        rna_cache = {}
        protein_cache = {}
        for i, (rna_seq, protein_seq) in enumerate(zip(rna_sequences, protein_sequences)):
            if i % 1000 == 0:
                print(f'Encoded {i}/{len(rna_sequences)} sequences')
            if rna_seq not in rna_cache:
                rna_cache[rna_seq] = encode_rna_sequence(rna_seq, self.rna_max_length)
            if protein_seq not in protein_cache:
                protein_cache[protein_seq] = encode_protein_sequence(protein_seq, self.protein_max_length)
            # Store references to the shared arrays; __getitem__ copies on access.
            self.rna_encoded.append(rna_cache[rna_seq])
            self.protein_encoded.append(protein_cache[protein_seq])
        print(f'Encoding complete. Dataset size: {len(self.rna_encoded)}')

    def __len__(self):
        # Count the pairs that were actually encoded (zip stops at the shorter
        # input), so every index below len() is valid for both encoded lists.
        return len(self.rna_encoded)

    def __getitem__(self, idx):
        # torch.tensor always copies, so callers can never mutate the shared
        # cached encodings through a returned tensor.
        rna_tensor = torch.tensor(self.rna_encoded[idx], dtype=torch.float32)
        protein_tensor = torch.tensor(self.protein_encoded[idx], dtype=torch.float32)
        score_tensor = torch.tensor([self.binding_scores[idx]], dtype=torch.float32)
        return {'rna': rna_tensor, 'protein': protein_tensor, 'score': score_tensor}


def load_subset_data(data_dir, subset_size=200):
    """Build RNA-protein training pairs from the first ``subset_size`` RNAs.

    Line ``k`` of ``training_data2.txt`` is read as the scores of RNA ``k``
    against every entry of the notebook-level ``protein_sequences`` list (RNAs
    come from the notebook-level ``rna_sequences`` list). Each RNA is paired
    with the scores on its own line, so a blank or malformed line only drops
    that RNA instead of shifting every later RNA onto the wrong scores.

    Args:
        data_dir: directory containing ``training_data2.txt``.
        subset_size: maximum number of RNAs (and score lines) to use.

    Returns:
        ``(rna_list, protein_list, scores)`` where the two lists hold one
        sequence per pair and ``scores`` is a NumPy array aligned with them.
    """
    print(f'Loading subset of {subset_size} RNA sequences...')
    rna_subset = rna_sequences[:subset_size]
    scores_per_rna = len(protein_sequences)
    rna_protein_pairs = []
    protein_rna_pairs = []
    final_scores = []
    with open(os.path.join(data_dir, 'training_data2.txt'), 'r') as f:
        # zip stops at whichever runs out first: the RNA subset or the file.
        for line_no, (rna_seq, line) in enumerate(zip(rna_subset, f), start=1):
            scores = [float(x) for x in line.split()]
            if len(scores) != scores_per_rna:
                # Skip rather than misalign: this RNA simply contributes no pairs.
                print(f'Skipping line {line_no}: expected {scores_per_rna} scores, found {len(scores)}')
                continue
            rna_protein_pairs.extend([rna_seq] * scores_per_rna)
            protein_rna_pairs.extend(protein_sequences)
            final_scores.extend(scores)
    print(f'Created {len(rna_protein_pairs)} RNA-protein pairs')
    return rna_protein_pairs, protein_rna_pairs, np.array(final_scores)

# Build the pair lists for the first 200 RNAs
subset_rna, subset_protein, subset_scores = load_subset_data(data_dir, subset_size=200)

print('Subset data summary:')
print(f'  RNA-protein pairs: {len(subset_rna)}')
print(f'  Score range: {subset_scores.min():.3f} - {subset_scores.max():.3f}')
print(f'  Score mean: {subset_scores.mean():.3f}')

# Fixed padded lengths used by the encoders
RNA_MAX_LENGTH, PROTEIN_MAX_LENGTH = 50, 500

dataset = RNAProteinDataset(
    subset_rna,
    subset_protein,
    subset_scores,
    rna_max_length=RNA_MAX_LENGTH,
    protein_max_length=PROTEIN_MAX_LENGTH,
)

# 80/20 train/validation split; validation takes whatever remains after rounding down
n_pairs = len(dataset)
train_size = int(0.8 * n_pairs)
val_size = n_pairs - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

print('Dataset split:')
print(f'  Training: {len(train_dataset)} samples')
print(f'  Validation: {len(val_dataset)} samples')

# Only the training loader shuffles; both load in the main process
batch_size = 32
loader_options = dict(batch_size=batch_size, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

print(f'Data loaders created with batch size: {batch_size}')

# Pull one batch to confirm the collated tensor shapes
sample_batch = next(iter(train_loader))
print('Sample batch shapes:')
for shape_label, batch_key in (('RNA', 'rna'), ('Protein', 'protein'), ('Scores', 'score')):
    print(f"  {shape_label}: {sample_batch[batch_key].shape}")

## 5. Basic Bidirectional LSTM Model

In [None]:
class BasicLSTM(nn.Module):
    """Two-branch bidirectional LSTM regressor for RNA-protein binding scores.

    Each one-hot sequence runs through its own stacked bidirectional LSTM. For
    each branch the last layer's final forward and backward hidden states are
    concatenated; the two branch vectors are joined and passed through a
    three-layer MLP that outputs one value per sample, shape ``(batch, 1)``.
    """

    def __init__(self, rna_input_size=5, protein_input_size=21,
                 rna_hidden_size=128, protein_hidden_size=128,
                 num_layers=2, dropout=0.2, fusion_hidden_size=256):
        super().__init__()
        self.rna_hidden_size = rna_hidden_size
        self.protein_hidden_size = protein_hidden_size
        self.num_layers = num_layers

        # Options common to both encoders; LSTM dropout is switched off for a single layer.
        shared_lstm_options = dict(
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True,
            batch_first=True,
        )
        self.rna_lstm = nn.LSTM(input_size=rna_input_size, hidden_size=rna_hidden_size,
                                **shared_lstm_options)
        self.protein_lstm = nn.LSTM(input_size=protein_input_size, hidden_size=protein_hidden_size,
                                    **shared_lstm_options)
        self.dropout = nn.Dropout(dropout)

        # Each branch contributes a forward and a backward state, hence the factor of two.
        fusion_input_size = 2 * (rna_hidden_size + protein_hidden_size)
        narrow_width = fusion_hidden_size // 2
        self.fusion_layers = nn.Sequential(
            nn.Linear(fusion_input_size, fusion_hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_hidden_size, narrow_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(narrow_width, 1),
        )
        self._init_weights()

    def _init_weights(self):
        """Re-initialise parameters by name.

        ``weight_ih`` tensors get Xavier-uniform values, ``weight_hh`` tensors
        get orthogonal values, and every bias (LSTM and linear) is zeroed.
        The second quarter of each ``bias_ih`` vector is then set to 1.
        Other weights keep their default initialisation.
        """
        for name, param in self.named_parameters():
            values = param.data
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(values)
                continue
            if 'weight_hh' in name:
                nn.init.orthogonal_(values)
                continue
            if 'bias' not in name:
                continue
            values.zero_()
            if 'bias_ih' in name:
                n = param.size(0)
                values[n // 4:n // 2].fill_(1)

    @staticmethod
    def _last_layer_state(hidden, num_layers, batch_size, hidden_size):
        """Concatenate the last layer's two directional final states along the feature axis."""
        top = hidden.view(num_layers, 2, batch_size, hidden_size)[-1]
        return torch.cat((top[0], top[1]), dim=1)

    def forward(self, rna, protein):
        """Return a ``(batch, 1)`` tensor of predicted binding scores."""
        batch_size = rna.size(0)
        _, (rna_hidden, _) = self.rna_lstm(rna)
        rna_representation = self._last_layer_state(
            rna_hidden, self.num_layers, batch_size, self.rna_hidden_size)
        _, (protein_hidden, _) = self.protein_lstm(protein)
        protein_representation = self._last_layer_state(
            protein_hidden, self.num_layers, batch_size, self.protein_hidden_size)
        # Dropout is applied to the RNA vector first, then the protein vector.
        combined = torch.cat(
            (self.dropout(rna_representation), self.dropout(protein_representation)), dim=1)
        return self.fusion_layers(combined)

# Instantiate the model on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

model = BasicLSTM().to(device)

# Parameter counts, gathered in a single pass over the parameters
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_sizes)
trainable_params = sum(count for count, needs_grad in param_sizes if needs_grad)

print('Model created:')
print(f'  Total parameters: {total_params:,}')
print(f'  Trainable parameters: {trainable_params:,}')
# Size estimate assumes 4 bytes per parameter.
print(f'  Model size: ~{total_params * 4 / 1024**2:.1f} MB (FP32)')

# Forward-pass smoke test on one training batch, in eval mode and without gradients
model.eval()
with torch.no_grad():
    sample_batch = next(iter(train_loader))
    rna_test, protein_test = (sample_batch[key].to(device) for key in ('rna', 'protein'))
    output = model(rna_test, protein_test)
    print('Forward pass test:')
    print('  Input RNA shape:', rna_test.shape)
    print('  Input protein shape:', protein_test.shape)
    print('  Output shape:', output.shape)
    print('  Sample predictions:', output[:5].flatten().cpu().numpy())

## 6. Training Loop Implementation

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters for this run; written to config.json in the run directory below.
training_config = {
    'learning_rate': 1e-3,
    'weight_decay': 1e-4,
    'patience': 5,
    'min_delta': 1e-4,
    'max_grad_norm': 1.0,
    'lr_scheduler_patience': 3,
    'lr_scheduler_factor': 0.5,
    'num_epochs': 5,  # Keep small for Colab demo; increase as needed
    'device': str(device),
    'run_name': run_name,
    'timestamp': datetime.now().isoformat()
}

save_training_config(training_config, run_dir)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=training_config['learning_rate'], 
                      weight_decay=training_config['weight_decay'])
# mode='max': the training loop steps this scheduler with the validation
# correlation, where higher is better.
# `verbose` is deliberately not passed: it has been deprecated since PyTorch 2.2
# and is no longer accepted by recent releases, where it raises TypeError.
# The training loop already prints the current learning rate every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=training_config['lr_scheduler_factor'], 
    patience=training_config['lr_scheduler_patience']
)

def calculate_pearson_correlation(y_true, y_pred):
    """Pearson correlation between two arrays after flattening.

    Returns a built-in ``float`` (never a NumPy scalar) so the value can be
    stored in JSON. Returns 0.0 when the correlation is undefined: fewer than
    two samples, or a NaN result such as from constant input.
    """
    true_flat = np.asarray(y_true).flatten()
    pred_flat = np.asarray(y_pred).flatten()
    # pearsonr raises for fewer than two points; treat that as "no correlation".
    if true_flat.size < 2 or pred_flat.size < 2:
        return 0.0
    correlation, _ = pearsonr(true_flat, pred_flat)
    correlation = float(correlation)
    return correlation if not np.isnan(correlation) else 0.0

class TrainingInterruptHandler:
    """Flag object the training code polls to know whether it should stop."""

    def __init__(self) -> None:
        # False until interrupt() is called.
        self.interrupted = False

    def interrupt(self) -> None:
        """Set the stop flag and announce it."""
        self.interrupted = True
        notice = '\nTraining interrupted! Finishing current epoch...'
        print(notice)


# Single shared handler used by the training functions below.
interrupt_handler = TrainingInterruptHandler()

def train_epoch_enhanced(model, train_loader, criterion, optimizer, device, interrupt_handler, max_grad_norm=1.0):
    """Run one training epoch and return ``(mean_loss, pearson_correlation)``.

    Args:
        model: network called as ``model(rna, protein)``.
        train_loader: iterable of batches, each a dict with ``'rna'``,
            ``'protein'`` and ``'score'`` tensors.
        criterion: loss applied to ``(predictions, targets)``.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        interrupt_handler: object whose ``interrupted`` attribute is checked
            before every batch; when true the epoch stops early.
        max_grad_norm: gradient-norm clipping threshold.

    Returns:
        The loss averaged over the batches actually processed (so an
        interrupted epoch is not under-reported) and the Pearson correlation
        over all predictions made. ``(inf, 0.0)`` if no batch was processed.
    """
    model.train()
    total_loss = 0.0
    batches_done = 0
    all_predictions = []
    all_targets = []
    progress_bar = tqdm(train_loader, desc='Training')
    for batch_idx, batch in enumerate(progress_bar):
        if interrupt_handler.interrupted:
            print(f'Training interrupted at batch {batch_idx}')
            break
        rna = batch['rna'].to(device)
        protein = batch['protein'].to(device)
        targets = batch['score'].to(device)
        optimizer.zero_grad()
        predictions = model(rna, protein)
        loss = criterion(predictions, targets)
        loss.backward()
        # clip_grad_norm_ returns the total norm measured BEFORE clipping, which
        # is the informative number to display; no second pass is needed.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        total_loss += loss.item()
        batches_done += 1
        all_predictions.append(predictions.detach().cpu().numpy())
        all_targets.append(targets.detach().cpu().numpy())
        progress_bar.set_postfix({'Loss': f'{loss.item():.4f}', 'Grad Norm': f'{float(grad_norm):.3f}'})
    if not batches_done:
        return float('inf'), 0.0
    predictions = np.concatenate(all_predictions, axis=0)
    targets = np.concatenate(all_targets, axis=0)
    correlation = calculate_pearson_correlation(targets, predictions)
    return total_loss / batches_done, correlation


def validate_epoch_enhanced(model, val_loader, criterion, device):
    """Evaluate the model over ``val_loader`` without computing gradients.

    Returns ``(mean_loss, correlation, predictions, targets)`` where the mean
    is taken over batches and the last two items are NumPy arrays stacked
    along the batch axis.
    """
    model.eval()
    running_loss = 0.0
    prediction_chunks, target_chunks = [], []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validating'):
            rna, protein, targets = (batch[key].to(device) for key in ('rna', 'protein', 'score'))
            outputs = model(rna, protein)
            running_loss += criterion(outputs, targets).item()
            prediction_chunks.append(outputs.cpu().numpy())
            target_chunks.append(targets.cpu().numpy())
    predictions = np.concatenate(prediction_chunks, axis=0)
    targets = np.concatenate(target_chunks, axis=0)
    correlation = calculate_pearson_correlation(targets, predictions)
    mean_loss = running_loss / len(val_loader)
    return mean_loss, correlation, predictions, targets

print('Training setup complete')

# Bookkeeping for the run: per-epoch histories, best-model tracking, early stopping.
num_epochs = training_config['num_epochs']
train_losses, val_losses = [], []
train_correlations, val_correlations = [], []
best_val_correlation = -np.inf
best_model_state = None
epochs_without_improvement = 0
patience = training_config['patience']
min_delta = training_config['min_delta']

print(f'Starting training for {num_epochs} epochs on {device}...')
start_time = time.time()

try:
    for epoch in range(num_epochs):
        if interrupt_handler.interrupted:
            print(f'Training stopped at epoch {epoch+1} due to interruption')
            break
        train_loss, train_corr = train_epoch_enhanced(
            model, train_loader, criterion, optimizer, device, 
            interrupt_handler, training_config['max_grad_norm']
        )
        if interrupt_handler.interrupted:
            print(f'Epoch {epoch+1} interrupted during training')
            break
        val_loss, val_corr, val_predictions, val_targets = validate_epoch_enhanced(
            model, val_loader, criterion, device
        )
        # Metrics may arrive as NumPy scalars; plain floats keep the histories
        # and the summary JSON-serialisable.
        train_loss, train_corr = float(train_loss), float(train_corr)
        val_loss, val_corr = float(val_loss), float(val_corr)
        scheduler.step(val_corr)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_correlations.append(train_corr)
        val_correlations.append(val_corr)
        improvement = val_corr - best_val_correlation
        if improvement > min_delta:
            epochs_without_improvement = 0
            best_val_correlation = val_corr
            # state_dict() hands back tensors that share storage with the live
            # parameters, and dict.copy() is shallow. Clone every tensor so the
            # snapshot is not overwritten by later optimizer steps.
            best_model_state = {
                key: value.detach().clone() for key, value in model.state_dict().items()
            }
            if np.isfinite(improvement):
                print(f'New best model! Correlation improved by {improvement:.4f}')
            else:
                # First successful epoch: there is no previous best to compare with.
                print(f'New best model! Correlation: {val_corr:.4f}')
        else:
            epochs_without_improvement += 1
        lr = optimizer.param_groups[0]['lr']
        print(f'Epoch {epoch+1}/{num_epochs} - LR: {lr:.2e}')
        print(f'  Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
        print(f'  Train Corr: {train_corr:.4f}, Val Corr: {val_corr:.4f}')
        print(f'  Best Val Corr: {best_val_correlation:.4f}')
        if epochs_without_improvement >= patience:
            print(f'Early stopping triggered after {epoch+1} epochs')
            break
        print()
except KeyboardInterrupt:
    print(f'Interrupted at epoch {len(train_losses)+1}')
    interrupt_handler.interrupt()

total_time = time.time() - start_time
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f'Loaded best model (correlation: {best_val_correlation:.4f})')

training_summary = {
    # None (JSON null) rather than -inf when no epoch finished.
    'best_val_correlation': best_val_correlation if val_correlations else None,
    'best_epoch': int(np.argmax(val_correlations)) if val_correlations else 0,
    'final_train_correlation': train_correlations[-1] if train_correlations else 0,
    'final_val_correlation': val_correlations[-1] if val_correlations else 0,
    'total_epochs': len(train_losses),
    'total_training_time': total_time,
    'interrupted': interrupt_handler.interrupted,
    'early_stopped': epochs_without_improvement >= patience,
    'training_config': training_config,
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_correlations': train_correlations,
    'val_correlations': val_correlations
}

save_training_summary(training_summary, run_dir)
print('Training completed')

## 7. Model Evaluation and Metrics

In [None]:
# Evaluate on validation set with best model
model.eval()
val_loss, val_corr, final_predictions, final_targets = validate_epoch_enhanced(model, val_loader, criterion, device)

# Calculate additional metrics
mse = mean_squared_error(final_targets.flatten(), final_predictions.flatten())
mae = mean_absolute_error(final_targets.flatten(), final_predictions.flatten())
rmse = np.sqrt(mse)

print('Final Model Performance:')
print(f'  Validation Loss (MSE): {val_loss:.4f}')
print(f'  Pearson Correlation: {val_corr:.4f}')
print(f'  Mean Absolute Error: {mae:.4f}')
print(f'  Root Mean Square Error: {rmse:.4f}')

# Model inference speed test
print('\nInference Speed Test:')
model.eval()
# CUDA kernels launch asynchronously, so the clock must only be read after
# pending GPU work has finished; otherwise launch overhead is what gets timed.
sync_cuda = device.type == 'cuda'
with torch.no_grad():
    test_batch = next(iter(val_loader))
    rna_test = test_batch['rna'].to(device)
    protein_test = test_batch['protein'].to(device)
    # Warm-up passes, excluded from the measurement.
    for _ in range(5):
        _ = model(rna_test, protein_test)
    if sync_cuda:
        torch.cuda.synchronize()
    # perf_counter is a monotonic clock intended for measuring short intervals.
    start_time = time.perf_counter()
    num_runs = 20
    for _ in range(num_runs):
        _ = model(rna_test, protein_test)
    if sync_cuda:
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    avg_time = (end_time - start_time) / num_runs
    samples_per_second = rna_test.size(0) / avg_time

print(f'  Average inference time: {avg_time*1000:.2f}ms per batch')
print(f'  Throughput: {samples_per_second:.1f} samples/second')

print('\nPhase 1 Summary:')
print('  Model: Basic Bidirectional LSTM')
print(f'  Dataset: {len(subset_rna)} RNA-protein pairs')
print(f'  Training time: {total_time:.1f}s')
print(f'  Best validation correlation: {best_val_correlation:.4f}')
print(f'  Final validation correlation: {val_corr:.4f}')

## 8. Results Visualization and Saving

In [None]:
plots_dir = os.path.join(run_dir, 'plots')
os.makedirs(plots_dir, exist_ok=True)

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Training/Validation Loss
axes[0, 0].plot(train_losses, label='Training Loss', color='blue', linewidth=2)
axes[0, 0].plot(val_losses, label='Validation Loss', color='orange', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss (MSE)')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# 2. Training/Validation Correlation
axes[0, 1].plot(train_correlations, label='Training Correlation', color='blue', linewidth=2)
axes[0, 1].plot(val_correlations, label='Validation Correlation', color='orange', linewidth=2)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Pearson Correlation')
axes[0, 1].set_title('Training and Validation Correlation')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# 3. Predictions vs Targets Scatter Plot (the dashed diagonal marks perfect prediction)
axes[1, 0].scatter(final_targets.flatten(), final_predictions.flatten(), alpha=0.6, s=20)
min_val = min(final_targets.min(), final_predictions.min())
max_val = max(final_targets.max(), final_predictions.max())
axes[1, 0].plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='Perfect Prediction')
correlation = calculate_pearson_correlation(final_targets, final_predictions)
axes[1, 0].set_xlabel('True Binding Scores')
axes[1, 0].set_ylabel('Predicted Binding Scores')
axes[1, 0].set_title(f'Predictions vs Targets\nPearson Correlation: {correlation:.4f}')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# 4. Residuals Plot
residuals = final_targets.flatten() - final_predictions.flatten()
axes[1, 1].scatter(final_predictions.flatten(), residuals, alpha=0.6, s=20)
axes[1, 1].axhline(y=0, color='r', linestyle='--', linewidth=2)
axes[1, 1].set_xlabel('Predicted Binding Scores')
axes[1, 1].set_ylabel('Residuals (True - Predicted)')
axes[1, 1].set_title(f'Residuals Plot\nMean: {residuals.mean():.4f}, Std: {residuals.std():.4f}')
axes[1, 1].grid(True, alpha=0.3)

plt.suptitle(f'Phase 1 Training Results - {run_name}', fontsize=16, fontweight='bold')
plt.tight_layout()

# Save before show() so the written file contains the rendered figure.
main_plot_path = os.path.join(plots_dir, 'training_results.png')
plt.savefig(main_plot_path, dpi=300, bbox_inches='tight')
print('Main results plot saved to:', main_plot_path)
plt.show()

# Save model for future use
model_save_path = os.path.join(run_dir, 'models', 'phase1_basic_lstm_model.pth')
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save({
    'model_state_dict': model.state_dict(),
    # Read the hyperparameters from the model itself so the saved config always
    # matches the saved weights, even if BasicLSTM was built with non-default values.
    'model_config': {
        'rna_input_size': model.rna_lstm.input_size,
        'protein_input_size': model.protein_lstm.input_size,
        'rna_hidden_size': model.rna_hidden_size,
        'protein_hidden_size': model.protein_hidden_size,
        'num_layers': model.num_layers,
        'dropout': model.dropout.p,
        'fusion_hidden_size': model.fusion_layers[0].out_features
    },
    'training_history': training_summary,
    'sequence_lengths': {
        'rna_max_length': RNA_MAX_LENGTH,
        'protein_max_length': PROTEIN_MAX_LENGTH
    },
    'final_performance': {
        'correlation': float(correlation),
        'mse': float(mean_squared_error(final_targets.flatten(), final_predictions.flatten())),
        'mae': float(mean_absolute_error(final_targets.flatten(), final_predictions.flatten()))
    }
}, model_save_path)

print('Model saved to:', model_save_path)
print('All outputs saved to:', run_dir)