[WASM] Add Conv2D that calls into xnn pack. #2283
Changes from all commits
ceb3cf4
9f1e25f
cc48262
1cd5f15
0c58a0f
bab1ccf
98e5c09
81541e7
51f681c
6c4b744
23200a3
18e3398
0a6092b
6359ee4
6a37bfe
7b26420
63a32da
f656a20
15707da
24381c9
20c05a1
a8efa2b
86fef5b
5134c81
4e57f7a
80fdf3b
ccca603
e278402
2e4472b
f95090d
New file (+153 lines): the Conv2D kernel that calls into XNNPACK.

```cpp
/* Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <xnnpack.h>
#include <array>
#include <cmath>
#include <limits>
#include <map>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "src/cc/backend.h"
#include "src/cc/kernels/Conv2D.h"
#include "src/cc/util.h"

namespace {
// These integer values are keys to creating the conv2d operator. We use
// std::array instead of a vanilla array as it implements the comparison
// operators needed for std::map.
typedef std::array<int, 15> OperatorCacheKey;

// The operator cache maps the cache key to the xnn_operator_t instantiated for
// this set of arguments to the xnn_operator.
std::map<OperatorCacheKey, xnn_operator_t> operator_cache;

// Maps a filter id to a list of operator cache keys that this filter belongs
// to.
std::unordered_map<int, std::vector<OperatorCacheKey>>
    filter_operator_cache_key_map;

void delete_xnn_operators(int filter_id) {
  std::vector<OperatorCacheKey> operator_cache_keys =
      filter_operator_cache_key_map[filter_id];
  for (auto& operator_cache_key : operator_cache_keys) {
    auto& conv2d_op = operator_cache[operator_cache_key];
    // Release the XNNPACK operator before dropping it from the cache.
    xnn_delete_operator(conv2d_op);
    operator_cache.erase(operator_cache_key);
    tfjs::backend::xnn_operator_count--;
  }
  filter_operator_cache_key_map.erase(filter_id);
}
}  // namespace

namespace tfjs {
namespace wasm {
// We use the C-style API to interface with JavaScript.
extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void Conv2D(const int x_id, const int batch_size, const int input_height,
            const int input_width, const int filter_id,
            const int filter_height, const int filter_width, int pad_top,
            int pad_right, int pad_bottom, int pad_left,
            const int is_same_pad, const int dilation_height,
            const int dilation_width, const int stride_height,
            const int stride_width, const int input_channels,
            const int output_channels, const int out_id) {
  auto& x_info = backend::get_tensor_info(x_id);
  auto& filter_info = backend::get_tensor_info(filter_id);
  auto& out_info = backend::get_tensor_info(out_id);

  const float* x_buf = reinterpret_cast<float*>(x_info.memory_offset);
  const float* filter_buf = reinterpret_cast<float*>(filter_info.memory_offset);
  float* out_buf = reinterpret_cast<float*>(out_info.memory_offset);

  xnn_operator_t conv2d_op = nullptr;

  int flags = 0;
  if (is_same_pad) {
    pad_top = 0, pad_right = 0, pad_bottom = 0, pad_left = 0;
    flags = XNN_FLAG_TENSORFLOW_SAME_PADDING;
  }

  const int groups = 1;

  OperatorCacheKey cache_key = {pad_top, pad_right, pad_bottom,
                                pad_left, filter_height, filter_width,
                                stride_height, stride_width, dilation_height,
                                dilation_width, groups, input_channels,
                                output_channels, filter_id, flags};

  auto operator_cache_idx = operator_cache.find(cache_key);
  if (operator_cache_idx == operator_cache.end()) {
    float output_min = -std::numeric_limits<float>::infinity();
    float output_max = std::numeric_limits<float>::infinity();

    const float* bias_buf = nullptr;
```
Collaborator: Does TF.js support fused Conv2D+Bias? Adding bias separately is inefficient.
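If a fused bias were added, the wiring would be small: the XNNPACK call in this PR already accepts a bias pointer, so the kernel would only need an extra bias tensor id and would pass that buffer instead of `nullptr`. The sketch below is illustrative only; the `bias_id` argument is hypothetical and not part of this PR, and it assumes the same surrounding variables as the Conv2D kernel above.

```cpp
// Hypothetical sketch (not in this PR): wiring a fused bias into the same
// XNNPACK call. `bias_id` is an assumed extra kernel argument; all other
// variables are the ones defined in the Conv2D kernel above, and this would
// replace the bias_buf / xnn_create_convolution2d_nhwc_f32 portion.
const float* bias_buf = nullptr;
if (bias_id >= 0) {
  auto& bias_info = backend::get_tensor_info(bias_id);
  bias_buf = reinterpret_cast<float*>(bias_info.memory_offset);
}
xnn_status status = xnn_create_convolution2d_nhwc_f32(
    pad_top, pad_right, pad_bottom, pad_left, filter_height, filter_width,
    stride_height, stride_width, dilation_height, dilation_width, groups,
    input_channels /* group_input_channels */,
    output_channels /* group_output_channels */,
    input_channels /* input_pixel_stride */,
    output_channels /* output_pixel_stride */, filter_buf,
    bias_buf /* fused bias instead of nullptr */, output_min, output_max,
    flags, &conv2d_op);
```

The kernel as submitted keeps `bias_buf = nullptr` and continues as follows.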
```cpp
    xnn_status status = xnn_create_convolution2d_nhwc_f32(
        pad_top, pad_right, pad_bottom, pad_left, filter_height, filter_width,
        stride_height, stride_width, dilation_height, dilation_width, groups,
        input_channels /* group_input_channels */,
        output_channels /* group_output_channels */,
        input_channels /* input_pixel_stride */,
        output_channels /* output_pixel_stride */, filter_buf, bias_buf,
        output_min, output_max, flags, &conv2d_op);
    if (status != xnn_status_success) {
      util::warn(
          "XNN status for xnn_create_convolution2d_nhwc_f32 is not successful. "
          "Got status %d. Use -c dbg to see XNN logs.",
          status);
    }

    operator_cache.emplace(cache_key, conv2d_op);

    auto cache_keys_idx = filter_operator_cache_key_map.find(filter_id);
    if (cache_keys_idx == filter_operator_cache_key_map.end()) {
      std::vector<OperatorCacheKey> cache_keys = {cache_key};
      filter_operator_cache_key_map.emplace(filter_id, std::move(cache_keys));
      backend::register_disposal_callback(filter_id, *delete_xnn_operators);
    } else {
      auto& cache_keys = filter_operator_cache_key_map.at(filter_id);
      cache_keys.emplace_back(cache_key);
    }

    tfjs::backend::xnn_operator_count++;
  } else {
    conv2d_op = operator_cache_idx->second;
  }

  xnn_status status = xnn_setup_convolution2d_nhwc_f32(
      conv2d_op, batch_size, input_height, input_width, x_buf, out_buf,
      nullptr /* thread pool */);
  if (status != xnn_status_success) {
    util::warn(
        "XNN status for xnn_setup_convolution2d_nhwc_f32 is not successful. "
        "Got status %d. Use -c dbg to see XNN logs.",
        status);
  }

  xnn_run_operator(conv2d_op, nullptr /* thread pool */);
}

}  // extern "C"
}  // namespace wasm
}  // namespace tfjs
```
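As an aside on the cache-key design described in the anonymous namespace above: `std::array` works as a `std::map` key because it defines lexicographic comparison, which a raw C array does not. A minimal standalone illustration, not part of the PR:

```cpp
// Standalone illustration: std::array provides lexicographic operator<,
// so it can be used directly as the key type of a std::map.
#include <array>
#include <cstdio>
#include <map>

int main() {
  using Key = std::array<int, 3>;
  std::map<Key, const char*> cache;
  cache[{1, 0, 0}] = "op for pad_top=1";
  cache[{0, 1, 0}] = "op for pad_bottom=1";
  // A lookup with an equal-valued key hits the same entry.
  Key k = {1, 0, 0};
  std::printf("%s\n", cache.at(k));  // prints: op for pad_top=1
  return 0;
}
```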
New file (+34 lines): the Conv2D kernel header.

```cpp
/* Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * ===========================================================================*/

#ifndef KERNELS_CONV2D_H_
#define KERNELS_CONV2D_H_

namespace tfjs {

namespace wasm {
extern "C" {
void Conv2D(const int x_id, const int batch_size, const int input_height,
            const int input_width, const int filter_id,
            const int filter_height, const int filter_width, int pad_top,
            int pad_right, int pad_bottom, int pad_left,
            const int is_same_pad, const int dilation_height,
            const int dilation_width, const int stride_height,
            const int stride_width, const int input_channels,
            const int output_channels, const int out_id);
}

}  // namespace wasm
}  // namespace tfjs

#endif  // KERNELS_CONV2D_H_
```
New file (+138 lines): unit test covering the lifetime of the cached XNNPACK operators.

```cpp
/* Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * ===========================================================================*/

#include <gtest/gtest.h>

#include "src/cc/backend.h"
#include "src/cc/kernels/Conv2D.h"

TEST(CONV2D, xnn_operator_lifetime) {
  tfjs::wasm::init();

  ASSERT_EQ(0, tfjs::backend::num_tensors());

  const int x0_id = 0;
  const int x1_id = 1;
  const int size = 8;
  float x_values[size] = {1, 2, 3, 4, 5, 6, 7, 8};

  const int weights0_id = 2;
  const int weights1_id = 3;
  const int weights_size = 8;
  float weights_values[weights_size] = {1, 2, 3, 4, 5, 6, 7, 8};

  const int out_id = 4;
  const int out_size = 12;
  float out_values[out_size] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

  tfjs::wasm::register_tensor(x0_id, size, x_values);
  tfjs::wasm::register_tensor(x1_id, size, x_values);
  tfjs::wasm::register_tensor(weights0_id, weights_size, weights_values);
  tfjs::wasm::register_tensor(weights1_id, weights_size, weights_values);
  tfjs::wasm::register_tensor(out_id, out_size, out_values);

  ASSERT_EQ(5, tfjs::backend::num_tensors());
  ASSERT_EQ(0, tfjs::backend::xnn_operator_count);

  // One xnn_operator should be created for the first call to conv2d.
  const int batch_size = 1;
  const int input_height = 4;
  const int input_width = 2;
  const int filter_height = 4;
  const int filter_width = 2;
  const int pad_top0 = 1;
  const int pad_right = 0;
  const int pad_bottom0 = 0;
  const int pad_left = 0;
  const int is_same_pad0 = 0;
  const int dilation_height = 1;
  const int dilation_width = 1;
  const int stride_height = 1;
  const int stride_width = 1;
  const int input_channels = 1;
  const int output_channels = 1;
  tfjs::wasm::Conv2D(x0_id, batch_size, input_height, input_width, weights0_id,
                     filter_height, filter_width, pad_top0, pad_right,
                     pad_bottom0, pad_left, is_same_pad0, dilation_height,
                     dilation_width, stride_height, stride_width,
                     input_channels, output_channels, out_id);
  ASSERT_EQ(1, tfjs::backend::xnn_operator_count);

  // No new xnn_operator should be created for a second call to conv2d with
  // the same arguments.
  tfjs::wasm::Conv2D(x0_id, batch_size, input_height, input_width, weights0_id,
                     filter_height, filter_width, pad_top0, pad_right,
                     pad_bottom0, pad_left, is_same_pad0, dilation_height,
                     dilation_width, stride_height, stride_width,
                     input_channels, output_channels, out_id);
  ASSERT_EQ(1, tfjs::backend::xnn_operator_count);

  // No new xnn_operator should be created for the next call to conv2d with
  // the same arguments but a different input.
  tfjs::wasm::Conv2D(x1_id, batch_size, input_height, input_width, weights0_id,
                     filter_height, filter_width, pad_top0, pad_right,
                     pad_bottom0, pad_left, is_same_pad0, dilation_height,
                     dilation_width, stride_height, stride_width,
                     input_channels, output_channels, out_id);
  ASSERT_EQ(1, tfjs::backend::xnn_operator_count);

  // One new xnn_operator should be created for the next call to conv2d with
  // the same weights but different arguments.
  const int pad_top1 = 0;
  const int pad_bottom1 = 1;
  tfjs::wasm::Conv2D(x0_id, batch_size, input_height, input_width, weights0_id,
                     filter_height, filter_width, pad_top1, pad_right,
                     pad_bottom1, pad_left, is_same_pad0, dilation_height,
                     dilation_width, stride_height, stride_width,
                     input_channels, output_channels, out_id);
  ASSERT_EQ(2, tfjs::backend::xnn_operator_count);

  // One more xnn_operator should be created for the next call to conv2d with
  // new weights and the same input.
  tfjs::wasm::Conv2D(x0_id, batch_size, input_height, input_width, weights1_id,
                     filter_height, filter_width, pad_top0, pad_right,
                     pad_bottom0, pad_left, is_same_pad0, dilation_height,
                     dilation_width, stride_height, stride_width,
                     input_channels, output_channels, out_id);
  ASSERT_EQ(3, tfjs::backend::xnn_operator_count);

  // One more xnn_operator should be created for the next call to conv2d with
  // 'SAME' padding.
  const int is_same_pad1 = 1;
  tfjs::wasm::Conv2D(x0_id, batch_size, input_height, input_width, weights1_id,
                     filter_height, filter_width, pad_top0, pad_right,
                     pad_bottom0, pad_left, is_same_pad1, dilation_height,
                     dilation_width, stride_height, stride_width,
                     input_channels, output_channels, out_id);
  ASSERT_EQ(4, tfjs::backend::xnn_operator_count);

  // No new xnn_operator should be created for the next call to conv2d with
  // 'SAME' padding and a different input and raw padding values.
  tfjs::wasm::Conv2D(x1_id, batch_size, input_height, input_width, weights1_id,
                     filter_height, filter_width, pad_top1, pad_right,
                     pad_bottom1, pad_left, is_same_pad1, dilation_height,
                     dilation_width, stride_height, stride_width,
                     input_channels, output_channels, out_id);
  ASSERT_EQ(4, tfjs::backend::xnn_operator_count);

  // Disposing the first weights should remove 2 operators.
  tfjs::wasm::dispose_data(weights0_id);
  ASSERT_EQ(2, tfjs::backend::xnn_operator_count);

  // Disposing the second weights should remove the last 2 operators.
  tfjs::wasm::dispose_data(weights1_id);
  ASSERT_EQ(0, tfjs::backend::xnn_operator_count);

  tfjs::wasm::dispose();
}
```
Review comment: Erase filter operator cache key