-
-
Notifications
You must be signed in to change notification settings - Fork 142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature plots description #116
Changes from 6 commits
d2292eb
8c4bd7d
22a119c
4a676c0
7df997b
1d508d6
eba7dad
6202511
9d7a4c8
3b498e9
3339586
77fd9f4
441a4b7
88baaca
f26f911
fed38ff
8bd10f7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
import re | ||
from collections import OrderedDict | ||
from typing import NamedTuple, Dict, List, Pattern, Tuple, Optional | ||
from typing import NamedTuple, Dict, List, Pattern, Tuple, Optional, Union | ||
|
||
# Value of metrics - for value later, we want to support numpy arrays etc | ||
LogItem = NamedTuple('LogItem', [('step', int), ('value', float)]) | ||
|
@@ -24,10 +24,11 @@ def __init__( | |
current_step: int = -1, | ||
auto_generate_groups_if_not_available: bool = True, | ||
auto_generate_metric_to_name: bool = True, | ||
group_patterns: List[Tuple[str, str]] = [ | ||
group_patterns: List[Tuple[Pattern, str]] = [ | ||
(r'^(?!val(_|-))(.*)', 'training '), | ||
(r'^(val(_|-))(.*)', 'validation '), | ||
] | ||
], | ||
step_names: Union[str, Dict[str, str]] = 'epoch' | ||
): | ||
""" | ||
:param groups - dictionary with grouped metrics for example one group can contain | ||
|
@@ -39,6 +40,7 @@ def __init__( | |
based on common shortcuts: | ||
:param group_patterns - you can put there regular expressions to match a few metric names with group | ||
and replace its name using second value: | ||
:param step_names list names of x axis for each metrics group or one name for all metrics: | ||
""" | ||
self.log_history = {} | ||
self.groups = groups if groups is not None else {} | ||
|
@@ -47,6 +49,7 @@ def __init__( | |
self.auto_generate_groups = all((not groups, auto_generate_groups_if_not_available)) | ||
self.auto_generate_metric_to_name = auto_generate_metric_to_name | ||
self.group_patterns = tuple((re.compile(pattern), replace_with) for pattern, replace_with in group_patterns) | ||
self._step_names = step_names | ||
|
||
def update(self, logs: dict, current_step: Optional[int] = None) -> None: | ||
"""Update logs - loop step can be controlled outside or inside main logger""" | ||
|
@@ -158,3 +161,9 @@ def log_history(self, value: Dict[str, List[LogItem]]) -> None: | |
if len(value) > 0: | ||
raise RuntimeError('Cannot overwrite log history with non empty dictionary') | ||
self._log_history = value | ||
|
||
def step_name(self, group_name: str) -> str: | ||
if isinstance(self._step_names, str): | ||
return self._step_names | ||
else: | ||
return self._step_names.get(group_name, 'epoch') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is OK, but I think it would be more idiomatic to use DefaultDict (but it can stay as it is). |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from typing import Tuple, List, Dict, Optional | ||
from typing import Tuple, List, Dict, Optional, Callable | ||
|
||
import warnings | ||
|
||
|
@@ -17,44 +17,75 @@ def __init__( | |
max_cols: int = 2, | ||
max_epoch: int = None, | ||
skip_first: int = 2, | ||
extra_plots=[], | ||
figpath: Optional[str] = None | ||
extra_plots: List[Callable[[MainLogger], None]] = [], | ||
figpath: Optional[str] = None, | ||
after_subplot: Optional[Callable[[str, str], None]] = None, | ||
before_plots: Optional[Callable[[int], None]] = None, | ||
after_plots: Optional[Callable[[], None]] = None, | ||
): | ||
""" | ||
:param cell_size size of one chart: | ||
:param max_cols maximal number of charts in one row: | ||
:param max_epoch maximum epoch on x axis: | ||
:param skip_first number of first steps to skip: | ||
:param extra_plots extra charts functions: | ||
:param figpath path to save figure: | ||
:param after_subplot function which will be called after every subplot: | ||
:param before_plots function which will be called before all subplots: | ||
:param after_plots function which will be called after all subplots: | ||
""" | ||
self.cell_size = cell_size | ||
self.max_cols = max_cols | ||
self.max_epoch = max_epoch | ||
self.skip_first = skip_first # think about it | ||
self.extra_plots = extra_plots | ||
self.max_epoch = max_epoch | ||
self.figpath = figpath | ||
self.file_idx = 0 # now only for saving files | ||
self._after_subplot = after_subplot if after_subplot else self._default_after_subplot | ||
self._before_plots = before_plots if before_plots else self._default_before_plots | ||
self._after_plots = after_plots if after_plots else self._default_after_plots | ||
Comment on lines
+44
to
+46
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice pattern |
||
|
||
def send(self, logger: MainLogger): | ||
"""Draw figures with metrics and show""" | ||
log_groups = logger.grouped_log_history() | ||
figsize_x = self.max_cols * self.cell_size[0] | ||
figsize_y = ((len(log_groups) + 1) // self.max_cols + 1) * self.cell_size[1] | ||
|
||
max_rows = (len(log_groups) + len(self.extra_plots) + 1) // self.max_cols + 1 | ||
|
||
clear_output(wait=True) | ||
plt.figure(figsize=(figsize_x, figsize_y)) | ||
self._before_plots(len(log_groups)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
for group_id, (group_name, group_logs) in enumerate(log_groups.items()): | ||
plt.subplot(max_rows, self.max_cols, group_id + 1) | ||
self._draw_metric_subplot(group_logs, group_name=group_name) | ||
for group_idx, (group_name, group_logs) in enumerate(log_groups.items()): | ||
plt.subplot(max_rows, self.max_cols, group_idx + 1) | ||
self._draw_metric_subplot(group_logs, group_name=group_name, x_label=logger.step_name(group_name)) | ||
|
||
for i, extra_plot in enumerate(self.extra_plots): | ||
plt.subplot(max_rows, self.max_cols, i + len(log_groups) + 1) | ||
for idx, extra_plot in enumerate(self.extra_plots): | ||
plt.subplot(max_rows, self.max_cols, idx + len(log_groups) + 1) | ||
extra_plot(logger) | ||
|
||
plt.tight_layout() | ||
self._after_plots() | ||
if self.figpath is not None: | ||
plt.savefig(self.figpath.format(i=self.file_idx)) | ||
self.file_idx += 1 | ||
|
||
plt.show() | ||
|
||
def _draw_metric_subplot(self, group_logs: Dict[str, List[LogItem]], group_name: str = ''): | ||
def _default_after_subplot(self, group_name: str, x_label: str): | ||
"""Add title xlabel and legend to single chart""" | ||
plt.title(group_name) | ||
plt.xlabel(x_label) | ||
plt.legend(loc='center right') | ||
|
||
def _default_before_plots(self, num_of_log_groups: int) -> None: | ||
"""Set matplotlib window properties""" | ||
figsize_x = self.max_cols * self.cell_size[0] | ||
figsize_y = ((num_of_log_groups + 1) // self.max_cols + 1) * self.cell_size[1] | ||
plt.figure(figsize=(figsize_x, figsize_y)) | ||
|
||
def _default_after_plots(self): | ||
"""Set properties after charts creation""" | ||
plt.tight_layout() | ||
|
||
def _draw_metric_subplot(self, group_logs: Dict[str, List[LogItem]], group_name: str, x_label: str): | ||
# there used to be skip first part, but I skip it first | ||
if self.max_epoch is not None: | ||
plt.xlim(0, self.max_epoch) | ||
|
@@ -65,9 +96,7 @@ def _draw_metric_subplot(self, group_logs: Dict[str, List[LogItem]], group_name: | |
ys = [log.value for log in logs] | ||
plt.plot(xs, ys, label=name) | ||
|
||
plt.title(group_name) | ||
plt.xlabel('epoch') | ||
plt.legend(loc='center right') | ||
self._after_subplot(group_name, x_label) | ||
|
||
def _not_inline_warning(self): | ||
backend = matplotlib.get_backend() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we shouldn't have space after (e.g.
'training'
not'training '
).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done