Skip to content

Commit f86588c

Browse files
Added gradient op for QR decomposition
PiperOrigin-RevId: 172895297
1 parent 5c24b8b commit f86588c

File tree

2 files changed

+98
-10
lines changed

2 files changed

+98
-10
lines changed

tensorflow/python/kernel_tests/qr_op_test.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from tensorflow.python.framework import constant_op
2424
from tensorflow.python.ops import array_ops
25+
from tensorflow.python.ops import gradient_checker
2526
from tensorflow.python.ops import linalg_ops
2627
from tensorflow.python.ops import math_ops
2728
from tensorflow.python.ops import random_ops
@@ -140,11 +141,11 @@ def Test(self):
140141
x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
141142
for i in range(new_first_dim):
142143
if full_matrices_:
143-
np_q_reshape[i,:,:], _ = \
144-
np.linalg.qr(x_reshape[i,:,:], mode="complete")
144+
np_q_reshape[i, :, :], _ = np.linalg.qr(
145+
x_reshape[i, :, :], mode="complete")
145146
else:
146-
np_q_reshape[i,:,:], _ = \
147-
np.linalg.qr(x_reshape[i,:,:], mode="reduced")
147+
np_q_reshape[i, :, :], _ = np.linalg.qr(
148+
x_reshape[i, :, :], mode="reduced")
148149
np_q = np.reshape(np_q_reshape, q_dims)
149150
CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:]))
150151
CheckApproximation(self, x_np, q_tf_val, r_tf_val)
@@ -153,6 +154,46 @@ def Test(self):
153154
return Test
154155

155156

157+
class QrGradOpTest(test.TestCase):
  # Intentionally empty: gradient test methods are attached to this class
  # dynamically (via _AddTest calls in the __main__ block), one per
  # dtype/shape/full_matrices configuration.
  pass
159+
160+
161+
def _GetQrGradOpTest(dtype_, shape_, full_matrices_):
  """Builds a test method that gradient-checks linalg_ops.qr.

  Args:
    dtype_: Numpy dtype of the input matrix (real or complex float).
    shape_: Tuple giving the (possibly batched) shape of the input matrix.
    full_matrices_: Bool forwarded as the `full_matrices` argument of qr.

  Returns:
    A test method comparing theoretical and numeric Jacobians of qr.
  """

  def Test(self):
    # Seed so the drawn matrices (and hence the test) are deterministic.
    np.random.seed(42)

    def RandomInput():
      # Uniform entries in [-1, 1); complex dtypes also get a random
      # imaginary part. Draw order matches real-then-imaginary so the
      # random stream is identical for every configuration.
      m = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_)
      if dtype_ in [np.complex64, np.complex128]:
        m += 1j * np.random.uniform(
            low=-1.0, high=1.0, size=shape_).astype(dtype_)
      return m

    a = RandomInput()
    # Optimal stepsize for central difference is O(epsilon^{1/3}).
    epsilon = np.finfo(dtype_).eps
    delta = 0.1 * epsilon**(1.0 / 3.0)
    # Single precision needs a much looser tolerance than double.
    tol = 3e-2 if dtype_ in [np.float32, np.complex64] else 1e-6
    with self.test_session(use_gpu=True):
      tf_a = constant_op.constant(a)
      tf_b = linalg_ops.qr(tf_a, full_matrices=full_matrices_)
      # Gradient-check each of the two outputs (q and r) separately.
      for b in tf_b:
        x_init = RandomInput()
        theoretical, numerical = gradient_checker.compute_gradient(
            tf_a,
            tf_a.get_shape().as_list(),
            b,
            b.get_shape().as_list(),
            x_init_value=x_init,
            delta=delta)
        self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

  return Test
195+
196+
156197
if __name__ == "__main__":
157198
for dtype in np.float32, np.float64, np.complex64, np.complex128:
158199
for rows in 1, 2, 5, 10, 32, 100:
@@ -168,4 +209,21 @@ def Test(self):
168209
_AddTest(QrOpTest, "Qr", name,
169210
_GetQrOpTest(dtype, shape, full_matrices,
170211
use_static_shape))
212+
213+
  # TODO(pfau): Get working with complex types.
  # TODO(pfau): Get working with full_matrices when rows != cols
  # TODO(pfau): Get working when rows < cols
  # TODO(pfau): Get working with placeholders (dynamic shapes)
  for full_matrices in False, True:
    for dtype in np.float32, np.float64:
      for rows in 1, 2, 5, 10:
        for cols in 1, 2, 5, 10:
          # Gradient tests cover only square matrices, plus the reduced
          # (rows > cols) case when full_matrices is False — matching the
          # restrictions enforced by _QrGrad.
          if rows == cols or (not full_matrices and rows > cols):
            # `max(rows, cols) < 10` is a bool (0 or 1), so the (3, 2)
            # batch shape is only exercised for the smaller matrices.
            for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
              shape = batch_dims + (rows, cols)
              name = "%s_%s_full_%s" % (dtype.__name__,
                                        "_".join(map(str, shape)),
                                        full_matrices)
              _AddTest(QrGradOpTest, "QrGrad", name,
                       _GetQrGradOpTest(dtype, shape, full_matrices))
171229
test.main()

tensorflow/python/ops/linalg_grad.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,36 @@ def _CholeskyGrad(op, grad):
8181
return grad_a * 0.5
8282

8383

84+
@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr.

  Args:
    op: The Qr op, with outputs (q, r).
    dq: Incoming gradient with respect to q.
    dr: Incoming gradient with respect to r.

  Returns:
    The gradient with respect to the op's single input matrix.

  Raises:
    NotImplementedError: for complex dtypes, dynamic (not fully known)
      static shapes, or when r is not square (ncols > nrows, or
      full_matrices=True with ncols != nrows).
  """
  q, r = op.outputs
  if q.dtype.is_complex:
    raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  # r square <=> rows == cols (reduced) or full_matrices with rows == cols.
  if r.shape[-2].value != r.shape[-1].value:
    raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                              "or full_matrices is true and ncols != nrows.")

  # qdq_ and rdr_ are the antisymmetric parts (x - x^H) of q^H dq and
  # r dr^H respectively.
  qdq = math_ops.matmul(q, dq, adjoint_a=True)
  qdq_ = qdq - _linalg.adjoint(qdq)
  rdr = math_ops.matmul(r, dr, adjoint_b=True)
  rdr_ = rdr - _linalg.adjoint(rdr)
  # Keep only the lower-triangular part (including the diagonal).
  tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
  grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
  return grad_a + grad_b
112+
113+
84114
@ops.RegisterGradient("MatrixSolve")
85115
def _MatrixSolveGrad(op, grad):
86116
"""Gradient for MatrixSolve."""
@@ -105,7 +135,7 @@ def _MatrixSolveLsGrad(op, grad):
105135
# b) Implement a symmetric rank-k update op instead of computing
106136
# x*z + transpose(x*z). This pattern occurs other places in TensorFlow.
107137

108-
def _overdetermined(op, grad):
138+
def _Overdetermined(op, grad):
109139
"""Gradients for the overdetermined case of MatrixSolveLs.
110140
111141
This is the backprop for the solution to the normal equations of the first
@@ -130,7 +160,7 @@ def _overdetermined(op, grad):
130160
grad_b = math_ops.matmul(a, z)
131161
return (grad_a, grad_b, None)
132162

133-
def _underdetermined(op, grad):
163+
def _Underdetermined(op, grad):
134164
"""Gradients for the underdetermined case of MatrixSolveLs.
135165
136166
This is the backprop for the solution to the normal equations of the second
@@ -162,16 +192,16 @@ def _underdetermined(op, grad):
162192
matrix_shape = op.inputs[0].get_shape()[-2:]
163193
if matrix_shape.is_fully_defined():
164194
if matrix_shape[-2] >= matrix_shape[-1]:
165-
return _overdetermined(op, grad)
195+
return _Overdetermined(op, grad)
166196
else:
167-
return _underdetermined(op, grad)
197+
return _Underdetermined(op, grad)
168198
else:
169199
# We have to defer determining the shape to runtime and use
170200
# conditional execution of the appropriate graph.
171201
matrix_shape = array_ops.shape(op.inputs[0])[-2:]
172202
return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
173-
lambda: _overdetermined(op, grad),
174-
lambda: _underdetermined(op, grad))
203+
lambda: _Overdetermined(op, grad),
204+
lambda: _Underdetermined(op, grad))
175205

176206

177207
@ops.RegisterGradient("MatrixTriangularSolve")

0 commit comments

Comments
 (0)