diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts
index f770b48c584..58a21f9cd4c 100644
--- a/tfjs-core/src/ops/array_ops.ts
+++ b/tfjs-core/src/ops/array_ops.ts
@@ -26,6 +26,72 @@ import {op} from './operation';
 import {MPRandGauss, RandGamma, UniformRandom} from './rand';
 import {zeros, zerosLike} from './tensor_ops';
 
+/**
+ * Broadcasts an array to a compatible shape NumPy-style.
+ *
+ * The tensor's shape is compared to the broadcast shape from end to
+ * beginning. Ones are prepended to the tensor's shape until it has the same
+ * length as the broadcast shape. If input.shape[i]===shape[i], axis i is
+ * already broadcast-compatible. If input.shape[i]===1 and shape[i]===N, the
+ * input tensor is tiled N times along that axis (using tf.tile).
+ *
+ * @param x The tensor to be broadcast.
+ * @param shape The shape that the input is to be broadcast to.
+ */
+/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
+function broadcastTo_<R extends Rank>(
+    x: Tensor|TensorLike, shape: ShapeMap[R]): Tensor<R> {
+  let input = convertToTensor(x, 'broadcastTo', 'x');
+  const xShape = input.shape;
+
+  if (shape.some(d => !(d > 0) || d % 1 !== 0)) {
+    throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);
+  }
+
+  if (shape.length < input.rank) {
+    throw new Error(
+        `broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`);
+  }
+
+  if (shape.length > input.rank) {
+    // Prepend ones until the input's rank matches the broadcast shape.
+    const newShape = input.shape.slice();
+    while (newShape.length < shape.length) {
+      newShape.unshift(1);
+    }
+    input = input.reshape(newShape);
+  }
+
+  // reps[i] is the number of times axis i has to be tiled.
+  const reps: number[] = Array.from(shape);
+  for (let i = shape.length - 1; i >= 0; i--) {
+    if (input.shape[i] === shape[i]) {
+      reps[i] = 1;
+    } else if (input.shape[i] !== 1) {
+      throw new Error(
+          `broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
+    }
+  }
+
+  const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);
+
+  if (axes.length === 0) {
+    return input.clone() as Tensor<R>;
+  }
+
+  return ENGINE.runKernelFunc(
+             backend => backend.tile(input, reps), {input},
+             // Gradient: sum the incoming gradient over the tiled axes.
+             (dy: Tensor) => ({input: () => dy.sum(axes, /*keepDims=*/ true)})
+         ) as Tensor<R>;
+}
+
 /**
  * Creates a new tensor with the same values and shape as the specified
  * tensor.
@@ -1126,6 +1192,7 @@ export {
 };
 
 export const batchToSpaceND = op({batchToSpaceND_});
+export const broadcastTo = op({broadcastTo_});
 export const cast = op({cast_});
 export const clone = op({clone_});
 export const cumsum = op({cumsum_});
diff --git a/tfjs-core/src/ops/array_ops_test.ts b/tfjs-core/src/ops/array_ops_test.ts
index fe990494af2..33103efb735 100644
--- a/tfjs-core/src/ops/array_ops_test.ts
+++ b/tfjs-core/src/ops/array_ops_test.ts
@@ -19,10 +19,95 @@ import * as tf from '../index';
 import {ALL_ENVS, BROWSER_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util';
 import {expectArraysClose, expectArraysEqual, expectPromiseToFail, expectValuesInRange} from '../test_util';
 import {TypedArray} from '../types';
+import {Tensor} from '../tensor';
 import * as util from '../util';
 
 import {expectArrayInMeanStdRange, jarqueBeraNormalityTest} from './rand_util';
 
+describeWithFlags('broadcastTo', ALL_ENVS, () => {
+  it('[] -> [3,2]', async () => {
+    const a = tf.scalar(4.2);
+    const A = tf.tensor2d([[4.2, 4.2],
+                           [4.2, 4.2],
+                           [4.2, 4.2]]);
+
+    expectArraysClose(
+        await A.array(), await tf.broadcastTo(a, A.shape).array());
+
+    // test gradients
+    const w = tf.tensor2d([[ 4.7,  4.5],
+                           [-6.1, -6.6],
+                           [-8.1, -3.4]]);
+    const f = (a: Tensor) =>
+        tf.broadcastTo(a, A.shape).mul(w).mean().asScalar();
+    const h = (a: Tensor) => a.mul(w).mean().asScalar();
+
+    const df = tf.grad(f);
+    const dh = tf.grad(h);
+
+    expectArraysClose(await df(a).array(), await dh(a).array());
+  });
+
+  it('[2] -> [3,2]', async () => {
+    const a = tf.tensor1d([1, 2]);
+    const A = tf.tensor2d([[1, 2],
+                           [1, 2],
+                           [1, 2]]);
+
+    expectArraysClose(
+        await A.array(), await tf.broadcastTo(a, A.shape).array());
+
+    // test gradients
+    const w = tf.tensor2d([[ 4.7,  4.5],
+                           [-6.1, -6.6],
+                           [-8.1, -3.4]]);
+    const f = (a: Tensor) =>
+        tf.broadcastTo(a, A.shape).mul(w).mean().asScalar();
+    const h = (a: Tensor) => a.mul(w).mean().asScalar();
+
+    const df = tf.grad(f);
+    const dh = tf.grad(h);
+
+    expectArraysClose(await df(a).array(), await dh(a).array());
+  });
+
+  it('[3,1] -> [3,2]', async () => {
+    const a = tf.tensor2d([[1],
+                           [2],
+                           [3]]);
+    const A = tf.tensor2d([[1, 1],
+                           [2, 2],
+                           [3, 3]]);
+
+    expectArraysClose(
+        await A.array(), await tf.broadcastTo(a, A.shape).array());
+
+    // test gradients
+    const w = tf.tensor2d([[ 4.7,  4.5],
+                           [-6.1, -6.6],
+                           [-8.1, -3.4]]);
+    const f = (a: Tensor) =>
+        tf.broadcastTo(a, A.shape).mul(w).mean().asScalar();
+    const h = (a: Tensor) => a.mul(w).mean().asScalar();
+
+    const df = tf.grad(f);
+    const dh = tf.grad(h);
+
+    expectArraysClose(await df(a).array(), await dh(a).array());
+  });
+});
+
 describeWithFlags('zeros', ALL_ENVS, () => {
   it('1D default dtype', async () => {
     const a: tf.Tensor1D = tf.zeros([3]);
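
For reviewers, a minimal usage sketch of the new op (my own example, not part of the patch; it assumes the tfjs-core 1.x API used above, i.e. tf.grad and the chained op methods):

import * as tf from '@tensorflow/tfjs-core';

// [3,1] -> [3,2]: the size-1 axis is tiled twice.
const x = tf.tensor2d([[1], [2], [3]]);
const y = tf.broadcastTo(x, [3, 2]);
y.print();  // [[1, 1], [2, 2], [3, 3]]

// The op's gradient sums dy over the broadcast axes, so the gradient of
// sum(broadcastTo(x, [3,2])) w.r.t. x counts each element once per tile.
const g = tf.grad((t: tf.Tensor) => tf.broadcastTo(t, [3, 2]).sum())(x);
g.print();  // [[2], [2], [2]]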