import copy
import importlib
import inspect
import logging
import math
import os
import random
import string
import time
import traceback
from enum import Enum
from functools import wraps
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Tuple,
TypeVar,
Union,
Optional,
)
import threading
import fastapi.encoders
import numpy as np
import pydantic
import pydantic.json
import requests
import ray
import ray.util.serialization_addons
from ray.actor import ActorHandle
from ray.exceptions import RayTaskError
from ray.serve._private.constants import (
HTTP_PROXY_TIMEOUT,
SERVE_LOGGER_NAME,
)
from ray.types import ObjectRef
from ray.util.serialization import StandaloneSerializationContext
from ray._raylet import MessagePackSerializer
from ray._private.utils import import_attr
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
from ray._private.resource_spec import HEAD_NODE_RESOURCE_NAME
import __main__
try:
import pandas as pd
except ImportError:
pd = None
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
MESSAGE_PACK_OFFSET = 9
# Use a global singleton enum to emulate default options. We cannot use None
# for those options because None is a valid new value.
class DEFAULT(Enum):
VALUE = 1
class DeploymentOptionUpdateType(str, Enum):
# Nothing needs to be done other than setting the target state.
LightWeight = "LightWeight"
# Each DeploymentReplica instance (tracked in DeploymentState) uses certain options
# from the deployment config. These values need to be updated in DeploymentReplica.
NeedsReconfigure = "NeedsReconfigure"
# Options that are sent to the replica actor. If changed, reconfigure() on the actor
# needs to be called to update these values.
NeedsActorReconfigure = "NeedsActorReconfigure"
# If changed, restart all replicas.
HeavyWeight = "HeavyWeight"
# Type alias: objects that can be DEFAULT.VALUE have type Default[T]
T = TypeVar("T")
Default = Union[DEFAULT, T]
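
# A minimal usage sketch (illustrative only; `num_replicas` is a hypothetical
# option name). Typing an option as Default[T] with DEFAULT.VALUE as its
# default lets callers distinguish "option not provided" from an explicit
# None:
#
#     def options(num_replicas: Default[Optional[int]] = DEFAULT.VALUE):
#         if num_replicas is DEFAULT.VALUE:
#             ...  # Option not provided; keep the existing value.
#         else:
#             ...  # Option explicitly provided (possibly None).
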
logger = logging.getLogger(SERVE_LOGGER_NAME)
class _ServeCustomEncoders:
"""Group of custom encoders for common types that's not handled by FastAPI."""
@staticmethod
def encode_np_array(obj):
assert isinstance(obj, np.ndarray)
if obj.dtype.kind == "f": # floats
obj = obj.astype(float)
if obj.dtype.kind in {"i", "u"}: # signed and unsigned integers.
obj = obj.astype(int)
return obj.tolist()
@staticmethod
def encode_np_scaler(obj):
assert isinstance(obj, np.generic)
return obj.item()
@staticmethod
def encode_exception(obj):
assert isinstance(obj, Exception)
return str(obj)
@staticmethod
def encode_pandas_dataframe(obj):
assert isinstance(obj, pd.DataFrame)
return obj.to_dict(orient="records")
serve_encoders = {
np.ndarray: _ServeCustomEncoders.encode_np_array,
np.generic: _ServeCustomEncoders.encode_np_scaler,
Exception: _ServeCustomEncoders.encode_exception,
}
if pd is not None:
serve_encoders[pd.DataFrame] = _ServeCustomEncoders.encode_pandas_dataframe
def install_serve_encoders_to_fastapi():
"""Inject Serve's encoders so FastAPI's jsonable_encoder can pick it up."""
# https://stackoverflow.com/questions/62311401/override-default-encoders-for-jsonable-encoder-in-fastapi # noqa
pydantic.json.ENCODERS_BY_TYPE.update(serve_encoders)
    # FastAPI caches these encoders at import time, so we also need to refresh them.
fastapi.encoders.encoders_by_class_tuples = (
fastapi.encoders.generate_encoders_by_class_tuples(
pydantic.json.ENCODERS_BY_TYPE
)
)
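
# A minimal usage sketch (values illustrative). After installing the encoders,
# FastAPI's jsonable_encoder can serialize numpy types directly:
#
#     from fastapi.encoders import jsonable_encoder
#     install_serve_encoders_to_fastapi()
#     jsonable_encoder(np.array([1.0, 2.0]))  # -> [1.0, 2.0]
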
@ray.remote(num_cpus=0)
def block_until_http_ready(
http_endpoint,
backoff_time_s=1,
check_ready=None,
timeout=HTTP_PROXY_TIMEOUT,
):
http_is_ready = False
start_time = time.time()
while not http_is_ready:
try:
resp = requests.get(http_endpoint)
assert resp.status_code == 200
if check_ready is None:
http_is_ready = True
else:
http_is_ready = check_ready(resp)
except Exception:
pass
if 0 < timeout < time.time() - start_time:
raise TimeoutError("HTTP proxy not ready after {} seconds.".format(timeout))
time.sleep(backoff_time_s)
def get_random_letters(length=6):
return "".join(random.choices(string.ascii_letters, k=length))
def format_actor_name(actor_name, controller_name=None, *modifiers):
if controller_name is None:
name = actor_name
else:
name = "{}:{}".format(controller_name, actor_name)
for modifier in modifiers:
name += "-{}".format(modifier)
return name
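
# Example (names illustrative): format_actor_name("HTTPProxy", "SERVE_CONTROLLER", "node1")
# returns "SERVE_CONTROLLER:HTTPProxy-node1".
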
def compute_iterable_delta(old: Iterable, new: Iterable) -> Tuple[set, set, set]:
"""Given two iterables, return the entries that's (added, removed, updated).
Usage:
>>> from ray.serve._private.utils import compute_iterable_delta
>>> old = {"a", "b"}
>>> new = {"a", "d"}
>>> compute_iterable_delta(old, new)
({'d'}, {'b'}, {'a'})
"""
old_keys, new_keys = set(old), set(new)
added_keys = new_keys - old_keys
removed_keys = old_keys - new_keys
updated_keys = old_keys.intersection(new_keys)
return added_keys, removed_keys, updated_keys
def compute_dict_delta(old_dict, new_dict) -> Tuple[dict, dict, dict]:
"""Given two dicts, return the entries that's (added, removed, updated).
Usage:
>>> from ray.serve._private.utils import compute_dict_delta
>>> old = {"a": 1, "b": 2}
>>> new = {"a": 3, "d": 4}
>>> compute_dict_delta(old, new)
({'d': 4}, {'b': 2}, {'a': 3})
"""
added_keys, removed_keys, updated_keys = compute_iterable_delta(
old_dict.keys(), new_dict.keys()
)
return (
{k: new_dict[k] for k in added_keys},
{k: old_dict[k] for k in removed_keys},
{k: new_dict[k] for k in updated_keys},
)
def ensure_serialization_context():
"""Ensure the serialization addons on registered, even when Ray has not
been started."""
ctx = StandaloneSerializationContext()
ray.util.serialization_addons.apply(ctx)
def wrap_to_ray_error(function_name: str, exception: Exception) -> RayTaskError:
"""Utility method to wrap exceptions in user code."""
try:
# Raise and catch so we can access traceback.format_exc()
raise exception
except Exception as e:
traceback_str = ray._private.utils.format_error_message(traceback.format_exc())
return ray.exceptions.RayTaskError(function_name, traceback_str, e)
def msgpack_serialize(obj):
ctx = ray._private.worker.global_worker.get_serialization_context()
buffer = ctx.serialize(obj)
serialized = buffer.to_bytes()
return serialized
def msgpack_deserialize(data):
    # TODO: Ray does not provide a msgpack deserialization api.
    obj = MessagePackSerializer.loads(data[MESSAGE_PACK_OFFSET:], None)
    return obj
def merge_dict(dict1, dict2):
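    """Merge two dicts by summing values key-wise; a missing key counts as 0.

    Returns None if both inputs are None. For example (key order unspecified):
    merge_dict({"a": 1, "b": 2}, {"b": 3}) -> {"a": 1, "b": 5}.
    """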
if dict1 is None and dict2 is None:
return None
if dict1 is None:
dict1 = dict()
if dict2 is None:
dict2 = dict()
result = dict()
for key in dict1.keys() | dict2.keys():
result[key] = sum([e.get(key, 0) for e in (dict1, dict2)])
return result
def get_deployment_import_path(
deployment, replace_main=False, enforce_importable=False
):
"""
Gets the import path for deployment's func_or_class.
deployment: A deployment object whose import path should be returned
replace_main: If this is True, the function will try to replace __main__
with __main__'s file name if the deployment's module is __main__
"""
body = deployment._func_or_class
if isinstance(body, str):
# deployment's func_or_class is already an import path
return body
elif hasattr(body, "__ray_actor_class__"):
# If ActorClass, get the class or function inside
body = body.__ray_actor_class__
import_path = f"{body.__module__}.{body.__qualname__}"
if enforce_importable and "<locals>" in body.__qualname__:
raise RuntimeError(
"Deployment definitions must be importable to build the Serve app, "
f"but deployment '{deployment.name}' is inline defined or returned "
"from another function. Please restructure your code so that "
f"'{import_path}' can be imported (i.e., put it in a module)."
)
if replace_main:
# Replaces __main__ with its file name. E.g. suppose the import path
# is __main__.classname and classname is defined in filename.py.
# Its import path becomes filename.classname.
if import_path.split(".")[0] == "__main__" and hasattr(__main__, "__file__"):
file_name = os.path.basename(__main__.__file__)
extensionless_file_name = file_name.split(".")[0]
attribute_name = import_path.split(".")[-1]
import_path = f"{extensionless_file_name}.{attribute_name}"
return import_path
def parse_import_path(import_path: str):
"""
Takes in an import_path of form:
[subdirectory 1].[subdir 2]...[subdir n].[file name].[attribute name]
Parses this path and returns the module name (everything before the last
dot) and attribute name (everything after the last dot), such that the
attribute can be imported using "from module_name import attr_name".
"""
nodes = import_path.split(".")
if len(nodes) < 2:
raise ValueError(
f"Got {import_path} as import path. The import path "
f"should at least specify the file name and "
f"attribute name connected by a dot."
)
return ".".join(nodes[:-1]), nodes[-1]
def override_runtime_envs_except_env_vars(parent_env: Dict, child_env: Dict) -> Dict:
"""Creates a runtime_env dict by merging a parent and child environment.
This method is not destructive. It leaves the parent and child envs
the same.
The merge is a shallow update where the child environment inherits the
parent environment's settings. If the child environment specifies any
    env settings, those settings take precedence over the parent.
- Note: env_vars are a special case. The child's env_vars are combined
with the parent.
Args:
parent_env: The environment to inherit settings from.
child_env: The environment with override settings.
Returns: A new dictionary containing the merged runtime_env settings.
Raises:
TypeError: If a dictionary is not passed in for parent_env or child_env.
"""
if not isinstance(parent_env, Dict):
raise TypeError(
f'Got unexpected type "{type(parent_env)}" for parent_env. '
"parent_env must be a dictionary."
)
if not isinstance(child_env, Dict):
raise TypeError(
f'Got unexpected type "{type(child_env)}" for child_env. '
"child_env must be a dictionary."
)
defaults = copy.deepcopy(parent_env)
overrides = copy.deepcopy(child_env)
default_env_vars = defaults.get("env_vars", {})
override_env_vars = overrides.get("env_vars", {})
defaults.update(overrides)
default_env_vars.update(override_env_vars)
defaults["env_vars"] = default_env_vars
return defaults
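
# A minimal example (values illustrative): the child's top-level fields win,
# but env_vars are merged key-wise:
#
#     parent = {"pip": ["requests"], "env_vars": {"A": "1", "B": "2"}}
#     child = {"pip": ["numpy"], "env_vars": {"B": "3"}}
#     override_runtime_envs_except_env_vars(parent, child)
#     # -> {"pip": ["numpy"], "env_vars": {"A": "1", "B": "3"}}
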
class JavaActorHandleProxy:
"""Wraps actor handle and translate snake_case to camelCase."""
def __init__(self, handle: ActorHandle):
self.handle = handle
self._available_attrs = set(dir(self.handle))
def __getattr__(self, key: str):
if key in self._available_attrs:
camel_case_key = key
else:
components = key.split("_")
camel_case_key = components[0] + "".join(x.title() for x in components[1:])
return getattr(self.handle, camel_case_key)
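
# Example (method name illustrative): if the wrapped Java handle exposes
# `getStatus`, then JavaActorHandleProxy(handle).get_status resolves to
# handle.getStatus.
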
def require_packages(packages: List[str]):
"""Decorator making sure function run in specified environments
Examples:
>>> from ray.serve._private.utils import require_packages
>>> @require_packages(["numpy", "package_a"]) # doctest: +SKIP
... def func(): # doctest: +SKIP
... import numpy as np # doctest: +SKIP
... ... # doctest: +SKIP
>>> func() # doctest: +SKIP
ImportError: func requires ["numpy", "package_a"] but
["package_a"] are not available, please pip install them.
"""
def decorator(func):
def check_import_once():
if not hasattr(func, "_require_packages_checked"):
missing_packages = []
for package in packages:
try:
importlib.import_module(package)
except ModuleNotFoundError:
missing_packages.append(package)
if len(missing_packages) > 0:
raise ImportError(
f"{func} requires packages {packages} to run but "
f"{missing_packages} are missing. Please "
"`pip install` them or add them to "
"`runtime_env`."
)
setattr(func, "_require_packages_checked", True)
if inspect.iscoroutinefunction(func):
@wraps(func)
async def wrapped(*args, **kwargs):
check_import_once()
return await func(*args, **kwargs)
elif inspect.isroutine(func):
@wraps(func)
def wrapped(*args, **kwargs):
check_import_once()
return func(*args, **kwargs)
else:
raise ValueError("Decorator expect callable functions.")
return wrapped
return decorator
def in_interactive_shell():
# Taken from:
# https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
import __main__ as main
return not hasattr(main, "__file__")
def guarded_deprecation_warning(*args, **kwargs):
"""Wrapper for deprecation warnings, guarded by a flag."""
if os.environ.get("SERVE_WARN_V1_DEPRECATIONS", "0") == "1":
from ray._private.utils import deprecated
return deprecated(*args, **kwargs)
else:
def noop_decorator(func):
return func
return noop_decorator
def snake_to_camel_case(snake_str: str) -> str:
"""Convert a snake case string to camel case."""
words = snake_str.strip("_").split("_")
return words[0] + "".join(word[:1].upper() + word[1:] for word in words[1:])
def dict_keys_snake_to_camel_case(snake_dict: dict) -> dict:
"""Converts dictionary's keys from snake case to camel case.
Does not modify original dictionary.
"""
camel_dict = dict()
for key, val in snake_dict.items():
if isinstance(key, str):
camel_dict[snake_to_camel_case(key)] = val
else:
camel_dict[key] = val
return camel_dict
def check_obj_ref_ready_nowait(obj_ref: ObjectRef) -> bool:
"""Check if ray object reference is ready without waiting for it."""
finished, _ = ray.wait([obj_ref], timeout=0)
return len(finished) == 1
serve_telemetry_tag_map = {
"SERVE_API_VERSION": TagKey.SERVE_API_VERSION,
"SERVE_NUM_DEPLOYMENTS": TagKey.SERVE_NUM_DEPLOYMENTS,
"GCS_STORAGE": TagKey.GCS_STORAGE,
"SERVE_NUM_GPU_DEPLOYMENTS": TagKey.SERVE_NUM_GPU_DEPLOYMENTS,
"SERVE_FASTAPI_USED": TagKey.SERVE_FASTAPI_USED,
"SERVE_DAG_DRIVER_USED": TagKey.SERVE_DAG_DRIVER_USED,
"SERVE_HTTP_ADAPTER_USED": TagKey.SERVE_HTTP_ADAPTER_USED,
"SERVE_GRPC_INGRESS_USED": TagKey.SERVE_GRPC_INGRESS_USED,
"SERVE_REST_API_VERSION": TagKey.SERVE_REST_API_VERSION,
"SERVE_NUM_APPS": TagKey.SERVE_NUM_APPS,
"SERVE_NUM_REPLICAS_LIGHTWEIGHT_UPDATED": (
TagKey.SERVE_NUM_REPLICAS_LIGHTWEIGHT_UPDATED
),
"SERVE_USER_CONFIG_LIGHTWEIGHT_UPDATED": (
TagKey.SERVE_USER_CONFIG_LIGHTWEIGHT_UPDATED
),
"SERVE_AUTOSCALING_CONFIG_LIGHTWEIGHT_UPDATED": (
TagKey.SERVE_AUTOSCALING_CONFIG_LIGHTWEIGHT_UPDATED
),
}
def record_serve_tag(key: str, value: str):
"""Record telemetry.
TagKey objects cannot be pickled, so deployments can't directly record
telemetry using record_extra_usage_tag. They can instead call this function
which records telemetry for them.
"""
if key not in serve_telemetry_tag_map:
raise ValueError(
f'The TagKey "{key}" does not exist. Expected a key from: '
f"{list(serve_telemetry_tag_map.keys())}."
)
record_extra_usage_tag(serve_telemetry_tag_map[key], value)
def extract_self_if_method_call(args: List[Any], func: Callable) -> Optional[object]:
"""Check if this is a method rather than a function.
Does this by checking to see if `func` is the attribute of the first
(`self`) argument under `func.__name__`. Unfortunately, this is the most
robust solution to this I was able to find. It would also be preferable
    to do this check when the decorator runs, rather than when the method is called.
Returns the `self` object if it's a method call, else None.
Arguments:
args: arguments to the function/method call.
func: the unbound function that was called.
"""
if len(args) > 0:
method = getattr(args[0], func.__name__, False)
if method:
wrapped = getattr(method, "__wrapped__", False)
if wrapped and wrapped == func:
return args[0]
return None
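
# A minimal sketch (names hypothetical): inside a decorator's wrapper that was
# built with functools.wraps(func), calling a decorated method `obj.run(...)`
# gives extract_self_if_method_call(args, run) == obj, while calling a plain
# decorated function yields None.
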
class _MetricTask:
def __init__(self, task_func, interval_s, callback_func):
"""
Args:
task_func: a callable that MetricsPusher will try to call in each loop.
            interval_s: the interval, in seconds, at which task_func should be called.
callback_func: callback function is called when task_func is done, and
the result of task_func is passed to callback_func as the first
argument, and the timestamp of the call is passed as the second
argument.
"""
self.task_func: Callable = task_func
self.interval_s: float = interval_s
        self.callback_func: Optional[Callable[[Any, float], Any]] = callback_func
self.last_ref: Optional[ray.ObjectRef] = None
self.last_call_succeeded_time: Optional[float] = time.time()
class MetricsPusher:
"""
    Metrics pusher is a background thread that runs the registered tasks in a loop.
"""
def __init__(
self,
):
self.tasks: List[_MetricTask] = []
self.pusher_thread: Union[threading.Thread, None] = None
self.stop_event = threading.Event()
def register_task(self, task_func, interval_s, process_func=None):
self.tasks.append(_MetricTask(task_func, interval_s, process_func))
def start(self):
"""Start a background thread to run the registered tasks in a loop.
        We use a background thread so metrics delivery is not blocked by the
        user's code and stays consistent. The Python GIL ensures that this
        thread gets a fair share of time to execute.
"""
def send_forever():
while True:
if self.stop_event.is_set():
return
start = time.time()
for task in self.tasks:
try:
if start - task.last_call_succeeded_time >= task.interval_s:
if task.last_ref:
ready_refs, _ = ray.wait([task.last_ref], timeout=0)
if len(ready_refs) == 0:
continue
data = task.task_func()
task.last_call_succeeded_time = time.time()
if task.callback_func and ray.is_initialized():
task.last_ref = task.callback_func(
data, send_timestamp=time.time()
)
except Exception as e:
logger.warning(
f"MetricsPusher thread failed to run metric task: {e}"
)
# For all tasks, check when the task should be executed
# next. Sleep until the next closest time.
least_interval_s = math.inf
for task in self.tasks:
time_until_next_push = task.interval_s - (
time.time() - task.last_call_succeeded_time
)
least_interval_s = min(least_interval_s, time_until_next_push)
time.sleep(max(least_interval_s, 0))
if len(self.tasks) == 0:
raise ValueError("MetricsPusher has zero tasks registered.")
self.pusher_thread = threading.Thread(target=send_forever)
# Making this a daemon thread so it doesn't leak upon shutdown, and it
# doesn't need to block the replica's shutdown.
        self.pusher_thread.daemon = True
self.pusher_thread.start()
def __del__(self):
self.shutdown()
def shutdown(self):
"""Shutdown metrics pusher gracefully.
        This method is idempotent; calling it multiple times is safe.
"""
if not self.stop_event.is_set():
self.stop_event.set()
if self.pusher_thread:
self.pusher_thread.join()
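
# A minimal usage sketch (function names are hypothetical):
#
#     pusher = MetricsPusher()
#     pusher.register_task(collect_metrics, interval_s=10, process_func=push_metrics)
#     pusher.start()
#     ...
#     pusher.shutdown()
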
def call_function_from_import_path(import_path: str) -> Any:
"""Call the function given import path.
Args:
import_path: The import path of the function to call.
Raises:
ValueError: If the import path is invalid.
TypeError: If the import path is not callable.
        RuntimeError: If the function raises an exception during execution.
Returns:
The result of the function call.
"""
try:
callback_func = import_attr(import_path)
except Exception as e:
raise ValueError(f"The import path {import_path} cannot be imported: {e}")
if not callable(callback_func):
raise TypeError(f"The import path {import_path} is not callable.")
try:
return callback_func()
except Exception as e:
raise RuntimeError(f"The function {import_path} raised an exception: {e}")
def get_head_node_id() -> str:
"""Get the head node id.
    Iterates through all nodes in the Ray cluster and returns the node id of
    the first alive node that has the head node resource.
"""
head_node_id = None
for node in ray.nodes():
if HEAD_NODE_RESOURCE_NAME in node["Resources"] and node["Alive"]:
head_node_id = node["NodeID"]
break
assert head_node_id is not None, "Cannot find alive head node."
return head_node_id
def calculate_remaining_timeout(
*,
timeout_s: Optional[float],
start_time_s: float,
curr_time_s: float,
) -> Optional[float]:
"""Get the timeout remaining given an overall timeout, start time, and curr time.
    If the timeout passed in was `None` or negative, that timeout is returned
    directly. If the timeout is >= 0, the returned remaining timeout will
    always be >= 0.
"""
if timeout_s is None or timeout_s < 0:
return timeout_s
time_since_start_s = curr_time_s - start_time_s
return max(0, timeout_s - time_since_start_s)
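
# Worked example: with timeout_s=10, start_time_s=100.0, and curr_time_s=103.5,
# 3.5 seconds have elapsed, so the remaining timeout is 6.5 seconds.
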
def get_all_live_placement_group_names() -> List[str]:
"""Fetch and parse the Ray placement group table for live placement group names.
Placement groups are filtered based on their `scheduling_state`; any placement
group not in the "REMOVED" state is considered live.
"""
placement_group_table = ray.util.placement_group_table()
live_pg_names = []
for entry in placement_group_table.values():
pg_name = entry.get("name", "")
if (
pg_name
and entry.get("stats", {}).get("scheduling_state", "UNKNOWN") != "REMOVED"
):
live_pg_names.append(pg_name)
return live_pg_names