Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 13 additions & 24 deletions tfjs-backend-wasm/src/kernels/CropAndResize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,11 @@
* =============================================================================
*/

import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core';
import {CropAndResize, CropAndResizeAttrs, CropAndResizeInputs, NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';
import {cast} from './Cast';

interface CropAndResizeInputs extends NamedTensorInfoMap {
images: TensorInfo;
boxes: TensorInfo;
boxInd: TensorInfo;
}

interface CropAndResizeAttrs extends NamedAttrMap {
method: keyof InterpolationMethod;
extrapolationValue: number;
cropSize: [number, number];
}
import {cast} from './Cast';

// Must match enum in CropAndResize.cc
enum InterpolationMethod {
Expand Down Expand Up @@ -60,23 +49,23 @@ function setup(backend: BackendWasm): void {

function cropAndResize(args: {
backend: BackendWasm,
inputs: CropAndResizeInputs,
attrs: CropAndResizeAttrs
inputs: NamedTensorInfoMap,
attrs: NamedAttrMap
}): TensorInfo {
const {backend, inputs, attrs} = args;
const {method, extrapolationValue, cropSize} = attrs;
const {images, boxes, boxInd} = inputs;
const {method, extrapolationValue, cropSize} =
attrs as {} as CropAndResizeAttrs;
const {image, boxes, boxInd} = inputs as CropAndResizeInputs;

const numBoxes = boxes.shape[0];

const [cropHeight, cropWidth] = cropSize as [number, number];
const outShape = [numBoxes, cropHeight, cropWidth, images.shape[3]];
const outShape = [numBoxes, cropHeight, cropWidth, image.shape[3]];

let imagesData = backend.dataIdMap.get(images.dataId);
let imagesData = backend.dataIdMap.get(image.dataId);
let castedData;
if (images.dtype !== 'float32') {
castedData =
cast({backend, inputs: {x: images}, attrs: {dtype: 'float32'}});
if (image.dtype !== 'float32') {
castedData = cast({backend, inputs: {x: image}, attrs: {dtype: 'float32'}});
imagesData = backend.dataIdMap.get(castedData.dataId);
}

Expand All @@ -87,7 +76,7 @@ function cropAndResize(args: {
const out = backend.makeOutput(outShape, 'float32');
const outId = backend.dataIdMap.get(out.dataId).id;

const imagesShapeBytes = new Uint8Array(new Int32Array(images.shape).buffer);
const imagesShapeBytes = new Uint8Array(new Int32Array(image.shape).buffer);

wasmCropAndResize(
imagesId, boxesId, boxIndId, numBoxes, imagesShapeBytes, cropHeight,
Expand All @@ -103,7 +92,7 @@ function cropAndResize(args: {
}

registerKernel({
kernelName: 'CropAndResize',
kernelName: CropAndResize,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: cropAndResize
Expand Down
9 changes: 9 additions & 0 deletions tfjs-core/src/kernel_names.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,15 @@ export interface CumsumAttrs {
reverse: boolean;
}

export const CropAndResize = 'CropAndResize';
export type CropAndResizeInputs =
Pick<NamedTensorInfoMap, 'image'|'boxes'|'boxInd'>;
export interface CropAndResizeAttrs {
cropSize: [number, number];
method: 'bilinear'|'nearest';
extrapolationValue: number;
}

export const DepthToSpace = 'DepthToSpace';
export type DepthToSpaceInputs = Pick<NamedTensorInfoMap, 'x'>;
export interface DepthToSpaceAttrs {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand All @@ -16,7 +16,10 @@
*/

import {ENGINE, ForwardFunc} from '../engine';
import {CropAndResize, CropAndResizeAttrs, CropAndResizeInputs} from '../kernel_names';
import {NamedAttrMap} from '../kernel_registry';
import {Tensor1D, Tensor2D, Tensor4D} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';
Expand Down Expand Up @@ -84,13 +87,15 @@ function cropAndResize_(
method === 'bilinear' || method === 'nearest',
() => `method must be bilinear or nearest, but was ${method}`);

const forward: ForwardFunc<Tensor4D> = (backend, save) =>
backend.cropAndResize(
$image, $boxes, $boxInd, cropSize, method, extrapolationValue);
const forward: ForwardFunc<Tensor4D> = (backend) => backend.cropAndResize(
$image, $boxes, $boxInd, cropSize, method, extrapolationValue);

const inputs:
CropAndResizeInputs = {image: $image, boxes: $boxes, boxInd: $boxInd};
const attrs: CropAndResizeAttrs = {method, extrapolationValue, cropSize};
const res = ENGINE.runKernelFunc(
forward, {images: $image, boxes: $boxes, boxInd: $boxInd}, null /* der */,
'CropAndResize', {method, extrapolationValue, cropSize});
forward, inputs as {} as NamedTensorMap, null /* grad */, CropAndResize,
attrs as {} as NamedAttrMap);
return res;
}

Expand Down
Loading