diff --git a/tfjs-backend-wasm/src/cc/kernels/Cos.cc b/tfjs-backend-wasm/src/cc/kernels/Cos.cc new file mode 100644 index 00000000000..e07ddfd98fd --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Cos.cc @@ -0,0 +1,36 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "src/cc/backend.h" +#include "src/cc/unary.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Cos(const int x_id, const int out_id) { unary(x_id, out_id, cos); } + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/Rsqrt.cc b/tfjs-backend-wasm/src/cc/kernels/Rsqrt.cc new file mode 100644 index 00000000000..cb62a1b0697 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Rsqrt.cc @@ -0,0 +1,39 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "src/cc/backend.h" +#include "src/cc/unary.h" + +namespace { +float rsqrt(const float a) { return 1 / sqrt(a); } +} // namespace +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Rsqrt(const int x_id, const int out_id) { unary(x_id, out_id, rsqrt); } + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/Sin.cc b/tfjs-backend-wasm/src/cc/kernels/Sin.cc new file mode 100644 index 00000000000..3aff7190aed --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Sin.cc @@ -0,0 +1,36 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "src/cc/backend.h" +#include "src/cc/unary.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Sin(const int x_id, const int out_id) { unary(x_id, out_id, sin); } + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/Tanh.cc b/tfjs-backend-wasm/src/cc/kernels/Tanh.cc new file mode 100644 index 00000000000..61c98e894cd --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Tanh.cc @@ -0,0 +1,36 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "src/cc/backend.h" +#include "src/cc/unary.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Tanh(const int x_id, const int out_id) { unary(x_id, out_id, tanh); } + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Cos.ts b/tfjs-backend-wasm/src/kernels/Cos.ts new file mode 100644 index 00000000000..182dcd686f0 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Cos.ts @@ -0,0 +1,19 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerUnaryKernel} from './unary_kernel'; +registerUnaryKernel('Cos'); diff --git a/tfjs-backend-wasm/src/kernels/Exp.ts b/tfjs-backend-wasm/src/kernels/Exp.ts index b131f2abd4b..ccbffc28abd 100644 --- a/tfjs-backend-wasm/src/kernels/Exp.ts +++ b/tfjs-backend-wasm/src/kernels/Exp.ts @@ -15,5 +15,5 @@ * ============================================================================= */ -import { registerUnaryKernel } from './unary_kernel'; +import {registerUnaryKernel} from './unary_kernel'; registerUnaryKernel('Exp'); diff --git a/tfjs-backend-wasm/src/kernels/Rsqrt.ts b/tfjs-backend-wasm/src/kernels/Rsqrt.ts new file mode 100644 index 00000000000..e8be1918bd5 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Rsqrt.ts @@ -0,0 +1,19 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerUnaryKernel} from './unary_kernel'; +registerUnaryKernel('Rsqrt'); diff --git a/tfjs-backend-wasm/src/kernels/Sin.ts b/tfjs-backend-wasm/src/kernels/Sin.ts new file mode 100644 index 00000000000..5991ea16d21 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Sin.ts @@ -0,0 +1,19 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerUnaryKernel} from './unary_kernel'; +registerUnaryKernel('Sin'); diff --git a/tfjs-backend-wasm/src/kernels/Tanh.ts b/tfjs-backend-wasm/src/kernels/Tanh.ts new file mode 100644 index 00000000000..390389bf816 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Tanh.ts @@ -0,0 +1,19 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerUnaryKernel} from './unary_kernel'; +registerUnaryKernel('Tanh'); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 626114d6e37..389be69b21b 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -29,6 +29,7 @@ import './ClipByValue'; import './Concat'; import './Conv2D'; import './CropAndResize'; +import './Cos'; import './DepthwiseConv2dNative'; import './Div'; import './Exp'; @@ -55,11 +56,14 @@ import './Relu'; import './Relu6'; import './Reshape'; import './ResizeBilinear'; +import './Rsqrt'; import './Sigmoid'; +import './Sin'; import './Slice'; import './Square'; import './Sub'; import './Sum'; +import './Tanh'; import './Tile'; import './Transpose'; import './Unpack'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index ef7af28fdc8..a85766650bc 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -280,7 +280,23 @@ const TEST_FILTERS: TestFilter[] = [ 'gradient', // Gradient not yet implemented. 'string tensor' // String tensors not yet implemented. ] - } + }, + {startsWith: 'sin '}, + { + startsWith: 'cos ', + excludes: [ + 'gradient', // Gradient not yet implemented. + ] + }, + { + startsWith: 'tanh ', + excludes: ['gradient'] // Gradient not yet implemented. + }, + { + startsWith: 'rsqrt ', + excludes: ['gradient'] // Gradient not yet implemented. + }, + ]; const customInclude = (testName: string) => { diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts index da81c7b6e67..d87b98ac828 100644 --- a/tfjs-core/src/engine.ts +++ b/tfjs-core/src/engine.ts @@ -575,7 +575,7 @@ export class Engine implements TensorTracker, DataMover { this.makeTensorFromDataId(dataId, shape, dtype)); const outsToSave = outTensors.filter((_, i) => outputsToSave[i]); // Save the inputs and outputs. - saveFunc(inputsToSave.slice().concat(outsToSave)); + saveFunc((inputsToSave || []).slice().concat(outsToSave)); return outTensors; }; } else { diff --git a/tfjs-core/src/ops/unary_ops.ts b/tfjs-core/src/ops/unary_ops.ts index 1e3cdbe3850..beeb51f600f 100644 --- a/tfjs-core/src/ops/unary_ops.ts +++ b/tfjs-core/src/ops/unary_ops.ts @@ -347,13 +347,14 @@ function rsqrt_(x: T|TensorLike): T { const grad = (dy: T, saved: Tensor[]) => { const [$x] = saved; - return {$x: () => dy.div($x.pow(1.5).mul(2)).neg() as T}; + return {x: () => dy.div($x.pow(1.5).mul(2)).neg() as T}; }; + const inputsToSave = [$x]; return ENGINE.runKernelFunc((backend, save) => { const res = backend.rsqrt($x); save([$x]); return res; - }, {$x}, grad); + }, {x: $x}, grad, 'Rsqrt', {} /* attrs */, inputsToSave); } /** @@ -541,13 +542,14 @@ function sin_(x: T|TensorLike): T { const grad = (dy: T, saved: Tensor[]) => { const [$x] = saved; - return {$x: () => $x.toFloat().cos().mul(dy)} as {$x: () => T}; + return {x: () => $x.toFloat().cos().mul(dy)} as {x: () => T}; }; + const inputsToSave = [$x]; return ENGINE.runKernelFunc((backend, save) => { const res = backend.sin($x); save([$x]); return res; - }, {$x}, grad); + }, {x: $x}, grad, 'Sin', {} /* attrs */, inputsToSave); } /** @@ -566,13 +568,14 @@ function cos_(x: T|TensorLike): T { const grad = (dy: T, saved: Tensor[]) => { const [$x] = saved; - return {$x: () => $x.toFloat().sin().neg().mul(dy)} as {$x: () => T}; + return {x: () => $x.toFloat().sin().neg().mul(dy)} as {x: () => T}; }; + const inputsToSave = [$x]; return ENGINE.runKernelFunc((backend, save) => { const res = backend.cos($x); save([$x]); return res; - }, {$x}, grad); + }, {x: $x}, grad, 'Cos', {} /* attrs */, inputsToSave); } /** @@ -746,13 +749,17 @@ function tanh_(x: T|TensorLike): T { const grad = (dy: T, saved: Tensor[]) => { const [y] = saved; - return {$x: () => scalar(1).sub(y.square()).mulStrict(dy) as T}; + return {x: () => scalar(1).sub(y.square()).mulStrict(dy) as T}; }; - return ENGINE.runKernelFunc((backend, save) => { - const y = backend.tanh($x); - save([y]); - return y; - }, {$x}, grad); + const outputsToSave = [true]; + return ENGINE.runKernelFunc( + (backend, save) => { + const y = backend.tanh($x); + save([y]); + return y; + }, + {x: $x}, grad, 'Tanh', {} /* attrs */, null /* inputsToSave */, + outputsToSave); } /**