In [None]:
# Render matplotlib figures inline below each cell.
%matplotlib inline

In [None]:
import scanpy as sc
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import os
import muon

                         
import mplscience

# Publication plotting style from mplscience, followed by scanpy's figure defaults.
mplscience.set_style(reset_current=True)
# Show a single marker per scatter legend entry.
plt.rcParams["legend.scatterpoints"] = 1
# dpi controls inline display resolution, dpi_save the resolution of files written by savefig.
sc.set_figure_params(scanpy=True, dpi=500, dpi_save=500)

In [None]:
# Every figure produced by this notebook is written below this directory.
output_dir = "/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/plots/diffmap"
# Create it up front so the savefig / save= calls further down cannot fail on a missing folder.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Populations removed from the multiome object below, so only the remaining
# (endocrine-lineage) cells are analysed.
not_endocrine_celltypes = [
    "Mat. Acinar",
    "Imm. Acinar",
    "Prlf. Ductal",
    "Ductal",
    "Ngn3 high cycling",
]

In [None]:
# Load the annotated multiome object and keep only cells outside `not_endocrine_celltypes`.
mudata = muon.read("/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/mudata_with_annotation_all.h5mu")
# Slicing a MuData returns a view; the following cells write to .obs and .uns, so materialise
# the subset with .copy() instead of modifying it through a view.
mudata = mudata[~mudata.obs["cell_type"].isin(not_endocrine_celltypes)].copy()

In [None]:
# Fixed palette for every plot in this notebook, keyed by `cell_type` / `cell_type_refined` labels.
# NOTE(review): "Mat. Acinar" and "Fev+ Delta,1" share #98df8a, and "Ngn3 high,0" and
# "Ngn3 high" share #ff9896 — confirm this is intended.
color_dict = {
    # Types listed in not_endocrine_celltypes (filtered out of mudata).
    "Prlf. Ductal": "#f7b6d2",
    "Ductal": "#d62728",
    "Imm. Acinar": "#ffeb3b",
    "Mat. Acinar": "#98df8a",
    # Hormone-producing types.
    "Alpha": "#1f77b4",
    "Beta": "#ff7f0e",
    "Delta": "#279e68",
    "Eps. progenitors": "#aa40fc",
    "Epsilon": "#8c564b",
    # Fev+ populations, including the refined Fev+ Delta sub-clusters.
    "Fev+": "#e377c2",
    "Fev+ Alpha": "#b5bd61",
    "Fev+ Beta": "#42f5ec",
    "Fev+ Delta": "#aec7e8",
    "Fev+ Delta,0": "#17becf",
    "Fev+ Delta,1": "#98df8a",
    # Ngn3 populations, including the refined Ngn3 high sub-clusters.
    "Ngn3 high cycling": "#adf542",
    "Ngn3 high,0": "#ff9896",
    "Ngn3 high": "#ff9896",
    "Ngn3 high,1": "#f0b98d",
    "Ngn3 low": "#c5b0d5",
}

In [None]:
# Make both annotation columns categorical and attach a colour list built in category order.
# remove_unused_categories() drops labels emptied by the cell filter; a column that was already
# categorical would otherwise keep them and carry stale entries into the colour list / legend.
for annotation_key in ("cell_type", "cell_type_refined"):
    annotation = mudata.obs[annotation_key].astype("category").cat.remove_unused_categories()
    mudata.obs[annotation_key] = annotation
    mudata.uns[f"{annotation_key}_colors"] = [color_dict[ct] for ct in annotation.cat.categories]

In [None]:
from anndata import AnnData
# Wrap the MultiVI latent space in a plain AnnData so scanpy / PHATE can run on it.
# obs is copied so later edits to bdata.obs cannot silently change mudata.obs (and vice versa).
bdata = AnnData(X=mudata.obsm["X_MultiVI"], obs=mudata.obs.copy())
# Reuse the shared cell_type palette so the PHATE overview matches the other figures.
bdata.uns["cell_type_colors"] = list(mudata.uns["cell_type_colors"])

In [None]:
import scanpy.external as sce
# 3-component PHATE embedding of the MultiVI latent space (stored in bdata.obsm["X_phate"]).
# PHATE is stochastic, so a fixed random_state keeps the embedding, and every figure built
# from it, reproducible under Restart & Run All.
sce.tl.phate(bdata, k=30, n_components=3, random_state=0)

In [None]:
# Plot every pair of PHATE components, coloured by cell type, to choose the 2-D projection used below.
sce.pl.phate(
    bdata,
    color="cell_type",
    components="all",
)

In [None]:
from moscot.problems.time import TemporalProblem
# Load the previously solved temporal OT problem. Given the .pkl suffix this is presumably a
# pickle, so only load it from a trusted location.
tp0 = TemporalProblem.load(
    "/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/plots/OT_encodrine_analysis/TemporalProblem.pkl"
)

In [None]:
# Give the OT problem's AnnData the same palette: one colour per category, in category order.
tp0.adata.uns["cell_type_colors"] = list(
    map(color_dict.__getitem__, tp0.adata.obs["cell_type"].cat.categories)
)
tp0.adata.uns["cell_type_refined_colors"] = list(
    map(color_dict.__getitem__, tp0.adata.obs["cell_type_refined"].cat.categories)
)

In [None]:
# Transport each population of interest between the two time points of the OT problem.
# In moscot, `pull` maps mass backward from the later time point and `push` forward from the
# earlier one; each result is stored in tp0.adata.obs under `key_added`.
TRANSPORT_SOURCE_TIME = 14.5
TRANSPORT_TARGET_TIME = 16.5

# (obs column, category to transport, obs key for the result) — order matches the original cells.
PULL_SPECS = [
    ("cell_type", "Delta", "Delta_pull"),
    ("cell_type", "Beta", "Beta_pull"),
    ("cell_type", "Fev+", "Fev_pull"),
    ("cell_type", "Fev+ Alpha", "Fev_Alpha_pull"),
    ("cell_type", "Fev+ Beta", "Fev_Beta_pull"),
    ("cell_type", "Alpha", "Alpha_pull"),
    ("cell_type", "Epsilon", "Epsilon_pull"),
    ("cell_type", "Fev+ Delta", "Fev_delta_pull"),
    ("cell_type_refined", "Fev+ Delta,0", "Fev_delta0_pull"),
    ("cell_type_refined", "Fev+ Delta,1", "Fev_delta1_pull"),
    ("cell_type_refined", "Eps. progenitors", "Eps_prog_pull"),
]
PUSH_SPECS = [
    ("cell_type", "Fev+ Delta", "Fev_delta_push"),
    ("cell_type_refined", "Fev+ Delta,0", "Fev_delta0_push"),
    ("cell_type_refined", "Fev+ Delta,1", "Fev_delta1_push"),
    ("cell_type_refined", "Eps. progenitors", "Eps_prog_push"),
    ("cell_type", "Epsilon", "Epsilon_push"),
    ("cell_type", "Fev+", "Fev_push"),
    ("cell_type", "Fev+ Alpha", "Fev_Alpha_push"),
    ("cell_type", "Fev+ Beta", "Fev_Beta_push"),
]

for obs_key, category, key_added in PULL_SPECS:
    tp0.pull(TRANSPORT_SOURCE_TIME, TRANSPORT_TARGET_TIME, data=obs_key, subset=category, key_added=key_added)

for obs_key, category, key_added in PUSH_SPECS:
    tp0.push(TRANSPORT_SOURCE_TIME, TRANSPORT_TARGET_TIME, data=obs_key, subset=category, key_added=key_added)

In [None]:
# Drop the cycling Ngn3-high cells from the OT problem's AnnData before plotting. The filtered
# object is installed through moscot's private `_adata` attribute; .copy() materialises the
# view so the obsm write below does not operate on an AnnData view.
tp0._adata = tp0.adata[~tp0.adata.obs["cell_type"].isin(("Ngn3 high cycling",))].copy()

# 2-D layout used by all moscot plots below. The obsm key stays "diff_map" because the plotting
# cells reference it, but the coordinates are PHATE component 3 (x) and sign-flipped component 1 (y).
# Rows are matched by cell name rather than by position: bdata and tp0.adata were filtered
# independently, so their row order need not agree (a cell missing from bdata raises KeyError).
phate_by_cell = pd.DataFrame(bdata.obsm["X_phate"], index=bdata.obs_names)
layout = phate_by_cell.loc[tp0.adata.obs_names, [2, 0]].to_numpy(copy=True)
layout[:, 1] = -layout[:, 1]
tp0.adata.obsm["diff_map"] = layout

In [None]:
# Layout coloured by the refined annotation, at a lower display dpi than the notebook default.
with plt.rc_context({"figure.figsize": (5, 3), "figure.dpi": 300}):
    axes = sc.pl.embedding(tp0.adata, basis="diff_map", color="cell_type_refined", show=False)
    # The "diff_map" basis is filled from X_phate columns [2, 0], so these are PHATE components
    # (3 and 1), not diffusion components.
    axes.set_xlabel("PHATE 3")
    axes.set_ylabel("PHATE 1")


In [None]:
# Write the refined-annotation embedding from the previous cell to disk.
axes.get_figure().savefig(os.path.join(output_dir, "diffusion_refinement.png"))

In [None]:
# Re-draw the embedding's legend on an empty figure so it can be placed separately in the layout.
label_params = axes.get_legend_handles_labels()
figl, axl = plt.subplots()
axl.axis(False)
# Only `prop` is passed: Axes.legend ignores `fontsize` whenever `prop` is given, so the effective
# label size is 10.
axl.legend(*label_params, loc="center", markerscale=2, bbox_to_anchor=(0.5, 0.5), prop={"size": 10})

In [None]:
# Write the standalone legend figure to disk.
figl.savefig(os.path.join(output_dir, "diffusion_refinement_legend_only.png"))


In [None]:
import moscot as mt
# Backward (pull) transport results on the PHATE layout, one figure per population.
# Each figure is saved under output_dir with its obs key as the file name.
PULL_PLOT_KEYS = [
    "Delta_pull",
    "Alpha_pull",
    "Beta_pull",
    "Epsilon_pull",
    "Fev_delta_pull",
    "Fev_delta1_pull",
    "Fev_delta0_pull",
    "Eps_prog_pull",
    "Fev_pull",
    "Fev_Alpha_pull",
    "Fev_Beta_pull",
]

for pull_key in PULL_PLOT_KEYS:
    mt.plotting.pull(
        tp0,
        key=pull_key,
        basis="diff_map",
        time_points=[14.5, 15.5, 16.5],
        figsize=(15, 3),
        dot_scale_factor=5.0,
        save=os.path.join(output_dir, pull_key),
    )

## Push: forward transport of selected populations (14.5 → 16.5)

In [None]:
# Forward (push) transport results on the PHATE layout, one figure per population.
# Each figure is saved under output_dir with its obs key as the file name; every key is
# plotted exactly once, so no file is overwritten.
PUSH_PLOT_KEYS = [
    "Epsilon_push",
    "Fev_delta_push",
    "Fev_delta0_push",
    "Fev_delta1_push",
    "Fev_push",
    "Fev_Alpha_push",
    "Fev_Beta_push",
    "Eps_prog_push",
]

for push_key in PUSH_PLOT_KEYS:
    mt.plotting.push(
        tp0,
        key=push_key,
        basis="diff_map",
        time_points=[14.5, 15.5, 16.5],
        figsize=(15, 3),
        dot_scale_factor=5.0,
        save=os.path.join(output_dir, push_key),
    )