Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
16d628b
setup
annxingyuan Dec 20, 2019
ab89a56
things compile
annxingyuan Dec 20, 2019
b9ee891
compiling
annxingyuan Dec 20, 2019
212d7f9
compiles
annxingyuan Dec 20, 2019
5049370
compiles
annxingyuan Dec 20, 2019
8c00717
basic
annxingyuan Dec 20, 2019
f11541c
almost
annxingyuan Dec 20, 2019
03c67b7
lint
annxingyuan Dec 20, 2019
93d27c7
Merge branch 'master' into wasm_scatter
annxingyuan Dec 20, 2019
926f275
testfix
annxingyuan Dec 20, 2019
b1729fc
initialization
annxingyuan Dec 20, 2019
fe1b327
pass
annxingyuan Dec 20, 2019
2013746
use pointer
annxingyuan Dec 20, 2019
a25a047
clean
annxingyuan Dec 20, 2019
b32ba77
clean
annxingyuan Dec 20, 2019
a637383
clean
annxingyuan Dec 20, 2019
98994cd
ws
annxingyuan Dec 20, 2019
a609fca
clean
annxingyuan Dec 20, 2019
2e4e565
clean
annxingyuan Dec 20, 2019
e2936e4
clean
annxingyuan Dec 20, 2019
df4edd0
compiles
annxingyuan Dec 22, 2019
d3eded1
compiles
annxingyuan Dec 22, 2019
0f596e8
tests pass
annxingyuan Dec 22, 2019
f1f7e6d
remove imports
annxingyuan Dec 22, 2019
280c1bd
lock
annxingyuan Dec 22, 2019
70f89e1
install
annxingyuan Dec 23, 2019
5c2bb2e
remove impl
annxingyuan Dec 23, 2019
653fc35
clean
annxingyuan Dec 23, 2019
84d52e2
use const
annxingyuan Dec 23, 2019
bfa29a7
fix
annxingyuan Dec 23, 2019
ca1befb
export
annxingyuan Dec 23, 2019
0334d89
rename
annxingyuan Dec 23, 2019
7ab11f5
Merge branch 'master' into wasm_scatter
annxingyuan Dec 23, 2019
2d7950e
revert yarn.lock
annxingyuan Dec 23, 2019
51a33c7
update
annxingyuan Dec 23, 2019
5a0d163
case
annxingyuan Dec 24, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tfjs-backend-wasm/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
"path": false
},
"peerDependencies": {
"@tensorflow/tfjs-core": "1.5.1"
"@tensorflow/tfjs-core": "link:../tfjs-core"
},
"dependencies": {
"@types/emscripten": "~0.0.34"
},
"devDependencies": {
"@bazel/bazel": "^0.28.0",
"@bazel/buildifier": "0.29.0",
"@tensorflow/tfjs-core": "1.5.1",
"@tensorflow/tfjs-core": "link:../tfjs-core",
"@types/jasmine": "~2.8.6",
"clang-format": "~1.2.4",
"jasmine": "~3.1.0",
Expand Down
10 changes: 10 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ tfjs_cc_library(
":Relu",
":Relu6",
":ResizeBilinear",
":ScatterNd",
":Sigmoid",
":Sub",
":Tile",
Expand Down Expand Up @@ -586,6 +587,15 @@ tfjs_unit_test(
],
)

tfjs_cc_library(
name = "ScatterNd",
srcs = ["kernels/ScatterNd.cc"],
deps = [
":backend",
":util",
],
)

tfjs_cc_library(
name = "Sigmoid",
srcs = ["kernels/Sigmoid.cc"],
Expand Down
99 changes: 99 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/ScatterNd.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <vector>

#include "src/cc/backend.h"
#include "src/cc/util.h"

namespace {
template <typename T>
void scatter(const int* indices_ptr, const T* updates_ptr,
const size_t slice_rank, const size_t num_updates,
const size_t slice_size, const std::vector<size_t>& strides_ptr,
const size_t output_size, const size_t dtype_size,
T* out_buf_ptr) {
// Initialize output to 0.
memset(out_buf_ptr, 0, output_size * dtype_size);

for (size_t i = 0; i < num_updates; ++i) {
size_t flattened_index = 0;
for (size_t j = 0; j < slice_rank; ++j) {
flattened_index += *indices_ptr * strides_ptr[j];

indices_ptr++;
}

out_buf_ptr += flattened_index * slice_size;

for (size_t k = 0; k < slice_size; ++k) {
*out_buf_ptr += *updates_ptr;

out_buf_ptr++;
updates_ptr++;
}

out_buf_ptr -= (flattened_index * slice_size + slice_size);
}
}

} // namespace

namespace tfjs {
namespace wasm {
extern "C" {
#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif

void ScatterNd(const size_t indices_id, const size_t updates_id,
const DType dtype, const size_t slice_rank,
const size_t num_updates, const size_t slice_size,
const size_t* strides_ptr, const size_t output_size,
const size_t out_id) {
auto& indices_info = backend::get_tensor_info(indices_id);
auto& updates_info = backend::get_tensor_info(updates_id);
const std::vector<size_t>& strides =
std::vector<size_t>(strides_ptr, strides_ptr + slice_rank);
const int* indices_buf = indices_info.i32();
auto& out_info = backend::get_tensor_info_out(out_id);

switch (dtype) {
case DType::float32:
scatter<float>(indices_buf, updates_info.f32(), slice_rank, num_updates,
slice_size, strides, output_size, sizeof(float),
out_info.f32_write());
break;
case DType::int32:
scatter<int32_t>(indices_buf, updates_info.i32(), slice_rank, num_updates,
slice_size, strides, output_size, sizeof(int32_t),
out_info.i32_write());
break;
case DType::boolean:
scatter<bool>(indices_buf, updates_info.b(), slice_rank, num_updates,
slice_size, strides, output_size, sizeof(bool),
out_info.b_write());
break;
default:
util::warn("Scatter for tensor id %d failed. Unknown dtype %d",
indices_id, dtype);
}
}
} // extern "C"
} // namespace wasm
} // namespace tfjs
19 changes: 1 addition & 18 deletions tfjs-backend-wasm/src/kernels/ClipByValue.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,6 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand Down
88 changes: 88 additions & 0 deletions tfjs-backend-wasm/src/kernels/ScatterNd.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {NamedAttrMap, NamedTensorInfoMap, registerKernel, scatter_util, TensorInfo, util} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';
import {CppDType} from './types';

interface ScatterNdInputs extends NamedTensorInfoMap {
indices: TensorInfo;
updates: TensorInfo;
}

interface ScatterNdAttrs extends NamedAttrMap {
shape: number[];
}

let wasmScatterNd: (
indicesId: number, updatesId: number, dtype: CppDType, sliceRank: number,
numUpdates: number, sliceSize: number, strides: Uint8Array,
outputSize: number, outId: number) => void;

function setup(backend: BackendWasm): void {
wasmScatterNd = backend.wasm.cwrap('ScatterNd', null /*void*/, [
'number', // indicesId
'number', // updatesId
'number', // dtype
'number', // sliceRank
'number', // numUpdates
'number', // sliceSize
'array', // strides
'number', // outputSize
'number' // outId
]);
}

function scatterNd(
args:
{backend: BackendWasm, inputs: ScatterNdInputs, attrs: ScatterNdAttrs}):
TensorInfo {
const {backend, inputs, attrs} = args;
const {indices, updates} = inputs;
const {shape} = attrs;

const out = backend.makeOutput(shape, updates.dtype);
if (util.sizeFromShape(shape) === 0) {
return out;
}

const {sliceRank, numUpdates, sliceSize, strides, outputSize} =
scatter_util.calculateShapes(updates, indices, shape);

const indicesData = backend.dataIdMap.get(indices.dataId);
const indicesId = indicesData.id;

const updatesData = backend.dataIdMap.get(updates.dataId);
const updatesId = updatesData.id;

const stridesBytes = new Uint8Array(new Int32Array(strides).buffer);

const outId = backend.dataIdMap.get(out.dataId).id;
wasmScatterNd(
indicesId, updatesId, CppDType[updates.dtype], sliceRank, numUpdates,
sliceSize, stridesBytes, outputSize, outId);

return out;
}

registerKernel({
kernelName: 'ScatterNd',
backendName: 'wasm',
setupFunc: setup,
kernelFunc: scatterNd
});
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/kernels/all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import './Relu6';
import './Reshape';
import './ResizeBilinear';
import './Rsqrt';
import './ScatterNd';
import './Sigmoid';
import './Sin';
import './Slice';
Expand Down
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ const TEST_FILTERS: TestFilter[] = [
'gradient' // Not yet implemented.
]
},
{include: 'scatterND '},
{
include: 'abs ',
excludes: [
Expand Down
14 changes: 3 additions & 11 deletions tfjs-backend-wasm/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,9 @@
resolved "https://registry.yarnpkg.com/@bazel/hide-bazel-files/-/hide-bazel-files-0.38.3.tgz#e98231d3d360d51860d9c1a7c3345b40dab4cf81"
integrity sha512-o+dNkfDm3qxWQ8h/04cWuTcjR7qnjZi3pQGv4aklVb16oPWx2jF8BzbkwvWuIkdbOl9VnqYP0vaHzwQVJRRcIA==

"@tensorflow/tfjs-core@1.5.1":
version "1.5.1"
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.5.1.tgz#490209617f744fef660e8f81fe8b858e95b0d10b"
integrity sha512-N4fsi8mLsRwRs8UJN2cARB1rYFxyVXkLyZ4wOusiR976BwwZbCwQrTTSIPzPqYT3rwiexEUzm7sM6ZaDl5dpXA==
dependencies:
"@types/offscreencanvas" "~2019.3.0"
"@types/seedrandom" "2.4.27"
"@types/webgl-ext" "0.0.30"
"@types/webgl2" "0.0.4"
node-fetch "~2.1.2"
seedrandom "2.4.3"
"@tensorflow/tfjs-core@link:../tfjs-core":
version "0.0.0"
uid ""

"@types/emscripten@~0.0.34":
version "0.0.34"
Expand Down
4 changes: 3 additions & 1 deletion tfjs-core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import * as backend_util from './backends/backend_util';
import * as io from './io/io';
import * as math from './math';
import * as browser from './ops/browser';
import * as scatter_util from './ops/scatter_nd_util';
import * as slice_util from './ops/slice_util';
import * as serialization from './serialization';
import {setOpHandler} from './tensor';
Expand Down Expand Up @@ -99,7 +100,8 @@ export {
backend_util,
webgl,
tensor_util,
slice_util
slice_util,
scatter_util
};

// Backend specific.
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/ops/scatter_nd.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ function scatterND_<R extends Rank>(

return ENGINE.runKernelFunc(
backend => backend.scatterND($indices, $updates, shape),
{$indices, $updates});
{indices: $indices, updates: $updates}, null /* backward */, 'ScatterNd',
{shape});
}

export const scatterND = op({scatterND_});
9 changes: 6 additions & 3 deletions tfjs-core/src/ops/scatter_nd_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
* =============================================================================
*/
import {TensorInfo} from '../kernel_registry';
import {Tensor} from '../tensor';
import {computeStrides, sizeFromShape} from '../util';

Expand Down Expand Up @@ -123,9 +124,11 @@ export function validateInput(
* @returns ScatterShapeInfo
*/
export function calculateShapes(
updates: Tensor, indices: Tensor, shape: number[]): ScatterShapeInfo {
updates: TensorInfo, indices: TensorInfo,
shape: number[]): ScatterShapeInfo {
// Calculate the number of dimensions in indices
const sliceRank = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1;
const indicesRank = indices.shape.length;
const sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1;

// Calculate the number of elements that make up each slice of our updated
// tensor. This allows us to work with flattened tensors and copy over whole
Expand All @@ -138,7 +141,7 @@ export function calculateShapes(
}

const safeSliceDim = (sliceRank < 1) ? 1 : sliceRank;
const numUpdates = indices.size / safeSliceDim;
const numUpdates = sizeFromShape(indices.shape) / safeSliceDim;

const strides = [...computeStrides(shape.slice(0, sliceRank)), 1];
const outputSize = sizeFromShape(shape);
Expand Down