Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jax.jit を使って高速化するときのメモ #5

Closed
sotetsuk opened this issue Aug 27, 2022 · 4 comments
Closed

jax.jit を使って高速化するときのメモ #5

sotetsuk opened this issue Aug 27, 2022 · 4 comments

Comments

@sotetsuk
Copy link
Owner

sotetsuk commented Aug 27, 2022

jax のnumpyとの基本的な違いとして、配列操作がimmutable(コピーしかできない)

jax.jit を使うためには動的な計算が制限される。

  • 引数を使ったif
  • 引数を長さに使ったfor
  • 引数をindexにつかった配列操作

などが動かないはず。
大体としてこれらを使うとjitでコンパイルできる(なぜかはよくわからない)

TIPS

  • 基本的にはまず、jitなしの実装とテストを用意する
  • 少しずつ関数呼び出しの深い方からテストとjitが通るようにjitに書き換えていく
  • もとの実装の時点で深いネストの実装を避ける
  • 深いネストは細かく純関数に分ける
    • このとき、たとえばネスト内で複数の変数が更新されていたら(e.g., x,y,z)、これらすべてを引数にとって、同様に返すような変数にすると楽

if

基本構文。<true_fn>,<false_fn> の型は同じである必要がある。

y = jax.lax.cond(
   x > 0, # cond
   lambda: x ** 2,  # if true 
   lambda: x, # if false
)

and/or&,| を使う

x = jax.lax.cond(
    (x % 2 == 0) & (x % 3 == 0),
    lambda: x**2,
    lambda: x,
)

switch

3つ以上の条件分岐などに使える

for

n がstaticなら大丈夫

@jax.jit
def f():
    s = 0
    for i in range(8):
        s += 1
    return s
@sotetsuk
Copy link
Owner Author

@sotetsuk
Copy link
Owner Author

関数の途中でのリターンやbreakはできる?どうする?

@OkanoShinri
Copy link
Collaborator

OkanoShinri commented Sep 29, 2022

forはコンパイル時に全て展開されてかなり時間を食うので、固定長であっても使わない方が良い。
jax.lax.fori_loopjax.lax.mapに変える。

  • before
# 遅い
for xy in range(BOARD_SIZE * BOARD_SIZE):
  _, _, is_illegal = step(state, xy)
  legal_action = legal_action.at[xy].set(is_illegal)
  • after
# 速い
jax.lax.map(lambda xy: step(state, xy), jnp.arange(BOARD_SIZE * BOARD_SIZE))[2]

どうしてもforですべての要素を調べないといけない場合が生じる場合、アルゴリズムを見直した方が良い。
(stateに新たに変数を追加するくらいで遅くなったりはしない)
例:囲碁で石を取った時に、隣接する連に呼吸点を追加する部分

  • before
for _xy in range(BOARD_SIZE * BOARD_SIZE):
  for _around in [[-1, 0], [1, 0], [0, 1], [0, -1]]:
    # 全ての点の周囲四方を調べ、取られた石と一致したら呼吸点に追加
  • after
# 隣接しているかどうかをstateに記録(0:なし 1:呼吸点 2:石)
jax.lax.map(
  lambda l: jnp.where((l > 0) & surrounded_stones, 1, l),  # 呼吸点に追加
  _state.liberty,
)

whileは条件が合えば使える。
jaxの制限を意識しすぎて非効率的なアルゴリズムを組むくらいならjax.lax.while_loopを使った方が良い。

@sotetsuk
Copy link
Owner Author

sotetsuk commented Oct 5, 2022

  • 途中でのリターンはcondで分けてreturn以降を別関数へ分ける
  • fori_loop, map, scanを使い分ける
  • while_loop
  • どうすてもifがネストするときには条件に名前を付けてネストを浅くする

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants