diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 0126adfa16e..e3951df8944 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -94,6 +94,18 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "batch_mat_mul_impl", + srcs = ["batch_mat_mul_impl.cc"], + hdrs = ["batch_mat_mul_impl.h"], + deps = [ + ":backend", + ":prelu_impl", + ":transpose_impl", + ":util", + ], +) + tfjs_cc_library( name = "interpolate_bilinear_impl", srcs = ["interpolate_bilinear_impl.cc"], @@ -182,6 +194,7 @@ tfjs_cc_library( ":NonMaxSuppressionV5", ":NotEqual", ":PadV2", + ":Pow", ":Prelu", ":Relu", ":Relu6", @@ -192,6 +205,7 @@ tfjs_cc_library( ":Sub", ":Tile", ":Transpose", + ":_FusedMatMul", ], ) @@ -256,10 +270,21 @@ tfjs_cc_library( hdrs = ["kernels/BatchMatMul.h"], deps = [ ":backend", + ":batch_mat_mul_impl", ":util", ], ) +tfjs_cc_library( + name = "_FusedMatMul", + srcs = ["kernels/_FusedMatMul.cc"], + hdrs = ["kernels/_FusedMatMul.h"], + deps = [ + ":backend", + ":batch_mat_mul_impl", + ], +) + tfjs_unit_test( name = "BatchMatMul_test", srcs = ["kernels/BatchMatMul_test.cc"], @@ -268,6 +293,14 @@ tfjs_unit_test( ], ) +tfjs_unit_test( + name = "_FusedMatMul_test", + srcs = ["kernels/_FusedMatMul_test.cc"], + deps = [ + ":_FusedMatMul", + ], +) + tfjs_cc_library( name = "ClipByValue", srcs = ["kernels/ClipByValue.cc"], @@ -291,6 +324,7 @@ tfjs_cc_library( srcs = ["kernels/Conv2D.cc"], hdrs = ["kernels/Conv2D.h"], deps = [ + ":backend", ":conv2d_impl", ], ) @@ -318,6 +352,7 @@ tfjs_cc_library( srcs = ["kernels/DepthwiseConv2dNative.cc"], hdrs = ["kernels/DepthwiseConv2dNative.h"], deps = [ + ":backend", ":conv2d_impl", ], ) @@ -390,6 +425,7 @@ tfjs_cc_library( srcs = ["kernels/FusedDepthwiseConv2D.cc"], hdrs = ["kernels/FusedDepthwiseConv2D.h"], deps = [ + ":backend", ":conv2d_impl", ], ) @@ -593,6 +629,16 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Pow", + srcs = ["kernels/Pow.cc"], + deps = [ + ":backend", + ":binary", + ":util", + ], +) + tfjs_cc_library( name = "Prelu", srcs = ["kernels/Prelu.cc"], diff --git a/tfjs-backend-wasm/src/cc/backend.h b/tfjs-backend-wasm/src/cc/backend.h index f730bd120ea..581066ec8ef 100644 --- a/tfjs-backend-wasm/src/cc/backend.h +++ b/tfjs-backend-wasm/src/cc/backend.h @@ -27,6 +27,9 @@ enum DType { complex64 = 4, }; +// Must match enum in kernels/types.ts. +enum FusableActivation { LINEAR = 0, RELU = 1, RELU6 = 2, PRELU = 3 }; + // Holds the memory offset and the size of a tensor. struct TensorInfo { // Pointer to the bytes where the data is allocated. diff --git a/tfjs-backend-wasm/src/cc/batch_mat_mul_impl.cc b/tfjs-backend-wasm/src/cc/batch_mat_mul_impl.cc new file mode 100644 index 00000000000..3cf5b15e979 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/batch_mat_mul_impl.cc @@ -0,0 +1,297 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ===========================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/prelu_impl.h" +#include "src/cc/util.h" + +#include "src/cc/batch_mat_mul_impl.h" + +const size_t kBlockSize = 48; + +namespace { +// We use std::tuple as the cache key as it implements the compare operator +// needed for std::map. +typedef std::tuple OperatorCacheKey; + +// The operator cache maps the weights id to the xnn_operator_t instantiated for +// this set of weights. +std::map operator_cache; + +std::unordered_map> + b_operator_cache_key_map; + +std::unordered_map> + bias_operator_cache_key_map; + +void erase_from_cache(const size_t tensor_id, + std::unordered_map>& + operator_cache_key_map) { + auto operator_cache_keys_idx = operator_cache_key_map.find(tensor_id); + if (operator_cache_keys_idx != operator_cache_key_map.end()) { + std::vector& operator_cache_keys = + operator_cache_keys_idx->second; + for (auto& operator_cache_key : operator_cache_keys) { + auto operator_cache_key_idx = operator_cache.find(operator_cache_key); + if (operator_cache_key_idx != operator_cache.end()) { + auto& cached_op = operator_cache_key_idx->second; + xnn_delete_operator(cached_op); + tfjs::backend::xnn_operator_count--; + + operator_cache.erase(operator_cache_key); + } + } + operator_cache_key_map.erase(tensor_id); + } +} + +void delete_xnn_operators(const size_t tensor_id) { + erase_from_cache(tensor_id, b_operator_cache_key_map); + erase_from_cache(tensor_id, bias_operator_cache_key_map); +} + +void associate_tensor_with_key( + const size_t tensor_id, const OperatorCacheKey& cache_key, + std::unordered_map>& + operator_cache_key_map) { + auto cache_keys_idx = operator_cache_key_map.find(tensor_id); + if (cache_keys_idx == operator_cache_key_map.end()) { + std::vector cache_keys = {cache_key}; + operator_cache_key_map.emplace(tensor_id, std::move(cache_keys)); + tfjs::backend::register_disposal_callback(tensor_id, *delete_xnn_operators); + + } else { + auto& cache_keys = operator_cache_key_map.at(tensor_id); + cache_keys.emplace_back(cache_key); + } +} + +void xnn_matmul(const size_t a_id, const size_t* a_shape_ptr, + const size_t a_shape_len, const size_t b_id, + const size_t* b_shape_ptr, const size_t b_shape_len, + const size_t out_id, const size_t bias_id, + const float output_min, const float output_max, + const size_t clamp_method) { + auto& a_info = tfjs::backend::get_tensor_info(a_id); + auto& b_info = tfjs::backend::get_tensor_info(b_id); + auto& out_info = tfjs::backend::get_tensor_info_out(out_id); + + const float* a_buf = a_info.f32(); + const float* b_buf = b_info.f32(); + float* out_buf = out_info.f32_write(); + + const float* bias_buf = nullptr; + if (bias_id != 0) { + bias_buf = tfjs::backend::get_tensor_info_out(bias_id).f32(); + } + + xnn_operator_t fully_connected_op = nullptr; + + OperatorCacheKey cache_key = {b_id, bias_id, clamp_method}; + + // We assume b is the weights and cache the xnn operator on it. + auto operator_cache_idx = operator_cache.find(cache_key); + if (operator_cache_idx == operator_cache.end()) { + const size_t input_channels = b_shape_ptr[1]; + const size_t output_channels = b_shape_ptr[2]; + const size_t input_stride = input_channels; + const size_t output_stride = output_channels; + + // XNNPack expects b to already be transposed. 
TensorFlow.js doesn't do this + // automatically so we have to tell XNNPack to do the transposing. + const uint32_t flags = XNN_FLAG_TRANSPOSE_WEIGHTS; + xnn_status status = xnn_create_fully_connected_nc_f32( + input_channels, output_channels, input_stride, output_stride, b_buf, + bias_buf, output_min, output_max, flags, &fully_connected_op); + if (status != xnn_status_success) { + tfjs::util::warn( + "XNN status for xnn_create_fully_connected_nc_f32 is not successful. " + "Got status %d. Use -c dbg to see XNN logs.", + status); + return; + } + + operator_cache.insert({cache_key, fully_connected_op}); + + associate_tensor_with_key(b_id, cache_key, b_operator_cache_key_map); + if (bias_id != 0) { + associate_tensor_with_key(bias_id, cache_key, + bias_operator_cache_key_map); + } + + tfjs::backend::xnn_operator_count++; + } else { + fully_connected_op = operator_cache_idx->second; + } + + const size_t batch_size = a_shape_ptr[1]; + xnn_status status = + xnn_setup_fully_connected_nc_f32(fully_connected_op, batch_size, a_buf, + out_buf, nullptr /* thread pool */); + if (status != xnn_status_success) { + tfjs::util::warn( + "XNN status for xnn_setup_fully_connected_nc_f32 is not successful. " + "Got status %d. Use -c dbg to see XNN logs.", + status); + return; + } + + xnn_run_operator(fully_connected_op, nullptr /* thread pool */); +} + +void slow_batch_matmul(const size_t a_id, const size_t* a_shape_ptr, + const size_t a_shape_len, const size_t b_id, + const size_t* b_shape_ptr, const size_t b_shape_len, + const bool transpose_a, const bool transpose_b, + const size_t out_id, const size_t bias_id, + const float output_min, const float output_max) { + const size_t shared_dim = transpose_a ? a_shape_ptr[1] : a_shape_ptr[2]; + const size_t left_dim = transpose_a ? a_shape_ptr[2] : a_shape_ptr[1]; + const size_t right_dim = transpose_b ? b_shape_ptr[1] : b_shape_ptr[2]; + const size_t batch_dim = a_shape_ptr[0]; + + std::vector a_shape(a_shape_ptr, a_shape_ptr + a_shape_len); + std::vector b_shape(b_shape_ptr, b_shape_ptr + b_shape_len); + const std::vector a_strides = tfjs::util::compute_strides(a_shape); + const std::vector b_strides = tfjs::util::compute_strides(b_shape); + + size_t a_batch = a_strides[0]; + size_t a_outer_step, a_inner_step; + if (transpose_a) { + a_outer_step = 1; + a_inner_step = a_strides[1]; + } else { + a_outer_step = a_strides[1]; + a_inner_step = 1; + } + size_t b_batch = b_strides[0]; + size_t b_outer_step, b_inner_step; + if (transpose_b) { + b_outer_step = b_strides[1]; + b_inner_step = 1; + } else { + b_outer_step = 1; + b_inner_step = b_strides[1]; + } + + auto& a_info = tfjs::backend::get_tensor_info(a_id); + auto& b_info = tfjs::backend::get_tensor_info(b_id); + auto& out_info = tfjs::backend::get_tensor_info_out(out_id); + + const float* a_buf = a_info.f32(); + const float* b_buf = b_info.f32(); + float* out_buf = out_info.f32_write(); + + const float* bias_buf = nullptr; + auto& bias_info = tfjs::backend::get_tensor_info_out(bias_id); + if (bias_id != 0) { + bias_buf = bias_info.f32(); + } + + const size_t size = left_dim * right_dim; + + // Zero out the output buffer because it might have been used before. 
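+  // The blocked loops below accumulate the partial products into this buffer,
+  // adding the (1-D broadcast) bias and clamping to [output_min, output_max].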
+ std::fill(out_buf, out_buf + batch_dim * size, 0); + + for (size_t b = 0; b < batch_dim; ++b) { + for (size_t i0 = 0; i0 < left_dim; i0 += kBlockSize) { + for (size_t j0 = 0; j0 < right_dim; j0 += kBlockSize) { + for (size_t k0 = 0; k0 < shared_dim; k0 += kBlockSize) { + // for when kBlockSize doesn't evenly divide the input + const size_t i_block = std::min(i0 + kBlockSize, left_dim); + const size_t j_block = std::min(j0 + kBlockSize, right_dim); + const size_t k_block = std::min(k0 + kBlockSize, shared_dim); + + for (size_t i = i0; i < i_block; ++i) { + for (size_t j = j0; j < j_block; ++j) { + float sum = 0.0; + + for (size_t k = k0; k < k_block; ++k) { + sum += + a_buf[b * a_batch + i * a_outer_step + k * a_inner_step] * + b_buf[k * b_inner_step + j * b_outer_step + b * b_batch]; + } + size_t innermost_dim = i * right_dim + j; + size_t out_buf_index = b * size + innermost_dim; + float current = out_buf[out_buf_index]; + + // Handles 1D broadcasting. + size_t bias_index = std::min(innermost_dim, bias_info.size - 1); + + out_buf[out_buf_index] = std::max( + std::min(current + sum + bias_buf[bias_index], output_max), + output_min); + } + } + } + } + } + } +} +} // namespace + +namespace tfjs { +namespace wasm { +void fused_batch_mat_mul(const size_t a_id, const size_t* a_shape_ptr, + const size_t a_shape_len, const size_t b_id, + const size_t* b_shape_ptr, const size_t b_shape_len, + const bool transpose_a, const bool transpose_b, + const FusableActivation activation, + const size_t bias_id, const size_t prelu_weights_id, + const size_t out_id) { + FusableActivation clamp_method = activation; + if (activation == FusableActivation::PRELU) { + clamp_method = FusableActivation::LINEAR; + } + + float output_min = -std::numeric_limits::infinity(); + float output_max = std::numeric_limits::infinity(); + + if (activation == FusableActivation::RELU) { + output_min = 0; + } else if (activation == FusableActivation::RELU6) { + output_min = 0; + output_max = 6; + } + + if (!transpose_a && !transpose_b && a_shape_ptr[0] == 1 && + b_shape_ptr[0] == 1) { + xnn_matmul(a_id, a_shape_ptr, a_shape_len, b_id, b_shape_ptr, b_shape_len, + out_id, bias_id, output_min, output_max, clamp_method); + } else { + slow_batch_matmul(a_id, a_shape_ptr, a_shape_len, b_id, b_shape_ptr, + b_shape_len, transpose_a, transpose_b, out_id, bias_id, + output_min, output_max); + } + + auto& out_info = backend::get_tensor_info_out(out_id); + float* out_buf = out_info.f32_write(); + if (activation == FusableActivation::PRELU) { + prelu(out_buf, out_info.size, prelu_weights_id, out_id); + } +} +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/batch_mat_mul_impl.h b/tfjs-backend-wasm/src/cc/batch_mat_mul_impl.h new file mode 100644 index 00000000000..476eeb5d088 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/batch_mat_mul_impl.h @@ -0,0 +1,34 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ===========================================================================*/ + +#ifndef BATCH_MAT_MUL_IMPL_H_ +#define BATCH_MAT_MUL_IMPL_H_ + +#include + +namespace tfjs { +namespace wasm { + +void fused_batch_mat_mul(const size_t a_id, const size_t* a_shape_ptr, + const size_t a_shape_len, const size_t b_id, + const size_t* b_shape_ptr, const size_t b_shape_len, + const bool transpose_a, const bool transpose_b, + const FusableActivation activation, + const size_t bias_id, const size_t prelu_weights_id, + const size_t out_id); + +} // namespace wasm +} // namespace tfjs + +#endif // BATCH_MAT_MUL_IMPL_H_ diff --git a/tfjs-backend-wasm/src/cc/conv2d_impl.h b/tfjs-backend-wasm/src/cc/conv2d_impl.h index 1836362b5aa..9d796f8fd94 100644 --- a/tfjs-backend-wasm/src/cc/conv2d_impl.h +++ b/tfjs-backend-wasm/src/cc/conv2d_impl.h @@ -17,12 +17,11 @@ #include +#include "src/cc/backend.h" + namespace tfjs { namespace wasm { -// Must match enum in FusedConv2D.ts. -enum FusableActivation { LINEAR = 0, RELU = 1, RELU6 = 2, PRELU = 3 }; - void conv2d(const size_t x_id, const size_t batch_size, const size_t input_height, const size_t input_width, const size_t filter_id, const size_t filter_height, diff --git a/tfjs-backend-wasm/src/cc/kernels/BatchMatMul.cc b/tfjs-backend-wasm/src/cc/kernels/BatchMatMul.cc index e0417dbacaf..32e05d3e54b 100644 --- a/tfjs-backend-wasm/src/cc/kernels/BatchMatMul.cc +++ b/tfjs-backend-wasm/src/cc/kernels/BatchMatMul.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 Google Inc. All Rights Reserved. +/* Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -16,179 +16,12 @@ #include #endif -#include -#include #include -#include -#include -#include -#include #include "src/cc/backend.h" -#include "src/cc/util.h" - +#include "src/cc/batch_mat_mul_impl.h" #include "src/cc/kernels/BatchMatMul.h" -const size_t kBlockSize = 48; - -namespace { -// We use std::tuple as the cache key as it implements the compare operator -// needed for std::map. -typedef std::tuple OperatorCacheKey; - -// The operator cache maps the weights id to the xnn_operator_t instantiated for -// this set of weights. -std::map operator_cache; - -void delete_xnn_operator(const size_t weights_id) { - xnn_operator_t fully_connected_op = operator_cache.at(weights_id); - xnn_delete_operator(fully_connected_op); - tfjs::backend::xnn_operator_count--; - - operator_cache.erase(weights_id); -} - -void xnn_matmul(const size_t a_id, const size_t* a_shape_ptr, - const size_t a_shape_len, const size_t b_id, - const size_t* b_shape_ptr, const size_t b_shape_len, - const size_t out_id) { - auto& a_info = tfjs::backend::get_tensor_info(a_id); - auto& b_info = tfjs::backend::get_tensor_info(b_id); - auto& out_info = tfjs::backend::get_tensor_info_out(out_id); - - const float* a_buf = a_info.f32(); - const float* b_buf = b_info.f32(); - float* out_buf = out_info.f32_write(); - - xnn_operator_t fully_connected_op = nullptr; - - OperatorCacheKey cache_key = {b_id}; - - // We assume b is the weights and cache the xnn operator on it. 
- auto operator_cache_idx = operator_cache.find(cache_key); - if (operator_cache_idx == operator_cache.end()) { - const size_t input_channels = b_shape_ptr[1]; - const size_t output_channels = b_shape_ptr[2]; - const size_t input_stride = input_channels; - const size_t output_stride = output_channels; - const float* bias = nullptr; - - const float output_min = -std::numeric_limits::infinity(); - const float output_max = std::numeric_limits::infinity(); - - // XNNPack expects b to already be transposed. TensorFlow.js doesn't do this - // automatically so we have to tell XNNPack to do the transposing. - const uint32_t flags = XNN_FLAG_TRANSPOSE_WEIGHTS; - xnn_status status = xnn_create_fully_connected_nc_f32( - input_channels, output_channels, input_stride, output_stride, b_buf, - bias, output_min, output_max, flags, &fully_connected_op); - if (status != xnn_status_success) { - tfjs::util::warn( - "XNN status for xnn_create_fully_connected_nc_f32 is not successful. " - "Got status %d. Use -c dbg to see XNN logs.", - status); - return; - } - - operator_cache.insert({cache_key, fully_connected_op}); - - tfjs::backend::register_disposal_callback(b_id, *delete_xnn_operator); - - tfjs::backend::xnn_operator_count++; - } else { - fully_connected_op = operator_cache_idx->second; - } - - const size_t batch_size = a_shape_ptr[1]; - xnn_status status = - xnn_setup_fully_connected_nc_f32(fully_connected_op, batch_size, a_buf, - out_buf, nullptr /* thread pool */); - if (status != xnn_status_success) { - tfjs::util::warn( - "XNN status for xnn_setup_fully_connected_nc_f32 is not successful. " - "Got status %d. Use -c dbg to see XNN logs.", - status); - return; - } - - xnn_run_operator(fully_connected_op, nullptr /* thread pool */); -} - -void slow_batch_matmul(const size_t a_id, const size_t* a_shape_ptr, - const size_t a_shape_len, const size_t b_id, - const size_t* b_shape_ptr, const size_t b_shape_len, - const bool transpose_a, const bool transpose_b, - const size_t out_id) { - const size_t shared_dim = transpose_a ? a_shape_ptr[1] : a_shape_ptr[2]; - const size_t left_dim = transpose_a ? a_shape_ptr[2] : a_shape_ptr[1]; - const size_t right_dim = transpose_b ? b_shape_ptr[1] : b_shape_ptr[2]; - const size_t batch_dim = a_shape_ptr[0]; - - std::vector a_shape(a_shape_ptr, a_shape_ptr + a_shape_len); - std::vector b_shape(b_shape_ptr, b_shape_ptr + b_shape_len); - const std::vector a_strides = tfjs::util::compute_strides(a_shape); - const std::vector b_strides = tfjs::util::compute_strides(b_shape); - - size_t a_batch = a_strides[0]; - size_t a_outer_step, a_inner_step; - if (transpose_a) { - a_outer_step = 1; - a_inner_step = a_strides[1]; - } else { - a_outer_step = a_strides[1]; - a_inner_step = 1; - } - size_t b_batch = b_strides[0]; - size_t b_outer_step, b_inner_step; - if (transpose_b) { - b_outer_step = b_strides[1]; - b_inner_step = 1; - } else { - b_outer_step = 1; - b_inner_step = b_strides[1]; - } - - auto& a_info = tfjs::backend::get_tensor_info(a_id); - auto& b_info = tfjs::backend::get_tensor_info(b_id); - auto& out_info = tfjs::backend::get_tensor_info_out(out_id); - - const float* a_buf = a_info.f32(); - const float* b_buf = b_info.f32(); - float* out_buf = out_info.f32_write(); - - const size_t size = left_dim * right_dim; - - // Zero out the output buffer because it might have been used before. 
- std::fill(out_buf, out_buf + batch_dim * size, 0); - - for (size_t b = 0; b < batch_dim; ++b) { - for (size_t i0 = 0; i0 < left_dim; i0 += kBlockSize) { - for (size_t j0 = 0; j0 < right_dim; j0 += kBlockSize) { - for (size_t k0 = 0; k0 < shared_dim; k0 += kBlockSize) { - // for when kBlockSize doesn't evenly divide the input - const size_t i_block = std::min(i0 + kBlockSize, left_dim); - const size_t j_block = std::min(j0 + kBlockSize, right_dim); - const size_t k_block = std::min(k0 + kBlockSize, shared_dim); - - for (size_t i = i0; i < i_block; ++i) { - for (size_t j = j0; j < j_block; ++j) { - float sum = 0.0; - - for (size_t k = k0; k < k_block; ++k) { - sum += - a_buf[b * a_batch + i * a_outer_step + k * a_inner_step] * - b_buf[k * b_inner_step + j * b_outer_step + b * b_batch]; - } - out_buf[b * size + (i * right_dim + j)] += sum; - } - } - } - } - } - } -} -} // namespace - namespace tfjs { namespace wasm { // We use C-style API to interface with Javascript. @@ -202,14 +35,12 @@ void BatchMatMul(const size_t a_id, const size_t* a_shape_ptr, const size_t* b_shape_ptr, const size_t b_shape_len, const bool transpose_a, const bool transpose_b, const size_t out_id) { - if (!transpose_a && !transpose_b && a_shape_ptr[0] == 1 && - b_shape_ptr[0] == 1) { - xnn_matmul(a_id, a_shape_ptr, a_shape_len, b_id, b_shape_ptr, b_shape_len, - out_id); - } else { - slow_batch_matmul(a_id, a_shape_ptr, a_shape_len, b_id, b_shape_ptr, - b_shape_len, transpose_a, transpose_b, out_id); - } + const size_t bias_id = 0; + const size_t prelu_weights_id = 0; + const FusableActivation activation = FusableActivation::LINEAR; + tfjs::wasm::fused_batch_mat_mul( + a_id, a_shape_ptr, a_shape_len, b_id, b_shape_ptr, b_shape_len, + transpose_a, transpose_b, activation, bias_id, prelu_weights_id, out_id); } } // extern "C" diff --git a/tfjs-backend-wasm/src/cc/kernels/BatchMatMul.h b/tfjs-backend-wasm/src/cc/kernels/BatchMatMul.h index 0b9570f3839..bc5b5728c4b 100644 --- a/tfjs-backend-wasm/src/cc/kernels/BatchMatMul.h +++ b/tfjs-backend-wasm/src/cc/kernels/BatchMatMul.h @@ -1,4 +1,4 @@ -/* Copyright 2019 Google Inc. All Rights Reserved. +/* Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at diff --git a/tfjs-backend-wasm/src/cc/kernels/BatchMatMul_test.cc b/tfjs-backend-wasm/src/cc/kernels/BatchMatMul_test.cc index a5bb986dd3e..1ac04d3d4a9 100644 --- a/tfjs-backend-wasm/src/cc/kernels/BatchMatMul_test.cc +++ b/tfjs-backend-wasm/src/cc/kernels/BatchMatMul_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 Google Inc. All Rights Reserved. +/* Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at diff --git a/tfjs-backend-wasm/src/cc/kernels/Conv2D.cc b/tfjs-backend-wasm/src/cc/kernels/Conv2D.cc index 552c9e93521..a3b9a973cd7 100644 --- a/tfjs-backend-wasm/src/cc/kernels/Conv2D.cc +++ b/tfjs-backend-wasm/src/cc/kernels/Conv2D.cc @@ -20,7 +20,7 @@ #include -#include "src/cc/conv2d_impl.h" +#include "src/cc/backend.h" namespace tfjs { namespace wasm { diff --git a/tfjs-backend-wasm/src/cc/kernels/DepthwiseConv2dNative.cc b/tfjs-backend-wasm/src/cc/kernels/DepthwiseConv2dNative.cc index 621690197d4..a9f3cc9adb6 100644 --- a/tfjs-backend-wasm/src/cc/kernels/DepthwiseConv2dNative.cc +++ b/tfjs-backend-wasm/src/cc/kernels/DepthwiseConv2dNative.cc @@ -18,6 +18,7 @@ #include +#include "src/cc/backend.h" #include "src/cc/conv2d_impl.h" #include "src/cc/kernels/DepthwiseConv2dNative.h" diff --git a/tfjs-backend-wasm/src/cc/kernels/FusedConv2D_test.cc b/tfjs-backend-wasm/src/cc/kernels/FusedConv2D_test.cc index 6a75526979d..1b73d35d514 100644 --- a/tfjs-backend-wasm/src/cc/kernels/FusedConv2D_test.cc +++ b/tfjs-backend-wasm/src/cc/kernels/FusedConv2D_test.cc @@ -18,7 +18,6 @@ #include #include "src/cc/backend.h" -#include "src/cc/conv2d_impl.h" #include "src/cc/kernels/FusedConv2D.h" #include "src/cc/util.h" @@ -76,8 +75,7 @@ TEST(FUSEDCONV2D, xnn_operator_lifetime) { const size_t input_channels = 1; const size_t output_channels = 1; - const tfjs::wasm::FusableActivation activation = - tfjs::wasm::FusableActivation::LINEAR; + const FusableActivation activation = FusableActivation::LINEAR; const size_t prelu_weights_id = 0; tfjs::wasm::FusedConv2D( @@ -173,8 +171,7 @@ TEST(FUSEDCONV2D, xnn_operator_lifetime) { // One new XNN operator should be created for the next call to conv2d with a // different activation. - const tfjs::wasm::FusableActivation activation2 = - tfjs::wasm::FusableActivation::RELU6; + const FusableActivation activation2 = FusableActivation::RELU6; tfjs::wasm::FusedConv2D( x1_id, batch_size, input_height, input_width, weights1_id, filter_height, filter_width, bias1_id, pad_top1, pad_right, pad_bottom1, pad_left, diff --git a/tfjs-backend-wasm/src/cc/kernels/FusedDepthwiseConv2D.cc b/tfjs-backend-wasm/src/cc/kernels/FusedDepthwiseConv2D.cc index 5d8b55453a6..e796e9b8043 100644 --- a/tfjs-backend-wasm/src/cc/kernels/FusedDepthwiseConv2D.cc +++ b/tfjs-backend-wasm/src/cc/kernels/FusedDepthwiseConv2D.cc @@ -20,6 +20,7 @@ #include "src/cc/kernels/FusedDepthwiseConv2D.h" +#include "src/cc/backend.h" #include "src/cc/conv2d_impl.h" namespace tfjs { diff --git a/tfjs-backend-wasm/src/cc/kernels/FusedDepthwiseConv2D.h b/tfjs-backend-wasm/src/cc/kernels/FusedDepthwiseConv2D.h index 34bacbbe4e2..4a876d349da 100644 --- a/tfjs-backend-wasm/src/cc/kernels/FusedDepthwiseConv2D.h +++ b/tfjs-backend-wasm/src/cc/kernels/FusedDepthwiseConv2D.h @@ -17,7 +17,7 @@ #include -#include "src/cc/conv2d_impl.h" +#include "src/cc/backend.h" namespace tfjs { diff --git a/tfjs-backend-wasm/src/cc/kernels/FusedDepthwiseConv2D_test.cc b/tfjs-backend-wasm/src/cc/kernels/FusedDepthwiseConv2D_test.cc index 5f362c9126b..37afa3c7b12 100644 --- a/tfjs-backend-wasm/src/cc/kernels/FusedDepthwiseConv2D_test.cc +++ b/tfjs-backend-wasm/src/cc/kernels/FusedDepthwiseConv2D_test.cc @@ -76,8 +76,7 @@ TEST(FUSEDDEPTHWISECONV2D, xnn_operator_lifetime) { const size_t input_channels = 1; const size_t output_channels = 1; - const tfjs::wasm::FusableActivation activation = - tfjs::wasm::FusableActivation::LINEAR; + const FusableActivation activation = FusableActivation::LINEAR; 
tfjs::wasm::FusedDepthwiseConv2D( x0_id, batch_size, input_height, input_width, weights0_id, filter_height, @@ -89,8 +88,7 @@ TEST(FUSEDDEPTHWISECONV2D, xnn_operator_lifetime) { // One new xnn operator should be created for second call to conv2d with no // bias and prelu activation. - const tfjs::wasm::FusableActivation prelu_activation = - tfjs::wasm::FusableActivation::PRELU; + const FusableActivation prelu_activation = FusableActivation::PRELU; const size_t prelu_weights_id = 8; const size_t prelu_size = 8; @@ -190,8 +188,7 @@ TEST(FUSEDDEPTHWISECONV2D, xnn_operator_lifetime) { // One new XNN operator should be created for the next call to conv2d with a // different activation. - const tfjs::wasm::FusableActivation activation2 = - tfjs::wasm::FusableActivation::RELU6; + const FusableActivation activation2 = FusableActivation::RELU6; tfjs::wasm::FusedDepthwiseConv2D( x1_id, batch_size, input_height, input_width, weights1_id, filter_height, filter_width, bias1_id, pad_top1, pad_right, pad_bottom1, pad_left, diff --git a/tfjs-backend-wasm/src/cc/kernels/Pow.cc b/tfjs-backend-wasm/src/cc/kernels/Pow.cc new file mode 100644 index 00000000000..646fe355fa1 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Pow.cc @@ -0,0 +1,61 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "src/cc/binary.h" +#include "src/cc/util.h" + +namespace { +template +inline T power(T a, T b) { + return pow(a, b); +} +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Pow(const size_t a_id, const size_t* a_shape_ptr, const size_t a_shape_len, + const size_t b_id, const size_t* b_shape_ptr, const size_t b_shape_len, + const DType dtype, const size_t out_id) { + switch (dtype) { + case DType::float32: + binary_f32(a_id, b_id, out_id, power); + break; + case DType::int32: + binary_i32(a_id, b_id, out_id, power); + break; + case DType::boolean: + binary_bool(a_id, b_id, out_id, power); + break; + default: + util::warn("Pow for tensor ids %d and %d failed. Unknown dtype %d", a_id, + b_id, dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/_FusedMatMul.cc b/tfjs-backend-wasm/src/cc/kernels/_FusedMatMul.cc new file mode 100644 index 00000000000..5eef8ff02c1 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/_FusedMatMul.cc @@ -0,0 +1,46 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "src/cc/backend.h" +#include "src/cc/batch_mat_mul_impl.h" +#include "src/cc/kernels/_FusedMatMul.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void _FusedMatMul(const size_t a_id, const size_t* a_shape_ptr, + const size_t a_shape_len, const size_t b_id, + const size_t* b_shape_ptr, const size_t b_shape_len, + const bool transpose_a, const bool transpose_b, + const FusableActivation activation, const size_t bias_id, + const size_t prelu_weights_id, const size_t out_id) { + tfjs::wasm::fused_batch_mat_mul( + a_id, a_shape_ptr, a_shape_len, b_id, b_shape_ptr, b_shape_len, + transpose_a, transpose_b, activation, bias_id, prelu_weights_id, out_id); +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/_FusedMatMul.h b/tfjs-backend-wasm/src/cc/kernels/_FusedMatMul.h new file mode 100644 index 00000000000..98501aabd55 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/_FusedMatMul.h @@ -0,0 +1,35 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifndef KERNELS__FUSEDMATMUL_H_ +#define KERNELS__FUSEDMATMUL_H_ + +#include + +namespace tfjs { +namespace wasm { +extern "C" { + +void _FusedMatMul(const size_t a_id, const size_t* a_shape_ptr, + const size_t a_shape_len, const size_t b_id, + const size_t* b_shape_ptr, const size_t b_shape_len, + const bool transpose_a, const bool transpose_b, + const FusableActivation activation, const size_t bias_id, + const size_t prelu_weights_id, const size_t out_id); +} + +} // namespace wasm +} // namespace tfjs + +#endif // KERNELS__FUSEDMATMUL_H_ diff --git a/tfjs-backend-wasm/src/cc/kernels/_FusedMatMul_test.cc b/tfjs-backend-wasm/src/cc/kernels/_FusedMatMul_test.cc new file mode 100644 index 00000000000..96ec6f90181 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/_FusedMatMul_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ===========================================================================*/
+
+#include
+
+#include
+#include
+
+#include "src/cc/backend.h"
+#include "src/cc/kernels/_FusedMatMul.h"
+
+TEST(_FUSED_MATMUL, xnn_operator_lifetime) {
+  tfjs::wasm::init();
+
+  ASSERT_EQ(0, tfjs::backend::num_tensors());
+
+  size_t a0_id = 1;
+  size_t a1_id = 2;
+  size_t size = 2;
+  float a_values[2] = {1, 2};
+  std::vector<size_t> a_shape = {1, 2, 1};
+  size_t* a_shape_ptr = a_shape.data();
+
+  size_t b0_id = 3;
+  size_t b1_id = 4;
+  float b_values[2] = {1, 2};
+  std::vector<size_t> b_shape = {1, 1, 2};
+  size_t* b_shape_ptr = b_shape.data();
+
+  size_t out_id = 5;
+  float out_values[2] = {0, 0};
+
+  tfjs::wasm::register_tensor(a0_id, size, a_values);
+  tfjs::wasm::register_tensor(a1_id, size, a_values);
+  tfjs::wasm::register_tensor(b0_id, size, b_values);
+  tfjs::wasm::register_tensor(b1_id, size, b_values);
+  tfjs::wasm::register_tensor(out_id, size, out_values);
+
+  ASSERT_EQ(5, tfjs::backend::num_tensors());
+  ASSERT_EQ(0, tfjs::backend::xnn_operator_count);
+
+  const FusableActivation activation = FusableActivation::LINEAR;
+  size_t bias_id = 0;
+  size_t prelu_weights_id = 0;
+
+  // One new xnn_operator should be created for the first call to
+  // _FusedMatMul with no bias.
+  tfjs::wasm::_FusedMatMul(a0_id, a_shape_ptr, a_shape.size(), b0_id,
+                           b_shape_ptr, b_shape.size(), false /* transpose_a */,
+                           false /* transpose_b */, activation, bias_id,
+                           prelu_weights_id, out_id);
+  ASSERT_EQ(1, tfjs::backend::xnn_operator_count);
+
+  // No new xnn_operators should be created for the second call to
+  // _FusedMatMul with the same arguments.
+  tfjs::wasm::_FusedMatMul(a0_id, a_shape_ptr, a_shape.size(), b0_id,
+                           b_shape_ptr, b_shape.size(), false /* transpose_a */,
+                           false /* transpose_b */, activation, bias_id,
+                           prelu_weights_id, out_id);
+  ASSERT_EQ(1, tfjs::backend::xnn_operator_count);
+
+  // No new xnn_operators should be created for calling _FusedMatMul
+  // with a new a.
+  tfjs::wasm::_FusedMatMul(a1_id, a_shape_ptr, a_shape.size(), b0_id,
+                           b_shape_ptr, b_shape.size(), false /* transpose_a */,
+                           false /* transpose_b */, activation, bias_id,
+                           prelu_weights_id, out_id);
+  ASSERT_EQ(1, tfjs::backend::xnn_operator_count);
+
+  // One new xnn_operator should be created for calling _FusedMatMul
+  // with a new b.
+  tfjs::wasm::_FusedMatMul(a0_id, a_shape_ptr, a_shape.size(), b1_id,
+                           b_shape_ptr, b_shape.size(), false /* transpose_a */,
+                           false /* transpose_b */, activation, bias_id,
+                           prelu_weights_id, out_id);
+  ASSERT_EQ(2, tfjs::backend::xnn_operator_count);
+
+  // No new xnn_operators should be created for the next call to
+  // _FusedMatMul with the same b.
+ tfjs::wasm::_FusedMatMul(a0_id, a_shape_ptr, a_shape.size(), b1_id, + b_shape_ptr, b_shape.size(), false /* transpose_a */, + false /* transpose_b */, activation, bias_id, + prelu_weights_id, out_id); + ASSERT_EQ(2, tfjs::backend::xnn_operator_count); + + const size_t bias1_id = 6; + const size_t bias_size = 1; + float bias_values[bias_size] = {1}; + tfjs::wasm::register_tensor(bias1_id, bias_size, bias_values); + // One new xnn_operator should be created for calling _FusedMatMul with a + // new bias. + tfjs::wasm::_FusedMatMul(a0_id, a_shape_ptr, a_shape.size(), b1_id, + b_shape_ptr, b_shape.size(), false /* transpose_a */, + false /* transpose_b */, activation, bias1_id, + prelu_weights_id, out_id); + ASSERT_EQ(3, tfjs::backend::xnn_operator_count); + + // One new xnn_operator should be created for calling _FusedMatMul with a + // different activation. + const FusableActivation activation2 = FusableActivation::RELU; + tfjs::wasm::_FusedMatMul(a0_id, a_shape_ptr, a_shape.size(), b1_id, + b_shape_ptr, b_shape.size(), false /* transpose_a */, + false /* transpose_b */, activation2, bias1_id, + prelu_weights_id, out_id); + ASSERT_EQ(4, tfjs::backend::xnn_operator_count); + + // Disposing a's should not remove xnn operators. + tfjs::wasm::dispose_data(a0_id); + tfjs::wasm::dispose_data(a1_id); + ASSERT_EQ(4, tfjs::backend::xnn_operator_count); + + // Disposing the second bias should remove the xnn_operators it's associated + // with. + tfjs::wasm::dispose_data(bias1_id); + ASSERT_EQ(2, tfjs::backend::xnn_operator_count); + + // Disposing b's should remove xnn operators. + tfjs::wasm::dispose_data(b0_id); + ASSERT_EQ(1, tfjs::backend::xnn_operator_count); + + tfjs::wasm::dispose_data(b1_id); + ASSERT_EQ(0, tfjs::backend::xnn_operator_count); + + tfjs::wasm::dispose(); +} diff --git a/tfjs-backend-wasm/src/kernels/FusedConv2D.ts b/tfjs-backend-wasm/src/kernels/FusedConv2D.ts index cf3c70e9fe3..7fd7054c2fd 100644 --- a/tfjs-backend-wasm/src/kernels/FusedConv2D.ts +++ b/tfjs-backend-wasm/src/kernels/FusedConv2D.ts @@ -19,6 +19,8 @@ import {backend_util, KernelFunc, NamedTensorInfoMap, registerKernel, TensorInfo import {BackendWasm} from '../backend_wasm'; +import {FusableActivation} from './types'; + interface FusedConv2DInputs extends NamedTensorInfoMap { x: TensorInfo; filter: TensorInfo; @@ -61,14 +63,6 @@ function setup(backend: BackendWasm) { ]); } -// Must match enum in conv2d_impl.h. -enum FusableActivation { - linear = 0, - relu = 1, - relu6 = 2, - prelu = 3 -} - function fusedConv2d(args: { inputs: FusedConv2DInputs, backend: BackendWasm, diff --git a/tfjs-backend-wasm/src/kernels/FusedDepthwiseConv2D.ts b/tfjs-backend-wasm/src/kernels/FusedDepthwiseConv2D.ts index ba482015887..03d28d6e172 100644 --- a/tfjs-backend-wasm/src/kernels/FusedDepthwiseConv2D.ts +++ b/tfjs-backend-wasm/src/kernels/FusedDepthwiseConv2D.ts @@ -18,6 +18,7 @@ import {backend_util, KernelFunc, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; +import {FusableActivation} from './types'; interface FusedDepthwiseConv2DInputs extends NamedTensorInfoMap { x: TensorInfo; @@ -62,14 +63,6 @@ function setup(backend: BackendWasm) { ]); } -// Must match enum in conv2d_impl.h. 
-enum FusableActivation { - linear = 0, - relu = 1, - relu6 = 2, - prelu = 3 -} - function fusedDepthwiseConv2d(args: { inputs: FusedDepthwiseConv2DInputs, backend: BackendWasm, diff --git a/tfjs-backend-wasm/src/kernels/Pow.ts b/tfjs-backend-wasm/src/kernels/Pow.ts new file mode 100644 index 00000000000..fae7489c2e6 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Pow.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerBinaryKernel} from './binary_kernel'; +const supportsFullBroadcast = false; +registerBinaryKernel('Pow', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/_FusedMatMul.ts b/tfjs-backend-wasm/src/kernels/_FusedMatMul.ts new file mode 100644 index 00000000000..8d5c4130751 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/_FusedMatMul.ts @@ -0,0 +1,121 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +import {FusableActivation} from './types'; + +interface FusedMatMulInputs extends NamedTensorInfoMap { + a: TensorInfo; + b: TensorInfo; + bias?: TensorInfo; + preluActivationWeights?: TensorInfo; +} + +interface FusedMatMulAttrs extends NamedAttrMap { + transposeA: boolean; + transposeB: boolean; + activation: FusableActivation; +} + +let wasmFusedMatMul: ( + aId: number, aShape: Uint8Array, aShapeSize: number, bId: number, + bShape: Uint8Array, bShapeSize: number, transposeA: boolean, + transposeB: boolean, activation: number, biasId: number, + preluActivationWeightsId: number, outId: number) => void; + +function setup(backend: BackendWasm) { + wasmFusedMatMul = backend.wasm.cwrap('_FusedMatMul', null /* void */, [ + 'number', // a_id + 'array', // a_shape + 'number', // a_shape.length + 'number', // b_id + 'array', // b_shape + 'number', // b_shape.length + 'number', // transpose_a + 'number', // transpose_b + 'number', // activation + 'number', // biasId + 'number', // preluActivationWeightsId + 'number' // out_id + ]); +} + +function fusedBatchMatMul(args: { + inputs: FusedMatMulInputs, + backend: BackendWasm, + attrs: FusedMatMulAttrs +}) { + const {inputs, backend, attrs} = args; + const {a, b, bias, preluActivationWeights} = inputs; + + if (a.dtype !== 'float32' || b.dtype !== 'float32') { + throw new Error( + `_FusedMatMul for non non-float32 tensors not yet supported.`); + } + + const {transposeA, transposeB, activation} = attrs; + const aId = backend.dataIdMap.get(a.dataId).id; + const bId = backend.dataIdMap.get(b.dataId).id; + + let biasId = 0; + if (bias != null) { + const biasData = backend.dataIdMap.get(bias.dataId); + if (biasData.shape.length !== 1) { + throw new Error( + `_FusedMatMul only supports rank-1 bias but got ` + + `rank ${biasData.shape.length}.`); + } + biasId = biasData.id; + } + const preluActivationWeightsId = preluActivationWeights == null ? + 0 : + backend.dataIdMap.get(preluActivationWeights.dataId).id; + const fusedActivation = + FusableActivation[activation as {} as keyof typeof FusableActivation]; + if (fusedActivation == null) { + throw new Error( + `${activation} activation not yet supported for FusedConv2D ` + + `in the wasm backend.`); + } + + const leftDim = transposeA ? a.shape[2] : a.shape[1]; + const rightDim = transposeB ? 
b.shape[1] : b.shape[2]; + const batchDim = a.shape[0]; + + const out = backend.makeOutput([batchDim, leftDim, rightDim], a.dtype); + const outId = backend.dataIdMap.get(out.dataId).id; + + const aShapeBytes = new Uint8Array(new Int32Array(a.shape).buffer); + const bShapeBytes = new Uint8Array(new Int32Array(b.shape).buffer); + + wasmFusedMatMul( + aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, + transposeA, transposeB, fusedActivation, biasId, preluActivationWeightsId, + outId); + + return out; +} + +registerKernel({ + kernelName: '_FusedMatMul', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: fusedBatchMatMul +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 157a4cbd8fa..56cec9e79f0 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -18,6 +18,7 @@ // We explicitly import the modular kernels so they get registered in the // global registry when we compile the library. A modular build would replace // the contents of this file and import only the kernels that are needed. +import './_FusedMatMul'; import './Abs'; import './Add'; import './AddN'; @@ -56,6 +57,7 @@ import './NonMaxSuppressionV3'; import './NonMaxSuppressionV5'; import './NotEqual'; import './PadV2'; +import './Pow'; import './Prelu'; import './Relu'; import './Relu6'; diff --git a/tfjs-backend-wasm/src/kernels/types.ts b/tfjs-backend-wasm/src/kernels/types.ts index 2a778e4aa30..d13b2e0b871 100644 --- a/tfjs-backend-wasm/src/kernels/types.ts +++ b/tfjs-backend-wasm/src/kernels/types.ts @@ -23,3 +23,11 @@ export enum CppDType { string = 3, complex64 = 4 } + +// Must match enum in cc/fusable_activations.h. +export enum FusableActivation { + linear = 0, + relu = 1, + relu6 = 2, + prelu = 3 +} diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 8c4d289b7e1..01d7c8bbc81 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -35,6 +35,15 @@ const TEST_FILTERS: TestFilter[] = [ ] }, {include: 'softmax'}, + { + include: 'pow', + excludes: [ + 'gradient', // zerosLike not defined yet. + 'broadcasting same rank Tensors different shape', // Broadcasting along + // inner dims not + // supported yet. + ] + }, { include: 'add ', excludes: [ @@ -53,11 +62,12 @@ const TEST_FILTERS: TestFilter[] = [ { include: 'relu', excludes: [ - 'derivative', // Not yet implemented. - 'gradient', // Not yet implemented. - 'valueAndGradients', // Not yet implemented. - 'fused matmul', // Not yet implemented. - 'broadcasted bias', // Not yet implemented. + 'derivative', // Not yet implemented. + 'gradient', // Not yet implemented. + 'valueAndGradients', // Not yet implemented. + 'broadcasted bias', // Not yet implemented. + 'fused A x B with 2d bias' // Fused matMul with 2D bias not yet + // supported. ] }, { @@ -79,12 +89,15 @@ const TEST_FILTERS: TestFilter[] = [ { include: 'matmul ', excludes: [ - 'valueAndGradients', // Gradients not defined yet - 'gradient', // Gradients not defined yet - 'fused matmul', // Fused kernels aren't ready yet - 'zero in its shape', // Zero in shapes aren't supported yet - 'matmul followed by mul', // mul not supported yet - 'upcasts', // Upcasting not supported yet. 
+ 'valueAndGradients', // Gradients not defined yet + 'gradient', // Gradients not defined yet + 'zero in its shape', // Zero in shapes aren't supported yet + 'matmul followed by mul', // mul not supported yet + 'upcasts', // Upcasting not supported yet. + 'fused A x B with elu', // Fused matMul with elu activation not yet + // supported. + 'fused A x B with 2d bias', // Fused matMul with 2D bias not yet + // supported. ] }, { diff --git a/tfjs-core/src/ops/binary_ops.ts b/tfjs-core/src/ops/binary_ops.ts index d705fe3fd31..9aea464bea4 100644 --- a/tfjs-core/src/ops/binary_ops.ts +++ b/tfjs-core/src/ops/binary_ops.ts @@ -20,7 +20,7 @@ import {Tensor} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import {makeTypesMatch} from '../tensor_util'; import {convertToTensor} from '../tensor_util_env'; -import {TensorLike, upcastType} from '../types'; +import {TensorLike} from '../types'; import * as util from '../util'; import * as broadcast_util from './broadcast_util'; import {where} from './logical_ops'; @@ -247,14 +247,14 @@ function subStrict_(a: T|TensorLike, b: T|TensorLike): T { * @param exp The exponent `tf.Tensor` to pow element-wise. */ /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ -function pow_(base: T|TensorLike, exp: Tensor|TensorLike): T { - const $base = convertToTensor(base, 'base', 'pow'); - const $exp = convertToTensor(exp, 'exp', 'pow'); +function pow_( + base: Tensor|TensorLike, exp: Tensor|TensorLike): T { + let $base = convertToTensor(base, 'base', 'pow'); + let $exp = convertToTensor(exp, 'exp', 'pow'); + [$base, $exp] = makeTypesMatch($base, $exp); const outShape = broadcast_util.assertAndGetBroadcastShape($base.shape, $exp.shape); - base = $base.cast(upcastType($base.dtype, $exp.dtype)); - exp = $exp.cast(upcastType($base.dtype, $exp.dtype)); const grad = (dy: Tensor, saved: Tensor[]) => { const [$base, $exp, y] = saved; const derBase = () => { @@ -276,13 +276,17 @@ function pow_(base: T|TensorLike, exp: Tensor|TensorLike): T { } return res.reshape($exp.shape); }; - return {$base: derBase, $exp: derExp}; + return {a: derBase, b: derExp}; }; + + const attrs = {}; + const inputsToSave = [$base, $exp]; + const outputsToSave = [true]; return ENGINE.runKernelFunc((backend, save) => { const y = backend.pow($base, $exp); save([$base, $exp, y]); return y; - }, {$base, $exp}, grad) as T; + }, {a: $base, b: $exp}, grad, 'Pow', attrs, inputsToSave, outputsToSave) as T; } /** diff --git a/tfjs-core/src/ops/fused_ops.ts b/tfjs-core/src/ops/fused_ops.ts index 821b35406f5..0977cbae2d4 100644 --- a/tfjs-core/src/ops/fused_ops.ts +++ b/tfjs-core/src/ops/fused_ops.ts @@ -186,66 +186,70 @@ function fusedMatMul_({ let biasGradient = {}; if (bias != null) { - biasGradient = {$bias: () => getFusedBiasGradient($bias, dyActivation)}; + biasGradient = {bias: () => getFusedBiasGradient($bias, dyActivation)}; } if (!transposeA && !transposeB) { return Object.assign( { - $a: () => dyActivation.matMul(b3D as Tensor3D, false, true), - $b: () => a3D.matMul(dyActivation, true, false) + a: () => dyActivation.matMul(b3D as Tensor3D, false, true), + b: () => a3D.matMul(dyActivation, true, false) }, biasGradient); } else if (!transposeA && transposeB) { return Object.assign( { - $a: () => dyActivation.matMul(b3D as Tensor3D, false, false), - $b: () => dyActivation.matMul(a3D as Tensor3D, true, false) + a: () => dyActivation.matMul(b3D as Tensor3D, false, false), + b: () => dyActivation.matMul(a3D as Tensor3D, true, false) }, biasGradient); } else if (transposeA && 
!transposeB) { return Object.assign( { - $a: () => b3D.matMul(dyActivation, false, true), - $b: () => a3D.matMul(dyActivation, false, false) + a: () => b3D.matMul(dyActivation, false, true), + b: () => a3D.matMul(dyActivation, false, false) }, biasGradient); } else { return Object.assign( { - $a: () => b3D.matMul(dyActivation, true, true), - $b: () => dyActivation.matMul(a3D as Tensor3D, true, true) + a: () => b3D.matMul(dyActivation, true, true), + b: () => dyActivation.matMul(a3D as Tensor3D, true, true) }, biasGradient); } }; - const inputs: { - $a: Tensor, - $b: Tensor, - $bias?: Tensor, - $preluActivationWeights?: Tensor - } = {$a: a3D, $b: b3D}; + const inputs: + {a: Tensor, b: Tensor, + bias?: Tensor, + preluActivationWeights?: Tensor} = {a: a3D, b: b3D}; if (bias != null) { - inputs.$bias = $bias; + inputs.bias = $bias; } if (preluActivationWeights != null) { - inputs.$preluActivationWeights = $preluActivationWeights; + inputs.preluActivationWeights = $preluActivationWeights; } - const res = ENGINE.runKernelFunc((backend, save) => { - const y = backend.fusedBatchMatMul({ - a: a3D, - b: b3D, - transposeA, - transposeB, - bias: $bias, - activation, - preluActivationWeights: $preluActivationWeights - }); - save([a3D, b3D, y]); - return y; - }, inputs, grad); + const inputsToSave = [a3D, b3D]; + const outputsToSave = [true]; + + const res = ENGINE.runKernelFunc( + (backend, save) => { + const y = backend.fusedBatchMatMul({ + a: a3D, + b: b3D, + transposeA, + transposeB, + bias: $bias, + activation, + preluActivationWeights: $preluActivationWeights + }); + save([a3D, b3D, y]); + return y; + }, + inputs, grad, '_FusedMatMul', {transposeA, transposeB, activation}, + inputsToSave, outputsToSave); return res.reshape(outShape) as T; } diff --git a/tfjs-core/src/ops/fused_test.ts b/tfjs-core/src/ops/fused_test.ts index e8c6f9c91dc..bdffb90d6f2 100644 --- a/tfjs-core/src/ops/fused_test.ts +++ b/tfjs-core/src/ops/fused_test.ts @@ -20,7 +20,7 @@ import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose} from '../test_util'; describeWithFlags('fused matmul', ALL_ENVS, () => { - it('A x B', async () => { + it('fused A x B', async () => { const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); @@ -30,7 +30,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expectArraysClose(await c.data(), [0, 8, -3, 20]); }); - it('A x B with relu', async () => { + it('fused A x B with relu', async () => { const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); const transposeA = false; @@ -43,7 +43,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expectArraysClose(await c.data(), [0, 8, 0, 20]); }); - it('A x B with elu', async () => { + it('fused A x B with elu', async () => { const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); const transposeA = false; @@ -56,7 +56,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expectArraysClose(await c.data(), [0, 8, -0.9502, 20]); }); - it('A x B with relu6', async () => { + it('fused A x B with relu6', async () => { const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); const transposeA = false; @@ -69,7 +69,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expectArraysClose(await c.data(), [0, 6, 0, 6]); }); - it('A x B with prelu', async () => { + it('fused A x B with prelu', async () => { const a = tf.tensor2d([1, 2, 3, 4, 
5, 6], [2, 3]); const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); const alpha = tf.tensor2d([0.5, 0.5], [1, 2]); @@ -90,7 +90,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expectArraysClose(await c.data(), [0, 8, -1.5, 20]); }); - it('A x B with relu transpose', async () => { + it('fused A x B with relu transpose', async () => { const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [2, 3]); const transposeA = false; @@ -103,7 +103,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expectArraysClose(await c.data(), [0, 9, 0, 24]); }); - it('A x B with relu and bias', async () => { + it('fused A x B with 2d bias and relu', async () => { const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); const c = tf.tensor2d([1, 1, 1, 1], [2, 2]); @@ -117,7 +117,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expectArraysClose(await d.data(), [1, 9, 0, 21]); }); - it('A x B with relu and broadcasted bias', async () => { + it('fused A x B with relu and broadcasted bias', async () => { const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); const c = tf.tensor1d([1, 1]); @@ -132,7 +132,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expectArraysClose(await d.data(), [1, 9, 0, 21]); }); - it('A x B with elu and broadcasted bias', async () => { + it('fused A x B with elu and broadcasted bias', async () => { const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); const c = tf.tensor1d([1, 1]); @@ -147,7 +147,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expectArraysClose(await d.data(), [1, 9, -0.8647, 21]); }); - it('A x B with relu and broadcasted bias different rank', async () => { + it('fused A x B with relu and broadcasted bias different rank', async () => { const a = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [2, 2, 3]); const b = tf.tensor3d([0, 1, -3, 2, 2, 1, 0, 1, -3, 2, 2, 1], [2, 3, 2]); const c = tf.tensor2d([1, 2], [1, 2]); @@ -162,7 +162,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expectArraysClose(await d.data(), [2, 6, 0, 18, 0, 30, 0, 42]); }); - it('A x B with bias only', async () => { + it('fused A x B with 2d bias only', async () => { const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); const c = tf.tensor2d([1, 1, 1, 1], [2, 2]); @@ -176,7 +176,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expectArraysClose(await d.data(), [1, 9, -2, 21]); }); - it('A x B with relu gradient', async () => { + it('fused A x B with relu gradient', async () => { const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]); @@ -224,7 +224,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expect(fusedDb.shape).toEqual(b.shape); }); - it('A x B with relu bias gradient', async () => { + it('fused A x B with relu bias gradient', async () => { const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); const c = tf.tensor2d([1, 1, 1, 1], [2, 2]); @@ -252,7 +252,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expectArraysClose(await dc.array(), await fusedDc.array()); }); - it('A x B with relu bias gradient transpose', async () => { + it('fused A x B with relu bias gradient transpose', async () => { const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [3, 2]); 
const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); const c = tf.tensor2d([1, 1, 1, 1], [2, 2]); @@ -280,7 +280,7 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { expectArraysClose(await dc.array(), await fusedDc.array()); }); - it('A x B with relu and broadcasted bias gradient', async () => { + it('fused A x B with relu and broadcasted bias gradient', async () => { const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); const c = tf.tensor2d([[1]]); @@ -1021,7 +1021,7 @@ describeWithFlags('fused conv2d', ALL_ENVS, () => { expectArraysClose(await dbiasFused.array(), await dbias.array()); }); - it('fused matmul with relu6', async () => { + it('fused matmul with relu6 and gradients', async () => { const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
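For context, here is a minimal sketch of the user-facing call that exercises the new wasm `_FusedMatMul` path added in this change. The wasm backend import and the `tf.setBackend('wasm')` setup are assumptions of this sketch and are not part of the diff above; the fused op call itself mirrors the tests in fused_test.ts.

import * as tf from '@tensorflow/tfjs-core';
import '@tensorflow/tfjs-backend-wasm';  // registers the 'wasm' backend

async function run() {
  await tf.setBackend('wasm');

  const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
  const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
  const bias = tf.tensor1d([1, 1]);  // rank-1 bias, as required by the wasm kernel

  // Dispatches to the '_FusedMatMul' kernel registered in this change, which
  // fuses the matmul, the bias add, and the relu clamp into one wasm call.
  const c = tf.fused.matMul(
      {a, b, transposeA: false, transposeB: false, bias, activation: 'relu'});
  console.log(await c.data());  // [1, 9, 0, 21], as in the broadcasted-bias test
}

run();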