-
Notifications
You must be signed in to change notification settings - Fork 342
/
_destvi.py
401 lines (366 loc) · 13.9 KB
/
_destvi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
import logging
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence, Union
import numpy as np
import pandas as pd
import torch
from anndata import AnnData
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import LayerField, NumericalObsField
from scvi.model import CondSCVI
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin
from scvi.module import MRDeconv
from scvi.utils import setup_anndata_dsp
from scvi.utils._docstrings import devices_dsp
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class DestVI(UnsupervisedTrainingMixin, BaseModelClass):
    """Multi-resolution deconvolution of Spatial Transcriptomics data (DestVI) :cite:p:`Lopez21`.

    Most users will use the alternate constructor (see example).

    Parameters
    ----------
    st_adata
        spatial transcriptomics AnnData object that has been registered via
        :meth:`~scvi.model.DestVI.setup_anndata`.
    cell_type_mapping
        mapping between numerals and cell type labels
    decoder_state_dict
        state_dict from the decoder of the CondSCVI model
    px_decoder_state_dict
        state_dict from the px_decoder of the CondSCVI model
    px_r
        parameters for the px_r tensor in the CondSCVI model
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    dropout_decoder
        Value forwarded as ``dropout_decoder`` to the module. :meth:`from_rna_model`
        takes it from the CondSCVI module's ``dropout_rate``.
    l1_reg
        Scalar parameter indicating the strength of L1 regularization on cell type
        proportions.
    **module_kwargs
        Keyword args for :class:`~scvi.module.MRDeconv`

    Examples
    --------
    >>> sc_adata = anndata.read_h5ad(path_to_scRNA_anndata)
    >>> scvi.model.CondSCVI.setup_anndata(sc_adata)
    >>> sc_model = scvi.model.CondSCVI(sc_adata)
    >>> st_adata = anndata.read_h5ad(path_to_ST_anndata)
    >>> DestVI.setup_anndata(st_adata)
    >>> spatial_model = DestVI.from_rna_model(st_adata, sc_model)
    >>> spatial_model.train(max_epochs=2000)
    >>> st_adata.obsm["proportions"] = spatial_model.get_proportions()
    >>> gamma = spatial_model.get_gamma()

    Notes
    -----
    See further usage examples in the following tutorials:

    1. :doc:`/tutorials/notebooks/spatial/DestVI_tutorial`
    """

    _module_cls = MRDeconv

    def __init__(
        self,
        st_adata: AnnData,
        cell_type_mapping: np.ndarray,
        decoder_state_dict: OrderedDict,
        px_decoder_state_dict: OrderedDict,
        px_r: np.ndarray,
        n_hidden: int,
        n_latent: int,
        n_layers: int,
        dropout_decoder: float,
        l1_reg: float,
        **module_kwargs,
    ):
        super().__init__(st_adata)
        self.module = self._module_cls(
            n_spots=st_adata.n_obs,
            n_labels=cell_type_mapping.shape[0],
            decoder_state_dict=decoder_state_dict,
            px_decoder_state_dict=px_decoder_state_dict,
            px_r=px_r,
            n_genes=st_adata.n_vars,
            n_latent=n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_decoder=dropout_decoder,
            l1_reg=l1_reg,
            **module_kwargs,
        )
        self.cell_type_mapping = cell_type_mapping
        self._model_summary_string = "DestVI Model"
        # locals() is captured here on purpose: do not introduce extra local
        # variables above this line, they would leak into the init params.
        self.init_params_ = self._get_init_params(locals())

    @classmethod
    def from_rna_model(
        cls,
        st_adata: AnnData,
        sc_model: CondSCVI,
        vamp_prior_p: int = 15,
        l1_reg: float = 0.0,
        **module_kwargs,
    ):
        """Alternate constructor for exploiting a pre-trained model on a RNA-seq dataset.

        Parameters
        ----------
        st_adata
            registered anndata object
        sc_model
            trained CondSCVI model
        vamp_prior_p
            number of mixture parameter for VampPrior calculations. If `None`, no
            VampPrior is computed and `None` is passed for its mean, variance and
            mixture-proportion arguments.
        l1_reg
            Scalar parameter indicating the strength of L1 regularization on cell type proportions.
            A value of 50 leads to sparser results.
        **module_kwargs
            Additional keyword args forwarded to the :class:`~scvi.model.DestVI`
            constructor (and from there to the module).
        """
        decoder_state_dict = sc_model.module.decoder.state_dict()
        px_decoder_state_dict = sc_model.module.px_decoder.state_dict()
        px_r = sc_model.module.px_r.detach().cpu().numpy()
        mapping = sc_model.adata_manager.get_state_registry(
            REGISTRY_KEYS.LABELS_KEY
        ).categorical_mapping
        dropout_decoder = sc_model.module.dropout_rate

        if vamp_prior_p is None:
            # All three prior pieces must be bound: `mp_vprior` is passed to the
            # constructor below regardless of which branch was taken.
            mean_vprior = None
            var_vprior = None
            mp_vprior = None
        else:
            mean_vprior, var_vprior, mp_vprior = sc_model.get_vamp_prior(
                sc_model.adata, p=vamp_prior_p
            )

        return cls(
            st_adata,
            mapping,
            decoder_state_dict,
            px_decoder_state_dict,
            px_r,
            sc_model.module.n_hidden,
            sc_model.module.n_latent,
            sc_model.module.n_layers,
            mean_vprior=mean_vprior,
            var_vprior=var_vprior,
            mp_vprior=mp_vprior,
            dropout_decoder=dropout_decoder,
            l1_reg=l1_reg,
            **module_kwargs,
        )

    def get_proportions(
        self,
        keep_noise: bool = False,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ) -> pd.DataFrame:
        """Returns the estimated cell type proportion for the spatial data.

        Shape is n_cells x n_labels OR n_cells x (n_labels + 1) if keep_noise.

        Parameters
        ----------
        keep_noise
            whether to account for the noise term as a standalone cell type in the proportion estimate.
        indices
            Indices of cells in adata to use. Only used if amortization. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Only used if amortization. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        DataFrame indexed by observation names with one column per cell type
        (plus ``"noise_term"`` when `keep_noise` is set).
        """
        self._check_if_trained()

        column_names = self.cell_type_mapping
        index_names = self.adata.obs.index
        if keep_noise:
            column_names = np.append(column_names, "noise_term")

        if self.module.amortization in ["both", "proportion"]:
            stdl = self._make_data_loader(
                adata=self.adata, indices=indices, batch_size=batch_size
            )
            prop_ = []
            for tensors in stdl:
                generative_inputs = self.module._get_generative_input(tensors, None)
                prop_local = self.module.get_proportions(
                    x=generative_inputs["x"], keep_noise=keep_noise
                )
                prop_.append(prop_local.cpu())
            data = torch.cat(prop_).numpy()
            # Explicit None check: truthiness of a multi-element NumPy array
            # raises, and an empty selection must still subset the index.
            if indices is not None:
                index_names = index_names[indices]
        else:
            if indices is not None:
                logger.info(
                    "No amortization for proportions, ignoring indices and returning results for the full data"
                )
            data = self.module.get_proportions(keep_noise=keep_noise)

        return pd.DataFrame(
            data=data,
            columns=column_names,
            index=index_names,
        )

    def get_gamma(
        self,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
        return_numpy: bool = False,
    ) -> Union[np.ndarray, Dict[str, pd.DataFrame]]:
        """Returns the estimated cell-type specific latent space for the spatial data.

        Parameters
        ----------
        indices
            Indices of cells in adata to use. Only used if amortization. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Only used if amortization. Defaults to `scvi.settings.batch_size`.
        return_numpy
            if activated, will return a numpy array of shape is n_spots x n_latent x n_labels.

        Returns
        -------
        Either the numpy array described above, or a dictionary mapping each cell
        type in ``cell_type_mapping`` to a DataFrame (rows: observations,
        columns: latent dimensions).
        """
        self._check_if_trained()

        column_names = np.arange(self.module.n_latent)
        index_names = self.adata.obs.index

        if self.module.amortization in ["both", "latent"]:
            stdl = self._make_data_loader(
                adata=self.adata, indices=indices, batch_size=batch_size
            )
            gamma_ = []
            for tensors in stdl:
                generative_inputs = self.module._get_generative_input(tensors, None)
                gamma_local = self.module.get_gamma(x=generative_inputs["x"])
                gamma_.append(gamma_local.cpu())
            # Minibatches are stitched together along the last axis.
            data = torch.cat(gamma_, dim=-1).numpy()
            if indices is not None:
                index_names = index_names[indices]
        else:
            if indices is not None:
                logger.info(
                    "No amortization for latent values, ignoring indices and returning results for the full data"
                )
            data = self.module.get_gamma()

        # Move the last axis to the front so that the final axis can be indexed
        # per cell type below (documented output: n_spots x n_latent x n_labels).
        data = np.transpose(data, (2, 0, 1))
        if return_numpy:
            return data

        return {
            ct: pd.DataFrame(
                data=data[:, :, i], columns=column_names, index=index_names
            )
            for i, ct in enumerate(self.cell_type_mapping)
        }

    def get_scale_for_ct(
        self,
        label: str,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ) -> pd.DataFrame:
        r"""Return the scaled parameter of the NB for every spot in queried cell types.

        Parameters
        ----------
        label
            cell type of interest
        indices
            Indices of cells in self.adata to use. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        Pandas dataframe of gene_expression

        Raises
        ------
        ValueError
            If `label` is not one of the entries of ``cell_type_mapping``.
        """
        self._check_if_trained()

        if label not in self.cell_type_mapping:
            raise ValueError("Unknown cell type")
        # Position of the first entry of the mapping equal to the requested label.
        y = np.where(label == self.cell_type_mapping)[0][0]

        stdl = self._make_data_loader(
            adata=self.adata, indices=indices, batch_size=batch_size
        )
        scale = []
        for tensors in stdl:
            generative_inputs = self.module._get_generative_input(tensors, None)
            x, ind_x = (
                generative_inputs["x"],
                generative_inputs["ind_x"],
            )
            px_scale = self.module.get_ct_specific_expression(x, ind_x, y)
            scale.append(px_scale.cpu())
        data = torch.cat(scale).numpy()

        column_names = self.adata.var.index
        index_names = self.adata.obs.index
        if indices is not None:
            index_names = index_names[indices]
        return pd.DataFrame(data=data, columns=column_names, index=index_names)

    @devices_dsp.dedent
    def train(
        self,
        max_epochs: int = 2000,
        lr: float = 0.003,
        use_gpu: Optional[Union[str, int, bool]] = None,
        accelerator: str = "auto",
        devices: Union[int, List[int], str] = "auto",
        train_size: float = 1.0,
        validation_size: Optional[float] = None,
        shuffle_set_split: bool = True,
        batch_size: int = 128,
        n_epochs_kl_warmup: int = 200,
        plan_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """Trains the model using MAP inference.

        Parameters
        ----------
        max_epochs
            Number of epochs to train for
        lr
            Learning rate for optimization.
        %(param_use_gpu)s
        %(param_accelerator)s
        %(param_devices)s
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the validation set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        shuffle_set_split
            Whether to shuffle indices before splitting. If `False`, the val, train, and test set are split in the
            sequential order of the data according to `validation_size` and `train_size` percentages.
        batch_size
            Minibatch size to use during training.
        n_epochs_kl_warmup
            number of epochs needed to reach unit kl weight in the elbo
        plan_kwargs
            Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
            `train()` will overwrite values present in `plan_kwargs`, when appropriate.
            The dictionary passed in is not modified.
        **kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.
        """
        update_dict = {
            "lr": lr,
            "n_epochs_kl_warmup": n_epochs_kl_warmup,
        }
        if plan_kwargs is not None:
            # Work on a copy so the caller's dictionary is left untouched;
            # explicit train() arguments still take precedence.
            plan_kwargs = dict(plan_kwargs)
            plan_kwargs.update(update_dict)
        else:
            plan_kwargs = update_dict

        super().train(
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            accelerator=accelerator,
            devices=devices,
            train_size=train_size,
            validation_size=validation_size,
            shuffle_set_split=shuffle_set_split,
            batch_size=batch_size,
            plan_kwargs=plan_kwargs,
            **kwargs,
        )

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: Optional[str] = None,
        **kwargs,
    ):
        """%(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_layer)s
        """
        # Must stay the first statement: locals() has to contain only the
        # arguments of this call.
        setup_method_args = cls._get_setup_method_args(**locals())
        # Store each observation's positional index in `adata.obs` (modifies
        # `adata` in place) and register it under REGISTRY_KEYS.INDICES_KEY.
        # NOTE(review): presumably consumed by the module to look up per-spot
        # parameters for a minibatch - confirm against MRDeconv.
        adata.obs["_indices"] = np.arange(adata.n_obs)
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
        ]
        adata_manager = AnnDataManager(
            fields=anndata_fields, setup_method_args=setup_method_args
        )
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)