Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Permalink
added support prod reduce op (#1279)
Browse files Browse the repository at this point in the history
* added support prod reduce op

* addressed the review comments, and removed the gradient func for prod op, since it requires the cumprod op
  • Loading branch information
pyu10055 committed Oct 1, 2018
1 parent ebcb598 commit 13601f4
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/kernels/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ export interface KernelBackend extends TensorStorage, BackendTimer {
floorDiv(a: Tensor, b: Tensor): Tensor;

sum(x: Tensor, axes: number[]): Tensor;
prod(x: Tensor, axes: number[]): Tensor;

unsortedSegmentSum<T extends Tensor>(
x: T, segmentIds: Tensor1D, numSegments: number): Tensor;
Expand Down
22 changes: 22 additions & 0 deletions src/kernels/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,28 @@ export class MathBackendCPU implements KernelBackend {
return result;
}

prod(x: Tensor, axes: number[]): Tensor {
this.assertNotComplex(x, 'sum');

const [outShape, reduceShape] =
axis_util.computeOutAndReduceShapes(x.shape, axes);
const resultDtype = upcastType(x.dtype, 'int32');
const result = ops.zeros(outShape, resultDtype);
const reduceSize = util.sizeFromShape(reduceShape);
const vals = result.dataSync();

const aVals = x.dataSync();
for (let i = 0; i < vals.length; ++i) {
const offset = i * reduceSize;
let prod = 1;
for (let j = 0; j < reduceSize; ++j) {
prod *= aVals[offset + j];
}
vals[i] = prod;
}
return result;
}

unsortedSegmentSum<T extends Tensor>(
x: T, segmentIds: Tensor1D, numSegments: number): Tensor {
this.assertNotComplex(x, 'unsortedSegmentSum');
Expand Down
11 changes: 10 additions & 1 deletion src/kernels/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ export class MathBackendWebGL implements KernelBackend {
}

private reduce(
x: Tensor2D, reduceType: 'all'|'any'|'max'|'min'|'sum',
x: Tensor2D, reduceType: 'all'|'any'|'max'|'min'|'sum'|'prod',
dtype: DataType): Tensor2D {
const batchSize = x.shape[0];
const inSize = x.shape[1];
Expand Down Expand Up @@ -824,6 +824,15 @@ export class MathBackendWebGL implements KernelBackend {
return this.reduce(a2D, 'sum', outputDType).reshape(outShape);
}

prod(x: Tensor, axes: number[]): Tensor {
const [outShape, reduceShape] =
axis_util.computeOutAndReduceShapes(x.shape, axes);
const inSize = util.sizeFromShape(reduceShape);
const a2D = x.as2D(-1, inSize);
const outputDType = sumOutType(x.dtype);
return this.reduce(a2D, 'prod', outputDType).reshape(outShape);
}

unsortedSegmentSum<T extends Tensor>(
x: T, segmentIds: Tensor1D, numSegments: number): Tensor {
let axis = 0;
Expand Down
13 changes: 11 additions & 2 deletions src/kernels/webgl/reduce_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ export class ReduceProgram implements GPGPUProgram {
userCode: string;

constructor(
reduceInfo: ReduceInfo, reduceType: 'all'|'any'|'max'|'min'|'sum') {
reduceInfo: ReduceInfo,
reduceType: 'all'|'any'|'max'|'min'|'sum'|'prod') {
const windowSize = reduceInfo.windowSize;
const batchSize = reduceInfo.batchSize;
const inSize = reduceInfo.inSize;
Expand All @@ -34,7 +35,9 @@ export class ReduceProgram implements GPGPUProgram {
let initializationValue = '0.0';
let compareOp = ``;

if (reduceType === 'min') {
if (reduceType === 'prod') {
initializationValue = '1.0';
} else if (reduceType === 'min') {
initializationValue = '1.0 / 0.0';
compareOp = `min`;
} else if (reduceType === 'max') {
Expand All @@ -47,6 +50,8 @@ export class ReduceProgram implements GPGPUProgram {

if (reduceType === 'sum') {
returnValue = `sumValue`;
} else if (reduceType === 'prod') {
returnValue = `prodValue`;
} else if (reduceType === 'all') {
returnValue = `allValue`;
} else if (reduceType === 'any') {
Expand All @@ -59,6 +64,9 @@ export class ReduceProgram implements GPGPUProgram {
let updateSnippet = `
if (${reduceType === 'sum'}) {
sumValue += dot(values, ones);
} else if (${reduceType === 'prod'}) {
vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);
prodValue *= tmp[0] * tmp[1];
} else {
minMaxValue = ${compareOp}(values, minMaxValue);
}
Expand Down Expand Up @@ -108,6 +116,7 @@ export class ReduceProgram implements GPGPUProgram {
int inOffset = outIdx * ${windowSize};
vec4 minMaxValue = vec4(${initializationValue});
float prodValue = 1.0;
float sumValue = 0.0;
float allValue = 1.0;
float anyValue = 0.0;
Expand Down
55 changes: 55 additions & 0 deletions src/ops/reduction_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,60 @@ function sum_<T extends Tensor>(
return customOp($x) as T;
}

/**
* Computes the product of elements across dimensions of a `Tensor`.
*
* Reduces the input along the dimensions given in `axes`. Unless `keepDims`
* is true, the rank of the `Tensor` is reduced by 1 for each entry in `axes`.
* If `keepDims` is true, the reduced dimensions are retained with length 1.
* If `axes` has no entries, all dimensions are reduced, and a `Tensor` with a
* single element is returned.
*
* ```js
* const x = tf.tensor1d([1, 2, 3]);
*
* x.prod().print(); // or tf.prod(x)
* ```
*
* ```js
* const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
*
* const axis = 1;
* x.prod(axis).print(); // or tf.prod(x, axis)
* ```
*
* @param x The input tensor to compute the product over. If the dtype is `bool`
* it will be converted to `int32` and the output dtype will be `int32`.
* @param axis The dimension(s) to reduce. By default it reduces
* all dimensions.
* @param keepDims If true, retains reduced dimensions with size 1.
*/
/** @doc {heading: 'Operations', subheading: 'Reduction'} */
function prod_<T extends Tensor>(
x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T {
let $x = convertToTensor(x, 'x', 'prod');

if ($x.dtype === 'bool') {
$x = $x.toInt();
}
const axes = axis_util.parseAxisParam(axis, $x.shape);

const permutation = axis_util.getAxesPermutation(axes, $x.rank);
let reductionAxes = axes;
let permutedX = $x;
if (permutation != null) {
permutedX = $x.transpose(permutation);
reductionAxes = axis_util.getInnerMostAxes(reductionAxes.length, $x.rank);
}
let value = ENV.engine.runKernel(
backend => backend.prod(permutedX, reductionAxes), {permutedX});
if (keepDims) {
const newShape = axis_util.expandShapeToKeepDim(value.shape, axes);
value = value.reshape(newShape);
}

return value as T;
}
/**
* Computes the mean of elements across dimensions of a `Tensor`.
*
Expand Down Expand Up @@ -554,3 +608,4 @@ export const mean = op({mean_});
export const min = op({min_});
export const moments = op({moments_});
export const sum = op({sum_});
export const prod = op({prod_});
99 changes: 99 additions & 0 deletions src/ops/reduction_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,105 @@ describeWithFlags('Reduction: sum', ALL_ENVS, () => {
});
});

describeWithFlags('Reduction: prod', ALL_ENVS, () => {
it('basic', () => {
const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]);
const result = tf.prod(a);
expectNumbersClose(result.get(), 0);
});

it('propagates NaNs', () => {
const a = tf.tensor2d([1, 2, 3, NaN, 0, 1], [3, 2]);
expect(tf.prod(a).get()).toEqual(NaN);
});

it('prod over dtype int32', () => {
const a = tf.tensor1d([1, 5, 7, 3], 'int32');
const prod = tf.prod(a);
expect(prod.get()).toBe(105);
});

it('prod over dtype bool', () => {
const a = tf.tensor1d([true, false, false, true, true], 'bool');
const prod = tf.prod(a);
expect(prod.get()).toBe(0);
});

it('prods all values in 2D array with keep dim', () => {
const a = tf.tensor2d([1, 2, 3, 1, 0, 1], [3, 2]);
const res = tf.prod(a, null, true /* keepDims */);

expect(res.shape).toEqual([1, 1]);
expectArraysClose(res, [0]);
});

it('prods across axis=0 in 2D array', () => {
const a = tf.tensor2d([1, 2, 3, 1, 0, 1], [3, 2]);
const res = tf.prod(a, [0]);

expect(res.shape).toEqual([2]);
expectArraysClose(res, [0, 2]);
});

it('prods across axis=0 in 2D array, keepDims', () => {
const a = tf.tensor2d([1, 2, 3, 1, 0, 1], [3, 2]);
const res = tf.prod(a, [0], true /* keepDims */);

expect(res.shape).toEqual([1, 2]);
expectArraysClose(res, [0, 2]);
});

it('prods across axis=1 in 2D array', () => {
const a = tf.tensor2d([1, 2, 3, 1, 1, 1], [3, 2]);
const res = tf.prod(a, [1]);

expect(res.shape).toEqual([3]);
expectArraysClose(res, [2, 3, 1]);
});

it('2D, axis=1 provided as number', () => {
const a = tf.tensor2d([1, 2, 3, 1, 1, 1], [2, 3]);
const res = tf.prod(a, 1);

expect(res.shape).toEqual([2]);
expectArraysClose(res, [6, 1]);
});

it('2D, axis = -1 provided as number', () => {
const a = tf.tensor2d([1, 2, 3, 1, 1, 1], [2, 3]);
const res = tf.prod(a, -1);

expect(res.shape).toEqual([2]);
expectArraysClose(res, [6, 1]);
});

it('prods across axis=0,1 in 2D array', () => {
const a = tf.tensor2d([1, 2, 3, 1, 1, 1], [3, 2]);
const res = tf.prod(a, [0, 1]);

expect(res.shape).toEqual([]);
expectArraysClose(res, [6]);
});

it('2D, axis=[-1,-2] in 2D array', () => {
const a = tf.tensor2d([1, 2, 3, 1, 1, 1], [3, 2]);
const res = tf.prod(a, [-1, -2]);

expect(res.shape).toEqual([]);
expectArraysClose(res, [6]);
});

it('throws when passed a non-tensor', () => {
expect(() => tf.prod({} as tf.Tensor))
.toThrowError(/Argument 'x' passed to 'prod' must be a Tensor/);
});

it('accepts a tensor-like object', () => {
const result = tf.prod([[1, 2], [3, 1], [1, 1]]);
expectNumbersClose(result.get(), 6);
});
});

describeWithFlags('Reduction: mean', ALL_ENVS, () => {
it('basic', () => {
const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]);
Expand Down
6 changes: 6 additions & 0 deletions src/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ export interface OpHandler {
logSumExp<T extends Tensor>(
x: Tensor, axis: number|number[], keepDims: boolean): T;
sum<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean): T;
prod<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean):
T;
mean<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean):
T;
min<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean): T;
Expand Down Expand Up @@ -769,6 +771,10 @@ export class Tensor<R extends Rank = Rank> {
this.throwIfDisposed();
return opHandler.sum(this, axis, keepDims);
}
prod<T extends Tensor>(axis: number|number[] = null, keepDims = false): T {
this.throwIfDisposed();
return opHandler.prod(this, axis, keepDims);
}
mean<T extends Tensor>(axis: number|number[] = null, keepDims = false): T {
this.throwIfDisposed();
return opHandler.mean(this, axis, keepDims);
Expand Down

0 comments on commit 13601f4

Please sign in to comment.