From a28d4a2dcb91737524c66f6307557e19dffa9a47 Mon Sep 17 00:00:00 2001
From: Na Li
Date: Wed, 10 Jun 2020 14:31:41 -0700
Subject: [PATCH 1/3] Modularize relu6, leakyRelu, prelu, selu.

---
 tfjs-core/src/gradients/Elu_grad.ts           |  29 ++
 tfjs-core/src/gradients/Prelu_grad.ts         |  47 ++
 tfjs-core/src/gradients/Relu6_grad.ts         |  34 ++
 tfjs-core/src/gradients/Selu_grad.ts          |  47 ++
 tfjs-core/src/kernel_names.ts                 |  15 +
 tfjs-core/src/ops/elu.ts                      |  54 +++
 tfjs-core/src/ops/elu_backprop_input.ts       |  43 +++
 tfjs-core/src/ops/elu_test.ts                 |  73 ++++
 tfjs-core/src/ops/fused_ops.ts                |   4 +-
 tfjs-core/src/ops/leaky_relu.ts               |  48 +++
 tfjs-core/src/ops/leaky_relu_test.ts          | 112 +++++
 tfjs-core/src/ops/ops.ts                      |   6 +-
 tfjs-core/src/ops/prelu.ts                    |  58 +++
 tfjs-core/src/ops/relu6.ts                    |  60 +++
 tfjs-core/src/ops/relu6_test.ts               |  52 +++
 tfjs-core/src/ops/relu_ops.ts                 | 197 ---------
 tfjs-core/src/ops/relu_test.ts                | 128 ++++++
 tfjs-core/src/ops/selu.ts                     |  56 +++
 tfjs-core/src/ops/selu_test.ts                | 140 ++++++
 tfjs-core/src/ops/unary_ops_test.ts           | 403 ------------------
 tfjs-core/src/public/chained_ops/elu.ts       |  30 ++
 .../src/public/chained_ops/leaky_relu.ts      |  31 ++
 tfjs-core/src/public/chained_ops/prelu.ts     |  31 ++
 .../chained_ops/register_all_chained_ops.ts   |   5 +
 .../register_all_chained_ops_test.ts          |   5 +
 tfjs-core/src/public/chained_ops/relu6.ts     |  30 ++
 tfjs-core/src/public/chained_ops/selu.ts      |  30 ++
 tfjs-core/src/register_all_gradients.ts       |   8 +
 tfjs-core/src/tensor.ts                       |  25 -
 tfjs-core/src/tests.ts                        |   5 +
 30 files changed, 1179 insertions(+), 627 deletions(-)
 create mode 100644 tfjs-core/src/gradients/Elu_grad.ts
 create mode 100644 tfjs-core/src/gradients/Prelu_grad.ts
 create mode 100644 tfjs-core/src/gradients/Relu6_grad.ts
 create mode 100644 tfjs-core/src/gradients/Selu_grad.ts
 create mode 100644 tfjs-core/src/ops/elu.ts
 create mode 100644 tfjs-core/src/ops/elu_backprop_input.ts
 create mode 100644 tfjs-core/src/ops/elu_test.ts
 create mode 100644 tfjs-core/src/ops/leaky_relu.ts
 create mode 100644 tfjs-core/src/ops/leaky_relu_test.ts
 create mode 100644 tfjs-core/src/ops/prelu.ts
 create mode 100644 tfjs-core/src/ops/relu6.ts
 create mode 100644 tfjs-core/src/ops/relu6_test.ts
 delete mode 100644 tfjs-core/src/ops/relu_ops.ts
 create mode 100644 tfjs-core/src/ops/relu_test.ts
 create mode 100644 tfjs-core/src/ops/selu.ts
 create mode 100644 tfjs-core/src/ops/selu_test.ts
 create mode 100644 tfjs-core/src/public/chained_ops/elu.ts
 create mode 100644 tfjs-core/src/public/chained_ops/leaky_relu.ts
 create mode 100644 tfjs-core/src/public/chained_ops/prelu.ts
 create mode 100644 tfjs-core/src/public/chained_ops/relu6.ts
 create mode 100644 tfjs-core/src/public/chained_ops/selu.ts

diff --git a/tfjs-core/src/gradients/Elu_grad.ts b/tfjs-core/src/gradients/Elu_grad.ts
new file mode 100644
index 00000000000..2957115a048
--- /dev/null
+++ b/tfjs-core/src/gradients/Elu_grad.ts
@@ -0,0 +1,29 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import {Elu} from '../kernel_names';
+import {GradConfig} from '../kernel_registry';
+import {eluBackpropInput} from '../ops/elu_backprop_input';
+import {Tensor} from '../tensor';
+
+export const eluGradConfig: GradConfig = {
+  kernelName: Elu,
+  outputsToSave: [true],
+  gradFunc: (dy: Tensor, saved: Tensor[]) => {
+    const [y] = saved;
+    return {x: () => eluBackpropInput(dy, y)};
+  }
+};
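A detail worth calling out in `Elu_grad.ts`: `outputsToSave: [true]` saves the forward output `y` rather than the input `x`. That works because the ELU derivative can be recovered from the output alone: for `x > 0` the output is `x` and the slope is 1; for `x <= 0` the output is `e^x - 1`, whose slope `e^x` equals `y + 1`. A minimal sketch of that identity in plain TypeScript (the helper names below are illustrative, not part of the patch):

```js
// Scalar ELU and its derivative, written two ways to show the identity
// the gradient config relies on (it only has access to y, not x).
const elu = (x: number): number => x > 0 ? x : Math.exp(x) - 1;

// Derivative from the input x: 1 for x > 0, e^x otherwise.
const dEluFromX = (x: number): number => x > 0 ? 1 : Math.exp(x);

// Derivative from the output y alone: e^x = y + 1 when x <= 0.
const dEluFromY = (y: number): number => y > 0 ? 1 : y + 1;

for (const x of [-2, -0.5, 0.5, 3]) {
  const y = elu(x);
  console.log(dEluFromX(x).toFixed(6) === dEluFromY(y).toFixed(6));  // true
}
```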
diff --git a/tfjs-core/src/gradients/Prelu_grad.ts b/tfjs-core/src/gradients/Prelu_grad.ts
new file mode 100644
index 00000000000..452ee27150f
--- /dev/null
+++ b/tfjs-core/src/gradients/Prelu_grad.ts
@@ -0,0 +1,47 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import {Prelu} from '../kernel_names';
+import {GradConfig} from '../kernel_registry';
+import {reshape} from '../ops/array_ops';
+import {getReductionAxes} from '../ops/broadcast_util';
+import {greater} from '../ops/greater';
+import {where} from '../ops/logical_ops';
+import {mul} from '../ops/mul';
+import {sum} from '../ops/reduction_ops';
+import {zerosLike} from '../ops/tensor_ops';
+import {Tensor} from '../tensor';
+
+export const preluGradConfig: GradConfig = {
+  kernelName: Prelu,
+  inputsToSave: ['x', 'alpha'],
+  gradFunc: (dy: Tensor, saved: Tensor[]) => {
+    const [x, alpha] = saved;
+    const mask = greater(x, 0);
+
+    return {
+      x: () => where(mask, dy, mul(dy, alpha)),
+      alpha: () => {
+        let res = where(mask, zerosLike(dy), mul(dy, x));
+        const reduceAxes = getReductionAxes(alpha.shape, dy.shape);
+        if (reduceAxes.length > 0) {
+          res = sum(res, reduceAxes);
+        }
+        return reshape(res, alpha.shape);
+      }
+    };
+  }
+};
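The `alpha` branch above is what handles broadcasting: when a scalar or per-channel `alpha` is broadcast across `x`, its gradient has to be summed back down to `alpha.shape` along the broadcast axes before reshaping. A quick illustration using the public API (values chosen arbitrarily):

```js
import * as tf from '@tensorflow/tfjs-core';

// x: [2, 2]; alpha: one slope per column, broadcast over rows.
const x = tf.tensor2d([[-1, 2], [-3, 4]]);
const alpha = tf.tensor1d([0.1, 0.2]);

const grads = tf.grads((x, alpha) => tf.prelu(x, alpha));
const [dx, dalpha] = grads([x, alpha]);

// dalpha has been reduced from x's shape [2, 2] back to alpha's
// shape [2]: each entry sums its column's negative-side contributions.
console.log(dx.shape, dalpha.shape);  // [2, 2], [2]
```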
diff --git a/tfjs-core/src/gradients/Relu6_grad.ts b/tfjs-core/src/gradients/Relu6_grad.ts
new file mode 100644
index 00000000000..d2db845b57a
--- /dev/null
+++ b/tfjs-core/src/gradients/Relu6_grad.ts
@@ -0,0 +1,34 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import {Relu6} from '../kernel_names';
+import {GradConfig} from '../kernel_registry';
+import {cast} from '../ops/array_ops';
+import {lessEqual} from '../ops/less_equal';
+import {mul} from '../ops/mul';
+import {step} from '../ops/unary_ops';
+import {Tensor} from '../tensor';
+
+export const relu6GradConfig: GradConfig = {
+  kernelName: Relu6,
+  inputsToSave: ['x'],
+  gradFunc: (dy: Tensor, saved: Tensor[]) => {
+    const [x] = saved;
+    const mask = mul(lessEqual(x, 6), step(x));
+
+    return {x: () => mul(dy, cast(mask, 'float32'))};
+  }
+};

diff --git a/tfjs-core/src/gradients/Selu_grad.ts b/tfjs-core/src/gradients/Selu_grad.ts
new file mode 100644
index 00000000000..91021e65e53
--- /dev/null
+++ b/tfjs-core/src/gradients/Selu_grad.ts
@@ -0,0 +1,47 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import {Selu} from '../kernel_names';
+import {GradConfig} from '../kernel_registry';
+import {cast} from '../ops/array_ops';
+import {where} from '../ops/logical_ops';
+import {mul} from '../ops/mul';
+import {SELU_SCALE, SELU_SCALEALPHA} from '../ops/selu_util';
+import {scalar} from '../ops/tensor_ops';
+import {exp} from '../ops/unary_ops';
+import {Tensor} from '../tensor';
+
+export const seluGradConfig: GradConfig = {
+  kernelName: Selu,
+  inputsToSave: ['x'],
+  gradFunc: (dy: Tensor, saved: Tensor[]) => {
+    const [x] = saved;
+    return {
+      x: () => {
+        const mask = x.greater(scalar(0));
+
+        const scaleAlpha = scalar(SELU_SCALEALPHA);
+        const scale = scalar(SELU_SCALE);
+
+        const greaterThanZeroDer = mul(dy, scale);
+        const lessEqualZeroDer =
+            mul(mul(dy, scaleAlpha), exp(cast(x, 'float32')));
+
+        return where(mask, greaterThanZeroDer, lessEqualZeroDer);
+      }
+    };
+  }
+};

diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts
index 87d3d771d35..192956144a4 100644
--- a/tfjs-core/src/kernel_names.ts
+++ b/tfjs-core/src/kernel_names.ts
@@ -193,6 +193,12 @@ export type DiagInputs = Pick<NamedTensorMap, 'x'>;
 export const Div = 'Div';
 export type DivInputs = BinaryInputs;
 
+export const Elu = 'Elu';
+export type EluInputs = Pick<NamedTensorMap, 'x'>;
+
+export const EluBackpropInput = 'EluBackpropInput';
+export type EluBackpropInputInputs = Pick<NamedTensorMap, 'dy'|'y'>;
+
 export const Equal = 'Equal';
 export type EqualInputs = BinaryInputs;
 
@@ -364,15 +370,24 @@ export type PoolInputs = Pick<NamedTensorMap, 'input'>;
 export const Pow = 'Pow';
 export type PowInputs = BinaryInputs;
 
+export const Prelu = 'Prelu';
+export type PreluInputs = Pick<NamedTensorMap, 'x'|'alpha'>;
+
 export const Real = 'Real';
 export type RealInputs = Pick<NamedTensorMap, 'input'>;
 
 export const Relu = 'Relu';
 export type ReluInputs = Pick<NamedTensorMap, 'x'>;
 
+export const Relu6 = 'Relu6';
+export type Relu6Inputs = Pick<NamedTensorMap, 'x'>;
+
 export const SelectV2 = 'SelectV2';
 export type SelectV2Inputs = Pick<NamedTensorMap, 'condition'|'t'|'e'>;
 
+export const Selu = 'Selu';
+export type SeluInputs = Pick<NamedTensorMap, 'x'>;
+
 export const SpaceToBatchND = 'SpaceToBatchND';
 export type SpaceToBatchNDInputs = Pick<NamedTensorMap, 'x'>;
 export interface SpaceToBatchNDAttrs {
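These four `GradConfig`s only take effect once they are registered; the diffstat shows `register_all_gradients.ts` gaining 8 lines for exactly that. Roughly, the wiring looks like this (a sketch of the registration pattern, not the literal file contents, which this patch does not show):

```js
import {registerGradient} from './kernel_registry';
import {eluGradConfig} from './gradients/Elu_grad';
import {preluGradConfig} from './gradients/Prelu_grad';
import {relu6GradConfig} from './gradients/Relu6_grad';
import {seluGradConfig} from './gradients/Selu_grad';

// Each config maps a kernel name ('Elu', 'Prelu', ...) to its gradFunc,
// so tf.grad() can look it up when the tape replays that kernel.
registerGradient(eluGradConfig);
registerGradient(preluGradConfig);
registerGradient(relu6GradConfig);
registerGradient(seluGradConfig);
```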
diff --git a/tfjs-core/src/ops/elu.ts b/tfjs-core/src/ops/elu.ts
new file mode 100644
index 00000000000..24cd19da830
--- /dev/null
+++ b/tfjs-core/src/ops/elu.ts
@@ -0,0 +1,54 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ENGINE, ForwardFunc} from '../engine';
+import {Elu, EluInputs} from '../kernel_names';
+import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+
+import {op} from './operation';
+
+/**
+ * Computes exponential linear element-wise: `x > 0 ? x : e ^ x - 1`.
+ *
+ * ```js
+ * const x = tf.tensor1d([-1, 1, -3, 2]);
+ *
+ * x.elu().print(); // or tf.elu(x)
+ * ```
+ * @param x The input tensor.
+ */
+/** @doc {heading: 'Operations', subheading: 'Basic math'} */
+function elu_<T extends Tensor>(x: T|TensorLike): T {
+  const $x = convertToTensor(x, 'x', 'elu');
+
+  const forward: ForwardFunc<Tensor> = (backend, save) => {
+    const y = backend.elu($x);
+    save([y]);
+    return y;
+  };
+
+  const inputs: EluInputs = {x: $x};
+
+  return ENGINE.runKernelFunc(
+             forward, inputs as {} as NamedTensorMap, null /* grad */, Elu) as
+      T;
+}
+
+export const elu = op({elu_});

diff --git a/tfjs-core/src/ops/elu_backprop_input.ts b/tfjs-core/src/ops/elu_backprop_input.ts
new file mode 100644
index 00000000000..bdbbfb5c924
--- /dev/null
+++ b/tfjs-core/src/ops/elu_backprop_input.ts
@@ -0,0 +1,43 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ENGINE, ForwardFunc} from '../engine';
+import {EluBackpropInput, EluBackpropInputInputs} from '../kernel_names';
+import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
+
+import {op} from './operation';
+
+/**
+ * Computes the derivative of the input of an elu.
+ *
+ * @param dy The derivative of the output.
+ * @param y The output of the forward function.
+ */
+function eluBackpropInput_<T extends Tensor>(dy: T, y: T): T {
+  const forward: ForwardFunc<Tensor> = (backend) => {
+    return backend.eluDer(dy, y);
+  };
+
+  const inputs: EluBackpropInputInputs = {dy, y};
+
+  return ENGINE.runKernelFunc(
+      forward, inputs as {} as NamedTensorMap, null /* grad */,
+      EluBackpropInput) as T;
+}
+
+export const eluBackpropInput = op({eluBackpropInput_});
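As a sanity check on the formula in the doc comment (and on the `-0.6321` expectation in `elu_test.ts` below), here is a small usage example; the values are arbitrary and only meant as an illustration:

```js
import * as tf from '@tensorflow/tfjs-core';

const x = tf.tensor1d([-1, 1, -3, 2]);
tf.elu(x).data().then(values => {
  // x <= 0 maps to e^x - 1, x > 0 passes through:
  // e^-1 - 1 ≈ -0.6321, 1, e^-3 - 1 ≈ -0.9502, 2
  console.log(values);
});
```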
diff --git a/tfjs-core/src/ops/elu_test.ts b/tfjs-core/src/ops/elu_test.ts
new file mode 100644
index 00000000000..592f361f887
--- /dev/null
+++ b/tfjs-core/src/ops/elu_test.ts
@@ -0,0 +1,73 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '../index';
+import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
+import {expectArraysClose} from '../test_util';
+
+describeWithFlags('elu', ALL_ENVS, () => {
+  it('calculate elu', async () => {
+    const a = tf.tensor1d([1, -1, 0]);
+    const result = tf.elu(a);
+
+    expect(result.shape).toEqual(a.shape);
+    expectArraysClose(await result.data(), [1, -0.6321, 0]);
+  });
+
+  it('elu propagates NaN', async () => {
+    const a = tf.tensor1d([1, NaN]);
+    const result = tf.elu(a);
+    expect(result.shape).toEqual(a.shape);
+    expectArraysClose(await result.data(), [1, NaN]);
+  });
+
+  it('derivative', async () => {
+    const x = tf.tensor1d([1, 3, -2]);
+    const dy = tf.tensor1d([5, 50, 500]);
+    const gradients = tf.grad(a => tf.elu(a))(x, dy);
+
+    expect(gradients.shape).toEqual(x.shape);
+    expect(gradients.dtype).toEqual('float32');
+    expectArraysClose(await gradients.data(), [5, 50, 500 * Math.exp(-2)]);
+  });
+
+  it('gradient with clones', async () => {
+    const x = tf.tensor1d([1, 3, -2]);
+    const dy = tf.tensor1d([5, 50, 500]);
+    const gradients = tf.grad(a => tf.elu(a.clone()).clone())(x, dy);
+
+    expect(gradients.shape).toEqual(x.shape);
+    expect(gradients.dtype).toEqual('float32');
+    expectArraysClose(await gradients.data(), [5, 50, 500 * Math.exp(-2)]);
+  });
+
+  it('throws when passed a non-tensor', () => {
+    expect(() => tf.elu({} as tf.Tensor))
+        .toThrowError(/Argument 'x' passed to 'elu' must be a Tensor/);
+  });
+
+  it('accepts a tensor-like object', async () => {
+    const result = tf.elu([1, -1, 0]);
+    expect(result.shape).toEqual([3]);
+    expectArraysClose(await result.data(), [1, -0.6321, 0]);
+  });
+
+  it('throws for string tensor', () => {
+    expect(() => tf.elu('q'))
+        .toThrowError(/Argument 'x' passed to 'elu' must be numeric/);
+  });
+});

diff --git a/tfjs-core/src/ops/fused_ops.ts b/tfjs-core/src/ops/fused_ops.ts
index f3e4d5e4736..7fe81be0c45 100644
--- a/tfjs-core/src/ops/fused_ops.ts
+++ b/tfjs-core/src/ops/fused_ops.ts
@@ -32,10 +32,12 @@ import {conv2DBackpropInput} from './conv2d_backprop_input';
 import {depthwiseConv2d as unfusedDepthwiseConv2d} from './depthwise_conv2d';
 import {depthwiseConv2dNativeBackpropFilter} from './depthwise_conv2d_native_backprop_filter';
 import {depthwiseConv2dNativeBackpropInput} from './depthwise_conv2d_native_backprop_input';
+import {elu} from './elu';
 import {Activation, shouldFuse} from './fused_util';
 import {matMul as unfusedMatMul} from './mat_mul';
+import {prelu} from './prelu';
 import {relu} from './relu';
-import {elu, prelu, relu6} from './relu_ops';
+import {relu6} from './relu6';
 
 // Returns gradient for fused activation.
 const getFusedDyActivation =
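The `fused_ops.ts` change only swaps the fused kernels over to the new modular imports; behavior is unchanged. For context, this is the path exercised when an activation is folded into a matmul, assuming the `tf.fused.matMul` config-object API (shapes and values here are arbitrary):

```js
import * as tf from '@tensorflow/tfjs-core';

const a = tf.tensor2d([[1, 2], [3, 4]]);
const b = tf.tensor2d([[5, -6], [-7, 8]]);

// Fuses the relu6 activation into the matmul kernel instead of
// running a separate relu6 op on the result afterwards.
const c = tf.fused.matMul({a, b, activation: 'relu6'});
c.print();  // same values as tf.relu6(tf.matMul(a, b))
```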
diff --git a/tfjs-core/src/ops/leaky_relu.ts b/tfjs-core/src/ops/leaky_relu.ts
new file mode 100644
index 00000000000..7f90d9b5c1e
--- /dev/null
+++ b/tfjs-core/src/ops/leaky_relu.ts
@@ -0,0 +1,48 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Tensor} from '../tensor';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+
+import {maximum} from './maximum';
+import {mul} from './mul';
+import {op} from './operation';
+import {scalar} from './tensor_ops';
+
+/**
+ * Computes leaky rectified linear element-wise.
+ *
+ * See
+ * [http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf](
+ *     http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf)
+ *
+ * ```js
+ * const x = tf.tensor1d([-1, 2, -3, 4]);
+ *
+ * x.leakyRelu(0.1).print(); // or tf.leakyRelu(x, 0.1)
+ * ```
+ * @param x The input tensor.
+ * @param alpha The scaling factor for negative values, defaults to 0.2.
+ */
+/** @doc {heading: 'Operations', subheading: 'Basic math'} */
+function leakyRelu_<T extends Tensor>(x: T|TensorLike, alpha = 0.2): T {
+  const $x = convertToTensor(x, 'x', 'leakyRelu');
+  return maximum(mul(scalar(alpha), $x), $x);
+}
+
+export const leakyRelu = op({leakyRelu_});
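Unlike the other ops in this patch, `leakyRelu` is implemented purely in terms of other ops: `max(alpha * x, x)`. That identity only holds for `alpha <= 1` (for a larger alpha the max would pick the wrong branch on positive inputs), which covers the intended use. A small check, values arbitrary:

```js
import * as tf from '@tensorflow/tfjs-core';

const x = tf.tensor1d([-1, 2, -3, 4]);
const alpha = 0.1;

const viaOp = tf.leakyRelu(x, alpha);
const viaMax = tf.maximum(tf.mul(tf.scalar(alpha), x), x);

viaOp.print();   // [-0.1, 2, -0.3, 4]
viaMax.print();  // identical
```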
diff --git a/tfjs-core/src/ops/leaky_relu_test.ts b/tfjs-core/src/ops/leaky_relu_test.ts
new file mode 100644
index 00000000000..604fb135354
--- /dev/null
+++ b/tfjs-core/src/ops/leaky_relu_test.ts
@@ -0,0 +1,112 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '../index';
+import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
+import {expectArraysClose} from '../test_util';
+
+describeWithFlags('leakyRelu', ALL_ENVS, () => {
+  it('basic', async () => {
+    const a = tf.tensor1d([0, 1, -2]);
+    const result = tf.leakyRelu(a);
+
+    expect(result.shape).toEqual(a.shape);
+    expectArraysClose(await result.data(), [0, 1, -0.4]);
+  });
+
+  it('propagates NaN', async () => {
+    const a = tf.tensor1d([0, 1, NaN]);
+    const result = tf.leakyRelu(a);
+
+    expect(result.shape).toEqual(a.shape);
+    expectArraysClose(await result.data(), [0, 1, NaN]);
+  });
+
+  it('gradients: Scalar', async () => {
+    const a = tf.scalar(-4);
+    const dy = tf.scalar(8);
+    const alpha = 0.1;
+
+    const gradients = tf.grad((a) => tf.leakyRelu(a, alpha))(a, dy);
+
+    expect(gradients.shape).toEqual(a.shape);
+    expect(gradients.dtype).toEqual('float32');
+    expectArraysClose(await gradients.data(), [8 * alpha]);
+  });
+
+  it('gradient with clones', async () => {
+    const a = tf.scalar(-4);
+    const dy = tf.scalar(8);
+    const alpha = 0.1;
+
+    const gradients =
+        tf.grad((a) => tf.leakyRelu(a.clone(), alpha).clone())(a, dy);
+
+    expect(gradients.shape).toEqual(a.shape);
+    expect(gradients.dtype).toEqual('float32');
+    expectArraysClose(await gradients.data(), [8 * alpha]);
+  });
+
+  it('gradients: Tensor1D', async () => {
+    const aValues = [1, -1, 0.1];
+    const dyValues = [1, 2, 3];
+    const alpha = 0.1;
+
+    const a = tf.tensor1d(aValues);
+    const dy = tf.tensor1d(dyValues);
+
+    const gradients = tf.grad((a) => tf.leakyRelu(a, alpha))(a, dy);
+
+    expect(gradients.shape).toEqual(a.shape);
+    expect(gradients.dtype).toEqual('float32');
+
+    expectArraysClose(await gradients.data(), [1, 2 * alpha, 3]);
+  });
+
+  it('gradients: Tensor2D', async () => {
+    const aValues = [1, -1, 0.1, 0.5];
+    const dyValues = [1, 2, 3, 4];
+    const alpha = 0.1;
+
+    const a = tf.tensor2d(aValues, [2, 2]);
+    const dy = tf.tensor2d(dyValues, [2, 2]);
+
+    const gradients = tf.grad((a) => tf.leakyRelu(a, alpha))(a, dy);
+
+    expect(gradients.shape).toEqual(a.shape);
+    expect(gradients.dtype).toEqual('float32');
+
+    expectArraysClose(await gradients.data(), [1, 2 * alpha, 3, 4]);
+  });
+
+  it('throws when passed a non-tensor', () => {
+    expect(() => tf.leakyRelu({} as tf.Tensor))
+        .toThrowError(/Argument 'x' passed to 'leakyRelu' must be a Tensor/);
+  });
+
+  it('accepts a tensor-like object', async () => {
+    const result = tf.leakyRelu([0, 1, -2]);
+
+    expect(result.shape).toEqual([3]);
+    expectArraysClose(await result.data(), [0, 1, -0.4]);
+  });
+
+  it('throws for string tensor', () => {
+    expect(() => tf.leakyRelu('q'))
+        .toThrowError(/Argument 'x' passed to 'leakyRelu' must be numeric/);
+  });
+});

diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts
index d8671048ed8..2747d48f9a8 100644
--- a/tfjs-core/src/ops/ops.ts
+++ b/tfjs-core/src/ops/ops.ts
@@ -46,6 +46,7 @@ export {diag} from './diag';
 export {div} from './div';
 export {divNoNan} from './div_no_nan';
 export {dot} from './dot';
+export {elu} from './elu';
 export {equal} from './equal';
 export {eye} from './eye';
 export {fill} from './fill';
@@ -53,6 +54,7 @@ export {floorDiv} from './floorDiv';
 export {greater} from './greater';
 export {greaterEqual} from './greater_equal';
 export {imag} from './imag';
+export {leakyRelu} from './leaky_relu';
 export {less} from './less';
 export {lessEqual} from './less_equal';
 export {localResponseNormalization} from './local_response_normalization';
@@ -76,12 +78,15 @@ export {pad3d} from './pad3d';
 export {pad4d} from './pad4d';
 export {pool} from './pool';
 export {pow} from './pow';
+export {prelu} from './prelu';
 export {rand} from './rand';
 export {randomGamma} from './random_gamma';
 export {randomNormal} from './random_normal';
 export {randomUniform} from './random_uniform';
 export {real} from './real';
 export {relu} from './relu';
+export {relu6} from './relu6';
+export {selu} from './selu';
 export {separableConv2d} from './separable_conv2d';
 export {spaceToBatchND} from './space_to_batch_nd';
 export {split} from './split';
@@ -98,7 +103,6 @@ export * from './unary_ops';
 export * from './reduction_ops';
 export * from './compare';
 export * from './binary_ops';
-export * from './relu_ops';
 export * from './logical_ops';
 export * from './array_ops';
 export * from './tensor_ops';
diff --git a/tfjs-core/src/ops/prelu.ts b/tfjs-core/src/ops/prelu.ts
new file mode 100644
index 00000000000..35a12a27a20
--- /dev/null
+++ b/tfjs-core/src/ops/prelu.ts
@@ -0,0 +1,58 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ENGINE, ForwardFunc} from '../engine';
+import {Prelu, PreluInputs} from '../kernel_names';
+import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+
+import {op} from './operation';
+
+/**
+ * Computes leaky rectified linear element-wise with parametric alphas.
+ *
+ * `x < 0 ? alpha * x : x`
+ *
+ * ```js
+ * const x = tf.tensor1d([-1, 2, -3, 4]);
+ * const alpha = tf.scalar(0.1);
+ *
+ * x.prelu(alpha).print(); // or tf.prelu(x, alpha)
+ * ```
+ * @param x The input tensor.
+ * @param alpha Scaling factor for negative values.
+ */
+/** @doc {heading: 'Operations', subheading: 'Basic math'} */
+function prelu_<T extends Tensor>(x: T|TensorLike, alpha: T|TensorLike): T {
+  const $x = convertToTensor(x, 'x', 'prelu');
+  const $alpha = convertToTensor(alpha, 'alpha', 'prelu');
+
+  const forward: ForwardFunc<Tensor> = (backend, save) => {
+    const res = backend.prelu($x, $alpha);
+    save([$x, $alpha]);
+    return res;
+  };
+
+  const inputs: PreluInputs = {x: $x, alpha: $alpha};
+  return ENGINE.runKernelFunc(
+             forward, inputs as {} as NamedTensorMap, null /* grad */, Prelu) as
+      T;
+}
+
+export const prelu = op({prelu_});
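`alpha` can be any shape that broadcasts against `x`, which is how Keras-style per-channel PReLU is expressed. An illustrative example with a per-column alpha:

```js
import * as tf from '@tensorflow/tfjs-core';

const x = tf.tensor2d([[-1, -2], [3, -4]]);
const alpha = tf.tensor1d([0.1, 0.5]);  // one slope per column

// Negative entries are scaled by their column's alpha:
// [[-0.1, -1], [3, -2]]
tf.prelu(x, alpha).print();
```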
diff --git a/tfjs-core/src/ops/relu6.ts b/tfjs-core/src/ops/relu6.ts
new file mode 100644
index 00000000000..d426d4903b2
--- /dev/null
+++ b/tfjs-core/src/ops/relu6.ts
@@ -0,0 +1,60 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ENGINE, ForwardFunc} from '../engine';
+import {Relu6, Relu6Inputs} from '../kernel_names';
+import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+
+import {cast} from './array_ops';
+import {op} from './operation';
+
+/**
+ * Computes rectified linear 6 element-wise: `min(max(x, 0), 6)`.
+ *
+ * ```js
+ * const x = tf.tensor1d([-1, 2, -3, 8]);
+ *
+ * x.relu6().print(); // or tf.relu6(x)
+ * ```
+ * @param x The input tensor. If the dtype is `bool`, the output dtype will be
+ *     `int32`.
+ */
+/** @doc {heading: 'Operations', subheading: 'Basic math'} */
+function relu6_<T extends Tensor>(x: T|TensorLike): T {
+  const $x = convertToTensor(x, 'x', 'relu6');
+
+  const forward: ForwardFunc<Tensor> = (backend, save) => {
+    save([$x]);
+
+    if ($x.dtype === 'bool') {
+      return cast($x, 'int32');
+    }
+
+    return backend.relu6($x);
+  };
+
+  const inputs: Relu6Inputs = {x: $x};
+
+  return ENGINE.runKernelFunc(
+             forward, inputs as {} as NamedTensorMap, null /* grad */, Relu6) as
+      T;
+}
+
+export const relu6 = op({relu6_});
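Functionally, `relu6` clamps into `[0, 6]`; the dedicated kernel exists so backends can implement (and fuse) it directly, but the result matches composing the primitive ops. Illustration only:

```js
import * as tf from '@tensorflow/tfjs-core';

const x = tf.tensor1d([-1, 2, -3, 8]);

tf.relu6(x).print();                      // [0, 2, 0, 6]
tf.minimum(tf.maximum(x, 0), 6).print();  // same values
```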
diff --git a/tfjs-core/src/ops/relu6_test.ts b/tfjs-core/src/ops/relu6_test.ts
new file mode 100644
index 00000000000..c44ce027caa
--- /dev/null
+++ b/tfjs-core/src/ops/relu6_test.ts
@@ -0,0 +1,52 @@
+/**
+ * @license
+ * Copyright 2017 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '../index';
+import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
+import {expectArraysClose} from '../test_util';
+
+describeWithFlags('relu6', ALL_ENVS, () => {
+  it('basic relu6', async () => {
+    const a = tf.tensor1d([1, -2, 0, 8, -0.1]);
+    const result = tf.relu6(a);
+    expectArraysClose(await result.data(), [1, 0, 0, 6, 0]);
+  });
+
+  it('gradients: relu6', async () => {
+    const a = tf.scalar(8);
+    const dy = tf.scalar(5);
+
+    const grad = tf.grad(a => tf.relu6(a));
+    const da = grad(a, dy);
+
+    expect(da.shape).toEqual(a.shape);
+    expect(da.dtype).toEqual('float32');
+    expectArraysClose(await da.data(), [0]);
+  });
+
+  it('gradients: relu6 array', async () => {
+    const a = tf.tensor2d([8, -1, 0, .1], [2, 2]);
+    const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+
+    const grad = tf.grad(a => tf.relu6(a));
+    const da = grad(a, dy);
+
+    expect(da.shape).toEqual(a.shape);
+    expect(da.dtype).toEqual('float32');
+    expectArraysClose(await da.data(), [0, 0, 0, 4]);
+  });
+});

diff --git a/tfjs-core/src/ops/relu_ops.ts b/tfjs-core/src/ops/relu_ops.ts
deleted file mode 100644
index 61119e2eeac..00000000000
--- a/tfjs-core/src/ops/relu_ops.ts
+++ /dev/null
@@ -1,197 +0,0 @@
-/**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
-import {ENGINE} from '../engine';
-import {Tensor} from '../tensor';
-import {convertToTensor} from '../tensor_util_env';
-import {TensorLike} from '../types';
-
-import {getReductionAxes} from './broadcast_util';
-import {where} from './logical_ops';
-import {maximum} from './maximum';
-import {mul} from './mul';
-import {op} from './operation';
-import {SELU_SCALE, SELU_SCALEALPHA} from './selu_util';
-import {scalar, zerosLike} from './tensor_ops';
-
-/**
- * Computes rectified linear 6 element-wise: `min(max(x, 0), 6)`.
- *
- * ```js
- * const x = tf.tensor1d([-1, 2, -3, 8]);
- *
- * x.relu6().print(); // or tf.relu6(x)
- * ```
- * @param x The input tensor. If the dtype is `bool`, the output dtype will be
- *     `int32'.
- */
-/** @doc {heading: 'Operations', subheading: 'Basic math'} */
-function relu6_<T extends Tensor>(x: T|TensorLike): T {
-  const $x = convertToTensor(x, 'x', 'relu6');
-
-  if ($x.dtype === 'bool') {
-    return $x.toInt();
-  }
-  const grad = (dy: T, saved: Tensor[]) => {
-    const [$x] = saved;
-    const mask = $x.lessEqual(6).mul($x.step());
-    // tslint:disable-next-line: no-unnecessary-type-assertion
-    return {x: () => mul(dy, mask.toFloat()) as T};
-  };
-  return ENGINE.runKernelFunc((backend, save) => {
-    const res = backend.relu6($x);
-    save([$x]);
-    return res;
-  }, {x: $x}, grad, 'Relu6');
-}
-
-/**
- * Computes exponential linear element-wise: `x > 0 ? e ^ x - 1 : 0`.
- *
- * ```js
- * const x = tf.tensor1d([-1, 1, -3, 2]);
- *
- * x.elu().print(); // or tf.elu(x)
- * ```
- * @param x The input tensor.
- */
-/** @doc {heading: 'Operations', subheading: 'Basic math'} */
-function elu_<T extends Tensor>(x: T|TensorLike): T {
-  const $x = convertToTensor(x, 'x', 'elu');
-
-  const grad = (dy: T, saved: Tensor[]) => {
-    const [y] = saved;
-    return {
-      $x: () =>
-          ENGINE.runKernelFunc(backend => backend.eluDer(dy, y), {dy, y}) as T
-    };
-  };
-  return ENGINE.runKernelFunc((backend, save) => {
-    const y = backend.elu($x);
-    save([y]);
-    return y;
-  }, {$x}, grad);
-}
-
-/**
- * Computes scaled exponential linear element-wise.
- *
- * `x < 0 ? scale * alpha * (exp(x) - 1) : x`
- *
- * ```js
- * const x = tf.tensor1d([-1, 2, -3, 4]);
- *
- * x.selu().print(); // or tf.selu(x)
- * ```
- * @param x The input tensor.
- */
-/** @doc {heading: 'Operations', subheading: 'Basic math'} */
-function selu_<T extends Tensor>(x: T|TensorLike): T {
-  const $x = convertToTensor(x, 'x', 'selu');
-
-  const grad = (dy: T, saved: Tensor[]) => {
-    const [$x] = saved;
-    return {
-      $x: () => {
-        const mask = $x.greater(scalar(0));
-
-        const scaleAlpha = scalar(SELU_SCALEALPHA);
-        const scale = scalar(SELU_SCALE);
-
-        const greaterThanZeroDer = dy.mul(scale);
-        const lessEqualZeroDer = dy.mul(scaleAlpha).mul($x.toFloat().exp());
-
-        return where(mask, greaterThanZeroDer, lessEqualZeroDer) as T;
-      }
-    };
-  };
-  return ENGINE.runKernelFunc((backend, save) => {
-    const res = backend.selu($x);
-    save([$x]);
-    return res;
-  }, {$x}, grad);
-}
-
-/**
- * Computes leaky rectified linear element-wise.
- *
- * See
- * [http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf](
- *     http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf)
- *
- * ```js
- * const x = tf.tensor1d([-1, 2, -3, 4]);
- *
- * x.leakyRelu(0.1).print(); // or tf.leakyRelu(x, 0.1)
- * ```
- * @param x The input tensor.
- * @param alpha The scaling factor for negative values, defaults to 0.2.
- */
-/** @doc {heading: 'Operations', subheading: 'Basic math'} */
-function leakyRelu_<T extends Tensor>(x: T|TensorLike, alpha = 0.2): T {
-  const $x = convertToTensor(x, 'x', 'leakyRelu');
-  return maximum(scalar(alpha).mul($x), $x);
-}
-
-/**
- * Computes leaky rectified linear element-wise with parametric alphas.
- *
- * `x < 0 ? alpha * x : f(x) = x`
- *
- * ```js
- * const x = tf.tensor1d([-1, 2, -3, 4]);
- * const alpha = tf.scalar(0.1);
- *
- * x.prelu(alpha).print(); // or tf.prelu(x, alpha)
- * ```
- * @param x The input tensor.
- * @param alpha Scaling factor for negative values.
- */
-/** @doc {heading: 'Operations', subheading: 'Basic math'} */
-function prelu_<T extends Tensor>(x: T|TensorLike, alpha: T|TensorLike): T {
-  const $x = convertToTensor(x, 'x', 'prelu');
-  const $alpha = convertToTensor(alpha, 'alpha', 'prelu');
-
-  const grad = (dy: Tensor, saved: Tensor[]) => {
-    const [$x, $alpha] = saved;
-    const mask = $x.greater(0);
-
-    return {
-      x: () => where(mask, dy, dy.mul($alpha)) as T,
-      alpha: () => {
-        let res = where(mask, zerosLike(dy), dy.mul($x));
-        const reduceAxes = getReductionAxes($alpha.shape, dy.shape);
-        if (reduceAxes.length > 0) {
-          res = res.sum(reduceAxes);
-        }
-        return res.reshape($alpha.shape) as T;
-      }
-    };
-  };
-
-  return ENGINE.runKernelFunc((backend, save) => {
-    const res = backend.prelu($x, $alpha);
-    save([$x, $alpha]);
-    return res;
-  }, {x: $x, alpha: $alpha}, grad, 'Prelu') as T;
-}
-
-export const elu = op({elu_});
-export const leakyRelu = op({leakyRelu_});
-export const prelu = op({prelu_});
-export const relu6 = op({relu6_});
-export const selu = op({selu_});

diff --git a/tfjs-core/src/ops/relu_test.ts b/tfjs-core/src/ops/relu_test.ts
new file mode 100644
index 00000000000..42ce232a71c
--- /dev/null
+++ b/tfjs-core/src/ops/relu_test.ts
@@ -0,0 +1,128 @@
+/**
+ * @license
+ * Copyright 2017 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */

+import * as tf from '../index';
+import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
+import {expectArraysClose} from '../test_util';
+
+describeWithFlags('relu', ALL_ENVS, () => {
+  it('basic', async () => {
+    const a = tf.tensor1d([1, -2, 0, 3, -0.1]);
+    const result = tf.relu(a);
+    expectArraysClose(await result.data(), [1, 0, 0, 3, 0]);
+  });
+
+  it('5D', async () => {
+    const a = tf.tensor5d([1, -2, 5, -3], [1, 2, 2, 1, 1]);
+    const result = tf.relu(a);
+    expectArraysClose(await result.data(), [1, 0, 5, 0]);
+  });
+
+  it('6D', async () => {
+    const a = tf.tensor6d([1, -2, 5, -3, -1, 4, 7, 8], [1, 2, 2, 2, 1, 1]);
+    const result = tf.relu(a);
+    expectArraysClose(await result.data(), [1, 0, 5, 0, 0, 4, 7, 8]);
+  });
+
+  it('does nothing to positive values', async () => {
+    const a = tf.scalar(1);
+    const result = tf.relu(a);
+    expectArraysClose(await result.data(), [1]);
+  });
+
+  it('sets negative values to 0', async () => {
+    const a = tf.scalar(-1);
+    const result = tf.relu(a);
+    expectArraysClose(await result.data(), [0]);
+  });
+
+  it('preserves zero values', async () => {
+    const a = tf.scalar(0);
+    const result = tf.relu(a);
+    expectArraysClose(await result.data(), [0]);
+  });
+
+  it('propagates NaNs, float32', async () => {
+    const a = tf.tensor1d([1, -2, 0, 3, -0.1, NaN]);
+    const result = tf.relu(a);
+    expect(result.dtype).toBe('float32');
+    expectArraysClose(await result.data(), [1, 0, 0, 3, 0, NaN]);
+  });
+
+  it('gradients: positive scalar', async () => {
+    const a = tf.scalar(3);
+    const dy = tf.scalar(5);
+
+    const grad = tf.grad(a => tf.relu(a));
+    const da = grad(a, dy);
+
+    expect(da.shape).toEqual(a.shape);
+    expect(da.dtype).toEqual('float32');
+    expectArraysClose(await da.data(), [5]);
+  });
+
+  it('gradient with clones', async () => {
+    const a = tf.scalar(3);
+    const dy = tf.scalar(5);
+
+    const grad = tf.grad(a => tf.relu(a.clone()).clone());
+    const da = grad(a, dy);
+
+    expect(da.shape).toEqual(a.shape);
+    expect(da.dtype).toEqual('float32');
+    expectArraysClose(await da.data(), [5]);
+  });
+
+  it('gradients: negative scalar', async () => {
+    const a = tf.scalar(-3);
+    const dy = tf.scalar(5);
+
+    const grad = tf.grad(a => tf.relu(a));
+    const da = grad(a, dy);
+
+    expect(da.shape).toEqual(a.shape);
+    expect(da.dtype).toEqual('float32');
+    expectArraysClose(await da.data(), [0]);
+  });
+
+  it('gradients: array', async () => {
+    const a = tf.tensor2d([1, -1, 0, .1], [2, 2]);
+    const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+
+    const grad = tf.grad(a => tf.relu(a));
+    const da = grad(a, dy);
+
+    expect(da.shape).toEqual(a.shape);
+    expect(da.dtype).toEqual('float32');
+    expectArraysClose(await da.data(), [1, 0, 0, 4]);
+  });
+
+  it('throws when passed a non-tensor', () => {
+    expect(() => tf.relu({} as tf.Tensor))
+        .toThrowError(/Argument 'x' passed to 'relu' must be a Tensor/);
+  });
+
+  it('accepts a tensor-like object', async () => {
+    const result = tf.relu([1, -2, 0, 3, -0.1]);
+    expectArraysClose(await result.data(), [1, 0, 0, 3, 0]);
+  });
+
+  it('throws for string tensor', () => {
+    expect(() => tf.relu('q'))
+        .toThrowError(/Argument 'x' passed to 'relu' must be numeric/);
+  });
+});

diff --git a/tfjs-core/src/ops/selu.ts b/tfjs-core/src/ops/selu.ts
new file mode 100644
index 00000000000..782048289db
--- /dev/null
+++ b/tfjs-core/src/ops/selu.ts
@@ -0,0 +1,56 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ENGINE, ForwardFunc} from '../engine';
+import {Selu, SeluInputs} from '../kernel_names';
+import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+
+import {op} from './operation';
+
+/**
+ * Computes scaled exponential linear element-wise.
+ *
+ * `x < 0 ? scale * alpha * (exp(x) - 1) : x`
+ *
+ * ```js
+ * const x = tf.tensor1d([-1, 2, -3, 4]);
+ *
+ * x.selu().print(); // or tf.selu(x)
+ * ```
+ * @param x The input tensor.
+ */
+/** @doc {heading: 'Operations', subheading: 'Basic math'} */
+function selu_<T extends Tensor>(x: T|TensorLike): T {
+  const $x = convertToTensor(x, 'x', 'selu');
+
+  const forward: ForwardFunc<Tensor> = (backend, save) => {
+    const res = backend.selu($x);
+    save([$x]);
+    return res;
+  };
+
+  const inputs: SeluInputs = {x: $x};
+
+  return ENGINE.runKernelFunc(
+             forward, inputs as {} as NamedTensorMap, null /* grad */, Selu) as
+      T;
+}
+
+export const selu = op({selu_});
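The `scale` and `alpha` in the doc comment are the SELU constants from `selu_util.ts`: `SELU_SCALE ≈ 1.0507` and `SELU_SCALEALPHA ≈ 1.7581` (scale times the ELU alpha of ≈1.6733, per Klambauer et al., 2017). They explain the magic numbers in `selu_test.ts` below; a plain-TypeScript sketch, with the constant values quoted here for illustration:

```js
const SELU_SCALE = 1.0507009873554805;
const SELU_SCALEALPHA = 1.7580993408473768;  // scale * alpha

const selu = (x: number): number =>
    x > 0 ? SELU_SCALE * x : SELU_SCALEALPHA * (Math.exp(x) - 1);

console.log(selu(1));   // 1.0507
console.log(selu(-1));  // ≈ -1.1113, matching the test expectation
console.log(selu(0));   // 0
```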
diff --git a/tfjs-core/src/ops/selu_test.ts b/tfjs-core/src/ops/selu_test.ts
new file mode 100644
index 00000000000..db21fa89b58
--- /dev/null
+++ b/tfjs-core/src/ops/selu_test.ts
@@ -0,0 +1,140 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '../index';
+import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
+import {expectArraysClose} from '../test_util';
+
+import * as selu_util from './selu_util';
+
+describeWithFlags('selu', ALL_ENVS, () => {
+  const scaleAlpha = selu_util.SELU_SCALEALPHA;
+  const scale = selu_util.SELU_SCALE;
+
+  it('calculate selu', async () => {
+    const a = tf.tensor1d([1, -1, 0]);
+    const result = tf.selu(a);
+
+    expect(result.shape).toEqual(a.shape);
+    expectArraysClose(await result.data(), [1.0507, -1.1113, 0]);
+  });
+
+  it('selu propagates NaN', async () => {
+    const a = tf.tensor1d([1, NaN]);
+    const result = tf.selu(a);
+    expect(result.shape).toEqual(a.shape);
+    expectArraysClose(await result.data(), [1.0507, NaN]);
+  });
+
+  it('gradients: Scalar', async () => {
+    let aValue = 1;
+    let dyValue = 1;
+    let a = tf.scalar(aValue);
+    let dy = tf.scalar(dyValue);
+
+    let gradients = tf.grad(a => tf.selu(a))(a, dy);
+
+    expect(gradients.shape).toEqual(a.shape);
+    expect(gradients.dtype).toEqual('float32');
+    expectArraysClose(await gradients.data(), [dyValue * scale]);
+
+    aValue = -1;
+    dyValue = 2;
+    a = tf.scalar(aValue);
+    dy = tf.scalar(dyValue);
+
+    gradients = tf.grad(a => tf.selu(a))(a, dy);
+
+    expect(gradients.shape).toEqual(a.shape);
+    expect(gradients.dtype).toEqual('float32');
+    expectArraysClose(
+        await gradients.data(), [dyValue * scaleAlpha * Math.exp(aValue)]);
+  });
+
+  it('gradient with clones', async () => {
+    const aValue = 1;
+    const dyValue = 1;
+    const a = tf.scalar(aValue);
+    const dy = tf.scalar(dyValue);
+
+    const gradients = tf.grad(a => tf.selu(a.clone()).clone())(a, dy);
+
+    expect(gradients.shape).toEqual(a.shape);
+    expect(gradients.dtype).toEqual('float32');
+    expectArraysClose(await gradients.data(), [dyValue * scale]);
+  });
+
+  it('gradients: Tensor1D', async () => {
+    const aValues = [1, -1, 0];
+    const dyValues = [1, 2, 3];
+    const a = tf.tensor1d(aValues);
+    const dy = tf.tensor1d(dyValues);
+
+    const gradients = tf.grad(a => tf.selu(a))(a, dy);
+
+    const expected = [];
+    for (let i = 0; i < a.size; i++) {
+      if (aValues[i] > 0) {
+        expected[i] = dyValues[i] * scale;
+      } else {
+        expected[i] = dyValues[i] * scaleAlpha * Math.exp(aValues[i]);
+      }
+    }
+
+    expect(gradients.shape).toEqual(a.shape);
+    expect(gradients.dtype).toEqual('float32');
+    expectArraysClose(await gradients.data(), expected);
+  });
+
+  it('gradients: Tensor2D', async () => {
+    const aValues = [1, -1, 0, 0.5];
+    const dyValues = [1, 2, 3, 4];
+    const a = tf.tensor2d(aValues, [2, 2]);
+    const dy = tf.tensor2d(dyValues, [2, 2]);
+
+    const gradients = tf.grad(a => tf.selu(a))(a, dy);
+
+    const expected = [];
+    for (let i = 0; i < a.size; i++) {
+      if (aValues[i] > 0) {
+        expected[i] = dyValues[i] * scale;
+      } else {
+        expected[i] = dyValues[i] * scaleAlpha * Math.exp(aValues[i]);
+      }
+    }
+
+    expect(gradients.shape).toEqual(a.shape);
+    expect(gradients.dtype).toEqual('float32');
+    expectArraysClose(await gradients.data(), expected);
+  });
+
+  it('throws when passed a non-tensor', () => {
+    expect(() => tf.selu({} as tf.Tensor))
+        .toThrowError(/Argument 'x' passed to 'selu' must be a Tensor/);
+  });
+
+  it('accepts a tensor-like object', async () => {
+    const result = tf.selu([1, -1, 0]);
+    expect(result.shape).toEqual([3]);
+    expectArraysClose(await result.data(), [1.0507, -1.1113, 0]);
+  });
+
+  it('throws for string tensor', () => {
+    expect(() => tf.selu('q'))
+        .toThrowError(/Argument 'x' passed to 'selu' must be numeric/);
+  });
+});
diff --git a/tfjs-core/src/ops/unary_ops_test.ts b/tfjs-core/src/ops/unary_ops_test.ts
index 488bfcef657..6141675c1ff 100644
--- a/tfjs-core/src/ops/unary_ops_test.ts
+++ b/tfjs-core/src/ops/unary_ops_test.ts
@@ -20,146 +20,6 @@ import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
 import {expectArraysClose, TEST_EPSILON_FLOAT16} from '../test_util';
 import * as util from '../util';
 
-import * as selu_util from './selu_util';
-
-describeWithFlags('relu', ALL_ENVS, () => {
-  it('basic', async () => {
-    const a = tf.tensor1d([1, -2, 0, 3, -0.1]);
-    const result = tf.relu(a);
-    expectArraysClose(await result.data(), [1, 0, 0, 3, 0]);
-  });
-
-  it('basic relu6', async () => {
-    const a = tf.tensor1d([1, -2, 0, 8, -0.1]);
-    const result = tf.relu6(a);
-    expectArraysClose(await result.data(), [1, 0, 0, 6, 0]);
-  });
-
-  it('5D', async () => {
-    const a = tf.tensor5d([1, -2, 5, -3], [1, 2, 2, 1, 1]);
-    const result = tf.relu(a);
-    expectArraysClose(await result.data(), [1, 0, 5, 0]);
-  });
-
-  it('6D', async () => {
-    const a = tf.tensor6d([1, -2, 5, -3, -1, 4, 7, 8], [1, 2, 2, 2, 1, 1]);
-    const result = tf.relu(a);
-    expectArraysClose(await result.data(), [1, 0, 5, 0, 0, 4, 7, 8]);
-  });
-
-  it('does nothing to positive values', async () => {
-    const a = tf.scalar(1);
-    const result = tf.relu(a);
-    expectArraysClose(await result.data(), [1]);
-  });
-
-  it('sets negative values to 0', async () => {
-    const a = tf.scalar(-1);
-    const result = tf.relu(a);
-    expectArraysClose(await result.data(), [0]);
-  });
-
-  it('preserves zero values', async () => {
-    const a = tf.scalar(0);
-    const result = tf.relu(a);
-    expectArraysClose(await result.data(), [0]);
-  });
-
-  it('propagates NaNs, float32', async () => {
-    const a = tf.tensor1d([1, -2, 0, 3, -0.1, NaN]);
-    const result = tf.relu(a);
-    expect(result.dtype).toBe('float32');
-    expectArraysClose(await result.data(), [1, 0, 0, 3, 0, NaN]);
-  });
-
-  it('gradients: positive scalar', async () => {
-    const a = tf.scalar(3);
-    const dy = tf.scalar(5);
-
-    const grad = tf.grad(a => tf.relu(a));
-    const da = grad(a, dy);
-
-    expect(da.shape).toEqual(a.shape);
-    expect(da.dtype).toEqual('float32');
-    expectArraysClose(await da.data(), [5]);
-  });
-
-  it('gradients: relu6', async () => {
-    const a = tf.scalar(8);
-    const dy = tf.scalar(5);
-
-    const grad = tf.grad(a => tf.relu6(a));
-    const da = grad(a, dy);
-
-    expect(da.shape).toEqual(a.shape);
-    expect(da.dtype).toEqual('float32');
-    expectArraysClose(await da.data(), [0]);
-  });
-
-  it('gradient with clones', async () => {
-    const a = tf.scalar(3);
-    const dy = tf.scalar(5);
-
-    const grad = tf.grad(a => tf.relu(a.clone()).clone());
-    const da = grad(a, dy);
-
-    expect(da.shape).toEqual(a.shape);
-    expect(da.dtype).toEqual('float32');
-    expectArraysClose(await da.data(), [5]);
-  });
-
-  it('gradients: negative scalar', async () => {
-    const a = tf.scalar(-3);
-    const dy = tf.scalar(5);
-
-    const grad = tf.grad(a => tf.relu(a));
-    const da = grad(a, dy);
-
-    expect(da.shape).toEqual(a.shape);
-    expect(da.dtype).toEqual('float32');
-    expectArraysClose(await da.data(), [0]);
-  });
-
-  it('gradients: array', async () => {
-    const a = tf.tensor2d([1, -1, 0, .1], [2, 2]);
-    const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]);
-
-    const grad = tf.grad(a => tf.relu(a));
-    const da = grad(a, dy);
-
-    expect(da.shape).toEqual(a.shape);
-    expect(da.dtype).toEqual('float32');
-    expectArraysClose(await da.data(), [1, 0, 0, 4]);
-  });
-
-  it('gradients: relu6 array', async () => {
-    const a = tf.tensor2d([8, -1, 0, .1], [2, 2]);
-    const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]);
-
-    const grad = tf.grad(a => tf.relu6(a));
-    const da = grad(a, dy);
-
-    expect(da.shape).toEqual(a.shape);
-    expect(da.dtype).toEqual('float32');
-    expectArraysClose(await da.data(), [0, 0, 0, 4]);
-  });
-
-  it('throws when passed a non-tensor', () => {
-    expect(() => tf.relu({} as tf.Tensor))
-        .toThrowError(/Argument 'x' passed to 'relu' must be a Tensor/);
-  });
-
-  it('accepts a tensor-like object', async () => {
-    const result = tf.relu([1, -2, 0, 3, -0.1]);
-    expectArraysClose(await result.data(), [1, 0, 0, 3, 0]);
-  });
-
-  it('throws for string tensor', () => {
-    expect(() => tf.relu('q'))
-        .toThrowError(/Argument 'x' passed to 'relu' must be numeric/);
-  });
-});
-
 describeWithFlags('abs', ALL_ENVS, () => {
   it('basic', async () => {
     const a = tf.tensor1d([1, -2, 0, 3, -0.1]);
@@ -2829,269 +2689,6 @@ describeWithFlags('tanh', ALL_ENVS, () => {
   });
 });
 
-describeWithFlags('leakyRelu', ALL_ENVS, () => {
-  it('basic', async () => {
-    const a = tf.tensor1d([0, 1, -2]);
-    const result = tf.leakyRelu(a);
-
-    expect(result.shape).toEqual(a.shape);
-    expectArraysClose(await result.data(), [0, 1, -0.4]);
-  });
-
-  it('propagates NaN', async () => {
-    const a = tf.tensor1d([0, 1, NaN]);
-    const result = tf.leakyRelu(a);
-
-    expect(result.shape).toEqual(a.shape);
-    expectArraysClose(await result.data(), [0, 1, NaN]);
-  });
-
-  it('gradients: Scalar', async () => {
-    const a = tf.scalar(-4);
-    const dy = tf.scalar(8);
-    const alpha = 0.1;
-
-    const gradients = tf.grad((a) => tf.leakyRelu(a, alpha))(a, dy);
-
-    expect(gradients.shape).toEqual(a.shape);
-    expect(gradients.dtype).toEqual('float32');
-    expectArraysClose(await gradients.data(), [8 * alpha]);
-  });
-
-  it('gradient with clones', async () => {
-    const a = tf.scalar(-4);
-    const dy = tf.scalar(8);
-    const alpha = 0.1;
-
-    const gradients =
-        tf.grad((a) => tf.leakyRelu(a.clone(), alpha).clone())(a, dy);
-
-    expect(gradients.shape).toEqual(a.shape);
-    expect(gradients.dtype).toEqual('float32');
-    expectArraysClose(await gradients.data(), [8 * alpha]);
-  });
-
-  it('gradients: Tensor1D', async () => {
-    const aValues = [1, -1, 0.1];
-    const dyValues = [1, 2, 3];
-    const alpha = 0.1;
-
-    const a = tf.tensor1d(aValues);
-    const dy = tf.tensor1d(dyValues);
-
-    const gradients = tf.grad((a) => tf.leakyRelu(a, alpha))(a, dy);
-
-    expect(gradients.shape).toEqual(a.shape);
-    expect(gradients.dtype).toEqual('float32');
-
-    expectArraysClose(await gradients.data(), [1, 2 * alpha, 3]);
-  });
-
-  it('gradients: Tensor2D', async () => {
-    const aValues = [1, -1, 0.1, 0.5];
-    const dyValues = [1, 2, 3, 4];
-    const alpha = 0.1;
-
-    const a = tf.tensor2d(aValues, [2, 2]);
-    const dy = tf.tensor2d(dyValues, [2, 2]);
-
-    const gradients = tf.grad((a) => tf.leakyRelu(a, alpha))(a, dy);
-
-    expect(gradients.shape).toEqual(a.shape);
-    expect(gradients.dtype).toEqual('float32');
-
-    expectArraysClose(await gradients.data(), [1, 2 * alpha, 3, 4]);
-  });
-
-  it('throws when passed a non-tensor', () => {
-    expect(() => tf.leakyRelu({} as tf.Tensor))
-        .toThrowError(/Argument 'x' passed to 'leakyRelu' must be a Tensor/);
-  });
-
-  it('accepts a tensor-like object', async () => {
-    const result = tf.leakyRelu([0, 1, -2]);
-
-    expect(result.shape).toEqual([3]);
-    expectArraysClose(await result.data(), [0, 1, -0.4]);
-  });
-
-  it('throws for string tensor', () => {
-    expect(() => tf.leakyRelu('q'))
-        .toThrowError(/Argument 'x' passed to 'leakyRelu' must be numeric/);
-  });
-});
-
-describeWithFlags('elu', ALL_ENVS, () => {
-  it('calculate elu', async () => {
-    const a = tf.tensor1d([1, -1, 0]);
-    const result = tf.elu(a);
-
-    expect(result.shape).toEqual(a.shape);
-    expectArraysClose(await result.data(), [1, -0.6321, 0]);
-  });
-
-  it('elu propagates NaN', async () => {
-    const a = tf.tensor1d([1, NaN]);
-    const result = tf.elu(a);
-    expect(result.shape).toEqual(a.shape);
-    expectArraysClose(await result.data(), [1, NaN]);
-  });
-
-  it('derivative', async () => {
-    const x = tf.tensor1d([1, 3, -2]);
-    const dy = tf.tensor1d([5, 50, 500]);
-    const gradients = tf.grad(a => tf.elu(a))(x, dy);
-
-    expect(gradients.shape).toEqual(x.shape);
-    expect(gradients.dtype).toEqual('float32');
-    expectArraysClose(await gradients.data(), [5, 50, 500 * Math.exp(-2)]);
-  });
-
-  it('gradient with clones', async () => {
-    const x = tf.tensor1d([1, 3, -2]);
-    const dy = tf.tensor1d([5, 50, 500]);
-    const gradients = tf.grad(a => tf.elu(a.clone()).clone())(x, dy);
-
-    expect(gradients.shape).toEqual(x.shape);
-    expect(gradients.dtype).toEqual('float32');
-    expectArraysClose(await gradients.data(), [5, 50, 500 * Math.exp(-2)]);
-  });
-
-  it('throws when passed a non-tensor', () => {
-    expect(() => tf.elu({} as tf.Tensor))
-        .toThrowError(/Argument 'x' passed to 'elu' must be a Tensor/);
-  });
-
-  it('accepts a tensor-like object', async () => {
-    const result = tf.elu([1, -1, 0]);
-    expect(result.shape).toEqual(result.shape);
-    expectArraysClose(await result.data(), [1, -0.6321, 0]);
-  });
-
-  it('throws for string tensor', () => {
-    expect(() => tf.elu('q'))
-        .toThrowError(/Argument 'x' passed to 'elu' must be numeric/);
-  });
-});
-
-describeWithFlags('selu', ALL_ENVS, () => {
-  const scaleAlpha = selu_util.SELU_SCALEALPHA;
-  const scale = selu_util.SELU_SCALE;
-
-  it('calculate selu', async () => {
-    const a = tf.tensor1d([1, -1, 0]);
-    const result = tf.selu(a);
-
-    expect(result.shape).toEqual(a.shape);
-    expectArraysClose(await result.data(), [1.0507, -1.1113, 0]);
-  });
-
-  it('selu propagates NaN', async () => {
-    const a = tf.tensor1d([1, NaN]);
-    const result = tf.selu(a);
-    expect(result.shape).toEqual(a.shape);
-    expectArraysClose(await result.data(), [1.0507, NaN]);
-  });
-
-  it('gradients: Scalar', async () => {
-    let aValue = 1;
-    let dyValue = 1;
-    let a = tf.scalar(aValue);
-    let dy = tf.scalar(dyValue);
-
-    let gradients = tf.grad(a => tf.selu(a))(a, dy);
-
-    expect(gradients.shape).toEqual(a.shape);
-    expect(gradients.dtype).toEqual('float32');
-    expectArraysClose(await gradients.data(), [dyValue * scale]);
-
-    aValue = -1;
-    dyValue = 2;
-    a = tf.scalar(aValue);
-    dy = tf.scalar(dyValue);
-
-    gradients = tf.grad(a => tf.selu(a))(a, dy);
-
-    expect(gradients.shape).toEqual(a.shape);
-    expect(gradients.dtype).toEqual('float32');
-    expectArraysClose(
-        await gradients.data(), [dyValue * scaleAlpha * Math.exp(aValue)]);
-  });
-
-  it('gradient with clones', async () => {
-    const aValue = 1;
-    const dyValue = 1;
-    const a = tf.scalar(aValue);
-    const dy = tf.scalar(dyValue);
-
-    const gradients = tf.grad(a => tf.selu(a.clone()).clone())(a, dy);
-
-    expect(gradients.shape).toEqual(a.shape);
-    expect(gradients.dtype).toEqual('float32');
-    expectArraysClose(await gradients.data(), [dyValue * scale]);
-  });
-
-  it('gradients: Tensor1D', async () => {
-    const aValues = [1, -1, 0];
-    const dyValues = [1, 2, 3];
-    const a = tf.tensor1d(aValues);
-    const dy = tf.tensor1d(dyValues);
-
-    const gradients = tf.grad(a => tf.selu(a))(a, dy);
tf.selu(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - if (aValues[i] > 0) { - expected[i] = dyValues[i] * scale; - } else { - expected[i] = dyValues[i] * scaleAlpha * Math.exp(aValues[i]); - } - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('gradients: Tensor2D', async () => { - const aValues = [1, -1, 0, 0.5]; - const dyValues = [1, 2, 3, 4]; - const a = tf.tensor2d(aValues, [2, 2]); - const dy = tf.tensor2d(dyValues, [2, 2]); - - const gradients = tf.grad(a => tf.selu(a))(a, dy); - - const expected = []; - for (let i = 0; i < a.size; i++) { - if (aValues[i] > 0) { - expected[i] = dyValues[i] * scale; - } else { - expected[i] = dyValues[i] * scaleAlpha * Math.exp(aValues[i]); - } - } - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), expected); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.selu({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'selu' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const result = tf.selu([1, -1, 0]); - expect(result.shape).toEqual([3]); - expectArraysClose(await result.data(), [1.0507, -1.1113, 0]); - }); - - it('throws for string tensor', () => { - expect(() => tf.selu('q')) - .toThrowError(/Argument 'x' passed to 'selu' must be numeric/); - }); -}); - describeWithFlags('clip', ALL_ENVS, () => { it('basic', async () => { const a = tf.tensor1d([3, -1, 0, 100, -7, 2]); diff --git a/tfjs-core/src/public/chained_ops/elu.ts b/tfjs-core/src/public/chained_ops/elu.ts new file mode 100644 index 00000000000..923348819ed --- /dev/null +++ b/tfjs-core/src/public/chained_ops/elu.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {elu} from '../../ops/elu'; +import {Tensor} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + elu(): T; + } +} + +Tensor.prototype.elu = function(this: T): T { + this.throwIfDisposed(); + return elu(this); +}; diff --git a/tfjs-core/src/public/chained_ops/leaky_relu.ts b/tfjs-core/src/public/chained_ops/leaky_relu.ts new file mode 100644 index 00000000000..9c9e8c588ab --- /dev/null +++ b/tfjs-core/src/public/chained_ops/leaky_relu.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import {leakyRelu} from '../../ops/leaky_relu';
+import {Tensor} from '../../tensor';
+import {Rank} from '../../types';
+
+declare module '../../tensor' {
+  interface Tensor<R extends Rank = Rank> {
+    leakyRelu<T extends Tensor>(alpha: number): T;
+  }
+}
+
+Tensor.prototype.leakyRelu = function<T extends Tensor>(
+    this: T, alpha: number): T {
+  this.throwIfDisposed();
+  return leakyRelu(this, alpha);
+};
diff --git a/tfjs-core/src/public/chained_ops/prelu.ts b/tfjs-core/src/public/chained_ops/prelu.ts
new file mode 100644
index 00000000000..1bd48a756cf
--- /dev/null
+++ b/tfjs-core/src/public/chained_ops/prelu.ts
@@ -0,0 +1,31 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import {prelu} from '../../ops/prelu';
+import {Tensor} from '../../tensor';
+import {Rank, TensorLike} from '../../types';
+
+declare module '../../tensor' {
+  interface Tensor<R extends Rank = Rank> {
+    prelu<T extends Tensor>(alpha: T|TensorLike): T;
+  }
+}
+
+Tensor.prototype.prelu = function<T extends Tensor>(
+    this: T, alpha: T|TensorLike): T {
+  this.throwIfDisposed();
+  return prelu(this, alpha);
+};
diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
index d762c49def9..c201ebfd9e4 100644
--- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
+++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
@@ -31,10 +31,12 @@ import './depthwise_conv2D_deprecated';
 import './div';
 import './div_no_nan';
 import './dot';
+import './elu';
 import './equal';
 import './floorDiv';
 import './greater';
 import './greater_equal';
+import './leaky_relu';
 import './less';
 import './less_equal';
 import './local_response_normalization';
@@ -50,7 +52,10 @@ import './one_hot';
 import './pad';
 import './pool';
 import './pow';
+import './prelu';
 import './relu';
+import './relu6';
+import './selu';
 import './separable_conv2d';
 import './split';
 import './squared_difference';
diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
index 41cb1556084..07f15d6bcd1 100644
--- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
+++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
@@ -41,10 +41,12 @@ const CHAINED_OPS = [
  'div',
  'divNoNan',
  'dot',
+  'elu',
  'equal',
  'floorDiv',
  'greater',
  'greaterEqual',
+  'leakyRelu',
  'less',
  'lessEqual',
  'localResponseNormalization',
@@ -60,7 +62,10 @@
  'pad',
  'pool',
  'pow',
+  'prelu',
  'relu',
+  'relu6',
+  'selu',
  'separableConv2d',
  'spaceToBatchND',
  'split',
diff --git a/tfjs-core/src/public/chained_ops/relu6.ts b/tfjs-core/src/public/chained_ops/relu6.ts
new file mode 100644
index 00000000000..f2977eb5985
--- /dev/null
+++ b/tfjs-core/src/public/chained_ops/relu6.ts
@@ -0,0 +1,30 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import {relu6} from '../../ops/relu6';
+import {Tensor} from '../../tensor';
+import {Rank} from '../../types';
+
+declare module '../../tensor' {
+  interface Tensor<R extends Rank = Rank> {
+    relu6<T extends Tensor>(): T;
+  }
+}
+
+Tensor.prototype.relu6 = function<T extends Tensor>(this: T): T {
+  this.throwIfDisposed();
+  return relu6(this);
+};
diff --git a/tfjs-core/src/public/chained_ops/selu.ts b/tfjs-core/src/public/chained_ops/selu.ts
new file mode 100644
index 00000000000..334e1adaa90
--- /dev/null
+++ b/tfjs-core/src/public/chained_ops/selu.ts
@@ -0,0 +1,30 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import {selu} from '../../ops/selu';
+import {Tensor} from '../../tensor';
+import {Rank} from '../../types';
+
+declare module '../../tensor' {
+  interface Tensor<R extends Rank = Rank> {
+    selu<T extends Tensor>(): T;
+  }
+}
+
+Tensor.prototype.selu = function<T extends Tensor>(this: T): T {
+  this.throwIfDisposed();
+  return selu(this);
+};
diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts
index e9367fda611..0e32c60a3b9 100644
--- a/tfjs-core/src/register_all_gradients.ts
+++ b/tfjs-core/src/register_all_gradients.ts
@@ -29,6 +29,7 @@ import {conv3DGradConfig} from './gradients/Conv3D_grad';
 import {cumsumGradConfig} from './gradients/Cumsum_grad';
 import {depthwiseConv2dNativeGradConfig} from './gradients/DepthwiseConv2dNative_grad';
 import {divGradConfig} from './gradients/Div_grad';
+import {eluGradConfig} from './gradients/Elu_grad';
 import {floorDivGradConfig} from './gradients/FloorDiv_grad';
 import {fusedBatchNormGradConfig} from './gradients/FusedBatchNorm_grad';
 import {greaterEqualGradConfig} from './gradients/GreaterEqual_grad';
@@ -44,7 +45,10 @@ import {multiplyGradConfig} from './gradients/Multiply_grad';
 import {oneHotGradConfig} from './gradients/OneHot_grad';
 import {padV2GradConfig} from './gradients/PadV2_grad';
 import {powGradConfig} from './gradients/Pow_grad';
+import {preluGradConfig} from './gradients/Prelu_grad';
+import {relu6GradConfig} from './gradients/Relu6_grad';
 import {reluGradConfig} from './gradients/Relu_grad';
+import {seluGradConfig} from './gradients/Selu_grad';
 import {spaceToBatchNDGradConfig} from './gradients/SpaceToBatchND_grad';
 import {splitVGradConfig} from './gradients/SplitV_grad';
 import {squareGradConfig} from './gradients/Square_grad';
@@ -72,6 +76,7 @@ const gradConfigs: GradConfig[] = [
  cumsumGradConfig,
  depthwiseConv2dNativeGradConfig,
  divGradConfig,
+  eluGradConfig,
  floorDivGradConfig,
  fusedBatchNormGradConfig,
  greaterEqualGradConfig,
@@ -92,7 +97,10 @@
  oneHotGradConfig,
  padV2GradConfig,
  powGradConfig,
+  preluGradConfig,
  reluGradConfig,
+  relu6GradConfig,
+  seluGradConfig,
  spaceToBatchNDGradConfig,
  splitVGradConfig,
  squareGradConfig,
diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts
index 64da828b9fb..16bf5dbeb6f 100644
--- a/tfjs-core/src/tensor.ts
+++ b/tfjs-core/src/tensor.ts
@@ -258,11 +258,6 @@ export interface OpHandler {
  atanh<T extends Tensor>(x: T): T;
  erf<T extends Tensor>(x: T): T;
  step<T extends Tensor>(x: T, alpha: number): T;
-  relu6<T extends Tensor>(x: T): T;
-  elu<T extends Tensor>(x: T): T;
-  selu<T extends Tensor>(x: T): T;
-  leakyRelu<T extends Tensor>(x: T, alpha: number): T;
-  prelu<T extends Tensor>(x: T, alpha: T|TensorLike): T;
  softmax<T extends Tensor>(logits: T, dim: number): T;
  logSoftmax<T extends Tensor>(logits: T, axis: number): T;
  image: {
@@ -950,26 +945,6 @@ export class Tensor {
    this.throwIfDisposed();
    return opHandler.clipByValue(this, min, max);
  }
-  relu6<T extends Tensor>(this: T): T {
-    this.throwIfDisposed();
-    return opHandler.relu6(this);
-  }
-  elu<T extends Tensor>(this: T): T {
-    this.throwIfDisposed();
-    return opHandler.elu(this);
-  }
-  selu<T extends Tensor>(this: T): T {
-    this.throwIfDisposed();
-    return opHandler.selu(this);
-  }
-  leakyRelu(alpha = 0.2): Tensor {
-    this.throwIfDisposed();
-    return opHandler.leakyRelu(this, alpha);
-  }
-  prelu(alpha: Tensor|TensorLike): Tensor {
-    this.throwIfDisposed();
-    return opHandler.prelu(this, alpha);
-  }
  sigmoid<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.sigmoid(this);
diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts
index 34926997604..47510b52c88 100644
--- a/tfjs-core/src/tests.ts
+++ b/tfjs-core/src/tests.ts
@@ -70,6 +70,7 @@ import './ops/depth_to_space_test';
 import './ops/diag_test';
 import './ops/dropout_test';
 import './ops/dropout_util_test';
+import './ops/elu_test';
 import './ops/equal_test';
 import './ops/eye_test';
 import './ops/fused_test';
@@ -78,6 +79,7 @@ import './ops/greater_equal_test';
 import './ops/greater_test';
 import './ops/image_ops_test';
 import './ops/in_top_k_test';
+import './ops/leaky_relu_test';
 import './ops/less_equal_test';
 import './ops/less_test';
 import './ops/linalg_ops_test';
@@ -101,11 +103,14 @@ import './ops/random_gamma_test';
 import './ops/random_normal_test';
 import './ops/random_uniform_test';
 import './ops/reduction_ops_test';
+import './ops/relu6_test';
+import './ops/relu_test';
 import './ops/resize_bilinear_test';
 import './ops/resize_nearest_neighbor_test';
 import './ops/reverse_test';
 import './ops/scatter_nd_test';
 import './ops/segment_ops_test';
+import './ops/selu_test';
 import './ops/signal_ops_test';
 import './ops/slice_test';
 import './ops/slice_util_test';

From 6eb76124cb07c30cef0a547a5b825165fd34032d Mon Sep 17 00:00:00 2001
From: Na Li
Date: Wed, 10 Jun 2020 16:36:16 -0700
Subject: [PATCH 2/3] .

---
 tfjs-core/src/gradients/Elu_grad.ts     | 18 +++++++--
 tfjs-core/src/gradients/Relu_grad.ts    |  4 +-
 tfjs-core/src/kernel_names.ts           |  4 +-
 tfjs-core/src/ops/elu_backprop_input.ts | 53 -------
 4 files changed, 20 insertions(+), 59 deletions(-)
 delete mode 100644 tfjs-core/src/ops/elu_backprop_input.ts

diff --git a/tfjs-core/src/gradients/Elu_grad.ts b/tfjs-core/src/gradients/Elu_grad.ts
index 2957115a048..29f46bb1f28 100644
--- a/tfjs-core/src/gradients/Elu_grad.ts
+++ b/tfjs-core/src/gradients/Elu_grad.ts
@@ -14,16 +14,28 @@
 * limitations under the License.
 * =============================================================================
 */
-import {Elu} from '../kernel_names';
+import {ENGINE, ForwardFunc} from '../engine';
+import {Elu, EluGrad, EluGradInputs} from '../kernel_names';
 import {GradConfig} from '../kernel_registry';
-import {eluBackpropInput} from '../ops/elu_backprop_input';
 import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
 
 export const eluGradConfig: GradConfig = {
   kernelName: Elu,
   outputsToSave: [true],
   gradFunc: (dy: Tensor, saved: Tensor[]) => {
     const [y] = saved;
-    return {x: () => eluBackpropInput(dy, y)};
+
+    const backPropKernelFunc: ForwardFunc<Tensor> = (backend) => {
+      return backend.eluDer(dy, y);
+    };
+
+    const inputs: EluGradInputs = {dy, y};
+
+    return {
+      x: () => ENGINE.runKernelFunc(
+          backPropKernelFunc, inputs as {} as NamedTensorMap, null /* grad */,
+          EluGrad)
+    };
   }
 };
diff --git a/tfjs-core/src/gradients/Relu_grad.ts b/tfjs-core/src/gradients/Relu_grad.ts
index da2fa8c89ea..ec121f1dadd 100644
--- a/tfjs-core/src/gradients/Relu_grad.ts
+++ b/tfjs-core/src/gradients/Relu_grad.ts
@@ -16,7 +16,9 @@
 */
 import {Relu} from '../kernel_names';
 import {GradConfig} from '../kernel_registry';
+import {cast} from '../ops/array_ops';
 import {mul} from '../ops/mul';
+import {step} from '../ops/unary_ops';
 import {Tensor} from '../tensor';
 
 export const reluGradConfig: GradConfig = {
@@ -24,6 +26,6 @@ export const reluGradConfig: GradConfig = {
   inputsToSave: ['x'],
   gradFunc: (dy: Tensor, saved: Tensor[]) => {
     const [x] = saved;
-    return {x: () => mul(dy, x.step().toFloat())};
+    return {x: () => mul(dy, cast(step(x), 'float32'))};
   }
 };
diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts
index 192956144a4..d8139e7eed2 100644
--- a/tfjs-core/src/kernel_names.ts
+++ b/tfjs-core/src/kernel_names.ts
@@ -196,8 +196,8 @@ export type DivInputs = BinaryInputs;
 export const Elu = 'Elu';
 export type EluInputs = Pick<NamedTensorMap, 'x'>;
 
-export const EluBackpropInput = 'EluBackpropInput';
-export type EluBackpropInputInputs = Pick<NamedTensorMap, 'dy'|'y'>;
+export const EluGrad = 'EluGrad';
+export type EluGradInputs = Pick<NamedTensorMap, 'dy'|'y'>;
 
 export const Equal = 'Equal';
 export type EqualInputs = BinaryInputs;
diff --git a/tfjs-core/src/ops/elu_backprop_input.ts b/tfjs-core/src/ops/elu_backprop_input.ts
deleted file mode 100644
index bdbbfb5c924..00000000000
--- a/tfjs-core/src/ops/elu_backprop_input.ts
+++ /dev/null
@@ -1,53 +0,0 @@
-/**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
-
-import {ENGINE, ForwardFunc} from '../engine';
-import {EluBackpropInput, EluBackpropInputInputs} from '../kernel_names';
-import {Tensor} from '../tensor';
-import {NamedTensorMap} from '../tensor_types';
-
-import {op} from './operation';
-
-/**
- * Computes exponential linear element-wise: `x > 0 ? e ^ x - 1 : 0`.
- *
- * ```js
- * const x = tf.tensor1d([-1, 1, -3, 2]);
- *
- * x.elu().print();  // or tf.elu(x)
- * ```
- * @param x The input tensor.
- */
-/**
- * Computes the derivative of the input of a elu.
- *
- * @param dy The derivative of the output.
- * @param y The output of the forward function.
- */
-function eluBackpropInput_<T extends Tensor>(dy: T, y: T): T {
-  const forward: ForwardFunc<Tensor> = (backend) => {
-    return backend.eluDer(dy, y);
-  };
-
-  const inputs: EluBackpropInputInputs = {dy, y};
-
-  return ENGINE.runKernelFunc(
-             forward, inputs as {} as NamedTensorMap, null /* grad */,
-             EluBackpropInput) as T;
-}
-
-export const eluBackpropInput = op({eluBackpropInput_});

From 681a2f36f94d8a6dbe72b980fbedee9bdccd0702 Mon Sep 17 00:00:00 2001
From: Na Li
Date: Thu, 11 Jun 2020 15:07:26 -0700
Subject: [PATCH 3/3] .

---
 tfjs-core/src/gradients/Selu_grad.ts | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tfjs-core/src/gradients/Selu_grad.ts b/tfjs-core/src/gradients/Selu_grad.ts
index 91021e65e53..2faff420e7d 100644
--- a/tfjs-core/src/gradients/Selu_grad.ts
+++ b/tfjs-core/src/gradients/Selu_grad.ts
@@ -17,6 +17,7 @@
 import {Selu} from '../kernel_names';
 import {GradConfig} from '../kernel_registry';
 import {cast} from '../ops/array_ops';
+import {greater} from '../ops/greater';
 import {where} from '../ops/logical_ops';
 import {mul} from '../ops/mul';
 import {SELU_SCALE, SELU_SCALEALPHA} from '../ops/selu_util';
@@ -31,7 +32,7 @@ export const seluGradConfig: GradConfig = {
     const [x] = saved;
     return {
       x: () => {
-        const mask = x.greater(scalar(0));
+        const mask = greater(x, scalar(0));
         const scaleAlpha = scalar(SELU_SCALEALPHA);
         const scale = scalar(SELU_SCALE);
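
Reviewer aid, not part of the series: a minimal sketch of the public surface these three patches leave behind. It assumes a tfjs-core build that contains all three commits; the expected values in the comments follow the assertions in the tests moved above, and tf.fill is used only as a convenient way to build a broadcastable alpha for prelu.

import * as tf from '@tensorflow/tfjs-core';

// Chained variants are wired up by register_all_chained_ops.ts.
const x = tf.tensor1d([-2, -0.5, 0, 3, 8]);
x.relu6().print();                   // [0, 0, 0, 3, 6]
x.leakyRelu(0.2).print();            // [-0.4, -0.1, 0, 3, 8]
x.elu().print();                     // [-0.8647, -0.3935, 0, 3, 8]
x.selu().print();
x.prelu(tf.fill([5], 0.2)).print();  // [-0.4, -0.1, 0, 3, 8]

// Gradients resolve through the configs added in register_all_gradients.ts;
// relu6 passes gradient only where 0 < x < 6.
const dRelu6 = tf.grad((t: tf.Tensor) => tf.relu6(t));
dRelu6(x).print();                   // [0, 0, 0, 1, 0]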