forked from pykeen/pykeen
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tensorboard.py
78 lines (63 loc) · 2.72 KB
/
tensorboard.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# -*- coding: utf-8 -*-
"""An adapter for TensorBoard."""
import datetime
import pathlib
from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, Type
from .base import ResultTracker
from ..constants import PYKEEN_LOGS
from ..utils import flatten_dictionary
if TYPE_CHECKING:
import torch.utils.tensorboard
__all__ = [
'TensorBoardResultTracker',
]
class TensorBoardResultTracker(ResultTracker):
    """A tracker for TensorBoard.

    Metrics are written as scalars and parameters as text entries through a
    ``torch.utils.tensorboard.SummaryWriter`` that is created in
    :meth:`start_run` and closed in :meth:`end_run`.
    """

    #: The class that's used to instantiate a summarywriter
    SummaryWriter: Type['torch.utils.tensorboard.SummaryWriter']

    def __init__(
        self,
        experiment_path: Optional[str] = None,
        experiment_name: Optional[str] = None,
        tags: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize result tracking via Tensorboard.

        :param experiment_path:
            The experiment path. A custom path at which the tensorboard logs will be saved. Anything accepted by
            :class:`pathlib.Path` (e.g. a string or an existing path object) may be passed.
        :param experiment_name:
            The name of the experiment, will be used as a sub directory name for the logging. If no name is given,
            the current time is used. If experiment_path is set, this argument has no effect.
        :param tags:
            The additional run details which are presented as tags to be logged
        """
        # Imported inside the constructor (the module-level import is guarded by TYPE_CHECKING),
        # so importing this module does not itself import torch's tensorboard package.
        from torch.utils.tensorboard import SummaryWriter as _SummaryWriter

        # ``start_run`` instantiates ``self.SummaryWriter``; the class-level annotation alone does
        # not create the attribute, so it has to be assigned here.
        self.SummaryWriter = _SummaryWriter
        # Kept as an alias for code that reads the previous attribute name.
        self.summary_writer_cls = _SummaryWriter
        # Stored on the instance; not consumed by any method of this class.
        self.tags = tags

        if experiment_path is None:
            if experiment_name is None:
                experiment_name = datetime.datetime.now().isoformat()
            path = PYKEEN_LOGS.joinpath("tensorboard", experiment_name)
        else:
            # Normalise every explicit path (str or path object) instead of only handling ``str``,
            # which previously left ``path`` unbound for other inputs.
            path = pathlib.Path(experiment_path)
        self.path = path

    def start_run(self, run_name: Optional[str] = None) -> None:
        """Create the summary writer that logs into ``self.path``.

        :param run_name:
            Forwarded as the writer's ``comment``.
            NOTE(review): per the PyTorch documentation ``comment`` has no effect once ``log_dir`` is given - confirm
            whether the run name is meant to show up anywhere.
        """
        self.writer = self.SummaryWriter(log_dir=self.path, comment=run_name)

    def log_metrics(
        self,
        metrics: Mapping[str, float],
        step: Optional[int] = None,
        prefix: Optional[str] = None,
    ) -> None:
        """Write each (flattened) metric as a scalar and flush the writer.

        :param metrics:
            The metrics to log; passed through ``flatten_dictionary`` together with ``prefix`` first.
        :param step:
            Forwarded as ``global_step`` of every scalar.
        :param prefix:
            Forwarded to ``flatten_dictionary``.
        """
        metrics = flatten_dictionary(dictionary=metrics, prefix=prefix)
        for key, value in metrics.items():
            self.writer.add_scalar(tag=key, scalar_value=value, global_step=step)
        self.writer.flush()

    def log_params(self, params: Mapping[str, Any], prefix: Optional[str] = None) -> None:
        """Write each (flattened) parameter as a text entry and flush the writer.

        :param params:
            The parameters to log; passed through ``flatten_dictionary`` together with ``prefix`` first.
        :param prefix:
            Forwarded to ``flatten_dictionary``.
        """
        params = flatten_dictionary(dictionary=params, prefix=prefix)
        for key, value in params.items():
            # Keys and values are stringified because they may be arbitrary objects.
            self.writer.add_text(tag=str(key), text_string=str(value))
        self.writer.flush()

    def end_run(self) -> None:
        """Flush pending events and close the summary writer."""
        self.writer.flush()
        self.writer.close()