-
Notifications
You must be signed in to change notification settings - Fork 5.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Train] [Tune] Refactor MLflow #20802
Changes from 16 commits
9437251
2720e9c
9740782
b73503e
0331ef6
4efb081
4cdcb31
ae64f04
41e9b0d
0557561
567924c
b004dee
d02237a
06de0d9
4755286
5950734
1faaea4
0683718
de719eb
e1c19d3
7ab7814
ac564e9
15e9c21
3cba50b
939f3b2
4379615
34c64b7
61bfe2e
192420e
8880fb1
4caad5b
d031913
6d35f86
c827cb4
66c0b3c
4f9eb20
d2aa72c
a5ffbd0
fb7325f
95fdc64
b045256
dd894ce
c2f8dab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
from ray.train.callbacks import TrainingCallback | ||
from ray.train.constants import (RESULT_FILE_JSON, TRAINING_ITERATION, | ||
TIME_TOTAL_S, TIMESTAMP, PID) | ||
from ray.util.ml_utils.mlflow import MLflowLoggerUtil | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -174,6 +175,100 @@ def _validate_worker_to_log(self, worker_to_log) -> int: | |
return worker_to_log | ||
|
||
|
||
class MLflowLoggerCallback(TrainingSingleWorkerLoggingCallback):
    """MLflow Logger to automatically log Train results and config to MLflow.

    MLflow (https://mlflow.org) Tracking is an open source library for
    recording and querying experiments. This Ray Train callback
    sends information (config parameters, training results & metrics,
    and artifacts) to MLflow for automatic experiment tracking.

    Args:
        tracking_uri (Optional[str]): The tracking URI for where to manage
            experiments and runs. This can either be a local file path or a
            remote server. This arg gets passed directly to mlflow
            initialization. If None, the callback's logdir is used.
        registry_uri (Optional[str]): The registry URI that gets passed
            directly to mlflow initialization. If None, the callback's
            logdir is used.
        experiment_id (Optional[str]): The experiment id of an already
            existing experiment. If not
            passed in, experiment_name will be used.
        experiment_name (Optional[str]): The experiment name to use for this
            Train run.
            If the experiment with the name already exists with MLflow,
            it will be used. If not, a new experiment will be created with
            this name.
        tags (Optional[Dict]): An optional dictionary of string keys and
            values to set as tags on the run
        save_artifact (bool): If set to True, automatically save the entire
            contents of the Train local_dir as an artifact to the
            corresponding run in MLflow.
        logdir (Optional[str]): Path to directory where the results file
            should be. If None, will be set by the Trainer. If no tracking
            uri or registry uri are passed in, the logdir will be used for
            both.
        worker_to_log (int): Worker index to log. By default, will log the
            worker with index 0.

    Raises:
        ValueError: Raised by ``start_training`` when the MLflow setup
            reports failure, i.e. no usable ``experiment_id`` or
            ``experiment_name`` was provided.
    """

    def __init__(self,
                 tracking_uri: Optional[str] = None,
                 registry_uri: Optional[str] = None,
                 experiment_id: Optional[str] = None,
                 experiment_name: Optional[str] = None,
                 tags: Optional[Dict] = None,
                 save_artifact: bool = False,
                 logdir: Optional[str] = None,
                 worker_to_log: int = 0) -> None:
        """Store the MLflow settings; MLflow itself is set up later.

        No MLflow calls are made here beyond constructing the utility
        object. The experiment and run are created in ``start_training``.
        """
        super().__init__(logdir=logdir, worker_to_log=worker_to_log)

        self.tracking_uri = tracking_uri
        self.registry_uri = registry_uri
        self.experiment_id = experiment_id
        self.experiment_name = experiment_name
        self.tags = tags

        self.save_artifact = save_artifact
        self.mlflow_util = MLflowLoggerUtil()

    def start_training(self, logdir: str, config: Dict, **info) -> None:
        """Set up MLflow, start a run, and log ``config`` as run params.

        Args:
            logdir: Log directory forwarded to the parent callback.
            config: Configuration logged to MLflow as parameters.
            **info: Extra keyword arguments forwarded to the parent callback.

        Raises:
            ValueError: If the MLflow setup does not succeed.
        """
        # Unpack ``info`` so the parent receives the same keyword arguments
        # this method was called with, instead of a single nested ``info``
        # dict.
        super().start_training(logdir=logdir, config=config, **info)

        # Fall back to the callback's resolved logdir (``self.logdir``, not
        # the raw ``logdir`` argument) for any URI that was not provided.
        tracking_uri = self.tracking_uri
        if tracking_uri is None:
            tracking_uri = str(self.logdir)

        registry_uri = self.registry_uri
        if registry_uri is None:
            registry_uri = str(self.logdir)

        success = self.mlflow_util.setup_mlflow(
            tracking_uri=tracking_uri,
            registry_uri=registry_uri,
            experiment_id=self.experiment_id,
            experiment_name=self.experiment_name,
            create_experiment_if_not_exists=True)

        if not success:
            raise ValueError("No experiment_name or experiment_id passed in, "
                             "Please "
                             "set one of these to use the "
                             "MLflowLoggerCallback.")

        self.mlflow_util.start_run(tags=self.tags, set_active=True)
        self.mlflow_util.log_params(params_to_log=config)

    def handle_result(self, results: List[Dict], **info) -> None:
        """Log the selected worker's result dict to MLflow as metrics.

        Args:
            results: One result dict per worker.
            **info: Unused.
        """
        # ``_workers_to_log`` is not assigned in this class; presumably the
        # parent stores the validated ``worker_to_log`` index there --
        # verify against the parent class.
        result = results[self._workers_to_log]

        # The training iteration is used as the MLflow step.
        self.mlflow_util.log_metrics(
            metrics_to_log=result, step=result[TRAINING_ITERATION])

    def finish_training(self, error: bool = False, **info) -> None:
        """Optionally upload the logdir as artifacts, then end the run.

        Args:
            error: If True the run is ended with status ``"FAILED"``,
                otherwise ``"FINISHED"``.
            **info: Unused.
        """
        if self.save_artifact:
            self.mlflow_util.save_artifacts(dir=str(self.logdir))
        self.mlflow_util.end_run(status="FAILED" if error else "FINISHED")
amogkam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class TBXLoggerCallback(TrainingSingleWorkerLoggingCallback): | ||
"""Logs Train results in TensorboardX format. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we removing `RAY_CI_SGD_AFFECTED` here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it was ever needed in the first place.