Skip to content

Commit

Permalink
[wasm] Add Maximum
Browse files Browse the repository at this point in the history
  • Loading branch information
sapphi-red committed Dec 11, 2019
1 parent 189ee59 commit 625db8c
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 2 deletions.
11 changes: 11 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD
Expand Up @@ -129,6 +129,7 @@ tfjs_cc_library(
":DepthwiseConv2dNative",
":FloorDiv",
":Minimum",
":Maximum",
":FusedConv2D",
":FusedDepthwiseConv2D",
":Div",
Expand Down Expand Up @@ -300,6 +301,16 @@ tfjs_cc_library(
],
)

tfjs_cc_library(
name = "Maximum",
srcs = ["kernels/Maximum.cc"],
deps = [
":backend",
":binary",
":util",
],
)

tfjs_cc_library(
name = "FusedConv2D",
srcs = ["kernels/FusedConv2D.cc"],
Expand Down
53 changes: 53 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/Maximum.cc
@@ -0,0 +1,53 @@
/* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <algorithm>

#include "src/cc/binary.h"
#include "src/cc/util.h"

namespace tfjs {
namespace wasm {
// We use C-style API to interface with Javascript.
extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void Maximum(const int a_id, const size_t* a_shape_ptr, const int a_shape_len,
const int b_id, const size_t* b_shape_ptr, const int b_shape_len,
const DType dtype, const int out_id) {
switch (dtype) {
case DType::float32:
binary_f32(a_id, b_id, out_id,
[](float a, float b) { return std::max(a, b); });
break;
case DType::int32:
binary_i32(a_id, b_id, out_id,
[](int a, int b) { return std::max(a, b); });
break;
default:
util::warn(
"Maximum for tensor ids %d and %d failed. Unsupported dtype %d",
a_id, b_id, dtype);
}
}

} // extern "C"
} // namespace wasm
} // namespace tfjs
20 changes: 20 additions & 0 deletions tfjs-backend-wasm/src/kernels/Maximum.ts
@@ -0,0 +1,20 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import { registerBinaryKernel } from './binary_kernel';
const supportsBroadcast = false;
registerBinaryKernel('Maximum', supportsBroadcast);
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/kernels/all_kernels.ts
Expand Up @@ -37,6 +37,7 @@ import './FusedBatchNorm';
import './FusedConv2D';
import './FusedDepthwiseConv2D';
import './Max';
import './Maximum';
import './MaxPool';
import './Min';
import './Minimum';
Expand Down
8 changes: 8 additions & 0 deletions tfjs-backend-wasm/src/setup_test.ts
Expand Up @@ -203,6 +203,14 @@ const TEST_FILTERS: TestFilter[] = [
// dims not supported yet.
]
},
{
include: 'maximum',
excludes: [
'gradient', // Not yet implemented.
'broadcasts 2x1 Tensor2D and 2x2 Tensor2D' // Broadcasting along inner
// dims not supported yet.
]
},
];

const customInclude = (testName: string) => {
Expand Down
4 changes: 2 additions & 2 deletions tfjs-core/src/ops/binary_ops.ts
Expand Up @@ -738,13 +738,13 @@ function maximum_<T extends Tensor>(
const [$a, $b] = saved;
const derA = () => dy.mul($a.greaterEqual($b).toFloat());
const derB = () => dy.mul($a.less($b).toFloat());
return {$a: derA, $b: derB};
return {a: derA, b: derB};
};
return ENGINE.runKernelFunc((backend, save) => {
const res = backend.maximum($a, $b);
save([$a, $b]);
return res;
}, {$a, $b}, der) as T;
}, {a: $a, b: $b}, der, 'Maximum') as T;
}

/**
Expand Down

0 comments on commit 625db8c

Please sign in to comment.