# Statistics Notebook
This notebook contains the data processing and statistics for our experiments. It covers everything from parsing the `.fasta` files generated by our model to producing statistical summaries — including **median/mean** results for **single-seed** experiments and **median/mean** results across **multiple seeds**. Please check and adjust the file paths before running.
- Process `.fasta` files to `.txt` files.
- Calculate the single-seed statistics.
- Calculate the multi-seed statistics.

In [4]:
import numpy as np
import os
import re
import glob
from pathlib import Path
from statistics import median
from typing import List

In [11]:
#####Process fasta files and generate the data txt files that we need#####
# Metric keys expected in every ">sample=..." header line.
_SAMPLE_METRIC_NAMES = ('perplexity', 'recovery', 'edit_dist', 'sc_score')

# One compiled pattern per metric. The numeric part accepts an optional sign and
# exponent, so negative or scientific-notation values are no longer rejected or
# truncated, and it never captures trailing punctuation that float() would
# refuse to parse.
_SAMPLE_METRIC_PATTERNS = {
    name: re.compile(name + r'=([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)')
    for name in _SAMPLE_METRIC_NAMES
}


def _extract_sample_metrics(header):
    """Return the four metrics of a sample header as floats, or None if any is missing."""
    metrics = {}
    for name, pattern in _SAMPLE_METRIC_PATTERNS.items():
        match = pattern.search(header)
        if match is None:
            return None
        metrics[name] = float(match.group(1))
    return metrics


def parse_fasta_file(fasta_path):
    """
    Parse a single FASTA file to extract the input sequence length and the
    median metrics over all of its samples.

    The file content is split on '>' into records. The record whose header
    contains 'input_sequence' supplies the input length; every record whose
    header contains 'sample=' supplies one set of metrics. A sample is only
    counted when all four metrics can be read from its header.

    Args:
    - fasta_path: Path of the FASTA file to read (opened as UTF-8 text)

    Returns:
    - input_length: Length of the input sequence with all whitespace removed
      (0 if no input record was found or the file could not be processed)
    - median_metrics: Median metrics over the parsed samples
      {perplexity, recovery, edit_dist, sc_score}, or None when no sample
      could be parsed or an error occurred
    """
    try:
        with open(fasta_path, 'r', encoding='utf-8') as f:
            content = f.read()

        input_length = 0
        sample_metrics = []
        skipped_samples = 0

        # Split different sequence blocks; text before the first '>' and empty
        # fragments are ignored
        for seq_block in content.split('>'):
            seq_block = seq_block.strip()
            if not seq_block:
                continue

            header, *sequence_lines = seq_block.split('\n')

            # Process input sequence
            if 'input_sequence' in header:
                # Combine all sequence lines and drop every whitespace character
                sequence = ''.join(''.join(sequence_lines).split())
                input_length = len(sequence)
                print(f"  Input sequence length: {input_length}")

            # Process sample sequence
            elif 'sample=' in header:
                metrics = _extract_sample_metrics(header)
                if metrics is None:
                    skipped_samples += 1
                else:
                    sample_metrics.append(metrics)

        # Make dropped samples visible instead of silently shrinking the sample set
        if skipped_samples:
            print(f"  Warning: Skipped {skipped_samples} sample(s) with missing or unparseable metrics")

        # Calculate medians
        if sample_metrics:
            median_metrics = {
                name: median(m[name] for m in sample_metrics)
                for name in _SAMPLE_METRIC_NAMES
            }
            print(f"  Found {len(sample_metrics)} samples")
            print(f"  Median metrics: perplexity={median_metrics['perplexity']:.4f}, recovery={median_metrics['recovery']:.4f}, edit_dist={median_metrics['edit_dist']:.4f}, sc_score={median_metrics['sc_score']:.4f}")
        else:
            print("  Warning: No valid sample metrics found")
            median_metrics = None

        return input_length, median_metrics

    except Exception as e:
        # Deliberately broad: one unreadable file must not abort a whole batch;
        # the caller treats (0, None) as "processing failed"
        print(f"  Error: Exception occurred while processing file: {e}")
        return 0, None

def process_all_fasta_files(input_dir, output_file):
    """
    Summarise every ``*.fasta`` file of a directory into one text file.

    Each file is handed to ``parse_fasta_file``; for every file that yields a
    positive input length and a metrics dict, one line
    ``length perplexity recovery edit_dist sc_score`` (metrics with four
    decimals) is written to ``output_file``. Files are processed in sorted
    path order. Progress and a final summary are printed to stdout.

    Args:
    - input_dir: Directory searched (non-recursively) for ``*.fasta`` files
    - output_file: Path of the text file to write; its parent directory is
      created if needed

    Returns:
    - None. Nothing is written when the input directory is missing, contains
      no ``.fasta`` files, or no file could be processed.
    """
    print("=== FASTA File Processing Script ===")

    # Guard: the input location has to exist
    if not os.path.exists(input_dir):
        print(f"Error: Input directory does not exist: {input_dir}")
        return

    # Make sure the destination directory is there before any work is done
    target_dir = os.path.dirname(output_file)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
        print(f"Output directory: {target_dir}")

    # Collect the candidate files
    fasta_files = glob.glob(os.path.join(input_dir, "*.fasta"))
    total = len(fasta_files)
    if total == 0:
        print(f"Warning: No .fasta files found in {input_dir}")
        return

    print(f"Found {total} FASTA files")

    metric_order = ('perplexity', 'recovery', 'edit_dist', 'sc_score')
    summary_lines = []
    failures = 0

    # Parse the files one by one, in sorted order
    for position, path in enumerate(sorted(fasta_files), start=1):
        print(f"\n[{position}/{total}] Processing file: {os.path.basename(path)}")

        seq_len, medians = parse_fasta_file(path)

        if seq_len <= 0 or medians is None:
            print("  ✗ Processing failed")
            failures += 1
            continue

        columns = [str(seq_len)] + [f"{medians[key]:.4f}" for key in metric_order]
        summary_lines.append(" ".join(columns))
        print("  ✓ Processed successfully")

    # Without a single usable file there is nothing to save
    if not summary_lines:
        print("No successfully processed data, unable to generate output file")
        return

    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            f.writelines(line + '\n' for line in summary_lines)

        print("\n=== Processing Completed ===")
        print(f"Total files: {total}")
        print(f"Successfully processed: {len(summary_lines)}")
        print(f"Failed: {failures}")
        print(f"Results saved to: {output_file}")

        # Short preview of what was written
        print("\nResults preview (first 5 lines):")
        for line in summary_lines[:5]:
            print(f"  {line}")
        if len(summary_lines) > 5:
            print(f"  ... (total {len(summary_lines)} lines)")

    except Exception as e:
        print(f"Error saving file: {e}")

def validate_output_format(output_file):
    """
    Spot-check the first three lines of a results file and print what was found.

    A well-formed line has exactly five whitespace-separated fields: an integer
    length followed by four floats (perplexity, recovery, edit_dist, sc_score).
    Lines with a different field count are reported as format errors. A field
    that cannot be converted stops the check with an error message.

    Args:
    - output_file: Path of the text file to inspect (read as UTF-8)

    Returns:
    - None; all findings are printed to stdout.
    """
    if not os.path.exists(output_file):
        print("Output file does not exist")
        return

    print("\n=== Validate Output Format ===")

    # Field name and converter for each of the five expected columns
    field_parsers = (
        ('length', int),
        ('perplexity', float),
        ('recovery', float),
        ('edit_dist', float),
        ('sc_score', float),
    )

    try:
        with open(output_file, 'r', encoding='utf-8') as f:
            all_lines = f.readlines()

        print(f"Total lines: {len(all_lines)}")

        # Only the first 3 lines are inspected
        for line_no, raw in enumerate(all_lines[:3], start=1):
            text = raw.strip()
            tokens = text.split()
            if len(tokens) != len(field_parsers):
                print(f"Line {line_no} format error: {text}")
                continue

            # Convert every field before printing so a bad value aborts cleanly
            described = ", ".join(
                f"{name}={convert(token)}"
                for (name, convert), token in zip(field_parsers, tokens)
            )
            print(f"Line {line_no}: {described}")

    except Exception as e:
        print(f"Error during validation: {e}")

def main(input_directory="./testfasta/gr_1_seed4", output_filepath="./data/gr_1_seed4.txt"):
    """
    Convert one directory of FASTA files into a results file and check it.

    Args:
    - input_directory: Directory containing the ``.fasta`` files to process.
      Defaults to the seed-4 run used in this notebook.
    - output_filepath: Text file that receives one summary line per FASTA file.

    Calling ``main()`` without arguments behaves exactly as before; pass other
    paths to process a different run without editing the function body.
    """
    # Process all FASTA files
    process_all_fasta_files(input_directory, output_filepath)

    # Validate output format
    validate_output_format(output_filepath)


if __name__ == "__main__":
    main()

=== FASTA File Processing Script ===
Output directory: ./data
Found 612 FASTA files

[1/612] Processing file: 1EFW_1_C.fasta
  Input sequence length: 63
  Found 16 samples
  Median metrics: perplexity=1.2711, recovery=0.6667, edit_dist=21.0000, sc_score=0.8121
  ✓ Processed successfully

[2/612] Processing file: 1EFW_1_D.fasta
  Input sequence length: 63
  Found 16 samples
  Median metrics: perplexity=1.2607, recovery=0.6429, edit_dist=22.5000, sc_score=0.7189
  ✓ Processed successfully

[3/612] Processing file: 1EHT_3_A.fasta
  Input sequence length: 33
  Found 10 samples
  Median metrics: perplexity=1.3628, recovery=0.6406, edit_dist=12.5000, sc_score=0.5989
  ✓ Processed successfully

[4/612] Processing file: 1EHZ_1_A.fasta
  Input sequence length: 62
  Found 16 samples
  Median metrics: perplexity=1.2371, recovery=0.7419, edit_dist=15.0000, sc_score=0.6084
  ✓ Processed successfully

[5/612] Processing file: 1EVV_1_A.fasta
  Input sequence length: 62
  Found 16 samples
  Median met

In [None]:
### Single seed RNA Design Data Analysis Script###
def analyze_rna_data(filename="./data/v11s4.txt"):  # Pass another path to analyze a different results file
    """
    Print mean and median recovery / SC score per length range for one results file.

    Every usable line of ``filename`` is whitespace-separated with at least
    five columns, of which three are read:
    column 1 -> length (int)
    column 3 -> recovery (float)
    column 5 -> scc (float)
    Lines with fewer columns or unparseable values are ignored.

    Rows are grouped by length into 0-100, 101-200 and >200. A range without
    any rows is reported as 0.000. The result is printed as two
    ``&``-separated rows (Average, Median); nothing is returned.
    """
    if not os.path.exists(filename):
        print(f"Error: '{filename}' is not found in the current directory.")
        return

    # Buckets in report order: 0-100 nt, 101-200 nt, >200 nt
    short_rows, mid_rows, long_rows = [], [], []

    with open(filename, 'r') as f:
        for raw_line in f:
            fields = raw_line.strip().split()
            # Skip lines that do not carry all five columns
            if len(fields) < 5:
                continue
            try:
                length = int(fields[0])
                recovery = float(fields[2])
                scc = float(fields[4])
            except (ValueError, IndexError):
                continue

            if length <= 100:
                target = short_rows
            elif length <= 200:
                target = mid_rows
            else:
                target = long_rows
            target.append([recovery, scc])

    def summarize(rows):
        """Return mean/median of both columns, or zeros for an empty bucket."""
        if len(rows) == 0:
            return {
                "mean_recovery": 0.0, "median_recovery": 0.0,
                "mean_scc": 0.0, "median_scc": 0.0
            }
        table = np.array(rows)
        recovery_col = table[:, 0]
        scc_col = table[:, 1]
        return {
            "mean_recovery": np.mean(recovery_col),
            "median_recovery": np.median(recovery_col),
            "mean_scc": np.mean(scc_col),
            "median_scc": np.median(scc_col)
        }

    per_range = [summarize(rows) for rows in (short_rows, mid_rows, long_rows)]

    def table_row(prefix, kind):
        """Build one output row: three recovery cells followed by three SC cells."""
        cells = [f"{stats[kind + '_recovery']:.3f}" for stats in per_range]
        cells += [f"{stats[kind + '_scc']:.3f}" for stats in per_range]
        return prefix + " & ".join(cells)

    print(f"Reading {filename} ...")
    print("           & Recovery (0-100) & Recovery (100-200) & Recovery (>200) & SC Score (0-100) & SC Score (100-200) & SC Score (>200)")
    print(table_row("Average:    & ", "mean"))
    print(table_row("Median:     & ", "median"))


if __name__ == "__main__":
    analyze_rna_data()

Reading ./data/v11s0.txt ...
           & Recovery (0-100) & Recovery (100-200) & Recovery (>200) & SC Score (0-100) & SC Score (100-200) & SC Score (>200)
Average:    & 0.521 & 0.588 & 0.706 & 0.712 & 0.635 & 0.414
Median:     & 0.488 & 0.560 & 0.693 & 0.769 & 0.618 & 0.403


In [None]:
def _all_files_exist(filenames):
    """Return True if every file exists; otherwise report the first missing one and return False."""
    for filename in filenames:
        if not os.path.exists(filename):
            print(f"Error: '{filename}' is not found in the current directory.")
            return False
    return True


def _bucket_rows_by_length(filename):
    """Read one results file into (small, medium, large) lists of [recovery, scc] rows.

    Lines need at least five whitespace-separated columns; column 1 is the
    length, column 3 the recovery and column 5 the SC score. Lines that are
    too short or hold unparseable values are skipped.
    """
    small, medium, large = [], [], []
    with open(filename, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 5:
                continue
            try:
                length = int(parts[0])
                recovery = float(parts[2])
                scc = float(parts[4])
            except ValueError:
                continue

            if length <= 100:
                small.append([recovery, scc])
            elif length <= 200:
                medium.append([recovery, scc])
            else:
                large.append([recovery, scc])
    return small, medium, large


def _cross_seed_stats(filenames, reducer):
    """Reduce each file per length range with `reducer`, then take mean/std across files.

    Returns two lists (recovery, SC score), each holding one (mean, std) pair
    for the small, medium and large range. A file contributes nothing to a
    range in which it has no rows; a range without any contribution yields
    (0.0, 0.0).
    """
    recovery_per_range = [[], [], []]
    scc_per_range = [[], [], []]

    for filename in filenames:
        for index, rows in enumerate(_bucket_rows_by_length(filename)):
            if not rows:
                continue
            table = np.array(rows)
            recovery_per_range[index].append(reducer(table[:, 0]))
            scc_per_range[index].append(reducer(table[:, 1]))

    def mean_and_std(values):
        if not values:
            return 0.0, 0.0
        # np.std defaults to the population standard deviation (ddof=0).
        # NOTE(review): confirm this is the intended spread for a handful of seeds.
        return np.mean(values), np.std(values)

    return (
        [mean_and_std(values) for values in recovery_per_range],
        [mean_and_std(values) for values in scc_per_range],
    )


def _format_stability_row(label, stats):
    """Join (mean, std) pairs as 'mean ± std' cells behind the given row label."""
    return label + " & ".join(f"{mean:.3f} ± {std:.3f}" for mean, std in stats)


def analyze_mean_stability(filenames: List[str]):
    """
    Print the cross-seed stability of the per-file averages.

    For each file the average recovery and SC score are taken per length
    range (0-100, 101-200, >200); the report shows the mean ± standard
    deviation of those averages across all files. Nothing is printed except
    an error message if any of the files is missing.

    Args:
        filenames: Paths of the per-seed results files.
    """
    if not _all_files_exist(filenames):
        return

    recovery_stats, scc_stats = _cross_seed_stats(filenames, np.mean)

    print("\n--- Multiple Seed Experiments [Average] Stability Analysis ---")
    print(f"Reading {', '.join(filenames)}...")

    print("Average               & 0-100 nt     & 100-200 nt     & >200 nt")
    print("-" * 90)
    print(_format_stability_row("Recovery (Mean): & ", recovery_stats))
    print(_format_stability_row("SC Score (Mean): & ", scc_stats))


def analyze_median_stability(filenames: List[str]):
    """
    Print the cross-seed stability of the per-file medians.

    For each file the median recovery and SC score are taken per length
    range (0-100, 101-200, >200); the report shows the mean ± standard
    deviation of those medians across all files. Nothing is printed except
    an error message if any of the files is missing.

    Args:
        filenames: Paths of the per-seed results files.
    """
    if not _all_files_exist(filenames):
        return

    recovery_stats, scc_stats = _cross_seed_stats(filenames, np.median)

    print("\n--- Multiple Seed Experiments [Median] Stability Analysis ---")
    print(f"Reading {', '.join(filenames)} ...")

    print("Median                & 0-100 nt     & 100-200 nt     & >200 nt")
    print("-" * 90)
    print(_format_stability_row("Recovery (Median): & ", recovery_stats))
    print(_format_stability_row("SC Score (Median): & ", scc_stats))


if __name__ == "__main__":
    files_to_analyze = [
        "./data/gvpa1s4_avg.txt", 
        "./data/gvpa1s1_avg.txt", 
        "./data/gvpa1s0_avg.txt"
    ]
    
    analyze_mean_stability(files_to_analyze)
    analyze_median_stability(files_to_analyze)


--- Multiple Seed Experiments [Average] Stability Analysis ---
Reading ./data/gr_1_seed4_avg.txt, ./data/gr_1_seed42_avg.txt, ./data/gr_1_seed2_avg.txt...
Average               & 0-100 nt     & 100-200 nt     & >200 nt
------------------------------------------------------------------------------------------
Recovery (Mean): & 0.511 ± 0.004 & 0.521 ± 0.038 & 0.673 ± 0.010
SC Score (Mean): & 0.684 ± 0.009 & 0.602 ± 0.025 & 0.391 ± 0.013

--- Multiple Seed Experiments [Median] Stability Analysis ---
Reading ./data/gr_1_seed4_avg.txt, ./data/gr_1_seed42_avg.txt, ./data/gr_1_seed2_avg.txt ...
Median                & 0-100 nt     & 100-200 nt     & >200 nt
------------------------------------------------------------------------------------------
Recovery (Median): & 0.493 ± 0.005 & 0.497 ± 0.017 & 0.610 ± 0.019
SC Score (Median): & 0.703 ± 0.014 & 0.564 ± 0.062 & 0.379 ± 0.018


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib

# --- UPDATED Data extracted from the new image ---
# List of models to plot. Each entry is a key into the per-metric dicts in
# `data` below; the list order is the order in which models are drawn.
models = [
    'gRNAde',
    'long-short',
    'gat',
    'gat_gate_0.5',          # Renamed from 'gat_gvp_0.5' for simplicity
    'TransformerHead_0.3',
    'GVPAttention_lr1e-4', # New model
    'GVPAttention_lr2e-4'  # New model
]
# Corresponding names for the legend. Paired with `models` by position, so
# both lists must keep the same length and order.
legend_names = [
    'gRNAde',
    'long-short transformer(window=16)',
    'gat transformer',
    'gat transformer(gate=0.5)',     # Updated legend name
    'TransformerHead(gate=0.3)',
    'GVPAttention(lr=0.0001)',     # New legend name
    'GVPAttention(lr=0.0002)'      # New legend name
]


# Sequence-length range labels; one group of bars is drawn per label
ranges = ['Short (0-100 nt)', 'Medium (100-200 nt)', 'Long (>200 nt)'] # Range labels in English

# Updated data dictionary based on Table 3 ONLY.
# Layout: metric name -> model key -> three hard-coded values, one per entry
# of `ranges` (short, medium, long).
data = {
    'Average Recovery': {
        'gRNAde': [0.509, 0.530, 0.673],
        'long-short': [0.487, 0.454, 0.707],
        'gat': [0.472, 0.467, 0.713],
        'gat_gate_0.5': [0.487, 0.465, 0.693],         # Data for gat(gate=0.5)
        'TransformerHead_0.3': [0.498, 0.481, 0.730],
        'GVPAttention_lr1e-4': [0.515, 0.550, 0.679], # New data
        'GVPAttention_lr2e-4': [0.524, 0.569, 0.698]  # New data
    },
    'Average SC Score': {
        'gRNAde': [0.691, 0.607, 0.393],
        'long-short': [0.692, 0.637, 0.398],
        'gat': [0.693, 0.631, 0.415],
        'gat_gate_0.5': [0.695, 0.637, 0.432],         # Data for gat(gate=0.5)
        'TransformerHead_0.3': [0.692, 0.635, 0.418],
        'GVPAttention_lr1e-4': [0.699, 0.607, 0.399], # New data
        'GVPAttention_lr2e-4': [0.688, 0.576, 0.403]  # New data
    }
    # Median data is NOT included as it was not provided in the new image
}
# --- END UPDATED DATA ---

# --- Plotting Function (Unchanged from previous successful version) ---
def plot_metric_comparison(metric_name, data_dict, models_list, legend_names_list, ranges_list):
    """
    Draw a grouped bar chart of one metric: one group per sequence-length
    range, one bar per model. Every bar carries its value (three decimals)
    as a horizontal label, the legend is placed automatically inside the
    axes, and the figure is saved as a PNG in the current directory before
    being shown.

    Args:
        metric_name (str): The name of the metric (e.g., 'Average Recovery').
            Its last word becomes the y-axis label and the whole name is used
            in the title and the output filename.
        data_dict (dict): Maps model key -> list of values, one per range.
            Keys from models_list that are missing here are skipped with a warning.
        models_list (list): Model keys to plot, in drawing order
        legend_names_list (list): Legend names, matched to models_list by index
        ranges_list (list): Sequence length range names used as x tick labels
    """
    model_count = len(models_list)
    group_positions = np.arange(len(ranges_list))

    # Each group occupies 0.8 of its slot, shared evenly between the models
    group_width = 0.8
    bar_width = group_width / model_count
    bar_offsets = np.linspace(-group_width / 2 + bar_width / 2,
                              group_width / 2 - bar_width / 2,
                              model_count)

    fig, ax = plt.subplots(figsize=(14, 7), layout='constrained')

    for idx, model_key in enumerate(models_list):
        # A model without data for this metric is reported and left out
        if model_key not in data_dict:
            print(f"Warning: Model key '{model_key}' not found in data for metric '{metric_name}'. Skipping.")
            continue

        heights = data_dict[model_key]
        bars = ax.bar(group_positions + bar_offsets[idx],
                      heights,
                      bar_width,
                      label=legend_names_list[idx])

        # Horizontal value label above each bar
        ax.bar_label(bars,
                     labels=[f"{value:.3f}" for value in heights],
                     padding=3,
                     fontsize=9,
                     fontweight='bold',
                     rotation=0)

    # Axis labelling and title
    ax.set_ylabel(metric_name.split()[-1])
    ax.set_title(f'{metric_name} Comparison Across Sequence Lengths')
    ax.set_xticks(group_positions)
    ax.set_xticklabels(ranges_list)

    # Legend inside the axes, position chosen by matplotlib
    ax.legend(loc='best', fontsize='large')

    # Y-axis limits follow the values of the models that have data
    plotted_values = []
    for model_key in models_list:
        if model_key in data_dict:
            plotted_values.extend(data_dict[model_key])

    if plotted_values:
        lowest = min(plotted_values)
        highest = max(plotted_values)
    else:
        print(f"Warning: No data found to plot for metric '{metric_name}'. Y-axis limits may be incorrect.")
        lowest, highest = 0, 1

    # 5% headroom below and 15% above the data span; fixed padding if the span is zero
    if highest > lowest:
        pad_below = (highest - lowest) * 0.05
        pad_above = (highest - lowest) * 0.15
    else:
        pad_below = 0.05
        pad_above = 0.15
    ax.set_ylim([lowest - pad_below, highest + pad_above])

    # Dashed horizontal grid lines
    ax.yaxis.grid(True, linestyle='--', alpha=0.7)

    # Save next to the notebook; the filename is derived from the metric name
    filename = f"{metric_name.replace(' ', '_')}_comparison_new_data.png"
    fig.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Saved plot to: {filename}")

    plt.show()

# --- Generate the TWO plots for which data is available ---
# One chart per available metric; a metric missing from `data` is reported
# instead of stopping the remaining plots.
for metric in ('Average Recovery', 'Average SC Score'):
    try:
        plot_metric_comparison(metric, data[metric], models, legend_names, ranges)
    except KeyError as e:
        print(f"Error plotting {metric}: Missing key {e}")

# Removed plotting calls for 'Median Recovery' and 'Median SC Score'