diff --git a/tfjs-backend-cpu/src/shared.ts b/tfjs-backend-cpu/src/shared.ts index 91cad3291c5..515376b1f8b 100644 --- a/tfjs-backend-cpu/src/shared.ts +++ b/tfjs-backend-cpu/src/shared.ts @@ -26,6 +26,7 @@ export {equalImpl} from './kernels/Equal'; export {expImpl} from './kernels/Exp'; export {expm1Impl} from './kernels/Expm1'; export {floorImpl} from './kernels/Floor'; +export {floorDivImpl} from './kernels/FloorDiv'; export {gatherNdImpl} from './kernels/GatherNd_Impl'; export {gatherV2Impl} from './kernels/GatherV2_impl'; export {greaterImpl} from './kernels/Greater'; diff --git a/tfjs-backend-webgpu/src/kernel_utils/shared.ts b/tfjs-backend-webgpu/src/kernel_utils/shared.ts index db11d7e64da..2d7cbf1aa07 100644 --- a/tfjs-backend-webgpu/src/kernel_utils/shared.ts +++ b/tfjs-backend-webgpu/src/kernel_utils/shared.ts @@ -35,6 +35,7 @@ const { expImpl: expImplCPU, expm1Impl: expm1ImplCPU, floorImpl: floorImplCPU, + floorDivImpl: floorDivImplCPU, gatherNdImpl: gatherNdImplCPU, gatherV2Impl: gatherV2ImplCPU, greaterEqualImpl: greaterEqualImplCPU, @@ -72,6 +73,7 @@ export { expImplCPU, expm1ImplCPU, floorImplCPU, + floorDivImplCPU, gatherNdImplCPU, gatherV2ImplCPU, greaterEqualImplCPU, diff --git a/tfjs-backend-webgpu/src/kernels/FloorDiv.ts b/tfjs-backend-webgpu/src/kernels/FloorDiv.ts index a25fce5daaa..bd03d728c8f 100644 --- a/tfjs-backend-webgpu/src/kernels/FloorDiv.ts +++ b/tfjs-backend-webgpu/src/kernels/FloorDiv.ts @@ -19,9 +19,11 @@ import {FloorDiv, KernelConfig} from '@tensorflow/tfjs-core'; import {BinaryOpType} from '../binary_op_util'; import {binaryKernelFunc} from '../kernel_utils/kernel_funcs_utils'; +import {floorDivImplCPU} from '../kernel_utils/shared'; export const floorDiv = - binaryKernelFunc({opType: BinaryOpType.INT_DIV, dtype: 'int32'}); + binaryKernelFunc({opType: BinaryOpType.INT_DIV, + cpuKernelImpl: floorDivImplCPU, dtype: 'int32'}); export const floorDivConfig: KernelConfig = { kernelName: FloorDiv,