-
Notifications
You must be signed in to change notification settings - Fork 5.4k
/
checkpoints.py
336 lines (278 loc) · 12.1 KB
/
checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
import logging
import json
import os
from packaging import version
import re
from typing import Any, Dict, Union
import ray
from ray.rllib.utils.serialization import NOT_SERIALIZABLE, serialize_type
from ray.train import Checkpoint
from ray.util import log_once
from ray.util.annotations import PublicAPI
logger = logging.getLogger(__name__)

# Checkpoint versions written by RLlib for Algorithm and Policy checkpoints.
# History:
#   0.1: Ray 2.0.0
#        A single `checkpoint-[iter num]` file for Algorithm checkpoints
#        within the checkpoint directory. Policy checkpoints not supported
#        across all DL frameworks.
#   1.0: Ray >=2.1.0
#        An algorithm_state.pkl file for the state of the Algorithm (excluding
#        individual policy states).
#        One sub-dir inside the "policies" sub-dir for each policy with a
#        dedicated policy_state.pkl in it for the policy state.
#   1.1: Same as 1.0, but has a new "format" field in the rllib_checkpoint.json
#        file indicating whether the checkpoint is `cloudpickle` (default) or
#        `msgpack`.
#   1.2: Introduces the checkpoint for the new Learner API if the Learner API
#        is enabled.
CHECKPOINT_VERSION, CHECKPOINT_VERSION_LEARNER = (
    version.Version("1.1"),
    version.Version("1.2"),
)
@PublicAPI(stability="alpha")
def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]:
    """Returns a dict with information about an Algorithm/Policy checkpoint.

    If the given checkpoint is a >=v1.0 checkpoint directory, try reading all
    information from the contained `rllib_checkpoint.json` file.

    Args:
        checkpoint: The checkpoint directory (str) or an AIR Checkpoint object.

    Returns:
        A dict containing the keys:
        "type": One of "Policy" or "Algorithm".
        "format": One of "cloudpickle" or "msgpack", indicating how the state
            file is encoded.
        "checkpoint_version": A version tuple, e.g. v1.0, indicating the checkpoint
            version. This will help RLlib to remain backward compatible wrt. future
            Ray and checkpoint versions.
        "checkpoint_dir": The directory with all the checkpoint files in it. This
            might be the same as the incoming `checkpoint` arg.
        "state_file": The main file with the Algorithm/Policy's state information
            in it. This is usually a pickle-encoded file.
        "policy_ids": An optional set of PolicyIDs in case we are dealing with an
            Algorithm checkpoint. None if `checkpoint` is a Policy checkpoint.

    Raises:
        ValueError: If `checkpoint` is neither an existing directory nor an
            existing file. Also raised if it is a directory without an old-style
            `checkpoint-[0-9]+` file, a `policy_state.[pkl|msgpck]` file, or an
            `algorithm_state.[pkl|msgpck]` file.
    """
    # Default checkpoint info.
    info = {
        "type": "Algorithm",
        "format": "cloudpickle",
        "checkpoint_version": CHECKPOINT_VERSION,
        "checkpoint_dir": None,
        "state_file": None,
        "policy_ids": None,
    }

    # `checkpoint` is a Checkpoint instance: Translate to directory and continue.
    if isinstance(checkpoint, Checkpoint):
        checkpoint = checkpoint.to_directory()

    # Checkpoint is dir.
    if os.path.isdir(checkpoint):
        info.update({"checkpoint_dir": checkpoint})

        # Figure out whether this is an older checkpoint format (with a
        # `checkpoint-<iter>` file in it). `fullmatch` keeps sidecar files such
        # as `checkpoint-10.tune_metadata` from being taken as the state file;
        # sorting makes the pick independent of `os.listdir` order.
        for file in sorted(os.listdir(checkpoint)):
            path_file = os.path.join(checkpoint, file)
            if os.path.isfile(path_file) and re.fullmatch(r"checkpoint-\d+", file):
                info.update(
                    {
                        "checkpoint_version": version.Version("0.1"),
                        "state_file": path_file,
                    }
                )
                return info

        # No old checkpoint file found.
        # If rllib_checkpoint.json file present, read available information from it
        # and then continue with the checkpoint analysis (possibly overriding further
        # information).
        rllib_checkpoint_json = os.path.join(checkpoint, "rllib_checkpoint.json")
        if os.path.isfile(rllib_checkpoint_json):
            with open(rllib_checkpoint_json, encoding="utf-8") as f:
                rllib_checkpoint_info = json.load(fp=f)
            if "checkpoint_version" in rllib_checkpoint_info:
                rllib_checkpoint_info["checkpoint_version"] = version.Version(
                    rllib_checkpoint_info["checkpoint_version"]
                )
            info.update(rllib_checkpoint_info)
        else:
            # No rllib_checkpoint.json file present: Warn and continue trying to
            # figure out checkpoint info ourselves.
            if log_once("no_rllib_checkpoint_json_file"):
                logger.warning(
                    "No `rllib_checkpoint.json` file found in checkpoint directory "
                    f"{checkpoint}! Trying to extract checkpoint info from other files "
                    "found in that dir."
                )

        # Policy checkpoint file found.
        for extension in ("pkl", "msgpck"):
            policy_state_file = os.path.join(checkpoint, f"policy_state.{extension}")
            if os.path.isfile(policy_state_file):
                # `checkpoint_version` is deliberately not touched here. It is
                # either the default (CHECKPOINT_VERSION) or the value read from
                # rllib_checkpoint.json, which must not be clobbered.
                info.update(
                    {
                        "type": "Policy",
                        "format": "cloudpickle" if extension == "pkl" else "msgpack",
                        "state_file": policy_state_file,
                    }
                )
                return info

        # Valid Algorithm checkpoint >v0 file found?
        for extension in ("pkl", "msgpck"):
            state_file = os.path.join(checkpoint, f"algorithm_state.{extension}")
            if os.path.isfile(state_file):
                checkpoint_format = "cloudpickle" if extension == "pkl" else "msgpack"
                break
        else:
            raise ValueError(
                "Given checkpoint does not seem to be valid! No file with the name "
                "`algorithm_state.[pkl|msgpck]` (or `checkpoint-[0-9]+`) found."
            )

        info.update(
            {
                "format": checkpoint_format,
                "state_file": state_file,
            }
        )

        # Collect all policy IDs in the sub-dir "policies/". Each policy lives in
        # its own sub-directory, so stray files (e.g. OS metadata) are skipped.
        policies_dir = os.path.join(checkpoint, "policies")
        if os.path.isdir(policies_dir):
            policy_ids = {
                policy_id
                for policy_id in os.listdir(policies_dir)
                if os.path.isdir(os.path.join(policies_dir, policy_id))
            }
            info.update({"policy_ids": policy_ids})

    # Checkpoint is a file: Use as-is (interpreting it as old Algorithm checkpoint
    # version).
    elif os.path.isfile(checkpoint):
        info.update(
            {
                "checkpoint_version": version.Version("0.1"),
                "checkpoint_dir": os.path.dirname(checkpoint),
                "state_file": checkpoint,
            }
        )
    else:
        raise ValueError(
            f"Given checkpoint ({checkpoint}) not found! Must be a "
            "checkpoint directory (or a file for older checkpoint versions)."
        )

    return info
@PublicAPI(stability="beta")
def convert_to_msgpack_checkpoint(
    checkpoint: Union[str, Checkpoint],
    msgpack_checkpoint_dir: str,
) -> str:
    """Converts an Algorithm checkpoint (pickle based) to a msgpack based one.

    Msgpack has the advantage of being python version independent.

    Args:
        checkpoint: The directory, in which to find the Algorithm checkpoint (pickle
            based).
        msgpack_checkpoint_dir: The directory, in which to create the new msgpack
            based checkpoint. It is created (including parents) if it does not
            exist yet.

    Returns:
        The directory in which the msgpack checkpoint has been created. Note that
        this is the same as `msgpack_checkpoint_dir`.

    Raises:
        ImportError: If msgpack or msgpack_numpy cannot be imported or set up.
    """
    from ray.rllib.algorithms import Algorithm
    from ray.rllib.utils.policy import validate_policy_id

    # Try to import msgpack and msgpack_numpy.
    msgpack = try_import_msgpack(error=True)

    # Restore the Algorithm using the python version dependent checkpoint.
    algo = Algorithm.from_checkpoint(checkpoint)
    try:
        state = algo.__getstate__()

        # Convert all code in state into serializable data.
        # Serialize the algorithm class.
        state["algorithm_class"] = serialize_type(state["algorithm_class"])
        # Serialize the algorithm's config object.
        state["config"] = state["config"].serialize()

        # Extract policy states from worker state (Policies get their own
        # checkpoint sub-dirs). Functions in the worker state cannot be
        # msgpack'ed and are replaced by a placeholder.
        policy_states = {}
        if "worker" in state:
            worker_state = state["worker"]
            policy_states = worker_state.pop("policy_states", {})
            # Policy mapping fn.
            worker_state["policy_mapping_fn"] = NOT_SERIALIZABLE
            # Is Policy to train function.
            worker_state["is_policy_to_train"] = NOT_SERIALIZABLE

        # Add RLlib checkpoint version (as string).
        if state["config"].get("_enable_new_api_stack", False):
            state["checkpoint_version"] = str(CHECKPOINT_VERSION_LEARNER)
        else:
            state["checkpoint_version"] = str(CHECKPOINT_VERSION)

        # The target dir may not exist yet; create it before writing any file.
        os.makedirs(msgpack_checkpoint_dir, exist_ok=True)

        # Write state (w/o policies) to disk.
        state_file = os.path.join(msgpack_checkpoint_dir, "algorithm_state.msgpck")
        with open(state_file, "wb") as f:
            msgpack.dump(state, f)

        # Write rllib_checkpoint.json.
        with open(
            os.path.join(msgpack_checkpoint_dir, "rllib_checkpoint.json"), "w"
        ) as f:
            json.dump(
                {
                    "type": "Algorithm",
                    "checkpoint_version": state["checkpoint_version"],
                    "format": "msgpack",
                    "state_file": state_file,
                    "policy_ids": list(policy_states.keys()),
                    "ray_version": ray.__version__,
                    "ray_commit": ray.__commit__,
                },
                f,
            )

        # Write individual policies to disk, each in their own sub-directory.
        for pid, policy_state in policy_states.items():
            # From here on, disallow policyIDs that would not work as directory names.
            validate_policy_id(pid, error=True)
            policy_dir = os.path.join(msgpack_checkpoint_dir, "policies", pid)
            os.makedirs(policy_dir, exist_ok=True)
            policy = algo.get_policy(pid)
            policy.export_checkpoint(
                policy_dir,
                policy_state=policy_state,
                checkpoint_format="msgpack",
            )
    finally:
        # Release all resources used by the Algorithm, even if conversion failed.
        algo.stop()

    return msgpack_checkpoint_dir
@PublicAPI(stability="beta")
def convert_to_msgpack_policy_checkpoint(
    policy_checkpoint: Union[str, Checkpoint],
    msgpack_checkpoint_dir: str,
) -> str:
    """Re-exports a pickle-based Policy checkpoint in msgpack format.

    Unlike pickle, msgpack checkpoints do not depend on the python version that
    wrote them.

    Args:
        policy_checkpoint: The directory, in which to find the Policy checkpoint
            (pickle based).
        msgpack_checkpoint_dir: Target directory for the msgpack based checkpoint;
            created (with any missing parents) if it does not exist.

    Returns:
        `msgpack_checkpoint_dir`, i.e. the directory now holding the msgpack
        checkpoint.
    """
    from ray.rllib.policy.policy import Policy

    restored = Policy.from_checkpoint(policy_checkpoint)
    os.makedirs(msgpack_checkpoint_dir, exist_ok=True)

    restored_state = restored.get_state()
    restored.export_checkpoint(
        msgpack_checkpoint_dir,
        policy_state=restored_state,
        checkpoint_format="msgpack",
    )

    # Drop our reference to the restored Policy.
    del restored
    return msgpack_checkpoint_dir
@PublicAPI
def try_import_msgpack(error: bool = False):
    """Tries importing msgpack and msgpack_numpy and returns the patched msgpack module.

    Returns None if error is False and msgpack or msgpack_numpy is not installed.
    Raises an error, if error is True and the modules could not be imported.

    Args:
        error: Whether to raise an error if msgpack/msgpack_numpy cannot be imported.

    Returns:
        The `msgpack` module (after `msgpack_numpy.patch()` has been applied), or
        None if importing/patching failed and `error` is False.

    Raises:
        ImportError: If error=True and msgpack/msgpack_numpy could not be imported
            or set up. The underlying exception is chained as the cause.
    """
    try:
        import msgpack
        import msgpack_numpy

        # Make msgpack_numpy look like msgpack.
        msgpack_numpy.patch()
    except Exception as e:
        # Deliberately broad: a failure inside `patch()` ("setup") is treated the
        # same as a missing package.
        if error:
            raise ImportError(
                "Could not import or setup msgpack and msgpack_numpy! "
                "Try running `pip install msgpack msgpack_numpy` first."
            ) from e
        return None
    return msgpack