In [None]:
from pathlib import Path
from typing import Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import pearsonr
from sklearn.metrics import pairwise_distances

from utils import *
from ardca import *
from classes import *

In [None]:
# Notebook configuration: everything a reader might want to tune lives here.
SEED = 42      # seed for NumPy's global RNG, which drives the np.random.choice subsampling below
N_SEQ = 500    # number of MSA rows subsampled / sequences generated for the PCA section
msa_path = "../data/PF00014.fasta.gz"
model_path = "models/ardca_PF00014_v3.pt"

# Seed once so that Restart Kernel -> Run All reproduces the same subsamples and figures.
np.random.seed(SEED)

# Every plotting cell saves into "out/"; savefig does not create missing directories.
Path("out").mkdir(parents=True, exist_ok=True)

In [None]:
# Load the multiple sequence alignment and the arDCA model from the paths set in the
# configuration cell.
# NOTE(review): max_gap_fraction=1.0 presumably disables gap-based filtering of sequences —
# confirm against read_fasta_alignment.
alignment = read_fasta_alignment(msa_path, max_gap_fraction=1.0)
model = load_ardca_model(model_path)

# PCA on MSA Sequences vs. Model-Generated Sequences

In [None]:
# Perform PCA on the full MSA; the second return value is not needed here.
alignment2d, _ = pca_from_onehot(alignment)

# Draw N_SEQ distinct rows (replace=False) from the PCA projection for plotting.
# Note: random_msa_samples holds projected coordinates, not the sequences themselves.
random_indices = np.random.choice(alignment2d.shape[0], size=N_SEQ, replace=False)
random_msa_samples = alignment2d[random_indices]

In [None]:
# 2-D density histogram of the first two PCA coordinates of the subsampled MSA.
fig, ax = plt.subplots(figsize=(8, 6))
h = ax.hist2d(random_msa_samples[:, 0], random_msa_samples[:, 1], bins=50, density=True)

ax.set_title(f"PCA on MSA with {N_SEQ} samples")
ax.set_xlabel("principal component 1")
ax.set_ylabel("principal component 2")

# h[3] is the image artist returned by hist2d; it supplies the colour scale.
fig.colorbar(h[3], ax=ax)
fig.savefig("out/PCA_DATA")

In [None]:
# Draw N_SEQ sequences from the model with a fixed seed and convert the result to a
# NumPy array so it can be fed to the same helpers as the MSA.
generated_seqs = model.sample_sequences(n_samples=N_SEQ, seed=42).numpy()

In [None]:
# Perform PCA on the model-generated sequences; the second return value is not needed here.
# NOTE(review): this is a separate pca_from_onehot call from the one made on the MSA, so the
# two PCA plots may not share the same principal axes — confirm whether the generated
# sequences should instead be projected onto the components obtained from the MSA.
pca_generated, _ = pca_from_onehot(generated_seqs)

In [None]:
# 2-D density histogram of the first two PCA coordinates of the generated sequences.
fig, ax = plt.subplots(figsize=(8, 6))
h = ax.hist2d(pca_generated[:, 0], pca_generated[:, 1], bins=50, density=True)

ax.set_title(f"PCA on Generated Sequences with {N_SEQ} samples")
ax.set_xlabel("principal component 1")
ax.set_ylabel("principal component 2")

# h[3] is the image artist returned by hist2d; it supplies the colour scale.
fig.colorbar(h[3], ax=ax)
fig.savefig("out/PCA_MODEL")

# UMAP

In [None]:
# Embed the full alignment in two dimensions with UMAP (second return value unused).
alignment_umap, _ = umap_from_onehot(alignment, n_components=2)

# Scatter of the two embedding coordinates, drawn through the explicit figure/axes objects.
umap_fig, umap_ax = plt.subplots(figsize=(7, 6))
umap_ax.scatter(alignment_umap[:, 0], alignment_umap[:, 1], s=10, alpha=0.8, color="steelblue")
umap_ax.set_xlabel("UMAP-1")
umap_ax.set_ylabel("UMAP-2")
umap_ax.set_title("UMAP from MSA")

umap_fig.tight_layout()
umap_fig.savefig("out/UMAP_DATA")

In [None]:
# Embed the model-generated sequences in two dimensions with UMAP (second return value unused).
generated_umap, _ = umap_from_onehot(generated_seqs, n_components=2)

# Scatter of the two embedding coordinates, drawn through the explicit figure/axes objects.
umap_fig, umap_ax = plt.subplots(figsize=(7, 6))
umap_ax.scatter(generated_umap[:, 0], generated_umap[:, 1], s=10, alpha=0.8, color="steelblue")
umap_ax.set_xlabel("UMAP-1")
umap_ax.set_ylabel("UMAP-2")
umap_ax.set_title("UMAP from Model")

umap_fig.tight_layout()
umap_fig.savefig("out/UMAP_MODEL")

# Hamming Distance Matrix

In [None]:
# Size of the subsample used for the Hamming-distance heatmaps.
# NOTE: this rebinds the notebook-level N_SEQ, so every later cell sees 200 as well.
N_SEQ = 200

# Pick N_SEQ distinct sequences from the alignment.
random_indices = np.random.choice(alignment.shape[0], size=N_SEQ, replace=False)
msa_samples = alignment[random_indices]

# Pairwise distances are computed between the encoded rows returned by one_hot_for_pca.
# NOTE(review): if those rows are flattened one-hot vectors, the values are a rescaled
# version of the per-site Hamming distance — confirm against one_hot_for_pca.
X = one_hot_for_pca(msa_samples)
dist_matrix = pairwise_distances(X, metric="hamming")

# sns.clustermap builds its own figure, so no plt.figure() call is made beforehand
# (that would only leave an extra empty figure in the cell output).
# NOTE(review): the hierarchical clustering is run on the rows of dist_matrix as if they
# were feature vectors; pass row_linkage/col_linkage if clustering on the distances
# themselves is intended.
cluster_grid = sns.clustermap(
    dist_matrix,
    cmap="viridis",
    figsize=(12, 10),
    method="average",  # Use average linkage for clustering
    xticklabels=False, # Hide x-axis labels
    yticklabels=False, # Hide y-axis labels
    cbar_kws={"label": "Hamming Distance"}
)

# Title the whole figure: plt.title() would act on the current axes, which after
# clustermap is the colorbar axes rather than the heatmap.
cluster_grid.ax_heatmap.figure.suptitle(f"Hamming Distance on {N_SEQ} Data Samples", y=1.02)

# ClusterGrid.savefig saves with a tight bounding box, so the raised title is not clipped.
cluster_grid.savefig("out/HD_DATA")

In [None]:
# Re-sample from the model with the current N_SEQ; this overwrites the generated_seqs
# array produced earlier in the notebook.
generated_seqs = model.sample_sequences(n_samples=N_SEQ, seed=42).numpy()

# Pairwise distances are computed between the encoded rows returned by one_hot_for_pca.
# NOTE(review): if those rows are flattened one-hot vectors, the values are a rescaled
# version of the per-site Hamming distance — confirm against one_hot_for_pca.
Y = one_hot_for_pca(generated_seqs)
dist_matrix = pairwise_distances(Y, metric="hamming")

# sns.clustermap builds its own figure, so no plt.figure() call is made beforehand
# (that would only leave an extra empty figure in the cell output).
cluster_grid = sns.clustermap(
    dist_matrix,
    cmap="viridis",
    figsize=(12, 10),
    method="average",  # Use average linkage for clustering
    xticklabels=False, # Hide x-axis labels
    yticklabels=False, # Hide y-axis labels
    cbar_kws={"label": "Hamming Distance"}
)

# Title the whole figure: plt.title() would act on the current axes, which after
# clustermap is the colorbar axes rather than the heatmap.
cluster_grid.ax_heatmap.figure.suptitle(f"Hamming Distance on {N_SEQ} Model Samples", y=1.02)

# ClusterGrid.savefig saves with a tight bounding box, so the raised title is not clipped.
cluster_grid.savefig("out/HD_MODEL")

# One-, Two-, and Three-Site Connected Correlations

In [None]:
# Re-sample from the model for the correlation analysis and cast the result to int16.
# N_SEQ here is whatever the most recently executed assignment set it to (200 after the
# Hamming-distance section when the notebook is run top to bottom).
generated_seqs = model.sample_sequences(n_samples=N_SEQ, seed=42).numpy().astype(np.int16)

In [None]:
def _plot_correlation_panel(ax,
                            data_values: np.ndarray,
                            model_values: np.ndarray,
                            title: str,
                            diagonal: Optional[Tuple[float, float]] = None,
                            alpha: float = 0.4,
                            marker_size: float = 1) -> Tuple[float, float]:
    """
    Draw one data-vs-model scatter panel and annotate it with Pearson r and fitted slope.

    Args:
        ax: Matplotlib axes to draw on
        data_values: 1-D statistics computed from the data (x axis)
        model_values: 1-D statistics computed from the model samples (y axis)
        title: Panel title
        diagonal: (low, high) end points of the dashed y = x reference line; when None
            the joint min/max of both value arrays is used
        alpha: Marker transparency
        marker_size: Marker size passed to scatter

    Returns:
        Tuple (pearson, slope), where slope is the linear coefficient of a degree-1 fit
    """
    pearson, _ = pearsonr(data_values, model_values)
    slope = np.polyfit(data_values, model_values, 1)[0]

    ax.scatter(data_values, model_values, alpha=alpha, s=marker_size, color='blue')

    if diagonal is None:
        diagonal = (min(np.min(data_values), np.min(model_values)),
                    max(np.max(data_values), np.max(model_values)))
    ax.plot(list(diagonal), list(diagonal), 'k--', alpha=0.8, linewidth=1)

    ax.set_xlabel('Data')
    ax.set_ylabel('Sample')
    ax.set_title(title)
    # Two decimals for Pearson (same precision as the slope): with one decimal a value
    # such as 0.96 would be displayed as "1.0".
    ax.text(0.05, 0.95, f'arDCA, Pearson: {pearson:.2f}\nSlope: {slope:.2f}',
            transform=ax.transAxes, verticalalignment='top', fontsize=10,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    ax.grid(True, alpha=0.3)

    return pearson, slope


def comprehensive_correlation_analysis(data_sequences: np.ndarray,
                                     model_sequences: np.ndarray,
                                     data_weights: np.ndarray,
                                     data_M_eff: float,
                                     model_weights: np.ndarray,
                                     model_M_eff: float,
                                     q: int = 21,
                                     max_triplets: Optional[int] = 500,
                                     figsize: Tuple[int, int] = (15, 5)) -> Dict:
    """
    Perform comprehensive correlation analysis and create comparison plots.

    One-, two- and three-site statistics are computed for both sequence sets and shown
    as three side-by-side data-vs-model scatter panels.

    Args:
        data_sequences: Real MSA sequences of shape (M_data, L)
        model_sequences: Model-generated sequences of shape (M_model, L)
        data_weights: Weights for data_sequences, forwarded to the compute_* helpers
            (presumably one weight per sequence — confirm against those helpers)
        data_M_eff: Effective number of data sequences, forwarded to the compute_* helpers
        model_weights: Weights for model_sequences, forwarded to the compute_* helpers
        model_M_eff: Effective number of model sequences, forwarded to the compute_* helpers
        q: Number of amino acid types
        max_triplets: Maximum triplets for 3-site correlations
        figsize: Figure size for the combined plot

    Returns:
        Dictionary containing all computed correlations and statistics, with keys
        'single_site_freqs', 'two_site_corrs', 'three_site_corrs' (each holding 'data',
        'model', 'pearson', 'slope'; the three-site entry also holds 'triplet_indices')
        and 'figure'.
    """
    print("Computing single-site frequencies...")
    data_f1 = compute_empirical_f1(data_sequences, data_weights, data_M_eff, q)
    model_f1 = compute_empirical_f1(model_sequences, model_weights, model_M_eff, q)

    print("Computing two-site correlations...")
    data_c2 = compute_two_site_correlations(data_sequences, data_weights, data_M_eff, q)
    model_c2 = compute_two_site_correlations(model_sequences, model_weights, model_M_eff, q)

    print("Computing three-site correlations...")
    # Only the triplet indices returned for the data are kept.
    # NOTE(review): if compute_three_site_correlations picks triplets at random, the model
    # call may use different triplets than the data call — confirm.
    data_c3, triplet_indices = compute_three_site_correlations(data_sequences, data_weights, data_M_eff, q, max_triplets)
    model_c3, _ = compute_three_site_correlations(model_sequences, model_weights, model_M_eff, q, max_triplets)

    # Create comparison plots
    fig, axes = plt.subplots(1, 3, figsize=figsize)

    # Single-site frequencies: compare only entries whose data frequency is non-zero.
    data_f1_flat = data_f1.flatten()
    model_f1_flat = model_f1.flatten()
    nonzero_mask = data_f1_flat > 1e-6
    corr_f1, slope_f1 = _plot_correlation_panel(
        axes[0], data_f1_flat[nonzero_mask], model_f1_flat[nonzero_mask], '$F_i$',
        diagonal=(0, 1), alpha=0.6, marker_size=2)

    # Two-site correlations
    # NOTE(review): the non-zero entries are extracted independently for data and model
    # and then truncated to a common length; if the two selections differ, paired points
    # may not refer to the same sites/states — confirm against extract_nonzero_correlations.
    data_c2_flat = extract_nonzero_correlations(data_c2)
    model_c2_flat = extract_nonzero_correlations(model_c2)
    min_len = min(len(data_c2_flat), len(model_c2_flat))
    corr_c2, slope_c2 = _plot_correlation_panel(
        axes[1], data_c2_flat[:min_len], model_c2_flat[:min_len], '$C_{ij}$')

    # Three-site correlations (same extraction/truncation scheme as the two-site panel)
    data_c3_flat = extract_nonzero_correlations(data_c3)
    model_c3_flat = extract_nonzero_correlations(model_c3)
    min_len = min(len(data_c3_flat), len(model_c3_flat))
    corr_c3, slope_c3 = _plot_correlation_panel(
        axes[2], data_c3_flat[:min_len], model_c3_flat[:min_len], '$C_{ijk}$')

    plt.tight_layout()
    plt.show()

    # Return results dictionary
    results = {
        'single_site_freqs': {
            'data': data_f1,
            'model': model_f1,
            'pearson': corr_f1,
            'slope': slope_f1
        },
        'two_site_corrs': {
            'data': data_c2,
            'model': model_c2,
            'pearson': corr_c2,
            'slope': slope_c2
        },
        'three_site_corrs': {
            'data': data_c3,
            'model': model_c3,
            'pearson': corr_c3,
            'slope': slope_c3,
            'triplet_indices': triplet_indices
        },
        'figure': fig
    }

    return results

In [None]:
# Weights and effective sequence count for the natural alignment.
# NOTE(review): theta=0.8 / gap_idx=0 are passed straight to compute_weights_blockwise;
# presumably a similarity threshold and the integer code of the gap symbol — confirm.
W, M_eff = compute_weights_blockwise(X_idx=alignment, theta=0.8, gap_idx=0)

# Generated sequences are weighted uniformly, so their effective count is just their number.
n_generated = len(generated_seqs)
model_weights = np.ones(n_generated)
model_M_eff = float(n_generated)

In [None]:
# Run comprehensive analysis.
# The data sequences must be the full `alignment`: W and M_eff were computed from it, so
# they match it row for row. `random_msa_samples` holds N_SEQ rows of the 2-D PCA
# projection, not sequences, and does not match W in length.
results = comprehensive_correlation_analysis(
    alignment,
    generated_seqs,
    W,
    M_eff,
    model_weights,
    model_M_eff,
    q=q,  # NOTE(review): q is not assigned in any visible cell; presumably exported by a wildcard import — confirm
    max_triplets=200
)

# Print summary statistics
print(f"Single-site correlation: {results['single_site_freqs']['pearson']:.3f}")
print(f"Two-site correlation: {results['two_site_corrs']['pearson']:.3f}")
print(f"Three-site correlation: {results['three_site_corrs']['pearson']:.3f}")

# Show a compact summary instead of dumping every array held in `results`.
{
    name: {"pearson": stats["pearson"], "slope": stats["slope"]}
    for name, stats in results.items()
    if name != "figure"
}