diff --git a/tfjs-core/src/gradients/Concat_grad.ts b/tfjs-core/src/gradients/Concat_grad.ts new file mode 100644 index 00000000000..7c4d6963fd2 --- /dev/null +++ b/tfjs-core/src/gradients/Concat_grad.ts @@ -0,0 +1,34 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {Concat, ConcatAttrs} from '../kernel_names'; +import {GradConfig, NamedAttrMap} from '../kernel_registry'; +import {split} from '../ops/split'; +import {Tensor} from '../tensor'; +import {parseAxisParam} from '../util'; + +export const concatGradConfig: GradConfig = { + kernelName: Concat, + saveAllInputs: true, + gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { + const shapes = saved.map(t => t.shape); + const {axis} = attrs as {} as ConcatAttrs; + const $axis = parseAxisParam(axis, saved[0].shape)[0]; + const sizeSplits = shapes.map(s => s[$axis]); + const derTensors = split(dy, sizeSplits, $axis); + return derTensors.map(t => () => t) as {}; + } +}; diff --git a/tfjs-core/src/gradients/SplitV_grad.ts b/tfjs-core/src/gradients/SplitV_grad.ts new file mode 100644 index 00000000000..0fd58fe32e3 --- /dev/null +++ b/tfjs-core/src/gradients/SplitV_grad.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {SplitV, SplitVAttrs} from '../kernel_names'; +import {GradConfig, NamedAttrMap} from '../kernel_registry'; +import {concat} from '../ops/concat'; +import {Tensor} from '../tensor'; + +export const splitVGradConfig: GradConfig = { + kernelName: SplitV, + gradFunc: (dy: Tensor[], saved: Tensor[], attrs: NamedAttrMap) => { + const {axis} = attrs as {} as SplitVAttrs; + + return {x: () => concat(dy, axis)}; + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index b54d4ca91d7..656eca8b085 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -29,6 +29,19 @@ export type AddNInputs = TensorInfo[]; export type BinaryInputs = Pick; +export const BroadcastTo = 'BroadcastTo'; +export type BroadcastToInputs = Pick; +export interface BroadCastToAttrs { + shape: number[]; + inputShape: number[]; // for gradient +} + +export const Concat = 'Concat'; +export type ConcatInputs = TensorInfo[]; +export interface ConcatAttrs { + axis: number; +} + export const Div = 'Div'; export type DivInputs = BinaryInputs; @@ -39,24 +52,21 @@ export interface FusedBatchNormAttrs { varianceEpsilon: number; } -export const NotEqual = 'NotEqual'; -export type NotEqualInputs = BinaryInputs; - -export const SquaredDifference = 'SquaredDifference'; -export type SquaredDifferenceInputs = BinaryInputs; - -export const Square = 'Square'; -export type SquareInputs = Pick; - -export const Sub = 'Sub'; -export type SubInputs = BinaryInputs; +export const Identity = 'Identity'; +export type IdentityInputs = Pick; -export const Transpose = 'Transpose'; -export type TransposeInputs = Pick; -export interface TransposeAttrs { - perm: number[]; +export const MaxPoolWithArgmax = 'MaxPoolWithArgmax'; +export type MaxPoolWithArgmaxInputs = Pick; +export interface MaxPoolWithArgmaxAttrs { + filterSize: [number, number]|number; + strides: [number, number]|number; + pad: 'valid'|'same'|number; + includeBatchInIndex: boolean; } +export const NotEqual = 'NotEqual'; +export type NotEqualInputs = BinaryInputs; + export const NonMaxSuppressionV5 = 'NonMaxSuppressionV5'; export type NonMaxSuppressionV5Inputs = Pick; @@ -67,13 +77,6 @@ export interface NonMaxSuppressionV5Attrs { softNmsSigma: number; } -export const BroadcastTo = 'BroadcastTo'; -export type BroadcastToInputs = Pick; -export interface BroadCastToAttrs { - shape: number[]; - inputShape: number[]; // for gradient -} - export const OneHot = 'OneHot'; export type OneHotInputs = Pick; export interface OneHotAttrs { @@ -82,8 +85,28 @@ export interface OneHotAttrs { offValue: number; } -export const Identity = 'Identity'; -export type IdentityInputs = Pick; +export const PadV2 = 'PadV2'; +export type PadV2Inputs = Pick; +export interface PadV2Attrs { + paddings: Array<[number, number]>; + constantValue: number; +} + +export const SplitV = 'SplitV'; +export type SplitVInputs = Pick; +export interface SplitVAttrs { + numOrSizeSplits: number[]|number; + axis: number; +} + +export const SquaredDifference = 'SquaredDifference'; +export type SquaredDifferenceInputs = BinaryInputs; + +export const Square = 'Square'; +export type SquareInputs = Pick; + +export const Sub = 'Sub'; +export type SubInputs = BinaryInputs; export const Tile = 'Tile'; export type TileInputs = Pick; @@ -91,11 +114,10 @@ export interface TileAttrs { reps: number[]; } -export const PadV2 = 'PadV2'; -export type PadV2Inputs = Pick; -export interface PadV2Attrs { - paddings: Array<[number, number]>; - constantValue: number; +export const Transpose = 'Transpose'; +export type TransposeInputs = Pick; +export interface TransposeAttrs { + perm: number[]; } /** @@ -109,12 +131,3 @@ export interface FromPixelsInputs { export interface FromPixelsAttrs { numChannels: number; } - -export const MaxPoolWithArgmax = 'MaxPoolWithArgmax'; -export type MaxPoolWithArgmaxInputs = Pick; -export interface MaxPoolWithArgmaxAttrs { - filterSize: [number, number]|number; - strides: [number, number]|number; - pad: 'valid'|'same'|number; - includeBatchInIndex: boolean; -} diff --git a/tfjs-core/src/ops/add_n.ts b/tfjs-core/src/ops/add_n.ts index 621abd06e34..0ae7b52c8f3 100644 --- a/tfjs-core/src/ops/add_n.ts +++ b/tfjs-core/src/ops/add_n.ts @@ -64,8 +64,11 @@ function addN_(tensors: Array): T { } }); - const forward: ForwardFunc = (backend, save) => - backend.addN($tensors); + const forward: ForwardFunc = (backend, save) => { + const res = backend.addN($tensors); + save($tensors); + return res; + }; const inputs: AddNInputs = $tensors; diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts index cd35b7ab58c..7164a8f4f00 100644 --- a/tfjs-core/src/ops/array_ops.ts +++ b/tfjs-core/src/ops/array_ops.ts @@ -21,7 +21,7 @@ import {convertToTensor, convertToTensorArray} from '../tensor_util_env'; import {DataType, DataTypeMap, Rank, ShapeMap, TensorLike, TensorLike4D} from '../types'; import * as util from '../util'; import {getAxesPermutation, getInnerMostAxes} from './axis_util'; -import {concat} from './concat_split'; +import {concat} from './concat'; import {op} from './operation'; /** diff --git a/tfjs-core/src/ops/concat.ts b/tfjs-core/src/ops/concat.ts new file mode 100644 index 00000000000..98b9056bfe5 --- /dev/null +++ b/tfjs-core/src/ops/concat.ts @@ -0,0 +1,111 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ENGINE, ForwardFunc} from '../engine'; +import {Concat, ConcatAttrs, ConcatInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensorArray} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import {assert, parseAxisParam, sizeFromShape} from '../util'; + +import {assertParamsConsistent, computeOutShape} from './concat_util'; +import {op} from './operation'; +import {tensor} from './tensor_ops'; + +/** + * Concatenates a list of `tf.Tensor`s along a given axis. + * + * The tensors ranks and types must match, and their sizes must match in all + * dimensions except `axis`. + * + * Also available are stricter rank-specific methods that assert that + * `tensors` are of the given rank: + * - `tf.concat1d` + * - `tf.concat2d` + * - `tf.concat3d` + * - `tf.concat4d` + * + * Except `tf.concat1d` (which does not have axis param), all methods have + * same signature as this method. + * + * ```js + * const a = tf.tensor1d([1, 2]); + * const b = tf.tensor1d([3, 4]); + * a.concat(b).print(); // or a.concat(b) + * ``` + * + * ```js + * const a = tf.tensor1d([1, 2]); + * const b = tf.tensor1d([3, 4]); + * const c = tf.tensor1d([5, 6]); + * tf.concat([a, b, c]).print(); + * ``` + * + * ```js + * const a = tf.tensor2d([[1, 2], [10, 20]]); + * const b = tf.tensor2d([[3, 4], [30, 40]]); + * const axis = 1; + * tf.concat([a, b], axis).print(); + * ``` + * @param tensors A list of tensors to concatenate. + * @param axis The axis to concate along. Defaults to 0 (the first dim). + */ +/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ +function concat_(tensors: Array, axis = 0): T { + assert(tensors.length >= 1, () => 'Pass at least one tensor to concat'); + + let $tensors = convertToTensorArray(tensors, 'tensors', 'concat'); + if ($tensors[0].dtype === 'complex64') { + $tensors.forEach(tensor => { + if (tensor.dtype !== 'complex64') { + throw new Error(`Cannot concatenate complex64 tensors with a tensor + with dtype ${tensor.dtype}. `); + } + }); + } + + const $axis = parseAxisParam(axis, $tensors[0].shape)[0]; + const outShape = computeOutShape($tensors.map(t => t.shape), $axis); + if (sizeFromShape(outShape) === 0) { + return tensor([], outShape) as T; + } + // Keep only non-empty tensors (ignore tensors with 0 in their shape). + $tensors = $tensors.filter(t => t.size > 0); + if ($tensors.length === 1) { + return $tensors[0]; + } + + const shapes = $tensors.map(t => t.shape); + assertParamsConsistent(shapes, $axis); + + const forward: ForwardFunc = (backend, save) => { + const $axis = parseAxisParam(axis, $tensors[0].shape)[0]; + const res = backend.concat($tensors, $axis); + save($tensors); + return res; + }; + + const inputs: ConcatInputs = $tensors; + const attr: ConcatAttrs = {axis}; + + return ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* grad */, Concat, + attr as {} as NamedAttrMap) as T; +} + +export const concat = op({concat_}); diff --git a/tfjs-core/src/ops/concat_1d.ts b/tfjs-core/src/ops/concat_1d.ts new file mode 100644 index 00000000000..672f85a7d84 --- /dev/null +++ b/tfjs-core/src/ops/concat_1d.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {Tensor1D} from '../tensor'; +import {TensorLike} from '../types'; + +import {concat} from './concat'; +import {op} from './operation'; + +/** + * Concatenates a list of`tf.Tensor1D`s along an axis. See `concat` for details. + * + * For example, if: + * A: shape(3) = |r1, g1, b1| + * B: shape(2) = |r2, g2| + * C = tf.concat1d([A, B]) == |r1, g1, b1, r2, g2| + * + * @param tensors A list of`tf.Tensor`s to concatenate. + * @return The concatenated array. + */ +function concat1d_(tensors: Array): Tensor1D { + return concat(tensors, 0 /* axis */); +} + +export const concat1d = op({concat1d_}); diff --git a/tfjs-core/src/ops/concat_2d.ts b/tfjs-core/src/ops/concat_2d.ts new file mode 100644 index 00000000000..881226613dd --- /dev/null +++ b/tfjs-core/src/ops/concat_2d.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {Tensor2D} from '../tensor'; +import {TensorLike} from '../types'; + +import {concat} from './concat'; +import {op} from './operation'; + +/** + * Concatenates a list of`tf.Tensor2D`s along an axis. See `concat` for details. + * + * For example, if: + * A: shape(2, 3) = | r1, g1, b1 | + * | r2, g2, b2 | + * + * B: shape(2, 3) = | r3, g3, b3 | + * | r4, g4, b4 | + * + * C = tf.concat2d([A, B], axis) + * + * if axis = 0: + * C: shape(4, 3) = | r1, g1, b1 | + * | r2, g2, b2 | + * | r3, g3, b3 | + * | r4, g4, b4 | + * + * if axis = 1: + * C = shape(2, 6) = | r1, g1, b1, r3, g3, b3 | + * | r2, g2, b2, r4, g4, b4 | + * + * + * @param tensors A list of `tf.Tensor`s to concatenate. + * @param axis The axis to concatenate along. + * @return The concatenated array. + */ +function concat2d_( + tensors: Array, axis: number): Tensor2D { + return concat(tensors, axis); +} + +export const concat2d = op({concat2d_}); diff --git a/tfjs-core/src/ops/concat_3d.ts b/tfjs-core/src/ops/concat_3d.ts new file mode 100644 index 00000000000..9ed7de5cfeb --- /dev/null +++ b/tfjs-core/src/ops/concat_3d.ts @@ -0,0 +1,59 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {Tensor3D} from '../tensor'; +import {TensorLike} from '../types'; + +import {concat} from './concat'; +import {op} from './operation'; + +/** + * Concatenates a list of `tf.Tensor3D`s along an axis. + * See `concat` for details. + * + * For example, if: + * A: shape(2, 1, 3) = | r1, g1, b1 | + * | r2, g2, b2 | + * + * B: shape(2, 1, 3) = | r3, g3, b3 | + * | r4, g4, b4 | + * + * C = tf.concat3d([A, B], axis) + * + * if axis = 0: + * C: shape(4, 1, 3) = | r1, g1, b1 | + * | r2, g2, b2 | + * | r3, g3, b3 | + * | r4, g4, b4 | + * + * if axis = 1: + * C: shape(2, 2, 3) = | r1, g1, b1, r3, g3, b3 | + * | r2, g2, b2, r4, g4, b4 | + * + * if axis = 2: + * C = shape(2, 1, 6) = | r1, g1, b1, r3, g3, b3 | + * | r2, g2, b2, r4, g4, b4 | + * + * @param tensors A list of`tf.Tensor`s to concatenate. + * @param axis The axis to concate along. + * @return The concatenated array. + */ +function concat3d_( + tensors: Array, axis: number): Tensor3D { + return concat(tensors, axis); +} + +export const concat3d = op({concat3d_}); diff --git a/tfjs-core/src/ops/concat_4d.ts b/tfjs-core/src/ops/concat_4d.ts new file mode 100644 index 00000000000..78df8db29a5 --- /dev/null +++ b/tfjs-core/src/ops/concat_4d.ts @@ -0,0 +1,36 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {Tensor4D} from '../tensor'; +import {TensorLike} from '../types'; + +import {concat} from './concat'; +import {op} from './operation'; + +/** + * Concatenates a list of `tf.Tensor4D`s along an axis. + * See `concat` for details. + * + * @param tensors A list of `tf.Tensor`s to concatenate. + * @param axis The axis to concate along. + * @return The concatenated array. + */ +function concat4d_( + tensors: Array, axis: number): Tensor4D { + return concat(tensors, axis); +} + +export const concat4d = op({concat4d_}); diff --git a/tfjs-core/src/ops/concat_split.ts b/tfjs-core/src/ops/concat_split.ts deleted file mode 100644 index 92909ed19eb..00000000000 --- a/tfjs-core/src/ops/concat_split.ts +++ /dev/null @@ -1,261 +0,0 @@ -/** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {ENGINE} from '../engine'; -import {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; -import {convertToTensor, convertToTensorArray} from '../tensor_util_env'; -import {TensorLike} from '../types'; -import {assert, sizeFromShape} from '../util'; -import {parseAxisParam} from '../util'; -import {assertParamsConsistent, computeOutShape} from './concat_util'; -import {op} from './operation'; -import {tensor} from './tensor_ops'; - -/** - * Concatenates a list of`tf.Tensor1D`s along an axis. See `concat` for details. - * - * For example, if: - * A: shape(3) = |r1, g1, b1| - * B: shape(2) = |r2, g2| - * C = tf.concat1d([A, B]) == |r1, g1, b1, r2, g2| - * - * @param tensors A list of`tf.Tensor`s to concatenate. - * @return The concatenated array. - */ -function concat1d_(tensors: Array): Tensor1D { - return concat(tensors, 0 /* axis */); -} - -/** - * Concatenates a list of`tf.Tensor2D`s along an axis. See `concat` for details. - * - * For example, if: - * A: shape(2, 3) = | r1, g1, b1 | - * | r2, g2, b2 | - * - * B: shape(2, 3) = | r3, g3, b3 | - * | r4, g4, b4 | - * - * C = tf.concat2d([A, B], axis) - * - * if axis = 0: - * C: shape(4, 3) = | r1, g1, b1 | - * | r2, g2, b2 | - * | r3, g3, b3 | - * | r4, g4, b4 | - * - * if axis = 1: - * C = shape(2, 6) = | r1, g1, b1, r3, g3, b3 | - * | r2, g2, b2, r4, g4, b4 | - * - * - * @param tensors A list of `tf.Tensor`s to concatenate. - * @param axis The axis to concatenate along. - * @return The concatenated array. - */ -function concat2d_( - tensors: Array, axis: number): Tensor2D { - return concat(tensors, axis); -} - -/** - * Concatenates a list of `tf.Tensor3D`s along an axis. - * See `concat` for details. - * - * For example, if: - * A: shape(2, 1, 3) = | r1, g1, b1 | - * | r2, g2, b2 | - * - * B: shape(2, 1, 3) = | r3, g3, b3 | - * | r4, g4, b4 | - * - * C = tf.concat3d([A, B], axis) - * - * if axis = 0: - * C: shape(4, 1, 3) = | r1, g1, b1 | - * | r2, g2, b2 | - * | r3, g3, b3 | - * | r4, g4, b4 | - * - * if axis = 1: - * C: shape(2, 2, 3) = | r1, g1, b1, r3, g3, b3 | - * | r2, g2, b2, r4, g4, b4 | - * - * if axis = 2: - * C = shape(2, 1, 6) = | r1, g1, b1, r3, g3, b3 | - * | r2, g2, b2, r4, g4, b4 | - * - * @param tensors A list of`tf.Tensor`s to concatenate. - * @param axis The axis to concate along. - * @return The concatenated array. - */ -function concat3d_( - tensors: Array, axis: number): Tensor3D { - return concat(tensors, axis); -} - -/** - * Concatenates a list of `tf.Tensor4D`s along an axis. - * See `concat` for details. - * - * @param tensors A list of `tf.Tensor`s to concatenate. - * @param axis The axis to concate along. - * @return The concatenated array. - */ -function concat4d_( - tensors: Array, axis: number): Tensor4D { - return concat(tensors, axis); -} - -/** - * Concatenates a list of `tf.Tensor`s along a given axis. - * - * The tensors ranks and types must match, and their sizes must match in all - * dimensions except `axis`. - * - * Also available are stricter rank-specific methods that assert that - * `tensors` are of the given rank: - * - `tf.concat1d` - * - `tf.concat2d` - * - `tf.concat3d` - * - `tf.concat4d` - * - * Except `tf.concat1d` (which does not have axis param), all methods have - * same signature as this method. - * - * ```js - * const a = tf.tensor1d([1, 2]); - * const b = tf.tensor1d([3, 4]); - * a.concat(b).print(); // or a.concat(b) - * ``` - * - * ```js - * const a = tf.tensor1d([1, 2]); - * const b = tf.tensor1d([3, 4]); - * const c = tf.tensor1d([5, 6]); - * tf.concat([a, b, c]).print(); - * ``` - * - * ```js - * const a = tf.tensor2d([[1, 2], [10, 20]]); - * const b = tf.tensor2d([[3, 4], [30, 40]]); - * const axis = 1; - * tf.concat([a, b], axis).print(); - * ``` - * @param tensors A list of tensors to concatenate. - * @param axis The axis to concate along. Defaults to 0 (the first dim). - */ -/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ -function concat_(tensors: Array, axis = 0): T { - assert(tensors.length >= 1, () => 'Pass at least one tensor to concat'); - let $tensors = convertToTensorArray(tensors, 'tensors', 'concat'); - if ($tensors[0].dtype === 'complex64') { - $tensors.forEach(tensor => { - if (tensor.dtype !== 'complex64') { - throw new Error(`Cannot concatenate complex64 tensors with a tensor - with dtype ${tensor.dtype}. `); - } - }); - } - - axis = parseAxisParam(axis, $tensors[0].shape)[0]; - const outShape = computeOutShape($tensors.map(t => t.shape), axis); - if (sizeFromShape(outShape) === 0) { - return tensor([], outShape) as T; - } - // Keep only non-empty tensors (ignore tensors with 0 in their shape). - $tensors = $tensors.filter(t => t.size > 0); - if ($tensors.length === 1) { - return $tensors[0]; - } - - const shapes = $tensors.map(t => t.shape); - assertParamsConsistent(shapes, axis); - const der = (dy: T) => { - const sizeSplits = shapes.map(s => s[axis]); - const derTensors = split(dy, sizeSplits, axis); - return derTensors.map(t => () => t) as {}; - }; - const inputs = $tensors as {}; - const attr = {axis}; - return ENGINE.runKernelFunc( - backend => backend.concat($tensors, axis) as T, inputs, der, 'Concat', - attr); -} - -/** - * Splits a `tf.Tensor` into sub tensors. - * - * If `numOrSizeSplits` is a number, splits `x` along dimension `axis` - * into `numOrSizeSplits` smaller tensors. - * Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`. - * - * If `numOrSizeSplits` is a number array, splits `x` into - * `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the - * same size as `x` except along dimension `axis` where the size is - * `numOrSizeSplits[i]`. - * - * ```js - * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); - * const [a, b] = tf.split(x, 2, 1); - * a.print(); - * b.print(); - * - * const [c, d, e] = tf.split(x, [1, 2, 1], 1); - * c.print(); - * d.print(); - * e.print(); - * ``` - * - * @param x The input tensor to split. - * @param numOrSizeSplits Either an integer indicating the number of - * splits along the axis or an array of integers containing the sizes of - * each output tensor along the axis. If a number then it must evenly divide - * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`. - * @param axis The dimension along which to split. Defaults to 0 (the first - * dim). - */ -/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ -function split_( - x: T|TensorLike, numOrSizeSplits: number[]|number, axis = 0): T[] { - const $x = convertToTensor(x, 'x', 'split'); - - axis = parseAxisParam(axis, $x.shape)[0]; - let splitSizes: number[]; - if (typeof (numOrSizeSplits) === 'number') { - assert( - $x.shape[axis] % numOrSizeSplits === 0, - () => 'Number of splits must evenly divide the axis.'); - splitSizes = - new Array(numOrSizeSplits).fill($x.shape[axis] / numOrSizeSplits); - } else { - assert( - $x.shape[axis] === numOrSizeSplits.reduce((a, b) => a + b), - () => 'The sum of sizes must match the size of the axis dimension.'); - splitSizes = numOrSizeSplits; - } - const der = (dy: T[]) => ({$x: () => concat(dy, axis)}); - return ENGINE.runKernelFunc( - backend => backend.split($x, splitSizes, axis), {$x}, der); -} - -export const concat = op({concat_}); -export const concat1d = op({concat1d_}); -export const concat2d = op({concat2d_}); -export const concat3d = op({concat3d_}); -export const concat4d = op({concat4d_}); -export const split = op({split_}); diff --git a/tfjs-core/src/ops/linalg_ops.ts b/tfjs-core/src/ops/linalg_ops.ts index 722651d2ff9..cb00fa80657 100644 --- a/tfjs-core/src/ops/linalg_ops.ts +++ b/tfjs-core/src/ops/linalg_ops.ts @@ -27,12 +27,12 @@ import {TensorLike} from '../types'; import {assert} from '../util'; import {squeeze, stack, unstack} from './array_ops'; -import {split} from './concat_split'; import {eye} from './eye'; import {logicalAnd, where} from './logical_ops'; import {norm} from './norm'; import {op} from './operation'; import {sum} from './reduction_ops'; +import {split} from './split'; import {sub} from './sub'; import {range, scalar, tensor2d, zeros} from './tensor_ops'; diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 00237b3052d..4dd9d8b7115 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -24,6 +24,11 @@ export {batchNorm3d, batchNormalization3d} from './batchnorm3d'; export {batchNorm4d, batchNormalization4d} from './batchnorm4d'; export {broadcastTo} from './broadcast_to'; export {clone} from './clone'; +export {concat} from './concat'; +export {concat1d} from './concat_1d'; +export {concat2d} from './concat_2d'; +export {concat3d} from './concat_3d'; +export {concat4d} from './concat_4d'; export {div} from './div'; export {divNoNan} from './div_no_nan'; export {eye} from './eye'; @@ -39,6 +44,7 @@ export {rand} from './rand'; export {randomGamma} from './random_gamma'; export {randomNormal} from './random_normal'; export {randomUniform} from './random_uniform'; +export {split} from './split'; export {square} from './square'; export {squaredDifference} from './squared_difference'; export {sub} from './sub'; @@ -47,7 +53,6 @@ export {truncatedNormal} from './truncated_normal'; export * from './boolean_mask'; export * from './complex_ops'; -export * from './concat_split'; // Selectively exporting to avoid exposing gradient ops. export {conv1d, conv2d, conv3d, depthwiseConv2d, separableConv2d, conv2dTranspose, conv3dTranspose} from './conv'; export * from './matmul'; diff --git a/tfjs-core/src/ops/signal_ops.ts b/tfjs-core/src/ops/signal_ops.ts index 35a240a6f02..899bf8ecb59 100644 --- a/tfjs-core/src/ops/signal_ops.ts +++ b/tfjs-core/src/ops/signal_ops.ts @@ -19,7 +19,7 @@ import {op} from '../ops/operation'; import {Tensor, Tensor1D} from '../tensor'; import {mul} from './binary_ops'; -import {concat} from './concat_split'; +import {concat} from './concat'; import {slice} from './slice'; import {rfft} from './spectral_ops'; import {fill, tensor1d, tensor2d} from './tensor_ops'; @@ -88,9 +88,9 @@ function frame_( if (padEnd) { while (start < signal.size) { const padLen = (start + frameLength) - signal.size; - const pad = concat( - [slice(signal, start, frameLength - padLen), - fill([padLen], padValue)]); + const pad = concat([ + slice(signal, start, frameLength - padLen), fill([padLen], padValue) + ]); output.push(pad); start += frameStep; } @@ -131,8 +131,8 @@ function stft_( const windowedSignal = mul(framedSignal, windowFn(frameLength)); const output: Tensor[] = []; for (let i = 0; i < framedSignal.shape[0]; i++) { - output.push(rfft(windowedSignal.slice([i, 0], [1, frameLength]), - fftLength)); + output.push( + rfft(windowedSignal.slice([i, 0], [1, frameLength]), fftLength)); } return concat(output); } diff --git a/tfjs-core/src/ops/split.ts b/tfjs-core/src/ops/split.ts new file mode 100644 index 00000000000..5751f9499f7 --- /dev/null +++ b/tfjs-core/src/ops/split.ts @@ -0,0 +1,95 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ENGINE, ForwardFunc} from '../engine'; +import {SplitV, SplitVAttrs, SplitVInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import {assert,} from '../util'; +import {parseAxisParam} from '../util'; + +import {op} from './operation'; + +/** + * Splits a `tf.Tensor` into sub tensors. + * + * If `numOrSizeSplits` is a number, splits `x` along dimension `axis` + * into `numOrSizeSplits` smaller tensors. + * Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`. + * + * If `numOrSizeSplits` is a number array, splits `x` into + * `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the + * same size as `x` except along dimension `axis` where the size is + * `numOrSizeSplits[i]`. + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + * const [a, b] = tf.split(x, 2, 1); + * a.print(); + * b.print(); + * + * const [c, d, e] = tf.split(x, [1, 2, 1], 1); + * c.print(); + * d.print(); + * e.print(); + * ``` + * + * @param x The input tensor to split. + * @param numOrSizeSplits Either an integer indicating the number of + * splits along the axis or an array of integers containing the sizes of + * each output tensor along the axis. If a number then it must evenly divide + * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`. + * @param axis The dimension along which to split. Defaults to 0 (the first + * dim). + */ +/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ +function split_( + x: Tensor|TensorLike, numOrSizeSplits: number[]|number, axis = 0): T[] { + const $x = convertToTensor(x, 'x', 'split'); + + const $axis = parseAxisParam(axis, $x.shape)[0]; + let splitSizes: number[]; + + if (typeof (numOrSizeSplits) === 'number') { + assert( + $x.shape[$axis] % numOrSizeSplits === 0, + () => 'Number of splits must evenly divide the axis.'); + splitSizes = + new Array(numOrSizeSplits).fill($x.shape[$axis] / numOrSizeSplits); + } else { + assert( + $x.shape[$axis] === numOrSizeSplits.reduce((a, b) => a + b), + () => 'The sum of sizes must match the size of the axis dimension.'); + splitSizes = numOrSizeSplits; + } + + const forward: ForwardFunc = (backend, _) => { + const $axis = parseAxisParam(axis, $x.shape)[0]; + return backend.split($x, splitSizes, $axis) as {} as T; + }; + + const inputs: SplitVInputs = {x: $x}; + const attr: SplitVAttrs = {numOrSizeSplits, axis}; + + return ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* grad */, SplitV, + attr as {} as NamedAttrMap) as {} as T[]; +} + +export const split = op({split_}); diff --git a/tfjs-core/src/public/chained_ops/concat.ts b/tfjs-core/src/public/chained_ops/concat.ts new file mode 100644 index 00000000000..c8b80db6136 --- /dev/null +++ b/tfjs-core/src/public/chained_ops/concat.ts @@ -0,0 +1,34 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {concat} from '../../ops/concat'; +import {Tensor} from '../../tensor'; +import {Rank, TensorLike} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + concat(tensors: T|Array, axis?: number): T; + } +} + +Tensor.prototype.concat = function( + x: T|Array, axis?: number): T { + this.throwIfDisposed(); + if (x instanceof Tensor) { + x = [x]; + } + return concat([this, ...x], axis) as T; +}; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index 6381adc4449..2516bae57db 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -17,11 +17,13 @@ import './add'; import './batchnorm'; import './broadcast_to'; +import './concat'; import './div'; import './div_no_nan'; import './one_hot'; import './not_equal'; import './pad'; +import './split'; import './squared_difference'; import './sub'; import './tile'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index e9175206785..541100ca35c 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -24,8 +24,8 @@ import {ALL_ENVS, describeWithFlags} from '../../jasmine_util'; // flexibility to change in future. const CHAINED_OPS = [ - 'add', 'batchNorm', 'broadcastTo', 'div', 'divNoNan', 'oneHot', 'notEqual', - 'pad', 'square', 'sub', 'tile', 'transpose' + 'add', 'batchNorm', 'broadcastTo', 'concat', 'div', 'divNoNan', 'notEqual', + 'oneHot', 'pad', 'split', 'square', 'sub', 'tile', 'transpose' ]; describeWithFlags('chained ops', ALL_ENVS, () => { diff --git a/tfjs-core/src/public/chained_ops/split.ts b/tfjs-core/src/public/chained_ops/split.ts new file mode 100644 index 00000000000..4044308edec --- /dev/null +++ b/tfjs-core/src/public/chained_ops/split.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {split} from '../../ops/split'; +import {Tensor} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + split(numOrSizeSplits: number[]|number, axis?: number): + T[]; + } +} + +Tensor.prototype.split = function( + numOrSizeSplits: number[]|number, axis?: number): T[] { + this.throwIfDisposed(); + return split(this, numOrSizeSplits, axis); +}; diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index 9c440df575c..3339be99279 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -17,11 +17,13 @@ import {addGradConfig} from './gradients/Add_grad'; import {addNGradConfig} from './gradients/AddN_grad'; import {broadcastToGradConfig} from './gradients/BroadcastTo_grad'; +import {concatGradConfig} from './gradients/Concat_grad'; import {divGradConfig} from './gradients/Div_grad'; import {fusedBatchNormGradConfig} from './gradients/FusedBatchNorm_grad'; import {identityGradConfig} from './gradients/Identity_grad'; import {oneHotGradConfig} from './gradients/OneHot_grad'; import {padV2GradConfig} from './gradients/PadV2_grad'; +import {splitVGradConfig} from './gradients/SplitV_grad'; import {squareGradConfig} from './gradients/Square_grad'; import {squaredDifferenceGradConfig} from './gradients/SquaredDifference_grad'; import {subGradConfig} from './gradients/Sub_grad'; @@ -32,10 +34,11 @@ import {registerGradient} from './kernel_registry'; // Export all kernel configs here so that the package can auto register them const gradConfigs: GradConfig[] = [ - addGradConfig, addNGradConfig, broadcastToGradConfig, divGradConfig, - fusedBatchNormGradConfig, identityGradConfig, oneHotGradConfig, - padV2GradConfig, squareGradConfig, squaredDifferenceGradConfig, - tileGradConfig, transposeGradConfig, subGradConfig + addGradConfig, addNGradConfig, broadcastToGradConfig, concatGradConfig, + divGradConfig, fusedBatchNormGradConfig, identityGradConfig, oneHotGradConfig, + padV2GradConfig, splitVGradConfig, squareGradConfig, + squaredDifferenceGradConfig, tileGradConfig, transposeGradConfig, + subGradConfig ]; for (const gradientConfig of gradConfigs) { diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index 3759a3e8233..f5ece96217a 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -189,10 +189,7 @@ export interface OpHandler { keepDims: boolean): Tensor; slice>( x: T, begin: number|number[], size?: number|number[]): T; - split( - x: T, numOrSizeSplits: number[]|number, axis?: number): T[]; reverse(x: T, axis?: number|number[]): T; - concat(tensors: Array, axis: number): T; stack(tensors: Array, axis: number): Tensor; unstack(value: T, axis: number): Tensor[]; all(x: Tensor, axis: number|number[], keepDims: boolean): T; @@ -790,18 +787,6 @@ export class Tensor { this.throwIfDisposed(); return opHandler.reverse(this, axis); } - concat(this: T, x: T|Array, axis = 0): T { - this.throwIfDisposed(); - if (x instanceof Tensor) { - x = [x]; - } - return opHandler.concat([this, ...x], axis); - } - split(this: T, numOrSizeSplits: number[]|number, axis = 0): - T[] { - this.throwIfDisposed(); - return opHandler.split(this, numOrSizeSplits, axis); - } stack(x: Tensor, axis = 0): Tensor { return opHandler.stack([this, x], axis); }