/
iterated_sigmoid_centered.py
143 lines (116 loc) · 5.2 KB
/
iterated_sigmoid_centered.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""IteratedSigmoidCentered bijector."""
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
__all__ = [
'IteratedSigmoidCentered',
]
class IteratedSigmoidCentered(bijector.AutoCompositeTensorBijector):
  """Bijector which applies a Stick Breaking procedure.

  Given a vector `x`, transform it in to a vector `y` such that
  `y[i] > 0, sum_i y[i] = 1.`. In other words, takes a vector in
  `R^{k-1}` (unconstrained space) and maps it to a vector in the
  unit simplex in `R^{k}`.

  This transformation is centered in that it maps the zero vector
  `[0., 0., ... 0.]` to the center of the simplex `[1/k, ... 1/k]`.

  This bijector arises from the stick-breaking procedure for constructing
  a Dirichlet distribution / Dirichlet process as defined in [Stan, 2018][1].

  Example Use:

  ```python
  bijector.IteratedSigmoidCentered().forward([0., 0., 0.])
  # Result: [0.25, 0.25, 0.25, 0.25]
  # Extra result: 0.25

  bijector.IteratedSigmoidCentered().inverse([0.25, 0.25, 0.25, 0.25])
  # Result: [0., 0., 0.]
  # Extra coordinate removed.
  ```

  At first blush it may seem like the [Invariance of domain](
  https://en.wikipedia.org/wiki/Invariance_of_domain) theorem implies this
  implementation is not a bijection. However, the appended dimension
  makes the (forward) image non-open and the theorem does not directly apply.

  #### References

  [1]: Stan Development Team. 2018. Stan Modeling Language Users Guide and
       Reference Manual, Version 2.18.0. http://mc-stan.org
  """

  def __init__(self,
               validate_args=False,
               name='iterated_sigmoid'):
    """Instantiates the `IteratedSigmoidCentered` bijector.

    Args:
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.
        Default value: `'iterated_sigmoid'`.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      super(IteratedSigmoidCentered, self).__init__(
          # Operates on the rightmost (vector) dimension of the event.
          forward_min_event_ndims=1,
          validate_args=validate_args,
          parameters=parameters,
          name=name)

  @classmethod
  def _parameter_properties(cls, dtype):
    # This bijector has no tensor-valued parameters.
    return dict()

  def _forward_event_shape(self, input_shape):
    """Static shape: the forward map appends one coordinate (k-1 -> k)."""
    if not input_shape[-1:].is_fully_defined():
      # Last dimension unknown; the +1 cannot be applied statically.
      return input_shape
    return input_shape[:-1].concatenate(input_shape[-1] + 1)

  def _forward_event_shape_tensor(self, input_shape):
    """Dynamic-shape counterpart of `_forward_event_shape`."""
    return tf.concat([input_shape[:-1], [input_shape[-1] + 1]], axis=0)

  def _inverse_event_shape(self, output_shape):
    """Static shape: the inverse map removes one coordinate (k -> k-1)."""
    if not output_shape[-1:].is_fully_defined():
      return output_shape
    # A simplex point needs at least one coordinate.
    if output_shape[-1] < 1:
      raise ValueError('output_shape[-1] = %d < 1' % output_shape[-1])
    return output_shape[:-1].concatenate(output_shape[-1] - 1)

  def _inverse_event_shape_tensor(self, output_shape):
    """Dynamic-shape counterpart of `_inverse_event_shape`."""
    if self.validate_args:
      # It is not possible for a negative shape so we need only check < 1.
      dependencies = [assert_util.assert_greater(
          output_shape[-1], 0, message='Need last dimension greater than 0.')]
    else:
      dependencies = []
    with tf.control_dependencies(dependencies):
      return tf.concat([output_shape[:-1], [output_shape[-1] - 1]], axis=0)

  def _forward(self, x):
    """Maps unconstrained `x` in R^{N-1} to a point `y` on the N-simplex."""
    # As specified in the Stan reference manual, the procedure is as follows:
    # N = x.shape[-1] + 1
    # z_k = sigmoid(x + log(1 / (N - k)))
    # y_1 = z_1
    # y_k = (1 - sum_{i=1 to k-1} y_i) * z_k
    # y_N = 1 - sum_{i=1 to N-1} y_i
    # TODO(b/128857065): The numerics can possibly be improved here with a
    # log-space computation.
    # offset[k] = log(1 / (N - k)) = -log(N - k); the centering term that
    # sends x = 0 to the uniform point [1/N, ..., 1/N].
    offset = -tf.math.log(
        tf.cast(
            tf.range(ps.shape(x)[-1], 0, delta=-1),
            dtype=dtype_util.base_dtype(x.dtype)))
    z = tf.math.sigmoid(x + offset)
    # Remaining stick length 1 - sum_{i<k} y_i equals prod_{i<k} (1 - z_i),
    # so an exclusive cumprod yields all y_k in one vectorized pass.
    y = z * tf.math.cumprod(1 - z, axis=-1, exclusive=True)
    # Append the leftover stick so the output sums to 1.
    return tf.concat([y, 1. - tf.reduce_sum(y, axis=-1, keepdims=True)],
                     axis=-1)

  def _inverse(self, y):
    """Maps a point `y` on the N-simplex back to unconstrained R^{N-1}."""
    # As specified in the Stan reference manual, the procedure is as follows:
    # N = y.shape[-1]
    # z_k = y_k / (1 - sum_{i=1 to k-1} y_i)
    # x_k = logit(z_k) - log(1 / (N - k))
    # offset[k] = log(N - k) for k = 1..N-1, undoing the centering applied
    # in `_forward`.
    offset = tf.math.log(
        tf.cast(
            tf.range(ps.shape(y)[-1] - 1, 0, delta=-1),
            dtype=dtype_util.base_dtype(y.dtype)))
    # z_k = fraction of the remaining stick taken at step k.
    z = y / (1. - tf.math.cumsum(y, axis=-1, exclusive=True))
    # logit(z) written as log(z) - log1p(-z); the redundant last coordinate
    # of y is dropped.
    return tf.math.log(z[..., :-1]) - tf.math.log1p(-z[..., :-1]) + offset

  def _inverse_log_det_jacobian(self, y):
    """Log |det J| of the inverse map, reduced over the event dimension."""
    z = y / (1. - tf.math.cumsum(y, axis=-1, exclusive=True))
    # x_k depends only on y_1..y_k, so the inverse Jacobian is triangular and
    # its log-determinant is the sum of log diagonal terms, which simplify to
    # -log(y_k) - log1p(-z_k) for k = 1..N-1.
    return tf.reduce_sum(
        (-tf.math.log(y[..., :-1]) - tf.math.log1p(-z[..., :-1])), axis=-1)