-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
__init__.py
78 lines (66 loc) · 2.6 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from enum import Enum
from functools import partial
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from . import (
default_hooks as default,
powerSGD_hook as powerSGD,
quantization_hooks as quantization,
)
def _ddp_comm_hook_wrapper(comm_hook, model, state):
model.register_comm_hook(state, comm_hook)
def _powerSGD_comm_hook_wrapper(
    comm_hook, model, state, matrix_approximation_rank, random_seed=0
):
    """Wrap a bare process group into a ``PowerSGDState`` and register the hook.

    To keep the same call signature as the other DDP comm-hook wrappers,
    ``state`` here is just a process group; the extra configuration that
    PowerSGD needs (``matrix_approximation_rank``, ``random_seed``) is folded
    into a freshly constructed ``powerSGD.PowerSGDState`` before registration.
    """
    wrapped_state = powerSGD.PowerSGDState(
        process_group=state,
        matrix_approximation_rank=matrix_approximation_rank,
        random_seed=random_seed,
    )
    model.register_comm_hook(wrapped_state, comm_hook)
class DDPCommHookType(Enum):
    """
    DDPCommHookType enumerates the hooks of ``torch.distributed.algorithms.ddp_comm_hooks``
    as names and ``ddp_comm_hook_wrapper`` partials with hook specified. As an example,
    you can register allreduce hook by
    ``DDPCommHookType.ALLREDUCE.value(model=model, state=process_group)``.
    """
    # Each member's value is a partial over a wrapper; calling it with
    # ``model=...`` and ``state=...`` performs the actual registration.
    # Plain gradient allreduce, expressed as a comm hook.
    ALLREDUCE = partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook)
    # Compresses gradients to fp16 before allreduce.
    FP16_COMPRESS = partial(
        _ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook
    )
    # Per-tensor quantization compression hook.
    QUANTIZE_PER_TENSOR = partial(
        _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook
    )
    # Per-channel quantization compression hook.
    QUANTIZE_PER_CHANNEL = partial(
        _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook
    )
    # PowerSGD low-rank gradient compression with rank 1; uses the PowerSGD
    # wrapper so the process group gets wrapped into a PowerSGDState.
    POWER_SGD = partial(
        _powerSGD_comm_hook_wrapper,
        comm_hook=powerSGD.powerSGD_hook,
        matrix_approximation_rank=1,
    )
    # Rank-2 PowerSGD can give a higher accuracy than the default rank-1 version,
    # but it runs slower and consumes more memory.
    POWER_SGD_RANK2 = partial(
        _powerSGD_comm_hook_wrapper,
        comm_hook=powerSGD.powerSGD_hook,
        matrix_approximation_rank=2,
    )
def register_ddp_comm_hook(
    comm_hook_type: DDPCommHookType, model: DistributedDataParallel, state=None
):
    """
    Register one of the standard ``ddp_comm_hooks`` on a DDP model.

    The hook to install is selected via the ``comm_hook_type`` enum member,
    and ``state`` is forwarded to the member's wrapper (for most hooks this
    is the process group). Uses Python comm hook implementations.

    Example::
        >>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state)
    """
    # The enum member's value is a partial wrapper that does the registration.
    wrapper = comm_hook_type.value
    wrapper(model=model, state=state)