In [1]:
import cupy as cp
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import src.mvae.mt.mvae.utils as utils
import torch
import yaml

from functools import partial
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.geometry.euclidean import Euclidean
from geomstats.geometry.hyperbolic import Hyperbolic
from geomstats.geometry.product_manifold import ProductManifold
from scipy import stats
from scipy.io import mmread
from src.lightning.gene import GeneModule
from src.mvae.mt.data import GeneDataset
from src.mvae.mt.mvae.components import *
from src.mvae.mt.mvae.distributions import *
from src.mvae.mt.mvae.models.gene_vae import GeneVAE
from src.mvae.mt.mvae.ops.hyperbolics import lorentz_to_poincare
from src.mvae.mt.mvae.ops.spherical import spherical_to_projected
from src.rkmeans.rkmeans import RiemannianKMeans, merge_clusters
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [None]:
# Load dataset: parse the YAML experiment config and build the dataset from its "data" options
config_path = "/home/romainlhardy/code/hyperbolic-cancer/configs/cellxgene/cellxgene_e20h20s20.yaml"
with open(config_path, "r") as config_file:
    config = yaml.safe_load(config_file)

dataset_options = config["data"]["options"]
dataset = GeneDataset(**dataset_options)

# Print the dataset's size attributes followed by its number of items, one value per line
for size in (dataset.n_gene_r, dataset.n_gene_p, dataset.n_batch, len(dataset)):
    print(size)

# Shuffled loader; the model-loading cell draws a single batch from it
dataloader = DataLoader(
    dataset,
    batch_size=2048,
    num_workers=16,
    shuffle=True,
)

# Inspect one randomly chosen item (the index comes from NumPy's global RNG, which is not seeded here)
sample_index = np.random.choice(len(dataset))
x_r, x_p, batch_idx = dataset[sample_index]
print(x_r, x_p, batch_idx)
print(x_r.max())

In [None]:
# Load MC-VAE model
checkpoint_path = "/home/romainlhardy/code/hyperbolic-cancer/models/mvae/cellxgene_mvae_e20h20s20.ckpt"

device = "cuda"

# The model options must carry the dataset-derived sizes before the module is constructed
config["lightning"]["model"]["options"]["n_gene_r"] = dataset.n_gene_r
config["lightning"]["model"]["options"]["n_gene_p"] = dataset.n_gene_p
config["lightning"]["model"]["options"]["n_batch"] = dataset.n_batch
module = GeneModule(config).to(device)

if checkpoint_path is not None:
    # map_location remaps the stored tensors onto `device`, so the checkpoint loads even if it
    # was saved from a different device than the one used here
    module.load_state_dict(torch.load(checkpoint_path, map_location=device)["state_dict"])

model = module.model
model.eval()

# Smoke test: one forward pass on a single batch. Gradients are not needed for inference, so
# the pass runs under no_grad to avoid building (and holding on to) an autograd graph.
x_r, x_p, batch_idx = next(iter(dataloader))
with torch.no_grad():
    outputs = model(x_r.to(device), x_p.to(device), batch_idx.to(device))

In [None]:
# Compute latent embeddings
def get_latents(reparametrized, num_components=1):
    """Collect the posterior location of every latent component across batches.

    Args:
        reparametrized: non-empty sequence with one entry per batch; each entry is a
            sequence with one item per latent component, and each item exposes
            `q_z.loc` as a torch tensor.
        num_components: number of latent components to collect. Every batch entry must
            contain at most this many items.

    Returns:
        A list of `num_components` NumPy arrays; the i-th array is the concatenation,
        along axis 0, of the i-th component's `q_z.loc` over all batches.
    """
    assert len(reparametrized) > 0

    per_component = [[] for _ in range(num_components)]
    for batch_outputs in reparametrized:
        for index, component_output in enumerate(batch_outputs):
            # Detach and move to host memory before converting to NumPy
            loc = component_output.q_z.loc
            per_component[index].append(loc.detach().cpu().numpy())

    return [np.concatenate(chunks, axis=0) for chunks in per_component]

# Rebuild the loader without shuffling so the embeddings follow dataset order
dataloader = DataLoader(
    dataset,
    batch_size=2048,
    num_workers=16,
    shuffle=False,
)

# Run the model over the whole dataset with autograd disabled, keeping each batch's
# "reparametrized" output for the latent extraction below
reparametrized = []
with torch.no_grad():
    for x_r, x_p, batch_idx in tqdm(dataloader):
        outputs = model(x_r.to(device), x_p.to(device), batch_idx.to(device))
        reparametrized.append(outputs["reparametrized"])

# One array of posterior locations per latent component of the model
num_components = len(model.components)
latents = get_latents(reparametrized, num_components)

In [None]:
# Create a product manifold: one geomstats factor per latent component of the model
components = []
for component in model.components:
    if isinstance(component, EuclideanComponent):
        components.append(Euclidean(dim=component.dim))
    elif isinstance(component, HyperbolicComponent):
        # NOTE(review): the `- 1` presumably converts an ambient dimension into the
        # intrinsic dimension geomstats expects — confirm against the component classes
        components.append(Hyperbolic(dim=component.dim - 1))
    elif isinstance(component, SphericalComponent):
        components.append(Hypersphere(dim=component.dim - 1))
    else:
        # Fail loudly: silently skipping a component would build a manifold whose
        # dimension no longer matches the concatenated latent vectors
        raise TypeError(f"Unsupported latent component type: {type(component).__name__}")

manifold = ProductManifold(components)
print(manifold.dim)

In [None]:
# Run fine-grained k-means clustering
# Stack the per-component latents into one point of the product manifold per row
X = np.concatenate(latents, axis=-1)
assert X.shape[1] == manifold.shape[0]

# Subsample the data if X is large. The subset is drawn from a dedicated seeded generator
# (same seed as the k-means `random_state`) so the fit is repeatable across runs.
max_size = 50000
if X.shape[0] > max_size:
    subsample_rng = np.random.default_rng(42)
    X_sub = X[subsample_rng.choice(X.shape[0], max_size, replace=False)]
else:
    # Copy only when no subsampling happens, so X_sub never aliases X
    X_sub = X.copy()

rkmeans = RiemannianKMeans(
    space=manifold,
    n_clusters=1000,
    batch_size=len(X_sub),
    random_state=42,
    init="random",
    tol=1e-2,
    max_iter=1000,
    verbose=1,
)
rkmeans.fit(X_sub)

# These labels cover only the rows of X_sub
labels = rkmeans.labels_
cluster_centers = rkmeans.cluster_centers_
print(labels.shape, cluster_centers.shape)

In [None]:
# Predict on the whole dataset
# `rkmeans` was fitted on `X_sub`; this call labels every row of `X` and replaces the
# `labels` previously read from `rkmeans.labels_`.
labels = rkmeans.predict(X)
print(labels.shape)

In [None]:
# Merge clusters
# NOTE(review): `merge_clusters` is defined in src.rkmeans; the meaning and units of
# `merge_threshold` are not visible here — confirm against its implementation.
final_assignments, n_final_clusters = merge_clusters(
    manifold, labels, cluster_centers, merge_threshold=1.5, verbose=1
)

# Show the resulting labels together with how many points each one holds
merged_summary = np.unique(final_assignments, return_counts=True)
print(merged_summary)

In [None]:
# Merge small clusters: every cluster with fewer than `min_cluster_size` points is folded
# into the single largest cluster (not into its nearest neighbour).
unique_labels, counts = np.unique(final_assignments, return_counts=True)
print(f"Cluster counts before merging small clusters: {counts}")
print(f"Number of clusters before merging: {len(unique_labels)}")

min_cluster_size = 1000

# Identify the largest cluster
if len(counts) > 0:
    largest_cluster_index = np.argmax(counts)
    largest_cluster_label = unique_labels[largest_cluster_index]
    print(f"Largest cluster label: {largest_cluster_label} with size {counts[largest_cluster_index]}")

    # Identify small clusters
    small_cluster_indices = np.where(counts < min_cluster_size)[0]
    small_cluster_labels = unique_labels[small_cluster_indices]
    print(f"Labels of small clusters (< {min_cluster_size}): {small_cluster_labels}")

    # The largest cluster can itself be below the threshold (edge case); it is never
    # reassigned to itself, so drop it from the set of labels to absorb.
    labels_to_absorb = small_cluster_labels[small_cluster_labels != largest_cluster_label]

    # Reassign all small clusters in a single pass over the assignments instead of one
    # full-array comparison per small label. The array is updated in place.
    final_assignments[np.isin(final_assignments, labels_to_absorb)] = largest_cluster_label
    for label in labels_to_absorb:
        print(f"Reassigned cluster {label} to {largest_cluster_label}")
else:
    print("No clusters found in final_assignments.")

# Verify the result
unique_labels_after, counts_after = np.unique(final_assignments, return_counts=True)
print(f"Final cluster counts after merging: {counts_after}")
print(f"Number of final clusters after merging: {len(unique_labels_after)}")
print(np.unique(final_assignments, return_counts=True))

In [None]:
# Remap cluster labels to be contiguous from 0
# `return_inverse` yields, for every element, the index of its label within the sorted
# unique labels — which is exactly the contiguous relabelling — without a Python-level
# loop over the whole assignment array.
unique_labels_final, contiguous_labels = np.unique(final_assignments, return_inverse=True)
print(f"Unique labels before remapping: {unique_labels_final}")

# Mapping from the current labels to the new labels (0, 1, 2, ...), kept for inspection
label_mapping = {old_label: new_label for new_label, old_label in enumerate(unique_labels_final)}

# Apply the mapping; cast to NumPy's default integer type so the result stays an integer
# array even when there are no assignments at all
final_assignments = contiguous_labels.astype(np.int_, copy=False)

# Verify the remapping
unique_labels_remapped, counts_remapped = np.unique(final_assignments, return_counts=True)
print(f"Unique labels after remapping: {unique_labels_remapped}")
print(f"Counts after remapping: {counts_remapped}")
print(f"Number of clusters after remapping: {len(unique_labels_remapped)}")

In [22]:
# Write the final cluster assignments to disk as a .npy file.
# NOTE(review): the destination is a hard-coded absolute path under one user's home directory.
np.save("/home/romainlhardy/code/hyperbolic-cancer/data/cellxgene/climb_cluster_assignments.npy", final_assignments)