In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from pywt import wavedec, waverec
from pyemma.coordinates import tica
from scipy.stats import pearsonr
import hdbscan
import os
import warnings
# NOTE(review): this silences every Python warning raised after this point
# (deprecations, numerical RuntimeWarnings, ...), not just known-noisy ones.
# Consider narrowing it with the `category=` / `module=` arguments.
warnings.filterwarnings("ignore")


In [None]:
# ============================================================
# Visualization Settings
# ============================================================
# All global matplotlib overrides are applied in a single call so the
# notebook-wide figure style can be read (and tuned) in one place.
mpl.rcParams.update({
    # axes frame
    'axes.linewidth':    2,
    # major ticks: length, thickness and direction (pointing into the axes)
    'xtick.major.size':  4,
    'xtick.major.width': 2,
    'ytick.major.size':  4,
    'ytick.major.width': 2,
    'xtick.direction':   'in',
    'ytick.direction':   'in',
    # text and export resolution
    'font.size':         14,
    'savefig.dpi':       300,
})

In [None]:
# ============================================================
# Step 1: Load Data & Create Output Directories
# ============================================================
# Output folders used by the later steps; exist_ok keeps re-runs idempotent.
os.makedirs('wt3-150',  exist_ok=True)
os.makedirs('tica-test', exist_ok=True)

# np.load on an .npz archive returns an NpzFile that keeps the file open;
# the context manager closes it as soon as the array has been read.
with np.load('tri_sin_phi_data.npz') as loaded:
    tri_sin_phi = loaded['tri_sin_phi']

# Columns 1..9 are sliced below, so fail early (and clearly) on a narrower
# array instead of silently producing fewer feature columns.
if tri_sin_phi.ndim != 2 or tri_sin_phi.shape[1] < 10:
    raise ValueError(
        f"'tri_sin_phi' must be 2-D with at least 10 columns, got shape {tri_sin_phi.shape}")

# NOTE(review): column 0 is not used anywhere in this notebook; columns 1-6
# are taken as the input features and 7-9 as the three phi angles, as the
# variable names suggest -- confirm against whatever produced the .npz file.
data_orig = tri_sin_phi[:, 1:7]   # original data (6 columns)
phi1      = tri_sin_phi[:, 7]
phi2      = tri_sin_phi[:, 8]
phi3      = tri_sin_phi[:, 9]

print(f"data_orig shape:    {data_orig.shape}")
print(f"tri_sin_phi shape:  {tri_sin_phi.shape}")

In [None]:
# ============================================================
# Step 2: Discrete Wavelet Transform (DWT) Denoising
#   wavelet : coif7
#   level   : 3
#   action  : zero out the detail coefficients of all 3 levels, i.e.
#             reconstruct from the level-3 approximation only
# ============================================================
# Single source of truth for the decomposition settings: the same wavelet
# must be used for wavedec and waverec, and the number of detail bands that
# get zeroed follows the decomposition level automatically.
DWT_WAVELET = 'coif7'
DWT_LEVEL   = 3

data_wt = np.zeros_like(data_orig)
low_freq_energies  = []
high_freq_energies = []
energy_ratios      = []

for i in range(data_orig.shape[1]):
    # coeffs = [approximation, detail(level), ..., detail(1)]
    coeffs = wavedec(data_orig[:, i], DWT_WAVELET, level=DWT_LEVEL)

    # --- energy calculation (before zeroing) ---
    # "low" = approximation band, "high" = all detail bands together;
    # the ratio is the fraction of coefficient energy kept by the denoising.
    low_freq_energy  = np.sum(coeffs[0] ** 2)
    high_freq_energy = np.sum([np.sum(c ** 2) for c in coeffs[1:]])
    energy_ratio     = low_freq_energy / (low_freq_energy + high_freq_energy)

    low_freq_energies.append(low_freq_energy)
    high_freq_energies.append(high_freq_energy)
    energy_ratios.append(energy_ratio)

    # --- zero out detail coefficients & reconstruct ---
    coeffs[1:] = [np.zeros_like(detail) for detail in coeffs[1:]]
    # the reconstruction is truncated to the input length before it is
    # written into the matching column of data_wt
    data_wt[:, i] = waverec(coeffs, DWT_WAVELET)[:data_orig.shape[0]]

# --- save energy info ---
low_freq_energies  = np.array(low_freq_energies)
high_freq_energies = np.array(high_freq_energies)
energy_ratios      = np.array(energy_ratios)
np.savetxt('low_freq_energies.txt',  low_freq_energies)
np.savetxt('high_freq_energies.txt', high_freq_energies)
np.savetxt('energy_ratios.txt',      energy_ratios)
np.savetxt('tri_wt3_data.txt',       data_wt)

# --- save original vs wavelet comparison (first 2000 frames, column index 1) ---
np.savetxt('tri-wt_sin2.txt',
           np.column_stack((data_orig[:2000, 1], data_wt[:2000, 1])),
           delimiter='\t')

# --- quick visual check: original vs wavelet (first 1000 frames, column index 1) ---
plt.figure(figsize=(20, 7))
plt.plot(data_orig[:1000, 1], c='r', label='Original')
plt.plot(data_wt[:1000, 1],   c='k', label='Wavelet')
plt.legend()
plt.xlabel('Frame')
plt.ylabel('Value')
plt.title('Original vs Wavelet-Denoised (column 1, first 1000 frames)')
plt.savefig('wt_check_original_vs_wavelet.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# ============================================================
# Step 3: Eigenvalue Evolution vs Lag Time
#   scans lag = 10, 60, 110, ..., 460
#   fits tICA (dim=2) on both original and wavelet data
# ============================================================
lags = np.arange(10, 500, 50)

# One eigenvalue spectrum per scanned lag, collected separately per dataset.
spectra_orig = []
spectra_wt   = []

for lag in lags:
    # original data
    model_orig = tica(lag=lag, dim=2)
    model_orig.fit(data_orig)
    spectra_orig.append(model_orig.eigenvalues)

    # wavelet-denoised data
    model_wt = tica(lag=lag, dim=2)
    model_wt.fit(data_wt)
    spectra_wt.append(model_wt.eigenvalues)

# rows follow `lags`, columns follow the eigenvalue index
eigenvalues_orig = np.array(spectra_orig)
eigenvalues_wt   = np.array(spectra_wt)

# --- save ---
# each text file holds the lag in column 0 followed by the eigenvalues
for out_path, spectrum in (
        ('wt3-150/lag_and_eigenvalues_evolution-orin3.txt', eigenvalues_orig),
        ('wt3-150/lag_and_eigenvalues_evolution-wt3.txt',   eigenvalues_wt)):
    np.savetxt(out_path, np.column_stack((lags, spectrum)))

# --- plot: original, then wavelet ---
for spectrum, plot_title, png_name in (
        (eigenvalues_orig, 'Eigenvalue Evolution vs Lag Time (Original)',
         'eigenvalues_evolution_original.png'),
        (eigenvalues_wt,   'Eigenvalue Evolution vs Lag Time (Wavelet)',
         'eigenvalues_evolution_wavelet.png')):
    fig, ax = plt.subplots(figsize=(10, 6))
    # one curve per eigenvalue index
    for i in range(spectrum.shape[1]):
        ax.plot(lags, spectrum[:, i], marker='o', label=f'Eigenvalue {i+1}')
    ax.set_xlabel('Lag Time')
    ax.set_ylabel('Eigenvalues')
    ax.set_title(plot_title)
    ax.legend()
    ax.grid(True)
    fig.savefig(png_name, dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# ============================================================
# Step 4: tICA Scatter Plots Across Different Lags (5×5 grid)
# ============================================================
# One figure per dataset (wavelet first, then original). Every panel shows a
# fresh 2-D tICA projection fitted at one of the lag times in `lags`; panels
# are placed on a fixed 5x5 grid, so at most 25 lags can be shown.
scatter_jobs = (
    (data_wt,   'tICA Scatter — Wavelet',  'tica_grid_wavelet.png'),
    (data_orig, 'tICA Scatter — Original', 'tica_grid_original.png'),
)

for features, heading, png_name in scatter_jobs:
    fig = plt.figure(figsize=(20, 20))
    for n, lag in enumerate(lags):
        ax = fig.add_subplot(5, 5, n + 1)
        model = tica(lag=lag, dim=2)
        dr = model.fit_transform(features)
        # tiny, mostly transparent markers so dense regions stand out
        ax.scatter(dr[:, 0], dr[:, 1], s=0.1, alpha=0.1)
        ax.set_title(f'lag={lag}')
        ax.axis('off')
    fig.suptitle(heading, fontsize=20)
    fig.tight_layout()
    fig.savefig(png_name, dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
# ============================================================
# Step 5: Fix Lag = 150, fit tICA on both datasets
# ============================================================
# Two independent estimators with identical settings (lag 150, 2 output
# dimensions): one for the raw features, one for the wavelet-denoised ones.
TICA_orig, TICA_wt = tica(lag=150, dim=2), tica(lag=150, dim=2)

# Fit each model on its own dataset and keep the 2-D projections, which the
# following steps (correlation, clustering, export) operate on.
tica_dr_orig = TICA_orig.fit_transform(data_orig)
tica_data_wt = TICA_wt.fit_transform(data_wt)

In [None]:
# ============================================================
# Step 6: Feature Importance (tICA eigenvector weights)
# ============================================================
# Importance of a feature = absolute value of its entry in the second
# eigenvector column (index 1, reported as "tIC2").
feat_imp_orig = np.abs(TICA_orig.eigenvectors[:, 1])
feat_imp_wt   = np.abs(TICA_wt.eigenvectors[:, 1])

for banner, importances in (
        ("\n--- Feature Importance (Original, tIC2) ---", feat_imp_orig),
        ("\n--- Feature Importance (Wavelet, tIC2) ---",  feat_imp_wt)):
    print(banner)
    for i, val in enumerate(importances):
        print(f"  Feature {i}: {val:.4f}")

# Raw (signed) weights of the first two eigenvectors, for reference.
for banner, fitted_model in (
        ("\n--- Eigenvectors (Original) ---", TICA_orig),
        ("\n--- Eigenvectors (Wavelet) ---",  TICA_wt)):
    print(banner)
    print(fitted_model.eigenvectors[:, :2])

In [None]:
# ============================================================
# Step 7: Pearson Correlation — Original vs Wavelet tICA
# ============================================================
# Component-wise comparison of the two projections: tIC1 with tIC1 and
# tIC2 with tIC2.
pearson_corrs = np.zeros(2)
p_values      = np.zeros(2)

for i in range(2):
    r_value, p_value = pearsonr(tica_data_wt[:, i], tica_dr_orig[:, i])
    pearson_corrs[i] = r_value
    p_values[i]      = p_value
    print(f"Pearson corr tIC{i+1}: {pearson_corrs[i]:.4f}  (p = {p_values[i]:.4e})")

# --- heatmaps ---
# One 2x2 matrix per component: ones on the diagonal, the cross-correlation
# off the diagonal.
for label, corr in zip(('tIC1', 'tIC2'), pearson_corrs):
    mat = np.full((2, 2), corr)
    np.fill_diagonal(mat, 1)
    tick_names = [f'{label} (Orig)', f'{label} (Wavelet)']

    fig, ax = plt.subplots(figsize=(6, 4))
    sns.heatmap(mat, annot=True, cmap='coolwarm',
                xticklabels=tick_names, yticklabels=tick_names, ax=ax)
    ax.set_title(f'Correlation Heatmap — {label}')
    fig.savefig(f'corr_heatmap_{label}.png', dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
# ============================================================
# Step 8: HDBSCAN Clustering
# ============================================================
# The wavelet and original projections go through the same
# report -> plot -> save sequence; the shared parts live in the helpers below.


def _summarize_clusters(labels):
    """Print noise/cluster statistics for an HDBSCAN label array.

    Parameters
    ----------
    labels : 1-D integer array of cluster labels, with -1 marking noise.

    Returns
    -------
    (mask, n_clusters, n_noise)
        mask       : boolean array, True for points assigned to a cluster
        n_clusters : number of distinct non-noise labels
        n_noise    : number of points labelled -1
    """
    mask = labels != -1
    # np.unique returns the labels sorted, so clusters are always listed in
    # ascending order (iterating a set gives no ordering guarantee).
    cluster_ids = np.unique(labels[mask])
    n_clusters = len(cluster_ids)
    n_noise = np.sum(labels == -1)
    total = len(labels)

    print(f"Total points:          {total}")
    print(f"Noise points (-1):     {n_noise} ({n_noise/total*100:.2f}%)")
    print(f"Clustered points:      {np.sum(mask)} ({np.sum(mask)/total*100:.2f}%)")
    print(f"Number of clusters:    {n_clusters}")
    for k in cluster_ids:
        count = np.sum(labels == k)
        print(f"  Cluster {k}: {count} points ({count/total*100:.2f}%)")
    return mask, n_clusters, n_noise


def _plot_clusters(x, y, labels, mask, xlabel, ylabel, png_name,
                   title=None, limits=None):
    """Scatter noise points in gray and clustered points colored by label.

    Parameters
    ----------
    x, y     : 1-D coordinate arrays, one entry per point.
    labels   : cluster labels aligned with x/y.
    mask     : boolean array, True where the point belongs to a cluster.
    xlabel, ylabel : axis labels.
    png_name : output path passed to savefig.
    title    : optional bold figure title.
    limits   : optional (low, high) pair applied to both axes.
    """
    plt.figure(figsize=(8, 6))
    # noise goes first so the colored clusters are drawn on top of it
    plt.scatter(x[~mask], y[~mask], c='lightgray', s=1, alpha=0.3, label='Noise')
    if mask.sum() > 0:
        plt.scatter(x[mask], y[mask], c=labels[mask], s=1, alpha=0.8, cmap='tab20')
        plt.colorbar(label='Cluster Label')
    plt.xlabel(xlabel, fontsize=16)
    plt.ylabel(ylabel, fontsize=16)
    if limits is not None:
        plt.xlim(*limits)
        plt.ylim(*limits)
    if title is not None:
        plt.title(title, fontsize=16, fontweight='bold')
    plt.savefig(png_name, dpi=300, bbox_inches='tight')
    plt.show()


def _save_clusters(path, x, y, labels, mask):
    """Write the clustered (non-noise) points as tab-separated x, y, label rows."""
    np.savetxt(path, np.column_stack((x[mask], y[mask], labels[mask])),
               delimiter='\t')


# --- 8a: Wavelet data ---
print("\n" + "="*80)
print("HDBSCAN Clustering - Wavelet Data")
print("="*80)

cluster_wt = hdbscan.HDBSCAN(min_cluster_size=1500, min_samples=300,
                          core_dist_n_jobs=6, cluster_selection_method='eom',
                          gen_min_span_tree=True)
cluster_wt.fit(tica_data_wt)
labels_wt = cluster_wt.labels_
mask_wt, n_clusters_wt, n_noise_wt = _summarize_clusters(labels_wt)

# --- plot: wavelet - tICA space with noise in gray ---
_plot_clusters(tica_data_wt[:, 0], tica_data_wt[:, 1], labels_wt, mask_wt,
               'dCV1', 'dCV2', 'hdbscan_wt_tica_space.png',
               title=f'Wavelet HDBSCAN (n_clusters={n_clusters_wt})')

# scatter in phi1-phi2 space
_plot_clusters(phi1, phi2, labels_wt, mask_wt,
               'φ1 (degrees)', 'φ2 (degrees)', 'hdbscan_wt_phi1_phi2.png',
               limits=(-180, 180))

# --- save wavelet cluster results ---
_save_clusters('wt3-150/tri-wt_cluster.txt',
               tica_data_wt[:, 0], tica_data_wt[:, 1], labels_wt, mask_wt)
# NOTE(review): this export uses phi2/phi3 while the plot above and the
# original-data export below use phi1/phi2 -- confirm this is intentional.
_save_clusters('wt3-150/tri-wt_cluster-phi23.txt', phi2, phi3, labels_wt, mask_wt)

# --- 8b: Original data ---
print("\n" + "="*80)
print("HDBSCAN Clustering - Original Data")
print("="*80)

# NOTE(review): the original data is clustered with smaller
# min_cluster_size/min_samples than the wavelet data (1000/200 vs 1500/300).
cluster_orig = hdbscan.HDBSCAN(min_cluster_size=1000, min_samples=200,
                            core_dist_n_jobs=6, cluster_selection_method='eom',
                            gen_min_span_tree=True)
cluster_orig.fit(tica_dr_orig)
labels_orig = cluster_orig.labels_
mask_orig, n_clusters_orig, n_noise_orig = _summarize_clusters(labels_orig)

# --- plot: original - tICA space with noise in gray ---
_plot_clusters(tica_dr_orig[:, 0], tica_dr_orig[:, 1], labels_orig, mask_orig,
               'dCV1', 'dCV2', 'hdbscan_orig_tica_space.png',
               title=f'Original HDBSCAN (n_clusters={n_clusters_orig})')

# scatter in phi1-phi2 space
_plot_clusters(phi1, phi2, labels_orig, mask_orig,
               'φ1 (degrees)', 'φ2 (degrees)', 'hdbscan_orig_phi1_phi2.png',
               limits=(-180, 180))

# --- save original cluster results ---
_save_clusters('wt3-150/tri-orin_cluster.txt',
               tica_dr_orig[:, 0], tica_dr_orig[:, 1], labels_orig, mask_orig)
_save_clusters('wt3-150/tri-orin_cluster-phi12.txt', phi1, phi2, labels_orig, mask_orig)

print("="*80)

In [None]:
# ============================================================
# Step 9: Export .traj Files & Print Ranges
# ============================================================
# Each .traj line holds: (1-based frame index) * 1000, tIC1, tIC2.
# NOTE(review): the factor 1000 is presumably a frame-to-time conversion --
# confirm the intended units.
traj_line = '%d \t %.4f \t %.4f\n'

# --- wavelet ---
# NOTE(review): this file name says "wt4" while the other wavelet outputs in
# this notebook are named "wt3" -- confirm the name is intended.
with open('tica-test/encoded-wt4-150.traj', 'w') as f:
    f.writelines(traj_line % (1000 * frame, row[0], row[1])
                 for frame, row in enumerate(tica_data_wt, start=1))
wt_low, wt_high = tica_data_wt.min(axis=0), tica_data_wt.max(axis=0)
print(f"Wavelet  tIC1 range: [{wt_low[0]:.4f}, {wt_high[0]:.4f}]")
print(f"Wavelet  tIC2 range: [{wt_low[1]:.4f}, {wt_high[1]:.4f}]")

# --- original ---
with open('tica-test/encoded-orin3-150.traj', 'w') as f:
    f.writelines(traj_line % (1000 * frame, row[0], row[1])
                 for frame, row in enumerate(tica_dr_orig, start=1))
orig_low, orig_high = tica_dr_orig.min(axis=0), tica_dr_orig.max(axis=0)
print(f"Original tIC1 range: [{orig_low[0]:.4f}, {orig_high[0]:.4f}]")
print(f"Original tIC2 range: [{orig_low[1]:.4f}, {orig_high[1]:.4f}]")