From 0565a580bc6b1566f4dcf61983a3df91a3841dc8 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 7 Jul 2020 21:56:47 -0400 Subject: [PATCH 1/3] modularize abs, acos, acosh, asin, asinh, atan, atanh --- tfjs-core/src/gradients/Abs_grad.ts | 32 + tfjs-core/src/gradients/Acos_grad.ts | 44 + tfjs-core/src/gradients/Acosh_grad.ts | 40 + tfjs-core/src/gradients/Asin_grad.ts | 35 + tfjs-core/src/gradients/Asinh_grad.ts | 41 + tfjs-core/src/gradients/Atan_grad.ts | 34 + tfjs-core/src/gradients/Atanh_grad.ts | 35 + tfjs-core/src/kernel_names.ts | 21 + tfjs-core/src/ops/abs.ts | 53 + tfjs-core/src/ops/abs_test.ts | 167 +++ tfjs-core/src/ops/absolute_difference.ts | 4 +- tfjs-core/src/ops/acos.ts | 47 + tfjs-core/src/ops/acos_test.ts | 123 ++ tfjs-core/src/ops/acosh.ts | 49 + tfjs-core/src/ops/acosh_test.ts | 148 +++ tfjs-core/src/ops/asin.ts | 48 + tfjs-core/src/ops/asin_test.ts | 119 ++ tfjs-core/src/ops/asinh.ts | 50 + tfjs-core/src/ops/asinh_test.ts | 139 +++ tfjs-core/src/ops/atan.ts | 49 + tfjs-core/src/ops/atan_test.ts | 130 +++ tfjs-core/src/ops/atanh.ts | 50 + tfjs-core/src/ops/atanh_test.ts | 140 +++ tfjs-core/src/ops/huber_loss.ts | 3 +- tfjs-core/src/ops/norm.ts | 3 +- tfjs-core/src/ops/ops.ts | 7 + tfjs-core/src/ops/sigmoid_cross_entropy.ts | 3 +- tfjs-core/src/ops/unary_ops.ts | 213 ---- tfjs-core/src/ops/unary_ops_test.ts | 1192 +++----------------- tfjs-core/src/register_all_gradients.ts | 14 + tfjs-core/src/tests.ts | 8 + 31 files changed, 1815 insertions(+), 1226 deletions(-) create mode 100644 tfjs-core/src/gradients/Abs_grad.ts create mode 100644 tfjs-core/src/gradients/Acos_grad.ts create mode 100644 tfjs-core/src/gradients/Acosh_grad.ts create mode 100644 tfjs-core/src/gradients/Asin_grad.ts create mode 100644 tfjs-core/src/gradients/Asinh_grad.ts create mode 100644 tfjs-core/src/gradients/Atan_grad.ts create mode 100644 tfjs-core/src/gradients/Atanh_grad.ts create mode 100644 tfjs-core/src/ops/abs.ts create mode 100644 tfjs-core/src/ops/abs_test.ts create mode 100644 tfjs-core/src/ops/acos.ts create mode 100644 tfjs-core/src/ops/acos_test.ts create mode 100644 tfjs-core/src/ops/acosh.ts create mode 100644 tfjs-core/src/ops/acosh_test.ts create mode 100644 tfjs-core/src/ops/asin.ts create mode 100644 tfjs-core/src/ops/asin_test.ts create mode 100644 tfjs-core/src/ops/asinh.ts create mode 100644 tfjs-core/src/ops/asinh_test.ts create mode 100644 tfjs-core/src/ops/atan.ts create mode 100644 tfjs-core/src/ops/atan_test.ts create mode 100644 tfjs-core/src/ops/atanh.ts create mode 100644 tfjs-core/src/ops/atanh_test.ts diff --git a/tfjs-core/src/gradients/Abs_grad.ts b/tfjs-core/src/gradients/Abs_grad.ts new file mode 100644 index 00000000000..f0632fd3c23 --- /dev/null +++ b/tfjs-core/src/gradients/Abs_grad.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Abs} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {cast} from '../ops/array_ops'; +import {mul} from '../ops/mul'; +import {step} from '../ops/unary_ops'; +import {Tensor} from '../tensor'; + +export const absGradConfig: GradConfig = { + kernelName: Abs, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + return {x: () => mul(dy, step(cast(x, 'float32'), -1))}; + } +}; diff --git a/tfjs-core/src/gradients/Acos_grad.ts b/tfjs-core/src/gradients/Acos_grad.ts new file mode 100644 index 00000000000..208e7dc4473 --- /dev/null +++ b/tfjs-core/src/gradients/Acos_grad.ts @@ -0,0 +1,44 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Acos} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {cast} from '../ops/array_ops'; +import {div} from '../ops/div'; +import {neg} from '../ops/neg'; +import {square} from '../ops/square'; +import {sub} from '../ops/sub'; +import {scalar} from '../ops/tensor_ops'; +import {sqrt} from '../ops/unary_ops'; +import {Tensor} from '../tensor'; + +export const acosGradConfig: GradConfig = { + kernelName: Acos, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + + return { + x: () => { + const a = square(cast(x, 'float32')); + const b = sqrt(sub(scalar(1), a)); + return neg(div(dy, b)); + } + + }; + } +}; diff --git a/tfjs-core/src/gradients/Acosh_grad.ts b/tfjs-core/src/gradients/Acosh_grad.ts new file mode 100644 index 00000000000..8e38f8c757f --- /dev/null +++ b/tfjs-core/src/gradients/Acosh_grad.ts @@ -0,0 +1,40 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Acosh} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {cast} from '../ops/array_ops'; +import {div} from '../ops/div'; +import {square} from '../ops/square'; +import {sub} from '../ops/sub'; +import {sqrt} from '../ops/unary_ops'; +import {Tensor} from '../tensor'; + +export const acoshGradConfig: GradConfig = { + kernelName: Acosh, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + + return { + x: () => { + const a = sqrt(sub(square(cast(x, 'float32')), 1)); + return div(dy, a); + } + }; + } +}; diff --git a/tfjs-core/src/gradients/Asin_grad.ts b/tfjs-core/src/gradients/Asin_grad.ts new file mode 100644 index 00000000000..290cd3ecb4f --- /dev/null +++ b/tfjs-core/src/gradients/Asin_grad.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Asin} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {cast} from '../ops/array_ops'; +import {div} from '../ops/div'; +import {square} from '../ops/square'; +import {sub} from '../ops/sub'; +import {scalar} from '../ops/tensor_ops'; +import {sqrt} from '../ops/unary_ops'; +import {Tensor} from '../tensor'; + +export const asinGradConfig: GradConfig = { + kernelName: Asin, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + return {x: () => div(dy, sqrt(sub(scalar(1), square(cast(x, 'float32')))))}; + } +}; diff --git a/tfjs-core/src/gradients/Asinh_grad.ts b/tfjs-core/src/gradients/Asinh_grad.ts new file mode 100644 index 00000000000..8916f012aa1 --- /dev/null +++ b/tfjs-core/src/gradients/Asinh_grad.ts @@ -0,0 +1,41 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Asinh} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {add} from '../ops/add'; +import {cast} from '../ops/array_ops'; +import {div} from '../ops/div'; +import {square} from '../ops/square'; +import {scalar} from '../ops/tensor_ops'; +import {sqrt} from '../ops/unary_ops'; +import {Tensor} from '../tensor'; + +export const asinhGradConfig: GradConfig = { + kernelName: Asinh, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + + return { + x: () => { + const a = sqrt(add(scalar(1), square(cast(x, 'float32')))); + return div(dy, a); + } + }; + } +}; diff --git a/tfjs-core/src/gradients/Atan_grad.ts b/tfjs-core/src/gradients/Atan_grad.ts new file mode 100644 index 00000000000..3eea316857c --- /dev/null +++ b/tfjs-core/src/gradients/Atan_grad.ts @@ -0,0 +1,34 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Atan} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {add} from '../ops/add'; +import {cast} from '../ops/array_ops'; +import {div} from '../ops/div'; +import {square} from '../ops/square'; +import {Tensor} from '../tensor'; + +export const atanGradConfig: GradConfig = { + kernelName: Atan, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + + return {x: () => div(dy, add(square(cast(x, 'float32')), 1))}; + } +}; diff --git a/tfjs-core/src/gradients/Atanh_grad.ts b/tfjs-core/src/gradients/Atanh_grad.ts new file mode 100644 index 00000000000..67d1aae4685 --- /dev/null +++ b/tfjs-core/src/gradients/Atanh_grad.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Atanh} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {cast} from '../ops/array_ops'; +import {div} from '../ops/div'; +import {square} from '../ops/square'; +import {sub} from '../ops/sub'; +import {scalar} from '../ops/tensor_ops'; +import {Tensor} from '../tensor'; + +export const atanhGradConfig: GradConfig = { + kernelName: Atanh, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + + return {x: () => div(dy, sub(scalar(1), square(cast(x, 'float32'))))}; + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 751c6d45f1a..97cb10a35f7 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -23,6 +23,15 @@ import {ExplicitPadding} from '../src/ops/conv_util'; import {NamedTensorInfoMap, TensorInfo} from './kernel_registry'; import {DataType, PixelData} from './types'; +export const Abs = 'Abs'; +export type AbsInputs = UnaryInputs; + +export const Acos = 'Acos'; +export type AcosInputs = UnaryInputs; + +export const Acosh = 'Acosh'; +export type AcoshInputs = UnaryInputs; + export const Add = 'Add'; export type AddInputs = BinaryInputs; @@ -55,6 +64,18 @@ export interface ArgMinAttrs { axis: number; } +export const Asin = 'Asin'; +export type AsinInputs = UnaryInputs; + +export const Asinh = 'Asinh'; +export type AsinhInputs = UnaryInputs; + +export const Atan = 'Atan'; +export type AtanInputs = UnaryInputs; + +export const Atanh = 'Atanh'; +export type AtanhInputs = UnaryInputs; + export const Atan2 = 'Atan2'; export type Atan2Inputs = BinaryInputs; diff --git a/tfjs-core/src/ops/abs.ts b/tfjs-core/src/ops/abs.ts new file mode 100644 index 00000000000..60a7acf7e1f --- /dev/null +++ b/tfjs-core/src/ops/abs.ts @@ -0,0 +1,53 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Abs, AbsInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes absolute value element-wise: `abs(x)` + * + * ```js + * const x = tf.tensor1d([-1, 2, -3, 4]); + * + * x.abs().print(); // or tf.abs(x) + * ``` + * @param x The input `tf.Tensor`. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function abs_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'abs'); + + const inputs: AbsInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + save([$x]); + if ($x.dtype === 'complex64') { + return backend.complexAbs($x); + } + + return backend.abs($x); + }, inputs as {} as NamedTensorMap, null /* grad */, Abs); +} + +export const abs = op({abs_}); diff --git a/tfjs-core/src/ops/abs_test.ts b/tfjs-core/src/ops/abs_test.ts new file mode 100644 index 00000000000..b2ff9bf5fb2 --- /dev/null +++ b/tfjs-core/src/ops/abs_test.ts @@ -0,0 +1,167 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('abs', ALL_ENVS, () => { + it('basic', async () => { + const a = tf.tensor1d([1, -2, 0, 3, -0.1]); + const result = tf.abs(a); + expectArraysClose(await result.data(), [1, 2, 0, 3, 0.1]); + }); + + it('5D', async () => { + const a = tf.tensor5d([1, -2, 0, -3], [1, 2, 2, 1, 1]); + const result = tf.abs(a); + expectArraysClose(await result.data(), [1, 2, 0, 3]); + }); + + it('6D', async () => { + const a = tf.tensor6d([1, -2, 5, -3, -1, 4, 7, 8], [1, 2, 2, 2, 1, 1]); + const result = tf.abs(a); + expectArraysClose(await result.data(), [1, 2, 5, 3, 1, 4, 7, 8]); + }); + + it('complex64 rank-1', async () => { + const a = tf.complex([-2, -1, 0, 1, 2], [1, 2, 3, 0, -1]); + const result = tf.abs(a); + expectArraysClose(await result.data(), [ + Math.sqrt(-2 * -2 + 1 * 1), Math.sqrt(-1 * -1 + 2 * 2), + Math.sqrt(0 * 0 + 3 * 3), Math.sqrt(1 * 1 + 0 * 0), + Math.sqrt(2 * 2 + -1 * -1) + ]); + expect(result.shape).toEqual([5]); + }); + + it('complex64 rank-2', async () => { + const a = tf.complex([[-3, -2, -1], [0, 1, 2]], [[4, 1, 2], [3, 0, -1]]); + const result = tf.abs(a); + expectArraysClose(await result.data(), [ + Math.sqrt(-3 * -3 + 4 * 4), Math.sqrt(-2 * -2 + 1 * 1), + Math.sqrt(-1 * -1 + 2 * 2), Math.sqrt(0 * 0 + 3 * 3), + Math.sqrt(1 * 1 + 0 * 0), Math.sqrt(2 * 2 + -1 * -1) + ]); + expect(result.shape).toEqual([2, 3]); + }); + + it('complex64 rank-3', async () => { + const a = tf.complex( + [[[-3, -2], [-1, 0]], [[1, 2], [3, 4]]], + [[[4, 1], [2, 3]], [[0, -1], [-3, -4]]]); + const result = tf.abs(a); + expectArraysClose(await result.data(), [ + Math.sqrt(-3 * -3 + 4 * 4), Math.sqrt(-2 * -2 + 1 * 1), + Math.sqrt(-1 * -1 + 2 * 2), Math.sqrt(0 * 0 + 3 * 3), + Math.sqrt(1 * 1 + 0 * 0), Math.sqrt(2 * 2 + -1 * -1), + Math.sqrt(3 * 3 + -3 * -3), Math.sqrt(4 * 4 + -4 * -4) + ]); + expect(result.shape).toEqual([2, 2, 2]); + }); + + it('is underflow-safe for complex64', async () => { + const floatBits = tf.backend().floatPrecision(); + let small; + switch (floatBits) { + case 32: + small = 1e-30; + break; + case 16: + small = 1e-4; + break; + default: + throw new Error(`Test not implemented for ENV.engine.floatPrecision()=${ + floatBits}.`); + } + + const a = tf.complex([small, 0, small, 0], [small, small, 0, 0]); + const result = tf.abs(a); + expectArraysClose( + await result.data(), + [ + Math.hypot(small, small), Math.hypot(0, small), Math.hypot(small, 0), + Math.hypot(0, 0) + ], + /*tolerance=*/small / 100); + expect(result.shape).toEqual([4]); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([1, -2, 0, 3, -0.1, NaN]); + const result = tf.abs(a); + expectArraysClose(await result.data(), [1, 2, 0, 3, 0.1, NaN]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(4); + const dy = tf.scalar(8); + + const da = tf.grad(a => tf.abs(a))(a, dy); + + expect(da.shape).toEqual(a.shape); + expect(da.dtype).toEqual('float32'); + expectArraysClose(await da.data(), [8 * 1]); + }); + + it('gradient with clones', () => { + const a = tf.scalar(4); + const dy = tf.scalar(8); + + const da = tf.grad(a => a.clone().abs().clone())(a, dy); + + expect(da.shape).toEqual(a.shape); + expect(da.dtype).toEqual('float32'); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([1, 2, -3, 5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const da = tf.grad(a => tf.abs(a))(a, dy); + + expect(da.shape).toEqual(a.shape); + expect(da.dtype).toEqual('float32'); + expectArraysClose(await da.data(), [1 * 1, 2 * 1, 3 * -1, 4 * 1]); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([3, -1, -2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const da = tf.grad(a => tf.abs(a))(a, dy); + + expect(da.shape).toEqual(a.shape); + expect(da.dtype).toEqual('float32'); + expectArraysClose(await da.data(), [1 * 1, 2 * -1, 3 * -1, 4 * 1]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.abs({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'abs' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const result = tf.abs([1, -2, 0, 3, -0.1]); + expectArraysClose(await result.data(), [1, 2, 0, 3, 0.1]); + }); + + it('throws for string tensor', () => { + expect(() => tf.abs('q')) + .toThrowError(/Argument 'x' passed to 'abs' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/absolute_difference.ts b/tfjs-core/src/ops/absolute_difference.ts index 0c6a1fe625d..f9abfe32c30 100644 --- a/tfjs-core/src/ops/absolute_difference.ts +++ b/tfjs-core/src/ops/absolute_difference.ts @@ -19,11 +19,13 @@ import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {assertShapesMatch} from '../util'; + +import {abs} from './abs'; import {computeWeightedLoss} from './compute_weighted_loss'; import {Reduction} from './loss_ops_utils'; import {op} from './operation'; import {sub} from './sub'; -import {abs} from './unary_ops'; + /** * Computes the absolute difference loss between two tensors. diff --git a/tfjs-core/src/ops/acos.ts b/tfjs-core/src/ops/acos.ts new file mode 100644 index 00000000000..bcda5c5f8b6 --- /dev/null +++ b/tfjs-core/src/ops/acos.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ENGINE} from '../engine'; +import {Acos, AcosInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes acos of the input `tf.Tensor` element-wise: `acos(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.acos().print(); // or tf.acos(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function acos_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'acos'); + const inputs: AcosInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.acos($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Acos); +} +export const acos = op({acos_}); diff --git a/tfjs-core/src/ops/acos_test.ts b/tfjs-core/src/ops/acos_test.ts new file mode 100644 index 00000000000..a78e6bf07c4 --- /dev/null +++ b/tfjs-core/src/ops/acos_test.ts @@ -0,0 +1,123 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('acos', ALL_ENVS, () => { + it('basic', async () => { + const values = [.1, -3, 2, 7, -4]; + const a = tf.tensor1d(values); + const result = tf.acos(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.acos(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([4, NaN, 0]); + const res = tf.acos(a); + expectArraysClose(await res.data(), [Math.acos(4), NaN, Math.acos(0)]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.acos(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), [(-1 * 8) / Math.sqrt(1 - (0.5 * 0.5))]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.acos(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), [(-1 * 8) / Math.sqrt(1 - (0.5 * 0.5))]); + }); + + it('gradients: Tensor1D', async () => { + const aValues = [-0.1, 0.2, 0.3, -0.5]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor1d(aValues); + const dy = tf.tensor1d(dyValues); + + const gradients = tf.grad(a => tf.acos(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = + (-1 * dyValues[i]) / Math.sqrt(1 - (aValues[i] * aValues[i])); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('gradients: Tensor2D', async () => { + const aValues = [-0.3, 0.1, 0.2, 0.3]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor2d(aValues, [2, 2]); + const dy = tf.tensor2d(dyValues, [2, 2]); + + const gradients = tf.grad(a => tf.acos(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = + (-1 * dyValues[i]) / Math.sqrt(1 - (aValues[i] * aValues[i])); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.acos({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'acos' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const values = [.1, -3, 2, 7, -4]; + const result = tf.acos(values); + + const expected = []; + for (let i = 0; i < values.length; i++) { + expected[i] = Math.acos(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('throws for string tensor', () => { + expect(() => tf.acos('q')) + .toThrowError(/Argument 'x' passed to 'acos' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/acosh.ts b/tfjs-core/src/ops/acosh.ts new file mode 100644 index 00000000000..ff9cd794da3 --- /dev/null +++ b/tfjs-core/src/ops/acosh.ts @@ -0,0 +1,49 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Acosh, AcoshInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes the inverse hyperbolic cos of the input `tf.Tensor` element-wise: + * `acosh(x)` + * + * ```js + * const x = tf.tensor1d([10, 1, 3, 5.7]); + * + * x.acosh().print(); // or tf.acosh(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function acosh_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'acosh'); + const inputs: AcoshInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.acosh($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Acosh); +} +export const acosh = op({acosh_}); diff --git a/tfjs-core/src/ops/acosh_test.ts b/tfjs-core/src/ops/acosh_test.ts new file mode 100644 index 00000000000..400eb39f23f --- /dev/null +++ b/tfjs-core/src/ops/acosh_test.ts @@ -0,0 +1,148 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('acosh', ALL_ENVS, () => { + it('basic', async () => { + const values = [2, 3, 4, 5, 6]; + const a = tf.tensor1d(values); + const result = tf.acosh(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.acosh(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('scalar', async () => { + const value = 2; + const a = tf.scalar(value); + const result = tf.acosh(a); + + const expected = [Math.acosh(value)]; + expectArraysClose(await result.data(), expected); + }); + + it('tensor2d', async () => { + const values = [2, 3, 4, 5]; + const a = tf.tensor2d(values, [2, 2]); + const result = tf.acosh(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.acosh(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([4, NaN, 2]); + const res = tf.acosh(a); + expectArraysClose(await res.data(), [Math.acosh(4), NaN, Math.acosh(2)]); + }); + + it('NaN outside function domain', async () => { + const a = tf.tensor1d([4, -1, 2]); + const res = tf.acosh(a); + expectArraysClose(await res.data(), [Math.acosh(4), NaN, Math.acosh(2)]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(1.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.acosh(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), [8.0 / Math.sqrt(1.5 * 1.5 - 1.0)]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(1.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.acosh(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), [8.0 / Math.sqrt(1.5 * 1.5 - 1.0)]); + }); + + it('gradients: Tensor1D', async () => { + const aValues = [2, 3, 5, 10]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor1d(aValues); + const dy = tf.tensor1d(dyValues); + + const gradients = tf.grad(a => tf.acosh(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] / Math.sqrt(Math.pow(aValues[i], 2) - 1.0); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('gradients: Tensor2D', async () => { + const aValues = [2, 3, 5, 7]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor2d(aValues, [2, 2]); + const dy = tf.tensor2d(dyValues, [2, 2]); + + const gradients = tf.grad(a => tf.acosh(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] / Math.sqrt(Math.pow(aValues[i], 2) - 1.0); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.acosh({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'acosh' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const values = [2, 3, 4, 5, 6]; + const result = tf.acosh(values); + + const expected = []; + for (let i = 0; i < values.length; i++) { + expected[i] = Math.acosh(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('throws for string tensor', () => { + expect(() => tf.acosh('q')) + .toThrowError(/Argument 'x' passed to 'acosh' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/asin.ts b/tfjs-core/src/ops/asin.ts new file mode 100644 index 00000000000..b11b3808b51 --- /dev/null +++ b/tfjs-core/src/ops/asin.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Asin, AsinInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes asin of the input `tf.Tensor` element-wise: `asin(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.asin().print(); // or tf.asin(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function asin_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'asin'); + const inputs: AsinInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.asin($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Asin); +} +export const asin = op({asin_}); diff --git a/tfjs-core/src/ops/asin_test.ts b/tfjs-core/src/ops/asin_test.ts new file mode 100644 index 00000000000..ff75ed18f0c --- /dev/null +++ b/tfjs-core/src/ops/asin_test.ts @@ -0,0 +1,119 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('asin', ALL_ENVS, () => { + it('basic', async () => { + const values = [.1, -3, 2, 7, -4]; + const a = tf.tensor1d(values); + const result = tf.asin(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.asin(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([4, NaN, 0]); + const res = tf.asin(a); + expectArraysClose(await res.data(), [Math.asin(4), NaN, Math.asin(0)]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.asin(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 / Math.sqrt(1 - (0.5 * 0.5))]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.asin(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 / Math.sqrt(1 - (0.5 * 0.5))]); + }); + + it('gradients: Tensor1D', async () => { + const aValues = [-0.1, 0.2, 0.3, -0.5]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor1d(aValues); + const dy = tf.tensor1d(dyValues); + + const gradients = tf.grad(a => tf.asin(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] / Math.sqrt(1 - (aValues[i] * aValues[i])); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('gradients: Tensor2D', async () => { + const aValues = [-0.3, 0.1, 0.2, 0.3]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor2d(aValues, [2, 2]); + const dy = tf.tensor2d(dyValues, [2, 2]); + + const gradients = tf.grad(a => tf.asin(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] / Math.sqrt(1 - (aValues[i] * aValues[i])); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.asin({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'asin' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const values = [.1, -3, 2, 7, -4]; + const result = tf.asin(values); + + const expected = []; + for (let i = 0; i < values.length; i++) { + expected[i] = Math.asin(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('throws for string tensor', () => { + expect(() => tf.asin('q')) + .toThrowError(/Argument 'x' passed to 'asin' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/asinh.ts b/tfjs-core/src/ops/asinh.ts new file mode 100644 index 00000000000..90083da6544 --- /dev/null +++ b/tfjs-core/src/ops/asinh.ts @@ -0,0 +1,50 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Asinh, AsinhInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes inverse hyperbolic sin of the input `tf.Tensor` element-wise: + * `asinh(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.asinh().print(); // or tf.asinh(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function asinh_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'asinh'); + + const inputs: AsinhInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.asinh($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Asinh); +} +export const asinh = op({asinh_}); diff --git a/tfjs-core/src/ops/asinh_test.ts b/tfjs-core/src/ops/asinh_test.ts new file mode 100644 index 00000000000..83251b18030 --- /dev/null +++ b/tfjs-core/src/ops/asinh_test.ts @@ -0,0 +1,139 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('asinh', ALL_ENVS, () => { + it('basic', async () => { + const values = [1, -3, 2, 7, -4]; + const a = tf.tensor1d(values); + const result = tf.asinh(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.asinh(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('scalar', async () => { + const a = tf.scalar(1); + const result = tf.asinh(a); + + const expected = [Math.asinh(1)]; + expectArraysClose(await result.data(), expected); + }); + + it('tensor2D', async () => { + const values = [1, -3, 2, 7]; + const a = tf.tensor2d(values, [2, 2]); + const result = tf.asinh(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.asinh(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([4, NaN, 0]); + const res = tf.asinh(a); + expectArraysClose(await res.data(), [Math.asinh(4), NaN, Math.asinh(0)]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.asinh(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 / Math.sqrt(1.0 + 0.5 * 0.5)]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.asinh(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 / Math.sqrt(1.0 + 0.5 * 0.5)]); + }); + + it('gradients: Tensor1D', async () => { + const aValues = [-1, 2, 3, -5]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor1d(aValues); + const dy = tf.tensor1d(dyValues); + + const gradients = tf.grad(a => tf.asinh(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] / Math.sqrt(1 + aValues[i] * aValues[i]); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('gradients: Tensor2D', async () => { + const aValues = [-3, 1, 2, 3]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor2d(aValues, [2, 2]); + const dy = tf.tensor2d(dyValues, [2, 2]); + + const gradients = tf.grad(a => tf.asinh(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] / Math.sqrt(1 + aValues[i] * aValues[i]); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.asinh({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'asinh' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const values = [1, -3, 2, 7, -4]; + const result = tf.asinh(values); + + const expected = []; + for (let i = 0; i < values.length; i++) { + expected[i] = Math.asinh(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('throws for string tensor', () => { + expect(() => tf.asinh('q')) + .toThrowError(/Argument 'x' passed to 'asinh' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/atan.ts b/tfjs-core/src/ops/atan.ts new file mode 100644 index 00000000000..dbc09500016 --- /dev/null +++ b/tfjs-core/src/ops/atan.ts @@ -0,0 +1,49 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Atan, AtanInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes atan of the input `tf.Tensor` element-wise: `atan(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.atan().print(); // or tf.atan(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function atan_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'atan'); + + const inputs: AtanInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.atan($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Atan); +} +export const atan = op({atan_}); diff --git a/tfjs-core/src/ops/atan_test.ts b/tfjs-core/src/ops/atan_test.ts new file mode 100644 index 00000000000..79b7b85796a --- /dev/null +++ b/tfjs-core/src/ops/atan_test.ts @@ -0,0 +1,130 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('atan', ALL_ENVS, () => { + it('basic', async () => { + const values = [1, -3, 2, 7, -4]; + const a = tf.tensor1d(values); + const result = tf.atan(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.atan(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('6D atan', async () => { + const a = tf.range(1, 65).reshape([2, 2, 2, 2, 2, 2]); + const result = tf.atan(a); + + const expected = []; + for (let i = 1; i < 65; ++i) { + expected[i - 1] = Math.atan(i); + } + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([4, NaN, 0]); + const res = tf.atan(a); + expectArraysClose(await res.data(), [Math.atan(4), NaN, Math.atan(0)]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.atan(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 / (1 + (0.5 * 0.5))]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.atan(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 / (1 + (0.5 * 0.5))]); + }); + + it('gradients: Tensor1D', async () => { + const aValues = [-0.1, 0.2, 0.3, -0.5]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor1d(aValues); + const dy = tf.tensor1d(dyValues); + + const gradients = tf.grad(a => tf.atan(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] / (1 + (aValues[i] * aValues[i])); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('gradients: Tensor2D', async () => { + const aValues = [-0.3, 0.1, 0.2, 0.3]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor2d(aValues, [2, 2]); + const dy = tf.tensor2d(dyValues, [2, 2]); + + const gradients = tf.grad(a => tf.atan(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] / (1 + (aValues[i] * aValues[i])); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.atan({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'atan' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const values = [1, -3, 2, 7, -4]; + const result = tf.atan(values); + + const expected = []; + for (let i = 0; i < values.length; i++) { + expected[i] = Math.atan(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('throws for string tensor', () => { + expect(() => tf.atan('q')) + .toThrowError(/Argument 'x' passed to 'atan' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/atanh.ts b/tfjs-core/src/ops/atanh.ts new file mode 100644 index 00000000000..083c333b9a8 --- /dev/null +++ b/tfjs-core/src/ops/atanh.ts @@ -0,0 +1,50 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Atanh, AtanhInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes inverse hyperbolic tan of the input `tf.Tensor` element-wise: + * `atanh(x)` + * + * ```js + * const x = tf.tensor1d([0, .1, -.1, .7]); + * + * x.atanh().print(); // or tf.atanh(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function atanh_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'atanh'); + + const inputs: AtanhInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.atanh($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Atanh); +} +export const atanh = op({atanh_}); diff --git a/tfjs-core/src/ops/atanh_test.ts b/tfjs-core/src/ops/atanh_test.ts new file mode 100644 index 00000000000..26e3b4e89c5 --- /dev/null +++ b/tfjs-core/src/ops/atanh_test.ts @@ -0,0 +1,140 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('atanh', ALL_ENVS, () => { + it('basic', async () => { + const values = [-0.25, 0.25, 0.5, .75, -0.4]; + const a = tf.tensor1d(values); + const result = tf.atanh(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.atanh(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('scalar', async () => { + const value = 0.2; + const a = tf.scalar(value); + const result = tf.atanh(a); + + const expected = [Math.atanh(value)]; + expectArraysClose(await result.data(), expected); + }); + + it('tensor2d', async () => { + const values = [0.2, 0.3, 0.4, 0.5]; + const a = tf.tensor2d(values, [2, 2]); + const result = tf.atanh(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.atanh(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([0.5, NaN, 0]); + const res = tf.atanh(a); + expectArraysClose(await res.data(), [Math.atanh(0.5), NaN, Math.atanh(0)]); + }); + + it('NaN outside function domain', async () => { + const a = tf.tensor1d([-2, 0, 2]); + const res = tf.atanh(a); + expectArraysClose(await res.data(), [NaN, Math.atanh(0), NaN]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.atanh(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 / (1 - 0.5 * 0.5)]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.atanh(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 / (1 - 0.5 * 0.5)]); + }); + + it('gradients: Tensor1D', async () => { + const aValues = [-0.1, 0.2, 0.3, -0.5]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor1d(aValues); + const dy = tf.tensor1d(dyValues); + + const gradients = tf.grad(a => tf.atanh(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] / (1 - Math.pow(aValues[i], 2)); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('gradients: Tensor2D', async () => { + const aValues = [-0.3, 0.1, 0.2, 0.3]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor2d(aValues, [2, 2]); + const dy = tf.tensor2d(dyValues, [2, 2]); + + const gradients = tf.grad(a => tf.atanh(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] / (1 - Math.pow(aValues[i], 2)); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.atanh({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'atanh' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const result = tf.atanh(0.2); + expectArraysClose(await result.data(), [Math.atanh(0.2)]); + }); + + it('throws for string tensor', () => { + expect(() => tf.atanh('q')) + .toThrowError(/Argument 'x' passed to 'atanh' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/huber_loss.ts b/tfjs-core/src/ops/huber_loss.ts index 992f8207357..9ee27b0c9d4 100644 --- a/tfjs-core/src/ops/huber_loss.ts +++ b/tfjs-core/src/ops/huber_loss.ts @@ -19,6 +19,8 @@ import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {assertShapesMatch} from '../util'; + +import {abs} from './abs'; import {add} from './add'; import {computeWeightedLoss} from './compute_weighted_loss'; import {Reduction} from './loss_ops_utils'; @@ -28,7 +30,6 @@ import {op} from './operation'; import {square} from './square'; import {sub} from './sub'; import {scalar} from './tensor_ops'; -import {abs} from './unary_ops'; /** * Computes the huber loss between two tensors. diff --git a/tfjs-core/src/ops/norm.ts b/tfjs-core/src/ops/norm.ts index c9f87121374..8a68364045f 100644 --- a/tfjs-core/src/ops/norm.ts +++ b/tfjs-core/src/ops/norm.ts @@ -20,6 +20,7 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {parseAxisParam} from '../util'; +import {abs} from './abs'; import * as axis_util from './axis_util'; import {max} from './max'; import {min} from './min'; @@ -29,7 +30,7 @@ import {reshape} from './reshape'; import {square} from './square'; import {sum} from './sum'; import {scalar} from './tensor_ops'; -import {abs, sqrt} from './unary_ops'; +import {sqrt} from './unary_ops'; /** * Computes the norm of scalar, vectors, and matrices. diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 50fee3c0e72..99ecd497cbf 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -16,13 +16,20 @@ */ // Modularized ops. +export {abs} from './abs'; +export {acos} from './acos'; +export {acosh} from './acosh'; export {add} from './add'; export {addN} from './add_n'; export {all} from './all'; export {any} from './any'; export {argMax} from './arg_max'; export {argMin} from './arg_min'; +export {asin} from './asin'; +export {asinh} from './asinh'; +export {atan} from './atan'; export {atan2} from './atan2'; +export {atanh} from './atanh'; export {avgPool} from './avg_pool'; export {avgPool3d} from './avg_pool_3d'; export {basicLSTMCell} from './basic_lstm_cell'; diff --git a/tfjs-core/src/ops/sigmoid_cross_entropy.ts b/tfjs-core/src/ops/sigmoid_cross_entropy.ts index 36123ab567c..7316002b2da 100644 --- a/tfjs-core/src/ops/sigmoid_cross_entropy.ts +++ b/tfjs-core/src/ops/sigmoid_cross_entropy.ts @@ -20,6 +20,7 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {assertShapesMatch} from '../util'; +import {abs} from './abs'; import {add} from './add'; import {computeWeightedLoss} from './compute_weighted_loss'; import {Reduction} from './loss_ops_utils'; @@ -29,7 +30,7 @@ import {op} from './operation'; import {relu} from './relu'; import {sub} from './sub'; import {scalar} from './tensor_ops'; -import {abs, exp, log1p} from './unary_ops'; +import {exp, log1p} from './unary_ops'; function sigmoidCrossEntropyWithLogits_( labels: T|TensorLike, logits: T|TensorLike): O { diff --git a/tfjs-core/src/ops/unary_ops.ts b/tfjs-core/src/ops/unary_ops.ts index 3c2441a3066..dfc16f520f7 100644 --- a/tfjs-core/src/ops/unary_ops.ts +++ b/tfjs-core/src/ops/unary_ops.ts @@ -298,35 +298,6 @@ function reciprocal_(x: T|TensorLike): T { }, {$x}, grad); } -/** - * Computes absolute value element-wise: `abs(x)` - * - * ```js - * const x = tf.tensor1d([-1, 2, -3, 4]); - * - * x.abs().print(); // or tf.abs(x) - * ``` - * @param x The input `tf.Tensor`. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function abs_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'abs'); - - if ($x.dtype === 'complex64') { - return ENGINE.runKernelFunc(backend => backend.complexAbs($x), {$x}); - } - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {x: () => dy.mul($x.toFloat().step(-1))} as {x: () => T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.abs($x); - save([$x]); - return res; - }, {x: $x}, grad, 'Abs'); -} - /** * Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)` * @@ -520,92 +491,6 @@ function tan_(x: T|TensorLike): T { }, {$x}, grad); } -/** - * Computes asin of the input `tf.Tensor` element-wise: `asin(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.asin().print(); // or tf.asin(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function asin_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'asin'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return { - // tslint:disable-next-line: no-unnecessary-type-assertion - $x: () => dy.div(scalar(1).sub($x.toFloat().square()).sqrt()) as T - }; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.asin($x); - save([$x]); - return res; - }, {$x}, grad); -} - -/** - * Computes acos of the input `tf.Tensor` element-wise: `acos(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.acos().print(); // or tf.acos(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function acos_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'acos'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return { - $x: () => { - const a = $x.toFloat().square(); - const b = scalar(1).sub(a).sqrt(); - // tslint:disable-next-line: no-unnecessary-type-assertion - return (dy.div(b) as T).neg(); - } - - }; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.acos($x); - save([$x]); - return res; - }, {$x}, grad); -} - -/** - * Computes atan of the input `tf.Tensor` element-wise: `atan(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.atan().print(); // or tf.atan(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function atan_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'atan'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {$x: () => dy.div($x.toFloat().square().add(1))} as {$x: () => T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.atan($x); - save([$x]); - return res; - }, {$x}, grad); -} - /** * Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)` * @@ -688,97 +573,6 @@ function tanh_(x: T|TensorLike): T { outputsToSave); } -/** - * Computes inverse hyperbolic sin of the input `tf.Tensor` element-wise: - * `asinh(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.asinh().print(); // or tf.asinh(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function asinh_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'asinh'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return { - $x: () => { - const a = scalar(1).add($x.toFloat().square()).sqrt(); - // tslint:disable-next-line: no-unnecessary-type-assertion - return dy.div(a) as T; - } - }; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.asinh($x); - save([$x]); - return res; - }, {$x}, grad); -} - -/** - * Computes the inverse hyperbolic cos of the input `tf.Tensor` element-wise: - * `acosh(x)` - * - * ```js - * const x = tf.tensor1d([10, 1, 3, 5.7]); - * - * x.acosh().print(); // or tf.acosh(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function acosh_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'acosh'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return { - $x: () => { - const a = $x.toFloat().square().sub(1).sqrt(); - // tslint:disable-next-line: no-unnecessary-type-assertion - return dy.div(a) as T; - } - }; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.acosh($x); - save([$x]); - return res; - }, {$x}, grad); -} - -/** - * Computes inverse hyperbolic tan of the input `tf.Tensor` element-wise: - * `atanh(x)` - * - * ```js - * const x = tf.tensor1d([0, .1, -.1, .7]); - * - * x.atanh().print(); // or tf.atanh(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function atanh_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'atanh'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {$x: () => dy.div(scalar(1).sub($x.toFloat().square()))} as - {$x: () => T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.atanh($x); - save([$x]); - return res; - }, {$x}, grad); -} - /** * Computes gause error function of the input `tf.Tensor` element-wise: * `erf(x)` @@ -837,13 +631,6 @@ function step_(x: T|TensorLike, alpha = 0.0): T { return ENGINE.runKernelFunc(backend => backend.step($x, alpha), {$x}, grad); } -export const abs = op({abs_}); -export const acos = op({acos_}); -export const acosh = op({acosh_}); -export const asin = op({asin_}); -export const asinh = op({asinh_}); -export const atan = op({atan_}); -export const atanh = op({atanh_}); export const clipByValue = op({clipByValue_}); export const cos = op({cos_}); export const cosh = op({cosh_}); diff --git a/tfjs-core/src/ops/unary_ops_test.ts b/tfjs-core/src/ops/unary_ops_test.ts index e7538511175..3bb25750325 100644 --- a/tfjs-core/src/ops/unary_ops_test.ts +++ b/tfjs-core/src/ops/unary_ops_test.ts @@ -20,153 +20,6 @@ import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose, TEST_EPSILON_FLOAT16} from '../test_util'; import * as util from '../util'; -describeWithFlags('abs', ALL_ENVS, () => { - it('basic', async () => { - const a = tf.tensor1d([1, -2, 0, 3, -0.1]); - const result = tf.abs(a); - expectArraysClose(await result.data(), [1, 2, 0, 3, 0.1]); - }); - - it('5D', async () => { - const a = tf.tensor5d([1, -2, 0, -3], [1, 2, 2, 1, 1]); - const result = tf.abs(a); - expectArraysClose(await result.data(), [1, 2, 0, 3]); - }); - - it('6D', async () => { - const a = tf.tensor6d([1, -2, 5, -3, -1, 4, 7, 8], [1, 2, 2, 2, 1, 1]); - const result = tf.abs(a); - expectArraysClose(await result.data(), [1, 2, 5, 3, 1, 4, 7, 8]); - }); - - it('complex64 rank-1', async () => { - const a = tf.complex([-2, -1, 0, 1, 2], [1, 2, 3, 0, -1]); - const result = tf.abs(a); - expectArraysClose(await result.data(), [ - Math.sqrt(-2 * -2 + 1 * 1), Math.sqrt(-1 * -1 + 2 * 2), - Math.sqrt(0 * 0 + 3 * 3), Math.sqrt(1 * 1 + 0 * 0), - Math.sqrt(2 * 2 + -1 * -1) - ]); - expect(result.shape).toEqual([5]); - }); - - it('complex64 rank-2', async () => { - const a = tf.complex([[-3, -2, -1], [0, 1, 2]], [[4, 1, 2], [3, 0, -1]]); - const result = tf.abs(a); - expectArraysClose(await result.data(), [ - Math.sqrt(-3 * -3 + 4 * 4), Math.sqrt(-2 * -2 + 1 * 1), - Math.sqrt(-1 * -1 + 2 * 2), Math.sqrt(0 * 0 + 3 * 3), - Math.sqrt(1 * 1 + 0 * 0), Math.sqrt(2 * 2 + -1 * -1) - ]); - expect(result.shape).toEqual([2, 3]); - }); - - it('complex64 rank-3', async () => { - const a = tf.complex( - [[[-3, -2], [-1, 0]], [[1, 2], [3, 4]]], - [[[4, 1], [2, 3]], [[0, -1], [-3, -4]]]); - const result = tf.abs(a); - expectArraysClose(await result.data(), [ - Math.sqrt(-3 * -3 + 4 * 4), Math.sqrt(-2 * -2 + 1 * 1), - Math.sqrt(-1 * -1 + 2 * 2), Math.sqrt(0 * 0 + 3 * 3), - Math.sqrt(1 * 1 + 0 * 0), Math.sqrt(2 * 2 + -1 * -1), - Math.sqrt(3 * 3 + -3 * -3), Math.sqrt(4 * 4 + -4 * -4) - ]); - expect(result.shape).toEqual([2, 2, 2]); - }); - - it('is underflow-safe for complex64', async () => { - const floatBits = tf.backend().floatPrecision(); - let small; - switch (floatBits) { - case 32: - small = 1e-30; - break; - case 16: - small = 1e-4; - break; - default: - throw new Error(`Test not implemented for ENV.engine.floatPrecision()=${ - floatBits}.`); - } - - const a = tf.complex([small, 0, small, 0], [small, small, 0, 0]); - const result = tf.abs(a); - expectArraysClose( - await result.data(), - [ - Math.hypot(small, small), Math.hypot(0, small), Math.hypot(small, 0), - Math.hypot(0, 0) - ], - /*tolerance=*/small / 100); - expect(result.shape).toEqual([4]); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([1, -2, 0, 3, -0.1, NaN]); - const result = tf.abs(a); - expectArraysClose(await result.data(), [1, 2, 0, 3, 0.1, NaN]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(4); - const dy = tf.scalar(8); - - const da = tf.grad(a => tf.abs(a))(a, dy); - - expect(da.shape).toEqual(a.shape); - expect(da.dtype).toEqual('float32'); - expectArraysClose(await da.data(), [8 * 1]); - }); - - it('gradient with clones', () => { - const a = tf.scalar(4); - const dy = tf.scalar(8); - - const da = tf.grad(a => a.clone().abs().clone())(a, dy); - - expect(da.shape).toEqual(a.shape); - expect(da.dtype).toEqual('float32'); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([1, 2, -3, 5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const da = tf.grad(a => tf.abs(a))(a, dy); - - expect(da.shape).toEqual(a.shape); - expect(da.dtype).toEqual('float32'); - expectArraysClose(await da.data(), [1 * 1, 2 * 1, 3 * -1, 4 * 1]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([3, -1, -2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const da = tf.grad(a => tf.abs(a))(a, dy); - - expect(da.shape).toEqual(a.shape); - expect(da.dtype).toEqual('float32'); - expectArraysClose(await da.data(), [1 * 1, 2 * -1, 3 * -1, 4 * 1]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.abs({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'abs' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const result = tf.abs([1, -2, 0, 3, -0.1]); - expectArraysClose(await result.data(), [1, 2, 0, 3, 0.1]); - }); - - it('throws for string tensor', () => { - expect(() => tf.abs('q')) - .toThrowError(/Argument 'x' passed to 'abs' must be numeric/); - }); -}); - describeWithFlags('step', ALL_ENVS, () => { it('with 1d tensor', async () => { const a = tf.tensor1d([1, -2, -.01, 3, -0.1]); @@ -1780,58 +1633,58 @@ describeWithFlags('tan', ALL_ENVS, () => { }); }); -describeWithFlags('asin', ALL_ENVS, () => { +describeWithFlags('sinh', ALL_ENVS, () => { it('basic', async () => { - const values = [.1, -3, 2, 7, -4]; + const values = [1, -3, 2, -1, -4]; const a = tf.tensor1d(values); - const result = tf.asin(a); + const result = tf.sinh(a); const expected = []; for (let i = 0; i < a.size; i++) { - expected[i] = Math.asin(values[i]); + expected[i] = Math.sinh(values[i]); } expectArraysClose(await result.data(), expected); }); it('propagates NaNs', async () => { const a = tf.tensor1d([4, NaN, 0]); - const res = tf.asin(a); - expectArraysClose(await res.data(), [Math.asin(4), NaN, Math.asin(0)]); + const res = tf.sinh(a); + expectArraysClose(await res.data(), [Math.sinh(4), NaN, Math.sinh(0)]); }); it('gradients: Scalar', async () => { const a = tf.scalar(0.5); const dy = tf.scalar(8); - const gradients = tf.grad(a => tf.asin(a))(a, dy); + const gradients = tf.grad(a => tf.sinh(a))(a, dy); expect(gradients.shape).toEqual(a.shape); expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 / Math.sqrt(1 - (0.5 * 0.5))]); + expectArraysClose(await gradients.data(), [8 * Math.cosh(0.5)]); }); it('gradient with clones', async () => { const a = tf.scalar(0.5); const dy = tf.scalar(8); - const gradients = tf.grad(a => tf.asin(a.clone()).clone())(a, dy); + const gradients = tf.grad(a => tf.sinh(a.clone()).clone())(a, dy); expect(gradients.shape).toEqual(a.shape); expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 / Math.sqrt(1 - (0.5 * 0.5))]); + expectArraysClose(await gradients.data(), [8 * Math.cosh(0.5)]); }); it('gradients: Tensor1D', async () => { - const aValues = [-0.1, 0.2, 0.3, -0.5]; + const aValues = [-1, 2, 3, -5]; const dyValues = [1, 2, 3, 4]; const a = tf.tensor1d(aValues); const dy = tf.tensor1d(dyValues); - const gradients = tf.grad(a => tf.asin(a))(a, dy); + const gradients = tf.grad(a => tf.sinh(a))(a, dy); const expected = []; for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] / Math.sqrt(1 - (aValues[i] * aValues[i])); + expected[i] = dyValues[i] * Math.cosh(aValues[i]); } expect(gradients.shape).toEqual(a.shape); @@ -1840,16 +1693,16 @@ describeWithFlags('asin', ALL_ENVS, () => { }); it('gradients: Tensor2D', async () => { - const aValues = [-0.3, 0.1, 0.2, 0.3]; + const aValues = [-3, 1, 2, 3]; const dyValues = [1, 2, 3, 4]; const a = tf.tensor2d(aValues, [2, 2]); const dy = tf.tensor2d(dyValues, [2, 2]); - const gradients = tf.grad(a => tf.asin(a))(a, dy); + const gradients = tf.grad(a => tf.sinh(a))(a, dy); const expected = []; for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] / Math.sqrt(1 - (aValues[i] * aValues[i])); + expected[i] = dyValues[i] * Math.cosh(aValues[i]); } expect(gradients.shape).toEqual(a.shape); @@ -1858,82 +1711,80 @@ describeWithFlags('asin', ALL_ENVS, () => { }); it('throws when passed a non-tensor', () => { - expect(() => tf.asin({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'asin' must be a Tensor/); + expect(() => tf.sinh({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'sinh' must be a Tensor/); }); it('accepts a tensor-like object', async () => { - const values = [.1, -3, 2, 7, -4]; - const result = tf.asin(values); + const values = [1, -3, 2, -1, -4]; + const result = tf.sinh(values); const expected = []; for (let i = 0; i < values.length; i++) { - expected[i] = Math.asin(values[i]); + expected[i] = Math.sinh(values[i]); } expectArraysClose(await result.data(), expected); }); it('throws for string tensor', () => { - expect(() => tf.asin('q')) - .toThrowError(/Argument 'x' passed to 'asin' must be numeric/); + expect(() => tf.sinh('q')) + .toThrowError(/Argument 'x' passed to 'sinh' must be numeric/); }); }); -describeWithFlags('acos', ALL_ENVS, () => { +describeWithFlags('cosh', ALL_ENVS, () => { it('basic', async () => { - const values = [.1, -3, 2, 7, -4]; + const values = [1, -3, 2, -1, -4]; const a = tf.tensor1d(values); - const result = tf.acos(a); + const result = tf.cosh(a); const expected = []; for (let i = 0; i < a.size; i++) { - expected[i] = Math.acos(values[i]); + expected[i] = Math.cosh(values[i]); } + expectArraysClose(await result.data(), expected); }); it('propagates NaNs', async () => { const a = tf.tensor1d([4, NaN, 0]); - const res = tf.acos(a); - expectArraysClose(await res.data(), [Math.acos(4), NaN, Math.acos(0)]); + const res = tf.cosh(a); + expectArraysClose(await res.data(), [Math.cosh(4), NaN, Math.cosh(0)]); }); it('gradients: Scalar', async () => { const a = tf.scalar(0.5); const dy = tf.scalar(8); - const gradients = tf.grad(a => tf.acos(a))(a, dy); + const gradients = tf.grad(a => tf.cosh(a))(a, dy); expect(gradients.shape).toEqual(a.shape); expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), [(-1 * 8) / Math.sqrt(1 - (0.5 * 0.5))]); + expectArraysClose(await gradients.data(), [8 * Math.sinh(0.5)]); }); it('gradient with clones', async () => { const a = tf.scalar(0.5); const dy = tf.scalar(8); - const gradients = tf.grad(a => tf.acos(a.clone()).clone())(a, dy); + const gradients = tf.grad(a => tf.cosh(a.clone()).clone())(a, dy); expect(gradients.shape).toEqual(a.shape); expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), [(-1 * 8) / Math.sqrt(1 - (0.5 * 0.5))]); + expectArraysClose(await gradients.data(), [8 * Math.sinh(0.5)]); }); it('gradients: Tensor1D', async () => { - const aValues = [-0.1, 0.2, 0.3, -0.5]; + const aValues = [-1, 2, 3, -5]; const dyValues = [1, 2, 3, 4]; const a = tf.tensor1d(aValues); const dy = tf.tensor1d(dyValues); - const gradients = tf.grad(a => tf.acos(a))(a, dy); + const gradients = tf.grad(a => tf.cosh(a))(a, dy); const expected = []; for (let i = 0; i < a.size; i++) { - expected[i] = - (-1 * dyValues[i]) / Math.sqrt(1 - (aValues[i] * aValues[i])); + expected[i] = dyValues[i] * Math.sinh(aValues[i]); } expect(gradients.shape).toEqual(a.shape); @@ -1942,17 +1793,16 @@ describeWithFlags('acos', ALL_ENVS, () => { }); it('gradients: Tensor2D', async () => { - const aValues = [-0.3, 0.1, 0.2, 0.3]; + const aValues = [-3, 1, 2, 3]; const dyValues = [1, 2, 3, 4]; const a = tf.tensor2d(aValues, [2, 2]); const dy = tf.tensor2d(dyValues, [2, 2]); - const gradients = tf.grad(a => tf.acos(a))(a, dy); + const gradients = tf.grad(a => tf.cosh(a))(a, dy); const expected = []; for (let i = 0; i < a.size; i++) { - expected[i] = - (-1 * dyValues[i]) / Math.sqrt(1 - (aValues[i] * aValues[i])); + expected[i] = dyValues[i] * Math.sinh(aValues[i]); } expect(gradients.shape).toEqual(a.shape); @@ -1961,90 +1811,83 @@ describeWithFlags('acos', ALL_ENVS, () => { }); it('throws when passed a non-tensor', () => { - expect(() => tf.acos({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'acos' must be a Tensor/); + expect(() => tf.cosh({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'cosh' must be a Tensor/); }); it('accepts a tensor-like object', async () => { - const values = [.1, -3, 2, 7, -4]; - const result = tf.acos(values); + const values = [1, -3, 2, -1, -4]; + const result = tf.cosh(values); const expected = []; for (let i = 0; i < values.length; i++) { - expected[i] = Math.acos(values[i]); + expected[i] = Math.cosh(values[i]); } + expectArraysClose(await result.data(), expected); }); it('throws for string tensor', () => { - expect(() => tf.acos('q')) - .toThrowError(/Argument 'x' passed to 'acos' must be numeric/); + expect(() => tf.cosh('q')) + .toThrowError(/Argument 'x' passed to 'cosh' must be numeric/); }); }); -describeWithFlags('atan', ALL_ENVS, () => { +describeWithFlags('tanh', ALL_ENVS, () => { it('basic', async () => { const values = [1, -3, 2, 7, -4]; const a = tf.tensor1d(values); - const result = tf.atan(a); + const result = tf.tanh(a); const expected = []; for (let i = 0; i < a.size; i++) { - expected[i] = Math.atan(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('6D atan', async () => { - const a = tf.range(1, 65).reshape([2, 2, 2, 2, 2, 2]); - const result = tf.atan(a); - - const expected = []; - for (let i = 1; i < 65; ++i) { - expected[i - 1] = Math.atan(i); + expected[i] = util.tanh(values[i]); } expectArraysClose(await result.data(), expected); }); it('propagates NaNs', async () => { const a = tf.tensor1d([4, NaN, 0]); - const res = tf.atan(a); - expectArraysClose(await res.data(), [Math.atan(4), NaN, Math.atan(0)]); + const res = tf.tanh(a); + expectArraysClose(await res.data(), [util.tanh(4), NaN, util.tanh(0)]); }); it('gradients: Scalar', async () => { const a = tf.scalar(0.5); const dy = tf.scalar(8); - const gradients = tf.grad(a => tf.atan(a))(a, dy); + const gradients = tf.grad(a => tf.tanh(a))(a, dy); expect(gradients.shape).toEqual(a.shape); expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 / (1 + (0.5 * 0.5))]); + expectArraysClose( + await gradients.data(), [8 * (1 - (Math.tanh(0.5) * Math.tanh(0.5)))]); }); it('gradient with clones', async () => { const a = tf.scalar(0.5); const dy = tf.scalar(8); - const gradients = tf.grad(a => tf.atan(a.clone()).clone())(a, dy); + const gradients = tf.grad(a => tf.tanh(a.clone()).clone())(a, dy); expect(gradients.shape).toEqual(a.shape); expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 / (1 + (0.5 * 0.5))]); + expectArraysClose( + await gradients.data(), [8 * (1 - (Math.tanh(0.5) * Math.tanh(0.5)))]); }); it('gradients: Tensor1D', async () => { - const aValues = [-0.1, 0.2, 0.3, -0.5]; + const aValues = [-1, 2, 3, -5]; const dyValues = [1, 2, 3, 4]; const a = tf.tensor1d(aValues); const dy = tf.tensor1d(dyValues); - const gradients = tf.grad(a => tf.atan(a))(a, dy); + const gradients = tf.grad(a => tf.tanh(a))(a, dy); const expected = []; for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] / (1 + (aValues[i] * aValues[i])); + expected[i] = + dyValues[i] * (1 - (Math.tanh(aValues[i]) * Math.tanh(aValues[i]))); } expect(gradients.shape).toEqual(a.shape); @@ -2053,16 +1896,17 @@ describeWithFlags('atan', ALL_ENVS, () => { }); it('gradients: Tensor2D', async () => { - const aValues = [-0.3, 0.1, 0.2, 0.3]; + const aValues = [-3, 1, 2, 3]; const dyValues = [1, 2, 3, 4]; const a = tf.tensor2d(aValues, [2, 2]); const dy = tf.tensor2d(dyValues, [2, 2]); - const gradients = tf.grad(a => tf.atan(a))(a, dy); + const gradients = tf.grad(a => tf.tanh(a))(a, dy); const expected = []; for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] / (1 + (aValues[i] * aValues[i])); + expected[i] = + dyValues[i] * (1 - (Math.tanh(aValues[i]) * Math.tanh(aValues[i]))); } expect(gradients.shape).toEqual(a.shape); @@ -2071,886 +1915,216 @@ describeWithFlags('atan', ALL_ENVS, () => { }); it('throws when passed a non-tensor', () => { - expect(() => tf.atan({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'atan' must be a Tensor/); + expect(() => tf.tanh({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'tanh' must be a Tensor/); }); it('accepts a tensor-like object', async () => { const values = [1, -3, 2, 7, -4]; - const result = tf.atan(values); + const result = tf.tanh(values); const expected = []; for (let i = 0; i < values.length; i++) { - expected[i] = Math.atan(values[i]); + expected[i] = util.tanh(values[i]); } expectArraysClose(await result.data(), expected); }); it('throws for string tensor', () => { - expect(() => tf.atan('q')) - .toThrowError(/Argument 'x' passed to 'atan' must be numeric/); + expect(() => tf.tanh('q')) + .toThrowError(/Argument 'x' passed to 'tanh' must be numeric/); }); }); -describeWithFlags('sinh', ALL_ENVS, () => { +describeWithFlags('clip', ALL_ENVS, () => { it('basic', async () => { - const values = [1, -3, 2, -1, -4]; - const a = tf.tensor1d(values); - const result = tf.sinh(a); + const a = tf.tensor1d([3, -1, 0, 100, -7, 2]); + const min = -1; + const max = 50; - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.sinh(values[i]); - } - expectArraysClose(await result.data(), expected); + const result = tf.clipByValue(a, min, max); + + expectArraysClose(await result.data(), [3, -1, 0, 50, -1, 2]); }); it('propagates NaNs', async () => { - const a = tf.tensor1d([4, NaN, 0]); - const res = tf.sinh(a); - expectArraysClose(await res.data(), [Math.sinh(4), NaN, Math.sinh(0)]); + const a = tf.tensor1d([3, -1, 0, 100, -7, 2, NaN]); + const min = -1; + const max = 50; + + const result = tf.clipByValue(a, min, max); + + expectArraysClose(await result.data(), [3, -1, 0, 50, -1, 2, NaN]); }); - it('gradients: Scalar', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); + it('min greater than max', () => { + const a = tf.tensor1d([3, -1, 0, 100, -7, 2]); + const min = 1; + const max = -1; - const gradients = tf.grad(a => tf.sinh(a))(a, dy); + const f = () => { + tf.clipByValue(a, min, max); + }; + expect(f).toThrowError(); + }); - expect(gradients.shape).toEqual(a.shape); + it('gradient: 1D tensor', async () => { + const min = -1; + const max = 2; + const x = tf.tensor1d([3, -2, 1]); // Only 1 is not clipped. + const dy = tf.tensor1d([5, 50, 500]); + const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); + + expect(gradients.shape).toEqual(x.shape); expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 * Math.cosh(0.5)]); + expectArraysClose(await gradients.data(), [0, 0, 500]); }); - it('gradient with clones', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.sinh(a.clone()).clone())(a, dy); + it('gradient: 1D tensor with max or min value', async () => { + const min = -1; + const max = 2; + const x = tf.tensor1d([-1, 1, 2, 3]); + const dy = tf.tensor1d([1, 10, 100, 1000]); + const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); - expect(gradients.shape).toEqual(a.shape); + expect(gradients.shape).toEqual(x.shape); expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 * Math.cosh(0.5)]); + expectArraysClose(await gradients.data(), [1, 10, 100, 0]); }); - it('gradients: Tensor1D', async () => { - const aValues = [-1, 2, 3, -5]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor1d(aValues); - const dy = tf.tensor1d(dyValues); + it('gradient: scalar', async () => { + const min = -1; + const max = 2; + const x = tf.scalar(-10); // Clipped. + const dy = tf.scalar(5); + const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); - const gradients = tf.grad(a => tf.sinh(a))(a, dy); + expect(gradients.shape).toEqual(x.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] * Math.cosh(aValues[i]); - } + it('gradient with clones', async () => { + const min = -1; + const max = 2; + const x = tf.scalar(-10); // Clipped. + const dy = tf.scalar(5); + const gradients = + tf.grad(x => x.clone().clipByValue(min, max).clone())(x, dy); - expect(gradients.shape).toEqual(a.shape); + expect(gradients.shape).toEqual(x.shape); expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); + expectArraysClose(await gradients.data(), [0]); }); - it('gradients: Tensor2D', async () => { - const aValues = [-3, 1, 2, 3]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor2d(aValues, [2, 2]); - const dy = tf.tensor2d(dyValues, [2, 2]); - - const gradients = tf.grad(a => tf.sinh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] * Math.cosh(aValues[i]); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.sinh({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'sinh' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const values = [1, -3, 2, -1, -4]; - const result = tf.sinh(values); - - const expected = []; - for (let i = 0; i < values.length; i++) { - expected[i] = Math.sinh(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('throws for string tensor', () => { - expect(() => tf.sinh('q')) - .toThrowError(/Argument 'x' passed to 'sinh' must be numeric/); - }); -}); - -describeWithFlags('cosh', ALL_ENVS, () => { - it('basic', async () => { - const values = [1, -3, 2, -1, -4]; - const a = tf.tensor1d(values); - const result = tf.cosh(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.cosh(values[i]); - } - - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([4, NaN, 0]); - const res = tf.cosh(a); - expectArraysClose(await res.data(), [Math.cosh(4), NaN, Math.cosh(0)]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.cosh(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 * Math.sinh(0.5)]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.cosh(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 * Math.sinh(0.5)]); - }); - - it('gradients: Tensor1D', async () => { - const aValues = [-1, 2, 3, -5]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor1d(aValues); - const dy = tf.tensor1d(dyValues); - - const gradients = tf.grad(a => tf.cosh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] * Math.sinh(aValues[i]); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('gradients: Tensor2D', async () => { - const aValues = [-3, 1, 2, 3]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor2d(aValues, [2, 2]); - const dy = tf.tensor2d(dyValues, [2, 2]); - - const gradients = tf.grad(a => tf.cosh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] * Math.sinh(aValues[i]); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.cosh({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'cosh' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const values = [1, -3, 2, -1, -4]; - const result = tf.cosh(values); - - const expected = []; - for (let i = 0; i < values.length; i++) { - expected[i] = Math.cosh(values[i]); - } - - expectArraysClose(await result.data(), expected); - }); - - it('throws for string tensor', () => { - expect(() => tf.cosh('q')) - .toThrowError(/Argument 'x' passed to 'cosh' must be numeric/); - }); -}); - -describeWithFlags('tanh', ALL_ENVS, () => { - it('basic', async () => { - const values = [1, -3, 2, 7, -4]; - const a = tf.tensor1d(values); - const result = tf.tanh(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = util.tanh(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([4, NaN, 0]); - const res = tf.tanh(a); - expectArraysClose(await res.data(), [util.tanh(4), NaN, util.tanh(0)]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.tanh(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), [8 * (1 - (Math.tanh(0.5) * Math.tanh(0.5)))]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.tanh(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), [8 * (1 - (Math.tanh(0.5) * Math.tanh(0.5)))]); - }); - - it('gradients: Tensor1D', async () => { - const aValues = [-1, 2, 3, -5]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor1d(aValues); - const dy = tf.tensor1d(dyValues); - - const gradients = tf.grad(a => tf.tanh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = - dyValues[i] * (1 - (Math.tanh(aValues[i]) * Math.tanh(aValues[i]))); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('gradients: Tensor2D', async () => { - const aValues = [-3, 1, 2, 3]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor2d(aValues, [2, 2]); - const dy = tf.tensor2d(dyValues, [2, 2]); - - const gradients = tf.grad(a => tf.tanh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = - dyValues[i] * (1 - (Math.tanh(aValues[i]) * Math.tanh(aValues[i]))); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.tanh({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'tanh' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const values = [1, -3, 2, 7, -4]; - const result = tf.tanh(values); - - const expected = []; - for (let i = 0; i < values.length; i++) { - expected[i] = util.tanh(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('throws for string tensor', () => { - expect(() => tf.tanh('q')) - .toThrowError(/Argument 'x' passed to 'tanh' must be numeric/); - }); -}); - -describeWithFlags('clip', ALL_ENVS, () => { - it('basic', async () => { - const a = tf.tensor1d([3, -1, 0, 100, -7, 2]); - const min = -1; - const max = 50; - - const result = tf.clipByValue(a, min, max); - - expectArraysClose(await result.data(), [3, -1, 0, 50, -1, 2]); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([3, -1, 0, 100, -7, 2, NaN]); - const min = -1; - const max = 50; - - const result = tf.clipByValue(a, min, max); - - expectArraysClose(await result.data(), [3, -1, 0, 50, -1, 2, NaN]); - }); - - it('min greater than max', () => { - const a = tf.tensor1d([3, -1, 0, 100, -7, 2]); - const min = 1; - const max = -1; - - const f = () => { - tf.clipByValue(a, min, max); - }; - expect(f).toThrowError(); - }); - - it('gradient: 1D tensor', async () => { - const min = -1; - const max = 2; - const x = tf.tensor1d([3, -2, 1]); // Only 1 is not clipped. - const dy = tf.tensor1d([5, 50, 500]); - const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); - - expect(gradients.shape).toEqual(x.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 500]); - }); - - it('gradient: 1D tensor with max or min value', async () => { - const min = -1; - const max = 2; - const x = tf.tensor1d([-1, 1, 2, 3]); - const dy = tf.tensor1d([1, 10, 100, 1000]); - const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); - - expect(gradients.shape).toEqual(x.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [1, 10, 100, 0]); - }); - - it('gradient: scalar', async () => { - const min = -1; - const max = 2; - const x = tf.scalar(-10); // Clipped. - const dy = tf.scalar(5); - const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); - - expect(gradients.shape).toEqual(x.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradient with clones', async () => { - const min = -1; - const max = 2; - const x = tf.scalar(-10); // Clipped. - const dy = tf.scalar(5); - const gradients = - tf.grad(x => x.clone().clipByValue(min, max).clone())(x, dy); - - expect(gradients.shape).toEqual(x.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradient with primitive as input', async () => { - const min = -1; - const max = 2; - const x = -10; - const dy = tf.scalar(5); - const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); - expect(gradients.shape).toEqual([]); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.clipByValue({} as tf.Tensor, 0, 1)) - .toThrowError(/Argument 'x' passed to 'clipByValue' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const min = -1; - const max = 50; - const result = tf.clipByValue([3, -1, 0, 100, -7, 2], min, max); - expectArraysClose(await result.data(), [3, -1, 0, 50, -1, 2]); - }); - - it('clip(x, eps, 1-eps) never returns 0 or 1', async () => { - const min = tf.backend().epsilon(); - const max = 0.5; - const res = await tf.clipByValue([0, 1], min, max).data(); - expect(res[0]).toBeGreaterThan(0); - expect(res[1]).toBeCloseTo(max); - }); - - it('throws for string tensor', () => { - expect(() => tf.clipByValue('q', 0, 1)) - .toThrowError(/Argument 'x' passed to 'clipByValue' must be numeric/); - }); -}); - -describeWithFlags('round', ALL_ENVS, () => { - it('basic', async () => { - const a = tf.tensor1d([0.9, 2.5, 2.3, 1.5, -4.5]); - const r = a.round(); - - expectArraysClose(await r.data(), [1, 2, 2, 2, -4]); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([1.5, NaN, -1.4]); - const r = tf.round(a); - expectArraysClose(await r.data(), [2, NaN, -1]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(5.2); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.round(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(5.2); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.round(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([-1.1, 2.6, 3, -5.9]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.round(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([-3, 1, 2.2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.round(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.round({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'round' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.round([0.9, 2.5, 2.3, 1.5, -4.5]); - expectArraysClose(await r.data(), [1, 2, 2, 2, -4]); - }); - - it('throws for string tensor', () => { - expect(() => tf.round('q')) - .toThrowError(/Argument 'x' passed to 'round' must be numeric/); - }); -}); - -describeWithFlags('asinh', ALL_ENVS, () => { - it('basic', async () => { - const values = [1, -3, 2, 7, -4]; - const a = tf.tensor1d(values); - const result = tf.asinh(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.asinh(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('scalar', async () => { - const a = tf.scalar(1); - const result = tf.asinh(a); - - const expected = [Math.asinh(1)]; - expectArraysClose(await result.data(), expected); - }); - - it('tensor2D', async () => { - const values = [1, -3, 2, 7]; - const a = tf.tensor2d(values, [2, 2]); - const result = tf.asinh(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.asinh(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([4, NaN, 0]); - const res = tf.asinh(a); - expectArraysClose(await res.data(), [Math.asinh(4), NaN, Math.asinh(0)]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.asinh(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 / Math.sqrt(1.0 + 0.5 * 0.5)]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.asinh(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 / Math.sqrt(1.0 + 0.5 * 0.5)]); - }); - - it('gradients: Tensor1D', async () => { - const aValues = [-1, 2, 3, -5]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor1d(aValues); - const dy = tf.tensor1d(dyValues); - - const gradients = tf.grad(a => tf.asinh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] / Math.sqrt(1 + aValues[i] * aValues[i]); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('gradients: Tensor2D', async () => { - const aValues = [-3, 1, 2, 3]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor2d(aValues, [2, 2]); - const dy = tf.tensor2d(dyValues, [2, 2]); - - const gradients = tf.grad(a => tf.asinh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] / Math.sqrt(1 + aValues[i] * aValues[i]); - } - - expect(gradients.shape).toEqual(a.shape); + it('gradient with primitive as input', async () => { + const min = -1; + const max = 2; + const x = -10; + const dy = tf.scalar(5); + const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); + expect(gradients.shape).toEqual([]); expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); + expectArraysClose(await gradients.data(), [0]); }); it('throws when passed a non-tensor', () => { - expect(() => tf.asinh({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'asinh' must be a Tensor/); + expect(() => tf.clipByValue({} as tf.Tensor, 0, 1)) + .toThrowError(/Argument 'x' passed to 'clipByValue' must be a Tensor/); }); it('accepts a tensor-like object', async () => { - const values = [1, -3, 2, 7, -4]; - const result = tf.asinh(values); - - const expected = []; - for (let i = 0; i < values.length; i++) { - expected[i] = Math.asinh(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('throws for string tensor', () => { - expect(() => tf.asinh('q')) - .toThrowError(/Argument 'x' passed to 'asinh' must be numeric/); - }); -}); - -describeWithFlags('acosh', ALL_ENVS, () => { - it('basic', async () => { - const values = [2, 3, 4, 5, 6]; - const a = tf.tensor1d(values); - const result = tf.acosh(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.acosh(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('scalar', async () => { - const value = 2; - const a = tf.scalar(value); - const result = tf.acosh(a); - - const expected = [Math.acosh(value)]; - expectArraysClose(await result.data(), expected); - }); - - it('tensor2d', async () => { - const values = [2, 3, 4, 5]; - const a = tf.tensor2d(values, [2, 2]); - const result = tf.acosh(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.acosh(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([4, NaN, 2]); - const res = tf.acosh(a); - expectArraysClose(await res.data(), [Math.acosh(4), NaN, Math.acosh(2)]); - }); - - it('NaN outside function domain', async () => { - const a = tf.tensor1d([4, -1, 2]); - const res = tf.acosh(a); - expectArraysClose(await res.data(), [Math.acosh(4), NaN, Math.acosh(2)]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(1.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.acosh(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), [8.0 / Math.sqrt(1.5 * 1.5 - 1.0)]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(1.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.acosh(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), [8.0 / Math.sqrt(1.5 * 1.5 - 1.0)]); - }); - - it('gradients: Tensor1D', async () => { - const aValues = [2, 3, 5, 10]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor1d(aValues); - const dy = tf.tensor1d(dyValues); - - const gradients = tf.grad(a => tf.acosh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] / Math.sqrt(Math.pow(aValues[i], 2) - 1.0); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('gradients: Tensor2D', async () => { - const aValues = [2, 3, 5, 7]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor2d(aValues, [2, 2]); - const dy = tf.tensor2d(dyValues, [2, 2]); - - const gradients = tf.grad(a => tf.acosh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] / Math.sqrt(Math.pow(aValues[i], 2) - 1.0); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.acosh({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'acosh' must be a Tensor/); + const min = -1; + const max = 50; + const result = tf.clipByValue([3, -1, 0, 100, -7, 2], min, max); + expectArraysClose(await result.data(), [3, -1, 0, 50, -1, 2]); }); - it('accepts a tensor-like object', async () => { - const values = [2, 3, 4, 5, 6]; - const result = tf.acosh(values); - - const expected = []; - for (let i = 0; i < values.length; i++) { - expected[i] = Math.acosh(values[i]); - } - expectArraysClose(await result.data(), expected); + it('clip(x, eps, 1-eps) never returns 0 or 1', async () => { + const min = tf.backend().epsilon(); + const max = 0.5; + const res = await tf.clipByValue([0, 1], min, max).data(); + expect(res[0]).toBeGreaterThan(0); + expect(res[1]).toBeCloseTo(max); }); it('throws for string tensor', () => { - expect(() => tf.acosh('q')) - .toThrowError(/Argument 'x' passed to 'acosh' must be numeric/); + expect(() => tf.clipByValue('q', 0, 1)) + .toThrowError(/Argument 'x' passed to 'clipByValue' must be numeric/); }); }); -describeWithFlags('atanh', ALL_ENVS, () => { +describeWithFlags('round', ALL_ENVS, () => { it('basic', async () => { - const values = [-0.25, 0.25, 0.5, .75, -0.4]; - const a = tf.tensor1d(values); - const result = tf.atanh(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.atanh(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('scalar', async () => { - const value = 0.2; - const a = tf.scalar(value); - const result = tf.atanh(a); - - const expected = [Math.atanh(value)]; - expectArraysClose(await result.data(), expected); - }); - - it('tensor2d', async () => { - const values = [0.2, 0.3, 0.4, 0.5]; - const a = tf.tensor2d(values, [2, 2]); - const result = tf.atanh(a); + const a = tf.tensor1d([0.9, 2.5, 2.3, 1.5, -4.5]); + const r = a.round(); - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.atanh(values[i]); - } - expectArraysClose(await result.data(), expected); + expectArraysClose(await r.data(), [1, 2, 2, 2, -4]); }); it('propagates NaNs', async () => { - const a = tf.tensor1d([0.5, NaN, 0]); - const res = tf.atanh(a); - expectArraysClose(await res.data(), [Math.atanh(0.5), NaN, Math.atanh(0)]); - }); - - it('NaN outside function domain', async () => { - const a = tf.tensor1d([-2, 0, 2]); - const res = tf.atanh(a); - expectArraysClose(await res.data(), [NaN, Math.atanh(0), NaN]); + const a = tf.tensor1d([1.5, NaN, -1.4]); + const r = tf.round(a); + expectArraysClose(await r.data(), [2, NaN, -1]); }); it('gradients: Scalar', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); + const a = tf.scalar(5.2); + const dy = tf.scalar(3); - const gradients = tf.grad(a => tf.atanh(a))(a, dy); + const gradients = tf.grad(a => tf.round(a))(a, dy); expect(gradients.shape).toEqual(a.shape); expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 / (1 - 0.5 * 0.5)]); + expectArraysClose(await gradients.data(), [0]); }); it('gradient with clones', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); + const a = tf.scalar(5.2); + const dy = tf.scalar(3); - const gradients = tf.grad(a => tf.atanh(a.clone()).clone())(a, dy); + const gradients = tf.grad(a => tf.round(a.clone()).clone())(a, dy); expect(gradients.shape).toEqual(a.shape); expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 / (1 - 0.5 * 0.5)]); + expectArraysClose(await gradients.data(), [0]); }); it('gradients: Tensor1D', async () => { - const aValues = [-0.1, 0.2, 0.3, -0.5]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor1d(aValues); - const dy = tf.tensor1d(dyValues); - - const gradients = tf.grad(a => tf.atanh(a))(a, dy); + const a = tf.tensor1d([-1.1, 2.6, 3, -5.9]); + const dy = tf.tensor1d([1, 2, 3, 4]); - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] / (1 - Math.pow(aValues[i], 2)); - } + const gradients = tf.grad(a => tf.round(a))(a, dy); expect(gradients.shape).toEqual(a.shape); expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); }); it('gradients: Tensor2D', async () => { - const aValues = [-0.3, 0.1, 0.2, 0.3]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor2d(aValues, [2, 2]); - const dy = tf.tensor2d(dyValues, [2, 2]); - - const gradients = tf.grad(a => tf.atanh(a))(a, dy); + const a = tf.tensor2d([-3, 1, 2.2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] / (1 - Math.pow(aValues[i], 2)); - } + const gradients = tf.grad(a => tf.round(a))(a, dy); expect(gradients.shape).toEqual(a.shape); expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); }); it('throws when passed a non-tensor', () => { - expect(() => tf.atanh({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'atanh' must be a Tensor/); + expect(() => tf.round({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'round' must be a Tensor/); }); it('accepts a tensor-like object', async () => { - const result = tf.atanh(0.2); - expectArraysClose(await result.data(), [Math.atanh(0.2)]); + const r = tf.round([0.9, 2.5, 2.3, 1.5, -4.5]); + expectArraysClose(await r.data(), [1, 2, 2, 2, -4]); }); it('throws for string tensor', () => { - expect(() => tf.atanh('q')) - .toThrowError(/Argument 'x' passed to 'atanh' must be numeric/); + expect(() => tf.round('q')) + .toThrowError(/Argument 'x' passed to 'round' must be numeric/); }); }); diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index b6b7755a79e..7a6ec9f8daf 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -14,11 +14,18 @@ * limitations under the License. * ============================================================================= */ +import {absGradConfig} from './gradients/Abs_grad'; +import {acosGradConfig} from './gradients/Acos_grad'; +import {acoshGradConfig} from './gradients/Acosh_grad'; import {addGradConfig} from './gradients/Add_grad'; import {addNGradConfig} from './gradients/AddN_grad'; import {argMaxGradConfig} from './gradients/ArgMax_grad'; import {argMinGradConfig} from './gradients/ArgMin_grad'; +import {asinGradConfig} from './gradients/Asin_grad'; +import {asinhGradConfig} from './gradients/Asinh_grad'; import {atan2GradConfig} from './gradients/Atan2_grad'; +import {atanGradConfig} from './gradients/Atan_grad'; +import {atanhGradConfig} from './gradients/Atanh_grad'; import {avgPool3DGradConfig} from './gradients/AvgPool3D_grad'; import {avgPoolGradConfig} from './gradients/AvgPool_grad'; import {batchMatMulGradConfig} from './gradients/BatchMatMul_grad'; @@ -78,11 +85,18 @@ import {registerGradient} from './kernel_registry'; // Export all kernel configs here so that the package can auto register them const gradConfigs: GradConfig[] = [ + absGradConfig, + acosGradConfig, + acoshGradConfig, addGradConfig, addNGradConfig, argMaxGradConfig, argMinGradConfig, + asinGradConfig, + asinhGradConfig, atan2GradConfig, + atanGradConfig, + atanhGradConfig, avgPool3DGradConfig, avgPoolGradConfig, batchMatMulGradConfig, diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index 7bc70093dff..ee53aa9fd37 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -38,7 +38,10 @@ import './io/router_registry_test'; import './io/weights_loader_test'; import './jasmine_util_test'; import './kernel_registry_test'; +import './ops/abs_test'; import './ops/absolute_difference_test'; +import './ops/acos_test'; +import './ops/acosh_test'; import './ops/add_n_test'; import './ops/add_test'; import './ops/all_test'; @@ -47,6 +50,10 @@ import './ops/arg_max_test'; import './ops/arg_min_test'; import './ops/arithmetic_test'; import './ops/array_ops_test'; +import './ops/asin_test'; +import './ops/asinh_test'; +import './ops/atan_test'; +import './ops/atanh_test'; import './ops/avg_pool_3d_test'; import './ops/avg_pool_test'; import './ops/axis_util_test'; @@ -163,6 +170,7 @@ import './ops/topk_test'; import './ops/transpose_test'; import './ops/truncated_normal_test'; import './ops/unary_ops_test'; +import './ops/undefined_test'; import './ops/unsorted_segment_sum_test'; import './ops/unstack_test'; import './ops/where_async_test'; From dc0c9cb0946e9990ce842c3079661f34d92d1918 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 7 Jul 2020 22:30:29 -0400 Subject: [PATCH 2/3] modularize cos, cosh, sin, sinh, tan, tanh --- tfjs-core/src/gradients/Cos_grad.ts | 34 ++ tfjs-core/src/gradients/Cosh_grad.ts | 33 ++ tfjs-core/src/gradients/Sin_grad.ts | 32 ++ tfjs-core/src/gradients/Sinh_grad.ts | 33 ++ tfjs-core/src/gradients/Tan_grad.ts | 33 ++ tfjs-core/src/gradients/Tanh_grad.ts | 34 ++ tfjs-core/src/kernel_names.ts | 18 + tfjs-core/src/ops/absolute_difference.ts | 1 - tfjs-core/src/ops/basic_lstm_cell.ts | 3 +- tfjs-core/src/ops/cos.ts | 49 ++ tfjs-core/src/ops/cos_test.ts | 117 +++++ tfjs-core/src/ops/cosh.ts | 48 ++ tfjs-core/src/ops/cosh_test.ts | 121 +++++ tfjs-core/src/ops/ops.ts | 6 + tfjs-core/src/ops/sin.ts | 49 ++ tfjs-core/src/ops/sin_test.ts | 111 +++++ tfjs-core/src/ops/sinh.ts | 48 ++ tfjs-core/src/ops/sinh_test.ts | 119 +++++ tfjs-core/src/ops/tan.ts | 49 ++ tfjs-core/src/ops/tan_test.ts | 122 +++++ tfjs-core/src/ops/tanh.ts | 49 ++ tfjs-core/src/ops/tanh_test.ts | 124 +++++ tfjs-core/src/ops/unary_ops.ts | 165 ------- tfjs-core/src/ops/unary_ops_test.ts | 596 +---------------------- tfjs-core/src/register_all_gradients.ts | 12 + tfjs-core/src/tests.ts | 7 +- 26 files changed, 1250 insertions(+), 763 deletions(-) create mode 100644 tfjs-core/src/gradients/Cos_grad.ts create mode 100644 tfjs-core/src/gradients/Cosh_grad.ts create mode 100644 tfjs-core/src/gradients/Sin_grad.ts create mode 100644 tfjs-core/src/gradients/Sinh_grad.ts create mode 100644 tfjs-core/src/gradients/Tan_grad.ts create mode 100644 tfjs-core/src/gradients/Tanh_grad.ts create mode 100644 tfjs-core/src/ops/cos.ts create mode 100644 tfjs-core/src/ops/cos_test.ts create mode 100644 tfjs-core/src/ops/cosh.ts create mode 100644 tfjs-core/src/ops/cosh_test.ts create mode 100644 tfjs-core/src/ops/sin.ts create mode 100644 tfjs-core/src/ops/sin_test.ts create mode 100644 tfjs-core/src/ops/sinh.ts create mode 100644 tfjs-core/src/ops/sinh_test.ts create mode 100644 tfjs-core/src/ops/tan.ts create mode 100644 tfjs-core/src/ops/tan_test.ts create mode 100644 tfjs-core/src/ops/tanh.ts create mode 100644 tfjs-core/src/ops/tanh_test.ts diff --git a/tfjs-core/src/gradients/Cos_grad.ts b/tfjs-core/src/gradients/Cos_grad.ts new file mode 100644 index 00000000000..0b42e974c8d --- /dev/null +++ b/tfjs-core/src/gradients/Cos_grad.ts @@ -0,0 +1,34 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Cos} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {cast} from '../ops/array_ops'; +import {mul} from '../ops/mul'; +import {neg} from '../ops/neg'; +import {sin} from '../ops/sin'; +import {Tensor} from '../tensor'; + +export const cosGradConfig: GradConfig = { + kernelName: Cos, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + + return {x: () => mul(neg(sin(cast(x, 'float32'))), dy)}; + } +}; diff --git a/tfjs-core/src/gradients/Cosh_grad.ts b/tfjs-core/src/gradients/Cosh_grad.ts new file mode 100644 index 00000000000..f983e4b7776 --- /dev/null +++ b/tfjs-core/src/gradients/Cosh_grad.ts @@ -0,0 +1,33 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Cosh} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {cast} from '../ops/array_ops'; +import {mul} from '../ops/mul'; +import {sinh} from '../ops/sinh'; +import {Tensor} from '../tensor'; + +export const coshGradConfig: GradConfig = { + kernelName: Cosh, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + + return {x: () => mul(sinh(cast(x, 'float32')), dy)}; + } +}; diff --git a/tfjs-core/src/gradients/Sin_grad.ts b/tfjs-core/src/gradients/Sin_grad.ts new file mode 100644 index 00000000000..687cdad0c7e --- /dev/null +++ b/tfjs-core/src/gradients/Sin_grad.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Sin} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {cast} from '../ops/array_ops'; +import {cos} from '../ops/cos'; +import {mul} from '../ops/mul'; +import {Tensor} from '../tensor'; + +export const sinGradConfig: GradConfig = { + kernelName: Sin, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + return {x: () => mul(cos(cast(x, 'float32')), dy)}; + } +}; diff --git a/tfjs-core/src/gradients/Sinh_grad.ts b/tfjs-core/src/gradients/Sinh_grad.ts new file mode 100644 index 00000000000..2976eacd410 --- /dev/null +++ b/tfjs-core/src/gradients/Sinh_grad.ts @@ -0,0 +1,33 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Sinh} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {cast} from '../ops/array_ops'; +import {cosh} from '../ops/cosh'; +import {mul} from '../ops/mul'; +import {Tensor} from '../tensor'; + +export const sinhGradConfig: GradConfig = { + kernelName: Sinh, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + + return {x: () => mul(cosh(cast(x, 'float32')), dy)}; + } +}; diff --git a/tfjs-core/src/gradients/Tan_grad.ts b/tfjs-core/src/gradients/Tan_grad.ts new file mode 100644 index 00000000000..858285b1ee1 --- /dev/null +++ b/tfjs-core/src/gradients/Tan_grad.ts @@ -0,0 +1,33 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tan} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {cos} from '../ops/cos'; +import {div} from '../ops/div'; +import {square} from '../ops/square'; +import {Tensor} from '../tensor'; + +export const tanGradConfig: GradConfig = { + kernelName: Tan, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + + return {x: () => div(dy, square(cos(x)))}; + } +}; diff --git a/tfjs-core/src/gradients/Tanh_grad.ts b/tfjs-core/src/gradients/Tanh_grad.ts new file mode 100644 index 00000000000..e2e1096c370 --- /dev/null +++ b/tfjs-core/src/gradients/Tanh_grad.ts @@ -0,0 +1,34 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tanh} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {mul} from '../ops/mul'; +import {square} from '../ops/square'; +import {sub} from '../ops/sub'; +import {scalar} from '../ops/tensor_ops'; +import {Tensor} from '../tensor'; + +export const tanhGradConfig: GradConfig = { + kernelName: Tanh, + outputsToSave: [true], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [y] = saved; + + return {x: () => mul(sub(scalar(1), square(y)), dy)}; + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 97cb10a35f7..f9a47048dc4 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -204,6 +204,12 @@ export interface Conv3DBackpropInputAttrs { pad: 'valid'|'same'; } +export const Cos = 'Cos'; +export type CosInputs = UnaryInputs; + +export const Cosh = 'Cosh'; +export type CoshInputs = UnaryInputs; + export const Cumsum = 'Cumsum'; export type CumsumInputs = Pick; export interface CumsumAttrs { @@ -547,6 +553,12 @@ export type SelectV2Inputs = Pick; export const Selu = 'Selu'; export type SeluInputs = Pick; +export const Sin = 'Sin'; +export type SinInputs = UnaryInputs; + +export const Sinh = 'Sinh'; +export type SinhInputs = UnaryInputs; + export const Sign = 'Sign'; export type SignInputs = UnaryInputs; @@ -580,6 +592,12 @@ export type SquareInputs = Pick; export const Sub = 'Sub'; export type SubInputs = BinaryInputs; +export const Tan = 'Tan'; +export type TanInputs = UnaryInputs; + +export const Tanh = 'Tanh'; +export type TanhInputs = UnaryInputs; + export const Tile = 'Tile'; export type TileInputs = Pick; export interface TileAttrs { diff --git a/tfjs-core/src/ops/absolute_difference.ts b/tfjs-core/src/ops/absolute_difference.ts index f9abfe32c30..9a91866c6cb 100644 --- a/tfjs-core/src/ops/absolute_difference.ts +++ b/tfjs-core/src/ops/absolute_difference.ts @@ -26,7 +26,6 @@ import {Reduction} from './loss_ops_utils'; import {op} from './operation'; import {sub} from './sub'; - /** * Computes the absolute difference loss between two tensors. * diff --git a/tfjs-core/src/ops/basic_lstm_cell.ts b/tfjs-core/src/ops/basic_lstm_cell.ts index e78de650841..cda956aceed 100644 --- a/tfjs-core/src/ops/basic_lstm_cell.ts +++ b/tfjs-core/src/ops/basic_lstm_cell.ts @@ -25,7 +25,8 @@ import {matMul} from './mat_mul'; import {mul} from './mul'; import {op} from './operation'; import {slice} from './slice'; -import {sigmoid, tanh} from './unary_ops'; +import {tanh} from './tanh'; +import {sigmoid} from './unary_ops'; /** * Computes the next state and output of a BasicLSTMCell. diff --git a/tfjs-core/src/ops/cos.ts b/tfjs-core/src/ops/cos.ts new file mode 100644 index 00000000000..9d4872d6de4 --- /dev/null +++ b/tfjs-core/src/ops/cos.ts @@ -0,0 +1,49 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Cos, CosInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes cos of the input `tf.Tensor` element-wise: `cos(x)` + * + * ```js + * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); + * + * x.cos().print(); // or tf.cos(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function cos_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'cos'); + + const inputs: CosInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.cos($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Cos); +} +export const cos = op({cos_}); diff --git a/tfjs-core/src/ops/cos_test.ts b/tfjs-core/src/ops/cos_test.ts new file mode 100644 index 00000000000..ee7d4d6349f --- /dev/null +++ b/tfjs-core/src/ops/cos_test.ts @@ -0,0 +1,117 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('cos', ALL_ENVS, () => { + it('basic', async () => { + const values = [1, -3, 2, 7, -4]; + const a = tf.tensor1d(values); + const result = tf.cos(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.cos(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([4, NaN, 0]); + const res = tf.cos(a); + expectArraysClose(await res.data(), [Math.cos(4), NaN, Math.cos(0)]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.cos(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 * Math.sin(5) * -1]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.cos(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 * Math.sin(5) * -1]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([-1, 2, 3, -5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.cos(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [ + 1 * Math.sin(-1) * -1, 2 * Math.sin(2) * -1, 3 * Math.sin(3) * -1, + 4 * Math.sin(-5) * -1 + ], + 1e-1); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.cos(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [ + 1 * Math.sin(-3) * -1, 2 * Math.sin(1) * -1, 3 * Math.sin(2) * -1, + 4 * Math.sin(3) * -1 + ], + 1e-1); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.cos({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'cos' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const values = [1, -3, 2, 7, -4]; + const result = tf.cos(values); + + const expected = []; + for (let i = 0; i < values.length; i++) { + expected[i] = Math.cos(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('throws for string tensor', () => { + expect(() => tf.cos('q')) + .toThrowError(/Argument 'x' passed to 'cos' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/cosh.ts b/tfjs-core/src/ops/cosh.ts new file mode 100644 index 00000000000..30843823e29 --- /dev/null +++ b/tfjs-core/src/ops/cosh.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Cosh, CoshInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.cosh().print(); // or tf.cosh(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function cosh_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'cosh'); + const inputs: CoshInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.cosh($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Cosh); +} +export const cosh = op({cosh_}); diff --git a/tfjs-core/src/ops/cosh_test.ts b/tfjs-core/src/ops/cosh_test.ts new file mode 100644 index 00000000000..93c0a0082c7 --- /dev/null +++ b/tfjs-core/src/ops/cosh_test.ts @@ -0,0 +1,121 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('cosh', ALL_ENVS, () => { + it('basic', async () => { + const values = [1, -3, 2, -1, -4]; + const a = tf.tensor1d(values); + const result = tf.cosh(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.cosh(values[i]); + } + + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([4, NaN, 0]); + const res = tf.cosh(a); + expectArraysClose(await res.data(), [Math.cosh(4), NaN, Math.cosh(0)]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.cosh(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 * Math.sinh(0.5)]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.cosh(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 * Math.sinh(0.5)]); + }); + + it('gradients: Tensor1D', async () => { + const aValues = [-1, 2, 3, -5]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor1d(aValues); + const dy = tf.tensor1d(dyValues); + + const gradients = tf.grad(a => tf.cosh(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] * Math.sinh(aValues[i]); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('gradients: Tensor2D', async () => { + const aValues = [-3, 1, 2, 3]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor2d(aValues, [2, 2]); + const dy = tf.tensor2d(dyValues, [2, 2]); + + const gradients = tf.grad(a => tf.cosh(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] * Math.sinh(aValues[i]); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.cosh({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'cosh' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const values = [1, -3, 2, -1, -4]; + const result = tf.cosh(values); + + const expected = []; + for (let i = 0; i < values.length; i++) { + expected[i] = Math.cosh(values[i]); + } + + expectArraysClose(await result.data(), expected); + }); + + it('throws for string tensor', () => { + expect(() => tf.cosh('q')) + .toThrowError(/Argument 'x' passed to 'cosh' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 99ecd497cbf..7f6d30703ec 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -52,6 +52,8 @@ export {conv2d} from './conv2d'; export {conv2dTranspose} from './conv2d_transpose'; export {conv3d} from './conv3d'; export {conv3dTranspose} from './conv3d_transpose'; +export {cos} from './cos'; +export {cosh} from './cosh'; export {cumsum} from './cumsum'; export {depthToSpace} from './depth_to_space'; export {depthwiseConv2d} from './depthwise_conv2d'; @@ -123,6 +125,8 @@ export {reverse4d} from './reverse_4d'; export {selu} from './selu'; export {separableConv2d} from './separable_conv2d'; export {sign} from './sign'; +export {sin} from './sin'; +export {sinh} from './sinh'; export {spaceToBatchND} from './space_to_batch_nd'; export {split} from './split'; export {square} from './square'; @@ -131,6 +135,8 @@ export {squeeze} from './squeeze'; export {stack} from './stack'; export {sub} from './sub'; export {sum} from './sum'; +export {tan} from './tan'; +export {tanh} from './tanh'; export {tile} from './tile'; export {truncatedNormal} from './truncated_normal'; export {unsortedSegmentSum} from './unsorted_segment_sum'; diff --git a/tfjs-core/src/ops/sin.ts b/tfjs-core/src/ops/sin.ts new file mode 100644 index 00000000000..1c228e1851f --- /dev/null +++ b/tfjs-core/src/ops/sin.ts @@ -0,0 +1,49 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Sin, SinInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes sin of the input Tensor element-wise: `sin(x)` + * + * ```js + * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); + * + * x.sin().print(); // or tf.sin(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function sin_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'sin'); + + const inputs: SinInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.sin($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Sin); +} +export const sin = op({sin_}); diff --git a/tfjs-core/src/ops/sin_test.ts b/tfjs-core/src/ops/sin_test.ts new file mode 100644 index 00000000000..536583e1c12 --- /dev/null +++ b/tfjs-core/src/ops/sin_test.ts @@ -0,0 +1,111 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('sin', ALL_ENVS, () => { + it('basic', async () => { + const values = [1, -3, 2, 7, -4]; + const a = tf.tensor1d(values); + const result = tf.sin(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.sin(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([4, NaN, 0]); + const res = tf.sin(a); + expectArraysClose(await res.data(), [Math.sin(4), NaN, Math.sin(0)]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.sin(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 * Math.cos(5)]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.sin(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 * Math.cos(5)]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([-1, 2, 3, -5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.sin(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [1 * Math.cos(-1), 2 * Math.cos(2), 3 * Math.cos(3), 4 * Math.cos(-5)], + 1e-1); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.sin(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [1 * Math.cos(-3), 2 * Math.cos(1), 3 * Math.cos(2), 4 * Math.cos(3)], + 1e-1); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.sin({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'sin' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const values = [1, -3, 2, 7, -4]; + const result = tf.sin(values); + + const expected = []; + for (let i = 0; i < values.length; i++) { + expected[i] = Math.sin(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('throws for string tensor', () => { + expect(() => tf.sin('q')) + .toThrowError(/Argument 'x' passed to 'sin' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/sinh.ts b/tfjs-core/src/ops/sinh.ts new file mode 100644 index 00000000000..010f42a9f43 --- /dev/null +++ b/tfjs-core/src/ops/sinh.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Sinh, SinhInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.sinh().print(); // or tf.sinh(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function sinh_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'sinh'); + const inputs: SinhInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.sinh($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Sinh); +} +export const sinh = op({sinh_}); diff --git a/tfjs-core/src/ops/sinh_test.ts b/tfjs-core/src/ops/sinh_test.ts new file mode 100644 index 00000000000..efd6e86b92f --- /dev/null +++ b/tfjs-core/src/ops/sinh_test.ts @@ -0,0 +1,119 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('sinh', ALL_ENVS, () => { + it('basic', async () => { + const values = [1, -3, 2, -1, -4]; + const a = tf.tensor1d(values); + const result = tf.sinh(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.sinh(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([4, NaN, 0]); + const res = tf.sinh(a); + expectArraysClose(await res.data(), [Math.sinh(4), NaN, Math.sinh(0)]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.sinh(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 * Math.cosh(0.5)]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.sinh(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [8 * Math.cosh(0.5)]); + }); + + it('gradients: Tensor1D', async () => { + const aValues = [-1, 2, 3, -5]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor1d(aValues); + const dy = tf.tensor1d(dyValues); + + const gradients = tf.grad(a => tf.sinh(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] * Math.cosh(aValues[i]); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('gradients: Tensor2D', async () => { + const aValues = [-3, 1, 2, 3]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor2d(aValues, [2, 2]); + const dy = tf.tensor2d(dyValues, [2, 2]); + + const gradients = tf.grad(a => tf.sinh(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] * Math.cosh(aValues[i]); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.sinh({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'sinh' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const values = [1, -3, 2, -1, -4]; + const result = tf.sinh(values); + + const expected = []; + for (let i = 0; i < values.length; i++) { + expected[i] = Math.sinh(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('throws for string tensor', () => { + expect(() => tf.sinh('q')) + .toThrowError(/Argument 'x' passed to 'sinh' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/tan.ts b/tfjs-core/src/ops/tan.ts new file mode 100644 index 00000000000..20055d4f0d4 --- /dev/null +++ b/tfjs-core/src/ops/tan.ts @@ -0,0 +1,49 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Tan, TanInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes tan of the input `tf.Tensor` element-wise, `tan(x)` + * + * ```js + * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); + * + * x.tan().print(); // or tf.tan(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function tan_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'tan'); + + const inputs: TanInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.tan($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Tan); +} +export const tan = op({tan_}); diff --git a/tfjs-core/src/ops/tan_test.ts b/tfjs-core/src/ops/tan_test.ts new file mode 100644 index 00000000000..ba3661ac8dc --- /dev/null +++ b/tfjs-core/src/ops/tan_test.ts @@ -0,0 +1,122 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose, TEST_EPSILON_FLOAT16} from '../test_util'; + +describeWithFlags('tan', ALL_ENVS, () => { + it('basic', async () => { + const values = [1, -3, 2, 7, -4]; + const a = tf.tensor1d(values); + const result = tf.tan(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.tan(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([4, NaN, 0]); + const res = tf.tan(a); + expectArraysClose(await res.data(), [Math.tan(4), NaN, Math.tan(0)]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.tan(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), [8 / (Math.cos(0.5) * Math.cos(0.5))]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.tan(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), [8 / (Math.cos(0.5) * Math.cos(0.5))]); + }); + + it('gradients: Tensor1D', async () => { + const aValues = [-1, 2, 3, -5]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor1d(aValues); + const dy = tf.tensor1d(dyValues); + + const gradients = tf.grad(a => tf.tan(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] / (Math.cos(aValues[i]) * Math.cos(aValues[i])); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + // The grad(tan(x)) which relies on 1/cos(x) is less precise on Windows. + expectArraysClose(await gradients.data(), expected, TEST_EPSILON_FLOAT16); + }); + + it('gradients: Tensor2D', async () => { + const aValues = [-3, 1, 2, 3]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor2d(aValues, [2, 2]); + const dy = tf.tensor2d(dyValues, [2, 2]); + + const gradients = tf.grad(a => tf.tan(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] / (Math.cos(aValues[i]) * Math.cos(aValues[i])); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.tan({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'tan' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const values = [1, -3, 2, 7, -4]; + const result = tf.tan(values); + + const expected = []; + for (let i = 0; i < values.length; i++) { + expected[i] = Math.tan(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('throws for string tensor', () => { + expect(() => tf.tan('q')) + .toThrowError(/Argument 'x' passed to 'tan' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/tanh.ts b/tfjs-core/src/ops/tanh.ts new file mode 100644 index 00000000000..a2a1555a310 --- /dev/null +++ b/tfjs-core/src/ops/tanh.ts @@ -0,0 +1,49 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Tanh, TanhInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes hyperbolic tangent of the input `tf.Tensor` element-wise: `tanh(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, 70]); + * + * x.tanh().print(); // or tf.tanh(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function tanh_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'tanh'); + + const inputs: TanhInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const y = backend.tanh($x); + save([y]); + return y; + }, inputs as {} as NamedTensorMap, null /* grad */, Tanh); +} +export const tanh = op({tanh_}); diff --git a/tfjs-core/src/ops/tanh_test.ts b/tfjs-core/src/ops/tanh_test.ts new file mode 100644 index 00000000000..18ad3655f13 --- /dev/null +++ b/tfjs-core/src/ops/tanh_test.ts @@ -0,0 +1,124 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; +import * as util from '../util'; + +describeWithFlags('tanh', ALL_ENVS, () => { + it('basic', async () => { + const values = [1, -3, 2, 7, -4]; + const a = tf.tensor1d(values); + const result = tf.tanh(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = util.tanh(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([4, NaN, 0]); + const res = tf.tanh(a); + expectArraysClose(await res.data(), [util.tanh(4), NaN, util.tanh(0)]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.tanh(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), [8 * (1 - (Math.tanh(0.5) * Math.tanh(0.5)))]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.tanh(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), [8 * (1 - (Math.tanh(0.5) * Math.tanh(0.5)))]); + }); + + it('gradients: Tensor1D', async () => { + const aValues = [-1, 2, 3, -5]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor1d(aValues); + const dy = tf.tensor1d(dyValues); + + const gradients = tf.grad(a => tf.tanh(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = + dyValues[i] * (1 - (Math.tanh(aValues[i]) * Math.tanh(aValues[i]))); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('gradients: Tensor2D', async () => { + const aValues = [-3, 1, 2, 3]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor2d(aValues, [2, 2]); + const dy = tf.tensor2d(dyValues, [2, 2]); + + const gradients = tf.grad(a => tf.tanh(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = + dyValues[i] * (1 - (Math.tanh(aValues[i]) * Math.tanh(aValues[i]))); + } + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.tanh({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'tanh' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const values = [1, -3, 2, 7, -4]; + const result = tf.tanh(values); + + const expected = []; + for (let i = 0; i < values.length; i++) { + expected[i] = util.tanh(values[i]); + } + expectArraysClose(await result.data(), expected); + }); + + it('throws for string tensor', () => { + expect(() => tf.tanh('q')) + .toThrowError(/Argument 'x' passed to 'tanh' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/unary_ops.ts b/tfjs-core/src/ops/unary_ops.ts index dfc16f520f7..13fb7f66a25 100644 --- a/tfjs-core/src/ops/unary_ops.ts +++ b/tfjs-core/src/ops/unary_ops.ts @@ -414,165 +414,6 @@ function softplus_(x: T|TensorLike): T { }, {$x}, grad); } -/** - * Computes sin of the input Tensor element-wise: `sin(x)` - * - * ```js - * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); - * - * x.sin().print(); // or tf.sin(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function sin_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'sin'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {x: () => $x.toFloat().cos().mul(dy)} as {x: () => T}; - }; - const inputsToSave = [$x]; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.sin($x); - save([$x]); - return res; - }, {x: $x}, grad, 'Sin', {} /* attrs */, inputsToSave); -} - -/** - * Computes cos of the input `tf.Tensor` element-wise: `cos(x)` - * - * ```js - * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); - * - * x.cos().print(); // or tf.cos(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function cos_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'cos'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {x: () => $x.toFloat().sin().neg().mul(dy)} as {x: () => T}; - }; - const inputsToSave = [$x]; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.cos($x); - save([$x]); - return res; - }, {x: $x}, grad, 'Cos', {} /* attrs */, inputsToSave); -} - -/** - * Computes tan of the input `tf.Tensor` element-wise, `tan(x)` - * - * ```js - * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); - * - * x.tan().print(); // or tf.tan(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function tan_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'tan'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {$x: () => dy.div($x.cos().square())} as {$x: () => T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.tan($x); - save([$x]); - return res; - }, {$x}, grad); -} - -/** - * Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.sinh().print(); // or tf.sinh(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function sinh_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'sinh'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - // tslint:disable-next-line: no-unnecessary-type-assertion - return {$x: () => $x.toFloat().cosh().mul(dy) as T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.sinh($x); - save([$x]); - return res; - }, {$x}, grad); -} - -/** - * Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.cosh().print(); // or tf.cosh(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function cosh_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'cosh'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - // tslint:disable-next-line: no-unnecessary-type-assertion - return {$x: () => $x.toFloat().sinh().mul(dy) as T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.cosh($x); - save([$x]); - return res; - }, {$x}, grad); -} - -/** - * Computes hyperbolic tangent of the input `tf.Tensor` element-wise: `tanh(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, 70]); - * - * x.tanh().print(); // or tf.tanh(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function tanh_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'tanh'); - - const grad = (dy: T, saved: Tensor[]) => { - const [y] = saved; - // tslint:disable-next-line: no-unnecessary-type-assertion - return {x: () => scalar(1).sub(y.square()).mul(dy) as T}; - }; - const outputsToSave = [true]; - return ENGINE.runKernelFunc( - (backend, save) => { - const y = backend.tanh($x); - save([y]); - return y; - }, - {x: $x}, grad, 'Tanh', {} /* attrs */, null /* inputsToSave */, - outputsToSave); -} - /** * Computes gause error function of the input `tf.Tensor` element-wise: * `erf(x)` @@ -632,8 +473,6 @@ function step_(x: T|TensorLike, alpha = 0.0): T { } export const clipByValue = op({clipByValue_}); -export const cos = op({cos_}); -export const cosh = op({cosh_}); export const erf = op({erf_}); export const exp = op({exp_}); export const expm1 = op({expm1_}); @@ -647,10 +486,6 @@ export const sigmoid = op({sigmoid_}); export const isNaN = op({isNaN_}); export const isInf = op({isInf_}); export const isFinite = op({isFinite_}); -export const sin = op({sin_}); -export const sinh = op({sinh_}); export const softplus = op({softplus_}); export const sqrt = op({sqrt_}); export const step = op({step_}); -export const tan = op({tan_}); -export const tanh = op({tanh_}); diff --git a/tfjs-core/src/ops/unary_ops_test.ts b/tfjs-core/src/ops/unary_ops_test.ts index 3bb25750325..a45ca1d2810 100644 --- a/tfjs-core/src/ops/unary_ops_test.ts +++ b/tfjs-core/src/ops/unary_ops_test.ts @@ -17,8 +17,7 @@ import * as tf from '../index'; import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; -import {expectArraysClose, TEST_EPSILON_FLOAT16} from '../test_util'; -import * as util from '../util'; +import {expectArraysClose} from '../test_util'; describeWithFlags('step', ALL_ENVS, () => { it('with 1d tensor', async () => { @@ -1343,599 +1342,6 @@ describeWithFlags('expm1', ALL_ENVS, () => { }); }); -describeWithFlags('sin', ALL_ENVS, () => { - it('basic', async () => { - const values = [1, -3, 2, 7, -4]; - const a = tf.tensor1d(values); - const result = tf.sin(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.sin(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([4, NaN, 0]); - const res = tf.sin(a); - expectArraysClose(await res.data(), [Math.sin(4), NaN, Math.sin(0)]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.sin(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 * Math.cos(5)]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.sin(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 * Math.cos(5)]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([-1, 2, 3, -5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.sin(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [1 * Math.cos(-1), 2 * Math.cos(2), 3 * Math.cos(3), 4 * Math.cos(-5)], - 1e-1); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.sin(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [1 * Math.cos(-3), 2 * Math.cos(1), 3 * Math.cos(2), 4 * Math.cos(3)], - 1e-1); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.sin({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'sin' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const values = [1, -3, 2, 7, -4]; - const result = tf.sin(values); - - const expected = []; - for (let i = 0; i < values.length; i++) { - expected[i] = Math.sin(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('throws for string tensor', () => { - expect(() => tf.sin('q')) - .toThrowError(/Argument 'x' passed to 'sin' must be numeric/); - }); -}); - -describeWithFlags('cos', ALL_ENVS, () => { - it('basic', async () => { - const values = [1, -3, 2, 7, -4]; - const a = tf.tensor1d(values); - const result = tf.cos(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.cos(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([4, NaN, 0]); - const res = tf.cos(a); - expectArraysClose(await res.data(), [Math.cos(4), NaN, Math.cos(0)]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.cos(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 * Math.sin(5) * -1]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.cos(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 * Math.sin(5) * -1]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([-1, 2, 3, -5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.cos(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [ - 1 * Math.sin(-1) * -1, 2 * Math.sin(2) * -1, 3 * Math.sin(3) * -1, - 4 * Math.sin(-5) * -1 - ], - 1e-1); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.cos(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [ - 1 * Math.sin(-3) * -1, 2 * Math.sin(1) * -1, 3 * Math.sin(2) * -1, - 4 * Math.sin(3) * -1 - ], - 1e-1); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.cos({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'cos' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const values = [1, -3, 2, 7, -4]; - const result = tf.cos(values); - - const expected = []; - for (let i = 0; i < values.length; i++) { - expected[i] = Math.cos(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('throws for string tensor', () => { - expect(() => tf.cos('q')) - .toThrowError(/Argument 'x' passed to 'cos' must be numeric/); - }); -}); - -describeWithFlags('tan', ALL_ENVS, () => { - it('basic', async () => { - const values = [1, -3, 2, 7, -4]; - const a = tf.tensor1d(values); - const result = tf.tan(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.tan(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([4, NaN, 0]); - const res = tf.tan(a); - expectArraysClose(await res.data(), [Math.tan(4), NaN, Math.tan(0)]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.tan(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), [8 / (Math.cos(0.5) * Math.cos(0.5))]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.tan(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), [8 / (Math.cos(0.5) * Math.cos(0.5))]); - }); - - it('gradients: Tensor1D', async () => { - const aValues = [-1, 2, 3, -5]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor1d(aValues); - const dy = tf.tensor1d(dyValues); - - const gradients = tf.grad(a => tf.tan(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] / (Math.cos(aValues[i]) * Math.cos(aValues[i])); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - // The grad(tan(x)) which relies on 1/cos(x) is less precise on Windows. - expectArraysClose(await gradients.data(), expected, TEST_EPSILON_FLOAT16); - }); - - it('gradients: Tensor2D', async () => { - const aValues = [-3, 1, 2, 3]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor2d(aValues, [2, 2]); - const dy = tf.tensor2d(dyValues, [2, 2]); - - const gradients = tf.grad(a => tf.tan(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] / (Math.cos(aValues[i]) * Math.cos(aValues[i])); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.tan({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'tan' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const values = [1, -3, 2, 7, -4]; - const result = tf.tan(values); - - const expected = []; - for (let i = 0; i < values.length; i++) { - expected[i] = Math.tan(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('throws for string tensor', () => { - expect(() => tf.tan('q')) - .toThrowError(/Argument 'x' passed to 'tan' must be numeric/); - }); -}); - -describeWithFlags('sinh', ALL_ENVS, () => { - it('basic', async () => { - const values = [1, -3, 2, -1, -4]; - const a = tf.tensor1d(values); - const result = tf.sinh(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.sinh(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([4, NaN, 0]); - const res = tf.sinh(a); - expectArraysClose(await res.data(), [Math.sinh(4), NaN, Math.sinh(0)]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.sinh(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 * Math.cosh(0.5)]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.sinh(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 * Math.cosh(0.5)]); - }); - - it('gradients: Tensor1D', async () => { - const aValues = [-1, 2, 3, -5]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor1d(aValues); - const dy = tf.tensor1d(dyValues); - - const gradients = tf.grad(a => tf.sinh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] * Math.cosh(aValues[i]); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('gradients: Tensor2D', async () => { - const aValues = [-3, 1, 2, 3]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor2d(aValues, [2, 2]); - const dy = tf.tensor2d(dyValues, [2, 2]); - - const gradients = tf.grad(a => tf.sinh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] * Math.cosh(aValues[i]); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.sinh({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'sinh' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const values = [1, -3, 2, -1, -4]; - const result = tf.sinh(values); - - const expected = []; - for (let i = 0; i < values.length; i++) { - expected[i] = Math.sinh(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('throws for string tensor', () => { - expect(() => tf.sinh('q')) - .toThrowError(/Argument 'x' passed to 'sinh' must be numeric/); - }); -}); - -describeWithFlags('cosh', ALL_ENVS, () => { - it('basic', async () => { - const values = [1, -3, 2, -1, -4]; - const a = tf.tensor1d(values); - const result = tf.cosh(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.cosh(values[i]); - } - - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([4, NaN, 0]); - const res = tf.cosh(a); - expectArraysClose(await res.data(), [Math.cosh(4), NaN, Math.cosh(0)]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.cosh(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 * Math.sinh(0.5)]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.cosh(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [8 * Math.sinh(0.5)]); - }); - - it('gradients: Tensor1D', async () => { - const aValues = [-1, 2, 3, -5]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor1d(aValues); - const dy = tf.tensor1d(dyValues); - - const gradients = tf.grad(a => tf.cosh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] * Math.sinh(aValues[i]); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('gradients: Tensor2D', async () => { - const aValues = [-3, 1, 2, 3]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor2d(aValues, [2, 2]); - const dy = tf.tensor2d(dyValues, [2, 2]); - - const gradients = tf.grad(a => tf.cosh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] * Math.sinh(aValues[i]); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.cosh({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'cosh' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const values = [1, -3, 2, -1, -4]; - const result = tf.cosh(values); - - const expected = []; - for (let i = 0; i < values.length; i++) { - expected[i] = Math.cosh(values[i]); - } - - expectArraysClose(await result.data(), expected); - }); - - it('throws for string tensor', () => { - expect(() => tf.cosh('q')) - .toThrowError(/Argument 'x' passed to 'cosh' must be numeric/); - }); -}); - -describeWithFlags('tanh', ALL_ENVS, () => { - it('basic', async () => { - const values = [1, -3, 2, 7, -4]; - const a = tf.tensor1d(values); - const result = tf.tanh(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = util.tanh(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([4, NaN, 0]); - const res = tf.tanh(a); - expectArraysClose(await res.data(), [util.tanh(4), NaN, util.tanh(0)]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.tanh(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), [8 * (1 - (Math.tanh(0.5) * Math.tanh(0.5)))]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.tanh(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), [8 * (1 - (Math.tanh(0.5) * Math.tanh(0.5)))]); - }); - - it('gradients: Tensor1D', async () => { - const aValues = [-1, 2, 3, -5]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor1d(aValues); - const dy = tf.tensor1d(dyValues); - - const gradients = tf.grad(a => tf.tanh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = - dyValues[i] * (1 - (Math.tanh(aValues[i]) * Math.tanh(aValues[i]))); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('gradients: Tensor2D', async () => { - const aValues = [-3, 1, 2, 3]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor2d(aValues, [2, 2]); - const dy = tf.tensor2d(dyValues, [2, 2]); - - const gradients = tf.grad(a => tf.tanh(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = - dyValues[i] * (1 - (Math.tanh(aValues[i]) * Math.tanh(aValues[i]))); - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.tanh({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'tanh' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const values = [1, -3, 2, 7, -4]; - const result = tf.tanh(values); - - const expected = []; - for (let i = 0; i < values.length; i++) { - expected[i] = util.tanh(values[i]); - } - expectArraysClose(await result.data(), expected); - }); - - it('throws for string tensor', () => { - expect(() => tf.tanh('q')) - .toThrowError(/Argument 'x' passed to 'tanh' must be numeric/); - }); -}); - describeWithFlags('clip', ALL_ENVS, () => { it('basic', async () => { const a = tf.tensor1d([3, -1, 0, 100, -7, 2]); diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index 7a6ec9f8daf..aaee434d501 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -36,6 +36,8 @@ import {concatGradConfig} from './gradients/Concat_grad'; import {conv2DGradConfig} from './gradients/Conv2D_grad'; import {conv2DBackpropInputGradConfig} from './gradients/Conv2DBackpropInput_grad'; import {conv3DGradConfig} from './gradients/Conv3D_grad'; +import {cosGradConfig} from './gradients/Cos_grad'; +import {coshGradConfig} from './gradients/Cosh_grad'; import {cumsumGradConfig} from './gradients/Cumsum_grad'; import {depthwiseConv2dNativeGradConfig} from './gradients/DepthwiseConv2dNative_grad'; import {dilation2dGradConfig} from './gradients/Dilation2D_grad'; @@ -70,12 +72,16 @@ import {reverseGradConfig} from './gradients/Reverse_grad'; import {selectV2PoolGradConfig} from './gradients/SelectV2_grad'; import {seluGradConfig} from './gradients/Selu_grad'; import {signGradConfig} from './gradients/Sign_grad'; +import {sinGradConfig} from './gradients/Sin_grad'; +import {sinhGradConfig} from './gradients/Sinh_grad'; import {spaceToBatchNDGradConfig} from './gradients/SpaceToBatchND_grad'; import {splitVGradConfig} from './gradients/SplitV_grad'; import {squareGradConfig} from './gradients/Square_grad'; import {squaredDifferenceGradConfig} from './gradients/SquaredDifference_grad'; import {subGradConfig} from './gradients/Sub_grad'; import {sumGradConfig} from './gradients/Sum_grad'; +import {tanGradConfig} from './gradients/Tan_grad'; +import {tanhGradConfig} from './gradients/Tanh_grad'; import {tileGradConfig} from './gradients/Tile_grad'; import {transposeGradConfig} from './gradients/Transpose_grad'; import {unpackGradConfig} from './gradients/Unpack_grad'; @@ -107,6 +113,8 @@ const gradConfigs: GradConfig[] = [ conv2DBackpropInputGradConfig, conv2DGradConfig, conv3DGradConfig, + cosGradConfig, + coshGradConfig, cumsumGradConfig, depthwiseConv2dNativeGradConfig, dilation2dGradConfig, @@ -144,6 +152,8 @@ const gradConfigs: GradConfig[] = [ selectV2PoolGradConfig, seluGradConfig, signGradConfig, + sinGradConfig, + sinhGradConfig, spaceToBatchNDGradConfig, spaceToBatchNDGradConfig, splitVGradConfig, @@ -152,6 +162,8 @@ const gradConfigs: GradConfig[] = [ squareGradConfig, subGradConfig, sumGradConfig, + tanGradConfig, + tanhGradConfig, tileGradConfig, transposeGradConfig, unpackGradConfig, diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index ee53aa9fd37..5fc04f12ba4 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -81,6 +81,8 @@ import './ops/conv2d_transpose_test'; import './ops/conv3d_test'; import './ops/conv3d_transpose_test'; import './ops/conv_util_test'; +import './ops/cos_test'; +import './ops/cosh_test'; import './ops/cosine_distance_test'; import './ops/crop_and_resize_test'; import './ops/cumsum_test'; @@ -154,6 +156,8 @@ import './ops/selu_test'; import './ops/sigmoid_cross_entropy_test'; import './ops/sign_test'; import './ops/signal_ops_test'; +import './ops/sin_test'; +import './ops/sinh_test'; import './ops/slice_test'; import './ops/slice_util_test'; import './ops/softmax_cross_entropy_test'; @@ -165,12 +169,13 @@ import './ops/stack_test'; import './ops/strided_slice_test'; import './ops/sub_test'; import './ops/sum_test'; +import './ops/tan_test'; +import './ops/tanh_test'; import './ops/tile_test'; import './ops/topk_test'; import './ops/transpose_test'; import './ops/truncated_normal_test'; import './ops/unary_ops_test'; -import './ops/undefined_test'; import './ops/unsorted_segment_sum_test'; import './ops/unstack_test'; import './ops/where_async_test'; From 6455264e4afa2646ef0683dcecc2d03cd7c6dc32 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 7 Jul 2020 22:31:37 -0400 Subject: [PATCH 3/3] add kernelname to gradient template --- tfjs-core/scripts/touch_modular_op_files.ts | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tfjs-core/scripts/touch_modular_op_files.ts b/tfjs-core/scripts/touch_modular_op_files.ts index ae2e8644710..66f27c2f80d 100644 --- a/tfjs-core/scripts/touch_modular_op_files.ts +++ b/tfjs-core/scripts/touch_modular_op_files.ts @@ -60,7 +60,7 @@ async function main() { console.log('Called touch_modular_op_files with args:', args); if (args.op != null) { - const ops: string[] = args.op.split(','); + const ops: string[] = args.op.split(',').filter((o: string) => o != null); ops.forEach(op => { let filePath = `./src/ops/${op}.ts`; let command = `touch ${filePath}`; @@ -84,7 +84,8 @@ async function main() { return str.charAt(0).toLowerCase() + str.slice(1); }; - const kernels: string[] = args.grad.split(','); + const kernels: string[] = + args.grad.split(',').filter((k: string) => k != null); kernels.forEach(kernelName => { const gradientFileTemplate = `/** @@ -104,12 +105,12 @@ async function main() { * ============================================================================= */ -import {KernelName, KernelNameAttrs} from '../kernel_names'; +import {${kernelName}, ${kernelName}Attrs} from '../kernel_names'; import {GradConfig, NamedAttrMap} from '../kernel_registry'; import {Tensor} from '../tensor'; export const ${downcaseFirstChar(kernelName)}GradConfig: GradConfig = { - kernelName: KernelName, + kernelName: ${kernelName}, inputsToSave: [], // UPDATE ME outputsToSave: [], // UPDATE ME gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {