-
Notifications
You must be signed in to change notification settings - Fork 342
/
_linear_scvi.py
164 lines (141 loc) · 5.16 KB
/
_linear_scvi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import logging
from typing import Literal, Optional
import pandas as pd
from anndata import AnnData
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import CategoricalObsField, LayerField
from scvi.model._utils import _init_library_size
from scvi.model.base import UnsupervisedTrainingMixin
from scvi.module import LDVAE
from scvi.utils import setup_anndata_dsp
from .base import BaseModelClass, RNASeqMixin, VAEMixin
logger = logging.getLogger(__name__)
class LinearSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
    """Linearly-decoded VAE :cite:p:`Svensson20`.

    Parameters
    ----------
    adata
        AnnData object that has been registered via
        :meth:`~scvi.model.LinearSCVI.setup_anndata`.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder NN.
    dropout_rate
        Dropout rate for neural networks.
    dispersion
        One of the following:

        * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
        * ``'gene-batch'`` - dispersion can differ between different batches
        * ``'gene-label'`` - dispersion can differ between different labels
        * ``'gene-cell'`` - dispersion can differ for every gene in every cell
    gene_likelihood
        One of:

        * ``'nb'`` - Negative binomial distribution
        * ``'zinb'`` - Zero-inflated negative binomial distribution
        * ``'poisson'`` - Poisson distribution
    latent_distribution
        One of:

        * ``'normal'`` - Normal distribution
        * ``'ln'`` - Logistic normal distribution (Normal(0, I) transformed by softmax)
    **model_kwargs
        Keyword args for :class:`~scvi.module.LDVAE`

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> scvi.model.LinearSCVI.setup_anndata(adata, batch_key="batch")
    >>> vae = scvi.model.LinearSCVI(adata)
    >>> vae.train()
    >>> adata.var["loadings"] = vae.get_loadings()

    Notes
    -----
    See further usage examples in the following tutorials:

    1. :doc:`/tutorials/notebooks/linear_decoder`
    """

    # Module class instantiated by ``__init__``; kept as a class attribute so
    # the constructor does not hard-code it.
    _module_cls = LDVAE

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
        gene_likelihood: Literal["zinb", "nb", "poisson"] = "nb",
        latent_distribution: Literal["normal", "ln"] = "normal",
        **model_kwargs,
    ):
        super().__init__(adata)

        n_batch = self.summary_stats.n_batch
        # NOTE(review): presumably per-batch log-library-size statistics used
        # to initialise the module; confirm against ``_init_library_size``.
        library_log_means, library_log_vars = _init_library_size(
            self.adata_manager, n_batch
        )

        # ``n_layers`` only configures the encoder (``n_layers_encoder``).
        self.module = self._module_cls(
            n_input=self.summary_stats.n_vars,
            n_batch=n_batch,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers_encoder=n_layers,
            dropout_rate=dropout_rate,
            dispersion=dispersion,
            gene_likelihood=gene_likelihood,
            latent_distribution=latent_distribution,
            library_log_means=library_log_means,
            library_log_vars=library_log_vars,
            **model_kwargs,
        )

        # f-strings keep each value next to its label; the resulting text is
        # identical to the previous positional ``str.format`` template.
        self._model_summary_string = (
            "LinearSCVI Model with the following params: \n"
            f"n_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, "
            f"dropout_rate: {dropout_rate}, dispersion: {dispersion}, "
            f"gene_likelihood: {gene_likelihood}, "
            f"latent_distribution: {latent_distribution}"
        )
        # Stored on the model because ``get_loadings`` names its columns from it.
        self.n_latent = n_latent
        # ``locals()`` hands the constructor arguments (together with the
        # locals defined above) to ``_get_init_params``.
        self.init_params_ = self._get_init_params(locals())

    def get_loadings(self) -> pd.DataFrame:
        """Extract per-gene weights in the linear decoder.

        Shape is genes by `n_latent`.

        Returns
        -------
        DataFrame indexed by ``adata.var_names`` with one column ``Z_<i>``
        for each ``i`` in ``range(n_latent)``, filled with the values returned
        by the module's ``get_loadings``.
        """
        cols = [f"Z_{i}" for i in range(self.n_latent)]
        return pd.DataFrame(
            self.module.get_loadings(), index=self.adata.var_names, columns=cols
        )

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: Optional[str] = None,
        labels_key: Optional[str] = None,
        layer: Optional[str] = None,
        **kwargs,
    ) -> None:
        """%(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_batch_key)s
        %(param_labels_key)s
        %(param_layer)s
        """
        # Must stay the first statement: ``locals()`` is expanded here, so any
        # local variable defined earlier would be forwarded as well.
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        ]
        adata_manager = AnnDataManager(
            fields=anndata_fields, setup_method_args=setup_method_args
        )
        # Extra keyword arguments are passed straight to ``register_fields``.
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)