Skip to content

Commit

Permalink
Get rid of SupportsBatchOperation template parameter from linear algebra ops, to reduce code size and compilation time.
Browse files Browse the repository at this point in the history

Change: 131328798
  • Loading branch information
tensorflower-gardener committed Aug 26, 2016
1 parent 7749e64 commit 6cf3f2a
Show file tree
Hide file tree
Showing 12 changed files with 104 additions and 128 deletions.
15 changes: 8 additions & 7 deletions tensorflow/core/kernels/cholesky_grad.cc
Expand Up @@ -22,10 +22,10 @@ limitations under the License.

namespace tensorflow {

template <typename Scalar, bool SupportsBatchOperation>
class CholeskyGrad : public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
template <typename Scalar>
class CholeskyGrad : public LinearAlgebraOp<Scalar> {
public:
typedef LinearAlgebraOp<Scalar, SupportsBatchOperation> Base;
typedef LinearAlgebraOp<Scalar> Base;

explicit CholeskyGrad(OpKernelConstruction* context) : Base(context) {}

Expand Down Expand Up @@ -156,8 +156,9 @@ class CholeskyGrad : public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
}
};

REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<float, false>), float);
REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<double, false>), double);
REGISTER_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<float, true>), float);
REGISTER_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<double, true>), double);
REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<float>), float);
REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<double>), double);
REGISTER_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<float>), float);
REGISTER_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<double>), double);

} // namespace tensorflow
15 changes: 8 additions & 7 deletions tensorflow/core/kernels/cholesky_op.cc
Expand Up @@ -29,10 +29,10 @@ limitations under the License.

namespace tensorflow {

template <class Scalar, bool SupportsBatchOperation>
class CholeskyOp : public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
template <class Scalar>
class CholeskyOp : public LinearAlgebraOp<Scalar> {
public:
typedef LinearAlgebraOp<Scalar, SupportsBatchOperation> Base;
typedef LinearAlgebraOp<Scalar> Base;

explicit CholeskyOp(OpKernelConstruction* context) : Base(context) {}

Expand Down Expand Up @@ -65,8 +65,9 @@ class CholeskyOp : public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
}
};

REGISTER_LINALG_OP("Cholesky", (CholeskyOp<float, false>), float);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<double, false>), double);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<float, true>), float);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<double, true>), double);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<float>), float);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<double>), double);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<float>), float);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<double>), double);

} // namespace tensorflow
16 changes: 7 additions & 9 deletions tensorflow/core/kernels/determinant_op.cc
Expand Up @@ -27,10 +27,10 @@ limitations under the License.

namespace tensorflow {

template <class Scalar, bool SupportsBatchOperation>
class DeterminantOp : public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
template <class Scalar>
class DeterminantOp : public LinearAlgebraOp<Scalar> {
public:
typedef LinearAlgebraOp<Scalar, SupportsBatchOperation> Base;
typedef LinearAlgebraOp<Scalar> Base;

explicit DeterminantOp(OpKernelConstruction* context) : Base(context) {}

Expand Down Expand Up @@ -60,11 +60,9 @@ class DeterminantOp : public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
}
};

REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<float, false>), float);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<double, false>), double);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<float, true>),
float);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<double, true>),
double);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<float>), float);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<double>), double);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<double>), double);

} // namespace tensorflow
66 changes: 27 additions & 39 deletions tensorflow/core/kernels/linalg_ops_common.cc
Expand Up @@ -27,8 +27,8 @@ limitations under the License.
namespace tensorflow {

// static
template <typename Scalar, bool SupportsBatchOperation>
void LinearAlgebraOp<Scalar, SupportsBatchOperation>::ValidateSingleMatrix(
template <typename Scalar>
void LinearAlgebraOp<Scalar>::ValidateSingleMatrix(
OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
OP_REQUIRES(context, input_matrix_shapes.size() == 1,
errors::InvalidArgument("Expected a single input matrix, got %d.",
Expand All @@ -38,10 +38,9 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::ValidateSingleMatrix(
}

// static
template <typename Scalar, bool SupportsBatchOperation>
void LinearAlgebraOp<Scalar, SupportsBatchOperation>::
ValidateSingleSquareMatrix(OpKernelContext* context,
const TensorShapes& input_matrix_shapes) {
template <typename Scalar>
void LinearAlgebraOp<Scalar>::ValidateSingleSquareMatrix(
OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
OP_REQUIRES(context, input_matrix_shapes.size() == 1,
errors::InvalidArgument("Expected a single input matrix, got %d.",
input_matrix_shapes.size()));
Expand All @@ -50,8 +49,8 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::
}

// static
template <typename Scalar, bool SupportsBatchOperation>
void LinearAlgebraOp<Scalar, SupportsBatchOperation>::ValidateSolver(
template <typename Scalar>
void LinearAlgebraOp<Scalar>::ValidateSolver(
OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
OP_REQUIRES(context, input_matrix_shapes.size() == 2,
errors::InvalidArgument("Expected two input matrices, got %d.",
Expand All @@ -67,8 +66,8 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::ValidateSolver(
}

// static
template <typename Scalar, bool SupportsBatchOperation>
void LinearAlgebraOp<Scalar, SupportsBatchOperation>::ValidateSquareSolver(
template <typename Scalar>
void LinearAlgebraOp<Scalar>::ValidateSquareSolver(
OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
OP_REQUIRES(context, input_matrix_shapes.size() == 2,
errors::InvalidArgument("Expected two input matrices, got %d.",
Expand All @@ -84,9 +83,8 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::ValidateSquareSolver(
errors::InvalidArgument("Input matrix and rhs are incompatible."));
}

template <typename Scalar, bool SupportsBatchOperation>
void LinearAlgebraOp<Scalar, SupportsBatchOperation>::Compute(
OpKernelContext* context) {
template <typename Scalar>
void LinearAlgebraOp<Scalar>::Compute(OpKernelContext* context) {
TensorInputs inputs;
TensorShapes input_matrix_shapes;
TensorShape batch_shape;
Expand All @@ -110,27 +108,20 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::Compute(
batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard);
}

template <typename Scalar, bool SupportsBatchOperation>
void LinearAlgebraOp<Scalar, SupportsBatchOperation>::AnalyzeInputs(
OpKernelContext* context, TensorInputs* inputs,
TensorShapes* input_matrix_shapes, TensorShape* batch_shape) {
template <typename Scalar>
void LinearAlgebraOp<Scalar>::AnalyzeInputs(OpKernelContext* context,
TensorInputs* inputs,
TensorShapes* input_matrix_shapes,
TensorShape* batch_shape) {
int input_rank = -1;
for (int i = 0; i < NumMatrixInputs(context); ++i) {
const Tensor& in = context->input(i);
if (i == 0) {
input_rank = in.dims();
if (SupportsBatchOperation) {
OP_REQUIRES(
context, input_rank >= 2,
errors::InvalidArgument("Input tensor ", i,
" must have rank >= 2, got", input_rank));
} else {
OP_REQUIRES(
context, input_rank == 2,
errors::InvalidArgument("Input tensor ", i,
" must have rank == 2, got", input_rank));
}

OP_REQUIRES(
context, input_rank >= 2,
errors::InvalidArgument("Input tensor ", i,
" must have rank >= 2, got", input_rank));
// If the tensor rank is greater than 2, we consider the inner-most
// dimensions as matrices, and loop over all the other outer ("batch")
// dimensions to compute the results.
Expand Down Expand Up @@ -163,8 +154,8 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::AnalyzeInputs(
ValidateInputMatrixShapes(context, *input_matrix_shapes);
}

template <typename Scalar, bool SupportsBatchOperation>
void LinearAlgebraOp<Scalar, SupportsBatchOperation>::PrepareOutputs(
template <typename Scalar>
void LinearAlgebraOp<Scalar>::PrepareOutputs(
OpKernelContext* context, const TensorShapes& input_matrix_shapes,
const TensorShape& batch_shape, TensorOutputs* outputs,
TensorShapes* output_matrix_shapes) {
Expand Down Expand Up @@ -205,8 +196,8 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::PrepareOutputs(
}
}

template <typename Scalar, bool SupportsBatchOperation>
void LinearAlgebraOp<Scalar, SupportsBatchOperation>::ComputeTensorSlice(
template <typename Scalar>
void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
OpKernelContext* context, int64 matrix_index, const TensorInputs& inputs,
const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
const TensorShapes& output_matrix_shapes) {
Expand Down Expand Up @@ -238,11 +229,8 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::ComputeTensorSlice(
ComputeMatrix(context, matrix_inputs, &matrix_outputs);
}

// Explicitly instantiate LinearAlgebraOp for the scalar types we expect to
// use.
template class LinearAlgebraOp<float, false>;
template class LinearAlgebraOp<float, true>;
template class LinearAlgebraOp<double, false>;
template class LinearAlgebraOp<double, true>;
// Explicitly instantiate LinearAlgebraOp for the scalar types we expect to use.
template class LinearAlgebraOp<float>;
template class LinearAlgebraOp<double>;

} // namespace tensorflow
8 changes: 3 additions & 5 deletions tensorflow/core/kernels/linalg_ops_common.h
Expand Up @@ -39,7 +39,7 @@ limitations under the License.
namespace tensorflow {

// Base class for linear algebra operators.
template <typename Scalar, bool SupportsBatchOperationT>
template <typename Scalar>
class LinearAlgebraOp : public OpKernel {
public:
explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
Expand Down Expand Up @@ -164,10 +164,8 @@ class LinearAlgebraOp : public OpKernel {

// Declare that LinearAlgebraOp is explicitly instantiated in
// linalg_ops_common.cc for float and double.
extern template class LinearAlgebraOp<float, false>;
extern template class LinearAlgebraOp<float, true>;
extern template class LinearAlgebraOp<double, false>;
extern template class LinearAlgebraOp<double, true>;
extern template class LinearAlgebraOp<float>;
extern template class LinearAlgebraOp<double>;

} // namespace tensorflow

Expand Down
15 changes: 7 additions & 8 deletions tensorflow/core/kernels/matrix_inverse_op.cc
Expand Up @@ -28,10 +28,10 @@ limitations under the License.

namespace tensorflow {

template <class Scalar, bool SupportsBatchOperation>
class MatrixInverseOp : public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
template <class Scalar>
class MatrixInverseOp : public LinearAlgebraOp<Scalar> {
public:
typedef LinearAlgebraOp<Scalar, SupportsBatchOperation> Base;
typedef LinearAlgebraOp<Scalar> Base;

explicit MatrixInverseOp(OpKernelConstruction* context) : Base(context) {
OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
Expand Down Expand Up @@ -77,10 +77,9 @@ class MatrixInverseOp : public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
TF_DISALLOW_COPY_AND_ASSIGN(MatrixInverseOp);
};

REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<float, false>), float);
REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<double, false>), double);
REGISTER_LINALG_OP("BatchMatrixInverse", (MatrixInverseOp<float, true>), float);
REGISTER_LINALG_OP("BatchMatrixInverse", (MatrixInverseOp<double, true>),
double);
REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<float>), float);
REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<double>), double);
REGISTER_LINALG_OP("BatchMatrixInverse", (MatrixInverseOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixInverse", (MatrixInverseOp<double>), double);

} // namespace tensorflow
15 changes: 7 additions & 8 deletions tensorflow/core/kernels/matrix_solve_ls_op.cc
Expand Up @@ -28,10 +28,10 @@ limitations under the License.

namespace tensorflow {

template <class Scalar, bool SupportsBatchOperation>
class MatrixSolveLsOp : public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
template <class Scalar>
class MatrixSolveLsOp : public LinearAlgebraOp<Scalar> {
public:
typedef LinearAlgebraOp<Scalar, SupportsBatchOperation> Base;
typedef LinearAlgebraOp<Scalar> Base;

explicit MatrixSolveLsOp(OpKernelConstruction* context) : Base(context) {
OP_REQUIRES_OK(context, context->GetAttr("fast", &fast_));
Expand Down Expand Up @@ -155,10 +155,9 @@ class MatrixSolveLsOp : public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
bool fast_;
};

REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<float, false>), float);
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<double, false>), double);
REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<float, true>), float);
REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<double, true>),
double);
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<float>), float);
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<double>), double);
REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<double>), double);

} // namespace tensorflow
14 changes: 7 additions & 7 deletions tensorflow/core/kernels/matrix_solve_op.cc
Expand Up @@ -28,10 +28,10 @@ limitations under the License.

namespace tensorflow {

template <class Scalar, bool SupportsBatchOperation>
class MatrixSolveOp : public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
template <class Scalar>
class MatrixSolveOp : public LinearAlgebraOp<Scalar> {
public:
typedef LinearAlgebraOp<Scalar, SupportsBatchOperation> Base;
typedef LinearAlgebraOp<Scalar> Base;

explicit MatrixSolveOp(OpKernelConstruction* context) : Base(context) {
OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
Expand Down Expand Up @@ -105,9 +105,9 @@ class MatrixSolveOp : public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
TF_DISALLOW_COPY_AND_ASSIGN(MatrixSolveOp);
};

REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<float, false>), float);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<double, false>), double);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<float, true>), float);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<double, true>), double);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<float>), float);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<double>), double);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<double>), double);

} // namespace tensorflow
19 changes: 9 additions & 10 deletions tensorflow/core/kernels/matrix_triangular_solve_op.cc
Expand Up @@ -27,11 +27,10 @@ limitations under the License.

namespace tensorflow {

template <class Scalar, bool SupportsBatchOperation>
class MatrixTriangularSolveOp
: public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
template <class Scalar>
class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
public:
typedef LinearAlgebraOp<Scalar, SupportsBatchOperation> Base;
typedef LinearAlgebraOp<Scalar> Base;

explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
: Base(context), lower_(true), adjoint_(false) {
Expand Down Expand Up @@ -104,13 +103,13 @@ class MatrixTriangularSolveOp
TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOp);
};

REGISTER_LINALG_OP("MatrixTriangularSolve",
(MatrixTriangularSolveOp<float, false>), float);
REGISTER_LINALG_OP("MatrixTriangularSolve",
(MatrixTriangularSolveOp<double, false>), double);
REGISTER_LINALG_OP("MatrixTriangularSolve", (MatrixTriangularSolveOp<float>),
float);
REGISTER_LINALG_OP("MatrixTriangularSolve", (MatrixTriangularSolveOp<double>),
double);
REGISTER_LINALG_OP("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOp<float, true>), float);
(MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOp<double, true>), double);
(MatrixTriangularSolveOp<double>), double);

} // namespace tensorflow
17 changes: 7 additions & 10 deletions tensorflow/core/kernels/self_adjoint_eig_op.cc
Expand Up @@ -28,11 +28,10 @@ limitations under the License.

namespace tensorflow {

template <class Scalar, bool SupportsBatchOperation>
class SelfAdjointEigOp
: public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
template <class Scalar>
class SelfAdjointEigOp : public LinearAlgebraOp<Scalar> {
public:
typedef LinearAlgebraOp<Scalar, SupportsBatchOperation> Base;
typedef LinearAlgebraOp<Scalar> Base;

explicit SelfAdjointEigOp(OpKernelConstruction* context) : Base(context) {}

Expand Down Expand Up @@ -69,10 +68,8 @@ class SelfAdjointEigOp
}
};

REGISTER_LINALG_OP("SelfAdjointEig", (SelfAdjointEigOp<float, false>), float);
REGISTER_LINALG_OP("SelfAdjointEig", (SelfAdjointEigOp<double, false>), double);
REGISTER_LINALG_OP("BatchSelfAdjointEig", (SelfAdjointEigOp<float, true>),
float);
REGISTER_LINALG_OP("BatchSelfAdjointEig", (SelfAdjointEigOp<double, true>),
double);
REGISTER_LINALG_OP("SelfAdjointEig", (SelfAdjointEigOp<float>), float);
REGISTER_LINALG_OP("SelfAdjointEig", (SelfAdjointEigOp<double>), double);
REGISTER_LINALG_OP("BatchSelfAdjointEig", (SelfAdjointEigOp<float>), float);
REGISTER_LINALG_OP("BatchSelfAdjointEig", (SelfAdjointEigOp<double>), double);
} // namespace tensorflow

0 comments on commit 6cf3f2a

Please sign in to comment.