Skip to content

Commit 0097c02

Browse files
committed
remove unnecessary device_put
1 parent 8a67269 commit 0097c02

File tree

7 files changed

+13
-14
lines changed

7 files changed

+13
-14
lines changed

numpyro/contrib/control_flow/cond.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from functools import partial
55
from typing import Any, Callable
66

7-
from jax import device_put, lax
7+
from jax import lax
88

99
from numpyro import handlers
1010
from numpyro.ops.pytree import PytreeTrace
@@ -69,7 +69,7 @@ def cond_wrapper(
6969

7070
wrapped_true_fun = wrap_fn(true_fun, substitute_stack)
7171
wrapped_false_fun = wrap_fn(false_fun, substitute_stack)
72-
wrapped_operand = device_put((rng_key, operand))
72+
wrapped_operand = (rng_key, operand)
7373
return lax.cond(pred, wrapped_true_fun, wrapped_false_fun, wrapped_operand)
7474

7575

numpyro/contrib/control_flow/scan.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Callable, Optional
77

88
import jax
9-
from jax import device_put, lax, random
9+
from jax import lax, random
1010
import jax.numpy as jnp
1111

1212
from numpyro import handlers
@@ -228,7 +228,6 @@ def body_fn(wrapped_carry, x, prefix=None):
228228
# return early if length = unroll_steps
229229
if length == unroll_steps:
230230
return wrapped_carry, (PytreeTrace({}), y0s)
231-
wrapped_carry = jax.tree.map(device_put, wrapped_carry)
232231
wrapped_carry, (pytree_trace, ys) = lax.scan(
233232
body_fn, wrapped_carry, xs_, length - unroll_steps, reverse
234233
)
@@ -331,7 +330,7 @@ def body_fn(wrapped_carry, x):
331330

332331
return (i + 1, rng_key, carry), (PytreeTrace(trace), y)
333332

334-
wrapped_carry = jax.tree.map(device_put, (0, rng_key, init))
333+
wrapped_carry = (jnp.asarray(0), rng_key, init)
335334
last_carry, (pytree_trace, ys) = lax.scan(
336335
body_fn, wrapped_carry, xs, length=length, reverse=reverse
337336
)

numpyro/infer/barker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
208208
wa_state,
209209
rng_key,
210210
)
211-
return jax.device_put(init_state)
211+
return init_state
212212

213213
def postprocess_fn(self, args, kwargs):
214214
if self._postprocess_fn is None:

numpyro/infer/hmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import os
88
import warnings
99

10-
from jax import device_put, lax, random, vmap
10+
from jax import lax, random, vmap
1111
from jax.flatten_util import ravel_pytree
1212
import jax.numpy as jnp
1313

@@ -359,7 +359,7 @@ def init_kernel(
359359
wa_state,
360360
rng_key_hmc,
361361
)
362-
return device_put(hmc_state)
362+
return hmc_state
363363

364364
def _hmc_next(
365365
step_size,

numpyro/infer/hmc_gibbs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import numpy as np
99

10-
from jax import device_put, grad, jacfwd, random, value_and_grad
10+
from jax import grad, jacfwd, random, value_and_grad
1111
from jax.flatten_util import ravel_pytree
1212
import jax.numpy as jnp
1313
from jax.scipy.special import expit
@@ -148,7 +148,7 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
148148

149149
z = {**gibbs_sites, **hmc_state.z}
150150

151-
return device_put(HMCGibbsState(z, hmc_state, rng_key))
151+
return HMCGibbsState(z, hmc_state, rng_key)
152152

153153
def sample(self, state, model_args, model_kwargs):
154154
model_kwargs = {} if model_kwargs is None else model_kwargs

numpyro/infer/sa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from collections import namedtuple
55

6-
from jax import device_put, lax, random, vmap
6+
from jax import lax, random, vmap
77
from jax.flatten_util import ravel_pytree
88
import jax.numpy as jnp
99
from jax.scipy.special import logsumexp
@@ -174,7 +174,7 @@ def init_kernel(
174174
adapt_state,
175175
rng_key_sa,
176176
)
177-
return device_put(sa_state)
177+
return sa_state
178178

179179
def sample_kernel(sa_state, model_args=(), model_kwargs=None):
180180
pe_fn = potential_fn

numpyro/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from tqdm.auto import tqdm as tqdm_auto
1919

2020
import jax
21-
from jax import device_put, jit, lax, vmap
21+
from jax import jit, lax, vmap
2222
from jax.core import Tracer
2323
from jax.experimental import io_callback
2424
import jax.numpy as jnp
@@ -386,7 +386,7 @@ def loop_fn(collection):
386386
diagnostics_fn = progbar_opts.pop("diagnostics_fn", None)
387387
progbar_desc = progbar_opts.pop("progbar_desc", lambda x: "")
388388

389-
vals = (init_val, collection, device_put(start_idx), device_put(thinning))
389+
vals = (init_val, collection, jnp.asarray(start_idx), jnp.asarray(thinning))
390390

391391
if upper == 0:
392392
# special case, only compiling

0 commit comments

Comments
 (0)