diff --git a/tfjs-backend-cpu/src/backend_cpu.ts b/tfjs-backend-cpu/src/backend_cpu.ts index 65d996083e5..8bf6be1794a 100644 --- a/tfjs-backend-cpu/src/backend_cpu.ts +++ b/tfjs-backend-cpu/src/backend_cpu.ts @@ -1152,7 +1152,7 @@ export class MathBackendCPU extends KernelBackend { const v = values[i]; resultValues[i] = v > max ? max : (v < min ? min : v); } - return this.makeOutput(resultValues, x.shape, 'float32'); + return this.makeOutput(resultValues, x.shape, x.dtype); } abs(x: T): T { diff --git a/tfjs-backend-wasm/src/kernels/ClipByValue.ts b/tfjs-backend-wasm/src/kernels/ClipByValue.ts index fe848574334..f3026ec10d0 100644 --- a/tfjs-backend-wasm/src/kernels/ClipByValue.ts +++ b/tfjs-backend-wasm/src/kernels/ClipByValue.ts @@ -39,7 +39,7 @@ function clip(args: { const {x} = inputs; const {clipValueMin, clipValueMax} = attrs; const xId = backend.dataIdMap.get(x.dataId).id; - const out = backend.makeOutput(x.shape, 'float32'); + const out = backend.makeOutput(x.shape, x.dtype); const outId = backend.dataIdMap.get(out.dataId).id; wasmClip(xId, clipValueMin, clipValueMax, outId); return out; diff --git a/tfjs-core/src/ops/clip_by_value_test.ts b/tfjs-core/src/ops/clip_by_value_test.ts index 0c64c7d2bbd..642199ac2dd 100644 --- a/tfjs-core/src/ops/clip_by_value_test.ts +++ b/tfjs-core/src/ops/clip_by_value_test.ts @@ -135,4 +135,13 @@ describeWithFlags('clipByValue', ALL_ENVS, () => { expect(() => tf.clipByValue('q', 0, 1)) .toThrowError(/Argument 'x' passed to 'clipByValue' must be numeric/); }); + + it('clip int32 tensor', async () => { + const min = -1; + const max = 50; + const tensor = tf.tensor([2, 3, 4], [3], 'int32'); + const result = tf.clipByValue(tensor, min, max); + expectArraysClose(await result.data(), [2, 3, 4]); + expect(result.dtype).toEqual('int32'); + }); }); diff --git a/tfjs-node/src/nodejs_kernel_backend.ts b/tfjs-node/src/nodejs_kernel_backend.ts index ba0f64331ff..466f68e30b8 100644 --- a/tfjs-node/src/nodejs_kernel_backend.ts +++ b/tfjs-node/src/nodejs_kernel_backend.ts @@ -722,8 +722,8 @@ export class NodeJSKernelBackend extends KernelBackend { } clip(x: T, min: number, max: number): T { - const xMin = this.minimum(x, scalar(max)); - return this.maximum(xMin, scalar(min)) as T; + const xMin = this.minimum(x, scalar(max, x.dtype)); + return this.maximum(xMin, scalar(min, x.dtype)) as T; } abs(x: T): T {