In [None]:
import os
os.environ['KERAS_BACKEND'] = 'jax'

import numpy as np
import keras
from umap import UMAP
import hdbscan
import matplotlib.pyplot as plt

# Import custom layers needed to load the encoder
import sys
sys.path.insert(0, '.')
from VAE import Sampling, ClipLayer

In [None]:
# ========== LOAD DATA ==========
# Raw per-sample feature matrix; the report below treats axis 0 as samples and
# axis 1 as features.
data_path = './Data/multimer_frequencies_l5000_shuffled.npy'
data = np.load(data_path)
print(f'Loaded {len(data)} samples with {data.shape[1]} features')

# Move to log-space (the original notes this matches the training preprocessing,
# so keep the cast and epsilon in sync with the training code). The 1e-6 offset
# keeps zero entries finite, since log(0) would be -inf.
X = np.log(np.asarray(data, dtype = np.float32) + 1e-6)
print(f'Log-transformed: Min {X.min():.2f}, Max {X.max():.2f}, Mean {X.mean():.2f}')

In [None]:
# ========== LOAD ENCODER AND GENERATE EMBEDDINGS ==========
# The saved encoder uses the custom Sampling and ClipLayer layers imported from
# VAE.py, so Keras must be given them to deserialize the model.
encoder = keras.models.load_model(
    'vae_encoder_best.keras',
    custom_objects = dict(Sampling = Sampling, ClipLayer = ClipLayer),
)
print('Encoder loaded successfully')

# The encoder returns three outputs: (z_mean, z_log_var, z). z_mean is used as
# the deterministic embedding (z presumably comes from the stochastic Sampling
# layer). Expected to be 256-D per the saved file name — TODO confirm.
z_mean, z_log_var, z = encoder.predict(X, verbose = 1, batch_size = 4096)
print(f'Embeddings shape: {z_mean.shape}')

In [None]:
# ========== UMAP REDUCTION FOR CLUSTERING ==========
# Reduce the VAE means to 20 dimensions before density-based clustering.
# min_dist = 0.0 lets UMAP pack neighbouring points tightly (favouring clumps
# over visual spread), and random_state fixes the layout for reproducibility.
print('Running UMAP to 20 dimensions for clustering...')
reducer_20d = UMAP(
    metric = 'euclidean',
    n_neighbors = 15,
    n_components = 20,
    min_dist = 0.0,
    random_state = 42,
    verbose = True,
)
embedding_20d = reducer_20d.fit_transform(z_mean)
print(f'UMAP 20D embedding shape: {embedding_20d.shape}')

# ========== HDBSCAN CLUSTERING ==========
print('Running HDBSCAN clustering...')
clusterer = hdbscan.HDBSCAN(
    cluster_selection_method = 'eom',
    metric = 'euclidean',
    min_cluster_size = 15,
    min_samples = None,  # None -> hdbscan falls back to min_cluster_size
)
labels = clusterer.fit_predict(embedding_20d)

# HDBSCAN labels outliers as -1; every other value is a cluster id.
n_noise = np.sum(labels == -1)
n_clusters = np.unique(labels[labels != -1]).size
print(f'Found {n_clusters} clusters')
print(f'Noise points: {n_noise} ({100 * n_noise / len(labels):.1f}%)')

In [None]:
# ========== UMAP TO 2D FOR VISUALIZATION ==========
# A second, independent UMAP fit used only for plotting. The cluster labels come
# from the 20-D embedding, so clusters that separate in 20-D may still overlap
# in this 2-D view. min_dist = 0.1 spreads points out more than the clustering
# projection does.
print('Running UMAP to 2D for visualization...')
reducer_2d = UMAP(
    metric = 'euclidean',
    n_neighbors = 15,
    n_components = 2,
    min_dist = 0.1,
    random_state = 42,
    verbose = True,
)
embedding_2d = reducer_2d.fit_transform(z_mean)
print(f'UMAP 2D embedding shape: {embedding_2d.shape}')

In [None]:
# ========== VISUALIZATION ==========
# Draw the 2-D UMAP layout coloured by the HDBSCAN labels (which were computed
# on the 20-D embedding), then save the figure to vae_clusters.png.
fig, ax = plt.subplots(figsize = (12, 10))

# One colour per cluster; noise is drawn separately in light gray.
# tab20 is a ListedColormap with only 20 entries. Sampling it at more than 20
# points floors several samples onto the same entry, so consecutive cluster ids
# would share identical colours. Beyond 20 clusters, use a continuous colormap
# sampled at n_clusters distinct points instead. Shuffle it with a fixed seed so
# neighbouring ids get contrasting hues while staying reproducible.
unique_labels = sorted(set(labels))
if n_clusters <= 20:
    colors = plt.cm.tab20(np.linspace(0, 1, 20))
else:
    colors = plt.cm.turbo(np.linspace(0, 1, n_clusters))
    colors = colors[np.random.default_rng(42).permutation(n_clusters)]

# Plot noise points first so clusters are drawn on top of them.
noise_mask = labels == -1
if noise_mask.any():
    ax.scatter(
        embedding_2d[noise_mask, 0],
        embedding_2d[noise_mask, 1],
        c = 'lightgray',
        s = 1,
        alpha = 0.3,
        label = f'Noise ({n_noise})'
    )

# Plot each cluster with its own colour.
cluster_ids = [label for label in unique_labels if label != -1]
for i, cluster_id in enumerate(cluster_ids):
    mask = labels == cluster_id
    ax.scatter(
        embedding_2d[mask, 0],
        embedding_2d[mask, 1],
        c = [colors[i % len(colors)]],
        s = 2,
        alpha = 0.6,
        label = f'Cluster {cluster_id} ({mask.sum()})'
    )

ax.set_xlabel('UMAP 1')
ax.set_ylabel('UMAP 2')
ax.set_title(f'VAE Embeddings: {n_clusters} clusters, {n_noise} noise points ({len(labels)} total)')

# A legend with more than 20 entries becomes unreadable, so only show it for
# small cluster counts.
if n_clusters <= 20:
    ax.legend(loc = 'best', markerscale = 3, fontsize = 8)

fig.tight_layout()
fig.savefig('vae_clusters.png', dpi = 300, bbox_inches = 'tight')
plt.show()
print('Saved visualization to vae_clusters.png')

In [None]:
# ========== SAVE RESULTS ==========
# Persist every intermediate product as .npy files in the working directory.
# A single ordered mapping drives both the saves and the report, so the
# printed list always matches what was actually written.
saved_outputs = {
    'vae_embeddings_256d.npy': z_mean,
    'umap_embedding_20d.npy': embedding_20d,
    'umap_embedding_2d.npy': embedding_2d,
    'cluster_labels.npy': labels,
}
for file_name, arr in saved_outputs.items():
    np.save(file_name, arr)

print('Saved:')
for file_name, arr in saved_outputs.items():
    print(f'  {file_name} - {arr.shape}')