From 1863ffe25955ab0426c5950c6142387431d8b55a Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Tue, 30 Jun 2020 15:46:30 -0700 Subject: [PATCH 01/11] add various op fixes for ddsp models --- .../executors/slice_join_executor.ts | 9 ------- .../executors/slice_join_executor_test.ts | 15 +---------- .../operations/executors/spectral_executor.ts | 15 +++++++++-- .../executors/spectral_executor_test.ts | 10 ++++++++ tfjs-core/src/ops/array_ops_test.ts | 12 +++++++++ tfjs-core/src/ops/split.ts | 7 ++++++ tfjs-core/src/ops/where.ts | 25 ++++++++++++++----- tfjs-core/src/ops/where_test.ts | 19 +++++++++++++- 8 files changed, 80 insertions(+), 32 deletions(-) diff --git a/tfjs-converter/src/operations/executors/slice_join_executor.ts b/tfjs-converter/src/operations/executors/slice_join_executor.ts index f2da33ea870..e7a0965bd4c 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor.ts @@ -128,15 +128,6 @@ export const executeOp: InternalOpExecutor = (node: Node, number[]; const tensor = getParamValue('x', node, tensorMap, context) as tfc.Tensor; - // Allow the last number of split array to be -1, which indicates the rest - // of dimension is allocated to the last split. - if (Array.isArray(numOrSizeSplits)) { - const negIndex = numOrSizeSplits.indexOf(-1); - if (negIndex !== -1) { - const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a); - numOrSizeSplits[negIndex] = tensor.shape[axis] - total; - } - } return tfc.split(tensor, numOrSizeSplits, axis); } case 'ScatterNd': { diff --git a/tfjs-converter/src/operations/executors/slice_join_executor_test.ts b/tfjs-converter/src/operations/executors/slice_join_executor_test.ts index 1b6852c5412..d8698fc3187 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor_test.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor_test.ts @@ -21,7 +21,7 @@ import * as slice_join from '../op_list/slice_join'; import {Node} from '../types'; import {executeOp} from './slice_join_executor'; -import {createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttr, createNumericArrayAttrFromIndex, createTensorAttr, createTensorsAttr, validateParam} from './test_helper'; +import {createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, createTensorsAttr, validateParam} from './test_helper'; describe('slice join', () => { let node: Node; @@ -340,19 +340,6 @@ describe('slice join', () => { expect(tfc.split).toHaveBeenCalledWith(input2[0], 2, 1); }); - it('should support -1 split', () => { - spyOn(tfc, 'split'); - node.op = 'Split'; - node.inputParams.axis = createNumberAttrFromIndex(0); - node.inputParams.x = createTensorAttr(1); - node.attrParams.numOrSizeSplits = createNumericArrayAttr([1, 1, -1]); - node.inputNames = ['input4', 'input3']; - const input3 = [tfc.tensor1d([1, 2, 3, 4, 5])]; - const input4 = [tfc.scalar(0)]; - executeOp(node, {input3, input4}, context); - - expect(tfc.split).toHaveBeenCalledWith(input3[0], [1, 1, 3], 0); - }); it('should match json def for split', () => { node.op = 'Split'; node.inputParams.axis = createNumberAttrFromIndex(0); diff --git a/tfjs-converter/src/operations/executors/spectral_executor.ts b/tfjs-converter/src/operations/executors/spectral_executor.ts index b665774f67a..62f01679beb 100644 --- a/tfjs-converter/src/operations/executors/spectral_executor.ts +++ b/tfjs-converter/src/operations/executors/spectral_executor.ts @@ -40,8 +40,19 @@ export const executeOp: InternalOpExecutor = getParamValue('x', node, tensorMap, context) as tfc.Tensor)]; } case 'IRFFT': { - return [tfc.irfft( - getParamValue('x', node, tensorMap, context) as tfc.Tensor)]; + const x = getParamValue('x', node, tensorMap, context) as tfc.Tensor; + let result = tfc.irfft( + getParamValue('x', node, tensorMap, context) as tfc.Tensor); + // when the input tensor is 3d, tfjs.irfft will treat it as 2d, we + // need to reshape the result. + if (x.rank === 3 && x.shape[0] !== 0) { + const temp = result; + const batch = x.shape[0]; + result = result.reshape( + [batch, result.shape[0] / batch, result.shape[1]]); + temp.dispose(); + } + return [result]; } default: throw TypeError(`Node type ${node.op} is not implemented`); diff --git a/tfjs-converter/src/operations/executors/spectral_executor_test.ts b/tfjs-converter/src/operations/executors/spectral_executor_test.ts index c493a55dda9..c6c7e319e41 100644 --- a/tfjs-converter/src/operations/executors/spectral_executor_test.ts +++ b/tfjs-converter/src/operations/executors/spectral_executor_test.ts @@ -92,6 +92,16 @@ describe('spectral', () => { expect(tfc.irfft).toHaveBeenCalledWith(input1[0]); }); + it('should reshape result for 3d', () => { + const result = tfc.tensor2d([2, 2, 2, 2], [2, 2]); + const input2 = [tfc.tensor3d([2, 2, 2, 2], [1, 2, 2])]; + spyOn(tfc, 'irfft').and.returnValue(result); + node.op = 'IRFFT'; + node.inputNames = ['input2']; + const output = executeOp(node, {input2}, context) as tfc.Tensor[]; + expect(output[0].rank).toEqual(3); + expect(output[0].shape).toEqual([1, 2, 2]); + }); it('should match json def', () => { node.op = 'IRFFT'; diff --git a/tfjs-core/src/ops/array_ops_test.ts b/tfjs-core/src/ops/array_ops_test.ts index 0179f106e9b..2866991761e 100644 --- a/tfjs-core/src/ops/array_ops_test.ts +++ b/tfjs-core/src/ops/array_ops_test.ts @@ -2196,6 +2196,18 @@ describeWithFlags('split', ALL_ENVS, () => { expect(res[2].shape).toEqual([2, 1]); expectArraysClose(await res[2].data(), [4, 8]); }); + it('should support -1 split', async () => { + const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const res = x.split([1, 1, -1], 1); + + expect(res.length).toEqual(3); + expect(res[0].shape).toEqual([2, 1]); + expectArraysClose(await res[0].data(), [1, 5]); + expect(res[1].shape).toEqual([2, 1]); + expectArraysClose(await res[1].data(), [2, 6]); + expect(res[2].shape).toEqual([2, 2]); + expectArraysClose(await res[2].data(), [3, 4, 7, 8]); + }); it('sizes to not sum to axis size throws error', () => { const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); diff --git a/tfjs-core/src/ops/split.ts b/tfjs-core/src/ops/split.ts index a68d07481c3..e01fc2b15af 100644 --- a/tfjs-core/src/ops/split.ts +++ b/tfjs-core/src/ops/split.ts @@ -73,6 +73,13 @@ function split_( splitSizes = new Array(numOrSizeSplits).fill($x.shape[$axis] / numOrSizeSplits); } else { + // Allow the last number of split array to be -1, which indicates the rest + // of dimension is allocated to that split. + const negIndex = numOrSizeSplits.indexOf(-1); + if (negIndex !== -1) { + const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a); + numOrSizeSplits[negIndex] = $x.shape[$axis] - total; + } assert( $x.shape[$axis] === numOrSizeSplits.reduce((a, b) => a + b), () => 'The sum of sizes must match the size of the axis dimension.'); diff --git a/tfjs-core/src/ops/where.ts b/tfjs-core/src/ops/where.ts index 6723e50f8b7..a671f739772 100644 --- a/tfjs-core/src/ops/where.ts +++ b/tfjs-core/src/ops/where.ts @@ -23,8 +23,10 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {assert, assertShapesMatch} from '../util'; +import {assertAndGetBroadcastShape} from './broadcast_util'; import {op} from './operation'; + /** * Returns the elements, either `a` or `b` depending on the `condition`. * @@ -41,7 +43,10 @@ import {op} from './operation'; * @param condition The input condition. Must be of dtype bool. * @param a If `condition` is rank 1, `a` may have a higher rank but * its first dimension must match the size of `condition`. - * @param b A tensor with the same shape and type as `a`. + * @param b A tensor with the same dtype as `a` and with shape that is + * compatible with `a`. + * @return A tensor with same detype as `a` and `b`, and shape that is + * broadcastable from `a` and `b`. */ /** @doc {heading: 'Operations', subheading: 'Logical'} */ function where_( @@ -50,22 +55,30 @@ function where_( const $b = convertToTensor(b, 'b', 'where'); const $condition = convertToTensor(condition, 'condition', 'where', 'bool'); - assertShapesMatch($a.shape, $b.shape, 'Error in where: '); + // find the broadcastable shape for $a and $b + const broadcastShape = assertAndGetBroadcastShape($a.shape, $b.shape); + const $broadcastedA = $a.broadcastTo(broadcastShape); + const $broadcastedB = $b.broadcastTo(broadcastShape); if ($condition.rank === 1) { // If condition rank is 1, then the first dimension must match the size of // condition. assert( - $condition.shape[0] === $a.shape[0], + $condition.shape[0] === $broadcastedA.shape[0], () => 'The first dimension of `a` must match the size of `condition`.'); } else { // A must have the same shape as condition. - assertShapesMatch($condition.shape, $b.shape, 'Error in where: '); + assertShapesMatch( + $condition.shape, $broadcastedB.shape, 'Error in where: '); } - const inputs: SelectV2Inputs = {condition: $condition, t: $a, e: $b}; + const inputs: SelectV2Inputs = { + condition: $condition, + t: $broadcastedA, + e: $broadcastedB + }; return ENGINE.runKernelFunc((backend, save) => { - const res = backend.select($condition, $a, $b); + const res = backend.select($condition, $broadcastedA, $broadcastedB); save([$condition]); return res; }, inputs as unknown as NamedTensorMap, null /* gradient */, SelectV2) as T; diff --git a/tfjs-core/src/ops/where_test.ts b/tfjs-core/src/ops/where_test.ts index 2de3d312c55..1b939d26f5c 100644 --- a/tfjs-core/src/ops/where_test.ts +++ b/tfjs-core/src/ops/where_test.ts @@ -212,7 +212,7 @@ describeWithFlags('where', ALL_ENVS, () => { it('Tensor4D different a/b shapes', () => { const c = tf.tensor4d([1, 0, 1, 1], [2, 2, 1, 1], 'bool'); - let a = tf.tensor4d([7, 7, 7, 7, 7, 7, 7, 7], [2, 2, 2, 1]); + let a = tf.tensor4d([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7], [2, 3, 2, 1]); let b = tf.tensor4d([3, 3, 3, 3], [2, 2, 1, 1]); let f = () => { tf.where(c, a, b); @@ -227,6 +227,23 @@ describeWithFlags('where', ALL_ENVS, () => { expect(f).toThrowError(); }); + it('Tensor4D broadcastable a/b shapes', () => { + const c = tf.tensor4d([1, 0, 1, 1], [2, 2, 1, 1], 'bool'); + const a = tf.tensor4d([7, 7, 7, 7, 7, 7, 7, 7], [2, 2, 2, 1]); + const b = [3]; + let f = () => { + tf.where(c, a, b); + }; + expect(f).toThrowError(); + + const a1 = [7]; + const b1 = tf.tensor4d([3, 3, 3, 3, 3, 3, 3, 3], [2, 2, 2, 1]); + f = () => { + tf.where(c, a1, b1); + }; + expect(f).toThrowError(); + }); + it('Tensor4D different condition/a shapes', () => { const c = tf.tensor4d([1, 0, 1, 1, 1, 0, 1, 1], [2, 2, 2, 1], 'bool'); const a = tf.tensor4d([7, 7, 7, 7], [2, 2, 1, 1]); From f0b294ea4cdd87a830be3300844fa148435c3449 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Tue, 30 Jun 2020 17:13:11 -0700 Subject: [PATCH 02/11] added todo, fix lint --- tfjs-core/src/ops/split.ts | 1 + tfjs-core/src/ops/where.ts | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tfjs-core/src/ops/split.ts b/tfjs-core/src/ops/split.ts index e01fc2b15af..a69b0b553f0 100644 --- a/tfjs-core/src/ops/split.ts +++ b/tfjs-core/src/ops/split.ts @@ -73,6 +73,7 @@ function split_( splitSizes = new Array(numOrSizeSplits).fill($x.shape[$axis] / numOrSizeSplits); } else { + // TODO(piyu): move the preprocess logic to kernels // Allow the last number of split array to be -1, which indicates the rest // of dimension is allocated to that split. const negIndex = numOrSizeSplits.indexOf(-1); diff --git a/tfjs-core/src/ops/where.ts b/tfjs-core/src/ops/where.ts index a671f739772..a53bcaeea7f 100644 --- a/tfjs-core/src/ops/where.ts +++ b/tfjs-core/src/ops/where.ts @@ -26,7 +26,6 @@ import {assert, assertShapesMatch} from '../util'; import {assertAndGetBroadcastShape} from './broadcast_util'; import {op} from './operation'; - /** * Returns the elements, either `a` or `b` depending on the `condition`. * @@ -55,6 +54,7 @@ function where_( const $b = convertToTensor(b, 'b', 'where'); const $condition = convertToTensor(condition, 'condition', 'where', 'bool'); + // TODO(piyu): move the preprocess logic to kernels // find the broadcastable shape for $a and $b const broadcastShape = assertAndGetBroadcastShape($a.shape, $b.shape); const $broadcastedA = $a.broadcastTo(broadcastShape); From 656cf846687ea45504bca62014fad70e2f72f23e Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Wed, 1 Jul 2020 11:34:25 -0700 Subject: [PATCH 03/11] move the logic to forwardFUnc --- tfjs-core/src/ops/split.ts | 27 ++++++++++++++------------- tfjs-core/src/ops/where.ts | 29 ++++++++++++----------------- 2 files changed, 26 insertions(+), 30 deletions(-) diff --git a/tfjs-core/src/ops/split.ts b/tfjs-core/src/ops/split.ts index a69b0b553f0..115fcb612f7 100644 --- a/tfjs-core/src/ops/split.ts +++ b/tfjs-core/src/ops/split.ts @@ -72,22 +72,23 @@ function split_( () => 'Number of splits must evenly divide the axis.'); splitSizes = new Array(numOrSizeSplits).fill($x.shape[$axis] / numOrSizeSplits); - } else { - // TODO(piyu): move the preprocess logic to kernels - // Allow the last number of split array to be -1, which indicates the rest - // of dimension is allocated to that split. - const negIndex = numOrSizeSplits.indexOf(-1); - if (negIndex !== -1) { - const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a); - numOrSizeSplits[negIndex] = $x.shape[$axis] - total; - } - assert( - $x.shape[$axis] === numOrSizeSplits.reduce((a, b) => a + b), - () => 'The sum of sizes must match the size of the axis dimension.'); - splitSizes = numOrSizeSplits; } const forward: ForwardFunc = (backend, _) => { + if (typeof (numOrSizeSplits) !== 'number') { + // TODO(piyu): move the preprocess logic to kernels + // Allow the last number of split array to be -1, which indicates the rest + // of dimension is allocated to that split. + const negIndex = numOrSizeSplits.indexOf(-1); + if (negIndex !== -1) { + const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a); + numOrSizeSplits[negIndex] = $x.shape[$axis] - total; + } + assert( + $x.shape[$axis] === numOrSizeSplits.reduce((a, b) => a + b), + () => 'The sum of sizes must match the size of the axis dimension.'); + splitSizes = numOrSizeSplits; + } return backend.split($x, splitSizes, $axis) as {} as T; }; diff --git a/tfjs-core/src/ops/where.ts b/tfjs-core/src/ops/where.ts index a53bcaeea7f..22174d1fc6c 100644 --- a/tfjs-core/src/ops/where.ts +++ b/tfjs-core/src/ops/where.ts @@ -44,7 +44,7 @@ import {op} from './operation'; * its first dimension must match the size of `condition`. * @param b A tensor with the same dtype as `a` and with shape that is * compatible with `a`. - * @return A tensor with same detype as `a` and `b`, and shape that is + * @return A tensor with same dtype as `a` and `b`, and shape that is * broadcastable from `a` and `b`. */ /** @doc {heading: 'Operations', subheading: 'Logical'} */ @@ -54,30 +54,25 @@ function where_( const $b = convertToTensor(b, 'b', 'where'); const $condition = convertToTensor(condition, 'condition', 'where', 'bool'); - // TODO(piyu): move the preprocess logic to kernels - // find the broadcastable shape for $a and $b - const broadcastShape = assertAndGetBroadcastShape($a.shape, $b.shape); - const $broadcastedA = $a.broadcastTo(broadcastShape); - const $broadcastedB = $b.broadcastTo(broadcastShape); - if ($condition.rank === 1) { // If condition rank is 1, then the first dimension must match the size of // condition. assert( - $condition.shape[0] === $broadcastedA.shape[0], + $condition.shape[0] === $a.shape[0], () => 'The first dimension of `a` must match the size of `condition`.'); - } else { - // A must have the same shape as condition. - assertShapesMatch( - $condition.shape, $broadcastedB.shape, 'Error in where: '); } - const inputs: SelectV2Inputs = { - condition: $condition, - t: $broadcastedA, - e: $broadcastedB - }; + const inputs: SelectV2Inputs = {condition: $condition, t: $a, e: $b}; return ENGINE.runKernelFunc((backend, save) => { + // find the broadcastable shape for $a and $b + const broadcastShape = assertAndGetBroadcastShape($a.shape, $b.shape); + const $broadcastedA = $a.broadcastTo(broadcastShape); + const $broadcastedB = $b.broadcastTo(broadcastShape); + if ($condition.rank !== 1) { + // A must have the same shape as condition. + assertShapesMatch( + $condition.shape, $broadcastedB.shape, 'Error in where: '); + } const res = backend.select($condition, $broadcastedA, $broadcastedB); save([$condition]); return res; From 7a42690c30bb3b8f13429026365dda01077e2012 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Wed, 1 Jul 2020 22:27:02 -0700 Subject: [PATCH 04/11] address comments --- tfjs-core/src/ops/split.ts | 20 +++++++++----------- tfjs-core/src/ops/where.ts | 28 +++++++++++++++------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/tfjs-core/src/ops/split.ts b/tfjs-core/src/ops/split.ts index 115fcb612f7..88aaf54c00a 100644 --- a/tfjs-core/src/ops/split.ts +++ b/tfjs-core/src/ops/split.ts @@ -66,18 +66,16 @@ function split_( const $axis = parseAxisParam(axis, $x.shape)[0]; let splitSizes: number[]; - if (typeof (numOrSizeSplits) === 'number') { - assert( - $x.shape[$axis] % numOrSizeSplits === 0, - () => 'Number of splits must evenly divide the axis.'); - splitSizes = - new Array(numOrSizeSplits).fill($x.shape[$axis] / numOrSizeSplits); - } - const forward: ForwardFunc = (backend, _) => { - if (typeof (numOrSizeSplits) !== 'number') { + if (typeof (numOrSizeSplits) === 'number') { + assert( + $x.shape[$axis] % numOrSizeSplits === 0, + () => 'Number of splits must evenly divide the axis.'); + splitSizes = + new Array(numOrSizeSplits).fill($x.shape[$axis] / numOrSizeSplits); + } else { // TODO(piyu): move the preprocess logic to kernels - // Allow the last number of split array to be -1, which indicates the rest + // Allow the number of split array to be -1, which indicates the rest // of dimension is allocated to that split. const negIndex = numOrSizeSplits.indexOf(-1); if (negIndex !== -1) { @@ -93,7 +91,7 @@ function split_( }; const inputs: SplitVInputs = {x: $x}; - const attr: SplitVAttrs = {numOrSizeSplits, axis}; + const attr: SplitVAttrs = {numOrSizeSplits, axis: $axis}; return ENGINE.runKernelFunc( forward, inputs as {} as NamedTensorMap, null /* grad */, SplitV, diff --git a/tfjs-core/src/ops/where.ts b/tfjs-core/src/ops/where.ts index 22174d1fc6c..0e3dc978b1e 100644 --- a/tfjs-core/src/ops/where.ts +++ b/tfjs-core/src/ops/where.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {ENGINE} from '../engine'; +import {ENGINE, ForwardFunc} from '../engine'; import {SelectV2, SelectV2Inputs} from '../kernel_names'; import {Tensor} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; @@ -53,17 +53,15 @@ function where_( const $a = convertToTensor(a, 'a', 'where'); const $b = convertToTensor(b, 'b', 'where'); const $condition = convertToTensor(condition, 'condition', 'where', 'bool'); - - if ($condition.rank === 1) { - // If condition rank is 1, then the first dimension must match the size of - // condition. - assert( - $condition.shape[0] === $a.shape[0], - () => 'The first dimension of `a` must match the size of `condition`.'); - } - - const inputs: SelectV2Inputs = {condition: $condition, t: $a, e: $b}; - return ENGINE.runKernelFunc((backend, save) => { + const forward: ForwardFunc = (backend, save) => { + if ($condition.rank === 1) { + // If condition rank is 1, then the first dimension must match the size of + // condition. + assert( + $condition.shape[0] === $a.shape[0], + () => + 'The first dimension of `a` must match the size of `condition`.'); + } // find the broadcastable shape for $a and $b const broadcastShape = assertAndGetBroadcastShape($a.shape, $b.shape); const $broadcastedA = $a.broadcastTo(broadcastShape); @@ -76,7 +74,11 @@ function where_( const res = backend.select($condition, $broadcastedA, $broadcastedB); save([$condition]); return res; - }, inputs as unknown as NamedTensorMap, null /* gradient */, SelectV2) as T; + }; + const inputs: SelectV2Inputs = {condition: $condition, t: $a, e: $b}; + return ENGINE.runKernelFunc( + forward, inputs as unknown as NamedTensorMap, null /* gradient */, + SelectV2) as T; } export const where = op({where_}); From f14e4bae33ef989cf34ea2dbc3ddddecba67df4b Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Wed, 1 Jul 2020 22:33:57 -0700 Subject: [PATCH 05/11] updated doc for the spit op --- tfjs-core/src/ops/split.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/tfjs-core/src/ops/split.ts b/tfjs-core/src/ops/split.ts index 88aaf54c00a..0703b0b0304 100644 --- a/tfjs-core/src/ops/split.ts +++ b/tfjs-core/src/ops/split.ts @@ -55,6 +55,7 @@ import {op} from './operation'; * splits along the axis or an array of integers containing the sizes of * each output tensor along the axis. If a number then it must evenly divide * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`. + * Can contain one -1 indicating that dimension is to be inferred. * @param axis The dimension along which to split. Defaults to 0 (the first * dim). */ From 274c8466d4f1dd0b4d4d864eb2e12cb37f68c9e0 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 6 Jul 2020 11:57:50 -0700 Subject: [PATCH 06/11] moved irfft fixes from converter to core --- .../operations/executors/spectral_executor.ts | 15 ++------------- .../executors/spectral_executor_test.ts | 10 ---------- tfjs-core/src/ops/spectral_ops.ts | 17 ++++++++++++----- tfjs-core/src/ops/spectral_ops_test.ts | 14 ++++++++++---- 4 files changed, 24 insertions(+), 32 deletions(-) diff --git a/tfjs-converter/src/operations/executors/spectral_executor.ts b/tfjs-converter/src/operations/executors/spectral_executor.ts index 62f01679beb..b665774f67a 100644 --- a/tfjs-converter/src/operations/executors/spectral_executor.ts +++ b/tfjs-converter/src/operations/executors/spectral_executor.ts @@ -40,19 +40,8 @@ export const executeOp: InternalOpExecutor = getParamValue('x', node, tensorMap, context) as tfc.Tensor)]; } case 'IRFFT': { - const x = getParamValue('x', node, tensorMap, context) as tfc.Tensor; - let result = tfc.irfft( - getParamValue('x', node, tensorMap, context) as tfc.Tensor); - // when the input tensor is 3d, tfjs.irfft will treat it as 2d, we - // need to reshape the result. - if (x.rank === 3 && x.shape[0] !== 0) { - const temp = result; - const batch = x.shape[0]; - result = result.reshape( - [batch, result.shape[0] / batch, result.shape[1]]); - temp.dispose(); - } - return [result]; + return [tfc.irfft( + getParamValue('x', node, tensorMap, context) as tfc.Tensor)]; } default: throw TypeError(`Node type ${node.op} is not implemented`); diff --git a/tfjs-converter/src/operations/executors/spectral_executor_test.ts b/tfjs-converter/src/operations/executors/spectral_executor_test.ts index c6c7e319e41..c493a55dda9 100644 --- a/tfjs-converter/src/operations/executors/spectral_executor_test.ts +++ b/tfjs-converter/src/operations/executors/spectral_executor_test.ts @@ -92,16 +92,6 @@ describe('spectral', () => { expect(tfc.irfft).toHaveBeenCalledWith(input1[0]); }); - it('should reshape result for 3d', () => { - const result = tfc.tensor2d([2, 2, 2, 2], [2, 2]); - const input2 = [tfc.tensor3d([2, 2, 2, 2], [1, 2, 2])]; - spyOn(tfc, 'irfft').and.returnValue(result); - node.op = 'IRFFT'; - node.inputNames = ['input2']; - const output = executeOp(node, {input2}, context) as tfc.Tensor[]; - expect(output[0].rank).toEqual(3); - expect(output[0].shape).toEqual([1, 2, 2]); - }); it('should match json def', () => { node.op = 'IRFFT'; diff --git a/tfjs-core/src/ops/spectral_ops.ts b/tfjs-core/src/ops/spectral_ops.ts index 8cfd88fec69..839e7f49939 100644 --- a/tfjs-core/src/ops/spectral_ops.ts +++ b/tfjs-core/src/ops/spectral_ops.ts @@ -179,11 +179,10 @@ function rfft_(input: Tensor, fftLength?: number): Tensor { function irfft_(input: Tensor): Tensor { const innerDimensionSize = input.shape[input.shape.length - 1]; const batch = input.size / innerDimensionSize; - + let ret: Tensor; if (innerDimensionSize <= 2) { const complexInput = input.as2D(batch, innerDimensionSize); - const ret = ifft(complexInput); - return real(ret); + ret = ifft(complexInput); } else { // The length of unique components of the DFT of a real-valued signal // is 2 * (input_len - 1) @@ -201,9 +200,17 @@ function irfft_(input: Tensor): Tensor { const r = realInput.concat(realConjugate, 1); const i = imagInput.concat(imagConjugate, 1); const complexInput = complex(r, i).as2D(outputShape[0], outputShape[1]); - const ret = ifft(complexInput); - return real(ret); + ret = ifft(complexInput); + } + ret = real(ret); + // reshape the result if the input is 3D tensor. + if (input.rank === 3 && input.shape[0] !== 0) { + const temp = ret; + const batch = input.shape[0]; + ret = ret.reshape([batch, ret.shape[0] / batch, ret.shape[1]]); + temp.dispose(); } + return ret; } export const fft = op({fft_}); diff --git a/tfjs-core/src/ops/spectral_ops_test.ts b/tfjs-core/src/ops/spectral_ops_test.ts index fe5a7459dc6..97c354ef5d2 100644 --- a/tfjs-core/src/ops/spectral_ops_test.ts +++ b/tfjs-core/src/ops/spectral_ops_test.ts @@ -353,9 +353,11 @@ describeWithFlags('3D IRFFT', ALL_ENVS, () => { const t1Real = tf.tensor3d([1, 2, 3, 4, 1, 2, 3, 4], [2, 2, 2]); const t1Imag = tf.tensor3d([0, 0, 0, 0, 0, 0, 0, 0], [2, 2, 2]); const t1 = tf.complex(t1Real, t1Imag); + const result = tf.spectral.irfft(t1); + expect(result.shape).toEqual([2, 2, 2]); + expectArraysClose( - await tf.spectral.irfft(t1).data(), - [1.5, -0.5, 3.5, -0.5, 1.5, -0.5, 3.5, -0.5]); + await result.data(), [1.5, -0.5, 3.5, -0.5, 1.5, -0.5, 3.5, -0.5]); }); it('should return the same value with TensorFlow (2x2x3 elements)', @@ -365,7 +367,9 @@ describeWithFlags('3D IRFFT', ALL_ENVS, () => { const t1Imag = tf.tensor3d([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2, 2, 3]); const t1 = tf.complex(t1Real, t1Imag); - expectArraysClose(await tf.spectral.irfft(t1).data(), [ + const result = tf.spectral.irfft(t1); + expect(result.shape).toEqual([2, 2, 4]); + expectArraysClose(await result.data(), [ 2, -0.5, 0, -0.5, 5, -0.5, 0, -0.5, 2, -0.5, 0, -0.5, 5, -0.5, 0, -0.5 ]); }); @@ -377,8 +381,10 @@ describeWithFlags('3D IRFFT', ALL_ENVS, () => { const t1Imag = tf.tensor3d([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6], [2, 2, 3]); const t1 = tf.complex(t1Real, t1Imag); + const result = tf.spectral.irfft(t1); + expect(result.shape).toEqual([2, 2, 4]); expectArraysClose( - await tf.spectral.irfft(t1).data(), + await result.data(), [2, -1.5, 0, 0.5, 5, -3, 0, 2, 2, -1.5, 0, 0.5, 5, -3, 0, 2]); }); }); From f0127021c7f450f939fa9b4c31f59d250b5ede62 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 6 Jul 2020 14:59:22 -0700 Subject: [PATCH 07/11] move split log to split_util --- tfjs-backend-wasm/src/kernels/Split.ts | 11 ++--- tfjs-core/src/backends/backend_util.ts | 1 + tfjs-core/src/ops/array_ops_test.ts | 5 +++ tfjs-core/src/ops/split.ts | 24 +---------- tfjs-core/src/ops/split_util.ts | 60 ++++++++++++++++++++++++++ 5 files changed, 71 insertions(+), 30 deletions(-) create mode 100644 tfjs-core/src/ops/split_util.ts diff --git a/tfjs-backend-wasm/src/kernels/Split.ts b/tfjs-backend-wasm/src/kernels/Split.ts index e2e72150606..0aa103e2a5c 100644 --- a/tfjs-backend-wasm/src/kernels/Split.ts +++ b/tfjs-backend-wasm/src/kernels/Split.ts @@ -16,6 +16,7 @@ */ import {NamedAttrMap, NamedTensorInfoMap, registerKernel, SplitV, SplitVAttrs, SplitVInputs, util} from '@tensorflow/tfjs-core'; +import {backend_util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; @@ -32,14 +33,8 @@ export function split(args: { const $axis = util.parseAxisParam(axis, x.shape)[0]; - let splitSizes: number[]; - if (typeof (numOrSizeSplits) === 'number') { - splitSizes = - new Array(numOrSizeSplits).fill(x.shape[$axis] / numOrSizeSplits); - } else { - splitSizes = numOrSizeSplits; - } - + const splitSizes = backend_util.prepareSplitSize(x, numOrSizeSplits, axis); + console.log(numOrSizeSplits, axis, splitSizes); const begin = new Array(x.shape.length).fill(0); const size = x.shape.slice(); return splitSizes.map(s => { diff --git a/tfjs-core/src/backends/backend_util.ts b/tfjs-core/src/backends/backend_util.ts index 901d96fa2b5..c188bf9289a 100644 --- a/tfjs-core/src/backends/backend_util.ts +++ b/tfjs-core/src/backends/backend_util.ts @@ -42,6 +42,7 @@ export * from '../ops/fused_util'; export * from '../ops/erf_util'; export * from '../log'; export * from '../backends/complex_util'; +export * from '../ops/split_util'; import * as segment_util from '../ops/segment_util'; export {segment_util}; diff --git a/tfjs-core/src/ops/array_ops_test.ts b/tfjs-core/src/ops/array_ops_test.ts index 2866991761e..bdb3e9148ba 100644 --- a/tfjs-core/src/ops/array_ops_test.ts +++ b/tfjs-core/src/ops/array_ops_test.ts @@ -2209,6 +2209,11 @@ describeWithFlags('split', ALL_ENVS, () => { expectArraysClose(await res[2].data(), [3, 4, 7, 8]); }); + it('multiple negative number throws error', () => { + const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const f = () => tf.split(x, [1, -1, -1], 1); + expect(f).toThrowError(); + }); it('sizes to not sum to axis size throws error', () => { const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); const f = () => tf.split(x, [1, 2], 1); diff --git a/tfjs-core/src/ops/split.ts b/tfjs-core/src/ops/split.ts index 0703b0b0304..4966481d8c6 100644 --- a/tfjs-core/src/ops/split.ts +++ b/tfjs-core/src/ops/split.ts @@ -21,10 +21,10 @@ import {Tensor} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; -import {assert,} from '../util'; import {parseAxisParam} from '../util'; import {op} from './operation'; +import {prepareSplitSize} from './split_util'; /** * Splits a `tf.Tensor` into sub tensors. @@ -65,29 +65,9 @@ function split_( const $x = convertToTensor(x, 'x', 'split'); const $axis = parseAxisParam(axis, $x.shape)[0]; - let splitSizes: number[]; const forward: ForwardFunc = (backend, _) => { - if (typeof (numOrSizeSplits) === 'number') { - assert( - $x.shape[$axis] % numOrSizeSplits === 0, - () => 'Number of splits must evenly divide the axis.'); - splitSizes = - new Array(numOrSizeSplits).fill($x.shape[$axis] / numOrSizeSplits); - } else { - // TODO(piyu): move the preprocess logic to kernels - // Allow the number of split array to be -1, which indicates the rest - // of dimension is allocated to that split. - const negIndex = numOrSizeSplits.indexOf(-1); - if (negIndex !== -1) { - const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a); - numOrSizeSplits[negIndex] = $x.shape[$axis] - total; - } - assert( - $x.shape[$axis] === numOrSizeSplits.reduce((a, b) => a + b), - () => 'The sum of sizes must match the size of the axis dimension.'); - splitSizes = numOrSizeSplits; - } + const splitSizes = prepareSplitSize($x, numOrSizeSplits, $axis); return backend.split($x, splitSizes, $axis) as {} as T; }; diff --git a/tfjs-core/src/ops/split_util.ts b/tfjs-core/src/ops/split_util.ts new file mode 100644 index 00000000000..6e01ac3395b --- /dev/null +++ b/tfjs-core/src/ops/split_util.ts @@ -0,0 +1,60 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {TensorInfo} from '../kernel_registry'; +import {Tensor} from '../tensor'; +import {assert} from '../util'; + +/** + * Prepare the split size array. When the input is a number, the axis is evenly + * divided among the split size. When the input contains the negative value, the + * rest of the axis is allocated toward that. + */ +export function prepareSplitSize( + x: Tensor|TensorInfo, numOrSizeSplits: number[]|number, + axis = 0): number[] { + let splitSizes = []; + if (typeof (numOrSizeSplits) === 'number') { + assert( + x.shape[axis] % numOrSizeSplits === 0, + () => 'Number of splits must evenly divide the axis.'); + splitSizes = + new Array(numOrSizeSplits).fill(x.shape[axis] / numOrSizeSplits); + } else { + const numOfNegs = numOrSizeSplits.reduce((count, value) => { + if (value === -1) { + count += 1; + } + return count; + }, 0); + assert( + numOfNegs <= 1, + () => 'There should be only one negative value in split array.'); + const negIndex = numOrSizeSplits.indexOf(-1); + // Allow the number of split array to be -1, which indicates the rest + // of dimension is allocated to that split. + if (negIndex !== -1) { + const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a); + numOrSizeSplits[negIndex] = x.shape[axis] - total; + } + assert( + x.shape[axis] === numOrSizeSplits.reduce((a, b) => a + b), + () => 'The sum of sizes must match the size of the axis dimension.'); + splitSizes = numOrSizeSplits; + } + + return splitSizes; +} From 16bb4f04bbe893e5203f6f8978ae3421e52d58b1 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 6 Jul 2020 14:59:41 -0700 Subject: [PATCH 08/11] remove console.log --- tfjs-backend-wasm/src/kernels/Split.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/tfjs-backend-wasm/src/kernels/Split.ts b/tfjs-backend-wasm/src/kernels/Split.ts index 0aa103e2a5c..f1236fd4a6b 100644 --- a/tfjs-backend-wasm/src/kernels/Split.ts +++ b/tfjs-backend-wasm/src/kernels/Split.ts @@ -34,7 +34,6 @@ export function split(args: { const $axis = util.parseAxisParam(axis, x.shape)[0]; const splitSizes = backend_util.prepareSplitSize(x, numOrSizeSplits, axis); - console.log(numOrSizeSplits, axis, splitSizes); const begin = new Array(x.shape.length).fill(0); const size = x.shape.slice(); return splitSizes.map(s => { From 6d2277654803a850bcc7039b0886f5165bb02347 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 6 Jul 2020 16:15:39 -0700 Subject: [PATCH 09/11] move the broadcast and validation logic out of forwardFunc --- tfjs-core/src/ops/where.ts | 42 ++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/tfjs-core/src/ops/where.ts b/tfjs-core/src/ops/where.ts index 0e3dc978b1e..dcb315d77a4 100644 --- a/tfjs-core/src/ops/where.ts +++ b/tfjs-core/src/ops/where.ts @@ -53,29 +53,35 @@ function where_( const $a = convertToTensor(a, 'a', 'where'); const $b = convertToTensor(b, 'b', 'where'); const $condition = convertToTensor(condition, 'condition', 'where', 'bool'); + // find the broadcastable shape for $a and $b + const broadcastShape = assertAndGetBroadcastShape($a.shape, $b.shape); + const $broadcastedA = $a.broadcastTo(broadcastShape); + const $broadcastedB = $b.broadcastTo(broadcastShape); + if ($condition.rank === 1) { + // If condition rank is 1, then the first dimension must match the size of + // condition. + assert( + $condition.shape[0] === $a.shape[0], + () => + 'The first dimension of `a` must match the size of `condition`.'); + } + + if ($condition.rank !== 1) { + // A must have the same shape as condition. + assertShapesMatch( + $condition.shape, $broadcastedB.shape, 'Error in where: '); + } + const forward: ForwardFunc = (backend, save) => { - if ($condition.rank === 1) { - // If condition rank is 1, then the first dimension must match the size of - // condition. - assert( - $condition.shape[0] === $a.shape[0], - () => - 'The first dimension of `a` must match the size of `condition`.'); - } - // find the broadcastable shape for $a and $b - const broadcastShape = assertAndGetBroadcastShape($a.shape, $b.shape); - const $broadcastedA = $a.broadcastTo(broadcastShape); - const $broadcastedB = $b.broadcastTo(broadcastShape); - if ($condition.rank !== 1) { - // A must have the same shape as condition. - assertShapesMatch( - $condition.shape, $broadcastedB.shape, 'Error in where: '); - } const res = backend.select($condition, $broadcastedA, $broadcastedB); save([$condition]); return res; }; - const inputs: SelectV2Inputs = {condition: $condition, t: $a, e: $b}; + const inputs: SelectV2Inputs = { + condition: $condition, + t: $broadcastedA, + e: $broadcastedB + }; return ENGINE.runKernelFunc( forward, inputs as unknown as NamedTensorMap, null /* gradient */, SelectV2) as T; From a8875f11d55ca472a980eb221f9665880de69c2d Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Wed, 8 Jul 2020 10:27:07 -0700 Subject: [PATCH 10/11] added todo comment --- tfjs-core/src/ops/where.ts | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/ops/where.ts b/tfjs-core/src/ops/where.ts index dcb315d77a4..03912334892 100644 --- a/tfjs-core/src/ops/where.ts +++ b/tfjs-core/src/ops/where.ts @@ -53,7 +53,9 @@ function where_( const $a = convertToTensor(a, 'a', 'where'); const $b = convertToTensor(b, 'b', 'where'); const $condition = convertToTensor(condition, 'condition', 'where', 'bool'); - // find the broadcastable shape for $a and $b + // TODO: move this logic to forward function when the broadcastTo op is + // implemented in WASM. + // Find the broadcastable shape for $a and $b. const broadcastShape = assertAndGetBroadcastShape($a.shape, $b.shape); const $broadcastedA = $a.broadcastTo(broadcastShape); const $broadcastedB = $b.broadcastTo(broadcastShape); @@ -62,8 +64,7 @@ function where_( // condition. assert( $condition.shape[0] === $a.shape[0], - () => - 'The first dimension of `a` must match the size of `condition`.'); + () => 'The first dimension of `a` must match the size of `condition`.'); } if ($condition.rank !== 1) { From 6b11b8d657cd6fc996448bad150a85b5039feee1 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Wed, 8 Jul 2020 10:29:56 -0700 Subject: [PATCH 11/11] update axis param --- tfjs-core/src/ops/split.ts | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/ops/split.ts b/tfjs-core/src/ops/split.ts index 4966481d8c6..6dc30b64a4f 100644 --- a/tfjs-core/src/ops/split.ts +++ b/tfjs-core/src/ops/split.ts @@ -64,15 +64,14 @@ function split_( x: Tensor|TensorLike, numOrSizeSplits: number[]|number, axis = 0): T[] { const $x = convertToTensor(x, 'x', 'split'); - const $axis = parseAxisParam(axis, $x.shape)[0]; - const forward: ForwardFunc = (backend, _) => { + const $axis = parseAxisParam(axis, $x.shape)[0]; const splitSizes = prepareSplitSize($x, numOrSizeSplits, $axis); return backend.split($x, splitSizes, $axis) as {} as T; }; const inputs: SplitVInputs = {x: $x}; - const attr: SplitVAttrs = {numOrSizeSplits, axis: $axis}; + const attr: SplitVAttrs = {numOrSizeSplits, axis}; return ENGINE.runKernelFunc( forward, inputs as {} as NamedTensorMap, null /* grad */, SplitV,