-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
poisson.py
472 lines (391 loc) · 17.4 KB
/
poisson.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The Poisson distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.distributions import distribution
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import batched_rejection_sampler as brs
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import implementation_selection
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import tensor_util
# Public API of this module.
__all__ = ['Poisson']
def _random_poisson_cpu(
    shape,
    rates=None,
    log_rates=None,
    output_dtype=tf.float32,
    seed=None,
    name=None):
  """Draws Poisson samples via the fast `tf.random.stateless_poisson` kernel.

  Args:
    shape: Leading (sample) dimensions of the output; the shape of the rate
      tensor is appended to it along axis 0 to form the full output shape.
    rates: Optional rate `Tensor`. When `None`, rates are obtained by
      exponentiating `log_rates`.
    log_rates: Optional log-rate `Tensor`; only read when `rates` is `None`.
    output_dtype: dtype of the returned samples.
    seed: Stateless seed forwarded to `tf.random.stateless_poisson`.
    name: Optional name scope; defaults to `'poisson_cpu'`.

  Returns:
    Samples of shape `concat([shape, tf.shape(rates)])`.
  """
  with tf.name_scope(name or 'poisson_cpu'):
    lam = tf.math.exp(log_rates) if rates is None else rates
    full_shape = tf.concat([shape, tf.shape(lam)], axis=0)
    return tf.random.stateless_poisson(
        shape=full_shape, seed=seed, lam=lam, dtype=output_dtype)
def _random_poisson_noncpu(
    shape,
    rates=None,
    log_rates=None,
    output_dtype=tf.float32,
    seed=None,
    name=None):
  """Draws Poisson samples with an XLA-friendly, pure-TF rejection sampler.

  Rates are split into two regimes that are sampled separately and then
  merged elementwise:
    * rate >= 10: transformed rejection (`_random_poisson_high_rate`).
    * rate < 10: product-of-uniforms loop (`_random_poisson_low_rate`).
  Entries whose log-rate is NaN (for instance, the log of a negative `rate`)
  produce NaN samples.

  Args:
    shape: Leading (sample) dimensions of the output; the shape of the
      log-rate tensor is appended to it along axis 0.
    rates: Optional rate `Tensor`; only read when `log_rates` is `None`.
    log_rates: Optional log-rate `Tensor`.
    output_dtype: dtype of the returned samples.
    seed: Stateless seed; it is split into one seed per rate regime.
    name: Optional name scope; defaults to `'poisson_noncpu'`.

  Returns:
    Samples of shape `concat([shape, tf.shape(log_rates)])` and dtype
    `output_dtype`.
  """
  with tf.name_scope(name or 'poisson_noncpu'):
    if log_rates is None:
      log_rates = tf.math.log(rates)
    full_shape = tf.concat([shape, tf.shape(log_rates)], axis=0)
    is_valid = ~tf.math.is_nan(log_rates)
    # Both samplers compute internally in float64.
    work_dtype = tf.float64
    low_seed, high_seed = samplers.split_seed(seed)
    log_rates_64 = tf.cast(log_rates, work_dtype)

    # Regime 1: rate >= 10. Entries outside this regime (including NaNs) are
    # given a placeholder log-rate of 100, which the rejection sampler is
    # highly likely to accept on its first pass.
    is_high = is_valid & (log_rates >= np.log(10.))
    high_draws = _random_poisson_high_rate(
        full_shape,
        log_rate=tf.where(is_high, log_rates_64, 100.),
        internal_dtype=work_dtype,
        seed=high_seed)

    # Regime 2: rate < 10. Entries outside this regime (including NaNs) are
    # given a tiny placeholder rate so the sum-of-exponentials loop almost
    # surely terminates on its first pass for them.
    is_low = is_valid & ~is_high
    low_draws = _random_poisson_low_rate(
        full_shape,
        rate=tf.where(is_low, tf.math.exp(log_rates_64), 1e-5),
        internal_dtype=work_dtype,
        seed=low_seed)

    merged = tf.where(is_high,
                      tf.cast(high_draws, output_dtype),
                      tf.cast(low_draws, output_dtype))
    return tf.where(is_valid, merged, np.nan)
# tf.function required to access Grappler's implementation_selector.
@tf.function(autograph=False)
def _random_poisson(
    shape,
    rates=None,
    log_rates=None,
    output_dtype=tf.float32,
    seed=None,
    name=None):
  """Sample a poisson, CPU specialized to stateless_poisson.

  Args:
    shape: Leading (sample) dimensions of the output. The shape of the rate
      parameter (`rates` or `log_rates`) is appended to it by the selected
      implementation to form the full output shape.
    rates: Batch of rates for Poisson distribution.
    log_rates: Batch of log rates for Poisson distribution. Pass exactly one
      of `rates` and `log_rates`: if both are given, the CPU implementation
      uses `rates` while the non-CPU implementation uses `log_rates`.
    output_dtype: DType of samples.
    seed: int or Tensor seed.
    name: Optional name for related ops.

  Returns:
    samples: Samples from poisson distributions.
    runtime_used_for_sampling: One of `implementation_selection._RUNTIME_*`.
  """
  with tf.name_scope(name or 'random_poisson'):
    seed = samplers.sanitize_seed(seed)
    shape = tf.convert_to_tensor(shape, dtype_hint=tf.int32, name='shape')
    params = dict(shape=shape, rates=rates, log_rates=log_rates,
                  output_dtype=output_dtype, seed=seed, name=name)
    # Grappler swaps in `_random_poisson_cpu` when running on CPU and keeps
    # `_random_poisson_noncpu` otherwise.
    sampler_impl = implementation_selection.implementation_selecting(
        fn_name='poisson',
        default_fn=_random_poisson_noncpu,
        cpu_fn=_random_poisson_cpu)
    return sampler_impl(**params)
class Poisson(distribution.Distribution):
  """Poisson distribution.

  The Poisson distribution is parameterized by an event `rate` parameter.

  #### Mathematical Details

  The probability mass function (pmf) is,

  ```none
  pmf(k; lambda, k >= 0) = (lambda^k / k!) / Z
  Z = exp(lambda).
  ```

  where `rate = lambda` and `Z` is the normalizing constant.
  """

  def __init__(self,
               rate=None,
               log_rate=None,
               interpolate_nondiscrete=True,
               validate_args=False,
               allow_nan_stats=True,
               name='Poisson'):
    """Initialize a batch of Poisson distributions.

    Args:
      rate: Floating point tensor, the rate parameter. `rate` must be
        non-negative. Must specify exactly one of `rate` and `log_rate`.
      log_rate: Floating point tensor, the log of the rate parameter.
        Must specify exactly one of `rate` and `log_rate`.
      interpolate_nondiscrete: Python `bool`. When `False`,
        `log_prob` returns `-inf` (and `prob` returns `0`) for non-integer
        inputs. When `True`, `log_prob` evaluates the continuous function
        `k * log_rate - lgamma(k+1) - rate`, which matches the Poisson pmf
        at integer arguments `k` (note that this function is not itself
        a normalized probability log-density).
        Default value: `True`.
      validate_args: Python `bool`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if none or both of `rate`, `log_rate` are specified.
      TypeError: if `rate` is not a float-type.
      TypeError: if `log_rate` is not a float-type.
    """
    parameters = dict(locals())
    if (rate is None) == (log_rate is None):
      raise ValueError('Must specify exactly one of `rate` and `log_rate`.')
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([rate, log_rate], dtype_hint=tf.float32)
      if not dtype_util.is_floating(dtype):
        raise TypeError('[log_]rate.dtype ({}) is not a float-type.'.format(
            dtype_util.name(dtype)))
      self._rate = tensor_util.convert_nonref_to_tensor(
          rate, name='rate', dtype=dtype)
      self._log_rate = tensor_util.convert_nonref_to_tensor(
          log_rate, name='log_rate', dtype=dtype)
      self._interpolate_nondiscrete = interpolate_nondiscrete
      super(Poisson, self).__init__(
          dtype=dtype,
          reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          name=name)

  @classmethod
  def _params_event_ndims(cls):
    return dict(rate=0, log_rate=0)

  @property
  def rate(self):
    """Rate parameter."""
    return self._rate

  @property
  def log_rate(self):
    """Log rate parameter."""
    return self._log_rate

  @property
  def interpolate_nondiscrete(self):
    """Interpolate (log) probs on non-integer inputs."""
    return self._interpolate_nondiscrete

  def _batch_shape_tensor(self):
    # Exactly one of `rate`/`log_rate` is non-None (enforced in __init__).
    x = self._rate if self._log_rate is None else self._log_rate
    return tf.shape(x)

  def _batch_shape(self):
    x = self._rate if self._log_rate is None else self._log_rate
    return x.shape

  def _event_shape_tensor(self):
    return tf.constant([], dtype=tf.int32)

  def _event_shape(self):
    return tf.TensorShape([])

  def _log_prob(self, x):
    log_rate = self._log_rate_parameter_no_checks()
    log_probs = (self._log_unnormalized_prob(x, log_rate) -
                 self._log_normalization(log_rate))
    if not self.interpolate_nondiscrete:
      # Ensure the gradient wrt `rate` is zero at non-integer points.
      log_probs = tf.where(
          tf.math.is_inf(log_probs),
          dtype_util.as_numpy_dtype(log_probs.dtype)(-np.inf),
          log_probs)
    return log_probs

  def _log_cdf(self, x):
    return tf.math.log(self.cdf(x))

  def _cdf(self, x):
    # CDF is the probability that the Poisson variable is less or equal to x.
    # When `interpolate_nondiscrete` is False, x is floored first, so for
    # fractional x the CDF equals the CDF at n = floor(x). When it is True,
    # the continuous interpolation `igammac(1 + x, rate)` is evaluated as is.
    # For negative x, the CDF is zero, but tf.igammac gives NaNs, so we impute
    # the values and handle this case explicitly.
    safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 0.)
    cdf = tf.math.igammac(1. + safe_x, self._rate_parameter_no_checks())
    return tf.where(x < 0., tf.zeros_like(cdf), cdf)

  def _log_normalization(self, log_rate):
    # log(Z) = log(exp(rate)) = rate = exp(log_rate).
    return tf.exp(log_rate)

  def _log_unnormalized_prob(self, x, log_rate):
    # The log-probability at negative points is always -inf.
    # Catch such x's and set the output value accordingly.
    safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 0.)
    # multiply_no_nan keeps 0 * log(0) == 0, so rate == 0 gives pmf(0) == 1.
    y = tf.math.multiply_no_nan(log_rate, safe_x) - tf.math.lgamma(1. + safe_x)
    return tf.where(
        tf.equal(x, safe_x), y, dtype_util.as_numpy_dtype(y.dtype)(-np.inf))

  def _mean(self):
    return self._rate_parameter_no_checks()

  def _variance(self):
    return self._rate_parameter_no_checks()

  @distribution_util.AppendDocstring(
      """Note: when `rate` is an integer, there are actually two modes: `rate`
      and `rate - 1`. In this case we return the larger, i.e., `rate`.""")
  def _mode(self):
    return tf.floor(self._rate_parameter_no_checks())

  def _sample_n(self, n, seed=None):
    seed = samplers.sanitize_seed(seed)
    # `_random_poisson` returns (samples, runtime); keep only the samples.
    return _random_poisson(
        shape=tf.convert_to_tensor([n]),
        rates=(None if self._rate is None else
               tf.convert_to_tensor(self._rate)),
        log_rates=(None if self._log_rate is None else
                   tf.convert_to_tensor(self._log_rate)),
        output_dtype=self.dtype,
        seed=seed)[0]

  def rate_parameter(self, name=None):
    """Rate vec computed from non-`None` input arg (`rate` or `log_rate`)."""
    with self._name_and_control_scope(name or 'rate_parameter'):
      return self._rate_parameter_no_checks()

  def _rate_parameter_no_checks(self):
    if self._rate is None:
      return tf.exp(self._log_rate)
    return tf.identity(self._rate)

  def log_rate_parameter(self, name=None):
    """Log-rate vec computed from non-`None` input arg (`rate`, `log_rate`)."""
    with self._name_and_control_scope(name or 'log_rate_parameter'):
      return self._log_rate_parameter_no_checks()

  def _log_rate_parameter_no_checks(self):
    if self._log_rate is None:
      return tf.math.log(self._rate)
    return tf.identity(self._log_rate)

  def _default_event_space_bijector(self):
    # Returns None: no event-space bijector is provided for this distribution.
    return

  def _parameter_control_dependencies(self, is_init):
    if not self.validate_args:
      return []
    assertions = []
    # Only `rate` is constrained; any real `log_rate` is valid.
    if self._rate is not None:
      if is_init != tensor_util.is_ref(self._rate):
        assertions.append(assert_util.assert_non_negative(
            self._rate,
            message='Argument `rate` must be non-negative.'))
    return assertions

  def _sample_control_dependencies(self, x):
    assertions = []
    if not self.validate_args:
      return assertions
    assertions.extend(distribution_util.assert_nonnegative_integer_form(x))
    return assertions
def _random_poisson_high_rate(sample_shape,
                              log_rate,
                              internal_dtype=tf.float64,
                              seed=None):
  """Draws Poisson samples by transformed rejection sampling.

  Given a CDF F(x) and a dominating transformation G(u) chosen to be close to
  the inverse CDF F^-1(u), each attempt proceeds as follows:

  1) Draw independent uniforms U and V and shift U by -0.5 (not strictly
  required, but it makes several expressions symmetric, so G is defined on
  [-0.5, 0.5]).
  2) If V <= alpha * F'(G(U)) * G'(U), return floor(G(U)); otherwise retry.
  Here alpha is the acceptance probability of the rejection algorithm.

  The dominating transformation used here is

    G(u) = (2 * a / (2 - |u|) + b) * u + c

  See [1] for details on transformed rejection.

  Args:
    sample_shape: The output sample shape. Must broadcast with `log_rate`.
    log_rate: Floating point tensor, log rate.
    internal_dtype: dtype to use for internal computations.
    seed: (optional) The random seed.

  Returns:
    Samples from the poisson distribution using transformed rejection.

  #### References

  [1]: W. Hormann, G. Derflinger, The Transformed Rejection Method For
  Generating Random Variables, An Alternative To The Ratio Of Uniforms Method
  (1994), Manuskript, Institut f. Statistik, Wirtschaftsuniversitat
  """
  lam = tf.math.exp(log_rate)
  # exp(0.5 * log_rate) == sqrt(rate).
  b = 0.931 + 2.53 * tf.math.exp(0.5 * log_rate)
  a = -0.059 + 0.02483 * b
  inverse_alpha = 1.1239 + 1.1328 / (b - 3.4)
  # Threshold on V for the cheap acceptance test; constant across attempts.
  v_quick = 0.9277 - 3.6224 / (b - 2.)

  def propose_and_test(attempt_seed):
    """Proposes one candidate per element and flags which are accepted."""
    u_seed, v_seed = samplers.split_seed(attempt_seed)
    u = samplers.uniform(sample_shape, dtype=internal_dtype, seed=u_seed) - 0.5
    v = samplers.uniform(sample_shape, dtype=internal_dtype, seed=v_seed)
    us = 0.5 - tf.math.abs(u)
    candidate = tf.math.floor(((2. * a) / us + b) * u + lam + 0.43)

    # Cheap acceptance that avoids evaluating the log-pmf.
    quick_accept = (us >= 0.07) & (v <= v_quick)
    # Full test: scaled log(V) against the Poisson log-pmf at `candidate`.
    log_v = tf.math.log(v * inverse_alpha / (a / tf.math.square(us) + b))
    log_pmf = -lam + candidate * log_rate - tf.math.lgamma(candidate + 1)
    # Make sure the sample is within bounds.
    in_bounds = (candidate >= 0) & ((us >= 0.013) | (v <= us))
    accepted = (quick_accept | (log_v <= log_pmf)) & in_bounds
    return candidate, accepted

  return brs.batched_las_vegas_algorithm(propose_and_test, seed=seed)[0]
def _random_poisson_low_rate(sample_shape,
                             rate,
                             internal_dtype=tf.float64,
                             seed=None):
  """Samples from the Poisson distribution using Knuth's algorithm.

  We use an algorithm attributed to Knuth: Seminumerical Algorithms. Art of
  Computer Programming, Volume 2. This algorithm runs in O(rate) expected time
  and requires O(rate) uniform variates in expectation. It is performant for
  rate ~<10.

  Given a Poisson process, the time between events is exponentially distributed.
  If we have a Poisson process with rate lambda, then the time between events
  is distributed as Exp(lambda). If X ~ Uniform(0, 1), then Y ~ Exp(lambda)
  where Y = -log(X) / lambda. Thus, to simulate a Poisson draw, we can sample
  Y_i ~ Exp(lambda), and we will have N ~ Poisson(lambda), where N is the
  largest number such that sum_{i=1}^N Y_i <= 1 (equivalently, N + 1 is the
  smallest number such that sum_{i=1}^{N+1} Y_i > 1). In terms of the uniforms,
  this is the first N for which prod_{i=1}^{N+1} X_i <= exp(-lambda), which is
  exactly the test applied in the loop below.

  Args:
    sample_shape: The output sample shape. Must broadcast with `rate`.
    rate: Floating point tensor, rate.
    internal_dtype: (optional) dtype to use for internal computations.
    seed: (optional) The random seed.

  Returns:
    An int32 tensor of shape `sample_shape` with samples from the poisson
    distribution. Elements not accepted within the 200-iteration cap keep
    their initial value of 0.
  """
  exp_neg_rate = tf.math.exp(-rate)

  def loop_body(should_continue, samples, prod, num_iters, seed):
    u_seed, next_seed = samplers.split_seed(seed)
    # After this update `prod` holds the product of `num_iters + 1` uniforms.
    prod = prod * samplers.uniform(
        sample_shape, dtype=internal_dtype, seed=u_seed)
    accept = should_continue & (prod <= exp_neg_rate)
    samples = tf.where(accept, num_iters, samples)
    return [
        should_continue & (~accept), samples, prod, num_iters + 1, next_seed
    ]

  _, samples, _, _, _ = tf.while_loop(
      cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
      body=loop_body,
      loop_vars=[
          tf.ones(sample_shape, dtype=tf.bool),  # should_continue
          tf.zeros(sample_shape, dtype=tf.int32),  # samples
          tf.ones(sample_shape, dtype=internal_dtype),  # prod
          tf.zeros([], dtype=tf.int32),  # num_iters
          seed,  # seed
      ],
      # Using a Chernoff-like bound, we can show that for lambda < 10,
      # Pr[X >= lambda + n] <= exp(-n^2 / 2(lambda + n)) < exp(-90). Hence,
      # there is minuscule probability that, even after a union bound over
      # batch size, a poisson sample with rate < 10 would attain a value > 200.
      maximum_iterations=200,
  )
  return samples