# Gene Trajectory Python tutorial: Human myeloid #

GeneTrajectory is a method for inferring gene trajectories in scRNA-seq data, which helps reveal the gene dynamics underlying biological processes. The GeneTrajectory workflow comprises four main steps:

- Step 1. Build a cell-cell kNN graph in which each cell is connected to its k-nearest neighbors. Find the shortest path connecting each pair of cells in the graph and denote its length as the graph distance between cells.
- Step 2. Compute pairwise graph-based Wasserstein distance between gene distributions, which quantifies the minimum cost of transporting the distribution of a given gene into the distribution of another gene in the cell graph.
- Step 3. Generate a low-dimensional representation of genes (using Diffusion Map by default) based on the gene-gene Wasserstein distance matrix. Identify gene trajectories in a sequential manner.
- Step 4. Determine the order of genes along each gene trajectory.

![GT_workflow.png](https://github.com/richcmwang/gene-trajectory-experiments/blob/main/docs/notebooks/tutorial_images/GT_workflow.png?raw=1)
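
Step 1 can be illustrated with a minimal sketch on toy data. It assumes that the length of a path is its number of edges (hop count) and uses plain NumPy/SciPy rather than the package's own `get_graph_distance`; the helper name `knn_graph_distance` and the toy coordinates are hypothetical.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import shortest_path


def knn_graph_distance(points, k):
    """Hop-count shortest-path distances on a symmetrized kNN graph."""
    n = points.shape[0]
    # Pairwise Euclidean distances between all points (cells).
    diff = points[:, None, :] - points[None, :, :]
    euclidean = np.sqrt((diff ** 2).sum(axis=-1))
    # Column 0 of the sorted order is the point itself, so skip it.
    neighbors = np.argsort(euclidean, axis=1)[:, 1:k + 1]
    rows = np.repeat(np.arange(n), k)
    adjacency = csr_matrix((np.ones(n * k), (rows, neighbors.ravel())), shape=(n, n))
    # directed=False lets a path use an edge in either direction;
    # unweighted=True counts edges instead of summing edge weights.
    return shortest_path(adjacency, directed=False, unweighted=True)


# Five toy "cells" on a line with uneven spacing (avoids distance ties).
toy_cells = np.array([[0.0], [1.0], [2.1], [3.3], [4.6]])
toy_graph_dist = knn_graph_distance(toy_cells, k=1)
print(toy_graph_dist[0])  # graph distance from the first cell to every cell
```

With `k=1` the toy graph is a chain, so the distance from the first cell grows by one hop per cell. Pairs of cells in disconnected components would get an infinite distance, which is one reason to pick `k` large enough for the graph to be connected.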

In [None]:
!pip install \
"igraph>=0.10" \
"matplotlib>=3.6" \
"numpy>=1.25" \
"pandas>=1.5" \
"pot>=0.8.2" \
"scanpy>=1.9.3" \
"scikit-misc>=0.1.3" \
"scikit-learn>=0.24" \
"scipy>=1.8" \
"seaborn>=0.13" \
"tqdm>=4.64.1"

In [None]:
!pip install "ipywidgets>=8.0.0" --upgrade

In [None]:
!git clone https://github.com/richcmwang/gene-trajectory-experiments.git
%cd gene-trajectory-experiments

In [None]:
import scanpy as sc
from gene_trajectory.add_gene_bin_score import add_gene_bin_score
from gene_trajectory.coarse_grain import select_top_genes, coarse_grain_adata
from gene_trajectory.extract_gene_trajectory import get_gene_embedding
from gene_trajectory.get_graph_distance import get_graph_distance
from gene_trajectory.gene_distance_shared import cal_ot_mat
from gene_trajectory.run_dm import run_dm
from gene_trajectory.plot.gene_trajectory_plots import plot_gene_trajectory_umap
from gene_trajectory.util.download_file import download_file_if_missing

from gene_trajectory.widgets import ExtractGeneTrajectoryWidget

In [None]:
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

In [None]:
from google.colab import output
output.enable_custom_widget_manager()

## Loading example data  ##
Standard preprocessing can be done with the scanpy Python package and includes
library-size normalization, finding variable features, scaling, and generating a PCA embedding (and a UMAP embedding for visualization).

We load a preprocessed Scanpy AnnData object on which we will perform the gene trajectory inference.

First, we download the data from [figshare](https://figshare.com/articles/dataset/Processed_AnnData_objects_for_GeneTrajectory_inference_Gene_Trajectory_Inference_for_Single-cell_Data_by_Optimal_Transport_Metrics_/25539547), where a copy of the data needed for the tutorial is saved.


In [None]:
download_file_if_missing('tutorial_data/human_myeloid_scanpy_obj.h5ad',
    url='https://figshare.com/ndownloader/files/45448603',
    md5_hash='923f8f00819e9f6401445af8d97275eb',
    create_target_folder_if_missing=True)

Once the file has been downloaded and saved to `tutorial_data/human_myeloid_scanpy_obj.h5ad`, we load the preprocessed AnnData object with Scanpy.

In [None]:
adata = sc.read_h5ad('tutorial_data/human_myeloid_scanpy_obj.h5ad')

We first review the dataset.

In [None]:
print(f"cell barcode (ID) x cell-level features: {adata.obs.shape}")
adata.obs   # cell barcode (ID) x  cell-level features

Each row represents a cell ID, and the columns represent:

* **`orig.ident`**: Original sample or batch ID (e.g., 0, 1, 2)
* **`nCount_RNA`**: Total number of RNA UMIs (counts) detected in the cell across genes for that cell
* **`nFeature_RNA`**: Number of genes detected (non-zero counts) for the cell
* **`observed`, `simulated`**: Likely from a modeling step (e.g., diffusion, entropy, or trajectory likelihoods)
* **`percent.mito`**: Percentage of reads from mitochondrial genes (common QC feature)
* **`RNA_snn_res.0.4`**: Clustering labels computed at resolution 0.4 (e.g., from a shared nearest neighbor graph)
* **`seurat_clusters`**: Final clustering assignments (often equivalent to a specific `RNA_snn_res`)
* **`celltype`**: Annotated or inferred cell type label
* **`TrajectoryX_genesY`**: Scores from gene trajectory analysis (e.g., from optimal transport or diffusion methods)
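
As a minimal sketch of how the QC columns above relate to a count matrix, the snippet below recomputes them on a toy matrix. The counts are made up for illustration, and identifying mitochondrial genes by the `MT-` name prefix is an assumption (a common convention for human data), not something stated by this dataset.

```python
import numpy as np

# Toy count matrix: 3 cells (rows) x 4 genes (columns); values are illustrative.
toy_counts = np.array([
    [5, 0, 3, 2],
    [0, 0, 7, 1],
    [4, 4, 0, 0],
])
toy_gene_names = np.array(["CD14", "FCGR3A", "MT-CO1", "MT-ND1"])

n_count = toy_counts.sum(axis=1)          # total counts per cell (cf. nCount_RNA)
n_feature = (toy_counts > 0).sum(axis=1)  # detected genes per cell (cf. nFeature_RNA)

# Assumption: mitochondrial genes are those whose name starts with "MT-".
is_mito = np.char.startswith(toy_gene_names, "MT-")
percent_mito = 100 * toy_counts[:, is_mito].sum(axis=1) / n_count  # cf. percent.mito

print(n_count, n_feature, percent_mito)
```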


In [None]:
print(f"gene x gene level features: {adata.var.shape}")
adata.var # gene x gene level features

The gene level features are statistics over the cell population.

* **`mean`**: Mean expression of the gene across all cells (raw or normalized, depending on the pipeline)
* **`variance`**: Variance of expression across all cells
* **`variance.expected`**: Expected variance under a null model (often based on a mean–variance relationship)
* **`variance.standardized`**: Observed variance divided by expected variance — used to score variability
* **`vst.mean`**: Mean expression after **variance-stabilizing transformation (VST)**
* **`vst.variance`**: Variance after VST
* **`vst.variance.expected`**: Expected variance under the VST-based null model
* **`vst.variance.standardized`**: Standardized VST variance = observed / expected
* **`vst.variable`**: Boolean indicating whether the gene is flagged as **highly variable** by the VST method
* **`alra_features`**: Gene identifier or label used during ALRA imputation (if that method was applied)

`adata.raw.X`:
- matrix of size `cell ID x genes`
- the values are gene counts

`adata.X`:
- matrix of size `cell ID x genes`
- the values may be normalized and transformed
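
To illustrate the difference between raw and transformed values, here is a minimal sketch of one common transformation: scaling each cell to a fixed total and applying `log1p`. Whether `adata.X` in this dataset was produced this way is an assumption; the function name `normalize_log1p` and the target of 10,000 counts per cell are illustrative choices.

```python
import numpy as np


def normalize_log1p(counts, target_sum=1e4):
    """Scale every cell (row) to target_sum total counts, then apply log(1 + x)."""
    library_size = counts.sum(axis=1, keepdims=True)
    return np.log1p(counts / library_size * target_sum)


# Two toy cells with the same gene proportions but a 10x difference in depth.
toy_raw = np.array([[1.0, 3.0], [10.0, 30.0]])
toy_norm = normalize_log1p(toy_raw)
print(toy_norm)
```

After normalization both toy cells have identical values, which shows how the transformation removes differences that are due only to sequencing depth.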

In [None]:
# cell ID x genes
print(f"Number of cell x Number of genes: {adata.raw.X.shape}")

Next, we add a `cell_type` annotation to the metadata based on the clustering labels and plot it in the UMAP representation.

Original clusters are labeled numerically.

In [None]:
adata.obs['cell_type'] = adata.obs['cluster'].replace({
  0: "CD14+ monocytes",
  1: "Intermediate monocytes",
  2: "CD16+ monocytes",
  3: "Myeloid type-2 dendritic cells"}
)
sc.pl.umap(adata, color=["cell_type"])


How is UMAP calculated in Scanpy?

**Step 1: PCA**

* Compress high-dimensional gene expression into fewer components while preserving structure.
* **Input**: `adata.X` (cells × genes)
* **Output**:
  `adata.obsm["X_pca"]` (cells × PCs)

**Step 2: Compute Neighborhood Graph**

* Build a graph connecting each cell to its nearest neighbors based on PCA space.
* **Input**: `adata.obsm["X_pca"]`
* **Output**:
  `adata.obsp["connectivities"]` → weighted cell-cell similarity graph
  `adata.obsp["distances"]` → raw distances

**Step 3: Compute UMAP Embedding**

* Learn a 2D or 3D layout that preserves local cell neighborhoods.
* **Input**: `adata.obsp["connectivities"]`
* **Output**:
  `adata.obsm["X_umap"]` (cells × 2)
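
In Scanpy these three steps correspond to `sc.pp.pca`, `sc.pp.neighbors`, and `sc.tl.umap`. As a minimal sketch of what Step 1 computes, the snippet below derives PCA scores from a centered matrix with a singular value decomposition in NumPy. It runs on random toy data and is not Scanpy's implementation; the helper name `pca_scores` is hypothetical.

```python
import numpy as np


def pca_scores(expression, n_components):
    """Project cells onto the leading principal components via SVD."""
    centered = expression - expression.mean(axis=0)  # center each gene
    # Rows of vt are the principal axes; u * s gives the per-cell scores.
    u, s, vt = np.linalg.svd(centered, full_matrices=False)
    return u[:, :n_components] * s[:n_components]


rng = np.random.default_rng(0)
toy_expression = rng.normal(size=(50, 20))  # 50 toy cells x 20 toy genes
toy_pcs = pca_scores(toy_expression, n_components=5)
print(toy_pcs.shape)  # cells x PCs, analogous to adata.obsm["X_pca"]
```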


## Gene-gene distance computation ##
We narrow down the gene list for gene-gene distance computation by focusing on the top 500 variable genes expressed by 1% - 50% of cells.



In [None]:
if 'counts' not in adata.layers:
    adata.layers['counts'] = adata.raw.X.copy()
genes = select_top_genes(adata, layer='counts', n_variable_genes=500)
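
The expression-fraction part of this filter can be sketched in NumPy as follows. This only illustrates the "expressed by 1% to 50% of cells" criterion on a toy matrix; it omits the variability ranking and makes no claim about how `select_top_genes` is implemented. The helper name `expression_fraction_mask` is hypothetical.

```python
import numpy as np


def expression_fraction_mask(counts, min_frac=0.01, max_frac=0.5):
    """Flag genes expressed (count > 0) in a fraction of cells within [min_frac, max_frac]."""
    fraction_expressing = (counts > 0).mean(axis=0)
    return (fraction_expressing >= min_frac) & (fraction_expressing <= max_frac)


# Toy matrix: 4 cells x 3 genes, expressed in 100%, 50%, and 0% of cells.
toy_counts = np.array([
    [3, 1, 0],
    [2, 0, 0],
    [5, 4, 0],
    [1, 0, 0],
])
toy_mask = expression_fraction_mask(toy_counts)
print(toy_mask)
```

Only the middle toy gene passes: the first is too widely expressed to be informative, and the last is not expressed at all.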

In Colab, support for third-party widgets remains active for the duration of the session. To disable it:

In [None]:
from google.colab import output
output.disable_custom_widget_manager()

In [None]:
len(genes)

## Prepare the input for gene-gene Wasserstein distance computation ##

Next, we construct the cell-cell kNN graph and calculate cell-cell graph distances.

In [None]:
run_dm(adata)
cell_graph_dist = get_graph_distance(adata, k=10)

In [None]:
gene_expression_updated, graph_dist_updated = coarse_grain_adata(adata, graph_dist=cell_graph_dist, features=genes, n=500)

In [None]:
gene_dist_mat = cal_ot_mat(gene_expr=gene_expression_updated,
                           ot_cost=graph_dist_updated,
                           show_progress_bar=True)
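
To make the transport cost concrete, here is a minimal sketch that solves the optimal transport problem between two gene distributions as a linear program with `scipy.optimize.linprog`, on a three-cell chain graph. It is an illustration under toy assumptions, not a description of how `cal_ot_mat` works internally; the function name `wasserstein_cost` and the toy vectors are hypothetical.

```python
import numpy as np
from scipy.optimize import linprog


def wasserstein_cost(mass_a, mass_b, cost):
    """Minimum cost of transporting distribution a onto distribution b."""
    a = mass_a / mass_a.sum()  # normalize expression to a probability distribution
    b = mass_b / mass_b.sum()
    n, m = cost.shape
    # The transport plan T is flattened row by row: entry (i, j) sits at i * m + j.
    row_sums = np.kron(np.eye(n), np.ones((1, m)))  # sum_j T[i, j] = a[i]
    col_sums = np.kron(np.ones((1, n)), np.eye(m))  # sum_i T[i, j] = b[j]
    result = linprog(
        cost.ravel(),
        A_eq=np.vstack([row_sums, col_sums]),
        b_eq=np.concatenate([a, b]),
        bounds=(0, None),
        method="highs",
    )
    return result.fun


# Graph distances between three cells arranged in a chain: 0 - 1 - 2.
toy_cost = np.array([[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]])
gene_a = np.array([1.0, 0.0, 0.0])  # expressed only in cell 0
gene_b = np.array([0.0, 0.0, 1.0])  # expressed only in cell 2
gene_c = np.array([1.0, 0.0, 1.0])  # expressed equally in cells 0 and 2
dist_ab = wasserstein_cost(gene_a, gene_b, toy_cost)
dist_ac = wasserstein_cost(gene_a, gene_c, toy_cost)
print(dist_ab, dist_ac)
```

Gene `gene_c` overlaps with `gene_a` in cell 0, so only half of its mass has to travel across the graph and its distance to `gene_a` is smaller than that of `gene_b`.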

## Gene trajectory inference and visualization ##

Next, we generate the gene embedding by employing Diffusion Map.

In [None]:
gene_embedding, _ = get_gene_embedding(gene_dist_mat, k = 5)
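
For intuition, a basic Diffusion Map can be sketched from a distance matrix as below. The fixed Gaussian bandwidth `sigma` is a simplifying assumption and is not taken from `get_gene_embedding`; the helper name `diffusion_map` and the toy distances are hypothetical.

```python
import numpy as np


def diffusion_map(dist, sigma, n_components):
    """Basic Diffusion Map coordinates from a pairwise distance matrix."""
    kernel = np.exp(-(dist ** 2) / (2 * sigma ** 2))  # Gaussian affinity
    degree = kernel.sum(axis=1)
    # Symmetric normalization; it has the same eigenvalues as the
    # random-walk matrix D^-1 K but allows a symmetric eigensolver.
    symmetric = kernel / np.sqrt(np.outer(degree, degree))
    eigvals, eigvecs = np.linalg.eigh(symmetric)
    order = np.argsort(eigvals)[::-1]  # sort from largest to smallest
    eigvals, eigvecs = eigvals[order], eigvecs[:, order]
    # Convert to right eigenvectors of the random-walk matrix.
    right_vecs = eigvecs / np.sqrt(degree)[:, None]
    # Skip the first (trivial, constant) eigenvector.
    coords = right_vecs[:, 1:n_components + 1] * eigvals[1:n_components + 1]
    return coords, eigvals


# Eight toy "genes" placed on a line; distances are absolute differences.
toy_positions = np.linspace(0.0, 3.0, 8)
toy_gene_dist = np.abs(toy_positions[:, None] - toy_positions[None, :])
toy_coords, toy_eigvals = diffusion_map(toy_gene_dist, sigma=1.0, n_components=2)
print(toy_coords.shape)
```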

The extraction of gene trajectories is done sequentially. The initial node (terminus-1) is defined by the gene with the largest distance from the origin in the Diffusion Map embedding. A random-walk procedure is then employed on the gene graph to select the other genes that belong to the trajectory terminated at terminus-1. After retrieving genes for the first trajectory, we identify the terminus of the subsequent gene trajectory among the remaining genes and repeat the steps above. This is done iteratively until all detectable trajectories are extracted.
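
The sequential procedure described above can be sketched on toy data as follows. This is a simplified illustration, not the package's `extract_gene_trajectory`: the kNN graph with self-loops, the fixed mass threshold `min_mass`, and all helper names are assumptions made for the example.

```python
import numpy as np


def knn_transition_matrix(dist, k):
    """Row-stochastic random-walk matrix on a symmetrized kNN graph with self-loops."""
    n = dist.shape[0]
    adjacency = np.eye(n)
    neighbors = np.argsort(dist, axis=1)[:, 1:k + 1]  # column 0 is the gene itself
    for i in range(n):
        adjacency[i, neighbors[i]] = 1.0
    adjacency = np.maximum(adjacency, adjacency.T)
    return adjacency / adjacency.sum(axis=1, keepdims=True)


def extract_one_trajectory(embedding, dist, available, k, t, min_mass=0.01):
    # Terminus: the still-available gene farthest from the origin of the embedding.
    norms = np.where(available, np.linalg.norm(embedding, axis=1), -np.inf)
    terminus = int(np.argmax(norms))
    # Diffuse a unit of mass from the terminus for t random-walk steps.
    mass = np.zeros(len(available))
    mass[terminus] = 1.0
    transition = knn_transition_matrix(dist, k)
    for _ in range(t):
        mass = mass @ transition
    # Assumed rule: keep available genes that received enough mass.
    members = available & (mass > min_mass)
    return terminus, members


# Six toy genes in a 1-D embedding: a group near the origin and a distant pair.
toy_embedding = np.array([[0.0], [1.0], [2.1], [3.3], [10.0], [11.5]])
toy_dist = np.abs(toy_embedding - toy_embedding.T)
available = np.ones(6, dtype=bool)

first_terminus, first_members = extract_one_trajectory(
    toy_embedding, toy_dist, available, k=1, t=3)
available &= ~first_members  # remove assigned genes before the next round
second_terminus, second_members = extract_one_trajectory(
    toy_embedding, toy_dist, available, k=1, t=3)
print(first_terminus, first_members, second_terminus, second_members)
```

In this sketch a larger `t` lets the walk reach genes farther from the terminus, which mirrors why the `t` values are tuned until every gene is covered.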

To refine the trajectories, we use the widget `ExtractGeneTrajectoryWidget`, which allows the parameters to be tuned interactively. We also label some genes that are important in this system (e.g. CLEC5A, CD1C, FCGR3A, and PKIB).


In [None]:
extract_gene_trajectory_widget = ExtractGeneTrajectoryWidget(gene_embedding, gene_dist_mat, genes,
                                                             label_genes=['CLEC5A', 'CD1C', 'FCGR3A', 'PKIB'])
extract_gene_trajectory_widget

We make the following changes in the widget:
- Since this tutorial uses a small dataset, reduce `k` from its default of `10` to `5`.
- Slide the values of `t_list` until all genes are covered, first extending `t` for `Trajectory-2` to `8`.
- Adjust `t` for `Trajectory-1` to `4`.
- Adjust `t` for `Trajectory-3` until all genes are covered, i.e. to `7`.

Tuning the parameters interactively is equivalent to setting
- `k` = `5`
- `t_list` = `[4, 8, 7]`

The same parameters could have been applied directly as follows (the two import paths are assumed to match the modules that provide `get_gene_embedding` and `plot_gene_trajectory_umap`):
```
from gene_trajectory.extract_gene_trajectory import extract_gene_trajectory
from gene_trajectory.plot.gene_trajectory_plots import plot_gene_trajectory_3d
gene_trajectory = extract_gene_trajectory(gene_embedding, gene_dist_mat, t_list = [4, 8, 7], gene_names=genes, k=5)
plot_gene_trajectory_3d(gene_trajectory, label_genes=['CLEC5A', 'CD1C', 'FCGR3A', 'PKIB'])
```

Next, we extract the gene trajectory from the widget:

In [None]:
gene_trajectory = extract_gene_trajectory_widget.gene_trajectory

In [None]:
gene_trajectory

## Visualize gene bin plots ##

To examine how each gene trajectory is reflected over the cell graph, we track how its genes are expressed across different regions of the cell embedding.
To generate gene bin plots, we use the smoothed expression values stored in the `alra` layer, which were computed using [ALRA](https://github.com/KlugerLab/ALRA/blob/master/README.md) imputation.


In [None]:
print(gene_trajectory.columns.tolist())


In [None]:
# Score all three extracted trajectories so that Trajectory1-3 can be plotted below
add_gene_bin_score(adata, gene_trajectory=gene_trajectory, n_bins=5, trajectories=3, layer='alra')
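
As a rough sketch of the idea behind a gene bin score, the snippet below splits the ordered genes of one trajectory into consecutive bins and, for each cell, computes the fraction of genes in each bin that the cell expresses. This definition is an assumption made for illustration and may differ from what `add_gene_bin_score` computes; the helper name `gene_bin_scores` and the toy matrix are hypothetical.

```python
import numpy as np


def gene_bin_scores(expression, ordered_genes, n_bins):
    """Per-cell fraction of expressed genes in each consecutive bin of ordered genes."""
    bins = np.array_split(np.asarray(ordered_genes), n_bins)
    return np.column_stack([(expression[:, b] > 0).mean(axis=1) for b in bins])


# Toy expression: 3 cells x 4 genes, with genes already ordered along a trajectory.
toy_expr = np.array([
    [2.0, 1.0, 0.0, 0.0],
    [0.0, 3.0, 1.0, 0.0],
    [0.0, 0.0, 4.0, 2.0],
])
toy_order = [0, 1, 2, 3]
toy_scores = gene_bin_scores(toy_expr, toy_order, n_bins=2)
print(toy_scores)  # rows: cells, columns: bins
```

In the toy result the first cell scores highest in the first bin and the last cell in the last bin, which is the kind of progression the gene bin plots are meant to show across the cell embedding.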

In [None]:
plot_gene_trajectory_umap(adata, 'Trajectory1', other_panels='cell_type')



In [None]:
plot_gene_trajectory_umap(adata, 'Trajectory2', other_panels='cell_type')


We plot Trajectory 3 in reverse order because we want the gene `CLEC5A` to be at the end of the trajectory rather than at the beginning.

In [None]:
plot_gene_trajectory_umap(adata, 'Trajectory3', other_panels='cell_type', reverse=True)
