-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
__init__.py
145 lines (123 loc) · 3.16 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import contextlib
from functools import partial
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.filter_manager import FilterManager
from ray.rllib.utils.framework import (
try_import_jax,
try_import_tf,
try_import_tfp,
try_import_torch,
)
from ray.rllib.utils.numpy import (
sigmoid,
softmax,
relu,
one_hot,
fc,
lstm,
SMALL_NUMBER,
LARGE_INTEGER,
MIN_LOG_NN_OUTPUT,
MAX_LOG_NN_OUTPUT,
)
from ray.rllib.utils.pre_checks.env import check_env
from ray.rllib.utils.schedules import (
LinearSchedule,
PiecewiseSchedule,
PolynomialSchedule,
ExponentialSchedule,
ConstantSchedule,
)
from ray.rllib.utils.test_utils import (
check,
check_compute_single_action,
check_train_results,
framework_iterator,
)
from ray.tune.utils import merge_dicts, deep_update
@DeveloperAPI
def add_mixins(base, mixins, reversed=False):
    """Return a new class that combines ``base`` with every class in ``mixins``.

    Mixins are taken from the end of ``mixins`` toward the front. Each step
    defines a fresh ``new_base`` class whose bases are the class built so far
    plus the current mixin.

    For classes that are otherwise unrelated, the resulting MRO (ignoring the
    intermediate ``new_base`` classes) is:
    * ``mixins[0], mixins[1], ..., base`` when ``reversed`` is False, i.e.
      earlier mixins take precedence and all mixins override ``base``;
    * the exact reverse, ``base, mixins[-1], ..., mixins[0]``, when
      ``reversed`` is True.

    Args:
        base: The class to extend.
        mixins: An iterable of mixin classes; None or empty means no mixins.
        reversed: Whether to put the accumulated class before the mixin in
            each new class's bases (note: this parameter name shadows the
            builtin ``reversed`` inside this function).

    Returns:
        The combined class, or ``base`` itself if there are no mixins.
    """
    pending = list(mixins or [])
    result = base
    # Walk back-to-front, mirroring repeated pops from the end of the list.
    for mixin in pending[::-1]:
        bases = (result, mixin) if reversed else (mixin, result)

        class new_base(*bases):
            pass

        result = new_base
    return result
@DeveloperAPI
def force_list(elements=None, to_tuple=False):
    """Coerce ``elements`` into a list (or, optionally, a tuple).

    Args:
        elements (Optional[any]): A single item, or a list/set/tuple of
            items. Only objects whose *exact* type is list, set or tuple are
            unpacked; anything else (including subclasses of those types) is
            wrapped as a single item. None yields an empty container.
        to_tuple: If this is exactly ``True``, a tuple is returned instead of
            a list.

    Returns:
        Union[list, tuple]: The items of ``elements`` (a set's items in its
        iteration order), or ``elements`` itself as the only item, or an
        empty container when ``elements`` is None.
    """
    container = tuple if to_tuple is True else list
    if elements is None:
        return container()
    if type(elements) in (list, set, tuple):
        return container(elements)
    # Anything else counts as one item.
    return container((elements,))
@DeveloperAPI
class NullContextManager(contextlib.AbstractContextManager):
    """A context manager that does nothing on entry or on exit.

    ``__enter__`` is overridden to return None instead of ``self`` (the
    ``AbstractContextManager`` default). So ``with NullContextManager() as x``
    binds ``x`` to None. ``__exit__`` also returns None, so exceptions raised
    inside the ``with`` body are not suppressed.
    """

    def __init__(self):
        # Takes no arguments and holds no state.
        return None

    def __enter__(self):
        return None

    def __exit__(self, *args):
        # A falsy result lets any in-flight exception propagate.
        return None
# Variant of `force_list` that returns a tuple: `to_tuple=True` is pre-bound
# via functools.partial (as a keyword, so a caller could still override it).
force_tuple = partial(force_list, to_tuple=True)
# Public names exported by a star-import of this package: lowercase names
# first, then capitalized ones, each group sorted alphabetically. Helpers
# defined in this module (incl. the @DeveloperAPI NullContextManager) are
# listed alongside the re-exported ones.
__all__ = [
    "add_mixins",
    "check",
    "check_compute_single_action",
    "check_env",
    "check_train_results",
    "deep_update",
    "deprecation_warning",
    "fc",
    "force_list",
    "force_tuple",
    "framework_iterator",
    "lstm",
    "merge_dicts",
    "one_hot",
    "override",
    "relu",
    "sigmoid",
    "softmax",
    "try_import_jax",
    "try_import_tf",
    "try_import_tfp",
    "try_import_torch",
    "ConstantSchedule",
    "DeveloperAPI",
    "ExponentialSchedule",
    "Filter",
    "FilterManager",
    "LARGE_INTEGER",
    "LinearSchedule",
    "MAX_LOG_NN_OUTPUT",
    "MIN_LOG_NN_OUTPUT",
    "NullContextManager",
    "PiecewiseSchedule",
    "PolynomialSchedule",
    "PublicAPI",
    "SMALL_NUMBER",
]