This repository has been archived by the owner on May 28, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
optuna_transformers.py
130 lines (118 loc) · 5.41 KB
/
optuna_transformers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright (c) 2021 Timothy Wolff-Piggott
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT
import logging
import os

from hpoflow.optuna_mlflow import OptunaMLflow
from optuna._imports import try_import

# Optional integration dependencies: try_import records an ImportError
# instead of raising, so the failure can be reported lazily via _imports.check().
with try_import() as _imports:
    import mlflow
    import transformers

# do the check eagerly and not in the constructor
# because OMLflowCallback inherits from transformers.TrainerCallback
_imports.check()

# module-level logger, named after this module per logging convention
_logger = logging.getLogger(__name__)
class OMLflowCallback(transformers.TrainerCallback):
    """
    Class based on ``transformers.TrainerCallback``; integrates with OptunaMLflow
    to send the logs to ``MLflow`` and ``Optuna`` during model training.
    """

    def __init__(
        self,
        trial: OptunaMLflow,
        log_training_args: bool = True,
        log_model_config: bool = True,
    ):
        """
        Check integration package dependencies and initialize class.

        Args:
            trial: OptunaMLflow object
            log_training_args: Whether to log all Transformers TrainingArguments as MLflow params
            log_model_config: Whether to log the Transformers model config as MLflow params
        """
        self._initialized = False
        self._log_artifacts = False
        # Keep one handle on the mlflow module so every MLflow access in this
        # class goes through the same attribute (setup previously bypassed it).
        self._ml_flow = mlflow
        self._trial = trial
        self._log_training_args = log_training_args
        self._log_model_config = log_model_config

    def setup(self, args, state, model):
        """
        Setup the optional MLflow integration.

        Logs the Transformers ``TrainingArguments`` and (if present) the model
        config as MLflow params on the world-zero process, dropping values that
        exceed MLflow's per-param length limit and batching the rest.

        Environment:
            HF_MLFLOW_LOG_ARTIFACTS (:obj:``str``, ``optional``):
                Whether to use MLflow .log_artifact() facility to log artifacts.
                This only makes sense if logging to a remote server, e.g. s3 or GCS.
                If set to ``True`` or ``1``, will copy whatever is in TrainerArgument's output_dir
                to the local or remote artifact storage. Using it without a remote storage will
                just copy the files to your artifact location.
        """
        log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
        if log_artifacts in {"TRUE", "1"}:
            self._log_artifacts = True
        if state.is_world_process_zero:
            combined_dict = {}
            if self._log_training_args:
                _logger.info("Logging training arguments.")
                combined_dict.update(args.to_dict())
            if self._log_model_config and getattr(model, "config", None) is not None:
                _logger.info("Logging model config.")
                # training args take precedence over model-config keys on clash
                combined_dict = {**model.config.to_dict(), **combined_dict}
            # remove params that are too long for MLflow
            # (hoist the limit: it is loop-invariant, and using it in the message
            # keeps the warning correct if MLflow ever changes the limit)
            max_param_length = self._ml_flow.utils.validation.MAX_PARAM_VAL_LENGTH
            for name, value in list(combined_dict.items()):
                # internally, all values are converted to str in MLflow
                if len(str(value)) > max_param_length:
                    _logger.warning(
                        "Trainer is attempting to log a value of '%s' for key '%s' as a "
                        "parameter. MLflow's log_param() only accepts values no longer than "
                        "%d characters so we dropped this attribute.",
                        value,
                        name,
                        max_param_length,
                    )
                    del combined_dict[name]
            # MLflow cannot log more than MAX_PARAMS_TAGS_PER_BATCH values in
            # one go, so we have to split the params into batches
            batch_size = self._ml_flow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH
            combined_dict_items = list(combined_dict.items())
            for i in range(0, len(combined_dict_items), batch_size):
                self._trial.log_params(dict(combined_dict_items[i : i + batch_size]))
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """
        Call setup if not yet initialized.
        """
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        """
        Log all metrics from Transformers logs as MLflow metrics at the appropriate step.

        Non-numeric values are skipped with a warning, since MLflow metrics
        must be int or float.
        """
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            metrics_to_log = {}
            for key, value in logs.items():
                if isinstance(value, (int, float)):
                    metrics_to_log[key] = value
                else:
                    _logger.warning(
                        "Trainer is attempting to log a value of '%s' of type %s for key "
                        "'%s' as a metric. MLflow's log_metric() only accepts float and "
                        "int types so we dropped this attribute.",
                        value,
                        type(value),
                        key,
                    )
            self._trial.log_metrics(metrics_to_log, step=state.global_step)

    def on_train_end(self, args, state, control, **kwargs):
        """
        Log the training output as MLflow artifacts if logging artifacts is enabled.
        """
        if self._initialized and state.is_world_process_zero:
            if self._log_artifacts:
                _logger.info("Logging artifacts. This may take time.")
                self._ml_flow.log_artifacts(args.output_dir)