In [215]:
import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning as L
import torch.nn.functional as F
import optuna
import tensorboard
from lightning.pytorch.loggers import TensorBoardLogger

import warnings
warnings.filterwarnings("ignore")
from importlib import reload

import torchclustermetrics 
reload(torchclustermetrics)
from torchclustermetrics import silhouette

# this ensures that I can update the class without losing my variables in my notebook
import xenium_cluster
reload(xenium_cluster)
from xenium_cluster import XeniumCluster
from utils.metrics import *

from sklearn.decomposition import PCA

In [None]:
# Location of the gzipped transcript table
file_path = 'data/hBreast/transcripts.csv.gz'

# Load the table and derive an error probability from the quality value:
# error_prob = 10 ** (-qv / 10)   (assumes qv is Phred-scaled -- TODO confirm)
df_transcripts = (
    pd.read_csv(file_path, compression='gzip')
    .assign(error_prob=lambda table: 10 ** (-table["qv"]/10))
)
df_transcripts.head(), df_transcripts.shape

In [None]:
# Discard every transcript row whose cell_id equals -1
# (presumably transcripts not assigned to any cell -- verify against the data spec)
df_transcripts = df_transcripts.loc[df_transcripts["cell_id"].ne(-1)]

In [None]:
# Long format: one row per (cell_id, feature_name) pair with its transcript count
cells = (
    df_transcripts
    .groupby(['cell_id', 'feature_name'])
    .size()
    .reset_index(name='count')
)

# Wide format: one row per cell, one column per feature; absent pairs become 0
cells_pivot = cells.pivot_table(
    values='count',
    index='cell_id',
    columns='feature_name',
    fill_value=0,
)
cells_pivot.shape

In [218]:
# Per-cell centroid: the mean x/y/z coordinate over all transcripts of that cell
location_means = (
    df_transcripts
    .groupby('cell_id')[['x_location', 'y_location', 'z_location']]
    .mean()
    .reset_index()
)

# Attach the per-feature counts (indexed by cell_id) to the right of the centroids
cells_pivot = location_means.join(cells_pivot, on='cell_id')

In [219]:
# log normalization
# Columns 0-3 of cells_pivot are cell_id and the x/y/z centroids produced by the
# join in the previous cell, so the slice 4: selects only the per-feature count
# columns; each count c is replaced by log(1 + c).
# NOTE(review): this overwrites cells_pivot in place, so re-running this cell
# applies log1p a second time -- restart from the pivot cell before re-running.
cells_pivot.iloc[:, 4:] = np.log1p(cells_pivot.iloc[:, 4:])

In [220]:
# Project the log-normalized feature columns (everything after cell_id and the
# x/y/z centroids) onto 5 principal components.
# random_state is pinned so that the projection is reproducible across kernel
# restarts even when scikit-learn picks a randomized SVD solver for this matrix.
pca = PCA(n_components=5, random_state=0)
pca_data = pca.fit_transform(cells_pivot.iloc[:, 4:])

In [271]:
NUM_CLUSTERS = 6   # number of mixture components used by model() and guide()
BATCH_SIZE = 512   # subsample size of the "data" plate in model() and guide()

# Seed the random number generators so that parameter initialisation,
# subsampling and Monte Carlo estimates are reproducible after a kernel restart.
SEED = 0
pyro.set_rng_seed(SEED)

# Clear the param store in case we're in a REPL
pyro.clear_param_store()

In [272]:
# Principal-component scores (5 per row of cells_pivot) as a float32 tensor
data = torch.tensor(pca_data, dtype=torch.float32)

In [273]:
def spatial_loss(model, guide, data, original_positions, batch_size=256, weight=100.0, sigma=1.0, *args, **kwargs):
    """Return the ELBO loss plus a weighted spatial-smoothness penalty.

    Args:
        model: Pyro model, forwarded to ``Trace_ELBO.differentiable_loss``.
        guide: Pyro guide, forwarded to ``Trace_ELBO.differentiable_loss``. It
            must register the ``"cluster_concentration_params_q"`` parameter,
            which is read back here after the ELBO has been evaluated.
        data: Observations passed to ``model`` and ``guide``.
        original_positions: Tensor of coordinates, one row per data point,
            indexed with the same row numbering as the guide parameter.
        batch_size: Number of rows subsampled for the smoothness term.
        weight: Multiplier applied to the smoothness term.
        sigma: Bandwidth of the Gaussian kernel on pairwise distances.
        *args, **kwargs: Extra arguments forwarded to ``model`` and ``guide``.

    Returns:
        Scalar tensor ``elbo + weight * smoothness``.
    """
    # Differentiable ELBO estimate using 10 particles.
    elbo_term = Trace_ELBO(num_particles=10).differentiable_loss(model, guide, data, *args, **kwargs)

    # This plate call draws its own set of row indices for the smoothness term.
    with pyro.plate("data", len(original_positions), subsample_size=batch_size) as ind:
        batch_params = pyro.param("cluster_concentration_params_q")[ind]
        batch_positions = original_positions[ind]

    # Gaussian-kernel affinity between every pair of subsampled rows.
    distances = torch.cdist(batch_positions, batch_positions, p=2)
    affinity = torch.exp(-distances**2 / (2 * sigma**2))

    # Near-one-hot stochastic relaxation of the per-row cluster parameters.
    soft_assignments = F.gumbel_softmax(batch_params, tau=0.25, dim=1)

    # Squared difference between the soft assignments of every pair of rows,
    # weighted by how close the two rows are in space.
    pairwise_gap = soft_assignments.unsqueeze(1) - soft_assignments.unsqueeze(0)
    squared_gap = torch.sum(pairwise_gap**2, dim=-1)
    spatial_term = torch.sum(affinity * squared_gap)

    return elbo_term + weight * spatial_term

In [274]:
def model(data):
    """Gaussian mixture with per-row mixing weights and shared components.

    Latent sites:
        cluster_means:  (NUM_CLUSTERS, D) component means, prior Normal(0, 1).
        cluster_scales: (NUM_CLUSTERS, D) component scales, prior LogNormal(0, 1).
        cluster_probs:  one Dirichlet(1, ..., 1) weight vector per subsampled row.

    Args:
        data: 2-D tensor of observations, one row per data point and
            D = data.size(1) features per row.
    """
    num_features = data.size(1)

    # The component means and scales are shared by every row, so they are
    # sampled once, outside the data plate (sites inside a plate are treated
    # as one independent copy per data point).
    cluster_means = pyro.sample(
        "cluster_means",
        dist.Normal(0., 1.).expand([NUM_CLUSTERS, num_features]).to_event(2),
    )
    cluster_scales = pyro.sample(
        "cluster_scales",
        dist.LogNormal(0., 1.).expand([NUM_CLUSTERS, num_features]).to_event(2),
    )

    with pyro.plate("data", len(data), subsample_size=BATCH_SIZE) as ind:
        batch_data = data[ind]
        batch_len = batch_data.size(0)

        # Per-row mixing weights with a flat Dirichlet prior.
        cluster_probs = pyro.sample("cluster_probs", dist.Dirichlet(torch.ones(batch_len, NUM_CLUSTERS)))

        # MixtureOfDiagNormals expects softmax *logits* as its third argument,
        # so the Dirichlet probabilities are converted with log(). In batch
        # mode it also expects (B, K, D) locs/scales, hence the expand().
        # The plate already supplies the batch dimension, so no to_event()
        # is applied to the observation site.
        pyro.sample(
            "obs",
            dist.MixtureOfDiagNormals(
                cluster_means.expand(batch_len, -1, -1),
                cluster_scales.expand(batch_len, -1, -1),
                torch.log(cluster_probs),
            ),
            obs=batch_data,
        )

def guide(data):
    """Mean-field variational family matching the sites of ``model``.

    Parameters registered in the Pyro param store:
        cluster_concentration_params_q: (len(data), NUM_CLUSTERS), positive.
        cluster_means_q:  (NUM_CLUSTERS, D), location of the Normal over means.
        cluster_scales_q: (NUM_CLUSTERS, D), positive, location of the
            LogNormal over scales.

    Args:
        data: 2-D tensor of observations (same tensor that is passed to ``model``).
    """
    # Initialize cluster assignment concentrations for the entire dataset; the
    # offset keeps every Dirichlet concentration at least MIN_CONCENTRATION.
    MIN_CONCENTRATION = 0.1
    cluster_concentration_params_q = pyro.param(
        "cluster_concentration_params_q",
        torch.ones(data.size(0), NUM_CLUSTERS),
        constraint=dist.constraints.positive,
    ) + MIN_CONCENTRATION

    # Global variational parameters for means and scales
    cluster_means_q_mean = pyro.param("cluster_means_q", torch.randn(NUM_CLUSTERS, data.size(1)))
    # NOTE(review): this value is the LogNormal *location* (log-space mean); the
    # positive constraint therefore keeps the median scale above 1 -- confirm
    # that this is intended.
    cluster_scales_q_mean = pyro.param("cluster_scales_q", torch.ones(NUM_CLUSTERS, data.size(1)), constraint=dist.constraints.positive)

    # Global sites: sampled outside the data plate, mirroring model().
    pyro.sample("cluster_means", dist.Normal(cluster_means_q_mean, 0.1).to_event(2))
    pyro.sample("cluster_scales", dist.LogNormal(cluster_scales_q_mean, 0.1).to_event(2))

    with pyro.plate("data", len(data), subsample_size=BATCH_SIZE) as ind:
        # Local site: one Dirichlet per subsampled row.
        pyro.sample("cluster_probs", dist.Dirichlet(cluster_concentration_params_q[ind]))


In [275]:
from pyro.optim import PyroOptim, PyroLRScheduler
# NOTE: this rebinds `Adam` to torch.optim.Adam, replacing the pyro.optim.Adam
# imported in the first cell of the notebook.
from torch.optim import Adam, lr_scheduler

starting_lr = 0.01
ending_lr = 0.00001
N_STEPS = 100000

# Single source of truth for the Adam hyper-parameters: the learning rate is
# taken from starting_lr instead of being repeated as a literal.
adam_params = {"lr": starting_lr, "betas": (0.90, 0.999)}

# Plain (unscheduled) optimizer; the SVI setup in this notebook passes
# `scheduler`, so this object is kept only for experiments without decay.
optimizer = PyroOptim(Adam, adam_params)

# Per-step multiplicative decay chosen so that N_STEPS scheduler steps take the
# learning rate from starting_lr down to ending_lr.
lr_decay = (ending_lr / starting_lr) ** (1 / N_STEPS)
scheduler = PyroLRScheduler(
    lr_scheduler.StepLR,
    {'optimizer': Adam, 'optim_args': dict(adam_params), 'step_size': 1, 'gamma': lr_decay},
)

In [276]:
# (x, y) centroid of every row of cells_pivot, standardized per axis
# (subtract the column mean, divide by the column standard deviation)
original_positions = torch.tensor(cells_pivot[["x_location", "y_location"]].to_numpy())
original_positions = original_positions.sub(original_positions.mean(dim=0)).div(original_positions.std(dim=0))

# Inference with the custom objective: ELBO plus the spatial smoothness term
svi = SVI(
    model,
    guide,
    scheduler,
    loss=lambda m, g, x, positions: spatial_loss(m, g, x, positions),
)

# Plain-ELBO alternative (no spatial term):
# svi = SVI(model, guide, scheduler, loss=Trace_ELBO(num_particles=10))

In [None]:
# Optimisation loop: one SVI gradient step per iteration, followed by a call to
# step() on svi.optim (the scheduler object handed to SVI in the setup cell).
# For the plain-ELBO variant call svi.step(data) instead.
for step in range(N_STEPS):
    loss = svi.step(data, original_positions)
    svi.optim.step()
    if step % 100 != 0:
        continue
    # Report every 100th step, with the loss expressed in millions
    print(f"Step {step} : loss = {round(loss/1e6, 4)}")

In [None]:
# Grab the learned variational parameters
raw_concentration_q = pyro.param("cluster_concentration_params_q").detach()

# The guide adds this offset to the raw parameter before building its
# Dirichlet, so the same offset is applied here (must match guide()).
MIN_CONCENTRATION = 0.1
dirichlet_concentration_q = raw_concentration_q + MIN_CONCENTRATION

# Use the Dirichlet posterior mean (concentration / row sum) rather than a
# single random draw, so the reported probabilities and the hard assignments
# are deterministic and identical on every re-run of this cell.
cluster_probs_q = dirichlet_concentration_q / dirichlet_concentration_q.sum(dim=1, keepdim=True)
cluster_assignments_q = cluster_probs_q.argmax(dim=1)

# NumPy copies of the raw parameters for inspection
cluster_concentration_params_q = raw_concentration_q.numpy()
cluster_means_q_mean = pyro.param("cluster_means_q").detach().numpy()
cluster_scales_q_mean = pyro.param("cluster_scales_q").detach().numpy()

# Output the learned cluster assignments and probabilities for each data point
print(cluster_assignments_q, cluster_probs_q)

In [None]:
# Cluster sizes: one unit-width bin centred on each integer label 0..NUM_CLUSTERS-1
# (the default 10 equal-width bins do not line up with integer labels).
plt.hist(np.asarray(cluster_assignments_q), bins=np.arange(NUM_CLUSTERS + 1) - 0.5, rwidth=0.9)
plt.xticks(range(NUM_CLUSTERS))
plt.xlabel("cluster")
plt.ylabel("count")
plt.title("Cluster assignment counts");

In [None]:
# Print small values in fixed-point instead of scientific notation
np.set_printoptions(suppress=True)

# Learned variational locations for the component means and scales, 4 decimals
cluster_means_q_mean.round(4), cluster_scales_q_mean.round(4)

In [None]:
# Preview the first five rows of the per-cell table
cells_pivot.head(n=5)

In [None]:
# Spatial map: one point per row of cells_pivot, coloured by its cluster assignment
_ = plt.scatter(
    x=cells_pivot["x_location"],
    y=cells_pivot["y_location"],
    c=cluster_assignments_q,
    s=1,
)