/
datablob_metadata.py
190 lines (167 loc) · 6.64 KB
/
datablob_metadata.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""Tools for storing metadata about :py:class:`~scalarstop.datablob.DataBlob` s."""
import io
import json
import os
from typing import Any, Dict, Union

import scalarstop.pickle
from scalarstop._constants import _METADATA_JSON_FILENAME, _METADATA_PICKLE_FILENAME
from scalarstop.dataclasses import asdict
from scalarstop.exceptions import DataBlobNotFound
from scalarstop.hyperparams import HyperparamsType
class DataBlobMetadata:
    """
    Represents the metadata from a :py:class:`~scalarstop.datablob.DataBlob`
    that is saved to or loaded from the filesystem.

    When we save metadata to the filesystem, we save the same information
    in two files--one in JSON format and the other in Python Pickle format.
    The JSON file is human-readable and can be parsed by non-Python
    programs. The Pickle file is kept so that the
    :py:class:`~scalarstop.datablob.DataBlob` 's hyperparams can be
    deserialized back into their original Python objects, which a JSON
    round-trip does not guarantee (for example, tuples come back as lists).
    Only the Pickle file is read back by :py:meth:`load`.

    Two instances compare equal when their :py:meth:`to_dict` outputs are
    equal. Because instances are mutable and compare by value, they are
    unhashable.
    """

    # Defining ``__eq__`` without ``__hash__`` already makes instances
    # unhashable; spelling it out documents that this is intentional.
    __hash__ = None  # type: ignore[assignment]

    @staticmethod
    def _pickle_filename(path: str) -> str:
        """
        The exact filename of the
        :py:class:`~scalarstop.datablob.DataBlob` 's Pickle metadata file.
        """
        return os.path.join(path, _METADATA_PICKLE_FILENAME)

    @staticmethod
    def _json_filename(path: str) -> str:
        """
        The exact filename of the
        :py:class:`~scalarstop.datablob.DataBlob` 's JSON metadata file.
        """
        return os.path.join(path, _METADATA_JSON_FILENAME)

    @classmethod
    def load(
        cls,
        path: str,
    ) -> "DataBlobMetadata":
        """
        Loads metadata from a :py:class:`~scalarstop.datablob.DataBlob` 's
        directory on the filesystem.

        Only the Pickle metadata file is read; the JSON file is ignored.

        Args:
            path: The :py:class:`~scalarstop.datablob.DataBlob` directory
                on the filesystem to load the metadata from.

        Returns:
            A :py:class:`DataBlobMetadata` built from the Pickle file.

        Raises:
            DataBlobNotFound: If the Pickle metadata file does not exist
                under ``path``.
        """
        # NOTE(review): scalarstop.pickle presumably wraps Python pickle, and
        # unpickling can execute arbitrary code -- only load DataBlob
        # directories that come from a trusted source.
        try:
            with open(cls._pickle_filename(path), "rb") as fh:
                metadata = scalarstop.pickle.load(fh)
        except FileNotFoundError as exc:
            raise DataBlobNotFound(path) from exc
        return cls(
            name=metadata["name"],
            group_name=metadata["group_name"],
            # Default to 1 when these keys are missing -- presumably for
            # metadata written before these fields were recorded.
            save_load_version=metadata.get("save_load_version", 1),
            num_shards=metadata.get("num_shards", 1),
            hyperparams=metadata["hyperparams"],
        )

    @classmethod
    def from_datablob(
        cls,
        datablob: "scalarstop.datablob.DataBlob",
        *,
        save_load_version: int,
        num_shards: int,
    ) -> "DataBlobMetadata":
        """
        Creates a :py:class:`DataBlobMetadata` object in memory
        from a :py:class:`~scalarstop.datablob.DataBlob` instance.

        The ``name``, ``group_name`` and ``hyperparams`` attributes are
        copied from ``datablob``.

        Args:
            datablob: The :py:class:`~scalarstop.datablob.DataBlob` for which
                this :py:class:`DataBlobMetadata` object is being created.
            save_load_version: The protocol version used to save or load
                this :py:class:`~scalarstop.datablob.DataBlob` to/from the
                filesystem.
            num_shards: The number of shards to divide the
                :py:class:`~scalarstop.datablob.DataBlob` into when
                saving to the filesystem.

        Returns:
            A new :py:class:`DataBlobMetadata` instance.
        """
        return cls(
            name=datablob.name,
            group_name=datablob.group_name,
            hyperparams=datablob.hyperparams,
            save_load_version=save_load_version,
            num_shards=num_shards,
        )

    def __init__(
        self,
        *,
        name: str,
        group_name: str,
        hyperparams: HyperparamsType,
        save_load_version: int,
        num_shards: int,
    ):
        """
        Creates a :py:class:`DataBlobMetadata` object in memory.

        Args:
            name: The :py:class:`~scalarstop.datablob.DataBlob` name.
            group_name: The :py:class:`~scalarstop.datablob.DataBlob` group
                name.
            hyperparams: The ``Hyperparams`` object for the
                :py:class:`~scalarstop.datablob.DataBlob`.
                This has to be an instance of
                :py:class:`~scalarstop.hyperparams.HyperparamsType` and
                **not** a Python dictionary.
            save_load_version: The protocol version used to save or load
                this :py:class:`~scalarstop.datablob.DataBlob` to/from the
                filesystem.
            num_shards: The number of shards to divide the
                :py:class:`~scalarstop.datablob.DataBlob` into when
                saving to the filesystem.

        Raises:
            TypeError: If ``hyperparams`` is not a
                :py:class:`~scalarstop.hyperparams.HyperparamsType` instance.
        """
        if not isinstance(hyperparams, HyperparamsType):
            raise TypeError(
                "ScalarStop's DataBlobMetadata requires a HyperparamsType "
                "instance for the `hyperparams` parameter. You provided "
                f"the object {hyperparams} of type {type(hyperparams)}."
            )
        self.name = name
        self.group_name = group_name
        self.save_load_version = save_load_version
        self.num_shards = num_shards
        self.hyperparams = hyperparams

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"name={self.name!r}, "
            f"group_name={self.group_name!r}, "
            f"save_load_version={self.save_load_version!r}, "
            f"num_shards={self.num_shards!r}, "
            f"hyperparams={self.hyperparams!r})"
        )

    def to_dict(self, *, hyperparams_as_dict: bool = False) -> Dict[str, Any]:
        """
        Return the metadata as a Python dictionary.

        Args:
            hyperparams_as_dict: Set to ``True`` to convert a
                :py:class:`~scalarstop.hyperparams.HyperparamsType`
                object to a Python dictionary.

        Returns:
            A dictionary with the keys ``name``, ``group_name``,
            ``save_load_version``, ``num_shards`` and ``hyperparams``.
        """
        hyperparams: Union[HyperparamsType, Dict[str, Any]] = (
            asdict(self.hyperparams) if hyperparams_as_dict else self.hyperparams
        )
        return dict(
            name=self.name,
            group_name=self.group_name,
            save_load_version=self.save_load_version,
            num_shards=self.num_shards,
            hyperparams=hyperparams,
        )

    def save(self, path: str) -> None:
        """
        Save the metadata to a given
        :py:class:`~scalarstop.datablob.DataBlob` directory on
        the filesystem.

        Both the JSON and the Pickle representations are serialized in
        memory first, so a serialization failure does not leave a truncated
        or unpaired metadata file behind.

        Args:
            path: The :py:class:`~scalarstop.datablob.DataBlob` directory
                on the filesystem to save the metadata to.

        Raises:
            TypeError: If the dictionary form of the hyperparams contains
                values that :py:func:`json.dumps` cannot encode.
        """
        # json.dump() writes incrementally, so an un-encodable value used to
        # leave a half-written JSON file and no Pickle file. Build both
        # payloads before opening either file.
        json_text = json.dumps(
            self.to_dict(hyperparams_as_dict=True),
            sort_keys=True,
            indent=4,
        )
        pickle_buffer = io.BytesIO()
        scalarstop.pickle.dump(
            obj=self.to_dict(hyperparams_as_dict=False),
            file=pickle_buffer,
        )

        with open(self._json_filename(path), "w", encoding="utf-8") as json_fh:
            json_fh.write(json_text)
        with open(self._pickle_filename(path), "wb") as pickle_fh:
            pickle_fh.write(pickle_buffer.getvalue())

    def __eq__(self, other: object) -> bool:
        """
        Compare two metadata objects by their :py:meth:`to_dict` contents.

        For objects that are not :py:class:`DataBlobMetadata` instances,
        ``NotImplemented`` is returned. Python then tries the other operand's
        comparison and finally falls back to identity, which yields ``False``
        for unrelated types.
        """
        if not isinstance(other, DataBlobMetadata):
            return NotImplemented
        return self.to_dict(hyperparams_as_dict=False) == other.to_dict(
            hyperparams_as_dict=False
        )

    def __ne__(self, other: object) -> bool:
        """Inverse of :py:meth:`__eq__`, propagating ``NotImplemented``."""
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result