In [None]:
import sctour as sct
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Load the T11 AnnData object from disk. A raw-count layer named 'counts' is
# expected to be present (it is copied into X in the next cell).
adata = sc.read(filename='Pancancer_T/h5ad/T11.h5ad')

In [None]:
# scTour's 'nb' (negative-binomial) loss used by the Trainer below works on raw
# counts, so X must hold the unnormalised count matrix.
# Copy rather than alias: with a plain assignment, X and the 'counts' layer are
# the same object, and any later in-place transform of X (normalisation, log1p,
# scaling) would silently overwrite the stored raw counts as well.
adata.X = adata.layers['counts'].copy()

In [None]:
# Compute per-cell / per-gene QC metrics (e.g. total_counts,
# n_genes_by_counts) and store them in adata.obs / adata.var.
# No top-gene percentage columns and no log1p-transformed columns are requested.
sc.pp.calculate_qc_metrics(
    adata,
    inplace=True,
    log1p=False,
    percent_top=None,
)

In [None]:
# X holds raw counts at this point (see the 'counts' layer assignment above).
# The default 'seurat' flavour assumes log-normalised input and mis-ranks genes
# on raw counts. 'seurat_v3' is designed for count data and is what the scTour
# workflow uses. Note: 'seurat_v3' requires the optional `scikit-misc` package.
# subset=True keeps only the 1000 selected genes for scTour training.
sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=1000, subset=True)

In [None]:
# Build and fit the scTour model.
#   loss_mode='nb'        -> negative-binomial reconstruction loss on raw counts
#   alpha_recon_lec/lode  -> weights of the reconstruction terms from the
#                            encoder-derived and ODE-derived latent spaces
tnode = sct.train.Trainer(
    adata,
    loss_mode='nb',
    alpha_recon_lec=0.5,
    alpha_recon_lode=0.5,
)
tnode.train()

In [None]:
# Per-cell pseudotime from the trained model. It is stored in adata.obs so it
# can be used for plotting and for the vector-field computation below.
# NOTE(review): the direction of scTour pseudotime should be checked against
# known biology; reverse it if needed (see sct.train.reverse_time).
pseudotime = tnode.get_time()
adata.obs['ptime'] = pseudotime

In [None]:
# get_latentsp returns three latent representations of the cells:
#   zs      - latent z inferred by variational inference (the encoder)
#   pred_zs - latent z produced by the ODE solver
#   mix_zs  - combination of the two, weighted by alpha_z / alpha_predz;
#             this is the representation used for downstream analysis.
(mix_zs, zs, pred_zs) = tnode.get_latentsp(alpha_predz=0.5, alpha_z=0.5)
adata.obsm['X_TNODE'] = mix_zs

In [None]:
# Evaluate the learned vector field at each cell's pseudotime and mixed latent
# position. The result is stored in obsm for the vector-field plots below.
adata.obsm['X_VF'] = tnode.get_vector_field(
    adata.obs['ptime'].to_numpy(),
    adata.obsm['X_TNODE'],
)

In [None]:
# Overview figure (2x2 grid):
#   [0, 0] UMAP coloured by T11 subtype   [0, 1] unused (hidden below)
#   [1, 0] UMAP coloured by pseudotime    [1, 1] scTour vector-field streamlines
# NOTE(review): sc.pl.umap reads adata.obsm['X_umap'], which this notebook never
# computes. Presumably it is stored in the input .h5ad; confirm.
T11_PALETTE = {"Grm1": "#63B4C1", "Grm2": "#5072AC"}

fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(10, 10))
sc.pl.umap(adata, color='T11_type', ax=axs[0, 0], palette=T11_PALETTE,
           legend_loc='on data', show=False, frameon=False)
sc.pl.umap(adata, color='ptime', ax=axs[1, 0], show=False, frameon=False)
sct.vf.plot_vector_field(adata, zs_key='X_TNODE', vf_key='X_VF', use_rep_neigh='X_TNODE',
                         color='T11_type', palette=T11_PALETTE,
                         show=False, ax=axs[1, 1], legend_loc='none', frameon=False,
                         size=100, alpha=0.7)
# Only three panels are drawn. Hide the fourth so an empty axes frame with
# ticks is not rendered in the top-right corner.
axs[0, 1].set_axis_off()
plt.show()

In [None]:
# Stand-alone vector-field plot coloured by T11 subtype, with denser and
# thicker streamlines than the overview figure.
# save=True is forwarded to the plotting call; presumably it writes the figure
# to scanpy's figure directory (confirm the output path).
# The palette is passed explicitly so the colours do not depend on colours that
# earlier plotting calls may have stored in adata.uns.
sct.vf.plot_vector_field(adata, zs_key='X_TNODE', vf_key='X_VF', use_rep_neigh='X_TNODE', color='T11_type',
                         palette={"Grm1": "#63B4C1", "Grm2": "#5072AC"},
                         frameon=False,
                         size=250, alpha=0.8,
                         stream_density=1.8,
                         smooth=1,
                         stream_linewidth=1.5,
                         stream_arrowsize=1.5,
                         save=True,
                         )