diff --git a/tfjs-backend-cpu/src/backend_cpu.ts b/tfjs-backend-cpu/src/backend_cpu.ts index 9d69e3d2d3c..fc819ecb2cd 100644 --- a/tfjs-backend-cpu/src/backend_cpu.ts +++ b/tfjs-backend-cpu/src/backend_cpu.ts @@ -365,7 +365,9 @@ export class MathBackendCPU extends KernelBackend { softmax(logits: T, dim: number): T { const axes = util.parseAxisParam([dim], logits.shape); - const maxLogit = this.max(logits, axes); + // TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel + // modularization. + const maxLogit = tf.max(logits, axes); const expandedShape = backend_util.expandShapeToKeepDim(maxLogit.shape, axes); const a = this.subtract(logits, maxLogit.reshape(expandedShape)); @@ -807,31 +809,6 @@ export class MathBackendCPU extends KernelBackend { }); } - max(x: Tensor, axes: number[]): Tensor { - assertNotComplex(x, 'max'); - - backend_util.assertAxesAreInnerMostDims('max', axes, x.rank); - const [outShape, reduceShape] = - backend_util.computeOutAndReduceShapes(x.shape, axes); - const result = tf.zeros(outShape, x.dtype); - const reduceSize = util.sizeFromShape(reduceShape); - const vals = this.readSync(result.dataId) as TypedArray; - - const aVals = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < vals.length; ++i) { - const offset = i * reduceSize; - let max = aVals[offset]; - for (let j = 0; j < reduceSize; ++j) { - const value = aVals[offset + j]; - if (value > max) { - max = value; - } - } - vals[i] = max; - } - return result; - } - maximum(a: Tensor, b: Tensor): Tensor { assertNotComplex([a, b], 'maximum'); diff --git a/tfjs-backend-cpu/src/kernels/Max.ts b/tfjs-backend-cpu/src/kernels/Max.ts new file mode 100644 index 00000000000..cd41de1f100 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Max.ts @@ -0,0 +1,66 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; +import {backend_util, KernelConfig} from '@tensorflow/tfjs-core'; +import {TypedArray, util} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; +import {assertNotComplex} from '../cpu_util'; + +import {maxImpl} from './Max_impl'; +import {transposeImpl} from './Transpose_impl'; + +export const maxConfig: KernelConfig = { + kernelName: Max, + backendName: 'cpu', + kernelFunc: ({inputs, attrs, backend}) => { + const {x} = inputs as MaxInputs; + const {reductionIndices} = attrs as {} as MaxAttrs; + const cpuBackend = backend as MathBackendCPU; + let xShape = x.shape; + const xRank = xShape.length; + + const origAxes = util.parseAxisParam(reductionIndices, xShape); + let axes = origAxes; + const permutedAxes = backend_util.getAxesPermutation(axes, xRank); + let xVals = cpuBackend.data.get(x.dataId).values as TypedArray; + if (permutedAxes != null) { + xVals = transposeImpl(xVals, xShape, x.dtype, permutedAxes); + axes = backend_util.getInnerMostAxes(axes.length, xRank); + + const newShape: number[] = new Array(xRank); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = xShape[permutedAxes[i]]; + } + + xShape = newShape; + } + + assertNotComplex(x, 'max'); + backend_util.assertAxesAreInnerMostDims('max', axes, xRank); + const [outShape, reduceShape] = + backend_util.computeOutAndReduceShapes(xShape, axes); + + const reduceSize = util.sizeFromShape(reduceShape); + 
+ const result = maxImpl(xVals, reduceSize, outShape, x.dtype); + + const dataId = cpuBackend.write(result, outShape, x.dtype); + return {dataId, shape: outShape, dtype: x.dtype}; + } +}; diff --git a/tfjs-backend-cpu/src/kernels/Max_impl.ts b/tfjs-backend-cpu/src/kernels/Max_impl.ts new file mode 100644 index 00000000000..c326ac3f9df --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Max_impl.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {DataType, NumericDataType, TypedArray, util} from '@tensorflow/tfjs-core'; + +export function maxImpl( + aVals: TypedArray, reduceSize: number, outShape: number[], + dtype: DataType): TypedArray { + const vals = util.getTypedArrayFromDType( + dtype as NumericDataType, util.sizeFromShape(outShape)); + + for (let i = 0; i < vals.length; ++i) { + const offset = i * reduceSize; + let max = aVals[offset]; + for (let j = 0; j < reduceSize; ++j) { + const value = aVals[offset + j]; + if (value > max) { + max = value; + } + } + vals[i] = max; + } + return vals; +} diff --git a/tfjs-backend-cpu/src/kernels/Transpose.ts b/tfjs-backend-cpu/src/kernels/Transpose.ts index 05ae13f03cb..c1c68be78f2 100644 --- a/tfjs-backend-cpu/src/kernels/Transpose.ts +++ b/tfjs-backend-cpu/src/kernels/Transpose.ts @@ -41,7 +41,7 @@ export const transposeConfig: KernelConfig = { } const values = cpuBackend.data.get(x.dataId).values as TypedArray; - const result = transposeImpl(values, x.shape, x.dtype, perm, newShape); + const result = transposeImpl(values, x.shape, x.dtype, perm); const dataId = cpuBackend.write(result, newShape, x.dtype); return {dataId, shape: newShape, dtype: x.dtype}; diff --git a/tfjs-backend-cpu/src/kernels/Transpose_impl.ts b/tfjs-backend-cpu/src/kernels/Transpose_impl.ts index 3a6d290b777..b7d78decb03 100644 --- a/tfjs-backend-cpu/src/kernels/Transpose_impl.ts +++ b/tfjs-backend-cpu/src/kernels/Transpose_impl.ts @@ -19,10 +19,15 @@ import {DataType, NumericDataType, TypedArray} from '@tensorflow/tfjs-core'; import {util} from '@tensorflow/tfjs-core'; export function transposeImpl( - xVals: TypedArray, xShape: number[], dtype: DataType, perm: number[], - newShape: number[]): TypedArray { - const xSize = util.sizeFromShape(xShape); + xVals: TypedArray, xShape: number[], dtype: DataType, + perm: number[]): TypedArray { const xRank = xShape.length; + const newShape: number[] = 
new Array(xRank); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = xShape[perm[i]]; + } + + const xSize = util.sizeFromShape(xShape); const xStrides = util.computeStrides(xShape); const newStrides = util.computeStrides(newShape); diff --git a/tfjs-backend-cpu/src/register_all_kernels.ts b/tfjs-backend-cpu/src/register_all_kernels.ts index 8618a514683..a1a6a66af82 100644 --- a/tfjs-backend-cpu/src/register_all_kernels.ts +++ b/tfjs-backend-cpu/src/register_all_kernels.ts @@ -20,6 +20,7 @@ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; import {divConfig} from './kernels/Div'; +import {maxConfig} from './kernels/Max'; import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; import {squareConfig} from './kernels/Square'; @@ -29,7 +30,7 @@ import {transposeConfig} from './kernels/Transpose'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig, - transposeConfig, maxPoolWithArgmaxConfig + transposeConfig, maxPoolWithArgmaxConfig, maxConfig ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-backend-cpu/yarn.lock b/tfjs-backend-cpu/yarn.lock index ffeeeada886..754e4d08dbc 100644 --- a/tfjs-backend-cpu/yarn.lock +++ b/tfjs-backend-cpu/yarn.lock @@ -139,11 +139,6 @@ acorn@^6.0.5: resolved "https://registry.yarnpkg.com/acorn/-/acorn-6.4.1.tgz#531e58ba3f51b9dacb9a6646ca4debf5b14ca474" integrity sha512-ZVA9k326Nwrj3Cj9jlh3wGFutC2ZornPNARZwsNYqQYgN0EsV2d53w5RN/co65Ohn4sUAUtb1rSUAOD6XN9idA== -acorn@^7.1.1: - version "7.1.1" - resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.1.1.tgz#e35668de0b402f359de515c5482a1ab9f89a69bf" - integrity sha512-add7dgA5ppRPxCFJoAGfMDi7PIBXq1RtGo7BhbLaxwrXPOmw8gq48Y9ozT01hUKy9byMjlR20EJhu5zlkErEkg== - after@0.8.2: version "0.8.2" resolved 
"https://registry.yarnpkg.com/after/-/after-0.8.2.tgz#fedb394f9f0e02aa9768e702bda23b505fae7e1f" diff --git a/tfjs-backend-webgl/src/backend_webgl.ts b/tfjs-backend-webgl/src/backend_webgl.ts index bee73a7bd15..3785adedadd 100644 --- a/tfjs-backend-webgl/src/backend_webgl.ts +++ b/tfjs-backend-webgl/src/backend_webgl.ts @@ -19,7 +19,7 @@ import './flags_webgl'; import * as tf from '@tensorflow/tfjs-core'; -import {complex, DataId, div, engine, env, imag, MemoryInfo, range, real, RecursiveArray, scalar, softmax, tensor, tidy, TimingInfo, transpose} from '@tensorflow/tfjs-core'; +import {complex, DataId, div, engine, env, imag, max, MemoryInfo, range, real, RecursiveArray, scalar, softmax, tensor, tidy, TimingInfo, transpose} from '@tensorflow/tfjs-core'; import {backend_util, buffer, kernel_impls, slice_util, util} from '@tensorflow/tfjs-core'; import {DataStorage, DataType, KernelBackend, NumericDataType, Rank, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorInfo, TypedArray, upcastType} from '@tensorflow/tfjs-core'; @@ -1305,19 +1305,6 @@ export class MathBackendWebGL extends KernelBackend { return this.compileAndRun(program, [a, b]); } - max(x: Tensor, axes: number[]): Tensor { - if (this.shouldExecuteOnCPU([x])) { - return this.cpuBackend.max(x, axes); - } - - backend_util.assertAxesAreInnerMostDims('max', axes, x.rank); - const [outShape, reduceShape] = - backend_util.computeOutAndReduceShapes(x.shape, axes); - const inSize = util.sizeFromShape(reduceShape); - const a2D = x.as2D(-1, inSize); - return this.reduce(a2D, 'max', a2D.dtype).reshape(outShape); - } - maximum(a: Tensor, b: Tensor): Tensor { if (this.shouldExecuteOnCPU([a, b])) { return this.cpuBackend.maximum(a, b); @@ -1554,7 +1541,9 @@ export class MathBackendWebGL extends KernelBackend { softmax(logits: T, dim: number): T { const axes = util.parseAxisParam([dim], logits.shape); - const maxLogit = this.max(logits, axes); + // TODO(annxingyuan): Call maxImpl rather than 
op as part of softmax kernel + // modularization. + const maxLogit = max(logits, axes); const expandedShape = backend_util.expandShapeToKeepDim(maxLogit.shape, axes); const a = this.subtract(logits, maxLogit.reshape(expandedShape)); diff --git a/tfjs-backend-webgl/src/kernel_utils/reduce.ts b/tfjs-backend-webgl/src/kernel_utils/reduce.ts new file mode 100644 index 00000000000..cfc5c1e5151 --- /dev/null +++ b/tfjs-backend-webgl/src/kernel_utils/reduce.ts @@ -0,0 +1,37 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {backend_util, DataType, TensorInfo} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {ReduceProgram} from '../reduce_gpu'; + +export function reduce( + x: TensorInfo, dtype: DataType, reductionType: backend_util.ReduceTypes, + backend: MathBackendWebGL): TensorInfo { + const [batchSize, inSize] = x.shape; + const windowSize = backend_util.computeOptimalWindowSize(inSize); + const reduceInfo = {windowSize, inSize, batchSize}; + const program = new ReduceProgram(reduceInfo, reductionType); + const output = backend.runWebGLProgram(program, [x], dtype); + + if (output.shape[1] === 1) { + return output; + } + + return reduce(output, dtype, reductionType, backend); +} diff --git a/tfjs-backend-webgl/src/kernel_utils/reshape.ts b/tfjs-backend-webgl/src/kernel_utils/reshape.ts new file mode 100644 index 00000000000..44417d68063 --- /dev/null +++ b/tfjs-backend-webgl/src/kernel_utils/reshape.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {TensorInfo} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {ReshapePackedProgram} from '../reshape_packed_gpu'; +import {getBatchDim, getRowsCols, isReshapeFree} from '../webgl_util'; + +function packedReshape( + input: TensorInfo, afterShape: number[], + backend: MathBackendWebGL): TensorInfo { + const input3DShape = + [getBatchDim(input.shape), + ...getRowsCols(input.shape)] as [number, number, number]; + const input3D: TensorInfo = { + dtype: input.dtype, + shape: input3DShape, + dataId: input.dataId + }; + const afterShapeAs3D = + [getBatchDim(afterShape), + ...getRowsCols(afterShape)] as [number, number, number]; + + const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape); + const preventEagerUnpackingOfOutput = true; + const output = backend.runWebGLProgram( + program, [input3D], input.dtype, null /* customSetup */, + preventEagerUnpackingOfOutput); + return {dataId: output.dataId, shape: afterShape, dtype: output.dtype}; +} + +export function reshape( + x: TensorInfo, afterShape: number[], + backend: MathBackendWebGL): TensorInfo { + const xTexData = backend.texData.get(x.dataId); + if (xTexData.isPacked && !isReshapeFree(x.shape, afterShape) && + !(xTexData.texture !== null && + isReshapeFree(xTexData.shape, afterShape))) { + return packedReshape(x, afterShape, backend); + } + + return {dataId: x.dataId, shape: afterShape, dtype: x.dtype}; +} diff --git a/tfjs-backend-webgl/src/kernels/Max.ts b/tfjs-backend-webgl/src/kernels/Max.ts new file mode 100644 index 00000000000..c826dcccc84 --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/Max.ts @@ -0,0 +1,71 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; +import {backend_util, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; + +import {maxImpl, maxImplCPU} from './Max_impl'; +import {transposeImpl} from './Transpose_impl'; + +export const maxConfig: KernelConfig = { + kernelName: Max, + backendName: 'webgl', + kernelFunc: ({inputs, attrs, backend}) => { + const {x} = inputs as MaxInputs; + const {reductionIndices} = attrs as {} as MaxAttrs; + const webglBackend = backend as MathBackendWebGL; + + const xRank = x.shape.length; + + const origAxes = util.parseAxisParam(reductionIndices, x.shape); + let axes = origAxes; + const permutedAxes = backend_util.getAxesPermutation(axes, xRank); + const maxInputIsTransposed = permutedAxes != null; + + let maxInput = x; + if (maxInputIsTransposed) { + maxInput = transposeImpl(x, permutedAxes, webglBackend); + axes = backend_util.getInnerMostAxes(axes.length, xRank); + } + + backend_util.assertAxesAreInnerMostDims('max', axes, xRank); + const [outShape, reduceShape] = + backend_util.computeOutAndReduceShapes(maxInput.shape, axes); + + let out; + if (webglBackend.shouldExecuteOnCPU([x])) { + const xTexData = webglBackend.texData.get(maxInput.dataId); + const values = xTexData.values as TypedArray; + const outValues = maxImplCPU( + values, util.sizeFromShape(reduceShape), outShape, x.dtype); + + out = webglBackend.makeTensorInfo(outShape, x.dtype); 
+ const outData = webglBackend.texData.get(out.dataId); + outData.values = outValues; + } else { + out = maxImpl(maxInput, reduceShape, outShape, webglBackend); + } + + if (maxInputIsTransposed) { + webglBackend.disposeData(maxInput.dataId); + } + + return out; + } +}; diff --git a/tfjs-backend-webgl/src/kernels/Max_impl.ts b/tfjs-backend-webgl/src/kernels/Max_impl.ts new file mode 100644 index 00000000000..257bddc8223 --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/Max_impl.ts @@ -0,0 +1,57 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {DataType, NumericDataType, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {reduce} from '../kernel_utils/reduce'; +import {reshape} from '../kernel_utils/reshape'; + +export const maxImpl = + (x: TensorInfo, reduceShape: number[], outShape: number[], + backend: MathBackendWebGL): TensorInfo => { + const inSize = util.sizeFromShape(reduceShape); + const xSize = util.sizeFromShape(x.shape); + const batchSize = xSize / inSize; + + return reshape( + reduce( + reshape(x, [batchSize, inSize], backend), x.dtype, 'max', + backend), + outShape, backend); + }; + +// todo(@annxingyuan) import this from cpu backend. 
+export function maxImplCPU( + aVals: TypedArray, reduceSize: number, outShape: number[], + dtype: DataType): TypedArray { + const vals = util.getTypedArrayFromDType( + dtype as NumericDataType, util.sizeFromShape(outShape)); + + for (let i = 0; i < vals.length; ++i) { + const offset = i * reduceSize; + let max = aVals[offset]; + for (let j = 0; j < reduceSize; ++j) { + const value = aVals[offset + j]; + if (value > max) { + max = value; + } + } + vals[i] = max; + } + return vals; +} diff --git a/tfjs-backend-webgl/src/register_all_kernels.ts b/tfjs-backend-webgl/src/register_all_kernels.ts index e7dfa515cea..d8ad7679d1c 100644 --- a/tfjs-backend-webgl/src/register_all_kernels.ts +++ b/tfjs-backend-webgl/src/register_all_kernels.ts @@ -18,6 +18,7 @@ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; import {divConfig} from './kernels/Div'; import {fromPixelsConfig} from './kernels/FromPixels'; +import {maxConfig} from './kernels/Max'; import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; import {squareConfig} from './kernels/Square'; @@ -26,8 +27,9 @@ import {transposeConfig} from './kernels/Transpose'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ - fromPixelsConfig, divConfig, nonMaxSuppressionV5Config, squareConfig, - squaredDifferenceConfig, transposeConfig, maxPoolWithArgmaxConfig + maxConfig, fromPixelsConfig, divConfig, nonMaxSuppressionV5Config, + squareConfig, squaredDifferenceConfig, transposeConfig, + maxPoolWithArgmaxConfig ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-backend-webgl/yarn.lock b/tfjs-backend-webgl/yarn.lock index ffeeeada886..754e4d08dbc 100644 --- a/tfjs-backend-webgl/yarn.lock +++ b/tfjs-backend-webgl/yarn.lock @@ -139,11 +139,6 @@ acorn@^6.0.5: resolved "https://registry.yarnpkg.com/acorn/-/acorn-6.4.1.tgz#531e58ba3f51b9dacb9a6646ca4debf5b14ca474" integrity 
sha512-ZVA9k326Nwrj3Cj9jlh3wGFutC2ZornPNARZwsNYqQYgN0EsV2d53w5RN/co65Ohn4sUAUtb1rSUAOD6XN9idA== -acorn@^7.1.1: - version "7.1.1" - resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.1.1.tgz#e35668de0b402f359de515c5482a1ab9f89a69bf" - integrity sha512-add7dgA5ppRPxCFJoAGfMDi7PIBXq1RtGo7BhbLaxwrXPOmw8gq48Y9ozT01hUKy9byMjlR20EJhu5zlkErEkg== - after@0.8.2: version "0.8.2" resolved "https://registry.yarnpkg.com/after/-/after-0.8.2.tgz#fedb394f9f0e02aa9768e702bda23b505fae7e1f" diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index e09724b1501..aa3d954e939 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -797,8 +797,8 @@ export class WebGPUBackend extends KernelBackend { const dimensions = [ convInfo.filterHeight, convInfo.filterWidth, ...pad, - convInfo.strideHeight, convInfo.strideWidth, - convInfo.dilationHeight, convInfo.dilationWidth + convInfo.strideHeight, convInfo.strideWidth, convInfo.dilationHeight, + convInfo.dilationWidth ]; const inputs: Tensor[] = [input, filter]; diff --git a/tfjs-core/src/gradients/Max_grad.ts b/tfjs-core/src/gradients/Max_grad.ts new file mode 100644 index 00000000000..bbc7382e0d8 --- /dev/null +++ b/tfjs-core/src/gradients/Max_grad.ts @@ -0,0 +1,51 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {Max, MaxAttrs} from '../kernel_names'; +import {GradConfig, NamedAttrMap} from '../kernel_registry'; +import * as axis_util from '../ops/axis_util'; +import {gradForMinAndMax} from '../ops/reduction_ops'; +import {transpose} from '../ops/transpose'; +import {Tensor} from '../tensor'; +import * as util from '../util'; + +/** + * Gradient registration for the `Max` kernel. Saves the input `x` and the + * kernel output and delegates the `x` gradient to `gradForMinAndMax`. + */ +export const maxGradConfig: GradConfig = { + kernelName: Max, + inputsToSave: ['x'], + outputsToSave: [true], + gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { + const maxAttrs: MaxAttrs = attrs as {} as MaxAttrs; + const {reductionIndices} = maxAttrs; + const [x, y] = saved; + const origAxes = util.parseAxisParam(reductionIndices, x.shape); + const permutedAxes = axis_util.getAxesPermutation(origAxes, x.rank); + const maxGrad = gradForMinAndMax(dy, y, x, origAxes, permutedAxes); + return { + x: () => { + let out = maxGrad['x'](); + if (permutedAxes != null) { + out = transpose(out); + } + return out; + } + }; + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index b54d4ca91d7..de8ff64a3f2 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -67,6 +67,12 @@ export 
interface NonMaxSuppressionV5Attrs { softNmsSigma: number; } +export const Max = 'Max'; +export type MaxInputs = Pick; +export interface MaxAttrs { + reductionIndices: number[]; +} + export const BroadcastTo = 'BroadcastTo'; export type BroadcastToInputs = Pick; export interface BroadCastToAttrs { diff --git a/tfjs-core/src/ops/conv2d_test.ts b/tfjs-core/src/ops/conv2d_test.ts index b74f97cd0e5..bec5ec5b077 100644 --- a/tfjs-core/src/ops/conv2d_test.ts +++ b/tfjs-core/src/ops/conv2d_test.ts @@ -360,25 +360,30 @@ describeWithFlags('conv2d', ALL_ENVS, () => { const pad = 'same'; const stride: [number, number] = [2, 2]; - const inputs = generateCaseInputs( - 1 * xSize * xSize * inputDepth, - fSize * fSize * inputDepth * outputDepth); - const x = tf.tensor4d(inputs.input, inputShape); - const w = - tf.tensor4d(inputs.filter, [fSize, fSize, inputDepth, outputDepth]); + // TODO(annxingyuan): Make this test work with large inputs using + // generateCaseInputs https://github.com/tensorflow/tfjs/issues/3143 + const inputData = []; + for (let i = 0; i < xSize * xSize * inputDepth; i++) { + inputData.push(i % 5); + } + + const wData = []; + for (let i = 0; i < fSize * fSize * inputDepth * outputDepth; i++) { + wData.push(i % 5); + } + + const x = tf.tensor4d(inputData, inputShape); + const w = tf.tensor4d(wData, [fSize, fSize, inputDepth, outputDepth]); const result = tf.conv2d(x, w, stride, pad); expect(result.shape).toEqual([1, 4, 4, 4]); expectArraysClose( await result.data(), new Float32Array([ - 57771, 58554, 59337, 60120, 66357, 67302, 68247, 69192, - 74943, 76050, 77157, 78264, 49071, 49890, 50709, 51528, - 126459, 128538, 130617, 132696, 135045, 137286, 139527, 141768, - 143631, 146034, 148437, 150840, 89679, 91362, 93045, 94728, - 195147, 198522, 201897, 205272, 203733, 207270, 210807, 214344, - 212319, 216018, 219717, 223416, 130287, 132834, 135381, 137928, - 105798, 108696, 111594, 114492, 109578, 112584, 115590, 118596, - 113358, 116472, 119586, 122700, 64502, 
66632, 68762, 70892 + 104, 125, 126, 102, 133, 126, 104, 57, 137, 102, 57, 112, 64, + 40, 76, 92, 116, 53, 110, 142, 50, 104, 133, 137, 104, 125, + 126, 102, 83, 88, 78, 33, 133, 126, 104, 57, 137, 102, 57, + 112, 116, 53, 110, 142, 37, 76, 100, 99, 33, 68, 83, 88, + 70, 83, 76, 64, 92, 88, 64, 40, 51, 44, 27, 50 ])); }); diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts new file mode 100644 index 00000000000..1909de35ec5 --- /dev/null +++ b/tfjs-core/src/ops/max.ts @@ -0,0 +1,85 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelBackend} from '../backends/backend'; +import {ENGINE} from '../engine'; +import {Tensor} from '../tensor'; +import {GradSaveFunc} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import * as axis_util from './axis_util'; +import {op} from './operation'; +import {transpose} from './transpose'; + +/** + * Computes the maximum of elements across dimensions of a `tf.Tensor`. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in + * `axes`. If `keepDims` is true, the reduced dimensions are retained with + * length 1. 
If `axes` has no entries, all dimensions are reduced, and a + * `tf.Tensor` with a single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.max().print(); // or tf.max(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.max(axis).print(); // or tf.max(x, axis) + * ``` + * + * @param x The input tensor. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. + */ +/** @doc {heading: 'Operations', subheading: 'Reduction'} */ +function max_( + x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T { + let $x = convertToTensor(x, 'x', 'max'); + const origAxes = util.parseAxisParam(axis, $x.shape); + + const forward = (backend: KernelBackend, save: GradSaveFunc) => { + let axes = origAxes; + const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = transpose($x, permutedAxes); + axes = axis_util.getInnerMostAxes(axes.length, $x.rank); + } + + const y = backend.max($x, axes); + save([$x, y]); + return y; + }; + + let res = ENGINE.runKernelFunc( + forward, {x: $x}, null /* gradient */, 'Max', {reductionIndices: axis}); + if (keepDims) { + const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); + res = res.reshape(newShape) as T; + } + return res as T; +} + +export const max = op({max_}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 00237b3052d..fc0b98a266e 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -27,6 +27,7 @@ export {clone} from './clone'; export {div} from './div'; export {divNoNan} from './div_no_nan'; export {eye} from './eye'; +export {max} from './max'; export {multinomial} from './multinomial'; export {notEqual} from './not_equal'; export {oneHot} 
from './one_hot';
diff --git a/tfjs-core/src/ops/reduce_util.ts b/tfjs-core/src/ops/reduce_util.ts
index cb9f27d158a..876c784dc35 100644
--- a/tfjs-core/src/ops/reduce_util.ts
+++ b/tfjs-core/src/ops/reduce_util.ts
@@ -29,6 +29,8 @@ export interface ReduceInfo {
   inSize: number;
 }
 
+export type ReduceTypes = 'all'|'any'|'max'|'min'|'sum'|'prod';
+
 export function computeOptimalWindowSize(inSize: number): number {
   if (inSize <= PARALLELIZE_THRESHOLD) {
     return inSize;
diff --git a/tfjs-core/src/ops/reduction_ops.ts b/tfjs-core/src/ops/reduction_ops.ts
index 37d8f0dfa8d..673b3f4d978 100644
--- a/tfjs-core/src/ops/reduction_ops.ts
+++ b/tfjs-core/src/ops/reduction_ops.ts
@@ -270,8 +270,11 @@ function mean_(
 /**
  * Gradient helper function for the min and max operations.
+ *
+ * Returns `{x: () => dx}`: `dy` is kept only where `xOrig` equals the reduced
+ * value `y`, and `dx` is transposed by `permutedAxes` when that is non-null.
  */
-function gradForMinAndMax(
+export function gradForMinAndMax<T extends Tensor>(
     dy: T, y: T, xOrig: Tensor, origAxes: number[], permutedAxes: number[]) {
   if (y.rank < xOrig.rank) {
     y = y.reshape(axis_util.expandShapeToKeepDim(y.shape, origAxes)) as T;
   }
@@ -280,8 +283,11 @@ function gradForMinAndMax(
   }
   return {
     x: () => {
+      // Mask `dy` so that only the positions whose input equals the reduced
+      // value `y` receive gradient; every other position gets zero.
       const dx = dy.mul(xOrig.equal(y).cast(dy.dtype));
+      // Only transpose when the caller supplied an axis permutation.
       return permutedAxes == null ? dx : dx.transpose(permutedAxes);
     }
   };
 }
@@ -344,64 +350,6 @@ function min_(
   return res;
 }
 
-/**
- * Computes the maximum of elements across dimensions of a `tf.Tensor`.
- *
- * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
- * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
- * `axes`. If `keepDims` is true, the reduced dimensions are retained with
- * length 1. If `axes` has no entries, all dimensions are reduced, and an
- * `tf.Tensor` with a single element is returned.
- * - * ```js - * const x = tf.tensor1d([1, 2, 3]); - * - * x.max().print(); // or tf.max(x) - * ``` - * - * ```js - * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); - * - * const axis = 1; - * x.max(axis).print(); // or tf.max(x, axis) - * ``` - * - * @param x The input tensor. - * @param axis The dimension(s) to reduce. By default it reduces - * all dimensions. - * @param keepDims If true, retains reduced dimensions with size 1. - */ -/** @doc {heading: 'Operations', subheading: 'Reduction'} */ -function max_( - x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T { - let $x = convertToTensor(x, 'x', 'max'); - const xOrig = $x; - - const origAxes = util.parseAxisParam(axis, $x.shape); - let axes = origAxes; - const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); - if (permutedAxes != null) { - $x = $x.transpose(permutedAxes); - axes = axis_util.getInnerMostAxes(axes.length, $x.rank); - } - - const grad = (dy: T, saved: Tensor[]) => - gradForMinAndMax(dy, saved[1], saved[0], origAxes, permutedAxes); - - const inputsToSave = [$x]; - const outputsToSave: boolean[] = [true]; - let res = ENGINE.runKernelFunc((backend, save) => { - const y = backend.max($x, axes); - save([xOrig, y]); - return y; - }, {x: $x}, grad, 'Max', {axes}, inputsToSave, outputsToSave); - if (keepDims) { - const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); - res = res.reshape(newShape) as T; - } - return res as T; -} - /** * Returns the indices of the minimum values along an `axis`. 
* @@ -625,7 +573,6 @@ export const any = op({any_}); export const argMax = op({argMax_}); export const argMin = op({argMin_}); export const logSumExp = op({logSumExp_}); -export const max = op({max_}); export const mean = op({mean_}); export const min = op({min_}); export const moments = op({moments_}); diff --git a/tfjs-core/src/public/chained_ops/max.ts b/tfjs-core/src/public/chained_ops/max.ts new file mode 100644 index 00000000000..c49ffdd7eab --- /dev/null +++ b/tfjs-core/src/public/chained_ops/max.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {max} from '../../ops/max'; +import {Tensor} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + max(axis?: number|number[], keepDims?: boolean): T; + } +} + +Tensor.prototype.max = function( + axis?: number|number[], keepDims?: boolean): T { + return max(this, axis, keepDims); +}; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index 6381adc4449..ecdf860305e 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -17,6 +17,7 @@ import './add'; import './batchnorm'; import './broadcast_to'; +import './max'; import './div'; import './div_no_nan'; import './one_hot'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index e9175206785..1a0e1954d86 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -25,7 +25,7 @@ import {ALL_ENVS, describeWithFlags} from '../../jasmine_util'; const CHAINED_OPS = [ 'add', 'batchNorm', 'broadcastTo', 'div', 'divNoNan', 'oneHot', 'notEqual', - 'pad', 'square', 'sub', 'tile', 'transpose' + 'pad', 'square', 'sub', 'tile', 'transpose', 'max' ]; describeWithFlags('chained ops', ALL_ENVS, () => { diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index 9c440df575c..d18b9280777 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -20,6 +20,7 @@ import {broadcastToGradConfig} from './gradients/BroadcastTo_grad'; import {divGradConfig} from './gradients/Div_grad'; import {fusedBatchNormGradConfig} from 
'./gradients/FusedBatchNorm_grad';
 import {identityGradConfig} from './gradients/Identity_grad';
+import {maxGradConfig} from './gradients/Max_grad';
 import {oneHotGradConfig} from './gradients/OneHot_grad';
 import {padV2GradConfig} from './gradients/PadV2_grad';
 import {squareGradConfig} from './gradients/Square_grad';
@@ -32,8 +33,8 @@ import {registerGradient} from './kernel_registry';
 
 // Export all kernel configs here so that the package can auto register them
 const gradConfigs: GradConfig[] = [
   addGradConfig, addNGradConfig, broadcastToGradConfig, divGradConfig,
-  fusedBatchNormGradConfig, identityGradConfig, oneHotGradConfig,
+  fusedBatchNormGradConfig, identityGradConfig, maxGradConfig, oneHotGradConfig,
   padV2GradConfig, squareGradConfig, squaredDifferenceGradConfig,
   tileGradConfig, transposeGradConfig, subGradConfig
 ];
diff --git a/tfjs-core/src/tape.ts b/tfjs-core/src/tape.ts
index 8f78106f150..7dd34a48479 100644
--- a/tfjs-core/src/tape.ts
+++ b/tfjs-core/src/tape.ts
@@ -165,6 +165,7 @@ export function backpropagateGradients(
 
       // Call the gradient function.
const dx = tidy(() => inputGradients[inputName]()); + if (dx.dtype !== 'float32') { throw new Error( `Error in gradient for op ${ diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index 3759a3e8233..3561160964f 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -205,7 +205,6 @@ export interface OpHandler { mean(x: Tensor, axis: number|number[], keepDims: boolean): T; min(x: Tensor, axis: number|number[], keepDims: boolean): T; - max(x: Tensor, axis: number|number[], keepDims: boolean): T; argMin(x: Tensor, axis: number): T; argMax(x: Tensor, axis: number): T; addStrict(a: T, b: T|TensorLike): T; @@ -854,10 +853,6 @@ export class Tensor { this.throwIfDisposed(); return opHandler.min(this, axis, keepDims); } - max(axis: number|number[] = null, keepDims = false): T { - this.throwIfDisposed(); - return opHandler.max(this, axis, keepDims); - } argMin(axis: number = null): T { this.throwIfDisposed(); return opHandler.argMin(this, axis);