From 5bd75468225e91b687c945a10a43667b9a490ceb Mon Sep 17 00:00:00 2001 From: NALLEIN Date: Tue, 24 Dec 2019 15:33:56 +0800 Subject: [PATCH 1/7] Add reduce and made some changes as requested --- tfjs-backend-webgpu/src/backend_webgpu.ts | 52 +++++++- .../src/kernels/reduce_webgpu.ts | 118 ++++++++++++++++++ .../src/kernels/unary_op_webgpu.ts | 1 + tfjs-backend-webgpu/src/setup_test.ts | 39 ++++++ 4 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 tfjs-backend-webgpu/src/kernels/reduce_webgpu.ts diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index eff9036ea1f..8b124a67813 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -20,11 +20,16 @@ import './flags_webgpu'; import {backend_util, DataStorage, DataType, engine, env, findBackend, KernelBackend, Rank, RecursiveArray, ShapeMap, slice_util, Tensor, Tensor2D, Tensor3D, Tensor4D, TimingInfo, util} from '@tensorflow/tfjs-core'; -import {Glslang} from '@webgpu/glslang/dist/web-devel/glslang.onefile'; // TODO(xing.xu): use FusedConv2DConfig from backend_util: // https://github.com/tensorflow/tfjs/issues/2471 // tslint:disable-next-line: no-imports-from-dist import {FusedConv2DConfig} from '@tensorflow/tfjs-core/dist/ops/fused_util'; +// TODO TODO to import computeOptimalWindowSize and other reduce_util function +// from backend_util, rather than from 'dist' +import {assertAxesAreInnerMostDims, computeOutAndReduceShapes} from '@tensorflow/tfjs-core/src/ops/axis_util'; +import {computeOptimalWindowSize} from '@tensorflow/tfjs-core/src/ops/reduce_util'; +import {sumOutType} from '@tensorflow/tfjs-core/src/types' +import {Glslang} from '@webgpu/glslang/dist/web-devel/glslang.onefile'; import {BufferManager} from './buffer_manager'; import {ArgMinMaxProgram} from './kernels/argminmax_webgpu'; @@ -41,6 +46,7 @@ import {MatMulPackedProgram} from './kernels/matmul_packed_webgpu'; import {MatMulProgram} from './kernels/matmul_webgpu'; import {MaxPoolProgram} from './kernels/maxpool_webgpu'; import {PadProgram} from './kernels/pad_webgpu'; +import {ReduceProgram} from './kernels/reduce_webgpu'; import {ResizeBilinearProgram} from './kernels/resize_bilinear_webgpu'; import {SelectProgram} from './kernels/select_webgpu'; import {SliceProgram} from './kernels/slice_webgpu'; @@ -825,6 +831,42 @@ export class WebGPUBackend extends KernelBackend { return this.argMinMaxReduce(x, axis, 'max'); } + private reduce(x: Tensor2D, reduceType: 'max'|'min'|'sum', dtype: DataType): + Tensor2D { + const batchSize = x.shape[0]; + const inSize = x.shape[1]; + const windowSize = computeOptimalWindowSize(inSize); + const reduceInfo = {windowSize, inSize, batchSize}; + const program = new ReduceProgram(reduceInfo, reduceType); + const output = this.makeOutputArray(program.outputShape, dtype); + return this.compileAndRun(program, [x], output); + } + + max(x: Tensor, axes: number[]): Tensor { + assertAxesAreInnerMostDims('max', axes, x.rank); + const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes); + const reduceSize = util.sizeFromShape(reduceShape); + const a2D = x.as2D(-1, reduceSize); + return this.reduce(a2D, 'max', a2D.dtype).reshape(outShape); + } + + min(x: Tensor, axes: number[]): Tensor { + assertAxesAreInnerMostDims('min', axes, x.rank); + const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes); + const reduceSize = util.sizeFromShape(reduceShape); + const a2D = x.as2D(-1, reduceSize); + return this.reduce(a2D, 'min', a2D.dtype).reshape(outShape); + } + + sum(x: Tensor, axes: number[]): Tensor { + assertAxesAreInnerMostDims('sum', axes, x.rank); + const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes); + const reduceSize = util.sizeFromShape(reduceShape); + const a2D = x.as2D(-1, reduceSize); + const outputDType = sumOutType(x.dtype); + return this.reduce(a2D, 'sum', outputDType).reshape(outShape); + } + clip(x: T, min: number, max: number): T { const program = new ClipProgram(x.shape, min, max); return this.compileAndRun(program, [x]); @@ -916,6 +958,14 @@ export class WebGPUBackend extends KernelBackend { return this.compileAndRun(program, [x]); } + abs(x: T): T { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.abs(x); + } + const program = new UnaryOpProgram(x.shape, unary_op.ABS); + return this.compileAndRun(program, [x]); + } + prelu(x: T, alpha: T): T { const program = new BinaryOpProgram(binary_op.PRELU, x.shape, alpha.shape); return this.compileAndRun(program, [x, alpha]); diff --git a/tfjs-backend-webgpu/src/kernels/reduce_webgpu.ts b/tfjs-backend-webgpu/src/kernels/reduce_webgpu.ts new file mode 100644 index 00000000000..99fcfb05033 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/reduce_webgpu.ts @@ -0,0 +1,118 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, util} from '@tensorflow/tfjs-core'; +import {ReduceInfo} from '@tensorflow/tfjs-core/src/ops/reduce_util'; +import {getCoordsDataType} from '../shader_preprocessor'; +import {computeDispatch} from '../webgpu_util'; +import {WebGPUProgram} from './webgpu_program'; + +export class ReduceProgram implements WebGPUProgram { + outputShape: number[]; + shaderKey: string; + userCode: string; + dispatchLayout: {x: number[], y: number[]}; + dispatch: [number, number, number]; + workGroupSize: [number, number, number]; + variableNames = ['x']; + + constructor(reduceInfo: ReduceInfo, reduceType: 'max'|'min'|'sum') { + const inputShape = [reduceInfo.batchSize, reduceInfo.inSize]; + const [outputShape, reduceShape] = + backend_util.computeOutAndReduceShapes(inputShape, [1]); + this.outputShape = outputShape.length === 0 ? [1] : outputShape; + const reduceSize = util.sizeFromShape(reduceShape); + + const reductionFactor = 2; + const xMaxThreads = 1024; + const xThreads = + Math.min(Math.ceil(reduceSize / reductionFactor), xMaxThreads); + + this.workGroupSize = [xThreads, 1, 1]; + this.dispatchLayout = {x: [], y: this.outputShape.map((d, i) => i)}; + this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape); + const reduceInSharedMemory = xThreads > 1; + + const minmaxOp = ` + if (candidate ${reduceType === 'min' ? '<' : '>'} bestValue + && !isnan(candidate)) + { bestValue = candidate; } + `; + const sumOp = ' bestValue += candidate; '; + const op = + (reduceType === 'min' || reduceType === 'max') ? minmaxOp : sumOp; + + const sharedMemorySnippet = ` + shared float xBestValues[WorkGroupSize]; + `; + const sharedMemoryReduceSnippet = ` + xBestValues[gl_LocalInvocationID.x] = bestValue; + ${reduceType === 'sum' ? 'bestValue=0;' : ' '} + int currentSize = WorkGroupSize; + while (currentSize > 1) { + barrier(); + for (int w = 0; w < ${reductionFactor}; ++w) { + int i = int(gl_LocalInvocationID.x) * ${reductionFactor} + w; + if (i < currentSize) { + float candidate = xBestValues[i]; + ${op} + } + } + xBestValues[gl_LocalInvocationID.x] = bestValue; + currentSize = DIV_CEIL(currentSize, ${reductionFactor}); + ${reduceType === 'sum' ? 'if(currentSize > 1) bestValue=0;' : ''} + } + if (gl_LocalInvocationID.x == 0) { + setOutput(flatOutputIndex, bestValue); + } + `; + + const outputCoordsType = getCoordsDataType(this.outputShape.length); + + this.userCode = ` + #define DIV_CEIL(x, y) (((x) - 1) / (y) + 1) + const int WorkGroupSize = int(gl_WorkGroupSize.x); + ${reduceInSharedMemory ? sharedMemorySnippet : ''} + int getOffset() { + const ${outputCoordsType} outputCoords = getOutputCoords(); + int offset = ${ + this.outputShape.length === 1 ? 'outputCoords' : + 'outputCoords[0]'} * xShape[1]; + return offset; + } + void main() { + const int offset= getOffset(); + ${ + reduceType === 'sum' ? 'float bestValue = 0;' : + 'float bestValue = x[offset];'} + const int Length = ${inputShape.length === 1 ? 'xShape' : 'xShape[1]'}; + const int WorkPerThread = DIV_CEIL(Length, WorkGroupSize); + for (int w = 0; w < WorkPerThread; ++w) { + int i = int(gl_GlobalInvocationID.x) * WorkPerThread + w; + if (i < Length) { + float candidate = x[offset + i]; + ${(reduceType === 'max' || reduceType === 'min') ? minmaxOp : sumOp} + } + } + const int flatOutputIndex = int(gl_GlobalInvocationID.y); + ${ + reduceInSharedMemory ? sharedMemoryReduceSnippet : + 'setOutput(flatOutputIndex, bestValue);'} + } + `; + } +} diff --git a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts index 90c0fe3d3a4..a3c2cd7f6d3 100644 --- a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts @@ -27,6 +27,7 @@ export const LINEAR = `return x;`; export const ELU = `return (x >= 0.0) ? x : (exp(x) - 1.0);`; export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`; +export const ABS = `return abs(a);`; export class UnaryOpProgram implements WebGPUProgram { outputShape: number[]; diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index e47fd960f7f..556d74bfe9a 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -283,6 +283,45 @@ const TEST_FILTERS: TestFilter[] = [ 'grad', // 'depthwiseConv2DDerFilter' not yet implemented, slice not yet // implemented ] + }, + { + include: 'Reduction: max', + excludes: [ + '5D', // Rank 5 is not yet implemented. + '6D', // Rank 5 is not yet implemented. + 'accepts tensor with bool', // Actual != Expected. + 'gradient', // zerosLike not yet implemented. + ] + }, + { + include: 'Reduction: min', + excludes: [ + '5D', // Rank 5 is not yet implemented. + '6D', // Rank 5 is not yet implemented. + 'accepts tensor with bool', // Actual != Expected. + 'gradient', // zerosLike not yet implemented. + ] + }, + { + include: 'Reduction: sum', + excludes: [ + 'dtype bool', // not support dtype bool yet. + '5D', // Rank 5 is not yet implemented. + '6D', // Rank 5 is not yet implemented. + 'accepts tensor with bool', // Actual != Expected. + 'gradient', // zerosLike not yet implemented. + ] + }, + { + include: 'abs', + excludes: [ + 'complex', // No complex support yet. + '5D', // Rank 5 is not yet implemented. + '6D', // Rank 5 is not yet implemented. + 'accepts tensor with bool', // Actual != Expected. + 'gradient', // zerosLike not yet implemented. + 'absoluteDifference', // absoluteDifference not yet implemented + ] } ]; From d2b1191eccae21a5fc6a21699c3659e485820dee Mon Sep 17 00:00:00 2001 From: NALLEIN Date: Tue, 24 Dec 2019 15:58:40 +0800 Subject: [PATCH 2/7] Made some changes as requested --- tfjs-backend-webgpu/src/backend_webgpu.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index 8b124a67813..34f237ab840 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -28,7 +28,7 @@ import {FusedConv2DConfig} from '@tensorflow/tfjs-core/dist/ops/fused_util'; // from backend_util, rather than from 'dist' import {assertAxesAreInnerMostDims, computeOutAndReduceShapes} from '@tensorflow/tfjs-core/src/ops/axis_util'; import {computeOptimalWindowSize} from '@tensorflow/tfjs-core/src/ops/reduce_util'; -import {sumOutType} from '@tensorflow/tfjs-core/src/types' +import {sumOutType} from '@tensorflow/tfjs-core/src/types'; import {Glslang} from '@webgpu/glslang/dist/web-devel/glslang.onefile'; import {BufferManager} from './buffer_manager'; From a6271070d3f12f416a7077a2122bcc90a0bb2d27 Mon Sep 17 00:00:00 2001 From: NALLEIN Date: Fri, 27 Dec 2019 10:10:37 +0800 Subject: [PATCH 3/7] Made some changes to minimize imports from non-public API --- tfjs-backend-webgpu/src/backend_webgpu.ts | 18 +++++++++--------- .../src/kernels/reduce_webgpu.ts | 1 + tfjs-core/src/backends/backend_util.ts | 3 ++- tfjs-core/src/index.ts | 3 ++- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index 34f237ab840..6d59f7531ad 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -24,9 +24,6 @@ import {backend_util, DataStorage, DataType, engine, env, findBackend, KernelBac // https://github.com/tensorflow/tfjs/issues/2471 // tslint:disable-next-line: no-imports-from-dist import {FusedConv2DConfig} from '@tensorflow/tfjs-core/dist/ops/fused_util'; -// TODO TODO to import computeOptimalWindowSize and other reduce_util function -// from backend_util, rather than from 'dist' -import {assertAxesAreInnerMostDims, computeOutAndReduceShapes} from '@tensorflow/tfjs-core/src/ops/axis_util'; import {computeOptimalWindowSize} from '@tensorflow/tfjs-core/src/ops/reduce_util'; import {sumOutType} from '@tensorflow/tfjs-core/src/types'; import {Glslang} from '@webgpu/glslang/dist/web-devel/glslang.onefile'; @@ -843,24 +840,27 @@ export class WebGPUBackend extends KernelBackend { } max(x: Tensor, axes: number[]): Tensor { - assertAxesAreInnerMostDims('max', axes, x.rank); - const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes); + backend_util.assertAxesAreInnerMostDims('max', axes, x.rank); + const [outShape, reduceShape] = + backend_util.computeOutAndReduceShapes(x.shape, axes); const reduceSize = util.sizeFromShape(reduceShape); const a2D = x.as2D(-1, reduceSize); return this.reduce(a2D, 'max', a2D.dtype).reshape(outShape); } min(x: Tensor, axes: number[]): Tensor { - assertAxesAreInnerMostDims('min', axes, x.rank); - const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes); + backend_util.assertAxesAreInnerMostDims('min', axes, x.rank); + const [outShape, reduceShape] = + backend_util.computeOutAndReduceShapes(x.shape, axes); const reduceSize = util.sizeFromShape(reduceShape); const a2D = x.as2D(-1, reduceSize); return this.reduce(a2D, 'min', a2D.dtype).reshape(outShape); } sum(x: Tensor, axes: number[]): Tensor { - assertAxesAreInnerMostDims('sum', axes, x.rank); - const [outShape, reduceShape] = computeOutAndReduceShapes(x.shape, axes); + backend_util.assertAxesAreInnerMostDims('sum', axes, x.rank); + const [outShape, reduceShape] = + backend_util.computeOutAndReduceShapes(x.shape, axes); const reduceSize = util.sizeFromShape(reduceShape); const a2D = x.as2D(-1, reduceSize); const outputDType = sumOutType(x.dtype); diff --git a/tfjs-backend-webgpu/src/kernels/reduce_webgpu.ts b/tfjs-backend-webgpu/src/kernels/reduce_webgpu.ts index 99fcfb05033..cdd9354bdd6 100644 --- a/tfjs-backend-webgpu/src/kernels/reduce_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/reduce_webgpu.ts @@ -16,6 +16,7 @@ */ import {backend_util, util} from '@tensorflow/tfjs-core'; +// TODO : use backend_util.reduce_util with the next release of tfjs-core. import {ReduceInfo} from '@tensorflow/tfjs-core/src/ops/reduce_util'; import {getCoordsDataType} from '../shader_preprocessor'; import {computeDispatch} from '../webgpu_util'; diff --git a/tfjs-core/src/backends/backend_util.ts b/tfjs-core/src/backends/backend_util.ts index 7bf8f9eed64..b894e853b36 100644 --- a/tfjs-core/src/backends/backend_util.ts +++ b/tfjs-core/src/backends/backend_util.ts @@ -30,7 +30,8 @@ export * from '../ops/broadcast_util'; export * from '../ops/concat_util'; export * from '../ops/conv_util'; export {Activation, FusedConv2DConfig} from '../ops/fused_util'; - +// TODO : use backend_util.reduce_util with the next release of tfjs-core. +export * from '../ops/reduce_util'; export {BackendValues, TypedArray, upcastType, PixelData} from '../types'; export {MemoryInfo, TimingInfo} from '../engine'; diff --git a/tfjs-core/src/index.ts b/tfjs-core/src/index.ts index ffb2638caa7..ea4de90efc8 100644 --- a/tfjs-core/src/index.ts +++ b/tfjs-core/src/index.ts @@ -68,7 +68,8 @@ export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer'; export {SGDOptimizer} from './optimizers/sgd_optimizer'; export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, Variable} from './tensor'; export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types'; -export {DataType, DataTypeMap, DataValues, Rank, RecursiveArray, ShapeMap, TensorLike} from './types'; +// TODO : import sumOutType directly from '@tensorflow/tfjs-core'. +export {DataType, DataTypeMap, DataValues, Rank, RecursiveArray, ShapeMap, sumOutType, TensorLike} from './types'; export * from './ops/ops'; export {LSTMCellFunc} from './ops/lstm'; From 08a35074186a8d4fefaebd00c65b0a3689024078 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 30 Dec 2019 15:54:13 -0500 Subject: [PATCH 4/7] Update backend_webgpu.ts --- tfjs-backend-webgpu/src/backend_webgpu.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index 6d59f7531ad..a403b11ac07 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -24,7 +24,9 @@ import {backend_util, DataStorage, DataType, engine, env, findBackend, KernelBac // https://github.com/tensorflow/tfjs/issues/2471 // tslint:disable-next-line: no-imports-from-dist import {FusedConv2DConfig} from '@tensorflow/tfjs-core/dist/ops/fused_util'; +// TODO: Import reduce_util from backend_util with next release of core. import {computeOptimalWindowSize} from '@tensorflow/tfjs-core/src/ops/reduce_util'; +// TODO: import sumOutType directly from '@tensorflow/tfjs-core' with next release of core. import {sumOutType} from '@tensorflow/tfjs-core/src/types'; import {Glslang} from '@webgpu/glslang/dist/web-devel/glslang.onefile'; From f5ac013ae686a200ea7799ce23debaf3c7c21b99 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 30 Dec 2019 15:54:46 -0500 Subject: [PATCH 5/7] Update backend_util.ts --- tfjs-core/src/backends/backend_util.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/tfjs-core/src/backends/backend_util.ts b/tfjs-core/src/backends/backend_util.ts index b894e853b36..f5a4839b469 100644 --- a/tfjs-core/src/backends/backend_util.ts +++ b/tfjs-core/src/backends/backend_util.ts @@ -30,7 +30,6 @@ export * from '../ops/broadcast_util'; export * from '../ops/concat_util'; export * from '../ops/conv_util'; export {Activation, FusedConv2DConfig} from '../ops/fused_util'; -// TODO : use backend_util.reduce_util with the next release of tfjs-core. export * from '../ops/reduce_util'; export {BackendValues, TypedArray, upcastType, PixelData} from '../types'; export {MemoryInfo, TimingInfo} from '../engine'; From b3047d3a4c50dbf094ed1a796309b6dcceb9f5e3 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 30 Dec 2019 15:55:06 -0500 Subject: [PATCH 6/7] Update index.ts --- tfjs-core/src/index.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/tfjs-core/src/index.ts b/tfjs-core/src/index.ts index ea4de90efc8..bd1ea023f2c 100644 --- a/tfjs-core/src/index.ts +++ b/tfjs-core/src/index.ts @@ -68,7 +68,6 @@ export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer'; export {SGDOptimizer} from './optimizers/sgd_optimizer'; export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, Variable} from './tensor'; export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types'; -// TODO : import sumOutType directly from '@tensorflow/tfjs-core'. export {DataType, DataTypeMap, DataValues, Rank, RecursiveArray, ShapeMap, sumOutType, TensorLike} from './types'; export * from './ops/ops'; From a2b1e270f361a5d67d98fff5c8dc999c3c20af5e Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 30 Dec 2019 16:00:44 -0500 Subject: [PATCH 7/7] Update backend_webgpu.ts --- tfjs-backend-webgpu/src/backend_webgpu.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index a403b11ac07..36caa78814b 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -26,7 +26,8 @@ import {backend_util, DataStorage, DataType, engine, env, findBackend, KernelBac import {FusedConv2DConfig} from '@tensorflow/tfjs-core/dist/ops/fused_util'; // TODO: Import reduce_util from backend_util with next release of core. import {computeOptimalWindowSize} from '@tensorflow/tfjs-core/src/ops/reduce_util'; -// TODO: import sumOutType directly from '@tensorflow/tfjs-core' with next release of core. +// TODO: import sumOutType directly from '@tensorflow/tfjs-core' with next +// release of core. import {sumOutType} from '@tensorflow/tfjs-core/src/types'; import {Glslang} from '@webgpu/glslang/dist/web-devel/glslang.onefile';