From 8fa3efee1affad2b5ca3e5d8e8c0b7e100654d60 Mon Sep 17 00:00:00 2001
From: Janne Lappalainen <34949352+lappalainenj@users.noreply.github.com>
Date: Wed, 26 Jan 2022 15:34:39 +0100
Subject: [PATCH] Plot functions for tensorboard data (#593)
* added vscode code-workspace to gitignore
* Plot functions for tensorboard data #586
* fixed minor style issues
* adding test for plot_summary
---
.gitignore | 1 +
sbi/inference/base.py | 8 ++
sbi/utils/__init__.py | 1 +
sbi/utils/tensorboard_output.py | 212 ++++++++++++++++++++++++++++++++
tests/plot_test.py | 52 ++++++++
5 files changed, 274 insertions(+)
create mode 100644 sbi/utils/tensorboard_output.py
create mode 100644 tests/plot_test.py
diff --git a/.gitignore b/.gitignore
index 2143160fb..dafb4a21b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -89,6 +89,7 @@ target/
# IDEs (VSCode, etc.)
.env
+*.code-workspace
# Class diagram
*.pyns
diff --git a/sbi/inference/base.py b/sbi/inference/base.py
index 1b4666939..1c18d856a 100644
--- a/sbi/inference/base.py
+++ b/sbi/inference/base.py
@@ -441,6 +441,14 @@ def _summarize(
Statistics are extracted from the arguments and from entries in self._summary
created during training.
+
+ Scalar tags:
+ - median_observation_distances
+ - epochs_trained
+ - best_validation_log_prob
+ - validation_log_probs_across_rounds
+ - train_log_probs_across_rounds
+ - epoch_durations_sec_across_rounds
"""
# NB. This is a subset of the logging as done in `GH:conormdurkan/lfi`. A big
diff --git a/sbi/utils/__init__.py b/sbi/utils/__init__.py
index d66c11fb9..a3f4dc7a5 100644
--- a/sbi/utils/__init__.py
+++ b/sbi/utils/__init__.py
@@ -68,3 +68,4 @@
)
from sbi.utils.user_input_checks_utils import MultipleIndependent
from sbi.utils.potentialutils import transformed_potential, pyro_potential_wrapper
+from sbi.utils.tensorboard_output import plot_summary, list_all_logs
diff --git a/sbi/utils/tensorboard_output.py b/sbi/utils/tensorboard_output.py
new file mode 100644
index 000000000..c5d818db3
--- /dev/null
+++ b/sbi/utils/tensorboard_output.py
@@ -0,0 +1,212 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
+"""Utils for processing tensorboard event data."""
+import inspect
+import logging
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.axes import Axes
+from matplotlib.figure import Figure
+from tensorboard.backend.event_processing.event_accumulator import (
+ DEFAULT_SIZE_GUIDANCE,
+ EventAccumulator,
+)
+
+from sbi.utils.plot import _get_default_opts
+from sbi.utils.io import get_log_root
+import sbi.inference.base
+
+# creating an alias for annotating, because sbi.inference.base.NeuralInference creates
+# a circular import error
+_NeuralInference = Any
+
+
+def plot_summary(
+ inference: Union[_NeuralInference, Path],
+ tags: List[str] = ["validation_log_probs_across_rounds"],
+ disable_tensorboard_prompt: bool = False,
+ tensorboard_scalar_limit: int = 10_000,
+ figsize: List[int] = [20, 6],
+ fontsize: float = 12,
+ fig: Optional[Figure] = None,
+ axes: Optional[Axes] = None,
+ xlabel: str = "epochs",
+ ylabel: List[str] = [],
+ plot_kwargs: Dict[str, Any] = {},
+) -> Tuple[Figure, Axes]:
+ """Plots data logged by the tensorboard summary writer of an inference object.
+
+ Args:
+ inference: inference object that holds a ._summary_writer.log_dir attribute.
+ Optionally the log_dir itself.
+        tags: list of summary writer tags to visualize.
+ disable_tensorboard_prompt: flag to disable the logging of how to run
+ tensorboard and valid tags. Default is False.
+ tensorboard_scalar_limit: overriding DEFAULT_SIZE_GUIDANCE.
+        figsize: determines the figure size. Default is [20, 6].
+ fontsize: determines the fontsize of axes ticks and labels. Default is 12.
+ fig: optional existing figure instance.
+ axes: optional existing axes instance.
+        xlabel: x-axis label describing 'steps' attribute of tensorboard's ScalarEvent.
+ ylabel: list of alternative ylabels for items in tags. Optional.
+ plot_kwargs: will be passed to ax.plot.
+
+ Returns a tuple of Figure and Axes objects.
+ """
+ logger = logging.getLogger(__name__)
+
+ size_guidance = deepcopy(DEFAULT_SIZE_GUIDANCE)
+ size_guidance.update(scalars=tensorboard_scalar_limit)
+
+ if isinstance(inference, sbi.inference.NeuralInference):
+ log_dir = inference._summary_writer.log_dir
+ elif isinstance(inference, Path):
+ log_dir = inference
+ else:
+ raise ValueError(f"inference {inference}")
+
+ all_event_data = _get_event_data_from_log_dir(log_dir, size_guidance)
+ scalars = all_event_data["scalars"]
+
+ if not disable_tensorboard_prompt:
+ logger.warning(
+ (
+ "For an interactive, detailed view of the summary, launch tensorboard "
+ f" with 'tensorboard --logdir={log_dir}' from a"
+ " terminal on your machine, visit http://127.0.0.1:6006 afterwards."
+ " Requires port forwarding if tensorboard runs on a remote machine, as"
+ " e.g. https://stackoverflow.com/a/42445070/7770835 explains.\n"
+ )
+ )
+ logger.warning(f"Valid tags are: {sorted(list(scalars.keys()))}.")
+
+ _check_tags(scalars, tags)
+
+ if len(scalars[tags[0]]["step"]) == tensorboard_scalar_limit:
+ logger.warning(
+ (
+ "Event data as large as the chosen limit for tensorboard scalars."
+ "Tensorboard might be subsampling your data, as "
+ "https://stackoverflow.com/a/65564389/7770835 explains."
+ " Consider increasing tensorboard_scalar_limit to see all data.\n"
+ )
+ )
+
+ plot_options = _get_default_opts()
+
+ plot_options.update(figsize=figsize, fontsize=fontsize)
+ if fig is None or axes is None:
+ fig, axes = plt.subplots(
+ 1, len(tags), figsize=plot_options["figsize"], **plot_options["subplots"]
+ )
+ axes = np.atleast_1d(axes)
+
+ ylabel = ylabel or tags
+
+ for i, ax in enumerate(axes):
+ ax.plot(scalars[tags[i]]["step"], scalars[tags[i]]["value"], **plot_kwargs)
+
+ ax.set_ylabel(ylabel[i], fontsize=fontsize)
+ ax.set_xlabel(xlabel, fontsize=fontsize)
+ ax.xaxis.set_tick_params(labelsize=fontsize)
+ ax.yaxis.set_tick_params(labelsize=fontsize)
+
+ plt.subplots_adjust(wspace=0.3)
+
+ return fig, axes
+
+
+def list_all_logs(inference: _NeuralInference) -> List:
+ """Returns a list of all log dirs for an inference class."""
+ method = inference.__class__.__name__
+ log_dir = Path(get_log_root()) / method
+ return sorted(log_dir.iterdir())
+
+
+def _get_event_data_from_log_dir(
+ log_dir: Union[str, Path], size_guidance=DEFAULT_SIZE_GUIDANCE
+) -> Dict[str, Dict[str, Dict[str, List[Any]]]]:
+ """All event data stored by tensorboards summary writer as nested dictionary.
+
+ The event data is stripped off from their native tensorboard event types and
+ represented in a tabular way, i.e. Dict[str, List].
+
+ The hierarchy of the dictionary is:
+ 1. tag type: event types that can be logged with tensorboard like 'scalars',
+ 'images', 'histograms', etc.
+ 2. tag: tag for the event type that the user of the SummaryWriter specifies.
+ 3. tag type attribute: attribute of the event.
+
+ Args:
+ log_dir: log dir of a tensorboard summary writer.
+        size_guidance: to avoid causing out of memory errors by loading too much data at
+ once into memory. Defaults to tensorboards default size_guidance.
+
+    Returns a nested, exhaustive dictionary of all event data under log_dir.
+
+ Based on: https://stackoverflow.com/a/45899735/7770835
+ """
+
+ event_acc = _get_event_accumulator(log_dir, size_guidance)
+
+ all_event_data = {}
+ # tensorboard logs different event types, like scalars, images, histograms etc.
+ for tag_type, list_of_tags in event_acc.Tags().items():
+ all_event_data[tag_type] = {}
+
+ if list_of_tags:
+
+ for tag in list_of_tags:
+ all_event_data[tag_type][tag] = {}
+
+ # to retrieve the data from the EventAccumulator as in
+ # event_acc.Scalars('epochs_trained')
+ _getter_fn = getattr(event_acc, tag_type.capitalize())
+ data = _getter_fn(tag)
+
+ # ScalarEvent has three attributes, wall_time, step, and value
+ # a generic way to get data from all other EventType as for ScalarEvent,
+ # we inspect their argument signature. These events are namedtuples that
+ # can be found here:
+ # https://github.com/tensorflow/tensorboard/blob/b84f3738032277894c6f3fd3e011f032a89d002c/tensorboard/backend/event_processing/event_accumulator.py#L37
+ _type = type(data[0])
+ for attribute in inspect.getfullargspec(_type).args:
+ if not attribute.startswith("_"):
+ if attribute not in all_event_data[tag_type][tag]:
+ all_event_data[tag_type][tag][attribute] = []
+ for datapoint in data:
+ all_event_data[tag_type][tag][attribute].append(
+ getattr(datapoint, attribute)
+ )
+ return all_event_data
+
+
+def _get_event_accumulator(
+ log_dir: Union[str, Path], size_guidance: Dict = DEFAULT_SIZE_GUIDANCE
+) -> EventAccumulator:
+ """Returns the tensorboard EventAccumulator instance for a log dir."""
+ event_acc = EventAccumulator(str(log_dir), size_guidance=size_guidance)
+ event_acc.Reload()
+ return event_acc
+
+
+def _check_tags(adict: Dict, tags: List[str]) -> None:
+ """Checks if tags are present in a dict."""
+ for tag in tags:
+ if tag not in adict:
+ raise KeyError(
+ f"'{tag}' is not a valid tag of the tensorboard SummaryWriter. "
+ f"Valid tags are: {list(adict.keys())}."
+ )
+
+
+def _remove_all_logs(path: Path) -> None:
+ """Removes all logs in path/sbi-logs."""
+ if (path / "sbi-logs").exists():
+ import shutil
+
+ shutil.rmtree(path / "sbi-logs")
diff --git a/tests/plot_test.py b/tests/plot_test.py
new file mode 100644
index 000000000..cb20144d8
--- /dev/null
+++ b/tests/plot_test.py
@@ -0,0 +1,52 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
+
+import pytest
+import torch
+from torch.utils.tensorboard import SummaryWriter
+
+from sbi.inference import (
+ SNLE,
+ SNPE,
+ SNRE,
+ prepare_for_sbi,
+ simulate_for_sbi,
+)
+from sbi import utils
+from matplotlib.figure import Figure
+from matplotlib.axes import Axes
+
+
+def test_plot_summary(tmp_path):
+ num_dim = 1
+ prior = utils.BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
+
+ summary_writer = SummaryWriter(tmp_path)
+
+ def linear_gaussian(theta):
+ return theta + 1.0 + torch.randn_like(theta) * 0.1
+
+ simulator, prior = prepare_for_sbi(linear_gaussian, prior)
+
+ # SNPE
+ inference = SNPE(prior=prior, summary_writer=summary_writer)
+ theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=5)
+ _ = inference.append_simulations(theta, x).train(max_num_epochs=1)
+ fig, axes = utils.plot_summary(inference)
+ assert isinstance(fig, Figure) and isinstance(axes[0], Axes)
+
+ # SNLE
+ inference = SNLE(prior=prior, summary_writer=summary_writer)
+ theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=5)
+ _ = inference.append_simulations(theta, x).train(max_num_epochs=1)
+ fig, axes = utils.plot_summary(inference)
+ assert isinstance(fig, Figure) and isinstance(axes[0], Axes)
+
+ # SNRE
+ inference = SNRE(prior=prior, summary_writer=summary_writer)
+ theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=6)
+ _ = inference.append_simulations(theta, x).train(
+ num_atoms=2, max_num_epochs=5, validation_fraction=0.5
+ )
+ fig, axes = utils.plot_summary(inference)
+ assert isinstance(fig, Figure) and isinstance(axes[0], Axes)