FunMC: Add AIS/SMC.
PiperOrigin-RevId: 492849049
SiegeLordEx authored and tensorflower-gardener committed Dec 4, 2022
1 parent 6fc9a9e commit e816859
Showing 5 changed files with 489 additions and 40 deletions.
6 changes: 5 additions & 1 deletion spinoffs/fun_mc/fun_mc/dynamic/backend_jax/tf_on_jax.py
@@ -161,8 +161,10 @@ def _get_static_value(value):

tf.newaxis = None

_impl_np()(jnp.cumsum)
_impl_np()(jnp.exp)
_impl_np()(jnp.einsum)
_impl_np()(jnp.floor)
_impl_np()(jnp.float32)
_impl_np()(jnp.float64)
_impl_np()(jnp.int32)
@@ -179,20 +181,22 @@ def _get_static_value(value):
_impl_np()(jnp.zeros_like)
_impl_np()(jnp.transpose)
_impl_np(name='fill')(jnp.full)
_impl_np(['nn'])(jax.nn.softmax)
_impl_np(['math'])(jnp.ceil)
_impl_np(['math'])(jnp.log)
_impl_np(['math'], name='mod')(jnp.mod)
_impl_np(['math'])(jnp.sqrt)
_impl_np(['math'], name='is_finite')(jnp.isfinite)
_impl_np(['math'], name='is_nan')(jnp.isnan)
_impl_np(['math'], name='pow')(jnp.power)
_impl_np(['math'], name='reduce_all')(jnp.all)
_impl_np(['math'], name='reduce_prod')(jnp.prod)
_impl_np(['math'], name='reduce_variance')(jnp.var)
_impl_np(name='abs')(jnp.abs)
_impl_np(name='Tensor')(jnp.ndarray)
_impl_np(name='concat')(jnp.concatenate)
_impl_np(name='constant')(jnp.array)
_impl_np(name='expand_dims')(jnp.expand_dims)
_impl_np(['math'], name='reduce_all')(jnp.all)
_impl_np(name='reduce_max')(jnp.max)
_impl_np(name='reduce_mean')(jnp.mean)
_impl_np(name='reduce_sum')(jnp.sum)
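Each `_impl_np(...)` call registers a NumPy/JAX function under the corresponding TensorFlow name on the shim, so TF-style FunMC code can run unchanged on the JAX backend. A minimal sketch of what the registrations above enable, assuming the module exposes its stand-in TF module as `tf_on_jax.tf`:

```python
import jax.numpy as jnp
from fun_mc.dynamic.backend_jax import tf_on_jax

tf = tf_on_jax.tf  # assumed handle to the shim's stand-in `tf` module

x = jnp.array([1., 2., 3.])
tf.cumsum(x)                # dispatches to jnp.cumsum
tf.math.reduce_variance(x)  # dispatches to jnp.var
tf.nn.softmax(x)            # dispatches to jax.nn.softmax
```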
19 changes: 16 additions & 3 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py
@@ -20,12 +20,12 @@
from jax import lax
from jax import random
from jax import tree_util
from jax.example_libraries import stax
import jax.numpy as jnp

__all__ = [
    'assert_same_shallow_tree',
    'block_until_ready',
    'diff',
    'flatten_tree',
    'get_shallow_tree',
    'inverse_fn',
@@ -38,6 +38,7 @@
    'random_integer',
    'random_normal',
    'random_uniform',
    'repeat',
    'split_seed',
    'trace',
    'value_and_grad',
@@ -143,7 +144,7 @@ def body(state):

def random_categorical(logits, num_samples, seed):
"""Returns a sample from a categorical distribution. `logits` must be 2D."""
probs = stax.softmax(logits)
probs = jax.nn.softmax(logits)
cum_sum = jnp.cumsum(probs, axis=-1)

eta = random.uniform(
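The hunk is truncated, but the scheme is visible: softmax the logits, build a per-row CDF with `cumsum`, then invert the CDF at uniform draws. A self-contained sketch of that inverse-CDF sampler (an approximation of the idea, not the exact FunMC code):

```python
import jax
import jax.numpy as jnp

def categorical_sketch(logits, num_samples, seed):
  """Inverse-CDF categorical sampling; `logits` is [batch, num_classes]."""
  probs = jax.nn.softmax(logits)
  cum_sum = jnp.cumsum(probs, axis=-1)  # per-row CDF
  eta = jax.random.uniform(seed, (num_samples,) + cum_sum.shape[:-1])
  # The sampled class is the number of CDF entries that lie below the draw.
  return jnp.sum(eta[..., jnp.newaxis] > cum_sum, axis=-1)

samples = categorical_sketch(
    jnp.log(jnp.array([[0.25, 0.75]])), 4, jax.random.PRNGKey(0))
```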
@@ -211,6 +212,7 @@ def wrapper(i, state_untraced_traced):
    state, untraced, traced = fn(state)
    trace_arrays = map_tree(lambda a, e: a.at[i].set(e), trace_arrays, traced)
    return (state, untraced, trace_arrays)

  state, untraced, traced = lax.fori_loop(
      jnp.asarray(0, num_steps.dtype),
      num_steps,
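The surrounding pattern is JAX's standard trace loop: run `fn` once to learn shapes, preallocate per-step output arrays, and fill them with `.at[i].set(...)` inside `lax.fori_loop`. A stripped-down sketch with a single traced array and no untraced outputs:

```python
import jax.numpy as jnp
from jax import lax

def trace_sketch(state, fn, num_steps):
  state, first = fn(state)  # one step up front to learn shape/dtype
  arr = jnp.zeros((num_steps,) + first.shape, first.dtype).at[0].set(first)

  def body(i, carry):
    state, arr = carry
    state, out = fn(state)
    return state, arr.at[i].set(out)

  return lax.fori_loop(1, num_steps, body, (state, arr))
```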
@@ -250,7 +252,6 @@ def scale_by_two(x):
  assert y_extra == 3
  assert y_ldj == jnp.log(2)
  ```
  """
  value, (extra, ldj) = fn(args)
  return value, (extra, ldj), ldj
@@ -307,11 +308,13 @@ def block_until_ready(tensors):
  Returns:
    tensors: Tensors that are guaranteed to be ready to materialize.
  """

  def _block_until_ready(tensor):
    if hasattr(tensor, 'block_until_ready'):
      return tensor.block_until_ready()
    else:
      return tensor

  return map_tree(_block_until_ready, tensors)


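A typical use is timing under JAX's asynchronous dispatch, where a jitted call returns before the computation finishes. A hypothetical snippet using the `block_until_ready` defined above:

```python
import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: jnp.sin(x).sum())
x = jnp.ones(1_000_000)
f(x)  # warm-up call so compilation is excluded from the timing

start = time.perf_counter()
y = block_until_ready(f(x))  # block so the timer measures the real work
print(time.perf_counter() - start)
```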
@@ -326,3 +329,13 @@ def named_call(f=None, name=None):
    return functools.partial(named_call, name=name)

  return jax.named_call(f, name=name)
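Because the function returns a `functools.partial` of itself when `f` is `None`, it doubles as a decorator factory; the wrapped op then appears under the given name in JAX traces and profiles. Hypothetical usage:

```python
@named_call(name='leapfrog_step')
def leapfrog_step(x):
  return x + 1.
```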


def diff(x, prepend=None):
  """Like jnp.diff."""
  return jnp.diff(x, prepend=prepend)


def repeat(x, repeats, total_repeat_length=None):
  """Like jnp.repeat."""
  return jnp.repeat(x, repeats, total_repeat_length=total_repeat_length)
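Both are thin pass-throughs to `jnp`. Illustrative values (not from the commit):

```python
import jax.numpy as jnp

x = jnp.array([1., 3., 6.])
diff(x, prepend=0.)  # -> [1., 2., 3.]
repeat(jnp.array([1, 2]), jnp.array([2, 3]), total_repeat_length=5)
# -> [1, 1, 2, 2, 2]
```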
28 changes: 25 additions & 3 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py
@@ -26,6 +26,7 @@
__all__ = [
    'assert_same_shallow_tree',
    'block_until_ready',
    'diff',
    'flatten_tree',
    'get_shallow_tree',
    'inverse_fn',
@@ -38,6 +39,7 @@
    'random_integer',
    'random_normal',
    'random_uniform',
    'repeat',
    'split_seed',
    'trace',
    'value_and_ldj',
@@ -177,7 +179,9 @@ def trace(state, fn, num_steps, unroll, parallel_iterations=10):
    state, first_untraced, first_traced = fn(state)
    arrays = tf.nest.map_structure(
        lambda v: tf.TensorArray(  # pylint: disable=g-long-lambda
            v.dtype, size=num_steps, element_shape=v.shape).write(0, v),
            v.dtype,
            size=num_steps,
            element_shape=v.shape).write(0, v),
        first_traced)
    start_idx = 1
  else:
@@ -189,7 +193,10 @@ def trace(state, fn, num_steps, unroll, parallel_iterations=10):

    arrays = tf.nest.map_structure(
        lambda spec: tf.TensorArray(  # pylint: disable=g-long-lambda
            spec.dtype, size=num_steps, element_shape=spec.shape), traced_spec)
            spec.dtype,
            size=num_steps,
            element_shape=spec.shape),
        traced_spec)
    first_untraced = tf.nest.map_structure(
        lambda spec: tf.zeros(spec.shape, spec.dtype), untraced_spec)
    start_idx = 0
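Both branches follow the same pattern: allocate one `tf.TensorArray` per traced leaf, then write into it each step. A minimal standalone sketch of that accumulation (single traced tensor, eager-friendly; not the exact FunMC code):

```python
import tensorflow as tf

def trace_sketch(state, fn, num_steps):
  state, first = fn(state)  # first step fixes the element shape/dtype
  arr = tf.TensorArray(
      first.dtype, size=num_steps, element_shape=first.shape).write(0, first)
  for i in tf.range(1, num_steps):
    state, out = fn(state)
    arr = arr.write(i, out)
  return state, arr.stack()
```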
@@ -266,7 +273,6 @@ def scale_by_two(x):
  assert y_extra == 3
  assert y_ldj == np.log(2)
  ```
  """
  value, (extra, ldj) = fn(args)
  return value, (extra, ldj), ldj
@@ -348,3 +354,19 @@ def wrapped(*args, **kwargs):
    return f(*args, **kwargs)

  return wrapped


def diff(x, prepend=None):
  """Like jnp.diff."""
  if prepend is not None:
    x = tf.concat([tf.convert_to_tensor(prepend, dtype=x.dtype)[tf.newaxis], x],
                  0)
  return x[1:] - x[:-1]


def repeat(x, repeats, total_repeat_length=None):
  """Like jnp.repeat."""
  res = tf.repeat(x, repeats)
  if total_repeat_length is not None:
    res.set_shape([total_repeat_length] + [None] * (len(res.shape) - 1))
  return res
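These mirror the JAX backend's versions; the `set_shape` call restores the static leading dimension that `tf.repeat` with tensor-valued `repeats` would otherwise lose. Illustrative values (not from the commit):

```python
import tensorflow as tf

x = tf.constant([1., 3., 6.])
diff(x, prepend=0.)  # -> [1., 2., 3.]

r = repeat(tf.constant([1, 2]), tf.constant([2, 3]), total_repeat_length=5)
r.shape  # -> (5,), statically known thanks to set_shape
```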