diff --git a/tensorflow_probability/python/edward2/random_variable.py b/tensorflow_probability/python/edward2/random_variable.py index 4991d2251b..f2b385e1de 100644 --- a/tensorflow_probability/python/edward2/random_variable.py +++ b/tensorflow_probability/python/edward2/random_variable.py @@ -23,6 +23,7 @@ from tensorflow.python.client import session as tf_session from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util __all__ = [ "RandomVariable", @@ -106,12 +107,11 @@ def __init__(self, NotImplementedError: `distribution` does not have a `sample` method. """ self._distribution = distribution - - self._sample_shape = tf.TensorShape(sample_shape) + self._sample_shape = sample_shape if value is not None: t_value = tf.convert_to_tensor(value, self.distribution.dtype) value_shape = t_value.shape - expected_shape = self._sample_shape.concatenate( + expected_shape = self.sample_shape.concatenate( self.distribution.batch_shape).concatenate( self.distribution.event_shape) if not value_shape.is_compatible_with(expected_shape): @@ -122,7 +122,7 @@ def __init__(self, self._value = t_value else: try: - self._value = self.distribution.sample(self._sample_shape) + self._value = self.distribution.sample(self.sample_shape_tensor()) except NotImplementedError: raise NotImplementedError( "sample is not implemented for {0}. You must either pass in the " @@ -141,8 +141,24 @@ def dtype(self): @property def sample_shape(self): - """Sample shape of random variable.""" - return self._sample_shape + """Sample shape of random variable as a `TensorShape`.""" + if isinstance(self._sample_shape, tf.Tensor): + return tf.TensorShape(tensor_util.constant_value(self._sample_shape)) + return tf.TensorShape(self._sample_shape) + + def sample_shape_tensor(self, name="sample_shape_tensor"): + """Sample shape of random variable as a 1-D `Tensor`. + + Args: + name: name to give to the op + + Returns: + batch_shape: `Tensor`. + """ + with tf.name_scope(name): + if isinstance(self._sample_shape, tf.Tensor): + return self._sample_shape + return tf.convert_to_tensor(self.sample_shape.as_list(), dtype=tf.int32) @property def shape(self): diff --git a/tensorflow_probability/python/edward2/random_variable_test.py b/tensorflow_probability/python/edward2/random_variable_test.py index 5e92b9ffc4..52d9d55a3c 100644 --- a/tensorflow_probability/python/edward2/random_variable_test.py +++ b/tensorflow_probability/python/edward2/random_variable_test.py @@ -377,6 +377,19 @@ def testShapeRandomVariable(self): self._testShape( ed.RandomVariable(tfd.Bernoulli(probs=0.5), sample_shape=[2, 1]), [2, 1], [], []) + self._testShape( + ed.RandomVariable(tfd.Bernoulli(probs=0.5), + sample_shape=tf.constant([2])), + [2], [], []) + self._testShape( + ed.RandomVariable(tfd.Bernoulli(probs=0.5), + sample_shape=tf.constant([2, 4])), + [2, 4], [], []) + + @tfe.run_test_in_graph_and_eager_modes() + def testRandomTensorSample(self): + num_samples = tf.cast(tfd.Poisson(rate=5.).sample(), tf.int32) + _ = tfd.Normal(loc=0.0, scale=1.0).sample(num_samples) if __name__ == "__main__":