Skip to content

Commit

Permalink
Merge commit for internal changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Vijay Vasudevan committed Jan 27, 2016
2 parents f13d006 + 42b1a67 commit d4422ff
Show file tree
Hide file tree
Showing 24 changed files with 1,269 additions and 89 deletions.
2 changes: 1 addition & 1 deletion tensorflow/core/framework/tensor_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class TensorShape {

struct TensorShapeDim {
explicit TensorShapeDim(int64 s) : size(s) {}
int size;
int64 size;
};

class TensorShapeIter {
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/core/framework/tensor_shape_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,17 @@ TEST(TensorShapeTest, SetDimForEmptyTensor) {
EXPECT_EQ(1400, s.num_elements());
}

TEST(TensorShapeTest, AppendShape64BitIndices) {
TensorShape s({10, 2147483648});

EXPECT_EQ(10, s.dim_size(0));
EXPECT_EQ(2147483648, s.dim_size(1));

TensorShape s2;
s2.AppendShape(s);
EXPECT_EQ(10, s2.dim_size(0));
EXPECT_EQ(2147483648, s2.dim_size(1));
}

} // namespace
} // namespace tensorflow
183 changes: 183 additions & 0 deletions tensorflow/core/kernels/matrix_solve_ls_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.
#include <cmath>

#include "third_party/eigen3/Eigen/Cholesky"
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/QR"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/binary_linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/tensor_shape.h"

namespace tensorflow {

template <class Scalar, bool SupportsBatchOperationT>
class MatrixSolveLsOp
: public BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> {
public:
explicit MatrixSolveLsOp(OpKernelConstruction* context)
: BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {
OP_REQUIRES_OK(context, context->GetAttr("fast", &fast_));
}

~MatrixSolveLsOp() override {}

TensorShape GetOutputMatrixShape(
const TensorShape& input_matrix_shape,
const TensorShape& rhs_matrix_shape) override {
CHECK_EQ(input_matrix_shape.dims(), rhs_matrix_shape.dims());
TensorShape output_matrix_shape = rhs_matrix_shape;
output_matrix_shape.set_dim(
output_matrix_shape.dims() - 2,
input_matrix_shape.dim_size(output_matrix_shape.dims() - 1));
return output_matrix_shape;
}

int64 GetCostPerUnit(const TensorShape& input_matrix_shape,
const TensorShape& rhs_matrix_shape) override {
const int64 rows = input_matrix_shape.dim_size(0);
const int64 rhss = rhs_matrix_shape.dim_size(1);
if (rows > (1LL << 20)) {
// A big number to cap the cost in case overflow.
return kint32max;
} else {
return 2 * rows * rows * (rows + rhss);
}
}

using typename BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>::Matrix;
using typename BinaryLinearAlgebraOp<Scalar,
SupportsBatchOperationT>::MatrixMap;
using typename BinaryLinearAlgebraOp<Scalar,
SupportsBatchOperationT>::ConstMatrixMap;

void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& matrix,
const ConstMatrixMap& rhs, MatrixMap* output) override {
const int64 rows = matrix.rows();
const int64 cols = matrix.cols();
OP_REQUIRES(
context, rows == rhs.rows(),
errors::InvalidArgument("Input matrix and rhs are incompatible."));
const auto& l2_regularizer_in = context->input(2);
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(l2_regularizer_in.shape()),
errors::InvalidArgument("l2_regularizer must be scalar, got shape ",
l2_regularizer_in.shape().DebugString()));
const double l2_regularizer = l2_regularizer_in.scalar<double>()();

OP_REQUIRES(context, l2_regularizer >= 0,
errors::InvalidArgument("l2_regularizer must be >= 0."));
if (rows == 0 || cols == 0) {
// The result is the empty matrix.
return;
}
if (fast_) {
// The fast branch assumes that matrix is not rank deficient and
// not too ill-conditioned. Specifically, the reciprobal condition number
// should be greater than the square root of the machine precision, i.e.
// 1 / cond(matrix) > sqrt(std::numeric_limits<Scalar>::epsilon()).
// This branch solves over- or underdetermined least-squares problems
// via the normal equations and Cholesky decomposition.
if (matrix.rows() >= matrix.cols()) {
// Overdetermined case (rows >= cols): Solves the ordinary (possibly
// regularized) least-squares problem
// min || A * X - RHS ||_F^2 + l2_regularizer ||X||_F^2
// by solving the normal equations
// (A^T * A + l2_regularizer * I) X = A^T RHS
// using Cholesky decomposition.
Matrix gramian(cols, cols);
gramian.template triangularView<Eigen::Lower>() =
matrix.transpose() * matrix;
if (l2_regularizer > 0) {
gramian +=
(Scalar(l2_regularizer) * Matrix::Ones(cols, 1)).asDiagonal();
}
const Eigen::LLT<Matrix, Eigen::Lower> llt(gramian);
OP_REQUIRES(
context, llt.info() == Eigen::Success,
errors::InvalidArgument("Input matrix was rank deficient or "
"ill-conditioned. Try setting fast=False "
"or provide a larger l2_regularizer > 0."));
*output = llt.solve(matrix.transpose() * rhs);
} else {
// Underdetermined case (rows < cols): Solves the minimum-norm problem
// min ||X||_F^2 s.t. A*X = RHS
// by solving the normal equations of the second kind
// (A * A^T + l2_regularizer * I) Z = RHS, X = A^T * Z
// using Cholesky decomposition.
Matrix gramian(rows, rows);
gramian.template triangularView<Eigen::Lower>() =
matrix * matrix.transpose();
if (l2_regularizer > 0) {
gramian +=
(Scalar(l2_regularizer) * Matrix::Ones(rows, 1)).asDiagonal();
}
const Eigen::LLT<Matrix, Eigen::Lower> llt(gramian);
OP_REQUIRES(
context, llt.info() == Eigen::Success,
errors::InvalidArgument("Input matrix was rank deficient or "
"ill-conditioned. Try setting fast=False "
"or provide an l2_regularizer > 0."));
*output = matrix.transpose() * llt.solve(rhs);
}
} else {
// Use a rank revealing factorization (QR with column pivoting).
//
// NOTICE: Currently, Eigen's implementation of column pivoted Householder
// QR has a few deficiencies:
// 1. It does not implement the post-processing step to compute a
// complete orthogonal factorization. This means that it does not
// return a minimum-norm solution for underdetermined and
// rank-deficient matrices. We could use the Eigen SVD instead, but
// the currently available JacobiSVD is so slow that is it is
// essentially useless (~100x slower than QR).
// 2. The implementation is not blocked, so for matrics that do not fit
// in cache, it is significantly slower than the equivalent blocked
// LAPACK routine xGEQP3 (e.g. Eigen is ~3x slower for 4k x 4k
// matrices). See http://www.netlib.org/lapack/lawnspdf/lawn114.pdf
// 3. The implementation uses the numerically unstable norm downdating
// formula from the original 1965 Businger & Golub paper. This can
// lead to incorrect rank determination for graded matrices. I
// (rmlarsen@) have a patch to bring this up to date by implementing
// the robust formula from
// http://www.netlib.org/lapack/lawnspdf/lawn176.pdf
//
// TODO(rmlarsen): a) Contribute 1. and 2. to Eigen.
// b) Evaluate new divide-and-conquer SVD in Eigen when
// it becomes available & robust.
*output = matrix.colPivHouseholderQr().solve(rhs);
}
}

private:
bool fast_;
};

REGISTER_BINARY_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<float, false>),
float);
REGISTER_BINARY_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<double, false>),
double);
REGISTER_BINARY_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<float, true>),
float);
REGISTER_BINARY_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<double, true>),
double);

} // namespace tensorflow
105 changes: 93 additions & 12 deletions tensorflow/core/ops/linalg_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ Calculates the determinant of a square matrix.
input: A tensor of shape `[M, M]`.
output: A scalar, equal to the determinant of the input.
T: The type of values in the input and output.
)doc");

REGISTER_OP("BatchMatrixDeterminant")
Expand All @@ -42,7 +41,6 @@ for all input submatrices `[..., :, :]`.
input: Shape is `[..., M, M]`.
output: Shape is `[...]`.
T: The type of values in the input and output.
)doc");

REGISTER_OP("MatrixInverse")
Expand All @@ -61,7 +59,6 @@ garbage result.
input: Shape is `[M, M]`.
output: Shape is `[M, M]` containing the matrix inverse of the input.
T: The type of values in the input and output.
)doc");

REGISTER_OP("BatchMatrixInverse")
Expand All @@ -84,7 +81,6 @@ garbage result.
input: Shape is `[..., M, M]`.
output: Shape is `[..., M, M]`.
T: The type of values in the input and output.
)doc");

REGISTER_OP("Cholesky")
Expand All @@ -103,7 +99,6 @@ input.
input: Shape is `[M, M]`.
output: Shape is `[M, M]`.
T: The type of values in the input and output.
)doc");

REGISTER_OP("BatchCholesky")
Expand All @@ -120,7 +115,6 @@ containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
input: Shape is `[..., M, M]`.
output: Shape is `[..., M, M]`.
T: The type of values in the input and output.
)doc");

REGISTER_OP("SelfAdjointEig")
Expand All @@ -138,7 +132,6 @@ subsequent rows are eigenvectors.
input: Shape is `[M, M]`.
output: Shape is `[M+1, M]`.
T: The type of values in the input and output.
)doc");

REGISTER_OP("BatchSelfAdjointEig")
Expand All @@ -157,7 +150,6 @@ eigenvalues, and subsequent [...,1:, :] containing the eigenvectors.
input: Shape is `[..., M, M]`.
output: Shape is `[..., M+1, M]`.
T: The type of values in the input and output.
)doc");

REGISTER_OP("MatrixSolve")
Expand All @@ -172,7 +164,6 @@ matrix: Shape is `[M, M]`.
rhs: Shape is `[M, K]`.
output: Shape is `[M, K]` containing the tensor that solves
matrix * output = rhs.
T: The type of values in the input and output.
)doc");

REGISTER_OP("BatchMatrixSolve")
Expand All @@ -191,7 +182,6 @@ matrix satisfies matrix[..., :, :] * output[..., :, :] = rhs[..., :, :].
matrix: Shape is `[..., M, M]`.
rhs: Shape is `[..., M, K]`.
output: Shape is `[..., M, K]`.
T: The type of values in the input and output.
)doc");

REGISTER_OP("MatrixTriangularSolve")
Expand All @@ -218,7 +208,6 @@ matrix: Shape is `[M, M]`.
rhs: Shape is `[M, K]`.
output: Shape is `[M, K]`.
lower: Boolean indicating whether matrix is lower or upper triangular.
T: The type of values in the input and output.
)doc");

REGISTER_OP("BatchMatrixTriangularSolve")
Expand Down Expand Up @@ -247,7 +236,99 @@ matrix: Shape is `[..., M, M]`.
rhs: Shape is `[..., M, K]`.
output: Shape is `[..., M, K]`.
lower: Boolean indicating whether matrix is lower or upper triangular.
T: The type of values in the input and output.
)doc");

REGISTER_OP("MatrixSolveLs")
.Input("matrix: T")
.Input("rhs: T")
.Input("l2_regularizer: double")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("fast: bool = True")
.Doc(R"doc(
Solves a linear least-squares problem.
Below we will use the following notation
`matrix`=\\(A \in \Re^{m \times n}\\),
`rhs`=\\(B \in \Re^{m \times k}\\),
`output`=\\(X \in \Re^{n \times k}\\),
`l2_regularizer`=\\(\lambda\\).
If `fast` is `True`, then the solution is computed by solving the normal
equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
\\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares
problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||A Z - B||_F^2 +
\lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
\\(X = A^T (A A^T + \lambda I)^{-1} B\\),
which (for \\(\lambda = 0\\)) is the minimum-norm solution to the
under-determined linear system, i.e.
\\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\),
subject to \\(A Z = B\\).
Notice that the fast path is only numerically stable when \\(A\\) is
numerically full rank and has a condition number
\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\)
or \\(\lambda\\) is sufficiently large.
If `fast` is `False` then the solution is computed using the rank revealing QR
decomposition with column pivoting. This will always compute a least-squares
solution that minimizes the residual norm \\(||A X - B||_F^2 \\), even when
\\( A \\) is rank deficient or ill-conditioned. Notice: The current version
does not compute a minimum norm solution. If `fast` is `False` then
`l2_regularizer` is ignored.
matrix: Shape is `[M, N]`.
rhs: Shape is `[M, K]`.
output: Shape is `[N, K]` containing the tensor that solves
`matrix * output = rhs` in the least-squares sense.
)doc");

REGISTER_OP("BatchMatrixSolveLs")
.Input("matrix: T")
.Input("rhs: T")
.Input("l2_regularizer: double")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("fast: bool = True")
.Doc(R"doc(
Solves multiple linear least-squares problems.
`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
form square matrices. Rhs is a tensor of shape `[..., M, K]`. The output
is a tensor shape `[..., N, K]` where each output matrix solves each of
the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] in the
least squares sense.
Below we will use the following notation for each pair of
matrix and right-hand sides in the batch:
`matrix`=\\(A \in \Re^{m \times n}\\),
`rhs`=\\(B \in \Re^{m \times k}\\),
`output`=\\(X \in \Re^{n \times k}\\),
`l2_regularizer`=\\(\lambda\\).
If `fast` is `True`, then the solution is computed by solving the normal
equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
\\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares
problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||A Z - B||_F^2 +
\lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
\\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the
minimum-norm solution to the under-determined linear system, i.e.
\\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\), subject to
\\(A Z = B\\). Notice that the fast path is only numerically stable when
\\(A\\) is numerically full rank and has a condition number
\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\) or\\(\lambda\\) is
sufficiently large.
If `fast` is `False` then the solution is computed using the rank revealing QR
decomposition with column pivoting. This will always compute a least-squares
solution that minimizes the residual norm \\(||A X - B||_F^2\\), even when
\\(A\\) is rank deficient or ill-conditioned. Notice: The current version does
not compute a minimum norm solution. If `fast` is `False` then `l2_regularizer`
is ignored.
matrix: Shape is `[..., M, N]`.
rhs: Shape is `[..., M, K]`.
output: Shape is `[..., N, K]`.
)doc");

} // namespace tensorflow

0 comments on commit d4422ff

Please sign in to comment.