forked from pykeen/pykeen
-
Notifications
You must be signed in to change notification settings - Fork 0
/
wandb.py
90 lines (71 loc) · 2.73 KB
/
wandb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# -*- coding: utf-8 -*-
"""An adapter for Weights and Biases."""
import os
from typing import TYPE_CHECKING, Any, Mapping, Optional
from .base import ResultTracker
from ..utils import flatten_dictionary
if TYPE_CHECKING:
    # Needed only for the string annotation on ``WANDBResultTracker.run``; the
    # runtime import of wandb happens lazily inside the tracker's constructor.
    import wandb.wandb_run

# Public API of this module.
__all__ = ["WANDBResultTracker"]
class WANDBResultTracker(ResultTracker):
    """A tracker for Weights and Biases.

    Note that you have to perform ``wandb login`` beforehand.
    """

    #: The WANDB run; ``None`` before :meth:`start_run` and after :meth:`end_run`.
    run: Optional["wandb.wandb_run.Run"]

    def __init__(
        self,
        project: str,
        offline: bool = False,
        **kwargs,
    ):
        """Initialize result tracking via WANDB.

        :param project:
            project name your WANDB login has access to.
        :param offline:
            whether to run in offline mode, i.e., without syncing with the wandb server.
            This is forwarded to :func:`wandb.init` as ``mode="offline"`` unless ``kwargs``
            already contains an explicit ``mode``.
        :param kwargs:
            additional keyword arguments passed to :func:`wandb.init`. A truthy
            ``allow_val_change`` is additionally used by :meth:`log_params` when updating
            the run config.

        :raises ValueError:
            If the project name is given as None
        """
        # Imported lazily so that wandb is only required when this tracker is actually used.
        import wandb as _wandb

        self.wandb = _wandb
        if project is None:
            raise ValueError("Weights & Biases requires a project name.")
        self.project = project
        # Always present so that log_params can read it back.
        kwargs.setdefault("allow_val_change", None)
        if offline:
            # Scope offline mode to the runs started by this tracker. Setting the
            # process-wide WANDB_MODE environment variable would persist after this
            # tracker and affect any other wandb run in the same process.
            kwargs.setdefault("mode", "offline")
        self.kwargs = kwargs
        self.run = None

    def _require_active_run(self, what: str) -> "wandb.wandb_run.Run":
        """Return the active run, raising if :meth:`start_run` has not been called."""
        if self.run is None:
            raise AssertionError(f"start_run must be called before logging any {what}")
        return self.run

    # docstr-coverage: inherited
    def start_run(self, run_name: Optional[str] = None) -> None:  # noqa: D102
        self.run = self.wandb.init(project=self.project, name=run_name, **self.kwargs)  # type: ignore

    # docstr-coverage: inherited
    def end_run(self, success: bool = True) -> None:  # noqa: D102
        # Ending when no run is active (never started, or already ended) is a no-op
        # instead of an AttributeError on ``None``.
        if self.run is None:
            return
        # The exit code reports the outcome of the run to wandb: 0 on success, -1 otherwise.
        self.run.finish(exit_code=0 if success else -1)
        self.run = None

    # docstr-coverage: inherited
    def log_metrics(
        self,
        metrics: Mapping[str, float],
        step: Optional[int] = None,
        prefix: Optional[str] = None,
    ) -> None:  # noqa: D102
        run = self._require_active_run("metrics")
        metrics = flatten_dictionary(dictionary=metrics, prefix=prefix)
        run.log(metrics, step=step)

    # docstr-coverage: inherited
    def log_params(self, params: Mapping[str, Any], prefix: Optional[str] = None) -> None:  # noqa: D102
        run = self._require_active_run("parameters")
        params = flatten_dictionary(dictionary=params, prefix=prefix)
        # Overwriting existing config values is only permitted when the user asked for it
        # via ``allow_val_change``; otherwise wandb's default (False) applies.
        run.config.update(params, allow_val_change=bool(self.kwargs.get("allow_val_change")))