Merge pull request #116 from Bartolo1024/feature-plots-description
Feature plots description
stared committed Jul 13, 2020
2 parents cd0862f + 8bd10f7 commit 2fb117d
Showing 8 changed files with 580 additions and 43 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -55,6 +55,8 @@ Look at notebook files with full working [examples](https://github.com/stared/livelossplot
- [poutyne.ipynb](https://github.com/stared/livelossplot/blob/master/examples/poutyne.ipynb) - a Poutyne callback ([Poutyne](https://poutyne.org/) is a Keras-like framework for PyTorch)
- [torchbearer.ipynb](https://github.com/stared/livelossplot/blob/master/examples/torchbearer.ipynb) - an example using the built in functionality from torchbearer ([torchbearer](https://github.com/ecs-vlc/torchbearer) is a model fitting library for PyTorch)
- [neptune.py](https://github.com/stared/livelossplot/blob/master/examples/neptune.py) and [neptune.ipynb](https://github.com/stared/livelossplot/blob/master/examples/neptune.ipynb) - a [Neptune.AI](https://neptune.ai/) integration
- [matplotlib.ipynb](https://github.com/stared/livelossplot/blob/master/examples/matplotlib.ipynb) - a Matplotlib output example
- [various_options.ipynb](https://github.com/stared/livelossplot/blob/master/examples/various_options.ipynb) - an extended API for metrics grouping and custom outputs

You can [run examples in Colab](https://colab.research.google.com/github/stared/livelossplot).

214 changes: 214 additions & 0 deletions examples/matplotlib.ipynb

Large diffs are not rendered by default.

273 changes: 273 additions & 0 deletions examples/various_options.ipynb

Large diffs are not rendered by default.

18 changes: 12 additions & 6 deletions livelossplot/main_logger.py
@@ -1,6 +1,6 @@
import re
from collections import OrderedDict
from typing import NamedTuple, Dict, List, Pattern, Tuple, Optional
from collections import OrderedDict, defaultdict
from typing import NamedTuple, Dict, List, Pattern, Tuple, Optional, Union

# Value of metrics - for value later, we want to support numpy arrays etc
LogItem = NamedTuple('LogItem', [('step', int), ('value', float)])
@@ -24,10 +24,11 @@ def __init__(
current_step: int = -1,
auto_generate_groups_if_not_available: bool = True,
auto_generate_metric_to_name: bool = True,
group_patterns: List[Tuple[str, str]] = [
(r'^(?!val(_|-))(.*)', 'training '),
(r'^(val(_|-))(.*)', 'validation '),
]
group_patterns: List[Tuple[Pattern, str]] = [
(r'^(?!val(_|-))(.*)', 'training'),
(r'^(val(_|-))(.*)', 'validation'),
],
step_names: Union[str, Dict[str, str]] = 'epoch'
):
"""
:param groups - dictionary with grouped metrics for example one group can contain
@@ -39,6 +40,7 @@ def __init__(
based on common shortcuts:
:param group_patterns - you can put there regular expressions to match a few metric names with group
and replace its name using second value:
:param step_names dictionary with a name of x axis for each metrics group or one name for all metrics:
"""
self.log_history = {}
self.groups = groups if groups is not None else {}
@@ -47,6 +49,10 @@ def __init__(
self.auto_generate_groups = all((not groups, auto_generate_groups_if_not_available))
self.auto_generate_metric_to_name = auto_generate_metric_to_name
self.group_patterns = tuple((re.compile(pattern), replace_with) for pattern, replace_with in group_patterns)
if isinstance(step_names, str):
self.step_names = defaultdict(lambda: step_names)
else:
self.step_names = defaultdict(lambda: 'epoch', step_names)

def update(self, logs: dict, current_step: Optional[int] = None) -> None:
"""Update logs - loop step can be controlled outside or inside main logger"""
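
The new `step_names` argument is exercised by `test_main_logger_step_names` further down; here is a minimal usage sketch based on the signature and the `defaultdict` fallback added above (the group names `'Accuracy'`, `'Loss'`, and `'Epoch Time'` are illustrative, not fixed by the library):

```python
# Sketch only: import path taken from livelossplot/main_logger.py in this diff.
from livelossplot.main_logger import MainLogger

# One x-axis label shared by every metrics group
logger = MainLogger(step_names='iteration')
assert logger.step_names['Accuracy'] == 'iteration'

# Per-group labels; any group not listed falls back to 'epoch'
logger = MainLogger(step_names={'Accuracy': 'evaluation', 'Loss': 'batch'})
assert logger.step_names['Accuracy'] == 'evaluation'
assert logger.step_names['Epoch Time'] == 'epoch'
```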
77 changes: 54 additions & 23 deletions livelossplot/outputs/matplotlib_plot.py
@@ -1,4 +1,4 @@
from typing import Tuple, List, Dict, Optional
from typing import Tuple, List, Dict, Optional, Callable

import warnings

@@ -17,57 +17,88 @@ def __init__(
max_cols: int = 2,
max_epoch: int = None,
skip_first: int = 2,
extra_plots=[],
figpath: Optional[str] = None
extra_plots: List[Callable[[MainLogger], None]] = [],
figpath: Optional[str] = None,
after_subplot: Optional[Callable[[plt.Axes, str, str], None]] = None,
before_plots: Optional[Callable[[plt.Figure, int], None]] = None,
after_plots: Optional[Callable[[plt.Figure], None]] = None,
):
"""
:param cell_size size of one chart:
:param max_cols maximal number of charts in one row:
:param max_epoch maximum epoch on x axis:
:param skip_first number of first steps to skip:
:param extra_plots extra charts functions:
:param figpath path to save figure:
:param after_subplot function which will be called after every subplot:
:param before_plots function which will be called before all subplots:
:param after_plots function which will be called after all subplots:
"""
self.cell_size = cell_size
self.max_cols = max_cols
self.max_epoch = max_epoch
self.skip_first = skip_first # think about it
self.extra_plots = extra_plots
self.max_epoch = max_epoch
self.figpath = figpath
self.file_idx = 0 # now only for saving files
self._after_subplot = after_subplot if after_subplot else self._default_after_subplot
self._before_plots = before_plots if before_plots else self._default_before_plots
self._after_plots = after_plots if after_plots else self._default_after_plots

def send(self, logger: MainLogger):
"""Draw figures with metrics and show"""
log_groups = logger.grouped_log_history()
figsize_x = self.max_cols * self.cell_size[0]
figsize_y = ((len(log_groups) + 1) // self.max_cols + 1) * self.cell_size[1]

max_rows = (len(log_groups) + len(self.extra_plots) + 1) // self.max_cols + 1
max_rows = (len(log_groups) + len(self.extra_plots)) // self.max_cols

clear_output(wait=True)
plt.figure(figsize=(figsize_x, figsize_y))
fig, axes = plt.subplots(max_rows, self.max_cols)
axes = axes.reshape(-1, self.max_cols)
self._before_plots(fig, len(log_groups))

for group_id, (group_name, group_logs) in enumerate(log_groups.items()):
plt.subplot(max_rows, self.max_cols, group_id + 1)
self._draw_metric_subplot(group_logs, group_name=group_name)
for group_idx, (group_name, group_logs) in enumerate(log_groups.items()):
ax = axes[group_idx // self.max_cols, group_idx % self.max_cols]
self._draw_metric_subplot(ax, group_logs, group_name=group_name, x_label=logger.step_names[group_name])

for i, extra_plot in enumerate(self.extra_plots):
plt.subplot(max_rows, self.max_cols, i + len(log_groups) + 1)
extra_plot(logger)
for idx, extra_plot in enumerate(self.extra_plots):
ax = axes[(len(log_groups) + idx) // self.max_cols, (len(log_groups) + idx) % self.max_cols]
extra_plot(ax, logger)

plt.tight_layout()
self._after_plots(fig)
if self.figpath is not None:
plt.savefig(self.figpath.format(i=self.file_idx))
fig.savefig(self.figpath.format(i=self.file_idx))
self.file_idx += 1

plt.show()

def _draw_metric_subplot(self, group_logs: Dict[str, List[LogItem]], group_name: str = ''):
def _default_after_subplot(self, ax: plt.Axes, group_name: str, x_label: str):
"""Add title xlabel and legend to single chart"""
ax.set_title(group_name)
ax.set_xlabel(x_label)
ax.legend(loc='center right')

def _default_before_plots(self, fig: plt.Figure, num_of_log_groups: int) -> None:
"""Set matplotlib window properties"""
clear_output(wait=True)
figsize_x = self.max_cols * self.cell_size[0]
figsize_y = ((num_of_log_groups + 1) // self.max_cols + 1) * self.cell_size[1]
fig.set_size_inches(figsize_x, figsize_y)

def _default_after_plots(self, fig: plt.Figure):
"""Set properties after charts creation"""
fig.tight_layout()

def _draw_metric_subplot(self, ax: plt.Axes, group_logs: Dict[str, List[LogItem]], group_name: str, x_label: str):
# a skip-first step used to live here; it is left out for now
if self.max_epoch is not None:
plt.xlim(0, self.max_epoch)
ax.set_xlim(0, self.max_epoch)

for name, logs in group_logs.items():
if len(logs) > 0:
xs = [log.step for log in logs]
ys = [log.value for log in logs]
plt.plot(xs, ys, label=name)
ax.plot(xs, ys, label=name)

plt.title(group_name)
plt.xlabel('epoch')
plt.legend(loc='center right')
self._after_subplot(ax, group_name, x_label)

def _not_inline_warning(self):
backend = matplotlib.get_backend()
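
A hedged sketch of how the new hooks could be wired up. It assumes the output class is `MatplotlibPlot` in `livelossplot.outputs` and that the `PlotLosses` front end accepts an `outputs` list; neither is shown in this diff. The callback signature follows the `Callable[[plt.Axes, str, str], None]` type hint added above:

```python
import matplotlib.pyplot as plt

from livelossplot import PlotLosses              # assumed front-end API
from livelossplot.outputs import MatplotlibPlot  # assumed module/class name

def after_subplot(ax: plt.Axes, group_name: str, x_label: str) -> None:
    # Replaces _default_after_subplot for every chart
    ax.set_title(group_name)
    ax.set_xlabel(x_label)
    ax.set_yscale('log')  # custom tweak: log-scale y axis
    ax.legend(loc='best')

plot_losses = PlotLosses(outputs=[MatplotlibPlot(after_subplot=after_subplot)])
```

Note that `extra_plots` callables now receive the target `Axes` as well as the logger (`extra_plot(ax, logger)`), so extra plots written against the old `extra_plot(logger)` signature need updating.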
8 changes: 4 additions & 4 deletions tests/test_extrema_printer.py
@@ -14,7 +14,7 @@ def test_extrema_print():
liveplot.update({'acc': 0.65, 'val_acc': 0.55, 'loss': 1.0, 'val_loss': 0.9})
liveplot.send()
assert len(plugin.extrema_cache['log-loss']) == 2
assert len(plugin.extrema_cache['log-loss']['training ']) == 3
assert plugin.extrema_cache['accuracy']['validation ']['min'] == 0.35
assert plugin.extrema_cache['accuracy']['validation ']['max'] == 0.55
assert plugin.extrema_cache['accuracy']['validation ']['current'] == 0.55
assert len(plugin.extrema_cache['log-loss']['training']) == 3
assert plugin.extrema_cache['accuracy']['validation']['min'] == 0.35
assert plugin.extrema_cache['accuracy']['validation']['max'] == 0.55
assert plugin.extrema_cache['accuracy']['validation']['current'] == 0.55
29 changes: 20 additions & 9 deletions tests/test_main_logger.py
@@ -27,8 +27,8 @@ def test_main_logger_with_groups():
grouped_log_history = logger.grouped_log_history()
assert len(grouped_log_history) == 2
assert len(grouped_log_history['acccuracy']) == 2
assert len(grouped_log_history['acccuracy']['validation ']) == 3
assert len(grouped_log_history['log-loss']['training ']) == 3
assert len(grouped_log_history['acccuracy']['validation']) == 3
assert len(grouped_log_history['log-loss']['training']) == 3


def test_main_logger_with_default_groups():
@@ -44,10 +44,10 @@ def test_main_logger_with_default_groups():
grouped_log_history = logger.grouped_log_history()
assert len(grouped_log_history) == 2
assert len(grouped_log_history['Accuracy']) == 2
assert len(grouped_log_history['Accuracy']['validation ']) == 3
assert len(grouped_log_history['Accuracy']['validation']) == 3


def test_main_logger_metrc_to_name():
def test_main_logger_metric_to_name():
"""Test group patterns"""
logger = MainLogger()
logger.update({'acc': 0.5, 'val_acc': 0.4, 'loss': 1.2, 'val_loss': 1.1, 'lr': 0.01})
@@ -56,10 +56,10 @@ def test_main_logger_metrc_to_name():
metric_to_name = logger.metric_to_name
assert 'lr' not in metric_to_name
target_metric_to_name = {
'acc': 'training ',
'val_acc': 'validation ',
'loss': 'training ',
'val_loss': 'validation ',
'acc': 'training',
'val_acc': 'validation',
'loss': 'training',
'val_loss': 'validation',
}
for metric, metric_name in metric_to_name.items():
assert metric_name == target_metric_to_name.get(metric)
@@ -72,7 +72,18 @@ def test_main_logger_autogroups():
logger.update({'acc': 0.55, 'val_acc': 0.45, 'loss': 1.1, 'val_loss': 1.0, 'lr': 0.001})
logger.update({'acc': 0.65, 'val_acc': 0.55, 'loss': 1.0, 'val_loss': 0.9, 'lr': 0.0001})
grouped_log_history = logger.grouped_log_history()
target_groups = {'Accuracy': ('validation ', 'training '), 'Loss': ('validation ', 'training '), 'lr': ('lr', )}
target_groups = {'Accuracy': ('validation', 'training'), 'Loss': ('validation', 'training'), 'lr': ('lr', )}
for target_group, target_metrics in target_groups.items():
for m1, m2 in zip(sorted(grouped_log_history[target_group].keys()), sorted(target_metrics)):
assert m1 == m2


def test_main_logger_step_names():
step_names = 'iteration'
logger = MainLogger(step_names=step_names)
assert logger.step_names['Accuracy'] == 'iteration'
step_names = {'Accuracy': 'evaluation', 'Loss': 'batch'}
logger = MainLogger(step_names=step_names)
assert logger.step_names['Accuracy'] == 'evaluation'
assert logger.step_names['Loss'] == 'batch'
assert logger.step_names['Epoch Time'] == 'epoch'
2 changes: 1 addition & 1 deletion tests/test_plot_losses.py
@@ -13,7 +13,7 @@ def send(self, logger: MainLogger):
assert len(grouped_log_history) == 2
assert len(grouped_log_history['Accuracy']) == 2
print(grouped_log_history)
assert len(grouped_log_history['Accuracy']['validation ']) == 2
assert len(grouped_log_history['Accuracy']['validation']) == 2


def test_plot_losses():
