In [25]:
# PCA -> GMM (run on first N PCs) + visualization (first 2 PCs)
import os
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from sklearn.decomposition import PCA
from sklearn.metrics import f1_score, classification_report
from sklearn.mixture import GaussianMixture
from sklearn import metrics
# ----------------- CONFIG -----------------
DATA_DIR = Path("desc_data_no_pca")
X_path = os.path.join(DATA_DIR, "X_scaled_hvg.npy")
meta_csv = os.path.join(DATA_DIR, "metadata_subset.csv")
SAMPLE_NAMES_PATH = None  # optional alignment file

OUT_DIR = "pca_gmm_results"
os.makedirs(OUT_DIR, exist_ok=True)

random_state = 42
pca_n_components_for_clustering = 200  # <-- use more PCs for GMM
pca_n_components_for_plot = 2         # <-- for plotting

# ----------------- Load data -----------------
print("Loading X...")
X = np.load(X_path)
n_samples, input_dim = X.shape
print(f"X: {n_samples:,} cells × {input_dim:,} features")

# ----------------- Load metadata -----------------
if not os.path.exists(meta_csv):
    raise FileNotFoundError(f"Metadata CSV not found: {meta_csv}")
meta = pd.read_csv(meta_csv)
print(f"Loaded metadata CSV: {len(meta)} rows")

# Align metadata if SAMPLE_NAMES_PATH provided
if SAMPLE_NAMES_PATH and os.path.exists(SAMPLE_NAMES_PATH):
    if SAMPLE_NAMES_PATH.endswith(".npy"):
        sample_names = list(np.load(SAMPLE_NAMES_PATH))
    else:
        with open(SAMPLE_NAMES_PATH) as f:
            sample_names = [line.strip() for line in f if line.strip()]
    if len(sample_names) != n_samples:
        raise ValueError(f"SAMPLE_NAMES length ({len(sample_names)}) != X rows ({n_samples})")
    meta_indexed = meta.set_index('sample_name')
    meta = meta_indexed.loc[sample_names].reset_index()
    print("Metadata aligned using SAMPLE_NAMES_PATH.")

# ----------------- Determine k -----------------
if 'subclass_label' in meta.columns:
    le = LabelEncoder()
    true_labels = meta['subclass_label'].astype(str).values
    y_int = le.fit_transform(true_labels)
    k = len(le.classes_)
    print(f"Detected {k} subclasses (using k = {k} for KMeans).")
else:
    true_labels = None
    y_int = None
    k = 10
    print("No 'subclass_label' found; using default k = 10.")

# ----------------- PCA -----------------
print(f"Running PCA ({pca_n_components_for_clustering} PCs for clustering)...")
pca = PCA(n_components=pca_n_components_for_clustering, random_state=random_state)
Z_full = pca.fit_transform(X)
print(f"PCA done. Shape for clustering: {Z_full.shape}")
print(f"Explained variance ratio (first {pca_n_components_for_plot} PCs): {pca.explained_variance_ratio_[:pca_n_components_for_plot].sum():.4f}")

# Save PCA coords for all components
pca_out = os.path.join(OUT_DIR, "pca_coords_full.npy")
np.save(pca_out, Z_full)


gmm = GaussianMixture(
    n_components=k,
    covariance_type="diag",
    max_iter=300,
    init_params="kmeans",
)
pred_labels = gmm.fit_predict(Z_full)

ari = metrics.adjusted_rand_score(true_labels, pred_labels)
nmi = metrics.normalized_mutual_info_score(true_labels, pred_labels)
hom = metrics.homogeneity_score(true_labels, pred_labels)
comp = metrics.completeness_score(true_labels, pred_labels)
print(f" GMM ari: {ari:.4f}")
print(f" GMM nmi: {nmi:.4f}")
print(f" GMM final loss: {gmm.lower_bound_:.4f}")

labels_out = os.path.join(OUT_DIR, "gmm_labels_on_pca.npy")
np.save(labels_out, pred_labels)
print("Saved GMM labels.")

# ----------------- Macro/Micro F1 -----------------
if true_labels is not None:
    df = pd.DataFrame({"cluster": pred_labels, "true": true_labels})
    cluster_to_label = df.groupby("cluster")["true"].agg(lambda x: x.value_counts().index[0]).to_dict()
    predicted_labels = np.array([cluster_to_label[c] for c in pred_labels])
    le2 = LabelEncoder()
    y_true_int = le2.fit_transform(true_labels)
    y_pred_int = le2.transform(predicted_labels)
    macro_f1 = f1_score(y_true_int, y_pred_int, average="macro")
    micro_f1 = f1_score(y_true_int, y_pred_int, average="micro")
    print(f"Macro-F1: {macro_f1:.4f}, Micro-F1: {micro_f1:.4f}")
    print("\nClassification report:")
    print(classification_report(y_true_int, y_pred_int, digits=4))

# ----------------- Plot first 2 PCs -----------------
# Z_plot = Z_full[:, :pca_n_components_for_plot]
# fig, axes = plt.subplots(1, 2, figsize=(14, 6), constrained_layout=True)

# ax = axes[0]
# ax.scatter(Z_plot[:,0], Z_plot[:,1], c=pred_labels, s=6, cmap='tab20', linewidth=0, alpha=0.8)
# ax.set_title(f"PCA (GMM k={k})")
# ax.set_xlabel("PC1"); ax.set_ylabel("PC2")

# ax = axes[1]
# if true_labels is not None:
#     n_true = len(np.unique(y_int))
#     ax.scatter(Z_plot[:,0], Z_plot[:,1], c=y_int, s=6, cmap='tab20', linewidth=0, alpha=0.8)
#     ax.set_title(f"PCA (ground truth: {n_true} subclasses)")
# else:
#     ax.text(0.5,0.5,"No ground truth", ha='center', va='center')
#     ax.set_title("Ground truth missing")
# ax.set_xlabel("PC1"); ax.set_ylabel("PC2")

# plt_path = os.path.join(OUT_DIR, "pca_gmm_vs_truth.png")
# plt.savefig(plt_path, dpi=150)
# plt.show()
# print("Saved PCA figure.")


Loading X...
X: 100,000 cells × 2,335 features
Loaded metadata CSV: 100000 rows
Detected 42 subclasses (using k = 42 for KMeans).
Running PCA (200 PCs for clustering)...
PCA done. Shape for clustering: (100000, 200)
Explained variance ratio (first 2 PCs): 0.0298
 GMM ari: 0.3250
 GMM nmi: 0.6840
 GMM final loss: -300.5599
Saved GMM labels.
Macro-F1: 0.4402, Micro-F1: 0.8177

Classification report:
              precision    recall  f1-score   support

           0     0.9905    0.9874    0.9889       317
           1     0.9081    0.7780    0.8380      1410
           2     0.0000    0.0000    0.0000        23
           3     0.0000    0.0000    0.0000       140
           4     0.0000    0.0000    0.0000        20
           5     0.0000    0.0000    0.0000       492
           6     0.9930    0.9539    0.9730      1929
           7     0.9996    0.9761    0.9877      5013
           8     0.0000    0.0000    0.0000        67
           9     0.0000    0.0000    0.0000       381
    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [21]:
# Visualize clusters vs. true labels with t-SNE (also save figures)
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from pathlib import Path

tsne = TSNE(n_components=2, perplexity=30, init='pca', learning_rate='auto')
X_tsne = tsne.fit_transform(Z_full)

true_cat = pd.Categorical(true_labels)
true_codes = true_cat.codes
n_true = len(true_cat.categories)
n_pred = pred_labels.max() + 1

fig, axs = plt.subplots(1, 2, figsize=(14, 6), dpi=120, sharex=True, sharey=True)
cmap = plt.cm.get_cmap('tab20', max(n_true, n_pred))

axs[0].scatter(X_tsne[:, 0], X_tsne[:, 1], c=true_codes, s=5, cmap=cmap, alpha=0.7)
axs[0].set_title('True labels (t-SNE)')
axs[0].set_xlabel('t-SNE 1')
axs[0].set_ylabel('t-SNE 2')

axs[1].scatter(X_tsne[:, 0], X_tsne[:, 1], c=pred_labels, s=5, cmap=cmap, alpha=0.7)
axs[1].set_title('GMM clusters (t-SNE)')
axs[1].set_xlabel('t-SNE 1')
axs[1].set_ylabel('t-SNE 2')

plt.tight_layout()
plots_dir = Path('tsne_plots')
plots_dir.mkdir(exist_ok=True)
fig_path = plots_dir / 'tsne_true_vs_gmm.png'
plt.savefig(fig_path, dpi=200)
print(f'Saved t-SNE comparison plot to {fig_path}')
plt.show()


KeyboardInterrupt: 

In [27]:
# GMM on data with additive Gaussian noise (mean=1, var=1) in PCA space
noise = np.random.normal(loc=1.0, scale=1.0, size=Z_full.shape)
Z_noisy = Z_full + noise

gmm_noisy = GaussianMixture(
    n_components=k,
    covariance_type='diag',
    random_state=random_state,
    max_iter=300,
    init_params='kmeans',
    reg_covar=1e-3,
)
pred_noisy = gmm_noisy.fit_predict(Z_noisy)

nmi_noisy = metrics.normalized_mutual_info_score(true_labels, pred_noisy)
ari_noisy = metrics.adjusted_rand_score(true_labels, pred_noisy)
hom_noisy = metrics.homogeneity_score(true_labels, pred_noisy)
comp_noisy = metrics.completeness_score(true_labels, pred_noisy)
sil_noisy = None
try:
    sil_noisy = metrics.silhouette_score(Z_noisy, pred_noisy, sample_size=min(2000, len(pred_noisy)), random_state=random_state)
except Exception as e:
    print(f'Silhouette not computed (noisy): {e}')

print('GMM on noisy PCA data metrics:')
print(f'  ARI: {ari_noisy:.4f}')
print(f'  NMI: {nmi_noisy:.4f}')
print(f'  Homogeneity: {hom_noisy:.4f}')
print(f'  Completeness: {comp_noisy:.4f}')
print(f" GMM noisy final loss: {gmm_noisy.lower_bound_:.4f}")
if sil_noisy is not None:
    print(f'  Silhouette: {sil_noisy:.4f}')
print('Cluster sizes (noisy):')
print(pd.Series(pred_noisy).value_counts().sort_index())

# t-SNE visualization for noisy data
# true_cat = pd.Categorical(true_labels)
# true_codes = true_cat.codes
# n_true = len(true_cat.categories)
# n_pred_noisy = pred_noisy.max() + 1
# tsne_noisy = TSNE(n_components=2, perplexity=30, random_state=random_state, init='pca', learning_rate='auto')
# X_tsne_noisy = tsne_noisy.fit_transform(Z_noisy)

# fig, axs = plt.subplots(1, 2, figsize=(14, 6), dpi=120, sharex=True, sharey=True)
# cmap_noisy = plt.cm.get_cmap('tab20', max(n_true, n_pred_noisy))
# axs[0].scatter(X_tsne_noisy[:, 0], X_tsne_noisy[:, 1], c=true_codes, s=5, cmap=cmap_noisy, alpha=0.7)
# axs[0].set_title('True labels (noisy t-SNE)')
# axs[0].set_xlabel('t-SNE 1')
# axs[0].set_ylabel('t-SNE 2')

# axs[1].scatter(X_tsne_noisy[:, 0], X_tsne_noisy[:, 1], c=pred_noisy, s=5, cmap=cmap_noisy, alpha=0.7)
# axs[1].set_title('GMM clusters (noisy t-SNE)')
# axs[1].set_xlabel('t-SNE 1')
# axs[1].set_ylabel('t-SNE 2')

# plt.tight_layout()
# plots_dir = Path('tsne_plots')
# plots_dir.mkdir(exist_ok=True)
# noisy_fig_path = plots_dir / 'tsne_noisy_true_vs_gmm.png'
# plt.savefig(noisy_fig_path, dpi=200)
# print(f'Saved noisy t-SNE comparison plot to {noisy_fig_path}')
# plt.show()


GMM on noisy PCA data metrics:
  ARI: 0.3167
  NMI: 0.6558
  Homogeneity: 0.7572
  Completeness: 0.5783
 GMM noisy final loss: -363.9682
  Silhouette: 0.0117
Cluster sizes (noisy):
0     1576
1     2754
2     5146
3     3147
4     6321
5     3714
6      988
7     3484
8     1146
9     2518
10    2212
11    1215
12    1375
13    1364
14    1923
15    5252
16    1895
17    2733
18    3334
19    3650
20    1634
21    1975
22    1820
23    1884
24    1606
25     831
26    1863
27    1104
28    1088
29    2672
30    2199
31    1485
32    2607
33    3295
34     459
35    1594
36     311
37    3079
38    4317
39    1992
40    3859
41    2579
Name: count, dtype: int64


In [28]:
# GMM with partially masked Gaussian noise (50% zeros)
noise = np.random.normal(loc=1.0, scale=1.0, size=Z_full.shape)
mask = np.random.rand(*noise.shape) < 0.5
noise[~mask] = 0.0
Z_partial = Z_full + noise

gmm_partial = GaussianMixture(
    n_components=k,
    covariance_type='diag',
    random_state=random_state,
    max_iter=300,
    init_params='kmeans',
    reg_covar=1e-3,
)
pred_partial = gmm_partial.fit_predict(Z_partial)

nmi_partial = metrics.normalized_mutual_info_score(true_labels, pred_partial)
ari_partial = metrics.adjusted_rand_score(true_labels, pred_partial)
hom_partial = metrics.homogeneity_score(true_labels, pred_partial)
comp_partial = metrics.completeness_score(true_labels, pred_partial)
sil_partial = None
try:
    sil_partial = metrics.silhouette_score(Z_partial, pred_partial, sample_size=min(2000, len(pred_partial)), random_state=random_state)
except Exception as e:
    print(f'Silhouette not computed (partial noise): {e}')

print('GMM on partially noisy PCA data metrics:')
print(f'  ARI: {ari_partial:.4f}')
print(f'  NMI: {nmi_partial:.4f}')
print(f'  Homogeneity: {hom_partial:.4f}')
print(f'  Completeness: {comp_partial:.4f}')
print(f" GMM partial noisy final loss: {gmm_partial.lower_bound_:.4f}")
if sil_partial is not None:
    print(f'  Silhouette: {sil_partial:.4f}')
print('Cluster sizes (partial noise):')
print(pd.Series(pred_partial).value_counts().sort_index())

# t-SNE visualization for partially noisy data
# true_cat = pd.Categorical(true_labels)
# true_codes = true_cat.codes
# n_true = len(true_cat.categories)
# n_pred_partial = pred_partial.max() + 1
# tsne_partial = TSNE(n_components=2, perplexity=30, random_state=random_state, init='pca', learning_rate='auto')
# X_tsne_partial = tsne_partial.fit_transform(Z_partial)

# fig, axs = plt.subplots(1, 2, figsize=(14, 6), dpi=120, sharex=True, sharey=True)
# cmap_partial = plt.cm.get_cmap('tab20', max(n_true, n_pred_partial))
# axs[0].scatter(X_tsne_partial[:, 0], X_tsne_partial[:, 1], c=true_codes, s=5, cmap=cmap_partial, alpha=0.7)
# axs[0].set_title('True labels (partial-noise t-SNE)')
# axs[0].set_xlabel('t-SNE 1')
# axs[0].set_ylabel('t-SNE 2')

# axs[1].scatter(X_tsne_partial[:, 0], X_tsne_partial[:, 1], c=pred_partial, s=5, cmap=cmap_partial, alpha=0.7)
# axs[1].set_title('GMM clusters (partial-noise t-SNE)')
# axs[1].set_xlabel('t-SNE 1')
# axs[1].set_ylabel('t-SNE 2')

# plt.tight_layout()
# plots_dir = Path('tsne_plots')
# plots_dir.mkdir(exist_ok=True)
# partial_fig_path = plots_dir / 'tsne_partial_noise_true_vs_gmm.png'
# plt.savefig(partial_fig_path, dpi=200)
# print(f'Saved partial-noise t-SNE comparison plot to {partial_fig_path}')
# plt.show()


GMM on partially noisy PCA data metrics:
  ARI: 0.3020
  NMI: 0.6549
  Homogeneity: 0.7564
  Completeness: 0.5774
 GMM partial noisy final loss: -351.6572
  Silhouette: 0.0098
Cluster sizes (partial noise):
0     3552
1     3336
2     5377
3     3977
4     2680
5     2522
6     1106
7     1162
8     4918
9      643
10    2180
11    2494
12    1896
13    3922
14     999
15     889
16    2835
17    1308
18    1225
19    3169
20    1152
21    1808
22    2243
23    1825
24    1802
25    5275
26    2425
27    2684
28    1403
29    1574
30    3372
31    1425
32    4007
33    1811
34    1459
35    1115
36    1682
37    2262
38    5910
39     801
40    2141
41    1634
Name: count, dtype: int64


In [29]:
## SFAA
import numpy as np
from sklearn.base import clone

def compute_pairwise_cluster_indicator(labels: np.ndarray) -> np.ndarray:
    """
    Build an n x n matrix M where M[i, j] = 1 if labels[i] == labels[j], else 0.
    This corresponds to Y Y^T in the paper.
    """
    labels = np.asarray(labels)
    n = labels.shape[0]
    M = np.zeros((n, n), dtype=np.int8)
    for k in np.unique(labels):
        idx = np.where(labels == k)[0]
        M[np.ix_(idx, idx)] = 1
    return M


def evaluate_spillover(
    X: np.ndarray,
    clusterer,
    base_labels: np.ndarray,
    base_indicator: np.ndarray,
    pivot_idx: int,
    delta: np.ndarray,
):
    """
    Apply a perturbation delta to X[pivot_idx], re-run clustering, and compute:
    - new labels
    - spillover score: number of pairwise label-relationship changes

    Spillover score is: || Y Y^T - Y' Y'^T ||_F^2 (we can use sum of squared diff).
    """
    X_pert = X.copy()
    X_pert[pivot_idx] = X[pivot_idx] + delta

    # Re-fit clusterer (black box)
    cl = clone(clusterer)
    pert_labels = cl.fit_predict(X_pert)

    pert_indicator = compute_pairwise_cluster_indicator(pert_labels)

    diff = base_indicator - pert_indicator
    # Frobenius norm squared of the difference:
    spill = np.sum(diff ** 2)

    return spill, pert_labels, X_pert


def pick_pivot_point(
    X: np.ndarray,
    labels: np.ndarray,
    k1: int,
    k2: int,
):
    """
    Among points in cluster k1, pick the one closest to the centroid of cluster k2.
    This is the 'most boundary-like' point along k1->k2 direction.
    """
    X = np.asarray(X)
    k1_idx = np.where(labels == k1)[0]
    k2_idx = np.where(labels == k2)[0]

    if len(k1_idx) == 0 or len(k2_idx) == 0:
        raise ValueError("One of the clusters is empty; cannot pick pivot.")

    c2 = X[k2_idx].mean(axis=0)

    # index (in k1_idx) of the point closest to c2
    local_best = np.argmin(np.linalg.norm(X[k1_idx] - c2, axis=1))
    pivot_idx = k1_idx[local_best]
    return pivot_idx


def random_delta(d: int, delta_max: np.ndarray) -> np.ndarray:
    """
    Sample a random perturbation in [-delta_max, delta_max] elementwise.
    delta_max can be scalar or vector of length d.
    """
    delta_max = np.asarray(delta_max)
    if delta_max.size == 1:
        delta_max = np.full(d, float(delta_max))
    return np.random.uniform(-delta_max, delta_max, size=d)


def attack_pair(
    X: np.ndarray,
    clusterer,
    k1: int,
    k2: int,
    delta_max,
    n_iters: int = 100,
    n_candidates_per_iter: int = 10,
    random_state: int | None = None,
):
    """
    Perform a one-point black-box attack targeting cluster k1 -> k2.

    Parameters
    ----------
    X : (n, d) array
        Data matrix.
    clusterer : sklearn-like estimator
        Any clustering algorithm with fit_predict(X).
    k1 : int
        Source cluster index (where pivot point starts).
    k2 : int
        Target cluster index.
    delta_max : float or (d,) array
        Max perturbation magnitude per feature (L_infinity ball).
    n_iters : int
        Number of optimization iterations.
    n_candidates_per_iter : int
        Random candidates per iteration.
    random_state : int or None
        Seed for reproducibility.

    Returns
    -------
    best_result : dict
        Contains:
            - 'pivot_idx'
            - 'best_delta'
            - 'best_spill'
            - 'base_labels'
            - 'adv_labels'
            - 'X_adv'
            - 'k1', 'k2'
    """
    rng = np.random.default_rng(random_state)

    # Fit base clustering once
    cl0 = clone(clusterer)
    base_labels = cl0.fit_predict(X)
    base_indicator = compute_pairwise_cluster_indicator(base_labels)

    # Pick pivot point in k1 closest to centroid of k2
    pivot_idx = pick_pivot_point(X, base_labels, k1, k2)
    d = X.shape[1]

    best_spill = -np.inf
    best_delta = np.zeros(d)
    best_labels = base_labels
    best_X_adv = X

    for it in range(n_iters):
        for _ in range(n_candidates_per_iter):
            delta = random_delta(d, delta_max)
            spill, pert_labels, X_pert = evaluate_spillover(
                X, clusterer, base_labels, base_indicator, pivot_idx, delta
            )

            if spill > best_spill:
                best_spill = spill
                best_delta = delta
                best_labels = pert_labels
                best_X_adv = X_pert

    return {
        "pivot_idx": pivot_idx,
        "best_delta": best_delta,
        "best_spill": best_spill,
        "base_labels": base_labels,
        "adv_labels": best_labels,
        "X_adv": best_X_adv,
        "k1": k1,
        "k2": k2,
    }


def attack_best_pair(
    X: np.ndarray,
    clusterer,
    delta_max,
    n_iters_per_pair: int = 50,
    n_candidates_per_iter: int = 10,
    random_state: int | None = None,
):
    """
    Search over all ordered cluster pairs (k1, k2) with k1 != k2 and
    return the best attack.

    This is the K-cluster extension: we don't assume only 2 clusters exist.
    """
    rng = np.random.default_rng(random_state)

    cl0 = clone(clusterer)
    base_labels = cl0.fit_predict(X)
    unique_clusters = np.unique(base_labels)

    best_global = None
    best_global_spill = -np.inf

    for k1 in unique_clusters:
        for k2 in unique_clusters:
            if k1 == k2:
                continue

            try:
                result = attack_pair(
                    X,
                    clusterer,
                    k1=int(k1),
                    k2=int(k2),
                    delta_max=delta_max,
                    n_iters=n_iters_per_pair,
                    n_candidates_per_iter=n_candidates_per_iter,
                    random_state=rng.integers(1e9),
                )
            except ValueError:
                # might happen if cluster is empty for some reason
                continue

            if result["best_spill"] > best_global_spill:
                best_global_spill = result["best_spill"]
                best_global = result

    return best_global


In [None]:
## run SFAA
