diff --git a/tfjs-backend-wasm/.bazelrc b/tfjs-backend-wasm/.bazelrc index 0754893952d..8d066759084 100644 --- a/tfjs-backend-wasm/.bazelrc +++ b/tfjs-backend-wasm/.bazelrc @@ -13,8 +13,6 @@ build:wasm --cxxopt="-std=c++11" build:wasm --cxxopt="-fno-rtti" build:wasm --cxxopt="-fno-exceptions" build:wasm --cxxopt="-fomit-frame-pointer" -build:wasm --cxxopt="-ffast-math" -build:wasm --copt="-ffast-math" # Disable sandbox environment because emsdk caches files by writing to # home directory. diff --git a/tfjs-backend-wasm/WORKSPACE b/tfjs-backend-wasm/WORKSPACE index 540effdb753..e76d39809be 100644 --- a/tfjs-backend-wasm/WORKSPACE +++ b/tfjs-backend-wasm/WORKSPACE @@ -8,9 +8,9 @@ emsdk_configure(name = "emsdk") git_repository( name = "xnnpack", - commit = "3a77ea7bbe30b2411591f2ab15f9c5032a25f688", + commit = "7278a95e3cfae6eac73f363c4fda5db53e1b2a87", remote = "https://github.com/google/XNNPACK.git", - shallow_since = "1577131863 -0800", + shallow_since = "1580796377 -0800", ) # The libraries below are transitive dependencies of XNNPACK that we need to @@ -64,6 +64,7 @@ http_archive( http_archive( name = "cpuinfo", build_file = "@xnnpack//third_party:cpuinfo.BUILD", + patches = ["@xnnpack//third_party:cpuinfo.patch"], sha256 = "3f2dc1970f397a0e59db72f9fca6ff144b216895c1d606f6c94a507c1e53a025", strip_prefix = "cpuinfo-d5e37adf1406cf899d7d9ec1d317c47506ccb970", urls = [ diff --git a/tfjs-backend-wasm/package.json b/tfjs-backend-wasm/package.json index e2fc73893a5..d86c9f17bc1 100644 --- a/tfjs-backend-wasm/package.json +++ b/tfjs-backend-wasm/package.json @@ -32,7 +32,7 @@ "path": false }, "peerDependencies": { - "@tensorflow/tfjs-core": "1.5.2" + "@tensorflow/tfjs-core": "link:../tfjs-core" }, "dependencies": { "@types/emscripten": "~0.0.34" @@ -40,7 +40,7 @@ "devDependencies": { "@bazel/bazel": "^0.28.0", "@bazel/buildifier": "0.29.0", - "@tensorflow/tfjs-core": "1.5.2", + "@tensorflow/tfjs-core": "link:../tfjs-core", "@types/jasmine": "~2.8.6", "clang-format": "~1.2.4", "jasmine": "~3.1.0", diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index d41cca267f1..0126adfa16e 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -177,8 +177,10 @@ tfjs_cc_library( ":Min", ":Minimum", ":Mul", + ":Neg", ":NonMaxSuppressionV3", ":NonMaxSuppressionV5", + ":NotEqual", ":PadV2", ":Prelu", ":Relu", @@ -186,6 +188,7 @@ tfjs_cc_library( ":ResizeBilinear", ":ScatterNd", ":Sigmoid", + ":Softmax", ":Sub", ":Tile", ":Transpose", @@ -457,6 +460,16 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "NotEqual", + srcs = ["kernels/NotEqual.cc"], + deps = [ + ":backend", + ":binary", + ":util", + ], +) + tfjs_cc_library( name = "LogicalAnd", srcs = ["kernels/LogicalAnd.cc"], @@ -542,6 +555,15 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Neg", + srcs = ["kernels/Neg.cc"], + deps = [ + ":backend", + ":unary", + ], +) + tfjs_cc_library( name = "NonMaxSuppressionV3", srcs = ["kernels/NonMaxSuppressionV3.cc"], @@ -647,6 +669,16 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Softmax", + srcs = ["kernels/Softmax.cc"], + hdrs = ["kernels/Softmax.h"], + deps = [ + ":backend", + ":unary", + ], +) + tfjs_unit_test( name = "Sigmoid_test", srcs = ["kernels/Sigmoid_test.cc"], @@ -655,6 +687,14 @@ tfjs_unit_test( ], ) +tfjs_unit_test( + name = "Softmax_test", + srcs = ["kernels/Softmax_test.cc"], + deps = [ + ":Softmax", + ], +) + tfjs_cc_library( name = "Square", srcs = ["kernels/Square.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Neg.cc b/tfjs-backend-wasm/src/cc/kernels/Neg.cc new file mode 100644 index 00000000000..7570e1cc810 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Neg.cc @@ -0,0 +1,40 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "src/cc/backend.h" +#include "src/cc/unary.h" + +namespace { +inline float neg(const float val) { return -val; } +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Neg(const int x_id, const int out_id) { unary(x_id, out_id, neg); } + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/NotEqual.cc b/tfjs-backend-wasm/src/cc/kernels/NotEqual.cc new file mode 100644 index 00000000000..064bb4e6194 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/NotEqual.cc @@ -0,0 +1,60 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/binary.h" +#include "src/cc/util.h" + +namespace { +template +inline bool notEqual(T a, T b) { + return a != b; +} +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void NotEqual(const int a_id, const size_t* a_shape_ptr, const int a_shape_len, + const int b_id, const size_t* b_shape_ptr, const int b_shape_len, + const DType input_type, const int out_id) { + switch (input_type) { + case DType::float32: + compare_f32(a_id, b_id, out_id, notEqual); + break; + case DType::int32: + compare_i32(a_id, b_id, out_id, notEqual); + break; + case DType::boolean: + compare_bool(a_id, b_id, out_id, notEqual); + break; + default: + util::warn( + "NotEqual for tensor ids %d and %d failed." + "Unsupported input_type %d", + a_id, b_id, input_type); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/Sigmoid.h b/tfjs-backend-wasm/src/cc/kernels/Sigmoid.h index 4ed39787b34..f24a6415b03 100644 --- a/tfjs-backend-wasm/src/cc/kernels/Sigmoid.h +++ b/tfjs-backend-wasm/src/cc/kernels/Sigmoid.h @@ -22,7 +22,6 @@ namespace wasm { extern "C" { void Sigmoid(const size_t x_id, const size_t out_id); - } } // namespace wasm diff --git a/tfjs-backend-wasm/src/cc/kernels/Softmax.cc b/tfjs-backend-wasm/src/cc/kernels/Softmax.cc new file mode 100644 index 00000000000..84dc027cb52 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Softmax.cc @@ -0,0 +1,99 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include +#include +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/kernels/Softmax.h" +#include "src/cc/util.h" + +namespace { +// We use std::tuple as the cache key as it implements the compare operator +// needed for std::map. +typedef std::tuple OperatorCacheKey; + +// The operator cache maps the weights id to the xnn_operator_t instantiated for +// this set of weights. +std::map operator_cache; + +} // namespace + +namespace tfjs { +namespace wasm { + +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif + +void Softmax(const size_t x_id, const size_t out_id, const size_t channels, + const size_t batch) { + auto& x_info = backend::get_tensor_info(x_id); + auto& out_info = backend::get_tensor_info_out(out_id); + + const float* x_buf = x_info.f32(); + float* out_buf = out_info.f32_write(); + + xnn_operator_t softmax_op = nullptr; + + OperatorCacheKey cache_key = {channels}; + + auto operator_cache_idx = operator_cache.find(cache_key); + if (operator_cache_idx == operator_cache.end()) { + const size_t input_stride = channels; + const size_t output_stride = channels; + const uint32_t flags = 0; + + xnn_status status = xnn_create_softmax_nc_f32( + channels, input_stride, output_stride, flags, &softmax_op); + if (status != xnn_status_success) { + tfjs::util::warn( + "XNN status for xnn_create_softmax_nc_f32 is not " + "successful. Got status %d. Use -c dbg to see XNN logs.", + status); + return; + } + + operator_cache.insert({cache_key, softmax_op}); + + tfjs::backend::xnn_operator_count++; + } else { + softmax_op = operator_cache_idx->second; + } + + xnn_status status = xnn_setup_softmax_nc_f32( + softmax_op, batch, x_buf, out_buf, nullptr /* thread pool */); + if (status != xnn_status_success) { + tfjs::util::warn( + "XNN status for xnn_setup_softmax_nc_f32 is not " + "successful. Got status %d. Use -c dbg to see XNN logs.", + status); + return; + } + + xnn_run_operator(softmax_op, nullptr /* thread pool */); +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/Softmax.h b/tfjs-backend-wasm/src/cc/kernels/Softmax.h new file mode 100644 index 00000000000..f2cd8ebc033 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Softmax.h @@ -0,0 +1,30 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifndef KERNELS_SOFTMAX_H_ +#define KERNELS_SOFTMAX_H_ + +#include + +namespace tfjs { +namespace wasm { +extern "C" { +void Softmax(const size_t x_id, const size_t out_id, const size_t channels, + const size_t batch); +} + +} // namespace wasm +} // namespace tfjs + +#endif // KERNELS_SOFTMAX_H_ diff --git a/tfjs-backend-wasm/src/cc/kernels/Softmax_test.cc b/tfjs-backend-wasm/src/cc/kernels/Softmax_test.cc new file mode 100644 index 00000000000..eeec4413617 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Softmax_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#include + +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/kernels/Softmax.h" + +TEST(SOFTMAX, xnn_operator_lifetime) { + tfjs::wasm::init(); + + ASSERT_EQ(0, tfjs::backend::num_tensors()); + + const size_t x0_id = 1; + const size_t x1_id = 2; + const size_t x_size = 4; + float x_values[x_size] = {1, 2, 2, 2}; + + const size_t out_id = 3; + const size_t out_size = 4; + float out_values[out_size] = {0, 0, 0, 0}; + + tfjs::wasm::register_tensor(x0_id, x_size, x_values); + tfjs::wasm::register_tensor(x1_id, x_size, x_values); + tfjs::wasm::register_tensor(out_id, out_size, out_values); + + ASSERT_EQ(3, tfjs::backend::num_tensors()); + ASSERT_EQ(0, tfjs::backend::xnn_operator_count); + + // One new xnn_operator should be created for the first call to Softmax. + tfjs::wasm::Softmax(x0_id, out_id, 4, 1); + ASSERT_EQ(1, tfjs::backend::xnn_operator_count); + + // No new xnn_operators should be created for the second call to + // Softmax with the same arguments. + tfjs::wasm::Softmax(x0_id, out_id, 4, 1); + ASSERT_EQ(1, tfjs::backend::xnn_operator_count); + + // No new xnn_operators should be created for the second call to + // Softmax with different arguments. + tfjs::wasm::Softmax(x1_id, out_id, 4, 1); + ASSERT_EQ(1, tfjs::backend::xnn_operator_count); + + tfjs::wasm::dispose(); +} diff --git a/tfjs-backend-wasm/src/kernels/Neg.ts b/tfjs-backend-wasm/src/kernels/Neg.ts new file mode 100644 index 00000000000..b226e799331 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Neg.ts @@ -0,0 +1,19 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerUnaryKernel} from './unary_kernel'; +registerUnaryKernel('Neg'); diff --git a/tfjs-backend-wasm/src/kernels/NotEqual.ts b/tfjs-backend-wasm/src/kernels/NotEqual.ts new file mode 100644 index 00000000000..a20be2a92d8 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/NotEqual.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerBinaryKernel} from './binary_kernel'; +const supportsFullBroadcast = false; +registerBinaryKernel('NotEqual', supportsFullBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/Softmax.ts b/tfjs-backend-wasm/src/kernels/Softmax.ts new file mode 100644 index 00000000000..19eeee113f9 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Softmax.ts @@ -0,0 +1,67 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface SoftmaxInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +interface SoftmaxAttrs extends NamedAttrMap { + dim: number; +} + +let wasmFunc: (xId: number, outId: number, channels: number, batch: number) => + void; + +function setup(backend: BackendWasm): void { + wasmFunc = backend.wasm.cwrap('Softmax', null /* void */, [ + 'number', // xId + 'number', // outId + 'number', // channels + 'number' // batch + ]); +} + +function softmax( + args: {backend: BackendWasm, inputs: SoftmaxInputs, attrs: SoftmaxAttrs}): + TensorInfo { + const {backend, inputs: {logits}, attrs: {dim}} = args; + const xId = backend.dataIdMap.get(logits.dataId).id; + const out = backend.makeOutput(logits.shape, logits.dtype); + const outId = backend.dataIdMap.get(out.dataId).id; + + const channels = logits.shape[dim]; + const batch = util.sizeFromShape(logits.shape) / channels; + + // Short-circuit zero-sized tensors. + if (util.sizeFromShape(out.shape) === 0) { + return out; + } + + wasmFunc(xId, outId, channels, batch); + return out; +} + +registerKernel({ + kernelName: 'Softmax', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: softmax +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 59de2028718..157a4cbd8fa 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -51,8 +51,10 @@ import './MaxPool'; import './Min'; import './Minimum'; import './Mul'; +import './Neg'; import './NonMaxSuppressionV3'; import './NonMaxSuppressionV5'; +import './NotEqual'; import './PadV2'; import './Prelu'; import './Relu'; @@ -64,6 +66,7 @@ import './ScatterNd'; import './Sigmoid'; import './Sin'; import './Slice'; +import './Softmax'; import './Square'; import './Sub'; import './Sum'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 9b794fa25f0..8c4d289b7e1 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -34,6 +34,7 @@ const TEST_FILTERS: TestFilter[] = [ 'Tensor2D float32 -> bool', 'Tensor2D int32 -> bool' ] }, + {include: 'softmax'}, { include: 'add ', excludes: [ diff --git a/tfjs-backend-wasm/yarn.lock b/tfjs-backend-wasm/yarn.lock index 2e76213ffed..ac1064763a2 100644 --- a/tfjs-backend-wasm/yarn.lock +++ b/tfjs-backend-wasm/yarn.lock @@ -73,17 +73,9 @@ resolved "https://registry.yarnpkg.com/@bazel/hide-bazel-files/-/hide-bazel-files-0.38.3.tgz#e98231d3d360d51860d9c1a7c3345b40dab4cf81" integrity sha512-o+dNkfDm3qxWQ8h/04cWuTcjR7qnjZi3pQGv4aklVb16oPWx2jF8BzbkwvWuIkdbOl9VnqYP0vaHzwQVJRRcIA== -"@tensorflow/tfjs-core@1.5.2": - version "1.5.2" - resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.5.2.tgz#df76752cf7c43987df1548fb69820935bd8215d7" - integrity sha512-Rj6l8xf0PxrEKctvX3bvxjqhHLaCBQT0ChvqFK6//HBu8A1/ao4SzeVKpXKNnP9Niax+qV3c9U9VcOwwIkCMag== - dependencies: - "@types/offscreencanvas" "~2019.3.0" - "@types/seedrandom" "2.4.27" - "@types/webgl-ext" "0.0.30" - "@types/webgl2" "0.0.4" - node-fetch "~2.1.2" - seedrandom "2.4.3" +"@tensorflow/tfjs-core@link:../tfjs-core": + version "0.0.0" + uid "" "@types/emscripten@~0.0.34": version "0.0.34" diff --git a/tfjs-core/src/backends/backend.ts b/tfjs-core/src/backends/backend.ts index aeb9ab3897a..d5e9cd02b4b 100644 --- a/tfjs-core/src/backends/backend.ts +++ b/tfjs-core/src/backends/backend.ts @@ -305,6 +305,9 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { expm1(x: T): T { return notYetImplemented('expm1'); } + softmax(x: T, dim: number): T { + return notYetImplemented('softmax'); + } log(x: T): T { return notYetImplemented('log'); } diff --git a/tfjs-core/src/backends/cpu/all_kernels.ts b/tfjs-core/src/backends/cpu/all_kernels.ts index 75920b35b88..b9191e1b4c2 100644 --- a/tfjs-core/src/backends/cpu/all_kernels.ts +++ b/tfjs-core/src/backends/cpu/all_kernels.ts @@ -18,5 +18,5 @@ // We explicitly import the modular kernels so they get registered in the // global registry when we compile the library. A modular build would replace // the contents of this file and import only the kernels that are needed. -import './square'; import './non_max_suppression_v5'; +import './square'; diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index ee2a1e6772a..f45855309a8 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -377,6 +377,17 @@ export class MathBackendCPU extends KernelBackend { return result.toTensor() as T; } + softmax(logits: T, dim: number): T { + const axes = util.parseAxisParam([dim], logits.shape); + const maxLogit = this.max(logits, axes); + const expandedShape = axis_util.expandShapeToKeepDim(maxLogit.shape, axes); + const a = this.subtract(logits, maxLogit.reshape(expandedShape)); + const b = this.exp(a); + const sumExp = this.sum(b, axes).reshape(expandedShape); + + return this.realDivide(b, sumExp) as T; + } + subtract(a: Tensor, b: Tensor): Tensor { if (a.dtype === 'complex64' || b.dtype === 'complex64') { return this.broadcastedBinaryComplexOp( diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index afa45013a53..a1dfe5c0086 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -1597,6 +1597,17 @@ export class MathBackendWebGL extends KernelBackend { return this.compileAndRun(program, [x]); } + softmax(logits: T, dim: number): T { + const axes = util.parseAxisParam([dim], logits.shape); + const maxLogit = this.max(logits, axes); + const expandedShape = axis_util.expandShapeToKeepDim(maxLogit.shape, axes); + const a = this.subtract(logits, maxLogit.reshape(expandedShape)); + const b = this.exp(a); + const sumExp = this.sum(b, axes).reshape(expandedShape); + + return this.realDivide(b, sumExp) as T; + } + log(x: T): T { if (this.shouldExecuteOnCPU([x])) { return this.cpuBackend.log(x); diff --git a/tfjs-core/src/ops/compare.ts b/tfjs-core/src/ops/compare.ts index 7b32b7d8b89..521d269a8ef 100644 --- a/tfjs-core/src/ops/compare.ts +++ b/tfjs-core/src/ops/compare.ts @@ -47,8 +47,9 @@ function notEqual_( let $b = convertToTensor(b, 'b', 'notEqual'); [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); - return ENGINE.runKernelFunc(backend => backend.notEqual($a, $b), {$a, $b}) as - T; + return ENGINE.runKernelFunc( + backend => backend.notEqual($a, $b), {a: $a, b: $b}, + null /* grad */, 'NotEqual') as T; } /** diff --git a/tfjs-core/src/ops/softmax.ts b/tfjs-core/src/ops/softmax.ts index c308db2ded4..fff8a3df23c 100644 --- a/tfjs-core/src/ops/softmax.ts +++ b/tfjs-core/src/ops/softmax.ts @@ -15,11 +15,13 @@ * ============================================================================= */ +import {ENGINE} from '../engine'; import {customGrad} from '../gradients'; import {Tensor} from '../tensor'; import {GradSaveFunc} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; + import {op} from './operation'; /** @@ -43,7 +45,7 @@ import {op} from './operation'; */ /** @doc {heading: 'Operations', subheading: 'Normalization'} */ function softmax_(logits: T|TensorLike, dim = -1): T { - const $logits = convertToTensor(logits, 'logits', 'softmax'); + const $logits = convertToTensor(logits, 'logits', 'softmax', 'float32'); if (dim === -1) { dim = $logits.rank - 1; @@ -54,25 +56,26 @@ function softmax_(logits: T|TensorLike, dim = -1): T { `Logits was rank ${$logits.rank} and dim was ${dim}`); } - const customOp = customGrad((logits: Tensor, save: GradSaveFunc) => { - // Do it in log space for numerical stability. - // exp(X - logSumExp(X)) - const keepDims = true; - const lse = logits.logSumExp([dim], keepDims); - const logResult = logits.toFloat().sub(lse); - const y = logResult.exp() as T; - save([y]); - const gradFunc = (dy: T, saved: Tensor[]) => { - const [y] = saved; - const dyTimesY = dy.mul(y); - const keepDims = true; - return dyTimesY.sub(dyTimesY.sum([dim], keepDims).mul(y)); - }; + const inputsToSave: Tensor[] = []; + const outputsToSave = [true]; - return {value: y, gradFunc}; - }); + return ENGINE.runKernelFunc( + (backend, save) => { + const y = backend.softmax($logits, dim); + save([y]); + return y; + }, + {logits: $logits}, + (dy: T, saved: Tensor[]) => { + const [y] = saved; + const dyTimesY = dy.mul(y); + const keepDims = true; - return customOp($logits); + return { + logits: () => dyTimesY.sub(dyTimesY.sum([dim], keepDims).mul(y)) + }; + }, + 'Softmax', {dim}, inputsToSave, outputsToSave); } /** diff --git a/tfjs-core/src/ops/unary_ops.ts b/tfjs-core/src/ops/unary_ops.ts index beeb51f600f..7a33f1e0b42 100644 --- a/tfjs-core/src/ops/unary_ops.ts +++ b/tfjs-core/src/ops/unary_ops.ts @@ -40,9 +40,13 @@ function neg_(x: T|TensorLike): T { const $x = convertToTensor(x, 'x', 'neg'); const grad = (dy: T) => { - return {$x: () => dy.neg()}; + return {x: () => dy.neg()}; }; - return ENGINE.runKernelFunc(backend => backend.neg($x), {$x}, grad); + + const attrs = {}; + const inputsToSave = [$x]; + return ENGINE.runKernelFunc( + backend => backend.neg($x), {x: $x}, grad, 'Neg', attrs, inputsToSave); } /**