
Require less alignment for attn bias #114173

Closed
drisspg wants to merge 5 commits

Conversation

drisspg (Contributor) commented Nov 20, 2023

Summary

Improved Fix for Attention Mask Alignment Issue (#112577)

This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and caused unneeded memory regressions. The update simplifies how alignment of the attention mask is handled for mem-efficient attention.

Changes

Alignment Check and Padding: First, the alignment of the attention mask is checked. If it is misaligned, the mask is padded and then sliced back to its original size, and a warning is raised to alert users.

Should this be warn_once?

We only call expand once, on the aligned mask.

Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115
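
For illustration, here is a minimal Python sketch of the pad-then-slice idea described above (the helper name, the alignment constant, and the warning text are placeholders for this sketch, not the actual code in attention.cpp):

import torch
import warnings

# Assumed element alignment on the mask's last dim; the constant the kernel
# actually requires may differ and depends on dtype.
ALIGNMENT = 8

def pad_and_slice_mask(attn_mask: torch.Tensor) -> torch.Tensor:
    # Pad the last dim up to a multiple of ALIGNMENT, then slice back to the
    # original length, so the returned view sits on storage whose row pitch
    # satisfies the alignment requirement.
    last = attn_mask.size(-1)
    if last % ALIGNMENT == 0:
        return attn_mask
    warnings.warn(
        "Attention mask is not aligned for the memory-efficient kernel; "
        "padding and re-slicing it (this materializes a copy)."
    )
    pad = ALIGNMENT - last % ALIGNMENT
    padded = torch.nn.functional.pad(attn_mask, (0, pad))
    return padded[..., :last]

The slice keeps the logical shape unchanged while the underlying storage gains an aligned row pitch, which is what the kernel checks.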

@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @kiukchung @d4l3k @LucasLLC

pytorch-bot (bot) commented Nov 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114173

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (5 Unrelated Failures)

As of commit c861b7f with merge base 9d68cfe:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

drisspg force-pushed the update_mask_preprocessing branch 2 times, most recently from db1acc9 to 4e572a3, on November 20, 2023 at 23:51
drisspg (Contributor, Author) commented Nov 21, 2023

I am having the most annoying time with this. I am revisiting my old repro:

import torch
import copy
# torch.use_deterministic_algorithms(True)

num_heads = 16
head_dim = 128
torch.set_printoptions(threshold=1000000, sci_mode=True)

def _attn_sdpa(query, key, value, attention_mask=None, contiguify=False, enable_mem_efficient=False):
    query_shape = query.shape
    batch_size = query_shape[0]
    kv_seq_len = key.shape[-2]

    query_length = query_shape[1]

    # NOTE: Maybe there is a better way to do this?
    query = query.view(batch_size, query_length, num_heads, head_dim).transpose(1, 2)

    # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
    key = key.unsqueeze(1)
    value = value.unsqueeze(1)

    key = key.expand(-1, num_heads, -1, -1)
    value = value.expand(-1, num_heads, -1, -1)

    if contiguify:
        key = key.contiguous()
        value = value.contiguous()

    if enable_mem_efficient:
        enable_math = False
    else:
        enable_math = True

    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient):
        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )

    return sdpa_result

device = "cuda"

use_random = False
if use_random:
    print("using random tensors")
    seq_len_q = 1
    seq_len_kv = 16
    query_sdpa = torch.rand(1, seq_len_q, 2048, device=device)
    key_sdpa = torch.rand(1, seq_len_kv, 128, device=device)
    value_sdpa = torch.rand(1, seq_len_kv, 128, device=device)
    attention_mask_sdpa = torch.ones(1, 1, seq_len_q, seq_len_kv, device=device, dtype=torch.bool)
else:
    path = "/home/drisspg/meta/scripts/sdpa/"
    query_sdpa = torch.load(path+"query_sdpa.pt").to(device)
    key_sdpa = torch.load(path+"key_sdpa.pt").to(device)
    value_sdpa = torch.load(path+"value_sdpa.pt").to(device)
    attention_mask_sdpa = torch.load(path+"attention_mask_sdpa.pt").to(device)

print("query_sdpa", query_sdpa.shape)
print("key_sdpa", key_sdpa.shape)
print("value_sdpa", value_sdpa.shape)
print("attention_mask_sdpa", attention_mask_sdpa.shape)
print("attention_mask_sdpa", attention_mask_sdpa)


res_non_contig_cuda_math = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False)
res_contig_cuda_math = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=True)

res_non_contig_cuda_memeff = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False, enable_mem_efficient=True)
res_contig_cuda_memeff = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=True, enable_mem_efficient=True)

res_non_contig_cuda_memeff_no_mask = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, contiguify=False, enable_mem_efficient=True)
res_contig_cuda_memeff_no_mask = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, contiguify=True, enable_mem_efficient=True)

def print_diff(text, tensor1, tensor2):
    print(f"{text}: mean abs-diff", (tensor1 - tensor2).abs().mean())
    print(f"{text}: mean rel-diff", ((tensor1 - tensor2).abs() / (tensor1.abs() + 1e-12)).mean())

print("Math compare".center(80, "*"))
print_diff("cuda non-contig/contig math", res_non_contig_cuda_math, res_contig_cuda_math)
print_diff("cuda non-contig/contig memeff", res_non_contig_cuda_memeff, res_contig_cuda_memeff)

print("Mem-eff mask compare".center(80, "*"))
print("With Mask Contiguity | Without Mask Contiguity")
print_diff("non-contig  |  contig", res_non_contig_cuda_memeff, res_contig_cuda_memeff_no_mask)
print_diff("contig      |  contig", res_contig_cuda_memeff, res_contig_cuda_memeff_no_mask)
print_diff("non-contig  |  non-contig", res_non_contig_cuda_memeff, res_non_contig_cuda_memeff_no_mask)
print_diff("contig      |  non-contig", res_contig_cuda_memeff, res_non_contig_cuda_memeff_no_mask)
print("*".center(80, "*"))
print("Allclose CUDA math non-contig/contig:", torch.allclose(res_non_contig_cuda_math, res_contig_cuda_math))
print("Allclose CUDA memeff non-contig/contig:", torch.allclose(res_non_contig_cuda_memeff, res_contig_cuda_memeff))

If I comment out the math runs then I don't get any numerical deviation, but with everything run I get:

**********************************Math compare**********************************
cuda non-contig/contig math: mean abs-diff tensor(0., device='cuda:0')
cuda non-contig/contig math: mean rel-diff tensor(0., device='cuda:0')
cuda non-contig/contig memeff: mean abs-diff tensor(1.4653e-03, device='cuda:0')
cuda non-contig/contig memeff: mean rel-diff tensor(6.0943e-01, device='cuda:0')
******************************Mem-eff mask compare******************************
With Mask Contiguity | Without Mask Contiguity
non-contig  |  contig: mean abs-diff tensor(1.4653e-03, device='cuda:0')
non-contig  |  contig: mean rel-diff tensor(6.0943e-01, device='cuda:0')
contig      |  contig: mean abs-diff tensor(0., device='cuda:0')
contig      |  contig: mean rel-diff tensor(0., device='cuda:0')
non-contig  |  non-contig: mean abs-diff tensor(1.4653e-03, device='cuda:0')
non-contig  |  non-contig: mean rel-diff tensor(6.0943e-01, device='cuda:0')
contig      |  non-contig: mean abs-diff tensor(0., device='cuda:0')
contig      |  non-contig: mean rel-diff tensor(0., device='cuda:0')
********************************************************************************
Allclose CUDA math non-contig/contig: True
Allclose CUDA memeff non-contig/contig: False

Ran through compute sanitizer and no errors were reported ...
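
One hypothetical way to see what differs between the two paths (not part of the repro above, just a debugging aid) is to dump the strides the kernel actually receives; with contiguify=False the expanded key/value carry a 0 stride in the head dimension:

def dump_strides(name, t):
    # Hypothetical helper: expanded (broadcast) tensors keep a 0 stride in the
    # broadcast dimension until .contiguous() materializes a real copy.
    print(f"{name}: shape={tuple(t.shape)}, stride={tuple(t.stride())}, contiguous={t.is_contiguous()}")

key_expanded = key_sdpa.unsqueeze(1).expand(-1, num_heads, -1, -1)
dump_strides("key (expanded)", key_expanded)
dump_strides("key (expanded + contiguous)", key_expanded.contiguous())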

drisspg (Contributor, Author) commented Nov 21, 2023

When I swap the first block of two lines and call mem-eff before calling math, there are no numerical differences:

res_non_contig_cuda_math = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False)
res_contig_cuda_math = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=True)

res_non_contig_cuda_memeff = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False, enable_mem_efficient=True)
res_contig_cuda_memeff = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=True, enable_mem_efficient=True)

drisspg (Contributor, Author) commented Nov 21, 2023

Another fun data point: this doesn't occur when the dtype is fp16.

drisspg force-pushed the update_mask_preprocessing branch 3 times, most recently from bfe6962 to 14277ed, on November 21, 2023 at 17:44
drisspg added the module: performance, module: nn, and release notes: nn labels on Nov 21, 2023
drisspg requested a review from bdhirsh on November 21, 2023 at 23:40
drisspg added the ciflow/trunk label on Nov 22, 2023
danthe3rd (Contributor) left a comment:

Thanks! This looks good to me and should save a lot of memory for people using masks with broadcasting :)
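
For context, a rough illustration of the broadcasting case (the shapes here are made up): expanding a mask to the full batch/head shape and materializing it costs batch * heads copies, whereas keeping it as an expanded view costs nothing extra.

import torch

B, H, L_q, L_kv = 32, 16, 1, 4096
mask = torch.ones(1, 1, L_q, L_kv, dtype=torch.bool)

view = mask.expand(B, H, L_q, L_kv)          # broadcast view, no extra memory
print(view.stride())                         # 0 strides in the broadcast dims

full = view.contiguous()                     # materialized copy: B * H times the storage
print(full.numel() * full.element_size(), "bytes")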

aten/src/ATen/native/transformers/attention.cpp (review comment, outdated, resolved)
aten/src/ATen/native/transformers/attention.cpp (review comment, outdated, resolved)
drisspg (Contributor, Author) commented Nov 22, 2023

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

For advanced debugging, check the merge workflow status.

# This is needed for the backward case:
# Hacky but not sure of a better way
if isinstance(arg.data, ir.StorageBox) and arg.data.is_input_buffer():
    if "expand" in arg.data.get_name():
drisspg (Contributor, Author) commented on Nov 27, 2023:

Is there a better way to match on this?

A contributor replied:

Would you clarify what exactly you are matching on or the case you're handling?

eellison (Contributor) left a comment:

asking a few questions


torch/_inductor/lowering.py (review comment, resolved)
torch/_inductor/lowering.py (review comment, resolved)
github-actions bot added the module: cpu, module: dynamo, and module: distributed labels on Nov 27, 2023
drisspg force-pushed the update_mask_preprocessing branch 2 times, most recently from 917c605 to 300081d, on November 27, 2023 at 20:32
drisspg (Contributor, Author) commented Nov 28, 2023

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

For advanced debugging, check the merge workflow status.

vfdev-5 pushed a commit to vfdev-5/pytorch that referenced this pull request Nov 29, 2023
drisspg added a commit to drisspg/pytorch that referenced this pull request Nov 30, 2023
drisspg added a commit to drisspg/pytorch that referenced this pull request Nov 30, 2023
drisspg added a commit to drisspg/pytorch that referenced this pull request Nov 30, 2023
drisspg added a commit to drisspg/pytorch that referenced this pull request Nov 30, 2023
atalman pushed a commit that referenced this pull request Dec 5, 2023

Pull Request resolved: pytorch#114173
Approved by: https://github.com/danthe3rd
albanD added the oncall: distributed label and removed the module: distributed label on Dec 8, 2023
atalman added a commit that referenced this pull request Dec 12, 2023
Labels
ciflow/inductor, ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: cpu (CPU specific problem, e.g. perf, algorithm), module: dynamo, module: inductor, module: nn (Related to torch.nn), module: performance (Issues related to performance, either of kernel code or framework glue), oncall: distributed (Add this issue/PR to distributed oncall triage queue), release notes: nn (release notes category), Reverted

6 participants