diff --git a/tfjs-backend-wasm/src/kernels/Cast.ts b/tfjs-backend-wasm/src/kernels/Cast.ts index 1592bb9f42a..ae0adb2205c 100644 --- a/tfjs-backend-wasm/src/kernels/Cast.ts +++ b/tfjs-backend-wasm/src/kernels/Cast.ts @@ -28,8 +28,9 @@ interface CastAttrs extends NamedAttrMap { dtype: DataType; } -function cast( - args: {inputs: CastInputs, attrs: CastAttrs, backend: BackendWasm}) { +export function cast( + args: {inputs: CastInputs, attrs: CastAttrs, backend: BackendWasm}): + TensorInfo { const {inputs: {x}, attrs: {dtype}, backend} = args; const out = backend.makeOutput(x.shape, dtype); const inVals = backend.typedArrayFromHeap(x); diff --git a/tfjs-backend-wasm/src/kernels/CropAndResize.ts b/tfjs-backend-wasm/src/kernels/CropAndResize.ts index b6cb223b3d2..a46f141fd2b 100644 --- a/tfjs-backend-wasm/src/kernels/CropAndResize.ts +++ b/tfjs-backend-wasm/src/kernels/CropAndResize.ts @@ -18,6 +18,7 @@ import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; +import {cast} from './Cast'; interface CropAndResizeInputs extends NamedTensorInfoMap { images: TensorInfo; @@ -71,11 +72,19 @@ function cropAndResize(args: { const [cropHeight, cropWidth] = cropSize as [number, number]; const outShape = [numBoxes, cropHeight, cropWidth, images.shape[3]]; - const imagesId = backend.dataIdMap.get(images.dataId).id; + let imagesData = backend.dataIdMap.get(images.dataId); + let castedData; + if (images.dtype !== 'float32') { + castedData = + cast({backend, inputs: {x: images}, attrs: {dtype: 'float32'}}); + imagesData = backend.dataIdMap.get(castedData.dataId); + } + + const imagesId = imagesData.id; const boxesId = backend.dataIdMap.get(boxes.dataId).id; const boxIndId = backend.dataIdMap.get(boxInd.dataId).id; - const out = backend.makeOutput(outShape, images.dtype); + const out = backend.makeOutput(outShape, 'float32'); const outId = backend.dataIdMap.get(out.dataId).id; const imagesShapeBytes = new Uint8Array(new Int32Array(images.shape).buffer); @@ -85,6 +94,11 @@ function cropAndResize(args: { cropWidth, InterpolationMethod[method as {} as keyof typeof InterpolationMethod], extrapolationValue, outId); + + if (castedData != null) { + backend.disposeData(castedData.dataId); + } + return out; } diff --git a/tfjs-backend-wasm/src/kernels/ResizeBilinear.ts b/tfjs-backend-wasm/src/kernels/ResizeBilinear.ts index 912927518c2..e2057874c36 100644 --- a/tfjs-backend-wasm/src/kernels/ResizeBilinear.ts +++ b/tfjs-backend-wasm/src/kernels/ResizeBilinear.ts @@ -19,6 +19,8 @@ import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from import {BackendWasm} from '../backend_wasm'; +import {cast} from './Cast'; + interface ResizeBilinearInputs extends NamedTensorInfoMap { x: TensorInfo; } @@ -48,7 +50,7 @@ function setup(backend: BackendWasm): void { ]); } -function cropAndResize(args: { +function resizeBilinear(args: { backend: BackendWasm, inputs: ResizeBilinearInputs, attrs: ResizeBilinearAttrs @@ -60,9 +62,15 @@ function cropAndResize(args: { const [batch, oldHeight, oldWidth, numChannels] = x.shape; const outShape = [batch, newHeight, newWidth, numChannels]; - const xId = backend.dataIdMap.get(x.dataId).id; + let xData = backend.dataIdMap.get(x.dataId); + let castedData; + if (xData.dtype !== 'float32') { + castedData = cast({backend, inputs: {x}, attrs: {dtype: 'float32'}}); + xData = backend.dataIdMap.get(castedData.dataId); + } + const xId = xData.id; - const out = backend.makeOutput(outShape, x.dtype); + const out = backend.makeOutput(outShape, 'float32'); if (util.sizeFromShape(x.shape) === 0) { return out; } @@ -71,6 +79,11 @@ function cropAndResize(args: { wasmResizeBilinear( xId, batch, oldHeight, oldWidth, numChannels, newHeight, newWidth, alignCorners ? 1 : 0, outId); + + if (castedData != null) { + backend.disposeData(castedData.dataId); + } + return out; } @@ -78,5 +91,5 @@ registerKernel({ kernelName: 'ResizeBilinear', backendName: 'wasm', setupFunc: setup, - kernelFunc: cropAndResize + kernelFunc: resizeBilinear }); diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index 80bc18f2bb7..acaa4f1b0a4 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -3561,8 +3561,8 @@ export class MathBackendCPU extends KernelBackend { const numBoxes = boxes.shape[0]; const [cropHeight, cropWidth] = cropSize; - const output = ops.buffer( - [numBoxes, cropHeight, cropWidth, numChannels], images.dtype); + const output = + ops.buffer([numBoxes, cropHeight, cropWidth, numChannels], 'float32'); const boxVals = this.readSync(boxes.dataId) as TypedArray; const boxIndVals = this.readSync(boxIndex.dataId) as TypedArray; diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index 083499f6db6..5e9e6c1db9a 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -2195,7 +2195,7 @@ export class MathBackendWebGL extends KernelBackend { new ResizeBilinearPackedProgram( x.shape, newHeight, newWidth, alignCorners) : new ResizeBilinearProgram(x.shape, newHeight, newWidth, alignCorners); - return this.compileAndRun(program, [x]); + return this.compileAndRun(program, [x], 'float32'); } resizeBilinearBackprop(dy: Tensor4D, x: Tensor4D, alignCorners: boolean): @@ -2260,7 +2260,7 @@ export class MathBackendWebGL extends KernelBackend { extrapolationValue: number): Tensor4D { const program = new CropAndResizeProgram( image.shape, boxes.shape, cropSize, method, extrapolationValue); - return this.compileAndRun(program, [image, boxes, boxIndex]); + return this.compileAndRun(program, [image, boxes, boxIndex], 'float32'); } depthToSpace(x: Tensor4D, blockSize: number, dataFormat: 'NHWC'|'NCHW'): diff --git a/tfjs-core/src/ops/image_ops.ts b/tfjs-core/src/ops/image_ops.ts index 29a64b93adf..336c98e7c43 100644 --- a/tfjs-core/src/ops/image_ops.ts +++ b/tfjs-core/src/ops/image_ops.ts @@ -271,7 +271,7 @@ function cropAndResize_( method?: 'bilinear'|'nearest', extrapolationValue?: number, ): Tensor4D { - const $image = convertToTensor(image, 'image', 'cropAndResize', 'float32'); + const $image = convertToTensor(image, 'image', 'cropAndResize'); const $boxes = convertToTensor(boxes, 'boxes', 'cropAndResize', 'float32'); const $boxInd = convertToTensor(boxInd, 'boxInd', 'cropAndResize', 'int32'); method = method || 'bilinear'; diff --git a/tfjs-core/src/ops/image_ops_test.ts b/tfjs-core/src/ops/image_ops_test.ts index 9a4be6dfbf5..bbac09fef01 100644 --- a/tfjs-core/src/ops/image_ops_test.ts +++ b/tfjs-core/src/ops/image_ops_test.ts @@ -220,74 +220,97 @@ describeWithFlags('cropAndResize', ALL_ENVS, () => { it('1x1-bilinear', async () => { const image: tf.Tensor4D = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([0, 0, 1, 1], [1, 4]); - const boxInd: tf.Tensor1D = tf.tensor1d([0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [1, 1], 'bilinear', 0); + expect(output.shape).toEqual([1, 1, 1, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), [2.5]); }); it('1x1-nearest', async () => { const image: tf.Tensor4D = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([0, 0, 1, 1], [1, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [1, 1], 'nearest', 0); + expect(output.shape).toEqual([1, 1, 1, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), [4.0]); }); it('1x1Flipped-bilinear', async () => { const image: tf.Tensor4D = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([1, 1, 0, 0], [1, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [1, 1], 'bilinear', 0); + expect(output.shape).toEqual([1, 1, 1, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), [2.5]); }); it('1x1Flipped-nearest', async () => { const image: tf.Tensor4D = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([1, 1, 0, 0], [1, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [1, 1], 'nearest', 0); + expect(output.shape).toEqual([1, 1, 1, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), [4.0]); }); it('3x3-bilinear', async () => { const image: tf.Tensor4D = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([0, 0, 1, 1], [1, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [3, 3], 'bilinear', 0); + expect(output.shape).toEqual([1, 3, 3, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), [1, 1.5, 2, 2, 2.5, 3, 3, 3.5, 4]); }); it('3x3-nearest', async () => { const image: tf.Tensor4D = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([0, 0, 1, 1], [1, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [3, 3], 'nearest', 0); + expect(output.shape).toEqual([1, 3, 3, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), [1, 2, 2, 3, 4, 4, 3, 4, 4]); }); it('3x3Flipped-bilinear', async () => { const image: tf.Tensor4D = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([1, 1, 0, 0], [1, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [3, 3], 'bilinear', 0); + expect(output.shape).toEqual([1, 3, 3, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), [4, 3.5, 3, 3, 2.5, 2, 2, 1.5, 1]); }); it('3x3Flipped-nearest', async () => { const image: tf.Tensor4D = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([1, 1, 0, 0], [1, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [3, 3], 'nearest', 0); + expect(output.shape).toEqual([1, 3, 3, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), [4, 4, 3, 4, 4, 3, 2, 2, 1]); }); it('3x3to2x2-bilinear', async () => { @@ -296,9 +319,12 @@ describeWithFlags('cropAndResize', ALL_ENVS, () => { const boxes: tf.Tensor2D = tf.tensor2d([0, 0, 1, 1, 0, 0, 0.5, 0.5], [2, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0, 0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [2, 2], 'bilinear', 0); + expect(output.shape).toEqual([2, 2, 2, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), [1, 3, 7, 9, 1, 2, 4, 5]); }); it('3x3to2x2-nearest', async () => { @@ -307,9 +333,12 @@ describeWithFlags('cropAndResize', ALL_ENVS, () => { const boxes: tf.Tensor2D = tf.tensor2d([0, 0, 1, 1, 0, 0, 0.5, 0.5], [2, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0, 0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [2, 2], 'nearest', 0); + expect(output.shape).toEqual([2, 2, 2, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), [1, 3, 7, 9, 1, 2, 4, 5]); }); it('3x3to2x2Flipped-bilinear', async () => { @@ -318,9 +347,12 @@ describeWithFlags('cropAndResize', ALL_ENVS, () => { const boxes: tf.Tensor2D = tf.tensor2d([1, 1, 0, 0, 0.5, 0.5, 0, 0], [2, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0, 0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [2, 2], 'bilinear', 0); + expect(output.shape).toEqual([2, 2, 2, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), [9, 7, 3, 1, 5, 4, 2, 1]); }); it('3x3to2x2Flipped-nearest', async () => { @@ -329,18 +361,24 @@ describeWithFlags('cropAndResize', ALL_ENVS, () => { const boxes: tf.Tensor2D = tf.tensor2d([1, 1, 0, 0, 0.5, 0.5, 0, 0], [2, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0, 0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [2, 2], 'nearest', 0); + expect(output.shape).toEqual([2, 2, 2, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), [9, 7, 3, 1, 5, 4, 2, 1]); }); it('3x3-BoxisRectangular', async () => { const image: tf.Tensor4D = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([0, 0, 1, 1.5], [1, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [3, 3], 'bilinear', 0); + expect(output.shape).toEqual([1, 3, 3, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose( await output.data(), [1, 1.75, 0, 2, 2.75, 0, 3, 3.75, 0]); }); @@ -348,9 +386,12 @@ describeWithFlags('cropAndResize', ALL_ENVS, () => { const image: tf.Tensor4D = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([0, 0, 1, 1.5], [1, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [3, 3], 'nearest', 0); + expect(output.shape).toEqual([1, 3, 3, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), [1, 2, 0, 3, 4, 0, 3, 4, 0]); }); it('2x2to3x3-Extrapolated', async () => { @@ -358,9 +399,12 @@ describeWithFlags('cropAndResize', ALL_ENVS, () => { const image: tf.Tensor4D = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([-1, -1, 1, 1], [1, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [3, 3], 'bilinear', val); + expect(output.shape).toEqual([1, 3, 3, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose( await output.data(), [val, val, val, val, 1, 2, val, 3, 4]); }); @@ -369,9 +413,12 @@ describeWithFlags('cropAndResize', ALL_ENVS, () => { const image: tf.Tensor4D = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([-1, -1, 1, 1], [1, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [3, 3], 'bilinear', val); + expect(output.shape).toEqual([1, 3, 3, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose( await output.data(), [val, val, val, val, 1, 2, val, 3, 4]); }); @@ -380,9 +427,12 @@ describeWithFlags('cropAndResize', ALL_ENVS, () => { const image: tf.Tensor4D = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([], [0, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [3, 3], 'bilinear', val); + expect(output.shape).toEqual([0, 3, 3, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose(await output.data(), []); }); it('MultipleBoxes-DifferentBoxes', async () => { @@ -391,9 +441,12 @@ describeWithFlags('cropAndResize', ALL_ENVS, () => { const boxes: tf.Tensor2D = tf.tensor2d([0, 0, 1, 1.5, 0, 0, 1.5, 1], [2, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0, 1], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [3, 3], 'bilinear', 0); + expect(output.shape).toEqual([2, 3, 3, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose( await output.data(), [1, 1.75, 0, 2, 2.75, 0, 3, 3.75, 0, 5, 5.5, 6, 6.5, 7, 7.5, 0, 0, 0]); @@ -403,11 +456,30 @@ describeWithFlags('cropAndResize', ALL_ENVS, () => { tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]); const boxes: tf.Tensor2D = tf.tensor2d([0, 0, 1, 1.5, 0, 0, 2, 1], [2, 4]); const boxInd: tf.Tensor1D = tf.tensor1d([0, 1], 'int32'); + const output = tf.image.cropAndResize(image, boxes, boxInd, [3, 3], 'nearest', 0); + expect(output.shape).toEqual([2, 3, 3, 1]); + expect(output.dtype).toBe('float32'); expectArraysClose( await output.data(), [1, 2, 0, 3, 4, 0, 3, 4, 0, 5, 6, 6, 7, 8, 8, 0, 0, 0]); }); + it('int32 image returns float output', async () => { + const image: tf.Tensor4D = + tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1], 'int32'); + const boxes: tf.Tensor2D = + tf.tensor2d([0, 0, 1, 1.5, 0, 0, 1.5, 1], [2, 4]); + const boxInd: tf.Tensor1D = tf.tensor1d([0, 1], 'int32'); + + const output = + tf.image.cropAndResize(image, boxes, boxInd, [3, 3], 'bilinear', 0); + + expect(output.shape).toEqual([2, 3, 3, 1]); + expect(output.dtype).toBe('float32'); + expectArraysClose( + await output.data(), + [1, 1.75, 0, 2, 2.75, 0, 3, 3.75, 0, 5, 5.5, 6, 6.5, 7, 7.5, 0, 0, 0]); + }); }); diff --git a/tfjs-core/src/ops/resize_bilinear_test.ts b/tfjs-core/src/ops/resize_bilinear_test.ts index e1a78b69441..565c7929696 100644 --- a/tfjs-core/src/ops/resize_bilinear_test.ts +++ b/tfjs-core/src/ops/resize_bilinear_test.ts @@ -52,6 +52,16 @@ describeWithFlags('resizeBilinear', ALL_ENVS, () => { 1.62451875, 1.83673334, 1.13944793, 2.01993227, 2.01919961, 2.67524052]); }); + it('works for ints', async () => { + const input = tf.tensor3d([1, 2, 3, 4, 5], [1, 5, 1], 'int32'); + const output = input.resizeBilinear([1, 10]); + + expect(output.shape).toEqual([1, 10, 1]); + expect(output.dtype).toBe('float32'); + expectArraysClose( + await output.data(), [1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5]); + }); + it('matches tensorflow w/ random numbers alignCorners=false', async () => { const input = tf.tensor3d( [