Merged
Commits (44)
1555ed8  clean (annxingyuan, Feb 4, 2020)
22171d3  shell (annxingyuan, Feb 4, 2020)
99b1e5c  compute lse (annxingyuan, Feb 4, 2020)
3865aac  softmax (annxingyuan, Feb 4, 2020)
971fe5a  add cpu (annxingyuan, Feb 4, 2020)
14100b7  upgrade deps (annxingyuan, Feb 4, 2020)
cb717af  setup (annxingyuan, Feb 4, 2020)
5560bb8  wtf (annxingyuan, Feb 4, 2020)
8a59b0c  use (annxingyuan, Feb 4, 2020)
d5037f0  simplify (annxingyuan, Feb 4, 2020)
6981390  softmax test (annxingyuan, Feb 4, 2020)
6b01d0c  fix (annxingyuan, Feb 5, 2020)
099a594  pass in dim (annxingyuan, Feb 5, 2020)
188bde6  test case (annxingyuan, Feb 5, 2020)
19a91bd  logs (annxingyuan, Feb 5, 2020)
74af251  delete (annxingyuan, Feb 5, 2020)
007edeb  add batch to key (annxingyuan, Feb 5, 2020)
bbcea30  move log statement (annxingyuan, Feb 5, 2020)
932e7a6  remove batch from cache (annxingyuan, Feb 6, 2020)
1244ca3  Merge branch 'master' into softmax (annxingyuan, Feb 6, 2020)
ccebc1e  add note (annxingyuan, Feb 6, 2020)
fcb079a  remove build flags (annxingyuan, Feb 7, 2020)
2daaa4f  remove logs (annxingyuan, Feb 7, 2020)
dd0ba9f  testing (annxingyuan, Feb 7, 2020)
7962041  remove header (annxingyuan, Feb 7, 2020)
4e7fdc8  add neg (annxingyuan, Feb 7, 2020)
869768b  register (annxingyuan, Feb 8, 2020)
3c9dda6  notequal (annxingyuan, Feb 8, 2020)
9c58b33  lint (annxingyuan, Feb 10, 2020)
88d8705  revive spy (annxingyuan, Feb 10, 2020)
648bc6c  save (annxingyuan, Feb 10, 2020)
40d5e82  add neg (annxingyuan, Feb 10, 2020)
e149666  save outputs (annxingyuan, Feb 11, 2020)
8903f08  start (annxingyuan, Feb 11, 2020)
2c9572b  fix (annxingyuan, Feb 11, 2020)
8d9f789  edit (annxingyuan, Feb 13, 2020)
fa76890  Merge branch 'master' into softmax (annxingyuan, Feb 13, 2020)
29a7cdc  remove (annxingyuan, Feb 13, 2020)
2ae0a59  revive (annxingyuan, Feb 13, 2020)
4442f31  revive (annxingyuan, Feb 13, 2020)
c4385c4  build (annxingyuan, Feb 13, 2020)
0024ad6  add h (annxingyuan, Feb 13, 2020)
9de07da  fix test (annxingyuan, Feb 13, 2020)
4827669  save (annxingyuan, Feb 13, 2020)
2 changes: 0 additions & 2 deletions tfjs-backend-wasm/.bazelrc
@@ -13,8 +13,6 @@ build:wasm --cxxopt="-std=c++11"
build:wasm --cxxopt="-fno-rtti"
build:wasm --cxxopt="-fno-exceptions"
build:wasm --cxxopt="-fomit-frame-pointer"
-build:wasm --cxxopt="-ffast-math"
-build:wasm --copt="-ffast-math"

# Disable sandbox environment because emsdk caches files by writing to
# home directory.
5 changes: 3 additions & 2 deletions tfjs-backend-wasm/WORKSPACE
@@ -8,9 +8,9 @@ emsdk_configure(name = "emsdk")

git_repository(
    name = "xnnpack",
-    commit = "3a77ea7bbe30b2411591f2ab15f9c5032a25f688",
+    commit = "7278a95e3cfae6eac73f363c4fda5db53e1b2a87",
    remote = "https://github.com/google/XNNPACK.git",
-    shallow_since = "1577131863 -0800",
+    shallow_since = "1580796377 -0800",
)

# The libraries below are transitive dependencies of XNNPACK that we need to
@@ -64,6 +64,7 @@ http_archive(
http_archive(
    name = "cpuinfo",
    build_file = "@xnnpack//third_party:cpuinfo.BUILD",
+    patches = ["@xnnpack//third_party:cpuinfo.patch"],
    sha256 = "3f2dc1970f397a0e59db72f9fca6ff144b216895c1d606f6c94a507c1e53a025",
    strip_prefix = "cpuinfo-d5e37adf1406cf899d7d9ec1d317c47506ccb970",
    urls = [
4 changes: 2 additions & 2 deletions tfjs-backend-wasm/package.json
@@ -32,15 +32,15 @@
    "path": false
  },
  "peerDependencies": {
-    "@tensorflow/tfjs-core": "1.5.2"
+    "@tensorflow/tfjs-core": "link:../tfjs-core"
  },
  "dependencies": {
    "@types/emscripten": "~0.0.34"
  },
  "devDependencies": {
    "@bazel/bazel": "^0.28.0",
    "@bazel/buildifier": "0.29.0",
-    "@tensorflow/tfjs-core": "1.5.2",
+    "@tensorflow/tfjs-core": "link:../tfjs-core",
    "@types/jasmine": "~2.8.6",
    "clang-format": "~1.2.4",
    "jasmine": "~3.1.0",
40 changes: 40 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD
@@ -177,15 +177,18 @@ tfjs_cc_library(
        ":Min",
        ":Minimum",
        ":Mul",
+        ":Neg",
        ":NonMaxSuppressionV3",
        ":NonMaxSuppressionV5",
+        ":NotEqual",
        ":PadV2",
        ":Prelu",
        ":Relu",
        ":Relu6",
        ":ResizeBilinear",
        ":ScatterNd",
        ":Sigmoid",
+        ":Softmax",
        ":Sub",
        ":Tile",
        ":Transpose",
@@ -457,6 +460,16 @@ tfjs_cc_library(
    ],
)

+tfjs_cc_library(
+    name = "NotEqual",
+    srcs = ["kernels/NotEqual.cc"],
+    deps = [
+        ":backend",
+        ":binary",
+        ":util",
+    ],
+)
+
tfjs_cc_library(
    name = "LogicalAnd",
    srcs = ["kernels/LogicalAnd.cc"],
@@ -542,6 +555,15 @@ tfjs_cc_library(
    ],
)

+tfjs_cc_library(
+    name = "Neg",
+    srcs = ["kernels/Neg.cc"],
+    deps = [
+        ":backend",
+        ":unary",
+    ],
+)
+
tfjs_cc_library(
    name = "NonMaxSuppressionV3",
    srcs = ["kernels/NonMaxSuppressionV3.cc"],
@@ -647,6 +669,16 @@ tfjs_cc_library(
    ],
)

+tfjs_cc_library(
+    name = "Softmax",
+    srcs = ["kernels/Softmax.cc"],
+    hdrs = ["kernels/Softmax.h"],
+    deps = [
+        ":backend",
+        ":unary",
+    ],
+)
+
tfjs_unit_test(
    name = "Sigmoid_test",
    srcs = ["kernels/Sigmoid_test.cc"],
@@ -655,6 +687,14 @@ tfjs_unit_test(
    ],
)

+tfjs_unit_test(
+    name = "Softmax_test",
+    srcs = ["kernels/Softmax_test.cc"],
+    deps = [
+        ":Softmax",
+    ],
+)
+
tfjs_cc_library(
    name = "Square",
    srcs = ["kernels/Square.cc"],
40 changes: 40 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/Neg.cc
@@ -0,0 +1,40 @@
/* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <cmath>

#include "src/cc/backend.h"
#include "src/cc/unary.h"

namespace {
inline float neg(const float val) { return -val; }
} // namespace

namespace tfjs {
namespace wasm {
// We use C-style API to interface with Javascript.
extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void Neg(const int x_id, const int out_id) { unary(x_id, out_id, neg); }

} // extern "C"
} // namespace wasm
} // namespace tfjs
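
For context: the unary helper that Neg delegates to is defined in src/cc/unary.h and is not shown in this diff. Below is a minimal sketch of that pattern, assuming the helper simply resolves both tensor ids through the backend and applies the element-wise function; the name unary_sketch and the size field access are assumptions, not the backend's verified API.

// Illustrative sketch only; the real helper in src/cc/unary.h may differ.
#include <cstddef>

#include "src/cc/backend.h"

namespace tfjs {
namespace wasm {

// Apply `op` element-wise to the tensor registered under `x_id`, writing the
// result into the tensor registered under `out_id`.
inline void unary_sketch(const int x_id, const int out_id,
                         float op(const float)) {
  auto& x_info = backend::get_tensor_info(x_id);
  auto& out_info = backend::get_tensor_info_out(out_id);

  const float* x_buf = x_info.f32();
  float* out_buf = out_info.f32_write();

  const size_t size = x_info.size;  // assumed element-count field
  for (size_t i = 0; i < size; ++i) {
    out_buf[i] = op(x_buf[i]);
  }
}

}  // namespace wasm
}  // namespace tfjs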
60 changes: 60 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/NotEqual.cc
@@ -0,0 +1,60 @@
/* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include "src/cc/binary.h"
#include "src/cc/util.h"

namespace {
template <class T>
inline bool notEqual(T a, T b) {
  return a != b;
}
} // namespace

namespace tfjs {
namespace wasm {
// We use C-style API to interface with Javascript.
extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void NotEqual(const int a_id, const size_t* a_shape_ptr, const int a_shape_len,
              const int b_id, const size_t* b_shape_ptr, const int b_shape_len,
              const DType input_type, const int out_id) {
  switch (input_type) {
    case DType::float32:
      compare_f32(a_id, b_id, out_id, notEqual<float>);
      break;
    case DType::int32:
      compare_i32(a_id, b_id, out_id, notEqual<int>);
      break;
    case DType::boolean:
      compare_bool(a_id, b_id, out_id, notEqual<bool>);
      break;
    default:
      util::warn(
          "NotEqual for tensor ids %d and %d failed. "
          "Unsupported input_type %d",
          a_id, b_id, input_type);
  }
}

} // extern "C"
} // namespace wasm
} // namespace tfjs
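
The tensor lookup and broadcasting are handled by compare_f32 and friends in src/cc/binary.h, which are outside this diff; the comparison itself is just the notEqual<T> template above. A self-contained illustration of how the predicate behaves for each supported dtype:

// Standalone illustration of the element-wise predicate used by NotEqual;
// tensor plumbing and broadcasting live elsewhere (src/cc/binary.h).
#include <cstdio>

template <class T>
inline bool notEqual(T a, T b) {
  return a != b;
}

int main() {
  std::printf("%d\n", notEqual<float>(1.0f, 1.5f));  // prints 1: values differ
  std::printf("%d\n", notEqual<int>(3, 3));          // prints 0: values equal
  std::printf("%d\n", notEqual<bool>(true, false));  // prints 1: values differ
  return 0;
}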
1 change: 0 additions & 1 deletion tfjs-backend-wasm/src/cc/kernels/Sigmoid.h
@@ -22,7 +22,6 @@ namespace wasm {
extern "C" {

void Sigmoid(const size_t x_id, const size_t out_id);
-
}

} // namespace wasm
99 changes: 99 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/Softmax.cc
@@ -0,0 +1,99 @@
/* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <xnnpack.h>
#include <cmath>
#include <cstddef>
#include <map>
#include <tuple>

#include "src/cc/backend.h"
#include "src/cc/kernels/Softmax.h"
#include "src/cc/util.h"

namespace {
// We use std::tuple as the cache key as it implements the compare operator
// needed for std::map.
typedef std::tuple<size_t> OperatorCacheKey;

// The operator cache maps the channel count to the xnn_operator_t
// instantiated for that number of channels.
std::map<OperatorCacheKey, xnn_operator_t> operator_cache;

} // namespace

namespace tfjs {
namespace wasm {

extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif

void Softmax(const size_t x_id, const size_t out_id, const size_t channels,
             const size_t batch) {
  auto& x_info = backend::get_tensor_info(x_id);
  auto& out_info = backend::get_tensor_info_out(out_id);

  const float* x_buf = x_info.f32();
  float* out_buf = out_info.f32_write();

  xnn_operator_t softmax_op = nullptr;

  OperatorCacheKey cache_key = {channels};

  auto operator_cache_idx = operator_cache.find(cache_key);
  if (operator_cache_idx == operator_cache.end()) {
    const size_t input_stride = channels;
    const size_t output_stride = channels;
    const uint32_t flags = 0;

    xnn_status status = xnn_create_softmax_nc_f32(
        channels, input_stride, output_stride, flags, &softmax_op);
    if (status != xnn_status_success) {
      tfjs::util::warn(
          "XNN status for xnn_create_softmax_nc_f32 is not "
          "successful. Got status %d. Use -c dbg to see XNN logs.",
          status);
      return;
    }

    operator_cache.insert({cache_key, softmax_op});

    tfjs::backend::xnn_operator_count++;
  } else {
    softmax_op = operator_cache_idx->second;
  }

  xnn_status status = xnn_setup_softmax_nc_f32(
      softmax_op, batch, x_buf, out_buf, nullptr /* thread pool */);
  if (status != xnn_status_success) {
    tfjs::util::warn(
        "XNN status for xnn_setup_softmax_nc_f32 is not "
        "successful. Got status %d. Use -c dbg to see XNN logs.",
        status);
    return;
  }

  xnn_run_operator(softmax_op, nullptr /* thread pool */);
}

} // extern "C"
} // namespace wasm
} // namespace tfjs
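
The numerical work above is delegated to XNNPACK. For reference, what the cached softmax_nc_f32 operator computes is a row-wise, numerically stable softmax over channels entries for each of batch rows; a plain reference version is sketched below (illustration only, not how the kernel actually runs).

// Reference implementation of row-wise softmax in NC layout: for each of
// `batch` rows of length `channels`, subtract the row max, exponentiate,
// then normalize. Illustration only; the kernel delegates to XNNPACK.
#include <algorithm>
#include <cmath>
#include <cstddef>

void softmax_reference(const float* x, float* out, const size_t batch,
                       const size_t channels) {
  for (size_t b = 0; b < batch; ++b) {
    const float* row_in = x + b * channels;
    float* row_out = out + b * channels;

    // Subtracting the row max keeps std::exp from overflowing.
    float max_val = row_in[0];
    for (size_t c = 1; c < channels; ++c) {
      max_val = std::max(max_val, row_in[c]);
    }

    float sum = 0.0f;
    for (size_t c = 0; c < channels; ++c) {
      row_out[c] = std::exp(row_in[c] - max_val);
      sum += row_out[c];
    }

    for (size_t c = 0; c < channels; ++c) {
      row_out[c] /= sum;
    }
  }
}

Keying the operator cache on the channel count means the comparatively expensive xnn_create_softmax_nc_f32 call happens once per distinct inner dimension, while repeated invocations only pay for xnn_setup_softmax_nc_f32 and xnn_run_operator.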
30 changes: 30 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/Softmax.h
@@ -0,0 +1,30 @@
/* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifndef KERNELS_SOFTMAX_H_
#define KERNELS_SOFTMAX_H_

#include <cstddef>

namespace tfjs {
namespace wasm {
extern "C" {
void Softmax(const size_t x_id, const size_t out_id, const size_t channels,
             const size_t batch);
}

} // namespace wasm
} // namespace tfjs

#endif // KERNELS_SOFTMAX_H_
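
The BUILD file above also registers a Softmax_test target whose source, kernels/Softmax_test.cc, does not appear in this section. A hypothetical sketch of such a test follows, assuming the gtest-based pattern and backend test helpers (tfjs::wasm::init, tfjs::wasm::register_tensor, tfjs::wasm::dispose) that the existing Sigmoid_test presumably uses; treat those helper names as assumptions rather than verified API.

// Hypothetical sketch of kernels/Softmax_test.cc (not shown in this diff).
// The init/register_tensor/dispose helpers are assumptions about the backend's
// test utilities; xnn_operator_count does appear in Softmax.cc above.
#include <gtest/gtest.h>

#include <cstddef>

#include "src/cc/backend.h"
#include "src/cc/kernels/Softmax.h"

TEST(SOFTMAX, xnn_operator_lifetime) {
  tfjs::wasm::init();

  const size_t x_id = 1;
  const size_t out_id = 2;
  const size_t channels = 2;
  const size_t batch = 2;
  const size_t size = channels * batch;
  float x_values[size] = {1, 2, 3, 4};
  float out_values[size] = {0, 0, 0, 0};

  tfjs::wasm::register_tensor(x_id, size, x_values);
  tfjs::wasm::register_tensor(out_id, size, out_values);

  // The first call should create and cache exactly one XNNPACK operator.
  tfjs::wasm::Softmax(x_id, out_id, channels, batch);
  ASSERT_EQ(1, tfjs::backend::xnn_operator_count);

  // A second call with the same channel count should reuse the cached operator.
  tfjs::wasm::Softmax(x_id, out_id, channels, batch);
  ASSERT_EQ(1, tfjs::backend::xnn_operator_count);

  tfjs::wasm::dispose();
}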