Add fp16 matrix multiplication (GEMM) support to StreamExecutor, gated on
compilation with CUDA 7.5; fp16 convolutions via cuDNN will come soon.
This does not update any TensorFlow ops, but it is a dependency of doing
that.

Note: fp16 axpy and dot do not exist in CUDA 7.5 and have thus not been added.
CUDA 8.0 supports both (through the axpyEx and dotEx interfaces).
Change: 122069402
A. Unique TensorFlower authored and tensorflower-gardener committed May 11, 2016
1 parent 939ede0 commit 5230554
Showing 4 changed files with 95 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tensorflow/stream_executor/blas.h
@@ -45,6 +45,7 @@ limitations under the License.

#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "third_party/eigen3/Eigen/Core"

namespace perftools {
namespace gputools {
@@ -846,6 +847,17 @@ class BlasSupport {
// op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and
// beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix;
// op(b) is a k-by-n matrix; c is an m-by-n matrix.
//
// Note: The half interface uses float precision internally; the version
// that uses half precision internally is not yet supported. There is no
// batched version of the half-precision interface.
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
float alpha,
const DeviceMemory<Eigen::half> &a, int lda,
const DeviceMemory<Eigen::half> &b, int ldb,
float beta,
DeviceMemory<Eigen::half> *c, int ldc) = 0;
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
float alpha, const DeviceMemory<float> &a, int lda,
@@ -1597,6 +1609,11 @@ class BlasSupport {
blas::Transpose trans, blas::Diagonal diag, uint64 n, \
const DeviceMemory<std::complex<double>> &a, int lda, \
DeviceMemory<std::complex<double>> *x, int incx) override; \
bool DoBlasGemm(Stream *stream, blas::Transpose transa, \
blas::Transpose transb, uint64 m, uint64 n, uint64 k, \
float alpha, const DeviceMemory<Eigen::half> &a, int lda, \
const DeviceMemory<Eigen::half> &b, int ldb, float beta, \
DeviceMemory<Eigen::half> *c, int ldc) override; \
bool DoBlasGemm(Stream *stream, blas::Transpose transa, \
blas::Transpose transb, uint64 m, uint64 n, uint64 k, \
float alpha, const DeviceMemory<float> &a, int lda, \
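For reference, the host-side sketch below (an illustration written for this page, not code from the commit) spells out what the new half-precision overload computes in the no-transpose, column-major case: c = alpha * a * b + beta * c, with products accumulated in float, matching the "uses float precision internally" note above. The function name ReferenceHalfGemm is hypothetical.

// Illustration only: reference semantics of the half-precision GEMM above,
// no-transpose case, column-major storage, float accumulation.
#include <vector>
#include "third_party/eigen3/Eigen/Core"

void ReferenceHalfGemm(int m, int n, int k, float alpha,
                       const std::vector<Eigen::half> &a, int lda,
                       const std::vector<Eigen::half> &b, int ldb, float beta,
                       std::vector<Eigen::half> *c, int ldc) {
  for (int col = 0; col < n; ++col) {
    for (int row = 0; row < m; ++row) {
      float acc = 0.0f;
      for (int i = 0; i < k; ++i) {
        // Column-major indexing: element (row, i) of a lives at row + i * lda,
        // which is why lda must be >= m when a is not transposed.
        acc += static_cast<float>(a[row + i * lda]) *
               static_cast<float>(b[i + col * ldb]);
      }
      (*c)[row + col * ldc] = static_cast<Eigen::half>(
          alpha * acc + beta * static_cast<float>((*c)[row + col * ldc]));
    }
  }
}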
55 changes: 55 additions & 0 deletions tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -254,6 +254,10 @@ PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasCgemmBatched)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasZgemmBatched)
CUBLAS_BLAS_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP)

#if CUDA_VERSION >= 7050
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmEx)
#endif

} // namespace dynload

static string ToString(cublasStatus_t status) {
@@ -1624,6 +1628,57 @@ bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
lda, CUDAComplex(CUDAMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasGemm(
    Stream *stream, blas::Transpose transa,
    blas::Transpose transb, uint64 m, uint64 n, uint64 k,
    float alpha, const DeviceMemory<Eigen::half> &a, int lda,
    const DeviceMemory<Eigen::half> &b, int ldb, float beta,
    DeviceMemory<Eigen::half> *c, int ldc) {
#if CUDA_VERSION >= 7050
  VLOG(1) << port::Printf(
      "doing cuBLAS SGEMM: at=%d bt=%d m=%llu n=%llu "
      "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
      "c=%p ldc=%d",
      static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
      a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
  if (transa == blas::Transpose::kNoTranspose) {
    if (lda < static_cast<int64>(m)) {
      LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
                      "precondition violation";
    }
  } else {
    if (lda < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
                   << ") (transpose case); precondition violation";
    }
  }
  if (transb == blas::Transpose::kNoTranspose) {
    if (ldb < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
                   << ") (no transpose case); precondition violation";
    }
  } else {
    if (ldb < static_cast<int64>(n)) {
      LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
                      "precondition violation";
    }
  }
  // TODO(sesse): Consider supporting the Hgemm interface, which uses half
  // calculations internally (faster on newer devices, such as Pascal and TX1,
  // but less precise).
  return DoBlasInternal(
      dynload::cublasSgemmEx, stream, true /* = pointer_mode_host */,
      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
      CUDAMemory(a), CUBLAS_DATA_HALF, lda,
      CUDAMemory(b), CUBLAS_DATA_HALF, ldb,
      &beta,
      CUDAMemoryMutable(c), CUBLAS_DATA_HALF, ldc);
#else
  LOG(FATAL) << "fp16 sgemm is not implemented in this cuBLAS version "
             << "(need at least CUDA 7.5)";
#endif
}

bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
float alpha, const DeviceMemory<float> &a, int lda,
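Outside of StreamExecutor's dynload and DoBlasInternal plumbing, the cuBLAS call this path reduces to looks roughly like the sketch below. This is an illustration, not code from the commit: the handle and device pointers are assumed to exist, and it uses the newer CUDA_R_16F spelling of the storage-type enum, where the commit uses the CUDA 7.5-era name CUBLAS_DATA_HALF for the same thing.

// Illustration only: a direct cublasSgemmEx call with half-precision storage
// and float alpha/beta passed as host pointers (pointer_mode_host above).
#include <cublas_v2.h>
#include <cuda_fp16.h>

cublasStatus_t HalfStorageGemm(cublasHandle_t handle, int m, int n, int k,
                               const __half *d_a, int lda,
                               const __half *d_b, int ldb,
                               __half *d_c, int ldc) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  // The computation runs in float; only the matrix storage is half.
  return cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                       d_a, CUDA_R_16F, lda,
                       d_b, CUDA_R_16F, ldb, &beta,
                       d_c, CUDA_R_16F, ldc);
}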
18 changes: 18 additions & 0 deletions tensorflow/stream_executor/stream.cc
@@ -2572,6 +2572,24 @@ Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
lda, x, incx);
}

Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
                             uint64 m, uint64 n, uint64 k, float alpha,
                             const DeviceMemory<Eigen::half> &a, int lda,
                             const DeviceMemory<Eigen::half> &b, int ldb,
                             float beta,
                             DeviceMemory<Eigen::half> *c, int ldc) {
  VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
            PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
            PARAM(beta), PARAM(c), PARAM(ldc));

  ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
               const DeviceMemory<Eigen::half> &, int,
               const DeviceMemory<Eigen::half> &, int,
               float, DeviceMemory<Eigen::half> *, int> impl;
  return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
              alpha, a, lda, b, ldb, beta, c, ldc);
}

Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, float alpha,
const DeviceMemory<float> &a, int lda,
5 changes: 5 additions & 0 deletions tensorflow/stream_executor/stream.h
@@ -922,6 +922,11 @@ class Stream {
DeviceMemory<std::complex<double>> *x, int incx);

// See BlasSupport::DoBlasGemm.
Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, float alpha,
const DeviceMemory<Eigen::half> &a, int lda,
const DeviceMemory<Eigen::half> &b, int ldb, float beta,
DeviceMemory<Eigen::half> *c, int ldc);
Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, float alpha,
const DeviceMemory<float> &a, int lda,
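Taken together, stream.h and stream.cc expose the new overload to callers. A minimal caller-side sketch follows (again an illustration, not part of the commit; the executor setup, helper names, and buffer sizes are hypothetical). Note that alpha and beta stay float even though the matrices are Eigen::half.

// Illustration only: exercising the new half-precision ThenBlasGemm overload.
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "third_party/eigen3/Eigen/Core"

namespace se = perftools::gputools;

bool RunHalfGemm(se::StreamExecutor *executor, int m, int n, int k) {
  se::Stream stream(executor);
  stream.Init();

  // Column-major device buffers: a is m-by-k, b is k-by-n, c is m-by-n.
  // (Contents are left uninitialized here for brevity.)
  se::DeviceMemory<Eigen::half> a = executor->AllocateArray<Eigen::half>(m * k);
  se::DeviceMemory<Eigen::half> b = executor->AllocateArray<Eigen::half>(k * n);
  se::DeviceMemory<Eigen::half> c = executor->AllocateArray<Eigen::half>(m * n);

  // Leading dimensions equal the matrix heights, satisfying the
  // lda >= m / ldb >= k preconditions checked in cuda_blas.cc.
  stream.ThenBlasGemm(se::blas::Transpose::kNoTranspose,
                      se::blas::Transpose::kNoTranspose, m, n, k,
                      /*alpha=*/1.0f, a, /*lda=*/m, b, /*ldb=*/k,
                      /*beta=*/0.0f, &c, /*ldc=*/m);
  stream.BlockHostUntilDone();
  return stream.ok();
}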
