JointDistributionCoroutineAutoBatched seed behaviour with batches #2011

@chrism0dwk

Description

Hi all,

I'm trying to compute a posterior predictive distribution over samples from a posterior distribution (Colab here). TFP 0.25 with JAX backend.

My model specification (a minimal reproducible example, and therefore contrived) is:

@tfd.JointDistributionCoroutineAutoBatched
def model_autobatched():
    theta = yield tfd.Normal(loc=0., scale=1., name="theta")
    yield tfd.Normal(loc=theta, scale=0.1, name="y")

i.e. a Normally-distributed observation model with Normally-distributed mean. To compute the posterior predictive distribution, I wish to sample the y component conditional on a vector of theta samples.

theta_samples = np.arange(5.)
model_autobatched.sample(theta=theta_samples, seed=jax.random.key(0))

giving

StructTuple(
  theta=Array([0., 1., 2., 3., 4.], dtype=float32),
  y=Array([0.06215769, 1.0621576, 2.0621576, 3.0621576, 4.0621576], dtype=float32)
)

Oh dear: we notice that y - theta is constant across the batch. This suggests that a single PRNG key is being reused for every draw of y given the corresponding theta sample, rather than a fresh key being split off per batch member.
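The symptom can be reproduced in plain JAX, outside TFP entirely; this hypothetical sketch contrasts reusing one key for the whole batch (constant residual, as observed above) with splitting one subkey per batch member (independent noise):

```python
import jax
import numpy as np

theta = np.arange(5.0)
key = jax.random.key(0)

# Same key reused for the whole batch: a single standard-normal draw is
# shared, so y_shared - theta is constant.
eps_shared = jax.random.normal(key)
y_shared = theta + 0.1 * eps_shared

# One subkey per batch member: each conditional draw gets independent noise.
subkeys = jax.random.split(key, len(theta))
eps_split = jax.vmap(jax.random.normal)(subkeys)
y_split = theta + 0.1 * eps_split
```

This is the behaviour one would expect sample(theta=...) to produce internally.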

Moreover, this approach fails entirely if sample_distributions is called:

model_autobatched.sample_distributions(theta=theta_samples, seed=jax.random.key(0))
ValueError: Attempt to convert a value (<object object at 0x7a53561590d0>) with an unsupported type (<class 'object'>) to a Tensor.

As a workaround, we could use the older JointDistributionCoroutine with Root annotations, which appears to work as desired (see Colab).

[edit] actually, JDCoroutine/Root only works here because the whole theta vector is passed to y's constructor, not because the model as a whole is vectorised.

Do we have a bug or a feature, I wonder?

Regards,

Chris
