Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
85 commits
Select commit Hold shift + click to select a range
5dd14db
initial
annxingyuan Mar 24, 2020
08260bb
merge
annxingyuan Mar 24, 2020
bb9543d
save xorig
annxingyuan Mar 24, 2020
fe39bb5
save output
annxingyuan Mar 25, 2020
e686064
max
annxingyuan Mar 25, 2020
70caa0d
separate out
annxingyuan Mar 25, 2020
a97c5fb
remove max
annxingyuan Mar 25, 2020
4e9ae96
cpu forward
annxingyuan Mar 25, 2020
2173ab2
create tensor info
annxingyuan Mar 25, 2020
face1b2
simplify
annxingyuan Mar 25, 2020
5ff78b4
split out impl
annxingyuan Mar 25, 2020
3a34f18
unchain
annxingyuan Mar 25, 2020
fcf1de2
merge
annxingyuan Mar 25, 2020
c16de99
return out directly
annxingyuan Mar 25, 2020
4a9f314
merge
annxingyuan Mar 27, 2020
7abe395
rename
annxingyuan Mar 27, 2020
8ead8d7
kernelize
annxingyuan Mar 27, 2020
cc7685b
wip
annxingyuan Mar 27, 2020
80b28e8
merge
annxingyuan Apr 22, 2020
f9c0df3
fix
annxingyuan Apr 22, 2020
f93656e
webgl build
annxingyuan Apr 22, 2020
7b14faf
fix
annxingyuan Apr 22, 2020
9020fed
max
annxingyuan Apr 22, 2020
17a46a2
add more logs
annxingyuan Apr 22, 2020
0e3ea93
add logs
annxingyuan Apr 22, 2020
e0c834b
logs
annxingyuan Apr 22, 2020
3a3622c
logs
annxingyuan Apr 22, 2020
10b8cb9
rem logs
annxingyuan Apr 22, 2020
e38df55
logs
annxingyuan Apr 22, 2020
04de4ce
remove hacks
annxingyuan Apr 22, 2020
61ff6ae
Merge branch 'master' into modular_max
annxingyuan Apr 23, 2020
0e2cb52
add logs
annxingyuan Apr 23, 2020
f068e52
add logs
annxingyuan Apr 23, 2020
4262572
add transpose to grad
annxingyuan Apr 23, 2020
c864263
add condition
annxingyuan Apr 23, 2020
2f941f6
passes
annxingyuan Apr 23, 2020
36cd5ec
reove logs
annxingyuan Apr 23, 2020
87b5dbc
add failing tests
annxingyuan Apr 23, 2020
9e805c9
Merge branch 'master' into modular_max
annxingyuan Apr 27, 2020
018bb6e
dispose
annxingyuan Apr 27, 2020
0227876
remove fit
annxingyuan Apr 27, 2020
cef4a01
fix wasm
annxingyuan Apr 27, 2020
7ce546c
undo
annxingyuan Apr 28, 2020
c731742
merge
annxingyuan Apr 28, 2020
5f687ac
change api
annxingyuan Apr 28, 2020
3f086ea
fix
annxingyuan Apr 28, 2020
c3f2394
rm dist
annxingyuan Apr 28, 2020
7d85575
export shared
annxingyuan Apr 29, 2020
3008a31
cup forward
annxingyuan Apr 29, 2020
7106795
import max
annxingyuan Apr 29, 2020
06ff0a2
remove logs
annxingyuan Apr 29, 2020
08abe3f
remove log
annxingyuan Apr 29, 2020
714742c
extract
annxingyuan Apr 29, 2020
9d0f05b
merge
annxingyuan Apr 29, 2020
1837abe
lint
annxingyuan Apr 29, 2020
3badd01
merge
annxingyuan May 4, 2020
efef5c8
pr comments
annxingyuan May 4, 2020
15e21db
pr comments
annxingyuan May 4, 2020
82fa093
pr comments
annxingyuan May 4, 2020
a0d5ffd
reorganize
annxingyuan May 4, 2020
a4ac9ff
add keepdims
annxingyuan May 4, 2020
c088772
remove reshape
annxingyuan May 4, 2020
b03ef21
add max
annxingyuan May 4, 2020
d7da66f
reduce
annxingyuan May 4, 2020
a9e6217
typo
annxingyuan May 4, 2020
5bc64b9
lint
annxingyuan May 4, 2020
e167f52
merge
annxingyuan May 5, 2020
65e04a1
add msg
annxingyuan May 5, 2020
d8dd8d5
move
annxingyuan May 5, 2020
c105e65
Merge branch 'master' into modular_max
annxingyuan May 6, 2020
acd1f66
properly dispose
annxingyuan May 6, 2020
84436ec
merge
annxingyuan May 7, 2020
aa50854
dispose intermediate
annxingyuan May 7, 2020
c9a1976
add logs
annxingyuan May 7, 2020
d43dadb
run one test
annxingyuan May 7, 2020
3522c95
hbn
annxingyuan May 7, 2020
f2b8f9f
Merge branch 'master' into modular_max
annxingyuan May 8, 2020
a4f0ae6
temp disable
annxingyuan May 8, 2020
1057015
revive
annxingyuan May 8, 2020
c22e41c
yarn
annxingyuan May 8, 2020
7671c28
revive
annxingyuan May 8, 2020
448d071
remove logs
annxingyuan May 8, 2020
814fe20
rm
annxingyuan May 8, 2020
4eed0b5
clean
annxingyuan May 8, 2020
d2cc027
Merge branch 'master' into modular_max
annxingyuan May 8, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 4 additions & 27 deletions tfjs-backend-cpu/src/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import * as tf from '@tensorflow/tfjs-core';
import {engine, env} from '@tensorflow/tfjs-core';
import {backend_util, buffer, slice_util, util} from '@tensorflow/tfjs-core';
import {BackendTimingInfo, DataStorage, DataType, DataValues, KernelBackend, NumericDataType, Rank, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, TypedArray, upcastType} from '@tensorflow/tfjs-core';
import {BackendTimingInfo, DataStorage, DataType, DataValues, KernelBackend, max, NumericDataType, Rank, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, TypedArray, upcastType} from '@tensorflow/tfjs-core';
import {kernel_impls} from '@tensorflow/tfjs-core';

const nonMaxSuppressionV3 = kernel_impls.nonMaxSuppressionV3;
Expand Down Expand Up @@ -365,7 +365,9 @@ export class MathBackendCPU extends KernelBackend {

softmax<T extends Tensor>(logits: T, dim: number): T {
const axes = util.parseAxisParam([dim], logits.shape);
const maxLogit = this.max(logits, axes);
// TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel
// modularization.
const maxLogit = max(logits, axes);
const expandedShape =
backend_util.expandShapeToKeepDim(maxLogit.shape, axes);
const a = this.subtract(logits, maxLogit.reshape(expandedShape));
Expand Down Expand Up @@ -807,31 +809,6 @@ export class MathBackendCPU extends KernelBackend {
});
}

max(x: Tensor, axes: number[]): Tensor {
assertNotComplex(x, 'max');

backend_util.assertAxesAreInnerMostDims('max', axes, x.rank);
const [outShape, reduceShape] =
backend_util.computeOutAndReduceShapes(x.shape, axes);
const result = tf.zeros(outShape, x.dtype);
const reduceSize = util.sizeFromShape(reduceShape);
const vals = this.readSync(result.dataId) as TypedArray;

const aVals = this.readSync(x.dataId) as TypedArray;
for (let i = 0; i < vals.length; ++i) {
const offset = i * reduceSize;
let max = aVals[offset];
for (let j = 0; j < reduceSize; ++j) {
const value = aVals[offset + j];
if (value > max) {
max = value;
}
}
vals[i] = max;
}
return result;
}

maximum(a: Tensor, b: Tensor): Tensor {
assertNotComplex([a, b], 'maximum');

Expand Down
3 changes: 2 additions & 1 deletion tfjs-backend-cpu/src/cpu_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ export function assertNotComplex(
if (t != null) {
util.assert(
t.dtype !== 'complex64',
() => `${opName} does not support complex64 tensors.`);
() => `${
opName} does not support complex64 tensors in the CPU backend.`);
}
});
}
65 changes: 65 additions & 0 deletions tfjs-backend-cpu/src/kernels/Max.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core';
import {backend_util, KernelConfig} from '@tensorflow/tfjs-core';
import {TypedArray, util} from '@tensorflow/tfjs-core';

import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';

import {maxImpl} from './Max_impl';
import {transposeImpl} from './Transpose_impl';

export const maxConfig: KernelConfig = {
kernelName: Max,
backendName: 'cpu',
kernelFunc: ({inputs, attrs, backend}) => {
const {x} = inputs as MaxInputs;
const {reductionIndices} = attrs as {} as MaxAttrs;
const cpuBackend = backend as MathBackendCPU;
let xShape = x.shape;
const xRank = xShape.length;

const origAxes = util.parseAxisParam(reductionIndices, xShape);
let axes = origAxes;
const permutedAxes = backend_util.getAxesPermutation(axes, xRank);
let xVals = cpuBackend.data.get(x.dataId).values as TypedArray;
if (permutedAxes != null) {
const newShape: number[] = new Array(xRank);
for (let i = 0; i < newShape.length; i++) {
newShape[i] = xShape[permutedAxes[i]];
}

xVals = transposeImpl(xVals, xShape, x.dtype, permutedAxes, newShape);
axes = backend_util.getInnerMostAxes(axes.length, xRank);

xShape = newShape;
}

assertNotComplex(x, 'max');
backend_util.assertAxesAreInnerMostDims('max', axes, xRank);
const [maxOutShape, reduceShape] =
backend_util.computeOutAndReduceShapes(xShape, axes);

const reduceSize = util.sizeFromShape(reduceShape);

const result = maxImpl(xVals, reduceSize, maxOutShape, x.dtype);
const dataId = cpuBackend.write(result, maxOutShape, x.dtype);
return {dataId, shape: maxOutShape, dtype: x.dtype};
}
};
38 changes: 38 additions & 0 deletions tfjs-backend-cpu/src/kernels/Max_impl.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {DataType, NumericDataType, TypedArray, util} from '@tensorflow/tfjs-core';

export function maxImpl(
aVals: TypedArray, reduceSize: number, outShape: number[],
dtype: DataType): TypedArray {
const vals = util.getTypedArrayFromDType(
dtype as NumericDataType, util.sizeFromShape(outShape));

for (let i = 0; i < vals.length; ++i) {
const offset = i * reduceSize;
let max = aVals[offset];
for (let j = 0; j < reduceSize; ++j) {
const value = aVals[offset + j];
if (value > max) {
max = value;
}
}
vals[i] = max;
}
return vals;
}
2 changes: 1 addition & 1 deletion tfjs-backend-cpu/src/kernels/Transpose_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import {util} from '@tensorflow/tfjs-core';
export function transposeImpl(
xVals: TypedArray, xShape: number[], dtype: DataType, perm: number[],
newShape: number[]): TypedArray {
const xSize = util.sizeFromShape(xShape);
const xRank = xShape.length;
const xSize = util.sizeFromShape(xShape);
const xStrides = util.computeStrides(xShape);
const newStrides = util.computeStrides(newShape);

Expand Down
3 changes: 2 additions & 1 deletion tfjs-backend-cpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core';

import {divConfig} from './kernels/Div';
import {maxConfig} from './kernels/Max';
import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax';
import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';
import {squareConfig} from './kernels/Square';
Expand All @@ -29,7 +30,7 @@ import {transposeConfig} from './kernels/Transpose';
// List all kernel configs here
const kernelConfigs: KernelConfig[] = [
nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig,
transposeConfig, maxPoolWithArgmaxConfig
transposeConfig, maxPoolWithArgmaxConfig, maxConfig
];

for (const kernelConfig of kernelConfigs) {
Expand Down
1 change: 1 addition & 0 deletions tfjs-backend-cpu/src/shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
*/

// Shared kernel impls for use in other backends.
export {maxImpl} from './kernels/Max_impl';
export {transposeImpl} from './kernels/Transpose_impl';
33 changes: 12 additions & 21 deletions tfjs-backend-wasm/src/kernels/Max.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,29 @@
* =============================================================================
*/

import {backend_util, NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core';
import {backend_util, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core';
import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

interface MaxInputs extends NamedTensorInfoMap {
x: TensorInfo;
}

interface MaxAttrs extends NamedAttrMap {
axes: number[];
}

let wasmMax: (xId: number, reduceSize: number, outId: number) => void;

function setup(backend: BackendWasm): void {
wasmMax =
backend.wasm.cwrap('Max', null /*void*/, ['number, number, number']);
}

function max(args: {backend: BackendWasm, inputs: MaxInputs, attrs: MaxAttrs}):
TensorInfo {
function max(args: {backend: BackendWasm, inputs: {}, attrs: {}}): TensorInfo {
const {backend, inputs, attrs} = args;
const {axes} = attrs;
const {x} = inputs;
const {reductionIndices} = attrs as MaxAttrs;
const {x} = inputs as MaxInputs;
const xId = backend.dataIdMap.get(x.dataId).id;

backend_util.assertAxesAreInnerMostDims('max', axes, x.shape.length);
const origAxes = util.parseAxisParam(reductionIndices, x.shape);

backend_util.assertAxesAreInnerMostDims('max', origAxes, x.shape.length);
const [outShape, reduceShape] =
backend_util.computeOutAndReduceShapes(x.shape, axes);
backend_util.computeOutAndReduceShapes(x.shape, origAxes);
const reduceSize = util.sizeFromShape(reduceShape);

const out = backend.makeOutput(outShape, x.dtype);
Expand All @@ -54,12 +48,9 @@ function max(args: {backend: BackendWasm, inputs: MaxInputs, attrs: MaxAttrs}):
const outId = backend.dataIdMap.get(out.dataId).id;

wasmMax(xId, reduceSize, outId);

return out;
}

registerKernel({
kernelName: 'Max',
backendName: 'wasm',
setupFunc: setup,
kernelFunc: max
});
registerKernel(
{kernelName: Max, backendName: 'wasm', setupFunc: setup, kernelFunc: max});
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/kernels/Reshape.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ interface ReshapeAttrs extends NamedAttrMap {
shape: number[];
}

function reshape(
export function reshape(
args: {inputs: ReshapeInputs, attrs: ReshapeAttrs, backend: BackendWasm}) {
const {inputs: {x}, attrs: {shape}} = args;
return {dataId: x.dataId, shape, dtype: x.dtype};
Expand Down
19 changes: 4 additions & 15 deletions tfjs-backend-webgl/src/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import './flags_webgl';

import * as tf from '@tensorflow/tfjs-core';
import {complex, DataId, div, engine, env, imag, MemoryInfo, range, real, RecursiveArray, scalar, softmax, tensor, tidy, TimingInfo, transpose} from '@tensorflow/tfjs-core';
import {complex, DataId, div, engine, env, imag, max, MemoryInfo, range, real, RecursiveArray, scalar, softmax, tensor, tidy, TimingInfo, transpose} from '@tensorflow/tfjs-core';
import {backend_util, buffer, kernel_impls, slice_util, util} from '@tensorflow/tfjs-core';
import {DataStorage, DataType, KernelBackend, NumericDataType, Rank, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorInfo, TypedArray, upcastType} from '@tensorflow/tfjs-core';

Expand Down Expand Up @@ -1304,19 +1304,6 @@ export class MathBackendWebGL extends KernelBackend {
return this.compileAndRun(program, [a, b]);
}

max(x: Tensor, axes: number[]): Tensor {
if (this.shouldExecuteOnCPU([x])) {
return this.cpuBackend.max(x, axes);
}

backend_util.assertAxesAreInnerMostDims('max', axes, x.rank);
const [outShape, reduceShape] =
backend_util.computeOutAndReduceShapes(x.shape, axes);
const inSize = util.sizeFromShape(reduceShape);
const a2D = x.as2D(-1, inSize);
return this.reduce(a2D, 'max', a2D.dtype).reshape(outShape);
}

maximum(a: Tensor, b: Tensor): Tensor {
if (this.shouldExecuteOnCPU([a, b])) {
return this.cpuBackend.maximum(a, b);
Expand Down Expand Up @@ -1553,7 +1540,9 @@ export class MathBackendWebGL extends KernelBackend {

softmax<T extends Tensor>(logits: T, dim: number): T {
const axes = util.parseAxisParam([dim], logits.shape);
const maxLogit = this.max(logits, axes);
// TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel
// modularization.
const maxLogit = max(logits, axes);
const expandedShape =
backend_util.expandShapeToKeepDim(maxLogit.shape, axes);
const a = this.subtract(logits, maxLogit.reshape(expandedShape));
Expand Down
39 changes: 39 additions & 0 deletions tfjs-backend-webgl/src/kernel_utils/reduce.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {backend_util, DataType, TensorInfo} from '@tensorflow/tfjs-core';

import {MathBackendWebGL} from '../backend_webgl';
import {ReduceProgram} from '../reduce_gpu';

type ReduceTypes = 'all'|'any'|'max'|'min'|'sum'|'prod';

export function reduce(
x: TensorInfo, dtype: DataType, reductionType: ReduceTypes,
backend: MathBackendWebGL): TensorInfo {
const [batchSize, inSize] = x.shape;
const windowSize = backend_util.computeOptimalWindowSize(inSize);
const reduceInfo = {windowSize, inSize, batchSize};
const program = new ReduceProgram(reduceInfo, reductionType);
const output = backend.runWebGLProgram(program, [x], dtype);

if (output.shape[1] === 1) {
return output;
}

return reduce(output, dtype, reductionType, backend);
}
Loading