diff --git a/tensorflow_probability/python/bijectors/gev_cdf.py b/tensorflow_probability/python/bijectors/gev_cdf.py index b40bed524c..4f9939a126 100644 --- a/tensorflow_probability/python/bijectors/gev_cdf.py +++ b/tensorflow_probability/python/bijectors/gev_cdf.py @@ -120,14 +120,18 @@ def _is_increasing(cls): def _forward(self, x): loc = tf.convert_to_tensor(self.loc) scale = tf.convert_to_tensor(self.scale) - concentration = tf.convert_to_tensor(self.concentration) + conc = tf.convert_to_tensor(self.concentration) with tf.control_dependencies( self._maybe_assert_valid_x( - x, loc=loc, scale=scale, concentration=concentration)): + x, loc=loc, scale=scale, concentration=conc)): z = (x - loc) / scale + + equal_zero = tf.equal(conc, 0.) + # deal with case that gradient is N/A when conc = 0 + safe_conc = tf.where(equal_zero, 1., conc) t = tf.where( - tf.equal(concentration, 0.), tf.math.exp(-z), - tf.math.exp(-tf.math.log1p(z * concentration) / concentration)) + equal_zero, tf.math.exp(-z), + tf.math.exp(-tf.math.log1p(z * safe_conc) / safe_conc)) return tf.exp(-t) def _inverse(self, y): @@ -135,24 +139,34 @@ def _inverse(self, y): t = -tf.math.log(y) conc = tf.convert_to_tensor(self.concentration) + + equal_zero = tf.equal(conc, 0.) + # deal with case that gradient is N/A when conc = 0 + safe_conc = tf.where(equal_zero, 1., conc) + z = tf.where( - tf.equal(conc, 0.), -tf.math.log(t), - tf.math.expm1(-tf.math.log(t) * conc) / conc) + equal_zero, -tf.math.log(t), + tf.math.expm1(-tf.math.log(t) * safe_conc) / safe_conc) return self.loc + self.scale * z def _forward_log_det_jacobian(self, x): loc = tf.convert_to_tensor(self.loc) scale = tf.convert_to_tensor(self.scale) - concentration = tf.convert_to_tensor(self.concentration) + conc = tf.convert_to_tensor(self.concentration) with tf.control_dependencies( self._maybe_assert_valid_x( - x, loc=loc, scale=scale, concentration=concentration)): + x, loc=loc, scale=scale, concentration=conc)): z = (x - loc) / scale + + equal_zero = tf.equal(conc, 0.) + # deal with case that gradient is N/A when conc = 0 + safe_conc = tf.where(equal_zero, 1., conc) + log_t = tf.where( - tf.equal(concentration, 0.), -z, - -tf.math.log1p(z * concentration) / concentration) - return (tf.math.multiply_no_nan(concentration + 1., log_t) - + equal_zero, -z, + -tf.math.log1p(z * safe_conc) / safe_conc) + return (tf.math.multiply_no_nan(conc + 1., log_t) - tf.math.exp(log_t) - tf.math.log(scale)) def _inverse_log_det_jacobian(self, y): diff --git a/tensorflow_probability/python/distributions/BUILD b/tensorflow_probability/python/distributions/BUILD index 313ccce56d..ff16d4be0e 100644 --- a/tensorflow_probability/python/distributions/BUILD +++ b/tensorflow_probability/python/distributions/BUILD @@ -72,6 +72,7 @@ multi_substrate_py_library( ":generalized_pareto", ":geometric", ":gumbel", + ":gev", ":half_cauchy", ":half_normal", ":half_student_t", @@ -740,6 +741,24 @@ multi_substrate_py_library( ], ) +multi_substrate_py_library( + name = "gev", + srcs = ["gev.py"], + deps = [ + ":kullback_leibler", + ":transformed_distribution", + ":uniform", + # numpy dep, + # tensorflow dep, + "//tensorflow_probability/python/bijectors:gev_cdf", + "//tensorflow_probability/python/bijectors:identity", + "//tensorflow_probability/python/bijectors:invert", + "//tensorflow_probability/python/internal:distribution_util", + "//tensorflow_probability/python/internal:dtype_util", + "//tensorflow_probability/python/internal:tensor_util", + ], +) + multi_substrate_py_library( name = "half_cauchy", srcs = ["half_cauchy.py"], @@ -2508,6 +2527,19 @@ multi_substrate_py_test( ], ) +multi_substrate_py_test( + name = "gev_test", + srcs = ["gev_test.py"], + jax_size = "medium", + deps = [ + # numpy dep, + # scipy dep, + # tensorflow dep, + "//tensorflow_probability", + "//tensorflow_probability/python/internal:test_util", + ], +) + multi_substrate_py_test( name = "half_cauchy_test", srcs = ["half_cauchy_test.py"], diff --git a/tensorflow_probability/python/distributions/__init__.py b/tensorflow_probability/python/distributions/__init__.py index f9df267a8e..cbf67b85b9 100644 --- a/tensorflow_probability/python/distributions/__init__.py +++ b/tensorflow_probability/python/distributions/__init__.py @@ -53,6 +53,7 @@ from tensorflow_probability.python.distributions.generalized_pareto import GeneralizedPareto from tensorflow_probability.python.distributions.geometric import Geometric from tensorflow_probability.python.distributions.gumbel import Gumbel +from tensorflow_probability.python.distributions.gev import GeneralizedExtremeValue from tensorflow_probability.python.distributions.half_cauchy import HalfCauchy from tensorflow_probability.python.distributions.half_normal import HalfNormal from tensorflow_probability.python.distributions.half_student_t import HalfStudentT @@ -184,6 +185,7 @@ 'GaussianProcessRegressionModel', 'VariationalGaussianProcess', 'Gumbel', + 'GeneralizedExtremeValue', 'HalfCauchy', 'HalfNormal', 'HalfStudentT', diff --git a/tensorflow_probability/python/distributions/gev.py b/tensorflow_probability/python/distributions/gev.py new file mode 100644 index 0000000000..9ce832b122 --- /dev/null +++ b/tensorflow_probability/python/distributions/gev.py @@ -0,0 +1,272 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""The GeneralizedExtremeValue distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports +import numpy as np +import tensorflow.compat.v2 as tf + +from tensorflow_probability.python.bijectors import gev_cdf as gev_cdf_bijector +from tensorflow_probability.python.bijectors import invert as invert_bijector +from tensorflow_probability.python.bijectors import softplus as softplus_bijector +from tensorflow_probability.python.distributions import transformed_distribution +from tensorflow_probability.python.distributions import uniform +from tensorflow_probability.python.internal import distribution_util +from tensorflow_probability.python.internal import dtype_util +from tensorflow_probability.python.internal import parameter_properties +from tensorflow_probability.python.internal import tensor_util +from tensorflow_probability.python.internal import prefer_static as ps + +class GeneralizedExtremeValue(transformed_distribution.TransformedDistribution): + """The scalar GeneralizedExtremeValue distribution + with location `loc`, `scale` and `concentration` parameters. + + #### Mathematical details + + The probability density function (pdf) of this distribution is, + + ```none + pdf(x; loc, scale, conc) = t(x; loc, scale, conc) ** (1 + conc) * exp( + -t(x; loc, scale, conc) ) / scale + where t(x) = + * (1 + conc * (x - loc) / scale) ) ** (-1 / conc) when conc != 0; + * exp(-(x - loc) / scale) when conc = 0. + ``` + + where `concentration = conc`. + + The cumulative density function of this distribution is, + + ```cdf(x; mu, sigma) = exp(-t(x))``` + + The generalized extreme value distribution is a member of the + [location-scale family](https://en.wikipedia.org/wiki/Location-scale_family), + i.e., it can be constructed as, + + ```none + X ~ GeneralizedExtremeValue(loc=0, scale=1, concentration=conc) + Y = loc + scale * X + ``` + + #### Examples + + Examples of initialization of one or a batch of distributions. + + ```python + tfd = tfp.distributions + + # Define a single scalar generalized extreme values distribution. + dist = tfd.GeneralizedExtremeValue(loc=0., scale=3., concentration=0.9) + + # Evaluate the cdf at 1, returning a scalar. + dist.cdf(1.) + + # Define a batch of two scalar valued generalized extreme values. + # The first has loc 1 and scale 11, the second 2 and 22. + dist = tfd.GeneralizedExtremeValue(loc=[1, 2.], scale=[11, 22.]) + + # Evaluate the pdf of the first distribution on 0, and the second on 1.5, + # returning a length two tensor. + dist.prob([0, 1.5]) + + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample([3]) + ``` + + Arguments are broadcast when possible. + + ```python + # Define a batch of two scalar valued Logistics. + # Both have location 1, but different scales. + dist = tfd.GeneralizedExtremeValue(loc=1., scale=1, concentration=[0, 0.9]) + + # Evaluate the pdf of both distributions on the same point, 3.0, + # returning a length 2 tensor. + dist.prob(3.0) + ``` + + """ + + def __init__(self, + loc, + scale, + concentration, + validate_args=False, + allow_nan_stats=True, + name='GeneralizedExtremeValue'): + """Construct generalized extreme value distributions with location `loc`, + scale `scale`, and concentration `concentration`. + + The parameters `loc`, `scale`, and `concentration` must be shaped in a way + that supports broadcasting (e.g. `loc + scale` + `concentration` is valid). + + Args: + loc: Floating point tensor, the location parameter of the distribution(s). + scale: Floating point tensor, the scales of the distribution(s). + scale must contain only positive values. + concentration: Floating point tensor, the concentration of + the distribution(s). + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + Default value: `False`. + allow_nan_stats: Python `bool`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value `NaN` to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + Default value: `True`. + name: Python `str` name prefixed to Ops created by this class. + Default value: `'GeneralizedExtremeValue'`. + + Raises: + TypeError: if loc and scale are different dtypes. + """ + parameters = dict(locals()) + with tf.name_scope(name) as name: + dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32) + loc = tensor_util.convert_nonref_to_tensor( + loc, name='loc', dtype=dtype) + scale = tensor_util.convert_nonref_to_tensor( + scale, name='scale', dtype=dtype) + concentration = tensor_util.convert_nonref_to_tensor( + concentration, name='concentration', dtype=dtype) + dtype_util.assert_same_float_dtype([loc, scale, concentration]) + # Positive scale is asserted by the incorporated GEV bijector. + self._gev_bijector = gev_cdf_bijector.GeneralizedExtremeValueCDF( + loc=loc, scale=scale, concentration=concentration, + validate_args=validate_args) + + # Because the uniform sampler generates samples in `[0, 1)` this would + # cause samples to lie in `(inf, -inf]` instead of `(inf, -inf)`. To fix + # this, we use `np.finfo(dtype_util.as_numpy_dtype(self.dtype).tiny` + # because it is the smallest, positive, 'normal' number. + batch_shape = distribution_util.get_broadcast_shape(loc, scale, + concentration) + super(GeneralizedExtremeValue, self).__init__( + # TODO(b/137665504): Use batch-adding meta-distribution to set the + # batch shape instead of tf.ones. + distribution=uniform.Uniform( + low=np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny, + high=tf.ones(batch_shape, dtype=dtype), + allow_nan_stats=allow_nan_stats), + # The GEV bijector encodes the CDF function as the forward, + # and hence needs to be inverted. + bijector=invert_bijector.Invert( + self._gev_bijector, validate_args=validate_args), + parameters=parameters, + name=name) + + @classmethod + def _parameter_properties(cls, dtype, num_classes=None): + # pylint: disable=g-long-lambda + return dict( + loc=parameter_properties.ParameterProperties(), + scale=parameter_properties.ParameterProperties( + default_constraining_bijector_fn=( + lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype)))), + concentration=parameter_properties.ParameterProperties()) + # pylint: enable=g-long-lambda + + @property + def loc(self): + """Distribution parameter for the location.""" + return self._gev_bijector.loc + + @property + def scale(self): + """Distribution parameter for scale.""" + return self._gev_bijector.scale + + @property + def concentration(self): + """Distribution parameter for shape.""" + return self._gev_bijector.concentration + + def _entropy(self): + scale = tf.broadcast_to(self.scale, + ps.broadcast_shape(ps.shape(self.scale), + ps.shape(self.loc))) + euler_gamma = tf.constant(np.euler_gamma, self.dtype) + return 1. + tf.math.log(scale) + euler_gamma * (1. + self.concentration) + + def _log_prob(self, x): + with tf.control_dependencies(self._gev_bijector._maybe_assert_valid_x(x)): + scale = tf.convert_to_tensor(self.scale) + z = (x - self.loc) / scale + + conc = tf.convert_to_tensor(self.concentration) + equal_zero = tf.equal(conc, 0.) + safe_conc = tf.where(equal_zero, 1., conc) + log_t = tf.where(equal_zero, -z, + -tf.math.log1p(z * safe_conc) / safe_conc) + + return (conc + 1) * log_t - tf.exp(log_t) - tf.math.log(scale) + + def _mean(self): + conc = tf.convert_to_tensor(self.concentration) + equal_zero = tf.equal(conc, 0.) + less_than_one = tf.less(conc, 1.) + safe_conc = tf.where(equal_zero, 1., conc) + + mean_zero = tf.fill(tf.shape(conc), tf.constant(np.euler_gamma, self.dtype)) + mean_fin = tf.math.expm1(tf.math.lgamma(1. - safe_conc)) / safe_conc + mean_inf = tf.fill(tf.shape(conc), tf.constant(np.inf, self.dtype)) + + mean_z = tf.where(equal_zero, + mean_zero, + tf.where(less_than_one, + mean_fin, + mean_inf)) + + return self.loc + self.scale * mean_z + + def _stddev(self): + conc = tf.convert_to_tensor(self.concentration) + equal_zero = tf.equal(conc, 0.) + less_than_half = tf.less(conc, 0.5) + + g1_square = tf.exp(tf.math.lgamma(1. - conc)) ** 2 + g2 = tf.exp(tf.math.lgamma(1. - 2. * conc)) + safe_conc = tf.where(equal_zero, 1., conc) + + std_z = tf.where(equal_zero, + tf.fill(tf.shape(conc), + tf.constant(np.pi / np.sqrt(6), self.dtype)), + tf.where(less_than_half, + tf.math.sqrt(g2 - g1_square) / tf.abs(safe_conc), + tf.fill(tf.shape(conc), + tf.constant(np.inf, self.dtype))) + ) + + return self.scale * tf.ones_like(self.loc) * std_z + + def _mode(self): + conc = tf.convert_to_tensor(self.concentration) + equal_zero = tf.equal(conc, 0.) + safe_conc = tf.where(equal_zero, 1., conc) + + mode_z = tf.where(equal_zero, + tf.zeros_like(conc), + tf.math.expm1(-conc * tf.math.log1p(conc)) / safe_conc) + + return self.loc + self.scale * mode_z + + def _parameter_control_dependencies(self, is_init): + return self._gev_bijector._parameter_control_dependencies(is_init) # pylint: disable=protected-access diff --git a/tensorflow_probability/python/distributions/gev_test.py b/tensorflow_probability/python/distributions/gev_test.py new file mode 100644 index 0000000000..927b2e6b27 --- /dev/null +++ b/tensorflow_probability/python/distributions/gev_test.py @@ -0,0 +1,470 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for GEV.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports +import numpy as np +from scipy import stats + +import tensorflow.compat.v1 as tf1 +import tensorflow.compat.v2 as tf +import tensorflow_probability as tfp + +from tensorflow_probability.python.internal import test_util + +tfd = tfp.distributions + + +class _GEVTest(object): + + def make_tensor(self, x): + x = tf.cast(x, self._dtype) + return tf1.placeholder_with_default( + x, shape=x.shape if self._use_static_shape else None) + + def testGEVShape(self): + loc = np.array([3.0] * 5, dtype=self._dtype) + scale = np.array([3.0] * 5, dtype=self._dtype) + conc = np.array([3.0] * 5, dtype=self._dtype) + gev = tfd.GeneralizedExtremeValue(loc=loc, scale=scale, + concentration=conc, + validate_args=True) + + self.assertEqual((5,), self.evaluate(gev.batch_shape_tensor())) + self.assertEqual(tf.TensorShape([5]), gev.batch_shape) + self.assertAllEqual([], self.evaluate(gev.event_shape_tensor())) + self.assertEqual(tf.TensorShape([]), gev.event_shape) + + def testInvalidScale(self): + scale = [-.01, 0., 2.] + with self.assertRaisesOpError('Argument `scale` must be positive.'): + gev = tfd.GeneralizedExtremeValue(loc=0., scale=scale, concentration=1., + validate_args=True) + self.evaluate(gev.mean()) + + scale = tf.Variable([.01]) + self.evaluate(scale.initializer) + gev = tfd.GeneralizedExtremeValue(loc=0., scale=scale, concentration=1., + validate_args=True) + self.assertIs(scale, gev.scale) + self.evaluate(gev.mean()) + with tf.control_dependencies([scale.assign([-.01])]): + with self.assertRaisesOpError('Argument `scale` must be positive.'): + self.evaluate(gev.mean()) + + def testGEVLogPdf(self): + batch_size = 6 + loc = np.array([0.] * batch_size, dtype=self._dtype) + scale = np.array([3.] * batch_size, dtype=self._dtype) + conc = np.array([2.] * batch_size, dtype=self._dtype) + gev_dist = stats.genextreme(-conc, loc=loc, scale=scale) + x = np.array([2., 3., 4., 5., 6., 7.], dtype=self._dtype) + gev = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc), + validate_args=True) + log_pdf = gev.log_prob(self.make_tensor(x)) + self.assertAllClose( + gev_dist.logpdf(x), + self.evaluate(log_pdf)) + + pdf = gev.prob(x) + self.assertAllClose( + gev_dist.pdf(x), self.evaluate(pdf)) + + def testGEVLogPdfMultidimensional(self): + batch_size = 6 + loc = np.array([[-2.0, -4.0, -5.0]] * batch_size, dtype=self._dtype) + scale = np.array([1.0], dtype=self._dtype) + conc = np.array([[0.0, 1.0, 2.0]] * batch_size, dtype=self._dtype) + gev_dist = stats.genextreme(-conc, loc=loc, scale=scale) + x = np.array([[2., 3., 4., 5., 6., 7.]], dtype=self._dtype).T + + gev = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc), + validate_args=True) + log_pdf = gev.log_prob(self.make_tensor(x)) + self.assertAllClose( + self.evaluate(log_pdf), gev_dist.logpdf(x)) + + pdf = gev.prob(self.make_tensor(x)) + self.assertAllClose( + self.evaluate(pdf), gev_dist.pdf(x)) + + def testGEVCDF(self): + batch_size = 6 + loc = np.array([0.] * batch_size, dtype=self._dtype) + scale = np.array([3.] * batch_size, dtype=self._dtype) + conc = np.array([2.] * batch_size, dtype=self._dtype) + gev_dist = stats.genextreme(-conc, loc=loc, scale=scale) + x = np.array([2., 3., 4., 5., 6., 7.], dtype=self._dtype) + + gev = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc), + validate_args=True) + + log_cdf = gev.log_cdf(self.make_tensor(x)) + self.assertAllClose( + self.evaluate(log_cdf), gev_dist.logcdf(x)) + + cdf = gev.cdf(self.make_tensor(x)) + self.assertAllClose( + self.evaluate(cdf), gev_dist.cdf(x)) + + def testGEVCdfMultidimensional(self): + batch_size = 6 + loc = np.array([[-2.0, -4.0, -5.0]] * batch_size, dtype=self._dtype) + scale = np.array([1.0], dtype=self._dtype) + conc = np.array([[0.0, 1.0, 2.0]] * batch_size, dtype=self._dtype) + gev_dist = stats.genextreme(-conc, loc=loc, scale=scale) + x = np.array([[2., 3., 4., 5., 6., 7.]], dtype=self._dtype).T + + gev = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc), + validate_args=True) + + log_cdf = gev.log_cdf(self.make_tensor(x)) + self.assertAllClose( + self.evaluate(log_cdf), + gev_dist.logcdf(x)) + + cdf = gev.cdf(self.make_tensor(x)) + self.assertAllClose( + self.evaluate(cdf), + gev_dist.cdf(x)) + + def testGEVMean(self): + loc = np.array([2.0], dtype=self._dtype) + scale = np.array([1.5], dtype=self._dtype) + conc = np.array([-0.9, 0.0], dtype=self._dtype) + gev_dist = stats.genextreme(-conc, loc=loc, scale=scale) + + gev = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc), + validate_args=True) + self.assertAllClose(self.evaluate(gev.mean()), + gev_dist.mean()) + + conc_with_inf_mean = np.array([2.], dtype=self._dtype) + gev_with_inf_mean = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc_with_inf_mean), + validate_args=True) + self.assertAllClose(self.evaluate(gev_with_inf_mean.mean()), + [np.inf]) + + def testGEVVariance(self): + loc = np.array([2.0], dtype=self._dtype) + scale = np.array([1.5], dtype=self._dtype) + conc = np.array([-0.9, 0.0], dtype=self._dtype) + gev_dist = stats.genextreme(-conc, loc=loc, scale=scale) + + gev = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc), + validate_args=True) + + self.assertAllClose(self.evaluate(gev.variance()), + gev_dist.var()) + + conc_with_inf_var = np.array([1.5], dtype=self._dtype) + gev_with_inf_var = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc_with_inf_var), + validate_args=True) + self.assertAllClose(self.evaluate(gev_with_inf_var.variance()), + [np.inf]) + + def testGEVStd(self): + loc = np.array([2.0], dtype=self._dtype) + scale = np.array([1.5], dtype=self._dtype) + conc = np.array([-0.9, 0.0], dtype=self._dtype) + gev_dist = stats.genextreme(-conc, loc=loc, scale=scale) + + gev = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc), + validate_args=True) + + self.assertAllClose(self.evaluate(gev.stddev()), + gev_dist.std()) + + conc_with_inf_std = np.array([1.5], dtype=self._dtype) + gev_with_inf_std = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc_with_inf_std), + validate_args=True) + self.assertAllClose(self.evaluate(gev_with_inf_std.stddev()), + [np.inf]) + + def testGEVMode(self): + loc = np.array([2.0], dtype=self._dtype) + scale = np.array([1.5], dtype=self._dtype) + conc = np.array([-0.9, 0.0, 1.5], dtype=self._dtype) + + gev = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc), + validate_args=True) + + np_mode_z = np.where(conc == 0., 0., ((conc+1)**(-conc) - 1.) / conc) + np_mode = loc + np_mode_z * scale + self.assertAllClose(self.evaluate(gev.mode()), np_mode) + + def testGEVSample(self): + loc = self._dtype(4.0) + scale = self._dtype(1.0) + conc = self._dtype(0.2) + n = int(1e6) + gev_dist = stats.genextreme(-conc, loc=loc, scale=scale) + + gev = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc), + validate_args=True) + + samples = gev.sample(n, seed=test_util.test_seed()) + sample_values = self.evaluate(samples) + self.assertEqual((n,), sample_values.shape) + self.assertAllClose( + gev_dist.mean(), + sample_values.mean(), rtol=.01) + self.assertAllClose( + gev_dist.var(), + sample_values.var(), rtol=.01) + + def testGEVSampleMultidimensionalMean(self): + loc = np.array([2.0, 4.0, 5.0], dtype=self._dtype) + scale = np.array([1.0, 0.8, 0.5], dtype=self._dtype) + conc = np.array([0.2], dtype=self._dtype) + gev_dist = stats.genextreme(-conc, loc=loc, scale=scale) + n = int(1e6) + + gev = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc), + validate_args=True) + + samples = gev.sample(n, seed=test_util.test_seed()) + sample_values = self.evaluate(samples) + self.assertAllClose( + gev_dist.mean(), + sample_values.mean(axis=0), + rtol=.03, + atol=0) + + def testGEVSampleMultidimensionalVar(self): + loc = np.array([2.0, 4.0, 5.0], dtype=self._dtype) + scale = np.array([1.0, 0.8, 0.5], dtype=self._dtype) + conc = np.array([0.2], dtype=self._dtype) + gev_dist = stats.genextreme(-conc, loc=loc, scale=scale) + n = int(1e6) + + gev = tfd.GeneralizedExtremeValue( + loc=self.make_tensor(loc), + scale=self.make_tensor(scale), + concentration=self.make_tensor(conc), + validate_args=True) + + + samples = gev.sample(n, seed=test_util.test_seed()) + sample_values = self.evaluate(samples) + self.assertAllClose( + gev_dist.var(), + sample_values.var(axis=0), + rtol=.03, + atol=0) + + @test_util.numpy_disable_gradient_test + def testFiniteGradientAtDifficultPoints(self): + def make_fn(dtype, attr): + x = np.array([1.]).astype(dtype) + return lambda m, s, p: getattr( # pylint: disable=g-long-lambda + tfd.GeneralizedExtremeValue(loc=m, scale=s, + concentration=p, validate_args=True), + attr)(x) + + loc = np.array([1.0], dtype=self._dtype) + scale = np.array([1.5], dtype=self._dtype) + conc = np.array([-0.7, 0.0, 0.5, 1.], dtype=self._dtype) + + for attr in ['log_prob', 'prob', 'cdf', 'log_cdf']: + value, grads = self.evaluate(tfp.math.value_and_gradient( + make_fn(self._dtype, attr), + [self.make_tensor(loc), # loc + self.make_tensor(scale), # scale + self.make_tensor(conc)])) # conc + self.assertAllFinite(value) + self.assertAllFinite(grads[0]) # d/d loc + self.assertAllFinite(grads[1]) # d/d scale + self.assertAllFinite(grads[2]) # d/d conc + + + def testBroadcastingParams(self): + + def _check(gev_dist): + self.assertEqual(gev_dist.mean().shape, (3,)) + self.assertEqual(gev_dist.variance().shape, (3,)) + self.assertEqual(gev_dist.entropy().shape, (3,)) + self.assertEqual(gev_dist.log_prob(6.).shape, (3,)) + self.assertEqual(gev_dist.prob(6.).shape, (3,)) + self.assertEqual(gev_dist.sample( + 37, seed=test_util.test_seed()).shape, (37, 3,)) + + _check( + tfd.GeneralizedExtremeValue(loc=[ + 2., + 3., + 4., + ], scale=2., concentration=1., validate_args=True)) + _check( + tfd.GeneralizedExtremeValue(loc=3., scale=[ + 2., + 3., + 4., + ], concentration=1., validate_args=True)) + _check( + tfd.GeneralizedExtremeValue(loc=3., scale=3., concentration=[ + 2., + 3., + 4., + ], validate_args=True)) + + def testBroadcastingPdfArgs(self): + + def _assert_shape(gev_dist, arg, shape): + self.assertEqual(gev_dist.log_prob(arg).shape, shape) + self.assertEqual(gev_dist.prob(arg).shape, shape) + + def _check(gev_dist): + _assert_shape(gev_dist, 5., (3,)) + xs = np.array([5., 6., 7.], dtype=np.float32) + _assert_shape(gev_dist, xs, (3,)) + xs = np.array([xs]) + _assert_shape(gev_dist, xs, (1, 3)) + xs = xs.T + _assert_shape(gev_dist, xs, (3, 3)) + + _check( + tfd.GeneralizedExtremeValue(loc=[ + -2., + -3., + -4., + ], scale=2., concentration=1., validate_args=True)) + _check( + tfd.GeneralizedExtremeValue(loc=-6., scale=[ + 2., + 3., + 4., + ], concentration=1., validate_args=True)) + _check( + tfd.GeneralizedExtremeValue(loc=-7., scale=3., concentration=[ + 2., + 3., + 4., + ], validate_args=True)) + + def _check2d(gev_dist): + _assert_shape(gev_dist, 5., (1, 3)) + xs = np.array([5., 6., 7.], dtype=np.float32) + _assert_shape(gev_dist, xs, (1, 3)) + xs = np.array([xs]) + _assert_shape(gev_dist, xs, (1, 3)) + xs = xs.T + _assert_shape(gev_dist, xs, (3, 3)) + + _check2d( + tfd.GeneralizedExtremeValue(loc=[[ + -2., + -3., + -4., + ]], scale=2., concentration=1., validate_args=True)) + _check2d( + tfd.GeneralizedExtremeValue(loc=-7., scale=[[ + 2., + 3., + 4., + ]], concentration=1., validate_args=True)) + _check2d( + tfd.GeneralizedExtremeValue(loc=-7., scale=3., concentration=[[ + 2., + 3., + 4., + ]], validate_args=True)) + + def _check2d_rows(gev_dist): + _assert_shape(gev_dist, 5., (3, 1)) + xs = np.array([5., 6., 7.], dtype=np.float32) # (3,) + _assert_shape(gev_dist, xs, (3, 3)) + xs = np.array([xs]) # (1,3) + _assert_shape(gev_dist, xs, (3, 3)) + xs = xs.T # (3,1) + _assert_shape(gev_dist, xs, (3, 1)) + + _check2d_rows( + tfd.GeneralizedExtremeValue( + loc=[[-2.], [-3.], [-4.]], scale=2., concentration=1., + validate_args=True)) + _check2d_rows( + tfd.GeneralizedExtremeValue( + loc=-7., scale=[[2.], [3.], [4.]], concentration=1., + validate_args=True)) + _check2d_rows( + tfd.GeneralizedExtremeValue( + loc=-7., scale=3., concentration=[[2.], [3.], [4.]], + validate_args=True)) + + +@test_util.test_all_tf_execution_regimes +class GEVTestStaticShape(test_util.TestCase, _GEVTest): + _dtype = np.float32 + _use_static_shape = True + + +@test_util.test_all_tf_execution_regimes +class GEVTestFloat64StaticShape(test_util.TestCase, _GEVTest): + _dtype = np.float64 + _use_static_shape = True + + +@test_util.test_all_tf_execution_regimes +class GEVTestDynamicShape(test_util.TestCase, _GEVTest): + _dtype = np.float32 + _use_static_shape = False + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_probability/python/distributions/hypothesis_testlib.py b/tensorflow_probability/python/distributions/hypothesis_testlib.py index 69d4cdb0dc..28f70dffa7 100644 --- a/tensorflow_probability/python/distributions/hypothesis_testlib.py +++ b/tensorflow_probability/python/distributions/hypothesis_testlib.py @@ -72,6 +72,7 @@ 'GeneralizedPareto', 'Geometric', 'Gumbel', + 'GeneralizedExtremeValue', 'HalfCauchy', 'HalfNormal', 'HalfStudentT',