Skip to content

Commit

Permalink
Avoid recompilation of rolled loops in threefry2x32. (jax-ml#3069)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomhennigan committed May 12, 2020
1 parent 11760ca commit abdf504
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,25 @@ def _threefry2x32_abstract_eval(*args):
aval = abstract_arrays.UnshapedArray(np.dtype(np.uint32))
return (aval,) * 2

rotate_left = _make_rotate_left(onp.uint32)

def apply_round(v, rot):
v = v[:]
v[0] = v[0] + v[1]
v[1] = rotate_left(v[1], rot)
v[1] = v[0] ^ v[1]
return v

def rotate_list(xs):
return xs[1:] + xs[:1]

def rolled_loop_step(i, state):
x, ks, rotations = state
for r in rotations[0]:
x = apply_round(x, r)
new_x = [x[0] + ks[0], x[1] + ks[1] + asarray(i + 1, dtype=onp.uint32)]
return new_x, rotate_list(ks), rotate_list(rotations)

def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
"""Apply the Threefry 2x32 hash.
Expand All @@ -119,15 +138,6 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
An array of dtype uint32 with the same shape as `count`.
"""
x = [x1, x2]
rotate_left = _make_rotate_left(onp.uint32)

def apply_round(v, rot):
v = v[:]
v[0] = v[0] + v[1]
v[1] = rotate_left(v[1], rot)
v[1] = v[0] ^ v[1]
return v


rotations = [onp.array([13, 15, 26, 6], dtype=onp.uint32),
onp.array([17, 29, 16, 24], dtype=onp.uint32)]
Expand All @@ -137,14 +147,7 @@ def apply_round(v, rot):
x[1] = x[1] + ks[1]

if use_rolled_loops:
def rotate_list(xs): return xs[1:] + xs[:1]
def step(i, state):
x, ks, rotations = state
for r in rotations[0]:
x = apply_round(x, r)
new_x = [x[0] + ks[0], x[1] + ks[1] + asarray(i + 1, dtype=onp.uint32)]
return new_x, rotate_list(ks), rotate_list(rotations)
x, _, _ = lax.fori_loop(0, 5, step, (x, rotate_list(ks), rotations))
x, _, _ = lax.fori_loop(0, 5, rolled_loop_step, (x, rotate_list(ks), rotations))

else:
for r in rotations[0]:
Expand Down

0 comments on commit abdf504

Please sign in to comment.