Add GPU support for float16 batched matmul #18436
Note that //tensorflow/python/kernel_tests:batch_matmul_op_test previously passed only because it does not specify force_gpu=True and falls back to the CPU.
yzhwang left a comment
Thanks a lot for the change, Ben! I think this will also be useful for enabling autotune for batch_matmul.
Please make sure that, with CUDA >= 9.1, the float16 batched matmul on the GPU really uses cublasGemmBatchedEx(). (Ideally we would want to know the CUDA/cuDNN version at the Python level so that we can write code that adapts to different versions, but I don't know how to do that or whether TensorFlow supports it.) Other than that, only one other comment (see below).
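For reference, here is a minimal NumPy sketch (an illustration, not the actual TensorFlow test) of the shape and precision contract a float16 batched matmul kernel is expected to satisfy: `[batch, m, k] x [batch, k, n] -> [batch, m, n]`, with results close to a float32 reference. This roughly mirrors the kind of check `batch_matmul_op_test` performs; the shapes and tolerance here are illustrative choices, not taken from the test itself.

```python
import numpy as np

# Batched float16 matmul: [batch, m, k] x [batch, k, n] -> [batch, m, n].
# np.matmul treats the leading dimension as the batch dimension.
a = np.random.rand(4, 2, 3).astype(np.float16)
b = np.random.rand(4, 3, 5).astype(np.float16)

out = np.matmul(a, b)  # float16 batched product

# Float32 reference, computed per batch element, then cast back.
ref = np.stack(
    [a[i].astype(np.float32) @ b[i].astype(np.float32) for i in range(a.shape[0])]
).astype(np.float16)

assert out.shape == (4, 2, 5)
# Loose tolerance: float16 has ~3 decimal digits of precision.
assert np.allclose(out.astype(np.float32), ref.astype(np.float32), atol=1e-2)
```

A GPU kernel routed through cublasGemmBatchedEx() should produce output satisfying the same contract, up to float16 rounding.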