From aeabb45922a500ec0496288061b7ff48cad92701 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Thu, 4 Jun 2020 10:22:53 -0400 Subject: [PATCH] modularize complex ops --- tfjs-core/src/io/io_utils.ts | 2 +- tfjs-core/src/kernel_names.ts | 9 +++ .../src/ops/{complex_ops.ts => complex.ts} | 55 ++++--------------- tfjs-core/src/ops/imag.ts | 51 +++++++++++++++++ tfjs-core/src/ops/ops.ts | 4 +- tfjs-core/src/ops/real.ts | 53 ++++++++++++++++++ tfjs-core/src/ops/spectral_ops.ts | 5 +- tfjs-core/src/ops/tensor_ops.ts | 6 +- 8 files changed, 136 insertions(+), 49 deletions(-) rename tfjs-core/src/ops/{complex_ops.ts => complex.ts} (56%) create mode 100644 tfjs-core/src/ops/imag.ts create mode 100644 tfjs-core/src/ops/real.ts diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index 49c590f1f23..1387152e2c0 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {complex} from '../ops/complex_ops'; +import {complex} from '../ops/complex'; import {tensor} from '../ops/tensor_ops'; import {NamedTensor, NamedTensorMap} from '../tensor_types'; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 4d92b597b71..30162115398 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -88,6 +88,9 @@ export interface BroadCastToAttrs { inputShape: number[]; // for gradient } +export const Complex = 'Complex'; +export type ComplexInputs = Pick; + export const Concat = 'Concat'; export type ConcatInputs = TensorInfo[]; export interface ConcatAttrs { @@ -206,6 +209,9 @@ export type GreaterEqualInputs = BinaryInputs; export const Identity = 'Identity'; export type IdentityInputs = Pick; +export const Imag = 'Imag'; +export type ImagInputs = Pick; + export const Less = 'Less'; export type LessInputs = BinaryInputs; @@ -330,6 +336,9 @@ export type PoolInputs = Pick; export const Pow = 'Pow'; export type PowInputs = BinaryInputs; +export const Real = 'Real'; +export type RealInputs = Pick; + export const Relu = 'Relu'; export type ReluInputs = Pick; diff --git a/tfjs-core/src/ops/complex_ops.ts b/tfjs-core/src/ops/complex.ts similarity index 56% rename from tfjs-core/src/ops/complex_ops.ts rename to tfjs-core/src/ops/complex.ts index 2b79863e180..d50fd10508c 100644 --- a/tfjs-core/src/ops/complex_ops.ts +++ b/tfjs-core/src/ops/complex.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google LLC. All Rights Reserved. + * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -14,11 +14,14 @@ * limitations under the License. * ============================================================================= */ -import {ENGINE} from '../engine'; +import {ENGINE, ForwardFunc} from '../engine'; +import {Complex, ComplexInputs} from '../kernel_names'; import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; + import {op} from './operation'; /** @@ -48,49 +51,13 @@ function complex_(real: T|TensorLike, imag: T|TensorLike): T { `real and imag shapes, ${$real.shape} and ${$imag.shape}, ` + `must match in call to tf.complex().`); + const forward: ForwardFunc = (backend) => { + return backend.complex($real, $imag); + }; + const inputs: ComplexInputs = {real: $real, imag: $imag}; return ENGINE.runKernelFunc( - backend => backend.complex($real, $imag), {$real, $imag}); -} - -/** - * Returns the real part of a complex (or real) tensor. - * - * Given a tensor input, this operation returns a tensor of type float that is - * the real part of each element in input considered as a complex number. - * - * If the input is real, it simply makes a clone. - * - * ```js - * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]); - * tf.real(x).print(); - * ``` - */ -/** @doc {heading: 'Tensors', subheading: 'Creation'} */ -function real_(input: T|TensorLike): T { - const $input = convertToTensor(input, 'input', 'real'); - - return ENGINE.runKernelFunc(backend => backend.real($input), {$input}); -} - -/** - * Returns the imaginary part of a complex (or real) tensor. - * - * Given a tensor input, this operation returns a tensor of type float that is - * the imaginary part of each element in input considered as a complex number. - * If input is real, a tensor of all zeros is returned. - * - * ```js - * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]); - * tf.imag(x).print(); - * ``` - */ -/** @doc {heading: 'Tensors', subheading: 'Creation'} */ -function imag_(input: T|TensorLike): T { - const $input = convertToTensor(input, 'input', 'imag'); - - return ENGINE.runKernelFunc(backend => backend.imag($input), {$input}); + forward, inputs as {} as NamedTensorMap, null /* gradient */, + Complex) as T; } export const complex = op({complex_}); -export const real = op({real_}); -export const imag = op({imag_}); diff --git a/tfjs-core/src/ops/imag.ts b/tfjs-core/src/ops/imag.ts new file mode 100644 index 00000000000..1b8e2006b03 --- /dev/null +++ b/tfjs-core/src/ops/imag.ts @@ -0,0 +1,51 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE, ForwardFunc} from '../engine'; +import {Imag, ImagInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import {op} from './operation'; +/** + * Returns the imaginary part of a complex (or real) tensor. + * + * Given a tensor input, this operation returns a tensor of type float that is + * the imaginary part of each element in input considered as a complex number. + * If input is real, a tensor of all zeros is returned. + * + * ```js + * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]); + * tf.imag(x).print(); + * ``` + */ +/** @doc {heading: 'Tensors', subheading: 'Creation'} */ +function imag_(input: T|TensorLike): T { + const $input = convertToTensor(input, 'input', 'imag'); + + const forward: ForwardFunc = (backend) => { + return backend.imag($input); + }; + + const inputs: ImagInputs = {input: $input}; + return ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* gradient */, + Imag) as T; +} + +export const imag = op({imag_}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 37f3e61dd4e..189fe76b270 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -27,6 +27,7 @@ export {batchNorm3d} from './batchnorm3d'; export {batchNorm4d} from './batchnorm4d'; export {broadcastTo} from './broadcast_to'; export {clone} from './clone'; +export {complex} from './complex'; export {concat} from './concat'; export {concat1d} from './concat_1d'; export {concat2d} from './concat_2d'; @@ -48,6 +49,7 @@ export {equal} from './equal'; export {eye} from './eye'; export {greater} from './greater'; export {greaterEqual} from './greater_equal'; +export {imag} from './imag'; export {less} from './less'; export {lessEqual} from './less_equal'; export {localResponseNormalization} from './local_response_normalization'; @@ -71,6 +73,7 @@ export {rand} from './rand'; export {randomGamma} from './random_gamma'; export {randomNormal} from './random_normal'; export {randomUniform} from './random_uniform'; +export {real} from './real'; export {relu} from './relu'; export {separableConv2d} from './separable_conv2d'; export {spaceToBatchND} from './space_to_batch_nd'; @@ -82,7 +85,6 @@ export {tile} from './tile'; export {truncatedNormal} from './truncated_normal'; export * from './boolean_mask'; -export * from './complex_ops'; export * from './reverse'; export * from './slice'; export * from './unary_ops'; diff --git a/tfjs-core/src/ops/real.ts b/tfjs-core/src/ops/real.ts new file mode 100644 index 00000000000..3f45dce0886 --- /dev/null +++ b/tfjs-core/src/ops/real.ts @@ -0,0 +1,53 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE, ForwardFunc} from '../engine'; +import {Real, RealInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import {op} from './operation'; + +/** + * Returns the real part of a complex (or real) tensor. + * + * Given a tensor input, this operation returns a tensor of type float that is + * the real part of each element in input considered as a complex number. + * + * If the input is real, it simply makes a clone. + * + * ```js + * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]); + * tf.real(x).print(); + * ``` + */ +/** @doc {heading: 'Tensors', subheading: 'Creation'} */ +function real_(input: T|TensorLike): T { + const $input = convertToTensor(input, 'input', 'real'); + + const forward: ForwardFunc = (backend) => { + return backend.real($input); + }; + + const inputs: RealInputs = {input: $input}; + return ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* gradient */, + Real) as T; +} + +export const real = op({real_}); diff --git a/tfjs-core/src/ops/spectral_ops.ts b/tfjs-core/src/ops/spectral_ops.ts index 2ec2b676240..0db449fd1b1 100644 --- a/tfjs-core/src/ops/spectral_ops.ts +++ b/tfjs-core/src/ops/spectral_ops.ts @@ -16,10 +16,13 @@ */ import {ENGINE} from '../engine'; -import {complex, imag, real} from '../ops/complex_ops'; +import {complex} from '../ops/complex'; +import {imag} from '../ops/imag'; import {op} from '../ops/operation'; +import {real} from '../ops/real'; import {Tensor, Tensor2D} from '../tensor'; import {assert} from '../util'; + import {scalar, zeros} from './tensor_ops'; /** diff --git a/tfjs-core/src/ops/tensor_ops.ts b/tfjs-core/src/ops/tensor_ops.ts index b3b9ad7ea6b..cb32bfe234c 100644 --- a/tfjs-core/src/ops/tensor_ops.ts +++ b/tfjs-core/src/ops/tensor_ops.ts @@ -17,14 +17,16 @@ import {ENGINE} from '../engine'; import {env} from '../environment'; - import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, Tensor6D, Variable} from '../tensor'; import {convertToTensor, inferShape} from '../tensor_util_env'; import {TensorLike, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D, TensorLike5D, TensorLike6D, TypedArray} from '../types'; import {DataType, Rank, ShapeMap} from '../types'; import {assert, assertNonNegativeIntegerDimensions, assertNonNull, flatten, inferDtype, isTypedArray, makeOnesTypedArray, makeZerosTypedArray, sizeFromShape, toTypedArray} from '../util'; -import {complex, imag, real} from './complex_ops'; + +import {complex} from './complex'; +import {imag} from './imag'; import {op} from './operation'; +import {real} from './real'; /** * Creates a `tf.Tensor` with the provided values, shape and dtype.