Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions tfjs-backend-wasm/src/kernels/Split.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

import {NamedAttrMap, NamedTensorInfoMap, registerKernel, SplitV, SplitVAttrs, SplitVInputs, util} from '@tensorflow/tfjs-core';
import {backend_util} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

Expand All @@ -32,14 +33,7 @@ export function split(args: {

const $axis = util.parseAxisParam(axis, x.shape)[0];

let splitSizes: number[];
if (typeof (numOrSizeSplits) === 'number') {
splitSizes =
new Array(numOrSizeSplits).fill(x.shape[$axis] / numOrSizeSplits);
} else {
splitSizes = numOrSizeSplits;
}

const splitSizes = backend_util.prepareSplitSize(x, numOrSizeSplits, axis);
const begin = new Array(x.shape.length).fill(0);
const size = x.shape.slice();
return splitSizes.map(s => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,6 @@ export const executeOp: InternalOpExecutor = (node: Node,
number[];
const tensor = getParamValue('x', node, tensorMap, context) as tfc.Tensor;

// Allow the last number of split array to be -1, which indicates the rest
// of dimension is allocated to the last split.
if (Array.isArray(numOrSizeSplits)) {
const negIndex = numOrSizeSplits.indexOf(-1);
if (negIndex !== -1) {
const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a);
numOrSizeSplits[negIndex] = tensor.shape[axis] - total;
}
}
return tfc.split(tensor, numOrSizeSplits, axis);
}
case 'ScatterNd': {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import * as slice_join from '../op_list/slice_join';
import {Node} from '../types';

import {executeOp} from './slice_join_executor';
import {createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttr, createNumericArrayAttrFromIndex, createTensorAttr, createTensorsAttr, validateParam} from './test_helper';
import {createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, createTensorsAttr, validateParam} from './test_helper';

describe('slice join', () => {
let node: Node;
Expand Down Expand Up @@ -340,19 +340,6 @@ describe('slice join', () => {

expect(tfc.split).toHaveBeenCalledWith(input2[0], 2, 1);
});
it('should support -1 split', () => {
spyOn(tfc, 'split');
node.op = 'Split';
node.inputParams.axis = createNumberAttrFromIndex(0);
node.inputParams.x = createTensorAttr(1);
node.attrParams.numOrSizeSplits = createNumericArrayAttr([1, 1, -1]);
node.inputNames = ['input4', 'input3'];
const input3 = [tfc.tensor1d([1, 2, 3, 4, 5])];
const input4 = [tfc.scalar(0)];
executeOp(node, {input3, input4}, context);

expect(tfc.split).toHaveBeenCalledWith(input3[0], [1, 1, 3], 0);
});
it('should match json def for split', () => {
node.op = 'Split';
node.inputParams.axis = createNumberAttrFromIndex(0);
Expand Down
1 change: 1 addition & 0 deletions tfjs-core/src/backends/backend_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export * from '../ops/fused_util';
export * from '../ops/erf_util';
export * from '../log';
export * from '../backends/complex_util';
export * from '../ops/split_util';

import * as segment_util from '../ops/segment_util';
export {segment_util};
Expand Down
17 changes: 17 additions & 0 deletions tfjs-core/src/ops/array_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1706,7 +1706,24 @@ describeWithFlags('split', ALL_ENVS, () => {
expect(res[2].shape).toEqual([2, 1]);
expectArraysClose(await res[2].data(), [4, 8]);
});
it('should support -1 split', async () => {
const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);
const res = x.split([1, 1, -1], 1);

expect(res.length).toEqual(3);
expect(res[0].shape).toEqual([2, 1]);
expectArraysClose(await res[0].data(), [1, 5]);
expect(res[1].shape).toEqual([2, 1]);
expectArraysClose(await res[1].data(), [2, 6]);
expect(res[2].shape).toEqual([2, 2]);
expectArraysClose(await res[2].data(), [3, 4, 7, 8]);
});

it('multiple negative number throws error', () => {
const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);
const f = () => tf.split(x, [1, -1, -1], 1);
expect(f).toThrowError();
});
it('sizes to not sum to axis size throws error', () => {
const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);
const f = () => tf.split(x, [1, 2], 1);
Expand Down
17 changes: 12 additions & 5 deletions tfjs-core/src/ops/spectral_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,10 @@ function rfft_(input: Tensor, fftLength?: number): Tensor {
function irfft_(input: Tensor): Tensor {
const innerDimensionSize = input.shape[input.shape.length - 1];
const batch = input.size / innerDimensionSize;

let ret: Tensor;
if (innerDimensionSize <= 2) {
const complexInput = input.as2D(batch, innerDimensionSize);
const ret = ifft(complexInput);
return real(ret);
ret = ifft(complexInput);
} else {
// The length of unique components of the DFT of a real-valued signal
// is 2 * (input_len - 1)
Expand All @@ -201,9 +200,17 @@ function irfft_(input: Tensor): Tensor {
const r = realInput.concat(realConjugate, 1);
const i = imagInput.concat(imagConjugate, 1);
const complexInput = complex(r, i).as2D(outputShape[0], outputShape[1]);
const ret = ifft(complexInput);
return real(ret);
ret = ifft(complexInput);
}
ret = real(ret);
// reshape the result if the input is 3D tensor.
if (input.rank === 3 && input.shape[0] !== 0) {
const temp = ret;
const batch = input.shape[0];
ret = ret.reshape([batch, ret.shape[0] / batch, ret.shape[1]]);
temp.dispose();
}
return ret;
}

export const fft = op({fft_});
Expand Down
14 changes: 10 additions & 4 deletions tfjs-core/src/ops/spectral_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,11 @@ describeWithFlags('3D IRFFT', ALL_ENVS, () => {
const t1Real = tf.tensor3d([1, 2, 3, 4, 1, 2, 3, 4], [2, 2, 2]);
const t1Imag = tf.tensor3d([0, 0, 0, 0, 0, 0, 0, 0], [2, 2, 2]);
const t1 = tf.complex(t1Real, t1Imag);
const result = tf.spectral.irfft(t1);
expect(result.shape).toEqual([2, 2, 2]);

expectArraysClose(
await tf.spectral.irfft(t1).data(),
[1.5, -0.5, 3.5, -0.5, 1.5, -0.5, 3.5, -0.5]);
await result.data(), [1.5, -0.5, 3.5, -0.5, 1.5, -0.5, 3.5, -0.5]);
});

it('should return the same value with TensorFlow (2x2x3 elements)',
Expand All @@ -365,7 +367,9 @@ describeWithFlags('3D IRFFT', ALL_ENVS, () => {
const t1Imag =
tf.tensor3d([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2, 2, 3]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(await tf.spectral.irfft(t1).data(), [
const result = tf.spectral.irfft(t1);
expect(result.shape).toEqual([2, 2, 4]);
expectArraysClose(await result.data(), [
2, -0.5, 0, -0.5, 5, -0.5, 0, -0.5, 2, -0.5, 0, -0.5, 5, -0.5, 0, -0.5
]);
});
Expand All @@ -377,8 +381,10 @@ describeWithFlags('3D IRFFT', ALL_ENVS, () => {
const t1Imag =
tf.tensor3d([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6], [2, 2, 3]);
const t1 = tf.complex(t1Real, t1Imag);
const result = tf.spectral.irfft(t1);
expect(result.shape).toEqual([2, 2, 4]);
expectArraysClose(
await tf.spectral.irfft(t1).data(),
await result.data(),
[2, -1.5, 0, 0.5, 5, -3, 0, 2, 2, -1.5, 0, 0.5, 5, -3, 0, 2]);
});
});
21 changes: 4 additions & 17 deletions tfjs-core/src/ops/split.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import {assert,} from '../util';
import {parseAxisParam} from '../util';

import {op} from './operation';
import {prepareSplitSize} from './split_util';

/**
* Splits a `tf.Tensor` into sub tensors.
Expand Down Expand Up @@ -55,6 +55,7 @@ import {op} from './operation';
* splits along the axis or an array of integers containing the sizes of
* each output tensor along the axis. If a number then it must evenly divide
* `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`.
* Can contain one -1 indicating that dimension is to be inferred.
* @param axis The dimension along which to split. Defaults to 0 (the first
* dim).
*/
Expand All @@ -63,23 +64,9 @@ function split_<T extends Tensor>(
x: Tensor|TensorLike, numOrSizeSplits: number[]|number, axis = 0): T[] {
const $x = convertToTensor(x, 'x', 'split');

const $axis = parseAxisParam(axis, $x.shape)[0];
let splitSizes: number[];

if (typeof (numOrSizeSplits) === 'number') {
assert(
$x.shape[$axis] % numOrSizeSplits === 0,
() => 'Number of splits must evenly divide the axis.');
splitSizes =
new Array(numOrSizeSplits).fill($x.shape[$axis] / numOrSizeSplits);
} else {
assert(
$x.shape[$axis] === numOrSizeSplits.reduce((a, b) => a + b),
() => 'The sum of sizes must match the size of the axis dimension.');
splitSizes = numOrSizeSplits;
}

const forward: ForwardFunc<Tensor> = (backend, _) => {
const $axis = parseAxisParam(axis, $x.shape)[0];
const splitSizes = prepareSplitSize($x, numOrSizeSplits, $axis);
return backend.split($x, splitSizes, $axis) as {} as T;
};

Expand Down
60 changes: 60 additions & 0 deletions tfjs-core/src/ops/split_util.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {TensorInfo} from '../kernel_registry';
import {Tensor} from '../tensor';
import {assert} from '../util';

/**
* Prepare the split size array. When the input is a number, the axis is evenly
* divided among the split size. When the input contains the negative value, the
* rest of the axis is allocated toward that.
*/
export function prepareSplitSize(
x: Tensor|TensorInfo, numOrSizeSplits: number[]|number,
axis = 0): number[] {
let splitSizes = [];
if (typeof (numOrSizeSplits) === 'number') {
assert(
x.shape[axis] % numOrSizeSplits === 0,
() => 'Number of splits must evenly divide the axis.');
splitSizes =
new Array(numOrSizeSplits).fill(x.shape[axis] / numOrSizeSplits);
} else {
const numOfNegs = numOrSizeSplits.reduce((count, value) => {
if (value === -1) {
count += 1;
}
return count;
}, 0);
assert(
numOfNegs <= 1,
() => 'There should be only one negative value in split array.');
const negIndex = numOrSizeSplits.indexOf(-1);
// Allow the number of split array to be -1, which indicates the rest
// of dimension is allocated to that split.
if (negIndex !== -1) {
const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a);
numOrSizeSplits[negIndex] = x.shape[axis] - total;
}
assert(
x.shape[axis] === numOrSizeSplits.reduce((a, b) => a + b),
() => 'The sum of sizes must match the size of the axis dimension.');
splitSizes = numOrSizeSplits;
}

return splitSizes;
}
39 changes: 28 additions & 11 deletions tfjs-core/src/ops/where.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
* =============================================================================
*/

import {ENGINE} from '../engine';
import {ENGINE, ForwardFunc} from '../engine';
import {SelectV2, SelectV2Inputs} from '../kernel_names';
import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import {assert, assertShapesMatch} from '../util';

import {assertAndGetBroadcastShape} from './broadcast_util';
import {op} from './operation';

/**
Expand All @@ -41,34 +42,50 @@ import {op} from './operation';
* @param condition The input condition. Must be of dtype bool.
* @param a If `condition` is rank 1, `a` may have a higher rank but
* its first dimension must match the size of `condition`.
* @param b A tensor with the same shape and type as `a`.
* @param b A tensor with the same dtype as `a` and with shape that is
* compatible with `a`.
* @return A tensor with same dtype as `a` and `b`, and shape that is
* broadcastable from `a` and `b`.
*/
/** @doc {heading: 'Operations', subheading: 'Logical'} */
function where_<T extends Tensor>(
condition: Tensor|TensorLike, a: T|TensorLike, b: T|TensorLike): T {
const $a = convertToTensor(a, 'a', 'where');
const $b = convertToTensor(b, 'b', 'where');
const $condition = convertToTensor(condition, 'condition', 'where', 'bool');

assertShapesMatch($a.shape, $b.shape, 'Error in where: ');

// TODO: move this logic to forward function when the broadcastTo op is
// implemented in WASM.
// Find the broadcastable shape for $a and $b.
const broadcastShape = assertAndGetBroadcastShape($a.shape, $b.shape);
const $broadcastedA = $a.broadcastTo(broadcastShape);
const $broadcastedB = $b.broadcastTo(broadcastShape);
if ($condition.rank === 1) {
// If condition rank is 1, then the first dimension must match the size of
// condition.
assert(
$condition.shape[0] === $a.shape[0],
() => 'The first dimension of `a` must match the size of `condition`.');
} else {
}

if ($condition.rank !== 1) {
// A must have the same shape as condition.
assertShapesMatch($condition.shape, $b.shape, 'Error in where: ');
assertShapesMatch(
$condition.shape, $broadcastedB.shape, 'Error in where: ');
}

const inputs: SelectV2Inputs = {condition: $condition, t: $a, e: $b};
return ENGINE.runKernelFunc((backend, save) => {
const res = backend.select($condition, $a, $b);
const forward: ForwardFunc<Tensor> = (backend, save) => {
const res = backend.select($condition, $broadcastedA, $broadcastedB);
save([$condition]);
return res;
}, inputs as unknown as NamedTensorMap, null /* gradient */, SelectV2) as T;
};
const inputs: SelectV2Inputs = {
condition: $condition,
t: $broadcastedA,
e: $broadcastedB
};
return ENGINE.runKernelFunc(
forward, inputs as unknown as NamedTensorMap, null /* gradient */,
SelectV2) as T;
}

export const where = op({where_});
19 changes: 18 additions & 1 deletion tfjs-core/src/ops/where_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ describeWithFlags('where', ALL_ENVS, () => {

it('Tensor4D different a/b shapes', () => {
const c = tf.tensor4d([1, 0, 1, 1], [2, 2, 1, 1], 'bool');
let a = tf.tensor4d([7, 7, 7, 7, 7, 7, 7, 7], [2, 2, 2, 1]);
let a = tf.tensor4d([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7], [2, 3, 2, 1]);
let b = tf.tensor4d([3, 3, 3, 3], [2, 2, 1, 1]);
let f = () => {
tf.where(c, a, b);
Expand All @@ -227,6 +227,23 @@ describeWithFlags('where', ALL_ENVS, () => {
expect(f).toThrowError();
});

it('Tensor4D broadcastable a/b shapes', () => {
const c = tf.tensor4d([1, 0, 1, 1], [2, 2, 1, 1], 'bool');
const a = tf.tensor4d([7, 7, 7, 7, 7, 7, 7, 7], [2, 2, 2, 1]);
const b = [3];
let f = () => {
tf.where(c, a, b);
};
expect(f).toThrowError();

const a1 = [7];
const b1 = tf.tensor4d([3, 3, 3, 3, 3, 3, 3, 3], [2, 2, 2, 1]);
f = () => {
tf.where(c, a1, b1);
};
expect(f).toThrowError();
});

it('Tensor4D different condition/a shapes', () => {
const c = tf.tensor4d([1, 0, 1, 1, 1, 0, 1, 1], [2, 2, 2, 1], 'bool');
const a = tf.tensor4d([7, 7, 7, 7], [2, 2, 1, 1]);
Expand Down