diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index eb720b191fd484..af27eb6c47d592 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -410,16 +410,6 @@ class DeviceLapackInfo : public ScratchSpace<int> {
 
 namespace functor {
 
-// Helper functor to compute the product of diagonal elements in all matrices
-// in a flattened batch.
-template <typename Device, typename Scalar>
-struct DeterminantFromPivotedLUFunctor {
-  void operator()(const Device& device,
-                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
-                  const int* pivots, typename TTypes<Scalar, 1>::Tensor output,
-                  int* info);
-};
-
 // Helper functor to set a batch of matrices to the identity.
 // TODO(rmlarsen): Use this kernel to replace the horribly inefficient tf.eye
 // op.
diff --git a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc
index 4171f9d68e4355..84330c041acda8 100644
--- a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc
+++ b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc
@@ -29,159 +29,11 @@ namespace functor {
 
 typedef Eigen::GpuDevice GPUDevice;
 
-namespace {
-
-// Hacks around missing support for complex arithmetic in nvcc.
-template <typename Scalar>
-__device__ inline Scalar Multiply(Scalar x, Scalar y) {
-  return x * y;
-}
-
-template <>
-__device__ inline cuComplex Multiply(cuComplex x, cuComplex y) {
-  return cuCmulf(x, y);
-}
-
-template <>
-__device__ inline cuDoubleComplex Multiply(cuDoubleComplex x,
-                                           cuDoubleComplex y) {
-  return cuCmul(x, y);
-}
-
-template <typename Scalar>
-__device__ inline Scalar Negate(Scalar x) {
-  return -x;
-}
-
-template <>
-__device__ inline cuComplex Negate(cuComplex x) {
-  return make_cuComplex(-cuCrealf(x), -cuCimagf(x));
-}
-
-template <>
-__device__ inline cuDoubleComplex Negate(cuDoubleComplex x) {
-  return make_cuDoubleComplex(-cuCreal(x), -cuCimag(x));
-}
-
-template <typename Scalar>
-__device__ inline bool IsFinite(Scalar x) {
-  return Eigen::numext::isfinite(x);
-}
-
-template <>
-__device__ inline bool IsFinite(cuComplex x) {
-  return Eigen::numext::isfinite(cuCrealf(x)) &&
-         Eigen::numext::isfinite(cuCimagf(x));
-}
-
-template <>
-__device__ inline bool IsFinite(cuDoubleComplex x) {
-  return Eigen::numext::isfinite(cuCreal(x)) &&
-         Eigen::numext::isfinite(cuCimag(x));
-}
-
-template <typename Scalar>
-struct Const {
-  template <typename RealScalar>
-  __device__ static inline Scalar make_const(const RealScalar x) {
-    return Scalar(x);
-  }
-};
-
-template <>
-struct Const<cuComplex> {
-  template <typename RealScalar>
-  __device__ static inline cuComplex make_const(const RealScalar x) {
-    return make_cuComplex(x, 0.0f);
-  }
-};
-
-template <>
-struct Const<cuDoubleComplex> {
-  template <typename RealScalar>
-  __device__ static inline cuDoubleComplex make_const(const RealScalar x) {
-    return make_cuDoubleComplex(x, 0.0f);
-  }
-};
-
-}  // namespace
-
-template <typename Scalar>
-__global__ void DeterminantFromPivotedLUKernel(int nthreads, int n,
-                                               const Scalar* lu_factor,
-                                               const int* all_pivots,
-                                               Scalar* dst, int* info) {
-  const int matrix_size = n * n;
-  const int stride = n + 1;
-  // We only parallelize over batches here. Performance is not critical,
-  // since this cheap O(n) kernel always follows an O(n^3) LU factorization.
-  // The main purpose is to avoid having to copy the LU decomposition to
-  // host memory.
-  CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
-    // Compute the order of the permutation from the number of transpositions
-    // encoded in the pivot array, see:
-    // http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=2&t=340
-    const int* pivots = all_pivots + o_idx * n;
-    int order = 0;
-    for (int i = 0; i < n - 1; ++i) {
-      // Notice: Internally, the cuBlas code uses Fortran convention (1-based)
-      // indexing so we expect pivots[i] == i + 1 for rows that were not moved.
-      order += pivots[i] != (i + 1);
-    }
-
-    // Compute the product of the diagonal elements of U from the partially
-    // pivoted LU factorization.
-    // TODO(rmlarsen): This naive implementation (matching that in Eigen used
-    // for the CPU kernel) is pathetically unstable. Should we implement
-    // log-determinant instead (a different set of ops altogether) or something
-    // like the method used in the old LINPACK code:
-    // http://www.netlib.org/linpack/dgedi.f ?
-    int i_idx = matrix_size * o_idx;
-    Scalar prod = lu_factor[i_idx];
-    for (int i = 1; i < n; ++i) {
-      i_idx += stride;
-      prod = Multiply(prod, lu_factor[i_idx]);
-    }
-    // Finally set the determinant to (-1)^order * prod(diag(U)).
-    dst[o_idx] = order % 2 ? Negate(prod) : prod;
-
-    // We write a magic value into the info array if the result was infinite.
-    if (!IsFinite(prod)) {
-      info[o_idx] = kint32min;
-    }
-  }
-}
-
-template <typename Scalar>
-struct DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
-  void operator()(const GPUDevice& device,
-                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
-                  const int* pivots, typename TTypes<Scalar, 1>::Tensor output,
-                  int* info) {
-    using CudaType = typename CUDAComplexT<Scalar>::type;
-    const int64 num_matrices = output.size();
-    const int64 n = lu_factor.dimension(2);
-    const CudaType* lu_factor_ptr =
-        reinterpret_cast<const CudaType*>(lu_factor.data());
-    CudaType* output_ptr = reinterpret_cast<CudaType*>(output.data());
-    CudaLaunchConfig config = GetCudaLaunchConfig(num_matrices, device);
-    DeterminantFromPivotedLUKernel<<<
-        config.block_count, config.thread_per_block, 0, device.stream()>>>(
-        config.virtual_thread_count, n, lu_factor_ptr, pivots, output_ptr,
-        info);
-  }
-};
-
-template struct DeterminantFromPivotedLUFunctor<GPUDevice, float>;
-template struct DeterminantFromPivotedLUFunctor<GPUDevice, double>;
-template struct DeterminantFromPivotedLUFunctor<GPUDevice, complex64>;
-template struct DeterminantFromPivotedLUFunctor<GPUDevice, complex128>;
-
 template <typename Scalar>
 __global__ void EyeKernel(Cuda3DLaunchConfig config, int batch_size, int m,
                           int n, Scalar* matrix_batch_ptr) {
   const int matrix_size = m * n;
-  const Scalar one = Const<Scalar>::make_const(1.0);
+  const Scalar one = Scalar(1);
   CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count, x) {
     if (batch >= batch_size) {
       break;
@@ -205,16 +57,14 @@ template <typename Scalar>
 struct EyeFunctor<GPUDevice, Scalar> {
   void operator()(const GPUDevice& device,
                   typename TTypes<Scalar, 3>::Tensor matrix_batch) {
-    using CudaType = typename CUDAComplexT<Scalar>::type;
     const int batch_size = matrix_batch.dimension(0);
     const int m = matrix_batch.dimension(1);
     const int n = matrix_batch.dimension(2);
-    CudaType* matrix_batch_ptr =
-        reinterpret_cast<CudaType*>(matrix_batch.data());
     Cuda3DLaunchConfig config = GetCuda3DLaunchConfig(
        batch_size, m, n, device, EyeKernel<Scalar>, 0, 0);
     EyeKernel<<<config.block_count, config.thread_per_block, 0,
-                device.stream()>>>(config, batch_size, m, n, matrix_batch_ptr);
+                device.stream()>>>(config, batch_size, m, n,
+                                   matrix_batch.data());
   }
 };
diff --git a/tensorflow/core/kernels/determinant_op.cc b/tensorflow/core/kernels/determinant_op.cc
index 876dbff0301d30..b06f42384ebb54 100644
--- a/tensorflow/core/kernels/determinant_op.cc
+++ b/tensorflow/core/kernels/determinant_op.cc
@@ -14,10 +14,13 @@ limitations under the License.
 ==============================================================================*/
 
 // See docs in ../ops/linalg_ops.cc.
+
 #include <cmath>
 
 #if GOOGLE_CUDA
 #define EIGEN_USE_GPU
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/determinant_op.h"
 #endif
 
 #include "third_party/eigen3/Eigen/LU"
@@ -31,23 +34,24 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 
 #if GOOGLE_CUDA
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 #endif
 
 namespace tensorflow {
 
-// A helper function to compute the sign and absolute value of the
-// log of the determinant of inputs via a partially pivoted LU
+// A helper function to compute the sign and absolute value of the log of the
+// determinant of inputs via a partially pivoted LU
 // factorization.
 //
-// Returns the sign in 'sign' and the log determinant in 'logdet'
+// Returns the log of the absolute value of the determinant, and its sign in
+// 'sign'.
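+// E.g. a matrix with determinant -3 yields log_abs_det = log(3) with
+// sign = -1; a singular matrix yields log_abs_det = -infinity with sign = 0.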
 template <class Scalar>
-static void SLogDet(
+static typename Eigen::NumTraits<Scalar>::Real SLogDet(
     const Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>& inputs,
-    Scalar* sign, Scalar* log_abs_det) {
-  *log_abs_det = 0;
+    Scalar* sign) {
+  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
+  RealScalar log_abs_det = 0;
   *sign = 1;
   // An empty matrix' determinant is defined to be 1.
   // (https://en.wikipedia.org/wiki/Determinant)
@@ -58,27 +62,25 @@ static void SLogDet(
     Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> LU = lu.matrixLU();
     *sign = lu.permutationP().determinant();
     auto diag = LU.diagonal().array().eval();
-    auto abs_diag = diag.cwiseAbs().template cast<Scalar>().eval();
-    *log_abs_det += abs_diag.log().sum();
+    auto abs_diag = diag.cwiseAbs().eval();
+    log_abs_det += abs_diag.log().sum();
     *sign *= (diag / abs_diag).prod();
   }
-  if (!Eigen::numext::isfinite(*log_abs_det)) {
+  if (!Eigen::numext::isfinite(log_abs_det)) {
     *sign = 0;
-    *log_abs_det = std::log(0.0);
+    log_abs_det =
+        log_abs_det > 0 ? -std::log(RealScalar(0)) : std::log(RealScalar(0));
   }
+  return log_abs_det;
 }
 
 template <class Scalar>
 class LogDeterminantOp : public LinearAlgebraOp<Scalar> {
  public:
-  typedef LinearAlgebraOp<Scalar> Base;
+  INHERIT_LINALG_TYPEDEFS(Scalar);
 
   explicit LogDeterminantOp(OpKernelConstruction* context) : Base(context) {}
 
-  using TensorShapes = typename Base::TensorShapes;
-  using MatrixMaps = typename Base::MatrixMaps;
-  using ConstMatrixMaps = typename Base::ConstMatrixMaps;
-
   TensorShapes GetOutputMatrixShapes(
       const TensorShapes& input_matrix_shapes) const final {
     return TensorShapes({TensorShape({}), TensorShape({})});
@@ -87,9 +89,9 @@ class LogDeterminantOp : public LinearAlgebraOp<Scalar> {
   void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                      MatrixMaps* outputs) final {
     Scalar sign;
-    Scalar log_abs_det;
-    SLogDet(Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>(inputs[0]),
-            &sign, &log_abs_det);
+    const RealScalar log_abs_det = SLogDet(
+        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>(inputs[0]),
+        &sign);
 
     outputs->at(0)(0, 0) = sign;
     outputs->at(1)(0, 0) = log_abs_det;
@@ -99,14 +101,10 @@
 template <class Scalar>
 class DeterminantOp : public LinearAlgebraOp<Scalar> {
  public:
-  typedef LinearAlgebraOp<Scalar> Base;
+  INHERIT_LINALG_TYPEDEFS(Scalar);
 
   explicit DeterminantOp(OpKernelConstruction* context) : Base(context) {}
 
-  using TensorShapes = typename Base::TensorShapes;
-  using MatrixMaps = typename Base::MatrixMaps;
-  using ConstMatrixMaps = typename Base::ConstMatrixMaps;
-
   TensorShapes GetOutputMatrixShapes(
       const TensorShapes& input_matrix_shape) const final {
     return TensorShapes({TensorShape({})});
@@ -115,15 +113,10 @@ class DeterminantOp : public LinearAlgebraOp<Scalar> {
   void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                      MatrixMaps* outputs) final {
     Scalar sign;
-    Scalar log_abs_det;
-    SLogDet(Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>(inputs[0]),
-            &sign, &log_abs_det);
-    Scalar determinant = sign * std::exp(log_abs_det);
-    // TODO(rmlarsen): Don't fail on infinite determinants, since that could
-    // be a valid result and the user should check for it instead.
-    OP_REQUIRES(context, Eigen::numext::isfinite(determinant),
-                errors::InvalidArgument("The determinant is not finite."));
-    outputs->at(0)(0, 0) = determinant;
+    const RealScalar log_abs_det = SLogDet(
+        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>(inputs[0]),
+        &sign);
+    outputs->at(0)(0, 0) = sign * std::exp(log_abs_det);
   }
 };
 
@@ -171,7 +164,7 @@ class DeterminantOpGpu : public AsyncOpKernel {
       return;
     }
 
-    // TODO(rmlarsen): Convert to std::make_unique when available.
+    // TODO(rmlarsen): Convert to absl::make_unique when available.
     std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
 
     // Reuse the input buffer or make a copy for the factorization step,
@@ -255,18 +248,160 @@ class DeterminantOpGpu : public AsyncOpKernel {
         for (int i = 0; i < host_infos[0].size(); ++i) {
           // It is OK for a matrix to be singular (signaled by info > 0),
           // corresponding to determinant of zero, but we do want to catch
-          // invalid arguments to GetrfBatched.
+          // invalid arguments to Getrf{Batched}.
           OP_REQUIRES_ASYNC(
-              context,
-              host_infos[0].data()[i] >= 0 ||
-                  host_infos[0].data()[i] == kint32min,
+              context, host_infos[0](i) >= 0,
               errors::InvalidArgument("Invalid input argument no. ",
                                       host_infos[0].data()[i],
                                       " for batch index ", i, "."),
               done);
+        }
+      }
+      done();
+    };
+    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
+                                                    std::move(info_checker));
+  }
+};
+
+template <class Scalar>
+class LogDeterminantOpGpu : public AsyncOpKernel {
+ public:
+  explicit LogDeterminantOpGpu(OpKernelConstruction* context)
+      : AsyncOpKernel(context) {}
+
+  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
+    const Tensor& input = context->input(0);
+    const int ndims = input.dims();
+    const int64 n = input.dim_size(ndims - 1);
+    // Validate inputs.
+    OP_REQUIRES_ASYNC(
+        context, ndims >= 2,
+        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
+        done);
+    OP_REQUIRES_ASYNC(
+        context, input.dim_size(ndims - 2) == n,
+        errors::InvalidArgument("Input matrices must be square, got ",
+                                input.dim_size(ndims - 2), " != ", n),
+        done);
+
+    // Allocate output.
+    TensorShape out_shape;
+    for (int dim = 0; dim < ndims - 2; ++dim) {
+      out_shape.AddDim(input.dim_size(dim));
+    }
+    out_shape.AppendShape(TensorShape({}));
+    Tensor* sign;
+    OP_REQUIRES_OK_ASYNC(context,
+                         context->allocate_output(0, out_shape, &sign), done);
+    Tensor* log_abs_det;
+    OP_REQUIRES_OK_ASYNC(
+        context, context->allocate_output(1, out_shape, &log_abs_det), done);
+
+    // By definition, the determinant of an empty matrix is equal to one.
+    const GPUDevice& d = context->eigen_device<GPUDevice>();
+    if (input.NumElements() == 0) {
+      functor::SetOneFunctor<GPUDevice, Scalar> one_func;
+      one_func(d, sign->template flat<Scalar>());
+      functor::SetZeroFunctor<GPUDevice, Scalar> zero_func;
+      zero_func(d, log_abs_det->template flat<Scalar>());
+      done();
+      return;
+    }
+
+    // TODO(rmlarsen): Convert to absl::make_unique when available.
+    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
+
+    // Reuse the input buffer or make a copy for the factorization step,
+    // depending on whether this op owns it exclusively.
+    Tensor input_copy;
+    OP_REQUIRES_OK_ASYNC(
+        context,
+        solver->forward_input_or_allocate_scoped_tensor(
+            {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
+        done);
+    if (!input.SharesBufferWith(input_copy)) {
+      d.memcpy(input_copy.flat<Scalar>().data(), input.flat<Scalar>().data(),
+               input.NumElements() * sizeof(Scalar));
+    }
+    auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
+    const int64 batch_size = input_copy_reshaped.dimension(0);
+
+    // Allocate pivots on the device.
+    Tensor pivots;
+    OP_REQUIRES_OK_ASYNC(
+        context,
+        solver->allocate_scoped_tensor(DataTypeToEnum<int>::value,
+                                       TensorShape{batch_size, n}, &pivots),
+        done);
+    auto pivots_mat = pivots.template matrix<int>();
+
+    // Prepare pointer arrays for cuBlas' batch interface.
+    // TODO(rmlarsen): Find a way to encode pointer arrays in pinned host
+    // memory without the ugly casting.
+    auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
+        sizeof(Scalar*) * batch_size, "input_copy_ptrs",
+        /* on_host */ true);
+
+    // Compute the partially pivoted LU factorization(s) of the
+    // matrix/matrices.
+    std::vector<DeviceLapackInfo> dev_info;
+    if (n / batch_size <= 128) {
+      // For small matrices or large batch sizes, we use the batched interface
+      // from cuBlas.
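+      // Choosing between the two paths is a heuristic: getrfBatched amortizes
+      // kernel-launch overhead across many small factorizations, while
+      // cuSolver's per-matrix getrf is faster once the individual matrices
+      // are large; the n / batch_size <= 128 cutoff is a rough tuning point.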
+      const Scalar** input_copy_ptrs_base =
+          reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
+      for (int batch = 0; batch < batch_size; ++batch) {
+        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
+      }
+      dev_info.push_back(
+          solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
+      OP_REQUIRES_OK_ASYNC(
+          context,
+          solver->GetrfBatched(n, input_copy_ptrs_base, n, pivots_mat.data(),
+                               &dev_info.back(), batch_size),
+          done);
+    } else {
+      // For large matrices or small batch sizes we use the non-batched
+      // interface from cuSolver, which is much faster for large matrices.
+      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
+      for (int batch = 0; batch < batch_size; ++batch) {
+        OP_REQUIRES_OK_ASYNC(
+            context,
+            solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n,
+                          &pivots_mat(batch, 0), &dev_info.back()(batch)),
+            done);
+      }
+    }
+
+    auto input_copy_reshaped_const =
+        const_cast<const Tensor*>(&input_copy)
+            ->template flat_inner_dims<Scalar, 3>();
+    auto sign_reshaped = sign->flat<Scalar>();
+    auto log_abs_det_reshaped = log_abs_det->flat<Scalar>();
+    // Compute the determinant for each batch as (-1)^s * prod(diag(U)),
+    // where s is the order of the permutation encoded in pivots and U is the
+    // upper triangular factor of the LU factorization, which is written to
+    // input_copy by the Getrf{Batched} kernel.
+    functor::LogDeterminantFromPivotedLUFunctor<GPUDevice, Scalar> functor;
+    functor(d, input_copy_reshaped_const, pivots_mat.data(), sign_reshaped,
+            log_abs_det_reshaped);
+
+    // Register callback to check info after kernels finish.
+    auto info_checker = [context, done](
+                            const Status& status,
+                            const std::vector<HostLapackInfo>& host_infos) {
+      if (!status.ok() && errors::IsInvalidArgument(status) &&
+          !host_infos.empty()) {
+        for (int i = 0; i < host_infos[0].size(); ++i) {
+          // It is OK for a matrix to be singular (signaled by info > 0),
+          // corresponding to determinant of zero, but we do want to catch
+          // invalid arguments to Getrf{Batched}.
           OP_REQUIRES_ASYNC(
-              context, host_infos[0].data()[i] != kint32min,
-              errors::InvalidArgument("The determinant is not finite."), done);
+              context, host_infos[0](i) >= 0,
+              errors::InvalidArgument("Invalid input argument no. ",
+                                      host_infos[0].data()[i],
+                                      " for batch index ", i, "."),
+              done);
         }
       }
       done();
@@ -282,6 +417,15 @@ REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<complex64>),
                        complex64);
 REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<complex128>),
                        complex128);
+
+REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<float>),
+                       float);
+REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<double>),
+                       double);
+REGISTER_LINALG_OP_GPU("LogMatrixDeterminant",
+                       (LogDeterminantOpGpu<complex64>), complex64);
+REGISTER_LINALG_OP_GPU("LogMatrixDeterminant",
+                       (LogDeterminantOpGpu<complex128>), complex128);
 #endif  // GOOGLE_CUDA
 
 REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<float>), float);
diff --git a/tensorflow/core/kernels/determinant_op.h b/tensorflow/core/kernels/determinant_op.h
new file mode 100644
index 00000000000000..e931e328e4bbb2
--- /dev/null
+++ b/tensorflow/core/kernels/determinant_op.h
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace functor {
+
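+// These functors are specialized for GPUDevice in determinant_op_gpu.cu.cc;
+// the CPU kernels in determinant_op.cc compute determinants directly with
+// Eigen and do not use them.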
+// Helper functor to compute Determinant from a partially pivoted LU
+// factorization.
+template <typename Device, typename Scalar>
+struct DeterminantFromPivotedLUFunctor {
+  void operator()(const Device& device,
+                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
+                  const int* pivots, typename TTypes<Scalar, 1>::Tensor output,
+                  int* info);
+};
+
+// Helper functor to compute sign and log of the absolute value of the
+// determinant from a partially pivoted LU factorization.
+template <typename Device, typename Scalar>
+struct LogDeterminantFromPivotedLUFunctor {
+  void operator()(const Device& device,
+                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
+                  const int* pivots, typename TTypes<Scalar, 1>::Tensor sign,
+                  typename TTypes<Scalar, 1>::Tensor log_abs_det);
+};
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
diff --git a/tensorflow/core/kernels/determinant_op_gpu.cu.cc b/tensorflow/core/kernels/determinant_op_gpu.cu.cc
new file mode 100644
index 00000000000000..c866204c97e6ac
--- /dev/null
+++ b/tensorflow/core/kernels/determinant_op_gpu.cu.cc
@@ -0,0 +1,168 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/determinant_op.h"
+
+#include <complex>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/cuda_solvers.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+namespace {
+__device__ int PermutationOrder(int n, const int* pivots) {
+  // Compute the order of the permutation from the number of transpositions
+  // encoded in the pivot array, see:
+  // http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=2&t=340
+  int order = 0;
+  for (int i = 0; i < n - 1; ++i) {
+    // Notice: Internally, the cuBlas code uses Fortran convention (1-based)
+    // indexing so we expect pivots[i] == i + 1 for rows that were not moved.
+    order += pivots[i] != (i + 1);
+  }
+  return order;
+}
+
+#if defined(__CUDACC__)
+// Hack around missing support for complex in NVCC.
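+// nvcc does not provide device overloads of the std::complex arithmetic
+// operators, so the few operations the kernel below needs (complex * complex,
+// complex * real, complex / real) are spelled out by hand.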
+template <typename T>
+__device__ inline std::complex<T> complex_multiply(const std::complex<T>& a,
+                                                   const std::complex<T>& b) {
+  const T a_real = Eigen::numext::real(a);
+  const T a_imag = Eigen::numext::imag(a);
+  const T b_real = Eigen::numext::real(b);
+  const T b_imag = Eigen::numext::imag(b);
+  return std::complex<T>(a_real * b_real - a_imag * b_imag,
+                         a_real * b_imag + a_imag * b_real);
+}
+__device__ inline complex64 operator*(const complex64& a, const complex64& b) {
+  return complex_multiply<float>(a, b);
+}
+__device__ inline complex64 operator*(const complex64& a, const float& b) {
+  return complex64(Eigen::numext::real(a) * b, Eigen::numext::imag(a) * b);
+}
+__device__ inline complex64 operator/(const complex64& a, const float& b) {
+  const float inv_b = 1.0f / b;
+  return a * inv_b;
+}
+__device__ inline complex128 operator*(const complex128& a,
+                                       const complex128& b) {
+  return complex_multiply<double>(a, b);
+}
+__device__ inline complex128 operator*(const complex128& a, const double& b) {
+  return complex128(Eigen::numext::real(a) * b, Eigen::numext::imag(a) * b);
+}
+__device__ inline complex128 operator/(const complex128& a, const double& b) {
+  const double inv_b = 1.0 / b;
+  return a * inv_b;
+}
+#endif
+}  // namespace
+
+// This kernel computes either determinant or log_abs_determinant, depending
+// on the value of the template parameter. If compute_log_abs_det is false,
+// the sign argument is ignored.
+template <typename Scalar, bool compute_log_abs_det = true>
+__global__ void DeterminantFromPivotedLUKernel(int nthreads, int n,
+                                               const Scalar* lu_factor,
+                                               const int* all_pivots,
+                                               Scalar* sign,
+                                               Scalar* log_abs_det) {
+  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+  const int matrix_size = n * n;
+  const int stride = n + 1;
+  // We only parallelize over batches here. Performance is not critical,
+  // since this cheap O(n) kernel always follows an O(n^3) LU factorization.
+  // The main purpose is to avoid having to copy the LU decomposition to
+  // host memory.
+  CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
+    // Initialize sign to (-1)^order.
+    const int order = PermutationOrder(n, all_pivots + o_idx * n);
+    Scalar prod_sign = order % 2 ? Scalar(-1) : Scalar(1);
+    RealScalar sum_log_abs_det = RealScalar(0);
+    int i_idx = matrix_size * o_idx;
+    for (int i = 0; i < n; ++i, i_idx += stride) {
+      const RealScalar abs_i = Eigen::numext::abs(lu_factor[i_idx]);
+      sum_log_abs_det += Eigen::numext::log(abs_i);
+      prod_sign = prod_sign * (lu_factor[i_idx] / abs_i);
+    }
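+    // A non-finite sum_log_abs_det means a diagonal entry of U was zero
+    // (singular input, so log|det| = -inf) or the summation over/underflowed.
+    // Normalize to sign = 0 and log_abs_det = +/-inf so callers always see a
+    // well-defined value.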
+    if (!Eigen::numext::isfinite(sum_log_abs_det)) {
+      prod_sign = Scalar(0);
+      sum_log_abs_det = sum_log_abs_det > 0
+                            ? -Eigen::numext::log(RealScalar(0))
+                            : Eigen::numext::log(RealScalar(0));
+    }
+    if (compute_log_abs_det) {
+      sign[o_idx] = prod_sign;
+      log_abs_det[o_idx] = Scalar(sum_log_abs_det);
+    } else {
+      log_abs_det[o_idx] = prod_sign * Eigen::numext::exp(sum_log_abs_det);
+    }
+  }
+}
+
+template <typename Scalar>
+struct DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
+  void operator()(const GPUDevice& device,
+                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
+                  const int* pivots, typename TTypes<Scalar, 1>::Tensor output,
+                  int* info) {
+    const int64 num_matrices = output.size();
+    const int64 n = lu_factor.dimension(2);
+    CudaLaunchConfig config = GetCudaLaunchConfig(num_matrices, device);
+    DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/false>
+        <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
+            config.virtual_thread_count, n, lu_factor.data(), pivots, nullptr,
+            output.data());
+  }
+};
+
+template struct DeterminantFromPivotedLUFunctor<GPUDevice, float>;
+template struct DeterminantFromPivotedLUFunctor<GPUDevice, double>;
+template struct DeterminantFromPivotedLUFunctor<GPUDevice, complex64>;
+template struct DeterminantFromPivotedLUFunctor<GPUDevice, complex128>;
+
+template <typename Scalar>
+struct LogDeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
+  void operator()(const GPUDevice& device,
+                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
+                  const int* pivots, typename TTypes<Scalar, 1>::Tensor sign,
+                  typename TTypes<Scalar, 1>::Tensor log_abs_det) {
+    const int64 num_matrices = sign.size();
+    const int64 n = lu_factor.dimension(2);
+    CudaLaunchConfig config = GetCudaLaunchConfig(num_matrices, device);
+    DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/true>
+        <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
+            config.virtual_thread_count, n, lu_factor.data(), pivots,
+            sign.data(), log_abs_det.data());
+  }
+};
+
+template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, float>;
+template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, double>;
+template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, complex64>;
+template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, complex128>;
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/linalg_ops_common.h b/tensorflow/core/kernels/linalg_ops_common.h
index 1d31786728f5c4..f7c3f1950b9af3 100644
--- a/tensorflow/core/kernels/linalg_ops_common.h
+++ b/tensorflow/core/kernels/linalg_ops_common.h
@@ -172,13 +172,14 @@ extern template class LinearAlgebraOp<complex128>;
 
 }  // namespace tensorflow
 
-#define INHERIT_LINALG_TYPEDEFS(Scalar)                      \
-  typedef LinearAlgebraOp<Scalar> Base;                      \
-  using Matrix = typename Base::Matrix;                      \
-  using MatrixMap = typename Base::MatrixMap;                \
-  using MatrixMaps = typename Base::MatrixMaps;              \
-  using ConstMatrixMap = typename Base::ConstMatrixMap;      \
-  using ConstMatrixMaps = typename Base::ConstMatrixMaps;    \
+#define INHERIT_LINALG_TYPEDEFS(Scalar)                        \
+  typedef LinearAlgebraOp<Scalar> Base;                        \
+  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;  \
+  using Matrix = typename Base::Matrix;                        \
+  using MatrixMap = typename Base::MatrixMap;                  \
+  using MatrixMaps = typename Base::MatrixMaps;                \
+  using ConstMatrixMap = typename Base::ConstMatrixMap;        \
+  using ConstMatrixMaps = typename Base::ConstMatrixMaps;      \
   using TensorShapes = typename Base::TensorShapes;
 
 #define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \
diff --git a/tensorflow/core/kernels/matrix_inverse_op.cc b/tensorflow/core/kernels/matrix_inverse_op.cc
index 64edfe470d0581..cae84f52d7a941 100644
--- a/tensorflow/core/kernels/matrix_inverse_op.cc
+++ b/tensorflow/core/kernels/matrix_inverse_op.cc
@@ -69,7 +69,6 @@ class MatrixInverseOp : public LinearAlgebraOp<Scalar> {
     // a result of basic user mistakes, such as providing integer valued
     // matrices that are exactly singular, or due to underflow if this
     // code is run with denormals being flushed to zero.
-    using RealScalar = typename Base::RealScalar;
     const RealScalar min_abs_pivot =
         lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff();
     OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
diff --git a/tensorflow/core/kernels/matrix_solve_op.cc b/tensorflow/core/kernels/matrix_solve_op.cc
index 2e4098dfabae73..169f3dae76d2fb 100644
--- a/tensorflow/core/kernels/matrix_solve_op.cc
+++ b/tensorflow/core/kernels/matrix_solve_op.cc
@@ -44,18 +44,12 @@ static const char kErrMsg[] = "Input matrix is not invertible.";
 template <class Scalar>
 class MatrixSolveOp : public LinearAlgebraOp<Scalar> {
  public:
-  typedef LinearAlgebraOp<Scalar> Base;
+  INHERIT_LINALG_TYPEDEFS(Scalar);
 
   explicit MatrixSolveOp(OpKernelConstruction* context) : Base(context) {
     OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
   }
 
-  using TensorShapes = typename Base::TensorShapes;
-  using Matrix = typename Base::Matrix;
-  using MatrixMaps = typename Base::MatrixMaps;
-  using ConstMatrixMap = typename Base::ConstMatrixMap;
-  using ConstMatrixMaps = typename Base::ConstMatrixMaps;
-
   void ValidateInputMatrixShapes(
       OpKernelContext* context,
       const TensorShapes& input_matrix_shapes) const final {
@@ -102,7 +96,6 @@ class MatrixSolveOp : public LinearAlgebraOp<Scalar> {
     // a result of basic user mistakes such providing integer valued
     // matrices that are exactly singular, or due to underflow if this
     // code is run with denormals being flushed to zero.
-    using RealScalar = typename Base::RealScalar;
     const RealScalar min_abs_pivot =
         lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff();
     OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op.cc b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
index 953f37fa029862..6f7e6a7496840f 100644
--- a/tensorflow/core/kernels/matrix_triangular_solve_op.cc
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
@@ -47,7 +47,7 @@ perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
 template <class Scalar>
 class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
  public:
-  typedef LinearAlgebraOp<Scalar> Base;
+  INHERIT_LINALG_TYPEDEFS(Scalar);
 
   explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
       : Base(context), lower_(true), adjoint_(false) {
@@ -55,13 +55,6 @@ class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
     OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
   }
 
-  using TensorShapes = typename Base::TensorShapes;
-  using Matrix = typename Base::Matrix;
-  using MatrixMap = typename Base::MatrixMap;
-  using MatrixMaps = typename Base::MatrixMaps;
-  using ConstMatrixMap = typename Base::ConstMatrixMap;
-  using ConstMatrixMaps = typename Base::ConstMatrixMaps;
-
   void ValidateInputMatrixShapes(
       OpKernelContext* context,
       const TensorShapes& input_matrix_shapes) const final {
@@ -97,7 +90,6 @@ class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
     // an empty set of equation as the empty matrix.
       return;
     }
-    using RealScalar = typename Base::RealScalar;
     const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
     OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                 errors::InvalidArgument("Input matrix is not invertible."));
diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py
index 7368fbc4a1a0d7..222038b22ef3c7 100644
--- a/tensorflow/python/kernel_tests/determinant_op_test.py
+++ b/tensorflow/python/kernel_tests/determinant_op_test.py
@@ -126,11 +126,10 @@ def testBasicComplex128(self):
     self._compareDeterminant(
         np.random.rand(3, 4, 5, 2, 2).astype(np.complex128))
 
-  def testOverflow(self):
+  def testInfiniteDeterminant(self):
     max_double = np.finfo("d").max
     huge_matrix = np.array([[max_double, 0.0], [0.0, max_double]])
-    with self.assertRaisesOpError("not finite"):
-      self._compareDeterminant(huge_matrix)
+    self._compareDeterminant(huge_matrix)
 
   def testNonSquareMatrix(self):
     # When the determinant of a non-square matrix is attempted we should return