import {Pow} from '../kernel_names';
import {GradConfig} from '../kernel_registry';
import {cast, reshape} from '../ops/array_ops';
import {mul} from '../ops/binary_ops';
import * as broadcast_util from '../ops/broadcast_util';
import {greater} from '../ops/greater';
import {where} from '../ops/logical_ops';
import {pow} from '../ops/pow';
import {sum} from '../ops/reduction_ops';
import {sub} from '../ops/sub';
import {scalar, zerosLike} from '../ops/tensor_ops';
import {log} from '../ops/unary_ops';
import {Tensor} from '../tensor';

/**
 * Gradient configuration for the `Pow` kernel (`y = base ^ exp`).
 *
 * Both inputs (`a` is the base, `b` is the exponent) and the forward output
 * are saved. `gradFunc` returns one lazily evaluated gradient per input; each
 * is summed over the axes along which that input was broadcast and reshaped
 * back to the input's own shape.
 */
export const powGradConfig: GradConfig = {
  kernelName: Pow,
  inputsToSave: ['a', 'b'],
  outputsToSave: [true],
  gradFunc: (dy: Tensor, saved: Tensor[]) => {
    const [base, exp, y] = saved;
    const broadcastShape =
        broadcast_util.assertAndGetBroadcastShape(base.shape, exp.shape);

    // Sums `grad` over the axes `target` was broadcast along (if any) and
    // gives the result `target`'s shape.
    const collapseTo = (grad: Tensor, target: Tensor) => {
      const axes =
          broadcast_util.getReductionAxes(target.shape, broadcastShape);
      const reduced = axes.length > 0 ? sum(grad, axes) : grad;
      return reshape(reduced, target.shape);
    };

    return {
      // d(base^exp)/d(base) = exp * base^(exp - 1)
      a: () => {
        const expAsFloat = cast(exp, 'float32');
        const expMinusOne = sub(expAsFloat, scalar(1));
        const localGrad = mul(expAsFloat, pow(base, expMinusOne));
        return collapseTo(mul(dy, localGrad), base);
      },
      // d(base^exp)/d(exp) = y * log(base). The log term is taken from
      // zerosLike(base) wherever `base > 0` does not hold.
      b: () => {
        const isPositive = greater(base, 0);
        const guardedLog = where(isPositive, log(base), zerosLike(base));
        return collapseTo(mul(dy, mul(y, guardedLog)), exp);
      }
    };
  }
};
- * - * ```js - * const a = tf.tensor([[2, 3], [4, 5]]) - * const b = tf.tensor([[1, 2], [3, 0]]).toInt(); - * - * a.pow(b).print(); // or tf.pow(a, b) - * ``` - * - * ```js - * const a = tf.tensor([[1, 2], [3, 4]]) - * const b = tf.tensor(2).toInt(); - * - * a.pow(b).print(); // or tf.pow(a, b) - * ``` - * We also expose `powStrict` which has the same signature as this op and - * asserts that `base` and `exp` are the same shape (does not broadcast). - * - * @param base The base `tf.Tensor` to pow element-wise. - * @param exp The exponent `tf.Tensor` to pow element-wise. - */ -/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ -function pow_( - base: Tensor|TensorLike, exp: Tensor|TensorLike): T { - let $base = convertToTensor(base, 'base', 'pow'); - let $exp = convertToTensor(exp, 'exp', 'pow'); - [$base, $exp] = makeTypesMatch($base, $exp); - - const outShape = - broadcast_util.assertAndGetBroadcastShape($base.shape, $exp.shape); - const grad = (dy: Tensor, saved: Tensor[]) => { - const [$base, $exp, y] = saved; - const derBase = () => { - const expFloat = $exp.toFloat(); - let res = dy.mul(expFloat.mul($base.pow(expFloat.sub(scalar(1))))); - const reduceAxes = broadcast_util.getReductionAxes($base.shape, outShape); - if (reduceAxes.length > 0) { - res = res.sum(reduceAxes); - } - return res.reshape($base.shape) as T; - }; - const derExp = () => { - const condition = $base.greater(0); - const logBase = $base.log().where(condition, zerosLike($base)); - let res = dy.mul(y.mul(logBase)); - const reduceAxes = broadcast_util.getReductionAxes($exp.shape, outShape); - if (reduceAxes.length > 0) { - res = res.sum(reduceAxes); - } - return res.reshape($exp.shape); - }; - return {a: derBase, b: derExp}; - }; - - const attrs = {}; - const inputsToSave = [$base, $exp]; - const outputsToSave = [true]; - return ENGINE.runKernelFunc((backend, save) => { - const y = backend.pow($base, $exp); - save([$base, $exp, y]); - return y; - }, {a: $base, b: $exp}, grad, 'Pow', 
attrs, inputsToSave, outputsToSave) as T; -} - /** * @deprecated * Computes the power of one `tf.Tensor` to another. Inputs must @@ -624,7 +554,6 @@ export const mod = op({mod_}); export const modStrict = op({modStrict_}); export const mul = op({mul_}); export const mulStrict = op({mulStrict_}); -export const pow = op({pow_}); export const powStrict = op({powStrict_}); export const squaredDifferenceStrict = op({squaredDifferenceStrict_}); export const subStrict = op({subStrict_}); diff --git a/tfjs-core/src/ops/moving_average.ts b/tfjs-core/src/ops/moving_average.ts index bc2730238ac..8cbe990ed36 100644 --- a/tfjs-core/src/ops/moving_average.ts +++ b/tfjs-core/src/ops/moving_average.ts @@ -20,8 +20,9 @@ import {assertTypesMatch} from '../tensor_util'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {pow} from './binary_ops'; + import {op} from './operation'; +import {pow} from './pow'; import {scalar} from './tensor_ops'; /** diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index b7c6c643bcf..37f3e61dd4e 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -66,6 +66,7 @@ export {pad2d} from './pad2d'; export {pad3d} from './pad3d'; export {pad4d} from './pad4d'; export {pool} from './pool'; +export {pow} from './pow'; export {rand} from './rand'; export {randomGamma} from './random_gamma'; export {randomNormal} from './random_normal'; diff --git a/tfjs-core/src/ops/pow.ts b/tfjs-core/src/ops/pow.ts new file mode 100644 index 00000000000..024525ce782 --- /dev/null +++ b/tfjs-core/src/ops/pow.ts @@ -0,0 +1,72 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
import {ENGINE, ForwardFunc} from '../engine';
import {Pow, PowInputs} from '../kernel_names';
import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {makeTypesMatch} from '../tensor_util';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';

import {op} from './operation';

/**
 * Computes the power of one `tf.Tensor` to another. Supports broadcasting.
 *
 * Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for
 * corresponding elements in x and y. The result's dtype will be the upcasted
 * type of the `base` and `exp` dtypes.
 *
 * ```js
 * const a = tf.tensor([[2, 3], [4, 5]]);
 * const b = tf.tensor([[1, 2], [3, 0]]).toInt();
 *
 * a.pow(b).print();  // or tf.pow(a, b)
 * ```
 *
 * ```js
 * const a = tf.tensor([[1, 2], [3, 4]]);
 * const b = tf.tensor(2).toInt();
 *
 * a.pow(b).print();  // or tf.pow(a, b)
 * ```
 * We also expose `powStrict` which has the same signature as this op and
 * asserts that `base` and `exp` are the same shape (does not broadcast).
 *
 * @param base The base `tf.Tensor` to pow element-wise.
 * @param exp The exponent `tf.Tensor` to pow element-wise.
 * @returns The result of dispatching the `Pow` kernel, cast to `T`.
 */
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function pow_<T extends Tensor>(
    base: Tensor|TensorLike, exp: Tensor|TensorLike): T {
  let $base = convertToTensor(base, 'base', 'pow');
  let $exp = convertToTensor(exp, 'exp', 'pow');
  // Reconcile the operand dtypes before dispatching.
  [$base, $exp] = makeTypesMatch($base, $exp);

  const inputs: PowInputs = {a: $base, b: $exp};

  // Forward pass: compute on the backend and save both operands plus the
  // output for the backward pass.
  const forward: ForwardFunc<Tensor> = (backend, save) => {
    const y = backend.pow($base, $exp);
    save([$base, $exp, y]);
    return y;
  };

  // No inline gradient is passed; the gradient for the `Pow` kernel is
  // declared separately (`powGradConfig`).
  return ENGINE.runKernelFunc(
             forward, inputs as {} as NamedTensorMap, null /* gradient */,
             Pow) as T;
}

export const pow = op({pow_});
import {pow} from '../../ops/pow';
import {Tensor} from '../../tensor';
import {Rank, TensorLike} from '../../types';

// Module augmentation: adds the chainable `pow` method to the `Tensor` type.
// The interface must repeat the class's type parameter list for the
// declarations to merge.
declare module '../../tensor' {
  interface Tensor<R extends Rank = Rank> {
    pow<T extends Tensor>(exp: Tensor|TensorLike): T;
  }
}

/**
 * Chainable form of `pow`: raises this tensor to `exp` element-wise.
 *
 * Calls `throwIfDisposed()` on the receiver first, then delegates to the
 * `pow` op with this tensor as the base.
 *
 * @param exp The exponent, as a tensor or tensor-like value.
 */
Tensor.prototype.pow = function<T extends Tensor>(exp: Tensor|TensorLike): T {
  this.throwIfDisposed();
  return pow(this, exp);
};
'./gradients/SpaceToBatchND_grad'; import {splitVGradConfig} from './gradients/SplitV_grad'; @@ -78,6 +79,7 @@ const gradConfigs: GradConfig[] = [ maxPool3DGradConfig, oneHotGradConfig, padV2GradConfig, + powGradConfig, reluGradConfig, spaceToBatchNDGradConfig, splitVGradConfig, diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index dc1299b677f..05a1a2e4c81 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -202,7 +202,6 @@ export interface OpHandler { addStrict(a: T, b: T|TensorLike): T; atan2(a: Tensor, b: Tensor|TensorLike): T; subStrict(a: T, b: T|TensorLike): T; - pow(base: T, exp: Tensor|TensorLike): T; powStrict(base: T, exp: Tensor|TensorLike): T; mul(a: Tensor, b: Tensor|TensorLike): T; mulStrict(a: T, b: T|TensorLike): T; @@ -776,10 +775,6 @@ export class Tensor { this.throwIfDisposed(); return opHandler.subStrict(this, x); } - pow(this: T, exp: Tensor|TensorLike): T { - this.throwIfDisposed(); - return opHandler.pow(this, exp); - } /** * @deprecated strict variants of ops have been deprecated */