# Monocyte & Macrophage dataset

This notebook accompanies the paper "Single-Cell Trajectory Inference for Detecting Transient Events in Biological Processes" by Hutton and Meyer. The data are from Specht et al. (2021), "[Single-cell proteomic and transcriptomic analysis of macrophage heterogeneity using SCoPE2](https://doi.org/10.1186/s13059-021-02267-5)".

In [None]:
import numpy as np
import scanpy as sc
import pandas as pd
import wavelet_pseudotime as wp
import wavelet_pseudotime.load_data
from matplotlib import pyplot as plt
from IPython.display import clear_output
from importlib import reload
from importlib import reload
import numpy as np
from anndata import AnnData
from typing import Union, List, Tuple
from wavelet_pseudotime.windowing import Window, GaussianWindow, RectWindow, ConfinedGaussianWindow
from collections import defaultdict as dd
from wavelet_pseudotime.wavelets import WaveletTransform, get_max_wavelets, get_max_scored_wavelets
from typing import Literal
import scanpy as sc
from wavelet_pseudotime.binning import quantile_binning_anndata
from wavelet_pseudotime.process import window_trajectory
from scipy.ndimage import gaussian_filter1d
import seaborn as sns
import os
import matplotlib.gridspec as gridspec

from datetime import datetime
# Results go into a directory stamped with today's date.
date_str = datetime.now().strftime("%Y_%m_%d")
r_dir = f"{date_str}_slavov"  # results directory
# makedirs(exist_ok=True) replaces the separate exists()/mkdir() pair: it is a
# no-op when the directory is already there and has no check-then-create gap.
os.makedirs(r_dir, exist_ok=True)

In [None]:
from wavelet_pseudotime.load_data import load_slavov, load_slavov_xomic

In [None]:
# Load the two datasets. Downstream cells treat `p` as the protein data and `t`
# as the scRNA data (see the histogram titles further down).
p, t = wavelet_pseudotime.load_data.load_slavov_xomic()

In [None]:
# Display the `p` object returned by load_slavov_xomic() for a quick look.
p

In [None]:
# Run the same pipeline on both modalities. Each call returns four objects that
# later cells use: wavelets, per-gene scores, pseudotimecourses and the AnnData.
# NOTE(review): `exclude_pt_ends` presumably trims the pseudotime range that is
# analysed (the protein run drops the top 2.5%) - confirm against the pipeline.
# NOTE(review): the ".hd5a" suffix looks like a typo for ".h5ad"; left untouched
# because the pipeline may look for previously saved files under this name.
run_pipeline = wavelet_pseudotime.process.pipeline_slavov

waves_p, scores_p, ps_p, adata_p = run_pipeline(
    p,
    save_name=f"{r_dir}/slavov_p.hd5a",
    repeat=True,
    exclude_pt_ends=(0, 0.975),
)
waves_t, scores_t, ps_t, adata_t = run_pipeline(
    t,
    save_name=f"{r_dir}/slavov_t.hd5a",
    repeat=True,
    exclude_pt_ends=(0,1),
)

In [None]:
# waves_p, scores_p, ps_p, adata_p = wavelet_pseudotime.process.pipeline_slavov(p, save_name="slavov_p.hd5a", repeat=True, exclude_pt_ends=(0, 1))

In [None]:
# Keep copies of the pseudotime values; the figure-assembly cell reuses them
# for its first two panels.
panel_6a = adata_t.obs["dpt_pseudotime"].copy()
panel_6b = adata_p.obs["dpt_pseudotime"].copy()

# One histogram per modality, each drawn on its own figure.
for _adata, _modality in ((adata_t, "scRNA"), (adata_p, "proteins")):
    _fig, _ax = plt.subplots()
    _ax.hist(_adata.obs["dpt_pseudotime"], bins=100)
    _ax.set_xlabel("Pseudotime")
    _ax.set_ylabel("Frequency")
    _ax.set_title(f"Distribution of Cells in Pseudotime ({_modality})")

# Saving of the stand-alone histograms is currently disabled:
# plt.savefig(f"{r_dir}/pt_dist_scrna.png")
# plt.savefig(f"{r_dir}/pt_dist_prot.png")



In [None]:
# Shape of the processed protein AnnData; compare with `prot.shape` below.
adata_p.shape

In [None]:
# Load the protein dataset again; the next cells copy its "celltype" column
# onto `adata_p`.
prot = wavelet_pseudotime.load_data.load_slavov()

In [None]:
# Shape of the freshly loaded protein data, for comparison with `adata_p.shape`.
prot.shape

In [None]:
# Copy the cell-type annotation from `prot` onto `adata_p`.
# NOTE(review): assigning a Series aligns on the obs index, so any cell of
# `adata_p` that is absent from `prot` would end up without a label - assumes
# both objects share the same obs names; confirm.
adata_p.obs["celltype"] = prot.obs["celltype"]

In [None]:
# Make sure the copied labels match the source table.
# The previous version compared adata_p.obs["celltype"] with itself, so it could
# never report a genuine mismatch. Here the reference is prot.obs, re-ordered to
# the obs index of adata_p. Cells that are missing from `prot` become NaN in the
# reference and are reported as well, because NaN never compares equal.
# Both sides are cast to object so categoricals with different category sets
# can be compared element-wise.
expected_celltype = prot.obs["celltype"].reindex(adata_p.obs.index).astype(object)
copied_celltype = adata_p.obs["celltype"].astype(object)
label_mismatch = (copied_celltype != expected_celltype).to_numpy()
for idx in adata_p.obs.index[label_mismatch]:
    print(idx)

In [None]:
# Scatter one protein against pseudotime, coloured by cell type.
g = "LYZ"
m0_cells = (adata_p.obs["celltype"] == "sc_m0").values
u_cells = ~m0_cells  # every cell that is not labelled sc_m0
# Pseudotime cut-off mask; with the bound at 1 it only removes cells at or above 1.
p07 = (adata_p.obs["dpt_pseudotime"] < 1)

for _type_mask, _type_label in ((m0_cells, "m0"), (u_cells, "u")):
    _selected = _type_mask & p07
    plt.plot(
        adata_p.obs.loc[_selected, "dpt_pseudotime"],
        adata_p[_selected, g].X[:, 0],
        ".",
        label=_type_label,
    )

plt.legend()
plt.xlabel("Pseudotime")
plt.ylabel("Protein Quant.")
plt.title(f"Expression of {g} along pseudotime")
# plt.savefig(f"{r_dir}/example_dist.png")

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Define number of bins for pseudotime
num_bins = 20

# Bin pseudotime into equal-count (quantile) intervals; labels=False stores the
# integer bin number instead of the interval.
adata_p.obs['pt_bin'] = pd.qcut(adata_p.obs['dpt_pseudotime'], q=num_bins, labels=False)

# Compute the proportion of each cell type per bin.
# observed=False is spelled out so that a categorical "celltype" column keeps
# all of its categories regardless of the pandas version's default.
phase_proportions = (
    adata_p.obs
    .groupby(['pt_bin', 'celltype'], observed=False)
    .size()
    .unstack(fill_value=0)
)
phase_proportions = phase_proportions.div(phase_proportions.sum(axis=1), axis=0)
panel_6c = phase_proportions.copy()

# Create stacked area plot
plt.figure(figsize=(8, 5))
plt.stackplot(phase_proportions.index, phase_proportions.T.values, labels=phase_proportions.columns, alpha=0.8)

# Formatting
plt.xlabel("Pseudotime Bin")
plt.ylabel("Proportion of Cells")
# NOTE(review): the bins are quantiles, so the 0 / 0.5 / 1 tick labels mark the
# fraction of cells passed, not pseudotime values - confirm this is intended.
plt.xticks(ticks=[0, (num_bins-1)//2, num_bins-1], labels=[0,0.5, 1])
plt.title("Cell Type Proportions Across Pseudotime")
# The stacked groups are cell types; the legend was previously titled "Phase".
plt.legend(title="Cell type", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(False)

# Show plot
# plt.savefig(f"{r_dir}/cell_prop.png", bbox_inches="tight")
plt.show()


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Bin pseudotime into equal-width intervals (pd.cut), so every bin spans the
# same pseudotime range. This overwrites the integer quantile bins that the
# previous cell stored in obs['pt_bin'].
num_bins = 20
adata_p.obs['pt_bin'] = pd.cut(adata_p.obs['dpt_pseudotime'], bins=num_bins)

# Count cells per (bin, cell type), then normalise each bin to fractions.
ct = pd.crosstab(adata_p.obs['pt_bin'], adata_p.obs['celltype'])
fractions = ct.div(ct.sum(axis=1), axis=0)

# Get the midpoints of each bin for the x-axis.
bin_midpoints = np.array([(interval.left + interval.right) / 2 for interval in fractions.index])

# Prepare data for the stackplot: one array of fractions per cell type.
categories = fractions.columns.tolist()
stack_data = [fractions[cat].values for cat in categories]

# Create the stackplot
fig, ax = plt.subplots()#figsize=(5, 5))
# Copies kept for the figure-assembly cell, which redraws this panel.
fig6c_x = bin_midpoints.copy()
fig6c_y = stack_data.copy()
fig6c_labels = categories.copy()
ax.stackplot(bin_midpoints, stack_data, labels=categories)

# Labels now describe the data. The earlier text ("xlocation", "Class Labels")
# was template placeholder wording that ended up in the saved figure. The
# wording matches the corresponding panel of the assembled figure.
ax.set_xlabel('Pseudotime')
ax.set_ylabel('Fraction')
ax.set_title('scP cell type fraction')
# ax.legend(loc='upper right')

# Remove extra whitespace/margins around the plot.
# Set x limits exactly to the first and last bin midpoints.
ax.set_xlim(bin_midpoints[0], bin_midpoints[-1])
# Remove all margins from both axes.
ax.margins(x=0, y=0)
# Adjust subplot parameters to use all the figure area.
plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

plt.savefig(f"{r_dir}/cell_prop.png", bbox_inches="tight")
plt.show()


In [None]:
# Position of the first bin (in row order of `phase_proportions`) where sc_m0
# cells make up more than half of the bin, expressed as a fraction of the
# number of bins. Raises IndexError if no bin qualifies.
# NOTE(review): `phase_proportions` uses quantile (equal-count) bins, while the
# later plots place this fraction on an axis built from the equal-width
# `bin_midpoints`. Whether that is consistent depends on how the pipeline
# samples `ps_p` along pseudotime - confirm.
_m0_majority = (phase_proportions["sc_m0"] > 0.5).to_numpy()
idx = np.flatnonzero(_m0_majority)[0]
fifty_trans = idx / phase_proportions.shape[0]

In [None]:
# Display the transition position computed in the previous cell.
fifty_trans

In [None]:
# Keep every entry of `scores_p` whose score is strictly above the cut-off,
# and report how many were kept.
thresh = 1.2
g_above_thresh = [
    name
    for name, score in scores_p.items()
    if score > thresh
]
print(len(g_above_thresh))

In [None]:
# Write the selected names to a text file, one per line.
# The context manager closes the file even if a write fails; the manual
# open()/close() pair left the handle open on error.
with open(f"{r_dir}/genes_above_threshold.txt", "w") as f:
    f.writelines(f"{name}\n" for name in g_above_thresh)

In [None]:
# Clustered pseudotimecourses
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import numpy as np
from collections import defaultdict as dd

# One pseudotimecourse per above-threshold name, in the order of g_above_thresh.
pt_above_thresh = [ps_p[g] for g in g_above_thresh]

# Fixed seed so the cluster assignment is repeatable.
kmeans = KMeans(n_clusters=4, random_state=0)
labels = kmeans.fit_predict(pt_above_thresh)

# cluster label -> names assigned to that cluster
c = dd(list)
for g, _cluster_label in zip(g_above_thresh, labels):
    c[_cluster_label].append(g)

fig, axs = plt.subplots(2,2)

# Shared x-range and position of the dashed transition marker.
_pt_lo = np.min(bin_midpoints)
_pt_hi = np.max(bin_midpoints)

# One panel per cluster, filled row by row.
for idx, ax in enumerate(axs.flat[:len(c)]):
    for g in c[idx]:
        x = np.linspace(_pt_lo, _pt_hi, len(ps_p[g]))
        ax.plot(x, ps_p[g])
    lin_loc = (_pt_hi - _pt_lo) * fifty_trans + _pt_lo
    ax.axvline(lin_loc, linestyle="--", label="50% transition")
    ax.set_title(f"Cluster {idx}")
    if idx == 0:
        ax.legend()  # legend only on the first panel
    if idx > 1:
        ax.set_xlabel("Pseudotime")  # bottom row only
    if idx % 2 == 0:
        ax.set_ylabel("Gene expression")  # left column only
fig.tight_layout()

In [None]:
# Mean timecourses
# Average the pseudotimecourses within each cluster of `c` and draw one line
# per cluster, visiting the cluster labels in sorted order.
_n_points = len(ps_p[c[0][0]])
kk = sorted(c.keys())
for idx in kk:
    x = np.zeros(_n_points)
    # print(len(c[idx]))  # number of genes in each cluster
    for g in c[idx]:
        x += ps_p[g]
    count = len(c[idx])
    x /= count
    minx = np.min(bin_midpoints)
    maxx = np.max(bin_midpoints)
    num = len(x)
    plt.plot(np.linspace(minx, maxx, num), x, label=f"Cluster {idx}")

# Dashed marker at the transition fraction, mapped onto the plotted x-range.
lin_pos = (maxx - minx) * fifty_trans + minx
plt.axvline(lin_pos, linestyle="--", label="50% transition")
plt.legend()
plt.title("Mean pseudotimecourses")
plt.xlabel("Pseudotime")
plt.ylabel("Protein quant.")
# NOTE(review): the combined-figure cell that follows saves to these same two
# paths, so these files are overwritten when the notebook is run top to bottom.
for _ext in ("png", "svg"):
    plt.savefig(f"{r_dir}/fig7_cluster_mean_timecourses.{_ext}")

In [None]:
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import numpy as np
from collections import defaultdict as dd

# --- Cluster the above-threshold pseudotimecourses ---------------------------
# The clustering now happens BEFORE its result is used. Previously this cell
# grouped the names with whatever `labels` an earlier cell had left in the
# kernel and only re-fitted KMeans afterwards, so the grouping relied on hidden
# state. The fit is seeded, so the assignment is the same as before.
pt_above_thresh = [ps_p[g] for g in g_above_thresh]
kmeans = KMeans(n_clusters=4, random_state=0)
labels = kmeans.fit_predict(pt_above_thresh)

# cluster label -> names assigned to that cluster
c = dd(list)
for idx_g, g in enumerate(g_above_thresh):
    c[labels[idx_g]].append(g)

# Shared x-range and the position of the dashed transition marker.
minx = np.min(bin_midpoints)
maxx = np.max(bin_midpoints)
lin_pos = (maxx - minx) * fifty_trans + minx
lin_loc = lin_pos  # same position, name kept for the per-cluster panels

# --- Figure layout: one wide panel on top, a 2x2 block of cluster panels below
fig = plt.figure(figsize=(8, 6))
gs = gridspec.GridSpec(3, 2, figure=fig)

# Top panel: mean pseudotimecourse of every cluster, labels in sorted order.
ax = fig.add_subplot(gs[0, :])
kk = sorted(c.keys())
for idx in kk:
    x = np.zeros(len(ps_p[c[0][0]]))
    for g in c[idx]:
        x += ps_p[g]
    count = len(c[idx])
    x /= count
    num = len(x)
    ax.plot(np.linspace(minx, maxx, num), x, label=f"Cluster {idx}")
ax.axvline(lin_pos, linestyle="--", label="50% transition")
ax.legend(bbox_to_anchor=(1,1))
ax.set_title("Mean pseudotimecourses", fontsize=16)
ax.set_xlabel("Pseudotime", fontsize=16)
ax.set_ylabel("Protein quant.", fontsize=16)

# Lower panels: every individual pseudotimecourse, one panel per cluster.
gs_list = [gs[1,0], gs[1,1], gs[2,0], gs[2,1]]

for idx in range(len(c)):
    gs0 = gs_list[idx]
    ax = fig.add_subplot(gs0)
    for g in c[idx]:
        x = np.linspace(minx, maxx, len(ps_p[g]))
        ax.plot(x, ps_p[g])
    ax.axvline(lin_loc, linestyle="--", label="50% transition")
    ax.set_title(f"Cluster {idx}", fontsize=16)
    if idx == 0:
        ax.legend()  # legend only on the first cluster panel
    if idx > 1:
        ax.set_xlabel("Pseudotime", fontsize=16)  # bottom row only
    if idx % 2 == 0:
        ax.set_ylabel("Protein quant.", fontsize=16)  # left column only
fig.tight_layout()

# NOTE(review): the preceding mean-timecourse cell saves to these same paths, so
# this figure replaces that one on disk - confirm this is the intended file.
fig.savefig(f"{r_dir}/fig7_cluster_mean_timecourses.png")
fig.savefig(f"{r_dir}/fig7_cluster_mean_timecourses.svg")

In [None]:
# Apply a wavelet transform at four scales and show the coefficients as an image.
# NOTE(review): `x` is whatever the previous cell left in the kernel. When the
# cells run in order that is the np.linspace x-axis from its last loop
# iteration, not an expression curve - confirm which signal was meant here.
scales = [1, 2, 3, 4]
wt = wavelet_pseudotime.wavelets.WaveletTransform(scales=scales, wavelet="mexh")
coef, _ = wt.apply(x)
plt.imshow(coef)

# Figure assembly

In [None]:
from string import ascii_uppercase

In [None]:
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import numpy as np
from collections import defaultdict as dd

fontsize=16


def _add_panel_letter(ax, letter, y=0.95):
    """Stamp a bold panel letter near the left edge of `ax` (axes coordinates)."""
    ax.text(0.02, y, letter, transform=ax.transAxes,
            fontsize=fontsize, fontweight='bold', va='top', ha='left')


# --- Cluster the above-threshold pseudotimecourses ---------------------------
# The clustering now happens BEFORE its result is used. Previously the names
# were grouped with the `labels` left behind by an earlier cell and KMeans was
# only re-fitted afterwards. The fit is seeded, so the assignment is unchanged.
pt_above_thresh = [ps_p[g] for g in g_above_thresh]
kmeans = KMeans(n_clusters=4, random_state=0)
labels = kmeans.fit_predict(pt_above_thresh)

# cluster label -> names assigned to that cluster
c = dd(list)
for idx_g, g in enumerate(g_above_thresh):
    c[labels[idx_g]].append(g)

# Shared x-range and the position of the dashed transition marker.
minx = np.min(bin_midpoints)
maxx = np.max(bin_midpoints)
lin_pos = (maxx - minx) * fifty_trans + minx
lin_loc = lin_pos  # same position, name kept for the per-cluster panels

fig = plt.figure(figsize=(10, 12))
gs = gridspec.GridSpec(5, 2, figure=fig)

# Panel A: pseudotime distribution of the scRNA cells.
ax = fig.add_subplot(gs[0,0])
ax.hist(panel_6a, bins=100);
ax.set_ylabel("Frequency", fontsize=fontsize)
ax.set_title("scRNA pseudotime", fontsize=fontsize)
_add_panel_letter(ax, ascii_uppercase[0])

# Panel B: pseudotime distribution of the protein cells.
ax = fig.add_subplot(gs[0,1])
ax.hist(panel_6b, bins=100);
ax.set_title("scP pseudotime", fontsize=fontsize)
ax.set_xlim([0,1])
_add_panel_letter(ax, ascii_uppercase[1])

# Panel C: cell-type fractions per pseudotime bin, from the saved copies.
ax = fig.add_subplot(gs[1,0])
ax.stackplot(fig6c_x, fig6c_y, labels=fig6c_labels)
ax.set_xlabel('Pseudotime', fontsize=fontsize)
ax.set_ylabel('Fraction', fontsize=fontsize)
ax.set_title('scP cell type fraction', fontsize=fontsize)
_add_panel_letter(ax, ascii_uppercase[2])

# Remove extra whitespace/margins around the plot.
# Set x limits exactly to the first and last bin midpoints.
ax.set_xlim(bin_midpoints[0], bin_midpoints[-1])
# Remove all margins from both axes.
ax.margins(x=0, y=0)
# Adjust subplot parameters to use all the figure area.
plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

# Panel D: one example protein, cells with pseudotime below 0.7, plus its
# pseudotimecourse from `ps_p`.
# g = "LMNA"
g = "CAPG"
ax = fig.add_subplot(gs[1,1])
m0_cells = (adata_p.obs["celltype"] == "sc_m0").values
u_cells = ~m0_cells
p07 = (adata_p.obs["dpt_pseudotime"] < 0.7)
ax.plot(adata_p.obs.loc[m0_cells & p07, "dpt_pseudotime"], adata_p[m0_cells & p07, g].X[:, 0], ".", label="m0")
ax.plot(adata_p.obs.loc[u_cells & p07, "dpt_pseudotime"], adata_p[u_cells & p07, g].X[:, 0], ".", label="u")
ax.legend()
ax.set_xlabel("Pseudotime", fontsize=fontsize)
ax.set_ylabel("Protein Quant.", fontsize=fontsize)
ax.set_title(f"Expression of {g}", fontsize=fontsize)
ax.plot(np.linspace(0,np.max(adata_p.obs["dpt_pseudotime"]), len(ps_p[g])), ps_p[g])
# The letter sits lower in this panel than in the others.
_add_panel_letter(ax, ascii_uppercase[3], y=0.15)

# Panel E: mean pseudotimecourse of every cluster, labels in sorted order.
ax = fig.add_subplot(gs[2,:])
kk = sorted(c.keys())
for idx in kk:
    x = np.zeros(len(ps_p[c[0][0]]))
    print(len(c[idx]))  # number of names in this cluster
    for g in c[idx]:
        x += ps_p[g]
    count = len(c[idx])
    x /= count
    num = len(x)
    ax.plot(np.linspace(minx, maxx, num), x, label=f"Cluster {idx}")
ax.axvline(lin_pos, linestyle="--", label="50% transition")
ax.legend(bbox_to_anchor=(1,1))
ax.set_title("Mean pseudotimecourses", fontsize=fontsize)
ax.set_xlabel("Pseudotime", fontsize=fontsize)
ax.set_ylabel("Protein quant.", fontsize=fontsize)
_add_panel_letter(ax, ascii_uppercase[4])

# Panels F onwards: every individual pseudotimecourse, one panel per cluster.
gs_list = [gs[3,0], gs[3,1], gs[4,0], gs[4,1]]

for idx in range(len(c)):
    gs0 = gs_list[idx]
    ax = fig.add_subplot(gs0)
    for g in c[idx]:
        x = np.linspace(minx, maxx, len(ps_p[g]))
        ax.plot(x, ps_p[g])
    ax.axvline(lin_loc, linestyle="--", label="50% transition")
    ax.set_title(f"Cluster {idx}", fontsize=fontsize)
    if idx == 0:
        ax.legend()  # legend only on the first cluster panel
    if idx > 1:
        ax.set_xlabel("Pseudotime", fontsize=fontsize)  # bottom row only
    if idx % 2 == 0:
        ax.set_ylabel("Protein quant.", fontsize=fontsize)  # left column only
    _add_panel_letter(ax, ascii_uppercase[5+idx])

fig.tight_layout()
fig.savefig(f"{r_dir}/fig6_slavov.png")
fig.savefig(f"{r_dir}/fig6_slavov.svg")

In [None]:
# Rank features per cell type; scanpy stores the result under
# adata_p.uns["rank_genes_groups"].
sc.tl.rank_genes_groups(adata_p, groupby="celltype")
_ranked_names = adata_p.uns["rank_genes_groups"]["names"]
top100 = _ranked_names[:100]  # first 100 rows of the ranked names
# Show the category order; the next cell indexes the groups by position.
adata_p.obs["celltype"].cat.categories

In [None]:
# Split the ranked names into one list per group: each row of `top100` holds
# one name per group, with position 0 used for m0 and position 1 for u.
# NOTE(review): this relies on the category order displayed by the previous
# cell - confirm it is (m0, u).
# The loop variable is deliberately not called `c`. That name holds the
# cluster -> names dictionary used by the plotting cells, and the old loop
# silently replaced it with a single row of `top100`.
top100_m0 = [row[0] for row in top100]
top100_u = [row[1] for row in top100]

In [None]:
# Take the 100 highest-scoring entries of `scores_p`. The records are sorted by
# ascending score, so the top scorers are the last 100.
score_list = [
    {"protein": name, "score": value}
    for name, value in scores_p.items()
]
sorted_scores = sorted(score_list, key=lambda rec: rec["score"])
top100_proteins = [rec["protein"] for rec in sorted_scores[-100:]]

In [None]:
# Names shared between each group's top-100 list and the top-100 scored list.
_top_scored = set(top100_proteins)
m0_overlap = set(top100_m0) & _top_scored
u_overlap = set(top100_u) & _top_scored

In [None]:
# Number of names shared between the m0 top-100 list and the top-100 scored list.
len(m0_overlap)

In [None]:
# Number of names shared between the u top-100 list and the top-100 scored list.
len(u_overlap)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Circle


def draw_venn_diagram(group1_count: int,
                      group2_count: int,
                      group3_count: int,
                      overlap12_count: int,
                      overlap23_count: int,
                      group_labels: Tuple[str, str, str] = ("Monocyte DE", "TE", "Macrophage DE"),
                      title: str = "Overlap between DE genes and TE genes") -> None:
    """Draw three circles in a row, the middle one overlapping both neighbours.

    The counts are printed as given; the function does no arithmetic on them.
    The figure is created with pyplot and left open, so the caller can save it
    with ``plt.savefig``.

    Parameters
    ----------
    group1_count, group2_count, group3_count
        Numbers written inside the left, middle and right circle.
    overlap12_count
        Number written where the left and middle circles overlap.
    overlap23_count
        Number written where the middle and right circles overlap.
    group_labels
        Captions for the left, middle and right circle. The defaults are the
        labels that used to be hard-coded.
    title
        Figure title. The default is the title that used to be hard-coded.
    """
    # create a new figure and axis
    fig, ax = plt.subplots(figsize=(7, 6))

    radius = 1.0
    fontsize = 16

    # Per-circle settings, left to right.
    centers = ((0, 0), (1.4, 0), (2.8, 0))
    face_colors = ('red', 'green', 'blue')
    text_colors = ('darkred', 'darkgreen', 'darkblue')
    counts = (group1_count, group2_count, group3_count)
    # Offsets from the circle centre for the count: the outer circles show it
    # on their non-overlapping side, the middle circle in its upper part.
    count_offsets = ((-0.5, 0.1), (-0.2, 0.4), (0.4, 0.1))
    # Offsets from the circle centre for the caption below each circle.
    caption_offsets = ((-0.5, -1.2), (-0.1, -1.2), (-0.5, -1.2))

    for center, face, ink, caption, count, count_off, caption_off in zip(
            centers, face_colors, text_colors, group_labels,
            counts, count_offsets, caption_offsets):
        ax.add_patch(Circle(center, radius, color=face, alpha=0.4, label=caption))
        ax.text(center[0] + count_off[0], center[1] + count_off[1], f"{count}",
                fontsize=fontsize, color=ink, weight='bold')
        ax.text(center[0] + caption_off[0], center[1] + caption_off[1], caption,
                fontsize=fontsize, color=ink)

    # Overlap counts sit slightly below the midpoint between neighbouring centres.
    for left, right, overlap in ((centers[0], centers[1], overlap12_count),
                                 (centers[1], centers[2], overlap23_count)):
        mid_x = (left[0] + right[0]) / 2
        mid_y = (left[1] + right[1]) / 2 - 0.2
        ax.text(mid_x - 0.1, mid_y, f"{overlap}",
                fontsize=fontsize, color='black', weight='bold')

    # Adjust plot limits so all circles are clearly visible.
    ax.set_xlim(-1, 4)
    ax.set_ylim(-1, 1)

    # Use an equal aspect ratio for proper circle representation.
    ax.set_aspect('equal')
    ax.axis('off')

    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    ax.set_title(title, fontsize=20)
    fig.tight_layout()
    # plt.show()


In [None]:
# Counts for the non-overlapping part of each circle.
# NOTE(review): the TE-only count subtracts both overlaps, which assumes no name
# appears in both u_overlap and m0_overlap - confirm.
_n_u_only = len(top100_u) - len(u_overlap)
_n_te_only = len(top100_proteins) - len(u_overlap) - len(m0_overlap)
_n_m0_only = len(top100_m0) - len(m0_overlap)

# Argument order is u (left circle), TE (middle), m0 (right circle).
draw_venn_diagram(_n_u_only, _n_te_only, _n_m0_only, len(u_overlap), len(m0_overlap))
for _ext in ("png", "svg"):
    plt.savefig(f"{r_dir}/venn_slavov.{_ext}", bbox_inches="tight")