[WASM] Add Concat and Transpose kernels
#2303
Conversation
nsthorat
left a comment
Reviewed 19 of 19 files at r1.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, and @nsthorat)
tfjs-backend-wasm/src/cc/kernels/BatchMatMul.cc, line 54 at r1 (raw file):
std::fill(out_buf, out_buf + batch_dim * size, 0);
for (int b = 0; b < batch_dim; ++b) {
just curious does this matter?
tfjs-backend-wasm/src/kernels/all_kernels.ts, line 33 at r1 (raw file):
import './Sigmoid';
import './Slice';
import './Square';
oops! something to think about is how we can catch this. for now this is cool.
tfjs-backend-wasm/src/kernels/Concat.ts, line 57 at r1 (raw file):
kernelName: 'Concat',
backendName: 'wasm',
kernelFunc: concat as {} as KernelFunc,
why do you have to cast as {}?
tfjs-backend-wasm/src/kernels/Transpose.ts, line 58 at r1 (raw file):
  dtype: inputs.x.dtype
};
let noChange = true;
noShapeChange
dsmilkov
left a comment
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @annxingyuan and @nsthorat)
tfjs-backend-wasm/src/cc/kernels/BatchMatMul.cc, line 54 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
just curious does this matter?
yup, which is why the style guide says to use preincrement: https://google.github.io/styleguide/cppguide.html#Preincrement_and_Predecrement
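A minimal sketch of the reasoning behind that rule: postincrement must return the value from before the increment, which for iterator types means constructing a temporary copy (the names below are illustrative, not from the PR):

```cpp
#include <vector>

int sum(const std::vector<int>& v) {
  int total = 0;
  // `it++` would construct and return a copy of the old iterator before
  // advancing; `++it` simply advances and returns the iterator itself.
  // For a plain `int` counter the two forms compile identically, so using
  // preincrement everywhere is mostly about a consistent habit.
  for (auto it = v.begin(); it != v.end(); ++it) {
    total += *it;
  }
  return total;
}
```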
tfjs-backend-wasm/src/kernels/all_kernels.ts, line 33 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
oops! something to think about is how we can catch this. for now this is cool.
Yeah, something analogous to the test-code-snippets in core would help
tfjs-backend-wasm/src/kernels/Concat.ts, line 57 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
why do you have to cast as {}?
Because concat takes an arbitrary number of tensors, I typed the input as TensorInfo[], but the compiler wants a NamedTensorInfoMap, and one can't inherit from the other.
tfjs-backend-wasm/src/kernels/Transpose.ts, line 58 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
noShapeChange
Done.
annxingyuan
left a comment
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, and @nsthorat)
tfjs-backend-wasm/src/cc/backend.h, line 49 at r2 (raw file):
// bucket.
TensorInfo &get_tensor_info(int tensor_id);
Just curious - do things work without the & prefix?
tfjs-backend-wasm/src/cc/util.h, line 142 at r2 (raw file):
// Returns the strides of a tensor given its shape. Note that the strides
// are of length R-1 where R is the rank of the tensor.
std::vector<int> compute_strides(const std::vector<int> shape);
Just curious - how do we decide whether to inline a function as opposed to defining it in the cc file?
tfjs-backend-wasm/src/cc/kernels/Transpose.cc, line 193 at r2 (raw file):
out_data[new_i] = x_data[i];
  }
}
Do you happen to know how much slower the generic implementation is?
tfjs-backend-wasm/src/kernels/Concat.ts, line 39 at r2 (raw file):
  return innerDim;
});
const inVals = inputs.map(input => backend.typedArrayFromHeap(input));
Question for my own understanding - you're saying calling typedArrayFromHeap is essentially free?
dsmilkov
left a comment
Reviewed 2 of 7 files at r2.
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @annxingyuan and @nsthorat)
tfjs-backend-wasm/src/cc/backend.h, line 49 at r2 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
Just curious - do things work without the & prefix?
It does, but without the &, it makes a copy of the result every time it returns (which can be much slower). With &, it returns a reference, which is just a pointer internally.
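As an illustration of the difference (the struct and map here are stand-ins, not the actual backend code):

```cpp
#include <map>
#include <vector>

// Hypothetical stand-in for the real TensorInfo; the vector payload makes
// the cost of copying visible.
struct TensorInfo {
  std::vector<float> buf;
};

std::map<int, TensorInfo> tensors;

// Returns by value: every call copies the whole struct, including `buf`.
TensorInfo get_tensor_info_by_value(int tensor_id) { return tensors[tensor_id]; }

// Returns a reference: internally just a pointer, so nothing is copied and
// callers can mutate the stored entry in place.
TensorInfo &get_tensor_info(int tensor_id) { return tensors[tensor_id]; }
```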
tfjs-backend-wasm/src/cc/util.h, line 142 at r2 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
Just curious - how do we decide whether to inline a function as opposed to defining it in the cc file?
There is no hard rule. The style guide says to inline when a function is less than 10 lines long: https://google.github.io/styleguide/cppguide.html#Inline_Functions
The reason I inlined the previous methods (loc_to_offset, offset_to_loc, offset) but not compute_strides is that they are called in the innermost loop (the hottest code), and the compiler has more opportunity to optimize that code when it is inlined (compile-time optimization is better than link-time optimization).
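For reference, a sketch of what the header comment describes (an illustration, not necessarily the exact implementation): in row-major layout the last dimension has an implicit stride of 1, so only R-1 strides are stored.

```cpp
#include <vector>

// strides[i] is the number of elements skipped when index i increases by
// one; the implicit stride of the last dimension is 1, hence length R-1.
std::vector<int> compute_strides(const std::vector<int> shape) {
  const int rank = shape.size();
  if (rank < 2) return {};  // Scalars and 1-D tensors need no explicit strides.
  std::vector<int> strides(rank - 1);
  strides[rank - 2] = shape[rank - 1];
  for (int i = rank - 3; i >= 0; --i) {
    strides[i] = strides[i + 1] * shape[i + 1];
  }
  return strides;
}
```

For example, shape {2, 3, 4} yields strides {12, 4}.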
tfjs-backend-wasm/src/cc/kernels/Transpose.cc, line 193 at r2 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
Do you happen to know how much slower the generic implementation is?
Yes, by about a factor of 10x.
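For context, a sketch of what a generic (rank-agnostic) transpose looks like and why it is slow: every element pays a divide/modulo chain to decode its coordinate. The signature and names are illustrative; strides here are the R-1-length vectors described above, with an implicit trailing stride of 1.

```cpp
#include <vector>

void transpose_generic(const float* x_data, float* out_data, int size,
                       const std::vector<int>& perm,
                       const std::vector<int>& x_strides,
                       const std::vector<int>& out_strides) {
  const int rank = perm.size();
  std::vector<int> coord(rank);
  for (int i = 0; i < size; ++i) {
    // Decode the flat input offset into a multi-dimensional coordinate.
    int rem = i;
    for (int d = 0; d < rank - 1; ++d) {
      coord[d] = rem / x_strides[d];
      rem %= x_strides[d];
    }
    coord[rank - 1] = rem;
    // Re-encode under the permutation: output axis d reads input axis perm[d].
    int new_i = coord[perm[rank - 1]];
    for (int d = 0; d < rank - 1; ++d) {
      new_i += coord[perm[d]] * out_strides[d];
    }
    out_data[new_i] = x_data[i];
  }
}
```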
tfjs-backend-wasm/src/kernels/Concat.ts, line 39 at r2 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
Question for my own understanding - you're saying calling typedArrayFromHeap is essentially free?
Yes, It's just a view (buffer) at an offset in the heap. There is no memory copy, no allocation.
Add optimized kernels for Concat and Transpose. For concat, we don't call into C++, since there is a single for-loop with a trivial memory access pattern and we just copy chunks of memory around. For transpose, we do call into C++, since there are various nested for-loops with non-trivial memory access patterns, depending on the rank of the tensor.
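For contrast with the generic transpose sketched earlier, here is a hypothetical fixed-rank fast path: plain nested loops with multiply-add indexing and no per-element division, which is what makes specialized per-rank loops so much faster.

```cpp
// Illustrative rank-2 transpose: direct nested loops, no coordinate
// decoding, leaving the compiler free to optimize the inner loop.
void transpose_2d(const float* x_data, int rows, int cols, float* out_data) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      out_data[c * rows + r] = x_data[r * cols + c];
    }
  }
}
```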