
Added gradient op for QR decomposition

PiperOrigin-RevId: 172895297
tensorflower-gardener committed Oct 20, 2017
1 parent 5c24b8b commit f86588ce8fb38ab3a6afc21eb08d2a2097b56adc
Showing with 98 additions and 10 deletions.
  1. +62 −4 tensorflow/python/kernel_tests/qr_op_test.py
  2. +36 −6 tensorflow/python/ops/linalg_grad.py
tensorflow/python/kernel_tests/qr_op_test.py

@@ -22,6 +22,7 @@

 from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
@@ -140,11 +141,11 @@ def Test(self):
       x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
       for i in range(new_first_dim):
         if full_matrices_:
-          np_q_reshape[i,:,:], _ = \
-              np.linalg.qr(x_reshape[i,:,:], mode="complete")
+          np_q_reshape[i, :, :], _ = np.linalg.qr(
+              x_reshape[i, :, :], mode="complete")
         else:
-          np_q_reshape[i,:,:], _ = \
-              np.linalg.qr(x_reshape[i,:,:], mode="reduced")
+          np_q_reshape[i, :, :], _ = np.linalg.qr(
+              x_reshape[i, :, :], mode="reduced")
       np_q = np.reshape(np_q_reshape, q_dims)
     CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:]))
     CheckApproximation(self, x_np, q_tf_val, r_tf_val)
@@ -153,6 +154,46 @@ def Test(self):
  return Test


class QrGradOpTest(test.TestCase):
  pass


def _GetQrGradOpTest(dtype_, shape_, full_matrices_):

  def Test(self):
    np.random.seed(42)
    a = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_)
    if dtype_ in [np.complex64, np.complex128]:
      a += 1j * np.random.uniform(
          low=-1.0, high=1.0, size=shape_).astype(dtype_)
    # Optimal stepsize for central difference is O(epsilon^{1/3}).
    epsilon = np.finfo(dtype_).eps
    delta = 0.1 * epsilon**(1.0 / 3.0)
    if dtype_ in [np.float32, np.complex64]:
      tol = 3e-2
    else:
      tol = 1e-6
    with self.test_session(use_gpu=True):
      tf_a = constant_op.constant(a)
      tf_b = linalg_ops.qr(tf_a, full_matrices=full_matrices_)
      for b in tf_b:
        x_init = np.random.uniform(
            low=-1.0, high=1.0, size=shape_).astype(dtype_)
        if dtype_ in [np.complex64, np.complex128]:
          x_init += 1j * np.random.uniform(
              low=-1.0, high=1.0, size=shape_).astype(dtype_)
        theoretical, numerical = gradient_checker.compute_gradient(
            tf_a,
            tf_a.get_shape().as_list(),
            b,
            b.get_shape().as_list(),
            x_init_value=x_init,
            delta=delta)
        self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

  return Test
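A note on the step size in `Test` above: `delta = 0.1 * epsilon**(1.0 / 3.0)` is the standard choice for central differences. A sketch of why, under the usual smoothness assumptions (with M bounding the third derivative of the function being differenced):

    % Central difference error = truncation + floating-point rounding:
    \frac{f(x+h) - f(x-h)}{2h} - f'(x)
      \;\approx\; \underbrace{\tfrac{1}{6} M h^{2}}_{\text{truncation}}
      \;+\; \underbrace{\tfrac{\epsilon}{h}\,|f(x)|}_{\text{rounding}}
    % Minimizing c_1 h^2 + c_2 \epsilon / h over h gives
    h^{\ast} \;=\; \Theta\!\left(\epsilon^{1/3}\right)

Too small an h and rounding noise dominates; too large and the O(h^2) truncation term does. Hence the looser `tol = 3e-2` for float32/complex64, where epsilon^{1/3} is only about 5e-3.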


if __name__ == "__main__":
  for dtype in np.float32, np.float64, np.complex64, np.complex128:
    for rows in 1, 2, 5, 10, 32, 100:
@@ -168,4 +209,21 @@ def Test(self):
          _AddTest(QrOpTest, "Qr", name,
                   _GetQrOpTest(dtype, shape, full_matrices,
                                use_static_shape))

  # TODO(pfau): Get working with complex types.
  # TODO(pfau): Get working with full_matrices when rows != cols
  # TODO(pfau): Get working when rows < cols
  # TODO(pfau): Get working with placeholders (dynamic shapes)
  for full_matrices in False, True:
    for dtype in np.float32, np.float64:
      for rows in 1, 2, 5, 10:
        for cols in 1, 2, 5, 10:
          if rows == cols or (not full_matrices and rows > cols):
            for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
              shape = batch_dims + (rows, cols)
              name = "%s_%s_full_%s" % (dtype.__name__,
                                        "_".join(map(str, shape)),
                                        full_matrices)
              _AddTest(QrGradOpTest, "QrGrad", name,
                       _GetQrGradOpTest(dtype, shape, full_matrices))
  test.main()
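`QrGradOpTest` is deliberately empty: its test methods are attached at import time by `_AddTest`, which is defined near the top of qr_op_test.py and so falls outside this diff. A plausible sketch of that helper, for readers without the full file at hand (the exact body, including the duplicate-name check, is an assumption):

    def _AddTest(test_class, op_name, testcase_name, fn):
      # Attaches `fn` as e.g. QrGradOpTest.test_QrGrad_float32_5_5_full_False,
      # so the runner discovers one method per generated parameter combination.
      test_name = "_".join(["test", op_name, testcase_name])
      if hasattr(test_class, test_name):
        raise RuntimeError("Test %s defined twice" % test_name)
      setattr(test_class, test_name, fn)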
tensorflow/python/ops/linalg_grad.py

@@ -81,6 +81,36 @@ def _CholeskyGrad(op, grad):
  return grad_a * 0.5


@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""
  q, r = op.outputs
  if q.dtype.is_complex:
    raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  if r.shape[-2].value != r.shape[-1].value:
    raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                              "or full_matrices is true and ncols != nrows.")

  qdq = math_ops.matmul(q, dq, adjoint_a=True)
  qdq_ = qdq - _linalg.adjoint(qdq)
  rdr = math_ops.matmul(r, dr, adjoint_b=True)
  rdr_ = rdr - _linalg.adjoint(rdr)
  tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
  grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
  return grad_a + grad_b
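Written out in matrix notation, as a direct transcription of the code above (with the incoming cotangents `dq`, `dr` written as Q-bar, R-bar, X^H the conjugate transpose, tril(.) the lower-triangular part including the diagonal, and R^{-H} = (R^{-1})^H), the returned gradient is:

    \bar{A} \;=\; \underbrace{Q\Big(\bar{R} + \operatorname{tril}\!\big(
            Q^{H}\bar{Q} - \bar{Q}^{H}Q + R\bar{R}^{H} - \bar{R}R^{H}
        \big)\,R^{-H}\Big)}_{\texttt{grad\_a}}
      \;+\; \underbrace{\big(\bar{Q} - Q\,Q^{H}\bar{Q}\big)\,R^{-H}}_{\texttt{grad\_b}}

The `_TriangularSolve(x, r)` helper realizes each `x R^{-H}` factor by back-substitution against the upper-triangular `r` rather than forming an explicit inverse, which is cheaper and better conditioned.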


@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
"""Gradient for MatrixSolve."""
@@ -105,7 +135,7 @@ def _MatrixSolveLsGrad(op, grad):
  # b) Implement a symmetric rank-k update op instead of computing
  #    x*z + transpose(x*z). This pattern occurs other places in TensorFlow.

-  def _overdetermined(op, grad):
+  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
@@ -130,7 +160,7 @@ def _overdetermined(op, grad):
    grad_b = math_ops.matmul(a, z)
    return (grad_a, grad_b, None)

-  def _underdetermined(op, grad):
+  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
@@ -162,16 +192,16 @@ def _underdetermined(op, grad):
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
-      return _overdetermined(op, grad)
+      return _Overdetermined(op, grad)
    else:
-      return _underdetermined(op, grad)
+      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
-                                 lambda: _overdetermined(op, grad),
-                                 lambda: _underdetermined(op, grad))
+                                 lambda: _Overdetermined(op, grad),
+                                 lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("MatrixTriangularSolve")
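With the gradient registered, Q and R become differentiable outputs like any other op. A minimal sketch of exercising the new gradient from user code, assuming the TF 1.x graph/session API of this era (the loss here is arbitrary, chosen only to yield a scalar):

    import numpy as np
    import tensorflow as tf

    a = tf.constant(np.random.randn(5, 5))          # square, real: the supported case
    q, r = tf.qr(a, full_matrices=False)
    loss = tf.reduce_sum(q * q) + tf.reduce_sum(r)  # any scalar function of Q and R
    grad_a, = tf.gradients(loss, [a])               # dispatches to _QrGrad above

    with tf.Session() as sess:
      print(sess.run(grad_a).shape)                 # (5, 5), same shape as `a`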
