# RNA 3D Validation Features Extraction

This notebook extracts all three types of features for RNA validation data:
1. Thermodynamic features from RNA sequences
2. Pseudodihedral angle features from 3D coordinates
3. Mutual Information features from Multiple Sequence Alignments (MSAs)

This notebook works with validation data that includes 3D structural information.

In [None]:
# Standard imports
import os
import sys
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import time
import json

# Ensure the parent directory is in the path so we can import our modules
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

# Import feature extraction modules
from src.analysis.thermodynamic_analysis import extract_thermodynamic_features
from src.analysis.dihedral_analysis import extract_dihedral_features
from src.analysis.mutual_information import calculate_mutual_information, convert_mi_to_evolutionary_features
from src.data.extract_features_simple import save_features_npz

## Configuration

Define paths and parameters for feature extraction.

In [None]:
# Run-time switches
LIMIT = 5  # Number of targets to process while testing; None processes everything
VERBOSE = True  # Print detailed progress during extraction

# Data layout, relative to the notebook's working directory
DATA_DIR = Path("../data")
RAW_DIR = DATA_DIR.joinpath("raw")
PROCESSED_DIR = DATA_DIR.joinpath("processed")

# One output directory per feature type, all under the processed directory
THERMO_DIR = PROCESSED_DIR.joinpath("thermo_features")
DIHEDRAL_DIR = PROCESSED_DIR.joinpath("dihedral_features")
MI_DIR = PROCESSED_DIR.joinpath("mi_features")

# Create any directory that is missing (including parents); existing ones are left alone
for directory in (RAW_DIR, PROCESSED_DIR, THERMO_DIR, DIHEDRAL_DIR, MI_DIR):
    directory.mkdir(parents=True, exist_ok=True)

## Helper Functions

Define utility functions for loading data and extracting features.

In [ ]:
def load_rna_data(csv_path):
    """
    Load RNA data from a CSV file.

    Args:
        csv_path: Path to CSV file containing RNA data

    Returns:
        DataFrame with RNA data, or None if the file could not be read
    """
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        # Best-effort loader: report the problem and let the caller try another file
        print(f"Error loading CSV file: {e}")
        return None

    print(f"Loaded {len(df)} entries from {csv_path}")
    return df


def get_unique_target_ids(df, id_col="ID"):
    """
    Extract unique target IDs from dataframe.

    Row IDs have the form TARGET_ID_RESIDUE_NUM, so the target ID is everything
    before the LAST underscore (e.g. "1SCL_A_12" -> "1SCL_A", "R1107_5" -> "R1107").

    Args:
        df: DataFrame with RNA data
        id_col: Column containing IDs

    Returns:
        Sorted list of unique target IDs
    """
    target_ids = set()
    for id_str in df[id_col]:
        # Missing values (NaN) and other non-string IDs cannot be parsed
        if not isinstance(id_str, str):
            continue

        # Strip only the trailing residue number; target IDs may or may not
        # contain an underscore themselves
        target_id, separator, _residue = id_str.rpartition('_')
        if separator and target_id:
            target_ids.add(target_id)

    unique_targets = sorted(target_ids)
    print(f"Found {len(unique_targets)} unique target IDs")
    return unique_targets


def load_structure_data(target_id, data_dir=RAW_DIR):
    """
    Load structure data for a given target from labels CSV.

    Args:
        target_id: Target ID
        data_dir: Directory containing data

    Returns:
        DataFrame with the label rows whose ID starts with "<target_id>_",
        or None if no labels file contains such rows
    """
    data_dir = Path(data_dir)

    # Define possible label files (train or validation)
    label_files = [
        data_dir / "train_labels.csv",
        data_dir / "validation_labels.csv"
    ]

    for label_file in label_files:
        if not label_file.exists():
            continue

        try:
            print(f"Looking for {target_id} in {label_file}")
            # Read the entire CSV file
            all_data = pd.read_csv(label_file)

            # Filter rows for this target ID; na=False keeps rows with a
            # missing ID from turning the mask into a non-boolean array
            row_mask = all_data["ID"].str.startswith(f"{target_id}_", na=False)
            target_data = all_data[row_mask]

            if len(target_data) > 0:
                print(f"Found {len(target_data)} residues for {target_id}")
                return target_data
        except Exception as e:
            print(f"Error loading from {label_file}: {e}")

    print(f"Could not find structure data for {target_id} in any labels file")
    return None


def _read_fasta_sequences(fasta_path):
    """
    Read the sequences of a FASTA file.

    Header lines (starting with '>') only delimit records; their text is
    discarded. Records without any sequence text are skipped.

    Args:
        fasta_path: Path to the FASTA file

    Returns:
        List of sequence strings in file order
    """
    sequences = []
    chunks = []  # pieces of the record currently being read

    with open(fasta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                if chunks:
                    sequences.append("".join(chunks))
                    chunks = []
            elif line:
                chunks.append(line)

    # Add the last sequence
    if chunks:
        sequences.append("".join(chunks))

    return sequences


def load_msa_data(target_id, data_dir=RAW_DIR):
    """
    Load MSA data for a given target.

    Looks for "<target_id><extension>" in a fixed list of directories first,
    then falls back to a recursive search below data_dir.

    Args:
        target_id: Target ID
        data_dir: Directory containing MSA data

    Returns:
        List of MSA sequences or None if not found
    """
    data_dir = Path(data_dir)

    # Define possible MSA directories and extensions
    msa_dirs = [
        data_dir / "MSA",
        data_dir,
        data_dir / "alignments",
        data_dir / "validation" / "MSA",
        data_dir / "validation",
        data_dir / "validation" / "alignments"
    ]

    extensions = [".MSA.fasta", ".fasta", ".fa", ".afa", ".msa"]

    # Try all combinations of directories and extensions
    for msa_dir in msa_dirs:
        if not msa_dir.exists():
            continue

        for ext in extensions:
            msa_path = msa_dir / f"{target_id}{ext}"
            if not msa_path.exists():
                continue

            print(f"Loading MSA data from {msa_path}")
            try:
                sequences = _read_fasta_sequences(msa_path)
            except Exception as e:
                # An unreadable candidate is not fatal; keep trying the others
                print(f"Error loading MSA data: {e}")
                continue

            print(f"Loaded {len(sequences)} sequences from MSA")
            return sequences

    # Fallback: try recursive search
    print("MSA file not found in standard locations, trying recursive search...")
    try:
        for msa_dir in [data_dir, data_dir / "validation"]:
            if not msa_dir.exists():
                continue

            for ext in extensions:
                msa_path = next(msa_dir.glob(f"**/{target_id}{ext}"), None)
                if msa_path is None:
                    continue

                print(f"Found MSA via recursive search: {msa_path}")
                sequences = _read_fasta_sequences(msa_path)
                print(f"Loaded {len(sequences)} sequences from MSA")
                return sequences
    except Exception as e:
        print(f"Error in recursive MSA search: {e}")

    print(f"Could not find MSA data for {target_id}")
    return None


def get_sequence_for_target(target_id, data_dir=RAW_DIR):
    """
    Get RNA sequence for a target ID from the sequence file.

    Falls back to the first sequence of the target's MSA when no sequence
    CSV contains the target.

    Args:
        target_id: Target ID
        data_dir: Directory containing sequence data

    Returns:
        RNA sequence as string or None if not found
    """
    data_dir = Path(data_dir)

    # Try different possible file locations
    sequence_paths = [
        data_dir / "sequences.csv",
        data_dir / "train_sequences.csv",
        data_dir / "validation_sequences.csv",
        data_dir / "rna_sequences.csv",
        data_dir / "validation" / "sequences.csv",
        data_dir / "validation" / "validation_sequences.csv"
    ]

    # Try different possible column names
    id_cols = ["target_id", "ID", "id"]
    seq_cols = ["sequence", "Sequence", "seq"]

    for path in sequence_paths:
        if not path.exists():
            continue

        try:
            df = pd.read_csv(path)

            for id_col in id_cols:
                if id_col not in df.columns:
                    continue

                # Find the target in the dataframe
                target_row = df[df[id_col] == target_id]
                if len(target_row) == 0:
                    continue

                for seq_col in seq_cols:
                    if seq_col in df.columns:
                        return target_row[seq_col].iloc[0]
        except Exception as e:
            print(f"Error loading sequence data from {path}: {e}")

    # If we still haven't found the sequence, try to extract it from MSA data
    msa_sequences = load_msa_data(target_id, data_dir)
    if msa_sequences:
        # The first sequence in the MSA is typically the target sequence
        return msa_sequences[0]

    print(f"Could not find sequence for {target_id}")
    return None

## Feature Extraction Functions

Define functions for extracting each type of feature.

In [None]:
def extract_thermo_features_for_target(target_id, sequence=None):
    """
    Compute thermodynamic features for one target and save them as an .npz file.

    Args:
        target_id: Target ID
        sequence: RNA sequence; when omitted it is looked up with
            get_sequence_for_target

    Returns:
        The feature dictionary (with 'target_id' and 'sequence' entries added),
        or None if the sequence is unavailable or any step raises
    """
    print(f"Extracting thermodynamic features for {target_id}")
    started_at = time.time()

    try:
        # Resolve the sequence only when the caller did not supply one
        if sequence is None:
            sequence = get_sequence_for_target(target_id)
        if sequence is None:
            print(f"Failed to get sequence for {target_id}")
            return None

        print(f"Calculating thermodynamic features for sequence of length {len(sequence)}")
        features = extract_thermodynamic_features(sequence)

        # Tag the features with their origin before they are written to disk
        npz_path = THERMO_DIR / f"{target_id}_thermo_features.npz"
        features['target_id'] = target_id
        features['sequence'] = sequence
        save_features_npz(features, npz_path)

        duration = time.time() - started_at
        print(f"Extracted thermodynamic features in {duration:.2f} seconds")
        return features

    except Exception as e:
        # Any failure is reported and mapped to None so a batch run can continue
        print(f"Error extracting thermodynamic features for {target_id}: {e}")
        import traceback
        traceback.print_exc()
        return None

def extract_dihedral_features_for_target(target_id, structure_data=None):
    """
    Compute pseudodihedral angle features for one target.

    Args:
        target_id: Target ID
        structure_data: DataFrame with structure coordinates; when omitted it is
            loaded with load_structure_data

    Returns:
        The dihedral feature dictionary (with a 'target_id' entry added), or None
        if the structure is missing, has fewer than 4 rows, lacks usable
        coordinate columns, or any step raises
    """
    print(f"Extracting dihedral features for {target_id}")
    started_at = time.time()

    try:
        # Resolve the structure only when the caller did not supply one
        if structure_data is None:
            structure_data = load_structure_data(target_id)
        if structure_data is None:
            print(f"Failed to get structure data for {target_id}")
            return None

        # A dihedral angle is defined by four consecutive points
        if len(structure_data) < 4:
            print(f"Not enough residues ({len(structure_data)}) for {target_id}, minimum 4 required for dihedral angles")
            return None

        # Fall back to the first x_*/y_*/z_* column triple when x_1/y_1/z_1 is incomplete
        available = structure_data.columns
        if any(name not in available for name in ('x_1', 'y_1', 'z_1')):
            x_candidates = [col for col in available if col.startswith('x_')]
            if not x_candidates:
                print(f"Missing coordinate columns for {target_id}")
                return None

            x_col = x_candidates[0]
            y_col = x_col.replace('x_', 'y_')
            z_col = x_col.replace('x_', 'z_')
            if y_col not in available or z_col not in available:
                print(f"Missing coordinate columns for {target_id}")
                return None

            # Rename columns for compatibility
            structure_data = structure_data.rename(
                columns={x_col: 'x_1', y_col: 'y_1', z_col: 'z_1'}
            )
            print(f"Renamed columns {x_col}, {y_col}, {z_col} to x_1, y_1, z_1")

        npz_path = DIHEDRAL_DIR / f"{target_id}_dihedral_features.npz"
        print(f"Calculating dihedral features for {len(structure_data)} residues")

        dihedral_features = extract_dihedral_features(
            structure_data,
            output_file=npz_path,
            include_raw_angles=True,
        )

        duration = time.time() - started_at
        print(f"Extracted dihedral features in {duration:.2f} seconds")

        # Tag the returned features with the target they belong to
        dihedral_features['target_id'] = target_id
        return dihedral_features

    except Exception as e:
        # Any failure is reported and mapped to None so a batch run can continue
        print(f"Error extracting dihedral features for {target_id}: {e}")
        import traceback
        traceback.print_exc()
        return None

def extract_mi_features_for_target(target_id, structure_data=None, msa_sequences=None):
    """
    Extract Mutual Information features for a given target.

    Args:
        target_id: Target ID
        structure_data: DataFrame with structure data for correlation calculation
            (optional, will be loaded if not provided)
        msa_sequences: List of MSA sequences (optional, will be loaded if not provided)

    Returns:
        Dictionary with MI features (with a 'target_id' entry added), or None if
        fewer than two MSA sequences are available, the MI calculation yields
        None, or any step raises
    """
    print(f"Extracting MI features for {target_id}")
    start_time = time.time()

    try:
        # Get MSA sequences if not provided
        if msa_sequences is None:
            msa_sequences = load_msa_data(target_id)

        # Validate the MSA no matter where it came from: a caller-supplied
        # alignment with fewer than two sequences is just as unusable as a
        # missing one
        if msa_sequences is None or len(msa_sequences) < 2:
            print(f"Failed to get MSA data for {target_id} or not enough sequences")
            return None

        # Get structure data if not provided (for correlation calculation)
        if structure_data is None and target_id is not None:
            structure_data = load_structure_data(target_id)

        # Calculate MI (this may take some time for large MSAs)
        print(f"Calculating MI for {len(msa_sequences)} sequences")
        mi_result = calculate_mutual_information(msa_sequences, verbose=VERBOSE)

        if mi_result is None:
            print(f"Failed to calculate MI for {target_id}")
            return None

        output_file = MI_DIR / f"{target_id}_mi_features.npz"

        # If we have structure data, use it for correlation calculation
        if structure_data is not None:
            print(f"Converting MI to evolutionary features with structural correlation")
            features = convert_mi_to_evolutionary_features(mi_result, structure_data, output_file=output_file)
        else:
            print(f"Converting MI to evolutionary features without structural correlation")
            features = mi_result

            # convert_mi_to_evolutionary_features was not given the output
            # path, so the raw MI result is written here instead
            np.savez_compressed(output_file, **features)
            print(f"Saved MI features to {output_file}")

        elapsed_time = time.time() - start_time
        print(f"Extracted MI features in {elapsed_time:.2f} seconds")

        # Add target ID
        features['target_id'] = target_id
        return features

    except Exception as e:
        # Any failure is reported and mapped to None so a batch run can continue
        print(f"Error extracting MI features for {target_id}: {e}")
        import traceback
        traceback.print_exc()
        return None

## Batch Processing

Process multiple targets in batch mode.

In [None]:
def process_target(target_id, extract_thermo=True, extract_dihedral=True, extract_mi=True):
    """
    Process a single target, extracting all requested feature types.

    A feature type whose output file already exists is reported as skipped.
    Input data is loaded only for feature types that actually have to be
    computed, so re-running over already processed targets does no data loading.

    Args:
        target_id: Target ID
        extract_thermo: Whether to extract thermodynamic features
        extract_dihedral: Whether to extract dihedral features
        extract_mi: Whether to extract MI features

    Returns:
        Dictionary with 'target_id', 'elapsed_time' (seconds) and, for each
        requested feature type ('thermo', 'dihedral', 'mi'), a dict holding
        'success' and, when the output already existed, 'skipped'
    """
    print(f"\nProcessing target: {target_id}")
    results = {'target_id': target_id}
    start_time = time.time()

    # Structure data is the only input shared by two feature types; it is
    # loaded by the dihedral step when that step runs and reused by the MI step
    structure_data = None

    # Extract thermodynamic features
    if extract_thermo:
        thermo_file = THERMO_DIR / f"{target_id}_thermo_features.npz"

        if thermo_file.exists():
            print(f"Thermodynamic features already exist for {target_id}")
            results['thermo'] = {'success': True, 'skipped': True}
        else:
            # The extractor looks the sequence up itself when none is passed
            thermo_features = extract_thermo_features_for_target(target_id)
            results['thermo'] = {'success': thermo_features is not None}

    # Extract dihedral features
    if extract_dihedral:
        dihedral_file = DIHEDRAL_DIR / f"{target_id}_dihedral_features.npz"

        if dihedral_file.exists():
            print(f"Dihedral features already exist for {target_id}")
            results['dihedral'] = {'success': True, 'skipped': True}
        else:
            structure_data = load_structure_data(target_id)
            dihedral_features = extract_dihedral_features_for_target(target_id, structure_data)
            results['dihedral'] = {'success': dihedral_features is not None}

    # Extract MI features
    if extract_mi:
        mi_file = MI_DIR / f"{target_id}_mi_features.npz"

        if mi_file.exists():
            print(f"MI features already exist for {target_id}")
            results['mi'] = {'success': True, 'skipped': True}
        else:
            # The extractor loads the MSA itself, and loads the structure data
            # too when the dihedral step did not already provide it
            mi_features = extract_mi_features_for_target(target_id, structure_data)
            results['mi'] = {'success': mi_features is not None}

    # Calculate total time
    elapsed_time = time.time() - start_time
    results['elapsed_time'] = elapsed_time
    print(f"Completed processing {target_id} in {elapsed_time:.2f} seconds")

    return results

def batch_process_targets(target_ids, extract_thermo=True, extract_dihedral=True, extract_mi=True):
    """
    Run process_target over a list of targets and report the outcome.

    Prints a per-feature summary and writes it, together with the per-target
    results, to 'validation_processing_summary.json' in PROCESSED_DIR.

    Args:
        target_ids: List of target IDs
        extract_thermo: Whether to extract thermodynamic features
        extract_dihedral: Whether to extract dihedral features
        extract_mi: Whether to extract MI features

    Returns:
        Dictionary mapping each target ID to the result of process_target
    """
    n_targets = len(target_ids)
    print(f"Starting batch processing for {n_targets} targets")
    batch_started = time.time()

    # Targets are processed one after another, in the order given
    results = {}
    for position, target_id in enumerate(target_ids, start=1):
        print(f"\nProcessing target {position}/{n_targets}: {target_id}")
        results[target_id] = process_target(
            target_id,
            extract_thermo=extract_thermo,
            extract_dihedral=extract_dihedral,
            extract_mi=extract_mi,
        )

    total_time = time.time() - batch_started

    # Tally outcomes per feature type; skipped targets also count as successful
    feature_keys = ('thermo', 'dihedral', 'mi')
    outcomes = list(results.values())
    success_counts = {
        key: sum(1 for outcome in outcomes if key in outcome and outcome[key]['success'])
        for key in feature_keys
    }
    skipped_counts = {
        key: sum(1 for outcome in outcomes if key in outcome and outcome[key].get('skipped', False))
        for key in feature_keys
    }

    # Report only the feature types that were requested
    print("\nBatch processing complete!")
    print(f"Total targets: {n_targets}")
    print(f"Total time: {total_time:.2f} seconds")
    if extract_thermo:
        print(f"Thermodynamic features: {success_counts['thermo']} successful ({skipped_counts['thermo']} skipped)")
    if extract_dihedral:
        print(f"Dihedral features: {success_counts['dihedral']} successful ({skipped_counts['dihedral']} skipped)")
    if extract_mi:
        print(f"MI features: {success_counts['mi']} successful ({skipped_counts['mi']} skipped)")

    # Persist the summary next to the processed features
    summary = {
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        'total_targets': n_targets,
        'total_time': total_time,
        'success_counts': success_counts,
        'skipped_counts': skipped_counts,
        'target_results': results
    }
    summary_path = PROCESSED_DIR / 'validation_processing_summary.json'
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)

    return results

## Load Data and Process

In [ ]:
# Candidate locations of the validation labels file, in order of preference
validation_paths = [
    RAW_DIR / "validation_labels.csv",
    RAW_DIR / "validation" / "validation_labels.csv"
]

# Use the first candidate that exists and loads successfully
validation_data = None
for validation_file in validation_paths:
    if not validation_file.exists():
        continue
    validation_data = load_rna_data(validation_file)
    if validation_data is not None:
        break

if validation_data is None:
    print("Error loading validation data. Please make sure at least one validation file exists.")

## Visualization and Validation

Visualize and validate the features.

In [ ]:
# Run batch feature extraction over the targets found in the validation data
if validation_data is None or validation_data.empty:
    print("No validation data available. Please check your validation data files.")
    target_ids = []
else:
    # Derive the list of targets from the per-residue rows
    target_ids = get_unique_target_ids(validation_data)

    # Keep only the first LIMIT targets while testing
    if LIMIT is not None and LIMIT < len(target_ids):
        print(f"Limiting to first {LIMIT} targets for testing")
        target_ids = target_ids[:LIMIT]

    # Extract all three feature types for every selected target
    results = batch_process_targets(
        target_ids,
        extract_thermo=True,
        extract_dihedral=True,
        extract_mi=True,
    )