-
-
Notifications
You must be signed in to change notification settings - Fork 605
/
base_logger.py
294 lines (231 loc) · 11.5 KB
/
base_logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
"""Base logger and its helper handlers."""
import numbers
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
from torch.optim import Optimizer
from ignite.engine import Engine, Events, EventsList, State
from ignite.engine.events import CallableEventWithFilter, RemovableEventHandle
class BaseHandler(metaclass=ABCMeta):
    """Base handler for defining various useful handlers.

    Concrete handlers implement ``__call__``, which receives the engine, the logger the
    handler is used with, and the name of the triggering event.
    """

    @abstractmethod
    def __call__(self, engine: Engine, logger: Any, event_name: Union[str, Events]) -> None:
        """Handle ``event_name`` fired by ``engine`` using ``logger``; subclasses must override."""
class BaseWeightsHandler(BaseHandler):
    """
    Base handler for logging weights or their gradients.

    Args:
        model: module whose named parameters are selected for logging.
        tag: optional tag stored as ``self.tag``.
        whitelist: optional parameter filter.
            - ``None``: every named parameter is selected.
            - callable ``(name, parameter) -> bool``: parameters for which it returns ``True``.
            - otherwise it is iterated as name prefixes; a parameter is selected when its
              name starts with any of them.

    Raises:
        TypeError: if ``model`` is not a ``torch.nn.Module``.
    """

    def __init__(
        self,
        model: nn.Module,
        tag: Optional[str] = None,
        whitelist: Optional[Union[List[str], Callable[[str, nn.Parameter], bool]]] = None,
    ):
        if not isinstance(model, torch.nn.Module):
            raise TypeError(f"Argument model should be of type torch.nn.Module, but given {type(model)}")

        self.model = model
        self.tag = tag

        named_params = model.named_parameters()
        if whitelist is None:
            selected = dict(named_params)
        elif callable(whitelist):
            selected = {name: param for name, param in named_params if whitelist(name, param)}
        else:
            selected = {}
            for name, param in named_params:
                # Every prefix is checked for each parameter (no early exit).
                prefix_hits = [prefix for prefix in whitelist if name.startswith(prefix)]
                if prefix_hits:
                    selected[name] = param

        # Items view of the selected (name, parameter) pairs, in model.named_parameters() order.
        self.weights = selected.items()
class BaseOptimizerParamsHandler(BaseHandler):
    """
    Base handler for logging optimizer parameters.

    Args:
        optimizer: a ``torch.optim.Optimizer``, or any object whose ``param_groups``
            attribute is a ``Sequence`` (e.g. list or tuple).
        param_name: name of the parameter group entry to log; ``"lr"`` by default.
        tag: optional tag stored as ``self.tag``.

    Raises:
        TypeError: if ``optimizer`` is neither an ``Optimizer`` nor exposes a sequence ``param_groups``.
    """

    def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None):
        # Real optimizers pass immediately; anything else must duck-type ``param_groups``.
        if not isinstance(optimizer, Optimizer):
            has_param_groups = hasattr(optimizer, "param_groups") and isinstance(optimizer.param_groups, Sequence)
            if not has_param_groups:
                raise TypeError(
                    "Argument optimizer should be torch.optim.Optimizer or has attribute 'param_groups' as list/tuple, "
                    f"but given {type(optimizer)}"
                )

        self.optimizer = optimizer
        self.param_name = param_name
        self.tag = tag
class BaseOutputHandler(BaseHandler):
    """
    Helper handler to log engine's output and/or metrics.

    Args:
        tag: first element (tuple keys) or leading path component (string keys) of every logged key.
        metric_names: list of names to read from ``engine.state.metrics``, or the string
            ``"all"`` to take every metric.
        output_transform: optional callable applied to ``engine.state.output``. A non-dict
            result is logged under the name ``"output"``.
        global_step_transform: optional callable ``(engine, event_name) -> int``. Defaults to
            ``engine.state.get_event_attrib_value(event_name)``.
        state_attributes: optional list of ``engine.state`` attribute names to log
            (missing attributes yield ``None``).

    Raises:
        TypeError: if ``metric_names`` is neither a list nor ``"all"``, or if
            ``output_transform`` / ``global_step_transform`` is given but not callable.
        ValueError: if none of ``metric_names``, ``output_transform`` or ``state_attributes`` is given.
    """

    def __init__(
        self,
        tag: str,
        metric_names: Optional[Union[str, List[str]]] = None,
        output_transform: Optional[Callable] = None,
        global_step_transform: Optional[Callable[[Engine, Union[str, Events]], int]] = None,
        state_attributes: Optional[List[str]] = None,
    ):
        if metric_names is not None:
            if not (isinstance(metric_names, list) or (isinstance(metric_names, str) and metric_names == "all")):
                raise TypeError(
                    f"metric_names should be either a list or equal 'all', got {type(metric_names)} instead."
                )

        if output_transform is not None and not callable(output_transform):
            raise TypeError(f"output_transform should be a function, got {type(output_transform)} instead.")

        if output_transform is None and metric_names is None and state_attributes is None:
            raise ValueError("Either metric_names, output_transform or state_attributes should be defined")

        if global_step_transform is not None and not callable(global_step_transform):
            raise TypeError(f"global_step_transform should be a function, got {type(global_step_transform)} instead.")

        if global_step_transform is None:
            # Default step: whatever engine.state.get_event_attrib_value returns for the triggering event.
            def global_step_transform(engine: Engine, event_name: Union[str, Events]) -> int:
                return engine.state.get_event_attrib_value(event_name)

        self.tag = tag
        self.metric_names = metric_names
        self.output_transform = output_transform
        self.global_step_transform = global_step_transform
        self.state_attributes = state_attributes

    def _setup_output_metrics_state_attrs(
        self, engine: Engine, log_text: Optional[bool] = False, key_tuple: Optional[bool] = True
    ) -> Dict[Any, Any]:
        """Helper method to setup metrics and state attributes to log.

        Collects the selected metrics, the transformed output and the requested state
        attributes, then flattens them into loggable values.

        Args:
            engine: engine whose ``state`` is read.
            log_text: if true, ``str`` values are kept; otherwise they trigger a warning and are skipped.
            key_tuple: if true, keys are tuples ``(tag, name[, index])``; otherwise they are
                ``"/"``-joined strings ``"tag/name[/index]"``.

        Returns:
            Ordered mapping from key to value. Numbers are kept as-is, 0-d tensors become Python
            scalars, and 1-d tensors are split into one entry per element keyed by the element index.
            Other values produce a warning and are skipped.
        """
        metrics_state_attrs = OrderedDict()
        if self.metric_names is not None:
            if isinstance(self.metric_names, str) and self.metric_names == "all":
                metrics_state_attrs = OrderedDict(engine.state.metrics)
            else:
                for name in self.metric_names:
                    if name not in engine.state.metrics:
                        warnings.warn(
                            f"Provided metric name '{name}' is missing "
                            f"in engine's state metrics: {list(engine.state.metrics.keys())}"
                        )
                        continue
                    metrics_state_attrs[name] = engine.state.metrics[name]

        if self.output_transform is not None:
            output_dict = self.output_transform(engine.state.output)

            if not isinstance(output_dict, dict):
                output_dict = {"output": output_dict}

            metrics_state_attrs.update(output_dict)

        if self.state_attributes is not None:
            metrics_state_attrs.update({name: getattr(engine.state, name, None) for name in self.state_attributes})

        metrics_state_attrs_dict: Dict[Any, Union[str, float, numbers.Number]] = OrderedDict()

        def key_tuple_tf(tag: str, name: str, *args: str) -> Tuple[str, ...]:
            return (tag, name) + args

        def key_str_tf(tag: str, name: str, *args: str) -> str:
            return "/".join((tag, name) + args)

        key_tf = key_tuple_tf if key_tuple else key_str_tf

        for name, value in metrics_state_attrs.items():
            if isinstance(value, numbers.Number):
                metrics_state_attrs_dict[key_tf(self.tag, name)] = value
            elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
                metrics_state_attrs_dict[key_tf(self.tag, name)] = value.item()
            elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
                # A single tolist() yields the same Python scalars as per-element .item(),
                # without one host transfer/sync per element for accelerator tensors.
                for i, v in enumerate(value.tolist()):
                    metrics_state_attrs_dict[key_tf(self.tag, name, str(i))] = v
            elif isinstance(value, str) and log_text:
                metrics_state_attrs_dict[key_tf(self.tag, name)] = value
            else:
                warnings.warn(f"Logger output_handler can not log metrics value type {type(value)}")

        return metrics_state_attrs_dict
class BaseWeightsScalarHandler(BaseWeightsHandler):
    """
    Helper handler to log model's weights or gradients as scalars.

    Args:
        model: module whose parameters are selected (validated by ``BaseWeightsHandler``).
        reduction: callable mapping a tensor to a Python number or a 0-d tensor;
            ``torch.norm`` by default.
        tag: optional tag.
        whitelist: optional parameter filter, forwarded to ``BaseWeightsHandler``.

    Raises:
        TypeError: if ``model`` is not a module, if ``reduction`` is not callable, or if
            ``reduction`` does not return a scalar for a sample ``(4, 2)`` tensor of ones.
    """

    def __init__(
        self,
        model: nn.Module,
        reduction: Callable[[torch.Tensor], Union[float, torch.Tensor]] = torch.norm,
        tag: Optional[str] = None,
        whitelist: Optional[Union[List[str], Callable[[str, nn.Parameter], bool]]] = None,
    ):
        super().__init__(model, tag=tag, whitelist=whitelist)

        if not callable(reduction):
            raise TypeError(f"Argument reduction should be callable, but given {type(reduction)}")

        # Probe the reduction on a sample tensor and require a scalar result.
        probe = reduction(torch.ones(4, 2))
        probe_is_0d_tensor = isinstance(probe, torch.Tensor) and probe.ndimension() == 0
        if not (isinstance(probe, numbers.Number) or probe_is_0d_tensor):
            raise TypeError(f"Output of the reduction function should be a scalar, but got {type(probe)}")

        self.reduction = reduction
class BaseLogger(metaclass=ABCMeta):
    """
    Base logger handler. See implementations: TensorboardLogger, VisdomLogger, PolyaxonLogger, MLflowLogger, ...

    Subclasses implement ``_create_output_handler`` and ``_create_opt_params_handler``. The
    logger is a context manager: leaving a ``with`` block calls ``close``.
    """

    def attach(
        self,
        engine: Engine,
        log_handler: Callable,
        event_name: Union[str, Events, CallableEventWithFilter, EventsList],
        *args: Any,
        **kwargs: Any,
    ) -> RemovableEventHandle:
        """Attach the logger to the engine and execute `log_handler` function at `event_name` events.

        Args:
            engine: engine object.
            log_handler: a logging handler to execute
            event_name: event to attach the logging handler to. Valid events are from
                :class:`~ignite.engine.events.Events` or :class:`~ignite.engine.events.EventsList` or any `event_name`
                added by :meth:`~ignite.engine.engine.Engine.register_events`.
            args: args forwarded to the `log_handler` method
            kwargs: kwargs forwarded to the `log_handler` method

        Returns:
            :class:`~ignite.engine.events.RemovableEventHandle`, which can be used to remove the handler.

        Raises:
            RuntimeError: if `event_name`, or any event of an `EventsList`, is not a known event.
                For an `EventsList`, nothing is attached in that case.
        """
        if isinstance(event_name, EventsList):
            # Validate every event first so an unknown event does not leave the handler
            # attached to the events listed before it.
            for name in event_name:
                if name not in State.event_to_attr:
                    raise RuntimeError(f"Unknown event name '{name}'")
            for name in event_name:
                # Forward args/kwargs exactly like the single-event branch below.
                engine.add_event_handler(name, log_handler, self, name, *args, **kwargs)
            return RemovableEventHandle(event_name, log_handler, engine)

        if event_name not in State.event_to_attr:
            raise RuntimeError(f"Unknown event name '{event_name}'")
        return engine.add_event_handler(event_name, log_handler, self, event_name, *args, **kwargs)

    def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle:
        """Shortcut method to attach `OutputHandler` to the logger.

        Args:
            engine: engine object.
            event_name: event to attach the logging handler to. Valid events are from
                :class:`~ignite.engine.events.Events` or any `event_name` added by
                :meth:`~ignite.engine.engine.Engine.register_events`. An
                :class:`~ignite.engine.events.EventsList` is also accepted, as it is passed to :meth:`attach`.
            args: args to initialize `OutputHandler`
            kwargs: kwargs to initialize `OutputHandler`

        Returns:
            :class:`~ignite.engine.events.RemovableEventHandle`, which can be used to remove the handler.
        """
        return self.attach(engine, self._create_output_handler(*args, **kwargs), event_name=event_name)

    def attach_opt_params_handler(
        self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any
    ) -> RemovableEventHandle:
        """Shortcut method to attach `OptimizerParamsHandler` to the logger.

        Args:
            engine: engine object.
            event_name: event to attach the logging handler to. Valid events are from
                :class:`~ignite.engine.events.Events` or any `event_name` added by
                :meth:`~ignite.engine.engine.Engine.register_events`. An
                :class:`~ignite.engine.events.EventsList` is also accepted, as it is passed to :meth:`attach`.
            args: args to initialize `OptimizerParamsHandler`
            kwargs: kwargs to initialize `OptimizerParamsHandler`

        Returns:
            :class:`~ignite.engine.events.RemovableEventHandle`, which can be used to remove the handler.

        .. versionchanged:: 0.4.3
            Added missing return statement.
        """
        return self.attach(engine, self._create_opt_params_handler(*args, **kwargs), event_name=event_name)

    @abstractmethod
    def _create_output_handler(self, *args: Any, **kwargs: Any) -> Callable:
        """Build the output handler from the arguments given to `attach_output_handler`.

        The signature matches the only call site, which passes no engine.
        """

    @abstractmethod
    def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> Callable:
        """Build the optimizer-params handler from the arguments given to `attach_opt_params_handler`."""

    def __enter__(self) -> "BaseLogger":
        return self

    # Parameter names (including `type`, which shadows the builtin) are kept for compatibility.
    def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
        self.close()

    def close(self) -> None:
        """Release logger resources. This base implementation does nothing."""