From 8fa3efee1affad2b5ca3e5d8e8c0b7e100654d60 Mon Sep 17 00:00:00 2001 From: Janne Lappalainen <34949352+lappalainenj@users.noreply.github.com> Date: Wed, 26 Jan 2022 15:34:39 +0100 Subject: [PATCH] Plot functions for tensorboard data (#593) * added vscode code-workspace to gitignore * Plot functions for tensorboard data #586 * fixed minor style issues * adding test for plot_summary --- .gitignore | 1 + sbi/inference/base.py | 8 ++ sbi/utils/__init__.py | 1 + sbi/utils/tensorboard_output.py | 212 ++++++++++++++++++++++++++++++++ tests/plot_test.py | 52 ++++++++ 5 files changed, 274 insertions(+) create mode 100644 sbi/utils/tensorboard_output.py create mode 100644 tests/plot_test.py diff --git a/.gitignore b/.gitignore index 2143160fb..dafb4a21b 100644 --- a/.gitignore +++ b/.gitignore @@ -89,6 +89,7 @@ target/ # IDEs (VSCode, etc.) .env +*.code-workspace # Class diagram *.pyns diff --git a/sbi/inference/base.py b/sbi/inference/base.py index 1b4666939..1c18d856a 100644 --- a/sbi/inference/base.py +++ b/sbi/inference/base.py @@ -441,6 +441,14 @@ def _summarize( Statistics are extracted from the arguments and from entries in self._summary created during training. + + Scalar tags: + - median_observation_distances + - epochs_trained + - best_validation_log_prob + - validation_log_probs_across_rounds + - train_log_probs_across_rounds + - epoch_durations_sec_across_rounds """ # NB. This is a subset of the logging as done in `GH:conormdurkan/lfi`. 
A big diff --git a/sbi/utils/__init__.py b/sbi/utils/__init__.py index d66c11fb9..a3f4dc7a5 100644 --- a/sbi/utils/__init__.py +++ b/sbi/utils/__init__.py @@ -68,3 +68,4 @@ ) from sbi.utils.user_input_checks_utils import MultipleIndependent from sbi.utils.potentialutils import transformed_potential, pyro_potential_wrapper +from sbi.utils.tensorboard_output import plot_summary, list_all_logs diff --git a/sbi/utils/tensorboard_output.py b/sbi/utils/tensorboard_output.py new file mode 100644 index 000000000..c5d818db3 --- /dev/null +++ b/sbi/utils/tensorboard_output.py @@ -0,0 +1,212 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . +"""Utils for processing tensorboard event data.""" +import inspect +import logging +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from tensorboard.backend.event_processing.event_accumulator import ( + DEFAULT_SIZE_GUIDANCE, + EventAccumulator, +) + +from sbi.utils.plot import _get_default_opts +from sbi.utils.io import get_log_root +import sbi.inference.base + +# creating an alias for annotating, because sbi.inference.base.NeuralInference creates +# a circular import error +_NeuralInference = Any + + +def plot_summary( + inference: Union[_NeuralInference, Path], + tags: List[str] = ["validation_log_probs_across_rounds"], + disable_tensorboard_prompt: bool = False, + tensorboard_scalar_limit: int = 10_000, + figsize: List[int] = [20, 6], + fontsize: float = 12, + fig: Optional[Figure] = None, + axes: Optional[Axes] = None, + xlabel: str = "epochs", + ylabel: List[str] = [], + plot_kwargs: Dict[str, Any] = {}, +) -> Tuple[Figure, Axes]: + """Plots data logged by the tensorboard summary writer of an inference object. 
+ + Args: + inference: inference object that holds a ._summary_writer.log_dir attribute. + Optionally the log_dir itself. + tags: list of summery writer tags to visualize. + disable_tensorboard_prompt: flag to disable the logging of how to run + tensorboard and valid tags. Default is False. + tensorboard_scalar_limit: overriding DEFAULT_SIZE_GUIDANCE. + figsize: determines the figure size. Defaults is [6, 6]. + fontsize: determines the fontsize of axes ticks and labels. Default is 12. + fig: optional existing figure instance. + axes: optional existing axes instance. + xlabel: x-axis label describing 'steps' attribute of tensorboards ScalarEvent. + ylabel: list of alternative ylabels for items in tags. Optional. + plot_kwargs: will be passed to ax.plot. + + Returns a tuple of Figure and Axes objects. + """ + logger = logging.getLogger(__name__) + + size_guidance = deepcopy(DEFAULT_SIZE_GUIDANCE) + size_guidance.update(scalars=tensorboard_scalar_limit) + + if isinstance(inference, sbi.inference.NeuralInference): + log_dir = inference._summary_writer.log_dir + elif isinstance(inference, Path): + log_dir = inference + else: + raise ValueError(f"inference {inference}") + + all_event_data = _get_event_data_from_log_dir(log_dir, size_guidance) + scalars = all_event_data["scalars"] + + if not disable_tensorboard_prompt: + logger.warning( + ( + "For an interactive, detailed view of the summary, launch tensorboard " + f" with 'tensorboard --logdir={log_dir}' from a" + " terminal on your machine, visit http://127.0.0.1:6006 afterwards." + " Requires port forwarding if tensorboard runs on a remote machine, as" + " e.g. https://stackoverflow.com/a/42445070/7770835 explains.\n" + ) + ) + logger.warning(f"Valid tags are: {sorted(list(scalars.keys()))}.") + + _check_tags(scalars, tags) + + if len(scalars[tags[0]]["step"]) == tensorboard_scalar_limit: + logger.warning( + ( + "Event data as large as the chosen limit for tensorboard scalars." 
+ "Tensorboard might be subsampling your data, as " + "https://stackoverflow.com/a/65564389/7770835 explains." + " Consider increasing tensorboard_scalar_limit to see all data.\n" + ) + ) + + plot_options = _get_default_opts() + + plot_options.update(figsize=figsize, fontsize=fontsize) + if fig is None or axes is None: + fig, axes = plt.subplots( + 1, len(tags), figsize=plot_options["figsize"], **plot_options["subplots"] + ) + axes = np.atleast_1d(axes) + + ylabel = ylabel or tags + + for i, ax in enumerate(axes): + ax.plot(scalars[tags[i]]["step"], scalars[tags[i]]["value"], **plot_kwargs) + + ax.set_ylabel(ylabel[i], fontsize=fontsize) + ax.set_xlabel(xlabel, fontsize=fontsize) + ax.xaxis.set_tick_params(labelsize=fontsize) + ax.yaxis.set_tick_params(labelsize=fontsize) + + plt.subplots_adjust(wspace=0.3) + + return fig, axes + + +def list_all_logs(inference: _NeuralInference) -> List: + """Returns a list of all log dirs for an inference class.""" + method = inference.__class__.__name__ + log_dir = Path(get_log_root()) / method + return sorted(log_dir.iterdir()) + + +def _get_event_data_from_log_dir( + log_dir: Union[str, Path], size_guidance=DEFAULT_SIZE_GUIDANCE +) -> Dict[str, Dict[str, Dict[str, List[Any]]]]: + """All event data stored by tensorboards summary writer as nested dictionary. + + The event data is stripped off from their native tensorboard event types and + represented in a tabular way, i.e. Dict[str, List]. + + The hierarchy of the dictionary is: + 1. tag type: event types that can be logged with tensorboard like 'scalars', + 'images', 'histograms', etc. + 2. tag: tag for the event type that the user of the SummaryWriter specifies. + 3. tag type attribute: attribute of the event. + + Args: + log_dir: log dir of a tensorboard summary writer. + size_guidance: to avoid causing out of memory erros by loading too much data at + once into memory. Defaults to tensorboards default size_guidance. 
Returns a nested, exhaustive dictionary of all event data under log_dir.
" + f"Valid tags are: {list(adict.keys())}." + ) + + +def _remove_all_logs(path: Path) -> None: + """Removes all logs in path/sbi-logs.""" + if (path / "sbi-logs").exists(): + import shutil + + shutil.rmtree(path / "sbi-logs") diff --git a/tests/plot_test.py b/tests/plot_test.py new file mode 100644 index 000000000..cb20144d8 --- /dev/null +++ b/tests/plot_test.py @@ -0,0 +1,52 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . + +import pytest +import torch +from torch.utils.tensorboard import SummaryWriter + +from sbi.inference import ( + SNLE, + SNPE, + SNRE, + prepare_for_sbi, + simulate_for_sbi, +) +from sbi import utils +from matplotlib.figure import Figure +from matplotlib.axes import Axes + + +def test_plot_summary(tmp_path): + num_dim = 1 + prior = utils.BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim)) + + summary_writer = SummaryWriter(tmp_path) + + def linear_gaussian(theta): + return theta + 1.0 + torch.randn_like(theta) * 0.1 + + simulator, prior = prepare_for_sbi(linear_gaussian, prior) + + # SNPE + inference = SNPE(prior=prior, summary_writer=summary_writer) + theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=5) + _ = inference.append_simulations(theta, x).train(max_num_epochs=1) + fig, axes = utils.plot_summary(inference) + assert isinstance(fig, Figure) and isinstance(axes[0], Axes) + + # SNLE + inference = SNLE(prior=prior, summary_writer=summary_writer) + theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=5) + _ = inference.append_simulations(theta, x).train(max_num_epochs=1) + fig, axes = utils.plot_summary(inference) + assert isinstance(fig, Figure) and isinstance(axes[0], Axes) + + # SNRE + inference = SNRE(prior=prior, summary_writer=summary_writer) + theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=6) + _ = inference.append_simulations(theta, x).train( + num_atoms=2, 
max_num_epochs=5, validation_fraction=0.5 + ) + fig, axes = utils.plot_summary(inference) + assert isinstance(fig, Figure) and isinstance(axes[0], Axes)