Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions tfjs-backend-cpu/src/kernels/Max.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export const maxConfig: KernelConfig = {
backendName: 'cpu',
kernelFunc: ({inputs, attrs, backend}) => {
const {x} = inputs as MaxInputs;
const {reductionIndices} = attrs as {} as MaxAttrs;
const {reductionIndices, keepDims} = attrs as {} as MaxAttrs;
const cpuBackend = backend as MathBackendCPU;
let xShape = x.shape;
const xRank = xShape.length;
Expand Down Expand Up @@ -60,6 +60,14 @@ export const maxConfig: KernelConfig = {

const result = maxImpl(xVals, reduceSize, maxOutShape, x.dtype);
const dataId = cpuBackend.write(result, maxOutShape, x.dtype);
return {dataId, shape: maxOutShape, dtype: x.dtype};

let outShape = maxOutShape;
if (keepDims) {
// reshape
const newShape = backend_util.expandShapeToKeepDim(maxOutShape, origAxes);
outShape = newShape;
}

return {dataId, shape: outShape, dtype: x.dtype};
}
};
47 changes: 32 additions & 15 deletions tfjs-backend-wasm/src/kernels/ArgMax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,13 @@
* =============================================================================
*/

import {KernelFunc, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core';
import {ArgMax, ArgMaxAttrs, ArgMaxInputs, KernelFunc, registerKernel, util} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

import {permuteAxesAndTranspose} from './kernel_utils';
import {CppDType} from './types';

interface ArgMaxInputs {
x: TensorInfo;
}

interface ArgMaxAttrs {
axis: number;
}

let wasmFunc: (
xId: number, dtype: number, outerSize: number, innerSize: number,
outId: number) => void;
Expand All @@ -45,19 +38,43 @@ function setup(backend: BackendWasm) {

function argmax(
args: {inputs: ArgMaxInputs, backend: BackendWasm, attrs: ArgMaxAttrs}) {
const {inputs: {x}, backend, attrs: {axis}} = args;
const outShape = x.shape.slice(0, -1);
const out = backend.makeOutput(outShape, 'int32');
const {backend, inputs, attrs} = args;
const {axis} = attrs as {} as ArgMaxAttrs;
const {x} = inputs as {} as ArgMaxInputs;
const xId = backend.dataIdMap.get(x.dataId).id;
let inputId = xId;
let input = x;

const {transposed, axes, inputWasTransposed} =
permuteAxesAndTranspose(x, axis, backend);

if (inputWasTransposed) {
const transposedId = backend.dataIdMap.get(transposed.dataId).id;
if (transposedId !== xId) {
// transpose was not a no-op. We will need to dispose of this
// once we are done.
input = transposed;
inputId = transposedId;
}
}

const outShape = input.shape.slice(0, -1);
const out = backend.makeOutput(outShape, 'int32');
const outId = backend.dataIdMap.get(out.dataId).id;
const outerSize = util.sizeFromShape(out.shape);
const innerSize = x.shape[axis];
wasmFunc(xId, CppDType[x.dtype], outerSize, innerSize, outId);
const innerSize = input.shape[axes[0]];
wasmFunc(inputId, CppDType[input.dtype], outerSize, innerSize, outId);

if (inputWasTransposed) {
// dispose of the transposed tensor.
backend.disposeData(transposed.dataId);
}

return out;
}

registerKernel({
kernelName: 'ArgMax',
kernelName: ArgMax,
backendName: 'wasm',
kernelFunc: argmax as {} as KernelFunc,
setupFunc: setup
Expand Down
79 changes: 40 additions & 39 deletions tfjs-backend-wasm/src/kernels/Max.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
* =============================================================================
*/

import {backend_util, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core';
import {backend_util, KernelFunc, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core';
import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

import {transpose} from './Transpose';
import {permuteAxesAndTranspose} from './kernel_utils';

let wasmMax: (xId: number, reduceSize: number, outId: number) => void;

Expand All @@ -29,56 +29,57 @@ function setup(backend: BackendWasm): void {
backend.wasm.cwrap('Max', null /*void*/, ['number, number, number']);
}

function max(args: {backend: BackendWasm, inputs: {}, attrs: {}}): TensorInfo {
function max(args: {backend: BackendWasm, inputs: MaxInputs, attrs: MaxAttrs}):
TensorInfo {
const {backend, inputs, attrs} = args;
const {reductionIndices} = attrs as MaxAttrs;
const {x} = inputs as MaxInputs;
const {reductionIndices: axis, keepDims} = attrs;
const {x} = inputs;
const xId = backend.dataIdMap.get(x.dataId).id;

let xShape = x.shape;
const xRank = x.shape.length;
const xVals = backend.typedArrayFromHeap(x);

const origAxes = util.parseAxisParam(reductionIndices, xShape);
let axes = origAxes;
const permutedAxes = backend_util.getAxesPermutation(axes, xRank);
const maxInputIsTransposed = permutedAxes != null;
if (maxInputIsTransposed) {
const newShape: number[] = new Array(xRank);
for (let i = 0; i < newShape.length; i++) {
newShape[i] = xShape[permutedAxes[i]];
}

axes = backend_util.getInnerMostAxes(axes.length, xRank);

const xTransposed =
transpose({inputs: {x}, attrs: {perm: permutedAxes}, backend});

if (backend.dataIdMap.get(xTransposed.dataId).id !== xId) {
// If perm is not no-op.
const xTransposedVals = backend.typedArrayFromHeap(xTransposed);
xVals.set(xTransposedVals, 0);
backend.disposeData(xTransposed.dataId);
let inputId = xId;
let input = x;

const {transposed, axes, originalAxes, inputWasTransposed} =
permuteAxesAndTranspose(x, axis, backend);

if (inputWasTransposed) {
const transposedId = backend.dataIdMap.get(transposed.dataId).id;
if (transposedId !== xId) {
// transpose was not a no-op. We will need to dispose of this
// once we are done.
input = transposed;
inputId = transposedId;
}
xShape = newShape;
}

backend_util.assertAxesAreInnerMostDims('max', axes, xRank);
const inputRank = input.shape.length;
backend_util.assertAxesAreInnerMostDims('max', axes, inputRank);
const [outShape, reduceShape] =
backend_util.computeOutAndReduceShapes(xShape, axes);
backend_util.computeOutAndReduceShapes(input.shape, axes);
const reduceSize = util.sizeFromShape(reduceShape);

const out = backend.makeOutput(outShape, x.dtype);
if (util.sizeFromShape(xShape) === 0) {
return out;
if (util.sizeFromShape(input.shape) !== 0) {
const outId = backend.dataIdMap.get(out.dataId).id;
wasmMax(inputId, reduceSize, outId);
}

const outId = backend.dataIdMap.get(out.dataId).id;
if (inputWasTransposed) {
// dispose of the transposed tensor.
backend.disposeData(transposed.dataId);
}

wasmMax(xId, reduceSize, outId);
if (keepDims) {
// reshape
const newShape = backend_util.expandShapeToKeepDim(out.shape, originalAxes);
out.shape = newShape;
}

return out;
}

registerKernel(
{kernelName: Max, backendName: 'wasm', setupFunc: setup, kernelFunc: max});
registerKernel({
kernelName: Max,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: max as {} as KernelFunc
});
56 changes: 38 additions & 18 deletions tfjs-backend-wasm/src/kernels/Min.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,11 @@
* =============================================================================
*/

import {backend_util, NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core';
import {backend_util, KernelFunc, Min, MinAttrs, MinInputs, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

interface MinInputs extends NamedTensorInfoMap {
x: TensorInfo;
}

interface MinAttrs extends NamedAttrMap {
axes: number[];
}
import {permuteAxesAndTranspose} from './kernel_utils';

let wasmMin: (xId: number, reduceSize: number, outId: number) => void;

Expand All @@ -37,29 +31,55 @@ function setup(backend: BackendWasm): void {
function min(args: {backend: BackendWasm, inputs: MinInputs, attrs: MinAttrs}):
TensorInfo {
const {backend, inputs, attrs} = args;
const {axes} = attrs;
const {axis, keepDims} = attrs;
const {x} = inputs;
const xId = backend.dataIdMap.get(x.dataId).id;
let inputId = xId;
let input = x;

const {transposed, axes, originalAxes, inputWasTransposed} =
permuteAxesAndTranspose(x, axis, backend);

if (inputWasTransposed) {
const transposedId = backend.dataIdMap.get(transposed.dataId).id;
if (transposedId !== xId) {
// transpose was not a no-op. We will need to dispose of this
// once we are done.
input = transposed;
inputId = transposedId;
}
}

const inputRank = input.shape.length;

backend_util.assertAxesAreInnerMostDims('min', axes, x.shape.length);
backend_util.assertAxesAreInnerMostDims('min', axes, inputRank);
const [outShape, reduceShape] =
backend_util.computeOutAndReduceShapes(x.shape, axes);
backend_util.computeOutAndReduceShapes(input.shape, axes);
const reduceSize = util.sizeFromShape(reduceShape);

const out = backend.makeOutput(outShape, x.dtype);
if (util.sizeFromShape(x.shape) === 0) {
return out;
const out = backend.makeOutput(outShape, input.dtype);
if (util.sizeFromShape(input.shape) !== 0) {
const outId = backend.dataIdMap.get(out.dataId).id;
wasmMin(inputId, reduceSize, outId);
}

const outId = backend.dataIdMap.get(out.dataId).id;
if (inputWasTransposed) {
// dispose of the transposed tensor.
backend.disposeData(transposed.dataId);
}

if (keepDims) {
// reshape
const newShape = backend_util.expandShapeToKeepDim(out.shape, originalAxes);
out.shape = newShape;
}

wasmMin(xId, reduceSize, outId);
return out;
}

registerKernel({
kernelName: 'Min',
kernelName: Min,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: min
kernelFunc: min as {} as KernelFunc
});
58 changes: 40 additions & 18 deletions tfjs-backend-wasm/src/kernels/Sum.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,11 @@
* =============================================================================
*/

import {backend_util, NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core';
import {backend_util, KernelFunc, registerKernel, Sum, SumAttrs, SumInputs, TensorInfo, util} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

interface SumInputs extends NamedTensorInfoMap {
x: TensorInfo;
}

interface SumAttrs extends NamedAttrMap {
axes: number[];
}
import {permuteAxesAndTranspose} from './kernel_utils';

let wasmSum: (xId: number, reduceSize: number, outId: number) => void;

Expand All @@ -37,29 +31,57 @@ function setup(backend: BackendWasm): void {
function sum(args: {backend: BackendWasm, inputs: SumInputs, attrs: SumAttrs}):
TensorInfo {
const {backend, inputs, attrs} = args;
const {axes} = attrs;
const {axis, keepDims} = attrs;
const {x} = inputs;
const xId = backend.dataIdMap.get(x.dataId).id;
let inputId = xId;
let input = x;

backend_util.assertAxesAreInnerMostDims('sum', axes, x.shape.length);
const {transposed, axes, originalAxes, inputWasTransposed} =
permuteAxesAndTranspose(x, axis, backend);

let reductionAxes = axes;
if (inputWasTransposed) {
const transposedId = backend.dataIdMap.get(transposed.dataId).id;
if (transposedId !== xId) {
// transpose was not a no-op. We will need to dispose of this
// once we are done.
input = transposed;
inputId = transposedId;
reductionAxes = backend_util.getInnerMostAxes(
reductionAxes.length, input.shape.length);
}
}

backend_util.assertAxesAreInnerMostDims(
'sum', reductionAxes, input.shape.length);
const [outShape, reduceShape] =
backend_util.computeOutAndReduceShapes(x.shape, axes);
backend_util.computeOutAndReduceShapes(input.shape, reductionAxes);
const reduceSize = util.sizeFromShape(reduceShape);

const out = backend.makeOutput(outShape, x.dtype);
if (util.sizeFromShape(x.shape) === 0) {
return out;
const out = backend.makeOutput(outShape, input.dtype);
if (util.sizeFromShape(input.shape) !== 0) {
const outId = backend.dataIdMap.get(out.dataId).id;
wasmSum(inputId, reduceSize, outId);
}

if (inputWasTransposed) {
// dispose of the transposed tensor.
backend.disposeData(transposed.dataId);
}

const outId = backend.dataIdMap.get(out.dataId).id;
if (keepDims) {
// reshape
const newShape = backend_util.expandShapeToKeepDim(out.shape, originalAxes);
out.shape = newShape;
}

wasmSum(xId, reduceSize, outId);
return out;
}

registerKernel({
kernelName: 'Sum',
kernelName: Sum,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: sum
kernelFunc: sum as {} as KernelFunc
});
Loading