Outdated documentation #938

Closed
AdrienCorenflos opened this issue May 18, 2020 · 7 comments

Comments

@AdrienCorenflos
Contributor

Hi,

This doc page is outdated:
https://www.tensorflow.org/probability/api_docs/python/tfp/util/SeedStream
By the way, after correcting the code in your example, it does not return 0.5 all the time in eager mode:

import tensorflow as tf

def broken_beta(shape, alpha, beta, seed):
  # The SeedStream doc claims that reusing one seed makes both gamma draws
  # identical, so this would always return 0.5; in eager mode it does not.
  x = tf.random.gamma(shape, alpha, seed=seed)
  y = tf.random.gamma(shape, beta, seed=seed)
  return x / (x + y)
@csuter
Member

csuter commented May 18, 2020

Hi, can you clarify what is outdated about the documentation? It sounds like that's separate from the eager seed behavior?

Eager seed weirdness is a known issue; it's documented here: https://www.tensorflow.org/api_docs/python/tf/random/set_seed

The appropriate workaround is to use the newer stateless sampler and/or Generator APIs. TFP is in the process of migrating all our samplers to use the stateless API, which will make the behavior in eager mode more like what one would expect.
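
For illustration, a stateless version of the snippet above could look like this. This is a sketch assuming a TF version that ships tf.random.stateless_gamma, which takes a shape-[2] integer seed and is a pure function of it; offsetting the seed is a crude way to derive a second one (newer TF versions provide tf.random.experimental.stateless_split for exactly this):

import tensorflow as tf

def stateless_beta(shape, alpha, beta, seed):
  # `seed` is a shape-[2] int tensor; derive a distinct second seed so the
  # two gamma draws are independent while the function stays deterministic.
  x = tf.random.stateless_gamma(shape, seed=seed, alpha=alpha)
  y = tf.random.stateless_gamma(shape, seed=seed + [0, 1], alpha=beta)
  return x / (x + y)

print(stateless_beta([2], 2., 3., tf.constant([4, 2])))  # same output on every call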

@AdrienCorenflos
Contributor Author

tf.random_gamma has not been part of the interface for a while. You should replace it with the snippet I provided.

Also, you may want to add a tf.function (autograph) decorator to the example, otherwise it's a bit misleading.

I am aware of the random seed issue; it's actually hurting me badly (the tf.random.set_seed operation is very slow and scales linearly with the number of operations in the graph), so I was looking at your code to try to hack my way around the issue by feeding the seed as a variable, but it's not working (gen_random_ops needs a genuine int as a seed and won't take a tensor). That's how I came across this outdated doc.

@csuter
Member

csuter commented May 18, 2020

Depending on what you're doing, you might be able to get around the seed issues today by passing either a list of 2 ints or a Tensor of 2 ints as the seed argument. This is what enables stateless sampling in the TFP API now. The migration is still in progress, so not all samplers will work (hence "depending on what you're doing"). I think we have quite a lot of the bases covered now, but it's not yet well documented.

The new entry point for most TFP sampling is here: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/internal/samplers.py
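
For example, with a distribution whose sampler has already been migrated, sampling becomes a pure function of the two-int seed (a sketch; whether a given sampler accepts this yet depends on the migration progress mentioned above):

import tensorflow as tf
import tensorflow_probability as tfp

dist = tfp.distributions.Normal(0., 1.)

# A stateless seed is a list or Tensor of two ints.
seed = tf.constant([4, 2], dtype=tf.int32)
x = dist.sample(seed=seed)
y = dist.sample(seed=seed)
# For migrated samplers, x == y in both eager and graph mode.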

@AdrienCorenflos
Contributor Author

AdrienCorenflos commented May 19, 2020

To be honest, ideally I would like to be able to pass a random state, à la numpy. Is TFP going to support this kind of interface?

It feels a bit artificial to have to construct something like this (if I understand what you are suggesting correctly):

def random_fun(global_seed):
    local_seed = tf.Variable(0)
    @tf.function
    def random_fun_inner(global_seed):
        x = dist.sample(seed=(global_seed, local_seed.assign_add(1)))
        y = dist.sample(seed=(global_seed, local_seed.assign_add(1)))
        return x + y
    reset_seed = local_seed.assign(0)
    with tf.control_dependencies([reset_seed]):
        return random_fun_inner(global_seed)

Something like this would feel much more natural as a user interface:

@tf.function
def random_fun(generator):
    x = dist.sample(generator=generator)
    y = dist.sample(generator=generator)
    return x + y

And it would be the job of the generator to increment itself.
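
For comparison, TF's stateful tf.random.Generator already works roughly this way for the raw TF samplers (though, as far as I can tell, TFP distributions don't accept a generator argument). A minimal sketch, assuming a TF version where tf.random.Generator is public:

import tensorflow as tf

g = tf.random.Generator.from_seed(42)

@tf.function
def random_fun(generator):
    # The generator advances its own internal state (a hidden tf.Variable)
    # on each draw, so there is no seed bookkeeping at the call site.
    x = generator.normal(shape=[])
    y = generator.normal(shape=[])
    return x + y

print(random_fun(g))  # a new value on each call; the state lives in `g`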

@AdrienCorenflos
Contributor Author

FYI I have tested this on Colab (with the global and local seeds as either the first or second element), but it doesn't work:

import tensorflow as tf
import tensorflow_probability as tfp

dist = tfp.distributions.Normal(0., 1.)

def random_fun(global_seed):
    local_seed = tf.Variable(0)
    @tf.function
    def random_fun_inner(global_seed):
        seed = [local_seed.assign_add(1), global_seed]
        x = dist.sample(seed=seed)
        seed = [local_seed.assign_add(1), global_seed]
        y = dist.sample(seed=seed)
        return x + y
    reset_seed = local_seed.assign(0)
    with tf.control_dependencies([reset_seed]):
        return random_fun_inner(global_seed)

@csuter
Member

csuter commented May 20, 2020

I'd recommend avoiding explicit Variables, though IIUC this is how the TF Generator API works under the hood. The pattern we're moving towards in TFP is the following, styled after JAX (see their excellent PRNG doc):

The API is still in our "internal" package, but we plan to make it part of the public API soon:

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import samplers

dist = tfp.distributions.Normal(0., 1.)

def random_fun(seed):
    @tf.function
    def random_fun_inner(seed):
        # Deterministically split the incoming seed into two sub-seeds; the
        # salt string decorrelates this call site from other uses of `seed`.
        x_seed, y_seed = samplers.split_seed(seed, salt="random_fun_inner")
        x = dist.sample(seed=x_seed)
        y = dist.sample(seed=y_seed)
        return x + y

    _, new_seed = samplers.split_seed(seed, salt="random_fun")
    return random_fun_inner(new_seed)

init_seed = 0
print(random_fun(init_seed))
# => 0.1234
print(random_fun(init_seed))
# => 0.1234
_, new_seed = samplers.split_seed(init_seed)
print(random_fun(new_seed))
# => 0.6789

@srvasude
Member

I'm going to close this since the documentation issue is fixed, but feel free to reopen if there are other issues.
