diff --git a/tfjs-backend-cpu/src/backend_cpu.ts b/tfjs-backend-cpu/src/backend_cpu.ts index 9d69e3d2d3c..afd7d33b08b 100644 --- a/tfjs-backend-cpu/src/backend_cpu.ts +++ b/tfjs-backend-cpu/src/backend_cpu.ts @@ -18,7 +18,7 @@ import * as tf from '@tensorflow/tfjs-core'; import {engine, env} from '@tensorflow/tfjs-core'; import {backend_util, buffer, slice_util, util} from '@tensorflow/tfjs-core'; -import {BackendTimingInfo, DataStorage, DataType, DataValues, KernelBackend, NumericDataType, Rank, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, TypedArray, upcastType} from '@tensorflow/tfjs-core'; +import {BackendTimingInfo, DataStorage, DataType, DataValues, KernelBackend, max, NumericDataType, Rank, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, TypedArray, upcastType} from '@tensorflow/tfjs-core'; import {kernel_impls} from '@tensorflow/tfjs-core'; const nonMaxSuppressionV3 = kernel_impls.nonMaxSuppressionV3; @@ -365,7 +365,9 @@ export class MathBackendCPU extends KernelBackend { softmax(logits: T, dim: number): T { const axes = util.parseAxisParam([dim], logits.shape); - const maxLogit = this.max(logits, axes); + // TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel + // modularization. + const maxLogit = max(logits, axes); const expandedShape = backend_util.expandShapeToKeepDim(maxLogit.shape, axes); const a = this.subtract(logits, maxLogit.reshape(expandedShape)); @@ -807,31 +809,6 @@ export class MathBackendCPU extends KernelBackend { }); } - max(x: Tensor, axes: number[]): Tensor { - assertNotComplex(x, 'max'); - - backend_util.assertAxesAreInnerMostDims('max', axes, x.rank); - const [outShape, reduceShape] = - backend_util.computeOutAndReduceShapes(x.shape, axes); - const result = tf.zeros(outShape, x.dtype); - const reduceSize = util.sizeFromShape(reduceShape); - const vals = this.readSync(result.dataId) as TypedArray; - - const aVals = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < vals.length; ++i) { - const offset = i * reduceSize; - let max = aVals[offset]; - for (let j = 0; j < reduceSize; ++j) { - const value = aVals[offset + j]; - if (value > max) { - max = value; - } - } - vals[i] = max; - } - return result; - } - maximum(a: Tensor, b: Tensor): Tensor { assertNotComplex([a, b], 'maximum'); diff --git a/tfjs-backend-cpu/src/cpu_util.ts b/tfjs-backend-cpu/src/cpu_util.ts index 630e9ba127d..28ed23e0ff1 100644 --- a/tfjs-backend-cpu/src/cpu_util.ts +++ b/tfjs-backend-cpu/src/cpu_util.ts @@ -26,7 +26,8 @@ export function assertNotComplex( if (t != null) { util.assert( t.dtype !== 'complex64', - () => `${opName} does not support complex64 tensors.`); + () => `${ + opName} does not support complex64 tensors in the CPU backend.`); } }); } diff --git a/tfjs-backend-cpu/src/kernels/Max.ts b/tfjs-backend-cpu/src/kernels/Max.ts new file mode 100644 index 00000000000..bff04ee9115 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Max.ts @@ -0,0 +1,65 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; +import {backend_util, KernelConfig} from '@tensorflow/tfjs-core'; +import {TypedArray, util} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; +import {assertNotComplex} from '../cpu_util'; + +import {maxImpl} from './Max_impl'; +import {transposeImpl} from './Transpose_impl'; + +export const maxConfig: KernelConfig = { + kernelName: Max, + backendName: 'cpu', + kernelFunc: ({inputs, attrs, backend}) => { + const {x} = inputs as MaxInputs; + const {reductionIndices} = attrs as {} as MaxAttrs; + const cpuBackend = backend as MathBackendCPU; + let xShape = x.shape; + const xRank = xShape.length; + + const origAxes = util.parseAxisParam(reductionIndices, xShape); + let axes = origAxes; + const permutedAxes = backend_util.getAxesPermutation(axes, xRank); + let xVals = cpuBackend.data.get(x.dataId).values as TypedArray; + if (permutedAxes != null) { + const newShape: number[] = new Array(xRank); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = xShape[permutedAxes[i]]; + } + + xVals = transposeImpl(xVals, xShape, x.dtype, permutedAxes, newShape); + axes = backend_util.getInnerMostAxes(axes.length, xRank); + + xShape = newShape; + } + + assertNotComplex(x, 'max'); + backend_util.assertAxesAreInnerMostDims('max', axes, xRank); + const [maxOutShape, reduceShape] = + backend_util.computeOutAndReduceShapes(xShape, axes); + + const reduceSize = util.sizeFromShape(reduceShape); + + const result = maxImpl(xVals, reduceSize, maxOutShape, x.dtype); + const dataId = cpuBackend.write(result, maxOutShape, x.dtype); + return {dataId, shape: maxOutShape, dtype: x.dtype}; + } +}; diff --git a/tfjs-backend-cpu/src/kernels/Max_impl.ts b/tfjs-backend-cpu/src/kernels/Max_impl.ts new file mode 100644 index 00000000000..c326ac3f9df --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Max_impl.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {DataType, NumericDataType, TypedArray, util} from '@tensorflow/tfjs-core'; + +export function maxImpl( + aVals: TypedArray, reduceSize: number, outShape: number[], + dtype: DataType): TypedArray { + const vals = util.getTypedArrayFromDType( + dtype as NumericDataType, util.sizeFromShape(outShape)); + + for (let i = 0; i < vals.length; ++i) { + const offset = i * reduceSize; + let max = aVals[offset]; + for (let j = 0; j < reduceSize; ++j) { + const value = aVals[offset + j]; + if (value > max) { + max = value; + } + } + vals[i] = max; + } + return vals; +} diff --git a/tfjs-backend-cpu/src/kernels/Transpose_impl.ts b/tfjs-backend-cpu/src/kernels/Transpose_impl.ts index 3a6d290b777..ffa2afa888a 100644 --- a/tfjs-backend-cpu/src/kernels/Transpose_impl.ts +++ b/tfjs-backend-cpu/src/kernels/Transpose_impl.ts @@ -21,8 +21,8 @@ import {util} from '@tensorflow/tfjs-core'; export function transposeImpl( xVals: TypedArray, xShape: number[], dtype: DataType, perm: number[], newShape: number[]): TypedArray { - const xSize = util.sizeFromShape(xShape); const xRank = xShape.length; + const xSize = util.sizeFromShape(xShape); const xStrides = util.computeStrides(xShape); const newStrides = util.computeStrides(newShape); diff --git a/tfjs-backend-cpu/src/register_all_kernels.ts b/tfjs-backend-cpu/src/register_all_kernels.ts index 8618a514683..a1a6a66af82 100644 --- a/tfjs-backend-cpu/src/register_all_kernels.ts +++ b/tfjs-backend-cpu/src/register_all_kernels.ts @@ -20,6 +20,7 @@ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; import {divConfig} from './kernels/Div'; +import {maxConfig} from './kernels/Max'; import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; import {squareConfig} from './kernels/Square'; @@ -29,7 +30,7 @@ import {transposeConfig} from './kernels/Transpose'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig, - transposeConfig, maxPoolWithArgmaxConfig + transposeConfig, maxPoolWithArgmaxConfig, maxConfig ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-backend-cpu/src/shared.ts b/tfjs-backend-cpu/src/shared.ts index 07572130c1f..09e7d1d5abe 100644 --- a/tfjs-backend-cpu/src/shared.ts +++ b/tfjs-backend-cpu/src/shared.ts @@ -16,4 +16,5 @@ */ // Shared kernel impls for use in other backends. 
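// A minimal consumer-side sketch of the shared maxImpl re-exported below.
// It assumes the reduced axes are already innermost, which the Max kernels
// in this PR guarantee by transposing first when needed.
import {shared} from '@tensorflow/tfjs-backend-cpu';

const {maxImpl} = shared;

// Reduce a [2, 3] tensor over its last axis: reduceSize = 3, outShape = [2].
const xVals = new Float32Array([1, 5, 2, 7, 0, 3]);
const out = maxImpl(xVals, 3 /* reduceSize */, [2] /* outShape */, 'float32');
console.log(out); // Float32Array [5, 7]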
+export {maxImpl} from './kernels/Max_impl'; export {transposeImpl} from './kernels/Transpose_impl'; diff --git a/tfjs-backend-wasm/src/kernels/Max.ts b/tfjs-backend-wasm/src/kernels/Max.ts index ab688b76c80..d9a31a4a5c0 100644 --- a/tfjs-backend-wasm/src/kernels/Max.ts +++ b/tfjs-backend-wasm/src/kernels/Max.ts @@ -15,18 +15,11 @@ * ============================================================================= */ -import {backend_util, NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {backend_util, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; -interface MaxInputs extends NamedTensorInfoMap { - x: TensorInfo; -} - -interface MaxAttrs extends NamedAttrMap { - axes: number[]; -} - let wasmMax: (xId: number, reduceSize: number, outId: number) => void; function setup(backend: BackendWasm): void { @@ -34,16 +27,17 @@ function setup(backend: BackendWasm): void { backend.wasm.cwrap('Max', null /*void*/, ['number, number, number']); } -function max(args: {backend: BackendWasm, inputs: MaxInputs, attrs: MaxAttrs}): - TensorInfo { +function max(args: {backend: BackendWasm, inputs: {}, attrs: {}}): TensorInfo { const {backend, inputs, attrs} = args; - const {axes} = attrs; - const {x} = inputs; + const {reductionIndices} = attrs as MaxAttrs; + const {x} = inputs as MaxInputs; const xId = backend.dataIdMap.get(x.dataId).id; - backend_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); + const origAxes = util.parseAxisParam(reductionIndices, x.shape); + + backend_util.assertAxesAreInnerMostDims('max', origAxes, x.shape.length); const [outShape, reduceShape] = - backend_util.computeOutAndReduceShapes(x.shape, axes); + backend_util.computeOutAndReduceShapes(x.shape, origAxes); const reduceSize = util.sizeFromShape(reduceShape); const out = backend.makeOutput(outShape, x.dtype); @@ -54,12 +48,9 @@ function max(args: {backend: BackendWasm, inputs: MaxInputs, attrs: MaxAttrs}): const outId = backend.dataIdMap.get(out.dataId).id; wasmMax(xId, reduceSize, outId); + return out; } -registerKernel({ - kernelName: 'Max', - backendName: 'wasm', - setupFunc: setup, - kernelFunc: max -}); +registerKernel( + {kernelName: Max, backendName: 'wasm', setupFunc: setup, kernelFunc: max}); diff --git a/tfjs-backend-wasm/src/kernels/Reshape.ts b/tfjs-backend-wasm/src/kernels/Reshape.ts index f1009c6ab38..36c6fab9c54 100644 --- a/tfjs-backend-wasm/src/kernels/Reshape.ts +++ b/tfjs-backend-wasm/src/kernels/Reshape.ts @@ -28,7 +28,7 @@ interface ReshapeAttrs extends NamedAttrMap { shape: number[]; } -function reshape( +export function reshape( args: {inputs: ReshapeInputs, attrs: ReshapeAttrs, backend: BackendWasm}) { const {inputs: {x}, attrs: {shape}} = args; return {dataId: x.dataId, shape, dtype: x.dtype}; diff --git a/tfjs-backend-webgl/src/backend_webgl.ts b/tfjs-backend-webgl/src/backend_webgl.ts index 2d7fe2bc3c7..a35a304ca0c 100644 --- a/tfjs-backend-webgl/src/backend_webgl.ts +++ b/tfjs-backend-webgl/src/backend_webgl.ts @@ -19,7 +19,7 @@ import './flags_webgl'; import * as tf from '@tensorflow/tfjs-core'; -import {complex, DataId, div, engine, env, imag, MemoryInfo, range, real, RecursiveArray, scalar, softmax, tensor, tidy, TimingInfo, transpose} from '@tensorflow/tfjs-core'; +import {complex, DataId, div, engine, env, imag, max, MemoryInfo, range, real, RecursiveArray, scalar, softmax, tensor, tidy, TimingInfo, transpose} 
from '@tensorflow/tfjs-core'; import {backend_util, buffer, kernel_impls, slice_util, util} from '@tensorflow/tfjs-core'; import {DataStorage, DataType, KernelBackend, NumericDataType, Rank, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorInfo, TypedArray, upcastType} from '@tensorflow/tfjs-core'; @@ -1304,19 +1304,6 @@ export class MathBackendWebGL extends KernelBackend { return this.compileAndRun(program, [a, b]); } - max(x: Tensor, axes: number[]): Tensor { - if (this.shouldExecuteOnCPU([x])) { - return this.cpuBackend.max(x, axes); - } - - backend_util.assertAxesAreInnerMostDims('max', axes, x.rank); - const [outShape, reduceShape] = - backend_util.computeOutAndReduceShapes(x.shape, axes); - const inSize = util.sizeFromShape(reduceShape); - const a2D = x.as2D(-1, inSize); - return this.reduce(a2D, 'max', a2D.dtype).reshape(outShape); - } - maximum(a: Tensor, b: Tensor): Tensor { if (this.shouldExecuteOnCPU([a, b])) { return this.cpuBackend.maximum(a, b); @@ -1553,7 +1540,9 @@ export class MathBackendWebGL extends KernelBackend { softmax(logits: T, dim: number): T { const axes = util.parseAxisParam([dim], logits.shape); - const maxLogit = this.max(logits, axes); + // TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel + // modularization. + const maxLogit = max(logits, axes); const expandedShape = backend_util.expandShapeToKeepDim(maxLogit.shape, axes); const a = this.subtract(logits, maxLogit.reshape(expandedShape)); diff --git a/tfjs-backend-webgl/src/kernel_utils/reduce.ts b/tfjs-backend-webgl/src/kernel_utils/reduce.ts new file mode 100644 index 00000000000..edf6df8db73 --- /dev/null +++ b/tfjs-backend-webgl/src/kernel_utils/reduce.ts @@ -0,0 +1,39 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, DataType, TensorInfo} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {ReduceProgram} from '../reduce_gpu'; + +type ReduceTypes = 'all'|'any'|'max'|'min'|'sum'|'prod'; + +export function reduce( + x: TensorInfo, dtype: DataType, reductionType: ReduceTypes, + backend: MathBackendWebGL): TensorInfo { + const [batchSize, inSize] = x.shape; + const windowSize = backend_util.computeOptimalWindowSize(inSize); + const reduceInfo = {windowSize, inSize, batchSize}; + const program = new ReduceProgram(reduceInfo, reductionType); + const output = backend.runWebGLProgram(program, [x], dtype); + + if (output.shape[1] === 1) { + return output; + } + + return reduce(output, dtype, reductionType, backend); +} diff --git a/tfjs-backend-webgl/src/kernel_utils/reshape.ts b/tfjs-backend-webgl/src/kernel_utils/reshape.ts new file mode 100644 index 00000000000..44417d68063 --- /dev/null +++ b/tfjs-backend-webgl/src/kernel_utils/reshape.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {TensorInfo} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {ReshapePackedProgram} from '../reshape_packed_gpu'; +import {getBatchDim, getRowsCols, isReshapeFree} from '../webgl_util'; + +function packedReshape( + input: TensorInfo, afterShape: number[], + backend: MathBackendWebGL): TensorInfo { + const input3DShape = + [getBatchDim(input.shape), + ...getRowsCols(input.shape)] as [number, number, number]; + const input3D: TensorInfo = { + dtype: input.dtype, + shape: input3DShape, + dataId: input.dataId + }; + const afterShapeAs3D = + [getBatchDim(afterShape), + ...getRowsCols(afterShape)] as [number, number, number]; + + const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape); + const preventEagerUnpackingOfOutput = true; + const output = backend.runWebGLProgram( + program, [input3D], input.dtype, null /* customSetup */, + preventEagerUnpackingOfOutput); + return {dataId: output.dataId, shape: afterShape, dtype: output.dtype}; +} + +export function reshape( + x: TensorInfo, afterShape: number[], + backend: MathBackendWebGL): TensorInfo { + const xTexData = backend.texData.get(x.dataId); + if (xTexData.isPacked && !isReshapeFree(x.shape, afterShape) && + !(xTexData.texture !== null && + isReshapeFree(xTexData.shape, afterShape))) { + return packedReshape(x, afterShape, backend); + } + + return {dataId: x.dataId, shape: afterShape, dtype: x.dtype}; +} diff --git a/tfjs-backend-webgl/src/kernel_utils/shared.ts b/tfjs-backend-webgl/src/kernel_utils/shared.ts new file mode 100644 index 00000000000..772b439a432 --- /dev/null +++ b/tfjs-backend-webgl/src/kernel_utils/shared.ts @@ -0,0 +1,22 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {shared} from '@tensorflow/tfjs-backend-cpu'; + +const {maxImpl: maxImplCPU, transposeImpl: transposeImplCPU} = shared; + +export {maxImplCPU, transposeImplCPU}; diff --git a/tfjs-backend-webgl/src/kernels/Max.ts b/tfjs-backend-webgl/src/kernels/Max.ts new file mode 100644 index 00000000000..54b193c2ebc --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/Max.ts @@ -0,0 +1,91 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Max, MaxAttrs, MaxInputs} from '@tensorflow/tfjs-core'; +import {backend_util, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {maxImplCPU} from '../kernel_utils/shared'; + +import {maxImpl} from './Max_impl'; +import {transposeImpl, transposeImplCPU} from './Transpose_impl'; + +export const maxConfig: KernelConfig = { + kernelName: Max, + backendName: 'webgl', + kernelFunc: ({inputs, attrs, backend}) => { + const {x} = inputs as MaxInputs; + const {reductionIndices} = attrs as {} as MaxAttrs; + const webglBackend = backend as MathBackendWebGL; + + const xRank = x.shape.length; + + const origAxes = util.parseAxisParam(reductionIndices, x.shape); + let axes = origAxes; + const permutedAxes = backend_util.getAxesPermutation(axes, xRank); + const maxInputIsTransposed = permutedAxes != null; + const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]); + + let maxInput = x; + if (maxInputIsTransposed) { + if (shouldExecuteOnCPU) { + const xTexData = webglBackend.texData.get(maxInput.dataId); + const values = xTexData.values as TypedArray; + + const newShape: number[] = new Array(xRank); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = x.shape[permutedAxes[i]]; + } + const maxInputValues = + transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape); + + maxInput = webglBackend.makeTensorInfo(newShape, x.dtype); + const maxInputData = webglBackend.texData.get(maxInput.dataId); + maxInputData.values = maxInputValues; + } else { + maxInput = transposeImpl(x, permutedAxes, webglBackend); + } + + axes = backend_util.getInnerMostAxes(axes.length, xRank); + } + + backend_util.assertAxesAreInnerMostDims('max', axes, xRank); + const [maxOutShape, reduceShape] = + backend_util.computeOutAndReduceShapes(maxInput.shape, axes); + + let out; + if (shouldExecuteOnCPU) { + const xTexData = webglBackend.texData.get(maxInput.dataId); + const values = xTexData.values as TypedArray; + + const outValues = maxImplCPU( + values, util.sizeFromShape(reduceShape), maxOutShape, x.dtype); + + out = webglBackend.makeTensorInfo(maxOutShape, x.dtype); + const outData = webglBackend.texData.get(out.dataId); + outData.values = outValues; + } else { + out = maxImpl(maxInput, reduceShape, maxOutShape, webglBackend); + } + + if (maxInputIsTransposed) { + webglBackend.disposeData(maxInput.dataId); + } + + return out; + } +}; diff --git a/tfjs-backend-webgl/src/kernels/Max_impl.ts b/tfjs-backend-webgl/src/kernels/Max_impl.ts new file mode 100644 index 00000000000..04a01f1bd75 --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/Max_impl.ts @@ -0,0 +1,39 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {reduce} from '../kernel_utils/reduce'; +import {reshape} from '../kernel_utils/reshape'; + +export function maxImpl( + x: TensorInfo, reduceShape: number[], outShape: number[], + backend: MathBackendWebGL): TensorInfo { + const inSize = util.sizeFromShape(reduceShape); + const xSize = util.sizeFromShape(x.shape); + const batchSize = xSize / inSize; + const reshapedInput = reshape(x, [batchSize, inSize], backend); + const reduced = reduce(reshapedInput, x.dtype, 'max', backend); + + if (reshapedInput.dataId !== x.dataId) { + // dispose the output of the packed reshape. + backend.disposeData(reshapedInput.dataId); + } + + return reshape(reduced, outShape, backend); +} diff --git a/tfjs-backend-webgl/src/kernels/Transpose_impl.ts b/tfjs-backend-webgl/src/kernels/Transpose_impl.ts index f7ee6eba1ac..1b0680bc7e2 100644 --- a/tfjs-backend-webgl/src/kernels/Transpose_impl.ts +++ b/tfjs-backend-webgl/src/kernels/Transpose_impl.ts @@ -15,10 +15,10 @@ * ============================================================================= */ -import {shared} from '@tensorflow/tfjs-backend-cpu'; import {env, TensorInfo} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; +import {transposeImplCPU} from '../kernel_utils/shared'; import {TransposeProgram} from '../transpose_gpu'; import {TransposePackedProgram} from '../transpose_packed_gpu'; @@ -30,4 +30,4 @@ export function transposeImpl( return backend.runWebGLProgram(program, [x], x.dtype); } -export const transposeImplCPU = shared.transposeImpl; +export {transposeImplCPU}; diff --git a/tfjs-backend-webgl/src/register_all_kernels.ts b/tfjs-backend-webgl/src/register_all_kernels.ts index e7dfa515cea..d8ad7679d1c 100644 --- a/tfjs-backend-webgl/src/register_all_kernels.ts +++ b/tfjs-backend-webgl/src/register_all_kernels.ts @@ -18,6 +18,7 @@ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; import {divConfig} from './kernels/Div'; import {fromPixelsConfig} from './kernels/FromPixels'; +import {maxConfig} from './kernels/Max'; import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; import {squareConfig} from './kernels/Square'; @@ -26,8 +27,9 @@ import {transposeConfig} from './kernels/Transpose'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ - fromPixelsConfig, divConfig, nonMaxSuppressionV5Config, squareConfig, - squaredDifferenceConfig, transposeConfig, maxPoolWithArgmaxConfig + maxConfig, fromPixelsConfig, divConfig, nonMaxSuppressionV5Config, + squareConfig, squaredDifferenceConfig, transposeConfig, + maxPoolWithArgmaxConfig ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index e09724b1501..aa3d954e939 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ 
b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -797,8 +797,8 @@ export class WebGPUBackend extends KernelBackend { const dimensions = [ convInfo.filterHeight, convInfo.filterWidth, ...pad, - convInfo.strideHeight, convInfo.strideWidth, - convInfo.dilationHeight, convInfo.dilationWidth + convInfo.strideHeight, convInfo.strideWidth, convInfo.dilationHeight, + convInfo.dilationWidth ]; const inputs: Tensor[] = [input, filter]; diff --git a/tfjs-core/src/gradients/Max_grad.ts b/tfjs-core/src/gradients/Max_grad.ts new file mode 100644 index 00000000000..0a126ea223c --- /dev/null +++ b/tfjs-core/src/gradients/Max_grad.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Max, MaxAttrs} from '../kernel_names'; +import {GradConfig, NamedAttrMap} from '../kernel_registry'; +import * as axis_util from '../ops/axis_util'; +import {gradForMinAndMax} from '../ops/reduction_ops_util'; +import {transpose} from '../ops/transpose'; +import {Tensor} from '../tensor'; +import * as util from '../util'; + +export const maxGradConfig: GradConfig = { + kernelName: Max, + inputsToSave: ['x'], + outputsToSave: [true], + gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { + const maxAttrs: MaxAttrs = attrs as {} as MaxAttrs; + const {reductionIndices} = maxAttrs; + const [x, y] = saved; + const origAxes = util.parseAxisParam(reductionIndices, x.shape); + const permutedAxes = axis_util.getAxesPermutation(origAxes, x.rank); + const maxGrad = gradForMinAndMax(dy, y, x, origAxes, permutedAxes); + return { + x: () => { + let out = maxGrad['x'](); + if (permutedAxes != null) { + out = transpose(out); + } + return out; + } + }; + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index cc800ee76b7..aaafc8813b6 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -164,6 +164,13 @@ export interface NonMaxSuppressionV5Attrs { softNmsSigma: number; } +export const Max = 'Max'; +export type MaxInputs = Pick; +export interface MaxAttrs { + reductionIndices: number|number[]; + keepDims: boolean; +} + export const OneHot = 'OneHot'; export type OneHotInputs = Pick; export interface OneHotAttrs { diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts new file mode 100644 index 00000000000..8a26f957ef0 --- /dev/null +++ b/tfjs-core/src/ops/max.ts @@ -0,0 +1,98 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelBackend} from '../backends/backend'; +import {ENGINE, ForwardFunc} from '../engine'; +import {Max, MaxAttrs, MaxInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor} from '../tensor'; +import {GradSaveFunc, NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; +import {reshape} from './array_ops'; +import * as axis_util from './axis_util'; +import {op} from './operation'; +import {transpose} from './transpose'; + +/** + * Computes the maximum of elements across dimensions of a `tf.Tensor`. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in + * `axes`. If `keepDims` is true, the reduced dimensions are retained with + * length 1. If `axes` has no entries, all dimensions are reduced, and an + * `tf.Tensor` with a single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.max().print(); // or tf.max(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.max(axis).print(); // or tf.max(x, axis) + * ``` + * + * @param x The input tensor. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. 
+ */ +/** @doc {heading: 'Operations', subheading: 'Reduction'} */ +function max_( + x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T { + const $x = convertToTensor(x, 'x', 'max'); + const forward: ForwardFunc = + (backend: KernelBackend, save: GradSaveFunc) => { + const origAxes = util.parseAxisParam(axis, $x.shape); + let axes = origAxes; + const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); + let maxInput = $x; + if (permutedAxes != null) { + maxInput = transpose($x, permutedAxes); + axes = axis_util.getInnerMostAxes(axes.length, maxInput.rank); + } + + const y = backend.max(maxInput, axes); + save([$x, y]); + + if (permutedAxes != null) { + backend.disposeData(maxInput.dataId); + } + + return y; + }; + const inputs: MaxInputs = {x: $x}; + const attrs: MaxAttrs = {reductionIndices: axis, keepDims}; + + const res = ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* gradient */, + Max, attrs as {} as NamedAttrMap) as T; + if (keepDims) { + return reshape( + res, + axis_util.expandShapeToKeepDim( + res.shape, util.parseAxisParam(axis, $x.shape))) as T; + } + return res; +} + +export const max = op({max_}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 878623699d5..f280ca94677 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -43,6 +43,7 @@ export {greater} from './greater'; export {greaterEqual} from './greater_equal'; export {less} from './less'; export {lessEqual} from './less_equal'; +export {max} from './max'; export {multinomial} from './multinomial'; export {notEqual} from './not_equal'; export {oneHot} from './one_hot'; diff --git a/tfjs-core/src/ops/reduction_ops.ts b/tfjs-core/src/ops/reduction_ops.ts index 37d8f0dfa8d..3f871d8a111 100644 --- a/tfjs-core/src/ops/reduction_ops.ts +++ b/tfjs-core/src/ops/reduction_ops.ts @@ -21,8 +21,10 @@ import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; + import * as axis_util from './axis_util'; import {op} from './operation'; +import {gradForMinAndMax} from './reduction_ops_util'; import {ones, scalar, zerosLike} from './tensor_ops'; /** @@ -267,25 +269,6 @@ function mean_( return customOp($x) as T; } -/** - * Gradient helper function for the min and max operations. - */ -function gradForMinAndMax( - dy: T, y: T, xOrig: Tensor, origAxes: number[], permutedAxes: number[]) { - if (y.rank < xOrig.rank) { - y = y.reshape(axis_util.expandShapeToKeepDim(y.shape, origAxes)) as T; - } - if (dy.rank < xOrig.rank) { - dy = dy.reshape(axis_util.expandShapeToKeepDim(dy.shape, origAxes)) as T; - } - return { - x: () => { - const dx = dy.mul(xOrig.equal(y).cast(dy.dtype)); - return permutedAxes == null ? dx : dx.transpose(permutedAxes); - } - }; -} - /** * Computes the minimum value from the input. * @@ -344,64 +327,6 @@ function min_( return res; } -/** - * Computes the maximum of elements across dimensions of a `tf.Tensor`. - * - * Reduces the input along the dimensions given in `axes`. Unless `keepDims` - * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in - * `axes`. If `keepDims` is true, the reduced dimensions are retained with - * length 1. If `axes` has no entries, all dimensions are reduced, and an - * `tf.Tensor` with a single element is returned. 
- * - * ```js - * const x = tf.tensor1d([1, 2, 3]); - * - * x.max().print(); // or tf.max(x) - * ``` - * - * ```js - * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); - * - * const axis = 1; - * x.max(axis).print(); // or tf.max(x, axis) - * ``` - * - * @param x The input tensor. - * @param axis The dimension(s) to reduce. By default it reduces - * all dimensions. - * @param keepDims If true, retains reduced dimensions with size 1. - */ -/** @doc {heading: 'Operations', subheading: 'Reduction'} */ -function max_( - x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T { - let $x = convertToTensor(x, 'x', 'max'); - const xOrig = $x; - - const origAxes = util.parseAxisParam(axis, $x.shape); - let axes = origAxes; - const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank); - if (permutedAxes != null) { - $x = $x.transpose(permutedAxes); - axes = axis_util.getInnerMostAxes(axes.length, $x.rank); - } - - const grad = (dy: T, saved: Tensor[]) => - gradForMinAndMax(dy, saved[1], saved[0], origAxes, permutedAxes); - - const inputsToSave = [$x]; - const outputsToSave: boolean[] = [true]; - let res = ENGINE.runKernelFunc((backend, save) => { - const y = backend.max($x, axes); - save([xOrig, y]); - return y; - }, {x: $x}, grad, 'Max', {axes}, inputsToSave, outputsToSave); - if (keepDims) { - const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); - res = res.reshape(newShape) as T; - } - return res as T; -} - /** * Returns the indices of the minimum values along an `axis`. * @@ -625,7 +550,6 @@ export const any = op({any_}); export const argMax = op({argMax_}); export const argMin = op({argMin_}); export const logSumExp = op({logSumExp_}); -export const max = op({max_}); export const mean = op({mean_}); export const min = op({min_}); export const moments = op({moments_}); diff --git a/tfjs-core/src/ops/reduction_ops_util.ts b/tfjs-core/src/ops/reduction_ops_util.ts new file mode 100644 index 00000000000..375ae223f89 --- /dev/null +++ b/tfjs-core/src/ops/reduction_ops_util.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tensor} from '../tensor'; +import * as axis_util from './axis_util'; + +/** + * Gradient helper function for the min and max operations. + */ +export function gradForMinAndMax( + dy: T, y: T, xOrig: Tensor, origAxes: number[], permutedAxes: number[]) { + if (y.rank < xOrig.rank) { + y = y.reshape(axis_util.expandShapeToKeepDim(y.shape, origAxes)) as T; + } + if (dy.rank < xOrig.rank) { + dy = dy.reshape(axis_util.expandShapeToKeepDim(dy.shape, origAxes)) as T; + } + return { + x: () => { + const dx = dy.mul(xOrig.equal(y).cast(dy.dtype)); + return permutedAxes == null ? 
dx : dx.transpose(permutedAxes); + } + }; +} diff --git a/tfjs-core/src/public/chained_ops/max.ts b/tfjs-core/src/public/chained_ops/max.ts new file mode 100644 index 00000000000..4b0cf030413 --- /dev/null +++ b/tfjs-core/src/public/chained_ops/max.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {max} from '../../ops/max'; +import {Tensor} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + max(axis?: number|number[], keepDims?: boolean): T; + } +} + +Tensor.prototype.max = function( + axis?: number|number[], keepDims?: boolean): T { + this.throwIfDisposed(); + return max(this, axis, keepDims); +}; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index f4e56813132..dca44625289 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -17,6 +17,7 @@ import './add'; import './batchnorm'; import './broadcast_to'; +import './max'; import './concat'; import './conv1d'; import './conv2d'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index 0cff805baec..7188841fea0 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -20,7 +20,7 @@ import {ALL_ENVS, describeWithFlags} from '../../jasmine_util'; // Testing for presence of chained op in this file will allow us to more easily // customize when we want this test to run. Currently it will run be default -// (And kerma will always load the chain augmentor files). But this gives us +// (And karma will always load the chain augmentor files). But this gives us // flexibility to change in future. 
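// A small usage sketch of the chained max op wired up above, assuming the
// standard tfjs-core entry point has pulled in the chained-op registrations:
import * as tf from '@tensorflow/tfjs-core';

const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
x.max(1).print();       // [2, 4]
x.max(1, true).print(); // [[2], [4]]; keepDims retains the reduced axis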
const CHAINED_OPS = [ @@ -43,6 +43,7 @@ const CHAINED_OPS = [ 'notEqual', 'oneHot', 'pad', + 'max', 'separableConv2d', 'split', 'square', diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index 0bbe37a8673..a425f50f391 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -26,6 +26,7 @@ import {divGradConfig} from './gradients/Div_grad'; import {fusedBatchNormGradConfig} from './gradients/FusedBatchNorm_grad'; import {greaterEqualGradConfig} from './gradients/GreaterEqual_grad'; import {identityGradConfig} from './gradients/Identity_grad'; +import {maxGradConfig} from './gradients/Max_grad'; import {oneHotGradConfig} from './gradients/OneHot_grad'; import {padV2GradConfig} from './gradients/PadV2_grad'; import {splitVGradConfig} from './gradients/SplitV_grad'; @@ -39,25 +40,16 @@ import {registerGradient} from './kernel_registry'; // Export all kernel configs here so that the package can auto register them const gradConfigs: GradConfig[] = [ - addGradConfig, - addNGradConfig, - broadcastToGradConfig, - concatGradConfig, - conv2DGradConfig, - conv2DBackpropInputGradConfig, - conv3DGradConfig, - depthwiseConv2dNativeGradConfig, - divGradConfig, - fusedBatchNormGradConfig, - greaterEqualGradConfig, - identityGradConfig, - oneHotGradConfig, - padV2GradConfig, - splitVGradConfig, - squareGradConfig, - squaredDifferenceGradConfig, - tileGradConfig, - transposeGradConfig, + addGradConfig, addNGradConfig, + broadcastToGradConfig, concatGradConfig, + conv2DGradConfig, conv2DBackpropInputGradConfig, + conv3DGradConfig, depthwiseConv2dNativeGradConfig, + divGradConfig, fusedBatchNormGradConfig, + greaterEqualGradConfig, identityGradConfig, + oneHotGradConfig, padV2GradConfig, + splitVGradConfig, maxGradConfig, + squareGradConfig, squaredDifferenceGradConfig, + tileGradConfig, transposeGradConfig, subGradConfig ]; diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index ebabaf42e63..095df29d2ab 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -202,7 +202,6 @@ export interface OpHandler { mean(x: Tensor, axis: number|number[], keepDims: boolean): T; min(x: Tensor, axis: number|number[], keepDims: boolean): T; - max(x: Tensor, axis: number|number[], keepDims: boolean): T; argMin(x: Tensor, axis: number): T; argMax(x: Tensor, axis: number): T; addStrict(a: T, b: T|TensorLike): T; @@ -810,10 +809,6 @@ export class Tensor { this.throwIfDisposed(); return opHandler.min(this, axis, keepDims); } - max(axis: number|number[] = null, keepDims = false): T { - this.throwIfDisposed(); - return opHandler.max(this, axis, keepDims); - } argMin(axis: number = null): T { this.throwIfDisposed(); return opHandler.argMin(this, axis); diff --git a/tfjs-data/yarn.lock b/tfjs-data/yarn.lock index 28988ae2a4f..9c05c22f06a 100644 --- a/tfjs-data/yarn.lock +++ b/tfjs-data/yarn.lock @@ -3262,6 +3262,11 @@ minimist@^1.1.0, minimist@^1.1.3, minimist@^1.2.0: resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.0.tgz#a35008b20f41383eec1fb914f4cd5df79a264284" integrity sha1-o1AIsg9BOD7sH7kU9M1d95omQoQ= +minimist@^1.2.5: + version "1.2.5" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.5.tgz#67d66014b66a6a8aaa0c083c5fd58df4e4e97602" + integrity sha512-FM9nNUYrRBAELZQT3xeZQ7fmMOBg6nWNmJKTcgsJeaLstP/UODVpGsr5OhXhhXg6f+qtJ8uiZ+PUxkDWcgIXLw== + minimist@~0.0.1: version "0.0.10" resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.10.tgz#de3f98543dbf96082be48ad1a0c7cda836301dcf" 
diff --git a/tfjs-layers/yarn.lock b/tfjs-layers/yarn.lock index 7cb0a9202cf..17fe70a34f2 100644 --- a/tfjs-layers/yarn.lock +++ b/tfjs-layers/yarn.lock @@ -3037,6 +3037,11 @@ minimist@^1.1.0, minimist@^1.1.3: resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.0.tgz#a35008b20f41383eec1fb914f4cd5df79a264284" integrity sha1-o1AIsg9BOD7sH7kU9M1d95omQoQ= +minimist@^1.2.5: + version "1.2.5" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.5.tgz#67d66014b66a6a8aaa0c083c5fd58df4e4e97602" + integrity sha512-FM9nNUYrRBAELZQT3xeZQ7fmMOBg6nWNmJKTcgsJeaLstP/UODVpGsr5OhXhhXg6f+qtJ8uiZ+PUxkDWcgIXLw== + minimist@~0.0.1: version "0.0.10" resolved "https://registry.yarnpkg.com/minimist/-/minimist-0.0.10.tgz#de3f98543dbf96082be48ad1a0c7cda836301dcf"
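For reference, a minimal end-to-end sketch of the modularized op and its gradient, assuming only the public entry points this diff touches (tf.max plus the registered Max kernel and gradient configs) and the usual side-effect registration of the CPU backend package:

import * as tf from '@tensorflow/tfjs-core';
import '@tensorflow/tfjs-backend-cpu';

// Forward pass routes through the registered 'Max' kernel config.
const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
tf.max(x, 1 /* axis */).print(); // [2, 4]

// Backward pass routes through maxGradConfig: per gradForMinAndMax, the
// gradient is dy where x equals the reduced max and 0 elsewhere.
const dx = tf.grad((t: tf.Tensor) => tf.max(t))(x);
dx.print(); // [[0, 0], [0, 1]]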