-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
jax.jit
を使って高速化するときのメモ
#5
Comments
関数の途中でのリターンやbreakはできる?どうする? |
forはコンパイル時に全て展開されてかなり時間を食うので、固定長であっても使わない方が良い。
# 遅い
for xy in range(BOARD_SIZE * BOARD_SIZE):
_, _, is_illegal = step(state, xy)
legal_action = legal_action.at[xy].set(is_illegal)
# 速い
jax.lax.map(lambda xy: step(state, xy), jnp.arange(BOARD_SIZE * BOARD_SIZE))[2] どうしても
for _xy in range(BOARD_SIZE * BOARD_SIZE):
for _around in [[-1, 0], [1, 0], [0, 1], [0, -1]]:
# 全ての点の周囲四方を調べ、取られた石と一致したら呼吸点に追加
# 隣接しているかどうかをstateに記録(0:なし 1:呼吸点 2:石)
jax.lax.map(
lambda l: jnp.where((l > 0) & surrounded_stones, 1, l), # 呼吸点に追加
_state.liberty,
) whileは条件が合えば使える。 |
|
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
jax
のnumpyとの基本的な違いとして、配列操作がimmutable(コピーしかできない)jax.jit
を使うためには動的な計算が制限される。などが動かないはず。
大体としてこれらを使うとjitでコンパイルできる(なぜかはよくわからない)
jax.lax.max
jax.lax.min
jax.lax.fori_loop
jax.lax.where
TIPS
if
基本構文。
<true_fn>
と,<false_fn>
の型は同じである必要がある。and/or
は&,|
を使うswitch
3つ以上の条件分岐などに使える
for
n
がstaticなら大丈夫The text was updated successfully, but these errors were encountered: