In [7]:
## Plot the final results once all upstream pipeline steps have finished.

In [8]:
# Core imports
import os
import numpy as np
import pandas as pd
import h5py
import pickle
import random
import time

# TensorFlow/Keras imports for model loading
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
tf.compat.v1.disable_v2_behavior()
from keras.models import model_from_json
import shap
shap.explainers.deep.deep_tf.op_handlers["AddV2"] = shap.explainers.deep.deep_tf.passthrough

# SEAM imports
import seam
from seam import Compiler, Attributer, Clusterer, MetaExplainer

# SQUID imports for mutagenesis
import squid

# Scipy for correlation analysis
from scipy.stats import spearmanr, pearsonr

# Matplotlib for plotting
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# =============================================================================
# Sweep configuration
# =============================================================================
# Mutation rates expressed as fractions, ordered from highest to lowest.
mutation_rates = [0.75, 0.50, 0.25, 0.20, 0.15, 0.12, 0.10, 0.08, 0.05, 0.03, 0.01]
# NOTE(review): presumably the per-library sequence count behind the '25K.h5'
# file names used elsewhere in this notebook -- confirm.
lib_size = 25_000
cluster_number = 30
task_index = 0  # Dev task

In [9]:
# =============================================================================
# Paths and input data
# =============================================================================

# Base paths. Everything below is derived from BASE_DIR, which is an absolute
# path; edit it here when running the notebook from a different location.
BASE_DIR = '/grid/wsbs/home_norepl/pmantill/SEAM_revisions/SEAM_revisions/hyperparameter_selection'
DEV20_LIBRARY_PATH = f'{BASE_DIR}/data_and_models/dev_20_library/dev_20_library.pkl'
MODEL_DIR = f'{BASE_DIR}/data_and_models/models'
MUTAGENESIS_LIBRARY_DIR = f'{BASE_DIR}/data_and_models/mut_sweep_libraries'
ATTRIBUTION_DIR = f'{BASE_DIR}/b_mutation_rate_sweep/seq_libraries/mut_sweep/deepshap'
RESULTS_DIR = f'{BASE_DIR}/b_mutation_rate_sweep/results'

# Nucleotide alphabet (presumably the channel order of the one-hot encoding --
# TODO confirm against the model inputs).
alphabet = ['A', 'C', 'G', 'T']

# 1. Load Dev_20 library (20 sequences, 3 removed)
# NOTE(review): read_pickle executes arbitrary code embedded in the file; only
# load pickles from a trusted source.
print("\nLoading Dev_20 library...")
dev_pkl = pd.read_pickle(DEV20_LIBRARY_PATH)
# Keep only the "dev" entry of the pickled object; the name is rebound, so the
# full pickle content is no longer reachable after this line.
dev_pkl = dev_pkl["dev"]
print(f"Loaded {len(dev_pkl)} sequences")

# Identifiers used to build the per-sequence `seq_{idx}` path components.
seq_indices = dev_pkl["test_idx"].tolist()
# presumably the one-hot encoded sequences -- TODO confirm shape/dtype
x_seqs = dev_pkl["ohe_seq"]



Loading Dev_20 library...
Loaded 20 sequences


In [10]:
def get_mut_label(mut_rate):
    """Convert a fractional mutation rate to its percentage label.

    The label is embedded in directory and file names throughout this
    notebook (e.g. ``0.75 -> '75.0%'``), so it has to be stable for every
    rate, not just the ones whose product with 100 happens to be exact.

    Args:
        mut_rate: Mutation rate as a fraction (e.g. ``0.10`` for 10%).

    Returns:
        str: The rate in percent followed by ``'%'``, e.g. ``'10.0%'``.
    """
    # Binary floating point makes some products inexact
    # (0.07 * 100 == 7.000000000000001), which would leak into paths.
    # Rounding to 10 decimals strips that noise while leaving values that
    # were already clean (75.0, 12.5, ...) unchanged.
    percent = round(mut_rate * 100, 10)
    return f"{percent}%"

# =============================================================================
# Completion Check Functions
# =============================================================================

def check_libraries_complete(seq_indices):
    """Return True when a 25K mutagenesis library file exists for every
    combination of sequence index and configured mutation rate."""
    return all(
        os.path.exists(
            f'{MUTAGENESIS_LIBRARY_DIR}/Dev/seq_{seq_idx}/{get_mut_label(mut_rate)}/25K.h5'
        )
        for seq_idx in seq_indices
        for mut_rate in mutation_rates
    )


def check_attributions_complete(seq_indices):
    """Return True when an attribution file exists for every combination of
    sequence index and configured mutation rate."""
    # Lazy generator: paths are only built (and checked) until the first miss.
    expected_files = (
        f'{ATTRIBUTION_DIR}/Dev/seq_{seq_idx}/{get_mut_label(mut_rate)}/25K.h5'
        for seq_idx in seq_indices
        for mut_rate in mutation_rates
    )
    return all(map(os.path.exists, expected_files))


def check_clustering_complete(seq_indices):
    """Return True when both clustering outputs (linkage matrix and cluster
    labels) exist for every sequence index and configured mutation rate."""
    # Checked in this order for each directory; the first missing file ends the scan.
    required_files = ('hierarchical_linkage_ward.npy', 'cluster_labels.npy')
    for seq_idx in seq_indices:
        for mut_rate in mutation_rates:
            mut_label = get_mut_label(mut_rate)
            cluster_dir = f'{RESULTS_DIR}/clustering/seq_{seq_idx}/{mut_label}'
            present = all(
                os.path.exists(os.path.join(cluster_dir, file_name))
                for file_name in required_files
            )
            if not present:
                return False
    return True


def check_msms_complete(seq_indices):
    """Return True when both `msm.csv` and `variance_summary.csv` exist for
    every sequence index and configured mutation rate."""

    def _has_msm_outputs(seq_idx, mut_rate):
        # One results directory per (sequence, mutation rate) pair.
        mut_label = get_mut_label(mut_rate)
        msm_dir = f'{RESULTS_DIR}/msms/seq_{seq_idx}/{mut_label}'
        return (os.path.exists(os.path.join(msm_dir, 'msm.csv'))
                and os.path.exists(os.path.join(msm_dir, 'variance_summary.csv')))

    return all(
        _has_msm_outputs(seq_idx, mut_rate)
        for seq_idx in seq_indices
        for mut_rate in mutation_rates
    )


def check_correlations_complete(seq_indices):
    """Return True when the correlation CSV against the 10% reference exists
    for every sequence index."""
    ref_label = get_mut_label(0.10)
    return all(
        os.path.exists(
            f'{RESULTS_DIR}/correlations/seq_{seq_idx}/correlation_with_{ref_label}.csv'
        )
        for seq_idx in seq_indices
    )


def check_correlation_plots_complete(seq_indices):
    """Return True when the per-sequence correlation plot against the 10%
    reference exists for every sequence index."""
    ref_label = get_mut_label(0.10)
    expected_plots = (
        f'{RESULTS_DIR}/correlations/seq_{seq_idx}/correlation_with_{ref_label}.png'
        for seq_idx in seq_indices
    )
    # Stops at the first plot that is not on disk.
    return not any(not os.path.exists(plot_path) for plot_path in expected_plots)


def check_summary_plots_complete():
    """Return True when a summary correlation plot exists for each configured
    mutation rate used as the reference."""
    for reference_rate in mutation_rates:
        ref_label = get_mut_label(reference_rate)
        summary_png = f'{RESULTS_DIR}/results_final/correlation_summary_with_{ref_label}.png'
        if os.path.exists(summary_png):
            continue
        return False
    return True

In [11]:
def compute_correlations_for_reference(seq_idx, reference_mut_rate):
    """Correlate each mutation rate's variance summary with the reference's.

    Reads the 'Variance' column of `variance_summary.csv` for the reference
    rate and for every configured mutation rate of one sequence, then computes
    Pearson and Spearman correlations between them.

    Args:
        seq_idx: Sequence identifier used in the `seq_{seq_idx}` directory name.
        reference_mut_rate: Mutation rate (fraction) used as the reference.

    Returns:
        A DataFrame with columns 'Mut_Rate', 'Mut_Rate_Numeric', 'Pearson' and
        'Spearman' (one row per mutation rate whose file exists), or None when
        the reference file is missing or no row could be produced.
    """

    def _load_variance(label):
        # None signals "no variance summary on disk for this label".
        csv_path = f'{RESULTS_DIR}/msms/seq_{seq_idx}/{label}/variance_summary.csv'
        if not os.path.exists(csv_path):
            return None
        return pd.read_csv(csv_path)['Variance'].values

    reference_variance = _load_variance(get_mut_label(reference_mut_rate))
    if reference_variance is None:
        return None

    rows = []
    for mut_rate in mutation_rates:
        mut_label = get_mut_label(mut_rate)
        candidate_variance = _load_variance(mut_label)
        if candidate_variance is None:
            # Rates without results are skipped rather than treated as errors.
            continue

        pearson_corr, _ = pearsonr(reference_variance, candidate_variance)
        spearman_corr, _ = spearmanr(reference_variance, candidate_variance)
        rows.append({
            'Mut_Rate': mut_label,
            'Mut_Rate_Numeric': mut_rate,
            'Pearson': pearson_corr,
            'Spearman': spearman_corr
        })

    if not rows:
        return None
    return pd.DataFrame(rows)


def plot_summary_correlations(seq_indices, reference_mut_rate):
    """Draw one summary figure of Pearson correlation vs mutation rate.

    One line per sequence is plotted, each showing the correlation of that
    sequence's variance summaries with the given reference mutation rate.

    Args:
        seq_indices: Iterable of sequence identifiers to include.
        reference_mut_rate: Mutation rate (fraction) used as the reference.

    Returns:
        Path of the saved PNG, or None when no sequence had more than one
        correlation row.
    """
    ref_label = get_mut_label(reference_mut_rate)

    # Destination of the figure
    output_dir = f'{RESULTS_DIR}/results_final'
    os.makedirs(output_dir, exist_ok=True)
    plot_path = os.path.join(output_dir, f'correlation_summary_with_{ref_label}.png')

    # Gather (sequence, correlation table) pairs, computed on the fly
    per_sequence = []
    for seq_idx in seq_indices:
        corr_df = compute_correlations_for_reference(seq_idx, reference_mut_rate)
        if corr_df is None or len(corr_df) <= 1:
            continue
        corr_df['seq_idx'] = seq_idx
        per_sequence.append((seq_idx, corr_df))

    if not per_sequence:
        print(f"No correlation data found for reference {ref_label}")
        return None

    # One viridis colour per plotted sequence
    palette = cm.viridis(np.linspace(0, 1, len(per_sequence)))

    fig, ax = plt.subplots(figsize=(10, 7))
    for line_color, (seq_idx, corr_df) in zip(palette, per_sequence):
        ax.plot(corr_df['Mut_Rate_Numeric']*100, corr_df['Pearson'], 'o-',
                color=line_color, markersize=6, linewidth=1.5, alpha=0.8,
                label=f'Seq {seq_idx}')

    # Guide lines: the reference rate (vertical) and perfect correlation (horizontal)
    ax.axvline(x=reference_mut_rate*100, color='black', linestyle='--',
               linewidth=1.5, alpha=0.7, label=f'{ref_label} reference')
    ax.axhline(y=1, color='gray', linestyle=':', alpha=0.5)

    ax.set_xlabel('Mutation Rate (%)', fontsize=12)
    ax.set_ylabel('Pearson Correlation', fontsize=12)
    ax.set_title(f'Correlation with {ref_label} Reference', fontsize=14)
    ax.set_ylim(0, 1.05)
    ax.grid(True, alpha=0.3)
    ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8, ncol=2)

    plt.tight_layout()
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"Saved summary correlation plot to {plot_path}")
    return plot_path


def plot_summary_correlation_plot():
    """Create the summary plot for every configured mutation rate whose
    figure is not on disk yet, reporting the status of each one."""
    for ref_mut_rate in mutation_rates:
        ref_label = get_mut_label(ref_mut_rate)
        plot_path = f'{RESULTS_DIR}/results_final/correlation_summary_with_{ref_label}.png'
        if os.path.exists(plot_path):
            print(f"Plot for {ref_label} already exists")
            continue
        print(f"Plot for {ref_label} is missing, generating...")
        plot_summary_correlations(seq_indices, ref_mut_rate)


def plot_combined_distance_subplots(seq_indices, excluded_reference_rates=(0.75,), n_cols=5):
    """
    Create a single figure with one subplot per reference mutation rate.

    Every rate in the module-level ``mutation_rates`` serves as a reference
    except those listed in ``excluded_reference_rates``. In each subplot the
    x-axis is the absolute distance (in percentage points) between a mutation
    rate and the reference, and the y-axis is the Pearson correlation taken
    from ``compute_correlations_for_reference``.

    Args:
        seq_indices: Iterable of (hashable) sequence identifiers to plot.
        excluded_reference_rates: Collection of mutation rates that must not
            be used as a reference. Defaults to ``(0.75,)``.
        n_cols: Number of subplot columns; the number of rows is derived from
            the number of reference rates. Defaults to 5, which yields a
            2 x 5 grid for 10 reference rates.

    Returns:
        Path of the saved PNG, or None when no sequence has usable correlation
        data (nothing is saved in that case).

    Raises:
        ValueError: If ``n_cols`` is smaller than 1.
    """
    if n_cols < 1:
        raise ValueError("n_cols must be at least 1")

    reference_rates = [r for r in mutation_rates if r not in excluded_reference_rates]
    n_refs = len(reference_rates)

    # Create output directory
    output_dir = f'{RESULTS_DIR}/results_final'
    os.makedirs(output_dir, exist_ok=True)
    plot_path = os.path.join(output_dir, 'correlation_vs_distance_combined.png')

    # Compute every correlation table exactly once. Each call re-reads the
    # variance CSVs, so the tables are cached for the plotting pass below.
    corr_by_ref_and_seq = {}
    for ref_rate in reference_rates:
        for seq_idx in seq_indices:
            corr_df = compute_correlations_for_reference(seq_idx, ref_rate)
            if corr_df is not None and len(corr_df) > 1:
                corr_by_ref_and_seq[(ref_rate, seq_idx)] = corr_df

    all_valid_seqs = sorted({seq_idx for _, seq_idx in corr_by_ref_and_seq})
    if not all_valid_seqs:
        # Mirror plot_summary_correlations: do not save an empty figure.
        print("No correlation data found for the combined distance plot")
        return None

    # Consistent colour per sequence across all subplots
    n_seqs = len(all_valid_seqs)
    colors = {seq_idx: cm.viridis(i / max(n_seqs - 1, 1))
              for i, seq_idx in enumerate(all_valid_seqs)}

    # Grid sized from the number of references (ceil division); 4x4 inches per
    # panel reproduces the original 20x8 figure for a 2 x 5 grid.
    n_rows = (n_refs + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows),
                             sharex=False, sharey=True, squeeze=False)
    axes = axes.flatten()

    for ax_idx, ref_rate in enumerate(reference_rates):
        ax = axes[ax_idx]
        ref_label = get_mut_label(ref_rate)

        for seq_idx in seq_indices:
            corr_df = corr_by_ref_and_seq.get((ref_rate, seq_idx))
            if corr_df is None:
                continue

            # Absolute distance from the reference, in percentage points
            abs_distance = np.abs(corr_df['Mut_Rate_Numeric'].values - ref_rate) * 100

            ax.plot(abs_distance, corr_df['Pearson'], 'o-',
                    color=colors[seq_idx], markersize=4, linewidth=1, alpha=0.7)

        # Add vertical line at distance = 0 (reference point)
        ax.axvline(x=0, color='black', linestyle='--', linewidth=1, alpha=0.5)

        # Add horizontal line at correlation = 1
        ax.axhline(y=1, color='gray', linestyle=':', alpha=0.5)

        ax.set_title(f'Ref: {ref_label}', fontsize=10)
        ax.set_ylim(0, 1.05)
        ax.grid(True, alpha=0.3)

        # x-label only where no used subplot sits directly below
        if ax_idx + n_cols >= n_refs:
            ax.set_xlabel('|Distance from Ref| (%)', fontsize=9)

        # y-label only for the leftmost column
        if ax_idx % n_cols == 0:
            ax.set_ylabel('Pearson Correlation', fontsize=9)

    # Hide grid cells that have no reference rate assigned
    for unused_ax in axes[n_refs:]:
        unused_ax.set_visible(False)

    # Create a single legend for all subplots
    legend_handles = [plt.Line2D([0], [0], color=colors[seq_idx], marker='o',
                                  linestyle='-', markersize=4, linewidth=1,
                                  label=f'Seq {seq_idx}')
                      for seq_idx in all_valid_seqs]

    fig.legend(handles=legend_handles, loc='center right', fontsize=7,
               bbox_to_anchor=(1.08, 0.5), ncol=1)

    fig.suptitle('Correlation vs Absolute Distance from Reference Mutation Rate',
                 fontsize=14, y=1.02)

    plt.tight_layout()
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

    print(f"Saved combined distance plot to {plot_path}")
    return plot_path

In [12]:
## Confirm that all expected files exist; missing summary plots are generated

def check_combined_distance_plot_complete():
    """Return True when the combined correlation-vs-distance figure is on disk."""
    return os.path.exists(
        f'{RESULTS_DIR}/results_final/correlation_vs_distance_combined.png'
    )


def check_all_files_complete():
    """Run every completeness check and report what is missing.

    Upstream artefacts (libraries, attributions, clustering, MSMs,
    correlations, per-sequence plots) are only reported. The summary plots and
    the combined distance plot are generated when they are missing.
    """
    # (check function, message printed when the check fails), in report order
    report_only_checks = (
        (check_libraries_complete, "Libraries are missing"),
        (check_attributions_complete, "Attributions are missing"),
        (check_clustering_complete, "Clustering results are missing"),
        (check_msms_complete, "MSMs and variance summaries are missing"),
        (check_correlations_complete, "Correlation files are missing"),
        (check_correlation_plots_complete, "Individual correlation plots are missing"),
    )
    for is_complete, missing_message in report_only_checks:
        if not is_complete(seq_indices):
            print(missing_message)

    # Summary plots: generate whichever are absent
    if check_summary_plots_complete():
        print("All summary plots exist")
    else:
        print("Summary plots are missing, generating...")
        plot_summary_correlation_plot()

    # Combined distance subplot figure: generate when absent
    if check_combined_distance_plot_complete():
        print("Combined distance plot exists")
    else:
        print("Combined distance plot is missing, generating...")
        plot_combined_distance_subplots(seq_indices)


# Run the completeness report; missing summary plots and the combined distance
# plot are generated as a side effect of this call.
check_all_files_complete()

All summary plots exist
Combined distance plot is missing, generating...
Saved combined distance plot to /grid/wsbs/home_norepl/pmantill/SEAM_revisions/SEAM_revisions/hyperparameter_selection/b_mutation_rate_sweep/results/results_final/correlation_vs_distance_combined.png
