-
Notifications
You must be signed in to change notification settings - Fork 342
/
_training_mixin.py
88 lines (79 loc) · 3.23 KB
/
_training_mixin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from typing import List, Optional, Union
from scvi.dataloaders import DataSplitter
from scvi.model._utils import get_max_epochs_heuristic
from scvi.train import TrainingPlan, TrainRunner
from scvi.utils._docstrings import devices_dsp
class UnsupervisedTrainingMixin:
    """General purpose unsupervised train method.

    The data splitter, training plan and train runner classes are stored as class
    attributes and looked up through ``self`` inside :meth:`train`, so a subclass can
    replace any of them without re-implementing the method.
    """

    # Class instantiated with ``self.adata_manager`` to build the data splits.
    _data_splitter_cls = DataSplitter
    # Class instantiated with ``self.module`` (plus ``plan_kwargs``) for training.
    _training_plan_cls = TrainingPlan
    # Class that is instantiated and then called to actually run the training.
    _train_runner_cls = TrainRunner

    @devices_dsp.dedent
    def train(
        self,
        max_epochs: Optional[int] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
        accelerator: str = "auto",
        devices: Union[int, List[int], str] = "auto",
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        shuffle_set_split: bool = True,
        batch_size: int = 128,
        early_stopping: bool = False,
        plan_kwargs: Optional[dict] = None,
        **trainer_kwargs,
    ):
        """Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset. If `None`, the value is taken from
            `get_max_epochs_heuristic` for the number of cells, which defaults to
            `np.min([round((20000 / n_cells) * 400), 400])`
        %(param_use_gpu)s
        %(param_accelerator)s
        %(param_devices)s
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the validation set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        shuffle_set_split
            Whether to shuffle indices before splitting. If `False`, the val, train, and test set are split in the
            sequential order of the data according to `validation_size` and `train_size` percentages.
        batch_size
            Minibatch size to use during training.
        early_stopping
            Perform early stopping. Additional arguments can be passed in `**trainer_kwargs`.
            See :class:`~scvi.train.Trainer` for further options.
        plan_kwargs
            Keyword args for :class:`~scvi.train.TrainingPlan`, forwarded unchanged to the
            training plan constructor. Any value that is not a `dict` (including `None`)
            is treated as "no extra arguments".
        **trainer_kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.

        Returns
        -------
        Whatever calling the train runner instance returns.
        """
        if max_epochs is None:
            max_epochs = get_max_epochs_heuristic(self.adata.n_obs)

        # Non-dict values (the default ``None`` in particular) mean "no plan arguments".
        if not isinstance(plan_kwargs, dict):
            plan_kwargs = {}

        data_splitter = self._data_splitter_cls(
            self.adata_manager,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            shuffle_set_split=shuffle_set_split,
        )
        training_plan = self._training_plan_cls(self.module, **plan_kwargs)

        # ``early_stopping`` is a named parameter of this method, so Python never places
        # that keyword in ``**trainer_kwargs``; assigning directly cannot clobber a
        # caller-supplied entry.
        trainer_kwargs["early_stopping"] = early_stopping

        runner = self._train_runner_cls(
            self,
            training_plan=training_plan,
            data_splitter=data_splitter,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            accelerator=accelerator,
            devices=devices,
            **trainer_kwargs,
        )
        return runner()