In [None]:
from __future__ import annotations

import tempfile
from pathlib import Path

import scanpy as sc
import umap
from umap.umap_ import nearest_neighbors

import umapjax

In [None]:
# Download link for the dataset used throughout this notebook.
url = "https://datasets.cellxgene.cziscience.com/b522bd5f-c34d-47a1-954a-f5379c98945a.h5ad"
# Cache the file in the platform's temporary directory rather than a hard-coded "/tmp",
# which does not exist on every operating system. `backup_url` lets scanpy fetch the file
# when it is not already on disk; backed="r" opens it read-only in backed mode.
adata = sc.read(Path(tempfile.gettempdir()) / "test.h5ad", backup_url=url, backed="r")

In [None]:
# Display the dimensions of the loaded dataset as a quick sanity check.
adata.shape

## UMAP from the authors' pipeline

In [None]:
# Plot the embedding stored under "X_umap", colored by batch and by cell type,
# with the panels laid out in a single column.
sc.pl.embedding(
    adata,
    basis="X_umap",
    color=["batch_id", "cell_type"],
    ncols=1,
    size=2,
)

## Setup

To compare runtimes fairly, we precompute the nearest-neighbor graph and initialize the layout from PCA. By default, umap/umapjax use a spectral initialization that runs on the CPU.

In [None]:
# Build the nearest-neighbor result once, from the PCA representation, so that both
# implementations below receive the same `precomputed_knn` and neighbor search is
# excluded from their runtimes. The seed is fixed so the graph is repeatable.
data_knn = nearest_neighbors(
    adata.obsm["X_pca"],
    n_neighbors=15,
    metric="euclidean",
    metric_kwds=None,
    angular=False,
    random_state=42,
)

In [None]:
# Initialize the layout from the first two principal components, so that the
# runtime of the default spectral initialization is not part of the comparison.
init = adata.obsm["X_pca"][:, :2]

## Umap-learn

In [None]:
# Reference run with umap-learn, reusing the precomputed neighbors and the PCA-based init.
ref_model = umap.UMAP(
    n_neighbors=15,
    precomputed_knn=data_knn,
    init=init,
)
# Store the result in `obsm` so it can be plotted by basis name.
adata.obsm["X_umap_ref"] = ref_model.fit_transform(
    adata.obsm["X_pca"],
)

In [None]:
# Plot the umap-learn embedding stored under "X_umap_ref", colored by batch and by
# cell type, with the panels laid out in a single column.
sc.pl.embedding(
    adata,
    basis="X_umap_ref",
    color=["batch_id", "cell_type"],
    ncols=1,
    size=2,
)

## Umapjax

In [None]:
# Same neighbors, init and n_neighbors as the umap-learn run above; only the implementation differs.
# NOTE(review): layout_backend="mx" selects the layout backend; confirm the intended value against umapjax.
jax_model = umapjax.UmapJax(n_neighbors=15, precomputed_knn=data_knn, init=init, layout_backend="mx")
# Store the result in `obsm` so it can be plotted by basis name.
adata.obsm["X_umap_jax"] = jax_model.fit_transform(
    adata.obsm["X_pca"],
)

In [None]:
# Plot the umapjax embedding stored under "X_umap_jax", colored by batch and by
# cell type, with the panels laid out in a single column.
sc.pl.embedding(
    adata,
    basis="X_umap_jax",
    color=["batch_id", "cell_type"],
    ncols=1,
    size=2,
)