added support for sigmoid activation function for fused ops (#4948)
FEATURE
* added support for sigmoid activation function for fused ops

* added sigmoid support in wasm and allow matmul to use sigmoid activation

* fix bazel format

* addressed comments
pyu10055 committed Apr 19, 2021
1 parent ac9d67d commit 8f96f93
Showing 19 changed files with 249 additions and 62 deletions.
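
For reference, a minimal TypeScript sketch of what this change enables at the API level: fusing a sigmoid activation into a single matmul call instead of running it as a separate op over the intermediate product. The tensor values are illustrative; the `activation: 'sigmoid'` option is what this commit adds.

import * as tf from '@tensorflow/tfjs';

// Fused matmul + sigmoid: on backends that support the fusion (CPU, WASM,
// WebGL, WebGPU after this change), the activation runs inside the matmul
// kernel rather than materializing the intermediate matmul result.
const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const b = tf.tensor2d([0.5, -0.5, 0.25, -0.25], [2, 2]);
const y = tf.fused.matMul({a, b, activation: 'sigmoid'});
y.print();  // every element lies in (0, 1)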
3 changes: 3 additions & 0 deletions tfjs-backend-cpu/src/utils/fused_utils.ts
@@ -24,6 +24,7 @@ import {leakyRelu} from '../kernels/LeakyRelu';
import {prelu} from '../kernels/Prelu';
import {relu} from '../kernels/Relu';
import {relu6} from '../kernels/Relu6';
import {sigmoid} from '../kernels/Sigmoid';

export function applyActivation(
backend: MathBackendCPU, x: TensorInfo, activation: backend_util.Activation,
@@ -40,6 +41,8 @@ export function applyActivation(
return prelu({inputs: {x, alpha: preluActivationWeights}, backend});
} else if (activation === 'leakyrelu') {
return leakyRelu({inputs: {x}, backend, attrs: {alpha: leakyreluAlpha}});
} else if (activation === 'sigmoid') {
return sigmoid({inputs: {x}, backend}) as TensorInfo;
}
throw new Error(
`Activation ${activation} has not been implemented for the CPU backend.`);
13 changes: 13 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD
@@ -124,6 +124,7 @@ tfjs_cc_library(
":backend",
":leakyrelu_impl",
":prelu_impl",
":sigmoid_impl",
":transpose_impl",
":util",
],
@@ -137,6 +138,7 @@ tfjs_cc_library(
":backend",
":leakyrelu_impl",
":prelu_impl",
":sigmoid_impl",
":transpose_impl",
":util",
],
@@ -181,6 +183,16 @@
],
)

tfjs_cc_library(
name = "sigmoid_impl",
srcs = ["sigmoid_impl.cc"],
hdrs = ["sigmoid_impl.h"],
deps = [
":backend",
":util",
],
)

tfjs_cc_library(
name = "transpose_impl",
srcs = ["transpose_impl.cc"],
@@ -935,6 +947,7 @@ tfjs_cc_library(
hdrs = ["kernels/Sigmoid.h"],
deps = [
":backend",
":sigmoid_impl",
":unary",
],
)
3 changes: 2 additions & 1 deletion tfjs-backend-wasm/src/cc/backend.h
@@ -34,7 +34,8 @@ enum FusableActivation {
RELU = 1,
RELU6 = 2,
PRELU = 3,
LEAKYRELU = 4
LEAKYRELU = 4,
SIGMOID = 5
};

// Holds the memory offset and the size of a tensor.
6 changes: 4 additions & 2 deletions tfjs-backend-wasm/src/cc/batch_mat_mul_impl.cc
@@ -27,6 +27,7 @@
#include "tfjs-backend-wasm/src/cc/backend.h"
#include "tfjs-backend-wasm/src/cc/leakyrelu_impl.h"
#include "tfjs-backend-wasm/src/cc/prelu_impl.h"
#include "tfjs-backend-wasm/src/cc/sigmoid_impl.h"
#include "tfjs-backend-wasm/src/cc/util.h"

#include "tfjs-backend-wasm/src/cc/batch_mat_mul_impl.h"
@@ -304,9 +305,10 @@ void fused_batch_mat_mul(const size_t a_id, const size_t* a_shape_ptr,
float* out_buf = out_info.f32_write();
if (activation == FusableActivation::PRELU) {
prelu(out_buf, out_info.size, prelu_weights_id, out_id);
}
if (activation == FusableActivation::LEAKYRELU) {
} else if (activation == FusableActivation::LEAKYRELU) {
leakyrelu(out_buf, out_info.size, leakyrelu_alpha, out_id);
} else if (activation == FusableActivation::SIGMOID) {
sigmoid(out_buf, out_info.size, out_id);
}
}
} // namespace wasm
4 changes: 4 additions & 0 deletions tfjs-backend-wasm/src/cc/conv2d_impl.cc
@@ -32,6 +32,7 @@
#include "tfjs-backend-wasm/src/cc/backend.h"
#include "tfjs-backend-wasm/src/cc/leakyrelu_impl.h"
#include "tfjs-backend-wasm/src/cc/prelu_impl.h"
#include "tfjs-backend-wasm/src/cc/sigmoid_impl.h"
#include "tfjs-backend-wasm/src/cc/transpose_impl.h"
#include "tfjs-backend-wasm/src/cc/util.h"

@@ -277,6 +278,9 @@ void conv2d(const size_t x_id, const size_t batch_size,
if (activation == FusableActivation::LEAKYRELU) {
leakyrelu(out_buf, out_info.size, leakyrelu_alpha, out_id);
}
if (activation == FusableActivation::SIGMOID) {
sigmoid(out_buf, out_info.size, out_id);
}
}

} // namespace wasm
59 changes: 3 additions & 56 deletions tfjs-backend-wasm/src/cc/kernels/Sigmoid.cc
@@ -22,22 +22,11 @@
#include <map>
#include <tuple>

#include "tfjs-backend-wasm/src/cc/kernels/Sigmoid.h"

#include "tfjs-backend-wasm/src/cc/backend.h"
#include "tfjs-backend-wasm/src/cc/kernels/Sigmoid.h"
#include "tfjs-backend-wasm/src/cc/sigmoid_impl.h"
#include "tfjs-backend-wasm/src/cc/util.h"

namespace {
// We use std::tuple as the cache key as it implements the compare operator
// needed for std::map.
typedef std::tuple<size_t> OperatorCacheKey;

// The operator cache maps the weights id to the xnn_operator_t instantiated for
// this set of weights.
std::map<OperatorCacheKey, xnn_operator_t> operator_cache;

} // namespace

namespace tfjs {
namespace wasm {
// We use C-style API to interface with Javascript.
@@ -48,51 +37,9 @@ EMSCRIPTEN_KEEPALIVE
#endif
void Sigmoid(const size_t x_id, const size_t out_id) {
auto& x_info = backend::get_tensor_info(x_id);
auto& out_info = backend::get_tensor_info_out(out_id);

const float* x_buf = x_info.f32();
float* out_buf = out_info.f32_write();

xnn_operator_t sigmoid_op = nullptr;

const size_t channels = 1;
OperatorCacheKey cache_key = {channels};

auto operator_cache_idx = operator_cache.find(cache_key);
if (operator_cache_idx == operator_cache.end()) {
const size_t input_stride = channels;
const size_t output_stride = channels;
const uint32_t flags = 0;

xnn_status status = xnn_create_sigmoid_nc_f32(
channels, input_stride, output_stride, flags, &sigmoid_op);
if (status != xnn_status_success) {
tfjs::util::warn(
"XNN status for xnn_create_sigmoid_nc_f32 is not "
"successful. Got status %d. Use -c dbg to see XNN logs.",
status);
return;
}

operator_cache.insert({cache_key, sigmoid_op});

tfjs::backend::xnn_operator_count++;
} else {
sigmoid_op = operator_cache_idx->second;
}

const size_t batch = x_info.size;
xnn_status status = xnn_setup_sigmoid_nc_f32(
sigmoid_op, batch, x_buf, out_buf, tfjs::backend::threadpool);
if (status != xnn_status_success) {
tfjs::util::warn(
"XNN status for xnn_setup_resize_bilinear2d_nhwc_f32 is not "
"successful. Got status %d. Use -c dbg to see XNN logs.",
status);
return;
}

xnn_run_operator(sigmoid_op, tfjs::backend::threadpool);
tfjs::wasm::sigmoid(x_buf, x_info.size, out_id);
}

} // extern "C"
91 changes: 91 additions & 0 deletions tfjs-backend-wasm/src/cc/sigmoid_impl.cc
@@ -0,0 +1,91 @@
/* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include "tfjs-backend-wasm/src/cc/sigmoid_impl.h"

#include <xnnpack.h>
#include <cmath>
#include <cstddef>
#include <map>
#include <tuple>

#include "tfjs-backend-wasm/src/cc/backend.h"
#include "tfjs-backend-wasm/src/cc/util.h"

namespace {
// We use std::tuple as the cache key as it implements the compare operator
// needed for std::map.
typedef std::tuple<size_t> OperatorCacheKey;

// The operator cache maps the weights id to the xnn_operator_t instantiated for
// this set of weights.
std::map<OperatorCacheKey, xnn_operator_t> operator_cache;
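// Descriptive note: `channels` below is a compile-time constant (1), so this
// cache key is degenerate — at most one sigmoid operator is ever created,
// then reused for every call and tensor size (batch is set per call).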

} // namespace

namespace tfjs {
namespace wasm {

void sigmoid(const float* x_buf, const size_t x_size, const size_t out_id) {
auto& out_info = backend::get_tensor_info_out(out_id);
float* out_buf = out_info.f32_write();

xnn_operator_t sigmoid_op = nullptr;

const size_t channels = 1;
OperatorCacheKey cache_key = {channels};

auto operator_cache_idx = operator_cache.find(cache_key);
if (operator_cache_idx == operator_cache.end()) {
const size_t input_stride = channels;
const size_t output_stride = channels;
const uint32_t flags = 0;

xnn_status status = xnn_create_sigmoid_nc_f32(
channels, input_stride, output_stride, flags, &sigmoid_op);
if (status != xnn_status_success) {
tfjs::util::warn(
"XNN status for xnn_create_sigmoid_nc_f32 is not "
"successful. Got status %d. Use -c dbg to see XNN logs.",
status);
return;
}

operator_cache.insert({cache_key, sigmoid_op});

tfjs::backend::xnn_operator_count++;
} else {
sigmoid_op = operator_cache_idx->second;
}

const size_t batch = x_size;
xnn_status status = xnn_setup_sigmoid_nc_f32(
sigmoid_op, batch, x_buf, out_buf, tfjs::backend::threadpool);
if (status != xnn_status_success) {
tfjs::util::warn(
"XNN status for xnn_setup_sigmoid_nc_f32 is not "
"successful. Got status %d. Use -c dbg to see XNN logs.",
status);
return;
}

xnn_run_operator(sigmoid_op, tfjs::backend::threadpool);
}

} // namespace wasm
} // namespace tfjs
28 changes: 28 additions & 0 deletions tfjs-backend-wasm/src/cc/sigmoid_impl.h
@@ -0,0 +1,28 @@
/* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifndef SIGMOID_IMPL_H_
#define SIGMOID_IMPL_H_

#include <cstddef>

namespace tfjs {
namespace wasm {

void sigmoid(const float* x_buf, const size_t x_size, const size_t out_id);

} // namespace wasm
} // namespace tfjs

#endif // SIGMOID_IMPL_H_
3 changes: 2 additions & 1 deletion tfjs-backend-wasm/src/kernels/types.ts
@@ -30,5 +30,6 @@ export enum FusableActivation {
relu = 1,
relu6 = 2,
prelu = 3,
leakyrelu = 4
leakyrelu = 4,
sigmoid = 5
}
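
These numeric values mirror the C++ FusableActivation enum in tfjs-backend-wasm/src/cc/backend.h above; both sides of the wasm boundary must agree that sigmoid is 5. A sketch of the kind of lookup a wasm kernel performs before calling into C++ (the helper name is hypothetical; TypeScript numeric enums support indexing by member name):

import {FusableActivation} from './types';

// Hypothetical helper: converts the activation string carried in fused-op
// attrs into the shared numeric enum value passed across the wasm boundary,
// e.g. activationToWasmEnum('sigmoid') === 5.
function activationToWasmEnum(activation: string): number {
  const value = FusableActivation[activation as keyof typeof FusableActivation];
  if (value == null) {
    throw new Error(`Activation ${activation} is not supported by the WASM backend.`);
  }
  return value;
}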
5 changes: 5 additions & 0 deletions tfjs-backend-webgl/src/kernel_utils/kernel_funcs_utils.ts
@@ -215,6 +215,11 @@ export function mapActivationToShaderProgram(
return LEAKYRELU_PACKED;
}
return LEAKYRELU;
} else if (activation === 'sigmoid') {
if (packed) {
return unary_packed_op.SIGMOID;
}
return unary_op.SIGMOID;
}
throw new Error(`Activation ${
activation} has not been implemented for the WebGL backend.`);
3 changes: 2 additions & 1 deletion tfjs-backend-webgl/src/unaryop_gpu.ts
@@ -52,7 +52,6 @@ export function STEP(alpha = 0.0) {
}

export const ELU = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;

export const RELU = CHECK_NAN_SNIPPET + `
return (x < 0.0) ? 0.0 : x;
`;
@@ -62,3 +61,5 @@ export const RELU6 = CHECK_NAN_SNIPPET + `
`;

export const CLONE = 'return x;';

export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * x));`;
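
For reference, this snippet (and its packed twin in unaryop_packed_gpu.ts below) implements the logistic sigmoid, sigmoid(x) = 1 / (1 + e^(-x)), which saturates smoothly to 0 and 1 at the extremes. Unlike RELU above, no CHECK_NAN_SNIPPET is applied here; NaN inputs generally propagate through exp on their own.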
2 changes: 2 additions & 0 deletions tfjs-backend-webgl/src/unaryop_packed_gpu.ts
@@ -54,6 +54,8 @@ export const RELU6 = `
return result;
`;

export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * x));`;

export class UnaryOpPackedProgram implements GPGPUProgram {
variableNames = ['A'];
userCode: string;
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
@@ -771,6 +771,8 @@ export class WebGPUBackend extends KernelBackend {
return unary_op.RELU6;
} else if (activation === 'prelu') {
return getBinaryOpString(BinaryOpType.PRELU, packed);
} else if (activation === 'sigmoid') {
return unary_op.SIGMOID;
}
throw new Error(`Activation ${
activation} has not been implemented for the WebGPU backend.`);
34 changes: 34 additions & 0 deletions tfjs-core/src/ops/fused/fused_conv2d_test.ts
@@ -756,6 +756,40 @@ describeWithFlags('fused conv2d', ALL_ENVS, () => {
expectArraysClose(await result.data(), expected);
});

it('basic with sigmoid', async () => {
const inputDepth = 2;
const inShape: [number, number, number, number] = [2, 2, 2, inputDepth];
const outputDepth = 2;
const fSize = 1;
const pad = 0;
const stride = 1;

const x = tf.tensor4d(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], inShape);
const alpha = 0.3;
const w = tf.tensor4d(
[-0.1, 0.1, -0.2, 0.05], [fSize, fSize, inputDepth, outputDepth]);

const result = tf.fused.conv2d({
x,
filter: w,
strides: stride,
pad,
dataFormat: 'NHWC',
dilations: [1, 1],
activation: 'sigmoid',
leakyreluAlpha: alpha
});
expect(result.shape).toEqual([2, 2, 2, 2]);
const expected = [
0.3775407, 0.549834, 0.24973989, 0.6224593, 0.15446526, 0.6899744,
0.09112296, 0.7502601, 0.0521535, 0.80218387, 0.02931219, 0.84553474,
0.0163025, 0.8807971, 0.0090133, 0.908877
];
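
// Spot check: the first output element comes from input pair (1, 2) and
// filter channel 0: 1 * (-0.1) + 2 * (-0.2) = -0.5, and
// sigmoid(-0.5) ≈ 0.3775407, matching expected[0] above.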

expectArraysClose(await result.data(), expected);
});

it('basic with broadcasted bias and relu', async () => {
const inputDepth = 2;
const inShape: [number, number, number, number] = [2, 2, 2, inputDepth];