Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions tfjs-backend-wasm/src/kernels/Slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {backend_util, buffer, NamedAttrMap, NamedTensorInfoMap, registerKernel, slice_util, util} from '@tensorflow/tfjs-core';
import {backend_util, buffer, NamedAttrMap, NamedTensorInfoMap, registerKernel, slice_util, Tensor, util} from '@tensorflow/tfjs-core';
import {TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';
Expand All @@ -32,33 +32,34 @@ interface SliceAttrs extends NamedAttrMap {
export function slice(
args: {inputs: SliceInputs, attrs: SliceAttrs, backend: BackendWasm}) {
const {inputs: {x}, attrs: {begin, size}, backend} = args;
const isContinous = slice_util.isSliceContinous(x.shape, begin, size);
const [begin_, size_] = slice_util.parseSliceParams(x as Tensor, begin, size);
const isContinous = slice_util.isSliceContinous(x.shape, begin_, size_);
const xVals = backend.typedArrayFromHeap(x);
const out = backend.makeOutput(size, x.dtype);
const out = backend.makeOutput(size_, x.dtype);
const outVals = backend.typedArrayFromHeap(out);
const xStrides = util.computeStrides(x.shape);
if (isContinous) {
const flatOffset = slice_util.computeFlatOffset(begin, xStrides);
const flatOffset = slice_util.computeFlatOffset(begin_, xStrides);
outVals.set(
xVals.subarray(flatOffset, flatOffset + util.sizeFromShape(size)));
xVals.subarray(flatOffset, flatOffset + util.sizeFromShape(size_)));
return out;
}
const rank = x.shape.length;
if (rank === 2) {
slice2d(
xVals, xStrides[0], outVals, begin as [number, number],
size as [number, number]);
xVals, xStrides[0], outVals, begin_ as [number, number],
size_ as [number, number]);
} else if (rank === 3) {
slice3d(
xVals, xStrides[0], xStrides[1], outVals,
begin as [number, number, number], size as [number, number, number]);
begin_ as [number, number, number], size_ as [number, number, number]);
} else if (rank === 4) {
slice4d(
xVals, xStrides[0], xStrides[1], xStrides[2], outVals,
begin as [number, number, number, number],
size as [number, number, number, number]);
begin_ as [number, number, number, number],
size_ as [number, number, number, number]);
} else {
genericSliceSlow(xVals, x, outVals, begin, size);
genericSliceSlow(xVals, x, outVals, begin_, size_);
}
return out;
}
Expand Down
7 changes: 0 additions & 7 deletions tfjs-core/scripts/extract_op.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import * as argparse from 'argparse';
import {execSync} from 'child_process';
import * as path from 'path';
import {FunctionDeclaration, ImportDeclaration, Project, SourceFile, VariableStatement} from 'ts-morph';

Expand Down Expand Up @@ -85,12 +84,6 @@ async function moveToNewFile(
newOpFile.fixUnusedIdentifiers();
await newOpFile.save();

// make a test file
// create a test file
const testFilePath = `src/ops/${args.op}_test.ts`;
const command = `touch ${testFilePath}`;
execSync(command);

// Add export to ops file.
const opsExportFile = project.getSourceFile('src/ops/ops.ts');
opsExportFile.addExportDeclaration({
Expand Down
46 changes: 46 additions & 0 deletions tfjs-core/src/gradients/Slice_grad.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Slice, SliceAttrs} from '../kernel_names';
import {GradConfig, NamedAttrMap} from '../kernel_registry';
import {pad} from '../ops/pad';
import {parseSliceParams} from '../ops/slice_util';
import {Tensor} from '../tensor';

export const sliceGradConfig: GradConfig = {
kernelName: Slice,
inputsToSave: ['x'],
gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {
const [x] = saved;
const {begin, size} = attrs as {} as SliceAttrs;

const inputShape = x.shape;
const [begin_, size_] = parseSliceParams(x, begin, size);

// Create an Nx2 padding where the first column represents how many
// zeros are prepended (at start) for each dimension, and the second
// column indicates how many zeros are appended (at end).

// The number of zeros to append is the shape of the input
// elementwise-subtracted by both the begin vector and sizes vector.
const paddings: Array<[number, number]> = [];
for (let i = 0; i < dy.rank; i++) {
paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]);
}
return {x: () => pad(dy, paddings)};
}
};
6 changes: 6 additions & 0 deletions tfjs-core/src/kernel_names.ts
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,12 @@ export type SelectV2Inputs = Pick<NamedTensorInfoMap, 'condition'|'t'|'e'>;
export const Selu = 'Selu';
export type SeluInputs = Pick<NamedTensorInfoMap, 'x'>;

export const Slice = 'Slice';
export type SliceInputs = Pick<NamedTensorInfoMap, 'x'>;
export interface SliceAttrs {
begin: number|number[];
size: number|number[];
}
export const Sin = 'Sin';
export type SinInputs = UnaryInputs;

Expand Down
6 changes: 5 additions & 1 deletion tfjs-core/src/ops/ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ export {separableConv2d} from './separable_conv2d';
export {sign} from './sign';
export {sin} from './sin';
export {sinh} from './sinh';
export {slice} from './slice';
export {slice1d} from './slice1d';
export {slice2d} from './slice2d';
export {slice3d} from './slice3d';
export {slice4d} from './slice4d';
export {spaceToBatchND} from './space_to_batch_nd';
export {split} from './split';
export {square} from './square';
Expand All @@ -145,7 +150,6 @@ export {where} from './where';
export {whereAsync} from './where_async';

export * from './boolean_mask';
export * from './slice';
export * from './unary_ops';
export * from './compare';
export * from './binary_ops';
Expand Down
131 changes: 15 additions & 116 deletions tfjs-core/src/ops/slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,75 +15,17 @@
* =============================================================================
*/

import {ENGINE} from '../engine';
import {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';
import {ENGINE, ForwardFunc} from '../engine';
import {Slice, SliceAttrs, SliceInputs} from '../kernel_names';
import {NamedAttrMap} from '../kernel_registry';
import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {convertToTensor} from '../tensor_util_env';
import {Rank, TensorLike} from '../types';
import * as util from '../util';

import {op} from './operation';
import {pad} from './pad';
import * as slice_util from './slice_util';

/**
* Extracts a 1D slice from 1D array starting at coordinates `begin` and is
* of length `size`. See `slice` for details.
*/
function slice1d_(
x: Tensor1D|TensorLike, begin: number, size: number): Tensor1D {
const $x = convertToTensor(x, 'x', 'slice1d');
util.assert(
$x.rank === 1,
() =>
`slice1d expects a rank-1 tensor, but got a rank-${$x.rank} tensor`);
return slice($x, [begin], [size]);
}

/**
* Extracts a 2D slice from a 2D array starting at coordinates `begin` and
* is of size `size`. See `slice` for details.
*/
function slice2d_(
x: Tensor2D|TensorLike, begin: [number, number],
size: [number, number]): Tensor2D {
const $x = convertToTensor(x, 'x', 'slice2d');
util.assert(
$x.rank === 2,
() =>
`slice2d expects a rank-2 tensor, but got a rank-${$x.rank} tensor`);
return slice($x, begin, size);
}

/**
* Extracts a 3D slice from a 3D array starting at coordinates `begin` and
* is of size `size`. See `slice` for details.
*/
function slice3d_(
x: Tensor3D|TensorLike, begin: [number, number, number],
size: [number, number, number]): Tensor3D {
const $x = convertToTensor(x, 'x', 'slice3d');
util.assert(
$x.rank === 3,
() =>
`slice3d expects a rank-3 tensor, but got a rank-${$x.rank} tensor`);
return slice($x, begin, size);
}

/**
* Extracts a 4D slice from a 4D array starting at coordinates `begin` and
* is of size `size`. See `slice` for details.
*/
function slice4d_(
x: Tensor4D|TensorLike, begin: [number, number, number, number],
size: [number, number, number, number]): Tensor4D {
const $x = convertToTensor(x, 'x', 'slice4d');
util.assert(
$x.rank === 4,
() =>
`slice4d expects a rank-4 tensor, but got a rank-${$x.rank} tensor`);
return slice($x, begin, size);
}

/**
* Extracts a slice from a `tf.Tensor` starting at coordinates `begin`
* and is of size `size`.
Expand Down Expand Up @@ -124,63 +66,20 @@ function slice_<R extends Rank, T extends Tensor<R>>(
if ($x.rank === 0) {
throw new Error('Slicing scalar is not possible');
}
// The following logic allows for more ergonomic calls.
let begin_: number[];
if (typeof begin === 'number') {
begin_ = [begin, ...new Array($x.rank - 1).fill(0)];
} else if (begin.length < $x.rank) {
begin_ = begin.concat(new Array($x.rank - begin.length).fill(0));
} else {
begin_ = begin.slice();
}
begin_.forEach(d => {
util.assert(
d !== -1, () => 'slice() does not support negative begin indexing.');
});
let size_: number[];
if (size == null) {
size_ = new Array($x.rank).fill(-1);
} else if (typeof size === 'number') {
size_ = [size, ...new Array($x.rank - 1).fill(-1)];
} else if (size.length < $x.rank) {
size_ = size.concat(new Array($x.rank - size.length).fill(-1));
} else {
size_ = size;
}
size_ = size_.map((d, i) => {
if (d >= 0) {
return d;
} else {
util.assert(
d === -1,
() => `Negative size values should be exactly -1 but got ` +
`${d} for the slice() size at index ${i}.`);
return $x.shape[i] - begin_[i];
}
});
const [begin_, size_] = slice_util.parseSliceParams($x, begin, size);
slice_util.assertParamsValid($x, begin_, size_);
const inputShape = $x.shape;
const grad = (dy: T) => {
// Create an Nx2 padding where the first column represents how many
// zeros are prepended (at start) for each dimension, and the second
// column indicates how many zeros are appended (at end).

// The number of zeros to append is the shape of the input
// elementwise-subtracted by both the begin vector and sizes vector.
const paddings: Array<[number, number]> = [];
for (let i = 0; i < dy.rank; i++) {
paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]);
}
return {x: () => pad(dy, paddings)};
const forward: ForwardFunc<Tensor> = (backend, save) => {
save([$x]);
return backend.slice($x, begin_, size_);
};
const attrs = {begin: begin_, size: size_};

const inputs: SliceInputs = {x: $x};
const attrs: SliceAttrs = {begin, size};

return ENGINE.runKernelFunc(
backend => backend.slice($x, begin_, size_), {x: $x}, grad, 'Slice',
attrs);
forward, inputs as {} as NamedTensorMap, null /* grad */, Slice,
attrs as {} as NamedAttrMap) as T;
}

export const slice = op({slice_});
export const slice1d = op({slice1d_});
export const slice2d = op({slice2d_});
export const slice3d = op({slice3d_});
export const slice4d = op({slice4d_});
39 changes: 39 additions & 0 deletions tfjs-core/src/ops/slice1d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Tensor1D} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';

import {op} from './operation';
import {slice} from './slice';

/**
* Extracts a 1D slice from 1D array starting at coordinates `begin` and is
* of length `size`. See `slice` for details.
*/
function slice1d_(
x: Tensor1D|TensorLike, begin: number, size: number): Tensor1D {
const $x = convertToTensor(x, 'x', 'slice1d');
util.assert(
$x.rank === 1,
() =>
`slice1d expects a rank-1 tensor, but got a rank-${$x.rank} tensor`);
return slice($x, [begin], [size]);
}
export const slice1d = op({slice1d_});
Loading