Skip to content

Conversation

@vyom1611
Copy link
Contributor

@vyom1611 vyom1611 commented Jan 6, 2025

  • Replaces non-jittable code with a new, JAX-compatible approach for computing embedding bags.

  • Handles offsets=None by performing a straightforward reduction (sum, mean, or max) across each row of the embedded tensor, aligning with PyTorch’s behavior for multi-dimensional indices.

  • Computes and returns offset2bag, bag_size, and max_indices when offsets is given, matching each mode (sum, mean, or max).

  • Converts offset2bag and bag_size to JAX arrays before returning, maintaining consistent data types.

  • Ensures the function’s return signature fully matches PyTorch expectations

qihqi
qihqi previously approved these changes Jan 6, 2025
@qihqi qihqi dismissed their stale review January 6, 2025 23:36

i'll run ci first

@qihqi qihqi self-requested a review January 6, 2025 23:36
@qihqi qihqi merged commit 5278e7a into pytorch:master Jan 7, 2025
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants