"""Abstract base class for the Keras neural network regressors.
The reason for this class between BaseClassifier and deep_learning classifiers is
because we can generalise tags and _predict
"""

__author__ = ["AurumnPegasus", "achieveordie"]
__all__ = ["BaseDeepRegressor"]

import os
from abc import ABC, abstractmethod

import numpy as np

from sktime.regression.base import BaseRegressor
from sktime.utils.validation._dependencies import _check_soft_dependencies


class BaseDeepRegressor(BaseRegressor, ABC):
    """Abstract base class for deep learning time series regression.

    The base regressor provides a deep learning default method for
    _predict, and provides a new abstract method for building a
    model.

    Parameters
    ----------
    batch_size : int, default = 40
        training batch size for the model

    Attributes
    ----------
    self.model_ - the fitted DL model
    """

    _tags = {
        "X_inner_mtype": "numpy3D",
        "capability:multivariate": True,
        "python_dependencies": "tensorflow",
    }

    @abstractmethod
    def build_model(self, input_shape, **kwargs):
        """Construct a compiled, un-trained, keras model that is ready for training.

        Parameters
        ----------
        input_shape : tuple
            The shape of the data fed into the input layer

        Returns
        -------
        A compiled Keras Model
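
        Examples
        --------
        A minimal sketch of what a concrete subclass might implement; the
        architecture and layer sizes below are illustrative only, they are
        not prescribed by this base class::

            import tensorflow as tf

            def build_model(self, input_shape, **kwargs):
                # functional API: input -> flatten -> dense -> one output unit
                input_layer = tf.keras.layers.Input(input_shape)
                flat = tf.keras.layers.Flatten()(input_layer)
                hidden = tf.keras.layers.Dense(64, activation="relu")(flat)
                # a single output unit, since this is a regressor
                output_layer = tf.keras.layers.Dense(1)(hidden)
                model = tf.keras.models.Model(
                    inputs=input_layer, outputs=output_layer
                )
                model.compile(loss="mse", optimizer="adam")
                return model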
"""
...

    def _predict(self, X, **kwargs):
        """Find regression estimate for all cases in X.

        Parameters
        ----------
        X : an np.ndarray of shape = (n_instances, n_dimensions, series_length)
            The input samples.

        Returns
        -------
        predictions : 1d numpy array
            array of predictions of each instance
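
        Examples
        --------
        Shape handling only, with illustrative sizes and a hypothetical
        fitted estimator ``reg``::

            X = np.random.rand(10, 3, 50)  # (instances, dimensions, length)
            y_pred = reg._predict(X)       # Keras model sees shape (10, 50, 3)
            y_pred.shape                   # (10,)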
"""
X = X.transpose((0, 2, 1))
y_pred = self.model_.predict(X, self.batch_size, **kwargs)
y_pred = np.squeeze(y_pred, axis=-1)
return y_pred

    def __getstate__(self):
        """Get dict config that will be used when a serialization method is called.

        Returns
        -------
        copy : dict, the config to be serialized
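
        Examples
        --------
        ``__getstate__`` is what makes pickling of a fitted estimator work;
        a usage sketch, where ``reg`` stands for a hypothetical fitted
        regressor instance::

            import pickle

            clone = pickle.loads(pickle.dumps(reg))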
"""
from tensorflow.keras.optimizers import Optimizer, serialize
copy = self.__dict__.copy()
# Either optimizer might not exist at all(-1),
# or it does and takes a value(including None)
optimizer_attr = copy.get("optimizer", -1)
if not isinstance(optimizer_attr, str):
if optimizer_attr is None:
# if it is None, then save it as 0, so it can be
# later correctly restored as None
copy["optimizer"] = 0
elif optimizer_attr == -1:
# if an `optimizer` parameter doesn't exist at all
# save it as -1
copy["optimizer"] = -1
elif isinstance(optimizer_attr, Optimizer):
copy["optimizer"] = serialize(optimizer_attr)
else:
raise ValueError(
f"`optimizer` of type {type(optimizer_attr)} cannot be "
"serialized, it should either be absent/None/str/"
"tf.keras.optimizers.Optimizer object"
)
else:
# if it was a string, don't touch since already serializable
pass
check_before_deletion = ["model_", "history", "optimizer_"]
for attribute in check_before_deletion:
if copy.get(attribute) is not None:
del copy[attribute]
return copy

    def __setstate__(self, state):
        """Magic method called during deserialization.

        Parameters
        ----------
        state : dict, as returned from __getstate__(), used for correct deserialization
        """
        from tensorflow.keras.optimizers import deserialize

        self.__dict__ = state

        if hasattr(self, "model_"):
            self.__dict__["model_"] = self.model_
            if hasattr(self.model_, "optimizer"):
                self.__dict__["optimizer_"] = self.model_.optimizer

        # if optimizer_ exists, set optimizer as optimizer_
        if self.__dict__.get("optimizer_") is not None:
            self.__dict__["optimizer"] = self.__dict__["optimizer_"]
        # else the model may not have been built, but an optimizer might be passed
        else:
            # Having 0 as value implies the "optimizer" attribute was None,
            # as per __getstate__()
            if self.__dict__.get("optimizer") == 0:
                self.__dict__["optimizer"] = None
            elif self.__dict__.get("optimizer") == -1:
                # `optimizer` doesn't exist as a parameter at all, so delete it.
                del self.__dict__["optimizer"]
            else:
                if isinstance(self.optimizer, dict):
                    self.__dict__["optimizer"] = deserialize(self.optimizer)
                else:
                    # must have been a string already, no need to set
                    pass

        if hasattr(self, "history"):
            self.__dict__["history"] = self.history

    def save(self, path=None, legacy_save=False):
        """Save serialized self to bytes-like object or to (.zip) file.

        Behaviour:
        if ``path`` is None, returns an in-memory serialized self
        if ``path`` is a file location, stores the zip with that name at the location.
        The contents of the zip file are:
            _metadata - contains class of self, i.e., type(self).
            _obj - serialized self. This class uses the default serialization (pickle).
            keras/ - model, optimizer and state stored inside this directory.
            history - serialized history object.

        Parameters
        ----------
        path : None or file location (str or Path)
            if None, self is saved to an in-memory object
            if file location, self is saved to that file location. E.g.:
                path="estimator" then a zip file ``estimator.zip`` will be made at cwd.
                path="/home/stored/estimator" then a zip file ``estimator.zip`` will be
                stored in ``/home/stored/``.
        legacy_save : bool, default = False
            whether to use the legacy saving method for the model. If
            tensorflow >= 2.16.0 is installed, this is ignored.
            The legacy saving method will be removed in sktime 0.30.0.

        Returns
        -------
        if ``path`` is None - in-memory serialized self
        if ``path`` is file location - ZipFile with reference to the file
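
        Examples
        --------
        Usage sketch; the file name ``my_regressor`` is illustrative and
        ``reg`` stands for a hypothetical fitted regressor instance::

            reg.save("my_regressor")  # writes my_regressor.zip in the cwd
            serial = reg.save()       # in-memory serialization instead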
"""
# TODO 0.30.0 - remove the legacy_save parameter in sktime 0.30.0
import pickle
import shutil
from pathlib import Path
from zipfile import ZipFile
if legacy_save:
from sktime.utils.warnings import warn
warn(
"WARNING: In the save method of classifiers and regressors,"
" saving logic has changed to be compatible with tensorflow 2.16. "
"The old saving logic is deprecated and will be removed in "
"sktime 0.30.0. "
"If tensorflow>=2.16.0 is installed, the new saving logic is always "
"used. If not, by default, the legacy saving logic is used until "
"sktime 0.28.last, and the new logic is used from sktime 0.29.0."
"For safe change in an environment with tensorflow<2.16.0, "
"set the legacy_save parameter explicitly to False to test the "
"new saving logic. If no issues are found, no changes to your code "
"are necessary. To keep using the legacy method, set the parameter "
"legacy_save to True. Note that the legacy_save parameter will be "
"removed entirely in sktime 0.30.0.",
FutureWarning,
obj=self,
stacklevel=2,
)
if _check_soft_dependencies("tensorflow>=2.16.0", severity="none"):
legacy_save = False

        if path is None:
            _check_soft_dependencies("h5py")
            import h5py

            in_memory_model = None
            if self.model_ is not None:
                # round-trip through a temporary .h5 file to obtain the
                # model as an in-memory byte image
                self.model_.save("disk_less.h5")
                with h5py.File("disk_less.h5", "r") as h5file:
                    in_memory_model = h5file.id.get_file_image()

            in_memory_history = pickle.dumps(self.history.history)

            return (
                type(self),
                (
                    pickle.dumps(self),
                    in_memory_model,
                    in_memory_history,
                ),
            )

        if not isinstance(path, (str, Path)):
            raise TypeError(
                "`path` is expected to either be a string or a Path object, "
                f"but found of type: {type(path)}."
            )

        path = Path(path) if isinstance(path, str) else path
        path.mkdir()

        if self.model_ is not None:
            if not legacy_save:
                keras_path = path / "keras" / "model.keras"
                os.makedirs(keras_path.parent, exist_ok=True)
                self.model_.save(keras_path)
            else:
                self.model_.save(path / "keras/")

        with open(path / "history", "wb") as history_writer:
            pickle.dump(self.history.history, history_writer)

        pickle.dump(type(self), open(path / "_metadata", "wb"))
        pickle.dump(self, open(path / "_obj", "wb"))

        shutil.make_archive(base_name=path, format="zip", root_dir=path)
        shutil.rmtree(path)
        return ZipFile(path.with_name(f"{path.stem}.zip"))

    @classmethod
    def load_from_serial(cls, serial):
        """Load object from serialized memory container.

        Parameters
        ----------
        serial : 2nd element of the output of ``cls.save(None)``
            This is a tuple of size 3.
            The first element represents the pickle-serialized instance.
            The second element represents the h5py-serialized ``keras`` model.
            The third element represents the pickle-serialized history of ``.fit()``.

        Returns
        -------
        Deserialized self resulting in output ``serial``, of ``cls.save(None)``
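
        Examples
        --------
        Round-trip sketch, where ``reg`` stands for a hypothetical fitted
        regressor instance::

            cls_, payload = reg.save()  # in-memory serialization
            reg2 = cls_.load_from_serial(payload)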
"""
import pickle
from tensorflow.keras.models import load_model
if not isinstance(serial, tuple):
raise TypeError(
"`serial` is expected to be a tuple, "
f"instead found of type: {type(serial)}"
)
if len(serial) != 3:
raise ValueError(
"`serial` should have 3 elements. "
"All 3 elements represent in-memory serialization "
"of the estimator. "
f"Found a tuple of length: {len(serial)} instead."
)
serial, in_memory_model, in_memory_history = serial
if in_memory_model is None:
cls.model_ = None
else:
with open("diskless.h5", "wb") as store_:
store_.write(in_memory_model)
cls.model_ = load_model("diskless.h5")
cls.history = pickle.loads(in_memory_history)
return pickle.loads(serial)

    @classmethod
    def load_from_path(cls, serial):
        """Load object from file location.

        Parameters
        ----------
        serial : Path pointing to the zip file produced by ``cls.save(path)``.

        Returns
        -------
        deserialized self resulting in output at ``path``, of ``cls.save(path)``
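
        Examples
        --------
        Usage sketch; the file name is illustrative and ``reg`` stands for
        a hypothetical fitted regressor instance::

            from pathlib import Path

            reg.save("my_regressor")  # writes my_regressor.zip
            reg2 = type(reg).load_from_path(Path("my_regressor.zip"))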
"""
import pickle
from shutil import rmtree
from zipfile import ZipFile
from tensorflow import keras
temp_unzip_loc = serial.parent / "temp_unzip/"
temp_unzip_loc.mkdir()
with ZipFile(serial, mode="r") as zip_file:
for file in zip_file.namelist():
if not file.startswith("keras/"):
continue
zip_file.extract(file, temp_unzip_loc)
keras_location_legacy = temp_unzip_loc / "keras"
keras_location = temp_unzip_loc / "keras" / "model.keras"
if keras_location.exists():
cls.model_ = keras.models.load_model(keras_location)
elif keras_location_legacy.exists():
cls.model_ = keras.models.load_model(keras_location_legacy)
else:
cls.model_ = None
rmtree(temp_unzip_loc)
cls.history = keras.callbacks.History()
with ZipFile(serial, mode="r") as file:
cls.history.set_params(pickle.loads(file.open("history").read()))
return pickle.loads(file.open("_obj").read())