Feature request: Weighted average for EmbeddingBag #4068
Comments
Hi, any updates on this? Being able to provide weights (on top of the indices) would be really useful.
This would be very helpful for my work with the Russian language.
As some motivation, I will just post a link to my post, where EmbeddingBags were superior to BPE in many applications for Russian =)
API Bikeshedding: which of these two APIs would be better?
I'm leaning towards (2) because I haven't been able to find use cases for "weighted mean" (which can be emulated via weighted sum) or "weighted max".
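As a side note, the emulation mentioned here can be sketched with the `per_sample_weights` argument that was eventually added to `nn.EmbeddingBag`: normalize the weights within each bag so that `mode='sum'` computes a weighted mean. The bag sizes `[3, 2]` below are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.EmbeddingBag(10, 4, mode='sum')

indices = torch.tensor([0, 3, 7, 1, 2])
offsets = torch.tensor([0, 3])  # two bags: [0, 3, 7] and [1, 2]
raw_w = torch.tensor([0.1, 0.2, 0.4, 0.7, 0.8])

# Normalize the weights within each bag; with mode='sum' this turns the
# weighted sum into a weighted mean.
bag_id = torch.repeat_interleave(torch.arange(len(offsets)), torch.tensor([3, 2]))
bag_sums = torch.zeros(len(offsets)).scatter_add_(0, bag_id, raw_w)
w = raw_w / bag_sums[bag_id]

out = emb(indices, offsets, per_sample_weights=w)  # shape (2, 4)
```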
FWIW, you don't have to actually implement weighted mean and weighted max if you implement (1); you can just make them raise errors. (This is not necessarily in favor of (1), but it's a comment on the reasoning.)
Most likely these weights will be calculated using some sort of attention mechanism. I wonder whether something like this can be implemented inside of this layer.
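A minimal sketch of that idea, assuming fixed-length bags (2D input to `nn.EmbeddingBag`, no offsets) and a small linear scoring head. The class and attribute names (`AttnBag`, `score`) are illustrative, not part of torch:

```python
import torch
import torch.nn as nn

class AttnBag(nn.Module):
    """Score each token with a linear head, softmax the scores per bag,
    and feed them to EmbeddingBag as per_sample_weights."""
    def __init__(self, num_emb, dim):
        super().__init__()
        self.emb = nn.Embedding(num_emb, dim)        # used only to score tokens
        self.bag = nn.EmbeddingBag(num_emb, dim, mode='sum')
        self.score = nn.Linear(dim, 1)

    def forward(self, tokens):                       # tokens: (batch, bag_len)
        w = self.score(self.emb(tokens)).squeeze(-1)
        w = torch.softmax(w, dim=-1)                 # attention weights per bag
        return self.bag(tokens, per_sample_weights=w)

m = AttnBag(100, 16)
out = m(torch.randint(0, 100, (4, 5)))               # out: (4, 16)
```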
EmbeddingBag CPU forward with per_sample_weights. On the way to #4068. Adds a new `per_sample_weights` argument to `nn.EmbeddingBag`'s forward pass and `embedding_bag`. This is only supported for `mode='sum'` and is interpreted as scaling each embedding before applying the reduction. For example:

```
indices: 0, 3, 7 ; 1, 2
per_sample_weights: 0.1, 0.2, 0.4 ; 0.7, -0.8
offsets: 0, 3
weights (embeddings): e_0, e_1, e_2, ..., e_7
```

returns 2 vectors:

```
0.1 * e_0 + 0.2 * e_3 + 0.4 * e_7
0.7 * e_1 - 0.8 * e_2
```

Future:
- CPU backward
- CUDA forward
- CUDA backward
- CPU differentiable per_sample_weights
- CUDA differentiable per_sample_weights

Test Plan:
- New tests

gh-metadata: pytorch pytorch 18735 gh/zou3519/26/head
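The worked example from the commit message above can be reproduced directly with the API as it shipped (a minimal sketch):

```python
import torch
import torch.nn as nn

emb = nn.EmbeddingBag(8, 4, mode='sum')  # embeddings e_0 ... e_7

indices = torch.tensor([0, 3, 7, 1, 2])
offsets = torch.tensor([0, 3])           # bags [0, 3, 7] and [1, 2]
psw = torch.tensor([0.1, 0.2, 0.4, 0.7, -0.8])

out = emb(indices, offsets, per_sample_weights=psw)

# The two returned vectors are the weighted sums from the example:
W = emb.weight
expected = torch.stack([
    0.1 * W[0] + 0.2 * W[3] + 0.4 * W[7],
    0.7 * W[1] - 0.8 * W[2],
])
```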
Added the feature in #18957.
Many thanks!
Is there currently a plan to implement per_sample_weights on CUDA for max aggregation?
Right now `torch.nn.EmbeddingBag` supports only `sum` and `mean`. What do you think about providing an option for weights to compute a weighted average? This would be more memory efficient than the current alternative.
For instance, something like `sp_weights` in `tf.nn.embedding_lookup_sparse` [1].
References:
[1] https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup_sparse
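For reference, the "current alternative" mentioned above is to materialize every looked-up embedding with `nn.Embedding` and reduce manually; the intermediate tensor is what a fused weighted `EmbeddingBag` avoids. A sketch, with illustrative sizes:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(8, 4)
idx = torch.tensor([[0, 3, 7]])      # one bag of three tokens
w = torch.tensor([[0.1, 0.2, 0.4]])  # per-token weights

# Materializes a (1, 3, 4) tensor of embeddings before reducing;
# this intermediate scales with the total number of looked-up tokens.
weighted_avg = (emb(idx) * w.unsqueeze(-1)).sum(1) / w.sum(1, keepdim=True)
```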