-
Notifications
You must be signed in to change notification settings - Fork 342
/
_trainrunner.py
148 lines (134 loc) · 5.52 KB
/
_trainrunner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import logging
import warnings
from typing import List, Optional, Union
import lightning.pytorch as pl
import numpy as np
import pandas as pd
from scvi import settings
from scvi.dataloaders import DataSplitter, SemiSupervisedDataSplitter
from scvi.model._utils import parse_device_args
from scvi.model.base import BaseModelClass
from scvi.train import Trainer
logger = logging.getLogger(__name__)
class TrainRunner:
    """TrainRunner calls Trainer.fit() and handles pre and post training procedures.

    Parameters
    ----------
    model
        model to train
    training_plan
        initialized TrainingPlan
    data_splitter
        initialized :class:`~scvi.dataloaders.SemiSupervisedDataSplitter` or
        :class:`~scvi.dataloaders.DataSplitter`
    max_epochs
        max_epochs to train for
    use_gpu
        Use default GPU if available (if `True`), or index of GPU to use (if int), or name
        of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False). Passing in `use_gpu != None`
        will override `accelerator` and `devices` arguments. This argument is deprecated in
        v1.0 and will be removed in v1.1. Please use `accelerator` and `devices` instead.
    accelerator
        Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu",
        "mps", "auto") as well as custom accelerator instances.
    devices
        The devices to use. Can be set to a positive number (int or str), a sequence of
        device indices (list or str), the value -1 to indicate all available devices should
        be used, or "auto" for automatic selection based on the chosen accelerator.
    trainer_kwargs
        Extra kwargs for :class:`~scvi.train.Trainer`

    Examples
    --------
    >>> # Following code should be within a subclass of BaseModelClass
    >>> data_splitter = DataSplitter(self.adata)
    >>> training_plan = TrainingPlan(self.module, len(data_splitter.train_idx))
    >>> runner = TrainRunner(
    >>>     self,
    >>>     training_plan=training_plan,
    >>>     data_splitter=data_splitter,
    >>>     max_epochs=max_epochs)
    >>> runner()
    """

    # Trainer class instantiated in __init__; subclasses may override.
    _trainer_cls = Trainer

    def __init__(
        self,
        model: BaseModelClass,
        training_plan: pl.LightningModule,
        data_splitter: Union[SemiSupervisedDataSplitter, DataSplitter],
        max_epochs: int,
        use_gpu: Optional[Union[str, int, bool]] = None,
        accelerator: str = "auto",
        devices: Union[int, List[int], str] = "auto",
        **trainer_kwargs,
    ):
        self.training_plan = training_plan
        self.data_splitter = data_splitter
        self.model = model
        # Resolve the (deprecated) `use_gpu` together with `accelerator`/`devices`
        # into a Lightning accelerator string, a Lightning devices spec, and a
        # torch device for post-training model placement.
        accelerator, lightning_devices, device = parse_device_args(
            use_gpu=use_gpu,
            accelerator=accelerator,
            devices=devices,
            return_device="torch",
        )
        self.accelerator = accelerator
        self.lightning_devices = lightning_devices
        self.device = device
        self.trainer = self._trainer_cls(
            max_epochs=max_epochs,
            accelerator=accelerator,
            devices=lightning_devices,
            **trainer_kwargs,
        )

    def __call__(self):
        """Run training."""
        # Propagate dataset sizes to the training plan when the splitter
        # exposes them (e.g., for loss scaling); not all splitters do.
        if hasattr(self.data_splitter, "n_train"):
            self.training_plan.n_obs_training = self.data_splitter.n_train
        if hasattr(self.data_splitter, "n_val"):
            self.training_plan.n_obs_validation = self.data_splitter.n_val

        self.trainer.fit(self.training_plan, self.data_splitter)
        self._update_history()

        # data splitter only gets these attrs after fit
        self.model.train_indices = self.data_splitter.train_idx
        self.model.test_indices = self.data_splitter.test_idx
        self.model.validation_indices = self.data_splitter.val_idx

        self.model.module.eval()
        self.model.is_trained_ = True
        self.model.to_device(self.device)
        self.model.trainer = self.trainer

    def _update_history(self):
        """Set or extend ``model.history_`` from the trainer's logger.

        On a first training run, copies the logger's history onto the model
        (``None`` if the logger has no ``history`` attribute). On continued
        training, re-indexes the new per-key DataFrames to continue after the
        previous epochs and concatenates them onto the stored history.
        """
        # model is being further trained
        # this was set to true during first training session
        if self.model.is_trained_ is True:
            # if not using the default logger (e.g., tensorboard)
            if not isinstance(self.model.history_, dict):
                warnings.warn(
                    "Training history cannot be updated. Logger can be accessed from "
                    "`model.trainer.logger`",
                    UserWarning,
                    stacklevel=settings.warnings_stacklevel,
                )
                return
            else:
                new_history = self.trainer.logger.history
                for key, val in self.model.history_.items():
                    # e.g., no validation loss due to training params
                    if key not in new_history:
                        continue
                    # Shift the new rows' index so epochs continue after the
                    # previous session's last epoch instead of restarting at 0.
                    prev_len = len(val)
                    new_len = len(new_history[key])
                    index = np.arange(prev_len, prev_len + new_len)
                    new_history[key].index = index
                    self.model.history_[key] = pd.concat(
                        [
                            val,
                            new_history[key],
                        ]
                    )
                    # pd.concat drops the index name; restore it.
                    self.model.history_[key].index.name = val.index.name
        else:
            # set history_ attribute if it exists
            # other pytorch lightning loggers might not have history attr
            try:
                self.model.history_ = self.trainer.logger.history
            except AttributeError:
                # BUGFIX: was `self.history_ = None`, which set the attribute on
                # the runner and left `model.history_` untouched; the history
                # belongs on the model, consistent with the branch above.
                self.model.history_ = None