NeptuneLogger improvements (#951)
* added output handlers, model checkpoint handler, added mnist example, added initial tests file

* added exp link to examples, added tests

* added neptune to docs

* fixed test

* fixed imports

* added neptune-client to test dependencies

* fixed missing package message

* dropped model checkpoint handler

* updated experiment link

* dropped __future__ print_function

* added NeptuneSaver and tests

* autopep8 fix

* updated token to anonymous user neptuner

* updated experiment link

* updated token to 'ANONYMOUS'

* updated examples, fixed tests

* autopep8 fix

* fixed serializable test

* fixed serializable model test

* local

* autopep8 fix

* added self.experiment back

Co-authored-by: AutoPEP8 <>
jakubczakon committed Apr 21, 2020
1 parent 29058e2 commit ca9d08e
Showing 2 changed files with 95 additions and 102 deletions.
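
Taken together, the changes below let user code talk to Neptune through the logger object itself. The following is a minimal usage sketch assembled from the docstring examples in this diff, not a verbatim excerpt; `trainer`, `train_loader`, and the project name are placeholders, and "ANONYMOUS" is the anonymous token mentioned in the commit message.

# Minimal sketch: log training loss to Neptune from an existing ignite `trainer`.
from ignite.contrib.handlers.neptune_logger import NeptuneLogger, OutputHandler
from ignite.engine import Events

npt_logger = NeptuneLogger(
    api_token="ANONYMOUS",                              # anonymous token used in the examples
    project_name="shared/pytorch-ignite-integration",   # placeholder project
    experiment_name="mnist",                             # forwarded to create_experiment as `name`
)

npt_logger.attach(
    trainer,
    log_handler=OutputHandler(tag="training", output_transform=lambda loss: {"loss": loss}),
    event_name=Events.ITERATION_COMPLETED,
)

trainer.run(train_loader, max_epochs=5)
npt_logger.close()  # stops the underlying Neptune experiment
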
ignite/contrib/handlers/neptune_logger.py (66 changes: 38 additions & 28 deletions)
@@ -87,12 +87,12 @@ def global_step_transform(*args, **kwargs):
                               event_name=Events.EPOCH_COMPLETED)
 
     Args:
-        tag (str): common title for all produced plots. For example, 'training'
+        tag (str): common title for all produced plots. For example, "training"
         metric_names (list of str, optional): list of metric names to plot or a string "all" to plot all available
             metrics.
         output_transform (callable, optional): output transform function to prepare `engine.state.output` as a number.
             For example, `output_transform = lambda output: output`
-            This function can also return a dictionary, e.g `{'loss': loss1, 'another_loss': loss2}` to label the plot
+            This function can also return a dictionary, e.g `{"loss": loss1, "another_loss": loss2}` to label the plot
             with corresponding keys.
         another_engine (Engine): Deprecated (see :attr:`global_step_transform`). Another engine to use to provide the
             value of event. Typically, user can provide
@@ -121,7 +121,7 @@ def __init__(self, tag, metric_names=None, output_transform=None, another_engine
     def __call__(self, engine, logger, event_name):
 
         if not isinstance(logger, NeptuneLogger):
-            raise RuntimeError("Handler 'OutputHandler' works only with NeptuneLogger")
+            raise RuntimeError("Handler OutputHandler works only with NeptuneLogger")
 
         metrics = self._setup_output_metrics(engine)
 
@@ -135,10 +135,10 @@ def __call__(self, engine, logger, event_name):
 
         for key, value in metrics.items():
             if isinstance(value, numbers.Number) or isinstance(value, torch.Tensor) and value.ndimension() == 0:
-                logger.experiment.log_metric("{}/{}".format(self.tag, key), x=global_step, y=value)
+                logger.log_metric("{}/{}".format(self.tag, key), x=global_step, y=value)
             elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
                 for i, v in enumerate(value):
-                    logger.experiment.log_metric("{}/{}/{}".format(self.tag, key, i), x=global_step, y=v.item())
+                    logger.log_metric("{}/{}/{}".format(self.tag, key, i), x=global_step, y=v.item())
             else:
                 warnings.warn("NeptuneLogger output_handler can not log " "metrics value type {}".format(type(value)))
 
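The hunk above keeps the metric naming scheme intact and only changes where the call lands: a scalar metric becomes one Neptune series named `<tag>/<key>`, while a 1-D tensor is fanned out into one series per element. A small, hypothetical helper (not part of the PR) that reproduces just that naming logic:

import numbers

import torch

def neptune_metric_names(tag, metrics):
    # Mirrors how OutputHandler builds metric names before calling logger.log_metric.
    names = []
    for key, value in metrics.items():
        if isinstance(value, numbers.Number) or (isinstance(value, torch.Tensor) and value.ndimension() == 0):
            names.append("{}/{}".format(tag, key))
        elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
            names.extend("{}/{}/{}".format(tag, key, i) for i in range(len(value)))
    return names

print(neptune_metric_names("training", {"nll": 0.31, "recall": torch.tensor([0.9, 0.7])}))
# ['training/nll', 'training/recall/0', 'training/recall/1']
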
@@ -170,15 +170,15 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
     Args:
         optimizer (torch.optim.Optimizer): torch optimizer which parameters to log
         param_name (str): parameter name
-        tag (str, optional): common title for all produced plots. For example, 'generator'
+        tag (str, optional): common title for all produced plots. For example, generator
     """
 
     def __init__(self, optimizer, param_name="lr", tag=None):
         super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
 
     def __call__(self, engine, logger, event_name):
         if not isinstance(logger, NeptuneLogger):
-            raise RuntimeError("Handler 'OptimizerParamsHandler' works only with NeptuneLogger")
+            raise RuntimeError("Handler OptimizerParamsHandler works only with NeptuneLogger")
 
         global_step = engine.state.get_event_attrib_value(event_name)
         tag_prefix = "{}/".format(self.tag) if self.tag else ""
@@ -188,7 +188,7 @@ def __call__(self, engine, logger, event_name):
         }
 
         for k, v in params.items():
-            logger.experiment.log_metric(k, x=global_step, y=v)
+            logger.log_metric(k, x=global_step, y=v)
 
 
 class WeightsScalarHandler(BaseWeightsScalarHandler):
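
The OptimizerParamsHandler hunks above receive the same two fixes: plain quoting in the error message and direct `logger.log_metric` calls. For orientation, a hedged attach sketch, assuming the `npt_logger`, `trainer`, and `optimizer` objects from a typical ignite training script:

from ignite.contrib.handlers.neptune_logger import OptimizerParamsHandler
from ignite.engine import Events

# Logs the learning rate of each param group once per iteration,
# under names such as "lr/group_0" (or "generator/lr/group_0" when tag="generator").
npt_logger.attach(
    trainer,
    log_handler=OptimizerParamsHandler(optimizer, param_name="lr"),
    event_name=Events.ITERATION_STARTED,
)
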
@@ -220,7 +220,7 @@ class WeightsScalarHandler(BaseWeightsScalarHandler):
     Args:
         model (torch.nn.Module): model to log weights
         reduction (callable): function to reduce parameters into scalar
-        tag (str, optional): common title for all produced plots. For example, 'generator'
+        tag (str, optional): common title for all produced plots. For example, generator
     """
 
@@ -230,7 +230,7 @@ def __init__(self, model, reduction=torch.norm, tag=None):
     def __call__(self, engine, logger, event_name):
 
         if not isinstance(logger, NeptuneLogger):
-            raise RuntimeError("Handler 'WeightsScalarHandler' works only with NeptuneLogger")
+            raise RuntimeError("Handler WeightsScalarHandler works only with NeptuneLogger")
 
         global_step = engine.state.get_event_attrib_value(event_name)
         tag_prefix = "{}/".format(self.tag) if self.tag else ""
@@ -239,7 +239,7 @@ def __call__(self, engine, logger, event_name):
                 continue
 
             name = name.replace(".", "/")
-            logger.experiment.log_metric(
+            logger.log_metric(
                 "{}weights_{}/{}".format(tag_prefix, self.reduction.__name__, name),
                 x=global_step,
                 y=self.reduction(p.data),
@@ -275,7 +275,7 @@ class GradsScalarHandler(BaseWeightsScalarHandler):
     Args:
         model (torch.nn.Module): model to log weights
         reduction (callable): function to reduce parameters into scalar
-        tag (str, optional): common title for all produced plots. For example, 'generator'
+        tag (str, optional): common title for all produced plots. For example, generator
     """
 
@@ -284,7 +284,7 @@ def __init__(self, model, reduction=torch.norm, tag=None):
 
     def __call__(self, engine, logger, event_name):
         if not isinstance(logger, NeptuneLogger):
-            raise RuntimeError("Handler 'GradsScalarHandler' works only with NeptuneLogger")
+            raise RuntimeError("Handler GradsScalarHandler works only with NeptuneLogger")
 
         global_step = engine.state.get_event_attrib_value(event_name)
         tag_prefix = "{}/".format(self.tag) if self.tag else ""
@@ -293,7 +293,7 @@ def __call__(self, engine, logger, event_name):
                 continue
 
             name = name.replace(".", "/")
-            logger.experiment.log_metric(
+            logger.log_metric(
                 "{}grads_{}/{}".format(tag_prefix, self.reduction.__name__, name),
                 x=global_step,
                 y=self.reduction(p.grad),
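
WeightsScalarHandler and GradsScalarHandler are updated in lockstep: each reduces a parameter tensor (or its gradient) to a single number and now logs it straight through the logger. A sketch of attaching both, again assuming the `npt_logger`, `trainer`, and `model` objects from the surrounding examples:

import torch

from ignite.contrib.handlers.neptune_logger import GradsScalarHandler, WeightsScalarHandler
from ignite.engine import Events

# Track the norm of every weight tensor, e.g. "weights_norm/fc/weight".
npt_logger.attach(
    trainer,
    log_handler=WeightsScalarHandler(model, reduction=torch.norm),
    event_name=Events.ITERATION_COMPLETED,
)

# Track the norm of the matching gradients, e.g. "grads_norm/fc/weight".
npt_logger.attach(
    trainer,
    log_handler=GradsScalarHandler(model, reduction=torch.norm),
    event_name=Events.ITERATION_COMPLETED,
)
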
@@ -354,7 +354,7 @@ class NeptuneLogger(BaseLogger):
         # Attach the logger to the trainer to log training loss at each iteration
         npt_logger.attach(trainer,
-                          log_handler=OutputHandler(tag="training", output_transform=lambda loss: {'loss': loss}),
+                          log_handler=OutputHandler(tag="training", output_transform=lambda loss: {"loss": loss}),
                           event_name=Events.ITERATION_COMPLETED)
 
         # Attach the logger to the evaluator on the training dataset and log NLL, Accuracy metrics after each epoch
@@ -397,11 +397,11 @@ class NeptuneLogger(BaseLogger):
         from ignite.handlers import Checkpoint
 
         def score_function(engine):
-            return engine.state.metrics['accuracy']
+            return engine.state.metrics["accuracy"]
 
-        to_save = {'model': model}
+        to_save = {"model": model}
         handler = Checkpoint(to_save, NeptuneSaver(npt_logger), n_saved=2,
-                             filename_prefix='best', score_function=score_function,
+                             filename_prefix="best", score_function=score_function,
                              score_name="validation_accuracy",
                              global_step_transform=global_step_from_engine(trainer))
         validation_evaluator.add_event_handler(Events.COMPLETED, handler)
@@ -425,11 +425,20 @@ def score_function(engine):
         # Attach the logger to the trainer to log training loss at each iteration
         npt_logger.attach(trainer,
                           log_handler=OutputHandler(tag="training",
-                                                    output_transform=lambda loss: {'loss': loss}),
+                                                    output_transform=lambda loss: {"loss": loss}),
                           event_name=Events.ITERATION_COMPLETED)
     """
 
+    def __getattr__(self, attr):
+
+        import neptune
+
+        def wrapper(*args, **kwargs):
+            return getattr(neptune, attr)(*args, **kwargs)
+
+        return wrapper
+
     def __init__(self, *args, **kwargs):
         try:
             import neptune
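
The new `__getattr__` is the heart of this PR: any attribute not found on the logger itself is resolved against the `neptune` client module and wrapped in a callable, which is why the handlers above can call `logger.log_metric(...)`, NeptuneSaver can call `log_artifact`/`delete_artifacts`, and `close()` can call `self.stop()`. Because `__getattr__` only runs when normal lookup fails, attributes set in `__init__` such as `self.experiment` are unaffected. The toy class below demonstrates the same delegation pattern against the standard-library `math` module; it is an illustration, not code from the PR:

import math

class MathProxy:
    # Unknown attributes are forwarded to the `math` module, much like
    # NeptuneLogger forwards unknown attributes to the `neptune` module.
    def __getattr__(self, attr):
        def wrapper(*args, **kwargs):
            return getattr(math, attr)(*args, **kwargs)
        return wrapper

proxy = MathProxy()
print(proxy.sqrt(2))      # delegates to math.sqrt
print(proxy.hypot(3, 4))  # delegates to math.hypot
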
@@ -444,16 +453,17 @@ def __init__(self, *args, **kwargs):
             neptune.init(project_qualified_name="dry-run/project", backend=neptune.OfflineBackend())
         else:
             self.mode = "online"
-            neptune.init(api_token=kwargs["api_token"], project_qualified_name=kwargs["project_name"])
+            neptune.init(api_token=kwargs.get("api_token"), project_qualified_name=kwargs.get("project_name"))
 
         kwargs["name"] = kwargs.pop("experiment_name", None)
         self._experiment_kwargs = {
             k: v for k, v in kwargs.items() if k not in ["api_token", "project_name", "offline_mode"]
         }
 
         self.experiment = neptune.create_experiment(**self._experiment_kwargs)
 
     def close(self):
-        self.experiment.stop()
+        self.stop()
 
 
 class NeptuneSaver:
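
In the constructor hunk above, `offline_mode=True` initializes the throwaway "dry-run/project" against `neptune.OfflineBackend()` so nothing is sent anywhere, while the online branch now uses `kwargs.get(...)`, so `api_token` and `project_name` may be omitted and the Neptune client is left to fall back on its own configuration (for example environment variables). A hedged sketch of both modes, with placeholder names:

from ignite.contrib.handlers.neptune_logger import NeptuneLogger

# Offline: useful for tests and CI; metrics go to neptune.OfflineBackend().
offline_logger = NeptuneLogger(offline_mode=True, experiment_name="debug-run")
offline_logger.close()

# Online: anonymous token from the examples, placeholder project name.
online_logger = NeptuneLogger(
    api_token="ANONYMOUS",
    project_name="shared/pytorch-ignite-integration",
    experiment_name="mnist",
)
online_logger.close()  # forwards to neptune.stop() through __getattr__
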
@@ -486,20 +496,20 @@ class NeptuneSaver:
         from ignite.handlers import Checkpoint
 
         def score_function(engine):
-            return engine.state.metrics['accuracy']
+            return engine.state.metrics["accuracy"]
 
-        to_save = {'model': model}
+        to_save = {"model": model}
 
         # pass neptune logger to NeptuneServer
         handler = Checkpoint(to_save, NeptuneSaver(npt_logger), n_saved=2,
-                             filename_prefix='best', score_function=score_function,
+                             filename_prefix="best", score_function=score_function,
                              score_name="validation_accuracy",
                              global_step_transform=global_step_from_engine(trainer))
 
         evaluator.add_event_handler(Events.COMPLETED, handler)
 
         # We need to close the logger when we are done
         npt_logger.close()
 
     For example, you can access model checkpoints and download them from here:
@@ -508,13 +518,13 @@ def score_function(engine):
     """
 
     def __init__(self, neptune_logger: NeptuneLogger):
-        self._experiment = neptune_logger.experiment
+        self._logger = neptune_logger
 
     def __call__(self, checkpoint: Mapping, filename: str) -> None:
 
         with tempfile.NamedTemporaryFile() as tmp:
             torch.save(checkpoint, tmp.name)
-            self._experiment.log_artifact(tmp.name, filename)
+            self._logger.log_artifact(tmp.name, filename)
 
     def remove(self, filename: str) -> None:
-        self._experiment.delete_artifacts(filename)
+        self._logger.delete_artifacts(filename)
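
NeptuneSaver now keeps the logger itself and rides on the same delegation for `log_artifact` and `delete_artifacts`: a checkpoint is serialized to a temporary file with `torch.save`, and that file is uploaded under the checkpoint's filename. The sketch below isolates that save path with a hypothetical stand-in logger so it runs without a Neptune account; it is not code from the PR:

import tempfile

import torch

class PrintingLogger:
    # Stand-in exposing the two methods NeptuneSaver relies on.
    def log_artifact(self, path, destination):
        print("would upload {} as {}".format(path, destination))

    def delete_artifacts(self, name):
        print("would delete {}".format(name))

# Same save path as NeptuneSaver.__call__: serialize to a temp file, then upload.
def save_checkpoint(logger, checkpoint, filename):
    with tempfile.NamedTemporaryFile() as tmp:
        torch.save(checkpoint, tmp.name)
        logger.log_artifact(tmp.name, filename)

save_checkpoint(PrintingLogger(), {"weights": torch.zeros(2, 2)}, "best_model_1.pth")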