# _scvi.py
import logging
from typing import List, Literal, Optional
import numpy as np
from anndata import AnnData
from scvi import REGISTRY_KEYS
from scvi._types import MinifiedDataType
from scvi.data import AnnDataManager
from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE
from scvi.data._utils import _get_adata_minify_type
from scvi.data.fields import (
BaseAnnDataField,
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
ObsmField,
StringUnsField,
)
from scvi.model._utils import _init_library_size
from scvi.model.base import UnsupervisedTrainingMixin
from scvi.model.utils import get_minified_adata_scrna
from scvi.module import VAE
from scvi.utils import setup_anndata_dsp
from .base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin
# Keys under which the minified-adata fields are stored on the AnnData object.
# ``adata.obsm`` entry that receives the latent qzm params in ``minify_adata``.
_SCVI_LATENT_QZM = "_scvi_latent_qzm"
# ``adata.obsm`` entry that receives the latent qzv params in ``minify_adata``.
_SCVI_LATENT_QZV = "_scvi_latent_qzv"
# ``adata.obs`` column that receives the row-wise sum of the registered counts.
_SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size"

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class SCVI(
    RNASeqMixin,
    VAEMixin,
    ArchesMixin,
    UnsupervisedTrainingMixin,
    BaseMinifiedModeModelClass,
):
    """single-cell Variational Inference :cite:p:`Lopez18`.

    Parameters
    ----------
    adata
        AnnData object that has been registered via :meth:`~scvi.model.SCVI.setup_anndata`.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    dropout_rate
        Dropout rate for neural networks.
    dispersion
        One of the following:

        * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
        * ``'gene-batch'`` - dispersion can differ between different batches
        * ``'gene-label'`` - dispersion can differ between different labels
        * ``'gene-cell'`` - dispersion can differ for every gene in every cell
    gene_likelihood
        One of:

        * ``'nb'`` - Negative binomial distribution
        * ``'zinb'`` - Zero-inflated negative binomial distribution
        * ``'poisson'`` - Poisson distribution
    latent_distribution
        One of:

        * ``'normal'`` - Normal distribution
        * ``'ln'`` - Logistic normal distribution (Normal(0, I) transformed by softmax)
    **model_kwargs
        Keyword args for :class:`~scvi.module.VAE`

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
    >>> vae = scvi.model.SCVI(adata)
    >>> vae.train()
    >>> adata.obsm["X_scVI"] = vae.get_latent_representation()
    >>> adata.obsm["X_normalized_scVI"] = vae.get_normalized_expression()

    Notes
    -----
    See further usage examples in the following tutorials:

    1. :doc:`/tutorials/notebooks/quick_start/api_overview`
    2. :doc:`/tutorials/notebooks/scrna/harmonization`
    3. :doc:`/tutorials/notebooks/scrna/scarches_scvi_tools`
    4. :doc:`/tutorials/notebooks/scrna/scvi_in_R`
    """

    # Module class instantiated by ``__init__``.
    _module_cls = VAE

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
        gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
        latent_distribution: Literal["normal", "ln"] = "normal",
        **model_kwargs,
    ):
        super().__init__(adata)

        # Category counts are only looked up when categorical covariates
        # were registered; otherwise the module receives ``None``.
        if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry:
            cat_cov_state = self.adata_manager.get_state_registry(
                REGISTRY_KEYS.CAT_COVS_KEY
            )
            n_cats_per_cov = cat_cov_state.n_cats_per_key
        else:
            n_cats_per_cov = None

        n_batch = self.summary_stats.n_batch
        use_size_factor_key = (
            REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
        )

        # Library-size statistics are computed only when no size factor is
        # registered and the adata is not minified.
        needs_library_prior = (
            not use_size_factor_key and self.minified_data_type is None
        )
        if needs_library_prior:
            library_log_means, library_log_vars = _init_library_size(
                self.adata_manager, n_batch
            )
        else:
            library_log_means = library_log_vars = None

        module_args = {
            "n_input": self.summary_stats.n_vars,
            "n_batch": n_batch,
            "n_labels": self.summary_stats.n_labels,
            "n_continuous_cov": self.summary_stats.get("n_extra_continuous_covs", 0),
            "n_cats_per_cov": n_cats_per_cov,
            "n_hidden": n_hidden,
            "n_latent": n_latent,
            "n_layers": n_layers,
            "dropout_rate": dropout_rate,
            "dispersion": dispersion,
            "gene_likelihood": gene_likelihood,
            "latent_distribution": latent_distribution,
            "use_size_factor_key": use_size_factor_key,
            "library_log_means": library_log_means,
            "library_log_vars": library_log_vars,
        }
        self.module = self._module_cls(**module_args, **model_kwargs)
        self.module.minified_data_type = self.minified_data_type

        summary_template = (
            "SCVI Model with the following params: \nn_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: "
            "{}, dispersion: {}, gene_likelihood: {}, latent_distribution: {}"
        )
        self._model_summary_string = summary_template.format(
            n_hidden,
            n_latent,
            n_layers,
            dropout_rate,
            dispersion,
            gene_likelihood,
            latent_distribution,
        )

        # ``locals()`` hands every local name of this constructor to the helper.
        self.init_params_ = self._get_init_params(locals())

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: Optional[str] = None,
        batch_key: Optional[str] = None,
        labels_key: Optional[str] = None,
        size_factor_key: Optional[str] = None,
        categorical_covariate_keys: Optional[List[str]] = None,
        continuous_covariate_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        """%(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_layer)s
        %(param_batch_key)s
        %(param_labels_key)s
        %(param_size_factor_key)s
        %(param_cat_cov_keys)s
        %(param_cont_cov_keys)s
        """
        # Must remain the first statement: ``locals()`` is expanded here, so
        # only the call arguments may exist as local names at this point.
        setup_method_args = cls._get_setup_method_args(**locals())

        count_field = LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True)
        batch_field = CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
        labels_field = CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key)
        size_factor_field = NumericalObsField(
            REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False
        )
        cat_covs_field = CategoricalJointObsField(
            REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
        )
        cont_covs_field = NumericalJointObsField(
            REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
        )
        anndata_fields = [
            count_field,
            batch_field,
            labels_field,
            size_factor_field,
            cat_covs_field,
            cont_covs_field,
        ]

        # A minified adata needs the additional minification fields registered.
        minify_type = _get_adata_minify_type(adata)
        if minify_type is not None:
            anndata_fields.extend(cls._get_fields_for_adata_minification(minify_type))

        manager = AnnDataManager(
            fields=anndata_fields, setup_method_args=setup_method_args
        )
        manager.register_fields(adata, **kwargs)
        cls.register_manager(manager)

    @staticmethod
    def _get_fields_for_adata_minification(
        minified_data_type: MinifiedDataType,
    ) -> List[BaseAnnDataField]:
        """Build the anndata fields needed to register an adata minified as ``minified_data_type``.

        Raises
        ------
        NotImplementedError
            If ``minified_data_type`` is not ``ADATA_MINIFY_TYPE.LATENT_POSTERIOR``.
        """
        if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
            raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

        return [
            ObsmField(REGISTRY_KEYS.LATENT_QZM_KEY, _SCVI_LATENT_QZM),
            ObsmField(REGISTRY_KEYS.LATENT_QZV_KEY, _SCVI_LATENT_QZV),
            NumericalObsField(REGISTRY_KEYS.OBSERVED_LIB_SIZE, _SCVI_OBSERVED_LIB_SIZE),
            # The uns entry recording the minification type is always last.
            StringUnsField(REGISTRY_KEYS.MINIFY_TYPE_KEY, _ADATA_MINIFY_TYPE_UNS_KEY),
        ]

    def minify_adata(
        self,
        minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
        use_latent_qzm_key: str = "X_latent_qzm",
        use_latent_qzv_key: str = "X_latent_qzv",
    ) -> None:
        """Replace the model's adata with a minified copy.

        The minified adata gets new registered fields: latent qzm, latent qzv, an adata uns
        entry containing the minified-adata type, and library size. The module's
        ``minified_data_type`` attribute is updated to reflect the minification.

        Parameters
        ----------
        minified_data_type
            How to minify the data. Currently only supports `latent_posterior_parameters`.
            If minified_data_type == `latent_posterior_parameters`:

            * the original count data is removed (`adata.X`, adata.raw, and any layers)
            * the parameters of the latent representation of the original data is stored
            * everything else is left untouched
        use_latent_qzm_key
            Key to use in `adata.obsm` where the latent qzm params are stored
        use_latent_qzv_key
            Key to use in `adata.obsm` where the latent qzv params are stored

        Raises
        ------
        NotImplementedError
            If ``minified_data_type`` is not ``ADATA_MINIFY_TYPE.LATENT_POSTERIOR``.
        ValueError
            If the module's ``use_observed_lib_size`` is ``False``.

        Notes
        -----
        The modification is not done inplace -- instead the model is assigned a new (minified)
        version of the adata.
        """
        # TODO(adamgayoso): Add support for a scenario where we want to cache the latent posterior
        # without removing the original counts.
        if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
            raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

        if self.module.use_observed_lib_size is False:
            raise ValueError(
                "Cannot minify the data if `use_observed_lib_size` is False"
            )

        minified_adata = get_minified_adata_scrna(self.adata, minified_data_type)

        # Copy the latent params from the caller-chosen obsm keys to the fixed
        # keys that the minification fields read from.
        obsm_key_pairs = (
            (_SCVI_LATENT_QZM, use_latent_qzm_key),
            (_SCVI_LATENT_QZV, use_latent_qzv_key),
        )
        for target_key, source_key in obsm_key_pairs:
            minified_adata.obsm[target_key] = self.adata.obsm[source_key]

        # Row-wise sum of the registered counts, squeezed to drop size-1 axes.
        counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
        row_sums = np.asarray(counts.sum(axis=1))
        minified_adata.obs[_SCVI_OBSERVED_LIB_SIZE] = np.squeeze(row_sums)

        self._update_adata_and_manager_post_minification(
            minified_adata, minified_data_type
        )
        self.module.minified_data_type = minified_data_type