# Copyright 2020 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The Lambert W x F distribution class."""
# Dependency imports
import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.bijectors import identity as identity_bijector
from tensorflow_probability.python.bijectors import lambertw_transform
from tensorflow_probability.python.bijectors import softplus as softplus_bijector
from tensorflow_probability.python.distributions import normal
from tensorflow_probability.python.distributions import transformed_distribution
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.python.internal import tensor_util
__all__ = [
"LambertWDistribution",
"LambertWNormal",
]
class LambertWDistribution(transformed_distribution.TransformedDistribution):
"""Implements a general heavy-tail Lambert W x F distribution.
Lambert W x F random variables are a transformed version of a random variables
with distribution F that have heavier tails. In particular, they are defined
as a (non-linear) transformation of random variables X with distribution F.
It therefore is straightforward to implement Lambert W x F distributions as a
particular TransformedDistribution, where the input can be specified by user
as any TensorFlow Distribution class.
### Mathematical Details
Let X be a random variable following distribution F with mean mu
and standard deviation sigma, define as U = (X-mu)/sigma its zero-mean,
unit-variance version. Then
Y = (U * exp (delta/2 * U^2)) * sigma + mu
is a location-scale heavy-tailed Lambert W x F with parameters mu,
sigma and delta, where delta can take any non-negative real value. In
particular, for delta = 0, the Lambert W x F distribution reduces to the
F distribution. That is F distributions are a subset of Lambert W x
F distributions.
See `tfp.bijectors.LambertWTail` for details on the transformation.
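
  ### Examples

  A minimal usage sketch (assuming this class is exposed as
  `tfp.distributions.LambertWDistribution`; the parameter values below are
  arbitrary):

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Heavy-tailed version of a standard Normal; delta = 0.1 adds tail weight.
  lambertw_gaussian = tfd.LambertWDistribution(
      distribution=tfd.Normal(loc=0., scale=1.),
      shift=0.,
      scale=1.,
      tailweight=0.1)
  samples = lambertw_gaussian.sample(1000, seed=42)
  log_probs = lambertw_gaussian.log_prob(samples)
  ```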
### References:
[1]: Goerg, G.M., 2011. Lambert W random variables - a new family of
generalized skewed distributions with applications to risk estimation.
The Annals of Applied Statistics, 5(3), pp.2197-2230.
[2]: Goerg, G.M., 2015. The Lambert way to Gaussianize heavy-tailed data with
the inverse of Tukey's h transformation as a special case. The Scientific
World Journal.
"""
def __init__(self,
distribution,
shift,
scale,
tailweight=None,
validate_args=False,
allow_nan_stats=True,
name="LambertWDistribution"):
"""Initializes the class.
Args:
distribution: `tf.Distribution`-like instance. Distribution F that is
transformed to produce this Lambert W x F distribution.
      shift: shift that should be applied before & after the tail
        transformation. For a location-scale family `distribution` (e.g.,
        `Normal` or `StudentT`) this is usually set to the mean / location
        parameter. For a scale family `distribution` (e.g., `Gamma` or
        `Fisher`) this must be set to 0 to guarantee a proper transformation
        on the positive real line.
      scale: scaling factor that should be applied before & after the tail
        transformation. Usually the standard deviation or scaling parameter
        of the `distribution`.
      tailweight: Tail parameter `delta` of the resulting Lambert W x F
        distribution(s). Must be non-negative. If `None`, it defaults to 0,
        in which case the distribution reduces to F.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
outputs.
allow_nan_stats: Python `bool`, default `True`. When `True`,
statistics (e.g., mean, mode, variance) use the value '`NaN`' to
indicate the result is undefined. When `False`, an exception is raised
if one or more of the statistic's batch members are undefined.
name: A name for the operation (optional).
"""
parameters = dict(locals())
with tf.name_scope(name) as name:
dtype = dtype_util.common_dtype([tailweight, shift, scale], tf.float32)
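      # A `None` tailweight defaults to delta = 0, i.e., no tail
      # transformation, so the distribution reduces to F itself.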
tailweight = 0. if tailweight is None else tailweight
self._tailweight = tensor_util.convert_nonref_to_tensor(
tailweight, name="tailweight", dtype=dtype)
self._shift = tensor_util.convert_nonref_to_tensor(
shift, name="shift", dtype=dtype)
self._scale = tensor_util.convert_nonref_to_tensor(
scale, name="scale", dtype=dtype)
dtype_util.assert_same_float_dtype((self.tailweight, self.shift,
self.scale))
self._allow_nan_stats = allow_nan_stats
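      # The heavier tails are produced by the `LambertWTail` bijector applied
      # to samples from `distribution`.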
super(LambertWDistribution, self).__init__(
distribution=distribution,
bijector=lambertw_transform.LambertWTail(
shift=shift, scale=scale,
tailweight=tailweight,
validate_args=validate_args),
parameters=parameters,
validate_args=validate_args,
name=name)
@classmethod
def _parameter_properties(cls, dtype, num_classes=None):
# pylint: disable=g-long-lambda
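    # `scale` must be positive and `tailweight` non-negative; both get a
    # Softplus constraining bijector bounded below by machine epsilon.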
return dict(
distribution=parameter_properties.BatchedComponentProperties(),
shift=parameter_properties.ParameterProperties(),
scale=parameter_properties.ParameterProperties(
default_constraining_bijector_fn=(
lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype)))),
tailweight=parameter_properties.ParameterProperties(
default_constraining_bijector_fn=(
lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype)))))
# pylint: enable=g-long-lambda
@property
def allow_nan_stats(self):
return self._allow_nan_stats
@property
def shift(self):
"""Distribution parameter for the shift before & after transformation."""
return self._shift
@property
def scale(self):
"""Distribution parameter for the scaling before & after transformation."""
return self._scale
@property
def tailweight(self):
"""Distribution parameter for the tail parameter delta."""
return self._tailweight
experimental_is_sharded = False
class LambertWNormal(LambertWDistribution):
"""Implements a location-scale heavy-tail Lambert W x Normal distribution."""
def __init__(self,
loc,
scale,
tailweight=None,
validate_args=False,
allow_nan_stats=True,
name="LambertWNormal"):
"""Initializes the class.
See `tfp.distributions.LambertWDistribution` for details.
Args:
loc: location parameter `loc` of the Normal distribution(s). This
coincides with the location parameter of the resulting LambertWNormal.
scale: scale parameter `scale` of the Normal distribution(s).
tailweight: Tail parameter `delta` of the distribution(s). If `None`, it
defaults to 0, which implies LambertWNormal == Normal.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
outputs.
allow_nan_stats: Python `bool`, default `True`. When `True`,
statistics (e.g., mean, mode, variance) use the value '`NaN`' to
indicate the result is undefined. When `False`, an exception is raised
if one or more of the statistic's batch members are undefined.
name: A name for the operation (optional).
"""
parameters = dict(locals())
with tf.name_scope(name) as name:
dtype = dtype_util.common_dtype([tailweight, loc, scale], tf.float32)
super(LambertWNormal, self).__init__(
distribution=normal.Normal(loc=loc, scale=scale),
shift=loc,
scale=scale,
tailweight=tailweight,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=name)
self._parameters = parameters
self._loc = tensor_util.convert_nonref_to_tensor(
loc, name="loc", dtype=dtype)
dtype_util.assert_same_float_dtype((self.tailweight, self.loc,
self.scale))
@property
def loc(self):
"""Location parameter of the Lambert W x Normal distribution."""
return self._loc
@classmethod
def _parameter_properties(cls, dtype, num_classes=None):
# pylint: disable=g-long-lambda
return dict(
loc=parameter_properties.ParameterProperties(),
scale=parameter_properties.ParameterProperties(
default_constraining_bijector_fn=(
lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype)))),
tailweight=parameter_properties.ParameterProperties(
default_constraining_bijector_fn=(
lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype)))))
# pylint: enable=g-long-lambda
  @distribution_util.AppendDocstring(
      """The mean of Lambert W x Normal equals `loc` if `tailweight < 1`,
      otherwise it is `NaN`. If `self.allow_nan_stats=False`, then an exception
      will be raised rather than returning `NaN`.""")
def _mean(self):
tailweight = tf.convert_to_tensor(self.tailweight)
loc = tf.convert_to_tensor(self.loc)
mean = loc * tf.ones(self.batch_shape, dtype=self.dtype)
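    # The mean exists only for tailweight < 1; otherwise return NaN (or raise
    # if allow_nan_stats is False).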
if self.allow_nan_stats:
return tf.where(
tailweight < 1.,
mean,
dtype_util.as_numpy_dtype(self.dtype)(np.nan))
else:
return distribution_util.with_dependencies([
assert_util.assert_less(
              tailweight,
              tf.ones([], dtype=self.dtype),
message="mean not defined for components of tailweight >= 1"),
], mean)
@distribution_util.AppendDocstring("""
The variance for Lambert W x Normal is finite if `tailweight < 0.5`. For
`0.5 <= tailweight < 1` it is infinite, and for `tailweight > 1` it is
undefined (since mean does not exist either).
""")
def _variance(self):
tailweight = tf.convert_to_tensor(self.tailweight)
scale = tf.convert_to_tensor(self.scale)
# For tail < 0.5, the variance is finite. See Eq (18) in
# https://www.hindawi.com/journals/tswj/2015/909231/
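    # That is, Var[Y] = scale**2 * (1 - 2 * tailweight)**(-3 / 2) when
    # tailweight < 0.5.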
var = (tf.cast(tf.pow(1. - 2. * tailweight, - 3. / 2.), dtype=self.dtype) *
tf.math.square(scale))
# We need to put the tf.where inside the outer tf.where to ensure we never
# hit a NaN in the gradient.
result_where_defined = tf.where(
tailweight < 0.5,
var,
tf.convert_to_tensor(np.inf, dtype=self.dtype))
if self.allow_nan_stats:
ans = tf.where(
tailweight < 1.0,
result_where_defined,
tf.convert_to_tensor(np.nan, self.dtype))
else:
ans = distribution_util.with_dependencies([
          assert_util.assert_less(
              tailweight,
              tf.ones([], dtype=self.dtype),
message="variance not defined for components of tailweight >= 1"),
], result_where_defined)
return tf.broadcast_to(ans, self._batch_shape_tensor())
def _mode(self):
    # The mode exists for any tail parameter and always equals the location
    # parameter `loc` (which is also the mean whenever the mean exists).
loc = tf.convert_to_tensor(self.loc)
return tf.broadcast_to(loc, self.batch_shape)
def _parameter_control_dependencies(self, is_init):
if not self.validate_args:
return []
assertions = []
if is_init != tensor_util.is_ref(self._tailweight):
assertions.append(assert_util.assert_greater_equal(
self._tailweight, tf.zeros([], dtype=self.dtype),
message="Argument `tailweight` must be non-negative."))
return assertions
def _default_event_space_bijector(self):
# TODO(b/145620027) Finalize choice of bijector.
return identity_bijector.Identity(validate_args=self.validate_args)