-
Notifications
You must be signed in to change notification settings - Fork 81
/
jax.py
140 lines (104 loc) · 3.98 KB
/
jax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
from __future__ import annotations
import enum
import threading
import weakref
import awkward as ak
from awkward import highlevel
from awkward._nplikes.numpy import Numpy
from awkward._typing import TypeVar
# NumPy-backed nplike obtained from `Numpy.instance()`. It is not referenced by
# the functions in this module; NOTE(review): presumably kept for importers of
# this module — confirm it is still needed.
numpy = Numpy.instance()
def assert_never(arg) -> None:
    """
    Fail loudly when supposedly unreachable code is reached.

    Args:
        arg: the unexpected value; it is interpolated into the error message.

    Raises:
        AssertionError: always.
    """
    message = f"this should never be run: {arg}"
    raise AssertionError(message)
class _RegistrationState(enum.Enum):
INIT = enum.auto()
SUCCESS = enum.auto()
FAILED = enum.auto()
# Lock held by the functions in this module while they read or update
# `_registration_state`. It is reentrant, so a thread already holding it may
# acquire it again; NOTE(review): presumably this protects against registration
# code re-entering these helpers on the same thread — confirm.
_registration_lock = threading.RLock()
# Current outcome of the JAX registration; within this module only `_register()`
# assigns it.
_registration_state = _RegistrationState.INIT
def register_and_check():
    """
    Verify that JAX can be imported, then register Awkward Array node types
    (and any queued high-level behavior classes) with JAX's tree mechanism.

    Only the first completed registration attempt does any work; later calls
    still check the import but do not re-register.

    Raises:
        ModuleNotFoundError: if the ``jax`` package cannot be found. The message
            explains how to install it, and the original import error is
            suppressed.
    """
    try:
        import jax  # noqa: TID251, F401
    except ModuleNotFoundError:
        install_hint = """install the 'jax' package with:
python3 -m pip install jax jaxlib
or
conda install -c conda-forge jax jaxlib
"""
        raise ModuleNotFoundError(install_hint) from None

    _register()
# Type variable for the classes accepted by `register_behavior_class`: the
# class objects `highlevel.Array` / `highlevel.Record` or subclasses of them.
# The bound is given as a string, so it is not evaluated when the TypeVar is
# created.
HighLevelType = TypeVar(
    "HighLevelType", bound="type[highlevel.Array | highlevel.Record]"
)
# High-level classes for `_register()` to register with JAX. Seeded with the two
# built-in high-level classes; `register_behavior_class` adds to it whenever
# registration has not (yet) succeeded. A WeakSet holds its members weakly, so a
# class that is garbage-collected before registration simply drops out.
_known_highlevel_classes = weakref.WeakSet([highlevel.Array, highlevel.Record])
def register_behavior_class(cls: HighLevelType):
    """
    Register a high-level behavior class with JAX, or queue it for later.

    If JAX registration has already succeeded, ``cls`` is registered as a
    pytree class immediately. Otherwise it is added to
    ``_known_highlevel_classes``, which ``_register()`` processes when it runs
    while the state is still INIT. After a FAILED registration attempt the class
    is only queued, because ``_register()`` does not retry.

    Args:
        cls: behavior class to register with JAX
    """
    # Holding the lock makes the state check and the resulting action atomic
    # with respect to `_register()`.
    with _registration_lock:
        if _registration_state != _RegistrationState.SUCCESS:
            _known_highlevel_classes.add(cls)
            return

        # Registration succeeded, so calling into the JAX connector is safe.
        import awkward._connect.jax as jax_connect

        jax_connect.register_pytree_class(cls)
def _register():
    """
    Register Awkward Array node types with JAX's tree mechanism.

    Passes each ``ak.contents`` node class listed below, ``ak.record.Record``,
    and every class queued in ``_known_highlevel_classes`` to
    ``awkward._connect.jax.register_pytree_class``, in that order.

    Only one attempt ever does work: once the state has left INIT (SUCCESS or
    FAILED), later calls return immediately.

    Raises:
        Exception: whatever the connector import or a registration call raised.
            The state is set to FAILED before the exception propagates.
    """
    global _registration_state

    # The lock ensures that concurrent callers cannot both observe INIT and
    # register twice.
    with _registration_lock:
        if _registration_state != _RegistrationState.INIT:
            return

        try:
            import awkward._connect.jax as jax_connect

            layout_node_types = (
                ak.contents.BitMaskedArray,
                ak.contents.ByteMaskedArray,
                ak.contents.EmptyArray,
                ak.contents.IndexedArray,
                ak.contents.IndexedOptionArray,
                ak.contents.NumpyArray,
                ak.contents.ListArray,
                ak.contents.ListOffsetArray,
                ak.contents.RecordArray,
                ak.contents.UnionArray,
                ak.contents.UnmaskedArray,
                ak.record.Record,
            )
            for node_type in layout_node_types:
                jax_connect.register_pytree_class(node_type)
            for highlevel_cls in _known_highlevel_classes:
                jax_connect.register_pytree_class(highlevel_cls)
        except Exception:
            # Record the failure so that `assert_registered` can report it.
            _registration_state = _RegistrationState.FAILED
            raise

        _registration_state = _RegistrationState.SUCCESS
def assert_registered():
    """
    Ensure that JAX integration is registered.

    Returns normally only when registration has succeeded.

    Raises:
        RuntimeError: if no registration attempt has completed (for example,
            ``ak.jax.register_and_check()`` was never called), or if the last
            attempt failed.
    """
    with _registration_lock:
        state = _registration_state

        if state == _RegistrationState.SUCCESS:
            return
        if state == _RegistrationState.INIT:
            raise RuntimeError("JAX features require `ak.jax.register_and_check()`")
        if state == _RegistrationState.FAILED:
            raise RuntimeError(
                "JAX features require `ak.jax.register_and_check()`, "
                "but the last call to `ak.jax.register_and_check()` did not succeed. "
                "Please look for a traceback to identify the error."
            )

        # Unreachable while _RegistrationState has exactly these three members.
        assert_never(state)
def import_jax():
    """
    Return the ``jax`` module after confirming that JAX integration is registered.

    Raises:
        RuntimeError: propagated from ``assert_registered`` when registration
            has not been attempted or did not succeed.
    """
    assert_registered()

    import jax as jax_module  # noqa: TID251

    return jax_module