Skip to content
67 changes: 67 additions & 0 deletions tfjs-core/src/ops/array_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,72 @@ import {op} from './operation';
import {MPRandGauss, RandGamma, UniformRandom} from './rand';
import {zeros, zerosLike} from './tensor_ops';

/** Broadcast an array to a compatible shape NumPy-style.
*
* The tensor's shape is compared to the broadcast shape from end to beginning.
* Ones are prepended to the tensor's shape until is has the same length as
* the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
* already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
* the input tensor is tiled N times along that axis (using tf.tile).
*
* @param input The tensor that is to be broadcasted.
* @param shape The input is to be broadcast to this shape.
*/
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function broadcastTo_<R extends Rank>(
x: Tensor|TensorLike, shape: ShapeMap[R]
): Tensor<R>
{
let input = convertToTensor(x, 'broadcastTo', 'x');
const xShape = input.shape;

if( shape.some(d => !(d > 0) || d%1 !== 0) ) {
throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);
}

if( shape.length < input.rank ) {
throw new Error(
`broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`
);
}

if( shape.length > input.rank )
{
const newShape = input.shape.slice();
while(newShape.length < shape.length ) {
newShape.unshift(1);
}
input = input.reshape(newShape);
}

const reps: number[] = Array.from(shape);
for( let i=shape.length-1; i >= 0; i-- )
{
if( input.shape[i] === shape[i] ) {
reps[i] = 1;
}
else if( input.shape[i] !== 1 ) {
throw new Error(
`broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`
);
}
}

const axes = reps.map( ( n,i) => n > 1 ? i : -1 ).filter( i => i >= 0 );

if( axes.length === 0 ) {
return input.clone() as Tensor<R>;
}

return ENGINE.runKernelFunc(
backend => backend.tile(input,reps),
{input},
(dy: Tensor) => ({
input: () => dy.sum(axes,/*keepDims=*/true)
})
) as Tensor<R>;
}

/**
* Creates a new tensor with the same values and shape as the specified
* tensor.
Expand Down Expand Up @@ -1126,6 +1192,7 @@ export {
};

export const batchToSpaceND = op({batchToSpaceND_});
export const broadcastTo = op({broadcastTo_});
export const cast = op({cast_});
export const clone = op({clone_});
export const cumsum = op({cumsum_});
Expand Down
85 changes: 85 additions & 0 deletions tfjs-core/src/ops/array_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,95 @@ import * as tf from '../index';
import {ALL_ENVS, BROWSER_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util';
import {expectArraysClose, expectArraysEqual, expectPromiseToFail, expectValuesInRange} from '../test_util';
import {TypedArray} from '../types';
import {Tensor} from '../tensor';
import * as util from '../util';

import {expectArrayInMeanStdRange, jarqueBeraNormalityTest} from './rand_util';

describeWithFlags('broadcastTo', ALL_ENVS, () => {
it('[] -> [3,2]', async () => {
const a = tf.scalar(4.2);
const A = tf.tensor2d([[4.2, 4.2],
[4.2, 4.2],
[4.2, 4.2]]);

expectArraysClose(
await A.array(),
await tf.broadcastTo(a,A.shape).array()
);

// test gradients
const w = tf.tensor2d([[ 4.7, 4.5],
[-6.1,-6.6],
[-8.1,-3.4]]),
f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(),
h = (a: Tensor) => a.mul(w).mean().asScalar();

const df = tf.grad(f),
dh = tf.grad(h);

expectArraysClose(
await df(a).array(),
await dh(a).array()
);
});

it('[2] -> [3,2]', async () => {
const a = tf.tensor1d( [1,2] );
const A = tf.tensor2d([[1,2],
[1,2],
[1,2]]);
expectArraysClose(
await A.array(),
await tf.broadcastTo(a,A.shape).array()
);

// test gradients
const w = tf.tensor2d([[ 4.7, 4.5],
[-6.1,-6.6],
[-8.1,-3.4]]),
f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(),
h = (a: Tensor) => a.mul(w).mean().asScalar();

const df = tf.grad(f),
dh = tf.grad(h);

expectArraysClose(
await df(a).array(),
await dh(a).array()
);
});

it('[3,1] -> [3,2]', async () => {
const a = tf.tensor2d([[1],
[2],
[3]]);
const A = tf.tensor2d([[1,1],
[2,2],
[3,3]]);

expectArraysClose(
await A.array(),
await tf.broadcastTo(a,A.shape).array()
);

// test gradients
const w = tf.tensor2d([[ 4.7, 4.5],
[-6.1,-6.6],
[-8.1,-3.4]]),
f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(),
h = (a: Tensor) => a.mul(w).mean().asScalar();

const df = tf.grad(f),
dh = tf.grad(h);

expectArraysClose(
await df(a).array(),
await dh(a).array()
);
});
});

describeWithFlags('zeros', ALL_ENVS, () => {
it('1D default dtype', async () => {
const a: tf.Tensor1D = tf.zeros([3]);
Expand Down