# Visualization of Heart mapping using `moscot.spatiotemporal` applied to the MOSTA dataset

Imports the mapping results produced by the grid search in `run_mosta_st_map.py` and visualizes the heart push-forward across time points.


In [None]:
# IPython autoreload: with mode 2, imported modules are reloaded before each
# cell runs, so edits to helper code are picked up without a kernel restart.
%load_ext autoreload 
%autoreload 2

In [None]:
import os
import sys

In [None]:
from datetime import datetime
import numpy as np
from copy import copy
import pickle as pkl
import glob

import scanpy as sc
import squidpy as sq
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
import mplscience

# Apply the mplscience figure style, then draw a single marker per legend
# entry for scatter plots.
mplscience.set_style()
plt.rcParams.update({"legend.scatterpoints": 1})

## Dependencies

Requires running:

1. `0_Data_preparation/ZP_2023-04-20_spatiotemporal_fullembryo-preprocess.ipynb`: creates `mouse_embryo_all_stage_renormalized.h5ad`
2. `1_Cell_type_transition_analysis/1_mapping_across_timepoints`: creates the heart push-forward files (`output/*_heart_push.pkl`) loaded below


## Set parameters

In [None]:
# Make the repository-level `paths` module importable. The relative path
# depends on how deep this notebook sits in the repository and must be
# adapted per notebook.
sys.path.insert(0, "../../../../../")

from paths import DATA_DIR, FIG_DIR

# Inputs, mapping outputs and figures for this analysis all live in the
# spatiotemporal subfolder.
FIG_DIR, DATA_DIR = (
    FIG_DIR / "space/spatiotemporal",
    DATA_DIR / "space/spatiotemporal",
)


## Load processed data

In [None]:
# Renormalized AnnData covering all embryo stages (see Dependencies, item 1).
adata_six = sc.read_h5ad(DATA_DIR / "mouse_embryo_all_stage_renormalized.h5ad")

In [None]:
# Sort explicitly: `unique()` returns values in order of appearance, which is
# not guaranteed to be chronological.
tps = np.sort(np.asarray(adata_six.obs["time"].unique()))
# (source, target) pairs of consecutive observed time points; entry i matches
# the mapping output for pair index i.
tps_couple = [[start, end] for start, end in zip(tps[:-1], tps[1:])]

## Load mappings

In [None]:
# `alpha` used by the mapping run for each time-point pair (entry i belongs to
# pair i). Kept as strings because they are inserted verbatim into the output
# file names.
tps_alpha = [
    "0.4",   # pair 0
    "0.99",  # pair 1
    "0.4",   # pair 2
    "0.8",   # pair 3
    "0.8",   # pair 4
    "0.6",   # pair 5
    "0.99",  # pair 6
]

In [None]:
# Initialise as float: normalised push-forward values are written into this
# column below, and setting floats into an int64 column triggers a dtype
# upcast that pandas >= 2.1 warns about (and plans to make an error).
adata_six.obs["Heart_push"] = 0.0

In [None]:
# Hyperparameters shared by every time-point pair; they are encoded in the
# output file names written by the mapping run.
epsilon = 0.001
rank = 500
gamma = 10

if len(tps_alpha) != len(tps_couple):
    raise ValueError(
        f"Expected one alpha per time-point pair, got {len(tps_alpha)} alphas "
        f"for {len(tps_couple)} pairs."
    )

for tp, (alpha, (start, end)) in enumerate(zip(tps_alpha, tps_couple)):
    file_push = DATA_DIR / (
        f"output/mouse_embryo_eps_{epsilon}_rank_{rank}_gamma_{gamma}"
        f"_alpha_{alpha}_tp_{tp}_heart_push.pkl"
    )
    print(f"{tp} and {start}-{end}")
    # NOTE: unpickling can execute arbitrary code; only load files produced by
    # our own mapping run.
    with open(file_push, "rb") as handle:
        df = pkl.load(handle)
    # Scale each pair to a maximum of 1. Presumably column 0 holds the
    # push-forward mass indexed by obs names — TODO confirm against
    # run_mosta_st_map.py. Guard against an all-zero push (division by zero).
    push = df[0]
    push_max = push.max()
    adata_six.obs.loc[df.index, "Heart_push"] = push / push_max if push_max > 0 else push

In [None]:
# Colour of the "Heart" annotation. Relies on `uns["annotation_colors"]` being
# ordered like the annotation categories: the boolean mask over the
# categories selects the matching colour entry.
heart_col = adata_six.uns["annotation_colors"][
    adata_six.obs["annotation"].cat.categories == "Heart"
][0]

In [None]:
# Colour map from a light tint up to the Heart annotation colour, so the
# push-forward intensity is drawn in the heart's own colour.
cmap = sns.light_palette(heart_col, as_cmap=True)

# Clip the colour scale at the 97th percentile so a few very high values do
# not wash out the rest of the map.
vmax = np.percentile(adata_six.obs["Heart_push"], 97)
sq.pl.spatial_scatter(
    adata_six,
    shape=None,
    color=["Heart_push"],
    cmap=cmap,
    size=1,
    frameon=False,
    figsize=(18, 3),
    dpi=300,
    legend_loc=None,
    vmax=vmax,
)

# savefig fails if the output directory does not exist yet.
os.makedirs(FIG_DIR, exist_ok=True)
plt.savefig(FIG_DIR / "Heart_push.png", bbox_inches="tight", transparent=True, dpi=300)
plt.show()