diff --git a/mshadow/dot_engine-inl.h b/mshadow/dot_engine-inl.h
index 2266f0b506be..062181b910c8 100644
--- a/mshadow/dot_engine-inl.h
+++ b/mshadow/dot_engine-inl.h
@@ -23,7 +23,7 @@ struct DotEngine {
                           const Tensor &rhs,
                           DType scale);
 };
-// handles the dot
+// handles the dot, use CblasColMajor
 template
 struct BLASEngine {
   inline static bool GetT(bool t) {
@@ -38,6 +38,13 @@ struct BLASEngine {
                           DType beta, DType *C, int ldc) {
     LOG(FATAL) << "Not implmented!";
   }
+  inline static void batched_gemm(Stream *stream,
+                                  bool transa, bool transb,
+                                  int m, int n, int k, DType alpha,
+                                  const DType *A, int lda, const DType *B, int ldb,
+                                  DType beta, DType *C, int ldc, int batch_count) {
+    LOG(FATAL) << "Not implmented!";
+  }
   inline static void gemv(Stream *stream,
                           bool trans, int m, int n,
                           DType alpha, const DType *A, int lda,
@@ -45,12 +52,25 @@ struct BLASEngine {
                           DType beta, DType *Y, int incY) {
     LOG(FATAL) << "Not implmented!";
   }
+  inline static void batched_gemv(Stream *stream,
+                                  bool trans, int m, int n,
+                                  DType alpha, const DType *A, int lda,
+                                  const DType *X, int incX,
+                                  DType beta, DType *Y, int incY, int batch_count) {
+    LOG(FATAL) << "Not implmented!";
+  }
   inline static void ger(Stream *stream,
                          int m, int n, DType alpha,
                          const DType *X, int incX,
                          const DType *Y, int incY, DType *A, int lda) {
     LOG(FATAL) << "Not implmented!";
   }
+  inline static void batched_ger(Stream *stream,
+                                 int m, int n, DType alpha,
+                                 const DType *X, int incX,
+                                 const DType *Y, int incY, DType *A, int lda, int batch_count) {
+    LOG(FATAL) << "Not implmented!";
+  }
   inline static void dot(Stream *stream,
                          int n,
                          const DType* X, int incX,
@@ -92,6 +112,17 @@ struct BLASEngine {
       LOG(FATAL) << "Not implmented!";
     }
   }
+  inline static void batched_gemm(Stream *stream,
+                                  bool transa, bool transb,
+                                  int m, int n, int k, DType alpha,
+                                  const float *A, int lda, const float *B, int ldb,
+                                  float beta, float *C, int ldc, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      gemm(stream, transa, transb, m, n, k, alpha,
+           A + i * m * k, lda, B + i * k * n, ldb,
+           beta, C + i * m * n, ldc);
+    }
+  }
   inline static void gemv(Stream *stream,
                           bool trans, int m, int n,
                           float alpha, const float *A, int lda,
@@ -99,12 +130,25 @@ struct BLASEngine {
                           float beta, float *Y, int incY) {
     LOG(FATAL) << "Not implmented!";
   }
+  inline static void batched_gemv(Stream *stream,
+                                  bool trans, int m, int n,
+                                  float alpha, const float *A, int lda,
+                                  const float *X, int incX,
+                                  float beta, float *Y, int incY, int batch_count) {
+    LOG(FATAL) << "Not implmented!";
+  }
   inline static void ger(Stream *stream,
                          int m, int n, float alpha,
                          const float *X, int incX,
                          const float *Y, int incY, float *A, int lda) {
     LOG(FATAL) << "Not implmented!";
   }
+  inline static void batched_ger(Stream *stream,
+                                 int m, int n, float alpha,
+                                 const float *X, int incX,
+                                 const float *Y, int incY, float *A, int lda, int batch_count) {
+    LOG(FATAL) << "Not implmented!";
+  }
   inline static void dot(Stream *stream,
                          int n,
                          const float* X, int incX,
@@ -145,6 +189,17 @@ struct BLASEngine {
       LOG(FATAL) << "Not implmented!";
     }
   }
+  inline static void batched_gemm(Stream *stream,
+                                  bool transa, bool transb,
+                                  int m, int n, int k, DType alpha,
+                                  const double *A, int lda, const double *B, int ldb,
+                                  double beta, double *C, int ldc, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      gemm(stream, transa, transb, m, n, k, alpha,
+           A + i * m * k, lda, B + i * k * n, ldb,
+           beta, C + i * m * n, ldc);
+    }
+  }
   inline static void gemv(Stream *stream,
                           bool trans, int m, int n,
                           double alpha, const double *A, int lda,
@@ -152,12 +207,25 @@ struct BLASEngine {
                           double beta, double *Y, int incY) {
     LOG(FATAL) << "Not implmented!";
   }
+  inline static void batched_gemv(Stream *stream,
+                                  bool trans, int m, int n,
+                                  double alpha, const double *A, int lda,
+                                  const double *X, int incX,
+                                  double beta, double *Y, int incY, int batch_count) {
+    LOG(FATAL) << "Not implmented!";
+  }
   inline static void ger(Stream *stream,
                          int m, int n, double alpha,
                          const double *X, int incX,
                          const double *Y, int incY, double *A, int lda) {
     LOG(FATAL) << "Not implmented!";
   }
+  inline static void batched_ger(Stream *stream,
+                                 int m, int n, double alpha,
+                                 const double *X, int incX,
+                                 const double *Y, int incY, double *A, int lda, int batch_count) {
+    LOG(FATAL) << "Not implmented!";
+  }
   inline static void dot(Stream *stream,
                          int n,
                          const double* X, int incX,
@@ -183,6 +251,17 @@ struct BLASEngine {
     cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb), m, n, k, alpha,
                 A, lda, B, ldb, beta, C, ldc);
   }
+  inline static void batched_gemm(Stream *stream,
+                                  bool transa, bool transb,
+                                  int m, int n, int k, float alpha,
+                                  const float *A, int lda, const float *B, int ldb,
+                                  float beta, float *C, int ldc, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      gemm(stream, transa, transb, m, n, k, alpha,
+           A + i * m * k, lda, B + i * k * n, ldb,
+           beta, C + i * m * n, ldc);
+    }
+  }
   inline static void gemv(Stream *stream,
                           bool trans, int m, int n,
                           float alpha, const float *A, int lda,
@@ -191,12 +270,32 @@ struct BLASEngine {
     cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha,
                 A, lda, X, incX, beta, Y, incY);
   }
+  inline static void batched_gemv(Stream *stream,
+                                  bool trans, int m, int n,
+                                  float alpha, const float *A, int lda,
+                                  const float *X, int incX,
+                                  float beta, float *Y, int incY, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
+           X + i * (trans ? m : n) * incX, incX,
+           beta, Y + i * (trans ? n : m) * incY, incY);
+    }
+  }
   inline static void ger(Stream *stream,
                          int m, int n, float alpha,
                          const float *X, int incX,
                          const float *Y, int incY, float *A, int lda) {
     cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
   }
+  inline static void batched_ger(Stream *stream,
+                                 int m, int n, float alpha,
+                                 const float *X, int incX,
+                                 const float *Y, int incY, float *A, int lda, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
+          A + i * lda * n, lda);
+    }
+  }
   inline static void dot(Stream *stream,
                          int n,
                          const float* X, int incX,
@@ -221,6 +320,17 @@ struct BLASEngine {
     cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb), m, n, k, alpha,
                 A, lda, B, ldb, beta, C, ldc);
   }
+  inline static void batched_gemm(Stream *stream,
+                                  bool transa, bool transb,
+                                  int m, int n, int k, double alpha,
+                                  const double *A, int lda, const double *B, int ldb,
+                                  double beta, double *C, int ldc, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      gemm(stream, transa, transb, m, n, k, alpha,
+           A + i * m * k, lda, B + i * k * n, ldb,
+           beta, C + i * m * n, ldc);
+    }
+  }
   inline static void gemv(Stream *stream,
                           bool trans, int m, int n,
                           double alpha, const double *A, int lda,
@@ -229,12 +339,32 @@ struct BLASEngine {
     cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha,
                 A, lda, X, incX, beta, Y, incY);
   }
+  inline static void batched_gemv(Stream *stream,
+                                  bool trans, int m, int n,
+                                  double alpha, const double *A, int lda,
+                                  const double *X, int incX,
+                                  double beta, double *Y, int incY, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
+           X + i * (trans ? m : n) * incX, incX,
+           beta, Y + i * (trans ? n : m) * incY, incY);
+    }
+  }
   inline static void ger(Stream *stream,
                          int m, int n, double alpha,
                          const double *X, int incX,
                          const double *Y, int incY, double *A, int lda) {
     cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
   }
+  inline static void batched_ger(Stream *stream,
+                                 int m, int n, double alpha,
+                                 const double *X, int incX,
+                                 const double *Y, int incY, double *A, int lda, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
+          A + i * lda * n, lda);
+    }
+  }
   inline static void dot(Stream *stream,
                          int n,
                          const double* X, int incX,
@@ -290,6 +420,17 @@ struct BLASEngine {
     LOG(FATAL) << "Require CUDA version >= 7.5!";
 #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 7050
   }
+  inline static void batched_gemm(Stream *stream,
+                                  bool transa, bool transb,
+                                  int m, int n, int k, half::half_t alpha,
+                                  const half::half_t *A, int lda, const half::half_t *B, int ldb,
+                                  half::half_t beta, half::half_t *C, int ldc, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      gemm(stream, transa, transb, m, n, k, alpha,
+           A + i * m * k, lda, B + i * k * n, ldb,
+           beta, C + i * m * n, ldc);
+    }
+  }
   inline static void gemv(Stream *stream,
                           bool trans, int m, int n, half::half_t alpha,
                           const half::half_t *A, int lda,
@@ -297,12 +438,25 @@ struct BLASEngine {
                           half::half_t *Y, int incY) {
     LOG(FATAL) << "Not implmented!";
   }
+  inline static void batched_gemv(Stream *stream,
+                                  bool trans, int m, int n,
+                                  half::half_t alpha, const half::half_t *A, int lda,
+                                  const half::half_t *X, int incX,
+                                  half::half_t beta, half::half_t *Y, int incY, int batch_count) {
+    LOG(FATAL) << "Not implmented!";
+  }
   inline static void ger(Stream *stream,
                          int m, int n, half::half_t alpha,
                          const half::half_t *X, int incX,
                          const half::half_t *Y, int incY, half::half_t *A, int lda) {
     LOG(FATAL) << "Not implmented!";
   }
+  inline static void batched_ger(Stream *stream,
+                                 int m, int n, half::half_t alpha,
+                                 const half::half_t *X, int incX, const half::half_t *Y, int incY,
+                                 half::half_t *A, int lda, int batch_count) {
+    LOG(FATAL) << "Not implmented!";
+  }
   inline static void dot(Stream *stream,
                          int n,
                          const half::half_t* X, int incX,
@@ -320,7 +474,7 @@ struct BLASEngine {
   inline static void SetStream(Stream *stream) {
     cublasStatus_t err = cublasSetStream(Stream::GetBlasHandle(stream),
                     Stream::GetStream(stream));
-    CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas set stream fail";
+    CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail";
   }
   inline static void gemm(Stream *stream,
                           bool transa, bool transb,
@@ -331,7 +485,18 @@ struct BLASEngine {
     cublasStatus_t err = cublasSgemm(Stream::GetBlasHandle(stream),
                 GetT(transa), GetT(transb), m, n, k, &alpha,
                 A, lda, B, ldb, &beta, C, ldc);
-    CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas Sgemm fail";
+    CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemm fail";
+  }
+  inline static void batched_gemm(Stream *stream,
+                                  bool transa, bool transb,
+                                  int m, int n, int k, float alpha,
+                                  const float *A, int lda, const float *B, int ldb,
+                                  float beta, float *C, int ldc, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      gemm(stream, transa, transb, m, n, k, alpha,
+           A + i * m * k, lda, B + i * k * n, ldb,
+           beta, C + i * m * n, ldc);
+    }
   }
   inline static void gemv(Stream *stream,
                           bool trans, int m, int n, float alpha,
@@ -342,6 +507,17 @@ struct BLASEngine {
                 GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
     CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemv fail";
   }
+  inline static void batched_gemv(Stream *stream,
+                                  bool trans, int m, int n,
+                                  float alpha, const float *A, int lda,
+                                  const float *X, int incX,
+                                  float beta, float *Y, int incY, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
+           X + i * (trans ? m : n) * incX, incX,
+           beta, Y + i * (trans ? n : m) * incY, incY);
+    }
+  }
   inline static void ger(Stream *stream,
                          int m, int n, float alpha,
                          const float *X, int incX,
@@ -350,6 +526,15 @@ struct BLASEngine {
                     m, n, &alpha, X, incX, Y, incY, A, lda);
     CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sger fail";
   }
+  inline static void batched_ger(Stream *stream,
+                                 int m, int n, float alpha,
+                                 const float *X, int incX,
+                                 const float *Y, int incY, float *A, int lda, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
+          A + i * lda * n, lda);
+    }
+  }
   inline static void dot(Stream *stream,
                          int n,
                          const float* X, int incX,
@@ -373,7 +558,7 @@ struct BLASEngine {
   inline static void SetStream(Stream *stream) {
     cublasStatus_t err = cublasSetStream(Stream::GetBlasHandle(stream),
                     Stream::GetStream(stream));
-    CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas set stream fail";
+    CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail";
   }
   inline static void gemm(Stream *stream,
                           bool transa, bool transb,
@@ -386,6 +571,17 @@ struct BLASEngine {
                 A, lda, B, ldb, &beta, C, ldc);
     CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemm fail";
   }
+  inline static void batched_gemm(Stream *stream,
+                                  bool transa, bool transb,
+                                  int m, int n, int k, double alpha,
+                                  const double *A, int lda, const double *B, int ldb,
+                                  double beta, double *C, int ldc, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      gemm(stream, transa, transb, m, n, k, alpha,
+           A + i * m * k, lda, B + i * k * n, ldb,
+           beta, C + i * m * n, ldc);
+    }
+  }
   inline static void gemv(Stream *stream,
                           bool trans, int m, int n, double alpha,
                           const double *A, int lda,
@@ -395,6 +591,17 @@ struct BLASEngine {
                 GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
     CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemv fail";
   }
+  inline static void batched_gemv(Stream *stream,
+                                  bool trans, int m, int n,
+                                  double alpha, const double *A, int lda,
+                                  const double *X, int incX,
+                                  double beta, double *Y, int incY, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
+           X + i * (trans ? m : n) * incX, incX,
+           beta, Y + i * (trans ? n : m) * incY, incY);
+    }
+  }
   inline static void ger(Stream *stream,
                          int m, int n, double alpha,
                          const double *X, int incX,
@@ -403,6 +610,15 @@ struct BLASEngine {
                     m, n, &alpha, X, incX, Y, incY, A, lda);
     CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dger fail";
   }
+  inline static void batched_ger(Stream *stream,
+                                 int m, int n, double alpha,
+                                 const double *X, int incX,
+                                 const double *Y, int incY, double *A, int lda, int batch_count) {
+    for (int i = 0; i < batch_count; ++i) {
+      ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
+          A + i * lda * n, lda);
+    }
+  }
   inline static void dot(Stream *stream,
                          int n,
                          const double* X, int incX,
@@ -422,6 +638,9 @@ struct BLASEngine {
 inline static Shape<2> GetShape(const Shape<2> &shape, bool transpose) {
   return transpose ? Shape2(shape[1], shape[0]) : shape;
 }
+inline static Shape<3> GetBatchedShape(const Shape<3> &shape, bool transpose) {
+  return transpose ? Shape3(shape[0], shape[2], shape[1]) : shape;
+}
 // dst = dot(lhs[.T], rhs[.T])
 template
@@ -498,7 +717,7 @@ struct DotEngine {
     // set kernel stream
     // if there is no stream, crush
     BLASEngine::SetStream(dst.stream_);
-    CHECK_EQ(dst.size(0), lhs.size(0) && dst.size(1) == rhs.size(0))
+    CHECK(dst.size(0) == lhs.size(0) && dst.size(1) == rhs.size(0))
       << "dot-ger: matrix shape mismatch"
       << "dst: " << dst.shape_ << "\n"
      << "lhs: " << lhs.shape_ << "\n"
@@ -513,6 +732,73 @@ struct DotEngine {
     }
   }
 };
+// dst = batched_dot(lhs[.T], rhs[.T])
+template
+struct DotEngine {
+  inline static void Eval(Tensor *p_dst,
+                          const Tensor &lhs,
+                          const Tensor &rhs,
+                          DType scale) {
+    Tensor &dst = *p_dst;
+    // set kernel stream
+    // if there is no stream, crush
+    BLASEngine::SetStream(dst.stream_);
+    Shape<3> sleft = GetBatchedShape(lhs.shape_, transpose_left);
+    Shape<3> sright = GetBatchedShape(rhs.shape_, transpose_right);
+    CHECK(dst.size(0) == sleft[0] && dst.size(0) == sright[0])
+      << "batch_dot-gemm: batchsize must be equal."
+      << "dst: " << dst.shape_ << "\n"
+      << "lhs: " << sleft << "\n"
+      << "rhs: " << sright << "\n";
+    CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1])
+      << "batch_dot-gemm: matrix shape mismatch"
+      << "dst: " << dst.shape_ << "\n"
+      << "lhs: " << sleft << "\n"
+      << "rhs: " << sright << "\n";
+    // use column major argument to compatible with most BLAS
+    if (sleft[1] == 1) {
+      // For (batch, 1, K) gemm (batch, K, N), we can use (batch, N, K) gemv (batch, K)
+      BLASEngine::batched_gemv
+          (dst.stream_,
+           transpose_right,
+           rhs.size(2), rhs.size(1), scale * SV::AlphaBLAS(),
+           rhs.dptr_, rhs.stride_,
+           lhs.dptr_, 1, SV::BetaBLAS(),
+           dst.dptr_, 1, dst.size(0));
+    } else if (sleft[2] == 1 && (SV::BetaBLAS() == 0.0f || SV::BetaBLAS() == 1.0f)) {
+      // For (batch, M, 1) gemm (batch, 1, N) + Beta = 0, we can use (batch, M) ger (batch, N)
+      if (SV::BetaBLAS() == 0.0f) {
+        dst = DType(0);
+      }
+      BLASEngine::batched_ger
+          (dst.stream_, sright[2], sleft[1], scale * SV::AlphaBLAS(),
+           rhs.dptr_, 1, lhs.dptr_, 1, dst.dptr_, dst.stride_, dst.size(0));
+    } else if (sright[2] == 1) {
+      // For (batch, M, K) gemm (batch, K, 1), we can use (batch, M, K) gemv (batch, K)
+      BLASEngine::batched_gemv
+          (dst.stream_,
+           !transpose_left,
+           lhs.size(2), lhs.size(1), scale * SV::AlphaBLAS(),
+           lhs.dptr_, lhs.stride_,
+           rhs.dptr_, 1, SV::BetaBLAS(),
+           dst.dptr_, 1, dst.size(0));
+    } else {
+      // For general case, use gemm
+      BLASEngine::batched_gemm
+          (dst.stream_,
+           transpose_right, transpose_left,
+           transpose_right ? rhs.size(1) : rhs.size(2),
+           transpose_left ? lhs.size(2) : lhs.size(1),
+           transpose_right ? rhs.size(2) : rhs.size(1),
+           DType(scale * SV::AlphaBLAS()),
+           rhs.dptr_, rhs.stride_,
+           lhs.dptr_, lhs.stride_,
+           DType(SV::BetaBLAS()),
+           dst.dptr_, dst.stride_, dst.size(0));
+    }
+  }
+};
 } // namespace expr
 } // namespace mshadow
 #endif // MSHADOW_DOT_ENGINE_INL_H_
diff --git a/mshadow/expression.h b/mshadow/expression.h
index f02e260f6405..fdafdad0b552 100644
--- a/mshadow/expression.h
+++ b/mshadow/expression.h
@@ -241,6 +241,13 @@ inline DotExp
 dot(const TransposeExp &lhs, const TransposeExp &rhs) {
   return DotExp(lhs.exp, rhs.exp, DType(1.0f));
 }
+/*! \brief batch_dot operator def */
+template
+inline DotExp
+batch_dot(const RValueExp &lhs, const RValueExp &rhs) {
+  return DotExp(
+    lhs.self(), rhs.self(), DType(1.0f));
+}
 //---------------
 // BinaryMapExp
 // --------------
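For reference, the loop-based batched_* fallbacks added above all follow the same pattern: batch i is reached by offsetting the base pointers by one full matrix (A + i * m * k, B + i * k * n, C + i * m * n) and forwarding to the single-matrix routine, which implicitly assumes the per-batch matrices are stored back to back. The standalone sketch below illustrates that semantics outside mshadow; naive_gemm_colmajor and batched_gemm_fallback are hypothetical stand-ins for the gemm wrapper and the batched_gemm loop, assuming column-major storage and contiguous batches.

#include <cstdio>
#include <vector>

// Reference semantics of the loop-based batched_gemm fallback:
// C_i = alpha * A_i * B_i + beta * C_i for each batch i, where the i-th
// matrices sit at fixed offsets m*k, k*n and m*n from the base pointers
// (contiguous batches, column-major storage).
static void naive_gemm_colmajor(int m, int n, int k, float alpha,
                                const float *A, int lda,
                                const float *B, int ldb,
                                float beta, float *C, int ldc) {
  for (int col = 0; col < n; ++col) {
    for (int row = 0; row < m; ++row) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        acc += A[row + p * lda] * B[p + col * ldb];  // A(row, p) * B(p, col)
      }
      C[row + col * ldc] = alpha * acc + beta * C[row + col * ldc];
    }
  }
}

static void batched_gemm_fallback(int m, int n, int k, float alpha,
                                  const float *A, int lda,
                                  const float *B, int ldb,
                                  float beta, float *C, int ldc,
                                  int batch_count) {
  // Same offset arithmetic as the batched_gemm loops in the diff.
  for (int i = 0; i < batch_count; ++i) {
    naive_gemm_colmajor(m, n, k, alpha,
                        A + i * m * k, lda, B + i * k * n, ldb,
                        beta, C + i * m * n, ldc);
  }
}

int main() {
  const int batch = 2, m = 2, k = 3, n = 2;
  std::vector<float> A(batch * m * k, 1.0f);  // each A_i is all ones
  std::vector<float> B(batch * k * n, 2.0f);  // each B_i is all twos
  std::vector<float> C(batch * m * n, 0.0f);
  batched_gemm_fallback(m, n, k, 1.0f, A.data(), m, B.data(), k,
                        0.0f, C.data(), m, batch);
  // Every entry of each C_i should be k * 1 * 2 = 6.
  std::printf("C[0] = %.1f, C[last] = %.1f\n", C.front(), C.back());
  return 0;
}

The general-case branch in the new DotEngine specialization relies on the same contiguity assumption, and it passes rhs before lhs with the transpose flags swapped, which appears to be the usual trick for evaluating a row-major (batch, M, K) x (batch, K, N) product through a column-major BLAS interface without an explicit transpose.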