# Cell cycle dataset

This notebook accompanies the paper "Single-Cell Trajectory Inference for Detecting Transient Events in Biological Processes" by Hutton and Meyer. The data are from the 2025 paper by Bubis et al., "[Challenging the Astral mass analyzer to quantify up to 5,300 proteins per single cell at unseen accuracy to uncover cellular heterogeneity](https://doi.org/10.1038/s41592-024-02559-1)".

In [None]:
import numpy as np
import pandas as pd
from wavelet_pseudotime.windowing import GaussianWindow
from importlib import reload
import wavelet_pseudotime.synthetic
import wavelet_pseudotime.process
import wavelet_pseudotime.wavelets
from wavelet_pseudotime.wavelets import mag_median
from matplotlib import pyplot as plt
import scanpy as sc
import anndata as ad
import pypsupertime

import os
from datetime import datetime
# Date-stamped results directory; every figure and gene list below is written into it.
date_str = datetime.now().strftime("%Y_%m_%d")
r_dir = f"{date_str}_cellcycle"
# exist_ok=True replaces the check-then-create pattern, so the cell is safe to re-run
# and cannot fail if the directory appears between the check and the mkdir.
os.makedirs(r_dir, exist_ok=True)

In [None]:
# The import cell only imports wavelet_pseudotime.synthetic/.process/.wavelets, so
# `wavelet_pseudotime.load_data` would resolve only if one of those modules happens to
# import it as a side effect. Import it explicitly so the lookup below cannot break.
import wavelet_pseudotime.load_data

thresh = 4
window_params = {"n_windows": 30, "sigma": 0.03, "max_distance": 0.11}

# Run the cell-cycle pipeline. The outputs are used below as:
#   scoresr - mapping gene -> score
#   psdr    - windowed signal along pseudotime, indexed by gene
#   adata   - annotated data with obs["psupertime"] and obs["phase"]
#   psuper  - fitted psupertime model (used for the coefficient plot)
wr, scoresr, psdr, adata, psuper = wavelet_pseudotime.process.pipeline_astral_cellcycle(
    wavelet_pseudotime.load_data.load_astral,
    window_params=window_params,
    scoring_threshold=thresh,
    coverage_threshold=0.0,
    save_name=f"{r_dir}/astral2.h5ad",
    exclude_pt_ends=(0.1, 0.9),
    repeat=True,
)

In [None]:
num_bins = 20

# Split the cells into equal-count (quantile) bins of psupertime; retbins also
# returns the bin edges in psupertime units.
adata.obs['psupertime_bin'], bin_edges = pd.qcut(
    adata.obs['psupertime'], q=num_bins, labels=False, retbins=True
)

# Centre of every bin, used as the x coordinate of that bin.
bin_midpoints = (bin_edges[1:] + bin_edges[:-1]) / 2

# Rows are bins, columns are phases; each row is normalised to sum to 1.
phase_counts = adata.obs.groupby(['psupertime_bin', 'phase']).size().unstack(fill_value=0)
phase_proportions = phase_counts.div(phase_counts.sum(axis=1), axis=0)

# Kept for the large figure at the end of the notebook.
fig9_x, fig9_y, fig9_labels = bin_midpoints, phase_proportions.T.values, phase_proportions.columns

# Stacked phase proportions against the actual psupertime values.
plt.figure(figsize=(8, 5))
plt.stackplot(fig9_x, fig9_y, labels=fig9_labels, alpha=0.8)

plt.xlabel("Psupertime")
plt.ylabel("Proportion of Cells")
plt.title("Stacked Cell Cycle Phase Proportions Across Psupertime")
plt.legend(title="Phase", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(False)


In [None]:
# Display the table of phase proportions (one row per psupertime bin, one column per phase).
phase_proportions

In [None]:
# Number of rows in the proportion table, i.e. the number of psupertime bins it contains.
phase_proportions.shape[0]

In [None]:
# Locate the two phase transitions and express them on the signal index axis.

# First bin in which G1 is at most half of the cells.
last_g1_idx = np.flatnonzero((phase_proportions["G1"] <= 0.5).to_numpy())[0]  # 5
# First later bin in which S is at most half of the cells. The search starts at
# last_g1_idx + 1 but only last_g1_idx is added back, so the value is one bin before
# that match.  NOTE(review): confirm that the missing +1 is intended.
last_s_idx = np.flatnonzero((phase_proportions["S"].iloc[last_g1_idx + 1:] <= 0.5).to_numpy())[0] + last_g1_idx  # 11

# From here on the *_idx names hold fractions along pseudotime, not bin indices.
last_g1_idx, last_s_idx = (
    bin_idx / phase_proportions.shape[0] for bin_idx in (last_g1_idx, last_s_idx)
)

# Index range of the windowed signals; 29 is presumably n_windows - 1 — TODO confirm.
pt_min, pt_max = 0, 29
last_g1 = pt_min + (pt_max - pt_min) * last_g1_idx
last_s = pt_min + (pt_max - pt_min) * last_s_idx

In [None]:
# Count how many observations carry each value of the "phase" annotation.
adata.obs["phase"].value_counts()

In [None]:
# Keep only the genes whose score exceeds the (stricter) threshold.
# Note that this reuses the name `thresh` from the pipeline cell.
thresh = 7
g_above_thresh = [gene for gene, score in scoresr.items() if score > thresh]
print(len(g_above_thresh))

In [None]:
# Windowed signals of the selected genes, in the same order as g_above_thresh.
pt_above_thresh = [psdr[g] for g in g_above_thresh]

In [None]:
# Distribution of the scores over all scored genes.
all_scores = [*scoresr.values()]
plt.hist(all_scores, bins=500);
plt.xlabel("Score")
plt.ylabel("Frequency")
plt.title("Score distribution for genes in cell cycle dataset")

In [None]:
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import numpy as np
from collections import defaultdict as dd
# Cluster the selected signals into 5 groups; labels[i] belongs to g_above_thresh[i].
kmeans = KMeans(n_clusters=5, random_state=0)
labels = kmeans.fit(pt_above_thresh).labels_

In [None]:
# c maps cluster label -> list of genes assigned to that cluster.
c = dd(list)
for g, cluster_label in zip(g_above_thresh, labels):
    c[cluster_label].append(g)

In [None]:
# One figure per cluster: every member gene's windowed signal, with the two
# phase transitions drawn as dashed vertical lines.
for idx in range(5):
    fig, ax = plt.subplots()
    for g in c[idx]:
        ax.plot(psdr[g])
    for boundary in (last_g1, last_s):
        ax.axvline(boundary, linestyle="--")
    ax.set_title(f"Cluster {idx}")
    # Fixed tick positions on the signal index axis, labelled with the phases.
    ax.set_xticks([5, 15, 25], ["G1", "S", "G2M"])
    fig.savefig(f"{r_dir}/cluster_{idx}.png")

In [None]:
# Write the genes of each cluster to its own text file, one gene per line.
for idx in range(5):
    # The context manager closes the file even if a write raises.
    with open(f"{r_dir}/genes_group_{idx}.txt", "w") as f:
        for g in c[idx]:
            f.write(f"{g}\n")

In [None]:
## For writing out all genes

# f = open(f"{r_dir}/astral_genes.txt", "w")
# for g in g_above_thresh:
#     print(g)
#     f.write(f"{g}\n")
# f.close()

In [None]:
# Express the G1->S and S->G2/M transitions in pseudotime units instead of signal index.
#
# NOTE(review): the bins come from pd.qcut, so they are equal-count (quantile) bins. A
# bin-index fraction is therefore a fraction of cells, not of the pseudotime range; the
# linear mapping below only matches the bin positions if psupertime is roughly uniformly
# distributed. bin_edges would give the positions directly — confirm intent.

g1_at_most_half = (phase_proportions["G1"] <= 0.5).to_numpy()
last_g1_idx = np.flatnonzero(g1_at_most_half)[0]  # 5
s_at_most_half = (phase_proportions["S"].iloc[last_g1_idx + 1:] <= 0.5).to_numpy()
last_s_idx = np.flatnonzero(s_at_most_half)[0] + last_g1_idx  # 11

# Convert the bin indices into fractions along pseudotime.
last_g1_idx = last_g1_idx / phase_proportions.shape[0]
last_s_idx = last_s_idx / phase_proportions.shape[0]

# Observed pseudotime range.
pt_min = np.min(adata.obs["psupertime"])
pt_max = np.max(adata.obs["psupertime"])

last_g1_pt = pt_min + (pt_max - pt_min) * last_g1_idx
last_s_pt = pt_min + (pt_max - pt_min) * last_s_idx

In [None]:
# Windowed signal (line) and per-cell values (dots) for one gene along pseudotime.
g = "AK6"
cell_pt = adata.obs["psupertime"]
# The windowed signal is assumed to span the observed pseudotime range evenly.
x = np.linspace(np.min(cell_pt), np.max(cell_pt), len(psdr[g]))
plt.plot(x, psdr[g])
plt.plot(cell_pt, adata[:, g].X[:, 0], ".")
plt.axvline(last_g1_pt, linestyle=":", label="Transition to S")
plt.axvline(last_s_pt, linestyle="--", label="Transition to G2/M")
plt.legend()
plt.title(f"{g} expression along pseudotime")
plt.xlabel("Pseudotime")
plt.ylabel("Gene expression")
plt.savefig(f"{r_dir}/{g}_expression.png")

In [None]:
# Same plot as for the previous gene, for ATL2.
g = "ATL2"
cell_pt = adata.obs["psupertime"]
x = np.linspace(np.min(cell_pt), np.max(cell_pt), len(psdr[g]))
plt.plot(x, psdr[g])
plt.plot(cell_pt, adata[:, g].X[:, 0], ".")
# Transitions hard-coded as bins 5 and 11 of 20, mapped linearly onto the pseudotime range.
# NOTE(review): these match last_g1_pt / last_s_pt only while the computed bins stay 5 and 11.
plt.axvline(pt_min + (pt_max - pt_min)*5/20, linestyle=":", label="Transition to S")
plt.axvline(pt_min + (pt_max - pt_min)*11/20, linestyle="--", label="Transition to G2/M")
plt.legend()
plt.title(f"{g} expression along pseudotime")
plt.xlabel("Pseudotime")
plt.ylabel("Gene expression")
plt.savefig(f"{r_dir}/{g}_expression.png")

# Known cell cycle markers

In [None]:
# Read the marker list, one gene per line; the context manager guarantees the
# file is closed even if reading fails.
with open('regev_lab_cell_cycle_genes.txt', "r") as f:
    cell_cycle_genes = [x.strip() for x in f]
# The first 43 entries are treated as S markers and the rest as G2/M markers.
# NOTE(review): the split position is hard-coded — confirm it matches the file.
s_genes = cell_cycle_genes[:43]
g2m_genes = cell_cycle_genes[43:]
# Keep only markers that are present in the dataset (exact, case-sensitive match).
cell_cycle_genes = [x for x in cell_cycle_genes if x in adata.var_names]

In [None]:
# Match markers to the dataset case-insensitively, but report the matches using the
# dataset's own spelling: the names are used below as keys into psdr and adata, so
# upper-casing them could produce names that do not exist in the data.

# which S genes are in our data?
s_lower = {s.lower() for s in s_genes}
s_in_data = sorted({v for v in adata.var_names if v.lower() in s_lower})

# which g2 genes are in our data?
g2_lower = {s.lower() for s in g2m_genes}
g2_in_data = sorted({v for v in adata.var_names if v.lower() in g2_lower})

In [None]:
# Report how many S markers were matched in the dataset.
print(f"Number of S-related genes in the dataset: {len(s_in_data)}")

In [None]:
# Report how many G2/M markers were matched in the dataset.
print(f"Number of G2-related genes in the dataset: {len(g2_in_data)}")

In [None]:
# Grid of small panels, one per matched S marker: windowed signal plus per-cell values.
# The grid is fixed at 8 x 3, so it holds at most 24 genes.
fig, axs = plt.subplots(8, 3)
fig.set_size_inches(8, 12)
cell_pt = adata.obs["psupertime"]
for idx, g in enumerate(s_in_data):
    ax = axs[np.unravel_index(idx, axs.shape)]
    x = np.linspace(np.min(cell_pt), np.max(cell_pt), len(psdr[g]))
    ax.plot(x, psdr[g], label="Windowed signal")
    ax.plot(cell_pt, adata[:, g].X[:, 0], ".", label="Cell data")
    # Transitions hard-coded as bins 5 and 11 of 20, mapped onto the pseudotime range.
    ax.axvline(pt_min + (pt_max - pt_min)*5/20, linestyle=":", label="Transition to S")
    ax.axvline(pt_min + (pt_max - pt_min)*11/20, linestyle="--", label="Transition to G2/M")
    ax.set_title(f"{g}")
    # A single legend for the whole grid, attached to the twelfth panel.
    if idx == 11:
        ax.legend(bbox_to_anchor=(1, 1))
# Leave room at the top for the figure title.
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
fig.suptitle("S genes along pseudotime")

plt.savefig(f"{r_dir}/s_gene_expression.png")

In [None]:
# Grid of small panels, one per matched G2/M marker: windowed signal plus per-cell values.
num_plots = len(g2_in_data)
num_columns = 3
# Ceiling division: enough rows to hold every gene.
num_rows = -(-num_plots // num_columns)

# squeeze=False keeps axs two-dimensional even when there is a single row, so the
# (row, column) indexing below works for any number of genes >= 1.
fig, axs = plt.subplots(num_rows, num_columns, squeeze=False)
fig.set_figheight(4/3*num_rows)
fig.set_figwidth(8)
for idx, g in enumerate(g2_in_data):
    i, j = np.unravel_index(idx, axs.shape)
    ax = axs[i, j]
    x = np.linspace(np.min(adata.obs["psupertime"]), np.max(adata.obs["psupertime"]), len(psdr[g]))
    ax.plot(x, psdr[g], label="Windowed signal")
    ax.plot(adata.obs["psupertime"], adata[:, g].X[:, 0], ".", label="Cell data")
    # Transitions hard-coded as bins 5 and 11 of 20, mapped onto the pseudotime range.
    ax.axvline((pt_max - pt_min)*5/20 + pt_min, linestyle=":", label="Transition to S")
    ax.axvline((pt_max - pt_min)*11/20 + pt_min, linestyle="--", label="Transition to G2/M")
    ax.set_title(f"{g}")
    # A single legend for the whole grid, attached to the twelfth panel (if there is one).
    if idx == 11:
        ax.legend(bbox_to_anchor=(1, 1))
# Leave room at the top for the figure title.
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
fig.suptitle("G2M genes along pseudotime")

plt.savefig(f"{r_dir}/g2m_gene_expression.png")

In [None]:
from gprofiler import GProfiler

In [None]:
from gprofiler import GProfiler

# Quick look at the enrichment of a single cluster (cluster 0). The per-cluster CSV
# files are written by the full enrichment loop further down, not here.
idx = 0
gp = GProfiler(return_dataframe=True)
results = gp.profile(organism="hsapiens", query=c[idx], sources=['GO:BP', 'GO:MF', 'GO:CC'])
# sort_values returns a new frame rather than sorting in place, so keep the result.
results = results.sort_values(by="p_value", ascending=True)
sig_res = results[results["p_value"] < 0.05]
sig_res.head()

In [None]:
def plot_df(df: pd.DataFrame, title: str = None, save=None, ax=None, pmin=None, pmax=None, sources=None, annotated_names: list[str] = None
) -> None:
    """
    Draw an enrichment bubble plot: one circle per row of ``df``.

    Circles are placed at x = -log10(p_value) and grouped on the y axis by their
    'source' value (one band per source, with a fixed vertical jitter so circles in
    the same band overlap less). The marker size scales with 'intersection_size',
    and a legend shows three reference sizes (minimum, midpoint, maximum).

    Parameters:
        df (pd.DataFrame): must contain the columns
            - 'source': category that selects the y-axis band.
            - 'p_value': strictly positive values.
            - 'intersection_size': numeric; scales the marker size.
            - 'name': row label; only used to select rows for annotation.
            The frame is copied, so the caller's object is not modified.
        title (str): axes title; "Function Enrichment Analysis" when None.
        save: if not None, passed to ``savefig`` of the figure that owns ``ax``.
        ax: axes to draw on; a new figure and axes are created when None.
        pmin, pmax: x-axis limits in -log10(p_value) units; applied only when both are given.
        sources: the full set of source categories, so that several plots share the
            same y layout; defaults to the sources present in ``df``.
        annotated_names (list[str]): values of 'name' whose circles get a text
            label with an arrow.

    Raises:
        ValueError: if any p_value is <= 0 (the logarithm would be undefined).
    """
    fontsize = 16
    size_scale = 10  # marker size = intersection_size * size_scale

    # Validate first so that invalid input does not leave an empty figure behind.
    if (df["p_value"] <= 0).any():
        raise ValueError("All p_value entries must be positive so that -log10 can be computed.")

    if ax is None:
        _, ax = plt.subplots()
    # Layout and saving act on the figure that owns `ax`, which is not necessarily
    # pyplot's current figure.
    fig = ax.figure

    df = df.copy()  # Avoid modifying the original DataFrame
    df["neg_log10"] = -np.log10(df["p_value"])

    # Map every source to a base y position 1, 2, 3, ...
    if sources is None:
        unique_sources = sorted(df["source"].unique())
    else:
        unique_sources = sorted(np.unique(sources))
    source_to_index = {source: idx for idx, source in enumerate(unique_sources, start=1)}
    df["base_y"] = df["source"].map(source_to_index)

    # Reproducible vertical jitter. A private RandomState(0) yields the same numbers as
    # np.random.seed(0) followed by np.random.uniform, without resetting the global
    # random state that other code may depend on.
    jitter = np.random.RandomState(0).uniform(-0.2, 0.2, size=len(df))
    df["y_pos"] = df["base_y"] + jitter

    # One scatter call per source so that each group gets its own colour and label.
    for source in unique_sources:
        subset = df[df["source"] == source]
        ax.scatter(
            subset["neg_log10"],
            subset["y_pos"],
            s=subset["intersection_size"] * size_scale,
            alpha=0.7,
            label=source,
            edgecolors="w"
        )

    ax.set_xlabel("-log10(p_value)", fontsize=fontsize)
    ax.set_yticks(list(source_to_index.values()), list(source_to_index.keys()), fontsize=fontsize)
    # One unit of padding below the first and above the last band ([0, 4] for 3 sources).
    ax.set_ylim([0, len(unique_sources) + 1])
    if title is None:
        ax.set_title("Function Enrichment Analysis", fontsize=fontsize)
    else:
        ax.set_title(title, fontsize=fontsize)

    # Label the requested terms with an arrow pointing at their circle.
    if annotated_names:
        offset_x = 0.5
        offset_y = 0.5
        for _, row in df.iterrows():
            if row["name"] in annotated_names:
                x_point = row["neg_log10"]
                y_point = row["y_pos"]
                ax.annotate(
                    row["name"],
                    xy=(x_point, y_point),
                    xytext=(x_point + offset_x, y_point + offset_y),
                    arrowprops=dict(facecolor="black", arrowstyle="->"),
                    fontsize=10,
                    bbox=dict(boxstyle="round,pad=0.3", fc="yellow", alpha=0.5)
                )

    # Size legend with three reference markers. Skipped for an empty frame, where
    # min/max are NaN and cannot be converted to an integer.
    if len(df) > 0:
        size_min = df["intersection_size"].min()
        size_max = df["intersection_size"].max()
        size_mid = int((size_min + size_max) / 2)
        sizes = [size_min, size_mid, size_max]
        # Empty scatters only provide legend handles; they draw nothing.
        markers = [
            ax.scatter([], [], s=size * size_scale, color="gray", alpha=0.7, edgecolors="w")
            for size in sizes
        ]
        labels = [f"{size}" for size in sizes]
        ax.legend(markers, labels, title="Intersection Size", bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0)

    # Switch the grid on explicitly; a bare grid() call toggles the current state instead.
    ax.grid(True)
    if pmin is not None and pmax is not None:
        ax.set_xlim([pmin, pmax])
    fig.tight_layout()
    if save is not None:
        fig.savefig(save)


In [None]:
# Run GO enrichment for each of the 5 clusters. results[idx] keeps the full result
# table of cluster idx; the significant terms are also written to one CSV per cluster.
results = []
sources = ['GO:BP', 'GO:MF', 'GO:CC']
gp = GProfiler(return_dataframe=True)  # one client is enough for all queries
for idx in range(5):
    cluster_res = gp.profile(organism="hsapiens", query=c[idx], sources=sources)
    # sort_values returns a new frame rather than sorting in place, so keep the result.
    cluster_res = cluster_res.sort_values(by="p_value", ascending=True)
    results.append(cluster_res)
    sig_res = cluster_res[cluster_res["p_value"] < 0.05]
    # Clusters without any significant term get no CSV file.
    if sig_res.shape[0] == 0:
        continue
    sig_res[["source", "native", "name", "p_value", "intersection_size"]].to_csv(f"{r_dir}/astral_enrichment_cluster{idx}.csv")

In [None]:
# Enrichment figure: one panel per cluster that has significant terms, all panels
# sharing the same x-range.
fontsize = 16

# Terms to label, indexed by cluster number.
annot_lists = [
    ["fatty acid catabolic process", "phagocytic vesicle membrane"],
    ["DNA replication", "nuclear chromosome"],
    ["N-acylsphingosine amidohydrolase activity", "tertiary granule lumen"],
]

# Significant terms per cluster, computed once and reused for the range and the panels.
sig_by_cluster = [r[r["p_value"] < 0.05] for r in results]

# Smallest and largest significant p-value over all clusters, widened a little so
# that the extreme circles do not sit on the axes border.
p_min = np.inf
p_max = -np.inf
for sig in sig_by_cluster:
    if sig.shape[0] == 0:
        continue
    p_min = min(p_min, sig["p_value"].min())
    p_max = max(p_max, sig["p_value"].max())
p_min *= 0.8
p_max *= 1.3

# At least 3 panels; more if more than 3 clusters have significant terms, so that
# indexing axs below cannot run out of range.
n_panels = max(3, sum(sig.shape[0] > 0 for sig in sig_by_cluster[:5]))
fig, axs = plt.subplots(n_panels)
fig.set_figheight(8)

fig_idx = 0  # next free panel; advances only for clusters that are plotted
for idx in range(5):
    sig_res = sig_by_cluster[idx]
    if sig_res.shape[0] == 0:
        continue
    # Clusters without an annotation list are plotted without labels.
    annotated = annot_lists[idx] if idx < len(annot_lists) else None
    # The x axis is -log10(p), so the largest p-value gives the lower limit.
    plot_df(sig_res, title=f"Enrichment for Cluster {idx}", ax=axs[fig_idx], pmin=-np.log10(p_max), pmax=-np.log10(p_min), sources=sources, annotated_names=annotated)
    fig_idx += 1
fig.savefig(f"{r_dir}/fig9_cellcycle_enrichment.png")
fig.savefig(f"{r_dir}/fig9_cellcycle_enrichment.svg")

# Figure assembly

In [None]:
import matplotlib.gridspec as gridspec
from string import ascii_uppercase

In [None]:
# Display the summary of the annotated data object.
adata

In [None]:
# Compute a UMAP embedding (stored in adata) and plot it.
# NOTE(review): sc.tl.umap presumably relies on a neighbor graph already present in
# adata from the pipeline — confirm.
sc.tl.umap(adata)
sc.pl.umap(adata)

In [None]:
from collections import defaultdict as dd
# Load the PSCS result and collect, for clusters 0-3, the member genes and their
# pseudotime signals.
adata_pscs = sc.read_h5ad("pscs_cellcycle.h5ad")
var_list = list(adata_pscs.var_names)
cluster_genes = {}
pt_signals = dd(list)
for idx in range(4):
    # Genes whose "te_cluster" entry equals this cluster index.
    in_cluster = (adata_pscs.uns["te_cluster"] == idx).values
    cluster_genes[idx] = list(adata_pscs.uns["te_cluster"].loc[in_cluster].index)
    print(len(cluster_genes[idx]))
    for g in cluster_genes[idx]:
        # The gene's position in var_names selects its column of the signal matrix.
        g_idx = var_list.index(g)
        pt_signals[idx].append(adata_pscs.uns["pseudotime_signals"][:, g_idx])
sc.tl.umap(adata_pscs)
sc.pl.umap(adata_pscs)

In [None]:
# Assemble the multi-panel figure (panels A-G) on a 4 x 2 grid.
fig = plt.figure(figsize=(10, 12))
gs = gridspec.GridSpec(4, 2, figure=fig)
fontsize = 16

# Panel D: top genes identified by psupertime.
ax = fig.add_subplot(gs[1, 1])
psupertime_figure = psuper.plot_identified_gene_coefficients(adata, n_top=10, ax=ax)
ax.set_title("Genes for psupertime", fontsize=fontsize)
ax.text(0.05, 0.15, ascii_uppercase[3], transform=ax.transAxes,
        fontsize=16, fontweight='bold', va='top', ha='left')

# Panel E: stacked phase proportions against psupertime, spanning both columns.
ax = fig.add_subplot(gs[2, :])
ax.stackplot(bin_midpoints, phase_proportions.T.values, labels=phase_proportions.columns, alpha=0.8)
ax.set_xlabel("Pseudotime", fontsize=fontsize)
ax.set_ylabel("Proportion of Cells", fontsize=fontsize)
ax.set_title("Cell Cycle Phase Proportions Across Pseudotime", fontsize=fontsize)
ax.legend(title="Phase", loc='upper right')
ax.grid(False)
ax.set_xlim(bin_midpoints[0], bin_midpoints[-1])
# Remove all margins from both axes.
ax.margins(x=0, y=0)
# Adjust subplot parameters to use all the figure area.
plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
ax.text(0.025, 0.95, ascii_uppercase[4], transform=ax.transAxes,
        fontsize=16, fontweight='bold', va='top', ha='left')

# Panels F and G: pseudotime signals of PSCS clusters 1 and 3.
# Tick positions sit in the middle of each phase segment on the signal index axis.
# The signal length is read from a fixed gene instead of whichever gene an earlier
# cell happened to leave in `g`, so this cell does not depend on leftover loop state.
# NOTE(review): assumes all psdr signals have the same length — confirm.
n_windows = len(psdr["PCNA"])
g1_tick = last_g1/2
s_tick = (last_s + last_g1)/2
g2_tick = (n_windows + last_s)/2
for plot_idx, idx in enumerate([1, 3]):
    ax = fig.add_subplot(gs[3, plot_idx])
    for signal in pt_signals[idx]:
        ax.plot(signal)
    ax.axvline(last_g1, linestyle="--")
    ax.axvline(last_s, linestyle="--")
    ax.set_title(f"Cluster {idx}", fontsize=fontsize)
    ax.set_xticks([g1_tick, s_tick, g2_tick])
    ax.set_xticklabels(["G1", "S", "G2M"], fontsize=fontsize-2)
    ax.set_xlabel("Pseudotime", fontsize=fontsize)
    ax.text(0.05, 0.95, ascii_uppercase[plot_idx+5], transform=ax.transAxes,
            fontsize=16, fontweight='bold', va='top', ha='left')

# Panel A: UMAP coloured by Leiden cluster.
ax_sc = fig.add_subplot(gs[0, 0])
sc.pl.umap(adata, color=["leiden"], ax=ax_sc, show=False, s=800)
ax_sc.set_title("Leiden clusters", fontsize=fontsize)
ax_sc.text(0.05, 0.95, ascii_uppercase[0], transform=ax_sc.transAxes,
           fontsize=16, fontweight='bold', va='top', ha='left')

# Panel C: UMAP coloured by cell-cycle phase.
ax_sc = fig.add_subplot(gs[1, 0])
sc.pl.umap(adata, color=["phase"], ax=ax_sc, show=False, s=800)
ax_sc.set_title("Cell cycle phase", fontsize=fontsize)
ax_sc.text(0.05, 0.95, ascii_uppercase[2], transform=ax_sc.transAxes,
           fontsize=16, fontweight='bold', va='top', ha='left')

# Panel B: UMAP coloured by n_genes.
ax_sc = fig.add_subplot(gs[0, 1])
sc.pl.umap(adata, color=["n_genes"], ax=ax_sc, show=False, s=800)
ax_sc.text(0.05, 0.95, ascii_uppercase[1], transform=ax_sc.transAxes,
           fontsize=16, fontweight='bold', va='top', ha='left')

fig.tight_layout()
plt.savefig(f"{r_dir}/fig10_pscs_cellcycle_cluster_pseudotimecourses.png")
plt.savefig(f"{r_dir}/fig10_pscs_cellcycle_cluster_pseudotimecourses.svg")

In [None]:
# PCNA and UNG side by side: windowed signal (line) and per-cell values (dots).
fig, axs = plt.subplots(1, 2)
fig.set_size_inches(8, 4)

# Pseudotime axis for the windowed signals.
mmin = np.min(adata.obs["psupertime"])
mmax = np.max(adata.obs["psupertime"])
pt = np.linspace(mmin, mmax, len(psdr["PCNA"]))

# Convert the index-based transitions to pseudotime. This overwrites the values of
# last_g1_pt / last_s_pt computed earlier in the notebook.
# NOTE(review): last_g1 / last_s were scaled with pt_max = 29 but are divided by 30
# here, so these differ slightly from the earlier values — confirm which is intended.
last_g1_pt = mmin + (mmax - mmin)*last_g1/30
last_s_pt = mmin + (mmax - mmin)*last_s/30

# Phase labels sit in the middle of each segment.
g1_tick = (mmin + last_g1_pt)/2
s_tick = (last_s_pt + last_g1_pt)/2
g2_tick = (mmax + last_s_pt)/2

for idx, (ax, g) in enumerate(zip(axs, ["PCNA", "UNG"])):
    ax.plot(pt, psdr[g])
    ax.plot(adata.obs["psupertime"], adata[:, g].X[:, 0], ".")
    for boundary in (last_g1_pt, last_s_pt):
        ax.axvline(boundary, linestyle="--")
    ax.set_title(f"{g}")
    ax.set_xticks([g1_tick, s_tick, g2_tick])
    ax.set_xticklabels(["G1", "S", "G2M"])
    # Only the left panel carries the y label.
    if idx == 0:
        ax.set_ylabel("Protein Quant.")
fig.tight_layout()
fig.savefig(f"{r_dir}/pcna_ung_pseudotimecourses.png")
fig.savefig(f"{r_dir}/pcna_ung_pseudotimecourses.svg")

### Cell cycle genes

Since we are using supervised pseudotime to determine cell cycle phase, it is possible that the clusters we are identifying are simply groups of those same genes. This section clusters the pseudotimecourses of those genes and shows that they do not exhibit the same behavior as those identified by scTransient.

In [None]:
from sklearn.cluster import KMeans

In [None]:
# Markers present in the data. Sorted because set iteration order for strings can
# change between interpreter runs, which would change the row order of the matrix
# that is clustered below and make the cluster numbering irreproducible.
cell_cycle_genes_in_data = sorted(set(cell_cycle_genes).intersection(adata.var_names))

In [None]:
# Stack the windowed signals of the marker genes into a (genes x windows) matrix.
n_windows = len(psdr["UNG"])
psdr_mat = np.zeros((len(cell_cycle_genes_in_data), n_windows))
for idx, g in enumerate(cell_cycle_genes_in_data):
    psdr_mat[idx, :] = psdr[g]

# Lookup tables between gene name and matrix row, in both directions.
idx_to_regev_gene = dict(enumerate(cell_cycle_genes_in_data))
regev_gene_to_idx = {g: idx for idx, g in idx_to_regev_gene.items()}

In [None]:
# Cluster the marker signals into k groups; clusters[i] is the label of row i of psdr_mat.
k = 5
km = KMeans(n_clusters=k, random_state=42, n_init=10)
clusters = km.fit(psdr_mat).labels_


In [None]:
# One panel per cluster of the known marker genes, plotted along pseudotime.
genes_per_cluster = dd(list)
pt_range = np.linspace(mmin, mmax, len(psdr["UNG"]))

# Phase labels sit in the middle of each segment (the same for every panel).
g1_tick = (mmin + last_g1_pt)/2
s_tick = (last_s_pt + last_g1_pt)/2
g2_tick = (mmax + last_s_pt)/2

fig, axs = plt.subplots(3, 2)
for idx in range(5):
    ax = axs[np.unravel_index(idx, axs.shape)]
    # The loop variable is deliberately not called `c`: that name holds the
    # cluster -> genes mapping built earlier in the notebook, and rebinding it here
    # would break the earlier cells when they are re-run.
    for gidx, cluster_label in enumerate(clusters):
        if cluster_label == idx:
            ax.plot(pt_range, psdr_mat[gidx, :])
            genes_per_cluster[idx].append(cell_cycle_genes_in_data[gidx])
    ax.set_title(f"Cluster {idx}")
    ax.axvline(last_g1_pt, linestyle="--")
    ax.axvline(last_s_pt, linestyle="--")
    ax.set_xticks([g1_tick, s_tick, g2_tick])
    ax.set_xticklabels(["G1", "S", "G2M"])
fig.tight_layout()
# Five clusters on a 3 x 2 grid: drop the unused sixth panel.
fig.delaxes(axs[2, 1])
plt.savefig(f"{r_dir}/clustered_cell_cycle_genes.png")
plt.savefig(f"{r_dir}/clustered_cell_cycle_genes.svg")

In [None]:
# Write the marker genes of each cluster to its own text file, one gene per line.
for idx in range(5):
    # The context manager closes the file even if a write raises.
    with open(f"{r_dir}/cell_cycle_cluster_{idx}.txt", "w") as f:
        for g in genes_per_cluster[idx]:
            f.write(f"{g}\n")

# genes_per_cluster