From 491c24a49dd39d15a8f5f62496b6cbcb96166495 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Tue, 18 Aug 2020 09:20:04 -0700 Subject: [PATCH 1/3] add clipByValue int32 test and fix the node impl --- tfjs-core/src/ops/clip_by_value_test.ts | 9 +++++++++ tfjs-node/src/nodejs_kernel_backend.ts | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/ops/clip_by_value_test.ts b/tfjs-core/src/ops/clip_by_value_test.ts index 0c64c7d2bbd..642199ac2dd 100644 --- a/tfjs-core/src/ops/clip_by_value_test.ts +++ b/tfjs-core/src/ops/clip_by_value_test.ts @@ -135,4 +135,13 @@ describeWithFlags('clipByValue', ALL_ENVS, () => { expect(() => tf.clipByValue('q', 0, 1)) .toThrowError(/Argument 'x' passed to 'clipByValue' must be numeric/); }); + + it('clip int32 tensor', async () => { + const min = -1; + const max = 50; + const tensor = tf.tensor([2, 3, 4], [3], 'int32'); + const result = tf.clipByValue(tensor, min, max); + expectArraysClose(await result.data(), [2, 3, 4]); + expect(result.dtype).toEqual('int32'); + }); }); diff --git a/tfjs-node/src/nodejs_kernel_backend.ts b/tfjs-node/src/nodejs_kernel_backend.ts index ba0f64331ff..466f68e30b8 100644 --- a/tfjs-node/src/nodejs_kernel_backend.ts +++ b/tfjs-node/src/nodejs_kernel_backend.ts @@ -722,8 +722,8 @@ export class NodeJSKernelBackend extends KernelBackend { } clip(x: T, min: number, max: number): T { - const xMin = this.minimum(x, scalar(max)); - return this.maximum(xMin, scalar(min)) as T; + const xMin = this.minimum(x, scalar(max, x.dtype)); + return this.maximum(xMin, scalar(min, x.dtype)) as T; } abs(x: T): T { From 58179f5251a6b0fa02953b848376e20109bcd6b3 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Tue, 18 Aug 2020 10:11:48 -0700 Subject: [PATCH 2/3] fix cpu backend failed test for clip --- tfjs-backend-cpu/src/backend_cpu.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-backend-cpu/src/backend_cpu.ts b/tfjs-backend-cpu/src/backend_cpu.ts index 65d996083e5..8bf6be1794a 100644 --- a/tfjs-backend-cpu/src/backend_cpu.ts +++ b/tfjs-backend-cpu/src/backend_cpu.ts @@ -1152,7 +1152,7 @@ export class MathBackendCPU extends KernelBackend { const v = values[i]; resultValues[i] = v > max ? max : (v < min ? min : v); } - return this.makeOutput(resultValues, x.shape, 'float32'); + return this.makeOutput(resultValues, x.shape, x.dtype); } abs(x: T): T { From cde2a8e873af7a3bb50ea949308e9da17e5b9f7c Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Tue, 18 Aug 2020 10:35:38 -0700 Subject: [PATCH 3/3] fix wasm backend --- tfjs-backend-wasm/src/kernels/ClipByValue.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-backend-wasm/src/kernels/ClipByValue.ts b/tfjs-backend-wasm/src/kernels/ClipByValue.ts index fe848574334..f3026ec10d0 100644 --- a/tfjs-backend-wasm/src/kernels/ClipByValue.ts +++ b/tfjs-backend-wasm/src/kernels/ClipByValue.ts @@ -39,7 +39,7 @@ function clip(args: { const {x} = inputs; const {clipValueMin, clipValueMax} = attrs; const xId = backend.dataIdMap.get(x.dataId).id; - const out = backend.makeOutput(x.shape, 'float32'); + const out = backend.makeOutput(x.shape, x.dtype); const outId = backend.dataIdMap.get(out.dataId).id; wasmClip(xId, clipValueMin, clipValueMax, outId); return out;