Skip to content

Conversation

@larryliu0820
Copy link
Contributor

Summary:

In torchtune's MultiHeadAttention we have this logic:

If y is not None, calculate the values of k and v from y and update the KVCache.

Otherwise (if y is None), retrieve the value of k and v from KVCache.

This logic is not able to be handled by export world. Here I'm proposing a rewrite:

If y does not have all values equal to nan (not a number), calculate the values of k and v from y and update the KVCache.

Otherwise (if all of the values of y are nan), retrieve the value of k and v from KVCache.

This rewrite allows the module to satisfy the requirement of torch.cond and avoid specialization:

  • The operands to torch.cond should have the same shape for the true branch and the false branch.

This means we will have to change this logic in torchtune:

        if encoder_input is not None:
            encoder_embed = self.encoder(**encoder_input)

        output = self.decoder(
            tokens=tokens,
            mask=mask,
            encoder_input=encoder_embed,
            encoder_mask=encoder_mask,
            input_pos=input_pos,
        )

To be:

        if encoder_input is not None:
            encoder_embed = self.encoder(**encoder_input)
        else:
            encoder_embed = torch.full_like(encoder_input, torch.nan)
        output = self.decoder(
            tokens=tokens,
            mask=mask,
            encoder_input=encoder_embed,
            encoder_mask=encoder_mask,
            input_pos=input_pos,
        )

Test Plan: Rely on unit tests

Reviewers:

Subscribers:

Tasks:

Tags:

Summary

[PLEASE REMOVE] See CONTRIBUTING.md's Pull Requests for ExecuTorch PR guidelines.

[PLEASE REMOVE] If this PR closes an issue, please add a Fixes #<issue-id> line.

[PLEASE REMOVE] If this PR introduces a fix or feature that should be the upcoming release notes, please add a "Release notes: " label. For a list of available release notes labels, check out CONTRIBUTING.md's Pull Requests.

Test plan

[PLEASE REMOVE] How did you test this PR? Please write down any manual commands you used and note down tests that you have written if applicable.

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/6869

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit b7763e9 with merge base ad15852 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 14, 2024
Copy link
Contributor

@jackzhxng jackzhxng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems fine - so instead of passing in y is None we pass in a tensor full of zeroes

Summary:

In torchtune's MultiHeadAttention we have this logic:

If `y` is not None, calculate the values of `k` and `v` from y and
update the KVCache.

Otherwise (if `y` is None), retrieve the value of `k` and `v` from
KVCache.

This logic is not able to be handled by export world. Here I'm proposing
a rewrite:

If `y` does not have all values equal to nan (not a number), calculate
the values of `k` and `v` from `y` and update the KVCache.

Otherwise (if all of the values of `y` are nan), retrieve the value of
`k` and `v` from KVCache.

This rewrite allows the module to satisfy the requirement of
`torch.cond` and avoid specialization:
* The operands to `torch.cond` should have the same shape for the true
  branch and the false branch.

This means we will have to change this logic in torchtune:

```
        if encoder_input is not None:
            encoder_embed = self.encoder(**encoder_input)

        output = self.decoder(
            tokens=tokens,
            mask=mask,
            encoder_input=encoder_embed,
            encoder_mask=encoder_mask,
            input_pos=input_pos,
        )
```

To be:

```
        if encoder_input is not None:
            encoder_embed = self.encoder(**encoder_input)
        else:
            encoder_embed = torch.full_like(encoder_input, torch.nan)
        output = self.decoder(
            tokens=tokens,
            mask=mask,
            encoder_input=encoder_embed,
            encoder_mask=encoder_mask,
            input_pos=input_pos,
        )
```

Test Plan: Rely on unit tests

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@larryliu0820 larryliu0820 merged commit 04b3d92 into main Nov 15, 2024
39 checks passed
@larryliu0820 larryliu0820 deleted the torch_cond_attention branch November 15, 2024 08:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants