Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
5dd14db
initial
annxingyuan Mar 24, 2020
08260bb
merge
annxingyuan Mar 24, 2020
bb9543d
save xorig
annxingyuan Mar 24, 2020
fe39bb5
save output
annxingyuan Mar 25, 2020
e686064
max
annxingyuan Mar 25, 2020
70caa0d
separate out
annxingyuan Mar 25, 2020
a97c5fb
remove max
annxingyuan Mar 25, 2020
4e9ae96
cpu forward
annxingyuan Mar 25, 2020
2173ab2
create tensor info
annxingyuan Mar 25, 2020
face1b2
simplify
annxingyuan Mar 25, 2020
5ff78b4
split out impl
annxingyuan Mar 25, 2020
3a34f18
unchain
annxingyuan Mar 25, 2020
fcf1de2
merge
annxingyuan Mar 25, 2020
c16de99
return out directly
annxingyuan Mar 25, 2020
4a9f314
merge
annxingyuan Mar 27, 2020
7abe395
rename
annxingyuan Mar 27, 2020
8ead8d7
kernelize
annxingyuan Mar 27, 2020
cc7685b
wip
annxingyuan Mar 27, 2020
80b28e8
merge
annxingyuan Apr 22, 2020
f9c0df3
fix
annxingyuan Apr 22, 2020
f93656e
webgl build
annxingyuan Apr 22, 2020
7b14faf
fix
annxingyuan Apr 22, 2020
9020fed
max
annxingyuan Apr 22, 2020
17a46a2
add more logs
annxingyuan Apr 22, 2020
0e3ea93
add logs
annxingyuan Apr 22, 2020
e0c834b
logs
annxingyuan Apr 22, 2020
3a3622c
logs
annxingyuan Apr 22, 2020
10b8cb9
rem logs
annxingyuan Apr 22, 2020
e38df55
logs
annxingyuan Apr 22, 2020
04de4ce
remove hacks
annxingyuan Apr 22, 2020
61ff6ae
Merge branch 'master' into modular_max
annxingyuan Apr 23, 2020
0e2cb52
add logs
annxingyuan Apr 23, 2020
f068e52
add logs
annxingyuan Apr 23, 2020
4262572
add transpose to grad
annxingyuan Apr 23, 2020
c864263
add condition
annxingyuan Apr 23, 2020
2f941f6
passes
annxingyuan Apr 23, 2020
36cd5ec
remove logs
annxingyuan Apr 23, 2020
87b5dbc
add failing tests
annxingyuan Apr 23, 2020
9e805c9
Merge branch 'master' into modular_max
annxingyuan Apr 27, 2020
018bb6e
dispose
annxingyuan Apr 27, 2020
0227876
remove fit
annxingyuan Apr 27, 2020
cef4a01
fix wasm
annxingyuan Apr 27, 2020
7ce546c
undo
annxingyuan Apr 28, 2020
82019c1
fix test
annxingyuan Apr 28, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 3 additions & 26 deletions tfjs-backend-cpu/src/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,9 @@ export class MathBackendCPU extends KernelBackend {

softmax<T extends Tensor>(logits: T, dim: number): T {
const axes = util.parseAxisParam([dim], logits.shape);
const maxLogit = this.max(logits, axes);
// TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel
// modularization.
const maxLogit = tf.max(logits, axes);
const expandedShape =
backend_util.expandShapeToKeepDim(maxLogit.shape, axes);
const a = this.subtract(logits, maxLogit.reshape(expandedShape));
Expand Down Expand Up @@ -807,31 +809,6 @@ export class MathBackendCPU extends KernelBackend {
});
}

max(x: Tensor, axes: number[]): Tensor {
assertNotComplex(x, 'max');

backend_util.assertAxesAreInnerMostDims('max', axes, x.rank);
const [outShape, reduceShape] =
backend_util.computeOutAndReduceShapes(x.shape, axes);
const result = tf.zeros(outShape, x.dtype);
const reduceSize = util.sizeFromShape(reduceShape);
const vals = this.readSync(result.dataId) as TypedArray;

const aVals = this.readSync(x.dataId) as TypedArray;
for (let i = 0; i < vals.length; ++i) {
const offset = i * reduceSize;
let max = aVals[offset];
for (let j = 0; j < reduceSize; ++j) {
const value = aVals[offset + j];
if (value > max) {
max = value;
}
}
vals[i] = max;
}
return result;
}

maximum(a: Tensor, b: Tensor): Tensor {
assertNotComplex([a, b], 'maximum');

Expand Down
66 changes: 66 additions & 0 deletions tfjs-backend-cpu/src/kernels/Max.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core';
import {backend_util, KernelConfig} from '@tensorflow/tfjs-core';
import {TypedArray, util} from '@tensorflow/tfjs-core';

import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';

import {maxImpl} from './Max_impl';
import {transposeImpl} from './Transpose_impl';

/**
 * CPU kernel for the modular `Max` reduction op.
 *
 * Reduces `x` over `reductionIndices` by taking the maximum. If the
 * reduction axes are not the inner-most dimensions, the input is first
 * transposed so that they are, then a flat inner-loop reduction
 * (`maxImpl`) is run over each output position.
 */
export const maxConfig: KernelConfig = {
  kernelName: Max,
  backendName: 'cpu',
  kernelFunc: ({inputs, attrs, backend}) => {
    const {x} = inputs as MaxInputs;
    const {reductionIndices} = attrs as {} as MaxAttrs;
    const cpuBackend = backend as MathBackendCPU;

    // Reject complex inputs before touching the backing buffer; otherwise
    // the transpose below would run on complex data before the check fires.
    assertNotComplex(x, 'max');

    let xShape = x.shape;
    const xRank = xShape.length;

    const origAxes = util.parseAxisParam(reductionIndices, xShape);
    let axes = origAxes;
    const permutedAxes = backend_util.getAxesPermutation(axes, xRank);
    let xVals = cpuBackend.data.get(x.dataId).values as TypedArray;
    if (permutedAxes != null) {
      // Reduction axes are not inner-most: transpose the values so they
      // are, and recompute the shape/axes in the permuted layout.
      xVals = transposeImpl(xVals, xShape, x.dtype, permutedAxes);
      axes = backend_util.getInnerMostAxes(axes.length, xRank);

      const newShape: number[] = new Array(xRank);
      for (let i = 0; i < newShape.length; i++) {
        newShape[i] = xShape[permutedAxes[i]];
      }
      xShape = newShape;
    }

    backend_util.assertAxesAreInnerMostDims('max', axes, xRank);
    const [outShape, reduceShape] =
        backend_util.computeOutAndReduceShapes(xShape, axes);

    // Number of contiguous input elements folded into each output element.
    const reduceSize = util.sizeFromShape(reduceShape);

    const result = maxImpl(xVals, reduceSize, outShape, x.dtype);

    const dataId = cpuBackend.write(result, outShape, x.dtype);
    return {dataId, shape: outShape, dtype: x.dtype};
  }
};
38 changes: 38 additions & 0 deletions tfjs-backend-cpu/src/kernels/Max_impl.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {DataType, NumericDataType, TypedArray, util} from '@tensorflow/tfjs-core';

/**
 * Shared CPU implementation of the max reduction.
 *
 * Treats `aVals` as `outSize` contiguous groups of `reduceSize` elements
 * and writes the maximum of each group into a fresh typed array of shape
 * `outShape` / dtype `dtype`.
 */
export function maxImpl(
    aVals: TypedArray, reduceSize: number, outShape: number[],
    dtype: DataType): TypedArray {
  const outSize = util.sizeFromShape(outShape);
  const out = util.getTypedArrayFromDType(dtype as NumericDataType, outSize);

  for (let outIdx = 0; outIdx < outSize; ++outIdx) {
    const base = outIdx * reduceSize;
    // Seed with the first element of the group, then scan the rest.
    let best = aVals[base];
    for (let j = 1; j < reduceSize; ++j) {
      const candidate = aVals[base + j];
      if (candidate > best) {
        best = candidate;
      }
    }
    out[outIdx] = best;
  }
  return out;
}
2 changes: 1 addition & 1 deletion tfjs-backend-cpu/src/kernels/Transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export const transposeConfig: KernelConfig = {
}

const values = cpuBackend.data.get(x.dataId).values as TypedArray;
const result = transposeImpl(values, x.shape, x.dtype, perm, newShape);
const result = transposeImpl(values, x.shape, x.dtype, perm);

const dataId = cpuBackend.write(result, newShape, x.dtype);
return {dataId, shape: newShape, dtype: x.dtype};
Expand Down
11 changes: 8 additions & 3 deletions tfjs-backend-cpu/src/kernels/Transpose_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@ import {DataType, NumericDataType, TypedArray} from '@tensorflow/tfjs-core';
import {util} from '@tensorflow/tfjs-core';

export function transposeImpl(
xVals: TypedArray, xShape: number[], dtype: DataType, perm: number[],
newShape: number[]): TypedArray {
const xSize = util.sizeFromShape(xShape);
xVals: TypedArray, xShape: number[], dtype: DataType,
perm: number[]): TypedArray {
const xRank = xShape.length;
const newShape: number[] = new Array(xRank);
for (let i = 0; i < newShape.length; i++) {
newShape[i] = xShape[perm[i]];
}

const xSize = util.sizeFromShape(xShape);
const xStrides = util.computeStrides(xShape);
const newStrides = util.computeStrides(newShape);

Expand Down
3 changes: 2 additions & 1 deletion tfjs-backend-cpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core';

import {divConfig} from './kernels/Div';
import {maxConfig} from './kernels/Max';
import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax';
import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';
import {squareConfig} from './kernels/Square';
Expand All @@ -29,7 +30,7 @@ import {transposeConfig} from './kernels/Transpose';
// List all kernel configs here
const kernelConfigs: KernelConfig[] = [
nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig,
transposeConfig, maxPoolWithArgmaxConfig
transposeConfig, maxPoolWithArgmaxConfig, maxConfig
];

for (const kernelConfig of kernelConfigs) {
Expand Down
5 changes: 0 additions & 5 deletions tfjs-backend-cpu/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,6 @@ acorn@^6.0.5:
resolved "https://registry.yarnpkg.com/acorn/-/acorn-6.4.1.tgz#531e58ba3f51b9dacb9a6646ca4debf5b14ca474"
integrity sha512-ZVA9k326Nwrj3Cj9jlh3wGFutC2ZornPNARZwsNYqQYgN0EsV2d53w5RN/co65Ohn4sUAUtb1rSUAOD6XN9idA==

acorn@^7.1.1:
version "7.1.1"
resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.1.1.tgz#e35668de0b402f359de515c5482a1ab9f89a69bf"
integrity sha512-add7dgA5ppRPxCFJoAGfMDi7PIBXq1RtGo7BhbLaxwrXPOmw8gq48Y9ozT01hUKy9byMjlR20EJhu5zlkErEkg==

after@0.8.2:
version "0.8.2"
resolved "https://registry.yarnpkg.com/after/-/after-0.8.2.tgz#fedb394f9f0e02aa9768e702bda23b505fae7e1f"
Expand Down
19 changes: 4 additions & 15 deletions tfjs-backend-webgl/src/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import './flags_webgl';

import * as tf from '@tensorflow/tfjs-core';
import {complex, DataId, div, engine, env, imag, MemoryInfo, range, real, RecursiveArray, scalar, softmax, tensor, tidy, TimingInfo, transpose} from '@tensorflow/tfjs-core';
import {complex, DataId, div, engine, env, imag, max, MemoryInfo, range, real, RecursiveArray, scalar, softmax, tensor, tidy, TimingInfo, transpose} from '@tensorflow/tfjs-core';
import {backend_util, buffer, kernel_impls, slice_util, util} from '@tensorflow/tfjs-core';
import {DataStorage, DataType, KernelBackend, NumericDataType, Rank, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorInfo, TypedArray, upcastType} from '@tensorflow/tfjs-core';

Expand Down Expand Up @@ -1305,19 +1305,6 @@ export class MathBackendWebGL extends KernelBackend {
return this.compileAndRun(program, [a, b]);
}

max(x: Tensor, axes: number[]): Tensor {
if (this.shouldExecuteOnCPU([x])) {
return this.cpuBackend.max(x, axes);
}

backend_util.assertAxesAreInnerMostDims('max', axes, x.rank);
const [outShape, reduceShape] =
backend_util.computeOutAndReduceShapes(x.shape, axes);
const inSize = util.sizeFromShape(reduceShape);
const a2D = x.as2D(-1, inSize);
return this.reduce(a2D, 'max', a2D.dtype).reshape(outShape);
}

maximum(a: Tensor, b: Tensor): Tensor {
if (this.shouldExecuteOnCPU([a, b])) {
return this.cpuBackend.maximum(a, b);
Expand Down Expand Up @@ -1554,7 +1541,9 @@ export class MathBackendWebGL extends KernelBackend {

softmax<T extends Tensor>(logits: T, dim: number): T {
const axes = util.parseAxisParam([dim], logits.shape);
const maxLogit = this.max(logits, axes);
// TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel
// modularization.
const maxLogit = max(logits, axes);
const expandedShape =
backend_util.expandShapeToKeepDim(maxLogit.shape, axes);
const a = this.subtract(logits, maxLogit.reshape(expandedShape));
Expand Down
37 changes: 37 additions & 0 deletions tfjs-backend-webgl/src/kernel_utils/reduce.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {backend_util, DataType, TensorInfo} from '@tensorflow/tfjs-core';

import {MathBackendWebGL} from '../backend_webgl';
import {ReduceProgram} from '../reduce_gpu';

/**
 * Runs the WebGL reduction program repeatedly until the reduction
 * dimension of `x` (shape `[batch, inSize]`) has been collapsed to 1.
 *
 * Each pass reduces windows of an optimally-chosen size, so large inputs
 * are reduced over several successively smaller passes.
 */
export function reduce(
    x: TensorInfo, dtype: DataType, reductionType: backend_util.ReduceTypes,
    backend: MathBackendWebGL): TensorInfo {
  let current = x;
  while (true) {
    const [batchSize, inSize] = current.shape;
    const windowSize = backend_util.computeOptimalWindowSize(inSize);
    const reduceInfo = {windowSize, inSize, batchSize};
    const program = new ReduceProgram(reduceInfo, reductionType);
    const output = backend.runWebGLProgram(program, [current], dtype);

    if (output.shape[1] === 1) {
      // Fully reduced: [batch, 1].
      return output;
    }
    current = output;
  }
}
58 changes: 58 additions & 0 deletions tfjs-backend-webgl/src/kernel_utils/reshape.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {TensorInfo} from '@tensorflow/tfjs-core';

import {MathBackendWebGL} from '../backend_webgl';
import {ReshapePackedProgram} from '../reshape_packed_gpu';
import {getBatchDim, getRowsCols, isReshapeFree} from '../webgl_util';

/**
 * Reshapes a packed texture by running the packed-reshape program.
 *
 * Both the input shape and the target shape are collapsed to a canonical
 * 3D form [batch, rows, cols] before the GPU program runs; the output
 * keeps its packed texture and is relabelled with `afterShape`.
 */
function packedReshape(
    input: TensorInfo, afterShape: number[],
    backend: MathBackendWebGL): TensorInfo {
  const inputAs3D =
      [getBatchDim(input.shape),
       ...getRowsCols(input.shape)] as [number, number, number];
  const targetAs3D =
      [getBatchDim(afterShape),
       ...getRowsCols(afterShape)] as [number, number, number];

  // View the same data as a 3D tensor; no copy happens here.
  const input3D: TensorInfo = {
    dtype: input.dtype,
    shape: inputAs3D,
    dataId: input.dataId
  };

  const program = new ReshapePackedProgram(targetAs3D, inputAs3D);
  // Keep the result packed so the caller does not pay for an unpack.
  const preventEagerUnpackingOfOutput = true;
  const output = backend.runWebGLProgram(
      program, [input3D], input.dtype, null /* customSetup */,
      preventEagerUnpackingOfOutput);
  return {dataId: output.dataId, shape: afterShape, dtype: output.dtype};
}

/**
 * Reshapes `x` to `afterShape` on the WebGL backend.
 *
 * When the reshape is "free" (unpacked data, or a packed layout whose
 * physical texture already matches), the same buffer is reused with new
 * shape metadata; otherwise a packed-reshape GPU program is executed.
 */
export function reshape(
    x: TensorInfo, afterShape: number[],
    backend: MathBackendWebGL): TensorInfo {
  const texData = backend.texData.get(x.dataId);
  if (!texData.isPacked || isReshapeFree(x.shape, afterShape) ||
      (texData.texture !== null &&
       isReshapeFree(texData.shape, afterShape))) {
    // Free reshape: same underlying data, just new shape metadata.
    return {dataId: x.dataId, shape: afterShape, dtype: x.dtype};
  }

  return packedReshape(x, afterShape, backend);
}
Loading