Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Intel MKL] Optimized implementation of GatherND using OpenMP #20237

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 15 additions & 0 deletions tensorflow/core/kernels/gather_nd_op_cpu_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,25 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
#endif
generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
slice_size, Tindices, Tparams, Tout, &error_loc);

#ifdef INTEL_MKL
// Eigen implementation below is not highly performant. gather_nd_generator
// does not seem to be called in parallel, leading to very poor performance.
// Additionally, since it uses scalar (Tscratch) to invoke 'generate', it
// needs to go through redundant operations like 'reshape', 'broadcast' and
// 'sum'. OpenMP loop below essentially does same thing as Eigen code, but
// is considerably more efficient.
#pragma omp parallel for
for (Eigen::DenseIndex i = 0; i < batch_size; i++) {
const Eigen::array<Eigen::DenseIndex, 1> loc = i;
gather_nd_generator(loc);
}
#else
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks! if you have time, can you also make a non-intel mkl modification below to use Shard(), like this example?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ebrevdo Sorry, but I don't think I have time for that change. Context switch to make this small change is going to be costly for me. Sorry again.

Tscratch.device(d) = Tscratch.reshape(reshape_dims)
.broadcast(broadcast_dims)
.generate(gather_nd_generator)
.sum();
#endif

// error_loc() returns -1 if there's no out-of-bounds index,
// otherwise it returns the location of an OOB index in Tindices.
Expand Down