From 52b27196e5cb5e2b78e907a25ed9c0997d3c8cac Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 2 Feb 2023 16:13:52 -0800 Subject: [PATCH 1/3] [WebGPU] Support forwarding FloorDiv to CPU --- tfjs-backend-webgpu/src/kernel_utils/shared.ts | 2 ++ tfjs-backend-webgpu/src/kernels/FloorDiv.ts | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tfjs-backend-webgpu/src/kernel_utils/shared.ts b/tfjs-backend-webgpu/src/kernel_utils/shared.ts index db11d7e64da..2d7cbf1aa07 100644 --- a/tfjs-backend-webgpu/src/kernel_utils/shared.ts +++ b/tfjs-backend-webgpu/src/kernel_utils/shared.ts @@ -35,6 +35,7 @@ const { expImpl: expImplCPU, expm1Impl: expm1ImplCPU, floorImpl: floorImplCPU, + floorDivImpl: floorDivImplCPU, gatherNdImpl: gatherNdImplCPU, gatherV2Impl: gatherV2ImplCPU, greaterEqualImpl: greaterEqualImplCPU, @@ -72,6 +73,7 @@ export { expImplCPU, expm1ImplCPU, floorImplCPU, + floorDivImplCPU, gatherNdImplCPU, gatherV2ImplCPU, greaterEqualImplCPU, diff --git a/tfjs-backend-webgpu/src/kernels/FloorDiv.ts b/tfjs-backend-webgpu/src/kernels/FloorDiv.ts index a25fce5daaa..f4fed3f92e0 100644 --- a/tfjs-backend-webgpu/src/kernels/FloorDiv.ts +++ b/tfjs-backend-webgpu/src/kernels/FloorDiv.ts @@ -16,12 +16,15 @@ */ import {FloorDiv, KernelConfig} from '@tensorflow/tfjs-core'; +import '@tensorflow/tfjs-backend-cpu'; import {BinaryOpType} from '../binary_op_util'; import {binaryKernelFunc} from '../kernel_utils/kernel_funcs_utils'; +import {floorDivImplCPU} from '../kernel_utils/shared'; export const floorDiv = - binaryKernelFunc({opType: BinaryOpType.INT_DIV, dtype: 'int32'}); + binaryKernelFunc({opType: BinaryOpType.INT_DIV, + cpuKernelImpl: floorDivImplCPU, dtype: 'int32'}); export const floorDivConfig: KernelConfig = { kernelName: FloorDiv, From 2bb13189e7c916587f65f9871a65e5aff2f1bd5e Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 2 Feb 2023 16:15:26 -0800 Subject: [PATCH 2/3] Export floordiv as shared in CPU --- tfjs-backend-cpu/src/shared.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/tfjs-backend-cpu/src/shared.ts b/tfjs-backend-cpu/src/shared.ts index 91cad3291c5..515376b1f8b 100644 --- a/tfjs-backend-cpu/src/shared.ts +++ b/tfjs-backend-cpu/src/shared.ts @@ -26,6 +26,7 @@ export {equalImpl} from './kernels/Equal'; export {expImpl} from './kernels/Exp'; export {expm1Impl} from './kernels/Expm1'; export {floorImpl} from './kernels/Floor'; +export {floorDivImpl} from './kernels/FloorDiv'; export {gatherNdImpl} from './kernels/GatherNd_Impl'; export {gatherV2Impl} from './kernels/GatherV2_impl'; export {greaterImpl} from './kernels/Greater'; From fdae048716ac260a1f6a3215d0643f092f7adcbb Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 2 Feb 2023 16:39:41 -0800 Subject: [PATCH 3/3] Remove tfjs-backend-cpu import from floordiv.ts --- tfjs-backend-webgpu/src/kernels/FloorDiv.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/tfjs-backend-webgpu/src/kernels/FloorDiv.ts b/tfjs-backend-webgpu/src/kernels/FloorDiv.ts index f4fed3f92e0..bd03d728c8f 100644 --- a/tfjs-backend-webgpu/src/kernels/FloorDiv.ts +++ b/tfjs-backend-webgpu/src/kernels/FloorDiv.ts @@ -16,7 +16,6 @@ */ import {FloorDiv, KernelConfig} from '@tensorflow/tfjs-core'; -import '@tensorflow/tfjs-backend-cpu'; import {BinaryOpType} from '../binary_op_util'; import {binaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';