diff --git a/tfjs-core/src/gradients/Abs_grad.ts b/tfjs-core/src/gradients/Abs_grad.ts index 0325f1fadec..a0709ce6dbe 100644 --- a/tfjs-core/src/gradients/Abs_grad.ts +++ b/tfjs-core/src/gradients/Abs_grad.ts @@ -19,7 +19,7 @@ import {Abs} from '../kernel_names'; import {GradConfig} from '../kernel_registry'; import {cast} from '../ops/cast'; import {mul} from '../ops/mul'; -import {step} from '../ops/unary_ops'; +import {step} from '../ops/step'; import {Tensor} from '../tensor'; export const absGradConfig: GradConfig = { diff --git a/tfjs-core/src/gradients/Acos_grad.ts b/tfjs-core/src/gradients/Acos_grad.ts index ba720548717..718865319fc 100644 --- a/tfjs-core/src/gradients/Acos_grad.ts +++ b/tfjs-core/src/gradients/Acos_grad.ts @@ -21,9 +21,9 @@ import {cast} from '../ops/cast'; import {div} from '../ops/div'; import {neg} from '../ops/neg'; import {scalar} from '../ops/scalar'; +import {sqrt} from '../ops/sqrt'; import {square} from '../ops/square'; import {sub} from '../ops/sub'; -import {sqrt} from '../ops/unary_ops'; import {Tensor} from '../tensor'; export const acosGradConfig: GradConfig = { diff --git a/tfjs-core/src/gradients/Acosh_grad.ts b/tfjs-core/src/gradients/Acosh_grad.ts index ec30b6bec13..5fd89161121 100644 --- a/tfjs-core/src/gradients/Acosh_grad.ts +++ b/tfjs-core/src/gradients/Acosh_grad.ts @@ -19,9 +19,9 @@ import {Acosh} from '../kernel_names'; import {GradConfig} from '../kernel_registry'; import {cast} from '../ops/cast'; import {div} from '../ops/div'; +import {sqrt} from '../ops/sqrt'; import {square} from '../ops/square'; import {sub} from '../ops/sub'; -import {sqrt} from '../ops/unary_ops'; import {Tensor} from '../tensor'; export const acoshGradConfig: GradConfig = { diff --git a/tfjs-core/src/gradients/Asin_grad.ts b/tfjs-core/src/gradients/Asin_grad.ts index 799c51a72df..09851e3477c 100644 --- a/tfjs-core/src/gradients/Asin_grad.ts +++ b/tfjs-core/src/gradients/Asin_grad.ts @@ -20,9 +20,9 @@ import {GradConfig} from '../kernel_registry'; import {cast} from '../ops/cast'; import {div} from '../ops/div'; import {scalar} from '../ops/scalar'; +import {sqrt} from '../ops/sqrt'; import {square} from '../ops/square'; import {sub} from '../ops/sub'; -import {sqrt} from '../ops/unary_ops'; import {Tensor} from '../tensor'; export const asinGradConfig: GradConfig = { diff --git a/tfjs-core/src/gradients/Asinh_grad.ts b/tfjs-core/src/gradients/Asinh_grad.ts index df9ab69247f..4ee93ba40e6 100644 --- a/tfjs-core/src/gradients/Asinh_grad.ts +++ b/tfjs-core/src/gradients/Asinh_grad.ts @@ -21,8 +21,8 @@ import {add} from '../ops/add'; import {cast} from '../ops/cast'; import {div} from '../ops/div'; import {scalar} from '../ops/scalar'; +import {sqrt} from '../ops/sqrt'; import {square} from '../ops/square'; -import {sqrt} from '../ops/unary_ops'; import {Tensor} from '../tensor'; export const asinhGradConfig: GradConfig = { diff --git a/tfjs-core/src/gradients/FusedBatchNorm_grad.ts b/tfjs-core/src/gradients/FusedBatchNorm_grad.ts index e5f516974fe..f20d2b5160b 100644 --- a/tfjs-core/src/gradients/FusedBatchNorm_grad.ts +++ b/tfjs-core/src/gradients/FusedBatchNorm_grad.ts @@ -20,11 +20,11 @@ import {add} from '../ops/add'; import {getReductionAxes} from '../ops/broadcast_util'; import {mul} from '../ops/mul'; import {reshape} from '../ops/reshape'; +import {rsqrt} from '../ops/rsqrt'; import {scalar} from '../ops/scalar'; import {sub} from '../ops/sub'; import {sum} from '../ops/sum'; import {tile} from '../ops/tile'; -import {rsqrt} from '../ops/unary_ops'; import 
{Tensor} from '../tensor'; import {Rank, ShapeMap} from '../types'; diff --git a/tfjs-core/src/gradients/IsFinite_grad.ts b/tfjs-core/src/gradients/IsFinite_grad.ts new file mode 100644 index 00000000000..fa7a40416a1 --- /dev/null +++ b/tfjs-core/src/gradients/IsFinite_grad.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {IsFinite} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {zerosLike} from '../ops/zeros_like'; +import {Tensor} from '../tensor'; + +export const isFiniteGradConfig: GradConfig = { + kernelName: IsFinite, + gradFunc: (dy: Tensor) => { + // TODO(nsthorat): Let gradients be null for cases where we want to stop + // backpropagation. + return {x: () => zerosLike(dy)}; + } +}; diff --git a/tfjs-core/src/gradients/IsInf_grad.ts b/tfjs-core/src/gradients/IsInf_grad.ts new file mode 100644 index 00000000000..e56b750efbd --- /dev/null +++ b/tfjs-core/src/gradients/IsInf_grad.ts @@ -0,0 +1,31 @@ + +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {IsInf} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {zerosLike} from '../ops/zeros_like'; +import {Tensor} from '../tensor'; + +export const isInfGradConfig: GradConfig = { + kernelName: IsInf, + gradFunc: (dy: Tensor) => { + // TODO(nsthorat): Let gradients be null for cases where we want to stop + // backpropagation. + return {x: () => zerosLike(dy)}; + } +}; diff --git a/tfjs-core/src/gradients/IsNan_grad.ts b/tfjs-core/src/gradients/IsNan_grad.ts new file mode 100644 index 00000000000..c5c5bb9e18c --- /dev/null +++ b/tfjs-core/src/gradients/IsNan_grad.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {IsNan} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {zerosLike} from '../ops/zeros_like'; +import {Tensor} from '../tensor'; + +export const isNanGradConfig: GradConfig = { + kernelName: IsNan, + gradFunc: (dy: Tensor) => { + // TODO(nsthorat): Let gradients be null for cases where we want to stop + // backpropagation. + return {x: () => zerosLike(dy)}; + } +}; diff --git a/tfjs-core/src/gradients/Relu6_grad.ts b/tfjs-core/src/gradients/Relu6_grad.ts index c85effe4c29..ebf2ca7a407 100644 --- a/tfjs-core/src/gradients/Relu6_grad.ts +++ b/tfjs-core/src/gradients/Relu6_grad.ts @@ -19,7 +19,7 @@ import {GradConfig} from '../kernel_registry'; import {cast} from '../ops/cast'; import {lessEqual} from '../ops/less_equal'; import {mul} from '../ops/mul'; -import {step} from '../ops/unary_ops'; +import {step} from '../ops/step'; import {Tensor} from '../tensor'; export const relu6GradConfig: GradConfig = { diff --git a/tfjs-core/src/gradients/Relu_grad.ts b/tfjs-core/src/gradients/Relu_grad.ts index f2826a8cf1e..c96bdcc3059 100644 --- a/tfjs-core/src/gradients/Relu_grad.ts +++ b/tfjs-core/src/gradients/Relu_grad.ts @@ -18,7 +18,7 @@ import {Relu} from '../kernel_names'; import {GradConfig} from '../kernel_registry'; import {cast} from '../ops/cast'; import {mul} from '../ops/mul'; -import {step} from '../ops/unary_ops'; +import {step} from '../ops/step'; import {Tensor} from '../tensor'; export const reluGradConfig: GradConfig = { diff --git a/tfjs-core/src/gradients/Round_grad.ts b/tfjs-core/src/gradients/Round_grad.ts new file mode 100644 index 00000000000..8ddcb155c37 --- /dev/null +++ b/tfjs-core/src/gradients/Round_grad.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Round} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {zerosLike} from '../ops/zeros_like'; +import {Tensor} from '../tensor'; + +export const roundGradConfig: GradConfig = { + kernelName: Round, + gradFunc: (dy: Tensor) => { + // TODO(nsthorat): Let gradients be null for cases where we want to stop + // backpropagation. + return {x: () => zerosLike(dy)}; + } +}; diff --git a/tfjs-core/src/gradients/Rsqrt_grad.ts b/tfjs-core/src/gradients/Rsqrt_grad.ts new file mode 100644 index 00000000000..fa77d7792e3 --- /dev/null +++ b/tfjs-core/src/gradients/Rsqrt_grad.ts @@ -0,0 +1,33 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Rsqrt} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {div} from '../ops/div'; +import {mul} from '../ops/mul'; +import {neg} from '../ops/neg'; +import {pow} from '../ops/pow'; +import {Tensor} from '../tensor'; + +export const rsqrtGradConfig: GradConfig = { + kernelName: Rsqrt, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + return {x: () => neg(div(dy, mul(pow(x, 1.5), 2)))}; + } +}; diff --git a/tfjs-core/src/gradients/Sigmoid_grad.ts b/tfjs-core/src/gradients/Sigmoid_grad.ts new file mode 100644 index 00000000000..5a595406e51 --- /dev/null +++ b/tfjs-core/src/gradients/Sigmoid_grad.ts @@ -0,0 +1,33 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Sigmoid} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {mul} from '../ops/mul'; +import {scalar} from '../ops/scalar'; +import {sub} from '../ops/sub'; +import {Tensor} from '../tensor'; + +export const sigmoidGradConfig: GradConfig = { + kernelName: Sigmoid, + outputsToSave: [true], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [y] = saved; + + return {x: () => mul(dy, mul(y, sub(scalar(1), y)))}; + } +}; diff --git a/tfjs-core/src/gradients/Softplus_grad.ts b/tfjs-core/src/gradients/Softplus_grad.ts new file mode 100644 index 00000000000..257ff601c11 --- /dev/null +++ b/tfjs-core/src/gradients/Softplus_grad.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {Softplus} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {mul} from '../ops/mul'; +import {sigmoid} from '../ops/sigmoid'; +import {Tensor} from '../tensor'; + +export const softplusGradConfig: GradConfig = { + kernelName: Softplus, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + + return {x: () => mul(dy, sigmoid(x))}; + } +}; diff --git a/tfjs-core/src/gradients/Sqrt_grad.ts b/tfjs-core/src/gradients/Sqrt_grad.ts new file mode 100644 index 00000000000..71012f49b8f --- /dev/null +++ b/tfjs-core/src/gradients/Sqrt_grad.ts @@ -0,0 +1,34 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Sqrt} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {cast} from '../ops/cast'; +import {div} from '../ops/div'; +import {mul} from '../ops/mul'; +import {sqrt} from '../ops/sqrt'; +import {Tensor} from '../tensor'; + +export const sqrtGradConfig: GradConfig = { + kernelName: Sqrt, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + + return {x: () => div(dy, mul(sqrt(cast(x, 'float32')), 2))}; + } +}; diff --git a/tfjs-core/src/gradients/Step_grad.ts b/tfjs-core/src/gradients/Step_grad.ts new file mode 100644 index 00000000000..a24535136ca --- /dev/null +++ b/tfjs-core/src/gradients/Step_grad.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Step} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {zerosLike} from '../ops/zeros_like'; +import {Tensor} from '../tensor'; + +export const stepGradConfig: GradConfig = { + kernelName: Step, + gradFunc: (dy: Tensor) => { + // TODO(manrajgrover): Return null for gradients when backprop supports + // it. 
+ return {x: () => zerosLike(dy)}; + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 2020d63ee8a..c5b017c2cac 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -349,6 +349,15 @@ export type IdentityInputs = Pick; export const Imag = 'Imag'; export type ImagInputs = Pick; +export const IsFinite = 'IsFinite'; +export type IsFiniteInputs = UnaryInputs; + +export const IsInf = 'IsInf'; +export type IsInfInputs = UnaryInputs; + +export const IsNan = 'IsNan'; +export type IsNanInputs = UnaryInputs; + export const Less = 'Less'; export type LessInputs = BinaryInputs; @@ -595,6 +604,12 @@ export interface ReverseAttrs { dims: number|number[]; } +export const Round = 'Round'; +export type RoundInputs = UnaryInputs; + +export const Rsqrt = 'Rsqrt'; +export type RsqrtInputs = UnaryInputs; + export const ScatterNd = 'ScatterNd'; export type ScatterNdInputs = Pick; export interface ScatterNdAttrs { @@ -622,6 +637,15 @@ export type SinhInputs = UnaryInputs; export const Sign = 'Sign'; export type SignInputs = UnaryInputs; +export const Sigmoid = 'Sigmoid'; +export type SigmoidInputs = UnaryInputs; + +export const Softplus = 'Softplus'; +export type SoftplusInputs = UnaryInputs; + +export const Sqrt = 'Sqrt'; +export type SqrtInputs = UnaryInputs; + export const Sum = 'Sum'; export type SumInputs = Pick; export interface SumAttrs { @@ -724,6 +748,12 @@ export type ZerosLikeInputs = UnaryInputs; /** * TensorFlow.js-only kernels */ +export const Step = 'Step'; +export type StepInputs = UnaryInputs; +export interface StepAttrs { + alpha: number; +} + export const FromPixels = 'FromPixels'; export interface FromPixelsInputs { pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement| diff --git a/tfjs-core/src/ops/basic_lstm_cell.ts b/tfjs-core/src/ops/basic_lstm_cell.ts index cda956aceed..d0644bad075 100644 --- a/tfjs-core/src/ops/basic_lstm_cell.ts +++ b/tfjs-core/src/ops/basic_lstm_cell.ts @@ -24,9 +24,9 @@ import {concat} from './concat'; import {matMul} from './mat_mul'; import {mul} from './mul'; import {op} from './operation'; +import {sigmoid} from './sigmoid'; import {slice} from './slice'; import {tanh} from './tanh'; -import {sigmoid} from './unary_ops'; /** * Computes the next state and output of a BasicLSTMCell. diff --git a/tfjs-core/src/ops/is_finite.ts b/tfjs-core/src/ops/is_finite.ts new file mode 100644 index 00000000000..88c75974224 --- /dev/null +++ b/tfjs-core/src/ops/is_finite.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {IsFinite, IsFiniteInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Returns which elements of x are finite. + * + * ```js + * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); + * + * x.isFinite().print(); // or tf.isFinite(x) + * ``` + * @param x The input Tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function isFinite_<T extends Tensor>(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'isFinite'); + + const inputs: IsFiniteInputs = {x: $x}; + + return ENGINE.runKernelFunc( + (backend) => backend.isFinite($x), inputs as {} as NamedTensorMap, + null /* grad */, IsFinite); +} +export const isFinite = op({isFinite_}); diff --git a/tfjs-core/src/ops/is_finite_test.ts b/tfjs-core/src/ops/is_finite_test.ts new file mode 100644 index 00000000000..8d09349add3 --- /dev/null +++ b/tfjs-core/src/ops/is_finite_test.ts @@ -0,0 +1,77 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('isFinite', ALL_ENVS, () => { + it('basic', async () => { + const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); + const r = tf.isFinite(a); + expect(r.dtype).toEqual('bool'); + expectArraysClose(await r.data(), [0, 0, 0, 1, 1]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(NaN); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.isFinite(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); + const dy = tf.tensor1d([1, 1, 1, 1, 1]); + + const gradients = tf.grad(a => tf.isFinite(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0, 0]); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([NaN, Infinity, -Infinity, 0], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.isFinite(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.isFinite({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'isFinite' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.isFinite([NaN, Infinity, -Infinity, 0, 1]); + expectArraysClose(await r.data(), [0, 0, 0, 1, 1]); + }); + + it('throws for string tensor', () => { + expect(() => tf.isFinite('q')) + .toThrowError(/Argument 'x' passed to 'isFinite' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/is_inf.ts b/tfjs-core/src/ops/is_inf.ts new file mode 100644 index 00000000000..4bd7a88e802 --- /dev/null +++ b/tfjs-core/src/ops/is_inf.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {IsInf, IsInfInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Returns which elements of x are Infinity or -Infinity. + * + * ```js + * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); + * + * x.isInf().print(); // or tf.isInf(x) + * ``` + * @param x The input Tensor.
+ */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function isInf_<T extends Tensor>(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'isInf'); + + const inputs: IsInfInputs = {x: $x}; + + return ENGINE.runKernelFunc( + (backend) => backend.isInf($x), inputs as {} as NamedTensorMap, + null /* grad */, IsInf); +} +export const isInf = op({isInf_}); diff --git a/tfjs-core/src/ops/is_inf_test.ts b/tfjs-core/src/ops/is_inf_test.ts new file mode 100644 index 00000000000..07e948ab7c3 --- /dev/null +++ b/tfjs-core/src/ops/is_inf_test.ts @@ -0,0 +1,77 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('isInf', ALL_ENVS, () => { + it('basic', async () => { + const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); + const r = tf.isInf(a); + expect(r.dtype).toEqual('bool'); + expectArraysClose(await r.data(), [0, 1, 1, 0, 0]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(NaN); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.isInf(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); + const dy = tf.tensor1d([1, 1, 1, 1, 1]); + + const gradients = tf.grad(a => tf.isInf(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0, 0]); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([NaN, Infinity, -Infinity, 0], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.isInf(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.isInf({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'isInf' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.isInf([NaN, Infinity, -Infinity, 0, 1]); + expectArraysClose(await r.data(), [0, 1, 1, 0, 0]); + }); + + it('throws for string tensor', () => { + expect(() => tf.isInf('q')) + .toThrowError(/Argument 'x' passed to 'isInf' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/is_nan.ts b/tfjs-core/src/ops/is_nan.ts new file mode 100644 index 00000000000..b731122f8c0 --- /dev/null +++ b/tfjs-core/src/ops/is_nan.ts @@ -0,0 +1,46 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {IsNan, IsNanInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Returns which elements of x are NaN. + * + * ```js + * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); + * + * x.isNaN().print(); // or tf.isNaN(x) + * ``` + * @param x The input Tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function isNaN_<T extends Tensor>(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'isNaN'); + const inputs: IsNanInputs = {x: $x}; + + return ENGINE.runKernelFunc( + backend => backend.isNaN($x), inputs as {} as NamedTensorMap, + null /* grad */, IsNan); +} +export const isNaN = op({isNaN_}); diff --git a/tfjs-core/src/ops/is_nan_test.ts b/tfjs-core/src/ops/is_nan_test.ts new file mode 100644 index 00000000000..db6144ab085 --- /dev/null +++ b/tfjs-core/src/ops/is_nan_test.ts @@ -0,0 +1,77 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('isNaN', ALL_ENVS, () => { + it('basic', async () => { + const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); + const r = tf.isNaN(a); + expect(r.dtype).toEqual('bool'); + expectArraysClose(await r.data(), [1, 0, 0, 0, 0]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(NaN); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.isNaN(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); + const dy = tf.tensor1d([1, 1, 1, 1, 1]); + + const gradients = tf.grad(a => tf.isNaN(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0, 0]); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([NaN, Infinity, -Infinity, 0], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.isNaN(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.isNaN({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'isNaN' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.isNaN([NaN, Infinity, -Infinity, 0, 1]); + expectArraysClose(await r.data(), [1, 0, 0, 0, 0]); + }); + + it('throws for string tensor', () => { + expect(() => tf.isNaN('q')) + .toThrowError(/Argument 'x' passed to 'isNaN' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/log_sigmoid.ts b/tfjs-core/src/ops/log_sigmoid.ts index f20da287691..d02aa0443fe 100644 --- a/tfjs-core/src/ops/log_sigmoid.ts +++ b/tfjs-core/src/ops/log_sigmoid.ts @@ -23,7 +23,8 @@ import {TensorLike} from '../types'; import {mul} from './mul'; import {neg} from './neg'; import {op} from './operation'; -import {sigmoid} from './unary_ops'; +import {sigmoid} from './sigmoid'; + /** * Computes log sigmoid of the input `tf.Tensor` element-wise: * `logSigmoid(x)`. For numerical stability, we use `-tf.softplus(-x)`. diff --git a/tfjs-core/src/ops/norm.ts b/tfjs-core/src/ops/norm.ts index be8364b2493..ee91476cd7c 100644 --- a/tfjs-core/src/ops/norm.ts +++ b/tfjs-core/src/ops/norm.ts @@ -28,9 +28,9 @@ import {op} from './operation'; import {pow} from './pow'; import {reshape} from './reshape'; import {scalar} from './scalar'; +import {sqrt} from './sqrt'; import {square} from './square'; import {sum} from './sum'; -import {sqrt} from './unary_ops'; /** * Computes the norm of scalar, vectors, and matrices. 
diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 91c36a2fd63..31d7107f658 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -79,6 +79,9 @@ export {gather} from './gather'; export {greater} from './greater'; export {greaterEqual} from './greater_equal'; export {imag} from './imag'; +export {isFinite} from './is_finite'; +export {isInf} from './is_inf'; +export {isNaN} from './is_nan'; export {leakyRelu} from './leaky_relu'; export {less} from './less'; export {lessEqual} from './less_equal'; @@ -138,10 +141,13 @@ export {reverse1d} from './reverse_1d'; export {reverse2d} from './reverse_2d'; export {reverse3d} from './reverse_3d'; export {reverse4d} from './reverse_4d'; +export {round} from './round'; +export {rsqrt} from './rsqrt'; export {scalar} from './scalar'; export {selu} from './selu'; export {separableConv2d} from './separable_conv2d'; export {setdiff1dAsync} from './setdiff1d_async'; +export {sigmoid} from './sigmoid'; export {sign} from './sign'; export {sin} from './sin'; export {sinh} from './sinh'; @@ -151,12 +157,15 @@ export {slice2d} from './slice2d'; export {slice3d} from './slice3d'; export {slice4d} from './slice4d'; export {softmax} from './softmax'; +export {softplus} from './softplus'; export {spaceToBatchND} from './space_to_batch_nd'; export {split} from './split'; +export {sqrt} from './sqrt'; export {square} from './square'; export {squaredDifference} from './squared_difference'; export {squeeze} from './squeeze'; export {stack} from './stack'; +export {step} from './step'; export {stridedSlice} from './strided_slice'; export {sub} from './sub'; export {sum} from './sum'; @@ -181,7 +190,6 @@ export {zeros} from './zeros'; export {zerosLike} from './zeros_like'; export * from './boolean_mask'; -export * from './unary_ops'; export * from './compare'; export * from './binary_ops'; export * from './transpose'; diff --git a/tfjs-core/src/ops/round.ts b/tfjs-core/src/ops/round.ts new file mode 100644 index 00000000000..ed16d57bbab --- /dev/null +++ b/tfjs-core/src/ops/round.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Round, RoundInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes round of input `tf.Tensor` element-wise: `round(x)`. + * It implements banker's rounding. + * + * ```js + * const x = tf.tensor1d([.6, 1.1, -3.3]); + * + * x.round().print(); // or tf.round(x) + * ``` + * @param x The input tensor. 
+ */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function round_<T extends Tensor>(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'round'); + const inputs: RoundInputs = {x: $x}; + + return ENGINE.runKernelFunc( + (backend) => backend.round($x), inputs as {} as NamedTensorMap, + null /* grad */, Round); +} + +export const round = op({round_}); diff --git a/tfjs-core/src/ops/round_test.ts b/tfjs-core/src/ops/round_test.ts new file mode 100644 index 00000000000..e3365ed11a4 --- /dev/null +++ b/tfjs-core/src/ops/round_test.ts @@ -0,0 +1,94 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('round', ALL_ENVS, () => { + it('basic', async () => { + const a = tf.tensor1d([0.9, 2.5, 2.3, 1.5, -4.5]); + const r = a.round(); + + expectArraysClose(await r.data(), [1, 2, 2, 2, -4]); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([1.5, NaN, -1.4]); + const r = tf.round(a); + expectArraysClose(await r.data(), [2, NaN, -1]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(5.2); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.round(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(5.2); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.round(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([-1.1, 2.6, 3, -5.9]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.round(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([-3, 1, 2.2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.round(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.round({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'round' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.round([0.9, 2.5, 2.3, 1.5, -4.5]); + expectArraysClose(await r.data(), [1, 2, 2, 2, -4]); + }); + + it('throws for string tensor', () => { + expect(() => tf.round('q')) + .toThrowError(/Argument
'x' passed to 'round' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/rsqrt.ts b/tfjs-core/src/ops/rsqrt.ts new file mode 100644 index 00000000000..04000961341 --- /dev/null +++ b/tfjs-core/src/ops/rsqrt.ts @@ -0,0 +1,50 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Rsqrt, RsqrtInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes reciprocal of square root of the input `tf.Tensor` element-wise: + * `y = 1 / sqrt(x)` + * + * ```js + * const x = tf.tensor1d([1, 2, 4, -1]); + * + * x.rsqrt().print(); // or tf.rsqrt(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function rsqrt_<T extends Tensor>(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'rsqrt'); + + const inputs: RsqrtInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.rsqrt($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Rsqrt); +} +export const rsqrt = op({rsqrt_}); diff --git a/tfjs-core/src/ops/rsqrt_test.ts b/tfjs-core/src/ops/rsqrt_test.ts new file mode 100644 index 00000000000..f9aeec61698 --- /dev/null +++ b/tfjs-core/src/ops/rsqrt_test.ts @@ -0,0 +1,105 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('rsqrt', ALL_ENVS, () => { + it('rsqrt', async () => { + const a = tf.tensor1d([2, 4]); + const r = tf.rsqrt(a); + expectArraysClose(await r.data(), [1 / Math.sqrt(2), 1 / Math.sqrt(4)]); + }); + + it('rsqrt propagates NaNs', async () => { + const a = tf.tensor1d([1, NaN]); + const r = tf.rsqrt(a); + expectArraysClose(await r.data(), [1 / Math.sqrt(1), NaN]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(4); + const dy = tf.scalar(8); + + const da = tf.grad(a => tf.rsqrt(a))(a, dy); + + expect(da.shape).toEqual(a.shape); + expect(da.dtype).toEqual('float32'); + expectArraysClose(await da.data(), [(-1 * 8) / (2 * Math.pow(4, 1.5))]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(4); + const dy = tf.scalar(8); + + const da = tf.grad(a => tf.rsqrt(a.clone()).clone())(a, dy); + + expect(da.shape).toEqual(a.shape); + expect(da.dtype).toEqual('float32'); + expectArraysClose(await da.data(), [(-1 * 8) / (2 * Math.pow(4, 1.5))]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([1, 2, 3, 5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.rsqrt(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [ + -1 * 1 / (2 * Math.pow(1, 1.5)), -1 * 2 / (2 * Math.pow(2, 1.5)), + -1 * 3 / (2 * Math.pow(3, 1.5)), -1 * 4 / (2 * Math.pow(5, 1.5)) + ], + 1e-1); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([3, 1, 2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.rsqrt(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [ + -1 * 1 / (2 * Math.pow(3, 1.5)), -1 * 2 / (2 * Math.pow(1, 1.5)), + -1 * 3 / (2 * Math.pow(2, 1.5)), -1 * 4 / (2 * Math.pow(3, 1.5)) + ], + 1e-1); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.rsqrt({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'rsqrt' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.rsqrt([2, 4]); + expectArraysClose(await r.data(), [1 / Math.sqrt(2), 1 / Math.sqrt(4)]); + }); + + it('throws for string tensor', () => { + expect(() => tf.rsqrt('q')) + .toThrowError(/Argument 'x' passed to 'rsqrt' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/sigmoid.ts b/tfjs-core/src/ops/sigmoid.ts new file mode 100644 index 00000000000..1e0c74c09d4 --- /dev/null +++ b/tfjs-core/src/ops/sigmoid.ts @@ -0,0 +1,49 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Sigmoid, SigmoidInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes sigmoid element-wise, `1 / (1 + exp(-x))` + * + * ```js + * const x = tf.tensor1d([0, -1, 2, -3]); + * + * x.sigmoid().print(); // or tf.sigmoid(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function sigmoid_<T extends Tensor>(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'sigmoid'); + + const inputs: SigmoidInputs = {x: $x}; + + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.sigmoid($x); + save([res]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Sigmoid); +} +export const sigmoid = op({sigmoid_}); diff --git a/tfjs-core/src/ops/sigmoid_test.ts b/tfjs-core/src/ops/sigmoid_test.ts new file mode 100644 index 00000000000..1489e6cb6b0 --- /dev/null +++ b/tfjs-core/src/ops/sigmoid_test.ts @@ -0,0 +1,107 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('sigmoid', ALL_ENVS, () => { + it('basic', async () => { + const values = [1, -3, 2, 7, -4]; + const a = tf.tensor1d(values); + + const result = tf.sigmoid(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = 1 / (1 + Math.exp(-values[i])); + } + expectArraysClose(await result.data(), expected); + }); + + it('6D', async () => { + const a = tf.ones([2, 2, 2, 2, 2, 2]); + const result = tf.sigmoid(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = 1 / (1 + Math.exp(-1.0)); + } + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([3, NaN]); + const res = tf.sigmoid(a); + expectArraysClose(await res.data(), [1 / (1 + Math.exp(-3)), NaN]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([1, 2, -3, 5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const da = tf.grad(a => tf.sigmoid(a))(a, dy); + + const aVals = await a.array(); + const dyVals = await dy.array(); + const expected = []; + for (let i = 0; i < a.size; i++) { + const y = 1 / (1 + Math.exp(-aVals[i])); + expected[i] = dyVals[i] * y * (1 - y); + } + + expectArraysClose(await da.data(), expected); + }); + + it('gradient with clones', async () => { + const a = tf.tensor1d([1, 2, -3, 5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const da = tf.grad(a => tf.sigmoid(a.clone()).clone())(a, dy); + + const aVals = await a.array(); + const dyVals = await dy.array(); + const expected = []; + for (let i = 0; i < a.size; i++) { + const y = 1 / (1 + Math.exp(-aVals[i])); + expected[i] = dyVals[i] * y * (1 - y); + } + + expectArraysClose(await da.data(), expected); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.sigmoid({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'sigmoid' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const values = [1, -3, 2, 7, -4]; + const result = tf.sigmoid(values); + + const expected = []; + for (let i = 0; i < values.length; i++) { + expected[i] = 1 / (1 + Math.exp(-values[i])); + } + expectArraysClose(await result.data(), expected); + }); + + it('throws for string tensor', () => { + expect(() => tf.sigmoid('q')) + .toThrowError(/Argument 'x' passed to 'sigmoid' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/softplus.ts b/tfjs-core/src/ops/softplus.ts new file mode 100644 index 00000000000..53f032693a0 --- /dev/null +++ b/tfjs-core/src/ops/softplus.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Softplus, SoftplusInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.softplus().print(); // or tf.softplus(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function softplus_<T extends Tensor>(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'softplus'); + + const inputs: SoftplusInputs = {x: $x}; + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.softplus($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Softplus); +} +export const softplus = op({softplus_}); diff --git a/tfjs-core/src/ops/softplus_test.ts b/tfjs-core/src/ops/softplus_test.ts new file mode 100644 index 00000000000..a4b3a1d011d --- /dev/null +++ b/tfjs-core/src/ops/softplus_test.ts @@ -0,0 +1,159 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('softplus', ALL_ENVS, () => { + it('basic', async () => { + const values = [1, -3, 2, 7, -4]; + const a = tf.tensor1d(values); + + const result = tf.softplus(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.log((1 + Math.exp(values[i]))); + } + expectArraysClose(await result.data(), expected); + }); + + it('scalar', async () => { + const a = tf.scalar(-2); + + const result = tf.softplus(a); + + const expected = [Math.log((1 + Math.exp(-2)))]; + expectArraysClose(await result.data(), expected); + }); + + it('tensor2D', async () => { + const values = [1, 2, -3, 5]; + const a = tf.tensor2d(values, [2, 2]); + + const result = tf.softplus(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.log((1 + Math.exp(values[i]))); + } + expectArraysClose(await result.data(), expected); + }); + + it('larger magnitude negative inputs', async () => { + const values = [-100, -200, -3000, -50000]; + const a = tf.tensor1d(values); + + const result = tf.softplus(a); + + const expected = [0, 0, 0, 0]; + + expectArraysClose(await result.data(), expected); + }); + + it('larger magnitude positive inputs', async () => { + const values = [100, 200, 3000]; + const a = tf.tensor1d(values); + + const result = tf.softplus(a); + + const expected = [100, 200, 3000]; + + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([3, NaN]); + const res = tf.softplus(a); + expectArraysClose(await res.data(), [Math.log((1 + Math.exp(3))), NaN]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(3); + const dy = tf.scalar(4); + const aVal = await a.array(); + const dyVal = await dy.array(); + + const da = tf.grad(a => tf.softplus(a))(a, dy); + const y = 1 / (1 + Math.exp(-aVal)); + + expectArraysClose(await da.data(), [dyVal * y]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(3); + const dy = tf.scalar(4); + const aVal = await a.array(); + const dyVal = await dy.array(); + + const da = tf.grad(a => tf.softplus(a.clone()).clone())(a, dy); + const y = 1 / (1 + Math.exp(-aVal)); + + expectArraysClose(await da.data(), [dyVal * y]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([1, 2, -3, 5]); + const aVals = await a.array(); + const dy = tf.tensor1d([1, 2, 3, 4]); + const dyVals = await dy.array(); + const da = tf.grad(a => tf.softplus(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + const y = 1 / (1 + Math.exp(-aVals[i])); + expected[i] = dyVals[i] * y; + } + + expectArraysClose(await da.data(), expected); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([1, 2, -3, 5], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const da = tf.grad(a => tf.softplus(a))(a, dy); + + const expected = []; + const aVals = await a.data(); + const dyVals = await dy.data(); + + for (let i = 0; i < a.size; i++) { + const y = 1 / (1 + Math.exp(-aVals[i])); + expected[i] = dyVals[i] * y; + } + + expectArraysClose(await da.data(), expected); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.softplus({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'softplus' must be a Tensor/); + }); + + it('accepts a 
tensor-like object', async () => {
+    const result = tf.softplus(-2);
+    const expected = [Math.log((1 + Math.exp(-2)))];
+    expectArraysClose(await result.data(), expected);
+  });
+
+  it('throws for string tensor', () => {
+    expect(() => tf.softplus('q'))
+        .toThrowError(/Argument 'x' passed to 'softplus' must be numeric/);
+  });
+});
diff --git a/tfjs-core/src/ops/sqrt.ts b/tfjs-core/src/ops/sqrt.ts
new file mode 100644
index 00000000000..2dcd3e8002e
--- /dev/null
+++ b/tfjs-core/src/ops/sqrt.ts
@@ -0,0 +1,49 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ENGINE} from '../engine';
+import {Sqrt, SqrtInputs} from '../kernel_names';
+import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+
+import {op} from './operation';
+
+/**
+ * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 4, -1]);
+ *
+ * x.sqrt().print();  // or tf.sqrt(x)
+ * ```
+ * @param x The input tensor.
+ */
+/** @doc {heading: 'Operations', subheading: 'Basic math'} */
+function sqrt_<T extends Tensor>(x: T|TensorLike): T {
+  const $x = convertToTensor(x, 'x', 'sqrt');
+
+  const inputs: SqrtInputs = {x: $x};
+
+  return ENGINE.runKernelFunc((backend, save) => {
+    const res = backend.sqrt($x);
+    save([$x]);
+    return res;
+  }, inputs as {} as NamedTensorMap, null /* grad */, Sqrt);
+}
+export const sqrt = op({sqrt_});
diff --git a/tfjs-core/src/ops/sqrt_test.ts b/tfjs-core/src/ops/sqrt_test.ts
new file mode 100644
index 00000000000..f9046ce2582
--- /dev/null
+++ b/tfjs-core/src/ops/sqrt_test.ts
@@ -0,0 +1,105 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('sqrt', ALL_ENVS, () => { + it('sqrt', async () => { + const a = tf.tensor1d([2, 4]); + const r = tf.sqrt(a); + expectArraysClose(await r.data(), [Math.sqrt(2), Math.sqrt(4)]); + }); + + it('sqrt propagates NaNs', async () => { + const a = tf.tensor1d([1, NaN]); + const r = tf.sqrt(a); + expectArraysClose(await r.data(), [Math.sqrt(1), NaN]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(4); + const dy = tf.scalar(8); + + const da = tf.grad(a => tf.sqrt(a))(a, dy); + + expect(da.shape).toEqual(a.shape); + expect(da.dtype).toEqual('float32'); + expectArraysClose(await da.data(), [8 / (2 * Math.sqrt(4))]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(4); + const dy = tf.scalar(8); + + const da = tf.grad(a => tf.sqrt(a.clone()).clone())(a, dy); + + expect(da.shape).toEqual(a.shape); + expect(da.dtype).toEqual('float32'); + expectArraysClose(await da.data(), [8 / (2 * Math.sqrt(4))]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([1, 2, 3, 5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.sqrt(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [ + 1 / (2 * Math.sqrt(1)), 2 / (2 * Math.sqrt(2)), + 3 / (2 * Math.sqrt(3)), 4 / (2 * Math.sqrt(5)) + ], + 1e-1); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([3, 1, 2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.sqrt(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [ + 1 / (2 * Math.sqrt(3)), 2 / (2 * Math.sqrt(1)), + 3 / (2 * Math.sqrt(2)), 4 / (2 * Math.sqrt(3)) + ], + 1e-1); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.sqrt({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'sqrt' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.sqrt([2, 4]); + expectArraysClose(await r.data(), [Math.sqrt(2), Math.sqrt(4)]); + }); + + it('throws for string tensor', () => { + expect(() => tf.sqrt('q')) + .toThrowError(/Argument 'x' passed to 'sqrt' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/square_test.ts b/tfjs-core/src/ops/square_test.ts new file mode 100644 index 00000000000..63eb300a30e --- /dev/null +++ b/tfjs-core/src/ops/square_test.ts @@ -0,0 +1,140 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('square', ALL_ENVS, () => { + it('1D array', async () => { + const a = tf.tensor1d([2, 4, Math.sqrt(2)]); + const r = tf.square(a); + expectArraysClose(await r.data(), [4, 16, 2]); + }); + + it('2D array', async () => { + const a = tf.tensor2d([1, 2, Math.sqrt(2), Math.sqrt(3)], [2, 2]); + const r = tf.square(a); + expect(r.shape).toEqual([2, 2]); + expectArraysClose(await r.data(), [1, 4, 2, 3]); + }); + + it('5D array', async () => { + const a = tf.tensor5d([1, 2, Math.sqrt(2), Math.sqrt(3)], [1, 1, 2, 2, 1]); + const r = tf.square(a); + expect(r.shape).toEqual([1, 1, 2, 2, 1]); + expectArraysClose(await r.data(), [1, 4, 2, 3]); + }); + + it('6D array', async () => { + const a = tf.tensor6d( + [1, 2, Math.sqrt(2), Math.sqrt(3), 3, 4, Math.sqrt(7), Math.sqrt(13)], + [1, 1, 2, 2, 2, 1]); + const r = tf.square(a); + expect(r.shape).toEqual(a.shape); + expectArraysClose(await r.data(), [1, 4, 2, 3, 9, 16, 7, 13]); + }); + + it('square propagates NaNs', async () => { + const a = tf.tensor1d([1.5, NaN]); + const r = tf.square(a); + expectArraysClose(await r.data(), [2.25, NaN]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.square(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [2 * 5 * 8]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.square(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [2 * 5 * 8]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([-1, 2, 3, -5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.square(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [-2, 4 * 2, 6 * 3, -10 * 4]); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.square(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [-6 * 1, 2 * 2, 4 * 3, 6 * 4]); + }); + + it('gradients: Tensor5D', async () => { + const a = tf.tensor5d([-3, 1, 2, 3], [1, 1, 1, 2, 2]); + const dy = tf.tensor5d([1, 2, 3, 4], [1, 1, 1, 2, 2]); + + const gradients = tf.grad(a => tf.square(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [-6 * 1, 2 * 2, 4 * 3, 6 * 4]); + }); + + it('gradients: Tensor6D', async () => { + const a = tf.tensor6d([-3, 1, 2, 3, -4, 5, 12, 3], [1, 1, 1, 2, 2, 2]); + const dy = tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 1, 2, 2, 2]); + + const gradients = tf.grad(a => tf.square(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [-6 * 1, 2 * 2, 4 * 3, 6 * 4, -8 * 5, 10 * 6, 24 * 7, 6 * 8]); + }); 
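+
+  // Note: the expected gradient values in the tests above follow
+  // d(x^2)/dx = 2 * x, so each entry is dy[i] * 2 * a[i].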
+
+  it('throws when passed a non-tensor', () => {
+    expect(() => tf.square({} as tf.Tensor))
+        .toThrowError(/Argument 'x' passed to 'square' must be a Tensor/);
+  });
+
+  it('accepts a tensor-like object', async () => {
+    const r = tf.square([2, 4, Math.sqrt(2)]);
+    expectArraysClose(await r.data(), [4, 16, 2]);
+  });
+
+  it('throws for string tensor', () => {
+    expect(() => tf.square('q'))
+        .toThrowError(/Argument 'x' passed to 'square' must be numeric/);
+  });
+});
diff --git a/tfjs-core/src/ops/step.ts b/tfjs-core/src/ops/step.ts
new file mode 100644
index 00000000000..14812b06977
--- /dev/null
+++ b/tfjs-core/src/ops/step.ts
@@ -0,0 +1,50 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ENGINE} from '../engine';
+import {Step, StepAttrs, StepInputs} from '../kernel_names';
+import {NamedAttrMap} from '../kernel_registry';
+import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+
+import {op} from './operation';
+
+/**
+ * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha * x`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, 2, -1, -3]);
+ *
+ * x.step(.5).print();  // or tf.step(x, .5)
+ * ```
+ * @param x The input tensor.
+ * @param alpha The gradient when input is negative.
+ */
+/** @doc {heading: 'Operations', subheading: 'Basic math'} */
+function step_<T extends Tensor>(x: T|TensorLike, alpha = 0.0): T {
+  const $x = convertToTensor(x, 'x', 'step');
+
+  const inputs: StepInputs = {x: $x};
+  const attrs: StepAttrs = {alpha};
+
+  return ENGINE.runKernelFunc(
+      backend => backend.step($x, alpha), inputs as {} as NamedTensorMap,
+      null /* grad */, Step, attrs as {} as NamedAttrMap);
+}
+export const step = op({step_});
diff --git a/tfjs-core/src/ops/step_test.ts b/tfjs-core/src/ops/step_test.ts
new file mode 100644
index 00000000000..39bcf6e0a69
--- /dev/null
+++ b/tfjs-core/src/ops/step_test.ts
@@ -0,0 +1,106 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('step', ALL_ENVS, () => { + it('with 1d tensor', async () => { + const a = tf.tensor1d([1, -2, -.01, 3, -0.1]); + const result = tf.step(a); + expectArraysClose(await result.data(), [1, 0, 0, 1, 0]); + }); + + it('with 1d tensor and alpha', async () => { + const a = tf.tensor1d([1, -2, -.01, 3, NaN]); + const result = tf.step(a, 0.1); + expectArraysClose(await result.data(), [1, 0.1, 0.1, 1, NaN]); + }); + + it('with 2d tensor', async () => { + const a = tf.tensor2d([1, -5, -3, 4], [2, 2]); + const result = tf.step(a); + expect(result.shape).toEqual([2, 2]); + expectArraysClose(await result.data(), [1, 0, 0, 1]); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([1, -2, -.01, 3, NaN]); + const result = tf.step(a); + expectArraysClose(await result.data(), [1, 0, 0, 1, NaN]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(-4); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.step(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(-4); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.step(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([1, 2, -3, 5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.step(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([3, -1, -2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.step(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.step({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'step' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const result = tf.step([1, -2, -.01, 3, -0.1]); + expectArraysClose(await result.data(), [1, 0, 0, 1, 0]); + }); + + it('throws for string tensor', () => { + expect(() => tf.step('q')) + .toThrowError(/Argument 'x' passed to 'step' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/unary_ops.ts b/tfjs-core/src/ops/unary_ops.ts deleted file mode 100644 index faad0b69eed..00000000000 --- a/tfjs-core/src/ops/unary_ops.ts +++ /dev/null @@ -1,248 +0,0 @@ -/** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {ENGINE} from '../engine'; -import {Tensor} from '../tensor'; -import {convertToTensor} from '../tensor_util_env'; -import {TensorLike} from '../types'; -import {op} from './operation'; -import {scalar} from './scalar'; -import {zerosLike} from './zeros_like'; - -/** - * RReturns which elements of x are NaN. - * - * ```js - * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); - * - * x.isNaN().print(); // or tf.isNaN(x) - * ``` - * @param x The input Tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function isNaN_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'isNaN'); - - // TODO(nsthorat): Let gradients be null for cases where we want to stop - // backpropgation. - const grad = (dy: T) => { - return {$x: () => zerosLike(dy)}; - }; - return ENGINE.runKernelFunc(backend => backend.isNaN($x), {$x}, grad); -} - -/** - * Returns which elements of x are Infinity or -Infinity. - * - * ```js - * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); - * - * x.isInf().print(); // or tf.isNaN(x) - * ``` - * @param x The input Tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function isInf_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'isInf'); - - // TODO(nsthorat): Let gradients be null for cases where we want to stop - // backpropgation. - const grad = (dy: T) => { - return {$x: () => zerosLike(dy)}; - }; - return ENGINE.runKernelFunc(backend => backend.isInf($x), {$x}, grad); -} - -/** - * Returns which elements of x are finite. - * - * ```js - * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); - * - * x.isFinite().print(); // or tf.isNaN(x) - * ``` - * @param x The input Tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function isFinite_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'isFinite'); - - // TODO(nsthorat): Let gradients be null for cases where we want to stop - // backpropgation. - const grad = (dy: T) => { - return {$x: () => zerosLike(dy)}; - }; - return ENGINE.runKernelFunc(backend => backend.isFinite($x), {$x}, grad); -} - -/** - * Computes round of input `tf.Tensor` element-wise: `round(x)`. - * It implements banker's rounding. - * - * ```js - * const x = tf.tensor1d([.6, 1.1, -3.3]); - * - * x.round().print(); // or tf.round(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function round_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'round'); - - // TODO(nsthorat): Let gradients be null for cases where we want to stop - // backpropgation. - const grad = (dy: T) => { - return {$x: () => zerosLike(dy)}; - }; - return ENGINE.runKernelFunc(backend => backend.round($x), {$x}, grad); -} - -/** - * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)` - * - * ```js - * const x = tf.tensor1d([1, 2, 4, -1]); - * - * x.sqrt().print(); // or tf.sqrt(x) - * ``` - * @param x The input tensor. 
- */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function sqrt_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'sqrt'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {x: () => dy.div($x.toFloat().sqrt().mul(2))} as {x: () => T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.sqrt($x); - save([$x]); - return res; - }, {x: $x}, grad, 'Sqrt', {}); -} - -/** - * Computes reciprocal of square root of the input `tf.Tensor` element-wise: - * `y = 1 / sqrt(x)` - * - * ```js - * const x = tf.tensor1d([1, 2, 4, -1]); - * - * x.rsqrt().print(); // or tf.rsqrt(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function rsqrt_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'rsqrt'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {x: () => dy.div($x.pow(1.5).mul(2)).neg() as T}; - }; - const inputsToSave = [$x]; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.rsqrt($x); - save([$x]); - return res; - }, {x: $x}, grad, 'Rsqrt', {} /* attrs */, inputsToSave); -} - -/** - * Computes sigmoid element-wise, `1 / (1 + exp(-x))` - * - * ```js - * const x = tf.tensor1d([0, -1, 2, -3]); - * - * x.sigmoid().print(); // or tf.sigmoid(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function sigmoid_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'sigmoid'); - - const grad = (dy: T, saved: Tensor[]) => { - const [y] = saved; - return {x: () => dy.mul(y.mul(scalar(1).sub(y)))} as {x: () => T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const y = backend.sigmoid($x); - save([y]); - return y; - }, {x: $x}, grad, 'Sigmoid'); -} - -/** - * Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.softplus().print(); // or tf.softplus(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function softplus_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'softplus'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {$x: () => dy.mul($x.sigmoid())} as {$x: () => T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.softplus($x); - save([$x]); - return res; - }, {$x}, grad); -} - -/** - * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha * x` - * - * ```js - * const x = tf.tensor1d([0, 2, -1, -3]); - * - * x.step(.5).print(); // or tf.step(x, .5) - * ``` - * @param x The input tensor. - * @param alpha The gradient when input is negative. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function step_(x: T|TensorLike, alpha = 0.0): T { - const $x = convertToTensor(x, 'x', 'step'); - - // TODO(manrajgrover): Return null for gradients when backprop supports - // it. 
- const grad = (dy: T) => { - return {$x: () => zerosLike(dy)}; - }; - return ENGINE.runKernelFunc(backend => backend.step($x, alpha), {$x}, grad); -} - -export const round = op({round_}); -export const rsqrt = op({rsqrt_}); -export const sigmoid = op({sigmoid_}); -export const isNaN = op({isNaN_}); -export const isInf = op({isInf_}); -export const isFinite = op({isFinite_}); -export const softplus = op({softplus_}); -export const sqrt = op({sqrt_}); -export const step = op({step_}); diff --git a/tfjs-core/src/ops/unary_ops_test.ts b/tfjs-core/src/ops/unary_ops_test.ts deleted file mode 100644 index 3b2154dbeff..00000000000 --- a/tfjs-core/src/ops/unary_ops_test.ts +++ /dev/null @@ -1,867 +0,0 @@ -/** - * @license - * Copyright 2017 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import * as tf from '../index'; -import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; -import {expectArraysClose} from '../test_util'; - -describeWithFlags('step', ALL_ENVS, () => { - it('with 1d tensor', async () => { - const a = tf.tensor1d([1, -2, -.01, 3, -0.1]); - const result = tf.step(a); - expectArraysClose(await result.data(), [1, 0, 0, 1, 0]); - }); - - it('with 1d tensor and alpha', async () => { - const a = tf.tensor1d([1, -2, -.01, 3, NaN]); - const result = tf.step(a, 0.1); - expectArraysClose(await result.data(), [1, 0.1, 0.1, 1, NaN]); - }); - - it('with 2d tensor', async () => { - const a = tf.tensor2d([1, -5, -3, 4], [2, 2]); - const result = tf.step(a); - expect(result.shape).toEqual([2, 2]); - expectArraysClose(await result.data(), [1, 0, 0, 1]); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([1, -2, -.01, 3, NaN]); - const result = tf.step(a); - expectArraysClose(await result.data(), [1, 0, 0, 1, NaN]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(-4); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.step(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(-4); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.step(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([1, 2, -3, 5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.step(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([3, -1, -2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.step(a))(a, dy); - 
- expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.step({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'step' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const result = tf.step([1, -2, -.01, 3, -0.1]); - expectArraysClose(await result.data(), [1, 0, 0, 1, 0]); - }); - - it('throws for string tensor', () => { - expect(() => tf.step('q')) - .toThrowError(/Argument 'x' passed to 'step' must be numeric/); - }); -}); - -describeWithFlags('sigmoid', ALL_ENVS, () => { - it('basic', async () => { - const values = [1, -3, 2, 7, -4]; - const a = tf.tensor1d(values); - - const result = tf.sigmoid(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = 1 / (1 + Math.exp(-values[i])); - } - expectArraysClose(await result.data(), expected); - }); - - it('6D', async () => { - const a = tf.ones([2, 2, 2, 2, 2, 2]); - const result = tf.sigmoid(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = 1 / (1 + Math.exp(-1.0)); - } - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([3, NaN]); - const res = tf.sigmoid(a); - expectArraysClose(await res.data(), [1 / (1 + Math.exp(-3)), NaN]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([1, 2, -3, 5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const da = tf.grad(a => tf.sigmoid(a))(a, dy); - - const aVals = await a.array(); - const dyVals = await dy.array(); - const expected = []; - for (let i = 0; i < a.size; i++) { - const y = 1 / (1 + Math.exp(-aVals[i])); - expected[i] = dyVals[i] * y * (1 - y); - } - - expectArraysClose(await da.data(), expected); - }); - - it('gradient with clones', async () => { - const a = tf.tensor1d([1, 2, -3, 5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const da = tf.grad(a => tf.sigmoid(a.clone()).clone())(a, dy); - - const aVals = await a.array(); - const dyVals = await dy.array(); - const expected = []; - for (let i = 0; i < a.size; i++) { - const y = 1 / (1 + Math.exp(-aVals[i])); - expected[i] = dyVals[i] * y * (1 - y); - } - - expectArraysClose(await da.data(), expected); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.sigmoid({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'sigmoid' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const values = [1, -3, 2, 7, -4]; - const result = tf.sigmoid(values); - - const expected = []; - for (let i = 0; i < values.length; i++) { - expected[i] = 1 / (1 + Math.exp(-values[i])); - } - expectArraysClose(await result.data(), expected); - }); - - it('throws for string tensor', () => { - expect(() => tf.sigmoid('q')) - .toThrowError(/Argument 'x' passed to 'sigmoid' must be numeric/); - }); -}); - -describeWithFlags('softplus', ALL_ENVS, () => { - it('basic', async () => { - const values = [1, -3, 2, 7, -4]; - const a = tf.tensor1d(values); - - const result = tf.softplus(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.log((1 + Math.exp(values[i]))); - } - expectArraysClose(await result.data(), expected); - }); - - it('scalar', async () => { - const a = tf.scalar(-2); - - const result = tf.softplus(a); - - const expected = [Math.log((1 + Math.exp(-2)))]; - expectArraysClose(await result.data(), expected); - }); - - 
it('tensor2D', async () => { - const values = [1, 2, -3, 5]; - const a = tf.tensor2d(values, [2, 2]); - - const result = tf.softplus(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.log((1 + Math.exp(values[i]))); - } - expectArraysClose(await result.data(), expected); - }); - - it('larger magnitude negative inputs', async () => { - const values = [-100, -200, -3000, -50000]; - const a = tf.tensor1d(values); - - const result = tf.softplus(a); - - const expected = [0, 0, 0, 0]; - - expectArraysClose(await result.data(), expected); - }); - - it('larger magnitude positive inputs', async () => { - const values = [100, 200, 3000]; - const a = tf.tensor1d(values); - - const result = tf.softplus(a); - - const expected = [100, 200, 3000]; - - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([3, NaN]); - const res = tf.softplus(a); - expectArraysClose(await res.data(), [Math.log((1 + Math.exp(3))), NaN]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(3); - const dy = tf.scalar(4); - const aVal = await a.array(); - const dyVal = await dy.array(); - - const da = tf.grad(a => tf.softplus(a))(a, dy); - const y = 1 / (1 + Math.exp(-aVal)); - - expectArraysClose(await da.data(), [dyVal * y]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(3); - const dy = tf.scalar(4); - const aVal = await a.array(); - const dyVal = await dy.array(); - - const da = tf.grad(a => tf.softplus(a.clone()).clone())(a, dy); - const y = 1 / (1 + Math.exp(-aVal)); - - expectArraysClose(await da.data(), [dyVal * y]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([1, 2, -3, 5]); - const aVals = await a.array(); - const dy = tf.tensor1d([1, 2, 3, 4]); - const dyVals = await dy.array(); - const da = tf.grad(a => tf.softplus(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - const y = 1 / (1 + Math.exp(-aVals[i])); - expected[i] = dyVals[i] * y; - } - - expectArraysClose(await da.data(), expected); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([1, 2, -3, 5], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const da = tf.grad(a => tf.softplus(a))(a, dy); - - const expected = []; - const aVals = await a.data(); - const dyVals = await dy.data(); - - for (let i = 0; i < a.size; i++) { - const y = 1 / (1 + Math.exp(-aVals[i])); - expected[i] = dyVals[i] * y; - } - - expectArraysClose(await da.data(), expected); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.softplus({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'softplus' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const result = tf.softplus(-2); - const expected = [Math.log((1 + Math.exp(-2)))]; - expectArraysClose(await result.data(), expected); - }); - - it('throws for string tensor', () => { - expect(() => tf.softplus('q')) - .toThrowError(/Argument 'x' passed to 'softplus' must be numeric/); - }); -}); - -describeWithFlags('sqrt', ALL_ENVS, () => { - it('sqrt', async () => { - const a = tf.tensor1d([2, 4]); - const r = tf.sqrt(a); - expectArraysClose(await r.data(), [Math.sqrt(2), Math.sqrt(4)]); - }); - - it('sqrt propagates NaNs', async () => { - const a = tf.tensor1d([1, NaN]); - const r = tf.sqrt(a); - expectArraysClose(await r.data(), [Math.sqrt(1), NaN]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(4); - const dy = tf.scalar(8); - - 
const da = tf.grad(a => tf.sqrt(a))(a, dy); - - expect(da.shape).toEqual(a.shape); - expect(da.dtype).toEqual('float32'); - expectArraysClose(await da.data(), [8 / (2 * Math.sqrt(4))]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(4); - const dy = tf.scalar(8); - - const da = tf.grad(a => tf.sqrt(a.clone()).clone())(a, dy); - - expect(da.shape).toEqual(a.shape); - expect(da.dtype).toEqual('float32'); - expectArraysClose(await da.data(), [8 / (2 * Math.sqrt(4))]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([1, 2, 3, 5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.sqrt(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [ - 1 / (2 * Math.sqrt(1)), 2 / (2 * Math.sqrt(2)), - 3 / (2 * Math.sqrt(3)), 4 / (2 * Math.sqrt(5)) - ], - 1e-1); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([3, 1, 2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.sqrt(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [ - 1 / (2 * Math.sqrt(3)), 2 / (2 * Math.sqrt(1)), - 3 / (2 * Math.sqrt(2)), 4 / (2 * Math.sqrt(3)) - ], - 1e-1); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.sqrt({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'sqrt' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.sqrt([2, 4]); - expectArraysClose(await r.data(), [Math.sqrt(2), Math.sqrt(4)]); - }); - - it('throws for string tensor', () => { - expect(() => tf.sqrt('q')) - .toThrowError(/Argument 'x' passed to 'sqrt' must be numeric/); - }); -}); - -describeWithFlags('rsqrt', ALL_ENVS, () => { - it('rsqrt', async () => { - const a = tf.tensor1d([2, 4]); - const r = tf.rsqrt(a); - expectArraysClose(await r.data(), [1 / Math.sqrt(2), 1 / Math.sqrt(4)]); - }); - - it('rsqrt propagates NaNs', async () => { - const a = tf.tensor1d([1, NaN]); - const r = tf.rsqrt(a); - expectArraysClose(await r.data(), [1 / Math.sqrt(1), NaN]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(4); - const dy = tf.scalar(8); - - const da = tf.grad(a => tf.rsqrt(a))(a, dy); - - expect(da.shape).toEqual(a.shape); - expect(da.dtype).toEqual('float32'); - expectArraysClose(await da.data(), [(-1 * 8) / (2 * Math.pow(4, 1.5))]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(4); - const dy = tf.scalar(8); - - const da = tf.grad(a => tf.rsqrt(a.clone()).clone())(a, dy); - - expect(da.shape).toEqual(a.shape); - expect(da.dtype).toEqual('float32'); - expectArraysClose(await da.data(), [(-1 * 8) / (2 * Math.pow(4, 1.5))]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([1, 2, 3, 5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.rsqrt(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [ - -1 * 1 / (2 * Math.pow(1, 1.5)), -1 * 2 / (2 * Math.pow(2, 1.5)), - -1 * 3 / (2 * Math.pow(3, 1.5)), -1 * 4 / (2 * Math.pow(5, 1.5)) - ], - 1e-1); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([3, 1, 2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.rsqrt(a))(a, dy); - - 
expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [ - -1 * 1 / (2 * Math.pow(3, 1.5)), -1 * 2 / (2 * Math.pow(1, 1.5)), - -1 * 3 / (2 * Math.pow(2, 1.5)), -1 * 4 / (2 * Math.pow(3, 1.5)) - ], - 1e-1); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.rsqrt({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'rsqrt' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.rsqrt([2, 4]); - expectArraysClose(await r.data(), [1 / Math.sqrt(2), 1 / Math.sqrt(4)]); - }); - - it('throws for string tensor', () => { - expect(() => tf.rsqrt('q')) - .toThrowError(/Argument 'x' passed to 'rsqrt' must be numeric/); - }); -}); - -describeWithFlags('square', ALL_ENVS, () => { - it('1D array', async () => { - const a = tf.tensor1d([2, 4, Math.sqrt(2)]); - const r = tf.square(a); - expectArraysClose(await r.data(), [4, 16, 2]); - }); - - it('2D array', async () => { - const a = tf.tensor2d([1, 2, Math.sqrt(2), Math.sqrt(3)], [2, 2]); - const r = tf.square(a); - expect(r.shape).toEqual([2, 2]); - expectArraysClose(await r.data(), [1, 4, 2, 3]); - }); - - it('5D array', async () => { - const a = tf.tensor5d([1, 2, Math.sqrt(2), Math.sqrt(3)], [1, 1, 2, 2, 1]); - const r = tf.square(a); - expect(r.shape).toEqual([1, 1, 2, 2, 1]); - expectArraysClose(await r.data(), [1, 4, 2, 3]); - }); - - it('6D array', async () => { - const a = tf.tensor6d( - [1, 2, Math.sqrt(2), Math.sqrt(3), 3, 4, Math.sqrt(7), Math.sqrt(13)], - [1, 1, 2, 2, 2, 1]); - const r = tf.square(a); - expect(r.shape).toEqual(a.shape); - expectArraysClose(await r.data(), [1, 4, 2, 3, 9, 16, 7, 13]); - }); - - it('square propagates NaNs', async () => { - const a = tf.tensor1d([1.5, NaN]); - const r = tf.square(a); - expectArraysClose(await r.data(), [2.25, NaN]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.square(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [2 * 5 * 8]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.square(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [2 * 5 * 8]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([-1, 2, 3, -5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.square(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [-2, 4 * 2, 6 * 3, -10 * 4]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.square(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [-6 * 1, 2 * 2, 4 * 3, 6 * 4]); - }); - - it('gradients: Tensor5D', async () => { - const a = tf.tensor5d([-3, 1, 2, 3], [1, 1, 1, 2, 2]); - const dy = tf.tensor5d([1, 2, 3, 4], [1, 1, 1, 2, 2]); - - const gradients = tf.grad(a => tf.square(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - 
expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [-6 * 1, 2 * 2, 4 * 3, 6 * 4]); - }); - - it('gradients: Tensor6D', async () => { - const a = tf.tensor6d([-3, 1, 2, 3, -4, 5, 12, 3], [1, 1, 1, 2, 2, 2]); - const dy = tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 1, 2, 2, 2]); - - const gradients = tf.grad(a => tf.square(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [-6 * 1, 2 * 2, 4 * 3, 6 * 4, -8 * 5, 10 * 6, 24 * 7, 6 * 8]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.square({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'square' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.square([2, 4, Math.sqrt(2)]); - expectArraysClose(await r.data(), [4, 16, 2]); - }); - - it('throws for string tensor', () => { - expect(() => tf.square('q')) - .toThrowError(/Argument 'x' passed to 'square' must be numeric/); - }); -}); - -describeWithFlags('isNaN', ALL_ENVS, () => { - it('basic', async () => { - const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); - const r = tf.isNaN(a); - expect(r.dtype).toEqual('bool'); - expectArraysClose(await r.data(), [1, 0, 0, 0, 0]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(NaN); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.isNaN(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); - const dy = tf.tensor1d([1, 1, 1, 1, 1]); - - const gradients = tf.grad(a => tf.isNaN(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0, 0]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([NaN, Infinity, -Infinity, 0], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.isNaN(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.isNaN({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'isNaN' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.isNaN([NaN, Infinity, -Infinity, 0, 1]); - expectArraysClose(await r.data(), [1, 0, 0, 0, 0]); - }); - - it('throws for string tensor', () => { - expect(() => tf.isNaN('q')) - .toThrowError(/Argument 'x' passed to 'isNaN' must be numeric/); - }); -}); - -describeWithFlags('isInf', ALL_ENVS, () => { - it('basic', async () => { - const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); - const r = tf.isInf(a); - expect(r.dtype).toEqual('bool'); - expectArraysClose(await r.data(), [0, 1, 1, 0, 0]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(NaN); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.isInf(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); - const dy = tf.tensor1d([1, 1, 1, 1, 1]); - - const gradients = 
tf.grad(a => tf.isInf(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0, 0]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([NaN, Infinity, -Infinity, 0], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.isInf(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.isInf({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'isInf' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.isInf([NaN, Infinity, -Infinity, 0, 1]); - expectArraysClose(await r.data(), [0, 1, 1, 0, 0]); - }); - - it('throws for string tensor', () => { - expect(() => tf.isInf('q')) - .toThrowError(/Argument 'x' passed to 'isInf' must be numeric/); - }); -}); - -describeWithFlags('isFinite', ALL_ENVS, () => { - it('basic', async () => { - const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); - const r = tf.isFinite(a); - expect(r.dtype).toEqual('bool'); - expectArraysClose(await r.data(), [0, 0, 0, 1, 1]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(NaN); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.isFinite(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); - const dy = tf.tensor1d([1, 1, 1, 1, 1]); - - const gradients = tf.grad(a => tf.isFinite(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0, 0]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([NaN, Infinity, -Infinity, 0], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.isFinite(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.isFinite({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'isFinite' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.isFinite([NaN, Infinity, -Infinity, 0, 1]); - expectArraysClose(await r.data(), [0, 0, 0, 1, 1]); - }); - - it('throws for string tensor', () => { - expect(() => tf.isFinite('q')) - .toThrowError(/Argument 'x' passed to 'isFinite' must be numeric/); - }); -}); - -describeWithFlags('round', ALL_ENVS, () => { - it('basic', async () => { - const a = tf.tensor1d([0.9, 2.5, 2.3, 1.5, -4.5]); - const r = a.round(); - - expectArraysClose(await r.data(), [1, 2, 2, 2, -4]); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([1.5, NaN, -1.4]); - const r = tf.round(a); - expectArraysClose(await r.data(), [2, NaN, -1]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(5.2); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.round(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradient with 
clones', async () => { - const a = tf.scalar(5.2); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.round(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([-1.1, 2.6, 3, -5.9]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.round(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([-3, 1, 2.2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.round(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.round({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'round' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.round([0.9, 2.5, 2.3, 1.5, -4.5]); - expectArraysClose(await r.data(), [1, 2, 2, 2, -4]); - }); - - it('throws for string tensor', () => { - expect(() => tf.round('q')) - .toThrowError(/Argument 'x' passed to 'round' must be numeric/); - }); -}); diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index 9caa7a6fbb9..4801243bf6e 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -54,6 +54,9 @@ import {fusedBatchNormGradConfig} from './gradients/FusedBatchNorm_grad'; import {gatherGradConfig} from './gradients/GatherV2_grad'; import {greaterEqualGradConfig} from './gradients/GreaterEqual_grad'; import {identityGradConfig} from './gradients/Identity_grad'; +import {isFiniteGradConfig} from './gradients/IsFinite_grad'; +import {isInfGradConfig} from './gradients/IsInf_grad'; +import {isNanGradConfig} from './gradients/IsNan_grad'; import {log1pGradConfig} from './gradients/Log1p_grad'; import {logGradConfig} from './gradients/Log_grad'; import {logSoftmaxGradConfig} from './gradients/LogSoftmax_grad'; @@ -79,17 +82,23 @@ import {reshapeGradConfig} from './gradients/Reshape_grad'; import {resizeBilinearGradConfig} from './gradients/ResizeBilinear_grad'; import {resizeNearestNeighborGradConfig} from './gradients/ResizeNearestNeighbor_grad'; import {reverseGradConfig} from './gradients/Reverse_grad'; +import {roundGradConfig} from './gradients/Round_grad'; +import {rsqrtGradConfig} from './gradients/Rsqrt_grad'; import {selectV2PoolGradConfig} from './gradients/SelectV2_grad'; import {seluGradConfig} from './gradients/Selu_grad'; +import {sigmoidGradConfig} from './gradients/Sigmoid_grad'; import {signGradConfig} from './gradients/Sign_grad'; import {sinGradConfig} from './gradients/Sin_grad'; import {sinhGradConfig} from './gradients/Sinh_grad'; import {sliceGradConfig} from './gradients/Slice_grad'; import {softmaxGradConfig} from './gradients/Softmax_grad'; +import {softplusGradConfig} from './gradients/Softplus_grad'; import {spaceToBatchNDGradConfig} from './gradients/SpaceToBatchND_grad'; import {splitVGradConfig} from './gradients/SplitV_grad'; +import {sqrtGradConfig} from './gradients/Sqrt_grad'; import {squareGradConfig} from './gradients/Square_grad'; import 
{squaredDifferenceGradConfig} from './gradients/SquaredDifference_grad'; +import {stepGradConfig} from './gradients/Step_grad'; import {subGradConfig} from './gradients/Sub_grad'; import {sumGradConfig} from './gradients/Sum_grad'; import {tanGradConfig} from './gradients/Tan_grad'; @@ -144,10 +153,13 @@ const gradConfigs: GradConfig[] = [ gatherGradConfig, greaterEqualGradConfig, identityGradConfig, + isFiniteGradConfig, + isInfGradConfig, + isNanGradConfig, log1pGradConfig, logGradConfig, - lrnGradConfig, logSoftmaxGradConfig, + lrnGradConfig, maxGradConfig, maxGradConfig, maximumGradConfig, @@ -171,21 +183,27 @@ const gradConfigs: GradConfig[] = [ resizeBilinearGradConfig, resizeNearestNeighborGradConfig, reverseGradConfig, + roundGradConfig, + rsqrtGradConfig, selectV2PoolGradConfig, seluGradConfig, + sigmoidGradConfig, signGradConfig, sinGradConfig, sinhGradConfig, sliceGradConfig, + softmaxGradConfig, + softplusGradConfig, spaceToBatchNDGradConfig, spaceToBatchNDGradConfig, splitVGradConfig, splitVGradConfig, + sqrtGradConfig, squaredDifferenceGradConfig, squareGradConfig, + stepGradConfig, subGradConfig, sumGradConfig, - softmaxGradConfig, tanGradConfig, tanhGradConfig, tileGradConfig, diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index 5b2cbfcbaff..4504912ac5f 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -113,6 +113,9 @@ import './ops/hann_window_test'; import './ops/hinge_loss_test'; import './ops/huber_loss_test'; import './ops/in_top_k_test'; +import './ops/is_finite_test'; +import './ops/is_inf_test'; +import './ops/is_nan_test'; import './ops/leaky_relu_test'; import './ops/less_equal_test'; import './ops/less_test'; @@ -169,10 +172,13 @@ import './ops/reverse_3d_test'; import './ops/reverse_4d_test'; import './ops/reverse_test'; import './ops/rotate_with_offset_test'; +import './ops/round_test'; +import './ops/rsqrt_test'; import './ops/scatter_nd_test'; import './ops/selu_test'; import './ops/setdiff1d_async_test'; import './ops/sigmoid_cross_entropy_test'; +import './ops/sigmoid_test'; import './ops/sign_test'; import './ops/sin_test'; import './ops/sinh_test'; @@ -184,11 +190,15 @@ import './ops/slice_test'; import './ops/slice_util_test'; import './ops/softmax_cross_entropy_test'; import './ops/softmax_test'; +import './ops/softplus_test'; import './ops/space_to_batch_nd_test'; import './ops/sparse_to_dense_test'; import './ops/spectral_ops_test'; import './ops/split_test'; +import './ops/sqrt_test'; +import './ops/square_test'; import './ops/stack_test'; +import './ops/step_test'; import './ops/stft_test'; import './ops/strided_slice_test'; import './ops/sub_test'; @@ -200,7 +210,6 @@ import './ops/to_pixels_test'; import './ops/topk_test'; import './ops/transpose_test'; import './ops/truncated_normal_test'; -import './ops/unary_ops_test'; import './ops/unsorted_segment_sum_test'; import './ops/unstack_test'; import './ops/where_async_test';