In [None]:
import numpy as np
import polars as pl
import seaborn as sns
import colorcet as cc
import matplotlib.pyplot as plt
from umap import UMAP
from pathlib import Path
from sklearn.decomposition import PCA
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import AgglomerativeClustering

# PCA

In [None]:
embedding_cols = [
    "gte_embedding_clean",
    "snowflake_embedding_clean",
]

embed_df = (
    pl.scan_parquet("../emails_with_tokens_and_embeddings.parquet")
    .select(["from", "subject", *embedding_cols])
    .collect()
)

# Number of principal components kept per embedding model; this represents
# ~90% of the variance for the embeddings.
N_PCS = 250
# Depending on data shape and scikit-learn version, PCA's "auto" solver may
# pick a randomized SVD; pinning the seed keeps the components (and all
# downstream clustering/plots) reproducible across Restart & Run All.
PCA_SEED = 0

pcs = {
    c: PCA(n_components=N_PCS, random_state=PCA_SEED).fit_transform(
        embed_df[c].to_numpy()
    )
    for c in embedding_cols
}

In [None]:
save_path = "../emails_with_tokens_and_embeddings_with_clusters.parquet"

# `cluster_labels` maps each embedding column name -> numpy array of cluster
# ids (one per row of `embed_df`), in both the compute and the cached branch,
# so downstream cells can index it by model name either way.
if not Path(save_path).exists():
    # takes a long time to run: ~40 minutes
    agg_clustering = AgglomerativeClustering(
        n_clusters=10, metric="cosine", linkage="complete"
    )

    cluster_labels = {}
    for k, v in pcs.items():
        cluster_labels[k] = agg_clustering.fit_predict(v)
        # Persist the labels next to the data so later runs can skip the fit.
        embed_df = embed_df.with_columns(
            pl.Series(f"cluster_labels_{k}", cluster_labels[k])
        )

    embed_df.write_parquet(save_path)
else:
    embed_df = pl.read_parquet(save_path)
    cluster_labels = {
        k: embed_df[f"cluster_labels_{k}"].to_numpy() for k in embedding_cols
    }

In [None]:
def match_clusters(cluster_labels_1, cluster_labels_2):
    """Find the one-to-one cluster correspondence that maximizes overlap.

    Builds the contingency table of samples co-assigned by two clusterings of
    the same rows, then solves the linear sum assignment problem on it so that
    the number of samples whose labels agree after relabeling is maximal.

    Parameters
    ----------
    cluster_labels_1, cluster_labels_2 : array-like
        Cluster labels for the same samples, in the same order (numpy arrays,
        polars/pandas Series or lists). Must have the same length.

    Returns
    -------
    row_ind : np.ndarray
        Matched labels from ``cluster_labels_2``.
    col_ind : np.ndarray
        The ``cluster_labels_1`` label matched to each entry of ``row_ind``,
        so ``dict(zip(row_ind, col_ind))`` maps clustering-2 labels onto
        clustering-1 labels. If the two clusterings have a different number of
        clusters, the surplus labels of the larger one are left unmatched.
    cluster_counts_df : pandas.DataFrame
        Contingency table with index ``cluster_2`` (sorted clustering-2
        labels), one column per clustering-1 label (sorted numerically,
        column names stringified) and co-occurrence counts as values.

    Raises
    ------
    ValueError
        If the two label sequences differ in length.
    """
    import pandas as pd

    labels_1 = np.asarray(cluster_labels_1)
    labels_2 = np.asarray(cluster_labels_2)
    if labels_1.shape != labels_2.shape:
        raise ValueError(
            "cluster_labels_1 and cluster_labels_2 must have the same length"
        )

    # Sorted unique labels plus, for every sample, the position of its label.
    uniq_1, inv_1 = np.unique(labels_1, return_inverse=True)
    uniq_2, inv_2 = np.unique(labels_2, return_inverse=True)
    inv_1 = inv_1.ravel()
    inv_2 = inv_2.ravel()

    # Contingency table: rows = clustering 2, columns = clustering 1.
    # bincount yields a signed integer array, so there is no unsigned
    # wrap-around when the solver is asked to maximize.
    n_rows, n_cols = uniq_2.size, uniq_1.size
    cluster_counts = np.bincount(
        inv_2 * n_cols + inv_1, minlength=n_rows * n_cols
    ).reshape(n_rows, n_cols)

    row_pos, col_pos = linear_sum_assignment(cluster_counts, maximize=True)

    cluster_counts_df = pd.DataFrame(
        cluster_counts,
        index=pd.Index(uniq_2, name="cluster_2"),
        columns=[str(c) for c in uniq_1],
    )
    # Translate matrix positions back to the actual label values.
    return uniq_2[row_pos], uniq_1[col_pos], cluster_counts_df

In [None]:
# Relabel the snowflake clusters so each carries the id of the gte cluster it
# overlaps most, making colors comparable across the panels below.
# Raw labels are read from `embed_df` (which holds both label columns whether
# they were just computed or loaded from cache), so this cell does not depend
# on how `cluster_labels` was built earlier and re-running it starts from the
# unaligned labels each time.
gte_col, snowflake_col = embedding_cols
gte_labels = embed_df[f"cluster_labels_{gte_col}"].to_numpy()
snowflake_labels = embed_df[f"cluster_labels_{snowflake_col}"].to_numpy()

row_ind, col_ind, _df = match_clusters(gte_labels, snowflake_labels)
# snowflake label -> matched gte label
cluster_2_to_1 = dict(zip(row_ind.tolist(), col_ind.tolist()))
cluster_labels = {
    gte_col: gte_labels,
    snowflake_col: np.array(
        [cluster_2_to_1[c] for c in snowflake_labels.tolist()]
    ),
}

In [None]:
title_map = {
    "gte_embedding_clean": "gte-base-en-v1.5",
    "snowflake_embedding_clean": "snowflake-arctic-embed-m-v2.0",
}
fig, axes = plt.subplots(1, 2, figsize=(6, 3))

for ax, (k, v) in zip(axes.flat, cluster_labels.items()):
    _pcs = pcs[k]
    ax.scatter(
        _pcs[:, 0],
        _pcs[:, 1],
        c=v,
        cmap=cc.cm.glasbey,
        # Fixed [0, 255] color range: each integer cluster id maps to the same
        # glasbey entry in both panels, independent of which ids a panel holds.
        vmin=0,
        vmax=255,
        s=1.5,
        lw=0,
        rasterized=True,
    )
    ax.set(xticks=[], yticks=[], title=title_map[k])
    sns.despine(ax=ax, left=True, bottom=True)

    # axis embellishment: short L-shaped axis stubs instead of full spines
    ax.plot([0, 0.2], [0, 0], transform=ax.transAxes, color="k", lw=3)
    ax.plot([0, 0], [0, 0.2], transform=ax.transAxes, color="k", lw=3)
    ax.set_xlabel("PC1", x=0, horizontalalignment="left")
    ax.set_ylabel("PC2", y=0.08, verticalalignment="bottom")

# Make sure the output directory exists on a fresh checkout.
Path("img").mkdir(exist_ok=True)
fig.savefig("img/pca_clusters.png", dpi=300, bbox_inches="tight")

# UMAP

In [None]:
# Project each model's PCA-reduced embeddings to 2-D with UMAP (~ 6 minutes).
# The same estimator is refit for every model, so after this cell `umap_mdl`
# holds the fit for the last key of `pcs`.
# random_state is fixed and n_jobs=1 so the layouts are reproducible.
umap_mdl = UMAP(
    metric="cosine",
    n_components=2,
    n_neighbors=20,
    min_dist=0.5,
    n_jobs=1,
    random_state=0,
)
umap_embedding = {
    model_name: umap_mdl.fit_transform(reduced)
    for model_name, reduced in pcs.items()
}

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(6, 3))

for ax, (k, v) in zip(axes.flat, cluster_labels.items()):
    _umaps = umap_embedding[k]
    ax.scatter(
        _umaps[:, 0],
        _umaps[:, 1],
        c=v,
        cmap=cc.cm.glasbey,
        # Fixed [0, 255] color range: each integer cluster id maps to the same
        # glasbey entry in both panels, independent of which ids a panel holds.
        vmin=0,
        vmax=255,
        s=1.5,
        lw=0,
        rasterized=True,
    )
    ax.set(xticks=[], yticks=[], title=title_map[k])
    sns.despine(ax=ax, left=True, bottom=True)

    # axis embellishment: short L-shaped axis stubs instead of full spines
    ax.plot([0, 0.2], [0, 0], transform=ax.transAxes, color="k", lw=3)
    ax.plot([0, 0], [0, 0.2], transform=ax.transAxes, color="k", lw=3)
    ax.set_xlabel("UMAP1", x=0, horizontalalignment="left")
    ax.set_ylabel("UMAP2", y=0.13, verticalalignment="bottom")

# Make sure the output directory exists on a fresh checkout.
Path("img").mkdir(exist_ok=True)
fig.savefig("img/umap_clusters.png", dpi=300, bbox_inches="tight")