Skip to content

Commit

Permalink
Added MLflow callback (#770)
Browse files Browse the repository at this point in the history
  • Loading branch information
cacharle committed Jun 17, 2021
1 parent 812f54d commit 852383e
Show file tree
Hide file tree
Showing 6 changed files with 379 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `load_best` attribute to `Checkpoint` callback to automatically load state of the best result at the end of training
- Added a `get_all_learnable_params` method to retrieve the named parameters of all PyTorch modules defined on the net, including of criteria if applicable
- Added `MlflowLogger` callback for logging to Mlflow (#769)

### Changed

Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
'sklearn': ('http://scikit-learn.org/stable/', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'python': ('https://docs.python.org/3', None),
'mlflow': ('https://mlflow.org/docs/latest/', None),
}

# Add any paths that contain templates here, relative to this directory.
Expand Down
1 change: 1 addition & 0 deletions skorch/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
'Initializer',
'LRScheduler',
'LoadInitState',
'MlflowLogger',
'NeptuneLogger',
'ParamMapper',
'PassthroughScoring',
Expand Down
184 changes: 183 additions & 1 deletion skorch/callbacks/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sys
import time
import tempfile
from contextlib import suppress
from numbers import Number
from itertools import cycle
Expand All @@ -16,7 +17,7 @@
from skorch.callbacks import Callback

__all__ = ['EpochTimer', 'NeptuneLogger', 'WandbLogger', 'PrintLog', 'ProgressBar',
'TensorBoard', 'SacredLogger']
'TensorBoard', 'SacredLogger', 'MlflowLogger']


def filter_log_keys(keys, keys_ignored=None):
Expand Down Expand Up @@ -875,3 +876,184 @@ def on_epoch_end(self, net, **kwargs):

for key in filter_log_keys(epoch_logs.keys(), self.keys_ignored_):
self.experiment.log_scalar(key + self.epoch_suffix_, epoch_logs[key], epoch)


class MlflowLogger(Callback):
    """Logs results from history and artifacts to Mlflow.

    "MLflow is an open source platform for managing
    the end-to-end machine learning lifecycle" (:doc:`mlflow:index`)

    Use this callback to automatically log your metrics
    and create/log artifacts to mlflow.

    The best way to log additional information is to log directly to the
    experiment object or subclass the ``on_*`` methods.

    To use this logger, you first have to install Mlflow:

    .. code-block::

        $ pip install mlflow

    Examples
    --------
    Mlflow :doc:`fluent API <mlflow:python_api/mlflow>`:

    >>> import mlflow
    >>> net = NeuralNetClassifier(net, callbacks=[MlflowLogger()])
    >>> with mlflow.start_run():
    ...     net.fit(X, y)

    Custom :py:class:`run <mlflow.entities.Run>` and
    :py:class:`client <mlflow.tracking.MlflowClient>`:

    >>> from mlflow.tracking import MlflowClient
    >>> client = MlflowClient()
    >>> experiment = client.get_experiment_by_name('Default')
    >>> run = client.create_run(experiment.experiment_id)
    >>> net = NeuralNetClassifier(..., callbacks=[MlflowLogger(run, client)])
    >>> net.fit(X, y)

    Parameters
    ----------
    run : mlflow.entities.Run (default=None)
      Instantiated :py:class:`mlflow.entities.Run` class.
      By default (if set to ``None``),
      :py:func:`mlflow.active_run` is used to get the current run.

    client : mlflow.tracking.MlflowClient (default=None)
      Instantiated :py:class:`mlflow.tracking.MlflowClient` class.
      By default (if set to ``None``),
      ``MlflowClient()`` is used, which by default has:

      - the tracking URI set by :py:func:`mlflow.set_tracking_uri`
      - the registry URI set by :py:func:`mlflow.set_registry_uri`

    create_artifact : bool (default=True)
      Whether to create artifacts for the network's
      params, optimizer, criterion and history.
      See :ref:`save_load`

    terminate_after_train : bool (default=True)
      Whether to terminate the ``Run`` object once training finishes.

    log_on_batch_end : bool (default=False)
      Whether to log loss and other metrics on batch level.

    log_on_epoch_end : bool (default=True)
      Whether to log loss and other metrics on epoch level.

    batch_suffix : str (default=None)
      A string that will be appended to all logged keys. By default (if set to
      ``None``) ``'_batch'`` is used if batch and epoch logging are both enabled
      and no suffix is used otherwise.

    epoch_suffix : str (default=None)
      A string that will be appended to all logged keys. By default (if set to
      ``None``) ``'_epoch'`` is used if batch and epoch logging are both enabled
      and no suffix is used otherwise.

    keys_ignored : str or list of str (default=None)
      Key or list of keys that should not be logged to Mlflow. Note that in
      addition to the keys provided by the user, keys such as those starting
      with ``'event_'`` or ending on ``'_best'`` are ignored by default.

    """
    def __init__(
            self,
            run=None,
            client=None,
            create_artifact=True,
            terminate_after_train=True,
            log_on_batch_end=False,
            log_on_epoch_end=True,
            batch_suffix=None,
            epoch_suffix=None,
            keys_ignored=None,
    ):
        self.run = run
        self.client = client
        self.create_artifact = create_artifact
        self.terminate_after_train = terminate_after_train
        self.log_on_batch_end = log_on_batch_end
        self.log_on_epoch_end = log_on_epoch_end
        self.batch_suffix = batch_suffix
        self.epoch_suffix = epoch_suffix
        self.keys_ignored = keys_ignored

    def initialize(self):
        """Resolve run/client defaults and normalize key filters.

        mlflow is imported lazily so that the callback can be
        constructed (and skorch imported) without mlflow installed.
        """
        self.run_ = self.run
        if self.run_ is None:
            import mlflow
            self.run_ = mlflow.active_run()
        self.client_ = self.client
        if self.client_ is None:
            from mlflow.tracking import MlflowClient
            self.client_ = MlflowClient()
        keys_ignored = self.keys_ignored
        if isinstance(keys_ignored, str):
            keys_ignored = [keys_ignored]
        self.keys_ignored_ = set(keys_ignored or [])
        # 'batches' holds nested per-batch dicts, not scalar metrics
        self.keys_ignored_.add('batches')
        self.batch_suffix_ = self._init_suffix(self.batch_suffix, '_batch')
        self.epoch_suffix_ = self._init_suffix(self.epoch_suffix, '_epoch')
        return self

    def _init_suffix(self, suffix, default):
        # A suffix is only needed to disambiguate keys when both batch
        # and epoch logging are active at the same time.
        if suffix is not None:
            return suffix
        return default if self.log_on_batch_end and self.log_on_epoch_end else ''

    def on_train_begin(self, net, **kwargs):
        # Global batch counter across epochs, used as the metric step.
        self._batch_count = 0

    def on_batch_end(self, net, training, **kwargs):
        if not self.log_on_batch_end:
            return
        self._batch_count += 1
        batch_logs = net.history[-1]['batches'][-1]
        self._iteration_log(batch_logs, self.batch_suffix_, self._batch_count)

    def on_epoch_end(self, net, **kwargs):
        if not self.log_on_epoch_end:
            return
        epoch_logs = net.history[-1]
        self._iteration_log(epoch_logs, self.epoch_suffix_, len(net.history))

    def _iteration_log(self, logs, suffix, step):
        # Log every non-ignored scalar from the history row to Mlflow.
        for key in filter_log_keys(logs.keys(), self.keys_ignored_):
            self.client_.log_metric(
                self.run_.info.run_id,
                key + suffix,
                logs[key],
                step=step,
            )

    def on_train_end(self, net, **kwargs):
        # Terminate the run even if artifact logging fails, so the run
        # does not stay open forever in the tracking server.
        try:
            self._log_artifacts(net)
        finally:
            if self.terminate_after_train:
                self.client_.set_terminated(self.run_.info.run_id)

    def _log_artifacts(self, net):
        if not self.create_artifact:
            return
        # Save to a temporary directory first; log_artifact copies the
        # files into the run's artifact store before cleanup.
        with tempfile.TemporaryDirectory(prefix='skorch_mlflow_logger_') as dirpath:
            dirpath = Path(dirpath)
            params_filepath = dirpath / 'params.pth'
            optimizer_filepath = dirpath / 'optimizer.pth'
            criterion_filepath = dirpath / 'criterion.pth'
            history_filepath = dirpath / 'history.json'
            net.save_params(
                f_params=params_filepath,
                f_optimizer=optimizer_filepath,
                f_criterion=criterion_filepath,
                f_history=history_filepath,
            )
            self.client_.log_artifact(self.run_.info.run_id, params_filepath)
            self.client_.log_artifact(self.run_.info.run_id, optimizer_filepath)
            self.client_.log_artifact(self.run_.info.run_id, criterion_filepath)
            self.client_.log_artifact(self.run_.info.run_id, history_filepath)

0 comments on commit 852383e

Please sign in to comment.