# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Distribution adapters for (soft) round functions."""
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_compression.python.distributions import deep_factorized
from tensorflow_compression.python.distributions import helpers
from tensorflow_compression.python.distributions import uniform_noise
from tensorflow_compression.python.ops import round_ops
__all__ = [
"MonotonicAdapter",
"RoundAdapter",
"NoisyRoundedNormal",
"NoisyRoundedDeepFactorized",
"SoftRoundAdapter",
"NoisySoftRoundedNormal",
"NoisySoftRoundedDeepFactorized",
]
class MonotonicAdapter(tfp.distributions.Distribution):
  """Adapt a continuous distribution via an ascending monotonic function.

  Models the distribution of `y = f(x)`, where `x` follows the base
  distribution and `f` is an ascending monotonic transform implemented by
  subclasses via `transform`/`inverse_transform`.

  This is described in Appendix E. in the paper
  > "Universally Quantized Neural Compression"<br />
  > Eirikur Agustsson & Lucas Theis<br />
  > https://arxiv.org/abs/2006.09952
  """

  # Subclasses whose transform is not invertible (e.g. hard rounding) set
  # this to False; quantile/mode/tail methods then raise NotImplementedError.
  invertible = True  # Set to false if the transform is not invertible.

  def __init__(self, base, name="MonotonicAdapter"):
    """Initializer.

    Args:
      base: A `tfp.distributions.Distribution` object representing a
        continuous-valued random variable.
      name: String. A name for this distribution.
    """
    # Capture the constructor arguments first, before any other local is
    # bound, as expected by the `tfp.distributions.Distribution` contract.
    parameters = dict(locals())
    self._base = base
    # Mirror the base distribution's dtype and flags so this adapter is a
    # drop-in replacement for it.
    super().__init__(
        dtype=base.dtype,
        reparameterization_type=base.reparameterization_type,
        validate_args=base.validate_args,
        allow_nan_stats=base.allow_nan_stats,
        parameters=parameters,
        name=name,
    )

  @property
  def base(self):
    """The base distribution."""
    return self._base

  def transform(self, x):
    """The forward transform."""
    raise NotImplementedError()

  def inverse_transform(self, y):
    """The backward transform."""
    # Let f(x) = self.transform(x)
    # Then g(y) = self.inverse_transform(y) is defined as
    # g(y) := inf_x { x : f(x) >= y }
    # which is just the inverse of `f` if it is invertible.
    raise NotImplementedError()

  # Shape information is unchanged by an elementwise monotonic transform,
  # so all four shape methods simply delegate to the base distribution.
  def _batch_shape_tensor(self):
    return self.base.batch_shape_tensor()

  def _batch_shape(self):
    return self.base.batch_shape

  def _event_shape_tensor(self):
    return self.base.event_shape_tensor()

  def _event_shape(self):
    return self.base.event_shape

  def _sample_n(self, n, seed=None):
    # Sample from the base distribution, then push the samples through the
    # forward transform.
    with tf.name_scope("round"):
      n = tf.convert_to_tensor(n, name="n")
      samples = self.base.sample(n, seed=seed)
      return self.transform(samples)

  # Densities are deliberately unavailable: for a non-smooth transform such
  # as rounding, the pushforward has no density w.r.t. Lebesgue measure.
  def _prob(self, *args, **kwargs):
    raise NotImplementedError

  def _log_prob(self, *args, **kwargs):
    raise NotImplementedError

  # pylint: disable=protected-access
  def _cdf(self, y):
    # Let f be the forward transform and g the inverse.
    # Then we have:
    # P( f(x) <= y )
    # P( g(f(x)) <= g(y) )
    # = P( x <= g(y) )
    return self.base._cdf(self.inverse_transform(y))

  def _log_cdf(self, y):
    return self.base._log_cdf(self.inverse_transform(y))

  def _survival_function(self, y):
    return self.base._survival_function(self.inverse_transform(y))

  def _log_survival_function(self, y):
    return self.base._log_survival_function(self.inverse_transform(y))

  def _quantile(self, value):
    if not self.invertible:
      raise NotImplementedError()
    # We have:
    # P( x <= z ) = value
    # if and only if
    # P( f(x) <= f(z) ) = value
    return self.transform(self.base._quantile(value))

  def _mode(self):
    # Same logic as for _quantile.
    if not self.invertible:
      raise NotImplementedError()
    return self.transform(self.base._mode())

  def _quantization_offset(self):
    # Same logic as for _quantile.
    if not self.invertible:
      raise NotImplementedError()
    return self.transform(helpers.quantization_offset(self.base))

  def _lower_tail(self, tail_mass):
    # Same logic as for _quantile.
    if not self.invertible:
      raise NotImplementedError()
    return self.transform(helpers.lower_tail(self.base, tail_mass))

  def _upper_tail(self, tail_mass):
    # Same logic as for _quantile.
    if not self.invertible:
      raise NotImplementedError()
    return self.transform(helpers.upper_tail(self.base, tail_mass))
  # pylint: enable=protected-access

  @classmethod
  def _parameter_properties(cls, dtype=tf.float32, num_classes=None):
    # Adapters are parameterized by a whole distribution object, which the
    # `parameter_properties` machinery cannot describe.
    raise NotImplementedError(
        f"`{cls.__name__}` does not implement `_parameter_properties`.")
class RoundAdapter(MonotonicAdapter):
  """Continuous density function + round.

  Adapts the base distribution via hard rounding, i.e. models the
  distribution of `round(x)` where `x` follows the base distribution.
  """

  # Rounding collapses whole intervals onto single integers, so no true
  # inverse exists; only the generalized inverse below is available.
  invertible = False

  def transform(self, x):
    # `tf.math.round` is the same op as `tf.round`.
    return tf.math.round(x)

  def inverse_transform(self, y):
    # Generalized inverse: g(y) := inf_x { x : round(x) >= y }.
    # Because round(x) is integer-valued, round(x) >= y holds exactly when
    # round(x) >= ceil(y), and the infimum of that set is the left edge of
    # the interval rounding to ceil(y), namely ceil(y) - 0.5.
    # Sanity check via CDFs:
    #   P( round(x) <= y ) = P( round(x) <= floor(y) )
    #                      = P( x <= floor(y) + 0.5 )
    #                      = P( x <= ceil(y) - 0.5 ).
    return tf.math.ceil(y) - 0.5

  def _quantization_offset(self):
    # Rounded values sit exactly on the integer grid, so no offset.
    return tf.convert_to_tensor(0.0, dtype=self.dtype)

  def _lower_tail(self, tail_mass):
    # Snap the base distribution's tails outward onto the integer grid.
    return tf.math.floor(helpers.lower_tail(self.base, tail_mass))

  def _upper_tail(self, tail_mass):
    return tf.math.ceil(helpers.upper_tail(self.base, tail_mass))
class NoisyRoundAdapter(uniform_noise.UniformNoiseAdapter):
  """Uniform noise + round."""

  def __init__(self, base, name="NoisyRoundAdapter"):
    """Initializer.

    Args:
      base: A `tfp.distributions.Distribution` object representing a
        continuous-valued random variable.
      name: String. A name for this distribution.
    """
    # First round the base distribution, then add uniform noise to it.
    rounded = RoundAdapter(base)
    super().__init__(rounded, name=name)
class NoisyRoundedDeepFactorized(NoisyRoundAdapter):
  """Rounded `DeepFactorized` + uniform noise."""

  def __init__(self, name="NoisyRoundedDeepFactorized", **kwargs):
    # All remaining keyword arguments parameterize the `DeepFactorized`
    # prior that is rounded and noised by the superclass.
    super().__init__(
        base=deep_factorized.DeepFactorized(**kwargs), name=name)
class NoisyRoundedNormal(NoisyRoundAdapter):
  """Rounded normal distribution + uniform noise."""

  def __init__(self, name="NoisyRoundedNormal", **kwargs):
    # Remaining keyword arguments (e.g. `loc`, `scale`) parameterize the
    # underlying normal distribution.
    prior = tfp.distributions.Normal(**kwargs)
    super().__init__(base=prior, name=name)
class SoftRoundAdapter(MonotonicAdapter):
  """Differentiable approximation to round.

  Unlike `RoundAdapter`, the soft rounding function is strictly monotonic
  and therefore invertible, so the inherited quantile/mode/tail methods
  remain available.
  """

  def __init__(self, base, alpha, name="SoftRoundAdapter"):
    """Initializer.

    Args:
      base: A `tfp.distributions.Distribution` object representing a
        continuous-valued random variable.
      alpha: Float or tf.Tensor. Controls smoothness of the approximation.
      name: String. A name for this distribution.
    """
    super().__init__(base=base, name=name)
    self._alpha = alpha

  def transform(self, x):
    # Forward: soft rounding with sharpness `alpha`.
    return round_ops.soft_round(x, self._alpha)

  def inverse_transform(self, y):
    # Backward: the exact inverse of the soft rounding function.
    return round_ops.soft_round_inverse(y, self._alpha)
class NoisySoftRoundAdapter(uniform_noise.UniformNoiseAdapter):
  """Uniform noise + differentiable approximation to round."""

  def __init__(self, base, alpha, name="NoisySoftRoundAdapter"):
    """Initializer.

    Args:
      base: A `tfp.distributions.Distribution` object representing a
        continuous-valued random variable.
      alpha: Float or tf.Tensor. Controls smoothness of soft round.
      name: String. A name for this distribution.
    """
    # First soft-round the base distribution, then add uniform noise to it.
    soft_rounded = SoftRoundAdapter(base, alpha)
    super().__init__(soft_rounded, name=name)
class NoisySoftRoundedNormal(NoisySoftRoundAdapter):
  """Soft rounded normal distribution + uniform noise."""

  def __init__(self, alpha=5.0, name="NoisySoftRoundedNormal", **kwargs):
    # Remaining keyword arguments (e.g. `loc`, `scale`) parameterize the
    # underlying normal distribution.
    prior = tfp.distributions.Normal(**kwargs)
    super().__init__(base=prior, alpha=alpha, name=name)
class NoisySoftRoundedDeepFactorized(NoisySoftRoundAdapter):
  """Soft rounded `DeepFactorized` + uniform noise."""

  def __init__(self, alpha=5.0, name="NoisySoftRoundedDeepFactorized",
               **kwargs):
    # All remaining keyword arguments parameterize the `DeepFactorized`
    # prior that is soft-rounded and noised by the superclass.
    prior = deep_factorized.DeepFactorized(**kwargs)
    super().__init__(base=prior, alpha=alpha, name=name)