Wasm: Gather function produces wrong results when using batches. #8658
Open
Description
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow.js): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Arch Linux
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow.js installed from (npm or script link): NPM
- TensorFlow.js version (use command below): 4.22.0
- Browser version:
- Tensorflow.js Converter Version:
Describe the current behavior
When `tf.gather` is called with `batchDims` set, the WASM backend returns incorrect results that differ from those of the tfjs CPU backend.
Describe the expected behavior
The WASM backend should output the correct results, matching the tfjs CPU backend.
Standalone code to reproduce the issue
import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-wasm';

async function test() {
  await tf.setBackend('wasm');
  await tf.ready();

  // Indices: one index per batch row, shape [4, 1], created directly as int32.
  const actions = tf.tensor2d([[1], [2], [3], [4]], [4, 1], 'int32');

  // Values: shape [4, 5].
  const values = tf.tensor2d([
    [0.6180683, -0.3866255, -0.1915082, 0.2409521, -0.3771759],
    [0.243487, -0.1231469, -0.1254758, 0.1159831, -0.2473405],
    [-0.0443349, -0.1508565, -0.0818814, -0.0168378, -0.0160572],
    [0.4880314, -0.2639047, -0.1795055, 0.2236233, -0.3836304],
  ]);

  // Gather along axis 1 with batchDims = 1: row i picks values[i][actions[i][0]].
  const result = tf.gather(values, actions, 1, 1);
  result.print();
}

test();

WASM backend output:
Tensor
    [[-0.3866255],
     [0 ],
     [-0.2639047],
     [0 ]]

tfjs CPU backend output (run in Node.js; this is the correct, expected result):
Tensor
[[-0.3866255],
[-0.1254758],
[-0.0168378],
[-0.3836304]]
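To pin down what the correct output is, here is a small plain-JavaScript reference for the batched gather used in the repro (axis = 1, batchDims = 1 on 2-D inputs), where row `i` of the result picks `values[i][indices[i][j]]`. This is only a sketch of the semantics under that assumption, not the TensorFlow.js implementation, and `gatherBatched2D` is a hypothetical helper name. On the repro data it yields the same values as the CPU backend.

```javascript
// Hypothetical reference helper (not part of TensorFlow.js).
// Semantics assumed for axis = 1, batchDims = 1 on 2-D inputs:
//   out[i][j] = values[i][indices[i][j]]
function gatherBatched2D(values, indices) {
  if (values.length !== indices.length) {
    throw new Error('values and indices must share the batch dimension');
  }
  return indices.map((row, i) =>
    row.map((idx) => {
      // Reject non-integer or out-of-range indices instead of silently returning 0.
      if (!Number.isInteger(idx) || idx < 0 || idx >= values[i].length) {
        throw new RangeError(`index ${idx} out of range for batch ${i}`);
      }
      return values[i][idx];
    })
  );
}

// Same data as the repro above.
const values = [
  [0.6180683, -0.3866255, -0.1915082, 0.2409521, -0.3771759],
  [0.243487, -0.1231469, -0.1254758, 0.1159831, -0.2473405],
  [-0.0443349, -0.1508565, -0.0818814, -0.0168378, -0.0160572],
  [0.4880314, -0.2639047, -0.1795055, 0.2236233, -0.3836304],
];
const actions = [[1], [2], [3], [4]];

const expected = gatherBatched2D(values, actions);
// expected value: [[-0.3866255], [-0.1254758], [-0.0168378], [-0.3836304]]
console.log(expected);
```

Comparing row by row, the WASM output matches only row 0; rows 1 and 3 come back as 0, and row 2 returns `-0.2639047`, which is `values[3][1]` rather than `values[2][3]`.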