# Hematopoiesis dataset

This notebook is part of the paper "Single-Cell Trajectory Inference for Detecting Transient Events in Biological Processes" by Hutton and Meyer. The data come from the 2015 paper by Paul et al., "[Transcriptional Heterogeneity and Lineage Commitment in Myeloid Progenitors](https://doi.org/10.1016/j.cell.2015.11.013)". The dataset is distributed through [Scanpy's](https://scanpy.readthedocs.io/en/stable/) built-in datasets.

In [None]:
import numpy as np
import pandas as pd
from wavelet_pseudotime.windowing import GaussianWindow
from importlib import reload
import wavelet_pseudotime.synthetic
import wavelet_pseudotime.process
import wavelet_pseudotime.wavelets
from wavelet_pseudotime.wavelets import mag_median
from matplotlib import pyplot as plt
import scanpy as sc
import anndata as ad
import os
from wavelet_pseudotime.wavelets import std_from_median
from mpl_toolkits.axes_grid1 import make_axes_locatable

from datetime import datetime
# All outputs of this notebook go into a per-day directory such as "YYYY_MM_DD_paul15".
# Scanpy's figure directory is pointed there too, so plots saved with `save=...` land alongside.
date_str = datetime.now().strftime("%Y_%m_%d")
r_dir = f"{date_str}_paul15"
# exist_ok avoids the check-then-create race and makes re-running this cell a no-op.
os.makedirs(r_dir, exist_ok=True)
sc.settings.figdir = r_dir

In [None]:
# Load the Paul et al. (2015) myeloid-progenitor dataset through the project's loader.
# process_data=True enables the loader's preprocessing; the exact steps live in
# wavelet_pseudotime.load_data and are not visible here.
# NOTE(review): `wavelet_pseudotime.load_data` (and later `graph_traversal`/`plotting`) is
# never imported explicitly above; this works only if the package __init__ imports it — confirm.
paul15 = wavelet_pseudotime.load_data.load_paul15(process_data=True)

In [None]:
# Embed the cells with UMAP and build the PAGA graph over the annotated clusters.
# Then show the dense cluster-by-cluster connectivity matrix as a heat map.
sc.tl.umap(paul15)
sc.tl.paga(paul15, groups="paul15_clusters")
connectivity = paul15.uns["paga"]["connectivities"].todense()
plt.imshow(connectivity)
plt.colorbar()

In [None]:
# Rebuild PAGA on the annotated clusters and save the abstracted graph.
# Edges with connectivity below 0.8 are not drawn.
sc.tl.paga(paul15, groups="paul15_clusters")
sc.pl.paga(
    paul15,
    save="_connectivity_thresh08.png",
    threshold=0.8,
)

In [None]:
# Cluster names in category order; used below to translate PAGA node indices into names.
node_labels = paul15.obs["paul15_clusters"].cat.categories.tolist()
print(node_labels)

In [None]:
# Find the minimum-cost path through the PAGA graph between the initial (stem cell)
# cluster and the erythrocyte cluster, draw it, and list the clusters along it.
# NOTE(review): 7 and 0 are positions in the paul15_clusters categories (see the
# node_labels printout above) — confirm they are the intended start/end clusters.
paul15_path, dist = wavelet_pseudotime.graph_traversal.find_min_cost_path(
    paul15.uns["paga"]["connectivities"], 7, 0
)
node_labels = list(paul15.obs["paul15_clusters"].cat.categories)
wavelet_pseudotime.plotting.draw_path(
    paul15.uns["paga"]["pos"],
    paul15_path,
    paul15.uns["paga"]["connectivities"],
    node_labels=node_labels,
)
# Map node indices to cluster names, reusing node_labels instead of rebuilding the
# category list once per path node.
paul15_labels_path = [node_labels[p] for p in paul15_path]
plt.xticks([])
plt.yticks([])
print(paul15_labels_path)

In [None]:
# Inspect the loaded dataset; the notebook echoes its repr.
paul15

In [None]:
# Copy DPT pseudotime into the "psupertime" column, which the plotting and binning
# cells below read.
# NOTE(review): presumably the wavelet pipeline also reads "psupertime" — confirm.
paul15.obs["psupertime"] = paul15.obs["dpt_pseudotime"]

In [None]:
# Re-inspect the dataset to confirm the "psupertime" column was added.
paul15

In [None]:
# reload(wavelet_pseudotime.process)

In [None]:
# Run the wavelet pipeline along the cluster path found above.
# waves, scores and psd are used below as dict-like objects keyed by gene name;
# adata carries the cluster labels and pseudotime used for the proportion plots.
waves, scores, psd, adata = wavelet_pseudotime.process.pipeline_paul15v2(
    paul15,
    trajectory=paul15_labels_path,
    node_col="paul15_clusters",
    scoring_threshold=3,
    exclude_pt_ends=(0.1, 0.9),
)

In [None]:
# Flatten the per-gene scores into a plain list for the histogram below.
scores_dist = list(scores.values())

In [None]:
# Histogram of every gene's wavelet score, drawn on pyplot's current axes.
score_ax = plt.gca()
score_ax.hist(scores_dist, bins=100)
score_ax.set_xlabel("Scores")
score_ax.set_ylabel("Frequency")
score_ax.set_title("Distribution of gene wavelet scores")

In [None]:
# For each gene in `waves`: plot the expression of on-path cells against pseudotime,
# overlay the gene's psd curve, and save one PNG per gene.
# Values that do not depend on the gene are computed once, outside the loop.
mmin = np.min(paul15.obs["psupertime"])
mmax = np.max(paul15.obs["psupertime"])
# Cells whose cluster lies on the chosen PAGA path.
idx = paul15.obs.index[paul15.obs["paul15_clusters"].isin(paul15_labels_path)]
pt_on_path = paul15[idx].obs["psupertime"]
for g in waves.keys():
    plt.figure()
    # psd[g] is drawn on an evenly spaced grid spanning the full pseudotime range.
    x = np.linspace(mmin, mmax, len(psd[g]))
    # Build the AnnData subset once per gene rather than twice.
    gene_view = paul15[idx, g]
    plt.plot(pt_on_path, gene_view.X[:, 0], ".")
    plt.plot(x, psd[g])
    plt.title(g)
    plt.savefig(f"{r_dir}/paul15_wavelet_detect_{g}.png")

In [None]:
# Write the names of the genes in `waves`, one per line.
# The context manager guarantees the file is closed even if a write fails.
with open(f"{r_dir}/paul15_genes.txt", "w", encoding="utf-8") as f:
    for g in waves.keys():
        f.write(f"{g}\n")

In [None]:
# Parallel sequences: kkeys[i] is a gene name and s[i] is its score, in dict order.
kkeys = list(scores.keys())
s = [scores[k] for k in kkeys]

In [None]:
# Histogram of gene scores with the selection cutoff (4) marked by a red line.
fig, ax = plt.subplots()
ax.hist(s, edgecolor="black", bins=100)
ax.set(xlabel="Gene score", ylabel="Frequency", title="Distribution of Gene Scores")
ax.axvline(4, label="Cutoff", color="red")
ax.legend()

In [None]:
# Positions (into kkeys / s) of genes whose score exceeds the cutoff of 4.
idx = np.flatnonzero(np.asarray(s) > 4)

In [None]:
# Gather the names of the selected genes and echo them one per line.
signal_keys = [kkeys[pos] for pos in idx]
for key in signal_keys:
    print(key)

In [None]:
# Number of genes that passed the score cutoff.
len(signal_keys)

In [None]:
# Wavelet transform at scales 1, 2 and 3, used for the per-gene coefficient maps below.
# "mexh" is presumably the PyWavelets Mexican-hat wavelet — confirm in WaveletTransform.
wt = wavelet_pseudotime.wavelets.WaveletTransform(scales=list(range(1, 4)), wavelet="mexh")

In [None]:
# For every gene above the cutoff, save a two-panel figure (PNG and SVG):
#   top    - expression of on-path cells vs. pseudotime, with the psd curve overlaid;
#   bottom - sqrt(|coefs * coefs_std|) of the wavelet coefficients, one row per scale.
out_dir = f"{r_dir}/paul15_pseudotimecourses"
# Create the output directory once, up front, instead of checking it for every gene.
os.makedirs(out_dir, exist_ok=True)
# Loop-invariant values, computed once.
mmin = np.min(paul15.obs["psupertime"])
mmax = np.max(paul15.obs["psupertime"])
idx = paul15.obs.index[paul15.obs["paul15_clusters"].isin(paul15_labels_path)]
pt_on_path = paul15[idx].obs["psupertime"]
for g in signal_keys:
    fig, (ax0, ax1) = plt.subplots(2, 1)
    x = np.linspace(mmin, mmax, len(psd[g]))

    ax0.plot(pt_on_path, paul15[idx, g].X[:, 0], ".")
    ax0.plot(x, psd[g])
    ax0.set_title(g)
    ax0.set_xlim([0, 1])
    ax0.set_ylabel("Gene expression")

    coefs, _ = wt.apply(psd[g])
    coefs_std = std_from_median(coefs)
    n_rows = coefs.shape[0]
    ax1.imshow(np.sqrt(np.abs(coefs * coefs_std)))
    # Label the image rows 1..n rather than 0-based.
    ax1.set_yticks(list(range(n_rows)))
    ax1.set_yticklabels(list(range(1, n_rows + 1)))
    ax1.set_xticks([])
    ax1.set_ylabel("Scale")
    ax1.set_title("Square root of score")
    fig.tight_layout()
    fig.savefig(f"{out_dir}/{g}.png", bbox_inches="tight")
    fig.savefig(f"{out_dir}/{g}.svg", bbox_inches="tight")

In [None]:
# Show the cluster names along the selected path.
paul15_labels_path

In [None]:
# Stacked-area plot showing how the on-path cluster labels are distributed over
# `number_bins` equal-width pseudotime bins.
fig, ax = plt.subplots(figsize=(10, 7))

keep_idx = adata.obs["paul15_clusters"].isin(paul15_labels_path)
adata2 = adata[keep_idx, :].copy()
number_bins = 10
adata2.obs["pt_bin"] = pd.cut(adata2.obs["psupertime"], bins=number_bins)

# Rows are pseudotime bins, columns are cluster labels, values are each label's
# fraction of the bin's cells.
ct = pd.crosstab(adata2.obs["pt_bin"], adata2.obs["paul15_clusters"])
row_totals = ct.sum(axis=1)
fractions = ct.div(row_totals, axis=0)
bin_midpoints = np.array([(iv.left + iv.right) / 2 for iv in fractions.index])
categories = fractions.columns.tolist()
stack_data = [fractions[label].values for label in categories]

ax.stackplot(bin_midpoints, stack_data, labels=categories)
ax.legend()


# Multi-panel figure for paper

In [None]:
from string import ascii_uppercase
import matplotlib.gridspec as gridspec

In [None]:
from gprofiler import GProfiler
# GO enrichment of the selected genes via g:Profiler; terms with p < 0.05 are kept.
# NOTE(review): Paul et al. (2015) is believed to be a mouse dataset, while this query
# uses organism="hsapiens" ("mmusculus" is the mouse ID) — confirm this is intended.
gp = GProfiler(return_dataframe=True)
results = gp.profile(organism="hsapiens", query=signal_keys, sources=["GO:BP", "GO:MF", "GO:CC"])
# sort_values returns a new frame. It was previously discarded, so the sort never took
# effect. A stable sort leaves rows that are already in p-value order untouched.
results = results.sort_values(by="p_value", ascending=True, kind="stable")
sig_results = results[results["p_value"] < 0.05]

In [None]:
# Settings for the enrichment panel:
# - the GO sources drawn as bands;
# - the observed range of significant p-values, padded below by 20% and above by 30%;
# - the terms to call out with labels.
sources = ["GO:BP", "GO:MF", "GO:CC"]
sig_pvals = sig_results["p_value"]
p_min = np.min(sig_pvals) * 0.8
p_max = np.max(sig_pvals) * 1.3
annot = ["cell redox homeostasis", "peroxiredoxin activity", "heme metabolic process"]

In [None]:
def plot_df(df: pd.DataFrame, title: str = None, save=None, ax=None, pmin=None, pmax=None, sources=None, annotated_names: list[str] = None
) -> None:
    """
    Plot enrichment results as a bubble chart with one circle per row of ``df``.

    The x-axis shows -log10(p_value). Rows are placed on one horizontal band per
    ``source`` value (bands at y = 1, 2, ...) with a small, reproducible vertical
    jitter. Circle area is proportional to ``intersection_size``. A size legend shows
    the smallest, mid-range and largest intersection sizes.

    Parameters:
        df (pd.DataFrame): Must contain the columns
            - 'source': group label; each listed value gets its own band.
            - 'p_value': strictly positive p-values.
            - 'intersection_size': integers used to scale circle sizes.
            - 'name': term name; only used to match ``annotated_names``.
        title: Axes title; defaults to "Function Enrichment Analysis".
        save: If given, path passed to ``savefig`` of the figure that owns ``ax``.
        ax: Matplotlib axes to draw on; a new figure and axes are created when None.
        pmin, pmax: x-axis limits in -log10(p) units; applied only when both are given.
        sources: Band labels (sorted and de-duplicated). Defaults to the sorted unique
            values of df['source']. Rows whose source is not listed are not drawn.
        annotated_names: Term names to call out with an arrow and a text box.

    Raises:
        ValueError: If any p_value is <= 0.
    """
    # Validate before creating any figure so a bad input leaves no stray figure behind.
    if (df["p_value"] <= 0).any():
        raise ValueError("All p_value entries must be positive so that -log10 can be computed.")

    df = df.copy()  # never modify the caller's DataFrame
    if ax is None:
        _, ax = plt.subplots()

    df["neg_log10"] = -np.log10(df["p_value"])

    # One band per source, at y = 1, 2, 3, ...
    if sources is None:
        unique_sources = sorted(df["source"].unique())
    else:
        unique_sources = sorted(np.unique(sources))
    source_to_index = {source: idx for idx, source in enumerate(unique_sources, start=1)}
    df["base_y"] = df["source"].map(source_to_index)

    # Vertical jitter keeps overlapping circles in a band visible. A private
    # RandomState(0) produces exactly the values np.random.seed(0) + np.random.uniform
    # did, without resetting numpy's global random state as a side effect.
    jitter = np.random.RandomState(0).uniform(-0.2, 0.2, size=len(df))
    df["y_pos"] = df["base_y"] + jitter

    # Marker area per unit of intersection_size; shared by the plot and the size legend.
    size_scale = 10
    for source in unique_sources:
        subset = df[df["source"] == source]
        ax.scatter(
            subset["neg_log10"],
            subset["y_pos"],
            s=subset["intersection_size"] * size_scale,
            alpha=0.7,
            label=source,
            edgecolors="w",
        )

    ax.set_xlabel("-log10(p_value)", fontsize=16)
    ax.set_yticks(list(source_to_index.values()), list(source_to_index.keys()), fontsize=14)
    # Leave one band of space above and below; equals (0, 4) for three sources.
    ax.set_ylim([0, len(unique_sources) + 1])
    ax.set_title("Function Enrichment Analysis" if title is None else title)

    # Size legend: smallest, mid-range and largest intersection sizes. The middle entry
    # is the midpoint of the range, not the median. It is skipped for an empty frame,
    # where min/max are NaN.
    if df.empty:
        sizes = []
    else:
        size_min = df["intersection_size"].min()
        size_max = df["intersection_size"].max()
        size_mid = int((size_min + size_max) / 2)
        sizes = [size_min, size_mid, size_max]
    markers = [
        ax.scatter([], [], s=size * size_scale, color="gray", alpha=0.7, edgecolors="w")
        for size in sizes
    ]
    labels = [f"{size}" for size in sizes]

    if annotated_names:
        # Each label box sits offset from its point. When another labelled term falls in
        # a band that already has a label, the vertical offset flips sign, so labels
        # alternate sides; the flip carries over to later labels.
        offset_x = -1.0
        offset_y = -0.5
        annotated_sources = set()
        for _, row in df.iterrows():
            if row["name"] not in annotated_names:
                continue
            if row["source"] in annotated_sources:
                offset_y *= -1
            annotated_sources.add(row["source"])
            x_point = row["neg_log10"]
            y_point = row["y_pos"]
            ax.annotate(
                row["name"],
                xy=(x_point, y_point),
                xytext=(x_point + offset_x, y_point + offset_y),
                arrowprops=dict(facecolor="black", arrowstyle="->"),
                fontsize=10,
                bbox=dict(boxstyle="round,pad=0.3", fc="yellow", alpha=0.5),
            )

    ax.legend(markers, labels, title="Intersection Size", bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0)
    ax.grid()
    if pmin is not None and pmax is not None:
        ax.set_xlim([pmin, pmax])
    # Lay out and save the figure that owns `ax`, not whichever figure pyplot considers
    # current; the two differ when the caller passes an axes from another figure.
    fig = ax.figure
    fig.tight_layout()
    if save is not None:
        fig.savefig(save)


In [None]:
# Put together the full multi-panel figure for the paper:
#   A PAGA graph            B minimum-cost path
#   C cluster proportions   D TES distribution
#   E Nfe2 expression       F GO enrichment
#   G Nfe2 score map (strip under E)
fig = plt.figure(figsize=(8, 10))
gs = gridspec.GridSpec(nrows=4, ncols=2, height_ratios=[1, 1, 1, 0.1], figure=fig)
fontsize = 16


def _label_panel(panel_ax, index, x=0.05, y=0.95):
    """Draw the bold panel letter ascii_uppercase[index] at (x, y) in axes coordinates."""
    panel_ax.text(x, y, ascii_uppercase[index], transform=panel_ax.transAxes,
                  fontsize=fontsize, fontweight="bold", va="top", ha="left")


# B: minimum-cost path through the PAGA graph.
# NOTE(review): 7 and 0 are positions in the paul15_clusters categories — confirm they
# are the intended start/end clusters.
ax = fig.add_subplot(gs[0, 1])
paul15_path, dist = wavelet_pseudotime.graph_traversal.find_min_cost_path(
    paul15.uns["paga"]["connectivities"], 7, 0
)
node_labels = list(paul15.obs["paul15_clusters"].cat.categories)
wavelet_pseudotime.plotting.draw_path(
    paul15.uns["paga"]["pos"], paul15_path, paul15.uns["paga"]["connectivities"],
    node_labels=node_labels, ax=ax,
)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title("Minimum cost path", fontsize=fontsize)
# Reuse node_labels instead of rebuilding the category list for every path node.
paul15_labels_path = [node_labels[p] for p in paul15_path]
_label_panel(ax, 1)

# C: fraction of each on-path cluster label per pseudotime bin.
ax = fig.add_subplot(gs[1, 0])
keep_idx = adata.obs["paul15_clusters"].isin(paul15_labels_path)
adata2 = adata[keep_idx, :].copy()
number_bins = 10
adata2.obs["pt_bin"] = pd.cut(adata2.obs["psupertime"], bins=number_bins)
ct = pd.crosstab(adata2.obs["pt_bin"], adata2.obs["paul15_clusters"])
fractions = ct.div(ct.sum(axis=1), axis=0)
bin_midpoints = np.array([(interval.left + interval.right) / 2 for interval in fractions.index])
categories = fractions.columns.tolist()
stack_data = [fractions[cat].values for cat in categories]
ax.stackplot(bin_midpoints, stack_data, labels=categories)
ax.legend()
ax.set_title("Proportion of cell labels", fontsize=fontsize)
ax.set_xlabel("Pseudotime", fontsize=fontsize)
ax.set_xlim(bin_midpoints[0], bin_midpoints[-1])
# Remove all margins from both axes. (The former plt.subplots_adjust call here was
# overridden by the tight_layout calls later in this cell, so it has been dropped.)
ax.margins(x=0, y=0)
_label_panel(ax, 2)

# D: distribution of TES with the selection cutoff.
ax = fig.add_subplot(gs[1, 1])
ax.hist(s, bins=100)
ax.set_xlabel("TES", fontsize=fontsize)
ax.set_ylabel("Frequency", fontsize=fontsize)
ax.set_title("Distribution of TES", fontsize=fontsize)
ax.axvline(4, color="red", label="Cutoff")
ax.legend()
_label_panel(ax, 3)

# E: expression of one example gene along pseudotime, with its psd curve overlaid.
ax0 = fig.add_subplot(gs[2, 0])
g = "Nfe2"
mmin = np.min(paul15.obs["psupertime"])
mmax = np.max(paul15.obs["psupertime"])
x = np.linspace(mmin, mmax, len(psd[g]))
idx = paul15.obs.index[paul15.obs["paul15_clusters"].isin(paul15_labels_path)]
gene_view = paul15[idx, g]
ax0.plot(gene_view.obs["psupertime"], gene_view.X[:, 0], ".")
ax0.plot(x, psd[g])
ax0.set_title(g)
ax0.set_xlim([0, 1])
ax0.set_ylabel("Gene expression", fontsize=fontsize)
_label_panel(ax0, 4)
coefs, _ = wt.apply(psd[g])
coefs_std = std_from_median(coefs)

# G: sqrt(|coefs * coefs_std|) for the same gene, one row per wavelet scale.
ax1 = fig.add_subplot(gs[3, 0])
ax1.imshow(np.sqrt(np.abs(coefs * coefs_std)))
ax1.set_yticks(list(range(coefs.shape[0])))
ax1.set_yticklabels(list(range(1, coefs.shape[0] + 1)))
ax1.set_xticks([])
ax1.set_ylabel("Scale", fontsize=fontsize - 2)
ax1.set_title("Root of TES", fontsize=fontsize)
# The strip is very short, so its letter is placed above it (y > 1 in axes units).
_label_panel(ax1, 6, x=0.01, y=1.8)

# F: GO enrichment bubble chart.
ax = fig.add_subplot(gs[2, 1])
p_min = np.min(sig_results["p_value"])
p_max = np.max(sig_results["p_value"])
annot = ["cell redox homeostasis", "peroxiredoxin activity", "heme metabolic process"]
p_min *= 0.8
p_max *= 1.3
# plot_df takes x-limits in -log10 units, so the largest p-value gives the left limit.
plot_df(sig_results, title="Enrichment for Erythropoiesis", ax=ax,
        pmin=-np.log10(p_max), pmax=-np.log10(p_min), sources=sources, annotated_names=annot)
_label_panel(ax, 5)

# A: PAGA abstracted graph (edges below 0.8 connectivity hidden).
ax = fig.add_subplot(gs[0, 0])
ax.set_title("PAGA connectivity", fontsize=fontsize)
sc.tl.paga(paul15, groups="paul15_clusters")
_label_panel(ax, 0)
sc.pl.paga(paul15, threshold=0.8, ax=ax, show=False)
fig.tight_layout()

fig.savefig(f"{r_dir}/fig3_hemo.png")
fig.savefig(f"{r_dir}/fig3_hemo.svg")