diff --git a/tfjs-backend-wasm/src/kernels/Split.ts b/tfjs-backend-wasm/src/kernels/Split.ts index e2e72150606..f1236fd4a6b 100644 --- a/tfjs-backend-wasm/src/kernels/Split.ts +++ b/tfjs-backend-wasm/src/kernels/Split.ts @@ -16,6 +16,7 @@ */ import {NamedAttrMap, NamedTensorInfoMap, registerKernel, SplitV, SplitVAttrs, SplitVInputs, util} from '@tensorflow/tfjs-core'; +import {backend_util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; @@ -32,14 +33,7 @@ export function split(args: { const $axis = util.parseAxisParam(axis, x.shape)[0]; - let splitSizes: number[]; - if (typeof (numOrSizeSplits) === 'number') { - splitSizes = - new Array(numOrSizeSplits).fill(x.shape[$axis] / numOrSizeSplits); - } else { - splitSizes = numOrSizeSplits; - } - + const splitSizes = backend_util.prepareSplitSize(x, numOrSizeSplits, axis); const begin = new Array(x.shape.length).fill(0); const size = x.shape.slice(); return splitSizes.map(s => { diff --git a/tfjs-converter/src/operations/executors/slice_join_executor.ts b/tfjs-converter/src/operations/executors/slice_join_executor.ts index f2da33ea870..e7a0965bd4c 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor.ts @@ -128,15 +128,6 @@ export const executeOp: InternalOpExecutor = (node: Node, number[]; const tensor = getParamValue('x', node, tensorMap, context) as tfc.Tensor; - // Allow the last number of split array to be -1, which indicates the rest - // of dimension is allocated to the last split. - if (Array.isArray(numOrSizeSplits)) { - const negIndex = numOrSizeSplits.indexOf(-1); - if (negIndex !== -1) { - const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a); - numOrSizeSplits[negIndex] = tensor.shape[axis] - total; - } - } return tfc.split(tensor, numOrSizeSplits, axis); } case 'ScatterNd': { diff --git a/tfjs-converter/src/operations/executors/slice_join_executor_test.ts b/tfjs-converter/src/operations/executors/slice_join_executor_test.ts index 1b6852c5412..d8698fc3187 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor_test.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor_test.ts @@ -21,7 +21,7 @@ import * as slice_join from '../op_list/slice_join'; import {Node} from '../types'; import {executeOp} from './slice_join_executor'; -import {createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttr, createNumericArrayAttrFromIndex, createTensorAttr, createTensorsAttr, validateParam} from './test_helper'; +import {createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, createTensorsAttr, validateParam} from './test_helper'; describe('slice join', () => { let node: Node; @@ -340,19 +340,6 @@ describe('slice join', () => { expect(tfc.split).toHaveBeenCalledWith(input2[0], 2, 1); }); - it('should support -1 split', () => { - spyOn(tfc, 'split'); - node.op = 'Split'; - node.inputParams.axis = createNumberAttrFromIndex(0); - node.inputParams.x = createTensorAttr(1); - node.attrParams.numOrSizeSplits = createNumericArrayAttr([1, 1, -1]); - node.inputNames = ['input4', 'input3']; - const input3 = [tfc.tensor1d([1, 2, 3, 4, 5])]; - const input4 = [tfc.scalar(0)]; - executeOp(node, {input3, input4}, context); - - expect(tfc.split).toHaveBeenCalledWith(input3[0], [1, 1, 3], 0); - }); it('should match json def for split', () => { node.op = 'Split'; node.inputParams.axis = createNumberAttrFromIndex(0); diff --git a/tfjs-core/src/backends/backend_util.ts b/tfjs-core/src/backends/backend_util.ts index 901d96fa2b5..c188bf9289a 100644 --- a/tfjs-core/src/backends/backend_util.ts +++ b/tfjs-core/src/backends/backend_util.ts @@ -42,6 +42,7 @@ export * from '../ops/fused_util'; export * from '../ops/erf_util'; export * from '../log'; export * from '../backends/complex_util'; +export * from '../ops/split_util'; import * as segment_util from '../ops/segment_util'; export {segment_util}; diff --git a/tfjs-core/src/ops/array_ops_test.ts b/tfjs-core/src/ops/array_ops_test.ts index c124102dd04..c7c3561fde1 100644 --- a/tfjs-core/src/ops/array_ops_test.ts +++ b/tfjs-core/src/ops/array_ops_test.ts @@ -1706,7 +1706,24 @@ describeWithFlags('split', ALL_ENVS, () => { expect(res[2].shape).toEqual([2, 1]); expectArraysClose(await res[2].data(), [4, 8]); }); + it('should support -1 split', async () => { + const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const res = x.split([1, 1, -1], 1); + expect(res.length).toEqual(3); + expect(res[0].shape).toEqual([2, 1]); + expectArraysClose(await res[0].data(), [1, 5]); + expect(res[1].shape).toEqual([2, 1]); + expectArraysClose(await res[1].data(), [2, 6]); + expect(res[2].shape).toEqual([2, 2]); + expectArraysClose(await res[2].data(), [3, 4, 7, 8]); + }); + + it('multiple negative number throws error', () => { + const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const f = () => tf.split(x, [1, -1, -1], 1); + expect(f).toThrowError(); + }); it('sizes to not sum to axis size throws error', () => { const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); const f = () => tf.split(x, [1, 2], 1); diff --git a/tfjs-core/src/ops/spectral_ops.ts b/tfjs-core/src/ops/spectral_ops.ts index 8cfd88fec69..839e7f49939 100644 --- a/tfjs-core/src/ops/spectral_ops.ts +++ b/tfjs-core/src/ops/spectral_ops.ts @@ -179,11 +179,10 @@ function rfft_(input: Tensor, fftLength?: number): Tensor { function irfft_(input: Tensor): Tensor { const innerDimensionSize = input.shape[input.shape.length - 1]; const batch = input.size / innerDimensionSize; - + let ret: Tensor; if (innerDimensionSize <= 2) { const complexInput = input.as2D(batch, innerDimensionSize); - const ret = ifft(complexInput); - return real(ret); + ret = ifft(complexInput); } else { // The length of unique components of the DFT of a real-valued signal // is 2 * (input_len - 1) @@ -201,9 +200,17 @@ function irfft_(input: Tensor): Tensor { const r = realInput.concat(realConjugate, 1); const i = imagInput.concat(imagConjugate, 1); const complexInput = complex(r, i).as2D(outputShape[0], outputShape[1]); - const ret = ifft(complexInput); - return real(ret); + ret = ifft(complexInput); + } + ret = real(ret); + // reshape the result if the input is 3D tensor. + if (input.rank === 3 && input.shape[0] !== 0) { + const temp = ret; + const batch = input.shape[0]; + ret = ret.reshape([batch, ret.shape[0] / batch, ret.shape[1]]); + temp.dispose(); } + return ret; } export const fft = op({fft_}); diff --git a/tfjs-core/src/ops/spectral_ops_test.ts b/tfjs-core/src/ops/spectral_ops_test.ts index fe5a7459dc6..97c354ef5d2 100644 --- a/tfjs-core/src/ops/spectral_ops_test.ts +++ b/tfjs-core/src/ops/spectral_ops_test.ts @@ -353,9 +353,11 @@ describeWithFlags('3D IRFFT', ALL_ENVS, () => { const t1Real = tf.tensor3d([1, 2, 3, 4, 1, 2, 3, 4], [2, 2, 2]); const t1Imag = tf.tensor3d([0, 0, 0, 0, 0, 0, 0, 0], [2, 2, 2]); const t1 = tf.complex(t1Real, t1Imag); + const result = tf.spectral.irfft(t1); + expect(result.shape).toEqual([2, 2, 2]); + expectArraysClose( - await tf.spectral.irfft(t1).data(), - [1.5, -0.5, 3.5, -0.5, 1.5, -0.5, 3.5, -0.5]); + await result.data(), [1.5, -0.5, 3.5, -0.5, 1.5, -0.5, 3.5, -0.5]); }); it('should return the same value with TensorFlow (2x2x3 elements)', @@ -365,7 +367,9 @@ describeWithFlags('3D IRFFT', ALL_ENVS, () => { const t1Imag = tf.tensor3d([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2, 2, 3]); const t1 = tf.complex(t1Real, t1Imag); - expectArraysClose(await tf.spectral.irfft(t1).data(), [ + const result = tf.spectral.irfft(t1); + expect(result.shape).toEqual([2, 2, 4]); + expectArraysClose(await result.data(), [ 2, -0.5, 0, -0.5, 5, -0.5, 0, -0.5, 2, -0.5, 0, -0.5, 5, -0.5, 0, -0.5 ]); }); @@ -377,8 +381,10 @@ describeWithFlags('3D IRFFT', ALL_ENVS, () => { const t1Imag = tf.tensor3d([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6], [2, 2, 3]); const t1 = tf.complex(t1Real, t1Imag); + const result = tf.spectral.irfft(t1); + expect(result.shape).toEqual([2, 2, 4]); expectArraysClose( - await tf.spectral.irfft(t1).data(), + await result.data(), [2, -1.5, 0, 0.5, 5, -3, 0, 2, 2, -1.5, 0, 0.5, 5, -3, 0, 2]); }); }); diff --git a/tfjs-core/src/ops/split.ts b/tfjs-core/src/ops/split.ts index a68d07481c3..6dc30b64a4f 100644 --- a/tfjs-core/src/ops/split.ts +++ b/tfjs-core/src/ops/split.ts @@ -21,10 +21,10 @@ import {Tensor} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; -import {assert,} from '../util'; import {parseAxisParam} from '../util'; import {op} from './operation'; +import {prepareSplitSize} from './split_util'; /** * Splits a `tf.Tensor` into sub tensors. @@ -55,6 +55,7 @@ import {op} from './operation'; * splits along the axis or an array of integers containing the sizes of * each output tensor along the axis. If a number then it must evenly divide * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`. + * Can contain one -1 indicating that dimension is to be inferred. * @param axis The dimension along which to split. Defaults to 0 (the first * dim). */ @@ -63,23 +64,9 @@ function split_( x: Tensor|TensorLike, numOrSizeSplits: number[]|number, axis = 0): T[] { const $x = convertToTensor(x, 'x', 'split'); - const $axis = parseAxisParam(axis, $x.shape)[0]; - let splitSizes: number[]; - - if (typeof (numOrSizeSplits) === 'number') { - assert( - $x.shape[$axis] % numOrSizeSplits === 0, - () => 'Number of splits must evenly divide the axis.'); - splitSizes = - new Array(numOrSizeSplits).fill($x.shape[$axis] / numOrSizeSplits); - } else { - assert( - $x.shape[$axis] === numOrSizeSplits.reduce((a, b) => a + b), - () => 'The sum of sizes must match the size of the axis dimension.'); - splitSizes = numOrSizeSplits; - } - const forward: ForwardFunc = (backend, _) => { + const $axis = parseAxisParam(axis, $x.shape)[0]; + const splitSizes = prepareSplitSize($x, numOrSizeSplits, $axis); return backend.split($x, splitSizes, $axis) as {} as T; }; diff --git a/tfjs-core/src/ops/split_util.ts b/tfjs-core/src/ops/split_util.ts new file mode 100644 index 00000000000..6e01ac3395b --- /dev/null +++ b/tfjs-core/src/ops/split_util.ts @@ -0,0 +1,60 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {TensorInfo} from '../kernel_registry'; +import {Tensor} from '../tensor'; +import {assert} from '../util'; + +/** + * Prepare the split size array. When the input is a number, the axis is evenly + * divided among the split size. When the input contains the negative value, the + * rest of the axis is allocated toward that. + */ +export function prepareSplitSize( + x: Tensor|TensorInfo, numOrSizeSplits: number[]|number, + axis = 0): number[] { + let splitSizes = []; + if (typeof (numOrSizeSplits) === 'number') { + assert( + x.shape[axis] % numOrSizeSplits === 0, + () => 'Number of splits must evenly divide the axis.'); + splitSizes = + new Array(numOrSizeSplits).fill(x.shape[axis] / numOrSizeSplits); + } else { + const numOfNegs = numOrSizeSplits.reduce((count, value) => { + if (value === -1) { + count += 1; + } + return count; + }, 0); + assert( + numOfNegs <= 1, + () => 'There should be only one negative value in split array.'); + const negIndex = numOrSizeSplits.indexOf(-1); + // Allow the number of split array to be -1, which indicates the rest + // of dimension is allocated to that split. + if (negIndex !== -1) { + const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a); + numOrSizeSplits[negIndex] = x.shape[axis] - total; + } + assert( + x.shape[axis] === numOrSizeSplits.reduce((a, b) => a + b), + () => 'The sum of sizes must match the size of the axis dimension.'); + splitSizes = numOrSizeSplits; + } + + return splitSizes; +} diff --git a/tfjs-core/src/ops/where.ts b/tfjs-core/src/ops/where.ts index 6723e50f8b7..03912334892 100644 --- a/tfjs-core/src/ops/where.ts +++ b/tfjs-core/src/ops/where.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {ENGINE} from '../engine'; +import {ENGINE, ForwardFunc} from '../engine'; import {SelectV2, SelectV2Inputs} from '../kernel_names'; import {Tensor} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; @@ -23,6 +23,7 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {assert, assertShapesMatch} from '../util'; +import {assertAndGetBroadcastShape} from './broadcast_util'; import {op} from './operation'; /** @@ -41,7 +42,10 @@ import {op} from './operation'; * @param condition The input condition. Must be of dtype bool. * @param a If `condition` is rank 1, `a` may have a higher rank but * its first dimension must match the size of `condition`. - * @param b A tensor with the same shape and type as `a`. + * @param b A tensor with the same dtype as `a` and with shape that is + * compatible with `a`. + * @return A tensor with same dtype as `a` and `b`, and shape that is + * broadcastable from `a` and `b`. */ /** @doc {heading: 'Operations', subheading: 'Logical'} */ function where_( @@ -49,26 +53,39 @@ function where_( const $a = convertToTensor(a, 'a', 'where'); const $b = convertToTensor(b, 'b', 'where'); const $condition = convertToTensor(condition, 'condition', 'where', 'bool'); - - assertShapesMatch($a.shape, $b.shape, 'Error in where: '); - + // TODO: move this logic to forward function when the broadcastTo op is + // implemented in WASM. + // Find the broadcastable shape for $a and $b. + const broadcastShape = assertAndGetBroadcastShape($a.shape, $b.shape); + const $broadcastedA = $a.broadcastTo(broadcastShape); + const $broadcastedB = $b.broadcastTo(broadcastShape); if ($condition.rank === 1) { // If condition rank is 1, then the first dimension must match the size of // condition. assert( $condition.shape[0] === $a.shape[0], () => 'The first dimension of `a` must match the size of `condition`.'); - } else { + } + + if ($condition.rank !== 1) { // A must have the same shape as condition. - assertShapesMatch($condition.shape, $b.shape, 'Error in where: '); + assertShapesMatch( + $condition.shape, $broadcastedB.shape, 'Error in where: '); } - const inputs: SelectV2Inputs = {condition: $condition, t: $a, e: $b}; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.select($condition, $a, $b); + const forward: ForwardFunc = (backend, save) => { + const res = backend.select($condition, $broadcastedA, $broadcastedB); save([$condition]); return res; - }, inputs as unknown as NamedTensorMap, null /* gradient */, SelectV2) as T; + }; + const inputs: SelectV2Inputs = { + condition: $condition, + t: $broadcastedA, + e: $broadcastedB + }; + return ENGINE.runKernelFunc( + forward, inputs as unknown as NamedTensorMap, null /* gradient */, + SelectV2) as T; } export const where = op({where_}); diff --git a/tfjs-core/src/ops/where_test.ts b/tfjs-core/src/ops/where_test.ts index 2de3d312c55..1b939d26f5c 100644 --- a/tfjs-core/src/ops/where_test.ts +++ b/tfjs-core/src/ops/where_test.ts @@ -212,7 +212,7 @@ describeWithFlags('where', ALL_ENVS, () => { it('Tensor4D different a/b shapes', () => { const c = tf.tensor4d([1, 0, 1, 1], [2, 2, 1, 1], 'bool'); - let a = tf.tensor4d([7, 7, 7, 7, 7, 7, 7, 7], [2, 2, 2, 1]); + let a = tf.tensor4d([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7], [2, 3, 2, 1]); let b = tf.tensor4d([3, 3, 3, 3], [2, 2, 1, 1]); let f = () => { tf.where(c, a, b); @@ -227,6 +227,23 @@ describeWithFlags('where', ALL_ENVS, () => { expect(f).toThrowError(); }); + it('Tensor4D broadcastable a/b shapes', () => { + const c = tf.tensor4d([1, 0, 1, 1], [2, 2, 1, 1], 'bool'); + const a = tf.tensor4d([7, 7, 7, 7, 7, 7, 7, 7], [2, 2, 2, 1]); + const b = [3]; + let f = () => { + tf.where(c, a, b); + }; + expect(f).toThrowError(); + + const a1 = [7]; + const b1 = tf.tensor4d([3, 3, 3, 3, 3, 3, 3, 3], [2, 2, 2, 1]); + f = () => { + tf.where(c, a1, b1); + }; + expect(f).toThrowError(); + }); + it('Tensor4D different condition/a shapes', () => { const c = tf.tensor4d([1, 0, 1, 1, 1, 0, 1, 1], [2, 2, 2, 1], 'bool'); const a = tf.tensor4d([7, 7, 7, 7], [2, 2, 1, 1]);