From be4219b68165dd1083f08a2eb8093947b0f20c17 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Fri, 10 Jul 2020 16:36:48 -0400 Subject: [PATCH 1/4] modularize more unary ops --- tfjs-core/scripts/touch_modular_op_files.ts | 2 +- tfjs-core/src/gradients/ClipByValue_grad.ts | 39 + tfjs-core/src/gradients/Erf_grad.ts | 34 + tfjs-core/src/gradients/Exp_grad.ts | 30 + tfjs-core/src/gradients/Expm1_grad.ts | 31 + tfjs-core/src/gradients/Log1p_grad.ts | 31 + tfjs-core/src/gradients/Log_grad.ts | 31 + tfjs-core/src/gradients/Pow_grad.ts | 2 +- tfjs-core/src/gradients/Reciprocal_grad.ts | 32 + tfjs-core/src/gradients/Selu_grad.ts | 2 +- tfjs-core/src/kernel_names.ts | 25 + tfjs-core/src/ops/clip_by_value.ts | 62 ++ tfjs-core/src/ops/clip_by_value_test.ts | 138 ++++ tfjs-core/src/ops/erf.ts | 57 ++ tfjs-core/src/ops/erf_test.ts | 135 ++++ tfjs-core/src/ops/exp.ts | 48 ++ tfjs-core/src/ops/exp_test.ts | 100 +++ tfjs-core/src/ops/expm1.ts | 49 ++ tfjs-core/src/ops/expm1_test.ts | 103 +++ tfjs-core/src/ops/log.ts | 48 ++ tfjs-core/src/ops/log1p.ts | 49 ++ tfjs-core/src/ops/log1p_test.ts | 94 +++ tfjs-core/src/ops/log_loss.ts | 2 +- tfjs-core/src/ops/log_sigmoid.ts | 60 ++ tfjs-core/src/ops/log_sigmoid_test.ts | 162 ++++ tfjs-core/src/ops/log_sum_exp.ts | 3 +- tfjs-core/src/ops/log_test.ts | 105 +++ tfjs-core/src/ops/ops.ts | 8 + tfjs-core/src/ops/reciprocal.ts | 48 ++ tfjs-core/src/ops/reciprocal_test.ts | 106 +++ tfjs-core/src/ops/sigmoid_cross_entropy.ts | 3 +- tfjs-core/src/ops/softmax_cross_entropy.ts | 2 +- tfjs-core/src/ops/unary_ops.ts | 244 ------ tfjs-core/src/ops/unary_ops_test.ts | 783 -------------------- tfjs-core/src/register_all_gradients.ts | 16 +- tfjs-core/src/tests.ts | 8 + 36 files changed, 1657 insertions(+), 1035 deletions(-) create mode 100644 tfjs-core/src/gradients/ClipByValue_grad.ts create mode 100644 tfjs-core/src/gradients/Erf_grad.ts create mode 100644 tfjs-core/src/gradients/Exp_grad.ts create mode 100644 tfjs-core/src/gradients/Expm1_grad.ts 
create mode 100644 tfjs-core/src/gradients/Log1p_grad.ts create mode 100644 tfjs-core/src/gradients/Log_grad.ts create mode 100644 tfjs-core/src/gradients/Reciprocal_grad.ts create mode 100644 tfjs-core/src/ops/clip_by_value.ts create mode 100644 tfjs-core/src/ops/clip_by_value_test.ts create mode 100644 tfjs-core/src/ops/erf.ts create mode 100644 tfjs-core/src/ops/erf_test.ts create mode 100644 tfjs-core/src/ops/exp.ts create mode 100644 tfjs-core/src/ops/exp_test.ts create mode 100644 tfjs-core/src/ops/expm1.ts create mode 100644 tfjs-core/src/ops/expm1_test.ts create mode 100644 tfjs-core/src/ops/log.ts create mode 100644 tfjs-core/src/ops/log1p.ts create mode 100644 tfjs-core/src/ops/log1p_test.ts create mode 100644 tfjs-core/src/ops/log_sigmoid.ts create mode 100644 tfjs-core/src/ops/log_sigmoid_test.ts create mode 100644 tfjs-core/src/ops/log_test.ts create mode 100644 tfjs-core/src/ops/reciprocal.ts create mode 100644 tfjs-core/src/ops/reciprocal_test.ts diff --git a/tfjs-core/scripts/touch_modular_op_files.ts b/tfjs-core/scripts/touch_modular_op_files.ts index 66f27c2f80d..fd827106e38 100644 --- a/tfjs-core/scripts/touch_modular_op_files.ts +++ b/tfjs-core/scripts/touch_modular_op_files.ts @@ -115,7 +115,7 @@ export const ${downcaseFirstChar(kernelName)}GradConfig: GradConfig = { outputsToSave: [], // UPDATE ME gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { const [] = saved; - const {} = attrs as {} as KernelNameAttrs; + const {} = attrs as {} as ${kernelName}Attrs; return { }; } diff --git a/tfjs-core/src/gradients/ClipByValue_grad.ts b/tfjs-core/src/gradients/ClipByValue_grad.ts new file mode 100644 index 00000000000..88c609dc838 --- /dev/null +++ b/tfjs-core/src/gradients/ClipByValue_grad.ts @@ -0,0 +1,39 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ClipByValue, ClipByValueAttrs} from '../kernel_names'; +import {GradConfig, NamedAttrMap} from '../kernel_registry'; +import {greaterEqual} from '../ops/greater_equal'; +import {lessEqual} from '../ops/less_equal'; +import {logicalAnd} from '../ops/logical_and'; +import {zerosLike} from '../ops/tensor_ops'; +import {Tensor} from '../tensor'; + +export const clipByValueGradConfig: GradConfig = { + kernelName: ClipByValue, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { + const [x] = saved; + const {clipValueMin, clipValueMax} = attrs as {} as ClipByValueAttrs; + return { + // tslint:disable-next-line: no-unnecessary-type-assertion + x: () => dy.where( + logicalAnd(greaterEqual(x, clipValueMin), lessEqual(x, clipValueMax)), + zerosLike(dy)), + }; + } +}; diff --git a/tfjs-core/src/gradients/Erf_grad.ts b/tfjs-core/src/gradients/Erf_grad.ts new file mode 100644 index 00000000000..dcb11ecf5bf --- /dev/null +++ b/tfjs-core/src/gradients/Erf_grad.ts @@ -0,0 +1,34 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Erf} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {exp} from '../ops/exp'; +import {mul} from '../ops/mul'; +import {neg} from '../ops/neg'; +import {square} from '../ops/square'; +import {Tensor} from '../tensor'; + +export const erfGradConfig: GradConfig = { + kernelName: Erf, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + const a = mul(exp(neg(square(x))), 2 / Math.sqrt(Math.PI)); + return {x: () => mul(dy, a)}; + } +}; diff --git a/tfjs-core/src/gradients/Exp_grad.ts b/tfjs-core/src/gradients/Exp_grad.ts new file mode 100644 index 00000000000..44c1acf6603 --- /dev/null +++ b/tfjs-core/src/gradients/Exp_grad.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {Exp} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {mul} from '../ops/mul'; +import {Tensor} from '../tensor'; + +export const expGradConfig: GradConfig = { + kernelName: Exp, + outputsToSave: [true], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [y] = saved; + return {x: () => mul(dy, y)}; + } +}; diff --git a/tfjs-core/src/gradients/Expm1_grad.ts b/tfjs-core/src/gradients/Expm1_grad.ts new file mode 100644 index 00000000000..3bf6e91f6e9 --- /dev/null +++ b/tfjs-core/src/gradients/Expm1_grad.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {Expm1} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {exp} from '../ops/exp'; +import {mul} from '../ops/mul'; +import {Tensor} from '../tensor'; + +export const expm1GradConfig: GradConfig = { + kernelName: Expm1, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + return {x: () => mul(dy, exp(x))}; + } +}; diff --git a/tfjs-core/src/gradients/Log1p_grad.ts b/tfjs-core/src/gradients/Log1p_grad.ts new file mode 100644 index 00000000000..323b6a04afd --- /dev/null +++ b/tfjs-core/src/gradients/Log1p_grad.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {Log1p} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {add} from '../ops/add'; +import {div} from '../ops/div'; +import {Tensor} from '../tensor'; + +export const log1pGradConfig: GradConfig = { + kernelName: Log1p, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + return {x: () => div(dy, add(x, 1))}; + } +}; diff --git a/tfjs-core/src/gradients/Log_grad.ts b/tfjs-core/src/gradients/Log_grad.ts new file mode 100644 index 00000000000..e571b74a7d0 --- /dev/null +++ b/tfjs-core/src/gradients/Log_grad.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {Log} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {cast} from '../ops/array_ops'; +import {div} from '../ops/div'; +import {Tensor} from '../tensor'; + +export const logGradConfig: GradConfig = { + kernelName: Log, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + return {x: () => div(dy, cast(x, 'float32'))}; + } +}; diff --git a/tfjs-core/src/gradients/Pow_grad.ts b/tfjs-core/src/gradients/Pow_grad.ts index e0422e65f33..0a5a9e12ab8 100644 --- a/tfjs-core/src/gradients/Pow_grad.ts +++ b/tfjs-core/src/gradients/Pow_grad.ts @@ -19,13 +19,13 @@ import {GradConfig} from '../kernel_registry'; import {cast} from '../ops/array_ops'; import * as broadcast_util from '../ops/broadcast_util'; import {greater} from '../ops/greater'; +import {log} from '../ops/log'; import {mul} from '../ops/mul'; import {pow} from '../ops/pow'; import {reshape} from '../ops/reshape'; import {sub} from '../ops/sub'; import {sum} from '../ops/sum'; import {scalar, zerosLike} from '../ops/tensor_ops'; -import {log} from '../ops/unary_ops'; import {where} from '../ops/where'; import {Tensor} from '../tensor'; diff --git a/tfjs-core/src/gradients/Reciprocal_grad.ts b/tfjs-core/src/gradients/Reciprocal_grad.ts new file mode 100644 index 00000000000..c4bfa6c3d3e --- /dev/null +++ b/tfjs-core/src/gradients/Reciprocal_grad.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Reciprocal} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {div} from '../ops/div'; +import {neg} from '../ops/neg'; +import {square} from '../ops/square'; +import {Tensor} from '../tensor'; + +export const reciprocalGradConfig: GradConfig = { + kernelName: Reciprocal, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [x] = saved; + return {x: () => div(dy, neg(square(x)))}; + } +}; diff --git a/tfjs-core/src/gradients/Selu_grad.ts b/tfjs-core/src/gradients/Selu_grad.ts index 59177669fa2..8cb76a95569 100644 --- a/tfjs-core/src/gradients/Selu_grad.ts +++ b/tfjs-core/src/gradients/Selu_grad.ts @@ -17,11 +17,11 @@ import {Selu} from '../kernel_names'; import {GradConfig} from '../kernel_registry'; import {cast} from '../ops/array_ops'; +import {exp} from '../ops/exp'; import {greater} from '../ops/greater'; import {mul} from '../ops/mul'; import {SELU_SCALE, SELU_SCALEALPHA} from '../ops/selu_util'; import {scalar} from '../ops/tensor_ops'; -import {exp} from '../ops/unary_ops'; import {where} from '../ops/where'; import {Tensor} from '../tensor'; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 70cf4cb01fa..5d595cc6a5c 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -143,6 +143,13 @@ export interface BroadCastToAttrs { export const Ceil = 'Ceil'; export type CeilInputs = UnaryInputs; +export const ClipByValue = 'ClipByValue'; 
+export type ClipByValueInputs = UnaryInputs; +export interface ClipByValueAttrs { + clipValueMin: number; + clipValueMax: number; +} + export const Complex = 'Complex'; export type ComplexInputs = Pick; @@ -283,9 +290,18 @@ export type EluInputs = Pick; export const EluGrad = 'EluGrad'; export type EluGradInputs = Pick; +export const Erf = 'Erf'; +export type ErfInputs = UnaryInputs; + export const Equal = 'Equal'; export type EqualInputs = BinaryInputs; +export const Exp = 'Exp'; +export type ExpInputs = UnaryInputs; + +export const Expm1 = 'Expm1'; +export type Expm1Inputs = UnaryInputs; + export const Floor = 'Floor'; export type FloorInputs = UnaryInputs; @@ -333,6 +349,12 @@ export type LessInputs = BinaryInputs; export const LessEqual = 'LessEqual'; export type LessEqualInputs = BinaryInputs; +export const Log = 'Log'; +export type LogInputs = UnaryInputs; + +export const Log1p = 'Log1p'; +export type Log1pInputs = UnaryInputs; + export const LogicalAnd = 'LogicalAnd'; export type LogicalAndInputs = BinaryInputs; @@ -502,6 +524,9 @@ export interface ProdAttrs { export const Real = 'Real'; export type RealInputs = Pick; +export const Reciprocal = 'Reciprocal'; +export type ReciprocalInputs = UnaryInputs; + export const Relu = 'Relu'; export type ReluInputs = Pick; diff --git a/tfjs-core/src/ops/clip_by_value.ts b/tfjs-core/src/ops/clip_by_value.ts new file mode 100644 index 00000000000..ad9f5f10b25 --- /dev/null +++ b/tfjs-core/src/ops/clip_by_value.ts @@ -0,0 +1,62 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ENGINE} from '../engine'; +import {ClipByValue, ClipByValueAttrs, ClipByValueInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {op} from './operation'; + +/** + * Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)` + * + * ```js + * const x = tf.tensor1d([-1, 2, -3, 4]); + * + * x.clipByValue(-2, 3).print(); // or tf.clipByValue(x, -2, 3) + * ``` + * @param x The input tensor. + * @param clipValueMin Lower-bound of range to be clipped to. + * @param clipValueMax Upper-bound of range to be clipped to. 
+ */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function clipByValue_( + x: T|TensorLike, clipValueMin: number, clipValueMax: number): T { + const $x = convertToTensor(x, 'x', 'clipByValue'); + util.assert( + (clipValueMin <= clipValueMax), + () => `Error in clip: min (${clipValueMin}) must be ` + + `less than or equal to max (${clipValueMax}).`); + + const inputs: ClipByValueInputs = {x: $x}; + const attrs: ClipByValueAttrs = {clipValueMin, clipValueMax}; + + return ENGINE.runKernelFunc( + (backend, save) => { + const res = backend.clip($x, clipValueMin, clipValueMax); + save([$x]); + return res; + }, + inputs as {} as NamedTensorMap, null /* grad */, ClipByValue, + attrs as {} as NamedAttrMap); +} + +export const clipByValue = op({clipByValue_}); diff --git a/tfjs-core/src/ops/clip_by_value_test.ts b/tfjs-core/src/ops/clip_by_value_test.ts new file mode 100644 index 00000000000..0c64c7d2bbd --- /dev/null +++ b/tfjs-core/src/ops/clip_by_value_test.ts @@ -0,0 +1,138 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('clipByValue', ALL_ENVS, () => { + it('basic', async () => { + const a = tf.tensor1d([3, -1, 0, 100, -7, 2]); + const min = -1; + const max = 50; + + const result = tf.clipByValue(a, min, max); + + expectArraysClose(await result.data(), [3, -1, 0, 50, -1, 2]); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([3, -1, 0, 100, -7, 2, NaN]); + const min = -1; + const max = 50; + + const result = tf.clipByValue(a, min, max); + + expectArraysClose(await result.data(), [3, -1, 0, 50, -1, 2, NaN]); + }); + + it('min greater than max', () => { + const a = tf.tensor1d([3, -1, 0, 100, -7, 2]); + const min = 1; + const max = -1; + + const f = () => { + tf.clipByValue(a, min, max); + }; + expect(f).toThrowError(); + }); + + it('gradient: 1D tensor', async () => { + const min = -1; + const max = 2; + const x = tf.tensor1d([3, -2, 1]); // Only 1 is not clipped. + const dy = tf.tensor1d([5, 50, 500]); + const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); + + expect(gradients.shape).toEqual(x.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 500]); + }); + + it('gradient: 1D tensor with max or min value', async () => { + const min = -1; + const max = 2; + const x = tf.tensor1d([-1, 1, 2, 3]); + const dy = tf.tensor1d([1, 10, 100, 1000]); + const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); + + expect(gradients.shape).toEqual(x.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [1, 10, 100, 0]); + }); + + it('gradient: scalar', async () => { + const min = -1; + const max = 2; + const x = tf.scalar(-10); // Clipped. 
+ const dy = tf.scalar(5); + const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); + + expect(gradients.shape).toEqual(x.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradient with clones', async () => { + const min = -1; + const max = 2; + const x = tf.scalar(-10); // Clipped. + const dy = tf.scalar(5); + const gradients = + tf.grad(x => x.clone().clipByValue(min, max).clone())(x, dy); + + expect(gradients.shape).toEqual(x.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradient with primitive as input', async () => { + const min = -1; + const max = 2; + const x = -10; + const dy = tf.scalar(5); + const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); + expect(gradients.shape).toEqual([]); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.clipByValue({} as tf.Tensor, 0, 1)) + .toThrowError(/Argument 'x' passed to 'clipByValue' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const min = -1; + const max = 50; + const result = tf.clipByValue([3, -1, 0, 100, -7, 2], min, max); + expectArraysClose(await result.data(), [3, -1, 0, 50, -1, 2]); + }); + + it('clip(x, eps, 1-eps) never returns 0 or 1', async () => { + const min = tf.backend().epsilon(); + const max = 0.5; + const res = await tf.clipByValue([0, 1], min, max).data(); + expect(res[0]).toBeGreaterThan(0); + expect(res[1]).toBeCloseTo(max); + }); + + it('throws for string tensor', () => { + expect(() => tf.clipByValue('q', 0, 1)) + .toThrowError(/Argument 'x' passed to 'clipByValue' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/erf.ts b/tfjs-core/src/ops/erf.ts new file mode 100644 index 00000000000..0968423740e --- /dev/null +++ b/tfjs-core/src/ops/erf.ts @@ -0,0 +1,57 @@ +/** + * 
@license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Erf, ErfInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {op} from './operation'; + +/** + * Computes Gauss error function of the input `tf.Tensor` element-wise: + * `erf(x)` + * + * ```js + * const x = tf.tensor1d([0, .1, -.1, .7]); + * + * x.erf().print(); // or tf.erf(x); + * ``` + * @param x The input tensor. 
+ */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function erf_(x: T|TensorLike): T { + let $x = convertToTensor(x, 'x', 'erf'); + util.assert( + $x.dtype === 'int32' || $x.dtype === 'float32', + () => 'Input dtype must be `int32` or `float32`.'); + + if ($x.dtype === 'int32') { + $x = $x.toFloat(); + } + + const inputs: ErfInputs = {x: $x}; + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.erf($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Erf); +} +export const erf = op({erf_}); diff --git a/tfjs-core/src/ops/erf_test.ts b/tfjs-core/src/ops/erf_test.ts new file mode 100644 index 00000000000..0dd9b037468 --- /dev/null +++ b/tfjs-core/src/ops/erf_test.ts @@ -0,0 +1,135 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('erf', ALL_ENVS, () => { + it('basic', async () => { + const values = [-0.25, 0.25, 0.5, .75, -0.4]; + const a = tf.tensor1d(values); + const result = tf.erf(a); + const expected = [-0.2763264, 0.2763264, 0.5204999, 0.7111556, -0.4283924]; + expectArraysClose(await result.data(), expected); + }); + + it('blowup', async () => { + const values = [-1.4, -2.5, -3.1, -4.4]; + const a = tf.tensor1d(values); + const result = tf.erf(a); + const expected = [-0.9522852, -0.999593, -0.9999883, -1]; + expectArraysClose(await result.data(), expected); + }); + + it('scalar', async () => { + const a = tf.scalar(1); + const result = tf.erf(a); + const expected = [0.8427008]; + expectArraysClose(await result.data(), expected); + }); + + it('scalar in int32', async () => { + const a = tf.scalar(1, 'int32'); + const result = tf.erf(a); + const expected = [0.8427008]; + expectArraysClose(await result.data(), expected); + }); + + it('tensor2d', async () => { + const values = [0.2, 0.3, 0.4, 0.5]; + const a = tf.tensor2d(values, [2, 2]); + const result = tf.erf(a); + const expected = [0.2227026, 0.32862678, 0.42839235, 0.5204999]; + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([0.5, NaN, 0]); + const res = tf.erf(a); + expectArraysClose(await res.data(), [0.5204999, NaN, 0.0]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + const gradients = tf.grad(a => tf.erf(a))(a, dy); + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [8 * 2 * Math.exp(-0.5 * 0.5) / Math.sqrt(Math.PI)]); + }); + + it('gradient with clones', async () 
=> { + const a = tf.scalar(0.5); + const dy = tf.scalar(8); + const gradients = tf.grad(a => tf.erf(a.clone()).clone())(a, dy); + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [8 * 2 * Math.exp(-0.5 * 0.5) / Math.sqrt(Math.PI)]); + }); + + it('gradients: Tensor1D', async () => { + const aValues = [-0.1, 0.2, 0.3, -0.5]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor1d(aValues); + const dy = tf.tensor1d(dyValues); + const gradients = tf.grad(a => tf.erf(a))(a, dy); + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] * 2 * Math.exp(-aValues[i] * aValues[i]) / + Math.sqrt(Math.PI); + } + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('gradients: Tensor2D', async () => { + const aValues = [-0.3, 0.1, 0.2, 0.3]; + const dyValues = [1, 2, 3, 4]; + const a = tf.tensor2d(aValues, [2, 2]); + const dy = tf.tensor2d(dyValues, [2, 2]); + const gradients = tf.grad(a => tf.erf(a))(a, dy); + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = dyValues[i] * 2 * Math.exp(-aValues[i] * aValues[i]) / + Math.sqrt(Math.PI); + } + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), expected); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.erf({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'erf' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const result = tf.erf(1); + expectArraysClose(await result.data(), [0.8427008]); + }); + + it('throws for string tensor', () => { + expect(() => tf.erf('q')) + .toThrowError(/Argument 'x' passed to 'erf' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/exp.ts b/tfjs-core/src/ops/exp.ts new file mode 100644 index 
00000000000..8fb32789a85 --- /dev/null +++ b/tfjs-core/src/ops/exp.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Exp, ExpInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes exponential of the input `tf.Tensor` element-wise. `e ^ x` + * + * ```js + * const x = tf.tensor1d([1, 2, -3]); + * + * x.exp().print(); // or tf.exp(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function exp_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'exp'); + + const inputs: ExpInputs = {x: $x}; + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.exp($x); + save([res]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Exp); +} +export const exp = op({exp_}); diff --git a/tfjs-core/src/ops/exp_test.ts b/tfjs-core/src/ops/exp_test.ts new file mode 100644 index 00000000000..26a82aa0cae --- /dev/null +++ b/tfjs-core/src/ops/exp_test.ts @@ -0,0 +1,100 @@ +/** + * @license + * Copyright 2017 Google LLC. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('exp', ALL_ENVS, () => { + it('exp', async () => { + const a = tf.tensor1d([1, 2, 0]); + const r = tf.exp(a); + + expectArraysClose(await r.data(), [Math.exp(1), Math.exp(2), 1]); + }); + + it('exp propagates NaNs', async () => { + const a = tf.tensor1d([1, NaN, 0]); + const r = tf.exp(a); + expectArraysClose(await r.data(), [Math.exp(1), NaN, 1]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.exp(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [3 * Math.exp(0.5)]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.exp(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [3 * Math.exp(0.5)]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([-1, 2, 3, -5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.exp(a))(a, 
dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [1 * Math.exp(-1), 2 * Math.exp(2), 3 * Math.exp(3), 4 * Math.exp(-5)], + 1e-1); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.exp(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [1 * Math.exp(-3), 2 * Math.exp(1), 3 * Math.exp(2), 4 * Math.exp(3)], + 1e-1); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.exp({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'exp' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.exp([1, 2, 0]); + expectArraysClose(await r.data(), [Math.exp(1), Math.exp(2), 1]); + }); + + it('throws for string tensor', () => { + expect(() => tf.exp('q')) + .toThrowError(/Argument 'x' passed to 'exp' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/expm1.ts b/tfjs-core/src/ops/expm1.ts new file mode 100644 index 00000000000..ef8db3a3722 --- /dev/null +++ b/tfjs-core/src/ops/expm1.ts @@ -0,0 +1,49 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Expm1, Expm1Inputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes exponential of the input `tf.Tensor` minus one element-wise. + * `e ^ x - 1` + * + * ```js + * const x = tf.tensor1d([1, 2, -3]); + * + * x.expm1().print(); // or tf.expm1(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function expm1_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'expm1'); + + const inputs: Expm1Inputs = {x: $x}; + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.expm1($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Expm1); +} +export const expm1 = op({expm1_}); diff --git a/tfjs-core/src/ops/expm1_test.ts b/tfjs-core/src/ops/expm1_test.ts new file mode 100644 index 00000000000..6b1908fec87 --- /dev/null +++ b/tfjs-core/src/ops/expm1_test.ts @@ -0,0 +1,103 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('expm1', ALL_ENVS, () => { + it('expm1', async () => { + const a = tf.tensor1d([1, 2, 0]); + const r = tf.expm1(a); + + expectArraysClose( + await r.data(), [Math.expm1(1), Math.expm1(2), Math.expm1(0)]); + }); + + it('expm1 propagates NaNs', async () => { + const a = tf.tensor1d([1, NaN, 0]); + const r = tf.expm1(a); + expectArraysClose(await r.data(), [Math.expm1(1), NaN, Math.expm1(0)]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.expm1(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [3 * Math.exp(0.5)]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(0.5); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.expm1(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [3 * Math.exp(0.5)]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([-1, 2, 3, -5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.expm1(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [1 * Math.exp(-1), 2 * Math.exp(2), 3 * Math.exp(3), 4 * Math.exp(-5)], + 1e-1); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.expm1(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + 
expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [1 * Math.exp(-3), 2 * Math.exp(1), 3 * Math.exp(2), 4 * Math.exp(3)], + 1e-1); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.expm1({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'expm1' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.expm1([1, 2, 0]); + + expectArraysClose( + await r.data(), [Math.expm1(1), Math.expm1(2), Math.expm1(0)]); + }); + + it('throws for string tensor', () => { + expect(() => tf.expm1('q')) + .toThrowError(/Argument 'x' passed to 'expm1' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/log.ts b/tfjs-core/src/ops/log.ts new file mode 100644 index 00000000000..949b2ec4cc9 --- /dev/null +++ b/tfjs-core/src/ops/log.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Log, LogInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)` + * + * ```js + * const x = tf.tensor1d([1, 2, Math.E]); + * + * x.log().print(); // or tf.log(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function log_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'log'); + + const inputs: LogInputs = {x: $x}; + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.log($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Log); +} +export const log = op({log_}); diff --git a/tfjs-core/src/ops/log1p.ts b/tfjs-core/src/ops/log1p.ts new file mode 100644 index 00000000000..ddaa5cf3590 --- /dev/null +++ b/tfjs-core/src/ops/log1p.ts @@ -0,0 +1,49 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Log1p, Log1pInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes natural logarithm of the input `tf.Tensor` plus one + * element-wise: `ln(1 + x)` + * + * ```js + * const x = tf.tensor1d([1, 2, Math.E - 1]); + * + * x.log1p().print(); // or tf.log1p(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function log1p_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'log1p'); + + const inputs: Log1pInputs = {x: $x}; + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.log1p($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Log1p); +} +export const log1p = op({log1p_}); diff --git a/tfjs-core/src/ops/log1p_test.ts b/tfjs-core/src/ops/log1p_test.ts new file mode 100644 index 00000000000..5befbfaff2a --- /dev/null +++ b/tfjs-core/src/ops/log1p_test.ts @@ -0,0 +1,94 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('log1p', ALL_ENVS, () => { + it('log1p', async () => { + const a = tf.tensor1d([1, 2]); + const r = tf.log1p(a); + expectArraysClose(await r.data(), [Math.log1p(1), Math.log1p(2)]); + }); + + it('log1p propagates NaNs', async () => { + const a = tf.tensor1d([1, NaN]); + const r = tf.log1p(a); + expectArraysClose(await r.data(), [Math.log1p(1), NaN]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(5); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.log1p(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [3 / (1 + 5)]); + }); + + it('gradient with clones', () => { + const a = tf.scalar(5); + const gradients = tf.grad(a => a.clone().log1p().clone())(a); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([-1, 2, 3, -5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.log1p(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [Infinity, 2 / (1 + 2), 3 / (1 + 3), 4 / (1 + -5)]); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.log1p(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose( + await gradients.data(), + [1 / (1 + -3), 2 / (1 + 1), 3 / (1 + 2), 4 / (1 + 3)]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => 
tf.log1p({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'log1p' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.log1p([1, 2]); + expectArraysClose(await r.data(), [Math.log1p(1), Math.log1p(2)]); + }); + + it('throws for string tensor', () => { + expect(() => tf.log1p('q')) + .toThrowError(/Argument 'x' passed to 'log1p' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/log_loss.ts b/tfjs-core/src/ops/log_loss.ts index 0543646c578..17e27c032b6 100644 --- a/tfjs-core/src/ops/log_loss.ts +++ b/tfjs-core/src/ops/log_loss.ts @@ -22,13 +22,13 @@ import {assertShapesMatch} from '../util'; import {add} from './add'; import {computeWeightedLoss} from './compute_weighted_loss'; +import {log} from './log'; import {Reduction} from './loss_ops_utils'; import {mul} from './mul'; import {neg} from './neg'; import {op} from './operation'; import {sub} from './sub'; import {scalar} from './tensor_ops'; -import {log} from './unary_ops'; /** * Computes the log loss between two tensors. diff --git a/tfjs-core/src/ops/log_sigmoid.ts b/tfjs-core/src/ops/log_sigmoid.ts new file mode 100644 index 00000000000..f20da287691 --- /dev/null +++ b/tfjs-core/src/ops/log_sigmoid.ts @@ -0,0 +1,60 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {customGrad} from '../gradients'; +import {Tensor} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {mul} from './mul'; +import {neg} from './neg'; +import {op} from './operation'; +import {sigmoid} from './unary_ops'; +/** + * Computes log sigmoid of the input `tf.Tensor` element-wise: + * `logSigmoid(x)`. For numerical stability, we use `-tf.softplus(-x)`. + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.logSigmoid().print(); // or tf.logSigmoid(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function logSigmoid_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'logSigmoid'); + + // Use a custom gradient to maintain previous implementation. + // There is no LogSigmoid kernel in TF so we can't use engine.runKernel + // directly + const customOp = customGrad((x: Tensor) => { + // TODO(yassogba) we can remove the chained softplus call here only + // after backends have modularized softplus at which point we can call + // engine runKernel(..., Softplus, ...) directly. + const value = neg(neg(x).softplus()); + + const gradFunc = (dy: T) => { + const derX = mul(dy, sigmoid(neg(x))); + return derX; + }; + return {value, gradFunc}; + }); + + return customOp($x) as T; +} +export const logSigmoid = op({logSigmoid_}); diff --git a/tfjs-core/src/ops/log_sigmoid_test.ts b/tfjs-core/src/ops/log_sigmoid_test.ts new file mode 100644 index 00000000000..8e7ee974d94 --- /dev/null +++ b/tfjs-core/src/ops/log_sigmoid_test.ts @@ -0,0 +1,162 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('logSigmoid', ALL_ENVS, () => { + it('basic', async () => { + const values = [1, -3, 2, 7, -4]; + const a = tf.tensor1d(values); + + const result = tf.logSigmoid(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.log(1 / (1 + Math.exp(-values[i]))); + } + expectArraysClose(await result.data(), expected); + }); + + it('scalar', async () => { + const a = tf.scalar(-2); + + const result = tf.logSigmoid(a); + + const expected = [Math.log(1 / (1 + Math.exp(2)))]; + expectArraysClose(await result.data(), expected); + }); + + it('tensor2D', async () => { + const values = [1, 2, -3, 5]; + const a = tf.tensor2d(values, [2, 2]); + + const result = tf.logSigmoid(a); + + const expected = []; + for (let i = 0; i < a.size; i++) { + expected[i] = Math.log(1 / (1 + Math.exp(-values[i]))); + } + expectArraysClose(await result.data(), expected); + }); + + it('larger magnitude negative inputs', async () => { + const values = [-100, -200, -3000]; + const a = tf.tensor1d(values); + + const result = tf.logSigmoid(a); + + const expected = [-100, -200, -3000]; + + expectArraysClose(await result.data(), expected); + }); + + it('larger magnitude positive inputs', async () => { + const values = [100, 200, 3000, 50000]; + const a = tf.tensor1d(values); + + const result = tf.logSigmoid(a); + 
+ const expected = [0, 0, 0, 0]; + + expectArraysClose(await result.data(), expected); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([3, NaN]); + const res = tf.logSigmoid(a); + expectArraysClose( + await res.data(), [Math.log(1 / (1 + Math.exp(-3))), NaN]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(3); + const dy = tf.scalar(4); + const dyVal = await dy.array(); + + const da = tf.grad(a => tf.logSigmoid(a))(a, dy); + const aVal = await a.array(); + const y = 1 / (1 + Math.exp(aVal)); + expectArraysClose(await da.data(), [dyVal * y]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([1, 2, -3, 5]); + const aVals = await a.array(); + const dy = tf.tensor1d([1, 2, 3, 4]); + const dyVals = await dy.array(); + const da = tf.grad(a => tf.logSigmoid(a))(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + const y = 1 / (1 + Math.exp(aVals[i])); + expected[i] = dyVals[i] * y; + } + + expectArraysClose(await da.data(), expected); + }); + + it('gradient with clones', async () => { + const a = tf.tensor1d([1, 2, -3, 5]); + const aVals = await a.array(); + const dy = tf.tensor1d([1, 2, 3, 4]); + const dyVals = await dy.array(); + const da = tf.grad(a => tf.logSigmoid(a.clone()).clone())(a, dy); + + const expected = []; + for (let i = 0; i < a.size; i++) { + const y = 1 / (1 + Math.exp(aVals[i])); + expected[i] = dyVals[i] * y; + } + + expectArraysClose(await da.data(), expected); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([1, 2, -3, 5], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const da = tf.grad(a => tf.logSigmoid(a))(a, dy); + + const expected = []; + const aVals = await a.data(); + const dyVals = await dy.data(); + for (let i = 0; i < a.size; i++) { + const y = 1 / (1 + Math.exp(aVals[i])); + expected[i] = dyVals[i] * y; + } + + expectArraysClose(await da.data(), expected); + }); + + it('throws when passed a non-tensor', () => { 
+ expect(() => tf.logSigmoid({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'logSigmoid' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const result = tf.logSigmoid(-2); + const expected = [Math.log(1 / (1 + Math.exp(2)))]; + expectArraysClose(await result.data(), expected); + }); + + it('throws for string tensor', () => { + expect(() => tf.logSigmoid('q')) + .toThrowError(/Argument 'x' passed to 'logSigmoid' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/log_sum_exp.ts b/tfjs-core/src/ops/log_sum_exp.ts index 4a61ae75819..d58cfdc94b8 100644 --- a/tfjs-core/src/ops/log_sum_exp.ts +++ b/tfjs-core/src/ops/log_sum_exp.ts @@ -22,12 +22,13 @@ import {parseAxisParam} from '../util'; import {add} from './add'; import {expandShapeToKeepDim} from './axis_util'; +import {exp} from './exp'; +import {log} from './log'; import {max} from './max'; import {op} from './operation'; import {reshape} from './reshape'; import {sub} from './sub'; import {sum} from './sum'; -import {exp, log} from './unary_ops'; /** * Computes the log(sum(exp(elements across the reduction dimensions)). diff --git a/tfjs-core/src/ops/log_test.ts b/tfjs-core/src/ops/log_test.ts new file mode 100644 index 00000000000..e34df65eb6d --- /dev/null +++ b/tfjs-core/src/ops/log_test.ts @@ -0,0 +1,105 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('log', ALL_ENVS, () => { + it('log', async () => { + const a = tf.tensor1d([1, 2]); + const r = tf.log(a); + expectArraysClose(await r.data(), [Math.log(1), Math.log(2)]); + }); + + it('log 6D', async () => { + const a = tf.range(1, 65).reshape([2, 2, 2, 2, 2, 2]); + const r = tf.log(a); + + const expectedResult = []; + for (let i = 1; i < 65; i++) { + expectedResult[i - 1] = Math.log(i); + } + + expectArraysClose(await r.data(), expectedResult); + }); + + it('log propagates NaNs', async () => { + const a = tf.tensor1d([1, NaN]); + const r = tf.log(a); + expectArraysClose(await r.data(), [Math.log(1), NaN]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(5); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.log(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [3 / 5]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(5); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.log(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [3 / 5]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([-1, 2, 3, -5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.log(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [1 / -1, 2 / 2, 3 / 3, 4 / -5]); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + 
const gradients = tf.grad(a => tf.log(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [1 / -3, 2 / 1, 3 / 2, 4 / 3]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.log({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'log' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.log([1, 2]); + expectArraysClose(await r.data(), [Math.log(1), Math.log(2)]); + }); + + it('throws for string tensor', () => { + expect(() => tf.log('q')) + .toThrowError(/Argument 'x' passed to 'log' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 378658b52ca..e9fbfecc4d1 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -40,6 +40,7 @@ export {batchNorm3d} from './batchnorm3d'; export {batchNorm4d} from './batchnorm4d'; export {broadcastTo} from './broadcast_to'; export {ceil} from './ceil'; +export {clipByValue} from './clip_by_value'; export {clone} from './clone'; export {complex} from './complex'; export {concat} from './concat'; @@ -64,7 +65,10 @@ export {divNoNan} from './div_no_nan'; export {dot} from './dot'; export {elu} from './elu'; export {equal} from './equal'; +export {erf} from './erf'; +export {exp} from './exp'; export {expandDims} from './expand_dims'; +export {expm1} from './expm1'; export {eye} from './eye'; export {fill} from './fill'; export {floor} from './floor'; @@ -77,6 +81,9 @@ export {leakyRelu} from './leaky_relu'; export {less} from './less'; export {lessEqual} from './less_equal'; export {localResponseNormalization} from './local_response_normalization'; +export {log} from './log'; +export {log1p} from './log1p'; +export {logSigmoid} from './log_sigmoid'; export {logSumExp} from './log_sum_exp'; export {logicalAnd} from './logical_and'; export {logicalNot} from './logical_not'; @@ -114,6 +121,7 @@ export 
{randomGamma} from './random_gamma'; export {randomNormal} from './random_normal'; export {randomUniform} from './random_uniform'; export {real} from './real'; +export {reciprocal} from './reciprocal'; export {relu} from './relu'; export {relu6} from './relu6'; export {reshape} from './reshape'; diff --git a/tfjs-core/src/ops/reciprocal.ts b/tfjs-core/src/ops/reciprocal.ts new file mode 100644 index 00000000000..de58c96c85e --- /dev/null +++ b/tfjs-core/src/ops/reciprocal.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Reciprocal, ReciprocalInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes reciprocal of x element-wise: `1 / x` + * + * ```js + * const x = tf.tensor1d([0, 1, 2]); + * + * x.reciprocal().print(); // or tf.reciprocal(x) + * ``` + * @param x The input tensor. 
+ */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function reciprocal_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'reciprocal'); + + const inputs: ReciprocalInputs = {x: $x}; + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.reciprocal($x); + save([$x]); + return res; + }, inputs as {} as NamedTensorMap, null /* grad */, Reciprocal); +} +export const reciprocal = op({reciprocal_}); diff --git a/tfjs-core/src/ops/reciprocal_test.ts b/tfjs-core/src/ops/reciprocal_test.ts new file mode 100644 index 00000000000..4d5ed4d2c61 --- /dev/null +++ b/tfjs-core/src/ops/reciprocal_test.ts @@ -0,0 +1,106 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('reciprocal', ALL_ENVS, () => { + it('1D array', async () => { + const a = tf.tensor1d([2, 3, 0, NaN]); + const r = tf.reciprocal(a); + expectArraysClose(await r.data(), [1 / 2, 1 / 3, Infinity, NaN]); + }); + + it('2D array', async () => { + const a = tf.tensor2d([1, Infinity, 0, NaN], [2, 2]); + const r = tf.reciprocal(a); + expect(r.shape).toEqual([2, 2]); + expectArraysClose(await r.data(), [1 / 1, 0, Infinity, NaN]); + }); + + it('reciprocal propagates NaNs', async () => { + const a = tf.tensor1d([1.5, NaN]); + const r = tf.reciprocal(a); + expectArraysClose(await r.data(), [1 / 1.5, NaN]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.reciprocal(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [-1 * 8 * (1 / (5 * 5))]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(5); + const dy = tf.scalar(8); + + const gradients = tf.grad(a => tf.reciprocal(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [-1 * 8 * (1 / (5 * 5))]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([-1, 2, 3, -5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.reciprocal(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [ + -1 * 1 * (1 / (-1 * -1)), -1 * 2 * (1 / (2 * 2)), -1 * 3 * (1 / (3 * 3)), + -1 * 4 * (1 / (-5 * -5)) + ]); + }); + + it('gradients: 
Tensor2D', async () => { + const a = tf.tensor2d([-1, 2, 3, -5], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.reciprocal(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [ + -1 * 1 * (1 / (-1 * -1)), -1 * 2 * (1 / (2 * 2)), -1 * 3 * (1 / (3 * 3)), + -1 * 4 * (1 / (-5 * -5)) + ]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.reciprocal({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'reciprocal' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.reciprocal([2, 3, 0, NaN]); + expectArraysClose(await r.data(), [1 / 2, 1 / 3, Infinity, NaN]); + }); + + it('throws for string tensor', () => { + expect(() => tf.reciprocal('q')) + .toThrowError(/Argument 'x' passed to 'reciprocal' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/sigmoid_cross_entropy.ts b/tfjs-core/src/ops/sigmoid_cross_entropy.ts index 7316002b2da..b98318ee523 100644 --- a/tfjs-core/src/ops/sigmoid_cross_entropy.ts +++ b/tfjs-core/src/ops/sigmoid_cross_entropy.ts @@ -23,6 +23,8 @@ import {assertShapesMatch} from '../util'; import {abs} from './abs'; import {add} from './add'; import {computeWeightedLoss} from './compute_weighted_loss'; +import {exp} from './exp'; +import {log1p} from './log1p'; import {Reduction} from './loss_ops_utils'; import {mul} from './mul'; import {neg} from './neg'; @@ -30,7 +32,6 @@ import {op} from './operation'; import {relu} from './relu'; import {sub} from './sub'; import {scalar} from './tensor_ops'; -import {exp, log1p} from './unary_ops'; function sigmoidCrossEntropyWithLogits_( labels: T|TensorLike, logits: T|TensorLike): O { diff --git a/tfjs-core/src/ops/softmax_cross_entropy.ts b/tfjs-core/src/ops/softmax_cross_entropy.ts index b86f01f1db9..93869f286e5 100644 --- a/tfjs-core/src/ops/softmax_cross_entropy.ts +++ 
b/tfjs-core/src/ops/softmax_cross_entropy.ts @@ -26,6 +26,7 @@ import {cast} from './array_ops'; import {expandShapeToKeepDim} from './axis_util'; import {computeWeightedLoss} from './compute_weighted_loss'; import {div} from './div'; +import {exp} from './exp'; import {logSumExp} from './log_sum_exp'; import {Reduction} from './loss_ops_utils'; import {mul} from './mul'; @@ -35,7 +36,6 @@ import {reshape} from './reshape'; import {sub} from './sub'; import {sum} from './sum'; import {scalar} from './tensor_ops'; -import {exp} from './unary_ops'; /** * Computes softmax cross entropy between logits and labels. diff --git a/tfjs-core/src/ops/unary_ops.ts b/tfjs-core/src/ops/unary_ops.ts index 13fb7f66a25..7995a23a3ee 100644 --- a/tfjs-core/src/ops/unary_ops.ts +++ b/tfjs-core/src/ops/unary_ops.ts @@ -19,7 +19,6 @@ import {ENGINE} from '../engine'; import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; -import * as util from '../util'; import {op} from './operation'; import {scalar, zerosLike} from './tensor_ops'; @@ -112,115 +111,6 @@ function round_(x: T|TensorLike): T { return ENGINE.runKernelFunc(backend => backend.round($x), {$x}, grad); } -/** - * Computes exponential of the input `tf.Tensor` element-wise. `e ^ x` - * - * ```js - * const x = tf.tensor1d([1, 2, -3]); - * - * x.exp().print(); // or tf.exp(x) - * ``` - * @param x The input tensor. 
- */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function exp_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'exp'); - - const bck = (dy: T, saved: Tensor[]) => { - // tslint:disable-next-line: no-unnecessary-type-assertion - return {x: () => dy.mul(saved[0]) as T}; - }; - const attrs = {}; - const inputsToSave: Tensor[] = []; - const outputsToSave = [true]; - return ENGINE.runKernelFunc((backend, save) => { - const y = backend.exp($x); - save([y]); - return y; - }, {x: $x}, bck, 'Exp', attrs, inputsToSave, outputsToSave); -} - -/** - * Computes exponential of the input `tf.Tensor` minus one element-wise. - * `e ^ x - 1` - * - * ```js - * const x = tf.tensor1d([1, 2, -3]); - * - * x.expm1().print(); // or tf.expm1(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function expm1_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'expm1'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {$x: () => dy.mul($x.exp())} as {$x: () => T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.expm1($x); - save([$x]); - return res; - }, {$x}, grad); -} - -/** - * Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)` - * - * ```js - * const x = tf.tensor1d([1, 2, Math.E]); - * - * x.log().print(); // or tf.log(x) - * ``` - * @param x The input tensor. 
- */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function log_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'log'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {x: () => dy.div($x.toFloat())} as {x: () => T}; - }; - - const attrs = {}; - const inputsToSave = [$x]; - - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.log($x); - save([$x]); - return res; - }, {x: $x}, grad, 'Log', attrs, inputsToSave); -} - -/** - * Computes natural logarithm of the input `tf.Tensor` plus one - * element-wise: `ln(1 + x)` - * - * ```js - * const x = tf.tensor1d([1, 2, Math.E - 1]); - * - * x.log1p().print(); // or tf.log1p(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function log1p_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'log1p'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {$x: () => dy.div($x.add(1))} as {$x: () => T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.log1p($x); - save([$x]); - return res; - }, {$x}, grad); -} - /** * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)` * @@ -273,71 +163,6 @@ function rsqrt_(x: T|TensorLike): T { }, {x: $x}, grad, 'Rsqrt', {} /* attrs */, inputsToSave); } -/** - * Computes reciprocal of x element-wise: `1 / x` - * - * ```js - * const x = tf.tensor1d([0, 1, 2]); - * - * x.reciprocal().print(); // or tf.reciprocal(x) - * ``` - * @param x The input tensor. 
- */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function reciprocal_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'reciprocal'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {$x: () => dy.div($x.square().neg())} as {$x: () => T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.reciprocal($x); - save([$x]); - return res; - }, {$x}, grad); -} - -/** - * Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)` - * - * ```js - * const x = tf.tensor1d([-1, 2, -3, 4]); - * - * x.clipByValue(-2, 3).print(); // or tf.clipByValue(x, -2, 3) - * ``` - * @param x The input tensor. - * @param clipValueMin Lower-bound of range to be clipped to. - * @param clipValueMax Upper-bound of range to be clipped to. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function clipByValue_( - x: T|TensorLike, clipValueMin: number, clipValueMax: number): T { - const $x = convertToTensor(x, 'x', 'clipByValue'); - util.assert( - (clipValueMin <= clipValueMax), - () => `Error in clip: min (${clipValueMin}) must be ` + - `less than or equal to max (${clipValueMax}).`); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return { - // tslint:disable-next-line: no-unnecessary-type-assertion - x: () => dy.where( - $x.greaterEqual(clipValueMin) - .logicalAnd($x.lessEqual(clipValueMax)), - zerosLike(dy)) as T, - }; - }; - const inputsToSave = [$x]; - const attr = {min: clipValueMin, max: clipValueMax}; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.clip($x, clipValueMin, clipValueMax); - save([$x]); - return res; - }, {x: $x}, grad, 'ClipByValue', attr, inputsToSave); -} - /** * Computes sigmoid element-wise, `1 / (1 + exp(-x))` * @@ -363,32 +188,6 @@ function sigmoid_(x: T|TensorLike): T { }, {x: $x}, grad, 'Sigmoid'); } -/** - * Computes log sigmoid of the input `tf.Tensor` element-wise: - * `logSigmoid(x)`. 
For numerical stability, we use `-tf.softplus(-x)`. - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.logSigmoid().print(); // or tf.logSigmoid(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function logSigmoid_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'logSigmoid'); - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return {$x: () => dy.mul($x.neg().sigmoid())} as {$x: () => T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.softplus($x.neg()).neg(); - save([$x]); - return res; - }, {$x}, grad); -} - /** * Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)` * @@ -414,41 +213,6 @@ function softplus_(x: T|TensorLike): T { }, {$x}, grad); } -/** - * Computes gause error function of the input `tf.Tensor` element-wise: - * `erf(x)` - * - * ```js - * const x = tf.tensor1d([0, .1, -.1, .7]); - * - * x.erf().print(); // or tf.erf(x); - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function erf_(x: T|TensorLike): T { - let $x = convertToTensor(x, 'x', 'erf'); - util.assert( - $x.dtype === 'int32' || $x.dtype === 'float32', - () => 'Input dtype must be `int32` or `float32`.'); - - if ($x.dtype === 'int32') { - $x = $x.toFloat(); - } - - const grad = (dy: T, saved: Tensor[]) => { - const [$x] = saved; - return { - $x: () => dy.mul($x.square().neg().exp().mul(2 / Math.sqrt(Math.PI))) - } as {$x: () => T}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.erf($x); - save([$x]); - return res; - }, {$x}, grad); -} - /** * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 
1 : alpha * x` * @@ -472,14 +236,6 @@ function step_(x: T|TensorLike, alpha = 0.0): T { return ENGINE.runKernelFunc(backend => backend.step($x, alpha), {$x}, grad); } -export const clipByValue = op({clipByValue_}); -export const erf = op({erf_}); -export const exp = op({exp_}); -export const expm1 = op({expm1_}); -export const log = op({log_}); -export const log1p = op({log1p_}); -export const logSigmoid = op({logSigmoid_}); -export const reciprocal = op({reciprocal_}); export const round = op({round_}); export const rsqrt = op({rsqrt_}); export const sigmoid = op({sigmoid_}); diff --git a/tfjs-core/src/ops/unary_ops_test.ts b/tfjs-core/src/ops/unary_ops_test.ts index a45ca1d2810..3b2154dbeff 100644 --- a/tfjs-core/src/ops/unary_ops_test.ts +++ b/tfjs-core/src/ops/unary_ops_test.ts @@ -192,148 +192,6 @@ describeWithFlags('sigmoid', ALL_ENVS, () => { }); }); -describeWithFlags('logSigmoid', ALL_ENVS, () => { - it('basic', async () => { - const values = [1, -3, 2, 7, -4]; - const a = tf.tensor1d(values); - - const result = tf.logSigmoid(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.log(1 / (1 + Math.exp(-values[i]))); - } - expectArraysClose(await result.data(), expected); - }); - - it('scalar', async () => { - const a = tf.scalar(-2); - - const result = tf.logSigmoid(a); - - const expected = [Math.log(1 / (1 + Math.exp(2)))]; - expectArraysClose(await result.data(), expected); - }); - - it('tensor2D', async () => { - const values = [1, 2, -3, 5]; - const a = tf.tensor2d(values, [2, 2]); - - const result = tf.logSigmoid(a); - - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = Math.log(1 / (1 + Math.exp(-values[i]))); - } - expectArraysClose(await result.data(), expected); - }); - - it('larger magnitude negative inputs', async () => { - const values = [-100, -200, -3000]; - const a = tf.tensor1d(values); - - const result = tf.logSigmoid(a); - - const expected = [-100, -200, -3000]; - - 
expectArraysClose(await result.data(), expected); - }); - - it('larger magnitude positive inputs', async () => { - const values = [100, 200, 3000, 50000]; - const a = tf.tensor1d(values); - - const result = tf.logSigmoid(a); - - const expected = [0, 0, 0, 0]; - - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([3, NaN]); - const res = tf.logSigmoid(a); - expectArraysClose( - await res.data(), [Math.log(1 / (1 + Math.exp(-3))), NaN]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(3); - const dy = tf.scalar(4); - const dyVal = await dy.array(); - - const da = tf.grad(a => tf.logSigmoid(a))(a, dy); - const aVal = await a.array(); - const y = 1 / (1 + Math.exp(aVal)); - expectArraysClose(await da.data(), [dyVal * y]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([1, 2, -3, 5]); - const aVals = await a.array(); - const dy = tf.tensor1d([1, 2, 3, 4]); - const dyVals = await dy.array(); - const da = tf.grad(a => tf.logSigmoid(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - const y = 1 / (1 + Math.exp(aVals[i])); - expected[i] = dyVals[i] * y; - } - - expectArraysClose(await da.data(), expected); - }); - - it('gradient with clones', async () => { - const a = tf.tensor1d([1, 2, -3, 5]); - const aVals = await a.array(); - const dy = tf.tensor1d([1, 2, 3, 4]); - const dyVals = await dy.array(); - const da = tf.grad(a => tf.logSigmoid(a.clone()).clone())(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - const y = 1 / (1 + Math.exp(aVals[i])); - expected[i] = dyVals[i] * y; - } - - expectArraysClose(await da.data(), expected); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([1, 2, -3, 5], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const da = tf.grad(a => tf.logSigmoid(a))(a, dy); - - const expected = []; - const aVals = await a.data(); - const dyVals = await 
dy.data(); - for (let i = 0; i < a.size; i++) { - const y = 1 / (1 + Math.exp(aVals[i])); - expected[i] = dyVals[i] * y; - } - - expectArraysClose(await da.data(), expected); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.logSigmoid({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'logSigmoid' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const result = tf.logSigmoid(-2); - const expected = [Math.log(1 / (1 + Math.exp(2)))]; - expectArraysClose(await result.data(), expected); - }); - - it('throws for string tensor', () => { - expect(() => tf.logSigmoid('q')) - .toThrowError(/Argument 'x' passed to 'logSigmoid' must be numeric/); - }); -}); - describeWithFlags('softplus', ALL_ENVS, () => { it('basic', async () => { const values = [1, -3, 2, 7, -4]; @@ -763,251 +621,6 @@ describeWithFlags('square', ALL_ENVS, () => { }); }); -describeWithFlags('reciprocal', ALL_ENVS, () => { - it('1D array', async () => { - const a = tf.tensor1d([2, 3, 0, NaN]); - const r = tf.reciprocal(a); - expectArraysClose(await r.data(), [1 / 2, 1 / 3, Infinity, NaN]); - }); - - it('2D array', async () => { - const a = tf.tensor2d([1, Infinity, 0, NaN], [2, 2]); - const r = tf.reciprocal(a); - expect(r.shape).toEqual([2, 2]); - expectArraysClose(await r.data(), [1 / 1, 0, Infinity, NaN]); - }); - - it('reciprocal propagates NaNs', async () => { - const a = tf.tensor1d([1.5, NaN]); - const r = tf.reciprocal(a); - expectArraysClose(await r.data(), [1 / 1.5, NaN]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => tf.reciprocal(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [-1 * 8 * (1 / (5 * 5))]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(5); - const dy = tf.scalar(8); - - const gradients = tf.grad(a => 
tf.reciprocal(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [-1 * 8 * (1 / (5 * 5))]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([-1, 2, 3, -5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.reciprocal(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [ - -1 * 1 * (1 / (-1 * -1)), -1 * 2 * (1 / (2 * 2)), -1 * 3 * (1 / (3 * 3)), - -1 * 4 * (1 / (-5 * -5)) - ]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([-1, 2, 3, -5], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.reciprocal(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [ - -1 * 1 * (1 / (-1 * -1)), -1 * 2 * (1 / (2 * 2)), -1 * 3 * (1 / (3 * 3)), - -1 * 4 * (1 / (-5 * -5)) - ]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.reciprocal({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'reciprocal' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.reciprocal([2, 3, 0, NaN]); - expectArraysClose(await r.data(), [1 / 2, 1 / 3, Infinity, NaN]); - }); - - it('throws for string tensor', () => { - expect(() => tf.reciprocal('q')) - .toThrowError(/Argument 'x' passed to 'reciprocal' must be numeric/); - }); -}); - -describeWithFlags('log', ALL_ENVS, () => { - it('log', async () => { - const a = tf.tensor1d([1, 2]); - const r = tf.log(a); - expectArraysClose(await r.data(), [Math.log(1), Math.log(2)]); - }); - - it('log 6D', async () => { - const a = tf.range(1, 65).reshape([2, 2, 2, 2, 2, 2]); - const r = tf.log(a); - - const expectedResult = []; - for (let i = 1; i < 65; i++) { - 
expectedResult[i - 1] = Math.log(i); - } - - expectArraysClose(await r.data(), expectedResult); - }); - - it('log propagates NaNs', async () => { - const a = tf.tensor1d([1, NaN]); - const r = tf.log(a); - expectArraysClose(await r.data(), [Math.log(1), NaN]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(5); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.log(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [3 / 5]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(5); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.log(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [3 / 5]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([-1, 2, 3, -5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.log(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [1 / -1, 2 / 2, 3 / 3, 4 / -5]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.log(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [1 / -3, 2 / 1, 3 / 2, 4 / 3]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.log({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'log' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.log([1, 2]); - expectArraysClose(await r.data(), [Math.log(1), Math.log(2)]); - }); - - it('throws for string tensor', () => { - expect(() => tf.log('q')) - 
.toThrowError(/Argument 'x' passed to 'log' must be numeric/); - }); -}); - -describeWithFlags('log1p', ALL_ENVS, () => { - it('log1p', async () => { - const a = tf.tensor1d([1, 2]); - const r = tf.log1p(a); - expectArraysClose(await r.data(), [Math.log1p(1), Math.log1p(2)]); - }); - - it('log1p propagates NaNs', async () => { - const a = tf.tensor1d([1, NaN]); - const r = tf.log1p(a); - expectArraysClose(await r.data(), [Math.log1p(1), NaN]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(5); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.log1p(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [3 / (1 + 5)]); - }); - - it('gradient with clones', () => { - const a = tf.scalar(5); - const gradients = tf.grad(a => a.clone().log1p().clone())(a); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([-1, 2, 3, -5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.log1p(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [Infinity, 2 / (1 + 2), 3 / (1 + 3), 4 / (1 + -5)]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.log1p(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [1 / (1 + -3), 2 / (1 + 1), 3 / (1 + 2), 4 / (1 + 3)]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.log1p({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'log1p' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r 
= tf.log1p([1, 2]); - expectArraysClose(await r.data(), [Math.log1p(1), Math.log1p(2)]); - }); - - it('throws for string tensor', () => { - expect(() => tf.log1p('q')) - .toThrowError(/Argument 'x' passed to 'log1p' must be numeric/); - }); -}); - describeWithFlags('isNaN', ALL_ENVS, () => { it('basic', async () => { const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); @@ -1179,287 +792,6 @@ describeWithFlags('isFinite', ALL_ENVS, () => { }); }); -describeWithFlags('exp', ALL_ENVS, () => { - it('exp', async () => { - const a = tf.tensor1d([1, 2, 0]); - const r = tf.exp(a); - - expectArraysClose(await r.data(), [Math.exp(1), Math.exp(2), 1]); - }); - - it('exp propagates NaNs', async () => { - const a = tf.tensor1d([1, NaN, 0]); - const r = tf.exp(a); - expectArraysClose(await r.data(), [Math.exp(1), NaN, 1]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.exp(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [3 * Math.exp(0.5)]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.exp(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [3 * Math.exp(0.5)]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([-1, 2, 3, -5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.exp(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [1 * Math.exp(-1), 2 * Math.exp(2), 3 * Math.exp(3), 4 * Math.exp(-5)], - 1e-1); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); - const dy = 
tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.exp(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [1 * Math.exp(-3), 2 * Math.exp(1), 3 * Math.exp(2), 4 * Math.exp(3)], - 1e-1); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.exp({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'exp' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.exp([1, 2, 0]); - expectArraysClose(await r.data(), [Math.exp(1), Math.exp(2), 1]); - }); - - it('throws for string tensor', () => { - expect(() => tf.exp('q')) - .toThrowError(/Argument 'x' passed to 'exp' must be numeric/); - }); -}); - -describeWithFlags('expm1', ALL_ENVS, () => { - it('expm1', async () => { - const a = tf.tensor1d([1, 2, 0]); - const r = tf.expm1(a); - - expectArraysClose( - await r.data(), [Math.expm1(1), Math.expm1(2), Math.expm1(0)]); - }); - - it('expm1 propagates NaNs', async () => { - const a = tf.tensor1d([1, NaN, 0]); - const r = tf.expm1(a); - expectArraysClose(await r.data(), [Math.expm1(1), NaN, Math.expm1(0)]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.expm1(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [3 * Math.exp(0.5)]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.expm1(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [3 * Math.exp(0.5)]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([-1, 2, 3, -5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const 
gradients = tf.grad(a => tf.expm1(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [1 * Math.exp(-1), 2 * Math.exp(2), 3 * Math.exp(3), 4 * Math.exp(-5)], - 1e-1); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([-3, 1, 2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.expm1(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [1 * Math.exp(-3), 2 * Math.exp(1), 3 * Math.exp(2), 4 * Math.exp(3)], - 1e-1); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.expm1({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'expm1' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.expm1([1, 2, 0]); - - expectArraysClose( - await r.data(), [Math.expm1(1), Math.expm1(2), Math.expm1(0)]); - }); - - it('throws for string tensor', () => { - expect(() => tf.expm1('q')) - .toThrowError(/Argument 'x' passed to 'expm1' must be numeric/); - }); -}); - -describeWithFlags('clip', ALL_ENVS, () => { - it('basic', async () => { - const a = tf.tensor1d([3, -1, 0, 100, -7, 2]); - const min = -1; - const max = 50; - - const result = tf.clipByValue(a, min, max); - - expectArraysClose(await result.data(), [3, -1, 0, 50, -1, 2]); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([3, -1, 0, 100, -7, 2, NaN]); - const min = -1; - const max = 50; - - const result = tf.clipByValue(a, min, max); - - expectArraysClose(await result.data(), [3, -1, 0, 50, -1, 2, NaN]); - }); - - it('min greater than max', () => { - const a = tf.tensor1d([3, -1, 0, 100, -7, 2]); - const min = 1; - const max = -1; - - const f = () => { - tf.clipByValue(a, min, max); - }; - expect(f).toThrowError(); - }); - - it('gradient: 1D tensor', async () => { - const 
min = -1; - const max = 2; - const x = tf.tensor1d([3, -2, 1]); // Only 1 is not clipped. - const dy = tf.tensor1d([5, 50, 500]); - const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); - - expect(gradients.shape).toEqual(x.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 500]); - }); - - it('gradient: 1D tensor with max or min value', async () => { - const min = -1; - const max = 2; - const x = tf.tensor1d([-1, 1, 2, 3]); - const dy = tf.tensor1d([1, 10, 100, 1000]); - const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); - - expect(gradients.shape).toEqual(x.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [1, 10, 100, 0]); - }); - - it('gradient: scalar', async () => { - const min = -1; - const max = 2; - const x = tf.scalar(-10); // Clipped. - const dy = tf.scalar(5); - const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); - - expect(gradients.shape).toEqual(x.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradient with clones', async () => { - const min = -1; - const max = 2; - const x = tf.scalar(-10); // Clipped. 
- const dy = tf.scalar(5); - const gradients = - tf.grad(x => x.clone().clipByValue(min, max).clone())(x, dy); - - expect(gradients.shape).toEqual(x.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradient with primitive as input', async () => { - const min = -1; - const max = 2; - const x = -10; - const dy = tf.scalar(5); - const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); - expect(gradients.shape).toEqual([]); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.clipByValue({} as tf.Tensor, 0, 1)) - .toThrowError(/Argument 'x' passed to 'clipByValue' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const min = -1; - const max = 50; - const result = tf.clipByValue([3, -1, 0, 100, -7, 2], min, max); - expectArraysClose(await result.data(), [3, -1, 0, 50, -1, 2]); - }); - - it('clip(x, eps, 1-eps) never returns 0 or 1', async () => { - const min = tf.backend().epsilon(); - const max = 0.5; - const res = await tf.clipByValue([0, 1], min, max).data(); - expect(res[0]).toBeGreaterThan(0); - expect(res[1]).toBeCloseTo(max); - }); - - it('throws for string tensor', () => { - expect(() => tf.clipByValue('q', 0, 1)) - .toThrowError(/Argument 'x' passed to 'clipByValue' must be numeric/); - }); -}); - describeWithFlags('round', ALL_ENVS, () => { it('basic', async () => { const a = tf.tensor1d([0.9, 2.5, 2.3, 1.5, -4.5]); @@ -1533,118 +865,3 @@ describeWithFlags('round', ALL_ENVS, () => { .toThrowError(/Argument 'x' passed to 'round' must be numeric/); }); }); - -describeWithFlags('erf', ALL_ENVS, () => { - it('basic', async () => { - const values = [-0.25, 0.25, 0.5, .75, -0.4]; - const a = tf.tensor1d(values); - const result = tf.erf(a); - const expected = [-0.2763264, 0.2763264, 0.5204999, 0.7111556, -0.4283924]; - expectArraysClose(await 
result.data(), expected); - }); - - it('blowup', async () => { - const values = [-1.4, -2.5, -3.1, -4.4]; - const a = tf.tensor1d(values); - const result = tf.erf(a); - const expected = [-0.9522852, -0.999593, -0.9999883, -1]; - expectArraysClose(await result.data(), expected); - }); - - it('scalar', async () => { - const a = tf.scalar(1); - const result = tf.erf(a); - const expected = [0.8427008]; - expectArraysClose(await result.data(), expected); - }); - - it('scalar in int32', async () => { - const a = tf.scalar(1, 'int32'); - const result = tf.erf(a); - const expected = [0.8427008]; - expectArraysClose(await result.data(), expected); - }); - - it('tensor2d', async () => { - const values = [0.2, 0.3, 0.4, 0.5]; - const a = tf.tensor2d(values, [2, 2]); - const result = tf.erf(a); - const expected = [0.2227026, 0.32862678, 0.42839235, 0.5204999]; - expectArraysClose(await result.data(), expected); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([0.5, NaN, 0]); - const res = tf.erf(a); - expectArraysClose(await res.data(), [0.5204999, NaN, 0.0]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - const gradients = tf.grad(a => tf.erf(a))(a, dy); - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [8 * 2 * Math.exp(-0.5 * 0.5) / Math.sqrt(Math.PI)]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(0.5); - const dy = tf.scalar(8); - const gradients = tf.grad(a => tf.erf(a.clone()).clone())(a, dy); - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose( - await gradients.data(), - [8 * 2 * Math.exp(-0.5 * 0.5) / Math.sqrt(Math.PI)]); - }); - - it('gradients: Tensor1D', async () => { - const aValues = [-0.1, 0.2, 0.3, -0.5]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor1d(aValues); - const dy = tf.tensor1d(dyValues); - 
const gradients = tf.grad(a => tf.erf(a))(a, dy); - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] * 2 * Math.exp(-aValues[i] * aValues[i]) / - Math.sqrt(Math.PI); - } - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('gradients: Tensor2D', async () => { - const aValues = [-0.3, 0.1, 0.2, 0.3]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor2d(aValues, [2, 2]); - const dy = tf.tensor2d(dyValues, [2, 2]); - const gradients = tf.grad(a => tf.erf(a))(a, dy); - const expected = []; - for (let i = 0; i < a.size; i++) { - expected[i] = dyValues[i] * 2 * Math.exp(-aValues[i] * aValues[i]) / - Math.sqrt(Math.PI); - } - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.erf({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'erf' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const result = tf.erf(1); - expectArraysClose(await result.data(), [0.8427008]); - }); - - it('throws for string tensor', () => { - expect(() => tf.erf('q')) - .toThrowError(/Argument 'x' passed to 'erf' must be numeric/); - }); -}); diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index 5aea486a509..730e734b2fb 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -32,6 +32,7 @@ import {batchMatMulGradConfig} from './gradients/BatchMatMul_grad'; import {batchToSpaceNDGradConfig} from './gradients/BatchToSpaceND_grad'; import {broadcastToGradConfig} from './gradients/BroadcastTo_grad'; import {ceilGradConfig} from './gradients/Ceil_grad'; +import {clipByValueGradConfig} from './gradients/ClipByValue_grad'; import {concatGradConfig} from 
'./gradients/Concat_grad'; import {conv2DGradConfig} from './gradients/Conv2D_grad'; import {conv2DBackpropInputGradConfig} from './gradients/Conv2DBackpropInput_grad'; @@ -43,12 +44,17 @@ import {depthwiseConv2dNativeGradConfig} from './gradients/DepthwiseConv2dNative import {dilation2dGradConfig} from './gradients/Dilation2D_grad'; import {divGradConfig} from './gradients/Div_grad'; import {eluGradConfig} from './gradients/Elu_grad'; +import {erfGradConfig} from './gradients/Erf_grad'; +import {expGradConfig} from './gradients/Exp_grad'; +import {expm1GradConfig} from './gradients/Expm1_grad'; import {floorGradConfig} from './gradients/Floor_grad'; import {floorDivGradConfig} from './gradients/FloorDiv_grad'; import {fusedBatchNormGradConfig} from './gradients/FusedBatchNorm_grad'; import {gatherGradConfig} from './gradients/GatherV2_grad'; import {greaterEqualGradConfig} from './gradients/GreaterEqual_grad'; import {identityGradConfig} from './gradients/Identity_grad'; +import {log1pGradConfig} from './gradients/Log1p_grad'; +import {logGradConfig} from './gradients/Log_grad'; import {lrnGradConfig} from './gradients/LRN_grad'; import {maxGradConfig} from './gradients/Max_grad'; import {maximumGradConfig} from './gradients/Maximum_grad'; @@ -63,6 +69,7 @@ import {oneHotGradConfig} from './gradients/OneHot_grad'; import {padV2GradConfig} from './gradients/PadV2_grad'; import {powGradConfig} from './gradients/Pow_grad'; import {preluGradConfig} from './gradients/Prelu_grad'; +import {reciprocalGradConfig} from './gradients/Reciprocal_grad'; import {relu6GradConfig} from './gradients/Relu6_grad'; import {reluGradConfig} from './gradients/Relu_grad'; import {reshapeGradConfig} from './gradients/Reshape_grad'; @@ -110,6 +117,7 @@ const gradConfigs: GradConfig[] = [ batchToSpaceNDGradConfig, broadcastToGradConfig, ceilGradConfig, + clipByValueGradConfig, concatGradConfig, conv2DBackpropInputGradConfig, conv2DGradConfig, @@ -121,12 +129,17 @@ const gradConfigs: 
GradConfig[] = [ dilation2dGradConfig, divGradConfig, eluGradConfig, + erfGradConfig, + expGradConfig, + expm1GradConfig, floorDivGradConfig, floorGradConfig, fusedBatchNormGradConfig, gatherGradConfig, greaterEqualGradConfig, identityGradConfig, + log1pGradConfig, + logGradConfig, lrnGradConfig, maxGradConfig, maxGradConfig, @@ -144,6 +157,7 @@ const gradConfigs: GradConfig[] = [ padV2GradConfig, powGradConfig, preluGradConfig, + reciprocalGradConfig, relu6GradConfig, reluGradConfig, reshapeGradConfig, @@ -153,9 +167,9 @@ const gradConfigs: GradConfig[] = [ selectV2PoolGradConfig, seluGradConfig, signGradConfig, - sliceGradConfig, sinGradConfig, sinhGradConfig, + sliceGradConfig, spaceToBatchNDGradConfig, spaceToBatchNDGradConfig, splitVGradConfig, diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index 728cebed70f..6e2851f8e3c 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -66,6 +66,7 @@ import './ops/boolean_mask_test'; import './ops/broadcast_to_test'; import './ops/broadcast_util_test'; import './ops/ceil_test'; +import './ops/clip_by_value_test'; import './ops/clone_test'; import './ops/compare_ops_test'; import './ops/complex_ops_test'; @@ -93,7 +94,10 @@ import './ops/dropout_test'; import './ops/dropout_util_test'; import './ops/elu_test'; import './ops/equal_test'; +import './ops/erf_test'; +import './ops/exp_test'; import './ops/expand_dims_test'; +import './ops/expm1_test'; import './ops/eye_test'; import './ops/floor_test'; import './ops/frame_test'; @@ -112,8 +116,11 @@ import './ops/leaky_relu_test'; import './ops/less_equal_test'; import './ops/less_test'; import './ops/local_response_normalization_test'; +import './ops/log1p_test'; import './ops/log_loss_test'; +import './ops/log_sigmoid_test'; import './ops/log_sum_exp_test'; +import './ops/log_test'; import './ops/logical_and_test'; import './ops/logical_not_test'; import './ops/logical_or_test'; @@ -145,6 +152,7 @@ import './ops/rand_test'; import 
'./ops/random_gamma_test'; import './ops/random_normal_test'; import './ops/random_uniform_test'; +import './ops/reciprocal_test'; import './ops/relu6_test'; import './ops/relu_test'; import './ops/resize_bilinear_test'; From 8aac335f79ee66a303ec51c37913addfa74a1de8 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Fri, 10 Jul 2020 18:16:43 -0400 Subject: [PATCH 2/4] fix clipbyvalue wasm --- tfjs-backend-wasm/src/kernels/ClipByValue.ts | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/ClipByValue.ts b/tfjs-backend-wasm/src/kernels/ClipByValue.ts index c5fd8681510..40aeaae3445 100644 --- a/tfjs-backend-wasm/src/kernels/ClipByValue.ts +++ b/tfjs-backend-wasm/src/kernels/ClipByValue.ts @@ -15,20 +15,10 @@ * ============================================================================= */ -import {KernelConfig, NamedAttrMap, NamedTensorInfoMap} from '@tensorflow/tfjs-core'; -import {TensorInfo} from '@tensorflow/tfjs-core'; +import {ClipByValue, ClipByValueAttrs, ClipByValueInputs, KernelConfig, KernelFunc} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; -interface ClipByValueInputs extends NamedTensorInfoMap { - x: TensorInfo; -} - -interface ClipByValueAttrs extends NamedAttrMap { - min: number; - max: number; -} - let wasmClip: (xId: number, min: number, max: number, outId: number) => void; function setup(backend: BackendWasm) { @@ -47,17 +37,17 @@ function clip(args: { }) { const {inputs, backend, attrs} = args; const {x} = inputs; - const {min, max} = attrs; + const {clipValueMin, clipValueMax} = attrs; const xId = backend.dataIdMap.get(x.dataId).id; const out = backend.makeOutput(x.shape, 'float32'); const outId = backend.dataIdMap.get(out.dataId).id; - wasmClip(xId, min, max, outId); + wasmClip(xId, clipValueMin, clipValueMax, outId); return out; } export const clipByValueConfig: KernelConfig = { - kernelName: 'ClipByValue', + kernelName: ClipByValue, backendName: 
'wasm', setupFunc: setup, - kernelFunc: clip + kernelFunc: clip as {} as KernelFunc }; From e8da6d7b68466358ff58e3a7bddcced7ac0d0d30 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Mon, 13 Jul 2020 09:54:37 -0400 Subject: [PATCH 3/4] code review fixes --- tfjs-backend-wasm/src/kernels/ClipByValue.ts | 2 +- tfjs-core/src/gradients/ClipByValue_grad.ts | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/ClipByValue.ts b/tfjs-backend-wasm/src/kernels/ClipByValue.ts index 40aeaae3445..fe848574334 100644 --- a/tfjs-backend-wasm/src/kernels/ClipByValue.ts +++ b/tfjs-backend-wasm/src/kernels/ClipByValue.ts @@ -22,7 +22,7 @@ import {BackendWasm} from '../backend_wasm'; let wasmClip: (xId: number, min: number, max: number, outId: number) => void; function setup(backend: BackendWasm) { - wasmClip = backend.wasm.cwrap('ClipByValue', null /* void */, [ + wasmClip = backend.wasm.cwrap(ClipByValue, null /* void */, [ 'number', // x_id 'number', // min 'number', // max diff --git a/tfjs-core/src/gradients/ClipByValue_grad.ts b/tfjs-core/src/gradients/ClipByValue_grad.ts index 88c609dc838..a00e280d171 100644 --- a/tfjs-core/src/gradients/ClipByValue_grad.ts +++ b/tfjs-core/src/gradients/ClipByValue_grad.ts @@ -21,6 +21,7 @@ import {greaterEqual} from '../ops/greater_equal'; import {lessEqual} from '../ops/less_equal'; import {logicalAnd} from '../ops/logical_and'; import {zerosLike} from '../ops/tensor_ops'; +import {where} from '../ops/where'; import {Tensor} from '../tensor'; export const clipByValueGradConfig: GradConfig = { @@ -30,8 +31,8 @@ export const clipByValueGradConfig: GradConfig = { const [x] = saved; const {clipValueMin, clipValueMax} = attrs as {} as ClipByValueAttrs; return { - // tslint:disable-next-line: no-unnecessary-type-assertion - x: () => dy.where( + x: () => where( + dy, logicalAnd(greaterEqual(x, clipValueMin), lessEqual(x, clipValueMax)), zerosLike(dy)), }; From 
99a2712df4eadf5fc927c230ffc06627d308dfa7 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Mon, 13 Jul 2020 10:17:24 -0400 Subject: [PATCH 4/4] fix where call --- tfjs-core/src/gradients/ClipByValue_grad.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tfjs-core/src/gradients/ClipByValue_grad.ts b/tfjs-core/src/gradients/ClipByValue_grad.ts index a00e280d171..b13b2a7c77a 100644 --- a/tfjs-core/src/gradients/ClipByValue_grad.ts +++ b/tfjs-core/src/gradients/ClipByValue_grad.ts @@ -32,9 +32,8 @@ export const clipByValueGradConfig: GradConfig = { const {clipValueMin, clipValueMax} = attrs as {} as ClipByValueAttrs; return { x: () => where( - dy, logicalAnd(greaterEqual(x, clipValueMin), lessEqual(x, clipValueMax)), - zerosLike(dy)), + dy, zerosLike(dy)), }; } };