Skip to content

Commit

Permalink
add shared time uncertainty plot to summarize stages (#216)
Browse files Browse the repository at this point in the history
  • Loading branch information
qinqian committed Apr 27, 2023
1 parent 595e849 commit 588afce
Showing 1 changed file with 48 additions and 2 deletions.
50 changes: 48 additions & 2 deletions reproducibility/figures/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

import hydra
import matplotlib.pyplot as plt
import numpy as np
import scvelo as scv
import seaborn as sns
from omegaconf import DictConfig

from pyrovelocity.config import print_config_tree
Expand Down Expand Up @@ -86,11 +88,13 @@ def plots(conf: DictConfig, logger: Logger) -> None:
volcano_plot = reports_data_model_conf.volcano_plot
rainbow_plot = reports_data_model_conf.rainbow_plot
vector_field_plot = reports_data_model_conf.vector_field_plot
shared_time_plot = reports_data_model_conf.shared_time_plot
output_filenames = [
dataframe_path,
volcano_plot,
rainbow_plot,
vector_field_plot,
shared_time_plot,
]
if all(os.path.isfile(f) for f in output_filenames):
logger.info(
Expand Down Expand Up @@ -133,6 +137,50 @@ def plots(conf: DictConfig, logger: Logger) -> None:
##################
# generate figures
##################
vector_field_basis = data_model_conf.vector_field_parameters.basis

# shared time plot
cell_time_mean = posterior_samples["cell_time"].mean(0).flatten()
cell_time_std = posterior_samples["cell_time"].std(0).flatten()
adata.obs["shared_time_uncertain"] = cell_time_std
adata.obs["shared_time_mean"] = cell_time_mean
fig, ax = plt.subplots(1, 2)
fig.set_size_inches(9.2, 3.5)
ax_cb = scv.pl.scatter(
adata,
c="shared_time_mean",
ax=ax[0],
show=False,
cmap="inferno",
fontsize=7,
colorbar=True,
)
ax_cb = scv.pl.scatter(
adata,
c="shared_time_uncertain",
ax=ax[1],
show=False,
cmap="inferno",
fontsize=7,
colorbar=True,
)
select = adata.obs["shared_time_uncertain"] > np.quantile(
adata.obs["shared_time_uncertain"], 0.9
)
sns.kdeplot(
adata.obsm[f"X_{vector_field_basis}"][:, 0][select],
adata.obsm[f"X_{vector_field_basis}"][:, 1][select],
ax=ax[1],
levels=3,
fill=False,
)
fig.savefig(
shared_time_plot,
facecolor=fig.get_facecolor(),
bbox_inches="tight",
edgecolor="none",
dpi=300,
)

# volcano plot

Expand Down Expand Up @@ -197,8 +245,6 @@ def plots(conf: DictConfig, logger: Logger) -> None:
logger.info(f"Generating figure: {vector_field_plot}")
fig, ax = plt.subplots()

vector_field_basis = data_model_conf.vector_field_parameters.basis

# embed_mean = plot_mean_vector_field(posterior_samples, adata, ax=ax)
scv.pl.velocity_embedding_grid(
adata,
Expand Down

0 comments on commit 588afce

Please sign in to comment.