-
Notifications
You must be signed in to change notification settings - Fork 3.6k
/
resolver.py
170 lines (138 loc) · 6.01 KB
/
resolver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import inspect
from typing import Any, Optional, Union
import torch
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn.lr_scheduler import (
ConstantWithWarmupLR,
CosineWithWarmupLR,
CosineWithWarmupRestartsLR,
LinearWithWarmupLR,
PolynomialWithWarmupLR,
)
from torch_geometric.resolver import normalize_string, resolver
try:
from torch.optim.lr_scheduler import LRScheduler
except ImportError: # PyTorch < 2.0
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
# Activation Resolver #########################################################
def swish(x: Tensor) -> Tensor:
    """Swish/SiLU activation: element-wise ``x * sigmoid(x)``."""
    return x.sigmoid().mul(x)
def activation_resolver(query: Union[Any, str] = 'relu', *args, **kwargs):
    """Resolve an activation from its name or type.

    Candidates are all :class:`torch.nn.Module` subclasses found in
    :mod:`torch.nn.modules.activation`, plus the local :func:`swish`.
    """
    base_cls = torch.nn.Module
    candidates = [
        obj for obj in vars(torch.nn.modules.activation).values()
        if isinstance(obj, type) and issubclass(obj, base_cls)
    ]
    candidates.append(swish)
    return resolver(candidates, {}, query, base_cls, 'Act', *args, **kwargs)
# Normalization Resolver ######################################################
def normalization_resolver(query: Union[Any, str], *args, **kwargs):
    """Resolve a normalization layer from its name or type.

    Candidates are all :class:`torch.nn.Module` subclasses exported by
    :mod:`torch_geometric.nn.norm`.
    """
    import torch_geometric.nn.norm as norm_module
    base_cls = torch.nn.Module
    candidates = [
        obj for obj in vars(norm_module).values()
        if isinstance(obj, type) and issubclass(obj, base_cls)
    ]
    return resolver(candidates, {}, query, base_cls, 'Norm', *args, **kwargs)
# Aggregation Resolver ########################################################
def aggregation_resolver(query: Union[Any, str], *args, **kwargs):
    """Resolve an aggregation operator from its name or type.

    A list/tuple query is wrapped into a
    :class:`~torch_geometric.nn.aggr.MultiAggregation`; otherwise, candidates
    are all :class:`~torch_geometric.nn.aggr.Aggregation` subclasses, with
    ``'add'`` aliased to ``SumAggregation``.
    """
    import torch_geometric.nn.aggr as aggr_module
    if isinstance(query, (list, tuple)):
        return aggr_module.MultiAggregation(query, *args, **kwargs)
    base_cls = aggr_module.Aggregation
    candidates = [
        obj for obj in vars(aggr_module).values()
        if isinstance(obj, type) and issubclass(obj, base_cls)
    ]
    alias_map = {'add': aggr_module.SumAggregation}
    return resolver(candidates, alias_map, query, base_cls, None, *args,
                    **kwargs)
# Optimizer Resolver ##########################################################
def optimizer_resolver(query: Union[Any, str], *args, **kwargs):
    """Resolve an optimizer from its name or type.

    Candidates are all :class:`torch.optim.Optimizer` subclasses exported by
    :mod:`torch.optim`.
    """
    base_cls = Optimizer
    candidates = []
    for obj in vars(torch.optim).values():
        if isinstance(obj, type) and issubclass(obj, base_cls):
            candidates.append(obj)
    return resolver(candidates, {}, query, base_cls, None, *args, **kwargs)
# Learning Rate Scheduler Resolver ############################################
def lr_scheduler_resolver(
    query: Union[Any, str],
    optimizer: Optimizer,
    warmup_ratio_or_steps: Optional[Union[float, int]] = 0.1,
    num_training_steps: Optional[int] = None,
    **kwargs,
) -> Union[LRScheduler, ReduceLROnPlateau]:
    r"""A resolver to obtain a learning rate scheduler implemented in either
    PyG or PyTorch from its name or type.

    Args:
        query (Any or str): The query name of the learning rate scheduler.
        optimizer (Optimizer): The optimizer to be scheduled.
        warmup_ratio_or_steps (float or int, optional): The number of warmup
            steps. If given as a `float`, it will act as a ratio that gets
            multiplied with the number of training steps to obtain the number
            of warmup steps. Only required for warmup-based LR schedulers.
            (default: :obj:`0.1`)
        num_training_steps (int, optional): The total number of training steps.
            (default: :obj:`None`)
        **kwargs (optional): Additional arguments of the LR scheduler.

    Raises:
        ValueError: If :obj:`warmup_ratio_or_steps` has an invalid type or
            value, if a warmup-based scheduler is requested but the number of
            warmup steps cannot be computed, or if :obj:`query` cannot be
            resolved.
    """
    # A non-string query is assumed to already be a scheduler instance.
    if not isinstance(query, str):
        return query

    # Resolve the number of warmup steps. It may legitimately stay `None`
    # when a float ratio is given without `num_training_steps`; previously
    # this led to an `UnboundLocalError` if a warmup-based scheduler was then
    # requested. Now we raise a descriptive error at the point of use instead.
    warmup_steps: Optional[int] = None
    if isinstance(warmup_ratio_or_steps, float):
        if warmup_ratio_or_steps < 0 or warmup_ratio_or_steps > 1:
            raise ValueError(f"`warmup_ratio_or_steps` needs to be between "
                             f"0.0 and 1.0 when given as a floating point "
                             f"number (got {warmup_ratio_or_steps}).")
        if num_training_steps is not None:
            warmup_steps = round(warmup_ratio_or_steps * num_training_steps)
    elif isinstance(warmup_ratio_or_steps, int):
        if warmup_ratio_or_steps < 0:
            raise ValueError(f"`warmup_ratio_or_steps` needs to be positive "
                             f"when given as an integer "
                             f"(got {warmup_ratio_or_steps}).")
        warmup_steps = warmup_ratio_or_steps
    else:
        raise ValueError(f"Found invalid type of `warmup_ratio_or_steps` "
                         f"(got {type(warmup_ratio_or_steps)})")

    base_cls = LRScheduler
    # All candidates below are classes (filtered via `isinstance(..., type)`),
    # so no extra `inspect.isclass` guard is needed before instantiation.
    classes = [
        scheduler for scheduler in vars(torch.optim.lr_scheduler).values()
        if isinstance(scheduler, type) and issubclass(scheduler, base_cls)
    ] + [ReduceLROnPlateau]
    customized_lr_schedulers = [
        ConstantWithWarmupLR,
        LinearWithWarmupLR,
        CosineWithWarmupLR,
        CosineWithWarmupRestartsLR,
        PolynomialWithWarmupLR,
    ]
    classes += customized_lr_schedulers

    query_repr = normalize_string(query)
    base_cls_repr = normalize_string('LR')
    for cls in classes:
        cls_repr = normalize_string(cls.__name__)
        # Match either the full class name or the name with 'LR' stripped:
        if query_repr in [cls_repr, cls_repr.replace(base_cls_repr, '')]:
            if cls in customized_lr_schedulers:
                # Inject warmup arguments only where the scheduler accepts
                # them, so callers do not need to pass them explicitly:
                cls_keys = inspect.signature(cls).parameters.keys()
                if 'num_warmup_steps' in cls_keys:
                    if warmup_steps is None:
                        raise ValueError(
                            f"Could not compute the number of warmup steps "
                            f"for '{query}': `warmup_ratio_or_steps` was "
                            f"given as a ratio, but `num_training_steps` "
                            f"is not set")
                    kwargs['num_warmup_steps'] = warmup_steps
                if 'num_training_steps' in cls_keys:
                    kwargs['num_training_steps'] = num_training_steps
            return cls(optimizer, **kwargs)

    choices = {cls.__name__ for cls in classes}
    raise ValueError(f"Could not resolve '{query}' among choices {choices}")