Skip to content
Fetching contributors…
Cannot retrieve contributors at this time
108 lines (91 sloc) 3.67 KB
 # Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TransformDiagonal bijector."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf

from tensorflow_probability.python.bijectors import bijector


__all__ = [
    "TransformDiagonal",
]


class TransformDiagonal(bijector.Bijector):
  """Bijector that transforms only the diagonal of a matrix.

  The diagonal of the input is extracted with `tf.linalg.diag_part`, passed
  through `diag_bijector`, and written back with `tf.linalg.set_diag`. Every
  off-diagonal entry passes through unchanged.

  #### Example

  ```python
  b = tfb.TransformDiagonal(diag_bijector=tfb.Exp())

  b.forward([[1., 0.],
             [0., 1.]])
  # ==> [[2.718, 0.],
  #      [0., 2.718]]
  ```

  """

  def __init__(self,
               diag_bijector,
               validate_args=False,
               name="transform_diagonal"):
    """Creates a `TransformDiagonal` bijector.

    Args:
      diag_bijector: `Bijector` instance applied to the matrix diagonal. Its
        `dtype` is reused as the `dtype` of this bijector.
      validate_args: Python `bool`; when `True`, arguments are checked for
        correctness.
      name: Python `str` name given to ops created by this object.
    """
    self._diag_bijector = diag_bijector
    # Inputs and outputs are matrices, so both directions need at least two
    # event dimensions.
    super(TransformDiagonal, self).__init__(
        dtype=diag_bijector.dtype,
        forward_min_event_ndims=2,
        inverse_min_event_ndims=2,
        validate_args=validate_args,
        name=name)

  @property
  def diag_bijector(self):
    """The `Bijector` applied to the diagonal entries."""
    return self._diag_bijector

  @staticmethod
  def _map_diagonal(fn, matrix):
    """Returns `matrix` with its diagonal replaced by `fn(diagonal)`."""
    return tf.linalg.set_diag(matrix, fn(tf.linalg.diag_part(matrix)))

  def _forward(self, x):
    return self._map_diagonal(self.diag_bijector.forward, x)

  def _inverse(self, y):
    return self._map_diagonal(self.diag_bijector.inverse, y)

  def _forward_log_det_jacobian(self, x):
    # Treat the transformation as a map on the flattened matrix vec(x). If
    # the n diagonal entries are listed first and the n**2 - n off-diagonal
    # entries after them (in any order), the Jacobian is block diagonal:
    #
    #   J = | J_diag   0 |   <- n rows
    #       |   0      I |   <- n**2 - n rows
    #
    # J_diag is the Jacobian of `diag_bijector`. I is the identity, because
    # off-diagonal entries are copied through unchanged. Since
    # log|det I| = 0, the total log-det-Jacobian is that of `diag_bijector`
    # acting on the diagonal viewed as a vector (event_ndims=1).
    #
    # The argument does not require J_diag itself to be diagonal, so
    # non-elementwise diagonal bijectors are handled as well.
    diagonal = tf.linalg.diag_part(x)
    return self.diag_bijector.forward_log_det_jacobian(
        diagonal, event_ndims=1)

  def _inverse_log_det_jacobian(self, y):
    # The inverse has the same block structure, with the identity block
    # again contributing zero. Only the diagonal bijector's inverse term
    # remains.
    diagonal = tf.linalg.diag_part(y)
    return self.diag_bijector.inverse_log_det_jacobian(
        diagonal, event_ndims=1)
You can’t perform that action at this time.