diff --git a/reproducibility/figures/summarize.py b/reproducibility/figures/summarize.py index 8c12312b6..2376c9bc2 100644 --- a/reproducibility/figures/summarize.py +++ b/reproducibility/figures/summarize.py @@ -7,7 +7,9 @@ import hydra import matplotlib.pyplot as plt +import numpy as np import scvelo as scv +import seaborn as sns from omegaconf import DictConfig from pyrovelocity.config import print_config_tree @@ -86,11 +88,13 @@ def plots(conf: DictConfig, logger: Logger) -> None: volcano_plot = reports_data_model_conf.volcano_plot rainbow_plot = reports_data_model_conf.rainbow_plot vector_field_plot = reports_data_model_conf.vector_field_plot + shared_time_plot = reports_data_model_conf.shared_time_plot output_filenames = [ dataframe_path, volcano_plot, rainbow_plot, vector_field_plot, + shared_time_plot, ] if all(os.path.isfile(f) for f in output_filenames): logger.info( @@ -133,6 +137,50 @@ def plots(conf: DictConfig, logger: Logger) -> None: ################## # generate figures ################## + vector_field_basis = data_model_conf.vector_field_parameters.basis + + # shared time plot + cell_time_mean = posterior_samples["cell_time"].mean(0).flatten() + cell_time_std = posterior_samples["cell_time"].std(0).flatten() + adata.obs["shared_time_uncertain"] = cell_time_std + adata.obs["shared_time_mean"] = cell_time_mean + fig, ax = plt.subplots(1, 2) + fig.set_size_inches(9.2, 3.5) + ax_cb = scv.pl.scatter( + adata, + c="shared_time_mean", + ax=ax[0], + show=False, + cmap="inferno", + fontsize=7, + colorbar=True, + ) + ax_cb = scv.pl.scatter( + adata, + c="shared_time_uncertain", + ax=ax[1], + show=False, + cmap="inferno", + fontsize=7, + colorbar=True, + ) + select = adata.obs["shared_time_uncertain"] > np.quantile( + adata.obs["shared_time_uncertain"], 0.9 + ) + sns.kdeplot( + adata.obsm[f"X_{vector_field_basis}"][:, 0][select], + adata.obsm[f"X_{vector_field_basis}"][:, 1][select], + ax=ax[1], + levels=3, + fill=False, + ) + fig.savefig( + shared_time_plot, + facecolor=fig.get_facecolor(), + bbox_inches="tight", + edgecolor="none", + dpi=300, + ) # volcano plot @@ -197,8 +245,6 @@ def plots(conf: DictConfig, logger: Logger) -> None: logger.info(f"Generating figure: {vector_field_plot}") fig, ax = plt.subplots() - vector_field_basis = data_model_conf.vector_field_parameters.basis - # embed_mean = plot_mean_vector_field(posterior_samples, adata, ax=ax) scv.pl.velocity_embedding_grid( adata,