In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import torch
from torch.utils.data import DataLoader
from deeptime.util.torch import MLP
from deeptime.decomposition.deep import TAE
from deeptime.util.data import TrajectoryDataset
import hdbscan
import os
import warnings
warnings.filterwarnings("ignore")

In [None]:
# ============================================================
# Visualization Settings
# ============================================================
# Apply the shared figure style with a single rcParams update:
# 2 pt axes lines, inward-pointing major ticks (length 4, width 2),
# a 14 pt base font and 300 dpi for saved figures.
mpl.rcParams.update({
    'axes.linewidth': 2,
    'xtick.major.size': 4,
    'xtick.major.width': 2,
    'ytick.major.size': 4,
    'ytick.major.width': 2,
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'font.size': 14,
    'savefig.dpi': 300,
})

In [None]:
# ============================================================
# Step 1: Load Data & Create Output Directories
# ============================================================
# Output directories written to by the training, clustering and export cells.
os.makedirs('wt3-150/wt3-tae/tae-new/wt3-60', exist_ok=True)
os.makedirs('wt3-150/wt3-tae/tae-new/orin-60', exist_ok=True)
os.makedirs('tae-noise', exist_ok=True)

# np.load on an .npz archive returns an NpzFile that keeps the file open;
# read the array inside a context manager so the handle is released.
with np.load('tri_sin_phi_data.npz') as loaded:
    tri_sin_phi = loaded['tri_sin_phi']

# Original (non-denoised) input features: columns 1..6 of the raw table
data_orig = tri_sin_phi[:, 1:7]

# Data after wavelet denoising (generated by WE-tICA-HDBSCAN)
data_wt = np.loadtxt('tri_wt3_data.txt')

# Dihedral angle data: columns 7, 8 and 9 of the raw table
phi1 = tri_sin_phi[:, 7]
phi2 = tri_sin_phi[:, 8]
phi3 = tri_sin_phi[:, 9]

print(f"data_orig shape:    {data_orig.shape}")
print(f"data_wt shape:      {data_wt.shape}")
print(f"tri_sin_phi shape:  {tri_sin_phi.shape}")

# Fail fast on inconsistent inputs. Both datasets are fed to the same network
# architecture, and the cluster masks derived from data_wt are later used to
# index phi1/phi2/phi3, so feature and frame counts must agree. Without this
# check a mismatch would only surface after both models have been trained.
if data_wt.ndim != 2 or data_wt.shape[1] != data_orig.shape[1]:
    raise ValueError(
        f"data_wt has shape {data_wt.shape}; expected a 2-D array with "
        f"{data_orig.shape[1]} feature columns like data_orig"
    )
if data_wt.shape[0] != tri_sin_phi.shape[0]:
    raise ValueError(
        f"data_wt has {data_wt.shape[0]} frames but tri_sin_phi has "
        f"{tri_sin_phi.shape[0]}; the two files must describe the same frames"
    )

In [None]:
# ============================================================
# Step 2a: Prepare DataLoader for TAE Training (Wavelet)
# ============================================================
print("\n" + "="*80)
print("Preparing Wavelet Data for TAE")
print("="*80)

# Lag passed to TrajectoryDataset; also reused by the Step 2b cell.
lag_time = 60
dataset_wt = TrajectoryDataset(lag_time, data_wt.astype(np.float32))

# Hold out 40% of the samples for validation.
n_val_wt = int(len(dataset_wt) * 0.4)
# A dedicated seeded generator makes the train/validation partition identical
# on every run without touching torch's global RNG state.
train_data_wt, val_data_wt = torch.utils.data.random_split(
    dataset_wt,
    [len(dataset_wt) - n_val_wt, n_val_wt],
    generator=torch.Generator().manual_seed(42),
)

loader_train_wt = DataLoader(train_data_wt, batch_size=32, shuffle=True)
loader_val_wt   = DataLoader(val_data_wt,   batch_size=32, shuffle=False)

print(f"Wavelet - Training samples:   {len(train_data_wt)}")
print(f"Wavelet - Validation samples: {len(val_data_wt)}")


In [None]:
# ============================================================
# Step 2b: Prepare DataLoader for TAE Training (Original)
# ============================================================
print("\n" + "="*80)
print("Preparing Original Data for TAE")
print("="*80)

# lag_time is defined in the Step 2a cell, so both datasets use the same lag.
dataset_orig = TrajectoryDataset(lag_time, data_orig.astype(np.float32))

# Hold out 40% of the samples for validation.
n_val_orig = int(len(dataset_orig) * 0.4)
# A dedicated seeded generator makes the train/validation partition identical
# on every run without touching torch's global RNG state.
train_data_orig, val_data_orig = torch.utils.data.random_split(
    dataset_orig,
    [len(dataset_orig) - n_val_orig, n_val_orig],
    generator=torch.Generator().manual_seed(42),
)

loader_train_orig = DataLoader(train_data_orig, batch_size=32, shuffle=True)
loader_val_orig   = DataLoader(val_data_orig,   batch_size=32, shuffle=False)

print(f"Original - Training samples:   {len(train_data_orig)}")
print(f"Original - Validation samples: {len(val_data_orig)}")

In [None]:
# ============================================================
# Step 3: Define TAE Model Architecture
#   encoder: 6 -> 12 -> 12 -> 2 (Tanh between layers and on the output)
#   decoder: 2 -> 12 -> 12 -> 6 (encoder layer sizes reversed, no output Tanh)
# ============================================================
units = [6, 12, 12, 2]

def create_tae_model():
    """Build a fresh TAE estimator from the layer sizes in ``units``.

    The encoder and decoder are MLPs with Tanh nonlinearities and no initial
    batch normalisation; only the encoder applies Tanh to its output. The
    estimator is configured with the Adam optimizer and a learning rate of 1e-4.

    Returns:
        A new, unfitted ``TAE`` instance.
    """
    shared_kwargs = dict(nonlinearity=torch.nn.Tanh, initial_batchnorm=False)
    encoder_net = MLP(units, output_nonlinearity=torch.nn.Tanh, **shared_kwargs)
    # Decoder mirrors the encoder: [2, 12, 12, 6]
    decoder_net = MLP(list(reversed(units)), **shared_kwargs)
    return TAE(encoder_net, decoder_net, optimizer='Adam', learning_rate=1e-4)

In [None]:
# ============================================================
# Step 4a: Train TAE Model (Wavelet)
# ============================================================
print("\n" + "="*80)
print("Training TAE Model - Wavelet Data")
print("="*80)

# Seed torch's global RNG so that weight initialisation and the mini-batch
# shuffling of the training loader are repeatable across kernel restarts.
torch.manual_seed(42)
tae_wt = create_tae_model()
tae_wt.fit(loader_train_wt, validation_loader=loader_val_wt, n_epochs=100)
print("Wavelet TAE training completed.")

# --- plot training & validation loss ---
# The loss arrays are unpacked as (x, y) rows. deeptime records the x values
# as optimisation steps rather than epochs, hence the axis label.
plt.figure(figsize=(10, 6))
plt.semilogy(*tae_wt.train_losses.T, label='Train Loss (Wavelet)')
plt.semilogy(*tae_wt.validation_losses.T, label='Validation Loss (Wavelet)')
plt.xlabel('Training step')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.title('TAE Training & Validation Loss (Wavelet)')
plt.savefig('tae_training_loss_wavelet.png', dpi=300, bbox_inches='tight')
plt.show()

# --- save training losses ---
np.save('wt3-150/wt3-tae/tae-new/wt3-60/train_losses.npy', np.array(tae_wt.train_losses))
np.save('wt3-150/wt3-tae/tae-new/wt3-60/val_losses.npy', np.array(tae_wt.validation_losses))

# --- save model ---
# torch.save serialises the whole model object (not just a state dict), so
# loading it back needs the same classes to be importable.
tae_model_wt = tae_wt.fetch_model()
torch.save(tae_model_wt, 'wt3-150/wt3-tae/tae-new/wt3-60/tae_model.pth')
print("Wavelet model saved to: wt3-150/wt3-tae/tae-new/wt3-60/tae_model.pth")

In [None]:
# ============================================================
# Step 4b: Train TAE Model (Original)
# ============================================================
print("\n" + "="*80)
print("Training TAE Model - Original Data")
print("="*80)

# Seed torch's global RNG so that weight initialisation and the mini-batch
# shuffling of the training loader are repeatable across kernel restarts.
torch.manual_seed(42)
tae_orig = create_tae_model()
tae_orig.fit(loader_train_orig, validation_loader=loader_val_orig, n_epochs=100)
print("Original TAE training completed.")

# --- plot training & validation loss ---
# The loss arrays are unpacked as (x, y) rows. deeptime records the x values
# as optimisation steps rather than epochs, hence the axis label.
plt.figure(figsize=(10, 6))
plt.semilogy(*tae_orig.train_losses.T, label='Train Loss (Original)')
plt.semilogy(*tae_orig.validation_losses.T, label='Validation Loss (Original)')
plt.xlabel('Training step')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.title('TAE Training & Validation Loss (Original)')
plt.savefig('tae_training_loss_original.png', dpi=300, bbox_inches='tight')
plt.show()

# --- save training losses ---
np.save('wt3-150/wt3-tae/tae-new/orin-60/train_losses.npy', np.array(tae_orig.train_losses))
np.save('wt3-150/wt3-tae/tae-new/orin-60/val_losses.npy', np.array(tae_orig.validation_losses))

# --- save model ---
# torch.save serialises the whole model object (not just a state dict), so
# loading it back needs the same classes to be importable.
tae_model_orig = tae_orig.fetch_model()
torch.save(tae_model_orig, 'wt3-150/wt3-tae/tae-new/orin-60/tae_model.pth')
print("Original model saved to: wt3-150/wt3-tae/tae-new/orin-60/tae_model.pth")


In [None]:
# ============================================================
# Step 5: Transform Data to TAE Latent Space
# ============================================================
def _plot_latent(points, title, out_png):
    """Scatter the first two latent coordinates, save the figure and show it."""
    latent_fig, latent_ax = plt.subplots(figsize=(8, 6))
    latent_ax.scatter(points[:, 0], points[:, 1], s=0.1, alpha=0.1)
    latent_ax.set_xlabel('dCV1', fontsize=16)
    latent_ax.set_ylabel('dCV2', fontsize=16)
    latent_ax.set_title(title, fontsize=16, fontweight='bold')
    latent_fig.savefig(out_png, dpi=300, bbox_inches='tight')
    plt.show()

# --- wavelet ---
dcv_wt = tae_model_wt.transform(data_wt)
print(f"\nWavelet TAE latent space shape: {dcv_wt.shape}")
_plot_latent(dcv_wt, 'TAE Latent Space (Wavelet)', 'tae_latent_space_wavelet.png')

# --- original ---
dcv_orig = tae_model_orig.transform(data_orig)
print(f"Original TAE latent space shape: {dcv_orig.shape}")
_plot_latent(dcv_orig, 'TAE Latent Space (Original)', 'tae_latent_space_original.png')


In [None]:
# ============================================================
# Step 6a: HDBSCAN Clustering on TAE Latent Space (Wavelet)
# ============================================================
print("\n" + "="*80)
print("HDBSCAN Clustering - Wavelet TAE")
print("="*80)

cluster_wt = hdbscan.HDBSCAN(
    min_cluster_size=600,
    min_samples=100,
    core_dist_n_jobs=6,
    cluster_selection_method='eom',
    gen_min_span_tree=True
)
cluster_wt.fit(dcv_wt)

# Label -1 marks noise; mask_wt selects the points assigned to a cluster.
labels_wt = cluster_wt.labels_
mask_wt = labels_wt != -1

# One pass over the labels yields the sorted cluster ids with their sizes,
# so the summary is printed in a deterministic order.
cluster_ids_wt, cluster_sizes_wt = np.unique(labels_wt[mask_wt], return_counts=True)
n_clusters_wt = len(cluster_ids_wt)
n_noise_wt = np.sum(labels_wt == -1)
print(f"Total points:          {len(labels_wt)}")
print(f"Noise points (-1):     {n_noise_wt} ({n_noise_wt/len(labels_wt)*100:.2f}%)")
print(f"Clustered points:      {np.sum(mask_wt)} ({np.sum(mask_wt)/len(labels_wt)*100:.2f}%)")
print(f"Number of clusters:    {n_clusters_wt}")
for k, count in zip(cluster_ids_wt, cluster_sizes_wt):
    print(f"  Cluster {k}: {count} points ({count/len(labels_wt)*100:.2f}%)")

# --- plot: Wavelet TAE latent space with clustering ---
plt.figure(figsize=(8, 6))
# Noise points (gray)
plt.scatter(dcv_wt[~mask_wt, 0], dcv_wt[~mask_wt, 1],
            c='lightgray', s=1, alpha=0.3, label='Noise')
# Cluster points (colored by label)
if mask_wt.sum() > 0:
    plt.scatter(dcv_wt[mask_wt, 0], dcv_wt[mask_wt, 1],
                c=labels_wt[mask_wt], s=1, alpha=0.8, cmap='tab20')
    plt.colorbar(label='Cluster Label')
# Render the 'Noise' label; a fixed location avoids the slow 'best' placement
# search on large scatters, and markerscale enlarges the 1 pt marker.
plt.legend(loc='upper right', markerscale=10)
plt.xlabel('dCV1', fontsize=16)
plt.ylabel('dCV2', fontsize=16)
plt.title(f'Wavelet TAE + HDBSCAN (n_clusters={n_clusters_wt})', fontsize=16, fontweight='bold')
plt.savefig('tae_hdbscan_clustering_wavelet.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# ============================================================
# Step 6b: HDBSCAN Clustering on TAE Latent Space (Original)
# ============================================================
print("\n" + "="*80)
print("HDBSCAN Clustering - Original TAE")
print("="*80)

cluster_orig = hdbscan.HDBSCAN(
    min_cluster_size=600,
    min_samples=100,
    core_dist_n_jobs=6,
    cluster_selection_method='eom',
    gen_min_span_tree=True
)
cluster_orig.fit(dcv_orig)

# Label -1 marks noise; mask_orig selects the points assigned to a cluster.
labels_orig = cluster_orig.labels_
mask_orig = labels_orig != -1

# One pass over the labels yields the sorted cluster ids with their sizes,
# so the summary is printed in a deterministic order.
cluster_ids_orig, cluster_sizes_orig = np.unique(labels_orig[mask_orig], return_counts=True)
n_clusters_orig = len(cluster_ids_orig)
n_noise_orig = np.sum(labels_orig == -1)
print(f"Total points:          {len(labels_orig)}")
print(f"Noise points (-1):     {n_noise_orig} ({n_noise_orig/len(labels_orig)*100:.2f}%)")
print(f"Clustered points:      {np.sum(mask_orig)} ({np.sum(mask_orig)/len(labels_orig)*100:.2f}%)")
print(f"Number of clusters:    {n_clusters_orig}")
for k, count in zip(cluster_ids_orig, cluster_sizes_orig):
    print(f"  Cluster {k}: {count} points ({count/len(labels_orig)*100:.2f}%)")

# --- plot: Original TAE latent space with clustering ---
plt.figure(figsize=(8, 6))
# Noise points (gray)
plt.scatter(dcv_orig[~mask_orig, 0], dcv_orig[~mask_orig, 1],
            c='lightgray', s=1, alpha=0.3, label='Noise')
# Cluster points (colored by label)
if mask_orig.sum() > 0:
    plt.scatter(dcv_orig[mask_orig, 0], dcv_orig[mask_orig, 1],
                c=labels_orig[mask_orig], s=1, alpha=0.8, cmap='tab20')
    plt.colorbar(label='Cluster Label')
# Render the 'Noise' label; a fixed location avoids the slow 'best' placement
# search on large scatters, and markerscale enlarges the 1 pt marker.
plt.legend(loc='upper right', markerscale=10)
plt.xlabel('dCV1', fontsize=16)
plt.ylabel('dCV2', fontsize=16)
plt.title(f'Original TAE + HDBSCAN (n_clusters={n_clusters_orig})', fontsize=16, fontweight='bold')
plt.savefig('tae_hdbscan_clustering_original.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# ============================================================
# Step 7: Project Clustering Results onto Dihedral Angle Space
# ============================================================
# The wavelet (7a) and original (7b) figures share one three-panel layout,
# so both are drawn by the same loop; only mask, labels, texts and output
# file differ between them.

# (x data, y data, x label, y label, panel title) for the three panels
phi_panels = (
    (phi1, phi2, 'φ1 (degrees)', 'φ2 (degrees)', 'φ1 vs φ2'),
    (phi1, phi3, 'φ1 (degrees)', 'φ3 (degrees)', 'φ1 vs φ3'),
    (phi2, phi3, 'φ2 (degrees)', 'φ3 (degrees)', 'φ2 vs φ3'),
)

# (progress message, cluster mask, labels, figure title, output file)
phi_figures = (
    ("\nPlotting Wavelet TAE clusters in phi space...",
     mask_wt, labels_wt,
     'Wavelet TAE Clustering in Dihedral Angle Space',
     'tae_hdbscan_phi_space_wavelet.png'),
    ("Plotting Original TAE clusters in phi space...",
     mask_orig, labels_orig,
     'Original TAE Clustering in Dihedral Angle Space',
     'tae_hdbscan_phi_space_original.png'),
)

for progress_msg, fig_mask, fig_labels, fig_title, fig_file in phi_figures:
    print(progress_msg)
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    has_clusters = fig_mask.sum() > 0

    for panel_idx, (x_ang, y_ang, x_text, y_text, panel_title) in enumerate(phi_panels):
        ax = axes[panel_idx]
        # Noise points first (gray), clustered points on top (colored by label)
        ax.scatter(x_ang[~fig_mask], y_ang[~fig_mask], c='lightgray', s=1, alpha=0.3)
        if has_clusters:
            sc = ax.scatter(x_ang[fig_mask], y_ang[fig_mask], c=fig_labels[fig_mask],
                            s=1, alpha=0.8, cmap='tab20')
            # Only the right-most panel carries the colorbar
            if panel_idx == len(phi_panels) - 1:
                plt.colorbar(sc, ax=ax, label='Cluster Label')
        ax.set_xlabel(x_text, fontsize=16)
        ax.set_ylabel(y_text, fontsize=16)
        ax.set_xlim(-180, 180)
        ax.set_ylim(-180, 180)
        ax.set_title(panel_title, fontsize=14, fontweight='bold')

    plt.suptitle(fig_title, fontsize=18, fontweight='bold')
    plt.tight_layout()
    plt.savefig(fig_file, dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
# ============================================================
# Step 8: Save Clustering Results
# ============================================================
def _write_columns(path, columns, header=''):
    """Write 1-D arrays side by side as a tab-separated text file.

    An empty header writes no header line; a non-empty one is written
    verbatim (no comment prefix) as the first line.
    """
    np.savetxt(path, np.column_stack(columns), delimiter='\t', header=header, comments='')

print("\n" + "="*80)
print("Saving Clustering Results")
print("="*80)

# --- Wavelet TAE results ---
noise_wt = ~mask_wt
latent_noise_wt = (dcv_wt[noise_wt, 0], dcv_wt[noise_wt, 1])
latent_clusters_wt = (dcv_wt[mask_wt, 0], dcv_wt[mask_wt, 1], labels_wt[mask_wt])

# TAE space
_write_columns('tae-noise/wt3_noise_tae.txt', latent_noise_wt, header='dCV1\tdCV2')
_write_columns('tae-noise/wt3_clusters_tae.txt', latent_clusters_wt,
               header='dCV1\tdCV2\tCluster_Label')
_write_columns('wt3-150/wt3-tae/tae-new/wt3-60/wt3_cluster-tae-60.txt', latent_clusters_wt)

# phi space
_write_columns('tae-noise/wt3_noise_phi23.txt',
               (phi2[noise_wt], phi3[noise_wt]), header='phi2\tphi3')
_write_columns('tae-noise/wt3_clusters_phi23.txt',
               (phi2[mask_wt], phi3[mask_wt], labels_wt[mask_wt]),
               header='phi2\tphi3\tCluster_Label')
_write_columns('wt3-150/wt3-tae/tae-new/wt3-60/tri-wt3-cluster-phi13.txt',
               (phi1[mask_wt], phi3[mask_wt], labels_wt[mask_wt]))

print("Wavelet TAE clustering results saved.")

# --- Original TAE results ---
noise_orig = ~mask_orig
latent_noise_orig = (dcv_orig[noise_orig, 0], dcv_orig[noise_orig, 1])
latent_clusters_orig = (dcv_orig[mask_orig, 0], dcv_orig[mask_orig, 1], labels_orig[mask_orig])

# TAE space
_write_columns('tae-noise/orin_noise_tae.txt', latent_noise_orig, header='dCV1\tdCV2')
_write_columns('tae-noise/orin_clusters_tae.txt', latent_clusters_orig,
               header='dCV1\tdCV2\tCluster_Label')
_write_columns('wt3-150/wt3-tae/tae-new/orin-60/orin_cluster-tae-60.txt', latent_clusters_orig)

# phi space
_write_columns('tae-noise/orin_noise_phi23.txt',
               (phi2[noise_orig], phi3[noise_orig]), header='phi2\tphi3')
_write_columns('tae-noise/orin_clusters_phi23.txt',
               (phi2[mask_orig], phi3[mask_orig], labels_orig[mask_orig]),
               header='phi2\tphi3\tCluster_Label')
_write_columns('wt3-150/wt3-tae/tae-new/orin-60/orin_cluster-phi13.txt',
               (phi1[mask_orig], phi3[mask_orig], labels_orig[mask_orig]))

print("Original TAE clustering results saved.")

In [None]:
# ============================================================
# Step 9: Save Model Parameters as Text Files
# ============================================================
def save_model_parameters(model, save_path, name):
    """Export every encoder/decoder parameter of ``model`` to a text file.

    One file is written per parameter, named
    ``<encoder|decoder>_<param name with dots as underscores>_<weights|biases>.txt``
    where the suffix is ``weights`` if the parameter name contains "weight"
    and ``biases`` otherwise.

    Args:
        model: Object exposing ``encoder`` and ``decoder`` attributes that
            provide ``named_parameters()`` (e.g. torch modules).
        save_path: Target directory; created if it does not exist yet. An
            empty string writes into the current working directory.
        name: Human-readable model name, used only in the progress message.
    """
    def save_params(module, module_name):
        for param_name, param in module.named_parameters():
            param_array = param.detach().cpu().numpy()
            # np.savetxt only accepts 1-D or 2-D input: promote scalars to a
            # single-element vector and flatten trailing axes of higher-rank
            # tensors, keeping the leading axis as rows.
            if param_array.ndim == 0:
                param_array = param_array.reshape(1)
            elif param_array.ndim > 2:
                param_array = param_array.reshape(param_array.shape[0], -1)
            param_type = 'weights' if 'weight' in param_name else 'biases'
            filename = f'{module_name}_{param_name.replace(".", "_")}_{param_type}.txt'
            np.savetxt(os.path.join(save_path, filename), param_array)

    # Make the function usable on its own, not only after the directories
    # were created elsewhere in the notebook.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    print(f"\nExtracting {name} model parameters:")
    save_params(model.encoder, "encoder")
    save_params(model.decoder, "decoder")
    print(f"  Saved to: {save_path}")

# Export both trained models, wavelet first, into their result directories.
for export_model, export_dir, export_tag in (
    (tae_model_wt, 'wt3-150/wt3-tae/tae-new/wt3-60/', "Wavelet"),
    (tae_model_orig, 'wt3-150/wt3-tae/tae-new/orin-60/', "Original"),
):
    save_model_parameters(export_model, export_dir, export_tag)

# Closing banner
final_rule = "=" * 80
print("\n" + final_rule)
print("All tasks completed successfully!")
print(final_rule)