From 62cbec88dccf355cbb495101984142279e67f989 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 9 Jun 2020 22:47:15 -0400 Subject: [PATCH 1/9] modularize resize nearest neighbor --- .../gradients/ResizeNearestNeighbor_grad.ts | 43 +++++++++ tfjs-core/src/kernel_names.ts | 11 +++ tfjs-core/src/ops/image_ops.ts | 60 ------------ tfjs-core/src/ops/ops.ts | 13 ++- tfjs-core/src/ops/resize_nearest_neighbor.ts | 91 +++++++++++++++++++ .../chained_ops/register_all_chained_ops.ts | 1 + .../register_all_chained_ops_test.ts | 1 + .../chained_ops/resize_nearest_neighbor.ts | 32 +++++++ tfjs-core/src/register_all_gradients.ts | 2 + tfjs-core/src/tensor.ts | 9 -- 10 files changed, 193 insertions(+), 70 deletions(-) create mode 100644 tfjs-core/src/gradients/ResizeNearestNeighbor_grad.ts create mode 100644 tfjs-core/src/ops/resize_nearest_neighbor.ts create mode 100644 tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts diff --git a/tfjs-core/src/gradients/ResizeNearestNeighbor_grad.ts b/tfjs-core/src/gradients/ResizeNearestNeighbor_grad.ts new file mode 100644 index 00000000000..bb642dadbfd --- /dev/null +++ b/tfjs-core/src/gradients/ResizeNearestNeighbor_grad.ts @@ -0,0 +1,43 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ENGINE, ForwardFunc} from '../engine'; +import {ResizeNearestNeighbor, ResizeNearestNeighborAttrs, ResizeNearestNeighborGrad, ResizeNearestNeighborGradInputs} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor, Tensor4D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; + +export const resizeNearestNeighborGradConfig: GradConfig = { + kernelName: ResizeNearestNeighbor, + inputsToSave: ['images'], + gradFunc: (dy: Tensor4D, saved: Tensor[], attrs: NamedAttrMap) => { + const [images] = saved; + + const backPropKernelFunc: ForwardFunc = (backend) => { + const {alignCorners} = attrs as {} as ResizeNearestNeighborAttrs; + return backend.resizeNearestNeighborBackprop( + dy, images as Tensor4D, alignCorners); + }; + + const inputs: ResizeNearestNeighborGradInputs = {images}; + const imagesDer = () => ENGINE.runKernelFunc( + backPropKernelFunc, inputs as {} as NamedTensorMap, null /* gradient */, + ResizeNearestNeighborGrad, attrs); + + return {images: imagesDer}; + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 87d3d771d35..85a39ff20d8 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -370,6 +370,17 @@ export type RealInputs = Pick; export const Relu = 'Relu'; export type ReluInputs = Pick; +export const ResizeNearestNeighbor = 'ResizeNearestNeighbor'; +export type ResizeNearestNeighborInputs = + Pick; +export interface ResizeNearestNeighborAttrs { + alignCorners: boolean; +} + +export const ResizeNearestNeighborGrad = 'ResizeNearestNeighbor'; +export type ResizeNearestNeighborGradInputs = + Pick; + export const SelectV2 = 'SelectV2'; export type SelectV2Inputs = Pick; diff --git a/tfjs-core/src/ops/image_ops.ts b/tfjs-core/src/ops/image_ops.ts index ba034e33526..cdd0f3c653b 100644 --- a/tfjs-core/src/ops/image_ops.ts +++ b/tfjs-core/src/ops/image_ops.ts @@ -85,66 +85,7 @@ function resizeBilinear_( return res as T; } -/** - * NearestNeighbor resize a batch of 3D images to a new shape. - * - * @param images The images, of rank 4 or rank 3, of shape - * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. - * @param size The new shape `[newHeight, newWidth]` to resize the - * images to. Each channel is resized individually. - * @param alignCorners Defaults to False. If true, rescale - * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4 - * corners of images and resized images. If false, rescale by - * `new_height / height`. Treat similarly the width dimension. - */ -/** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */ -function resizeNearestNeighbor_( - images: T|TensorLike, size: [number, number], alignCorners = false): T { - const $images = convertToTensor(images, 'images', 'resizeNearestNeighbor'); - util.assert( - $images.rank === 3 || $images.rank === 4, - () => `Error in resizeNearestNeighbor: x must be rank 3 or 4, but got ` + - `rank ${$images.rank}.`); - util.assert( - size.length === 2, - () => - `Error in resizeNearestNeighbor: new shape must 2D, but got shape ` + - `${size}.`); - util.assert( - $images.dtype === 'float32' || $images.dtype === 'int32', - () => '`images` must have `int32` or `float32` as dtype'); - let batchImages = $images as Tensor4D; - let reshapedTo4D = false; - if ($images.rank === 3) { - reshapedTo4D = true; - batchImages = - $images.as4D(1, $images.shape[0], $images.shape[1], $images.shape[2]); - } - const [newHeight, newWidth] = size; - - const forward: ForwardFunc = (backend, save) => { - save([batchImages]); - return backend.resizeNearestNeighbor( - batchImages, newHeight, newWidth, alignCorners); - }; - - const backward = (dy: Tensor4D, saved: Tensor[]) => { - return { - batchImages: () => ENGINE.runKernelFunc( - backend => backend.resizeNearestNeighborBackprop( - dy, saved[0] as Tensor4D, alignCorners), - {}) - }; - }; - - const res = ENGINE.runKernelFunc(forward, {batchImages}, backward); - - if (reshapedTo4D) { - return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; - } - return res as T; -} /** * Performs non maximum suppression of bounding boxes based on @@ -351,7 +292,6 @@ function cropAndResize_( } export const resizeBilinear = op({resizeBilinear_}); -export const resizeNearestNeighbor = op({resizeNearestNeighbor_}); export const nonMaxSuppressionAsync = nonMaxSuppressionAsync_; export const nonMaxSuppressionWithScore = op({nonMaxSuppressionWithScore_}); export const nonMaxSuppressionWithScoreAsync = nonMaxSuppressionWithScoreAsync_; diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index d8671048ed8..6586b2bc8a5 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -123,9 +123,20 @@ export {op} from './operation'; // Second level exports. import * as losses from './loss_ops'; import * as linalg from './linalg_ops'; -import * as image from './image_ops'; import * as spectral from './spectral_ops'; import * as fused from './fused_ops'; import * as signal from './signal_ops'; +import {resizeBilinear, cropAndResize, nonMaxSuppression, nonMaxSuppressionAsync, nonMaxSuppressionWithScore, nonMaxSuppressionWithScoreAsync} from './image_ops'; +import {resizeNearestNeighbor} from './resize_nearest_neighbor'; +const image = { + resizeNearestNeighbor, + resizeBilinear, + cropAndResize, + nonMaxSuppression, + nonMaxSuppressionAsync, + nonMaxSuppressionWithScore, + nonMaxSuppressionWithScoreAsync +}; + export {image, linalg, losses, spectral, fused, signal}; diff --git a/tfjs-core/src/ops/resize_nearest_neighbor.ts b/tfjs-core/src/ops/resize_nearest_neighbor.ts new file mode 100644 index 00000000000..0dcacf91818 --- /dev/null +++ b/tfjs-core/src/ops/resize_nearest_neighbor.ts @@ -0,0 +1,91 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE, ForwardFunc} from '../engine'; +import {ResizeNearestNeighbor, ResizeNearestNeighborAttrs, ResizeNearestNeighborInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor3D, Tensor4D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {op} from './operation'; + + + +/** + * NearestNeighbor resize a batch of 3D images to a new shape. + * + * @param images The images, of rank 4 or rank 3, of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. + * @param size The new shape `[newHeight, newWidth]` to resize the + * images to. Each channel is resized individually. + * @param alignCorners Defaults to False. If true, rescale + * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4 + * corners of images and resized images. If false, rescale by + * `new_height / height`. Treat similarly the width dimension. + */ +/** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */ +function resizeNearestNeighbor_( + images: T|TensorLike, size: [number, number], alignCorners = false): T { + const $images = convertToTensor(images, 'images', 'resizeNearestNeighbor'); + const $size = convertToTensor(size, 'size', 'resizeNearestNeighbor'); + + util.assert( + $images.rank === 3 || $images.rank === 4, + () => `Error in resizeNearestNeighbor: x must be rank 3 or 4, but got ` + + `rank ${$images.rank}.`); + util.assert( + size.length === 2, + () => + `Error in resizeNearestNeighbor: new shape must 2D, but got shape ` + + `${size}.`); + util.assert( + $images.dtype === 'float32' || $images.dtype === 'int32', + () => '`images` must have `int32` or `float32` as dtype'); + + let batchImages = $images as Tensor4D; + let reshapedTo4D = false; + if ($images.rank === 3) { + reshapedTo4D = true; + batchImages = + $images.as4D(1, $images.shape[0], $images.shape[1], $images.shape[2]); + } + const [newHeight, newWidth] = size; + + const inputs: + ResizeNearestNeighborInputs = {images: batchImages, size: $size}; + const attrs: ResizeNearestNeighborAttrs = {alignCorners}; + + const forward: ForwardFunc = (backend, save) => { + save([batchImages]); + return backend.resizeNearestNeighbor( + batchImages, newHeight, newWidth, alignCorners); + }; + + const res = ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* gradient */, + ResizeNearestNeighbor, attrs as {} as NamedAttrMap); + + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; + } + return res as T; +} + +export const resizeNearestNeighbor = op({resizeNearestNeighbor_}); diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index d762c49def9..fc9863a453f 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -51,6 +51,7 @@ import './pad'; import './pool'; import './pow'; import './relu'; +import './resize_nearest_neighbor'; import './separable_conv2d'; import './split'; import './squared_difference'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index 41cb1556084..2f613b47646 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -61,6 +61,7 @@ const CHAINED_OPS = [ 'pool', 'pow', 'relu', + 'resizeNearestNeighbor', 'separableConv2d', 'spaceToBatchND', 'split', diff --git a/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts b/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts new file mode 100644 index 00000000000..7488bd50929 --- /dev/null +++ b/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {resizeNearestNeighbor} from '../../ops/resize_nearest_neighbor'; +import {Tensor, Tensor3D, Tensor4D} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + resizeNearestNeighbor( + newShape2D: [number, number], alignCorners: boolean): T; + } +} + +Tensor.prototype.resizeNearestNeighbor = function( + this: T, newShape2D: [number, number], alignCorners: boolean): T { + this.throwIfDisposed(); + return resizeNearestNeighbor(this, newShape2D, alignCorners); +}; diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index e9367fda611..05d4d06df82 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -45,6 +45,7 @@ import {oneHotGradConfig} from './gradients/OneHot_grad'; import {padV2GradConfig} from './gradients/PadV2_grad'; import {powGradConfig} from './gradients/Pow_grad'; import {reluGradConfig} from './gradients/Relu_grad'; +import {resizeNearestNeighborGradConfig} from './gradients/ResizeNearestNeighbor_grad' import {spaceToBatchNDGradConfig} from './gradients/SpaceToBatchND_grad'; import {splitVGradConfig} from './gradients/SplitV_grad'; import {squareGradConfig} from './gradients/Square_grad'; @@ -93,6 +94,7 @@ const gradConfigs: GradConfig[] = [ padV2GradConfig, powGradConfig, reluGradConfig, + resizeNearestNeighborGradConfig, spaceToBatchNDGradConfig, splitVGradConfig, squareGradConfig, diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index 64da828b9fb..f63fed59486 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -268,8 +268,6 @@ export interface OpHandler { image: { resizeBilinear( images: T, size: [number, number], alignCorners: boolean): T; - resizeNearestNeighbor( - images: T, size: [number, number], alignCorners: boolean): T; }; unsortedSegmentSum( x: T, segmentIds: Tensor1D|TensorLike1D, numSegments: number): T; @@ -1066,13 +1064,6 @@ export class Tensor { return opHandler.image.resizeBilinear(this, newShape2D, alignCorners); } - resizeNearestNeighbor( - this: T, newShape2D: [number, number], alignCorners = false): T { - (this as Tensor).throwIfDisposed(); - return opHandler.image.resizeNearestNeighbor( - this, newShape2D, alignCorners); - } - // Pooling. variable(trainable = true, name?: string, dtype?: DataType): Variable { this.throwIfDisposed(); From 11d2dc05cec6c1bd8b803f950e7e369688770d93 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 9 Jun 2020 22:48:45 -0400 Subject: [PATCH 2/9] typo --- tfjs-core/src/kernel_names.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 85a39ff20d8..fc99f717143 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -377,7 +377,7 @@ export interface ResizeNearestNeighborAttrs { alignCorners: boolean; } -export const ResizeNearestNeighborGrad = 'ResizeNearestNeighbor'; +export const ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad'; export type ResizeNearestNeighborGradInputs = Pick; From 4eda313d7a9e4bab6e0049845cc30e6a09ec9ae2 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 9 Jun 2020 23:12:38 -0400 Subject: [PATCH 3/9] modularize resizeBilinear --- .../src/gradients/ResizeBilinear_grad.ts | 43 ++++++++++ tfjs-core/src/kernel_names.ts | 9 ++ tfjs-core/src/ops/image_ops.ts | 63 +------------- tfjs-core/src/ops/ops.ts | 3 +- tfjs-core/src/ops/resize_bilinear.ts | 83 +++++++++++++++++++ tfjs-core/src/ops/resize_bilinear_test.ts | 11 +-- tfjs-core/src/ops/resize_nearest_neighbor.ts | 2 - .../chained_ops/register_all_chained_ops.ts | 1 + .../register_all_chained_ops_test.ts | 1 + .../src/public/chained_ops/resize_bilinear.ts | 32 +++++++ tfjs-core/src/register_all_gradients.ts | 4 +- tfjs-core/src/tensor.ts | 12 --- 12 files changed, 181 insertions(+), 83 deletions(-) create mode 100644 tfjs-core/src/gradients/ResizeBilinear_grad.ts create mode 100644 tfjs-core/src/ops/resize_bilinear.ts create mode 100644 tfjs-core/src/public/chained_ops/resize_bilinear.ts diff --git a/tfjs-core/src/gradients/ResizeBilinear_grad.ts b/tfjs-core/src/gradients/ResizeBilinear_grad.ts new file mode 100644 index 00000000000..6fe41ffb11e --- /dev/null +++ b/tfjs-core/src/gradients/ResizeBilinear_grad.ts @@ -0,0 +1,43 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ENGINE, ForwardFunc} from '../engine'; +import {ResizeBilinear, ResizeBilinearAttrs, ResizeBilinearGrad, ResizeBilinearGradInputs} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor, Tensor4D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; + +export const resizeBilinearGradConfig: GradConfig = { + kernelName: ResizeBilinear, + inputsToSave: ['images'], + gradFunc: (dy: Tensor4D, saved: Tensor[], attrs: NamedAttrMap) => { + const [images] = saved; + + const backPropKernelFunc: ForwardFunc = (backend) => { + const {alignCorners} = attrs as {} as ResizeBilinearAttrs; + return backend.resizeBilinearBackprop( + dy, images as Tensor4D, alignCorners); + }; + + const inputs: ResizeBilinearGradInputs = {images}; + const imagesDer = () => ENGINE.runKernelFunc( + backPropKernelFunc, inputs as {} as NamedTensorMap, null /* gradient */, + ResizeBilinearGrad, attrs); + + return {images: imagesDer}; + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index fc99f717143..2075d3cfbef 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -381,6 +381,15 @@ export const ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad'; export type ResizeNearestNeighborGradInputs = Pick; +export const ResizeBilinear = 'ResizeBilinear'; +export type ResizeBilinearInputs = Pick; +export interface ResizeBilinearAttrs { + alignCorners: boolean; +} + +export const ResizeBilinearGrad = 'ResizeBilinearGrad'; +export type ResizeBilinearGradInputs = Pick; + export const SelectV2 = 'SelectV2'; export type SelectV2Inputs = Pick; diff --git a/tfjs-core/src/ops/image_ops.ts b/tfjs-core/src/ops/image_ops.ts index cdd0f3c653b..f22220898a0 100644 --- a/tfjs-core/src/ops/image_ops.ts +++ b/tfjs-core/src/ops/image_ops.ts @@ -17,7 +17,7 @@ import {nonMaxSuppressionV3, nonMaxSuppressionV5} from '../backends/non_max_suppression_impl'; import {ENGINE, ForwardFunc} from '../engine'; -import {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; +import {Tensor, Tensor1D, Tensor2D, Tensor4D} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; @@ -27,66 +27,6 @@ import {nonMaxSuppSanityCheck} from './nonmax_util'; import {op} from './operation'; export {nonMaxSuppression} from './non_max_suppression'; -/** - * Bilinear resize a batch of 3D images to a new shape. - * - * @param images The images, of rank 4 or rank 3, of shape - * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. - * @param size The new shape `[newHeight, newWidth]` to resize the - * images to. Each channel is resized individually. - * @param alignCorners Defaults to False. If true, rescale - * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4 - * corners of images and resized images. If false, rescale by - * `new_height / height`. Treat similarly the width dimension. - */ -/** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */ -function resizeBilinear_( - images: T|TensorLike, size: [number, number], alignCorners = false): T { - const $images = convertToTensor(images, 'images', 'resizeBilinear'); - util.assert( - $images.rank === 3 || $images.rank === 4, - () => `Error in resizeBilinear: x must be rank 3 or 4, but got ` + - `rank ${$images.rank}.`); - util.assert( - size.length === 2, - () => `Error in resizeBilinear: new shape must 2D, but got shape ` + - `${size}.`); - - let batchImages = $images as Tensor4D; - let reshapedTo4D = false; - if ($images.rank === 3) { - reshapedTo4D = true; - batchImages = - $images.as4D(1, $images.shape[0], $images.shape[1], $images.shape[2]); - } - - const [newHeight, newWidth] = size; - const forward: ForwardFunc = (backend, save) => { - save([batchImages]); - return backend.resizeBilinear( - batchImages, newHeight, newWidth, alignCorners); - }; - - const backward = (dy: Tensor4D, saved: Tensor[]) => { - return { - x: () => ENGINE.runKernelFunc( - backend => backend.resizeBilinearBackprop( - dy, saved[0] as Tensor4D, alignCorners), - {}) - }; - }; - - const res = ENGINE.runKernelFunc( - forward, {x: batchImages}, backward, 'ResizeBilinear', - {alignCorners, newHeight, newWidth}); - if (reshapedTo4D) { - return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; - } - return res as T; -} - - - /** * Performs non maximum suppression of bounding boxes based on * iou (intersection over union). @@ -291,7 +231,6 @@ function cropAndResize_( return res; } -export const resizeBilinear = op({resizeBilinear_}); export const nonMaxSuppressionAsync = nonMaxSuppressionAsync_; export const nonMaxSuppressionWithScore = op({nonMaxSuppressionWithScore_}); export const nonMaxSuppressionWithScoreAsync = nonMaxSuppressionWithScoreAsync_; diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 6586b2bc8a5..705d95d023e 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -127,7 +127,8 @@ import * as spectral from './spectral_ops'; import * as fused from './fused_ops'; import * as signal from './signal_ops'; -import {resizeBilinear, cropAndResize, nonMaxSuppression, nonMaxSuppressionAsync, nonMaxSuppressionWithScore, nonMaxSuppressionWithScoreAsync} from './image_ops'; +import {cropAndResize, nonMaxSuppression, nonMaxSuppressionAsync, nonMaxSuppressionWithScore, nonMaxSuppressionWithScoreAsync} from './image_ops'; +import {resizeBilinear} from './resize_bilinear'; import {resizeNearestNeighbor} from './resize_nearest_neighbor'; const image = { resizeNearestNeighbor, diff --git a/tfjs-core/src/ops/resize_bilinear.ts b/tfjs-core/src/ops/resize_bilinear.ts new file mode 100644 index 00000000000..d5553b8b6d1 --- /dev/null +++ b/tfjs-core/src/ops/resize_bilinear.ts @@ -0,0 +1,83 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE, ForwardFunc} from '../engine'; +import {ResizeBilinear, ResizeBilinearAttrs, ResizeBilinearInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor3D, Tensor4D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {op} from './operation'; + +/** + * Bilinear resize a batch of 3D images to a new shape. + * + * @param images The images, of rank 4 or rank 3, of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. + * @param size The new shape `[newHeight, newWidth]` to resize the + * images to. Each channel is resized individually. + * @param alignCorners Defaults to False. If true, rescale + * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4 + * corners of images and resized images. If false, rescale by + * `new_height / height`. Treat similarly the width dimension. + */ +/** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */ +function resizeBilinear_( + images: T|TensorLike, size: [number, number], alignCorners = false): T { + const $images = convertToTensor(images, 'images', 'resizeBilinear'); + const $size = convertToTensor(size, 'size', 'resizeBilinear'); + util.assert( + $images.rank === 3 || $images.rank === 4, + () => `Error in resizeBilinear: x must be rank 3 or 4, but got ` + + `rank ${$images.rank}.`); + util.assert( + size.length === 2, + () => `Error in resizeBilinear: new shape must 2D, but got shape ` + + `${size}.`); + + let batchImages = $images as Tensor4D; + let reshapedTo4D = false; + if ($images.rank === 3) { + reshapedTo4D = true; + batchImages = + $images.as4D(1, $images.shape[0], $images.shape[1], $images.shape[2]); + } + + const [newHeight, newWidth] = size; + const forward: ForwardFunc = (backend, save) => { + save([batchImages]); + return backend.resizeBilinear( + batchImages, newHeight, newWidth, alignCorners); + }; + + const inputs: ResizeBilinearInputs = {images: batchImages, size: $size}; + const attrs: ResizeBilinearAttrs = {alignCorners}; + + const res = ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* gradient */, + ResizeBilinear, attrs as {} as NamedAttrMap); + + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; + } + return res as T; +} + +export const resizeBilinear = op({resizeBilinear_}); diff --git a/tfjs-core/src/ops/resize_bilinear_test.ts b/tfjs-core/src/ops/resize_bilinear_test.ts index bd6f378d5a0..9d804b0aa7c 100644 --- a/tfjs-core/src/ops/resize_bilinear_test.ts +++ b/tfjs-core/src/ops/resize_bilinear_test.ts @@ -57,20 +57,21 @@ describeWithFlags('resizeBilinear', ALL_ENVS, () => { const output = input.resizeBilinear([4, 3], false); expectArraysClose(await output.data(), [ - 1.5632453, 2.13817763, 1.44398415, 1.07632685, 0.59306782, -0.36970866, - 1.59388208, 1.98745549, 1.2917161, 1.54812956, 1.30613375, 1.15276587, + 1.5632453, 2.13817763, 1.44398415, 1.07632685, 0.59306782, -0.36970866, + 1.59388208, 1.98745549, 1.2917161, 1.54812956, 1.30613375, 1.15276587, 1.62451875, 1.83673334, 1.13944793, 2.01993227, 2.01919961, 2.67524052, - 1.62451875, 1.83673334, 1.13944793, 2.01993227, 2.01919961, 2.67524052]); + 1.62451875, 1.83673334, 1.13944793, 2.01993227, 2.01919961, 2.67524052 + ]); }); it('works for ints', async () => { const input = tf.tensor3d([1, 2, 3, 4, 5], [1, 5, 1], 'int32'); - const output = input.resizeBilinear([1, 10]); + const output = input.resizeBilinear([1, 10], false); expect(output.shape).toEqual([1, 10, 1]); expect(output.dtype).toBe('float32'); expectArraysClose( - await output.data(), [1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5]); + await output.data(), [1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5]); }); it('matches tensorflow w/ random numbers alignCorners=false', async () => { diff --git a/tfjs-core/src/ops/resize_nearest_neighbor.ts b/tfjs-core/src/ops/resize_nearest_neighbor.ts index 0dcacf91818..4c0d0c372c6 100644 --- a/tfjs-core/src/ops/resize_nearest_neighbor.ts +++ b/tfjs-core/src/ops/resize_nearest_neighbor.ts @@ -26,8 +26,6 @@ import * as util from '../util'; import {op} from './operation'; - - /** * NearestNeighbor resize a batch of 3D images to a new shape. * diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index fc9863a453f..59d01d5b094 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -51,6 +51,7 @@ import './pad'; import './pool'; import './pow'; import './relu'; +import './resize_bilinear'; import './resize_nearest_neighbor'; import './separable_conv2d'; import './split'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index 2f613b47646..9e106a1400e 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -61,6 +61,7 @@ const CHAINED_OPS = [ 'pool', 'pow', 'relu', + 'resizeBilinear', 'resizeNearestNeighbor', 'separableConv2d', 'spaceToBatchND', diff --git a/tfjs-core/src/public/chained_ops/resize_bilinear.ts b/tfjs-core/src/public/chained_ops/resize_bilinear.ts new file mode 100644 index 00000000000..672be2b60d8 --- /dev/null +++ b/tfjs-core/src/public/chained_ops/resize_bilinear.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {resizeBilinear} from '../../ops/resize_bilinear'; +import {Tensor, Tensor3D, Tensor4D} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + resizeBilinear( + newShape2D: [number, number], alignCorners: boolean): T; + } +} + +Tensor.prototype.resizeBilinear = function( + this: T, newShape2D: [number, number], alignCorners: boolean): T { + this.throwIfDisposed(); + return resizeBilinear(this, newShape2D, alignCorners); +}; diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index 05d4d06df82..4005f2e0626 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -45,7 +45,8 @@ import {oneHotGradConfig} from './gradients/OneHot_grad'; import {padV2GradConfig} from './gradients/PadV2_grad'; import {powGradConfig} from './gradients/Pow_grad'; import {reluGradConfig} from './gradients/Relu_grad'; -import {resizeNearestNeighborGradConfig} from './gradients/ResizeNearestNeighbor_grad' +import {resizeBilinearGradConfig} from './gradients/ResizeBilinear_grad'; +import {resizeNearestNeighborGradConfig} from './gradients/ResizeNearestNeighbor_grad'; import {spaceToBatchNDGradConfig} from './gradients/SpaceToBatchND_grad'; import {splitVGradConfig} from './gradients/SplitV_grad'; import {squareGradConfig} from './gradients/Square_grad'; @@ -94,6 +95,7 @@ const gradConfigs: GradConfig[] = [ padV2GradConfig, powGradConfig, reluGradConfig, + resizeBilinearGradConfig, resizeNearestNeighborGradConfig, spaceToBatchNDGradConfig, splitVGradConfig, diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index f63fed59486..07059fb4af6 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -265,10 +265,6 @@ export interface OpHandler { prelu(x: T, alpha: T|TensorLike): T; softmax(logits: T, dim: number): T; logSoftmax(logits: T, axis: number): T; - image: { - resizeBilinear( - images: T, size: [number, number], alignCorners: boolean): T; - }; unsortedSegmentSum( x: T, segmentIds: Tensor1D|TensorLike1D, numSegments: number): T; topk(x: T, k: number, sorted: boolean): @@ -1056,14 +1052,6 @@ export class Tensor { this.throwIfDisposed(); return opHandler.logSoftmax(this, axis); } - - // Image ops. - resizeBilinear( - this: T, newShape2D: [number, number], alignCorners = false): T { - (this as Tensor).throwIfDisposed(); - return opHandler.image.resizeBilinear(this, newShape2D, alignCorners); - } - // Pooling. variable(trainable = true, name?: string, dtype?: DataType): Variable { this.throwIfDisposed(); From ea0e716870cd78ff2627a302135c8e6115ed3c09 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 10 Jun 2020 15:32:22 -0400 Subject: [PATCH 4/9] make size an attribute --- tfjs-core/src/kernel_names.ts | 7 ++++--- tfjs-core/src/ops/resize_bilinear.ts | 6 +++--- tfjs-core/src/ops/resize_nearest_neighbor.ts | 6 ++---- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 2075d3cfbef..1d245ffa085 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -371,10 +371,10 @@ export const Relu = 'Relu'; export type ReluInputs = Pick; export const ResizeNearestNeighbor = 'ResizeNearestNeighbor'; -export type ResizeNearestNeighborInputs = - Pick; +export type ResizeNearestNeighborInputs = Pick; export interface ResizeNearestNeighborAttrs { alignCorners: boolean; + size: [number, number]; } export const ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad'; @@ -382,9 +382,10 @@ export type ResizeNearestNeighborGradInputs = Pick; export const ResizeBilinear = 'ResizeBilinear'; -export type ResizeBilinearInputs = Pick; +export type ResizeBilinearInputs = Pick; export interface ResizeBilinearAttrs { alignCorners: boolean; + size: [number, number]; } export const ResizeBilinearGrad = 'ResizeBilinearGrad'; diff --git a/tfjs-core/src/ops/resize_bilinear.ts b/tfjs-core/src/ops/resize_bilinear.ts index d5553b8b6d1..cecb1e9c242 100644 --- a/tfjs-core/src/ops/resize_bilinear.ts +++ b/tfjs-core/src/ops/resize_bilinear.ts @@ -42,7 +42,7 @@ import {op} from './operation'; function resizeBilinear_( images: T|TensorLike, size: [number, number], alignCorners = false): T { const $images = convertToTensor(images, 'images', 'resizeBilinear'); - const $size = convertToTensor(size, 'size', 'resizeBilinear'); + util.assert( $images.rank === 3 || $images.rank === 4, () => `Error in resizeBilinear: x must be rank 3 or 4, but got ` + @@ -67,8 +67,8 @@ function resizeBilinear_( batchImages, newHeight, newWidth, alignCorners); }; - const inputs: ResizeBilinearInputs = {images: batchImages, size: $size}; - const attrs: ResizeBilinearAttrs = {alignCorners}; + const inputs: ResizeBilinearInputs = {images: batchImages}; + const attrs: ResizeBilinearAttrs = {alignCorners, size}; const res = ENGINE.runKernelFunc( forward, inputs as {} as NamedTensorMap, null /* gradient */, diff --git a/tfjs-core/src/ops/resize_nearest_neighbor.ts b/tfjs-core/src/ops/resize_nearest_neighbor.ts index 4c0d0c372c6..eb0ef5260b6 100644 --- a/tfjs-core/src/ops/resize_nearest_neighbor.ts +++ b/tfjs-core/src/ops/resize_nearest_neighbor.ts @@ -42,7 +42,6 @@ import {op} from './operation'; function resizeNearestNeighbor_( images: T|TensorLike, size: [number, number], alignCorners = false): T { const $images = convertToTensor(images, 'images', 'resizeNearestNeighbor'); - const $size = convertToTensor(size, 'size', 'resizeNearestNeighbor'); util.assert( $images.rank === 3 || $images.rank === 4, @@ -66,9 +65,8 @@ function resizeNearestNeighbor_( } const [newHeight, newWidth] = size; - const inputs: - ResizeNearestNeighborInputs = {images: batchImages, size: $size}; - const attrs: ResizeNearestNeighborAttrs = {alignCorners}; + const inputs: ResizeNearestNeighborInputs = {images: batchImages}; + const attrs: ResizeNearestNeighborAttrs = {alignCorners, size}; const forward: ForwardFunc = (backend, save) => { save([batchImages]); From ad360daefd85f213be8cc2087694b1fe68b567bd Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 10 Jun 2020 17:15:17 -0400 Subject: [PATCH 5/9] make wasm resizeBilinear use core interface --- .../src/kernels/ResizeBilinear.ts | 47 ++++++++----------- 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/ResizeBilinear.ts b/tfjs-backend-wasm/src/kernels/ResizeBilinear.ts index e2057874c36..13fa62e345a 100644 --- a/tfjs-backend-wasm/src/kernels/ResizeBilinear.ts +++ b/tfjs-backend-wasm/src/kernels/ResizeBilinear.ts @@ -15,22 +15,12 @@ * ============================================================================= */ -import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, ResizeBilinear, ResizeBilinearAttrs, ResizeBilinearInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; import {cast} from './Cast'; -interface ResizeBilinearInputs extends NamedTensorInfoMap { - x: TensorInfo; -} - -interface ResizeBilinearAttrs extends NamedAttrMap { - newWidth: number; - newHeight: number; - alignCorners: boolean; -} - let wasmResizeBilinear: ( xId: number, batch: number, oldHeight: number, oldWidth: number, numChannels: number, newHeight: number, newWidth: number, @@ -50,45 +40,46 @@ function setup(backend: BackendWasm): void { ]); } -function resizeBilinear(args: { - backend: BackendWasm, - inputs: ResizeBilinearInputs, - attrs: ResizeBilinearAttrs -}): TensorInfo { - const {backend, inputs, attrs} = args; - const {x} = inputs; - const {alignCorners, newHeight, newWidth} = attrs; +function resizeBilinear( + params: {backend: {}, inputs: NamedTensorInfoMap, attrs: NamedAttrMap}): + TensorInfo { + const {backend, inputs, attrs} = params; + const wasmBackend = backend as BackendWasm; + const {images} = inputs as {} as ResizeBilinearInputs; + const {alignCorners, size} = attrs as {} as ResizeBilinearAttrs; + const [newHeight, newWidth] = size; - const [batch, oldHeight, oldWidth, numChannels] = x.shape; + const [batch, oldHeight, oldWidth, numChannels] = images.shape; const outShape = [batch, newHeight, newWidth, numChannels]; - let xData = backend.dataIdMap.get(x.dataId); + let xData = wasmBackend.dataIdMap.get(images.dataId); let castedData; if (xData.dtype !== 'float32') { - castedData = cast({backend, inputs: {x}, attrs: {dtype: 'float32'}}); - xData = backend.dataIdMap.get(castedData.dataId); + castedData = cast( + {backend: wasmBackend, inputs: {x: images}, attrs: {dtype: 'float32'}}); + xData = wasmBackend.dataIdMap.get(castedData.dataId); } const xId = xData.id; - const out = backend.makeOutput(outShape, 'float32'); - if (util.sizeFromShape(x.shape) === 0) { + const out = wasmBackend.makeOutput(outShape, 'float32'); + if (util.sizeFromShape(images.shape) === 0) { return out; } - const outId = backend.dataIdMap.get(out.dataId).id; + const outId = wasmBackend.dataIdMap.get(out.dataId).id; wasmResizeBilinear( xId, batch, oldHeight, oldWidth, numChannels, newHeight, newWidth, alignCorners ? 1 : 0, outId); if (castedData != null) { - backend.disposeData(castedData.dataId); + wasmBackend.disposeData(castedData.dataId); } return out; } registerKernel({ - kernelName: 'ResizeBilinear', + kernelName: ResizeBilinear, backendName: 'wasm', setupFunc: setup, kernelFunc: resizeBilinear From dd61a3c8f37b9e3f15b826a1865c4e73c3b830c2 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 10 Jun 2020 17:43:02 -0400 Subject: [PATCH 6/9] fix default param --- tfjs-core/src/public/chained_ops/resize_bilinear.ts | 5 ++++- tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/public/chained_ops/resize_bilinear.ts b/tfjs-core/src/public/chained_ops/resize_bilinear.ts index 672be2b60d8..7b1b08adc39 100644 --- a/tfjs-core/src/public/chained_ops/resize_bilinear.ts +++ b/tfjs-core/src/public/chained_ops/resize_bilinear.ts @@ -26,7 +26,10 @@ declare module '../../tensor' { } Tensor.prototype.resizeBilinear = function( - this: T, newShape2D: [number, number], alignCorners: boolean): T { + this: T, newShape2D: [number, number], + //@ts-ignore even with the default assignment tsc thinks alignCorners has + // type 'any' + alignCorners = false): T { this.throwIfDisposed(); return resizeBilinear(this, newShape2D, alignCorners); }; diff --git a/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts b/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts index 7488bd50929..bd61d58fdf1 100644 --- a/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts +++ b/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts @@ -26,7 +26,10 @@ declare module '../../tensor' { } Tensor.prototype.resizeNearestNeighbor = function( - this: T, newShape2D: [number, number], alignCorners: boolean): T { + this: T, newShape2D: [number, number], + //@ts-ignore even with the default assignment tsc thinks alignCorners has + // type 'any' + alignCorners = false): T { this.throwIfDisposed(); return resizeNearestNeighbor(this, newShape2D, alignCorners); }; From ada6f1f97023f207875e90b17999df76fda95439 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 10 Jun 2020 18:15:20 -0400 Subject: [PATCH 7/9] save --- tfjs-core/src/public/chained_ops/resize_bilinear.ts | 2 +- tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/public/chained_ops/resize_bilinear.ts b/tfjs-core/src/public/chained_ops/resize_bilinear.ts index 7b1b08adc39..03f826ba01e 100644 --- a/tfjs-core/src/public/chained_ops/resize_bilinear.ts +++ b/tfjs-core/src/public/chained_ops/resize_bilinear.ts @@ -21,7 +21,7 @@ import {Rank} from '../../types'; declare module '../../tensor' { interface Tensor { resizeBilinear( - newShape2D: [number, number], alignCorners: boolean): T; + newShape2D: [number, number], alignCorners?: boolean): T; } } diff --git a/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts b/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts index bd61d58fdf1..f12f0176ad4 100644 --- a/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts +++ b/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts @@ -21,7 +21,7 @@ import {Rank} from '../../types'; declare module '../../tensor' { interface Tensor { resizeNearestNeighbor( - newShape2D: [number, number], alignCorners: boolean): T; + newShape2D: [number, number], alignCorners?: boolean): T; } } From a9ae3a255dde675d285be9448bda5d74a03d8060 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 10 Jun 2020 18:16:17 -0400 Subject: [PATCH 8/9] save --- tfjs-core/src/public/chained_ops/resize_bilinear.ts | 5 +---- tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/tfjs-core/src/public/chained_ops/resize_bilinear.ts b/tfjs-core/src/public/chained_ops/resize_bilinear.ts index 03f826ba01e..e70c15f8a61 100644 --- a/tfjs-core/src/public/chained_ops/resize_bilinear.ts +++ b/tfjs-core/src/public/chained_ops/resize_bilinear.ts @@ -26,10 +26,7 @@ declare module '../../tensor' { } Tensor.prototype.resizeBilinear = function( - this: T, newShape2D: [number, number], - //@ts-ignore even with the default assignment tsc thinks alignCorners has - // type 'any' - alignCorners = false): T { + this: T, newShape2D: [number, number], alignCorners?: boolean): T { this.throwIfDisposed(); return resizeBilinear(this, newShape2D, alignCorners); }; diff --git a/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts b/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts index f12f0176ad4..77a6fc791f9 100644 --- a/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts +++ b/tfjs-core/src/public/chained_ops/resize_nearest_neighbor.ts @@ -26,10 +26,7 @@ declare module '../../tensor' { } Tensor.prototype.resizeNearestNeighbor = function( - this: T, newShape2D: [number, number], - //@ts-ignore even with the default assignment tsc thinks alignCorners has - // type 'any' - alignCorners = false): T { + this: T, newShape2D: [number, number], alignCorners?: boolean): T { this.throwIfDisposed(); return resizeNearestNeighbor(this, newShape2D, alignCorners); }; From 4c0f20268e563d4ad03f5d6777888f226b977458 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Fri, 12 Jun 2020 11:51:37 -0400 Subject: [PATCH 9/9] simplify --- .../src/kernels/ResizeBilinear.ts | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/ResizeBilinear.ts b/tfjs-backend-wasm/src/kernels/ResizeBilinear.ts index 13fa62e345a..99ed1d608b5 100644 --- a/tfjs-backend-wasm/src/kernels/ResizeBilinear.ts +++ b/tfjs-backend-wasm/src/kernels/ResizeBilinear.ts @@ -40,39 +40,41 @@ function setup(backend: BackendWasm): void { ]); } -function resizeBilinear( - params: {backend: {}, inputs: NamedTensorInfoMap, attrs: NamedAttrMap}): - TensorInfo { - const {backend, inputs, attrs} = params; - const wasmBackend = backend as BackendWasm; - const {images} = inputs as {} as ResizeBilinearInputs; +function resizeBilinear(args: { + backend: BackendWasm, + inputs: NamedTensorInfoMap, + attrs: NamedAttrMap +}): TensorInfo { + const {backend, inputs, attrs} = args; + + const {images} = inputs as ResizeBilinearInputs; const {alignCorners, size} = attrs as {} as ResizeBilinearAttrs; const [newHeight, newWidth] = size; const [batch, oldHeight, oldWidth, numChannels] = images.shape; const outShape = [batch, newHeight, newWidth, numChannels]; - let xData = wasmBackend.dataIdMap.get(images.dataId); + let xData = backend.dataIdMap.get(images.dataId); let castedData; if (xData.dtype !== 'float32') { - castedData = cast( - {backend: wasmBackend, inputs: {x: images}, attrs: {dtype: 'float32'}}); - xData = wasmBackend.dataIdMap.get(castedData.dataId); + castedData = + cast({backend, inputs: {x: images}, attrs: {dtype: 'float32'}}); + xData = backend.dataIdMap.get(castedData.dataId); } const xId = xData.id; - const out = wasmBackend.makeOutput(outShape, 'float32'); + const out = backend.makeOutput(outShape, 'float32'); if (util.sizeFromShape(images.shape) === 0) { return out; } - const outId = wasmBackend.dataIdMap.get(out.dataId).id; + const outId = backend.dataIdMap.get(out.dataId).id; wasmResizeBilinear( xId, batch, oldHeight, oldWidth, numChannels, newHeight, newWidth, alignCorners ? 1 : 0, outId); if (castedData != null) { - wasmBackend.disposeData(castedData.dataId); + backend.disposeData(castedData.dataId); } return out;