Skip to content

Commit

Permalink
Update the Neptune integration (#906)
Browse files Browse the repository at this point in the history
Neptune integration was updated to work with a recent version of the Neptune
client.

Additional logging was added: model summary, configuration (learning rate,
optimizer, etc), event_lr (when available).

Tests were altered to match the current API of the client library and improved
when possible.

Co-authored-by: Sabine <sabine.nyholm@neptune.ai>
  • Loading branch information
twolodzko and normandy7 committed Oct 19, 2022
1 parent cfe568b commit 647702b
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 97 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- `NeptuneLogger` was updated to work with recent versions of Neptune client (v0.14.3 or higher); it now logs some additional data, including the model summary, configuration, and learning rate (when available) (#906)

### Fixed

## [0.12.0] - 2022-10-07
Expand Down
3 changes: 1 addition & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ future>=0.17.1
gpytorch>=1.5
jupyter
matplotlib>=2.0.2
mlflow
neptune-client>=0.4.103
neptune-client>=0.14.3
numpydoc
openpyxl
pandas
Expand Down
179 changes: 119 additions & 60 deletions skorch/callbacks/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,147 +65,206 @@ def on_epoch_end(self, net, **kwargs):


class NeptuneLogger(Callback):
"""Logs results from history to Neptune
"""Logs model metadata and training metrics to Neptune.
Neptune is a lightweight experiment tracking tool.
Neptune is a lightweight experiment-tracking tool.
You can read more about it here: https://neptune.ai
Use this callback to automatically log all interesting values from
your net's history to Neptune.
The best way to log additional information is to log directly to the
experiment object or subclass the ``on_*`` methods.
run object.
To monitor resource consumption install psutil
To monitor resource consumption, install psutil:
>>> python -m pip install psutil
$ python -m pip install psutil
You can view example experiment logs here:
https://ui.neptune.ai/o/shared/org/skorch-integration/e/SKOR-13/charts
https://app.neptune.ai/shared/skorch-integration/e/SKOR-23/
Examples
--------
>>> # Install neptune
>>> python -m pip install neptune-client
>>> # Create a neptune experiment object
>>> import neptune
...
... # We are using api token for an anonymous user.
... # For your projects use the token associated with your neptune.ai account
>>> neptune.init(api_token='ANONYMOUS',
... project_qualified_name='shared/skorch-integration')
...
... experiment = neptune.create_experiment(
... name='skorch-basic-example',
... params={'max_epochs': 20,
... 'lr': 0.01},
... upload_source_files=['skorch_example.py'])
$ # Install Neptune
$ python -m pip install neptune-client
>>> # Create a neptune_logger callback
>>> neptune_logger = NeptuneLogger(experiment, close_after_train=False)
>>> # Pass a logger to net callbacks argument
>>> # Create a Neptune run
>>> import neptune.new as neptune
>>> from neptune.new.types import File
...
... # This example uses the API token for anonymous users.
... # For your own projects, use the token associated with your neptune.ai account.
>>> run = neptune.init_run(
... api_token=neptune.ANONYMOUS_API_TOKEN,
... project='shared/skorch-integration',
... name='skorch-basic-example',
... source_files=['skorch_example.py'],
... )
>>> # Create a NeptuneLogger callback
>>> neptune_logger = NeptuneLogger(run, close_after_train=False)
>>> # Pass the logger to the net callbacks argument
>>> net = NeuralNetClassifier(
... ClassifierModule,
... max_epochs=20,
... lr=0.01,
... callbacks=[neptune_logger])
... callbacks=[neptune_logger, Checkpoint(dirname="./checkpoints")])
>>> net.fit(X, y)
>>> # Save the checkpoints to Neptune
>>> neptune_logger.run["checkpoints"].upload_files("./checkpoints")
>>> # Log additional metrics after training has finished
>>> from sklearn.metrics import roc_auc_score
... y_pred = net.predict_proba(X)
... auc = roc_auc_score(y, y_pred[:, 1])
... y_proba = net.predict_proba(X)
... auc = roc_auc_score(y, y_proba[:, 1])
...
... neptune_logger.experiment.log_metric('roc_auc_score', auc)
... neptune_logger.run["roc_auc_score"].log(auc)
>>> # log charts like ROC curve
... from scikitplot.metrics import plot_roc
... import matplotlib.pyplot as plt
>>> # Log charts, such as an ROC curve
>>> from sklearn.metrics import RocCurveDisplay
...
... fig, ax = plt.subplots(figsize=(16, 12))
... plot_roc(y, y_pred, ax=ax)
... neptune_logger.experiment.log_image('roc_curve', fig)
>>> roc_plot = RocCurveDisplay.from_estimator(net, X, y)
>>> neptune_logger.run["roc_curve"].upload(File.as_html(roc_plot.figure_))
>>> # log net object after training
>>> # Log the net object after training
... net.save_params(f_params='basic_model.pkl')
... neptune_logger.experiment.log_artifact('basic_model.pkl')
... neptune_logger.run["basic_model"].upload(File('basic_model.pkl'))
>>> # close experiment
... neptune_logger.experiment.stop()
>>> # Close the run
... neptune_logger.run.stop()
Parameters
----------
experiment : neptune.experiments.Experiment
Instantiated ``Experiment`` class.
run : neptune.new.Run
Instantiated ``Run`` class.
log_on_batch_end : bool (default=False)
Whether to log loss and other metrics on batch level.
close_after_train : bool (default=True)
Whether to close the ``Experiment`` object once training
Whether to close the ``Run`` object once training
finishes. Set this parameter to False if you want to continue
logging to the same Experiment or if you use it as a context
logging to the same run or if you use it as a context
manager.
keys_ignored : str or list of str (default=None)
Key or list of keys that should not be logged to Neptune. Note that in
addition to the keys provided by the user, keys such as those starting
with ``'event_'`` or ending on ``'_best'`` are ignored by default.
base_namespace: str
Namespace (folder) under which all metadata logged by the ``NeptuneLogger``
will be stored. Defaults to "training".
Attributes
----------
first_batch_ : bool
Helper attribute that is set to True at initialization and changes
to False on first batch end. Can be used when we want to log things
exactly once.
.. _Neptune: https://www.neptune.ai
"""

def __init__(
        self,
        run,
        *,
        log_on_batch_end=False,
        close_after_train=True,
        keys_ignored=None,
        base_namespace='training',
):
    # Only store the init parameters; derived attributes (with a
    # trailing underscore) are created in ``initialize``.
    self.run = run
    self.log_on_batch_end = log_on_batch_end
    self.close_after_train = close_after_train
    self.keys_ignored = keys_ignored
    self.base_namespace = base_namespace
@property
def _metric_logger(self):
    """Handle to the namespace inside the run that receives all metadata."""
    base = self._base_namespace
    return self.run[base]

@staticmethod
def _get_obj_name(obj):
    """Return the class name of ``obj`` (e.g. 'SGD' for an optimizer)."""
    cls = type(obj)
    return cls.__name__

def initialize(self):
    """Create the derived attributes used during logging."""
    ignored = self.keys_ignored
    if isinstance(ignored, str):
        ignored = [ignored]
    self.keys_ignored_ = set(ignored or [])
    # 'batches' is a nested structure in the history, not a scalar
    # metric, so it is never logged directly.
    self.keys_ignored_.add('batches')

    # Drop a single trailing slash so namespace paths join cleanly.
    namespace = self.base_namespace
    self._base_namespace = (
        namespace[:-1] if namespace.endswith("/") else namespace
    )

    return self

def on_train_begin(self, net, X, y, **kwargs):
    """Log static model and configuration metadata once, before training."""
    # TODO: we might want to improve logging of the multi-module net objects, see:
    # https://github.com/skorch-dev/skorch/pull/906#discussion_r993514643

    module = net.module_
    self._metric_logger['model/model_type'] = self._get_obj_name(module)
    self._metric_logger['model/summary'] = self._model_summary_file(module)

    # Hyper-parameters that stay fixed for the whole training run.
    config = {
        'optimizer': self._get_obj_name(net.optimizer_),
        'criterion': self._get_obj_name(net.criterion_),
        'lr': net.lr,
        'epochs': net.max_epochs,
        'batch_size': net.batch_size,
        'device': net.device,
    }
    for name, value in config.items():
        self._metric_logger['config/' + name] = value

def on_batch_end(self, net, **kwargs):
    """Log metrics of the most recent batch, if batch logging is enabled."""
    if self.log_on_batch_end:
        batch_logs = net.history[-1]['batches'][-1]

        for key in filter_log_keys(batch_logs.keys(), self.keys_ignored_):
            self._log_metric(key, batch_logs, batch=True)

def on_epoch_end(self, net, **kwargs):
    """Automatically log values from the last history step."""
    epoch_logs = net.history[-1]

    for key in filter_log_keys(epoch_logs.keys(), self.keys_ignored_):
        self._log_metric(key, epoch_logs, batch=False)

def on_train_end(self, net, **kwargs):
    """Log recorded learning rates and optionally close the run."""
    try:
        # 'event_lr' is only present when an lr scheduler callback was
        # used; log the whole series at once after training finishes.
        self._metric_logger['train/epoch/event_lr'].log(net.history[:, 'event_lr'])
    except KeyError:
        # no lr scheduler -> nothing to log
        pass
    if self.close_after_train:
        self.run.stop()

def _log_metric(self, name, logs, batch):
    """Log a single value from ``logs`` under a structured Neptune path."""
    prefix, _, suffix = name.partition('_')

    if not suffix:
        # Un-prefixed keys (e.g. 'dur') are logged at the namespace root;
        # 'dur' gets the more descriptive name 'epoch_duration'.
        key = 'epoch_duration' if prefix == 'dur' else prefix
        self._metric_logger[key].log(logs[name])
        return

    kind = 'validation' if prefix == 'valid' else prefix
    granularity = 'batch' if batch else 'epoch'

    # for example: train / epoch / loss
    self._metric_logger[kind][granularity][suffix].log(logs[name])

@staticmethod
def _model_summary_file(model):
    """Return the model's ``repr`` wrapped as a Neptune ``File``.

    The module summary (its string representation, i.e. the layer
    structure) is uploaded to Neptune as a .txt file.
    """
    try:
        # neptune-client=0.9.0+ package structure
        from neptune.new.types import File
    except ImportError:
        # neptune-client>=1.0.0 package structure
        from neptune.types import File

    return File.from_content(str(model), extension='txt')


class WandbLogger(Callback):
Expand Down

0 comments on commit 647702b

Please sign in to comment.