-
Notifications
You must be signed in to change notification settings - Fork 41
/
basic_trainer.py
352 lines (287 loc) · 14.5 KB
/
basic_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
import os
from typing import List, Tuple, Dict
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
class BasicTrainer:
r'''
A generic trainer class for EEG analysis providing interfaces for all trainers to implement contexts common in training deep learning models. After inheriting this class, :obj:`on_training_step`, :obj:`on_validation_step`, and :obj:`on_test_step` must be implemented.
The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:
- :obj:`before_training_epoch`: executed before each epoch of training starts
- :obj:`before_training_step`: executed before each batch of training starts
- :obj:`on_training_step`: the training process for each batch
- :obj:`after_training_step`: execute after the training of each batch
- :obj:`after_training_epoch`: executed after each epoch of training
- :obj:`before_validation_epoch`: executed before each round of validation starts
- :obj:`before_validation_step`: executed before the validation of each batch
- :obj:`on_validation_step`: validation process for each batch
- :obj:`after_validation_step`: executed after the validation of each batch
- :obj:`after_validation_epoch`: executed after each round of validation
- :obj:`before_test_epoch`: executed before each round of test starts
- :obj:`before_test_step`: executed before the test of each batch
- :obj:`on_test_step`: test process for each batch
- :obj:`after_test_step`: executed after the test of each batch
- :obj:`after_test_epoch`: executed after each round of test
If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the python file:
.. code-block:: python
trainer = BasicTrainer(model, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
Then, you can use the :obj:`torch.distributed.launch` or :obj:`torchrun` to run your python file.
.. code-block:: shell
python -m torch.distributed.launch \
--nproc_per_node=3 \
--nnodes=1 \
--node_rank=0 \
--master_addr="localhost" \
--master_port=2345 \
your_python_file.py
Here, :obj:`nproc_per_node` is the number of GPUs you specify.
Args:
model (Dict): A dictionary that stores neural networks for import, export and device conversion of neural networks.
device_ids (list): Use cpu if the list is empty. If the list contains indices of multiple GPUs, it needs to be launched with :obj:`torch.distributed.launch` or :obj:`torchrun`. (defualt: :obj:`[]`)
ddp_sync_bn (bool): Whether to replace batch normalization in network structure with cross-GPU synchronized batch normalization. Only valid when the length of :obj:`device_ids` is greater than one. (defualt: :obj:`True`)
ddp_replace_sampler (bool): Whether to replace sampler in dataloader with :obj:`DistributedSampler`. Only valid when the length of :obj:`device_ids` is greater than one. (defualt: :obj:`True`)
ddp_val (bool): Whether to use multi-GPU acceleration for the validation set. For experiments where data input order is sensitive, :obj:`ddp_val` should be set to :obj:`False`. Only valid when the length of :obj:`device_ids` is greater than one. (defualt: :obj:`True`)
ddp_test (bool): Whether to use multi-GPU acceleration for the test set. For experiments where data input order is sensitive, :obj:`ddp_test` should be set to :obj:`False`. Only valid when the length of :obj:`device_ids` is greater than one. (defualt: :obj:`True`)
.. automethod:: fit
.. automethod:: test
'''
def __init__(self,
modules: Dict,
device_ids: List[int] = [],
ddp_sync_bn: bool = True,
ddp_replace_sampler: bool = True,
ddp_val: bool = True,
ddp_test: bool = True):
# given params
self.modules = modules
self.device_ids = device_ids
self.ddp_sync_bn = ddp_sync_bn
self.ddp_replace_sampler = ddp_replace_sampler
self.ddp_val = ddp_val
self.ddp_test = ddp_test
# built-in params
self.rank = -1
self.local_rank = -1
self.world_size = 1
self.device = None
# cpu trainer
if len(device_ids) == 0:
for k, m in self.modules.items():
self.modules[k] = m.to(torch.device('cpu'))
self.device = torch.device('cpu')
# gpu trainer
if len(device_ids) == 1:
assert torch.cuda.is_available(
), 'GPU is not available, please set device to cpu!'
for k, m in self.modules.items():
self.modules[k] = m.to(torch.device('cuda', device_ids[0]))
self.device = torch.device('cuda', device_ids[0])
# ddp trainer
if len(device_ids) > 1:
assert torch.cuda.is_available(
), 'GPU is not available, please set device to cpu!'
assert torch.cuda.device_count() > max(
device_ids
), f'{torch.cuda.device_count()} GPUs are available, but try to access f{max(device_ids)}-th GPU!'
rank = int(os.getenv('RANK', -1))
local_rank = int(os.getenv('LOCAL_RANK', -1))
world_size = int(os.getenv('WORLD_SIZE', 1))
assert rank != -1, 'TorchEEG support to use multiple GPU with distributed data parallel, and you need to run the python script with torch.distributed.launch or torchrun.'
self.rank = rank
self.local_rank = local_rank
self.world_size = world_size
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
[str(d) for d in device_ids])
assert dist.is_available(
), 'Distributed data parallel is not supported!'
if not dist.is_initialized():
# nccl is not available in windows
backend = "nccl" if dist.is_nccl_available() else "gloo"
if self.rank == 0:
self.log(f"Using backend: {backend}.")
dist.init_process_group(backend=backend,
rank=self.rank,
world_size=self.world_size)
for k, m in self.modules.items():
self.modules[k] = m.to(torch.device('cuda', local_rank))
for k, m in self.modules.items():
ddp_m = DistributedDataParallel(m,
device_ids=[self.local_rank],
find_unused_parameters=True)
if self.ddp_sync_bn:
ddp_m = nn.SyncBatchNorm.convert_sync_batchnorm(ddp_m)
self.modules[k] = ddp_m
self.device = torch.device(self.local_rank)
@property
def is_distributed(self):
return self.rank != -1
@property
def is_main(self):
# if not distributed mode -1, then it is the main process
# if distributed mode > -1, then 0 is the main process
return self.rank in {-1, 0}
def log(self, *args, **kwargs):
# can be overwritten
if self.is_main:
print(*args, **kwargs)
def before_training_epoch(self, epoch_id: int, num_epochs: int):
# can be overwritten
...
def before_training_step(self, batch_id: int, num_batches: int):
# can be overwritten
...
def on_training_step(self, train_batch: Tuple, batch_id: int,
num_batches: int):
raise NotImplementedError
def after_training_step(self, batch_id: int, num_batches: int):
# can be overwritten
...
def after_training_epoch(self, epoch_id: int, num_epochs: int):
# can be overwritten
...
def before_validation_epoch(self, epoch_id: int, num_epochs: int):
# can be overwritten
...
def before_validation_step(self, batch_id: int, num_batches: int):
# can be overwritten
...
def on_validation_step(self, val_batch: Tuple, batch_id: int,
num_batches: int):
raise NotImplementedError
def after_validation_step(self, batch_id: int, num_batches: int):
# can be overwritten
...
def after_validation_epoch(self, epoch_id: int, num_epochs: int):
# can be overwritten
...
def on_reveive_dataloader(self, dataloader, mode='train'):
if mode == 'test' and not self.ddp_test:
return dataloader
if mode == 'val' and not self.ddp_val:
return dataloader
if not self.is_distributed:
return dataloader
shuffle = True
if isinstance(dataloader.sampler, SequentialSampler):
shuffle = False
sampler = DistributedSampler(dataloader.dataset, shuffle=shuffle)
dataloader = DataLoader(dataloader.dataset,
dataloader.batch_size,
sampler=sampler,
num_workers=dataloader.num_workers,
pin_memory=dataloader.pin_memory,
drop_last=dataloader.drop_last,
collate_fn=dataloader.collate_fn)
dataloader.need_to_set_epoch = True
return dataloader
def fit(self,
train_loader: DataLoader,
val_loader: DataLoader,
num_epochs: int = 1,
**kwargs):
r'''
Args:
train_loader (DataLoader): Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
val_loader (DataLoader): Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
num_epochs (int): training epochs. (defualt: :obj:`1`)
'''
train_loader = self.on_reveive_dataloader(train_loader, mode='train')
val_loader = self.on_reveive_dataloader(val_loader, mode='val')
for t in range(num_epochs):
if hasattr(train_loader, 'need_to_set_epoch'):
train_loader.sampler.set_epoch(t)
if hasattr(val_loader, 'need_to_set_epoch'):
val_loader.sampler.set_epoch(t)
num_batches = len(train_loader)
# set model to train mode
for k, m in self.modules.items():
self.modules[k].train()
# hook
self.before_training_epoch(t + 1, num_epochs, **kwargs)
for batch_id, train_batch in enumerate(train_loader):
# hook
self.before_training_step(batch_id, num_batches, **kwargs)
# hook
self.on_training_step(train_batch, batch_id, num_batches,
**kwargs)
# hook
self.after_training_step(batch_id, num_batches, **kwargs)
# hook
self.after_training_epoch(t + 1, num_epochs, **kwargs)
# set model to val mode
for k, m in self.modules.items():
self.modules[k].eval()
num_batches = len(val_loader)
# hook
self.before_validation_epoch(t + 1, num_epochs, **kwargs)
with torch.no_grad():
for batch_id, val_batch in enumerate(val_loader):
# hook
self.before_validation_step(batch_id, num_batches, **kwargs)
# hook
self.on_validation_step(val_batch, batch_id, num_batches,
**kwargs)
# hook
self.after_validation_step(batch_id, num_batches, **kwargs)
# hook
self.after_validation_epoch(t + 1, num_epochs, **kwargs)
return self
def test(self, test_loader: DataLoader, **kwargs):
r'''
Args:
test_loader (DataLoader): Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
'''
test_loader = self.on_reveive_dataloader(test_loader, mode='test')
for k, m in self.modules.items():
self.modules[k].eval()
num_batches = len(test_loader)
self.before_test_epoch(**kwargs)
with torch.no_grad():
for batch_id, test_batch in enumerate(test_loader):
# hook
self.before_test_step(batch_id, num_batches, **kwargs)
# hook
self.on_test_step(test_batch, batch_id, num_batches, **kwargs)
# hook
self.after_test_step(batch_id, num_batches, **kwargs)
self.after_test_epoch(**kwargs)
def before_test_epoch(self, **kwargs):
# can be overwritten
...
def before_test_step(self, batch_id: int, num_batches: int, **kwargs):
# can be overwritten
...
def on_test_step(self, test_batch: Tuple, batch_id: int, num_batches: int, **kwargs):
raise NotImplementedError
def after_test_step(self, batch_id: int, num_batches: int, **kwargs):
# can be overwritten
...
def after_test_epoch(self, **kwargs):
# can be overwritten
...
def load_state_dict(self, load_path, strict=False):
if self.is_distributed:
map_location = {'cpu': f'cuda:{self.local_rank}'}
else:
map_location = {'cpu': self.device}
state_dict = torch.load(load_path, map_location=map_location)
for k, m in self.modules.items():
self.modules[k].load_state_dict(state_dict[k], strict=strict)
def save_state_dict(self, save_path):
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
if self.is_main:
state_dict = {}
for k, m in self.modules.items():
if isinstance(m, DistributedDataParallel):
m = m.module
state_dict[k] = m.state_dict()
torch.save(state_dict, save_path)
# Use a barrier() to make sure that process a loads the model after process b
# saves it (In case the user calls load_state_dict right after save_state_dict).
if self.is_distributed:
dist.barrier()