🚀 Feature
Restructure the function `multi_head_attention_forward` in `nn.functional` into several functions to improve the ability to experiment. In particular, decompose the function so that the following are available:
- The input embedding functions.
- The computation of attention weights.
- The output embedding function.
This will allow users to try different embeddings or attention mechanisms without having to recode the rest; a short sketch of the three stages is given below.
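For concreteness, here is a minimal sketch of those three stages written inline with basic tensor ops (single head, packed projection weights, no masking). It only illustrates where the seams of the decomposition would sit, not how the functions should ultimately be implemented:

```python
import math

import torch
import torch.nn.functional as F

L, S, N, E = 5, 7, 2, 16                     # target len, source len, batch, embed dim
query, key, value = torch.randn(L, N, E), torch.randn(S, N, E), torch.randn(S, N, E)
in_proj_weight, in_proj_bias = torch.randn(3 * E, E), torch.zeros(3 * E)
out_proj_weight, out_proj_bias = torch.randn(E, E), torch.zeros(E)

# 1) Input embedding: project query / key / value (the packed-weight code path).
w_q, w_k, w_v = in_proj_weight.chunk(3)
b_q, b_k, b_v = in_proj_bias.chunk(3)
q, k, v = F.linear(query, w_q, b_q), F.linear(key, w_k, b_k), F.linear(value, w_v, b_v)

# 2) Attention: scaled dot-product attention over the projections.
scores = torch.einsum("lne,sne->nls", q, k) / math.sqrt(E)
attn_weights = scores.softmax(dim=-1)
attn_output = torch.einsum("nls,sne->lne", attn_weights, v)

# 3) Output embedding: final linear projection of the attended query.
out = F.linear(attn_output, out_proj_weight, out_proj_bias)
```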
Motivation
Addresses the decomposition of this function discussed in #32590. It also moves toward broader support for attention mechanisms.
Pitch
Currently, the `multi_head_attention_forward` function encapsulates projecting the `query`, `key`, and `value`, computing attention over these projections, and computing the output projection after attention is applied. Furthermore, the input embedding uses several code paths that implement different embeddings. Decomposing the function into several parts would make it more readable and open to experimentation.
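For reference, all of this currently happens inside one call. A rough usage sketch follows (shapes use the `(seq, batch, embed)` convention; the argument names and order are taken from the current signature and may shift between releases):

```python
import torch
import torch.nn.functional as F

L, S, N, E, num_heads = 5, 7, 2, 16, 4       # target len, source len, batch, embed dim, heads
query, key, value = torch.randn(L, N, E), torch.randn(S, N, E), torch.randn(S, N, E)

# Packed projection weights that the monolithic function expects.
in_proj_weight, in_proj_bias = torch.randn(3 * E, E), torch.zeros(3 * E)
out_proj_weight, out_proj_bias = torch.randn(E, E), torch.zeros(E)

attn_output, attn_weights = F.multi_head_attention_forward(
    query, key, value, E, num_heads,
    in_proj_weight, in_proj_bias,
    None, None, False, 0.0,                  # bias_k, bias_v, add_zero_attn, dropout_p
    out_proj_weight, out_proj_bias,
    need_weights=True,
)
# Input projection, attention, and output projection are all hidden inside this one call.
```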
The following plan is based on the above (a rough sketch of these pieces follows the list):
- Functions for computing the input embeddings `q`, `k`, and `v`. There are currently four code paths used for doing this, and three unique embeddings are used. Each embedding should be an individual function so that it's clearer what method is being used. The embeddings used are labeled 'self-attention' (where `query = key = value`), 'encoder-decoder attention' (where `key = value`), and one that is unlabeled but is probably just called attention. The last embedding has two code paths depending on whether `in_proj_weight` is used or separate weights are used for `query`, `key`, and `value`. (See L3669-L3748.)
- A function for applying attention to get a new query. Some models rely on computing attention in different ways, and separating this out would allow us to use those more freely. This should optionally return the attention weights. Specifically, this is the Scaled Dot-Product Attention. (See L3750-L3824.)
- A function for computing the output projection of the query. There is currently only one function needed for doing this. (See L3826-L3836.)
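Below is one possible shape for these three pieces. All names and signatures here are hypothetical placeholders meant only to make the plan concrete; they are not existing helpers, and head reshaping, masking conventions, and the self-attention / encoder-decoder fast paths are simplified away:

```python
import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def in_projection(
    query: Tensor, key: Tensor, value: Tensor,
    in_proj_weight: Optional[Tensor] = None, in_proj_bias: Optional[Tensor] = None,
    q_proj_weight: Optional[Tensor] = None, k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Compute the input embeddings q, k, v.

    Covers the packed-weight path and the separate-weight path; the
    self-attention and encoder-decoder special cases would get their own
    thin wrappers around this.
    """
    if in_proj_weight is not None:
        w_q, w_k, w_v = in_proj_weight.chunk(3)
    else:
        w_q, w_k, w_v = q_proj_weight, k_proj_weight, v_proj_weight
    b_q = b_k = b_v = None
    if in_proj_bias is not None:
        b_q, b_k, b_v = in_proj_bias.chunk(3)
    return F.linear(query, w_q, b_q), F.linear(key, w_k, b_k), F.linear(value, w_v, b_v)


def scaled_dot_product_attention(
    q: Tensor, k: Tensor, v: Tensor,
    attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0,
    need_weights: bool = True,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Apply scaled dot-product attention to (batch * heads, seq, head_dim) inputs.

    Optionally returns the attention weights so callers can inspect them.
    """
    scores = torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if attn_mask is not None:
        scores = scores + attn_mask
    attn = scores.softmax(dim=-1)
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)
    output = torch.bmm(attn, v)
    return output, (attn if need_weights else None)


def out_projection(attn_output: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    """Compute the output projection of the attended query."""
    return F.linear(attn_output, weight, bias)
```

Chained together (with the projections reshaped to `(batch * num_heads, seq_len, head_dim)` in between), these reproduce the high-level flow of the current function while letting users swap out any single stage, e.g. a custom attention mechanism in place of `scaled_dot_product_attention`.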
Alternatives
Some of the restructurings I have suggested could be skipped in favor of introducing fewer functions. In particular, only one of the input embeddings needs to be provided. The rest could be left to the end-user.
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki