-
Notifications
You must be signed in to change notification settings - Fork 2k
[wasm] Add ScatterND kernel. #2600
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
dsmilkov
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Just a few of comments regarding structuring c++
Reviewed 12 of 12 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, and @nsthorat)
tfjs-backend-wasm/yarn.lock, line 70 at r1 (raw file):
"@bazel/buildifier-linux_x64" "0.29.0" "@bazel/buildifier-win32_x64" "0.29.0"
the lock file seems to have changed a lot, but don't see any changes to package.json? did you explicitly remove the lock file before, or was the lock @master out of sync?
tfjs-backend-wasm/src/cc/scatter_impl.h, line 25 at r1 (raw file):
template <typename T> void scatter(const int* indices_ptr, const T* updates_ptr, size_t slice_rank,
no need for this header file (see comment below)
tfjs-backend-wasm/src/cc/scatter_impl.cc, line 28 at r1 (raw file):
template <typename T> void scatter(const int* indices_ptr, const T* updates_ptr, size_t slice_rank,
no need for a separate scatter_impl.cc file (move all this code inside kernels/ScatterND.cc) unless two or more kernels are using scatter as internal detail (sharing implementation), which is not the case as of this PR.
tfjs-backend-wasm/src/cc/kernels/ScatterND.h, line 24 at r1 (raw file):
extern "C" { void ScatterND(size_t indices_id, size_t updates_id, const DType dtype,
no need for this header file unless you are calling it from elsewhere in C++ (e.g. C++ unit tests, or other kernels)
tfjs-backend-wasm/src/cc/kernels/ScatterND.cc, line 53 at r1 (raw file):
tfjs::wasm::scatter<int32_t>( indices_buf, updates_info.i32(), slice_rank, num_updates, slice_size, strides, output_size, sizeof(int32), out_info.i32_write());
should this be size(int32_t)?
tfjs-backend-wasm/src/kernels/ScatterND.ts, line 66 at r1 (raw file):
const {sliceRank, numUpdates, sliceSize, strides, outputSize} = scatter_nd_util.calculateShapes( updates as Tensor, indices as Tensor, shape);
let's change the signature of calculateShapes to take TensorInfo instead of Tensor (which is sufficient), to avoid depending on Tensor in this file.
tfjs-core/src/index.ts, line 50 at r1 (raw file):
import * as math from './math'; import * as browser from './ops/browser'; import * as scatter_nd_util from './ops/scatter_nd_util';
since we are making this a public API, let's shorten it a bit: scatter_util (drop the nd)
nsthorat
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status:
complete! 1 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, and @nsthorat)
tfjs-backend-wasm/src/cc/scatter_impl.cc, line 28 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
no need for a separate scatter_impl.cc file (move all this code inside kernels/ScatterND.cc) unless two or more kernels are using scatter as internal detail (sharing implementation), which is not the case as of this PR.
once you do that you can also remove the template instantiation below
tfjs-backend-wasm/src/cc/kernels/ScatterND.cc, line 34 at r1 (raw file):
#endif void ScatterND(size_t indices_id, size_t updates_id, const DType dtype,
make as many things const here as possible
tfjs-backend-wasm/src/kernels/ScatterND.ts, line 85 at r1 (raw file):
registerKernel({ kernelName: 'ScatterND',
the "D" should be lower cased (the tensorflow op name is "ScatterNd")
annxingyuan
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status:
complete! 1 of 1 approvals obtained (waiting on @master)
tfjs-backend-wasm/yarn.lock, line 70 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
the lock file seems to have changed a lot, but don't see any changes to package.json? did you explicitly remove the lock file before, or was the lock @master out of sync?
Yes I explicitly removed node_modules and the lockfile and re-yarned at some point because I was having compile issues and I thought maybe something was out of sync.
tfjs-backend-wasm/src/cc/scatter_impl.h, line 25 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
no need for this header file (see comment below)
Done
tfjs-backend-wasm/src/cc/scatter_impl.cc, line 28 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
once you do that you can also remove the template instantiation below
Done
tfjs-backend-wasm/src/cc/kernels/ScatterND.h, line 24 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
no need for this header file unless you are calling it from elsewhere in C++ (e.g. C++ unit tests, or other kernels)
Done
tfjs-backend-wasm/src/cc/kernels/ScatterND.cc, line 34 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
make as many things const here as possible
Done
tfjs-backend-wasm/src/cc/kernels/ScatterND.cc, line 53 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
should this be size(int32_t)?
Done
tfjs-backend-wasm/src/kernels/ScatterND.ts, line 66 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
let's change the signature of calculateShapes to take TensorInfo instead of Tensor (which is sufficient), to avoid depending on Tensor in this file.
Done
tfjs-backend-wasm/src/kernels/ScatterND.ts, line 85 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
the "D" should be lower cased (the tensorflow op name is "ScatterNd")
Done
tfjs-core/src/index.ts, line 50 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
since we are making this a public API, let's shorten it a bit: scatter_util (drop the nd)
Done
dsmilkov
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Excellent work
Reviewed 10 of 11 files at r2, 2 of 2 files at r3.
Reviewable status:complete! 2 of 1 approvals obtained (waiting on @annxingyuan)
tfjs-backend-wasm/yarn.lock, line 70 at r1 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
Yes I explicitly removed node_modules and the lockfile and re-yarned at some point because I was having compile issues and I thought maybe something was out of sync.
chatted offline. need to run yarn build which will update package.json and yarn to depend to core@master
tfjs-backend-wasm/src/cc/kernels/ScatterND.cc, line 1 at r3 (raw file):
/* Copyright 2019 Google Inc. All Rights Reserved.
rename this to lowercase d as well
Changes
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is