-
Notifications
You must be signed in to change notification settings - Fork 342
/
_mde.py
78 lines (62 loc) · 2.42 KB
/
_mde.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from typing import Literal, Optional, Union
import numpy as np
import pandas as pd
import torch
from scipy.sparse import spmatrix
def mde(
    data: Union[np.ndarray, pd.DataFrame, spmatrix, torch.Tensor],
    device: Optional[Literal["cpu", "cuda"]] = None,
    **kwargs,
) -> np.ndarray:
    """Util to run :func:`pymde.preserve_neighbors` for visualization of scvi-tools embeddings.

    Parameters
    ----------
    data
        The data of shape (n_obs, k), where k is typically defined by one of the models
        in scvi-tools that produces an embedding (e.g., :class:`~scvi.model.SCVI`.)
        A :class:`pandas.DataFrame` is converted to its underlying values first.
    device
        Whether to run on cpu or gpu ("cuda"). If None, tries to run on gpu if available
        (as reported by :func:`torch.cuda.is_available`) and falls back to cpu otherwise.
        An explicitly passed value is always honored.
    kwargs
        Keyword args to :func:`pymde.preserve_neighbors`. They take precedence over the
        defaults set by this function (``embedding_dim=2``,
        ``constraint=pymde.Standardized()``, ``repulsive_fraction=0.7``,
        ``verbose=False``, ``n_neighbors=15``).

    Returns
    -------
    The pymde embedding, defaults to two dimensions.

    Raises
    ------
    ImportError
        If the optional ``pymde`` dependency is not installed.

    Notes
    -----
    This function is included in scvi-tools to provide an alternative to UMAP/TSNE that is GPU-
    accelerated. The appropriateness of use of visualization of high-dimensional spaces in single-
    cell omics remains an open research question. See:

    Chari, Tara, Joeyta Banerjee, and Lior Pachter. "The specious art of single-cell genomics." bioRxiv (2021).

    If you use this function in your research please cite:

    Agrawal, Akshay, Alnur Ali, and Stephen Boyd. "Minimum-distortion embedding." arXiv preprint arXiv:2103.02559 (2021).

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
    >>> vae = scvi.model.SCVI(adata)
    >>> vae.train()
    >>> adata.obsm["X_scVI"] = vae.get_latent_representation()
    >>> adata.obsm["X_mde"] = scvi.model.utils.mde(adata.obsm["X_scVI"])
    """
    # pymde is an optional dependency, so import it lazily and fail with an
    # actionable message when it is missing.
    try:
        import pymde
    except ImportError as err:
        raise ImportError(
            "Please install pymde package via `pip install pymde`"
        ) from err

    if isinstance(data, pd.DataFrame):
        data = data.values

    # Only auto-detect when the caller did not choose a device; previously the
    # argument was overwritten unconditionally and therefore ignored.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    _kwargs = {
        "embedding_dim": 2,
        "constraint": pymde.Standardized(),
        "repulsive_fraction": 0.7,
        "verbose": False,
        "device": device,
        "n_neighbors": 15,
    }
    # User-supplied keyword arguments win over the defaults above.
    _kwargs.update(kwargs)

    emb = pymde.preserve_neighbors(data, **_kwargs).embed(verbose=_kwargs["verbose"])

    # Move a tensor result back to host memory and hand back a numpy array.
    if isinstance(emb, torch.Tensor):
        emb = emb.cpu().numpy()

    return emb