-
Notifications
You must be signed in to change notification settings - Fork 341
/
_semi_dataloader.py
117 lines (101 loc) · 4.05 KB
/
_semi_dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from typing import List, Optional, Union
import numpy as np
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data._utils import get_anndata_attribute
from ._ann_dataloader import AnnDataLoader
from ._concat_dataloader import ConcatDataLoader
class SemiSupervisedDataLoader(ConcatDataLoader):
    """DataLoader that supports semisupervised training.

    Builds two underlying loaders through :class:`ConcatDataLoader`: one over all
    selected observations and one over the (optionally subsampled) labeled
    observations among them.

    Parameters
    ----------
    adata_manager
        :class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
    n_samples_per_label
        Number of subsamples for each label class to sample per epoch. By default, there
        is no label subsampling.
    indices
        The indices of the observations in the adata to load. A boolean mask over
        the observations is also accepted and converted to integer positions.
    shuffle
        Whether the data should be shuffled
    batch_size
        minibatch size to load each iteration
    data_and_attributes
        Dictionary with keys representing keys in data registry (`adata_manager.data_registry`)
        and value equal to desired numpy loading type (later made into torch tensor).
        If `None`, defaults to all registered data.
    drop_last
        Forwarded to :class:`ConcatDataLoader`, and to the labeled loader that
        :meth:`resample_labels` rebuilds.
    data_loader_kwargs
        Keyword arguments for :class:`~torch.utils.data.DataLoader`. They are
        also forwarded when :meth:`resample_labels` rebuilds the labeled loader.
    """

    def __init__(
        self,
        adata_manager: AnnDataManager,
        n_samples_per_label: Optional[int] = None,
        indices: Optional[List[int]] = None,
        shuffle: bool = False,
        batch_size: int = 128,
        data_and_attributes: Optional[dict] = None,
        drop_last: Union[bool, int] = False,
        **data_loader_kwargs,
    ):
        adata = adata_manager.adata
        if indices is None:
            indices = np.arange(adata.n_obs)

        indices = np.asarray(indices)
        if indices.dtype == np.bool_:
            # A boolean mask would otherwise act as a selector below while its
            # True/False values were returned as "observation indices".
            indices = np.flatnonzero(indices)
        self.indices = indices

        if len(self.indices) == 0:
            # NOTE(review): returning here skips super().__init__, so this
            # instance is not a usable loader; presumably callers never iterate
            # a loader built from an empty split -- confirm.
            return None

        self.n_samples_per_label = n_samples_per_label
        # Copy of the extra DataLoader options so resample_labels() can rebuild
        # the labeled loader with the same keyword arguments.
        self._labelled_loader_kwargs = dict(data_loader_kwargs)

        labels_state_registry = adata_manager.get_state_registry(
            REGISTRY_KEYS.LABELS_KEY
        )
        labels = get_anndata_attribute(
            adata_manager.adata,
            adata_manager.data_registry.labels.attr_name,
            labels_state_registry.original_key,
        ).ravel()
        # Labels of the selected observations, computed once instead of once
        # per category.
        subset_labels = labels[self.indices]
        unlabeled_category = labels_state_registry.unlabeled_category

        # save a nested list of the indices per labeled category; categories are
        # taken from all labels, so a category absent from the selection yields
        # an empty array here
        self.labeled_locs = [
            self.indices[np.where(subset_labels == label)[0]]
            for label in np.unique(labels)
            if label != unlabeled_category
        ]
        labelled_idx = self.subsample_labels()

        super().__init__(
            adata_manager=adata_manager,
            indices_list=[self.indices, labelled_idx],
            shuffle=shuffle,
            batch_size=batch_size,
            data_and_attributes=data_and_attributes,
            drop_last=drop_last,
            **data_loader_kwargs,
        )

    def resample_labels(self):
        """Resample the labeled data and rebuild the labeled loader.

        ``self.dataloaders[0]`` iterates over the full ``self.indices`` and is
        left untouched; only ``self.dataloaders[1]`` (the labeled set) is
        replaced with a loader over a fresh call to :meth:`subsample_labels`.
        """
        labelled_idx = self.subsample_labels()
        self.dataloaders[1] = AnnDataLoader(
            self.adata_manager,
            indices=labelled_idx,
            shuffle=self._shuffle,
            batch_size=self._batch_size,
            data_and_attributes=self.data_and_attributes,
            drop_last=self._drop_last,
            # Forward the construction-time DataLoader options; previously they
            # were dropped after the first resample.
            **self._labelled_loader_kwargs,
        )

    def subsample_labels(self):
        """Subsample each label class, keeping up to ``n_samples_per_label`` per class.

        Classes with fewer than ``n_samples_per_label`` observations are kept
        whole; larger classes are sampled without replacement using the global
        ``np.random`` state. If ``n_samples_per_label`` is ``None``, every
        labeled observation is kept.

        Returns
        -------
        Observation indices of the kept labeled samples, concatenated class by
        class. An empty array with the dtype of ``self.indices`` is returned
        when there are no labeled classes at all.
        """
        if not self.labeled_locs:
            # np.concatenate([]) raises; return the same empty selection that
            # results when no selected observation carries a label.
            return np.empty(0, dtype=self.indices.dtype)
        if self.n_samples_per_label is None:
            return np.concatenate(self.labeled_locs)

        sample_idx = []
        for loc in self.labeled_locs:
            if len(loc) < self.n_samples_per_label:
                sample_idx.append(loc)
            else:
                label_subset = np.random.choice(
                    loc, self.n_samples_per_label, replace=False
                )
                sample_idx.append(label_subset)
        return np.concatenate(sample_idx)