Restructure multi_head_attention_forward #34573

@Enealor

Description

🚀 Feature

Restructure the function multi_head_attention_forward in nn.functional into several functions to improve the ability to experiment. In particular, decompose the function so that the following are available:

  • The input embedding functions.
  • The computation of attention weights.
  • The output embedding function.

This will allow users to try different embeddings or attention mechanisms without having to recode the rest.
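A minimal sketch of how the three decomposed stages could compose back into the original forward pass. The function names here are hypothetical, chosen for illustration; they are not part of the existing torch.nn.functional API.

```python
import torch

def in_projection(q, k, v, w_q, w_k, w_v):
    # Stage 1: embed query, key, and value with separate weights.
    return q @ w_q.T, k @ w_k.T, v @ w_v.T

def scaled_dot_product(q, k, v):
    # Stage 2: scaled dot-product attention over the projected inputs,
    # also returning the attention weights.
    scale = q.size(-1) ** -0.5
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return attn @ v, attn

def out_projection(x, w_o):
    # Stage 3: the final output embedding.
    return x @ w_o.T

# Composing the stages reproduces the monolithic forward; swapping out any
# one stage (e.g. a different attention mechanism) leaves the rest untouched.
E = 8
q = torch.randn(2, 5, E)                      # (batch, seq, embed)
w = [torch.randn(E, E) for _ in range(4)]
qp, kp, vp = in_projection(q, q, q, w[0], w[1], w[2])  # self-attention case
out, weights = scaled_dot_product(qp, kp, vp)
out = out_projection(out, w[3])
# out: (2, 5, 8), weights: (2, 5, 5)
```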

Motivation

Addresses the issue of decomposing the function as mentioned in #32590. It also moves forward on including more support for attention mechanisms.

Pitch

Currently, the multi_head_attention_forward function encapsulates projecting the query, key, and value, computing attention over these projections, and computing the output projection after attention is applied. Furthermore, the input embedding step uses several code paths, each corresponding to a different embedding. Decomposing the function into several parts would make it more readable and more open to experimentation.

The following plan is based on the above:

  • Functions for computing the input embeddings q, k, and v. There are currently four code paths, implementing three distinct embeddings, so each embedding should be an individual function to make it clearer which method is being used. The embeddings are labeled 'self-attention' (where query = key = value), 'encoder-decoder attention' (where key = value), and one that is unlabeled but is probably best called simply attention. The last embedding has two code paths depending on whether the packed in_proj_weight is used or separate weights are provided for query, key, and value. (See L3669-L3748.)
  • A function for applying attention to obtain a new query, specifically Scaled Dot-Product Attention. Some models compute attention in different ways, and separating this step out would let users substitute those mechanisms more freely. It should optionally return the attention weights. (See L3750-L3824.)
  • A function for computing the output projection of the query. Only one variant is currently needed. (See L3826-L3836.)
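The input-embedding cases in the first point above can be sketched as follows. These helper names are illustrative only, not the actual torch.nn.functional internals; the packed in_proj_weight is assumed to have shape (3 * embed_dim, embed_dim) as in nn.MultiheadAttention.

```python
import torch
import torch.nn.functional as F

def self_attention_proj(x, in_proj_weight, in_proj_bias=None):
    # query = key = value: one matmul with the packed weight,
    # then split the result into q, k, v.
    return F.linear(x, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

def encoder_decoder_proj(q, kv, in_proj_weight, in_proj_bias=None):
    # key = value: project the query with the first third of the packed
    # weight, and key/value together with the remaining two thirds.
    E = q.size(-1)
    w_q, w_kv = in_proj_weight.split([E, 2 * E])
    b_q = b_kv = None
    if in_proj_bias is not None:
        b_q, b_kv = in_proj_bias.split([E, 2 * E])
    qp = F.linear(q, w_q, b_q)
    kp, vp = F.linear(kv, w_kv, b_kv).chunk(2, dim=-1)
    return qp, kp, vp

def separate_proj(q, k, v, w_q, w_k, w_v, b_q=None, b_k=None, b_v=None):
    # Fully general case: separate weights for query, key, and value.
    return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
```

Splitting the cases this way removes the branching from the main forward path: each caller picks the projection helper that matches its inputs, and the attention and output stages stay identical.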

Alternatives

Some of the restructurings I have suggested could be skipped in favor of introducing fewer functions. In particular, only one of the input embeddings needs to be provided. The rest could be left to the end-user.

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

Metadata


    Labels

    module: nn (Related to torch.nn)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
