-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
discrete_cosine_transform.py
88 lines (69 loc) · 3.19 KB
/
discrete_cosine_transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Discrete Cosine Transform bijector."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import dtype_util
# Names exported by `from <this module> import *`; this is the module's
# declared public API.
__all__ = [
'DiscreteCosineTransform',
]
class DiscreteCosineTransform(bijector.Bijector):
  """Compute `Y = g(X) = DCT(X)`, where the DCT type is set by `dct_type`.

  The [discrete cosine transform](
  https://en.wikipedia.org/wiki/Discrete_cosine_transform) efficiently applies
  a unitary DCT operator. This can be useful for mixing and decorrelating across
  the innermost event dimension.

  The inverse is `X = g^{-1}(Y) = IDCT(Y)`, where the IDCT is DCT-III for
  `dct_type == 2` and DCT-II for `dct_type == 3`.

  This bijector can be interleaved with affine bijectors to build a cascade of
  structured efficient linear layers as in [1].

  Note that the operator applied is orthonormal (i.e. `norm='ortho'`), so the
  log-determinant of the Jacobian is reported as zero in both directions.

  #### References

  [1]: Moczulski M, Denil M, Appleyard J, de Freitas N. ACDC: A structured
       efficient linear layer. In _International Conference on Learning
       Representations_, 2016. https://arxiv.org/abs/1511.05946
  """

  def __init__(self, dct_type=2, validate_args=False, name='dct'):
    """Instantiates the `DiscreteCosineTransform` bijector.

    Args:
      dct_type: Python `int`, the DCT type performed by the forward
        transformation. Currently, only 2 and 3 are supported.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      NotImplementedError: if `dct_type` is neither 2 nor 3.
    """
    # Snapshot the constructor arguments first, before any other local is
    # bound, so that only the call-site arguments are forwarded as
    # `parameters` to the base class.
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      # TODO(b/115910664): Support other DCT types.
      if dct_type not in (2, 3):
        # Name the actual constructor argument so the caller knows what to
        # change.
        raise NotImplementedError('`dct_type` must be one of 2 or 3')
      self._dct_type = dct_type
      super(DiscreteCosineTransform, self).__init__(
          forward_min_event_ndims=1,
          inverse_min_event_ndims=1,
          is_constant_jacobian=True,
          validate_args=validate_args,
          parameters=parameters,
          name=name)

  def _forward(self, x):
    """Applies the orthonormal DCT of the configured type to `x`."""
    return tf.signal.dct(x, type=self._dct_type, norm='ortho')

  def _inverse(self, y):
    """Applies the orthonormal inverse DCT of the configured type to `y`."""
    # `tf.signal.idct` is given the type of the forward transform being
    # inverted, so the same `dct_type` is passed here as in `_forward`.
    return tf.signal.idct(y, type=self._dct_type, norm='ortho')

  def _inverse_log_det_jacobian(self, y):
    """Returns a scalar zero in the base dtype of `y`."""
    # The transform is orthonormal, hence volume preserving: log|det J| = 0.
    return tf.constant(0., dtype=dtype_util.base_dtype(y.dtype))

  def _forward_log_det_jacobian(self, x):
    """Returns a scalar zero in the base dtype of `x`."""
    # The transform is orthonormal, hence volume preserving: log|det J| = 0.
    return tf.constant(0., dtype=dtype_util.base_dtype(x.dtype))