Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature plots description #116

Merged
merged 17 commits into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions livelossplot/main_logger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from collections import OrderedDict
from typing import NamedTuple, Dict, List, Pattern, Tuple, Optional
from typing import NamedTuple, Dict, List, Pattern, Tuple, Optional, Union

# Value of metrics - for value later, we want to support numpy arrays etc
LogItem = NamedTuple('LogItem', [('step', int), ('value', float)])
Expand All @@ -24,10 +24,11 @@ def __init__(
current_step: int = -1,
auto_generate_groups_if_not_available: bool = True,
auto_generate_metric_to_name: bool = True,
group_patterns: List[Tuple[str, str]] = [
group_patterns: List[Tuple[Pattern, str]] = [
(r'^(?!val(_|-))(.*)', 'training '),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we shouldn't have space after (e.g. 'training' not 'training ').

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

(r'^(val(_|-))(.*)', 'validation '),
]
],
step_names: Union[str, Dict[str, str]] = 'epoch'
):
"""
:param groups - dictionary with grouped metrics for example one group can contain
Expand All @@ -39,6 +40,7 @@ def __init__(
based on common shortcuts:
:param group_patterns - you can put there regular expressions to match a few metric names with group
and replace its name using second value:
:param step_names list names of x axis for each metrics group or one name for all metrics:
"""
self.log_history = {}
self.groups = groups if groups is not None else {}
Expand All @@ -47,6 +49,7 @@ def __init__(
self.auto_generate_groups = all((not groups, auto_generate_groups_if_not_available))
self.auto_generate_metric_to_name = auto_generate_metric_to_name
self.group_patterns = tuple((re.compile(pattern), replace_with) for pattern, replace_with in group_patterns)
self._step_names = step_names

def update(self, logs: dict, current_step: Optional[int] = None) -> None:
"""Update logs - loop step can be controlled outside or inside main logger"""
Expand Down Expand Up @@ -158,3 +161,9 @@ def log_history(self, value: Dict[str, List[LogItem]]) -> None:
if len(value) > 0:
raise RuntimeError('Cannot overwrite log history with non empty dictionary')
self._log_history = value

def step_name(self, group_name: str) -> str:
if isinstance(self._step_names, str):
return self._step_names
else:
return self._step_names.get(group_name, 'epoch')
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is OK, but I think it would be more idiomatic to use DefaultDict (but it can stay as it is).

63 changes: 46 additions & 17 deletions livelossplot/outputs/matplotlib_plot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, List, Dict, Optional
from typing import Tuple, List, Dict, Optional, Callable

import warnings

Expand All @@ -17,44 +17,75 @@ def __init__(
max_cols: int = 2,
max_epoch: int = None,
skip_first: int = 2,
extra_plots=[],
figpath: Optional[str] = None
extra_plots: List[Callable[[MainLogger], None]] = [],
figpath: Optional[str] = None,
after_subplot: Optional[Callable[[str, str], None]] = None,
before_plots: Optional[Callable[[int], None]] = None,
after_plots: Optional[Callable[[], None]] = None,
):
"""
:param cell_size size of one chart:
:param max_cols maximal number of charts in one row:
:param max_epoch maximum epoch on x axis:
:param skip_first number of first steps to skip:
:param extra_plots extra charts functions:
:param figpath path to save figure:
:param after_subplot function which will be called after every subplot:
:param before_plots function which will be called before all subplots:
:param after_plots function which will be called after all subplots:
"""
self.cell_size = cell_size
self.max_cols = max_cols
self.max_epoch = max_epoch
self.skip_first = skip_first # think about it
self.extra_plots = extra_plots
self.max_epoch = max_epoch
self.figpath = figpath
self.file_idx = 0 # now only for saving files
self._after_subplot = after_subplot if after_subplot else self._default_after_subplot
self._before_plots = before_plots if before_plots else self._default_before_plots
self._after_plots = after_plots if after_plots else self._default_after_plots
Comment on lines +44 to +46
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice pattern


def send(self, logger: MainLogger):
"""Draw figures with metrics and show"""
log_groups = logger.grouped_log_history()
figsize_x = self.max_cols * self.cell_size[0]
figsize_y = ((len(log_groups) + 1) // self.max_cols + 1) * self.cell_size[1]

max_rows = (len(log_groups) + len(self.extra_plots) + 1) // self.max_cols + 1

clear_output(wait=True)
plt.figure(figsize=(figsize_x, figsize_y))
self._before_plots(len(log_groups))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clear_output should be in before_plots.


for group_id, (group_name, group_logs) in enumerate(log_groups.items()):
plt.subplot(max_rows, self.max_cols, group_id + 1)
self._draw_metric_subplot(group_logs, group_name=group_name)
for group_idx, (group_name, group_logs) in enumerate(log_groups.items()):
plt.subplot(max_rows, self.max_cols, group_idx + 1)
self._draw_metric_subplot(group_logs, group_name=group_name, x_label=logger.step_name(group_name))

for i, extra_plot in enumerate(self.extra_plots):
plt.subplot(max_rows, self.max_cols, i + len(log_groups) + 1)
for idx, extra_plot in enumerate(self.extra_plots):
plt.subplot(max_rows, self.max_cols, idx + len(log_groups) + 1)
extra_plot(logger)

plt.tight_layout()
self._after_plots()
if self.figpath is not None:
plt.savefig(self.figpath.format(i=self.file_idx))
self.file_idx += 1

plt.show()

def _draw_metric_subplot(self, group_logs: Dict[str, List[LogItem]], group_name: str = ''):
def _default_after_subplot(self, group_name: str, x_label: str):
"""Add title xlabel and legend to single chart"""
plt.title(group_name)
plt.xlabel(x_label)
plt.legend(loc='center right')

def _default_before_plots(self, num_of_log_groups: int) -> None:
"""Set matplotlib window properties"""
figsize_x = self.max_cols * self.cell_size[0]
figsize_y = ((num_of_log_groups + 1) // self.max_cols + 1) * self.cell_size[1]
plt.figure(figsize=(figsize_x, figsize_y))

def _default_after_plots(self):
"""Set properties after charts creation"""
plt.tight_layout()

def _draw_metric_subplot(self, group_logs: Dict[str, List[LogItem]], group_name: str, x_label: str):
# there used to be skip first part, but I skip it first
if self.max_epoch is not None:
plt.xlim(0, self.max_epoch)
Expand All @@ -65,9 +96,7 @@ def _draw_metric_subplot(self, group_logs: Dict[str, List[LogItem]], group_name:
ys = [log.value for log in logs]
plt.plot(xs, ys, label=name)

plt.title(group_name)
plt.xlabel('epoch')
plt.legend(loc='center right')
self._after_subplot(group_name, x_label)

def _not_inline_warning(self):
backend = matplotlib.get_backend()
Expand Down
11 changes: 11 additions & 0 deletions tests/test_main_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,14 @@ def test_main_logger_autogroups():
for target_group, target_metrics in target_groups.items():
for m1, m2 in zip(sorted(grouped_log_history[target_group].keys()), sorted(target_metrics)):
assert m1 == m2


def test_main_logger_step_names():
step_names = 'iteration'
logger = MainLogger(step_names=step_names)
assert logger.step_name('Accuracy') == 'iteration'
step_names = {'Accuracy': 'evaluation', 'Loss': 'batch'}
logger = MainLogger(step_names=step_names)
assert logger.step_name('Accuracy') == 'evaluation'
assert logger.step_name('Loss') == 'batch'
assert logger.step_name('Epoch Time') == 'epoch'