In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pywt import wavedec, waverec
from pyemma.coordinates import tica
import hdbscan
from scipy.stats import pearsonr
import seaborn as sns
import os
import warnings
warnings.filterwarnings("ignore")

In [None]:
# ============================================================================
# CONFIGURATION
# ============================================================================

# --- Input ------------------------------------------------------------------
# Single .npz archive holding the 'trajectory', 'rmsd' and 'rg' arrays.
DATA_FILE = 'chig_sin_rmsd_rg_data.npz'

# --- Output -----------------------------------------------------------------
# Directory that receives the clustering text outputs.
OUTPUT_DIR = './cluster/'

# --- Wavelet denoising ------------------------------------------------------
WAVELET_TYPE = 'coif7'   # PyWavelets wavelet name
WAVELET_LEVEL = 4        # decomposition depth passed to wavedec()

# --- Feature selection ------------------------------------------------------
# NOTE: the two lists use different column conventions.
#   ORIGINAL_FEATURES index the (flattened) raw trajectory, where column 0 is
#   time. WT_FEATURES index the denoised matrix, from which the time column
#   has already been removed, so WT index k corresponds to raw column k + 1.
ORIGINAL_FEATURES = [1, 2, 3, 4, 5, 6, 7, 8, 10, 14]
WT_FEATURES = [2, 3, 4, 7, 9, 13]

# --- tICA -------------------------------------------------------------------
TICA_LAG_WT = 250    # lag (frames) for the wavelet-denoised features
TICA_LAG_ORIG = 450  # lag (frames) for the raw features; tuned separately
TICA_DIM = 2         # number of tICA components to keep

# --- HDBSCAN ----------------------------------------------------------------
MIN_CLUSTER_SIZE = 10_000
MIN_SAMPLES = 1_000
CLUSTER_METHOD = 'leaf'  # cluster_selection_method
N_JOBS = 6               # core_dist_n_jobs


In [None]:
# ============================================================================
# STEP 1: DATA LOADING
# ============================================================================

# Load trajectory features plus RMSD / Rg from a single .npz archive.
# The context manager closes the archive even if one of the keys is missing.
with np.load(DATA_FILE) as data_file:
    trajectory_raw = data_file['trajectory']
    rmsd = data_file['rmsd']
    rg = data_file['rg']

print(f"✓ Loaded from: {DATA_FILE}")
print(f"  - Trajectory: {trajectory_raw.shape}")
print(f"  - RMSD: {rmsd.shape}")
print(f"  - Rg: {rg.shape}")

# Later steps index RMSD / Rg with per-frame cluster masks, so the frame
# counts must agree.
assert len(rmsd) == len(rg) == trajectory_raw.shape[0], (
    f"Frame count mismatch: trajectory={trajectory_raw.shape[0]}, "
    f"rmsd={len(rmsd)}, rg={len(rg)}")

# Flatten to (n_frames, n_columns); column 0 is the time column.
trajectory_2d = trajectory_raw.reshape((trajectory_raw.shape[0], -1))

# Remove time column
data = trajectory_2d[:, 1:]
print(f"Trajectory: {data.shape[0]} frames, {data.shape[1]} features")

# Extract original feature subset for comparison.
# ORIGINAL_FEATURES index columns of the flattened raw array (time column
# included at index 0), unlike WT_FEATURES, which index `data`.
original_data = trajectory_2d[:, ORIGINAL_FEATURES]
print(f"Original features selected: {len(ORIGINAL_FEATURES)} columns")


In [None]:
# ============================================================================
# STEP 2: WAVELET TRANSFORM DENOISING
# ============================================================================

print("\n" + "="*80)
print("[STEP 2] Wavelet Transform Denoising")
print("="*80)
print(f"Wavelet: {WAVELET_TYPE}, Level: {WAVELET_LEVEL}")

# Keep a floating dtype for the reconstruction so it is never truncated to int.
wt_dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float64
wt_data = np.zeros(data.shape, dtype=wt_dtype)
low_freq_energies = []
high_freq_energies = []
energy_ratios = []

# Denoise each feature column independently, keeping only the approximation.
for i in range(data.shape[1]):
    # wavedec returns [cA_N, cD_N, ..., cD_1]: coeffs[0] is the approximation,
    # every later entry is a detail (high-frequency) band.
    coeffs = wavedec(data[:, i], WAVELET_TYPE, level=WAVELET_LEVEL)

    # Energy split before discarding the details
    low_freq_energy = np.sum(coeffs[0] ** 2)
    high_freq_energy = np.sum([np.sum(detail ** 2) for detail in coeffs[1:]])
    energy_ratio = low_freq_energy / (low_freq_energy + high_freq_energy)

    low_freq_energies.append(low_freq_energy)
    high_freq_energies.append(high_freq_energy)
    energy_ratios.append(energy_ratio)

    # Zero every detail band. Deriving the bands from coeffs[1:] keeps this
    # correct for any WAVELET_LEVEL; a fixed count of 4 would zero the
    # approximation itself when level < 4 and keep a detail band when level > 4.
    coeffs = [coeffs[0]] + [np.zeros_like(detail) for detail in coeffs[1:]]

    # waverec may return a slightly longer signal than the input; trim it.
    wt_data[:, i] = waverec(coeffs, WAVELET_TYPE)[:data.shape[0]]

# Convert to arrays
low_freq_energies = np.array(low_freq_energies)
high_freq_energies = np.array(high_freq_energies)
energy_ratios = np.array(energy_ratios)

# Save energy analysis
np.savetxt('low_freq_energies.txt', low_freq_energies)
np.savetxt('high_freq_energies.txt', high_freq_energies)
np.savetxt('energy_ratios.txt', energy_ratios)
print(f"\nEnergy analysis saved")
print(f"Mean energy ratio: {energy_ratios.mean():.4f}")

# Save denoised data
np.savetxt('./chig_wt4_data', wt_data)
print(f"Denoised data saved to: ./chig_wt4_data")

# Plot: Original vs Denoised (first 2000 frames, feature index 2)
plt.figure(figsize=[20, 7])
plt.plot(data[:2000, 2], c='r', label='Original', linewidth=1, alpha=0.7)
plt.plot(wt_data[:2000, 2], c='k', label='Denoised', linewidth=1.5)
plt.xlabel('Frame', fontsize=14)
plt.ylabel('Feature Value', fontsize=14)
plt.title('Wavelet Denoising Example (Feature 2)', fontsize=16)
plt.legend(fontsize=12)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('wavelet_denoising_example.png', dpi=300)
plt.show()

In [None]:
# ============================================================================
# STEP 3: tICA LAG TIME OPTIMIZATION (WAVELET-TRANSFORMED DATA)
# ============================================================================

print("\n" + "="*80)
print("[STEP 3] tICA Lag Time Optimization (Wavelet-Transformed Data)")
print("="*80)

# Select features for WT-tICA (indices into the time-stripped wt_data)
wt_data_selected = wt_data[:, WT_FEATURES]
print(f"Selected features: {WT_FEATURES}")
print(f"Feature matrix: {wt_data_selected.shape}")

# Scan lag times (in frames) and record the tICA eigenvalues for each one.
lags_wt = np.arange(10, 1000, 50)
eigenvalues_evolution_wt = []

print(f"Testing {len(lags_wt)} lag times...")
for lag in lags_wt:
    # Only the eigenvalues are needed: fit() skips projecting the whole
    # trajectory and does not overwrite `tica_dr_wt` (Step 4's final projection).
    TICA_model = tica(lag=lag, dim=TICA_DIM)
    TICA_model.fit(wt_data_selected)
    eigenvalues_evolution_wt.append(TICA_model.eigenvalues)

eigenvalues_evolution_wt = np.array(eigenvalues_evolution_wt)

# Save lag optimization results
wt_lag_data = np.column_stack((lags_wt, eigenvalues_evolution_wt))
np.savetxt('lag_and_eigenvalues_evolution-wt.txt', wt_lag_data)
print(f"Saved lag optimization: lag_and_eigenvalues_evolution-wt.txt")

# Plot eigenvalue evolution
plt.figure(figsize=(10, 6))
for i in range(eigenvalues_evolution_wt.shape[1]):
    plt.plot(lags_wt, eigenvalues_evolution_wt[:, i],
             marker='o', label=f'Eigenvalue {i+1}', linewidth=2)
plt.xlabel('Lag Time', fontsize=14)
plt.ylabel('Eigenvalues', fontsize=14)
plt.title('Evolution of Eigenvalues vs Lag Time (WT Data)', fontsize=16)
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('eigenvalue_evolution_wt.png', dpi=300)
plt.show()

# Plot tICA projections at different lag times (grid sized to the lag list)
grid_lags_wt = np.arange(50, 600, 50)
n_cols = 5
n_rows = int(np.ceil(len(grid_lags_wt) / n_cols))
plot_stride = 25  # plot every 25th frame to keep the scatter light

plt.figure(figsize=[20, 4 * n_rows])
for n, lag in enumerate(grid_lags_wt):
    plt.subplot(n_rows, n_cols, n + 1)
    # Fit on the full-resolution series so `lag` is in the same frame units as
    # the scan above and TICA_LAG_WT; fitting on a strided copy would silently
    # multiply the effective lag by the stride. Subsample only for plotting.
    TICA = tica(lag=lag, dim=TICA_DIM)
    TICA.fit(wt_data_selected)
    tica_dr_wt_temp = TICA.transform(wt_data_selected[::plot_stride])
    plt.scatter(tica_dr_wt_temp[:, 0], tica_dr_wt_temp[:, 1],
                s=0.1, alpha=0.1, c='blue')
    plt.title(f'Lag={lag}', fontsize=12)
    plt.axis('off')
plt.suptitle('tICA Projections at Different Lag Times (WT Data)', fontsize=18, y=0.995)
plt.tight_layout()
plt.savefig('tica_projections_different_lags_wt.png', dpi=300)
plt.show()

In [None]:
# ============================================================================
# STEP 4: PERFORM tICA WITH OPTIMAL LAG (WAVELET-TRANSFORMED DATA)
# ============================================================================

print("\n" + "="*80)
print("[STEP 4] Performing tICA with Optimal Lag (WT Data)")
print("="*80)
print(f"Optimal lag time: {TICA_LAG_WT}")

# Perform tICA
TICA = tica(lag=TICA_LAG_WT, dim=TICA_DIM)
tica_dr_wt = TICA.fit_transform(wt_data_selected)

print(f"tICA projection shape: {tica_dr_wt.shape}")
print(f"Range of tIC1: [{tica_dr_wt[:, 0].min():.4f}, {tica_dr_wt[:, 0].max():.4f}]")
print(f"Range of tIC2: [{tica_dr_wt[:, 1].min():.4f}, {tica_dr_wt[:, 1].max():.4f}]")

# Plot final tICA projection
plt.figure(figsize=[10, 8])
plt.scatter(tica_dr_wt[:, 0], tica_dr_wt[:, 1], s=0.1, alpha=0.01, c='blue')
plt.xlabel('tIC1', fontsize=14)
plt.ylabel('tIC2', fontsize=14)
plt.title(f'tICA Projection (WT Data, Lag={TICA_LAG_WT})', fontsize=16)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('tica_projection_wt_final.png', dpi=300)
plt.show()

# Save tICA trajectory: one row per frame, "<time> \t <tIC1> \t <tIC2>".
# NOTE(review): the "-200" in the file name does not match TICA_LAG_WT (250);
# kept unchanged because downstream tools may read this exact path.
save_dir = './tica-sin-test'
os.makedirs(save_dir, exist_ok=True)
# Presumably 500 time units per saved frame — TODO confirm against the MD output interval.
frame_times_wt = np.arange(1, len(tica_dr_wt) + 1) * 500
np.savetxt(os.path.join(save_dir, 'encoded-wt4-200.traj'),
           np.column_stack((frame_times_wt, tica_dr_wt[:, 0], tica_dr_wt[:, 1])),
           fmt='%d \t %.4f \t %.4f')
print(f"Saved WT tICA trajectory")

In [None]:
# ============================================================================
# STEP 5: FEATURE IMPORTANCE ANALYSIS (WAVELET-TRANSFORMED DATA)
# ============================================================================

print("\n" + "="*80)
print("[STEP 5] Feature Importance Analysis (WT Data)")
print("="*80)

# Feature importance from the tICA eigenvectors, using the same lag/dim as the
# Step 4 projection so the weights describe the tICs that are actually clustered.
tica_model_wt = tica(lag=TICA_LAG_WT, dim=TICA_DIM)
tica_model_wt.fit(wt_data_selected)
eigenvectors_wt = tica_model_wt.eigenvectors
# Eigenvectors are stored column-wise; column 0 is tIC1.
feature_importance_tica_wt = np.abs(eigenvectors_wt[:, 0])

print("Feature importance based on tICA eigenvector weights (WT):")
for i, importance in enumerate(feature_importance_tica_wt):
    print(f"  Feature {WT_FEATURES[i]}: {importance:.4f}")

# Plot feature importance
plt.figure(figsize=(8, 6))
plt.bar(range(len(feature_importance_tica_wt)), feature_importance_tica_wt,
        color='steelblue', alpha=0.7, edgecolor='black')
plt.xlabel('Feature Index', fontsize=14)
plt.ylabel('Importance', fontsize=14)
plt.title('Feature Importance for tIC1 (WT Data)', fontsize=16)
# tICA eigenvector weights are not bounded by 1, so never clip the tallest bar.
plt.ylim(0, max(1.5, 1.1 * feature_importance_tica_wt.max()))
plt.xticks(range(len(feature_importance_tica_wt)), WT_FEATURES)
plt.grid(alpha=0.3, axis='y')
plt.tight_layout()
plt.savefig('feature_importance_wt.png', dpi=300)
plt.show()

In [None]:
# ============================================================================
# STEP 6: tICA ON ORIGINAL DATA (FOR COMPARISON)
# ============================================================================

print("\n" + "="*80)
print("[STEP 6] tICA on Original Data (No Denoising)")
print("="*80)

# Lag time optimization for original data (lags in frames)
lags_orig = np.arange(10, 600, 100)
eigenvalues_evolution_orig = []

print(f"Testing {len(lags_orig)} lag times for original data...")
for lag in lags_orig:
    # Only the eigenvalues are needed, so fit() without projecting the data.
    TICA_model = tica(lag=lag, dim=TICA_DIM)
    TICA_model.fit(original_data)
    eigenvalues_evolution_orig.append(TICA_model.eigenvalues)

eigenvalues_evolution_orig = np.array(eigenvalues_evolution_orig)

# Save
orin_lag_data = np.column_stack((lags_orig, eigenvalues_evolution_orig))
np.savetxt('lag_and_eigenvalues_evolution-orin.txt', orin_lag_data)

# Plot eigenvalue evolution
plt.figure(figsize=(10, 6))
for i in range(eigenvalues_evolution_orig.shape[1]):
    plt.plot(lags_orig, eigenvalues_evolution_orig[:, i],
             marker='o', label=f'Eigenvalue {i+1}', linewidth=2)
plt.xlabel('Lag Time', fontsize=14)
plt.ylabel('Eigenvalues', fontsize=14)
plt.title('Evolution of Eigenvalues vs Lag Time (Original Data)', fontsize=16)
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('eigenvalue_evolution_original.png', dpi=300)
plt.show()

# Plot tICA projections at different lag times (original data)
grid_lags_orig = np.arange(50, 600, 50)
n_cols = 5
n_rows = int(np.ceil(len(grid_lags_orig) / n_cols))
plot_stride = 25  # plot every 25th frame to keep the scatter light

plt.figure(figsize=[20, 4 * n_rows])
for n, lag in enumerate(grid_lags_orig):
    plt.subplot(n_rows, n_cols, n + 1)
    # Fit on the full-resolution series so `lag` is in the same frame units as
    # TICA_LAG_ORIG; fitting on a strided copy would multiply the effective lag
    # by the stride. Subsample only for plotting.
    TICA = tica(lag=lag, dim=TICA_DIM)
    TICA.fit(original_data)
    tica_dr_orin_temp = TICA.transform(original_data[::plot_stride])
    plt.scatter(tica_dr_orin_temp[:, 0], tica_dr_orin_temp[:, 1],
                s=0.1, alpha=0.1, c='red')
    plt.title(f'Lag={lag}', fontsize=12)
    plt.axis('off')
plt.suptitle('tICA Projections at Different Lag Times (Original Data)', fontsize=18, y=0.995)
plt.tight_layout()
plt.savefig('tica_projections_different_lags_original.png', dpi=300)
plt.show()

# Perform tICA with optimal lag for original data
print(f"\nPerforming tICA with lag={TICA_LAG_ORIG}")
TICA = tica(lag=TICA_LAG_ORIG, dim=TICA_DIM)
tica_dr_orin = TICA.fit_transform(original_data)

print(f"Original tICA projection shape: {tica_dr_orin.shape}")
print(f"Range of tIC1: [{tica_dr_orin[:, 0].min():.4f}, {tica_dr_orin[:, 0].max():.4f}]")
print(f"Range of tIC2: [{tica_dr_orin[:, 1].min():.4f}, {tica_dr_orin[:, 1].max():.4f}]")

# Plot
plt.figure(figsize=[10, 8])
plt.scatter(tica_dr_orin[:, 0], tica_dr_orin[:, 1], s=0.1, alpha=0.01, c='red')
plt.xlabel('tIC1', fontsize=14)
plt.ylabel('tIC2', fontsize=14)
plt.title(f'tICA Projection (Original Data, Lag={TICA_LAG_ORIG})', fontsize=16)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('tica_projection_original_final.png', dpi=300)
plt.show()

# Save trajectory: one row per frame, "<time> \t <tIC1> \t <tIC2>".
# NOTE(review): the "-250" in the file name does not match TICA_LAG_ORIG (450);
# kept unchanged because downstream tools may read this exact path.
save_dir = './tica-sin-wt4-lag'
os.makedirs(save_dir, exist_ok=True)
# Presumably 500 time units per saved frame — TODO confirm against the MD output interval.
frame_times_orig = np.arange(1, len(tica_dr_orin) + 1) * 500
np.savetxt(os.path.join(save_dir, 'encoded-orin-250.traj'),
           np.column_stack((frame_times_orig, tica_dr_orin[:, 0], tica_dr_orin[:, 1])),
           fmt='%d \t %.4f \t %.4f')


In [None]:
# ============================================================================
# STEP 7: FEATURE IMPORTANCE ANALYSIS (ORIGINAL DATA)
# ============================================================================

print("\n" + "="*80)
print("[STEP 7] Feature Importance Analysis (Original Data)")
print("="*80)

# Same lag/dim as the Step 6 projection so the weights describe the tICs that
# are actually clustered.
tica_model_orig = tica(lag=TICA_LAG_ORIG, dim=TICA_DIM)
tica_model_orig.fit(original_data)
eigenvectors_orig = tica_model_orig.eigenvectors
# Eigenvectors are stored column-wise; column 0 is tIC1.
feature_importance_tica_orig = np.abs(eigenvectors_orig[:, 0])

# Label by raw column index (ORIGINAL_FEATURES), not by position in the subset.
# Raw column k corresponds to WT feature index k - 1 (time column removed).
print("Feature importance based on tICA eigenvector weights (Original):")
for i, importance in enumerate(feature_importance_tica_orig):
    print(f"  Feature {ORIGINAL_FEATURES[i]} (raw column): {importance:.4f}")

plt.figure(figsize=(8, 6))
plt.bar(range(len(feature_importance_tica_orig)), feature_importance_tica_orig,
        color='coral', alpha=0.7, edgecolor='black')
plt.xlabel('Raw Column Index', fontsize=14)
plt.ylabel('Importance', fontsize=14)
plt.title('Feature Importance for tIC1 (Original Data)', fontsize=16)
# tICA eigenvector weights are not bounded by 1, so never clip the tallest bar.
plt.ylim(0, max(1.5, 1.1 * feature_importance_tica_orig.max()))
plt.xticks(range(len(feature_importance_tica_orig)), ORIGINAL_FEATURES)
plt.grid(alpha=0.3, axis='y')
plt.tight_layout()
plt.savefig('feature_importance_original.png', dpi=300)
plt.show()

In [None]:
# ============================================================================
# STEP 8: CORRELATION ANALYSIS BETWEEN WT-tICA AND ORIGINAL-tICA
# ============================================================================

print("\n" + "="*80)
print("[STEP 8] Correlation Analysis: WT-tICA vs Original-tICA")
print("="*80)

# Pearson r (and p-value) between matching tICs of the two projections.
# tICs are only defined up to sign, so r close to -1 also indicates agreement.
pearson_corrs = np.zeros(2)
p_values = np.zeros(2)

for i in range(2):
    pearson_corrs[i], p_values[i] = pearsonr(tica_dr_wt[:, i], tica_dr_orin[:, i])
    print(f"Pearson correlation tIC{i+1}: {pearson_corrs[i]:.4f}, p-value: {p_values[i]:.4e}")

# One 2x2 heatmap per tIC: [[1, r], [r, 1]].
for ic_number, r_value in enumerate(pearson_corrs, start=1):
    tick_labels = [f'tIC{ic_number} (Orig)', f'tIC{ic_number} (Wavelet)']
    corr_matrix = np.array([[1, r_value], [r_value, 1]])

    plt.figure(figsize=(6, 4))
    sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', center=0, vmin=-1, vmax=1,
                xticklabels=tick_labels,
                yticklabels=tick_labels)
    plt.title(f'Correlation Heatmap for tIC{ic_number}', fontsize=14)
    plt.tight_layout()
    plt.savefig(f'correlation_heatmap_tIC{ic_number}.png', dpi=300)
    plt.show()

In [None]:
# ============================================================================
# STEP 9: HDBSCAN CLUSTERING ON WT-tICA
# ============================================================================

print("\n" + "="*80)
print("[STEP 9] HDBSCAN Clustering (WT-tICA)")
print("="*80)
print(f"min_cluster_size: {MIN_CLUSTER_SIZE}")
print(f"min_samples: {MIN_SAMPLES}")

# Density-based clustering of the WT-tICA projection; label -1 marks noise.
est_wt = hdbscan.HDBSCAN(
    min_cluster_size=MIN_CLUSTER_SIZE,
    min_samples=MIN_SAMPLES,
    core_dist_n_jobs=N_JOBS,
    cluster_selection_method=CLUSTER_METHOD,
    gen_min_span_tree=True,
)
est_wt.fit(tica_dr_wt)
labels_wt = est_wt.labels_

mask_wt = labels_wt != -1  # True for frames assigned to a real cluster
n_noise_wt = np.sum(~mask_wt)
n_clusters_wt = np.unique(labels_wt[mask_wt]).size

print(f"\nClustering results:")
print(f"  Clusters: {n_clusters_wt}")
print(f"  Noise points: {n_noise_wt} ({100*n_noise_wt/len(labels_wt):.2f}%)")


def _scatter_wt_clusters(x, y, cluster_ids, alpha, size, title, filename):
    """Scatter tIC1 vs tIC2 coloured by cluster id, then save and show it."""
    plt.figure(figsize=(10, 8))
    plt.scatter(x, y, c=cluster_ids, alpha=alpha, s=size, cmap='tab20')
    plt.colorbar(label='Cluster ID')
    plt.xlabel('tIC1', fontsize=14)
    plt.ylabel('tIC2', fontsize=14)
    plt.title(title, fontsize=16)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(filename, dpi=300)
    plt.show()


# Plot 1 - every frame, noise included
_scatter_wt_clusters(tica_dr_wt[:, 0], tica_dr_wt[:, 1], labels_wt,
                     alpha=1, size=0.01,
                     title='HDBSCAN Clustering - All Points (WT-tICA)',
                     filename='hdbscan_wt_tica_all_points.png')

# Plot 2 - clustered frames only
_scatter_wt_clusters(tica_dr_wt[:, 0][mask_wt], tica_dr_wt[:, 1][mask_wt],
                     labels_wt[mask_wt],
                     alpha=0.1, size=0.1,
                     title='HDBSCAN Clustering - Clusters Only (WT-tICA)',
                     filename='hdbscan_wt_tica_clusters_only.png')

# Save clustered frames (noise dropped), keeping every 10th of them
os.makedirs(OUTPUT_DIR, exist_ok=True)
kept_idx = np.flatnonzero(mask_wt)[::10]
data_wt_cluster = np.column_stack((
    tica_dr_wt[kept_idx, 0],
    tica_dr_wt[kept_idx, 1],
    labels_wt[kept_idx],
))
np.savetxt(os.path.join(OUTPUT_DIR, 'chig-wt_cluster-10.txt'),
           data_wt_cluster, delimiter='\t', comments='')

# Save the full trajectory with labels: 1-based frame, tIC1, tIC2, label
np.savetxt('wt250_label_plot.txt',
           np.column_stack((np.arange(1, len(tica_dr_wt) + 1),
                            tica_dr_wt[:, 0], tica_dr_wt[:, 1], labels_wt)),
           fmt='%d\t%.4f\t%.4f\t%d',
           header='Frame\tx1\tx2\tlabel', comments='')

print(f"\nSaved clustering results and plots")



In [None]:
# ============================================================================
# STEP 10: VALIDATION ON RMSD-Rg SPACE (WT-tICA CLUSTERING)
# ============================================================================

print("\n" + "="*80)
print("[STEP 10] Validation on RMSD-Rg Space (WT-tICA Clustering)")
print("="*80)

# Per-frame RMSD / Rg must line up with the per-frame cluster labels.
assert len(rmsd) == len(rg) == len(labels_wt), (
    f"RMSD/Rg/label length mismatch: {len(rmsd)}, {len(rg)}, {len(labels_wt)}")

# Plot RMSD-Rg colored by WT-tICA clusters
plt.figure(figsize=(10, 8))
plt.scatter(rmsd[mask_wt], rg[mask_wt],
            c=labels_wt[mask_wt], cmap='tab20', s=0.1, alpha=1)
plt.colorbar(label='Cluster ID')
plt.xlabel('RMSD (Å)', fontsize=14)
plt.ylabel('Rg (Å)', fontsize=14)
plt.title('HDBSCAN Clusters in RMSD-Rg Space (WT-tICA)', fontsize=16)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('hdbscan_wt_rmsd_rg_space.png', dpi=300)
plt.show()

# Save RMSD-Rg clustering (noise frames excluded)
os.makedirs(OUTPUT_DIR, exist_ok=True)
data_rmsd_rg = np.column_stack((rmsd[mask_wt], rg[mask_wt], labels_wt[mask_wt]))
np.savetxt(os.path.join(OUTPUT_DIR, 'chig-wt_rmsd.txt'),
           data_rmsd_rg, delimiter='\t', comments='')

# Save full RMSD-Rg trajectory with labels: 1-based frame, rmsd, rg, label
np.savetxt('rmsd_label_data-200.txt',
           np.column_stack((np.arange(1, len(rmsd) + 1), rmsd, rg, labels_wt)),
           fmt='%d\t%.4f\t%.4f\t%d',
           header='Frame\trmsd\trg\tlabel', comments='')

In [None]:
# ============================================================================
# STEP 11: HDBSCAN CLUSTERING ON ORIGINAL-tICA
# ============================================================================

print("\n" + "="*80)
print("[STEP 11] HDBSCAN Clustering (Original-tICA)")
print("="*80)

# Perform HDBSCAN on original tICA with exactly the same settings as the WT run
# (including cluster_selection_method) so the two clusterings are comparable.
est_orig = hdbscan.HDBSCAN(
    min_cluster_size=MIN_CLUSTER_SIZE,
    min_samples=MIN_SAMPLES,
    core_dist_n_jobs=N_JOBS,
    cluster_selection_method=CLUSTER_METHOD,
    gen_min_span_tree=True
)
est_orig.fit(tica_dr_orin)
labels_orig = est_orig.labels_

n_clusters_orig = len(set(labels_orig)) - (1 if -1 in labels_orig else 0)
n_noise_orig = np.sum(labels_orig == -1)

print(f"Clustering results:")
print(f"  Clusters: {n_clusters_orig}")
print(f"  Noise points: {n_noise_orig} ({100*n_noise_orig/len(labels_orig):.2f}%)")

# Plot: Clustering in original tICA space (noise frames hidden)
plt.figure(figsize=(10, 8))
mask_orig = (labels_orig != -1)
plt.scatter(tica_dr_orin[:, 0][mask_orig], tica_dr_orin[:, 1][mask_orig],
            c=labels_orig[mask_orig], alpha=1, s=0.01, cmap='tab20')
plt.colorbar(label='Cluster ID')
plt.xlabel('tIC1', fontsize=14)
plt.ylabel('tIC2', fontsize=14)
plt.title('HDBSCAN Clustering (Original-tICA Space)', fontsize=16)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('hdbscan_original_tica_space.png', dpi=300)
plt.show()

# Save clustering data (noise frames excluded)
os.makedirs(OUTPUT_DIR, exist_ok=True)
data_orig_cluster = np.column_stack((
    tica_dr_orin[:, 0][mask_orig],
    tica_dr_orin[:, 1][mask_orig],
    labels_orig[mask_orig]
))
np.savetxt(os.path.join(OUTPUT_DIR, 'chig-orin_cluster.txt'),
           data_orig_cluster, delimiter='\t', comments='')



In [1]:
# ============================================================================
# STEP 12: VALIDATION ON RMSD-Rg SPACE (ORIGINAL-tICA CLUSTERING)
# ============================================================================

print("\n" + "="*80)
print("[STEP 12] Validation on RMSD-Rg Space (Original-tICA Clustering)")
print("="*80)

# Per-frame RMSD / Rg must line up with the per-frame cluster labels.
assert len(rmsd) == len(rg) == len(labels_orig), (
    f"RMSD/Rg/label length mismatch: {len(rmsd)}, {len(rg)}, {len(labels_orig)}")

# Plot RMSD-Rg colored by original-tICA clusters
plt.figure(figsize=(10, 8))
plt.scatter(rmsd[mask_orig], rg[mask_orig],
            c=labels_orig[mask_orig], cmap='tab20', s=0.01, alpha=1)
plt.colorbar(label='Cluster ID')
plt.xlabel('RMSD (Å)', fontsize=14)
plt.ylabel('Rg (Å)', fontsize=14)
plt.title('HDBSCAN Clusters in RMSD-Rg Space (Original-tICA)', fontsize=16)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('hdbscan_original_rmsd_rg_space.png', dpi=300)
plt.show()

# Save RMSD-Rg clustering (noise frames excluded)
os.makedirs(OUTPUT_DIR, exist_ok=True)
data_orig_rmsd = np.column_stack((rmsd[mask_orig], rg[mask_orig], labels_orig[mask_orig]))
np.savetxt(os.path.join(OUTPUT_DIR, 'chig-orin_rmsd.txt'),
           data_orig_rmsd, delimiter='\t', comments='')