# Evaluation of `moscot-spatio-temporal` mapping on the MOSTA dataset, focusing on the liver


In [1]:
# Automatically reload edited, already-imported modules so changes are picked up without restarting the kernel.
%load_ext autoreload 
%autoreload 2

In [3]:
import scanpy as sc
import matplotlib.pyplot as plt

In [4]:
# Switch JAX to 64-bit (double) precision once, at the top of the notebook.
import jax

jax.config.update(
    "jax_enable_x64",
    True,
)

In [5]:
from moscot.problems.spatiotemporal import SpatioTemporalProblem

In [6]:
# Apply the mplscience matplotlib style; show a single marker per legend entry for scatter plots.
import mplscience

mplscience.set_style()
plt.rcParams.update({"legend.scatterpoints": 1})

In [7]:
import warnings
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", FutureWarning)

## Dependencies

Requires running:

1. `0_Data_preparation/ZP_2023-04-20_spatiotemporal_fullembryo-preprocess.ipynb`: creates `mouse_embryo_{tp1}_{tp2}_renormalized.h5ad` for every pair of consecutive time points.
2. `mosta_st_map_array_save.sh`: runs the SpatioTemporal problem over all pairs of consecutive time points and saves the output in `DATA_DIR / "grid/save/"` (uses `save_config.txt` and `mosta_st_map_accuracies_save.py`).


## Set parameters

In [8]:
import sys

# Directory containing the project-level `paths` module. This relative depth depends on
# where the notebook sits in the repository and must be adapted per notebook.
REPO_ROOT = "../../../../../"
if REPO_ROOT not in sys.path:  # keep the cell idempotent: no duplicate entries on re-run
    sys.path.insert(0, REPO_ROOT)

from paths import DATA_DIR, FIG_DIR

# Both are path objects joined with `/`; build further paths with `/`, not `+`.
FIG_DIR = FIG_DIR / "space/spatiotemporal/"
DATA_DIR = DATA_DIR / "space/spatiotemporal/"

## Evaluate results

In [9]:
# Saved SpatioTemporal problems to evaluate, one selected grid run per pair of consecutive
# time points. All runs share eps=0.0001, rank=500 and the squared-Euclidean cost; only
# gamma and alpha vary. Rows: (source time point, target time point, gamma, alpha).
lst_load = [
    f"stp_eps_0.0001_rank_500_gamma_{gamma}_alpha_{alpha}_tp_{source}_{target}_cost_sq_euclidean.pkl"
    for source, target, gamma, alpha in [
        ("9_5", "10_5", 100.0, 0.4),
        ("10_5", "11_5", 100.0, 0.6),
        ("11_5", "12_5", 100.0, 0.2),
        ("12_5", "13_5", 100.0, 0.6),
        ("13_5", "14_5", 10.0, 0.4),
        ("14_5", "15_5", 100.0, 0.4),
        ("15_5", "16_5", 10.0, 0.4),
    ]
]

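### Liver annotation pull analysis

For each pair of consecutive time points, pull the `Liver` cells of the later time point onto the earlier one. Then correlate the pulled signal with gene expression.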

In [10]:
## load res
# For each saved problem (one per pair of consecutive time points), run `pull` with the
# earlier time point as source and the later one as target, restricted to the `Liver`
# category of `annotation`, and keep a copy of the resulting AnnData.
adatas = []
for file in lst_load:
    # DATA_DIR is a path object (built with `/` above), so join with `/` rather than `+`.
    stp = SpatioTemporalProblem.load(DATA_DIR / "grid" / "save" / file)
    # `unique()` returns values in order of appearance; sort so that start < end.
    start, end = sorted(stp.adata.obs["time"].unique())
    print(f"pulling {start}-{end}")
    stp.pull(source=start, target=end, data="annotation", subset="Liver", key_added="Liver_pull", normalize=False)
    adatas.append(stp.adata.copy())
    

pulling 9.5-10.5
pulling 10.5-11.5
pulling 11.5-12.5
pulling 12.5-13.5
pulling 13.5-14.5
pulling 14.5-15.5
pulling 15.5-16.5


In [11]:
# Keep only the cells of the earlier (source) time point of each pair, where the pull result is read.
adata_pulls = []
for adata in adatas:
    # Sort instead of relying on `unique()` order, which follows row order in `obs`.
    source_time = sorted(adata.obs["time"].unique())[0]
    adata_pulls.append(adata[adata.obs["time"] == source_time].copy())
    

In [12]:
# Concatenate the source-time-point cells of all pairs and restrict to highly variable genes.
adata_pull = sc.concat(adata_pulls)
# With the default `merge=None`, `sc.concat` keeps no `var` columns. Restore them from the
# last pair explicitly, rather than from the leaked loop variable `adata`.
# Assigning `.var` does not reorder rows, so first check that the gene order actually matches.
pull_var = adata_pulls[-1].var
assert adata_pull.var_names.equals(pull_var.index), "gene order differs between time-point pairs"
adata_pull.var = pull_var.copy()
adata_pull = adata_pull[:, adata_pull.var["highly_variable"]].copy()
adata_pull

AnnData object with n_obs × n_vars = 397204 × 2000
    obs: 'annotation', 'timepoint', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'time', 'n_genes', 'total_counts_mt', 'pct_counts_mt', 'transition', 'proliferation', 'apoptosis', 'Liver_pull'
    var: 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'mt', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'
    obsm: 'X_pca', 'X_pca_30', 'spatial'

In [13]:
# Rebuild a single problem over the concatenated pull results; it is used below for `compute_feature_correlation`.
stp = (
    SpatioTemporalProblem(adata=adata_pull)
    .score_genes_for_marginals(gene_set_proliferation="mouse", gene_set_apoptosis="mouse")
    .prepare(time_key="time", spatial_key="spatial", joint_attr="X_pca_30", cost="sq_euclidean")
)

[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   


In [14]:
# Gene-wise correlation with the `Liver_pull` score (columns: corr, pval, qval and CI bounds).
df_liver = stp.compute_feature_correlation(obs_key="Liver_pull")

In [15]:
# The ten genes at the top of the correlation table, and their statistics.
n_top = 10
top_genes = df_liver.index[:n_top]
df_liver.head(n_top)

Unnamed: 0,Liver_pull_corr,Liver_pull_pval,Liver_pull_qval,Liver_pull_ci_low,Liver_pull_ci_high
Afp,0.59548,0.0,0.0,0.593469,0.597483
Alb,0.577517,0.0,0.0,0.575441,0.579586
Apoa2,0.534929,0.0,0.0,0.532705,0.537145
Car2,0.523767,0.0,0.0,0.521506,0.52602
Mt2,0.507893,0.0,0.0,0.505582,0.510197
Trf,0.506352,0.0,0.0,0.504036,0.508661
Gm45774,0.4971,0.0,0.0,0.494755,0.499438
Apoa1,0.49268,0.0,0.0,0.490321,0.495031
Cpox,0.484279,0.0,0.0,0.481894,0.486655
Rbp4,0.482769,0.0,0.0,0.48038,0.48515


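### Transcription factor `Hnf4a` push analysis

For each pair of consecutive time points, push `Hnf4a` expression from the earlier time point to the later one. Then correlate the pushed signal with gene expression.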

In [16]:
# For each saved problem (one per pair of consecutive time points), push the expression of
# the transcription factor `tf` from the earlier (source) to the later (target) time point.
adatas_push = []
tf = "Hnf4a"
for file in lst_load:
    # DATA_DIR is a path object (built with `/` above), so join with `/` rather than `+`.
    stp = SpatioTemporalProblem.load(DATA_DIR / "grid" / "save" / file)
    # `unique()` returns values in order of appearance; sort so that start < end.
    start, end = sorted(stp.adata.obs["time"].unique())
    print(f"pushing {start}-{end}")
    # `push` receives the data as an `obs` column name, so expose the TF expression there temporarily.
    # `sc.get.obs_df` returns a flat vector for both sparse and dense `X`. The removed sparse `.A`
    # returned an (n, 1) matrix and failed on dense `X`.
    stp.adata.obs[tf] = sc.get.obs_df(stp.adata, keys=[tf])[tf].to_numpy()
    stp.push(source=start, target=end, data=tf, key_added=f"{tf}_push", normalize=False)
    del stp.adata.obs[tf]
    adatas_push.append(stp.adata.copy())
    

pushing 9.5-10.5
pushing 10.5-11.5
pushing 11.5-12.5
pushing 12.5-13.5
pushing 13.5-14.5
pushing 14.5-15.5
pushing 15.5-16.5


In [17]:
# Keep only the cells of the later (target) time point of each pair, where the push result is read.
adatas_push_all = []
for adata in adatas_push:
    # Sort instead of relying on `unique()` order, which follows row order in `obs`.
    target_time = sorted(adata.obs["time"].unique())[-1]
    adatas_push_all.append(adata[adata.obs["time"] == target_time].copy())
    

In [18]:
# Concatenate the target-time-point cells of all pairs and restrict to highly variable genes.
adata_push = sc.concat(adatas_push_all)
# With the default `merge=None`, `sc.concat` keeps no `var` columns. Restore them from the
# last pair explicitly, rather than from the leaked loop variable `adata`.
# Assigning `.var` does not reorder rows, so first check that the gene order actually matches.
push_var = adatas_push_all[-1].var
assert adata_push.var_names.equals(push_var.index), "gene order differs between time-point pairs"
adata_push.var = push_var.copy()
adata_push = adata_push[:, adata_push.var["highly_variable"]].copy()
adata_push

AnnData object with n_obs × n_vars = 512814 × 2000
    obs: 'annotation', 'timepoint', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'time', 'n_genes', 'total_counts_mt', 'pct_counts_mt', 'transition', 'proliferation', 'apoptosis', 'Hnf4a_push'
    var: 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'mt', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'
    obsm: 'X_pca', 'X_pca_30', 'spatial'

In [19]:
# Rebuild a single problem over the concatenated push results; it is used below for `compute_feature_correlation`.
prepare_kwargs = dict(time_key="time", spatial_key="spatial", joint_attr="X_pca_30", cost="sq_euclidean")
stp = SpatioTemporalProblem(adata=adata_push)
stp = stp.score_genes_for_marginals(gene_set_proliferation="mouse", gene_set_apoptosis="mouse")
stp = stp.prepare(**prepare_kwargs)

[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   


In [20]:
# Gene-wise correlation with the pushed transcription-factor signal.
push_key = f"{tf}_push"
df = stp.compute_feature_correlation(obs_key=push_key)

In [21]:
# Top 5 genes by correlation with the pushed signal. Select the correlation column by name
# rather than by position, and use rich display instead of `print`.
df[f"{tf}_push_corr"].head(5)

Afp       0.415220
Alb       0.405604
Lgals2    0.388928
Lgals4    0.368240
Mt1       0.357200
Name: Hnf4a_push_corr, dtype: float64
