Skip to content

Commit

Permalink
transfer_states_to_host convenience function (#1707)
Browse files Browse the repository at this point in the history
  • Loading branch information
amifalk committed Dec 24, 2023
1 parent a6693bb commit 09551d9
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import numpy as np

from jax import jit, lax, local_device_count, pmap, random, vmap
from jax import jit, lax, local_device_count, pmap, random, vmap, device_get
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_map

Expand Down Expand Up @@ -721,6 +721,13 @@ def print_summary(self, prob=0.9, exclude_deterministic=True):
"Number of divergences: {}".format(jnp.sum(extra_fields["diverging"]))
)

def transfer_states_to_host(self):
"""
Reduce the memory footprint of collected samples by transfering them to the host device.
"""
self._states = device_get(self._states)
self._states_flat = device_get(self._states_flat)

def __getstate__(self):
state = self.__dict__.copy()
state["_cache"] = {}
Expand Down

0 comments on commit 09551d9

Please sign in to comment.