Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gradient for QR decomposition of wide matrices (nrows < ncols) #39321

Merged
merged 2 commits into from
Jun 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 1 addition & 2 deletions tensorflow/python/eager/pywrap_gradient_exclusions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {

absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 349> a = {{
static std::array<OpIndexInfo, 348> a = {{
{"Acosh"},
{"AllToAll", 1, {0}},
{"ApproximateEqual"},
Expand Down Expand Up @@ -222,7 +222,6 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
{"PlaceholderWithDefault"},
{"PopulationCount"},
{"PreventGradient"},
{"Qr"},
{"QuantizeAndDequantize"},
{"QuantizeAndDequantizeV2"},
{"QuantizeAndDequantizeV3"},
Expand Down
5 changes: 2 additions & 3 deletions tensorflow/python/kernel_tests/qr_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,13 @@ def benchmarkQROp(self):
use_static_shape))

# TODO(pfau): Get working with complex types.
# TODO(pfau): Get working with full_matrices when rows != cols
# TODO(pfau): Get working when rows < cols
# TODO(pfau): Get working with full_matrices when rows > cols
# TODO(pfau): Get working with placeholders (dynamic shapes)
for full_matrices in False, True:
for dtype in np.float32, np.float64:
for rows in 1, 2, 5, 10:
for cols in 1, 2, 5, 10:
if rows == cols or (not full_matrices and rows > cols):
if rows <= cols or (not full_matrices and rows > cols):
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
shape = batch_dims + (rows, cols)
name = "%s_%s_full_%s" % (dtype.__name__,
Expand Down
47 changes: 36 additions & 11 deletions tensorflow/python/ops/linalg_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,25 +493,50 @@ def _QrGrad(op, dq, dr):
if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
r.shape.as_list()[-1] is None):
raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
if r.shape.dims[-2].value != r.shape.dims[-1].value:
if (r.shape.dims[-2].value > r.shape.dims[-1].value and
q.shape.dims[-2].value == q.shape.dims[-1].value):
raise NotImplementedError("QrGrad not implemented when ncols > nrows "
"or full_matrices is true and ncols != nrows.")

qdq = math_ops.matmul(q, dq, adjoint_a=True)
qdq_ = qdq - _linalg.adjoint(qdq)
rdr = math_ops.matmul(r, dr, adjoint_b=True)
rdr_ = rdr - _linalg.adjoint(rdr)
tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)
"and full_matrices is true.")

def _TriangularSolve(x, r):
  """Right-multiplies `x` by the adjoint of the inverse of `r`.

  Equivalent to matmul(x, adjoint(matrix_inverse(r))) when `r` is upper
  triangular, but obtained through a triangular solve instead of forming
  the inverse.
  """
  # Solve against adjoint(x) using the upper triangle of r, then take the
  # adjoint of the solution to get back to the orientation of x.
  x_adjoint = _linalg.adjoint(x)
  solution = linalg_ops.matrix_triangular_solve(
      r, x_adjoint, lower=False, adjoint=False)
  return _linalg.adjoint(solution)

grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
return grad_a + grad_b
def _QrGradSquareAndDeepMatrices(q, r, dq, dr):
  """Computes the QR gradient for square and deep matrices.

  Handles matrix orders with num_rows >= num_cols when full_matrices is
  false.

  Args:
    q: The `q` factor of the decomposition.
    r: The `r` factor of the decomposition; passed to `_TriangularSolve`,
      which treats it as upper triangular.
    dq: Incoming gradient paired with `q`.
    dr: Incoming gradient paired with `r`.

  Returns:
    The sum of the two gradient terms built from `q`, `r`, `dq` and `dr`.
  """
  q_adj_dq = math_ops.matmul(q, dq, adjoint_a=True)
  r_dr_adj = math_ops.matmul(r, dr, adjoint_b=True)

  # Subtract from each product its own adjoint, add the two differences and
  # keep only the lower triangle (diagonal included) of the sum.
  difference_sum = (q_adj_dq - _linalg.adjoint(q_adj_dq)) + (
      r_dr_adj - _linalg.adjoint(r_dr_adj))
  lower_part = array_ops.matrix_band_part(difference_sum, -1, 0)

  first_term = math_ops.matmul(q, dr + _TriangularSolve(lower_part, r))
  second_term = _TriangularSolve(dq - math_ops.matmul(q, q_adj_dq), r)
  return first_term + second_term

num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1]

if num_rows >= num_cols:
return _QrGradSquareAndDeepMatrices(q, r, dq, dr)

# Partition a = [x, y], r = [u, v] and reduce to the square case
a = op.inputs[0]
y = a[..., :, num_rows:]
u = r[..., :, :num_rows]
dv = dr[..., :, num_rows:]
du = dr[..., :, :num_rows]
dy = math_ops.matmul(q, dv)
dx = _QrGradSquareAndDeepMatrices(q,
u,
dq + math_ops.matmul(y,
dv,
adjoint_b=True),
du)
return array_ops.concat([dx, dy], axis=-1)


@ops.RegisterGradient("MatrixSolve")
Expand Down