Skip to content

Commit

Permalink
Merge pull request #941 from YosefLab/renaming
Browse files: browse the repository at this point in the history
renaming data loader
  • Loading branch information
galenxing committed Feb 19, 2021
2 parents bcae907 + ef96280 commit f2c42c6
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 40 deletions.
2 changes: 1 addition & 1 deletion scvi/external/gimvi/_model.py
Expand Up @@ -225,7 +225,7 @@ def train(
def _make_scvi_dls(self, adatas: List[AnnData] = None, batch_size=128):
if adatas is None:
adatas = self.adatas
post_list = [self._make_scvi_dl(ad) for ad in adatas]
post_list = [self._make_data_loader(ad) for ad in adatas]
for i, dl in enumerate(post_list):
dl.mode = i

Expand Down
2 changes: 1 addition & 1 deletion scvi/external/solo/_model.py
Expand Up @@ -215,7 +215,7 @@ def predict(self, soft: bool = True):
"""
adata = self._validate_anndata(None)

scdl = self._make_scvi_dl(
scdl = self._make_data_loader(
adata=adata,
)

Expand Down
4 changes: 3 additions & 1 deletion scvi/model/_autozi.py
Expand Up @@ -176,7 +176,9 @@ def get_marginal_ll(
if indices is None:
indices = np.arange(adata.n_obs)

scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)

log_lkl = 0
to_sum = torch.zeros((n_mc_samples,))
Expand Down
8 changes: 6 additions & 2 deletions scvi/model/_peakvi.py
Expand Up @@ -235,7 +235,9 @@ def get_library_size_factors(
batch_size: int = 128,
):
adata = self._validate_anndata(adata)
scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)

library_sizes = []
for tensors in scdl:
Expand Down Expand Up @@ -300,7 +302,9 @@ def get_accessibility_estimates(
"""
adata = self._validate_anndata(adata)
post = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
post = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)
transform_batch = _get_batch_code_from_category(adata, transform_batch)

if threshold is not None and (threshold < 0 or threshold > 1):
Expand Down
20 changes: 10 additions & 10 deletions scvi/model/_scanvi.py
Expand Up @@ -356,7 +356,7 @@ def predict(
if indices is None:
indices = np.arange(adata.n_obs)

scdl = self._make_scvi_dl(
scdl = self._make_data_loader(
adata=adata,
indices=indices,
batch_size=batch_size,
Expand Down Expand Up @@ -415,7 +415,7 @@ def _train_test_val_split(
n_samples_per_label
Number of subsamples for each label class to sample per epoch
**kwargs
Keyword args for `_make_scvi_dl()`
Keyword args for `_make_data_loader()`
"""
train_size = float(train_size)
if train_size > 1.0 or train_size <= 0.0:
Expand Down Expand Up @@ -486,37 +486,37 @@ def get_train_val_split(n_samples, test_size, train_size):
indices_test = indices_test.astype(int)

if len(self._labeled_indices) != 0:
dataloader_class = SemiSupervisedDataLoader
data_loader_class = SemiSupervisedDataLoader
dl_kwargs = {
"unlabeled_category": unlabeled_category,
"n_samples_per_label": n_samples_per_label,
}
else:
dataloader_class = AnnDataLoader
data_loader_class = AnnDataLoader
dl_kwargs = {}
dl_kwargs.update(kwargs)

scanvi_train_dl = self._make_scvi_dl(
scanvi_train_dl = self._make_data_loader(
adata,
indices=indices_train,
shuffle=True,
scvi_dl_class=dataloader_class,
data_loader_class=data_loader_class,
drop_last=3,
**dl_kwargs,
)
scanvi_val_dl = self._make_scvi_dl(
scanvi_val_dl = self._make_data_loader(
adata,
indices=indices_val,
shuffle=True,
scvi_dl_class=dataloader_class,
data_loader_class=data_loader_class,
drop_last=3,
**dl_kwargs,
)
scanvi_test_dl = self._make_scvi_dl(
scanvi_test_dl = self._make_data_loader(
adata,
indices=indices_test,
shuffle=True,
scvi_dl_class=dataloader_class,
data_loader_class=data_loader_class,
drop_last=3,
**dl_kwargs,
)
Expand Down
24 changes: 18 additions & 6 deletions scvi/model/_totalvi.py
Expand Up @@ -272,7 +272,9 @@ def get_latent_library_size(
raise RuntimeError("Please train the model first.")

adata = self._validate_anndata(adata)
post = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
post = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)
libraries = []
for tensors in post:
inference_inputs = self.module._get_inference_input(tensors)
Expand Down Expand Up @@ -361,7 +363,9 @@ def get_normalized_expression(
Otherwise, shape is ``(cells, genes)``. Return type is ``pd.DataFrame`` unless ``return_numpy`` is True.
"""
adata = self._validate_anndata(adata)
post = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
post = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)

if gene_list is None:
gene_mask = slice(None)
Expand Down Expand Up @@ -524,7 +528,9 @@ def get_protein_foreground_probability(
Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
"""
adata = self._validate_anndata(adata)
post = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
post = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)

if protein_list is None:
protein_mask = slice(None)
Expand Down Expand Up @@ -754,7 +760,9 @@ def posterior_predictive_sample(
all_proteins = self.scvi_setup_dict_["protein_names"]
protein_mask = [True if p in protein_list else False for p in all_proteins]

scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)

scdl_list = []
for tensors in scdl:
Expand Down Expand Up @@ -802,7 +810,9 @@ def _get_denoised_samples(
int of which batch to condition on for all cells
"""
adata = self._validate_anndata(adata)
scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)

scdl_list = []
for tensors in scdl:
Expand Down Expand Up @@ -999,7 +1009,9 @@ def _data_loader_cls(self):
@torch.no_grad()
def get_protein_background_mean(self, adata, indices, batch_size):
adata = self._validate_anndata(adata)
scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)
background_mean = []
for tensors in scdl:
_, inference_outputs, _ = self.module.forward(tensors)
Expand Down
18 changes: 9 additions & 9 deletions scvi/model/base/_base_model.py
Expand Up @@ -47,13 +47,13 @@ def __init__(self, adata: Optional[AnnData] = None, use_gpu: Optional[bool] = No
self.validation_indices_ = None
self.history_ = None

def _make_scvi_dl(
def _make_data_loader(
self,
adata: AnnData,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = None,
shuffle: bool = False,
scvi_dl_class=None,
data_loader_class=None,
**data_loader_kwargs,
):
"""
Expand All @@ -77,13 +77,13 @@ def _make_scvi_dl(
batch_size = settings.batch_size
if indices is None:
indices = np.arange(adata.n_obs)
if scvi_dl_class is None:
scvi_dl_class = self._data_loader_cls
if data_loader_class is None:
data_loader_class = self._data_loader_cls

if "num_workers" not in data_loader_kwargs:
data_loader_kwargs.update({"num_workers": settings.dl_num_workers})

dl = scvi_dl_class(
dl = data_loader_class(
adata,
shuffle=shuffle,
indices=indices,
Expand Down Expand Up @@ -113,7 +113,7 @@ def _train_test_val_split(
validation_size
float, or None (default is None)
**kwargs
Keyword args for `_make_scvi_dl()`
Keyword args for `_make_data_loader()`
"""
train_size = float(train_size)
if train_size > 1.0 or train_size <= 0.0:
Expand All @@ -140,13 +140,13 @@ def _train_test_val_split(

# Do not remove drop_last=3: it skips over small minibatches.
return (
self._make_scvi_dl(
self._make_data_loader(
adata, indices=indices_train, shuffle=True, drop_last=3, **kwargs
),
self._make_scvi_dl(
self._make_data_loader(
adata, indices=indices_validation, shuffle=True, drop_last=3, **kwargs
),
self._make_scvi_dl(
self._make_data_loader(
adata, indices=indices_test, shuffle=True, drop_last=3, **kwargs
),
)
Expand Down
20 changes: 15 additions & 5 deletions scvi/model/base/_rnamixin.py
Expand Up @@ -81,7 +81,9 @@ def get_normalized_expression(
Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
"""
adata = self._validate_anndata(adata)
scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)

transform_batch = _get_batch_code_from_category(adata, transform_batch)

Expand Down Expand Up @@ -252,7 +254,9 @@ def posterior_predictive_sample(
raise ValueError("Invalid gene_likelihood.")

adata = self._validate_anndata(adata)
scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)

if indices is None:
indices = np.arange(adata.n_obs)
Expand Down Expand Up @@ -308,7 +312,9 @@ def _get_denoised_samples(
denoised_samples
"""
adata = self._validate_anndata(adata)
scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)

data_loader_list = []
for tensors in scdl:
Expand Down Expand Up @@ -453,7 +459,9 @@ def get_likelihood_parameters(
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
"""
adata = self._validate_anndata(adata)
scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)

dropout_list = []
mean_list = []
Expand Down Expand Up @@ -522,7 +530,9 @@ def get_latent_library_size(
if self.is_trained_ is False:
raise RuntimeError("Please train the model first.")
adata = self._validate_anndata(adata)
scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)
libraries = []
for tensors in scdl:
inference_inputs = self.module._get_inference_input(tensors)
Expand Down
16 changes: 12 additions & 4 deletions scvi/model/base/_vaemixin.py
Expand Up @@ -36,7 +36,9 @@ def get_elbo(
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
"""
adata = self._validate_anndata(adata)
scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)
elbo = compute_elbo(self.module, scdl)
return -elbo

Expand Down Expand Up @@ -69,7 +71,9 @@ def get_marginal_ll(
adata = self._validate_anndata(adata)
if indices is None:
indices = np.arange(adata.n_obs)
scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)
if hasattr(self.module, "marginal_ll"):
log_lkl = 0
for tensors in scdl:
Expand Down Expand Up @@ -106,7 +110,9 @@ def get_reconstruction_error(
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
"""
adata = self._validate_anndata(adata)
scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)
reconstruction_error = compute_reconstruction_error(self.module, scdl)
return reconstruction_error

Expand Down Expand Up @@ -148,7 +154,9 @@ def get_latent_representation(
raise RuntimeError("Please train the model first.")

adata = self._validate_anndata(adata)
scdl = self._make_scvi_dl(adata=adata, indices=indices, batch_size=batch_size)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)
latent = []
for tensors in scdl:
inference_inputs = self.module._get_inference_input(tensors)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_scarches.py
Expand Up @@ -10,7 +10,7 @@


def single_pass_for_online_update(model):
dl = model._make_scvi_dl(model.adata, indices=range(0, 10))
dl = model._make_data_loader(model.adata, indices=range(0, 10))
for i_batch, tensors in enumerate(dl):
_, _, scvi_loss = model.module(tensors)
scvi_loss.loss.backward()
Expand Down

0 comments on commit f2c42c6

Please sign in to comment.