
Conversation

@blacksde (Contributor)

Resolve issue #874

@googlebot added the cla: yes label (Declares that the user has signed CLA) on Oct 26, 2020
@blacksde (Contributor, Author)

@srvasude @brianwa84 Could you help review this PR? Thx~

@brianwa84 brianwa84 requested a review from SiegeLordEx October 26, 2020 14:52
@brianwa84 (Contributor)

Pavel, could you take a look?

@SiegeLordEx (Member) left a comment

Looks good, thanks! I glanced over the math, it seemed okay. I mainly have concerns about the gradients in a few spots.

that supports broadcasting (e.g. `loc + scale + concentration` is valid).

Args:
loc: Floating point tensor, the means of the distribution(s).
Member:

It's not the actual mean. Maybe just call it "the location parameter of the distribution(s)"

Contributor Author:

Thx for pointing this out, I will also fix all the other places where mean and loc are misused.

parameters=parameters,
name=name)

@staticmethod
Member:

Instead of the _param_shapes and _param_event_ndims methods, we have a new mechanism for describing this. In this case, you can do:

  @classmethod
  def _parameter_properties(cls, dtype, num_classes=None):
    # pylint: disable=g-long-lambda
    return dict(
        loc=parameter_properties.ParameterProperties(),
        scale=parameter_properties.ParameterProperties(
            default_constraining_bijector_fn=(
                lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype)))),
        concentration=parameter_properties.ParameterProperties())
    # pylint: enable=g-long-lambda

This new style lets us automatically constrain the parameters of the distribution.
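For context, a minimal sketch of how a constraining bijector registered this way might be used downstream; the class name tfd.GeneralizedExtremeValue and the public parameter_properties accessor are assumptions based on recent TFP releases, not part of this PR:

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Look up the per-parameter properties and pull out the default
# constraining bijector for `scale`.
props = tfd.GeneralizedExtremeValue.parameter_properties(tf.float32)
scale_bijector = props['scale'].default_constraining_bijector_fn()

# Map an unconstrained real value to a strictly positive scale.
positive_scale = scale_bijector.forward(tf.constant(-3.0))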


  def _entropy(self):
    # Use broadcasting rules to calculate the full broadcast sigma.
    scale = self.scale * tf.ones_like(self.loc)
Member:

Do `tf.broadcast_to(self.scale, ps.broadcast_shape(ps.shape(self.scale), ps.shape(self.loc)))` instead; it's a little more self-descriptive (you can drop the comment) and doesn't waste flops.

`ps` should be defined as `from tensorflow_probability.python.internal import prefer_static as ps` up top.
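A minimal sketch of the suggested pattern in isolation (the standalone helper and its name are illustrative only):

import tensorflow.compat.v2 as tf
from tensorflow_probability.python.internal import prefer_static as ps


def broadcast_scale(scale, loc):
  # Expand `scale` to the shape obtained by broadcasting `scale` against
  # `loc`, without paying for an elementwise multiply.
  return tf.broadcast_to(
      scale, ps.broadcast_shape(ps.shape(scale), ps.shape(loc)))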


    return self.loc + self.scale * mode_z

  def _default_event_space_bijector(self):
Member:

Drop this entirely, the superclass has a good default implementation for this distribution (it'll use the composition of sigmoid + GevCDF). Identity is particularly bad, since this distribution often isn't supported on the entire real line.

g1_square = tf.exp(tf.math.lgamma(1. - conc)) ** 2
g2 = tf.exp(tf.math.lgamma(1. - 2.*conc))

std_z = tf.where(equal_zero,
Member:

Same comment about NaN gradients.
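The NaN-gradient concern refers to the usual "double-where" trick: gradients of tf.where flow through both branches, so a branch that is singular at some inputs must be evaluated on safe inputs even where it is not selected. A generic illustration of the idea, not the GEV code from this PR:

import tensorflow.compat.v2 as tf


def reciprocal_or_zero(conc):
  equal_zero = tf.equal(conc, 0.)
  # Evaluate the singular branch on inputs that avoid the singularity, so
  # its gradient stays finite even at positions where it is not selected.
  safe_conc = tf.where(equal_zero, tf.ones_like(conc), conc)
  nonzero_branch = 1. / safe_conc
  return tf.where(equal_zero, tf.zeros_like(conc), nonzero_branch)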

less_than_half = tf.less(conc, 0.5)

g1_square = tf.exp(tf.math.lgamma(1. - conc)) ** 2
g2 = tf.exp(tf.math.lgamma(1. - 2.*conc))
Member:

Spaces around binary ops, please.

tfd = tfp.distributions


class _GEVTest(object):
Member:

Since I had all those concerns about NaN gradients, would you mind adding a test for those? See

  @test_util.numpy_disable_gradient_test
  def testFiniteGradientAtDifficultPoints(self):
    def make_fn(dtype, attr):
      x = np.array([-100., -20., -5., 5., 20., 100.]).astype(dtype)
      return lambda m, s, p: getattr(  # pylint: disable=g-long-lambda
          tfd.GeneralizedNormal(loc=m, scale=s, power=p, validate_args=True),
          attr)(x)

    # TODO(b/157524947): add 'log_cdf', currently fails at -100, -20, in fp32.
    for attr in ['log_prob', 'prob', 'cdf']:
      value, grads = self.evaluate(tfp.math.value_and_gradient(
          make_fn(self.dtype, attr),
          [tf.constant(0, self.dtype),  # mu
           tf.constant(1, self.dtype),  # scale
           tf.constant(2.1, self.dtype)]))  # power
      self.assertAllFinite(value)
      self.assertAllFinite(grads[0])  # d/d mu
      self.assertAllFinite(grads[1])  # d/d scale
      self.assertAllFinite(grads[2])  # d/d power
for an example. I especially care about the branch points of concentration=1, 0.5, 0 but other spots may be good to check too.
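A possible adaptation for this distribution, exercising those branch points; the test name, the x values, and the class name tfd.GeneralizedExtremeValue are illustrative assumptions rather than the code merged in this PR:

  @test_util.numpy_disable_gradient_test
  def testFiniteGradientAtBranchPoints(self):
    def make_fn(dtype, attr):
      # Points inside the support for every concentration tested below.
      x = np.array([-0.5, 0., 1., 5.]).astype(dtype)
      return lambda loc, scale, conc: getattr(  # pylint: disable=g-long-lambda
          tfd.GeneralizedExtremeValue(
              loc=loc, scale=scale, concentration=conc, validate_args=True),
          attr)(x)

    for conc in [0., 0.5, 1.]:  # branch points called out in the review
      for attr in ['log_prob', 'prob', 'cdf']:
        value, grads = self.evaluate(tfp.math.value_and_gradient(
            make_fn(self.dtype, attr),
            [tf.constant(0., self.dtype),      # loc
             tf.constant(1., self.dtype),      # scale
             tf.constant(conc, self.dtype)]))  # concentration
        self.assertAllFinite(value)
        for g in grads:
          self.assertAllFinite(g)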

@blacksde (Contributor, Author) commented Nov 9, 2020

> Looks good, thanks! I glanced over the math, it seemed okay. I mainly have concerns about the gradients in a few spots.

Hi @SiegeLordEx, thanks a lot for all your valuable comments. I have already fixed all of them (although I didn't reply one by one inline :) ). There are two things to point out:

  1. For the broadcasting test case you suggested, I referred to the one in half_student_t_test, as its test class is the same as what I used previously. I didn't want to change the basic structure too much, so I was hoping this one would be OK.

  2. For the gradient NaN issue, I also updated some logic in gev_cdf, as the CDF of GEV relies on this bijector.

Let me know if there is anything else that needs to be updated.

@blacksde blacksde requested a review from SiegeLordEx November 9, 2020 02:14
@SiegeLordEx (Member)

Looks good, thanks. I'll send this along to our internal review and that'll get this merged.

@copybara-service copybara-service bot merged commit 8e70d83 into tensorflow:master Nov 25, 2020
jburnim pushed a commit to jburnim/probability that referenced this pull request Dec 9, 2020