From 5dd14dbf3e957f745bb0071a209a0afbe125cbb6 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 24 Mar 2020 08:03:01 -0400 Subject: [PATCH 01/72] initial --- tfjs-core/src/backends/cpu/backend_cpu.ts | 30 +------ tfjs-core/src/backends/cpu/kernels/Max.ts | 67 +++++++++++++++ .../src/backends/cpu/register_all_kernels.ts | 6 +- tfjs-core/src/gradients/Max_grad.ts | 36 ++++++++ tfjs-core/src/kernel_names.ts | 6 ++ tfjs-core/src/ops/max.ts | 86 +++++++++++++++++++ tfjs-core/src/ops/ops.ts | 1 + tfjs-core/src/ops/reduction_ops.ts | 61 +------------ tfjs-core/src/public/chained_ops/max.ts | 31 +++++++ .../chained_ops/register_all_chained_ops.ts | 1 + .../register_all_chained_ops_test.ts | 1 + tfjs-core/src/register_all_gradients.ts | 2 + tfjs-core/src/tensor.ts | 5 -- 13 files changed, 239 insertions(+), 94 deletions(-) create mode 100644 tfjs-core/src/backends/cpu/kernels/Max.ts create mode 100644 tfjs-core/src/gradients/Max_grad.ts create mode 100644 tfjs-core/src/ops/max.ts create mode 100644 tfjs-core/src/public/chained_ops/max.ts diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index f45855309a8..6c663425249 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -30,6 +30,7 @@ import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util'; import * as erf_util from '../../ops/erf_util'; import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util'; import * as gather_nd_util from '../../ops/gather_nd_util'; +import {max} from '../../ops/max'; import * as ops from '../../ops/ops'; import {buffer, scalar, tensor, tensor4d} from '../../ops/ops'; import * as scatter_nd_util from '../../ops/scatter_nd_util'; @@ -379,7 +380,9 @@ export class MathBackendCPU extends KernelBackend { softmax(logits: T, dim: number): T { const axes = util.parseAxisParam([dim], logits.shape); - const maxLogit = this.max(logits, axes); + // TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel + // modularization. + const maxLogit = max(logits, axes); const expandedShape = axis_util.expandShapeToKeepDim(maxLogit.shape, axes); const a = this.subtract(logits, maxLogit.reshape(expandedShape)); const b = this.exp(a); @@ -826,31 +829,6 @@ export class MathBackendCPU extends KernelBackend { }); } - max(x: Tensor, axes: number[]): Tensor { - assertNotComplex(x, 'max'); - - axis_util.assertAxesAreInnerMostDims('max', axes, x.rank); - const [outShape, reduceShape] = - axis_util.computeOutAndReduceShapes(x.shape, axes); - const result = ops.zeros(outShape, x.dtype); - const reduceSize = util.sizeFromShape(reduceShape); - const vals = this.readSync(result.dataId) as TypedArray; - - const aVals = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < vals.length; ++i) { - const offset = i * reduceSize; - let max = aVals[offset]; - for (let j = 0; j < reduceSize; ++j) { - const value = aVals[offset + j]; - if (value > max) { - max = value; - } - } - vals[i] = max; - } - return result; - } - maximum(a: Tensor, b: Tensor): Tensor { assertNotComplex([a, b], 'maximum'); diff --git a/tfjs-core/src/backends/cpu/kernels/Max.ts b/tfjs-core/src/backends/cpu/kernels/Max.ts new file mode 100644 index 00000000000..1966486e820 --- /dev/null +++ b/tfjs-core/src/backends/cpu/kernels/Max.ts @@ -0,0 +1,67 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Max, MaxAttrs, MaxInputs} from '../../../kernel_names'; +import {KernelConfig} from '../../../kernel_registry'; +import * as axis_util from '../../../ops/axis_util'; +import {DataType, NumericDataType, TypedArray} from '../../../types'; +import * as util from '../../../util'; +import {MathBackendCPU} from '../backend_cpu'; +import {assertNotComplex} from '../cpu_util'; + +export function maxImpl( + aVals: TypedArray, reduceSize: number, outShape: number[], + dtype: DataType): TypedArray { + const vals = util.getTypedArrayFromDType( + dtype as NumericDataType, util.sizeFromShape(outShape)); + + for (let i = 0; i < vals.length; ++i) { + const offset = i * reduceSize; + let max = aVals[offset]; + for (let j = 0; j < reduceSize; ++j) { + const value = aVals[offset + j]; + if (value > max) { + max = value; + } + } + vals[i] = max; + } + return vals; +} + +export const maxConfig: KernelConfig = { + kernelName: Max, + backendName: 'cpu', + kernelFunc: ({inputs, attrs, backend}) => { + const {x} = inputs as MaxInputs; + const {axes} = attrs as {} as MaxAttrs; + const cpuBackend = backend as MathBackendCPU; + + assertNotComplex(x, 'max'); + axis_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); + const [outShape, reduceShape] = + axis_util.computeOutAndReduceShapes(x.shape, axes); + + const reduceSize = util.sizeFromShape(reduceShape); + + const aVals = cpuBackend.data.get(x.dataId).values as TypedArray; + const result = maxImpl(aVals, reduceSize, outShape, x.dtype); + + const dataId = cpuBackend.write(result, outShape, x.dtype); + return {dataId, shape: outShape, dtype: x.dtype}; + } +}; diff --git a/tfjs-core/src/backends/cpu/register_all_kernels.ts b/tfjs-core/src/backends/cpu/register_all_kernels.ts index 9516f9af311..bae3d571d29 100644 --- a/tfjs-core/src/backends/cpu/register_all_kernels.ts +++ b/tfjs-core/src/backends/cpu/register_all_kernels.ts @@ -19,15 +19,15 @@ // the contents of this file and import only the kernels that are needed. import {KernelConfig, registerKernel} from '../../kernel_registry'; +import {maxConfig} from './kernels/Max'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ - nonMaxSuppressionV5Config, - squareConfig, - squaredDifferenceConfig, + nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, maxConfig, + nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-core/src/gradients/Max_grad.ts b/tfjs-core/src/gradients/Max_grad.ts new file mode 100644 index 00000000000..9c004e56a96 --- /dev/null +++ b/tfjs-core/src/gradients/Max_grad.ts @@ -0,0 +1,36 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Max, MaxAttrs} from '../kernel_names'; +import {GradConfig, NamedAttrMap} from '../kernel_registry'; +import * as axis_util from '../ops/axis_util'; +import {gradForMinAndMax} from '../ops/reduction_ops'; +import {Tensor} from '../tensor'; +import * as util from '../util'; + +export const maxGradConfig: GradConfig = { + kernelName: Max, + inputsToSave: ['x'], + gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { + const maxAttrs: MaxAttrs = attrs as {} as MaxAttrs; + const {axes} = maxAttrs; + const [x, y] = saved; + const origAxes = util.parseAxisParam(axes, x.shape); + const permutedAxes = axis_util.getAxesPermutation(axes, x.rank); + return gradForMinAndMax(dy, y, x, origAxes, permutedAxes); + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 201d73ac05f..77db2c278d5 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -37,6 +37,12 @@ export interface NonMaxSuppressionV5Attrs { softNmsSigma: number; } +export const Max = 'Max'; +export type MaxInputs = Pick; +export interface MaxAttrs { + axes: number[]; +} + export const BroadcastTo = 'BroadcastTo'; export type BroadcastToInputs = Pick; export interface BroadCastToAttrs { diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts new file mode 100644 index 00000000000..21fec7a1d97 --- /dev/null +++ b/tfjs-core/src/ops/max.ts @@ -0,0 +1,86 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Tensor} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import * as axis_util from './axis_util'; +import {op} from './operation'; +import {gradForMinAndMax} from './reduction_ops'; + +/** + * Computes the maximum of elements across dimensions of a `tf.Tensor`. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in + * `axes`. If `keepDims` is true, the reduced dimensions are retained with + * length 1. If `axes` has no entries, all dimensions are reduced, and an + * `tf.Tensor` with a single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.max().print(); // or tf.max(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.max(axis).print(); // or tf.max(x, axis) + * ``` + * + * @param x The input tensor. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. + */ +/** @doc {heading: 'Operations', subheading: 'Reduction'} */ +function max_( + x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T { + let $x = convertToTensor(x, 'x', 'max'); + const xOrig = $x; + + const origAxes = util.parseAxisParam(axis, $x.shape); + let axes = origAxes; + const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = $x.transpose(permutedAxes); + axes = axis_util.getInnerMostAxes(axes.length, $x.rank); + } + + const grad = (dy: T, saved: Tensor[]) => + gradForMinAndMax(dy, saved[1], saved[0], origAxes, permutedAxes); + + const inputsToSave = [$x]; + const outputsToSave: boolean[] = [true]; + let res = ENGINE.runKernelFunc((backend, save) => { + const y = backend.max($x, axes); + save([xOrig, y]); + return y; + }, {x: $x}, grad, 'Max', {axes}, inputsToSave, outputsToSave); + if (keepDims) { + const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); + res = res.reshape(newShape) as T; + } + return res as T; +} + +export const max = op({max_}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index bbb169d6293..c58557ceea3 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -18,6 +18,7 @@ // Modularized ops. export {broadcastTo} from './broadcast_to'; export {clone} from './clone'; +export {max} from './max'; export {multinomial} from './multinomial'; export {rand} from './rand'; export {randomGamma} from './random_gamma'; diff --git a/tfjs-core/src/ops/reduction_ops.ts b/tfjs-core/src/ops/reduction_ops.ts index 37d8f0dfa8d..68fb6be6a45 100644 --- a/tfjs-core/src/ops/reduction_ops.ts +++ b/tfjs-core/src/ops/reduction_ops.ts @@ -270,7 +270,7 @@ function mean_( /** * Gradient helper function for the min and max operations. */ -function gradForMinAndMax( +export function gradForMinAndMax( dy: T, y: T, xOrig: Tensor, origAxes: number[], permutedAxes: number[]) { if (y.rank < xOrig.rank) { y = y.reshape(axis_util.expandShapeToKeepDim(y.shape, origAxes)) as T; @@ -344,64 +344,6 @@ function min_( return res; } -/** - * Computes the maximum of elements across dimensions of a `tf.Tensor`. - * - * Reduces the input along the dimensions given in `axes`. Unless `keepDims` - * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in - * `axes`. If `keepDims` is true, the reduced dimensions are retained with - * length 1. If `axes` has no entries, all dimensions are reduced, and an - * `tf.Tensor` with a single element is returned. - * - * ```js - * const x = tf.tensor1d([1, 2, 3]); - * - * x.max().print(); // or tf.max(x) - * ``` - * - * ```js - * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); - * - * const axis = 1; - * x.max(axis).print(); // or tf.max(x, axis) - * ``` - * - * @param x The input tensor. - * @param axis The dimension(s) to reduce. By default it reduces - * all dimensions. - * @param keepDims If true, retains reduced dimensions with size 1. - */ -/** @doc {heading: 'Operations', subheading: 'Reduction'} */ -function max_( - x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T { - let $x = convertToTensor(x, 'x', 'max'); - const xOrig = $x; - - const origAxes = util.parseAxisParam(axis, $x.shape); - let axes = origAxes; - const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); - if (permutedAxes != null) { - $x = $x.transpose(permutedAxes); - axes = axis_util.getInnerMostAxes(axes.length, $x.rank); - } - - const grad = (dy: T, saved: Tensor[]) => - gradForMinAndMax(dy, saved[1], saved[0], origAxes, permutedAxes); - - const inputsToSave = [$x]; - const outputsToSave: boolean[] = [true]; - let res = ENGINE.runKernelFunc((backend, save) => { - const y = backend.max($x, axes); - save([xOrig, y]); - return y; - }, {x: $x}, grad, 'Max', {axes}, inputsToSave, outputsToSave); - if (keepDims) { - const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); - res = res.reshape(newShape) as T; - } - return res as T; -} - /** * Returns the indices of the minimum values along an `axis`. * @@ -625,7 +567,6 @@ export const any = op({any_}); export const argMax = op({argMax_}); export const argMin = op({argMin_}); export const logSumExp = op({logSumExp_}); -export const max = op({max_}); export const mean = op({mean_}); export const min = op({min_}); export const moments = op({moments_}); diff --git a/tfjs-core/src/public/chained_ops/max.ts b/tfjs-core/src/public/chained_ops/max.ts new file mode 100644 index 00000000000..c49ffdd7eab --- /dev/null +++ b/tfjs-core/src/public/chained_ops/max.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {max} from '../../ops/max'; +import {Tensor} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + max(axis?: number|number[], keepDims?: boolean): T; + } +} + +Tensor.prototype.max = function( + axis?: number|number[], keepDims?: boolean): T { + return max(this, axis, keepDims); +}; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index 3a79a8dd6d9..aa580ebfa18 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -17,3 +17,4 @@ import './squared_difference'; import './broadcast_to'; +import './max'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index 0ce556c9aa0..fc2445d2e6b 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -26,6 +26,7 @@ import {ALL_ENVS, describeWithFlags} from '../../jasmine_util'; const CHAINED_OPS = [ 'square', 'broadcastTo', + 'max', ]; describeWithFlags('chained ops', ALL_ENVS, () => { diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index f7a934358a1..23bf408187d 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -16,6 +16,7 @@ */ import {broadcastToGradConfig} from './gradients/BroadcastTo_grad'; import {identityGradConfig} from './gradients/Identity_grad'; +import {maxGradConfig} from './gradients/Max_grad'; import {squareGradConfig} from './gradients/Square_grad'; import {squaredDifferenceGradConfig} from './gradients/SquaredDifference_grad'; import {GradConfig} from './kernel_registry'; @@ -25,6 +26,7 @@ import {registerGradient} from './kernel_registry'; const gradConfigs: GradConfig[] = [ squareGradConfig, squaredDifferenceGradConfig, + maxGradConfig, broadcastToGradConfig, identityGradConfig, ]; diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index cee26c4cde3..cb8a5e0f50e 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -217,7 +217,6 @@ export interface OpHandler { mean(x: Tensor, axis: number|number[], keepDims: boolean): T; min(x: Tensor, axis: number|number[], keepDims: boolean): T; - max(x: Tensor, axis: number|number[], keepDims: boolean): T; argMin(x: Tensor, axis: number): T; argMax(x: Tensor, axis: number): T; add(a: Tensor, b: Tensor|TensorLike): T; @@ -898,10 +897,6 @@ export class Tensor { this.throwIfDisposed(); return opHandler.min(this, axis, keepDims); } - max(axis: number|number[] = null, keepDims = false): T { - this.throwIfDisposed(); - return opHandler.max(this, axis, keepDims); - } argMin(axis: number = null): T { this.throwIfDisposed(); return opHandler.argMin(this, axis); From 08260bbae804ea0e13c531d3415469b5ebae3f54 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 24 Mar 2020 15:53:51 -0400 Subject: [PATCH 02/72] merge --- tfjs-core/src/backends/cpu/register_all_kernels.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tfjs-core/src/backends/cpu/register_all_kernels.ts b/tfjs-core/src/backends/cpu/register_all_kernels.ts index bae3d571d29..fc36171f39d 100644 --- a/tfjs-core/src/backends/cpu/register_all_kernels.ts +++ b/tfjs-core/src/backends/cpu/register_all_kernels.ts @@ -26,8 +26,7 @@ import {squaredDifferenceConfig} from './kernels/SquaredDifference'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ - nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, maxConfig, - nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig + nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, maxConfig ]; for (const kernelConfig of kernelConfigs) { From bb9543db8528d175976264e0b7dd6b8407c335d4 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 24 Mar 2020 19:59:34 -0400 Subject: [PATCH 03/72] save xorig --- tfjs-core/src/ops/max.ts | 2 +- tfjs-core/src/register_all_gradients.ts | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index 21fec7a1d97..c774f20f7fd 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -69,7 +69,7 @@ function max_( const grad = (dy: T, saved: Tensor[]) => gradForMinAndMax(dy, saved[1], saved[0], origAxes, permutedAxes); - const inputsToSave = [$x]; + const inputsToSave = [xOrig]; const outputsToSave: boolean[] = [true]; let res = ENGINE.runKernelFunc((backend, save) => { const y = backend.max($x, axes); diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index 23bf408187d..d667540709b 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -16,7 +16,7 @@ */ import {broadcastToGradConfig} from './gradients/BroadcastTo_grad'; import {identityGradConfig} from './gradients/Identity_grad'; -import {maxGradConfig} from './gradients/Max_grad'; +// import {maxGradConfig} from './gradients/Max_grad'; import {squareGradConfig} from './gradients/Square_grad'; import {squaredDifferenceGradConfig} from './gradients/SquaredDifference_grad'; import {GradConfig} from './kernel_registry'; @@ -26,7 +26,7 @@ import {registerGradient} from './kernel_registry'; const gradConfigs: GradConfig[] = [ squareGradConfig, squaredDifferenceGradConfig, - maxGradConfig, + // maxGradConfig, broadcastToGradConfig, identityGradConfig, ]; From fe39bb586df719f0d3a1c60419b009d4690ae7ec Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 24 Mar 2020 20:10:59 -0400 Subject: [PATCH 04/72] save output --- tfjs-core/src/gradients/Max_grad.ts | 1 + tfjs-core/src/register_all_gradients.ts | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/gradients/Max_grad.ts b/tfjs-core/src/gradients/Max_grad.ts index 9c004e56a96..3c200c79839 100644 --- a/tfjs-core/src/gradients/Max_grad.ts +++ b/tfjs-core/src/gradients/Max_grad.ts @@ -25,6 +25,7 @@ import * as util from '../util'; export const maxGradConfig: GradConfig = { kernelName: Max, inputsToSave: ['x'], + outputsToSave: [true], gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { const maxAttrs: MaxAttrs = attrs as {} as MaxAttrs; const {axes} = maxAttrs; diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index d667540709b..23bf408187d 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -16,7 +16,7 @@ */ import {broadcastToGradConfig} from './gradients/BroadcastTo_grad'; import {identityGradConfig} from './gradients/Identity_grad'; -// import {maxGradConfig} from './gradients/Max_grad'; +import {maxGradConfig} from './gradients/Max_grad'; import {squareGradConfig} from './gradients/Square_grad'; import {squaredDifferenceGradConfig} from './gradients/SquaredDifference_grad'; import {GradConfig} from './kernel_registry'; @@ -26,7 +26,7 @@ import {registerGradient} from './kernel_registry'; const gradConfigs: GradConfig[] = [ squareGradConfig, squaredDifferenceGradConfig, - // maxGradConfig, + maxGradConfig, broadcastToGradConfig, identityGradConfig, ]; From e68606443c635fd38997a9b35f09dc492adb4135 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 24 Mar 2020 20:22:25 -0400 Subject: [PATCH 05/72] max --- tfjs-core/src/ops/max.ts | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index c774f20f7fd..4fe347d316d 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -23,7 +23,6 @@ import * as util from '../util'; import * as axis_util from './axis_util'; import {op} from './operation'; -import {gradForMinAndMax} from './reduction_ops'; /** * Computes the maximum of elements across dimensions of a `tf.Tensor`. @@ -56,7 +55,6 @@ import {gradForMinAndMax} from './reduction_ops'; function max_( x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T { let $x = convertToTensor(x, 'x', 'max'); - const xOrig = $x; const origAxes = util.parseAxisParam(axis, $x.shape); let axes = origAxes; @@ -66,16 +64,11 @@ function max_( axes = axis_util.getInnerMostAxes(axes.length, $x.rank); } - const grad = (dy: T, saved: Tensor[]) => - gradForMinAndMax(dy, saved[1], saved[0], origAxes, permutedAxes); - - const inputsToSave = [xOrig]; - const outputsToSave: boolean[] = [true]; let res = ENGINE.runKernelFunc((backend, save) => { const y = backend.max($x, axes); - save([xOrig, y]); + save([$x, y]); return y; - }, {x: $x}, grad, 'Max', {axes}, inputsToSave, outputsToSave); + }, {x: $x}, null /* gradient */, 'Max', {axes}); if (keepDims) { const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); res = res.reshape(newShape) as T; From 70caa0d2d8ed4c095ac9017bf67c0de73d25bd68 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 25 Mar 2020 07:27:05 -0400 Subject: [PATCH 06/72] separate out --- tfjs-core/src/backends/cpu/kernels/Max.ts | 23 +---------- .../src/backends/cpu/kernels/Max_impl.ts | 39 +++++++++++++++++++ 2 files changed, 41 insertions(+), 21 deletions(-) create mode 100644 tfjs-core/src/backends/cpu/kernels/Max_impl.ts diff --git a/tfjs-core/src/backends/cpu/kernels/Max.ts b/tfjs-core/src/backends/cpu/kernels/Max.ts index 1966486e820..2dd8ed90cc9 100644 --- a/tfjs-core/src/backends/cpu/kernels/Max.ts +++ b/tfjs-core/src/backends/cpu/kernels/Max.ts @@ -18,30 +18,11 @@ import {Max, MaxAttrs, MaxInputs} from '../../../kernel_names'; import {KernelConfig} from '../../../kernel_registry'; import * as axis_util from '../../../ops/axis_util'; -import {DataType, NumericDataType, TypedArray} from '../../../types'; +import {TypedArray} from '../../../types'; import * as util from '../../../util'; import {MathBackendCPU} from '../backend_cpu'; import {assertNotComplex} from '../cpu_util'; - -export function maxImpl( - aVals: TypedArray, reduceSize: number, outShape: number[], - dtype: DataType): TypedArray { - const vals = util.getTypedArrayFromDType( - dtype as NumericDataType, util.sizeFromShape(outShape)); - - for (let i = 0; i < vals.length; ++i) { - const offset = i * reduceSize; - let max = aVals[offset]; - for (let j = 0; j < reduceSize; ++j) { - const value = aVals[offset + j]; - if (value > max) { - max = value; - } - } - vals[i] = max; - } - return vals; -} +import {maxImpl} from './Max_impl'; export const maxConfig: KernelConfig = { kernelName: Max, diff --git a/tfjs-core/src/backends/cpu/kernels/Max_impl.ts b/tfjs-core/src/backends/cpu/kernels/Max_impl.ts new file mode 100644 index 00000000000..5cdee7353ce --- /dev/null +++ b/tfjs-core/src/backends/cpu/kernels/Max_impl.ts @@ -0,0 +1,39 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {DataType, NumericDataType, TypedArray} from '../../../types'; +import * as util from '../../../util'; + +export function maxImpl( + aVals: TypedArray, reduceSize: number, outShape: number[], + dtype: DataType): TypedArray { + const vals = util.getTypedArrayFromDType( + dtype as NumericDataType, util.sizeFromShape(outShape)); + + for (let i = 0; i < vals.length; ++i) { + const offset = i * reduceSize; + let max = aVals[offset]; + for (let j = 0; j < reduceSize; ++j) { + const value = aVals[offset + j]; + if (value > max) { + max = value; + } + } + vals[i] = max; + } + return vals; +} From a97c5fb511201391d3f731f78f92e424a3165300 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 25 Mar 2020 09:02:21 -0400 Subject: [PATCH 07/72] remove max --- tfjs-core/src/backends/webgl/backend_webgl.ts | 18 ++---- .../src/backends/webgl/kernel_utils/reduce.ts | 39 +++++++++++++ .../backends/webgl/kernel_utils/reshape.ts | 58 +++++++++++++++++++ tfjs-core/src/backends/webgl/kernels/Max.ts | 58 +++++++++++++++++++ tfjs-core/src/backends/webgl/reduce_gpu.ts | 6 +- .../backends/webgl/register_all_kernels.ts | 2 + tfjs-core/src/ops/reduce_util.ts | 2 + 7 files changed, 165 insertions(+), 18 deletions(-) create mode 100644 tfjs-core/src/backends/webgl/kernel_utils/reduce.ts create mode 100644 tfjs-core/src/backends/webgl/kernel_utils/reshape.ts create mode 100644 tfjs-core/src/backends/webgl/kernels/Max.ts diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index f513265c7aa..ddb887659f4 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -32,6 +32,7 @@ import {computeOutShape} from '../../ops/concat_util'; import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util'; import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util'; import * as gather_nd_util from '../../ops/gather_nd_util'; +import {max} from '../../ops/max'; import * as reduce_util from '../../ops/reduce_util'; import * as scatter_nd_util from '../../ops/scatter_nd_util'; import * as segment_util from '../../ops/segment_util'; @@ -1330,19 +1331,6 @@ export class MathBackendWebGL extends KernelBackend { return this.compileAndRun(program, [a, b]); } - max(x: Tensor, axes: number[]): Tensor { - if (this.shouldExecuteOnCPU([x])) { - return this.cpuBackend.max(x, axes); - } - - axis_util.assertAxesAreInnerMostDims('max', axes, x.rank); - const [outShape, reduceShape] = - axis_util.computeOutAndReduceShapes(x.shape, axes); - const inSize = util.sizeFromShape(reduceShape); - const a2D = x.as2D(-1, inSize); - return this.reduce(a2D, 'max', a2D.dtype).reshape(outShape); - } - maximum(a: Tensor, b: Tensor): Tensor { if (this.shouldExecuteOnCPU([a, b])) { return this.cpuBackend.maximum(a, b); @@ -1591,7 +1579,9 @@ export class MathBackendWebGL extends KernelBackend { softmax(logits: T, dim: number): T { const axes = util.parseAxisParam([dim], logits.shape); - const maxLogit = this.max(logits, axes); + // TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel + // modularization. + const maxLogit = max(logits, axes); const expandedShape = axis_util.expandShapeToKeepDim(maxLogit.shape, axes); const a = this.subtract(logits, maxLogit.reshape(expandedShape)); const b = this.exp(a); diff --git a/tfjs-core/src/backends/webgl/kernel_utils/reduce.ts b/tfjs-core/src/backends/webgl/kernel_utils/reduce.ts new file mode 100644 index 00000000000..84ef675403d --- /dev/null +++ b/tfjs-core/src/backends/webgl/kernel_utils/reduce.ts @@ -0,0 +1,39 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {TensorInfo} from '../../../kernel_registry'; +import {computeOptimalWindowSize, ReduceTypes} from '../../../ops/reduce_util'; +import {DataType} from '../../../types'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {ReduceProgram} from '../reduce_gpu'; + +export function reduce( + x: TensorInfo, reduceShape: number[], dtype: DataType, + reductionType: ReduceTypes, backend: MathBackendWebGL): TensorInfo { + const [batchSize, inSize] = x.shape; + const windowSize = computeOptimalWindowSize(inSize); + const reduceInfo = {windowSize, inSize, batchSize}; + const program = new ReduceProgram(reduceInfo, reductionType); + const output = backend.runWebGLProgram(program, [x], dtype); + + if (output.shape[1] === 1) { + return output; + } + + return reduce(output, reduceShape, dtype, reductionType, backend); +} diff --git a/tfjs-core/src/backends/webgl/kernel_utils/reshape.ts b/tfjs-core/src/backends/webgl/kernel_utils/reshape.ts new file mode 100644 index 00000000000..922cad7c12e --- /dev/null +++ b/tfjs-core/src/backends/webgl/kernel_utils/reshape.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {TensorInfo} from '../../../kernel_registry'; +import {webgl_util} from '../../../webgl'; +import * as backend_util from '../../backend_util'; +import {MathBackendWebGL} from '../backend_webgl'; +import {ReshapePackedProgram} from '../reshape_packed_gpu'; + +function packedReshape( + input: TensorInfo, afterShape: number[], + backend: MathBackendWebGL): TensorInfo { + const input3DShape = [ + webgl_util.getBatchDim(input.shape), ...webgl_util.getRowsCols(input.shape) + ] as [number, number, number]; + const input3D: TensorInfo = { + dtype: input.dtype, + shape: input3DShape, + dataId: input.dataId + }; + const afterShapeAs3D = [ + webgl_util.getBatchDim(afterShape), ...webgl_util.getRowsCols(afterShape) + ] as [number, number, number]; + + const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape); + const preventEagerUnpackingOfOutput = true; + const output = backend.runWebGLProgram( + program, [input3D], input.dtype, null /* customSetup */, + preventEagerUnpackingOfOutput); + return {dataId: output.dataId, shape: afterShape, dtype: output.dtype}; +} + +export function reshape( + x: TensorInfo, afterShape: number[], + backend: MathBackendWebGL): TensorInfo { + const xTexData = backend.texData.get(x.dataId); + if (xTexData.isPacked && !webgl_util.isReshapeFree(x.shape, afterShape) && + !(xTexData.texture !== null && + webgl_util.isReshapeFree(xTexData.shape, afterShape))) { + return packedReshape(x, afterShape, backend); + } + + return backend_util.reshapeTensor(x as any, afterShape); +} diff --git a/tfjs-core/src/backends/webgl/kernels/Max.ts b/tfjs-core/src/backends/webgl/kernels/Max.ts new file mode 100644 index 00000000000..d94dbab870f --- /dev/null +++ b/tfjs-core/src/backends/webgl/kernels/Max.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Max, MaxAttrs, MaxInputs} from '../../../kernel_names'; +import {KernelConfig} from '../../../kernel_registry'; +import {TensorInfo} from '../../../kernel_registry'; +import * as axis_util from '../../../ops/axis_util'; +import {sizeFromShape} from '../../../util'; +import {MathBackendWebGL} from '../backend_webgl'; +import {reduce} from '../kernel_utils/reduce'; +import {reshape} from '../kernel_utils/reshape'; + +export const maxImpl = + (x: TensorInfo, reduceShape: number[], outShape: number[], + backend: MathBackendWebGL): TensorInfo => { + const inSize = sizeFromShape(reduceShape); + const xSize = sizeFromShape(x.shape); + const batchSize = xSize / inSize; + + return reshape( + reduce( + reshape(x, [batchSize, inSize], backend), reduceShape, x.dtype, + 'max', backend), + outShape, backend); + }; + +export const maxConfig: KernelConfig = { + kernelName: Max, + backendName: 'webgl', + kernelFunc: ({inputs, attrs, backend}) => { + const {x} = inputs as MaxInputs; + const {axes} = attrs as {} as MaxAttrs; + const webglBackend = backend as MathBackendWebGL; + + axis_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); + + const [outShape, reduceShape] = + axis_util.computeOutAndReduceShapes(x.shape, axes); + + const out = maxImpl(x, reduceShape, outShape, webglBackend); + + return {dataId: out.dataId, shape: outShape, dtype: x.dtype}; + } +}; diff --git a/tfjs-core/src/backends/webgl/reduce_gpu.ts b/tfjs-core/src/backends/webgl/reduce_gpu.ts index 273c7028837..428354964ec 100644 --- a/tfjs-core/src/backends/webgl/reduce_gpu.ts +++ b/tfjs-core/src/backends/webgl/reduce_gpu.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {ReduceInfo} from '../../ops/reduce_util'; +import {ReduceInfo, ReduceTypes} from '../../ops/reduce_util'; import {GPGPUProgram} from './gpgpu_math'; export class ReduceProgram implements GPGPUProgram { @@ -23,9 +23,7 @@ export class ReduceProgram implements GPGPUProgram { outputShape: number[]; userCode: string; - constructor( - reduceInfo: ReduceInfo, - reduceType: 'all'|'any'|'max'|'min'|'sum'|'prod') { + constructor(reduceInfo: ReduceInfo, reduceType: ReduceTypes) { const windowSize = reduceInfo.windowSize; const batchSize = reduceInfo.batchSize; const inSize = reduceInfo.inSize; diff --git a/tfjs-core/src/backends/webgl/register_all_kernels.ts b/tfjs-core/src/backends/webgl/register_all_kernels.ts index f7913ac6b21..17240e61be0 100644 --- a/tfjs-core/src/backends/webgl/register_all_kernels.ts +++ b/tfjs-core/src/backends/webgl/register_all_kernels.ts @@ -17,6 +17,7 @@ import {KernelConfig, registerKernel} from '../../kernel_registry'; import {fromPixelsConfig} from './kernels/FromPixels'; +import {maxConfig} from './kernels/Max'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; @@ -24,6 +25,7 @@ import {squaredDifferenceConfig} from './kernels/SquaredDifference'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ fromPixelsConfig, + maxConfig, nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, diff --git a/tfjs-core/src/ops/reduce_util.ts b/tfjs-core/src/ops/reduce_util.ts index cb9f27d158a..876c784dc35 100644 --- a/tfjs-core/src/ops/reduce_util.ts +++ b/tfjs-core/src/ops/reduce_util.ts @@ -29,6 +29,8 @@ export interface ReduceInfo { inSize: number; } +export type ReduceTypes = 'all'|'any'|'max'|'min'|'sum'|'prod'; + export function computeOptimalWindowSize(inSize: number): number { if (inSize <= PARALLELIZE_THRESHOLD) { return inSize; From 4e9ae96af816003488c4b9804f22dd34966e4a43 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 25 Mar 2020 09:46:01 -0400 Subject: [PATCH 08/72] cpu forward --- tfjs-core/src/backends/webgl/backend_webgl.ts | 5 +++-- tfjs-core/src/backends/webgl/kernels/Max.ts | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index ddb887659f4..46505fe0dc7 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -653,8 +653,9 @@ export class MathBackendWebGL extends KernelBackend { TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more sustainable strategy for optimizing backend execution of ops. */ - private shouldExecuteOnCPU( - inputs: Tensor[], sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD): boolean { + shouldExecuteOnCPU( + inputs: Tensor[]|TensorInfo[], + sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD): boolean { return this.getCPUBackend() != null && inputs.every( input => this.texData.get(input.dataId).texture == null && diff --git a/tfjs-core/src/backends/webgl/kernels/Max.ts b/tfjs-core/src/backends/webgl/kernels/Max.ts index d94dbab870f..77948b82315 100644 --- a/tfjs-core/src/backends/webgl/kernels/Max.ts +++ b/tfjs-core/src/backends/webgl/kernels/Max.ts @@ -15,10 +15,13 @@ * ============================================================================= */ +import {maxImpl as cpuMax} from '../../../backends/cpu/kernels/Max_impl'; import {Max, MaxAttrs, MaxInputs} from '../../../kernel_names'; import {KernelConfig} from '../../../kernel_registry'; import {TensorInfo} from '../../../kernel_registry'; import * as axis_util from '../../../ops/axis_util'; +import {TypedArray} from '../../../types'; +import * as util from '../../../util'; import {sizeFromShape} from '../../../util'; import {MathBackendWebGL} from '../backend_webgl'; import {reduce} from '../kernel_utils/reduce'; @@ -51,7 +54,19 @@ export const maxConfig: KernelConfig = { const [outShape, reduceShape] = axis_util.computeOutAndReduceShapes(x.shape, axes); - const out = maxImpl(x, reduceShape, outShape, webglBackend); + let out; + if (webglBackend.shouldExecuteOnCPU([x])) { + const xTexData = webglBackend.texData.get(x.dataId); + const values = xTexData.values as TypedArray; + const outValues = + cpuMax(values, util.sizeFromShape(reduceShape), outShape, x.dtype); + + out = webglBackend.makeTensorInfo(outShape, x.dtype); + const outData = webglBackend.texData.get(out.dataId); + outData.values = outValues; + } else { + out = maxImpl(x, reduceShape, outShape, webglBackend); + } return {dataId: out.dataId, shape: outShape, dtype: x.dtype}; } From 2173ab2d90e507174f99e8c338fbbd9fe60ae152 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 25 Mar 2020 09:49:40 -0400 Subject: [PATCH 09/72] create tensor info --- tfjs-core/src/backends/webgl/kernel_utils/reshape.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tfjs-core/src/backends/webgl/kernel_utils/reshape.ts b/tfjs-core/src/backends/webgl/kernel_utils/reshape.ts index 922cad7c12e..a19fee541df 100644 --- a/tfjs-core/src/backends/webgl/kernel_utils/reshape.ts +++ b/tfjs-core/src/backends/webgl/kernel_utils/reshape.ts @@ -17,7 +17,6 @@ import {TensorInfo} from '../../../kernel_registry'; import {webgl_util} from '../../../webgl'; -import * as backend_util from '../../backend_util'; import {MathBackendWebGL} from '../backend_webgl'; import {ReshapePackedProgram} from '../reshape_packed_gpu'; @@ -54,5 +53,5 @@ export function reshape( return packedReshape(x, afterShape, backend); } - return backend_util.reshapeTensor(x as any, afterShape); + return {dataId: x.dataId, shape: afterShape, dtype: x.dtype}; } From face1b2d4568b817e1268549c0b2b8e47cfe7422 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 25 Mar 2020 09:53:25 -0400 Subject: [PATCH 10/72] simplify --- tfjs-core/src/backends/webgl/kernel_utils/reduce.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/backends/webgl/kernel_utils/reduce.ts b/tfjs-core/src/backends/webgl/kernel_utils/reduce.ts index 84ef675403d..949f96d7ea1 100644 --- a/tfjs-core/src/backends/webgl/kernel_utils/reduce.ts +++ b/tfjs-core/src/backends/webgl/kernel_utils/reduce.ts @@ -23,8 +23,8 @@ import {MathBackendWebGL} from '../backend_webgl'; import {ReduceProgram} from '../reduce_gpu'; export function reduce( - x: TensorInfo, reduceShape: number[], dtype: DataType, - reductionType: ReduceTypes, backend: MathBackendWebGL): TensorInfo { + x: TensorInfo, dtype: DataType, reductionType: ReduceTypes, + backend: MathBackendWebGL): TensorInfo { const [batchSize, inSize] = x.shape; const windowSize = computeOptimalWindowSize(inSize); const reduceInfo = {windowSize, inSize, batchSize}; @@ -35,5 +35,5 @@ export function reduce( return output; } - return reduce(output, reduceShape, dtype, reductionType, backend); + return reduce(output, dtype, reductionType, backend); } From 5ff78b481f3cccf1f2351daa55763d5de9bb58c8 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 25 Mar 2020 10:05:03 -0400 Subject: [PATCH 11/72] split out impl --- tfjs-core/src/backends/webgl/kernels/Max.ts | 19 +--------- .../src/backends/webgl/kernels/Max_impl.ts | 36 +++++++++++++++++++ 2 files changed, 37 insertions(+), 18 deletions(-) create mode 100644 tfjs-core/src/backends/webgl/kernels/Max_impl.ts diff --git a/tfjs-core/src/backends/webgl/kernels/Max.ts b/tfjs-core/src/backends/webgl/kernels/Max.ts index 77948b82315..f124e79cddf 100644 --- a/tfjs-core/src/backends/webgl/kernels/Max.ts +++ b/tfjs-core/src/backends/webgl/kernels/Max.ts @@ -18,28 +18,11 @@ import {maxImpl as cpuMax} from '../../../backends/cpu/kernels/Max_impl'; import {Max, MaxAttrs, MaxInputs} from '../../../kernel_names'; import {KernelConfig} from '../../../kernel_registry'; -import {TensorInfo} from '../../../kernel_registry'; import * as axis_util from '../../../ops/axis_util'; import {TypedArray} from '../../../types'; import * as util from '../../../util'; -import {sizeFromShape} from '../../../util'; import {MathBackendWebGL} from '../backend_webgl'; -import {reduce} from '../kernel_utils/reduce'; -import {reshape} from '../kernel_utils/reshape'; - -export const maxImpl = - (x: TensorInfo, reduceShape: number[], outShape: number[], - backend: MathBackendWebGL): TensorInfo => { - const inSize = sizeFromShape(reduceShape); - const xSize = sizeFromShape(x.shape); - const batchSize = xSize / inSize; - - return reshape( - reduce( - reshape(x, [batchSize, inSize], backend), reduceShape, x.dtype, - 'max', backend), - outShape, backend); - }; +import {maxImpl} from './Max_impl'; export const maxConfig: KernelConfig = { kernelName: Max, diff --git a/tfjs-core/src/backends/webgl/kernels/Max_impl.ts b/tfjs-core/src/backends/webgl/kernels/Max_impl.ts new file mode 100644 index 00000000000..f2841686321 --- /dev/null +++ b/tfjs-core/src/backends/webgl/kernels/Max_impl.ts @@ -0,0 +1,36 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {TensorInfo} from '../../../kernel_registry'; +import {sizeFromShape} from '../../../util'; +import {MathBackendWebGL} from '../backend_webgl'; +import {reduce} from '../kernel_utils/reduce'; +import {reshape} from '../kernel_utils/reshape'; + +export const maxImpl = + (x: TensorInfo, reduceShape: number[], outShape: number[], + backend: MathBackendWebGL): TensorInfo => { + const inSize = sizeFromShape(reduceShape); + const xSize = sizeFromShape(x.shape); + const batchSize = xSize / inSize; + + return reshape( + reduce( + reshape(x, [batchSize, inSize], backend), x.dtype, 'max', + backend), + outShape, backend); + }; From 3a34f18e3101e09e767124c38013c8784d6e3810 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 25 Mar 2020 10:10:24 -0400 Subject: [PATCH 12/72] unchain --- tfjs-core/src/backends/cpu/register_all_kernels.ts | 2 +- tfjs-core/src/ops/max.ts | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/backends/cpu/register_all_kernels.ts b/tfjs-core/src/backends/cpu/register_all_kernels.ts index fc36171f39d..2f6d9653cd5 100644 --- a/tfjs-core/src/backends/cpu/register_all_kernels.ts +++ b/tfjs-core/src/backends/cpu/register_all_kernels.ts @@ -26,7 +26,7 @@ import {squaredDifferenceConfig} from './kernels/SquaredDifference'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ - nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, maxConfig + maxConfig, nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index 4fe347d316d..2c883fffde3 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -23,6 +23,7 @@ import * as util from '../util'; import * as axis_util from './axis_util'; import {op} from './operation'; +import {transpose} from './transpose'; /** * Computes the maximum of elements across dimensions of a `tf.Tensor`. @@ -60,7 +61,7 @@ function max_( let axes = origAxes; const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); if (permutedAxes != null) { - $x = $x.transpose(permutedAxes); + $x = transpose($x, permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, $x.rank); } From c16de99288373c0c5d16ad3da6759acfed5605d9 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 25 Mar 2020 10:46:08 -0400 Subject: [PATCH 13/72] return out directly --- tfjs-core/src/backends/webgl/kernels/Max.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-core/src/backends/webgl/kernels/Max.ts b/tfjs-core/src/backends/webgl/kernels/Max.ts index f124e79cddf..115470b5121 100644 --- a/tfjs-core/src/backends/webgl/kernels/Max.ts +++ b/tfjs-core/src/backends/webgl/kernels/Max.ts @@ -51,6 +51,6 @@ export const maxConfig: KernelConfig = { out = maxImpl(x, reduceShape, outShape, webglBackend); } - return {dataId: out.dataId, shape: outShape, dtype: x.dtype}; + return out; } }; From 7abe3958c0d22172ca465404fe96f234c12f5819 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 27 Mar 2020 09:22:30 -0400 Subject: [PATCH 14/72] rename --- tfjs-backend-wasm/src/kernels/Max.ts | 9 +++++---- tfjs-core/src/backends/cpu/kernels/Max.ts | 7 ++++--- tfjs-core/src/backends/webgl/kernels/Max.ts | 7 ++++--- tfjs-core/src/gradients/Max_grad.ts | 6 +++--- tfjs-core/src/kernel_names.ts | 2 +- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/Max.ts b/tfjs-backend-wasm/src/kernels/Max.ts index ab688b76c80..bd300605557 100644 --- a/tfjs-backend-wasm/src/kernels/Max.ts +++ b/tfjs-backend-wasm/src/kernels/Max.ts @@ -24,7 +24,7 @@ interface MaxInputs extends NamedTensorInfoMap { } interface MaxAttrs extends NamedAttrMap { - axes: number[]; + reductionIndices: number[]; } let wasmMax: (xId: number, reduceSize: number, outId: number) => void; @@ -37,13 +37,14 @@ function setup(backend: BackendWasm): void { function max(args: {backend: BackendWasm, inputs: MaxInputs, attrs: MaxAttrs}): TensorInfo { const {backend, inputs, attrs} = args; - const {axes} = attrs; + const {reductionIndices} = attrs; const {x} = inputs; const xId = backend.dataIdMap.get(x.dataId).id; - backend_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); + backend_util.assertAxesAreInnerMostDims( + 'max', reductionIndices, x.shape.length); const [outShape, reduceShape] = - backend_util.computeOutAndReduceShapes(x.shape, axes); + backend_util.computeOutAndReduceShapes(x.shape, reductionIndices); const reduceSize = util.sizeFromShape(reduceShape); const out = backend.makeOutput(outShape, x.dtype); diff --git a/tfjs-core/src/backends/cpu/kernels/Max.ts b/tfjs-core/src/backends/cpu/kernels/Max.ts index 2dd8ed90cc9..db2aa962e2e 100644 --- a/tfjs-core/src/backends/cpu/kernels/Max.ts +++ b/tfjs-core/src/backends/cpu/kernels/Max.ts @@ -29,13 +29,14 @@ export const maxConfig: KernelConfig = { backendName: 'cpu', kernelFunc: ({inputs, attrs, backend}) => { const {x} = inputs as MaxInputs; - const {axes} = attrs as {} as MaxAttrs; + const {reductionIndices} = attrs as {} as MaxAttrs; const cpuBackend = backend as MathBackendCPU; assertNotComplex(x, 'max'); - axis_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); + axis_util.assertAxesAreInnerMostDims( + 'max', reductionIndices, x.shape.length); const [outShape, reduceShape] = - axis_util.computeOutAndReduceShapes(x.shape, axes); + axis_util.computeOutAndReduceShapes(x.shape, reductionIndices); const reduceSize = util.sizeFromShape(reduceShape); diff --git a/tfjs-core/src/backends/webgl/kernels/Max.ts b/tfjs-core/src/backends/webgl/kernels/Max.ts index 115470b5121..ad162c30d46 100644 --- a/tfjs-core/src/backends/webgl/kernels/Max.ts +++ b/tfjs-core/src/backends/webgl/kernels/Max.ts @@ -29,13 +29,14 @@ export const maxConfig: KernelConfig = { backendName: 'webgl', kernelFunc: ({inputs, attrs, backend}) => { const {x} = inputs as MaxInputs; - const {axes} = attrs as {} as MaxAttrs; + const {reductionIndices} = attrs as {} as MaxAttrs; const webglBackend = backend as MathBackendWebGL; - axis_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); + axis_util.assertAxesAreInnerMostDims( + 'max', reductionIndices, x.shape.length); const [outShape, reduceShape] = - axis_util.computeOutAndReduceShapes(x.shape, axes); + axis_util.computeOutAndReduceShapes(x.shape, reductionIndices); let out; if (webglBackend.shouldExecuteOnCPU([x])) { diff --git a/tfjs-core/src/gradients/Max_grad.ts b/tfjs-core/src/gradients/Max_grad.ts index 3c200c79839..d8f682f84aa 100644 --- a/tfjs-core/src/gradients/Max_grad.ts +++ b/tfjs-core/src/gradients/Max_grad.ts @@ -28,10 +28,10 @@ export const maxGradConfig: GradConfig = { outputsToSave: [true], gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { const maxAttrs: MaxAttrs = attrs as {} as MaxAttrs; - const {axes} = maxAttrs; + const {reductionIndices} = maxAttrs; const [x, y] = saved; - const origAxes = util.parseAxisParam(axes, x.shape); - const permutedAxes = axis_util.getAxesPermutation(axes, x.rank); + const origAxes = util.parseAxisParam(reductionIndices, x.shape); + const permutedAxes = axis_util.getAxesPermutation(reductionIndices, x.rank); return gradForMinAndMax(dy, y, x, origAxes, permutedAxes); } }; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 686f716f94d..de753e343ad 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -51,7 +51,7 @@ export interface NonMaxSuppressionV5Attrs { export const Max = 'Max'; export type MaxInputs = Pick; export interface MaxAttrs { - axes: number[]; + reductionIndices: number[]; } export const BroadcastTo = 'BroadcastTo'; From 8ead8d7b3c1cf0e19418c62415f156d31a6629e8 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 27 Mar 2020 09:55:55 -0400 Subject: [PATCH 15/72] kernelize --- tfjs-backend-wasm/src/index_test.ts | 11 +++++++-- tfjs-backend-wasm/src/kernels/Max.ts | 12 +++++++--- tfjs-core/src/backends/cpu/kernels/Max.ts | 13 ++++++++--- tfjs-core/src/backends/webgl/kernels/Max.ts | 13 ++++++++--- tfjs-core/src/ops/max.ts | 26 ++++++++++++++------- 5 files changed, 55 insertions(+), 20 deletions(-) diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index 370615f2cf3..4c724ab3035 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -58,8 +58,8 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { }, 100); // Silences backend registration warnings. - spyOn(console, 'warn'); - spyOn(console, 'log'); + // spyOn(console, 'warn'); + // spyOn(console, 'log'); }); afterEach(() => { @@ -92,4 +92,11 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { expect(() => setWasmPath('too/late')) .toThrowError(/The WASM backend was already initialized. Make sure/); }); + + fit('max', async () => { + const a = tf.tensor1d([3, -1, 0, 100, -7, 2]); + const r = tf.max(a); + const data = await r.data(); + console.log(data); + }); }); diff --git a/tfjs-backend-wasm/src/kernels/Max.ts b/tfjs-backend-wasm/src/kernels/Max.ts index bd300605557..32486359125 100644 --- a/tfjs-backend-wasm/src/kernels/Max.ts +++ b/tfjs-backend-wasm/src/kernels/Max.ts @@ -41,10 +41,16 @@ function max(args: {backend: BackendWasm, inputs: MaxInputs, attrs: MaxAttrs}): const {x} = inputs; const xId = backend.dataIdMap.get(x.dataId).id; - backend_util.assertAxesAreInnerMostDims( - 'max', reductionIndices, x.shape.length); + const origAxes = util.parseAxisParam(reductionIndices, x.shape); + let axes = origAxes; + const permutedAxes = backend_util.getAxesPermutation(axes, x.shape.length); + if (permutedAxes != null) { + console.log('TRANSPOSE'); + } + + backend_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); const [outShape, reduceShape] = - backend_util.computeOutAndReduceShapes(x.shape, reductionIndices); + backend_util.computeOutAndReduceShapes(x.shape, axes); const reduceSize = util.sizeFromShape(reduceShape); const out = backend.makeOutput(outShape, x.dtype); diff --git a/tfjs-core/src/backends/cpu/kernels/Max.ts b/tfjs-core/src/backends/cpu/kernels/Max.ts index db2aa962e2e..8b656dd8d68 100644 --- a/tfjs-core/src/backends/cpu/kernels/Max.ts +++ b/tfjs-core/src/backends/cpu/kernels/Max.ts @@ -31,12 +31,19 @@ export const maxConfig: KernelConfig = { const {x} = inputs as MaxInputs; const {reductionIndices} = attrs as {} as MaxAttrs; const cpuBackend = backend as MathBackendCPU; + console.log('max cpu kernel func', x, reductionIndices); + + const origAxes = util.parseAxisParam(reductionIndices, x.shape); + let axes = origAxes; + const permutedAxes = axis_util.getAxesPermutation(axes, x.shape.length); + if (permutedAxes != null) { + console.log('TRANSPOSE'); + } assertNotComplex(x, 'max'); - axis_util.assertAxesAreInnerMostDims( - 'max', reductionIndices, x.shape.length); + axis_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); const [outShape, reduceShape] = - axis_util.computeOutAndReduceShapes(x.shape, reductionIndices); + axis_util.computeOutAndReduceShapes(x.shape, axes); const reduceSize = util.sizeFromShape(reduceShape); diff --git a/tfjs-core/src/backends/webgl/kernels/Max.ts b/tfjs-core/src/backends/webgl/kernels/Max.ts index ad162c30d46..66522edf091 100644 --- a/tfjs-core/src/backends/webgl/kernels/Max.ts +++ b/tfjs-core/src/backends/webgl/kernels/Max.ts @@ -31,15 +31,22 @@ export const maxConfig: KernelConfig = { const {x} = inputs as MaxInputs; const {reductionIndices} = attrs as {} as MaxAttrs; const webglBackend = backend as MathBackendWebGL; + console.log('max webgl kernel func', x, reductionIndices); - axis_util.assertAxesAreInnerMostDims( - 'max', reductionIndices, x.shape.length); + const origAxes = util.parseAxisParam(reductionIndices, x.shape); + let axes = origAxes; + const permutedAxes = axis_util.getAxesPermutation(axes, x.shape.length); + if (permutedAxes != null) { + console.log('TRANSPOSE'); + } + axis_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); const [outShape, reduceShape] = - axis_util.computeOutAndReduceShapes(x.shape, reductionIndices); + axis_util.computeOutAndReduceShapes(x.shape, axes); let out; if (webglBackend.shouldExecuteOnCPU([x])) { + console.log('running on the cpu instead'); const xTexData = webglBackend.texData.get(x.dataId); const values = xTexData.values as TypedArray; const outValues = diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index 2c883fffde3..afe04308e04 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -15,8 +15,10 @@ * ============================================================================= */ +import {KernelBackend} from '../backends/backend'; import {ENGINE} from '../engine'; import {Tensor} from '../tensor'; +import {GradSaveFunc} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; @@ -55,21 +57,27 @@ import {transpose} from './transpose'; /** @doc {heading: 'Operations', subheading: 'Reduction'} */ function max_( x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T { + console.log('max op'); + console.log(x); + console.log(axis); let $x = convertToTensor(x, 'x', 'max'); - const origAxes = util.parseAxisParam(axis, $x.shape); - let axes = origAxes; - const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); - if (permutedAxes != null) { - $x = transpose($x, permutedAxes); - axes = axis_util.getInnerMostAxes(axes.length, $x.rank); - } - let res = ENGINE.runKernelFunc((backend, save) => { + const forward = (backend: KernelBackend, save: GradSaveFunc) => { + let axes = origAxes; + const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = transpose($x, permutedAxes); + axes = axis_util.getInnerMostAxes(axes.length, $x.rank); + } + const y = backend.max($x, axes); save([$x, y]); return y; - }, {x: $x}, null /* gradient */, 'Max', {axes}); + }; + + let res = ENGINE.runKernelFunc( + forward, {x: $x}, null /* gradient */, 'Max', {reductionIndices: axis}); if (keepDims) { const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); res = res.reshape(newShape) as T; From cc7685b78328ca6d1355c0354f989d9c367803f8 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 27 Mar 2020 10:03:11 -0400 Subject: [PATCH 16/72] wip --- tfjs-core/src/backends/cpu/kernels/Max.ts | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tfjs-core/src/backends/cpu/kernels/Max.ts b/tfjs-core/src/backends/cpu/kernels/Max.ts index 8b656dd8d68..66006e68750 100644 --- a/tfjs-core/src/backends/cpu/kernels/Max.ts +++ b/tfjs-core/src/backends/cpu/kernels/Max.ts @@ -23,6 +23,7 @@ import * as util from '../../../util'; import {MathBackendCPU} from '../backend_cpu'; import {assertNotComplex} from '../cpu_util'; import {maxImpl} from './Max_impl'; +import {transposeImpl} from './Transpose_impl'; export const maxConfig: KernelConfig = { kernelName: Max, @@ -31,24 +32,28 @@ export const maxConfig: KernelConfig = { const {x} = inputs as MaxInputs; const {reductionIndices} = attrs as {} as MaxAttrs; const cpuBackend = backend as MathBackendCPU; + const xRank = x.shape.length; console.log('max cpu kernel func', x, reductionIndices); const origAxes = util.parseAxisParam(reductionIndices, x.shape); let axes = origAxes; - const permutedAxes = axis_util.getAxesPermutation(axes, x.shape.length); + const permutedAxes = axis_util.getAxesPermutation(axes, xRank); if (permutedAxes != null) { console.log('TRANSPOSE'); + const vals = cpuBackend.data.get(x.dataId).values as TypedArray; + const xTVals = transposeImpl(vals, x.shape, x.dtype, permutedAxes); + axes = axis_util.getInnerMostAxes(axes.length, xRank); } assertNotComplex(x, 'max'); - axis_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); + axis_util.assertAxesAreInnerMostDims('max', axes, xRank); const [outShape, reduceShape] = axis_util.computeOutAndReduceShapes(x.shape, axes); const reduceSize = util.sizeFromShape(reduceShape); - const aVals = cpuBackend.data.get(x.dataId).values as TypedArray; - const result = maxImpl(aVals, reduceSize, outShape, x.dtype); + const xVals = cpuBackend.data.get(x.dataId).values as TypedArray; + const result = maxImpl(xVals, reduceSize, outShape, x.dtype); const dataId = cpuBackend.write(result, outShape, x.dtype); return {dataId, shape: outShape, dtype: x.dtype}; From f9c0df3d2fba18f804925ce1060b02a0e5250855 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 22 Apr 2020 08:50:43 -0400 Subject: [PATCH 17/72] fix --- tfjs-backend-cpu/src/kernels/Max.ts | 23 +++++++++---------- tfjs-backend-cpu/src/kernels/Max_impl.ts | 3 +-- tfjs-backend-cpu/src/kernels/Transpose.ts | 2 +- .../src/kernels/Transpose_impl.ts | 11 ++++++--- tfjs-backend-webgpu/src/backend_webgpu.ts | 4 ++-- tfjs-core/yarn.lock | 5 ---- 6 files changed, 23 insertions(+), 25 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/Max.ts b/tfjs-backend-cpu/src/kernels/Max.ts index 66006e68750..ff436eb4f40 100644 --- a/tfjs-backend-cpu/src/kernels/Max.ts +++ b/tfjs-backend-cpu/src/kernels/Max.ts @@ -15,13 +15,13 @@ * ============================================================================= */ -import {Max, MaxAttrs, MaxInputs} from '../../../kernel_names'; -import {KernelConfig} from '../../../kernel_registry'; -import * as axis_util from '../../../ops/axis_util'; -import {TypedArray} from '../../../types'; -import * as util from '../../../util'; +import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; +import {backend_util, KernelConfig} from '@tensorflow/tfjs-core'; +import {TypedArray, util} from '@tensorflow/tfjs-core'; + import {MathBackendCPU} from '../backend_cpu'; import {assertNotComplex} from '../cpu_util'; + import {maxImpl} from './Max_impl'; import {transposeImpl} from './Transpose_impl'; @@ -37,22 +37,21 @@ export const maxConfig: KernelConfig = { const origAxes = util.parseAxisParam(reductionIndices, x.shape); let axes = origAxes; - const permutedAxes = axis_util.getAxesPermutation(axes, xRank); + const permutedAxes = backend_util.getAxesPermutation(axes, xRank); + let xVals = cpuBackend.data.get(x.dataId).values as TypedArray; if (permutedAxes != null) { console.log('TRANSPOSE'); - const vals = cpuBackend.data.get(x.dataId).values as TypedArray; - const xTVals = transposeImpl(vals, x.shape, x.dtype, permutedAxes); - axes = axis_util.getInnerMostAxes(axes.length, xRank); + xVals = transposeImpl(xVals, x.shape, x.dtype, permutedAxes); + axes = backend_util.getInnerMostAxes(axes.length, xRank); } assertNotComplex(x, 'max'); - axis_util.assertAxesAreInnerMostDims('max', axes, xRank); + backend_util.assertAxesAreInnerMostDims('max', axes, xRank); const [outShape, reduceShape] = - axis_util.computeOutAndReduceShapes(x.shape, axes); + backend_util.computeOutAndReduceShapes(x.shape, axes); const reduceSize = util.sizeFromShape(reduceShape); - const xVals = cpuBackend.data.get(x.dataId).values as TypedArray; const result = maxImpl(xVals, reduceSize, outShape, x.dtype); const dataId = cpuBackend.write(result, outShape, x.dtype); diff --git a/tfjs-backend-cpu/src/kernels/Max_impl.ts b/tfjs-backend-cpu/src/kernels/Max_impl.ts index 5cdee7353ce..c326ac3f9df 100644 --- a/tfjs-backend-cpu/src/kernels/Max_impl.ts +++ b/tfjs-backend-cpu/src/kernels/Max_impl.ts @@ -15,8 +15,7 @@ * ============================================================================= */ -import {DataType, NumericDataType, TypedArray} from '../../../types'; -import * as util from '../../../util'; +import {DataType, NumericDataType, TypedArray, util} from '@tensorflow/tfjs-core'; export function maxImpl( aVals: TypedArray, reduceSize: number, outShape: number[], diff --git a/tfjs-backend-cpu/src/kernels/Transpose.ts b/tfjs-backend-cpu/src/kernels/Transpose.ts index 05ae13f03cb..c1c68be78f2 100644 --- a/tfjs-backend-cpu/src/kernels/Transpose.ts +++ b/tfjs-backend-cpu/src/kernels/Transpose.ts @@ -41,7 +41,7 @@ export const transposeConfig: KernelConfig = { } const values = cpuBackend.data.get(x.dataId).values as TypedArray; - const result = transposeImpl(values, x.shape, x.dtype, perm, newShape); + const result = transposeImpl(values, x.shape, x.dtype, perm); const dataId = cpuBackend.write(result, newShape, x.dtype); return {dataId, shape: newShape, dtype: x.dtype}; diff --git a/tfjs-backend-cpu/src/kernels/Transpose_impl.ts b/tfjs-backend-cpu/src/kernels/Transpose_impl.ts index 3a6d290b777..b7d78decb03 100644 --- a/tfjs-backend-cpu/src/kernels/Transpose_impl.ts +++ b/tfjs-backend-cpu/src/kernels/Transpose_impl.ts @@ -19,10 +19,15 @@ import {DataType, NumericDataType, TypedArray} from '@tensorflow/tfjs-core'; import {util} from '@tensorflow/tfjs-core'; export function transposeImpl( - xVals: TypedArray, xShape: number[], dtype: DataType, perm: number[], - newShape: number[]): TypedArray { - const xSize = util.sizeFromShape(xShape); + xVals: TypedArray, xShape: number[], dtype: DataType, + perm: number[]): TypedArray { const xRank = xShape.length; + const newShape: number[] = new Array(xRank); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = xShape[perm[i]]; + } + + const xSize = util.sizeFromShape(xShape); const xStrides = util.computeStrides(xShape); const newStrides = util.computeStrides(newShape); diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index e09724b1501..aa3d954e939 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -797,8 +797,8 @@ export class WebGPUBackend extends KernelBackend { const dimensions = [ convInfo.filterHeight, convInfo.filterWidth, ...pad, - convInfo.strideHeight, convInfo.strideWidth, - convInfo.dilationHeight, convInfo.dilationWidth + convInfo.strideHeight, convInfo.strideWidth, convInfo.dilationHeight, + convInfo.dilationWidth ]; const inputs: Tensor[] = [input, filter]; diff --git a/tfjs-core/yarn.lock b/tfjs-core/yarn.lock index 90dfb1e96f2..bbebade706d 100644 --- a/tfjs-core/yarn.lock +++ b/tfjs-core/yarn.lock @@ -208,11 +208,6 @@ acorn@^6.0.5: resolved "https://registry.yarnpkg.com/acorn/-/acorn-6.4.0.tgz#b659d2ffbafa24baf5db1cdbb2c94a983ecd2784" integrity sha512-gac8OEcQ2Li1dxIEWGZzsp2BitJxwkwcOm0zHAJLcPJaVvm58FRnk6RkuLRpU1EujipU2ZFODv2P9DLMfnV8mw== -acorn@^7.1.1: - version "7.1.1" - resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.1.1.tgz#e35668de0b402f359de515c5482a1ab9f89a69bf" - integrity sha512-add7dgA5ppRPxCFJoAGfMDi7PIBXq1RtGo7BhbLaxwrXPOmw8gq48Y9ozT01hUKy9byMjlR20EJhu5zlkErEkg== - after@0.8.2: version "0.8.2" resolved "https://registry.yarnpkg.com/after/-/after-0.8.2.tgz#fedb394f9f0e02aa9768e702bda23b505fae7e1f" From f93656eb2df3acc58a8eff17ab0d6e6c3496ed57 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 22 Apr 2020 09:00:53 -0400 Subject: [PATCH 18/72] webgl build --- tfjs-backend-webgl/src/kernel_utils/reduce.ts | 8 ++--- .../src/kernel_utils/reshape.ts | 21 +++++++------- tfjs-backend-webgl/src/kernels/Max.ts | 22 +++++++------- tfjs-backend-webgl/src/kernels/Max_impl.ts | 29 ++++++++++++++++--- 4 files changed, 49 insertions(+), 31 deletions(-) diff --git a/tfjs-backend-webgl/src/kernel_utils/reduce.ts b/tfjs-backend-webgl/src/kernel_utils/reduce.ts index 949f96d7ea1..cfc5c1e5151 100644 --- a/tfjs-backend-webgl/src/kernel_utils/reduce.ts +++ b/tfjs-backend-webgl/src/kernel_utils/reduce.ts @@ -15,18 +15,16 @@ * ============================================================================= */ -import {TensorInfo} from '../../../kernel_registry'; -import {computeOptimalWindowSize, ReduceTypes} from '../../../ops/reduce_util'; -import {DataType} from '../../../types'; +import {backend_util, DataType, TensorInfo} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; import {ReduceProgram} from '../reduce_gpu'; export function reduce( - x: TensorInfo, dtype: DataType, reductionType: ReduceTypes, + x: TensorInfo, dtype: DataType, reductionType: backend_util.ReduceTypes, backend: MathBackendWebGL): TensorInfo { const [batchSize, inSize] = x.shape; - const windowSize = computeOptimalWindowSize(inSize); + const windowSize = backend_util.computeOptimalWindowSize(inSize); const reduceInfo = {windowSize, inSize, batchSize}; const program = new ReduceProgram(reduceInfo, reductionType); const output = backend.runWebGLProgram(program, [x], dtype); diff --git a/tfjs-backend-webgl/src/kernel_utils/reshape.ts b/tfjs-backend-webgl/src/kernel_utils/reshape.ts index a19fee541df..44417d68063 100644 --- a/tfjs-backend-webgl/src/kernel_utils/reshape.ts +++ b/tfjs-backend-webgl/src/kernel_utils/reshape.ts @@ -15,25 +15,26 @@ * ============================================================================= */ -import {TensorInfo} from '../../../kernel_registry'; -import {webgl_util} from '../../../webgl'; +import {TensorInfo} from '@tensorflow/tfjs-core'; + import {MathBackendWebGL} from '../backend_webgl'; import {ReshapePackedProgram} from '../reshape_packed_gpu'; +import {getBatchDim, getRowsCols, isReshapeFree} from '../webgl_util'; function packedReshape( input: TensorInfo, afterShape: number[], backend: MathBackendWebGL): TensorInfo { - const input3DShape = [ - webgl_util.getBatchDim(input.shape), ...webgl_util.getRowsCols(input.shape) - ] as [number, number, number]; + const input3DShape = + [getBatchDim(input.shape), + ...getRowsCols(input.shape)] as [number, number, number]; const input3D: TensorInfo = { dtype: input.dtype, shape: input3DShape, dataId: input.dataId }; - const afterShapeAs3D = [ - webgl_util.getBatchDim(afterShape), ...webgl_util.getRowsCols(afterShape) - ] as [number, number, number]; + const afterShapeAs3D = + [getBatchDim(afterShape), + ...getRowsCols(afterShape)] as [number, number, number]; const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape); const preventEagerUnpackingOfOutput = true; @@ -47,9 +48,9 @@ export function reshape( x: TensorInfo, afterShape: number[], backend: MathBackendWebGL): TensorInfo { const xTexData = backend.texData.get(x.dataId); - if (xTexData.isPacked && !webgl_util.isReshapeFree(x.shape, afterShape) && + if (xTexData.isPacked && !isReshapeFree(x.shape, afterShape) && !(xTexData.texture !== null && - webgl_util.isReshapeFree(xTexData.shape, afterShape))) { + isReshapeFree(xTexData.shape, afterShape))) { return packedReshape(x, afterShape, backend); } diff --git a/tfjs-backend-webgl/src/kernels/Max.ts b/tfjs-backend-webgl/src/kernels/Max.ts index 66522edf091..3a6b3cde86f 100644 --- a/tfjs-backend-webgl/src/kernels/Max.ts +++ b/tfjs-backend-webgl/src/kernels/Max.ts @@ -15,14 +15,12 @@ * ============================================================================= */ -import {maxImpl as cpuMax} from '../../../backends/cpu/kernels/Max_impl'; -import {Max, MaxAttrs, MaxInputs} from '../../../kernel_names'; -import {KernelConfig} from '../../../kernel_registry'; -import * as axis_util from '../../../ops/axis_util'; -import {TypedArray} from '../../../types'; -import * as util from '../../../util'; +import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; +import {backend_util, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-core'; + import {MathBackendWebGL} from '../backend_webgl'; -import {maxImpl} from './Max_impl'; + +import {maxImpl, maxImplCPU} from './Max_impl'; export const maxConfig: KernelConfig = { kernelName: Max, @@ -35,22 +33,22 @@ export const maxConfig: KernelConfig = { const origAxes = util.parseAxisParam(reductionIndices, x.shape); let axes = origAxes; - const permutedAxes = axis_util.getAxesPermutation(axes, x.shape.length); + const permutedAxes = backend_util.getAxesPermutation(axes, x.shape.length); if (permutedAxes != null) { console.log('TRANSPOSE'); } - axis_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); + backend_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); const [outShape, reduceShape] = - axis_util.computeOutAndReduceShapes(x.shape, axes); + backend_util.computeOutAndReduceShapes(x.shape, axes); let out; if (webglBackend.shouldExecuteOnCPU([x])) { console.log('running on the cpu instead'); const xTexData = webglBackend.texData.get(x.dataId); const values = xTexData.values as TypedArray; - const outValues = - cpuMax(values, util.sizeFromShape(reduceShape), outShape, x.dtype); + const outValues = maxImplCPU( + values, util.sizeFromShape(reduceShape), outShape, x.dtype); out = webglBackend.makeTensorInfo(outShape, x.dtype); const outData = webglBackend.texData.get(out.dataId); diff --git a/tfjs-backend-webgl/src/kernels/Max_impl.ts b/tfjs-backend-webgl/src/kernels/Max_impl.ts index f2841686321..257bddc8223 100644 --- a/tfjs-backend-webgl/src/kernels/Max_impl.ts +++ b/tfjs-backend-webgl/src/kernels/Max_impl.ts @@ -15,8 +15,8 @@ * ============================================================================= */ -import {TensorInfo} from '../../../kernel_registry'; -import {sizeFromShape} from '../../../util'; +import {DataType, NumericDataType, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core'; + import {MathBackendWebGL} from '../backend_webgl'; import {reduce} from '../kernel_utils/reduce'; import {reshape} from '../kernel_utils/reshape'; @@ -24,8 +24,8 @@ import {reshape} from '../kernel_utils/reshape'; export const maxImpl = (x: TensorInfo, reduceShape: number[], outShape: number[], backend: MathBackendWebGL): TensorInfo => { - const inSize = sizeFromShape(reduceShape); - const xSize = sizeFromShape(x.shape); + const inSize = util.sizeFromShape(reduceShape); + const xSize = util.sizeFromShape(x.shape); const batchSize = xSize / inSize; return reshape( @@ -34,3 +34,24 @@ export const maxImpl = backend), outShape, backend); }; + +// todo(@annxingyuan) import this from cpu backend. +export function maxImplCPU( + aVals: TypedArray, reduceSize: number, outShape: number[], + dtype: DataType): TypedArray { + const vals = util.getTypedArrayFromDType( + dtype as NumericDataType, util.sizeFromShape(outShape)); + + for (let i = 0; i < vals.length; ++i) { + const offset = i * reduceSize; + let max = aVals[offset]; + for (let j = 0; j < reduceSize; ++j) { + const value = aVals[offset + j]; + if (value > max) { + max = value; + } + } + vals[i] = max; + } + return vals; +} From 7b14fafb01997149ef43b7e79f08795506e0f077 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 22 Apr 2020 09:03:12 -0400 Subject: [PATCH 19/72] fix --- tfjs-backend-cpu/yarn.lock | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tfjs-backend-cpu/yarn.lock b/tfjs-backend-cpu/yarn.lock index ffeeeada886..754e4d08dbc 100644 --- a/tfjs-backend-cpu/yarn.lock +++ b/tfjs-backend-cpu/yarn.lock @@ -139,11 +139,6 @@ acorn@^6.0.5: resolved "https://registry.yarnpkg.com/acorn/-/acorn-6.4.1.tgz#531e58ba3f51b9dacb9a6646ca4debf5b14ca474" integrity sha512-ZVA9k326Nwrj3Cj9jlh3wGFutC2ZornPNARZwsNYqQYgN0EsV2d53w5RN/co65Ohn4sUAUtb1rSUAOD6XN9idA== -acorn@^7.1.1: - version "7.1.1" - resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.1.1.tgz#e35668de0b402f359de515c5482a1ab9f89a69bf" - integrity sha512-add7dgA5ppRPxCFJoAGfMDi7PIBXq1RtGo7BhbLaxwrXPOmw8gq48Y9ozT01hUKy9byMjlR20EJhu5zlkErEkg== - after@0.8.2: version "0.8.2" resolved "https://registry.yarnpkg.com/after/-/after-0.8.2.tgz#fedb394f9f0e02aa9768e702bda23b505fae7e1f" From 9020fedaf96959c848bea0f9ebdbd248e2d99321 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 22 Apr 2020 09:58:54 -0400 Subject: [PATCH 20/72] max --- tfjs-backend-webgl/src/kernels/Max.ts | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tfjs-backend-webgl/src/kernels/Max.ts b/tfjs-backend-webgl/src/kernels/Max.ts index 3a6b3cde86f..e371e4a03b1 100644 --- a/tfjs-backend-webgl/src/kernels/Max.ts +++ b/tfjs-backend-webgl/src/kernels/Max.ts @@ -21,24 +21,28 @@ import {backend_util, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-cor import {MathBackendWebGL} from '../backend_webgl'; import {maxImpl, maxImplCPU} from './Max_impl'; +import {transposeImpl} from './Transpose_impl'; export const maxConfig: KernelConfig = { kernelName: Max, backendName: 'webgl', kernelFunc: ({inputs, attrs, backend}) => { - const {x} = inputs as MaxInputs; + let {x} = inputs as MaxInputs; const {reductionIndices} = attrs as {} as MaxAttrs; const webglBackend = backend as MathBackendWebGL; console.log('max webgl kernel func', x, reductionIndices); + const xRank = x.shape.length; + const origAxes = util.parseAxisParam(reductionIndices, x.shape); let axes = origAxes; - const permutedAxes = backend_util.getAxesPermutation(axes, x.shape.length); + const permutedAxes = backend_util.getAxesPermutation(axes, xRank); if (permutedAxes != null) { - console.log('TRANSPOSE'); + x = transposeImpl(x, permutedAxes, webglBackend); + axes = backend_util.getInnerMostAxes(axes.length, xRank); } - backend_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); + backend_util.assertAxesAreInnerMostDims('max', axes, xRank); const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes(x.shape, axes); From 17a46a20d8ab42777fc96c289bf077db671783d0 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 22 Apr 2020 10:37:12 -0400 Subject: [PATCH 21/72] add more logs --- tfjs-backend-cpu/src/kernels/Max.ts | 4 ++++ tfjs-core/src/ops/reduction_ops_test.ts | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tfjs-backend-cpu/src/kernels/Max.ts b/tfjs-backend-cpu/src/kernels/Max.ts index ff436eb4f40..eb91d2f0b94 100644 --- a/tfjs-backend-cpu/src/kernels/Max.ts +++ b/tfjs-backend-cpu/src/kernels/Max.ts @@ -41,8 +41,11 @@ export const maxConfig: KernelConfig = { let xVals = cpuBackend.data.get(x.dataId).values as TypedArray; if (permutedAxes != null) { console.log('TRANSPOSE'); + console.log(permutedAxes); xVals = transposeImpl(xVals, x.shape, x.dtype, permutedAxes); axes = backend_util.getInnerMostAxes(axes.length, xRank); + console.log(axes); + console.log(xVals); } assertNotComplex(x, 'max'); @@ -53,6 +56,7 @@ export const maxConfig: KernelConfig = { const reduceSize = util.sizeFromShape(reduceShape); const result = maxImpl(xVals, reduceSize, outShape, x.dtype); + console.log(result); const dataId = cpuBackend.write(result, outShape, x.dtype); return {dataId, shape: outShape, dtype: x.dtype}; diff --git a/tfjs-core/src/ops/reduction_ops_test.ts b/tfjs-core/src/ops/reduction_ops_test.ts index a26623ff45b..90c5524bb7a 100644 --- a/tfjs-core/src/ops/reduction_ops_test.ts +++ b/tfjs-core/src/ops/reduction_ops_test.ts @@ -1563,7 +1563,7 @@ describeWithFlags('norm', ALL_ENVS, () => { expectArraysClose(await norm.data(), [4]); }); - it('axis=0,1 keepDims in 3D array norm', async () => { + fit('axis=0,1 keepDims in 3D array norm', async () => { const a = tf.tensor3d([1, 2, 3, 0, 0, 1], [3, 2, 1]); const norm = tf.norm(a, Infinity, [0, 1], true); From 0e3ea93a20c90105006cc70d6e2c358cafca5ddf Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 22 Apr 2020 10:49:28 -0400 Subject: [PATCH 22/72] add logs --- tfjs-backend-cpu/src/kernels/Max.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tfjs-backend-cpu/src/kernels/Max.ts b/tfjs-backend-cpu/src/kernels/Max.ts index eb91d2f0b94..4153720557c 100644 --- a/tfjs-backend-cpu/src/kernels/Max.ts +++ b/tfjs-backend-cpu/src/kernels/Max.ts @@ -56,6 +56,8 @@ export const maxConfig: KernelConfig = { const reduceSize = util.sizeFromShape(reduceShape); const result = maxImpl(xVals, reduceSize, outShape, x.dtype); + console.log('RESULT'); + console.log(reduceSize, outShape); console.log(result); const dataId = cpuBackend.write(result, outShape, x.dtype); From e0c834b75c8ef69cbb5dbaea34ad898b36e38a91 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 22 Apr 2020 10:55:32 -0400 Subject: [PATCH 23/72] logs --- tfjs-backend-cpu/src/kernels/Max.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/Max.ts b/tfjs-backend-cpu/src/kernels/Max.ts index 4153720557c..38c55a8d858 100644 --- a/tfjs-backend-cpu/src/kernels/Max.ts +++ b/tfjs-backend-cpu/src/kernels/Max.ts @@ -39,14 +39,13 @@ export const maxConfig: KernelConfig = { let axes = origAxes; const permutedAxes = backend_util.getAxesPermutation(axes, xRank); let xVals = cpuBackend.data.get(x.dataId).values as TypedArray; + console.log('permuted axes', permutedAxes); if (permutedAxes != null) { - console.log('TRANSPOSE'); - console.log(permutedAxes); xVals = transposeImpl(xVals, x.shape, x.dtype, permutedAxes); axes = backend_util.getInnerMostAxes(axes.length, xRank); - console.log(axes); - console.log(xVals); } + console.log('axes', axes); + console.log('x', xVals); assertNotComplex(x, 'max'); backend_util.assertAxesAreInnerMostDims('max', axes, xRank); @@ -57,6 +56,7 @@ export const maxConfig: KernelConfig = { const result = maxImpl(xVals, reduceSize, outShape, x.dtype); console.log('RESULT'); + console.log(reduceShape); console.log(reduceSize, outShape); console.log(result); From 3a3622ca895d10e20903a7c17a751649679bfb4c Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 22 Apr 2020 10:58:25 -0400 Subject: [PATCH 24/72] logs --- tfjs-backend-cpu/src/kernels/Max.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/tfjs-backend-cpu/src/kernels/Max.ts b/tfjs-backend-cpu/src/kernels/Max.ts index 38c55a8d858..4149d6d62ec 100644 --- a/tfjs-backend-cpu/src/kernels/Max.ts +++ b/tfjs-backend-cpu/src/kernels/Max.ts @@ -56,6 +56,7 @@ export const maxConfig: KernelConfig = { const result = maxImpl(xVals, reduceSize, outShape, x.dtype); console.log('RESULT'); + console.log('x shape', x.shape, axes); console.log(reduceShape); console.log(reduceSize, outShape); console.log(result); From 10b8cb99f65fbb3b96ca9bf4fac3fcb071a0706d Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 22 Apr 2020 11:03:20 -0400 Subject: [PATCH 25/72] rem logs --- tfjs-backend-cpu/src/kernels/Max.ts | 25 ++++++++++++------------- tfjs-core/src/ops/reduction_ops_test.ts | 2 +- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/Max.ts b/tfjs-backend-cpu/src/kernels/Max.ts index 4149d6d62ec..cd41de1f100 100644 --- a/tfjs-backend-cpu/src/kernels/Max.ts +++ b/tfjs-backend-cpu/src/kernels/Max.ts @@ -32,34 +32,33 @@ export const maxConfig: KernelConfig = { const {x} = inputs as MaxInputs; const {reductionIndices} = attrs as {} as MaxAttrs; const cpuBackend = backend as MathBackendCPU; - const xRank = x.shape.length; - console.log('max cpu kernel func', x, reductionIndices); + let xShape = x.shape; + const xRank = xShape.length; - const origAxes = util.parseAxisParam(reductionIndices, x.shape); + const origAxes = util.parseAxisParam(reductionIndices, xShape); let axes = origAxes; const permutedAxes = backend_util.getAxesPermutation(axes, xRank); let xVals = cpuBackend.data.get(x.dataId).values as TypedArray; - console.log('permuted axes', permutedAxes); if (permutedAxes != null) { - xVals = transposeImpl(xVals, x.shape, x.dtype, permutedAxes); + xVals = transposeImpl(xVals, xShape, x.dtype, permutedAxes); axes = backend_util.getInnerMostAxes(axes.length, xRank); + + const newShape: number[] = new Array(xRank); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = xShape[permutedAxes[i]]; + } + + xShape = newShape; } - console.log('axes', axes); - console.log('x', xVals); assertNotComplex(x, 'max'); backend_util.assertAxesAreInnerMostDims('max', axes, xRank); const [outShape, reduceShape] = - backend_util.computeOutAndReduceShapes(x.shape, axes); + backend_util.computeOutAndReduceShapes(xShape, axes); const reduceSize = util.sizeFromShape(reduceShape); const result = maxImpl(xVals, reduceSize, outShape, x.dtype); - console.log('RESULT'); - console.log('x shape', x.shape, axes); - console.log(reduceShape); - console.log(reduceSize, outShape); - console.log(result); const dataId = cpuBackend.write(result, outShape, x.dtype); return {dataId, shape: outShape, dtype: x.dtype}; diff --git a/tfjs-core/src/ops/reduction_ops_test.ts b/tfjs-core/src/ops/reduction_ops_test.ts index 90c5524bb7a..a26623ff45b 100644 --- a/tfjs-core/src/ops/reduction_ops_test.ts +++ b/tfjs-core/src/ops/reduction_ops_test.ts @@ -1563,7 +1563,7 @@ describeWithFlags('norm', ALL_ENVS, () => { expectArraysClose(await norm.data(), [4]); }); - fit('axis=0,1 keepDims in 3D array norm', async () => { + it('axis=0,1 keepDims in 3D array norm', async () => { const a = tf.tensor3d([1, 2, 3, 0, 0, 1], [3, 2, 1]); const norm = tf.norm(a, Infinity, [0, 1], true); From e38df556315cf2fe46ca66831a966cb17d7a9653 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 22 Apr 2020 11:13:35 -0400 Subject: [PATCH 26/72] logs --- tfjs-core/src/gradients/Max_grad.ts | 8 ++++++++ tfjs-core/src/ops/reduction_ops.ts | 2 ++ tfjs-core/src/ops/reduction_ops_test.ts | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tfjs-core/src/gradients/Max_grad.ts b/tfjs-core/src/gradients/Max_grad.ts index d8f682f84aa..6a121258e4f 100644 --- a/tfjs-core/src/gradients/Max_grad.ts +++ b/tfjs-core/src/gradients/Max_grad.ts @@ -27,11 +27,19 @@ export const maxGradConfig: GradConfig = { inputsToSave: ['x'], outputsToSave: [true], gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { + console.log('GRAD FUNC'); const maxAttrs: MaxAttrs = attrs as {} as MaxAttrs; const {reductionIndices} = maxAttrs; const [x, y] = saved; + console.log('max attrs'); + console.log(maxAttrs); + console.log('reduction indices'); + console.log(reductionIndices); + console.log('x shape', x.shape); const origAxes = util.parseAxisParam(reductionIndices, x.shape); + console.log('orig axes', origAxes); const permutedAxes = axis_util.getAxesPermutation(reductionIndices, x.rank); + console.log('permuted axes', permutedAxes); return gradForMinAndMax(dy, y, x, origAxes, permutedAxes); } }; diff --git a/tfjs-core/src/ops/reduction_ops.ts b/tfjs-core/src/ops/reduction_ops.ts index 68fb6be6a45..168d32166b9 100644 --- a/tfjs-core/src/ops/reduction_ops.ts +++ b/tfjs-core/src/ops/reduction_ops.ts @@ -272,6 +272,8 @@ function mean_( */ export function gradForMinAndMax( dy: T, y: T, xOrig: Tensor, origAxes: number[], permutedAxes: number[]) { + console.log('orig axes', origAxes); + console.log(permutedAxes); if (y.rank < xOrig.rank) { y = y.reshape(axis_util.expandShapeToKeepDim(y.shape, origAxes)) as T; } diff --git a/tfjs-core/src/ops/reduction_ops_test.ts b/tfjs-core/src/ops/reduction_ops_test.ts index a26623ff45b..208cb4624b8 100644 --- a/tfjs-core/src/ops/reduction_ops_test.ts +++ b/tfjs-core/src/ops/reduction_ops_test.ts @@ -320,7 +320,7 @@ describeWithFlags('max', ALL_ENVS, () => { expectArraysClose(await gradients.data(), [-1]); }); - it('max gradient: 1D, ties', async () => { + fit('max gradient: 1D, ties', async () => { const x = tf.tensor1d([1, 3, 7, 7]); const dy = tf.scalar(-1); const gradients = tf.grad(v => tf.max(v))(x, dy); From 04de4ce73265a4c6becc2beb36f09cebaa83bc9e Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 22 Apr 2020 11:17:52 -0400 Subject: [PATCH 27/72] remove hacks --- tfjs-core/src/gradients/Max_grad.ts | 10 +--------- tfjs-core/src/ops/reduction_ops_test.ts | 2 +- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/tfjs-core/src/gradients/Max_grad.ts b/tfjs-core/src/gradients/Max_grad.ts index 6a121258e4f..c80e26ee187 100644 --- a/tfjs-core/src/gradients/Max_grad.ts +++ b/tfjs-core/src/gradients/Max_grad.ts @@ -27,19 +27,11 @@ export const maxGradConfig: GradConfig = { inputsToSave: ['x'], outputsToSave: [true], gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { - console.log('GRAD FUNC'); const maxAttrs: MaxAttrs = attrs as {} as MaxAttrs; const {reductionIndices} = maxAttrs; const [x, y] = saved; - console.log('max attrs'); - console.log(maxAttrs); - console.log('reduction indices'); - console.log(reductionIndices); - console.log('x shape', x.shape); const origAxes = util.parseAxisParam(reductionIndices, x.shape); - console.log('orig axes', origAxes); - const permutedAxes = axis_util.getAxesPermutation(reductionIndices, x.rank); - console.log('permuted axes', permutedAxes); + const permutedAxes = axis_util.getAxesPermutation(origAxes, x.rank); return gradForMinAndMax(dy, y, x, origAxes, permutedAxes); } }; diff --git a/tfjs-core/src/ops/reduction_ops_test.ts b/tfjs-core/src/ops/reduction_ops_test.ts index 208cb4624b8..a26623ff45b 100644 --- a/tfjs-core/src/ops/reduction_ops_test.ts +++ b/tfjs-core/src/ops/reduction_ops_test.ts @@ -320,7 +320,7 @@ describeWithFlags('max', ALL_ENVS, () => { expectArraysClose(await gradients.data(), [-1]); }); - fit('max gradient: 1D, ties', async () => { + it('max gradient: 1D, ties', async () => { const x = tf.tensor1d([1, 3, 7, 7]); const dy = tf.scalar(-1); const gradients = tf.grad(v => tf.max(v))(x, dy); From 0e2cb52347130c5f490a7a0a0554d34d9426229a Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 23 Apr 2020 09:47:38 -0400 Subject: [PATCH 28/72] add logs --- tfjs-backend-cpu/src/kernels/Max.ts | 1 + tfjs-core/src/engine.ts | 5 +++++ tfjs-core/src/gradients/Max_grad.ts | 1 + tfjs-core/src/gradients/Transpose_grad.ts | 10 +++++++++- tfjs-core/src/ops/reduction_ops.ts | 6 +++++- tfjs-core/src/ops/reduction_ops_test.ts | 2 +- tfjs-core/src/tape.ts | 8 ++++++++ 7 files changed, 30 insertions(+), 3 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/Max.ts b/tfjs-backend-cpu/src/kernels/Max.ts index cd41de1f100..206e8c3bd75 100644 --- a/tfjs-backend-cpu/src/kernels/Max.ts +++ b/tfjs-backend-cpu/src/kernels/Max.ts @@ -40,6 +40,7 @@ export const maxConfig: KernelConfig = { const permutedAxes = backend_util.getAxesPermutation(axes, xRank); let xVals = cpuBackend.data.get(x.dataId).values as TypedArray; if (permutedAxes != null) { + console.log('TARNSPOSING'); xVals = transposeImpl(xVals, xShape, x.dtype, permutedAxes); axes = backend_util.getInnerMostAxes(axes.length, xRank); diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts index 4ee48fa5a10..ac2d55179e1 100644 --- a/tfjs-core/src/engine.ts +++ b/tfjs-core/src/engine.ts @@ -895,6 +895,11 @@ export class Engine implements TensorTracker, DataMover { return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs); }; } + console.log('PUSHING TO TAPE'); + for (const input in inputs) { + console.log(input); + console.log(inputs[input].shape); + } this.state.activeTape.push(tapeNode); } diff --git a/tfjs-core/src/gradients/Max_grad.ts b/tfjs-core/src/gradients/Max_grad.ts index c80e26ee187..ce859982c74 100644 --- a/tfjs-core/src/gradients/Max_grad.ts +++ b/tfjs-core/src/gradients/Max_grad.ts @@ -27,6 +27,7 @@ export const maxGradConfig: GradConfig = { inputsToSave: ['x'], outputsToSave: [true], gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { + console.log('MAX GRAD FUNC'); const maxAttrs: MaxAttrs = attrs as {} as MaxAttrs; const {reductionIndices} = maxAttrs; const [x, y] = saved; diff --git a/tfjs-core/src/gradients/Transpose_grad.ts b/tfjs-core/src/gradients/Transpose_grad.ts index 4d48500d7de..4c03e1cf6b7 100644 --- a/tfjs-core/src/gradients/Transpose_grad.ts +++ b/tfjs-core/src/gradients/Transpose_grad.ts @@ -27,6 +27,14 @@ export const transposeGradConfig: GradConfig = { const transposeAttrs: TransposeAttrs = attrs as {} as TransposeAttrs; const {perm} = transposeAttrs; const undoPerm = axis_util.getUndoAxesPermutation(perm); - return {x: () => transpose(dy, undoPerm)}; + return { + x: () => { + console.log('IN TRANSPOSE GRADIENT'); + console.log(dy.shape); + const out = transpose(dy, undoPerm); + console.log(out.shape); + return out; + } + }; } }; diff --git a/tfjs-core/src/ops/reduction_ops.ts b/tfjs-core/src/ops/reduction_ops.ts index 168d32166b9..673b3f4d978 100644 --- a/tfjs-core/src/ops/reduction_ops.ts +++ b/tfjs-core/src/ops/reduction_ops.ts @@ -272,6 +272,7 @@ function mean_( */ export function gradForMinAndMax( dy: T, y: T, xOrig: Tensor, origAxes: number[], permutedAxes: number[]) { + console.log('GRAD FRO MIN AND MAX'); console.log('orig axes', origAxes); console.log(permutedAxes); if (y.rank < xOrig.rank) { @@ -282,8 +283,11 @@ export function gradForMinAndMax( } return { x: () => { + console.log('INVOKING'); const dx = dy.mul(xOrig.equal(y).cast(dy.dtype)); - return permutedAxes == null ? dx : dx.transpose(permutedAxes); + const out = permutedAxes == null ? dx : dx.transpose(permutedAxes); + console.log('out', out.shape); + return out; } }; } diff --git a/tfjs-core/src/ops/reduction_ops_test.ts b/tfjs-core/src/ops/reduction_ops_test.ts index a26623ff45b..e69f164e0f6 100644 --- a/tfjs-core/src/ops/reduction_ops_test.ts +++ b/tfjs-core/src/ops/reduction_ops_test.ts @@ -345,7 +345,7 @@ describeWithFlags('max', ALL_ENVS, () => { expect(gradients.shape).toEqual([2, 3]); }); - it('max gradient: 2D, axes=0, keepDims=false', async () => { + fit('max gradient: 2D, axes=0, keepDims=false', async () => { const x = tf.tensor2d([[0, 20, 10], [-10, -30, 20]]); const dy = tf.tensor1d([-1, -1, -1]); const axis = 0; diff --git a/tfjs-core/src/tape.ts b/tfjs-core/src/tape.ts index 8f78106f150..6f5fd152113 100644 --- a/tfjs-core/src/tape.ts +++ b/tfjs-core/src/tape.ts @@ -43,6 +43,7 @@ export type NamedGradientMap = { */ export function getFilteredNodesXToY( tape: TapeNode[], xs: Tensor[], y: Tensor): TapeNode[] { + console.log('GETTING NODES FROM X TO Y'); // Forward pass to compute all the nodes and Tensors that are transitively a // function of x. const tensorsFromX: {[tensorId: number]: boolean} = {}; @@ -131,9 +132,11 @@ export function getFilteredNodesXToY( export function backpropagateGradients( tensorAccumulatedGradientMap: {[tensorId: number]: Tensor}, filteredTape: TapeNode[], tidy: (f: Function) => Tensor) { + console.log('BACKPROPPING'); // Walk the tape backward and keep a map of Tensor to its gradient. for (let i = filteredTape.length - 1; i >= 0; i--) { const node = filteredTape[i]; + console.log('tape node', node.kernelName); const dys: Tensor[] = []; node.outputs.forEach(o => { @@ -157,6 +160,7 @@ export function backpropagateGradients( const inputGradients = node.gradient(dys); for (const inputName in node.inputs) { + console.log('looping over inputs in node inputs', inputName); if (!(inputName in inputGradients)) { throw new Error( `Cannot backprop through input ${inputName}. ` + @@ -165,6 +169,8 @@ export function backpropagateGradients( // Call the gradient function. const dx = tidy(() => inputGradients[inputName]()); + console.log('just called the gradient function'); + if (dx.dtype !== 'float32') { throw new Error( `Error in gradient for op ${ @@ -172,6 +178,7 @@ export function backpropagateGradients( `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`); } const x = node.inputs[inputName]; + console.log(dx.shape, x.shape); if (!util.arraysEqual(dx.shape, x.shape)) { throw new Error( `Error in gradient for op ${ @@ -179,6 +186,7 @@ export function backpropagateGradients( `'${inputName}' has shape '${dx.shape}', which does not match ` + `the shape of the input '${x.shape}'`); } + console.log('got past errors'); if (tensorAccumulatedGradientMap[x.id] == null) { tensorAccumulatedGradientMap[x.id] = dx; From f068e52f1acb6b92f5def0c99001edd04c8f36c7 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 23 Apr 2020 10:08:44 -0400 Subject: [PATCH 29/72] add logs --- tfjs-core/src/engine.ts | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts index ac2d55179e1..13f00eba51c 100644 --- a/tfjs-core/src/engine.ts +++ b/tfjs-core/src/engine.ts @@ -466,6 +466,7 @@ export class Engine implements TensorTracker, DataMover { const inputs = {x}; const grad = (dy: Tensor) => ({x: () => dy.toFloat()}); const saved: Tensor[] = []; + console.log('ADDING TO TAPE IN CLONE'); this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {}); return y; } @@ -538,6 +539,11 @@ export class Engine implements TensorTracker, DataMover { backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]}, kernelName?: string, attrs?: NamedAttrMap, inputsToSave?: Tensor[], outputsToSave?: boolean[]): T { + console.log('RUN KERNEL FUNC', kernelName); + for (const input in inputs) { + console.log(input); + console.log(inputs[input].shape); + } let outputs: Tensor[]; let saved: Tensor[] = []; const isTapeOn = this.isTapeOn(); @@ -624,6 +630,7 @@ export class Engine implements TensorTracker, DataMover { }); if (isTapeOn) { + console.log('ADDING TO TAPE NODE WITHIN RUNKERNELFUNC'); this.addTapeNode( kernelName, inputs, outputs, backwardsFunc, saved, attrs); } @@ -871,6 +878,11 @@ export class Engine implements TensorTracker, DataMover { private addTapeNode( kernelName: string, inputs: NamedTensorMap, outputs: Tensor[], gradientsFunc: GradFunc, saved: Tensor[], attrs: NamedAttrMap): void { + console.log('PUSHING TO TAPE', kernelName); + for (const input in inputs) { + console.log(input); + console.log(inputs[input].shape); + } const tapeNode: TapeNode = {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved}; @@ -895,11 +907,6 @@ export class Engine implements TensorTracker, DataMover { return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs); }; } - console.log('PUSHING TO TAPE'); - for (const input in inputs) { - console.log(input); - console.log(inputs[input].shape); - } this.state.activeTape.push(tapeNode); } From 42625724eb23c92ac1d1c58c8b77ee3521a4b24a Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 23 Apr 2020 10:45:51 -0400 Subject: [PATCH 30/72] add transpose to grad --- tfjs-core/src/gradients/Max_grad.ts | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tfjs-core/src/gradients/Max_grad.ts b/tfjs-core/src/gradients/Max_grad.ts index ce859982c74..9e399e17dc6 100644 --- a/tfjs-core/src/gradients/Max_grad.ts +++ b/tfjs-core/src/gradients/Max_grad.ts @@ -19,6 +19,7 @@ import {Max, MaxAttrs} from '../kernel_names'; import {GradConfig, NamedAttrMap} from '../kernel_registry'; import * as axis_util from '../ops/axis_util'; import {gradForMinAndMax} from '../ops/reduction_ops'; +import {transpose} from '../ops/transpose'; import {Tensor} from '../tensor'; import * as util from '../util'; @@ -33,6 +34,12 @@ export const maxGradConfig: GradConfig = { const [x, y] = saved; const origAxes = util.parseAxisParam(reductionIndices, x.shape); const permutedAxes = axis_util.getAxesPermutation(origAxes, x.rank); - return gradForMinAndMax(dy, y, x, origAxes, permutedAxes); + const maxGrad = gradForMinAndMax(dy, y, x, origAxes, permutedAxes); + return { + x: () => { + const out = maxGrad['x'](); + return transpose(out); + } + }; } }; From c8642630fc9f9eb3db3e2515919fa6a152c867f1 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 23 Apr 2020 10:46:55 -0400 Subject: [PATCH 31/72] add condition --- tfjs-core/src/gradients/Max_grad.ts | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/gradients/Max_grad.ts b/tfjs-core/src/gradients/Max_grad.ts index 9e399e17dc6..bbc7382e0d8 100644 --- a/tfjs-core/src/gradients/Max_grad.ts +++ b/tfjs-core/src/gradients/Max_grad.ts @@ -37,8 +37,11 @@ export const maxGradConfig: GradConfig = { const maxGrad = gradForMinAndMax(dy, y, x, origAxes, permutedAxes); return { x: () => { - const out = maxGrad['x'](); - return transpose(out); + let out = maxGrad['x'](); + if (permutedAxes != null) { + out = transpose(out); + } + return out; } }; } From 2f941f6fea3cb9c3dbafb616330488f280c49e26 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 23 Apr 2020 10:51:17 -0400 Subject: [PATCH 32/72] passes --- tfjs-core/src/engine.ts | 12 ------------ tfjs-core/src/ops/reduction_ops_test.ts | 2 +- tfjs-core/src/tape.ts | 4 ---- 3 files changed, 1 insertion(+), 17 deletions(-) diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts index 13f00eba51c..4ee48fa5a10 100644 --- a/tfjs-core/src/engine.ts +++ b/tfjs-core/src/engine.ts @@ -466,7 +466,6 @@ export class Engine implements TensorTracker, DataMover { const inputs = {x}; const grad = (dy: Tensor) => ({x: () => dy.toFloat()}); const saved: Tensor[] = []; - console.log('ADDING TO TAPE IN CLONE'); this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {}); return y; } @@ -539,11 +538,6 @@ export class Engine implements TensorTracker, DataMover { backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]}, kernelName?: string, attrs?: NamedAttrMap, inputsToSave?: Tensor[], outputsToSave?: boolean[]): T { - console.log('RUN KERNEL FUNC', kernelName); - for (const input in inputs) { - console.log(input); - console.log(inputs[input].shape); - } let outputs: Tensor[]; let saved: Tensor[] = []; const isTapeOn = this.isTapeOn(); @@ -630,7 +624,6 @@ export class Engine implements TensorTracker, DataMover { }); if (isTapeOn) { - console.log('ADDING TO TAPE NODE WITHIN RUNKERNELFUNC'); this.addTapeNode( kernelName, inputs, outputs, backwardsFunc, saved, attrs); } @@ -878,11 +871,6 @@ export class Engine implements TensorTracker, DataMover { private addTapeNode( kernelName: string, inputs: NamedTensorMap, outputs: Tensor[], gradientsFunc: GradFunc, saved: Tensor[], attrs: NamedAttrMap): void { - console.log('PUSHING TO TAPE', kernelName); - for (const input in inputs) { - console.log(input); - console.log(inputs[input].shape); - } const tapeNode: TapeNode = {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved}; diff --git a/tfjs-core/src/ops/reduction_ops_test.ts b/tfjs-core/src/ops/reduction_ops_test.ts index e69f164e0f6..a26623ff45b 100644 --- a/tfjs-core/src/ops/reduction_ops_test.ts +++ b/tfjs-core/src/ops/reduction_ops_test.ts @@ -345,7 +345,7 @@ describeWithFlags('max', ALL_ENVS, () => { expect(gradients.shape).toEqual([2, 3]); }); - fit('max gradient: 2D, axes=0, keepDims=false', async () => { + it('max gradient: 2D, axes=0, keepDims=false', async () => { const x = tf.tensor2d([[0, 20, 10], [-10, -30, 20]]); const dy = tf.tensor1d([-1, -1, -1]); const axis = 0; diff --git a/tfjs-core/src/tape.ts b/tfjs-core/src/tape.ts index 6f5fd152113..fe09850d3e7 100644 --- a/tfjs-core/src/tape.ts +++ b/tfjs-core/src/tape.ts @@ -160,7 +160,6 @@ export function backpropagateGradients( const inputGradients = node.gradient(dys); for (const inputName in node.inputs) { - console.log('looping over inputs in node inputs', inputName); if (!(inputName in inputGradients)) { throw new Error( `Cannot backprop through input ${inputName}. ` + @@ -169,7 +168,6 @@ export function backpropagateGradients( // Call the gradient function. const dx = tidy(() => inputGradients[inputName]()); - console.log('just called the gradient function'); if (dx.dtype !== 'float32') { throw new Error( @@ -178,7 +176,6 @@ export function backpropagateGradients( `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`); } const x = node.inputs[inputName]; - console.log(dx.shape, x.shape); if (!util.arraysEqual(dx.shape, x.shape)) { throw new Error( `Error in gradient for op ${ @@ -186,7 +183,6 @@ export function backpropagateGradients( `'${inputName}' has shape '${dx.shape}', which does not match ` + `the shape of the input '${x.shape}'`); } - console.log('got past errors'); if (tensorAccumulatedGradientMap[x.id] == null) { tensorAccumulatedGradientMap[x.id] = dx; From 36cd5ecf72030932cea52cf6ab938feab3d5f987 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 23 Apr 2020 10:56:27 -0400 Subject: [PATCH 33/72] reove logs --- tfjs-backend-cpu/src/kernels/Max.ts | 1 - tfjs-backend-webgl/yarn.lock | 5 ----- tfjs-core/src/ops/reduction_ops_test.ts | 2 +- tfjs-core/src/tape.ts | 3 --- 4 files changed, 1 insertion(+), 10 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/Max.ts b/tfjs-backend-cpu/src/kernels/Max.ts index 206e8c3bd75..cd41de1f100 100644 --- a/tfjs-backend-cpu/src/kernels/Max.ts +++ b/tfjs-backend-cpu/src/kernels/Max.ts @@ -40,7 +40,6 @@ export const maxConfig: KernelConfig = { const permutedAxes = backend_util.getAxesPermutation(axes, xRank); let xVals = cpuBackend.data.get(x.dataId).values as TypedArray; if (permutedAxes != null) { - console.log('TARNSPOSING'); xVals = transposeImpl(xVals, xShape, x.dtype, permutedAxes); axes = backend_util.getInnerMostAxes(axes.length, xRank); diff --git a/tfjs-backend-webgl/yarn.lock b/tfjs-backend-webgl/yarn.lock index ffeeeada886..754e4d08dbc 100644 --- a/tfjs-backend-webgl/yarn.lock +++ b/tfjs-backend-webgl/yarn.lock @@ -139,11 +139,6 @@ acorn@^6.0.5: resolved "https://registry.yarnpkg.com/acorn/-/acorn-6.4.1.tgz#531e58ba3f51b9dacb9a6646ca4debf5b14ca474" integrity sha512-ZVA9k326Nwrj3Cj9jlh3wGFutC2ZornPNARZwsNYqQYgN0EsV2d53w5RN/co65Ohn4sUAUtb1rSUAOD6XN9idA== -acorn@^7.1.1: - version "7.1.1" - resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.1.1.tgz#e35668de0b402f359de515c5482a1ab9f89a69bf" - integrity sha512-add7dgA5ppRPxCFJoAGfMDi7PIBXq1RtGo7BhbLaxwrXPOmw8gq48Y9ozT01hUKy9byMjlR20EJhu5zlkErEkg== - after@0.8.2: version "0.8.2" resolved "https://registry.yarnpkg.com/after/-/after-0.8.2.tgz#fedb394f9f0e02aa9768e702bda23b505fae7e1f" diff --git a/tfjs-core/src/ops/reduction_ops_test.ts b/tfjs-core/src/ops/reduction_ops_test.ts index a26623ff45b..87b76065f30 100644 --- a/tfjs-core/src/ops/reduction_ops_test.ts +++ b/tfjs-core/src/ops/reduction_ops_test.ts @@ -1482,7 +1482,7 @@ describeWithFlags('norm', ALL_ENVS, () => { expectArraysEqual(await norm.data(), NaN); }); - it('axis=null in 2D array norm', async () => { + fit('axis=null in 2D array norm', async () => { const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]); const norm = tf.norm(a, Infinity); diff --git a/tfjs-core/src/tape.ts b/tfjs-core/src/tape.ts index fe09850d3e7..7dd34a48479 100644 --- a/tfjs-core/src/tape.ts +++ b/tfjs-core/src/tape.ts @@ -43,7 +43,6 @@ export type NamedGradientMap = { */ export function getFilteredNodesXToY( tape: TapeNode[], xs: Tensor[], y: Tensor): TapeNode[] { - console.log('GETTING NODES FROM X TO Y'); // Forward pass to compute all the nodes and Tensors that are transitively a // function of x. const tensorsFromX: {[tensorId: number]: boolean} = {}; @@ -132,11 +131,9 @@ export function getFilteredNodesXToY( export function backpropagateGradients( tensorAccumulatedGradientMap: {[tensorId: number]: Tensor}, filteredTape: TapeNode[], tidy: (f: Function) => Tensor) { - console.log('BACKPROPPING'); // Walk the tape backward and keep a map of Tensor to its gradient. for (let i = filteredTape.length - 1; i >= 0; i--) { const node = filteredTape[i]; - console.log('tape node', node.kernelName); const dys: Tensor[] = []; node.outputs.forEach(o => { From 87b5dbc2f1db89fa391b8f7589408dfdf3a2b320 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 23 Apr 2020 12:14:43 -0400 Subject: [PATCH 34/72] add failing tests --- tfjs-backend-webgl/src/kernels/Max.ts | 1 + tfjs-core/src/ops/max.ts | 1 + tfjs-core/src/ops/reduction_ops_test.ts | 8 ++++---- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tfjs-backend-webgl/src/kernels/Max.ts b/tfjs-backend-webgl/src/kernels/Max.ts index e371e4a03b1..6dfa292a403 100644 --- a/tfjs-backend-webgl/src/kernels/Max.ts +++ b/tfjs-backend-webgl/src/kernels/Max.ts @@ -38,6 +38,7 @@ export const maxConfig: KernelConfig = { let axes = origAxes; const permutedAxes = backend_util.getAxesPermutation(axes, xRank); if (permutedAxes != null) { + console.log('TRANSPOSE IN WEBGL MAX'); x = transposeImpl(x, permutedAxes, webglBackend); axes = backend_util.getInnerMostAxes(axes.length, xRank); } diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index afe04308e04..1909de35ec5 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -64,6 +64,7 @@ function max_( const origAxes = util.parseAxisParam(axis, $x.shape); const forward = (backend: KernelBackend, save: GradSaveFunc) => { + console.log('running forward func'); let axes = origAxes; const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); if (permutedAxes != null) { diff --git a/tfjs-core/src/ops/reduction_ops_test.ts b/tfjs-core/src/ops/reduction_ops_test.ts index 87b76065f30..2c742778454 100644 --- a/tfjs-core/src/ops/reduction_ops_test.ts +++ b/tfjs-core/src/ops/reduction_ops_test.ts @@ -1482,7 +1482,7 @@ describeWithFlags('norm', ALL_ENVS, () => { expectArraysEqual(await norm.data(), NaN); }); - fit('axis=null in 2D array norm', async () => { + it('axis=null in 2D array norm', async () => { const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]); const norm = tf.norm(a, Infinity); @@ -1500,7 +1500,7 @@ describeWithFlags('norm', ALL_ENVS, () => { expectArraysClose(await norm.data(), [3]); }); - it('axis=0 in 2D array norm', async () => { + fit('axis=0 in 2D array norm', async () => { const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]); const norm = tf.norm(a, Infinity, [0]); @@ -1599,7 +1599,7 @@ describeWithFlags('norm', ALL_ENVS, () => { expectArraysClose(await norm.data(), [3]); }); - it('axis=0,1 in 4D array norm', async () => { + fit('axis=0,1 in 4D array norm', async () => { const a = tf.tensor4d( [ 1, 2, 3, 0, 0, 1, 1, 2, 3, 0, 0, 1, @@ -1613,7 +1613,7 @@ describeWithFlags('norm', ALL_ENVS, () => { expectArraysClose(await norm.data(), [4, 3, 4, 3]); }); - it('axis=0,1 in 4D array norm', async () => { + fit('axis=0,1 in 4D array norm', async () => { const a = tf.tensor4d( [ 1, 2, 3, 0, 0, 1, 1, 2, 3, 0, 0, 1, From 018bb6e28d562c0c613de818ffbbd39e394d6092 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 27 Apr 2020 13:09:11 -0400 Subject: [PATCH 35/72] dispose --- tfjs-backend-webgl/src/kernels/Max.ts | 22 +++++++++++++--------- tfjs-core/src/ops/reduction_ops_test.ts | 4 ++-- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/tfjs-backend-webgl/src/kernels/Max.ts b/tfjs-backend-webgl/src/kernels/Max.ts index 6dfa292a403..c826dcccc84 100644 --- a/tfjs-backend-webgl/src/kernels/Max.ts +++ b/tfjs-backend-webgl/src/kernels/Max.ts @@ -27,30 +27,30 @@ export const maxConfig: KernelConfig = { kernelName: Max, backendName: 'webgl', kernelFunc: ({inputs, attrs, backend}) => { - let {x} = inputs as MaxInputs; + const {x} = inputs as MaxInputs; const {reductionIndices} = attrs as {} as MaxAttrs; const webglBackend = backend as MathBackendWebGL; - console.log('max webgl kernel func', x, reductionIndices); const xRank = x.shape.length; const origAxes = util.parseAxisParam(reductionIndices, x.shape); let axes = origAxes; const permutedAxes = backend_util.getAxesPermutation(axes, xRank); - if (permutedAxes != null) { - console.log('TRANSPOSE IN WEBGL MAX'); - x = transposeImpl(x, permutedAxes, webglBackend); + const maxInputIsTransposed = permutedAxes != null; + + let maxInput = x; + if (maxInputIsTransposed) { + maxInput = transposeImpl(x, permutedAxes, webglBackend); axes = backend_util.getInnerMostAxes(axes.length, xRank); } backend_util.assertAxesAreInnerMostDims('max', axes, xRank); const [outShape, reduceShape] = - backend_util.computeOutAndReduceShapes(x.shape, axes); + backend_util.computeOutAndReduceShapes(maxInput.shape, axes); let out; if (webglBackend.shouldExecuteOnCPU([x])) { - console.log('running on the cpu instead'); - const xTexData = webglBackend.texData.get(x.dataId); + const xTexData = webglBackend.texData.get(maxInput.dataId); const values = xTexData.values as TypedArray; const outValues = maxImplCPU( values, util.sizeFromShape(reduceShape), outShape, x.dtype); @@ -59,7 +59,11 @@ export const maxConfig: KernelConfig = { const outData = webglBackend.texData.get(out.dataId); outData.values = outValues; } else { - out = maxImpl(x, reduceShape, outShape, webglBackend); + out = maxImpl(maxInput, reduceShape, outShape, webglBackend); + } + + if (maxInputIsTransposed) { + webglBackend.disposeData(maxInput.dataId); } return out; diff --git a/tfjs-core/src/ops/reduction_ops_test.ts b/tfjs-core/src/ops/reduction_ops_test.ts index 2c742778454..cdff0c42699 100644 --- a/tfjs-core/src/ops/reduction_ops_test.ts +++ b/tfjs-core/src/ops/reduction_ops_test.ts @@ -1599,7 +1599,7 @@ describeWithFlags('norm', ALL_ENVS, () => { expectArraysClose(await norm.data(), [3]); }); - fit('axis=0,1 in 4D array norm', async () => { + it('axis=0,1 in 4D array norm', async () => { const a = tf.tensor4d( [ 1, 2, 3, 0, 0, 1, 1, 2, 3, 0, 0, 1, @@ -1613,7 +1613,7 @@ describeWithFlags('norm', ALL_ENVS, () => { expectArraysClose(await norm.data(), [4, 3, 4, 3]); }); - fit('axis=0,1 in 4D array norm', async () => { + it('axis=0,1 in 4D array norm', async () => { const a = tf.tensor4d( [ 1, 2, 3, 0, 0, 1, 1, 2, 3, 0, 0, 1, From 02278761c9db1a9fc5439f8d9416a38cd245fcea Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 27 Apr 2020 13:13:16 -0400 Subject: [PATCH 36/72] remove fit --- tfjs-core/src/ops/reduction_ops_test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-core/src/ops/reduction_ops_test.ts b/tfjs-core/src/ops/reduction_ops_test.ts index cdff0c42699..a26623ff45b 100644 --- a/tfjs-core/src/ops/reduction_ops_test.ts +++ b/tfjs-core/src/ops/reduction_ops_test.ts @@ -1500,7 +1500,7 @@ describeWithFlags('norm', ALL_ENVS, () => { expectArraysClose(await norm.data(), [3]); }); - fit('axis=0 in 2D array norm', async () => { + it('axis=0 in 2D array norm', async () => { const a = tf.tensor2d([1, 2, 3, 0, 0, 1], [3, 2]); const norm = tf.norm(a, Infinity, [0]); From cef4a01548b2609ce1d0c2d078f13b2f24802b1b Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 27 Apr 2020 13:30:15 -0400 Subject: [PATCH 37/72] fix wasm --- tfjs-backend-wasm/src/index_test.ts | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index 47a24e5bd05..fb1404f8416 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -58,8 +58,8 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { }, 100); // Silences backend registration warnings. - // spyOn(console, 'warn'); - // spyOn(console, 'log'); + spyOn(console, 'warn'); + spyOn(console, 'log'); }); afterEach(() => { @@ -121,11 +121,4 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { expect(() => setWasmPath('too/late')) .toThrowError(/The WASM backend was already initialized. Make sure/); }); - - fit('max', async () => { - const a = tf.tensor1d([3, -1, 0, 100, -7, 2]); - const r = tf.max(a); - const data = await r.data(); - console.log(data); - }); }); From 7ce546ca425cc6a113f5f1952dc4b797f2eaaa1f Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 28 Apr 2020 13:34:40 -0400 Subject: [PATCH 38/72] undo --- tfjs-backend-wasm/src/kernels/Max.ts | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/Max.ts b/tfjs-backend-wasm/src/kernels/Max.ts index 32486359125..ab688b76c80 100644 --- a/tfjs-backend-wasm/src/kernels/Max.ts +++ b/tfjs-backend-wasm/src/kernels/Max.ts @@ -24,7 +24,7 @@ interface MaxInputs extends NamedTensorInfoMap { } interface MaxAttrs extends NamedAttrMap { - reductionIndices: number[]; + axes: number[]; } let wasmMax: (xId: number, reduceSize: number, outId: number) => void; @@ -37,17 +37,10 @@ function setup(backend: BackendWasm): void { function max(args: {backend: BackendWasm, inputs: MaxInputs, attrs: MaxAttrs}): TensorInfo { const {backend, inputs, attrs} = args; - const {reductionIndices} = attrs; + const {axes} = attrs; const {x} = inputs; const xId = backend.dataIdMap.get(x.dataId).id; - const origAxes = util.parseAxisParam(reductionIndices, x.shape); - let axes = origAxes; - const permutedAxes = backend_util.getAxesPermutation(axes, x.shape.length); - if (permutedAxes != null) { - console.log('TRANSPOSE'); - } - backend_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes(x.shape, axes); From 5f687acaaf7105f567fdd7744e59e5da3c994bf0 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 28 Apr 2020 13:58:07 -0400 Subject: [PATCH 39/72] change api --- tfjs-backend-webgl/src/kernels/Max.ts | 9 +++++---- tfjs-backend-webgl/src/kernels/Transpose.ts | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tfjs-backend-webgl/src/kernels/Max.ts b/tfjs-backend-webgl/src/kernels/Max.ts index c826dcccc84..2ea53e46e00 100644 --- a/tfjs-backend-webgl/src/kernels/Max.ts +++ b/tfjs-backend-webgl/src/kernels/Max.ts @@ -15,13 +15,14 @@ * ============================================================================= */ +// tslint:disable-next-line: no-imports-from-dist import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; import {backend_util, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; -import {maxImpl, maxImplCPU} from './Max_impl'; -import {transposeImpl} from './Transpose_impl'; +import {maxImpl} from './Max_impl'; +import {transposeImpl, transposeImplCPU} from './Transpose_impl'; export const maxConfig: KernelConfig = { kernelName: Max, @@ -52,8 +53,8 @@ export const maxConfig: KernelConfig = { if (webglBackend.shouldExecuteOnCPU([x])) { const xTexData = webglBackend.texData.get(maxInput.dataId); const values = xTexData.values as TypedArray; - const outValues = maxImplCPU( - values, util.sizeFromShape(reduceShape), outShape, x.dtype); + const outValues = + transposeImplCPU(values, x.shape, x.dtype, permutedAxes); out = webglBackend.makeTensorInfo(outShape, x.dtype); const outData = webglBackend.texData.get(out.dataId); diff --git a/tfjs-backend-webgl/src/kernels/Transpose.ts b/tfjs-backend-webgl/src/kernels/Transpose.ts index 7493ad4467c..cb5e31079c3 100644 --- a/tfjs-backend-webgl/src/kernels/Transpose.ts +++ b/tfjs-backend-webgl/src/kernels/Transpose.ts @@ -41,7 +41,7 @@ export const transposeConfig: KernelConfig = { if (webglBackend.shouldExecuteOnCPU([x])) { const xTexData = webglBackend.texData.get(x.dataId); const values = xTexData.values as TypedArray; - const outValues = cpuTranspose(values, x.shape, x.dtype, perm, newShape); + const outValues = cpuTranspose(values, x.shape, x.dtype, perm); out = webglBackend.makeTensorInfo(newShape, x.dtype); const outData = webglBackend.texData.get(out.dataId); From 3f086ea92fd75f3f1837cc3b0d349b1e0692519d Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 28 Apr 2020 13:59:31 -0400 Subject: [PATCH 40/72] fix --- tfjs-backend-webgl/src/kernels/Max_impl.ts | 23 +--------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/tfjs-backend-webgl/src/kernels/Max_impl.ts b/tfjs-backend-webgl/src/kernels/Max_impl.ts index 257bddc8223..0bfc864a81a 100644 --- a/tfjs-backend-webgl/src/kernels/Max_impl.ts +++ b/tfjs-backend-webgl/src/kernels/Max_impl.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {DataType, NumericDataType, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core'; +import {TensorInfo, util} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; import {reduce} from '../kernel_utils/reduce'; @@ -34,24 +34,3 @@ export const maxImpl = backend), outShape, backend); }; - -// todo(@annxingyuan) import this from cpu backend. -export function maxImplCPU( - aVals: TypedArray, reduceSize: number, outShape: number[], - dtype: DataType): TypedArray { - const vals = util.getTypedArrayFromDType( - dtype as NumericDataType, util.sizeFromShape(outShape)); - - for (let i = 0; i < vals.length; ++i) { - const offset = i * reduceSize; - let max = aVals[offset]; - for (let j = 0; j < reduceSize; ++j) { - const value = aVals[offset + j]; - if (value > max) { - max = value; - } - } - vals[i] = max; - } - return vals; -} From c3f23946d621180d354f74d24354525b259fa6b2 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 28 Apr 2020 17:31:17 -0400 Subject: [PATCH 41/72] rm dist --- tfjs-backend-webgl/src/kernels/Max.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/tfjs-backend-webgl/src/kernels/Max.ts b/tfjs-backend-webgl/src/kernels/Max.ts index 2ea53e46e00..a1901c669ec 100644 --- a/tfjs-backend-webgl/src/kernels/Max.ts +++ b/tfjs-backend-webgl/src/kernels/Max.ts @@ -15,7 +15,6 @@ * ============================================================================= */ -// tslint:disable-next-line: no-imports-from-dist import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; import {backend_util, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-core'; From 7d85575ffecd8de58a7b8016d917cb79ff231784 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 29 Apr 2020 09:36:43 -0400 Subject: [PATCH 42/72] export shared --- tfjs-backend-cpu/src/index.ts | 3 ++- tfjs-backend-webgl/src/kernels/Transpose_impl.ts | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tfjs-backend-cpu/src/index.ts b/tfjs-backend-cpu/src/index.ts index 2d87de514fc..f85fd6e58cb 100644 --- a/tfjs-backend-cpu/src/index.ts +++ b/tfjs-backend-cpu/src/index.ts @@ -19,5 +19,6 @@ import {registerBackend} from '@tensorflow/tfjs-core'; import {MathBackendCPU} from './backend_cpu'; registerBackend('cpu', () => new MathBackendCPU(), 1 /* priority */); import './register_all_kernels'; +import * as shared from './shared'; -export {MathBackendCPU}; +export {MathBackendCPU, shared}; diff --git a/tfjs-backend-webgl/src/kernels/Transpose_impl.ts b/tfjs-backend-webgl/src/kernels/Transpose_impl.ts index 7f0284e4e54..f5edeb3cab3 100644 --- a/tfjs-backend-webgl/src/kernels/Transpose_impl.ts +++ b/tfjs-backend-webgl/src/kernels/Transpose_impl.ts @@ -15,14 +15,15 @@ * ============================================================================= */ -// tslint:disable-next-line: no-imports-from-dist -import {transposeImpl as transposeImplCPU} from '@tensorflow/tfjs-backend-cpu/dist/shared'; +import {shared} from '@tensorflow/tfjs-backend-cpu'; import {env, TensorInfo} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; import {TransposeProgram} from '../transpose_gpu'; import {TransposePackedProgram} from '../transpose_packed_gpu'; +const transposeImplCPU = shared.transposeImpl; + export function transposeImpl( x: TensorInfo, perm: number[], backend: MathBackendWebGL): TensorInfo { const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? From 3008a319188ef9c6d7fef8c2fddd4d2d1275e294 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 29 Apr 2020 10:10:36 -0400 Subject: [PATCH 43/72] cup forward --- tfjs-backend-cpu/src/kernels/Max.ts | 6 ++-- tfjs-backend-cpu/src/kernels/Transpose.ts | 2 +- .../src/kernels/Transpose_impl.ts | 9 ++---- tfjs-backend-cpu/src/shared.ts | 1 + tfjs-backend-webgl/src/kernels/Max.ts | 28 ++++++++++++++++--- tfjs-backend-webgl/src/kernels/Transpose.ts | 2 +- tfjs-core/src/ops/max.ts | 3 -- 7 files changed, 32 insertions(+), 19 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/Max.ts b/tfjs-backend-cpu/src/kernels/Max.ts index cd41de1f100..f43e385b1ba 100644 --- a/tfjs-backend-cpu/src/kernels/Max.ts +++ b/tfjs-backend-cpu/src/kernels/Max.ts @@ -40,14 +40,14 @@ export const maxConfig: KernelConfig = { const permutedAxes = backend_util.getAxesPermutation(axes, xRank); let xVals = cpuBackend.data.get(x.dataId).values as TypedArray; if (permutedAxes != null) { - xVals = transposeImpl(xVals, xShape, x.dtype, permutedAxes); - axes = backend_util.getInnerMostAxes(axes.length, xRank); - const newShape: number[] = new Array(xRank); for (let i = 0; i < newShape.length; i++) { newShape[i] = xShape[permutedAxes[i]]; } + xVals = transposeImpl(xVals, xShape, x.dtype, permutedAxes, newShape); + axes = backend_util.getInnerMostAxes(axes.length, xRank); + xShape = newShape; } diff --git a/tfjs-backend-cpu/src/kernels/Transpose.ts b/tfjs-backend-cpu/src/kernels/Transpose.ts index c1c68be78f2..05ae13f03cb 100644 --- a/tfjs-backend-cpu/src/kernels/Transpose.ts +++ b/tfjs-backend-cpu/src/kernels/Transpose.ts @@ -41,7 +41,7 @@ export const transposeConfig: KernelConfig = { } const values = cpuBackend.data.get(x.dataId).values as TypedArray; - const result = transposeImpl(values, x.shape, x.dtype, perm); + const result = transposeImpl(values, x.shape, x.dtype, perm, newShape); const dataId = cpuBackend.write(result, newShape, x.dtype); return {dataId, shape: newShape, dtype: x.dtype}; diff --git a/tfjs-backend-cpu/src/kernels/Transpose_impl.ts b/tfjs-backend-cpu/src/kernels/Transpose_impl.ts index b7d78decb03..ffa2afa888a 100644 --- a/tfjs-backend-cpu/src/kernels/Transpose_impl.ts +++ b/tfjs-backend-cpu/src/kernels/Transpose_impl.ts @@ -19,14 +19,9 @@ import {DataType, NumericDataType, TypedArray} from '@tensorflow/tfjs-core'; import {util} from '@tensorflow/tfjs-core'; export function transposeImpl( - xVals: TypedArray, xShape: number[], dtype: DataType, - perm: number[]): TypedArray { + xVals: TypedArray, xShape: number[], dtype: DataType, perm: number[], + newShape: number[]): TypedArray { const xRank = xShape.length; - const newShape: number[] = new Array(xRank); - for (let i = 0; i < newShape.length; i++) { - newShape[i] = xShape[perm[i]]; - } - const xSize = util.sizeFromShape(xShape); const xStrides = util.computeStrides(xShape); const newStrides = util.computeStrides(newShape); diff --git a/tfjs-backend-cpu/src/shared.ts b/tfjs-backend-cpu/src/shared.ts index 07572130c1f..09e7d1d5abe 100644 --- a/tfjs-backend-cpu/src/shared.ts +++ b/tfjs-backend-cpu/src/shared.ts @@ -16,4 +16,5 @@ */ // Shared kernel impls for use in other backends. +export {maxImpl} from './kernels/Max_impl'; export {transposeImpl} from './kernels/Transpose_impl'; diff --git a/tfjs-backend-webgl/src/kernels/Max.ts b/tfjs-backend-webgl/src/kernels/Max.ts index a1901c669ec..dadc88784b3 100644 --- a/tfjs-backend-webgl/src/kernels/Max.ts +++ b/tfjs-backend-webgl/src/kernels/Max.ts @@ -15,6 +15,7 @@ * ============================================================================= */ +import {shared as cpuSharedImpls} from '@tensorflow/tfjs-backend-cpu'; import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; import {backend_util, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-core'; @@ -37,10 +38,28 @@ export const maxConfig: KernelConfig = { let axes = origAxes; const permutedAxes = backend_util.getAxesPermutation(axes, xRank); const maxInputIsTransposed = permutedAxes != null; + const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]); let maxInput = x; if (maxInputIsTransposed) { - maxInput = transposeImpl(x, permutedAxes, webglBackend); + if (shouldExecuteOnCPU) { + const xTexData = webglBackend.texData.get(maxInput.dataId); + const values = xTexData.values as TypedArray; + + const newShape: number[] = new Array(xRank); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = x.shape[permutedAxes[i]]; + } + const maxInputValues = + transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape); + + maxInput = webglBackend.makeTensorInfo(newShape, x.dtype); + const maxInputData = webglBackend.texData.get(maxInput.dataId); + maxInputData.values = maxInputValues; + } else { + maxInput = transposeImpl(x, permutedAxes, webglBackend); + } + axes = backend_util.getInnerMostAxes(axes.length, xRank); } @@ -49,11 +68,12 @@ export const maxConfig: KernelConfig = { backend_util.computeOutAndReduceShapes(maxInput.shape, axes); let out; - if (webglBackend.shouldExecuteOnCPU([x])) { + if (shouldExecuteOnCPU) { const xTexData = webglBackend.texData.get(maxInput.dataId); const values = xTexData.values as TypedArray; - const outValues = - transposeImplCPU(values, x.shape, x.dtype, permutedAxes); + + const outValues = cpuSharedImpls.maxImpl( + values, util.sizeFromShape(reduceShape), outShape, x.dtype); out = webglBackend.makeTensorInfo(outShape, x.dtype); const outData = webglBackend.texData.get(out.dataId); diff --git a/tfjs-backend-webgl/src/kernels/Transpose.ts b/tfjs-backend-webgl/src/kernels/Transpose.ts index cb5e31079c3..7493ad4467c 100644 --- a/tfjs-backend-webgl/src/kernels/Transpose.ts +++ b/tfjs-backend-webgl/src/kernels/Transpose.ts @@ -41,7 +41,7 @@ export const transposeConfig: KernelConfig = { if (webglBackend.shouldExecuteOnCPU([x])) { const xTexData = webglBackend.texData.get(x.dataId); const values = xTexData.values as TypedArray; - const outValues = cpuTranspose(values, x.shape, x.dtype, perm); + const outValues = cpuTranspose(values, x.shape, x.dtype, perm, newShape); out = webglBackend.makeTensorInfo(newShape, x.dtype); const outData = webglBackend.texData.get(out.dataId); diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index 1909de35ec5..38154ef949e 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -57,9 +57,6 @@ import {transpose} from './transpose'; /** @doc {heading: 'Operations', subheading: 'Reduction'} */ function max_( x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T { - console.log('max op'); - console.log(x); - console.log(axis); let $x = convertToTensor(x, 'x', 'max'); const origAxes = util.parseAxisParam(axis, $x.shape); From 7106795fa8dfaa61fa762c13da7e3a4b638b6281 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 29 Apr 2020 10:14:58 -0400 Subject: [PATCH 44/72] import max --- tfjs-backend-cpu/src/backend_cpu.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tfjs-backend-cpu/src/backend_cpu.ts b/tfjs-backend-cpu/src/backend_cpu.ts index fc819ecb2cd..afd7d33b08b 100644 --- a/tfjs-backend-cpu/src/backend_cpu.ts +++ b/tfjs-backend-cpu/src/backend_cpu.ts @@ -18,7 +18,7 @@ import * as tf from '@tensorflow/tfjs-core'; import {engine, env} from '@tensorflow/tfjs-core'; import {backend_util, buffer, slice_util, util} from '@tensorflow/tfjs-core'; -import {BackendTimingInfo, DataStorage, DataType, DataValues, KernelBackend, NumericDataType, Rank, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, TypedArray, upcastType} from '@tensorflow/tfjs-core'; +import {BackendTimingInfo, DataStorage, DataType, DataValues, KernelBackend, max, NumericDataType, Rank, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, TypedArray, upcastType} from '@tensorflow/tfjs-core'; import {kernel_impls} from '@tensorflow/tfjs-core'; const nonMaxSuppressionV3 = kernel_impls.nonMaxSuppressionV3; @@ -367,7 +367,7 @@ export class MathBackendCPU extends KernelBackend { const axes = util.parseAxisParam([dim], logits.shape); // TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel // modularization. - const maxLogit = tf.max(logits, axes); + const maxLogit = max(logits, axes); const expandedShape = backend_util.expandShapeToKeepDim(maxLogit.shape, axes); const a = this.subtract(logits, maxLogit.reshape(expandedShape)); From 06ff0a201a8d654d04eaff8ad8b980d7fea86579 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 29 Apr 2020 10:19:37 -0400 Subject: [PATCH 45/72] remove logs --- tfjs-core/src/gradients/Max_grad.ts | 1 - tfjs-core/src/gradients/Transpose_grad.ts | 10 +--------- tfjs-core/src/ops/reduction_ops.ts | 8 +------- tfjs-core/src/tape.ts | 1 - 4 files changed, 2 insertions(+), 18 deletions(-) diff --git a/tfjs-core/src/gradients/Max_grad.ts b/tfjs-core/src/gradients/Max_grad.ts index bbc7382e0d8..189892beefc 100644 --- a/tfjs-core/src/gradients/Max_grad.ts +++ b/tfjs-core/src/gradients/Max_grad.ts @@ -28,7 +28,6 @@ export const maxGradConfig: GradConfig = { inputsToSave: ['x'], outputsToSave: [true], gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { - console.log('MAX GRAD FUNC'); const maxAttrs: MaxAttrs = attrs as {} as MaxAttrs; const {reductionIndices} = maxAttrs; const [x, y] = saved; diff --git a/tfjs-core/src/gradients/Transpose_grad.ts b/tfjs-core/src/gradients/Transpose_grad.ts index 4c03e1cf6b7..4d48500d7de 100644 --- a/tfjs-core/src/gradients/Transpose_grad.ts +++ b/tfjs-core/src/gradients/Transpose_grad.ts @@ -27,14 +27,6 @@ export const transposeGradConfig: GradConfig = { const transposeAttrs: TransposeAttrs = attrs as {} as TransposeAttrs; const {perm} = transposeAttrs; const undoPerm = axis_util.getUndoAxesPermutation(perm); - return { - x: () => { - console.log('IN TRANSPOSE GRADIENT'); - console.log(dy.shape); - const out = transpose(dy, undoPerm); - console.log(out.shape); - return out; - } - }; + return {x: () => transpose(dy, undoPerm)}; } }; diff --git a/tfjs-core/src/ops/reduction_ops.ts b/tfjs-core/src/ops/reduction_ops.ts index 673b3f4d978..68fb6be6a45 100644 --- a/tfjs-core/src/ops/reduction_ops.ts +++ b/tfjs-core/src/ops/reduction_ops.ts @@ -272,9 +272,6 @@ function mean_( */ export function gradForMinAndMax( dy: T, y: T, xOrig: Tensor, origAxes: number[], permutedAxes: number[]) { - console.log('GRAD FRO MIN AND MAX'); - console.log('orig axes', origAxes); - console.log(permutedAxes); if (y.rank < xOrig.rank) { y = y.reshape(axis_util.expandShapeToKeepDim(y.shape, origAxes)) as T; } @@ -283,11 +280,8 @@ export function gradForMinAndMax( } return { x: () => { - console.log('INVOKING'); const dx = dy.mul(xOrig.equal(y).cast(dy.dtype)); - const out = permutedAxes == null ? dx : dx.transpose(permutedAxes); - console.log('out', out.shape); - return out; + return permutedAxes == null ? dx : dx.transpose(permutedAxes); } }; } diff --git a/tfjs-core/src/tape.ts b/tfjs-core/src/tape.ts index 7dd34a48479..8f78106f150 100644 --- a/tfjs-core/src/tape.ts +++ b/tfjs-core/src/tape.ts @@ -165,7 +165,6 @@ export function backpropagateGradients( // Call the gradient function. const dx = tidy(() => inputGradients[inputName]()); - if (dx.dtype !== 'float32') { throw new Error( `Error in gradient for op ${ From 08abe3f098db653f054349767fb4ccc1f2e9de57 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 29 Apr 2020 10:20:17 -0400 Subject: [PATCH 46/72] remove log --- tfjs-core/src/ops/max.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index 38154ef949e..fa66e9d3d52 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -61,7 +61,6 @@ function max_( const origAxes = util.parseAxisParam(axis, $x.shape); const forward = (backend: KernelBackend, save: GradSaveFunc) => { - console.log('running forward func'); let axes = origAxes; const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); if (permutedAxes != null) { From 714742c97e1f9317e8da900a7db2d9516c2976bc Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 29 Apr 2020 10:32:18 -0400 Subject: [PATCH 47/72] extract --- tfjs-core/src/gradients/Max_grad.ts | 2 +- tfjs-core/src/ops/reduction_ops.ts | 22 ++------------ tfjs-core/src/ops/reduction_ops_util.ts | 38 +++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 20 deletions(-) create mode 100644 tfjs-core/src/ops/reduction_ops_util.ts diff --git a/tfjs-core/src/gradients/Max_grad.ts b/tfjs-core/src/gradients/Max_grad.ts index 189892beefc..0a126ea223c 100644 --- a/tfjs-core/src/gradients/Max_grad.ts +++ b/tfjs-core/src/gradients/Max_grad.ts @@ -18,7 +18,7 @@ import {Max, MaxAttrs} from '../kernel_names'; import {GradConfig, NamedAttrMap} from '../kernel_registry'; import * as axis_util from '../ops/axis_util'; -import {gradForMinAndMax} from '../ops/reduction_ops'; +import {gradForMinAndMax} from '../ops/reduction_ops_util'; import {transpose} from '../ops/transpose'; import {Tensor} from '../tensor'; import * as util from '../util'; diff --git a/tfjs-core/src/ops/reduction_ops.ts b/tfjs-core/src/ops/reduction_ops.ts index 68fb6be6a45..bc97ed6e3bd 100644 --- a/tfjs-core/src/ops/reduction_ops.ts +++ b/tfjs-core/src/ops/reduction_ops.ts @@ -21,10 +21,13 @@ import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; + import * as axis_util from './axis_util'; import {op} from './operation'; +import {gradForMinAndMax} from './reduction_ops_util'; import {ones, scalar, zerosLike} from './tensor_ops'; + /** * Computes the log(sum(exp(elements across the reduction dimensions)). * @@ -267,25 +270,6 @@ function mean_( return customOp($x) as T; } -/** - * Gradient helper function for the min and max operations. - */ -export function gradForMinAndMax( - dy: T, y: T, xOrig: Tensor, origAxes: number[], permutedAxes: number[]) { - if (y.rank < xOrig.rank) { - y = y.reshape(axis_util.expandShapeToKeepDim(y.shape, origAxes)) as T; - } - if (dy.rank < xOrig.rank) { - dy = dy.reshape(axis_util.expandShapeToKeepDim(dy.shape, origAxes)) as T; - } - return { - x: () => { - const dx = dy.mul(xOrig.equal(y).cast(dy.dtype)); - return permutedAxes == null ? dx : dx.transpose(permutedAxes); - } - }; -} - /** * Computes the minimum value from the input. * diff --git a/tfjs-core/src/ops/reduction_ops_util.ts b/tfjs-core/src/ops/reduction_ops_util.ts new file mode 100644 index 00000000000..375ae223f89 --- /dev/null +++ b/tfjs-core/src/ops/reduction_ops_util.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tensor} from '../tensor'; +import * as axis_util from './axis_util'; + +/** + * Gradient helper function for the min and max operations. + */ +export function gradForMinAndMax( + dy: T, y: T, xOrig: Tensor, origAxes: number[], permutedAxes: number[]) { + if (y.rank < xOrig.rank) { + y = y.reshape(axis_util.expandShapeToKeepDim(y.shape, origAxes)) as T; + } + if (dy.rank < xOrig.rank) { + dy = dy.reshape(axis_util.expandShapeToKeepDim(dy.shape, origAxes)) as T; + } + return { + x: () => { + const dx = dy.mul(xOrig.equal(y).cast(dy.dtype)); + return permutedAxes == null ? dx : dx.transpose(permutedAxes); + } + }; +} From 1837abe75a8c6598e6bd872bada6c4652546d206 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 29 Apr 2020 10:43:09 -0400 Subject: [PATCH 48/72] lint --- tfjs-core/src/ops/reduction_ops.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/tfjs-core/src/ops/reduction_ops.ts b/tfjs-core/src/ops/reduction_ops.ts index bc97ed6e3bd..3f871d8a111 100644 --- a/tfjs-core/src/ops/reduction_ops.ts +++ b/tfjs-core/src/ops/reduction_ops.ts @@ -27,7 +27,6 @@ import {op} from './operation'; import {gradForMinAndMax} from './reduction_ops_util'; import {ones, scalar, zerosLike} from './tensor_ops'; - /** * Computes the log(sum(exp(elements across the reduction dimensions)). * From efef5c8a14c350320ac04d984ba549259a83c380 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 10:14:07 -0400 Subject: [PATCH 49/72] pr comments --- tfjs-core/src/ops/max.ts | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index fa66e9d3d52..7ff45b6dc6f 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -16,13 +16,15 @@ */ import {KernelBackend} from '../backends/backend'; -import {ENGINE} from '../engine'; +import {ENGINE, ForwardFunc} from '../engine'; +import {Max} from '../kernel_names'; import {Tensor} from '../tensor'; import {GradSaveFunc} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; +import {reshape} from './array_ops'; import * as axis_util from './axis_util'; import {op} from './operation'; import {transpose} from './transpose'; @@ -60,24 +62,25 @@ function max_( let $x = convertToTensor(x, 'x', 'max'); const origAxes = util.parseAxisParam(axis, $x.shape); - const forward = (backend: KernelBackend, save: GradSaveFunc) => { - let axes = origAxes; - const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); - if (permutedAxes != null) { - $x = transpose($x, permutedAxes); - axes = axis_util.getInnerMostAxes(axes.length, $x.rank); - } + const forward: ForwardFunc = + (backend: KernelBackend, save: GradSaveFunc) => { + let axes = origAxes; + const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = transpose($x, permutedAxes); + axes = axis_util.getInnerMostAxes(axes.length, $x.rank); + } - const y = backend.max($x, axes); - save([$x, y]); - return y; - }; + const y = backend.max($x, axes); + save([$x, y]); + return y; + }; let res = ENGINE.runKernelFunc( - forward, {x: $x}, null /* gradient */, 'Max', {reductionIndices: axis}); + forward, {x: $x}, null /* gradient */, Max, {reductionIndices: axis}); if (keepDims) { const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); - res = res.reshape(newShape) as T; + res = reshape(res, newShape) as T; } return res as T; } From 15e21dbebe07d3eb769d042eece61d32578f0fe2 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 10:17:12 -0400 Subject: [PATCH 50/72] pr comments --- tfjs-core/src/kernel_names.ts | 2 +- tfjs-core/src/ops/max.ts | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index eb57d158db2..907e24c92c2 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -83,7 +83,7 @@ export interface NonMaxSuppressionV5Attrs { export const Max = 'Max'; export type MaxInputs = Pick; export interface MaxAttrs { - reductionIndices: number[]; + reductionIndices: number|number[]; } export const OneHot = 'OneHot'; diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index 7ff45b6dc6f..2259f1ec9bd 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -17,9 +17,10 @@ import {KernelBackend} from '../backends/backend'; import {ENGINE, ForwardFunc} from '../engine'; -import {Max} from '../kernel_names'; +import {Max, MaxAttrs, MaxInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; import {Tensor} from '../tensor'; -import {GradSaveFunc} from '../tensor_types'; +import {GradSaveFunc, NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; @@ -29,6 +30,7 @@ import * as axis_util from './axis_util'; import {op} from './operation'; import {transpose} from './transpose'; + /** * Computes the maximum of elements across dimensions of a `tf.Tensor`. * @@ -75,9 +77,12 @@ function max_( save([$x, y]); return y; }; + const inputs: MaxInputs = {x: $x}; + const attrs: MaxAttrs = {reductionIndices: axis}; let res = ENGINE.runKernelFunc( - forward, {x: $x}, null /* gradient */, Max, {reductionIndices: axis}); + forward, inputs as {} as NamedTensorMap, null /* gradient */, Max, + attrs as {} as NamedAttrMap); if (keepDims) { const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); res = reshape(res, newShape) as T; From 82fa0939938a000e50b22883e11d7c35940f858a Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 10:17:48 -0400 Subject: [PATCH 51/72] pr comments --- tfjs-core/src/public/chained_ops/max.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/tfjs-core/src/public/chained_ops/max.ts b/tfjs-core/src/public/chained_ops/max.ts index c49ffdd7eab..4b0cf030413 100644 --- a/tfjs-core/src/public/chained_ops/max.ts +++ b/tfjs-core/src/public/chained_ops/max.ts @@ -27,5 +27,6 @@ declare module '../../tensor' { Tensor.prototype.max = function( axis?: number|number[], keepDims?: boolean): T { + this.throwIfDisposed(); return max(this, axis, keepDims); }; From a0d5ffd87527a3f7103c5cce5fe57d1334375339 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 10:31:40 -0400 Subject: [PATCH 52/72] reorganize --- tfjs-backend-webgl/src/kernel_utils/shared.ts | 22 +++++++++++++++++++ tfjs-backend-webgl/src/kernels/Max.ts | 4 ++-- tfjs-backend-webgl/src/kernels/Max_impl.ts | 22 +++++++++---------- .../src/kernels/Transpose_impl.ts | 4 +--- 4 files changed, 35 insertions(+), 17 deletions(-) create mode 100644 tfjs-backend-webgl/src/kernel_utils/shared.ts diff --git a/tfjs-backend-webgl/src/kernel_utils/shared.ts b/tfjs-backend-webgl/src/kernel_utils/shared.ts new file mode 100644 index 00000000000..772b439a432 --- /dev/null +++ b/tfjs-backend-webgl/src/kernel_utils/shared.ts @@ -0,0 +1,22 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {shared} from '@tensorflow/tfjs-backend-cpu'; + +const {maxImpl: maxImplCPU, transposeImpl: transposeImplCPU} = shared; + +export {maxImplCPU, transposeImplCPU}; diff --git a/tfjs-backend-webgl/src/kernels/Max.ts b/tfjs-backend-webgl/src/kernels/Max.ts index dadc88784b3..1c74ad85105 100644 --- a/tfjs-backend-webgl/src/kernels/Max.ts +++ b/tfjs-backend-webgl/src/kernels/Max.ts @@ -15,11 +15,11 @@ * ============================================================================= */ -import {shared as cpuSharedImpls} from '@tensorflow/tfjs-backend-cpu'; import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; import {backend_util, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; +import {maxImplCPU} from '../kernel_utils/shared'; import {maxImpl} from './Max_impl'; import {transposeImpl, transposeImplCPU} from './Transpose_impl'; @@ -72,7 +72,7 @@ export const maxConfig: KernelConfig = { const xTexData = webglBackend.texData.get(maxInput.dataId); const values = xTexData.values as TypedArray; - const outValues = cpuSharedImpls.maxImpl( + const outValues = maxImplCPU( values, util.sizeFromShape(reduceShape), outShape, x.dtype); out = webglBackend.makeTensorInfo(outShape, x.dtype); diff --git a/tfjs-backend-webgl/src/kernels/Max_impl.ts b/tfjs-backend-webgl/src/kernels/Max_impl.ts index 0bfc864a81a..11ea80bcb87 100644 --- a/tfjs-backend-webgl/src/kernels/Max_impl.ts +++ b/tfjs-backend-webgl/src/kernels/Max_impl.ts @@ -21,16 +21,14 @@ import {MathBackendWebGL} from '../backend_webgl'; import {reduce} from '../kernel_utils/reduce'; import {reshape} from '../kernel_utils/reshape'; -export const maxImpl = - (x: TensorInfo, reduceShape: number[], outShape: number[], - backend: MathBackendWebGL): TensorInfo => { - const inSize = util.sizeFromShape(reduceShape); - const xSize = util.sizeFromShape(x.shape); - const batchSize = xSize / inSize; +export function maxImpl( + x: TensorInfo, reduceShape: number[], outShape: number[], + backend: MathBackendWebGL): TensorInfo { + const inSize = util.sizeFromShape(reduceShape); + const xSize = util.sizeFromShape(x.shape); + const batchSize = xSize / inSize; - return reshape( - reduce( - reshape(x, [batchSize, inSize], backend), x.dtype, 'max', - backend), - outShape, backend); - }; + return reshape( + reduce(reshape(x, [batchSize, inSize], backend), x.dtype, 'max', backend), + outShape, backend); +}; diff --git a/tfjs-backend-webgl/src/kernels/Transpose_impl.ts b/tfjs-backend-webgl/src/kernels/Transpose_impl.ts index f5edeb3cab3..1b0680bc7e2 100644 --- a/tfjs-backend-webgl/src/kernels/Transpose_impl.ts +++ b/tfjs-backend-webgl/src/kernels/Transpose_impl.ts @@ -15,15 +15,13 @@ * ============================================================================= */ -import {shared} from '@tensorflow/tfjs-backend-cpu'; import {env, TensorInfo} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; +import {transposeImplCPU} from '../kernel_utils/shared'; import {TransposeProgram} from '../transpose_gpu'; import {TransposePackedProgram} from '../transpose_packed_gpu'; -const transposeImplCPU = shared.transposeImpl; - export function transposeImpl( x: TensorInfo, perm: number[], backend: MathBackendWebGL): TensorInfo { const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? From a4ac9ff06f5b10e66ebfaa452c55559aa5824cc8 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 10:42:04 -0400 Subject: [PATCH 53/72] add keepdims --- tfjs-backend-cpu/src/kernels/Max.ts | 10 +++++++--- tfjs-backend-webgl/src/kernels/Max.ts | 16 +++++++++++----- tfjs-core/src/kernel_names.ts | 1 + tfjs-core/src/ops/max.ts | 13 ++++--------- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/Max.ts b/tfjs-backend-cpu/src/kernels/Max.ts index f43e385b1ba..62961fc1f02 100644 --- a/tfjs-backend-cpu/src/kernels/Max.ts +++ b/tfjs-backend-cpu/src/kernels/Max.ts @@ -30,7 +30,7 @@ export const maxConfig: KernelConfig = { backendName: 'cpu', kernelFunc: ({inputs, attrs, backend}) => { const {x} = inputs as MaxInputs; - const {reductionIndices} = attrs as {} as MaxAttrs; + const {reductionIndices, keepDims} = attrs as {} as MaxAttrs; const cpuBackend = backend as MathBackendCPU; let xShape = x.shape; const xRank = xShape.length; @@ -53,12 +53,16 @@ export const maxConfig: KernelConfig = { assertNotComplex(x, 'max'); backend_util.assertAxesAreInnerMostDims('max', axes, xRank); - const [outShape, reduceShape] = + const [maxOutShape, reduceShape] = backend_util.computeOutAndReduceShapes(xShape, axes); const reduceSize = util.sizeFromShape(reduceShape); - const result = maxImpl(xVals, reduceSize, outShape, x.dtype); + const result = maxImpl(xVals, reduceSize, maxOutShape, x.dtype); + let outShape = maxOutShape; + if (keepDims) { + outShape = backend_util.expandShapeToKeepDim(maxOutShape, origAxes); + } const dataId = cpuBackend.write(result, outShape, x.dtype); return {dataId, shape: outShape, dtype: x.dtype}; diff --git a/tfjs-backend-webgl/src/kernels/Max.ts b/tfjs-backend-webgl/src/kernels/Max.ts index 1c74ad85105..d7acd96162c 100644 --- a/tfjs-backend-webgl/src/kernels/Max.ts +++ b/tfjs-backend-webgl/src/kernels/Max.ts @@ -19,6 +19,7 @@ import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; import {backend_util, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; +import {reshape} from '../kernel_utils/reshape'; import {maxImplCPU} from '../kernel_utils/shared'; import {maxImpl} from './Max_impl'; @@ -29,7 +30,7 @@ export const maxConfig: KernelConfig = { backendName: 'webgl', kernelFunc: ({inputs, attrs, backend}) => { const {x} = inputs as MaxInputs; - const {reductionIndices} = attrs as {} as MaxAttrs; + const {reductionIndices, keepDims} = attrs as {} as MaxAttrs; const webglBackend = backend as MathBackendWebGL; const xRank = x.shape.length; @@ -64,7 +65,7 @@ export const maxConfig: KernelConfig = { } backend_util.assertAxesAreInnerMostDims('max', axes, xRank); - const [outShape, reduceShape] = + const [maxOutShape, reduceShape] = backend_util.computeOutAndReduceShapes(maxInput.shape, axes); let out; @@ -73,13 +74,18 @@ export const maxConfig: KernelConfig = { const values = xTexData.values as TypedArray; const outValues = maxImplCPU( - values, util.sizeFromShape(reduceShape), outShape, x.dtype); + values, util.sizeFromShape(reduceShape), maxOutShape, x.dtype); - out = webglBackend.makeTensorInfo(outShape, x.dtype); + out = webglBackend.makeTensorInfo(maxOutShape, x.dtype); const outData = webglBackend.texData.get(out.dataId); outData.values = outValues; } else { - out = maxImpl(maxInput, reduceShape, outShape, webglBackend); + out = maxImpl(maxInput, reduceShape, maxOutShape, webglBackend); + } + + if (keepDims) { + const outShape = backend_util.expandShapeToKeepDim(maxOutShape, origAxes); + out = reshape(out, outShape, webglBackend); } if (maxInputIsTransposed) { diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 907e24c92c2..655741d2cb3 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -84,6 +84,7 @@ export const Max = 'Max'; export type MaxInputs = Pick; export interface MaxAttrs { reductionIndices: number|number[]; + keepDims: boolean; } export const OneHot = 'OneHot'; diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index 2259f1ec9bd..e05eafb54ed 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -78,16 +78,11 @@ function max_( return y; }; const inputs: MaxInputs = {x: $x}; - const attrs: MaxAttrs = {reductionIndices: axis}; + const attrs: MaxAttrs = {reductionIndices: axis, keepDims}; - let res = ENGINE.runKernelFunc( - forward, inputs as {} as NamedTensorMap, null /* gradient */, Max, - attrs as {} as NamedAttrMap); - if (keepDims) { - const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); - res = reshape(res, newShape) as T; - } - return res as T; + return ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* gradient */, Max, + attrs as {} as NamedAttrMap) as T; } export const max = op({max_}); From c088772bcb83a5864faac4d9562cc4eaf401695a Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 10:48:27 -0400 Subject: [PATCH 54/72] remove reshape --- tfjs-core/src/ops/max.ts | 2 -- 1 file changed, 2 deletions(-) diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index e05eafb54ed..4952bdbb6b3 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -25,12 +25,10 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import * as axis_util from './axis_util'; import {op} from './operation'; import {transpose} from './transpose'; - /** * Computes the maximum of elements across dimensions of a `tf.Tensor`. * From b03ef21ee45a9210fd09f0fe8abb8d68e6c81460 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 11:25:29 -0400 Subject: [PATCH 55/72] add max --- tfjs-backend-wasm/src/kernels/Max.ts | 40 ++++++++++++------------ tfjs-backend-wasm/src/kernels/Reshape.ts | 2 +- tfjs-core/src/ops/max.ts | 6 +++- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/Max.ts b/tfjs-backend-wasm/src/kernels/Max.ts index ab688b76c80..ca8b5f6c4b8 100644 --- a/tfjs-backend-wasm/src/kernels/Max.ts +++ b/tfjs-backend-wasm/src/kernels/Max.ts @@ -15,17 +15,12 @@ * ============================================================================= */ -import {backend_util, NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {backend_util, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; -interface MaxInputs extends NamedTensorInfoMap { - x: TensorInfo; -} - -interface MaxAttrs extends NamedAttrMap { - axes: number[]; -} +import {reshape} from './Reshape'; let wasmMax: (xId: number, reduceSize: number, outId: number) => void; @@ -34,16 +29,17 @@ function setup(backend: BackendWasm): void { backend.wasm.cwrap('Max', null /*void*/, ['number, number, number']); } -function max(args: {backend: BackendWasm, inputs: MaxInputs, attrs: MaxAttrs}): - TensorInfo { +function max(args: {backend: BackendWasm, inputs: {}, attrs: {}}): TensorInfo { const {backend, inputs, attrs} = args; - const {axes} = attrs; - const {x} = inputs; + const {reductionIndices, keepDims} = attrs as MaxAttrs; + const {x} = inputs as MaxInputs; const xId = backend.dataIdMap.get(x.dataId).id; - backend_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); + const origAxes = util.parseAxisParam(reductionIndices, x.shape); + + backend_util.assertAxesAreInnerMostDims('max', origAxes, x.shape.length); const [outShape, reduceShape] = - backend_util.computeOutAndReduceShapes(x.shape, axes); + backend_util.computeOutAndReduceShapes(x.shape, origAxes); const reduceSize = util.sizeFromShape(reduceShape); const out = backend.makeOutput(outShape, x.dtype); @@ -54,12 +50,16 @@ function max(args: {backend: BackendWasm, inputs: MaxInputs, attrs: MaxAttrs}): const outId = backend.dataIdMap.get(out.dataId).id; wasmMax(xId, reduceSize, outId); + + if (keepDims) { + return reshape({ + inputs: {x: out}, + attrs: {shape: backend_util.expandShapeToKeepDim(out.shape, origAxes)}, + backend + }); + } return out; } -registerKernel({ - kernelName: 'Max', - backendName: 'wasm', - setupFunc: setup, - kernelFunc: max -}); +registerKernel( + {kernelName: Max, backendName: 'wasm', setupFunc: setup, kernelFunc: max}); diff --git a/tfjs-backend-wasm/src/kernels/Reshape.ts b/tfjs-backend-wasm/src/kernels/Reshape.ts index f1009c6ab38..36c6fab9c54 100644 --- a/tfjs-backend-wasm/src/kernels/Reshape.ts +++ b/tfjs-backend-wasm/src/kernels/Reshape.ts @@ -28,7 +28,7 @@ interface ReshapeAttrs extends NamedAttrMap { shape: number[]; } -function reshape( +export function reshape( args: {inputs: ReshapeInputs, attrs: ReshapeAttrs, backend: BackendWasm}) { const {inputs: {x}, attrs: {shape}} = args; return {dataId: x.dataId, shape, dtype: x.dtype}; diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index 4952bdbb6b3..f100bf465d9 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -24,7 +24,7 @@ import {GradSaveFunc, NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; - +import {reshape} from './array_ops'; import * as axis_util from './axis_util'; import {op} from './operation'; import {transpose} from './transpose'; @@ -73,6 +73,10 @@ function max_( const y = backend.max($x, axes); save([$x, y]); + + if (keepDims) { + return reshape(y, axis_util.expandShapeToKeepDim(y.shape, origAxes)); + } return y; }; const inputs: MaxInputs = {x: $x}; From d7da66f59a7603da9c41b0e952429d2bc9f20f75 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 11:27:57 -0400 Subject: [PATCH 56/72] reduce --- tfjs-backend-webgl/src/kernel_utils/reduce.ts | 4 +++- tfjs-core/src/ops/reduce_util.ts | 2 -- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tfjs-backend-webgl/src/kernel_utils/reduce.ts b/tfjs-backend-webgl/src/kernel_utils/reduce.ts index cfc5c1e5151..edf6df8db73 100644 --- a/tfjs-backend-webgl/src/kernel_utils/reduce.ts +++ b/tfjs-backend-webgl/src/kernel_utils/reduce.ts @@ -20,8 +20,10 @@ import {backend_util, DataType, TensorInfo} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; import {ReduceProgram} from '../reduce_gpu'; +type ReduceTypes = 'all'|'any'|'max'|'min'|'sum'|'prod'; + export function reduce( - x: TensorInfo, dtype: DataType, reductionType: backend_util.ReduceTypes, + x: TensorInfo, dtype: DataType, reductionType: ReduceTypes, backend: MathBackendWebGL): TensorInfo { const [batchSize, inSize] = x.shape; const windowSize = backend_util.computeOptimalWindowSize(inSize); diff --git a/tfjs-core/src/ops/reduce_util.ts b/tfjs-core/src/ops/reduce_util.ts index 876c784dc35..cb9f27d158a 100644 --- a/tfjs-core/src/ops/reduce_util.ts +++ b/tfjs-core/src/ops/reduce_util.ts @@ -29,8 +29,6 @@ export interface ReduceInfo { inSize: number; } -export type ReduceTypes = 'all'|'any'|'max'|'min'|'sum'|'prod'; - export function computeOptimalWindowSize(inSize: number): number { if (inSize <= PARALLELIZE_THRESHOLD) { return inSize; From a9e62170a829cbf9a1a6007088c88fdaeffaf53f Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 11:28:29 -0400 Subject: [PATCH 57/72] typo --- .../src/public/chained_ops/register_all_chained_ops_test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index 839a79dca90..568858c556d 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -20,7 +20,7 @@ import {ALL_ENVS, describeWithFlags} from '../../jasmine_util'; // Testing for presence of chained op in this file will allow us to more easily // customize when we want this test to run. Currently it will run be default -// (And kerma will always load the chain augmentor files). But this gives us +// (And karma will always load the chain augmentor files). But this gives us // flexibility to change in future. const CHAINED_OPS = [ From 5bc64b9b10517fd92cf611f8fc23e03a5e34dea4 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 14:17:26 -0400 Subject: [PATCH 58/72] lint --- tfjs-backend-webgl/src/kernels/Max_impl.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-backend-webgl/src/kernels/Max_impl.ts b/tfjs-backend-webgl/src/kernels/Max_impl.ts index 11ea80bcb87..7a323b49d02 100644 --- a/tfjs-backend-webgl/src/kernels/Max_impl.ts +++ b/tfjs-backend-webgl/src/kernels/Max_impl.ts @@ -31,4 +31,4 @@ export function maxImpl( return reshape( reduce(reshape(x, [batchSize, inSize], backend), x.dtype, 'max', backend), outShape, backend); -}; +} From 65e04a1cff5b40ac543d2d67a46639ac1aad7c54 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 5 May 2020 12:55:00 -0400 Subject: [PATCH 59/72] add msg --- tfjs-backend-cpu/src/cpu_util.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tfjs-backend-cpu/src/cpu_util.ts b/tfjs-backend-cpu/src/cpu_util.ts index 630e9ba127d..28ed23e0ff1 100644 --- a/tfjs-backend-cpu/src/cpu_util.ts +++ b/tfjs-backend-cpu/src/cpu_util.ts @@ -26,7 +26,8 @@ export function assertNotComplex( if (t != null) { util.assert( t.dtype !== 'complex64', - () => `${opName} does not support complex64 tensors.`); + () => `${ + opName} does not support complex64 tensors in the CPU backend.`); } }); } From d8dd8d5062bcf15eaba9fbbb6f7556535417d95c Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 5 May 2020 12:59:49 -0400 Subject: [PATCH 60/72] move --- tfjs-core/src/ops/max.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index f100bf465d9..b39b78c787d 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -60,10 +60,9 @@ import {transpose} from './transpose'; function max_( x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T { let $x = convertToTensor(x, 'x', 'max'); - const origAxes = util.parseAxisParam(axis, $x.shape); - const forward: ForwardFunc = (backend: KernelBackend, save: GradSaveFunc) => { + const origAxes = util.parseAxisParam(axis, $x.shape); let axes = origAxes; const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); if (permutedAxes != null) { From acd1f66339e8506dd7029b58d120c4cae076e901 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 6 May 2020 11:04:23 -0400 Subject: [PATCH 61/72] properly dispose --- tfjs-backend-webgl/src/kernels/Max_impl.ts | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tfjs-backend-webgl/src/kernels/Max_impl.ts b/tfjs-backend-webgl/src/kernels/Max_impl.ts index 7a323b49d02..04a01f1bd75 100644 --- a/tfjs-backend-webgl/src/kernels/Max_impl.ts +++ b/tfjs-backend-webgl/src/kernels/Max_impl.ts @@ -27,8 +27,13 @@ export function maxImpl( const inSize = util.sizeFromShape(reduceShape); const xSize = util.sizeFromShape(x.shape); const batchSize = xSize / inSize; + const reshapedInput = reshape(x, [batchSize, inSize], backend); + const reduced = reduce(reshapedInput, x.dtype, 'max', backend); - return reshape( - reduce(reshape(x, [batchSize, inSize], backend), x.dtype, 'max', backend), - outShape, backend); + if (reshapedInput.dataId !== x.dataId) { + // dispose the output of the packed reshape. + backend.disposeData(reshapedInput.dataId); + } + + return reshape(reduced, outShape, backend); } From aa50854caab7a0dcfa93c9568479e437491696c0 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 7 May 2020 09:31:58 -0400 Subject: [PATCH 62/72] dispose intermediate --- tfjs-core/src/ops/max.ts | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index b39b78c787d..68fed51697e 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -59,20 +59,25 @@ import {transpose} from './transpose'; /** @doc {heading: 'Operations', subheading: 'Reduction'} */ function max_( x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T { - let $x = convertToTensor(x, 'x', 'max'); + const $x = convertToTensor(x, 'x', 'max'); const forward: ForwardFunc = (backend: KernelBackend, save: GradSaveFunc) => { const origAxes = util.parseAxisParam(axis, $x.shape); let axes = origAxes; const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); + let maxInput = $x; if (permutedAxes != null) { - $x = transpose($x, permutedAxes); - axes = axis_util.getInnerMostAxes(axes.length, $x.rank); + maxInput = transpose($x, permutedAxes); + axes = axis_util.getInnerMostAxes(axes.length, maxInput.rank); } - const y = backend.max($x, axes); + const y = backend.max(maxInput, axes); save([$x, y]); + if (permutedAxes != null) { + backend.disposeData(maxInput.dataId); + } + if (keepDims) { return reshape(y, axis_util.expandShapeToKeepDim(y.shape, origAxes)); } From c9a1976f0897d418038ec358f181a480e9b8a5b3 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 7 May 2020 10:12:38 -0400 Subject: [PATCH 63/72] add logs --- tfjs-core/src/ops/max.ts | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index 68fed51697e..ad2fb29e191 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -62,26 +62,45 @@ function max_( const $x = convertToTensor(x, 'x', 'max'); const forward: ForwardFunc = (backend: KernelBackend, save: GradSaveFunc) => { + let numDataIds = backend.numDataIds(); + console.log('--------------- MAX KERNEL'); + console.log('start num data ids:', numDataIds); const origAxes = util.parseAxisParam(axis, $x.shape); let axes = origAxes; const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); let maxInput = $x; if (permutedAxes != null) { + console.log('TRANSPOSE'); maxInput = transpose($x, permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, maxInput.rank); } + numDataIds = backend.numDataIds(); + console.log('num data ids after transpose', numDataIds); + const y = backend.max(maxInput, axes); + numDataIds = backend.numDataIds(); + console.log('num data ids after max', numDataIds); save([$x, y]); if (permutedAxes != null) { backend.disposeData(maxInput.dataId); } + numDataIds = backend.numDataIds(); + console.log('num data ids after dispose', numDataIds); + + let output = y; + if (keepDims) { - return reshape(y, axis_util.expandShapeToKeepDim(y.shape, origAxes)); + console.log('RESHAPE'); + output = + reshape(y, axis_util.expandShapeToKeepDim(y.shape, origAxes)); } - return y; + + numDataIds = backend.numDataIds(); + console.log('num data ids after reshape', numDataIds); + return output; }; const inputs: MaxInputs = {x: $x}; const attrs: MaxAttrs = {reductionIndices: axis, keepDims}; From d43dadb33eb73f21bb5ccffd888c41ee436c9c2e Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 7 May 2020 13:26:23 -0400 Subject: [PATCH 64/72] run one test --- tfjs-core/src/ops/reduction_ops_test.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tfjs-core/src/ops/reduction_ops_test.ts b/tfjs-core/src/ops/reduction_ops_test.ts index a26623ff45b..6c610d88358 100644 --- a/tfjs-core/src/ops/reduction_ops_test.ts +++ b/tfjs-core/src/ops/reduction_ops_test.ts @@ -364,7 +364,9 @@ describeWithFlags('max', ALL_ENVS, () => { expect(gradients.shape).toEqual([2, 3]); }); - it('max gradient: 2D, axes=0, keepDims=true', async () => { + // tslint:disable-next-line: ban + fit('max gradient: 2D, axes=0, keepDims=true', async () => { + console.log('RUNNING PROBLEM TEST'); const x = tf.tensor2d([[0, 20, 10], [-10, -30, 20]]); const dy = tf.tensor2d([[-1, -1, -1]]); const axis = 0; @@ -372,6 +374,7 @@ describeWithFlags('max', ALL_ENVS, () => { const gradients = tf.grad(v => tf.max(v, axis, keepDims))(x, dy); expectArraysClose(await gradients.data(), [-1, -1, 0, 0, 0, -1]); expect(gradients.shape).toEqual([2, 3]); + console.log('END RUNNING PROBLEM TEST'); }); it('max gradient: 3D, axes=[1, 2], keepDims=false', async () => { From 3522c95f0f090633c3be4aaf047d5a190245a60f Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 7 May 2020 13:54:31 -0400 Subject: [PATCH 65/72] hbn --- tfjs-node/src/nodejs_kernel_backend.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tfjs-node/src/nodejs_kernel_backend.ts b/tfjs-node/src/nodejs_kernel_backend.ts index 55ffcc97b24..3a8251ad9f2 100644 --- a/tfjs-node/src/nodejs_kernel_backend.ts +++ b/tfjs-node/src/nodejs_kernel_backend.ts @@ -616,8 +616,10 @@ export class NodeJSKernelBackend extends KernelBackend { max(x: Tensor, axes: number[]): Tensor { const axesTensor = tensor1d(axes, 'int32'); - return this.executeSingleOutput( + const out = this.executeSingleOutput( 'Max', this.createReductionOpAttrs(x), [x, axesTensor]); + this.disposeData(axesTensor.dataId); + return out; } maximum(a: Tensor, b: Tensor): Tensor { From a4f0ae6ec57988d71313a67a926d851c3b5cdc6e Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 8 May 2020 07:19:34 -0400 Subject: [PATCH 66/72] temp disable --- tslint.json | 1 - 1 file changed, 1 deletion(-) diff --git a/tslint.json b/tslint.json index e580e8b1a9a..b97cfb9ef1e 100644 --- a/tslint.json +++ b/tslint.json @@ -5,7 +5,6 @@ "array-type": [true, "array-simple"], "arrow-return-shorthand": true, "ban": [true, - ["fit"], ["fdescribe"], ["xit"], ["xdescribe"], From 10570150fcebeb5f17a03a894bd951582f2b84f5 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 8 May 2020 07:55:39 -0400 Subject: [PATCH 67/72] revive --- tslint.json | 1 + 1 file changed, 1 insertion(+) diff --git a/tslint.json b/tslint.json index b97cfb9ef1e..e580e8b1a9a 100644 --- a/tslint.json +++ b/tslint.json @@ -5,6 +5,7 @@ "array-type": [true, "array-simple"], "arrow-return-shorthand": true, "ban": [true, + ["fit"], ["fdescribe"], ["xit"], ["xdescribe"], From c22e41ce090aae583828f79d2eac9c9fd2d0e00b Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 8 May 2020 09:18:49 -0400 Subject: [PATCH 68/72] yarn --- tfjs-data/yarn.lock | 5 +++++ tfjs-layers/yarn.lock | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/tfjs-data/yarn.lock b/tfjs-data/yarn.lock index 28988ae2a4f..9c05c22f06a 100644 --- a/tfjs-data/yarn.lock +++ b/tfjs-data/yarn.lock @@ -3262,6 +3262,11 @@ minimist@^1.1.0, minimist@^1.1.3, minimist@^1.2.0: resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.0.tgz#a35008b20f41383eec1fb914f4cd5df79a264284" integrity sha1-o1AIsg9BOD7sH7kU9M1d95omQoQ= +minimist@^1.2.5: + version "1.2.5" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.5.tgz#67d66014b66a6a8aaa0c083c5fd58df4e4e97602" + integrity sha512-FM9nNUYrRBAELZQT3xeZQ7fmMOBg6nWNmJKTcgsJeaLstP/UODVpGsr5OhXhhXg6f+qtJ8uiZ+PUxkDWcgIXLw== + minimist@~0.0.1: version "0.0.10" resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.10.tgz#de3f98543dbf96082be48ad1a0c7cda836301dcf" diff --git a/tfjs-layers/yarn.lock b/tfjs-layers/yarn.lock index 7cb0a9202cf..17fe70a34f2 100644 --- a/tfjs-layers/yarn.lock +++ b/tfjs-layers/yarn.lock @@ -3037,6 +3037,11 @@ minimist@^1.1.0, minimist@^1.1.3: resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.0.tgz#a35008b20f41383eec1fb914f4cd5df79a264284" integrity sha1-o1AIsg9BOD7sH7kU9M1d95omQoQ= +minimist@^1.2.5: + version "1.2.5" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.5.tgz#67d66014b66a6a8aaa0c083c5fd58df4e4e97602" + integrity sha512-FM9nNUYrRBAELZQT3xeZQ7fmMOBg6nWNmJKTcgsJeaLstP/UODVpGsr5OhXhhXg6f+qtJ8uiZ+PUxkDWcgIXLw== + minimist@~0.0.1: version "0.0.10" resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.10.tgz#de3f98543dbf96082be48ad1a0c7cda836301dcf" From 7671c286c90487e8858c8723ceb265bcb4f7f1b4 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 8 May 2020 09:48:31 -0400 Subject: [PATCH 69/72] revive --- tfjs-backend-cpu/src/kernels/Max.ts | 11 +++-------- tfjs-backend-webgl/src/kernels/Max.ts | 8 +------- tfjs-core/src/ops/max.ts | 25 +++++++++++-------------- tfjs-core/src/ops/reduction_ops_test.ts | 3 +-- tfjs-node/src/nodejs_kernel_backend.ts | 1 - 5 files changed, 16 insertions(+), 32 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/Max.ts b/tfjs-backend-cpu/src/kernels/Max.ts index 62961fc1f02..bff04ee9115 100644 --- a/tfjs-backend-cpu/src/kernels/Max.ts +++ b/tfjs-backend-cpu/src/kernels/Max.ts @@ -30,7 +30,7 @@ export const maxConfig: KernelConfig = { backendName: 'cpu', kernelFunc: ({inputs, attrs, backend}) => { const {x} = inputs as MaxInputs; - const {reductionIndices, keepDims} = attrs as {} as MaxAttrs; + const {reductionIndices} = attrs as {} as MaxAttrs; const cpuBackend = backend as MathBackendCPU; let xShape = x.shape; const xRank = xShape.length; @@ -59,12 +59,7 @@ export const maxConfig: KernelConfig = { const reduceSize = util.sizeFromShape(reduceShape); const result = maxImpl(xVals, reduceSize, maxOutShape, x.dtype); - let outShape = maxOutShape; - if (keepDims) { - outShape = backend_util.expandShapeToKeepDim(maxOutShape, origAxes); - } - - const dataId = cpuBackend.write(result, outShape, x.dtype); - return {dataId, shape: outShape, dtype: x.dtype}; + const dataId = cpuBackend.write(result, maxOutShape, x.dtype); + return {dataId, shape: maxOutShape, dtype: x.dtype}; } }; diff --git a/tfjs-backend-webgl/src/kernels/Max.ts b/tfjs-backend-webgl/src/kernels/Max.ts index d7acd96162c..54b193c2ebc 100644 --- a/tfjs-backend-webgl/src/kernels/Max.ts +++ b/tfjs-backend-webgl/src/kernels/Max.ts @@ -19,7 +19,6 @@ import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; import {backend_util, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; -import {reshape} from '../kernel_utils/reshape'; import {maxImplCPU} from '../kernel_utils/shared'; import {maxImpl} from './Max_impl'; @@ -30,7 +29,7 @@ export const maxConfig: KernelConfig = { backendName: 'webgl', kernelFunc: ({inputs, attrs, backend}) => { const {x} = inputs as MaxInputs; - const {reductionIndices, keepDims} = attrs as {} as MaxAttrs; + const {reductionIndices} = attrs as {} as MaxAttrs; const webglBackend = backend as MathBackendWebGL; const xRank = x.shape.length; @@ -83,11 +82,6 @@ export const maxConfig: KernelConfig = { out = maxImpl(maxInput, reduceShape, maxOutShape, webglBackend); } - if (keepDims) { - const outShape = backend_util.expandShapeToKeepDim(maxOutShape, origAxes); - out = reshape(out, outShape, webglBackend); - } - if (maxInputIsTransposed) { webglBackend.disposeData(maxInput.dataId); } diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index ad2fb29e191..dddcefa2ab4 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -90,24 +90,21 @@ function max_( numDataIds = backend.numDataIds(); console.log('num data ids after dispose', numDataIds); - let output = y; - - if (keepDims) { - console.log('RESHAPE'); - output = - reshape(y, axis_util.expandShapeToKeepDim(y.shape, origAxes)); - } - - numDataIds = backend.numDataIds(); - console.log('num data ids after reshape', numDataIds); - return output; + return y; }; const inputs: MaxInputs = {x: $x}; const attrs: MaxAttrs = {reductionIndices: axis, keepDims}; - return ENGINE.runKernelFunc( - forward, inputs as {} as NamedTensorMap, null /* gradient */, Max, - attrs as {} as NamedAttrMap) as T; + const res = ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* gradient */, + Max, attrs as {} as NamedAttrMap) as T; + if (keepDims) { + return reshape( + res, + axis_util.expandShapeToKeepDim( + res.shape, util.parseAxisParam(axis, $x.shape))) as T; + } + return res; } export const max = op({max_}); diff --git a/tfjs-core/src/ops/reduction_ops_test.ts b/tfjs-core/src/ops/reduction_ops_test.ts index 6c610d88358..2de3314af3d 100644 --- a/tfjs-core/src/ops/reduction_ops_test.ts +++ b/tfjs-core/src/ops/reduction_ops_test.ts @@ -364,8 +364,7 @@ describeWithFlags('max', ALL_ENVS, () => { expect(gradients.shape).toEqual([2, 3]); }); - // tslint:disable-next-line: ban - fit('max gradient: 2D, axes=0, keepDims=true', async () => { + it('max gradient: 2D, axes=0, keepDims=true', async () => { console.log('RUNNING PROBLEM TEST'); const x = tf.tensor2d([[0, 20, 10], [-10, -30, 20]]); const dy = tf.tensor2d([[-1, -1, -1]]); diff --git a/tfjs-node/src/nodejs_kernel_backend.ts b/tfjs-node/src/nodejs_kernel_backend.ts index 3a8251ad9f2..80d5ac23c8b 100644 --- a/tfjs-node/src/nodejs_kernel_backend.ts +++ b/tfjs-node/src/nodejs_kernel_backend.ts @@ -618,7 +618,6 @@ export class NodeJSKernelBackend extends KernelBackend { const axesTensor = tensor1d(axes, 'int32'); const out = this.executeSingleOutput( 'Max', this.createReductionOpAttrs(x), [x, axesTensor]); - this.disposeData(axesTensor.dataId); return out; } From 448d0712521706618e7b0136dfbf1f9a5d8572a4 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 8 May 2020 09:50:05 -0400 Subject: [PATCH 70/72] remove logs --- tfjs-core/src/ops/max.ts | 12 ------------ tfjs-core/src/ops/reduction_ops_test.ts | 2 -- 2 files changed, 14 deletions(-) diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index dddcefa2ab4..8a26f957ef0 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -62,34 +62,22 @@ function max_( const $x = convertToTensor(x, 'x', 'max'); const forward: ForwardFunc = (backend: KernelBackend, save: GradSaveFunc) => { - let numDataIds = backend.numDataIds(); - console.log('--------------- MAX KERNEL'); - console.log('start num data ids:', numDataIds); const origAxes = util.parseAxisParam(axis, $x.shape); let axes = origAxes; const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); let maxInput = $x; if (permutedAxes != null) { - console.log('TRANSPOSE'); maxInput = transpose($x, permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, maxInput.rank); } - numDataIds = backend.numDataIds(); - console.log('num data ids after transpose', numDataIds); - const y = backend.max(maxInput, axes); - numDataIds = backend.numDataIds(); - console.log('num data ids after max', numDataIds); save([$x, y]); if (permutedAxes != null) { backend.disposeData(maxInput.dataId); } - numDataIds = backend.numDataIds(); - console.log('num data ids after dispose', numDataIds); - return y; }; const inputs: MaxInputs = {x: $x}; diff --git a/tfjs-core/src/ops/reduction_ops_test.ts b/tfjs-core/src/ops/reduction_ops_test.ts index 2de3314af3d..a26623ff45b 100644 --- a/tfjs-core/src/ops/reduction_ops_test.ts +++ b/tfjs-core/src/ops/reduction_ops_test.ts @@ -365,7 +365,6 @@ describeWithFlags('max', ALL_ENVS, () => { }); it('max gradient: 2D, axes=0, keepDims=true', async () => { - console.log('RUNNING PROBLEM TEST'); const x = tf.tensor2d([[0, 20, 10], [-10, -30, 20]]); const dy = tf.tensor2d([[-1, -1, -1]]); const axis = 0; @@ -373,7 +372,6 @@ describeWithFlags('max', ALL_ENVS, () => { const gradients = tf.grad(v => tf.max(v, axis, keepDims))(x, dy); expectArraysClose(await gradients.data(), [-1, -1, 0, 0, 0, -1]); expect(gradients.shape).toEqual([2, 3]); - console.log('END RUNNING PROBLEM TEST'); }); it('max gradient: 3D, axes=[1, 2], keepDims=false', async () => { From 814fe2039a86629d6a1cf56298e0eb350428bfe4 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 8 May 2020 09:52:06 -0400 Subject: [PATCH 71/72] rm --- tfjs-backend-wasm/src/kernels/Max.ts | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/Max.ts b/tfjs-backend-wasm/src/kernels/Max.ts index ca8b5f6c4b8..d9a31a4a5c0 100644 --- a/tfjs-backend-wasm/src/kernels/Max.ts +++ b/tfjs-backend-wasm/src/kernels/Max.ts @@ -20,8 +20,6 @@ import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; -import {reshape} from './Reshape'; - let wasmMax: (xId: number, reduceSize: number, outId: number) => void; function setup(backend: BackendWasm): void { @@ -31,7 +29,7 @@ function setup(backend: BackendWasm): void { function max(args: {backend: BackendWasm, inputs: {}, attrs: {}}): TensorInfo { const {backend, inputs, attrs} = args; - const {reductionIndices, keepDims} = attrs as MaxAttrs; + const {reductionIndices} = attrs as MaxAttrs; const {x} = inputs as MaxInputs; const xId = backend.dataIdMap.get(x.dataId).id; @@ -51,13 +49,6 @@ function max(args: {backend: BackendWasm, inputs: {}, attrs: {}}): TensorInfo { wasmMax(xId, reduceSize, outId); - if (keepDims) { - return reshape({ - inputs: {x: out}, - attrs: {shape: backend_util.expandShapeToKeepDim(out.shape, origAxes)}, - backend - }); - } return out; } From 4eed0b5d8b022d209e9ac44c354209041a69c12c Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 8 May 2020 09:55:13 -0400 Subject: [PATCH 72/72] clean --- tfjs-node/src/nodejs_kernel_backend.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tfjs-node/src/nodejs_kernel_backend.ts b/tfjs-node/src/nodejs_kernel_backend.ts index 80d5ac23c8b..55ffcc97b24 100644 --- a/tfjs-node/src/nodejs_kernel_backend.ts +++ b/tfjs-node/src/nodejs_kernel_backend.ts @@ -616,9 +616,8 @@ export class NodeJSKernelBackend extends KernelBackend { max(x: Tensor, axes: number[]): Tensor { const axesTensor = tensor1d(axes, 'int32'); - const out = this.executeSingleOutput( + return this.executeSingleOutput( 'Max', this.createReductionOpAttrs(x), [x, axesTensor]); - return out; } maximum(a: Tensor, b: Tensor): Tensor {