# Benchmarking

In this notebook we benchmark our method for identifying circular structures against the synthetic dataset of Lederer et al.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import random
import seaborn as sns
import scanpy as sc
import scanpy.external as sce
import scipy.stats as ss

from pathlib import Path
from ripser import ripser
from scipy.spatial import distance
from tqdm import tqdm


RANDOM_STATE = 240209

In [None]:
pip install velocycle

In [None]:
import velocycle as vc
from velocycle import *

from lederer_utils import simulate_data
from lederer_utils import circular_corrcoef

import pyro

We need a function to align the estimated phase with the ground truth phase.

In [83]:
import matplotlib.pyplot as plt

def compute_final_circular_corrcoef_with_plotting(adata, reverse_threshold=0.35):
    """
    Align the estimated phase in ``adata.obs['coords']`` with the simulated phase and score it.

    The coordinates are recentred so that the cells with the lowest 2% of simulated phi
    values sit (on average) at 0. If the circular correlation after recentring is below
    ``reverse_threshold``, the orientation is flipped (``coords -> 1 - coords``).
    The original, recentred and (if applicable) reversed coordinates are plotted against
    the simulated phase on a single figure.

    Note: ``adata.obs['coords']`` is overwritten in place with the aligned coordinates.

    Parameters:
    adata : AnnData object
        Must provide ``obs['coords']`` (multiplied by 2*pi before scoring, and wrapped
        with ``% 1`` after recentring) and ``obs['simulated_phis']``.
    reverse_threshold : float, optional
        Correlation below which the coordinates are reversed. Defaults to 0.35, the
        value previously hard-coded in the function body.

    Returns:
    float
        The circular correlation coefficient of the aligned coordinates. This is the
        same value shown in the figure title and in the final printed line.
    """
    simulated_phis = adata.obs['simulated_phis']

    # Plot the original data
    plt.figure(figsize=(12, 6))
    plt.scatter(adata.obs['coords'], simulated_phis, label='Original', alpha=0.6)

    # Cells in the lowest 2% of simulated phi values define the new origin
    threshold_phi = np.percentile(simulated_phis, 2)
    low_phi_indexes = simulated_phis <= threshold_phi

    prior_corr = circular_corrcoef(2 * np.pi * adata.obs['coords'], simulated_phis)
    print("Prior circular correlation coefficient: ", prior_corr)

    # NOTE(review): this is a linear mean of a circular quantity; if the anchor cells
    # straddle the 0/1 wrap-around the origin will be misplaced — confirm acceptable.
    recenter_value = adata.obs.loc[low_phi_indexes, 'coords'].mean()

    # Shift so the anchor cells sit at 0, wrapping back into [0, 1)
    adata.obs['coords'] = (adata.obs['coords'] - recenter_value) % 1

    # Plot the reparametrized data
    plt.scatter(adata.obs['coords'], simulated_phis, label='Reparametrized', alpha=0.6)

    initial_corr = circular_corrcoef(2 * np.pi * adata.obs['coords'], simulated_phis)
    print(f'Initial circular correlation coefficient: {initial_corr}')

    # A low correlation after recentring is taken to mean the orientation is flipped
    if initial_corr < reverse_threshold:
        adata.obs['coords'] = 1 - adata.obs['coords']
        plt.scatter(adata.obs['coords'], simulated_phis, label='Reversed', alpha=0.6)

    # Score the coordinates as they are left in adata.obs
    circular_corr = circular_corrcoef(2 * np.pi * adata.obs['coords'], simulated_phis)

    # Report exactly the value that is returned (previously the maximum of the three
    # intermediate correlations was shown, which could differ from the return value)
    plt.xlabel('Coords')
    plt.ylabel('Simulated Phis')
    plt.legend()
    plt.title(f'Final Circular Correlation Coefficient: {circular_corr:.4f}')
    plt.show()

    print(f'Final circular correlation coefficient: {circular_corr}')

    return circular_corr

We run the method on the Lederer et al. synthetic dataset over the same parameter set of number of genes and cells. For each, we compute:
- The circular correlation between the ground truth (simulated) phase and the estimated phase.
- The ring score (both diameter and ratio) of the data transformed onto the first 2 principal components.

We recover a similar pattern to Lederer et al., where the circular correlation increases with both the number of genes and the number of cells. We observe that the ring score follows the same pattern. We argue that this demonstrates the viability of using the ring score as a preprocessing check for the existence of circular structure in data, prior to applying our method or any other that assumes circular structure.

In [None]:
import anndata as ad

adata = simulate_data(Nc=100, Ng=300)

import anndata as ad

def create_spliced_unspliced_adata(adata):
    """
    Build an AnnData whose variables are every gene twice: once spliced, once unspliced.

    The count matrix is the 'spliced' layer followed by the 'unspliced' layer along
    the gene axis, and ``var`` is duplicated accordingly with "_spliced" /
    "_unspliced" appended to the index.

    Parameters:
    adata (AnnData): Input AnnData object with 'spliced' and 'unspliced' layers.

    Returns:
    AnnData: New object with the concatenated count matrix as X, a copy of ``obs``,
             the duplicated ``var`` and a (shallow) copy of ``uns``.
    """
    def _as_dense(layer):
        # Sparse matrices expose .toarray(); anything else is used as-is
        return layer.toarray() if hasattr(layer, "toarray") else layer

    layer_suffixes = ("spliced", "unspliced")

    # Spliced columns first, then unspliced, side by side along the gene axis
    X_new = np.concatenate(
        [_as_dense(adata.layers[layer_name]) for layer_name in layer_suffixes],
        axis=1,
    )

    # One copy of the var metadata per layer, with the layer name appended to the index
    var_parts = []
    for layer_name in layer_suffixes:
        var_part = adata.var.copy()
        var_part.index = var_part.index + "_" + layer_name
        var_parts.append(var_part)
    var_new = pd.concat(var_parts)

    adata_new = ad.AnnData(X=X_new, obs=adata.obs.copy(), var=var_new)

    # Carry the unstructured metadata across
    adata_new.uns = adata.uns.copy()

    return adata_new

In [None]:
# Parameter grid. The full sweep is kept below for reference; this cell instead runs
# 20 replicate simulations at a single (Nc, Ng) setting.
#Ng_values = [100, 200, 350, 500, 750, 1000]
#Nc_values = [100, 200, 500, 1000]

Ng_values = [300]*20
Nc_values = [3000]


# One result matrix per metric, indexed as [Nc index, Ng index]
grid_shape = (len(Nc_values), len(Ng_values))
diam_scores = np.zeros(grid_shape)
ratio_scores = np.zeros(grid_shape)
circ_correlation = np.zeros(grid_shape)

# NOTE(review): `chnt` is not imported in any cell visible above — confirm it is
# defined before this cell is run on a fresh kernel.
for i, Nc in enumerate(Nc_values):
    for j, Ng in enumerate(Ng_values):
        print(f'Nc: {Nc}, Ng: {Ng}')

        # Simulate, attach the ground-truth harmonics to var, then stack spliced/unspliced genes
        adata = simulate_data(Nc=Nc, Ng=Ng)
        adata.var['sim_sin'] = adata.uns['simulated_nu'][:,:,1]
        adata.var['sim_cos'] = adata.uns['simulated_nu'][:,:,2]
        adata = create_spliced_unspliced_adata(adata)

        # Gene filtering, per-cell normalisation and log1p (no scaling step is applied)
        sc.pp.filter_genes(adata, min_cells=3)
        sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
        sc.pp.log1p(adata)

        # Diffusion and highly-variable-gene filtering are disabled:
        # sce.pp.magic(adata, knn=10, t=2, random_state=RANDOM_STATE, n_jobs=4)
        # sc.pp.highly_variable_genes(adata, subset=True)

        # Two principal components
        sc.pp.pca(adata, n_comps=2, use_highly_variable=False, random_state=RANDOM_STATE)

        # Ring scores (diameter and ratio variants) on those two components
        diam_score = chnt.ring_score(adata, score_type='diameter', exponent=2, comp=np.arange(2))
        ratio_score = chnt.ring_score(adata, score_type='ratio', exponent=2, comp=np.arange(2))
        diam_scores[i, j], ratio_scores[i, j] = diam_score, ratio_score

        # Circular coordinate, aligned to the simulated phase and scored
        chnt.circular(adata, comp = [0,1])
        circ_correlation[i,j] = compute_final_circular_corrcoef_with_plotting(adata)


# One heatmap per metric, side by side
fig, axes = plt.subplots(1, 3, figsize=(21, 6))

heatmap_panels = [
    (diam_scores, 'Diameter Score Heatmap'),
    (ratio_scores, 'Ratio Score Heatmap'),
    (circ_correlation, 'Circular Correlation Score Heatmap'),
]
for panel_ax, (panel_scores, panel_title) in zip(axes, heatmap_panels):
    sns.heatmap(panel_scores, xticklabels=Ng_values, yticklabels=Nc_values, ax=panel_ax, cmap='viridis')
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Ng (Number of Genes)')
    panel_ax.set_ylabel('Nc (Number of Cells)')

plt.tight_layout()
plt.show()

In [None]:
# Boxplot of the circular correlation scores: one box per Nc, pooling the columns
# (replicates) of that row.
# seaborn draws one box per *column* of a 2-D array, so the (Nc, replicate) matrix is
# transposed; untransposed, a (1, 20) matrix yields 20 single-point boxes and only the
# first would carry an Nc tick label.
plt.figure(figsize=(10, 6))
sns.boxplot(data=circ_correlation.T, orient='v')
plt.xticks(ticks=range(len(Nc_values)), labels=Nc_values)
plt.xlabel('Number of Cells (Nc)')
plt.ylabel('Circular Correlation Score')
plt.title('Circular Correlation Score Boxplot')
plt.show()


In [None]:
# Parameter grid for the sweep over number of genes (Ng) and number of cells (Nc).
# Ng uses 750 (not 700) so that this grid matches the reference list and the later
# VeloCycle comparison cell, which both use 750.
Ng_values = [100, 200, 350, 500, 750, 1000]
Nc_values = [100, 200, 500, 1000, 3000]


# One result matrix per metric, indexed as [Nc index, Ng index]
grid_shape = (len(Nc_values), len(Ng_values))
diam_scores = np.zeros(grid_shape)
ratio_scores = np.zeros(grid_shape)
circ_correlation = np.zeros(grid_shape)
pca_circ_correlation = np.zeros(grid_shape)

# NOTE(review): `chnt` is not imported in any cell visible above — confirm it is
# defined before this cell is run on a fresh kernel.
for i, Nc in enumerate(Nc_values):
    for j, Ng in enumerate(Ng_values):
        print(f'Nc: {Nc}, Ng: {Ng}')

        # Simulate, attach the ground-truth harmonics to var, then stack spliced/unspliced genes
        adata = simulate_data(Nc=Nc, Ng=Ng)
        adata.var['sim_sin'] = adata.uns['simulated_nu'][:,:,1]
        adata.var['sim_cos'] = adata.uns['simulated_nu'][:,:,2]
        adata = create_spliced_unspliced_adata(adata)

        # Gene filtering, per-cell normalisation and log1p (no scaling step is applied)
        sc.pp.filter_genes(adata, min_cells=3)
        sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
        sc.pp.log1p(adata)

        # Diffusion and highly-variable-gene filtering are disabled:
        # sce.pp.magic(adata, knn=10, t=2, random_state=RANDOM_STATE, n_jobs=4)
        # sc.pp.highly_variable_genes(adata, subset=True)

        # Two principal components
        sc.pp.pca(adata, n_comps=2, use_highly_variable=False, random_state=RANDOM_STATE)

        # Ring scores (diameter and ratio variants) on those two components
        diam_score = chnt.ring_score(adata, score_type='diameter', exponent=2, comp=np.arange(2))
        ratio_score = chnt.ring_score(adata, score_type='ratio', exponent=2, comp=np.arange(2))
        diam_scores[i, j], ratio_scores[i, j] = diam_score, ratio_score

        # Circular coordinate, aligned to the simulated phase and scored
        chnt.circular(adata, comp = [0,1])
        circ_correlation[i,j] = compute_final_circular_corrcoef_with_plotting(adata)

        # Baseline: the angle of each cell in the PC1/PC2 plane, rescaled from
        # arctan2's [-pi, pi] to [0, 1], then aligned and scored the same way
        angles = np.arctan2(adata.obsm['X_pca'][:, 0], adata.obsm['X_pca'][:, 1])
        pca_adata = adata.copy()
        pca_adata.obs['coords'] = ( angles + np.pi) / (2 * np.pi)
        pca_circ_correlation[i,j] = compute_final_circular_corrcoef_with_plotting(pca_adata)


# One annotated heatmap per metric, side by side
fig, axes = plt.subplots(1, 4, figsize=(28, 6))

heatmap_panels = [
    (diam_scores, 'Diameter Score Heatmap'),
    (ratio_scores, 'Ratio Score Heatmap'),
    (circ_correlation, 'Chunter Circular Correlation Score Heatmap'),
    (pca_circ_correlation, 'PCA Angle Circular Correlation Score Heatmap'),
]
for panel_ax, (panel_scores, panel_title) in zip(axes, heatmap_panels):
    sns.heatmap(panel_scores, xticklabels=Ng_values, yticklabels=Nc_values, ax=panel_ax, cmap='viridis', annot = True)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Ng (Number of Genes)')
    panel_ax.set_ylabel('Nc (Number of Cells)')

plt.tight_layout()
plt.show()

In [None]:
# Re-draw the ring-score and circular-correlation heatmaps with each cell annotated
fig, axes = plt.subplots(1, 3, figsize=(21, 6))

heatmap_panels = [
    (diam_scores, 'Diameter Score Heatmap'),
    (ratio_scores, 'Ratio Score Heatmap'),
    (circ_correlation, 'Circular Correlation Score Heatmap'),
]
for panel_ax, (panel_scores, panel_title) in zip(axes, heatmap_panels):
    sns.heatmap(panel_scores, xticklabels=Ng_values, yticklabels=Nc_values, ax=panel_ax, cmap='viridis', annot=True)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Ng (Number of Genes)')
    panel_ax.set_ylabel('Nc (Number of Cells)')

plt.tight_layout()
plt.show()

In [None]:
# PCA-angle baseline for whichever dataset `adata` currently holds:
# the angle of each cell in the PC1/PC2 plane.
angles = np.arctan2(adata.obsm['X_pca'][:, 0], adata.obsm['X_pca'][:, 1])

# Raw angle against the simulated phase, before any alignment
plt.figure(figsize=(10, 6))
plt.scatter(angles, adata.obs['simulated_phis'])
plt.xlabel('Angle')
plt.ylabel('Simulated Phis')
plt.show()

# Rescale arctan2's [-pi, pi] to [0, 1], then align and score on a copy so that
# adata.obs['coords'] itself is left untouched
pca_adata = adata.copy()
pca_adata.obs['coords'] = ( angles + np.pi) / (2 * np.pi)
compute_final_circular_corrcoef_with_plotting(pca_adata)

In [None]:
# Redo the above plot, printing the value in each heatmap cell (annot=True).
# NOTE(review): three axes are created but only axes[0] is drawn on in this cell,
# so the other two panels render empty — looks unfinished; confirm whether to
# complete or remove.
fig, axes = plt.subplots(1, 3, figsize=(21, 6))

sns.heatmap(diam_scores, xticklabels=Ng_values, yticklabels=Nc_values, ax=axes[0], cmap='viridis', annot=True)


# Gene phase recovery

In [None]:
# Simulate one dataset (3000 cells, 350 genes) and keep slices 1 and 2 of the last
# axis of `simulated_nu` as per-gene columns for later comparison
adata = simulate_data(Nc=3000, Ng=350)
adata.var['sim_sin'] = adata.uns['simulated_nu'][:,:,1]
adata.var['sim_cos'] = adata.uns['simulated_nu'][:,:,2]

# Stack spliced and unspliced counts as separate variables
adata = create_spliced_unspliced_adata(adata)

# Copy taken before any normalisation is applied.
# NOTE(review): `adata_velocycle` is not referenced again in the visible cells.
adata_velocycle = adata.copy()

# Apply standard normalization (in place on `adata`)
sc.pp.filter_genes(adata, min_cells=3)
sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)

# Log transform the data (no scaling step is applied here)
sc.pp.log1p(adata)

# Diffusion and filtering for highly variable genes (disabled)
# sce.pp.magic(adata, knn=10, t=6, random_state=RANDOM_STATE, n_jobs=4)
# sc.pp.highly_variable_genes(adata, subset=True)

# PCA with 5 components
sc.pp.pca(adata, n_comps=5, random_state=RANDOM_STATE)

# Calculate the diameter and ratio scores on all 5 components
# NOTE(review): `chnt` is not imported in any visible cell, and these two scores are
# neither displayed nor stored beyond the variables — confirm intent.
diam_score = chnt.ring_score(adata, score_type='diameter', exponent=2, comp=np.arange(5))
ratio_score = chnt.ring_score(adata, score_type='ratio', exponent=2, comp=np.arange(5))

# Calculate the circular coordinate from components 0 and 1 only, then align it to
# the simulated phase; the cell's output is the resulting correlation coefficient
chnt.circular(adata, comp = [0,1])
compute_final_circular_corrcoef_with_plotting(adata)

In [None]:
sc.pl.pca_variance_ratio(adata, log=True)

In [None]:
chnt.plot_diagram(adata, comp = [0,1])

In [None]:
chnt.plot_2d(adata, comp = [0,1], mode = 'pca', c = 'coords')

In [21]:
adata_hvg = adata[:,:350]

In [None]:
# subset adata by highly variable genes
sc.pp.highly_variable_genes(adata)
adata_hvg = adata[:, adata.var['highly_variable']]
adata_hvg

In [None]:
chnt.leadlag(adata_hvg)

In [None]:
chnt.phase_plot(adata_hvg, topk=30)

In [None]:
chnt.plot_top_genes(adata, k = 20)

In [None]:
chnt.leadlag_plot(adata_hvg, k = 50)

In [None]:
fig, ax = plt.subplots(1,2, figsize = (10,5))

# Left panel: ground-truth per-gene harmonic coefficients
ax[0].scatter(adata_hvg.var['sim_cos'], adata_hvg.var['sim_sin'])
ax[0].set_xlabel('Simulated cosine')
ax[0].set_ylabel('Simulated sine')

# Right panel: first lead-lag principal component in the complex plane.
# The axes are labelled as what is actually plotted (previously this panel reused
# the 'Simulated cosine' / 'Simulated sine' labels of the left panel).
first_leadlag_pc = adata_hvg.varm['leadlag_pcs'][:,0]
ax[1].scatter(np.real(first_leadlag_pc), np.imag(first_leadlag_pc))
ax[1].set_xlabel('Real Leadlag PC 1')
ax[1].set_ylabel('Imaginary Leadlag PC 1')


In [None]:
plt.scatter(adata.var['sim_sin'], adata.var['sim_cos'])

In [None]:
# Compare the first lead-lag PC with the simulated harmonics, one panel per component:
# imaginary part vs simulated sine, real part vs simulated cosine.
fig, axs = plt.subplots(1, 2, figsize = (10, 5))

panel_specs = [
    (np.imag, 'sim_sin', 'Imaginary Leadlag PC 1', 'Simulated Sine'),
    (np.real, 'sim_cos', 'Real Leadlag PC 1', 'Simulated Cosine'),
]

for panel_ax, (take_part, truth_column, x_label, y_label) in zip(axs, panel_specs):
    plt.sca(panel_ax)

    estimate = take_part(adata_hvg.varm['leadlag_pcs'][:,0])
    truth = adata_hvg.var[truth_column]

    # Pearson correlation, shown in the legend of the fitted line
    corr = np.corrcoef(estimate, truth)[0,1]

    plt.scatter(estimate, truth)
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    # Degree-1 least-squares fit evaluated at the observed x values
    line_coeffs = np.polyfit(estimate, truth, 1)
    plt.plot(estimate, np.polyval(line_coeffs, estimate), color = 'red', label = f'corr = {corr:.2f}')

    plt.legend()

plt.show()


In [89]:
top_genes = chnt.get_top_genes(adata, k = 10)

In [None]:
chnt.plot_2d(adata, c = top_genes[:5], mode = 'll')

In [None]:
chnt.plot_2d(adata, c = top_genes[5:], mode = 'll')

# Velocycle benchmark

In [27]:
# generic & ml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import pyro
import copy
import scipy
import pycircstat
import pickle
from scipy.stats import pearsonr
from scipy.linalg import lstsq
import statsmodels.api as sm

# scRNA-seq
import scanpy as sc
import anndata


In [None]:
adata = simulate_data(Nc=3000, Ng=350)

In [35]:
pyro.clear_param_store()

In [36]:
# VeloCycle operates on the raw counts and has a built-in count factor term
preprocessing.normalize_total(adata)

In [37]:
cycle_prior = cycle.Cycle.trivial_prior(gene_names=adata.var_names, harmonics=1)

In [38]:
# Keep only genes from biologically-relevant gene set that are present in the current dataset
cycle_prior, data_to_fit = preprocessing.filter_shared_genes(cycle_prior, adata, filter_type="intersection")

In [39]:
# Update the priors for gene harmonics
# to gene-specific means and stds, both derived from the 'spliced' layer
S = data_to_fit.layers['spliced'].toarray()  # requires a layer exposing .toarray()
S_means = S.mean(axis=0) # mean along axis 0 (presumably over cells) — a mean, not a sum
# NOTE(review): a column whose mean is 0 gives log(0) = -inf here — confirm such
# genes cannot reach this point.
nu0 = np.log(S_means)

# Mean prior: three stacked rows — nu0 followed by two rows of zeros
S_frac_means=np.vstack((nu0, 0*nu0, 0*nu0))
cycle_prior.set_means(S_frac_means)

# Standard deviation prior: half the std of log(S+1) along axis 0 for the first
# row, and half of that again for the remaining two rows
nu0std = np.std(np.log(S+1), axis=0)/2
S_frac_stds=np.vstack((nu0std, 0.5*nu0std, 0.5*nu0std))
cycle_prior.set_stds(S_frac_stds)

In [None]:
# Obtain a PCA prior for individual cell phases
# The prior for cell cycle coordinates are taken from the first two components of the PCA
# This is the only step of the model that relies on log transformed counts - but still no smoothing is performed
data_to_fit.layers["S_sz_log"] = np.log(data_to_fit.layers["S_sz"]+1)
phase_prior = phases.Phases.from_pca_heuristic(data_to_fit, 
                                               genes_to_use=adata.var_names, 
                                               layer='S_sz_log',
                                               concentration=5.0, plot=True, small_count=1)

In [None]:
USE_GPU = True
if USE_GPU and torch.cuda.is_available():
    print("Will use GPU")
    device = torch.device("cuda:0")
else:
    print("Will use CPU")
    device = torch.device("cpu")

    # Create design matrix for dataset with a single sample/batch
batch_design_matrix = preprocessing.make_design_matrix(adata, ids="batch")

In [42]:
pyro.clear_param_store()

In [43]:
# Call a preprocessing function to metaparameters to provide to Pyro
metapar = preprocessing.preprocess_for_phase_estimation(anndata=data_to_fit, 
                                          cycle_obj=cycle_prior, 
                                          phase_obj=phase_prior, 
                                          design_mtx=batch_design_matrix,
                                          n_harmonics=1,
                                          device = device,
                                          with_delta_nu=False)



In [44]:
# Define a pyro object for phase inference
phase_fit = phase_inference_model.PhaseFitModel(metaparams=metapar)

In [None]:


# Perform training using a decaying learning rate
num_steps = 500
initial_lr = 0.03
final_lr = 0.005
gamma = final_lr / initial_lr
lrd = gamma ** (1 / num_steps)
adam = pyro.optim.ClippedAdam({'lr': initial_lr, 'lrd': lrd, 'betas': (0.80, 0.99)})

phase_fit.fit(optimizer=adam, num_steps=num_steps)



**TODO:** It is unclear how to compare computational complexity fairly, since VeloCycle's runtime depends on several parameters (e.g. the number of optimisation steps).

In [60]:
cycle_pyro = phase_fit.cycle_pyro
phase_pyro = phase_fit.phase_pyro

In [None]:
phase_pyro.phis

In [None]:
list(np.array(phase_pyro.phis).flatten())

In [49]:
priors = phase_prior.phi_xy
prior_phis = np.angle(priors.iloc[0] + 1j*priors.iloc[1]) % (2*np.pi)

In [None]:
# Pairwise comparison of the simulated, prior and VeloCycle-inferred phases
fig, ax = plt.subplots(1,3, figsize = (10,5))

scatter_specs = [
    ('Simulated Phis', 'Prior Phis',
     lambda: (adata.obs['simulated_phis'], prior_phis)),
    ('Simulated Phis', 'VeloCycle Inferred Phis',
     lambda: (adata.obs['simulated_phis'], list(phase_pyro.phis))),
    ('Prior Phis', 'VeloCycle Inferred Phis',
     lambda: (prior_phis, list(phase_pyro.phis))),
]

for panel_ax, (x_label, y_label, get_xy) in zip(ax, scatter_specs):
    plt.sca(panel_ax)
    x_values, y_values = get_xy()
    plt.scatter(x_values, y_values)
    plt.xlabel(x_label)
    plt.ylabel(y_label)

In [None]:
def velocycle_estimation(adata, num_steps=200):
    """
    Run the VeloCycle phase-inference pipeline and return a copy of ``adata`` with the result.

    The input is copied first, so the caller's object is not modified. Gene-harmonic
    priors are built from the 'spliced' layer, a per-cell phase prior is taken from
    VeloCycle's PCA heuristic, and the phase model is fitted with a ClippedAdam
    optimiser whose learning rate decays from 0.03 to 0.005 over ``num_steps`` steps.

    Side effects: clears the global Pyro parameter store, prints which device is used,
    and (via ``plot=True`` in the PCA heuristic) may produce a plot.

    Parameters:
    adata : AnnData object
        Simulated dataset; after VeloCycle's preprocessing it must provide a 'spliced'
        layer exposing ``.toarray()`` and an 'S_sz' layer.
    num_steps : int, optional
        Number of optimisation steps. Defaults to 200, the value previously
        hard-coded in the function body.

    Returns:
    AnnData
        The working copy, with ``obs['coords']`` set to ``(phi + pi) / (2 * pi)`` for
        the fitted phases ``phi``.
    """
    adata = adata.copy()

    # VeloCycle operates on the raw counts and has a built-in count factor term
    preprocessing.normalize_total(adata)

    cycle_prior = cycle.Cycle.trivial_prior(gene_names=adata.var_names, harmonics=1)
    # Keep only the genes shared between the prior and the dataset
    cycle_prior, data_to_fit = preprocessing.filter_shared_genes(cycle_prior, adata, filter_type="intersection")

    # Update the priors for gene harmonics to gene-specific means and stds
    S = data_to_fit.layers['spliced'].toarray()
    S_means = S.mean(axis=0)  # mean along axis 0 (presumably over cells)
    # NOTE(review): a column whose mean is 0 gives log(0) = -inf — confirm such genes
    # cannot reach this point.
    nu0 = np.log(S_means)

    # Mean prior
    S_frac_means=np.vstack((nu0, 0*nu0, 0*nu0))
    cycle_prior.set_means(S_frac_means)

    # Standard deviation prior
    nu0std = np.std(np.log(S+1), axis=0)/2
    S_frac_stds=np.vstack((nu0std, 0.5*nu0std, 0.5*nu0std))
    cycle_prior.set_stds(S_frac_stds)

    # Obtain a PCA prior for individual cell phases
    # The prior for cell cycle coordinates are taken from the first two components of the PCA
    # This is the only step of the model that relies on log transformed counts - but still no smoothing is performed
    data_to_fit.layers["S_sz_log"] = np.log(data_to_fit.layers["S_sz"]+1)
    phase_prior = phases.Phases.from_pca_heuristic(data_to_fit,
                                                genes_to_use=adata.var_names,
                                                layer='S_sz_log',
                                                concentration=5.0, plot=True, small_count=1)

    USE_GPU = True
    if USE_GPU and torch.cuda.is_available():
        print("Will use GPU")
        device = torch.device("cuda:0")
    else:
        print("Will use CPU")
        device = torch.device("cpu")

    # Create design matrix for dataset with a single sample/batch
    batch_design_matrix = preprocessing.make_design_matrix(adata, ids="batch")

    # Start from an empty parameter store so earlier fits cannot leak into this one
    pyro.clear_param_store()

    # Call a preprocessing function to metaparameters to provide to Pyro
    metapar = preprocessing.preprocess_for_phase_estimation(anndata=data_to_fit,
                                            cycle_obj=cycle_prior,
                                            phase_obj=phase_prior,
                                            design_mtx=batch_design_matrix,
                                            n_harmonics=1,
                                            device = device,
                                            with_delta_nu=False)

    # Define a pyro object for phase inference
    phase_fit = phase_inference_model.PhaseFitModel(metaparams=metapar)

    # Perform training using a decaying learning rate: `lrd` is the per-step factor
    # that takes the rate from initial_lr to final_lr over num_steps steps
    initial_lr = 0.03
    final_lr = 0.005
    gamma = final_lr / initial_lr
    lrd = gamma ** (1 / num_steps)
    adam = pyro.optim.ClippedAdam({'lr': initial_lr, 'lrd': lrd, 'betas': (0.80, 0.99)})

    phase_fit.fit(optimizer=adam, num_steps=num_steps)

    # Flatten before assignment so a phase array carrying an extra singleton
    # dimension still maps onto one value per cell
    fitted_phis = np.asarray(phase_fit.phase_pyro.phis).reshape(-1)
    adata.obs['coords'] = (fitted_phis + np.pi)/(2*np.pi)

    # `adata` is already a private copy, so it is returned directly
    return adata


adata = simulate_data(Nc=3000, Ng=350)

adata = velocycle_estimation(adata)

adata

In [None]:
# Parameter grid for the three-way comparison (chunter, PCA angle, VeloCycle)
# over number of genes (Ng) and number of cells (Nc).
#Ng_values = [100, 200, 350, 500, 750, 1000]
#Nc_values = [100, 200, 500, 1000, 3000]

Ng_values = [100, 200, 350, 500, 750, 1000]
Nc_values = [100, 200, 500, 1000, 3000]


# One result matrix per metric, indexed as [Nc index, Ng index]
grid_shape = (len(Nc_values), len(Ng_values))
diam_scores = np.zeros(grid_shape)
ratio_scores = np.zeros(grid_shape)
circ_correlation = np.zeros(grid_shape)
pca_circ_correlation = np.zeros(grid_shape)
velocycle_circ_correlation = np.zeros(grid_shape)

# NOTE(review): `chnt` is not imported in any cell visible above — confirm it is
# defined before this cell is run on a fresh kernel.
for i, Nc in enumerate(Nc_values):
    for j, Ng in enumerate(Ng_values):
        print(f'Nc: {Nc}, Ng: {Ng}')

        adata = simulate_data(Nc=Nc, Ng=Ng)

        # VeloCycle is given the freshly simulated data, before any of the
        # preprocessing below; it works on its own copy
        velocycle_adata = velocycle_estimation(adata)
        velocycle_circ_correlation[i,j] = compute_final_circular_corrcoef_with_plotting(velocycle_adata)

        # Attach the ground-truth harmonics to var, then stack spliced/unspliced genes
        adata.var['sim_sin'] = adata.uns['simulated_nu'][:,:,1]
        adata.var['sim_cos'] = adata.uns['simulated_nu'][:,:,2]
        adata = create_spliced_unspliced_adata(adata)

        # Gene filtering, per-cell normalisation and log1p (no scaling step is applied)
        sc.pp.filter_genes(adata, min_cells=3)
        sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
        sc.pp.log1p(adata)

        # Diffusion and highly-variable-gene filtering are disabled:
        # sce.pp.magic(adata, knn=10, t=2, random_state=RANDOM_STATE, n_jobs=4)
        # sc.pp.highly_variable_genes(adata, subset=True)

        # Two principal components
        sc.pp.pca(adata, n_comps=2, use_highly_variable=False, random_state=RANDOM_STATE)

        # Ring scores (diameter and ratio variants) on those two components
        diam_score = chnt.ring_score(adata, score_type='diameter', exponent=2, comp=np.arange(2))
        ratio_score = chnt.ring_score(adata, score_type='ratio', exponent=2, comp=np.arange(2))
        diam_scores[i, j], ratio_scores[i, j] = diam_score, ratio_score

        # Circular coordinate, aligned to the simulated phase and scored
        chnt.circular(adata, comp = [0,1])
        circ_correlation[i,j] = compute_final_circular_corrcoef_with_plotting(adata)

        # Baseline: the angle of each cell in the PC1/PC2 plane, rescaled from
        # arctan2's [-pi, pi] to [0, 1], then aligned and scored the same way
        angles = np.arctan2(adata.obsm['X_pca'][:, 0], adata.obsm['X_pca'][:, 1])
        pca_adata = adata.copy()
        pca_adata.obs['coords'] = ( angles + np.pi) / (2 * np.pi)
        pca_circ_correlation[i,j] = compute_final_circular_corrcoef_with_plotting(pca_adata)

# One annotated heatmap per metric, side by side
fig, axes = plt.subplots(1, 5, figsize=(35, 6))

heatmap_panels = [
    (diam_scores, 'Diameter Score Heatmap'),
    (ratio_scores, 'Ratio Score Heatmap'),
    (circ_correlation, 'Chunter Circular Correlation Score Heatmap'),
    (pca_circ_correlation, 'PCA Angle Circular Correlation Score Heatmap'),
    (velocycle_circ_correlation, 'Velocycle Circular Correlation Score Heatmap'),
]
for panel_ax, (panel_scores, panel_title) in zip(axes, heatmap_panels):
    sns.heatmap(panel_scores, xticklabels=Ng_values, yticklabels=Nc_values, ax=panel_ax, cmap='viridis', annot = True)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Ng (Number of Genes)')
    panel_ax.set_ylabel('Nc (Number of Cells)')

plt.tight_layout()
plt.show()

In [None]:
# Replicate run of the three-way comparison: 20 independent simulations at a single
# (Nc, Ng) setting. The full grid is kept below for reference.
#Ng_values = [100, 200, 350, 500, 750, 1000]
#Nc_values = [100, 200, 500, 1000, 3000]

Ng_values = [300]*20
Nc_values = [3000]


# One result matrix per metric, indexed as [Nc index, replicate index]
grid_shape = (len(Nc_values), len(Ng_values))
diam_scores = np.zeros(grid_shape)
ratio_scores = np.zeros(grid_shape)
circ_correlation = np.zeros(grid_shape)
pca_circ_correlation = np.zeros(grid_shape)
velocycle_circ_correlation = np.zeros(grid_shape)

# NOTE(review): `chnt` is not imported in any cell visible above — confirm it is
# defined before this cell is run on a fresh kernel.
for i, Nc in enumerate(Nc_values):
    for j, Ng in enumerate(Ng_values):
        print(f'Nc: {Nc}, Ng: {Ng}')

        adata = simulate_data(Nc=Nc, Ng=Ng)

        # VeloCycle is given the freshly simulated data, before any of the
        # preprocessing below; it works on its own copy
        velocycle_adata = velocycle_estimation(adata)
        velocycle_circ_correlation[i,j] = compute_final_circular_corrcoef_with_plotting(velocycle_adata)

        # Attach the ground-truth harmonics to var, then stack spliced/unspliced genes
        adata.var['sim_sin'] = adata.uns['simulated_nu'][:,:,1]
        adata.var['sim_cos'] = adata.uns['simulated_nu'][:,:,2]
        adata = create_spliced_unspliced_adata(adata)

        # Gene filtering, per-cell normalisation and log1p (no scaling step is applied)
        sc.pp.filter_genes(adata, min_cells=3)
        sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
        sc.pp.log1p(adata)

        # Diffusion and highly-variable-gene filtering are disabled:
        # sce.pp.magic(adata, knn=10, t=2, random_state=RANDOM_STATE, n_jobs=4)
        # sc.pp.highly_variable_genes(adata, subset=True)

        # Two principal components
        sc.pp.pca(adata, n_comps=2, use_highly_variable=False, random_state=RANDOM_STATE)

        # Ring scores (diameter and ratio variants) on those two components
        diam_score = chnt.ring_score(adata, score_type='diameter', exponent=2, comp=np.arange(2))
        ratio_score = chnt.ring_score(adata, score_type='ratio', exponent=2, comp=np.arange(2))
        diam_scores[i, j], ratio_scores[i, j] = diam_score, ratio_score

        # Circular coordinate, aligned to the simulated phase and scored
        chnt.circular(adata, comp = [0,1])
        circ_correlation[i,j] = compute_final_circular_corrcoef_with_plotting(adata)

        # Baseline: the angle of each cell in the PC1/PC2 plane, rescaled from
        # arctan2's [-pi, pi] to [0, 1], then aligned and scored the same way
        angles = np.arctan2(adata.obsm['X_pca'][:, 0], adata.obsm['X_pca'][:, 1])
        pca_adata = adata.copy()
        pca_adata.obs['coords'] = ( angles + np.pi) / (2 * np.pi)
        pca_circ_correlation[i,j] = compute_final_circular_corrcoef_with_plotting(pca_adata)

# One annotated heatmap per metric, side by side
fig, axes = plt.subplots(1, 5, figsize=(35, 6))

heatmap_panels = [
    (diam_scores, 'Diameter Score Heatmap'),
    (ratio_scores, 'Ratio Score Heatmap'),
    (circ_correlation, 'Chunter Circular Correlation Score Heatmap'),
    (pca_circ_correlation, 'PCA Angle Circular Correlation Score Heatmap'),
    (velocycle_circ_correlation, 'Velocycle Circular Correlation Score Heatmap'),
]
for panel_ax, (panel_scores, panel_title) in zip(axes, heatmap_panels):
    sns.heatmap(panel_scores, xticklabels=Ng_values, yticklabels=Nc_values, ax=panel_ax, cmap='viridis', annot = True)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Ng (Number of Genes)')
    panel_ax.set_ylabel('Nc (Number of Cells)')

plt.tight_layout()
plt.show()

In [None]:
# Box plots of the circular correlation scores of chunter, PCA and velocycle,
# one panel per method on a shared y scale.
fig, axes = plt.subplots(1, 3, figsize=(15, 8))

# Shared y ticks so the three panels are directly comparable
y_ticks = [0.87, 0.88, 0.89, 0.90, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97]

# seaborn draws one box per *column* of a 2-D array. The score matrices are indexed
# [Nc, replicate], so they are transposed to pool the replicates of each Nc into a
# single box; untransposed, a (1, 20) matrix yields 20 single-point boxes.
boxplot_panels = [
    (circ_correlation, 'Chunter Circular Correlation Score Boxplot'),
    (pca_circ_correlation, 'PCA Circular Correlation Score Boxplot'),
    (velocycle_circ_correlation, 'Velocycle Circular Correlation Score Boxplot'),
]
for ax, (method_scores, panel_title) in zip(axes, boxplot_panels):
    sns.boxplot(data=method_scores.T, orient='v', ax=ax)
    ax.set_yticks(y_ticks)
    ax.set_ylim(min(y_ticks), max(y_ticks))  # Ensure y-limits match the y-ticks
    ax.set_ylabel('Circular Correlation Score')
    ax.set_title(panel_title)

# Adjust layout to avoid overlap
plt.tight_layout()

# Show the plot
plt.show()


In [114]:
# Collect the coordinates of each method from whichever objects the last benchmark
# iteration left in the namespace (adata, pca_adata, velocycle_adata).
chunter_coords = adata.obs['coords']

# NOTE(review): the +0.1 offset (wrapped back into [0, 1)) is unexplained — presumably
# a manual shift for visual alignment in the scatter matrix; confirm.
pca_coords = (pca_adata.obs['coords']+0.1) % 1

velocycle_coords = velocycle_adata.obs['coords']

# Simulated phase divided by 2*pi, to put it on the same scale as the coords columns
simulated_coords = adata.obs['simulated_phis'] / (2*np.pi)




In [None]:
import matplotlib.pyplot as plt
import pandas as pd

def plot_pairwise_scatter(chunter_coords, pca_coords, velocycle_coords, simulated_coords):
    """
    Show every pairing of four coordinate sets as a 4x4 grid of scatter plots.

    The cell in row i, column j plots coordinate set j on the x-axis against
    coordinate set i on the y-axis. Diagonal cells carry only the name of the
    coordinate set, which doubles as the label for that row and column.

    Parameters:
    chunter_coords (pd.Series): Coordinates from chunter.
    pca_coords (pd.Series): Coordinates from PCA.
    velocycle_coords (pd.Series): Coordinates from velocycle.
    simulated_coords (pd.Series): Simulated coordinates.

    Returns:
    None: The figure is displayed with plt.show().
    """
    # One column per coordinate set; column order fixes the grid order
    coord_table = pd.DataFrame({
        'Chunter': chunter_coords,
        'PCA': pca_coords,
        'Velocycle': velocycle_coords,
        'Simulated': simulated_coords
    })

    names = coord_table.columns
    size = len(names)
    bottom_row = size - 1

    fig, grid = plt.subplots(size, size, figsize=(12, 12))
    plt.subplots_adjust(hspace=0.5, wspace=0.5)

    for row, col in ((r, c) for r in range(size) for c in range(size)):
        panel = grid[row, col]

        if row == col:
            # Diagonal: just the name of the coordinate set, no ticks
            panel.text(0.5, 0.5, names[row], fontsize=12, ha='center', va='center')
            panel.set_xticks([])
            panel.set_yticks([])
            continue

        # Off-diagonal: column variable on x, row variable on y
        panel.scatter(coord_table.iloc[:, col], coord_table.iloc[:, row], alpha=0.6, s=10)

        # Only the left-most column keeps y ticks and a y label
        if col == 0:
            panel.set_ylabel(names[row])
        else:
            panel.set_yticks([])

        # Only the bottom row keeps x ticks and an x label
        if row == bottom_row:
            panel.set_xlabel(names[col])
        else:
            panel.set_xticks([])

    plt.show()

# Draw the pairwise comparison of the chunter, PCA, velocycle and simulated coordinates.
plot_pairwise_scatter(chunter_coords, pca_coords, velocycle_coords, simulated_coords)


In [None]:
# Inspection only: display the repr of `adata`; nothing is modified here.
adata

In [None]:
# Two side-by-side scatter plots:
#   left  - the first two columns of adata.varm['PCs'] against each other
#   right - the simulated cosine values against the simulated sine values
# Each axis label names the quantity that is actually plotted on it, and the
# Pearson correlation of the plotted pair is shown in the panel's legend.
# NOTE(review): the plotted data is unchanged from before; only the labels and
# legend were made consistent with it. If the intent was to plot the PCs
# against sim_sin / sim_cos instead, change the scatter arguments accordingly.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# --- left panel: PCA component 1 vs PCA component 2 ---
corr = np.corrcoef(adata.varm['PCs'][:, 0], adata.varm['PCs'][:, 1])[0, 1]

# The label gives legend() an entry to display; a legend without any labelled
# artist is empty and triggers a matplotlib warning.
axs[0].scatter(adata.varm['PCs'][:, 0], adata.varm['PCs'][:, 1], label=f"R={corr:.2f}")
axs[0].set_xlabel('PCA PC 1')
axs[0].set_ylabel('PCA PC 2')
axs[0].legend()

# --- right panel: simulated cosine vs simulated sine ---
corr = np.corrcoef(adata.var['sim_cos'], adata.var['sim_sin'])[0, 1]

axs[1].scatter(adata.var['sim_cos'], adata.var['sim_sin'], label=f"R={corr:.2f}")
axs[1].set_xlabel('Simulated Cosine')
axs[1].set_ylabel('Simulated Sine')
axs[1].legend()

plt.show()

In [None]:
# Run chnt.leadlag on the first 300 columns of adata and keep the result as ll_adata.
# NOTE(review): 300 is hard-coded — presumably the variables that carry the
# simulated signal; confirm. `chnt` is not imported in the visible part of
# this notebook, so verify it is imported before this cell runs.
ll_adata = chnt.leadlag(adata[:,:300])

In [None]:
# Apply chnt.reverse to ll_adata; the return value is not assigned.
# NOTE(review): presumably this flips the direction of the inferred ordering
# in place — confirm against the chnt documentation.
chnt.reverse(ll_adata)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress

# Compare the first leadlag principal component (column 0 of
# ll_adata.varm['leadlag_pcs']) with the simulated values: its imaginary part
# against sim_sin and its real part against sim_cos, each with a
# least-squares line and the Pearson correlation in the legend.

# One figure with two side-by-side panels. plt.subplots() creates the figure
# itself; an extra plt.figure() call beforehand would only leave an empty
# figure in the cell output.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# Extract data for plots
imag_pc1 = np.imag(ll_adata.varm['leadlag_pcs'][:, 0])
real_pc1 = np.real(ll_adata.varm['leadlag_pcs'][:, 0])
sim_sin = ll_adata.var['sim_sin']
sim_cos = ll_adata.var['sim_cos']

# Fit line for panel 0; linregress also returns Pearson's r, so no separate
# np.corrcoef call is needed
slope1, intercept1, r_value1, p_value1, std_err1 = linregress(imag_pc1, sim_sin)
corr1 = r_value1
line1 = slope1 * imag_pc1 + intercept1

# Fit line for panel 1
slope2, intercept2, r_value2, p_value2, std_err2 = linregress(real_pc1, sim_cos)
corr2 = r_value2
line2 = slope2 * real_pc1 + intercept2

# Panel 0: imaginary part of leadlag PC 1 against sim_sin
ax[0].scatter(imag_pc1, sim_sin, alpha=0.6)
ax[0].plot(imag_pc1, line1, color='red', label=f"R={corr1:.2f}")
ax[0].set_xlabel('Imaginary Leadlag PC 1')
ax[0].set_ylabel('Simulated Sine')
ax[0].legend()

# Panel 1: real part of leadlag PC 1 against sim_cos
ax[1].scatter(real_pc1, sim_cos, alpha=0.6)
ax[1].plot(real_pc1, line2, color='red', label=f"R={corr2:.2f}")
ax[1].set_xlabel('Real Leadlag PC 1')
ax[1].set_ylabel('Simulated Cosine')
ax[1].legend()

# Display the plot
plt.tight_layout()
plt.show()


In [None]:
# Phase comparison: angle of each entry of the first leadlag PC against the
# angle of the simulated complex vector sim_cos + i*sim_sin.
leadlag_pcs = ll_adata.varm['leadlag_pcs'][:, 0]
sim_vector = ll_adata.var['sim_cos'] + 1j * ll_adata.var['sim_sin']

# Angles of both complex vectors
leadlag_phases, sim_phases = np.angle(leadlag_pcs), np.angle(sim_vector)

# Circular correlation between the two sets of angles (shown in the title)
circular_corr = circular_corrcoef(leadlag_phases, sim_phases)

# Scatter of simulated (x) against leadlag (y) phases
phase_fig, phase_ax = plt.subplots(figsize=(8, 6))
phase_ax.scatter(sim_phases, leadlag_phases, alpha=0.6)

phase_ax.set_xlabel('Simulated Phases')
phase_ax.set_ylabel('Leadlag Phases')
phase_ax.set_title(f'Phase Comparison: Simulated vs. Leadlag\nCircular Correlation: R={circular_corr:.2f}')

# Dashed reference lines through zero on both axes
phase_ax.axhline(0, color='gray', linestyle='--', linewidth=0.8)
phase_ax.axvline(0, color='gray', linestyle='--', linewidth=0.8)

phase_fig.tight_layout()
plt.show()

In [None]:
# Phase comparison for PCA: the first two columns of ll_adata.varm['PCs'] are
# combined into one complex number per row (PC 1 + i * PC 2) and its angle is
# compared with the angle of the simulated vector sim_cos + i*sim_sin.
# The results are kept in PCA-specific names so they do not overwrite the
# leadlag phases held in `leadlag_phases`.
pcs = ll_adata.varm['PCs'][:, 0] + 1j * ll_adata.varm['PCs'][:, 1]
sim_vector = ll_adata.var['sim_cos'] + 1j * ll_adata.var['sim_sin']

# Constant rotation (in radians) applied to the PCA vector before taking the angle.
# NOTE(review): 1.1 is hard-coded — presumably tuned by eye so the PCA phases
# line up with the simulated ones in the scatter plot; confirm.
pca_phase_offset = 1.1

# Calculate the phases (angles) of the complex vectors
pca_phases = np.angle(np.exp(1j * pca_phase_offset) * pcs)
sim_phases = np.angle(sim_vector)

# Compute the circular correlation coefficient
circular_corr = circular_corrcoef(pca_phases, sim_phases)

# Create the scatter plot
plt.figure(figsize=(8, 6))
plt.scatter(sim_phases, pca_phases, alpha=0.6)

# Add labels and title
plt.xlabel('Simulated Phases')
plt.ylabel('PCA-based Phases')
plt.title(f'Phase Comparison: Simulated vs. PCA\nCircular Correlation: R={circular_corr:.2f}')
plt.axhline(0, color='gray', linestyle='--', linewidth=0.8)
plt.axvline(0, color='gray', linestyle='--', linewidth=0.8)

# Display the plot
plt.tight_layout()
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress

# Amplitude comparison: magnitude of each entry of the first leadlag PC
# against the magnitude of the simulated vector sim_cos + i*sim_sin, with a
# least-squares line and its correlation coefficient.
leadlag_pcs = ll_adata.varm['leadlag_pcs'][:, 0]
sim_vector = ll_adata.var['sim_cos'] + 1j * ll_adata.var['sim_sin']

# Magnitudes of both complex vectors
leadlag_amplitudes, sim_amplitudes = np.abs(leadlag_pcs), np.abs(sim_vector)

# Regress leadlag amplitude on simulated amplitude
slope, intercept, r_value, p_value, std_err = linregress(sim_amplitudes, leadlag_amplitudes)
correlation_coefficient = r_value
line_of_best_fit = intercept + slope * sim_amplitudes

amp_fig, amp_ax = plt.subplots(figsize=(8, 6))

# Data first, fitted line on top
amp_ax.scatter(sim_amplitudes, leadlag_amplitudes, alpha=0.6, label='Data Points')
amp_ax.plot(sim_amplitudes, line_of_best_fit, color='red', label=f'Best Fit (R={correlation_coefficient:.2f})')

amp_ax.set_xlabel('Simulated Amplitudes')
amp_ax.set_ylabel('Leadlag Amplitudes')
amp_ax.set_title('Amplitude Comparison with Best Fit and Correlation')

# Dashed reference lines through zero on both axes
amp_ax.axhline(0, color='gray', linestyle='--', linewidth=0.8)
amp_ax.axvline(0, color='gray', linestyle='--', linewidth=0.8)

amp_ax.legend()

amp_fig.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress

# Amplitude comparison for PCA: the first two columns of ll_adata.varm['PCs']
# are combined into one complex number per row (PC 1 + i * PC 2) and its
# magnitude is compared with the magnitude of sim_cos + i*sim_sin.
# PCA-specific names are used so this cell does not overwrite the leadlag
# results held in `leadlag_pcs` / `leadlag_amplitudes`.
pca_vector = ll_adata.varm['PCs'][:, 0] + 1j * ll_adata.varm['PCs'][:, 1]
sim_vector = ll_adata.var['sim_cos'] + 1j * ll_adata.var['sim_sin']

# Calculate the amplitudes (magnitudes) of the complex vectors
pca_amplitudes = np.abs(pca_vector)
sim_amplitudes = np.abs(sim_vector)

# Calculate the line of best fit and correlation coefficient
slope, intercept, r_value, p_value, std_err = linregress(sim_amplitudes, pca_amplitudes)
line_of_best_fit = slope * sim_amplitudes + intercept
correlation_coefficient = r_value

# Create the scatter plot
plt.figure(figsize=(8, 6))
plt.scatter(sim_amplitudes, pca_amplitudes, alpha=0.6, label='Data Points')

# Plot the line of best fit
plt.plot(sim_amplitudes, line_of_best_fit, color='red', label=f'Best Fit (R={correlation_coefficient:.2f})')

# Add labels and title
plt.xlabel('Simulated Amplitudes')
plt.ylabel('PCA Amplitudes')
plt.title('Amplitude Comparison with Best Fit and Correlation')
plt.axhline(0, color='gray', linestyle='--', linewidth=0.8)
plt.axvline(0, color='gray', linestyle='--', linewidth=0.8)

# Add legend
plt.legend()

# Display the plot
plt.tight_layout()
plt.show()


In [None]:
# Inspection only: display the `genes` attribute of cycle_pyro.
cycle_pyro.genes

In [200]:
# Inspection only: display the `omegas` attribute of phase_pyro
# (a single attribute lookup, not a listing of the object's methods).
phase_pyro.omegas


In [None]:
import numpy as np
from skimage.measure import marching_cubes
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

# Step 1: Define the function
def f(x, y, z):
    """Implicit function whose zero level set is the sphere of radius 1 centred at the origin."""
    return x**2 + y**2 + z**2 - 1  # Example: A sphere of radius 1

# Step 2: Create a 3D grid
x = np.linspace(-2, 2, 100)  # Grid range for x
y = np.linspace(-2, 2, 100)  # Grid range for y
z = np.linspace(-2, 2, 100)  # Grid range for z

# indexing='ij' makes values[i, j, k] == f(x[i], y[j], z[k]), so the array axes
# are ordered (x, y, z) exactly like the spacing tuple given to marching_cubes.
# The default 'xy' indexing would swap the first two axes.
X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

# Compute the function values on the grid
values = f(X, Y, Z)

# Step 3: Use Marching Cubes to approximate the level set
# Level set value
level = 0
spacing = (x[1] - x[0], y[1] - y[0], z[1] - z[0])
# The fourth return value (per-vertex values) gets its own name so it does not
# overwrite the grid of function values.
verts, faces, normals, vert_values = marching_cubes(values, level, spacing=spacing)

# marching_cubes measures vertices from the first grid point (index * spacing),
# so shift them by the grid origin to obtain real (x, y, z) coordinates.
# Without the shift the sphere would be drawn centred at (2, 2, 2).
verts = verts + np.array([x[0], y[0], z[0]])

# Step 4: Visualize the result
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

# Create a 3D polygon collection (one triangle per face)
mesh = Poly3DCollection(verts[faces], alpha=0.7, edgecolor='k')
ax.add_collection3d(mesh)

# Set the limits of the plot to the extent of the grid
ax.set_xlim(-2, 2)
ax.set_ylim(-2, 2)
ax.set_zlim(-2, 2)

# Add labels
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Scatter of the simulated phases against the leadlag phases.

# Extract the complex vectors
leadlag_pcs = ll_adata.varm['leadlag_pcs'][:, 0]
sim_vector = ll_adata.var['sim_cos'] + 1j * ll_adata.var['sim_sin']

# Calculate the phases (angles) of the complex vectors
leadlag_phases = np.angle(leadlag_pcs)
sim_phases = np.angle(sim_vector)

# Create the scatter plot.
# The y-values are the leadlag phases computed above, matching the y-axis
# label and the title.
# NOTE(review): if a comparison with phase_pyro.phis was intended instead,
# plot it in its own figure with matching labels, and confirm it has one entry
# per element of sim_phases.
plt.figure(figsize=(8, 6))
plt.scatter(sim_phases, leadlag_phases, alpha=0.6)

# Add labels and title
plt.xlabel('Simulated Phases (Sim Cos + i Sim Sin)')
plt.ylabel('Leadlag Phases (Leadlag PCs)')
plt.title('Phase Comparison: Simulated vs. Leadlag')
plt.axhline(0, color='gray', linestyle='--', linewidth=0.8)
plt.axvline(0, color='gray', linestyle='--', linewidth=0.8)

# Display the plot
plt.tight_layout()
plt.show()
