|
6 | 6 | from typing import Callable, Optional
|
7 | 7 |
|
8 | 8 | import jax
|
9 |
| -from jax import device_put, lax, random |
| 9 | +from jax import lax, random |
10 | 10 | import jax.numpy as jnp
|
11 | 11 |
|
12 | 12 | from numpyro import handlers
|
@@ -228,7 +228,6 @@ def body_fn(wrapped_carry, x, prefix=None):
|
228 | 228 | # return early if length = unroll_steps
|
229 | 229 | if length == unroll_steps:
|
230 | 230 | return wrapped_carry, (PytreeTrace({}), y0s)
|
231 |
| - wrapped_carry = jax.tree.map(device_put, wrapped_carry) |
232 | 231 | wrapped_carry, (pytree_trace, ys) = lax.scan(
|
233 | 232 | body_fn, wrapped_carry, xs_, length - unroll_steps, reverse
|
234 | 233 | )
|
@@ -331,7 +330,7 @@ def body_fn(wrapped_carry, x):
|
331 | 330 |
|
332 | 331 | return (i + 1, rng_key, carry), (PytreeTrace(trace), y)
|
333 | 332 |
|
334 |
| - wrapped_carry = jax.tree.map(device_put, (0, rng_key, init)) |
| 333 | + wrapped_carry = (jnp.asarray(0), rng_key, init) |
335 | 334 | last_carry, (pytree_trace, ys) = lax.scan(
|
336 | 335 | body_fn, wrapped_carry, xs, length=length, reverse=reverse
|
337 | 336 | )
|
|
0 commit comments