1 change: 1 addition & 0 deletions tfjs-backend-wasm/package.json
@@ -12,6 +12,7 @@
"build": "rimraf dist/ && tsc && ./scripts/build-wasm.sh && cp wasm-out/*.wasm dist/",
"build-ci": "./scripts/build-ci.sh",
"build-npm": "./scripts/build-npm.sh",
"clean": "rimraf dist/ && bazel clean --expunge",
"cpplint": "./scripts/cpplint.js",
"lint": "tslint -p . -t verbose && yarn cpplint",
"test": "./scripts/build-wasm.sh && karma start",
19 changes: 19 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD
@@ -79,6 +79,7 @@ tfjs_cc_library(
":Abs",
":Add",
":BatchMatMul",
":MaxPool",
":ClipByValue",
":CropAndResize",
":Conv2D",
@@ -97,6 +98,16 @@
]
)

tfjs_cc_library(
name = "MaxPool",
srcs = ["kernels/MaxPool.cc"],
hdrs = ["kernels/MaxPool.h"],
deps = [
":backend",
":util"
]
)

tfjs_cc_library(
name = "FusedBatchNorm",
srcs = ["kernels/FusedBatchNorm.cc"],
@@ -327,6 +338,14 @@ tfjs_unit_test(
]
)

tfjs_unit_test(
name = "MaxPool_test",
srcs = ["kernels/MaxPool_test.cc"],
deps = [
":MaxPool"
]
)

tfjs_unit_test(
name = "ClipByValue_test",
srcs = ["kernels/ClipByValue_test.cc"],
109 changes: 109 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/MaxPool.cc
@@ -0,0 +1,109 @@
/* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <xnnpack.h>
#include <array>
#include <cmath>
#include <limits>
#include <map>
#include <unordered_map>

#include "src/cc/backend.h"
#include "src/cc/kernels/MaxPool.h"
#include "src/cc/util.h"

namespace {
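// Cache of XNNPACK max-pooling operators, keyed on the full pooling
// configuration (padding, filter size, strides, dilations, channels, flags),
// so repeated MaxPool calls with identical attributes reuse an existing
// operator instead of creating a new one.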
typedef std::array<int, 14> OperatorCacheKey;

std::map<OperatorCacheKey, xnn_operator_t> operator_cache;
} // namespace

namespace tfjs {
namespace wasm {
extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void MaxPool(const int x_id, const int batch_size, const int input_height,
const int input_width, const int filter_height,
const int filter_width, int pad_top, int pad_right, int pad_bottom,
int pad_left, const int dilation_height, const int dilation_width,
const int stride_height, const int stride_width,
const int input_channels, const int output_channels,
const int out_id) {
auto& x_info = backend::get_tensor_info(x_id);
auto& out_info = backend::get_tensor_info(out_id);

const float* x_buf = reinterpret_cast<float*>(x_info.memory_offset);
float* out_buf = reinterpret_cast<float*>(out_info.memory_offset);

xnn_operator_t max_pool_op = nullptr;

const int flags = 0;
const int channels = input_channels;

OperatorCacheKey cache_key = {pad_top, pad_right, pad_bottom,
pad_left, filter_height, filter_width,
stride_height, stride_width, dilation_height,
dilation_width, channels, input_channels,
output_channels, flags};

auto operator_cache_idx = operator_cache.find(cache_key);

if (operator_cache_idx == operator_cache.end()) {
float output_min = -std::numeric_limits<float>::infinity();
float output_max = std::numeric_limits<float>::infinity();

xnn_status status = xnn_create_max_pooling2d_nhwc_f32(
pad_top, pad_right, pad_bottom, pad_left, filter_height, filter_width,
stride_height, stride_width, dilation_height, dilation_width, channels,
input_channels /* input_pixel_stride */,
output_channels /* output_pixel_stride */, output_min, output_max,
flags, &max_pool_op);

if (status != xnn_status_success) {
      util::warn(
          "XNN status for xnn_create_max_pooling2d_nhwc_f32 is not "
          "successful. Got status %d. Use -c dbg to see XNN logs.",
          status);
}

operator_cache.emplace(cache_key, max_pool_op);

tfjs::backend::xnn_operator_count++;
} else {
max_pool_op = operator_cache_idx->second;
}

xnn_status status = xnn_setup_max_pooling2d_nhwc_f32(
max_pool_op, batch_size, input_height, input_width, x_buf, out_buf,
nullptr /* thread pool */);
if (status != xnn_status_success) {
util::warn(
"XNN status for xnn_setup_max_pooling2d_nhwc_f32 is not successful. "
"Got status %d. Use -c dbg to see XNN logs.",
status);
return;
}

xnn_run_operator(max_pool_op, nullptr /* thread pool */);
}
} // extern "C"
} // namespace wasm
} // namespace tfjs
34 changes: 34 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/MaxPool.h
@@ -0,0 +1,34 @@
/* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifndef KERNELS_MAXPOOL_H_
#define KERNELS_MAXPOOL_H_

namespace tfjs {

namespace wasm {
extern "C" {
void MaxPool(const int x_id, const int batch_size, const int input_height,
const int input_width, const int filter_height,
const int filter_width, int pad_top, int pad_right, int pad_bottom,
int pad_left, const int dilation_height, const int dilation_width,
const int stride_height, const int stride_width,
const int input_channels, const int output_channels,
const int out_id);
}

} // namespace wasm
} // namespace tfjs

#endif // KERNELS_MAXPOOL_H_
82 changes: 82 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/MaxPool_test.cc
@@ -0,0 +1,82 @@

/* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#include <gtest/gtest.h>

#include "src/cc/backend.h"
#include "src/cc/kernels/MaxPool.h"

TEST(MAXPOOL, xnn_operator_lifetime) {
tfjs::wasm::init();

ASSERT_EQ(0, tfjs::backend::num_tensors());

const int x0_id = 0;
const int x1_id = 1;
const int size = 9;
float x_values[size] = {1, 2, 3, 4, 5, 6, 7, 8, 9};

const int out_id = 2;
const int out_size = 9;
float out_values[out_size] = {};

tfjs::wasm::register_tensor(x0_id, size, x_values);
tfjs::wasm::register_tensor(x1_id, size, x_values);
tfjs::wasm::register_tensor(out_id, out_size, out_values);

ASSERT_EQ(3, tfjs::backend::num_tensors());
ASSERT_EQ(0, tfjs::backend::xnn_operator_count);

// One xnn_operator should be created for first call to maxPool.
const int batch_size = 1;
const int input_height = 3;
const int input_width = 3;
const int filter_height = 2;
const int filter_width = 2;
const int pad_top = 0;
const int pad_right = 1;
const int pad_bottom = 1;
const int pad_left = 0;
const int dilation_height = 1;
const int dilation_width = 1;
const int stride_height = 1;
const int stride_width = 1;
const int input_channels = 1;
const int output_channels = 1;
tfjs::wasm::MaxPool(
x0_id, batch_size, input_height, input_width, filter_height, filter_width,
pad_top, pad_right, pad_bottom, pad_left, dilation_height, dilation_width,
stride_height, stride_width, input_channels, output_channels, out_id);
ASSERT_EQ(1, tfjs::backend::xnn_operator_count);

// No new xnn_operators should be created for the second call to maxPool with
// the same arguments.
tfjs::wasm::MaxPool(
x0_id, batch_size, input_height, input_width, filter_height, filter_width,
pad_top, pad_right, pad_bottom, pad_left, dilation_height, dilation_width,
stride_height, stride_width, input_channels, output_channels, out_id);
ASSERT_EQ(1, tfjs::backend::xnn_operator_count);

// One new xnn_operator should be created for the next call to maxPool with
// 'valid' padding.
tfjs::wasm::MaxPool(x0_id, batch_size, input_height, input_width,
filter_height, filter_width, pad_top, 0 /* pad_right */,
0 /* pad_bottom */, pad_left, dilation_height,
dilation_width, stride_height, stride_width,
input_channels, output_channels, out_id);
ASSERT_EQ(2, tfjs::backend::xnn_operator_count);

tfjs::wasm::dispose();
}
101 changes: 101 additions & 0 deletions tfjs-backend-wasm/src/kernels/MaxPool.ts
@@ -0,0 +1,101 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {backend_util, KernelFunc, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

interface MaxPoolInputs extends NamedTensorInfoMap {
  x: TensorInfo;
}

let wasmMaxPool: (
xId: number, batchSize: number, inputHeight: number, inputWidth: number,
filterHeight: number, filterWidth: number, padTop: number, padRight: number,
padBottom: number, padLeft: number, dilationHeight: number,
dilationWidth: number, strideHeight: number, strideWidth: number,
inputChannels: number, outputChannels: number, outId: number) => void;

function setup(backend: BackendWasm) {
wasmMaxPool = backend.wasm.cwrap('MaxPool', null /* void */, [
'number', // xId
'number', // batchSize
'number', // inputHeight
'number', // inputWidth
'number', // filterHeight
'number', // filterWidth
'number', // padTop
'number', // padRight
'number', // padBottom
'number', // padLeft
'number', // dilationHeight
'number', // dilationWidth
'number', // strideHeight
'number', // strideWidth
'number', // inputChannels
'number', // outputChannels
'number', // outId
]);
}

function maxPool(args: {
inputs: MaxPoolInputs,
backend: BackendWasm,
attrs: backend_util.Conv2DInfo
}) {
const {inputs, attrs, backend} = args;
const convInfo = attrs;

const {x} = inputs;
const xId = backend.dataIdMap.get(x.dataId).id;

const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const padTop = convInfo.padInfo.top;
const padRight = convInfo.padInfo.right;
const padBottom = convInfo.padInfo.bottom;
const padLeft = convInfo.padInfo.left;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const inputChannels = convInfo.inChannels;
const outputChannels = convInfo.outChannels;

if (convInfo.dataFormat !== 'channelsLast') {
throw new Error(
`wasm backend does not support dataFormat:'` +
`${convInfo.dataFormat}'. Please use 'channelsLast'.`);
}

const out = backend.makeOutput(convInfo.outShape, 'float32');
const outId = backend.dataIdMap.get(out.dataId).id;

wasmMaxPool(
xId, x.shape[0], x.shape[1], x.shape[2], filterHeight, filterWidth,
padTop, padRight, padBottom, padLeft, dilationHeight, dilationWidth,
strideHeight, strideWidth, inputChannels, outputChannels, outId);
return out;
}

registerKernel({
kernelName: 'MaxPool',
backendName: 'wasm',
setupFunc: setup,
kernelFunc: maxPool as {} as KernelFunc
});
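
For context, a minimal usage sketch of the kernel registered above (not part of this change). It assumes an application that imports @tensorflow/tfjs together with this wasm backend package, and a tfjs version where tf.setBackend('wasm') and tf.ready() are available:

import * as tf from '@tensorflow/tfjs';
// Assumption: importing the wasm backend package registers the backend and
// its kernels as a side effect, including the 'MaxPool' kernel added here.
import '@tensorflow/tfjs-backend-wasm';

async function run() {
  await tf.setBackend('wasm');
  await tf.ready();

  // NHWC input: batch=1, 3x3 spatial, 1 channel (same values as the C++ test).
  const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 3, 3, 1]);

  // 2x2 window, stride 1, 'same' padding; this dispatches to the wasm
  // MaxPool kernel above, which builds and caches an XNNPACK operator.
  const y = tf.maxPool(x, 2, 1, 'same');
  y.print();  // values: 5, 6, 6, 8, 9, 9, 8, 9, 9
}

run();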
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/kernels/all_kernels.ts
@@ -31,6 +31,7 @@ import './Div';
import './FusedBatchNorm';
import './FusedConv2D';
import './Max';
import './MaxPool';
import './Min';
import './Mul';
import './PadV2';