From 090265725fa38b7973bbd922f9470d0701d4c8c5 Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Sat, 19 Oct 2019 20:26:29 +0200 Subject: [PATCH 1/5] Implemented tf.broadcastTo --- tfjs-core/src/ops/array_ops.ts | 67 ++++++++++++++++++++++ tfjs-core/src/ops/array_ops_test.ts | 86 +++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+) diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts index 804e991b170..dbc6f0650cb 100644 --- a/tfjs-core/src/ops/array_ops.ts +++ b/tfjs-core/src/ops/array_ops.ts @@ -26,6 +26,72 @@ import {op} from './operation'; import {MPRandGauss, RandGamma, UniformRandom} from './rand'; import {zeros, zerosLike} from './tensor_ops'; +/** Broadcast an array to a compatible shape NumPy-style. + * + * The tensor's shape is compared to the broadcast shape from end to beginning. + * Ones are prepended to the tensor's shape until is has the same length as + * the broadcast shape. If input.shape[i]==shape[i], they (i+1)-th axis is + * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then + * the input tensor is tiled N times along that axis (using tf.tile). + * + * @param input The tensor that is to be broadcasted. + * @param shape The input is to be broadcast to this shape. + */ +/** @doc {heading: 'Tensors', subheading: 'Transformations'} */ +function broadcastTo_( + x: Tensor|TensorLike, shape: ShapeMap[R] +): Tensor +{ + let input = convertToTensor(x, 'broadcastTo', 'x'); + const xShape = input.shape; + + if( shape.some( d => d < 0 ) ) { + throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`); + } + + if( shape.length < input.rank ) { + throw new Error( + `broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.` + ); + } + + if( shape.length > input.rank ) + { + const newShape = input.shape.slice(); + while(newShape.length < shape.length ) { + newShape.unshift(1); + } + input = input.reshape(newShape); + } + + const reps: number[] = Array.from(shape); + for( let i=shape.length; i-- > 0; ) + { + if( input.shape[i] === shape[i] ) { + reps[i] = 1; + } + else if( input.shape[i] !== 1 ) { + throw new Error( + `broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].` + ); + } + } + + const axes = reps.map( ( n,i) => n > 1 ? i : -1 ).filter( i => i >= 0 ); + + if( axes.length === 0 ) { + return input as Tensor; + } + + return ENGINE.runKernelFunc( + backend => backend.tile(input,reps), + {input}, + (dy: Tensor) => ({ + input: () => dy.sum(axes,/*keepDims=*/true) + }) + ) as Tensor; +} + /** * Creates a new tensor with the same values and shape as the specified * tensor. @@ -1120,6 +1186,7 @@ export { }; export const batchToSpaceND = op({batchToSpaceND_}); +export const broadcastTo = op({broadcastTo_}); export const cast = op({cast_}); export const clone = op({clone_}); export const cumsum = op({cumsum_}); diff --git a/tfjs-core/src/ops/array_ops_test.ts b/tfjs-core/src/ops/array_ops_test.ts index fe990494af2..834e44982bd 100644 --- a/tfjs-core/src/ops/array_ops_test.ts +++ b/tfjs-core/src/ops/array_ops_test.ts @@ -19,10 +19,96 @@ import * as tf from '../index'; import {ALL_ENVS, BROWSER_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util'; import {expectArraysClose, expectArraysEqual, expectPromiseToFail, expectValuesInRange} from '../test_util'; import {TypedArray} from '../types'; +import {Tensor} from '../tensor'; import * as util from '../util'; import {expectArrayInMeanStdRange, jarqueBeraNormalityTest} from './rand_util'; +describeWithFlags('broadcastTo', ALL_ENVS, () => { + it('[] -> [3,2]', async () => { + const a = tf.scalar(4.2); + const A = tf.tensor2d([[4.2, 4.2], + [4.2, 4.2], + [4.2, 4.2]]); + + expectArraysEqual( + await A.array(), + await tf.broadcastTo(a,A.shape).array() + ); + + // test gradients + const w = tf.tensor2d([[ 4.7, 4.5], + [-6.1,-6.6], + [-8.1,-3.4]]), + f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(), + h = (a: Tensor) => a.mul(w).mean().asScalar(); + + const df = tf.grad(f), + dh = tf.grad(h); + + expectArraysClose( + await df(a).array(), + await dh(a).array() + ); + }); + + it('[2] -> [3,2]', async () => { + const a = tf.tensor1d( [1,2] ); + const A = tf.tensor2d([[1,2], + [1,2], + [1,2]]); + + expectArraysEqual( + await A.array(), + await tf.broadcastTo(a,A.shape).array() + ); + + // test gradients + const w = tf.tensor2d([[ 4.7, 4.5], + [-6.1,-6.6], + [-8.1,-3.4]]), + f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(), + h = (a: Tensor) => a.mul(w).mean().asScalar(); + + const df = tf.grad(f), + dh = tf.grad(h); + + expectArraysClose( + await df(a).array(), + await dh(a).array() + ); + }); + + it('[3,1] -> [3,2]', async () => { + const a = tf.tensor2d([[1], + [2], + [3]]); + const A = tf.tensor2d([[1,1], + [2,2], + [3,3]]); + + expectArraysEqual( + await A.array(), + await tf.broadcastTo(a,A.shape).array() + ); + + // test gradients + const w = tf.tensor2d([[ 4.7, 4.5], + [-6.1,-6.6], + [-8.1,-3.4]]), + f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(), + h = (a: Tensor) => a.mul(w).mean().asScalar(); + + const df = tf.grad(f), + dh = tf.grad(h); + + expectArraysClose( + await df(a).array(), + await dh(a).array() + ); + }); +}); + describeWithFlags('zeros', ALL_ENVS, () => { it('1D default dtype', async () => { const a: tf.Tensor1D = tf.zeros([3]); From fc855d661850a412d9274da3179a6ed0295d5678 Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Sat, 19 Oct 2019 20:30:52 +0200 Subject: [PATCH 2/5] tf.broadcastTo now checks shape argument for valid integer entries. --- tfjs-core/src/ops/array_ops.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts index dbc6f0650cb..762df0f10ef 100644 --- a/tfjs-core/src/ops/array_ops.ts +++ b/tfjs-core/src/ops/array_ops.ts @@ -45,7 +45,7 @@ function broadcastTo_( let input = convertToTensor(x, 'broadcastTo', 'x'); const xShape = input.shape; - if( shape.some( d => d < 0 ) ) { + if( shape.some(d => !(d > 0) || d%1 !== 0) ) { throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`); } From ff1f0ea44955263dcdeec650634d589c7d2408f1 Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Sat, 19 Oct 2019 20:51:59 +0200 Subject: [PATCH 3/5] Made test more lenient on WebGL (required on Mobile Safari) --- tfjs-core/src/ops/array_ops_test.ts | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/ops/array_ops_test.ts b/tfjs-core/src/ops/array_ops_test.ts index 834e44982bd..5cdaa165137 100644 --- a/tfjs-core/src/ops/array_ops_test.ts +++ b/tfjs-core/src/ops/array_ops_test.ts @@ -31,7 +31,11 @@ describeWithFlags('broadcastTo', ALL_ENVS, () => { [4.2, 4.2], [4.2, 4.2]]); - expectArraysEqual( + const expectTensorsEqual = tf.getBackend().startsWith('webgl') + ? expectArraysClose // <- necessary on: Mobile Safari 11.0.0 webgl1 + : expectArraysEqual; + + expectTensorsEqual( await A.array(), await tf.broadcastTo(a,A.shape).array() ); @@ -58,7 +62,11 @@ describeWithFlags('broadcastTo', ALL_ENVS, () => { [1,2], [1,2]]); - expectArraysEqual( + const expectTensorsEqual = tf.getBackend().startsWith('webgl') + ? expectArraysClose // <- necessary on: Mobile Safari 11.0.0 webgl1 + : expectArraysEqual; + + expectTensorsEqual( await A.array(), await tf.broadcastTo(a,A.shape).array() ); @@ -87,7 +95,11 @@ describeWithFlags('broadcastTo', ALL_ENVS, () => { [2,2], [3,3]]); - expectArraysEqual( + const expectTensorsEqual = tf.getBackend().startsWith('webgl') + ? expectArraysClose // <- necessary on: Mobile Safari 11.0.0 webgl1 + : expectArraysEqual; + + expectTensorsEqual( await A.array(), await tf.broadcastTo(a,A.shape).array() ); From 365aef43c24e4efb1156d80b8d3bcc00bd5f1fcb Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Thu, 21 Nov 2019 11:18:28 +0100 Subject: [PATCH 4/5] Fixed and changed according to review. --- tfjs-core/src/ops/array_ops.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts index 453442f8e93..58a21f9cd4c 100644 --- a/tfjs-core/src/ops/array_ops.ts +++ b/tfjs-core/src/ops/array_ops.ts @@ -30,7 +30,7 @@ import {zeros, zerosLike} from './tensor_ops'; * * The tensor's shape is compared to the broadcast shape from end to beginning. * Ones are prepended to the tensor's shape until is has the same length as - * the broadcast shape. If input.shape[i]==shape[i], they (i+1)-th axis is + * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then * the input tensor is tiled N times along that axis (using tf.tile). * @@ -65,7 +65,7 @@ function broadcastTo_( } const reps: number[] = Array.from(shape); - for( let i=shape.length; i-- > 0; ) + for( let i=shape.length-1; i >= 0; i-- ) { if( input.shape[i] === shape[i] ) { reps[i] = 1; @@ -80,7 +80,7 @@ function broadcastTo_( const axes = reps.map( ( n,i) => n > 1 ? i : -1 ).filter( i => i >= 0 ); if( axes.length === 0 ) { - return input as Tensor; + return input.clone() as Tensor; } return ENGINE.runKernelFunc( From d1ccdba8ca066f60826fd8307a5806140b2288ec Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Thu, 21 Nov 2019 14:03:36 +0100 Subject: [PATCH 5/5] Changed expectArraysEqual to expectArraysClose in broadcastTo tests --- tfjs-core/src/ops/array_ops_test.ts | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/tfjs-core/src/ops/array_ops_test.ts b/tfjs-core/src/ops/array_ops_test.ts index 5cdaa165137..33103efb735 100644 --- a/tfjs-core/src/ops/array_ops_test.ts +++ b/tfjs-core/src/ops/array_ops_test.ts @@ -31,11 +31,7 @@ describeWithFlags('broadcastTo', ALL_ENVS, () => { [4.2, 4.2], [4.2, 4.2]]); - const expectTensorsEqual = tf.getBackend().startsWith('webgl') - ? expectArraysClose // <- necessary on: Mobile Safari 11.0.0 webgl1 - : expectArraysEqual; - - expectTensorsEqual( + expectArraysClose( await A.array(), await tf.broadcastTo(a,A.shape).array() ); @@ -61,12 +57,7 @@ describeWithFlags('broadcastTo', ALL_ENVS, () => { const A = tf.tensor2d([[1,2], [1,2], [1,2]]); - - const expectTensorsEqual = tf.getBackend().startsWith('webgl') - ? expectArraysClose // <- necessary on: Mobile Safari 11.0.0 webgl1 - : expectArraysEqual; - - expectTensorsEqual( + expectArraysClose( await A.array(), await tf.broadcastTo(a,A.shape).array() ); @@ -95,11 +86,7 @@ describeWithFlags('broadcastTo', ALL_ENVS, () => { [2,2], [3,3]]); - const expectTensorsEqual = tf.getBackend().startsWith('webgl') - ? expectArraysClose // <- necessary on: Mobile Safari 11.0.0 webgl1 - : expectArraysEqual; - - expectTensorsEqual( + expectArraysClose( await A.array(), await tf.broadcastTo(a,A.shape).array() );