Require less alignment for attn bias #114173
Conversation
I am having the most annoying time with this. I am revisiting my old repro:

```python
import torch
import copy

# torch.use_deterministic_algorithms(True)
num_heads = 16
head_dim = 128

torch.set_printoptions(threshold=1000000, sci_mode=True)

def _attn_sdpa(query, key, value, attention_mask=None, contiguify=False, enable_mem_efficient=False):
    query_shape = query.shape
    batch_size = query_shape[0]
    kv_seq_len = key.shape[-2]
    query_length = query_shape[1]

    # NOTE: Maybe there is something better than this?
    query = query.view(batch_size, query_length, num_heads, head_dim).transpose(1, 2)

    # Without these unsqueezes, SDPA complains as the query and key/value have a different number of dimensions.
    key = key.unsqueeze(1)
    value = value.unsqueeze(1)
    key = key.expand(-1, num_heads, -1, -1)
    value = value.expand(-1, num_heads, -1, -1)

    if contiguify:
        key = key.contiguous()
        value = value.contiguous()

    enable_math = not enable_mem_efficient

    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient):
        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )
    return sdpa_result
device = "cuda"
use_random=False
if use_random:
print("using random tensors")
seq_len_q =1
seq_len_kv = 16
query_sdpa = torch.rand(1,seq_len_q,2048, device=device)
key_sdpa = torch.rand(1,seq_len_kv,128, device=device)
value_sdpa = torch.rand(1,seq_len_kv,128, device=device)
attention_mask_sdpa=torch.ones(1, 1, seq_len_q, seq_len_kv, device=device, dtype=torch.bool)
else:
path = "/home/drisspg/meta/scripts/sdpa/"
query_sdpa = torch.load(path+"query_sdpa.pt").to(device)
key_sdpa = torch.load(path+"key_sdpa.pt").to(device)
value_sdpa = torch.load(path+"value_sdpa.pt").to(device)
attention_mask_sdpa = torch.load(path+"attention_mask_sdpa.pt").to(device)
print("query_sdpa", query_sdpa.shape)
print("key_sdpa", key_sdpa.shape)
print("value_sdpa", value_sdpa.shape)
print("attention_mask_sdpa", attention_mask_sdpa.shape)
print("attention_mask_sdpa", attention_mask_sdpa)
res_non_contig_cuda_math = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False)
res_contig_cuda_math = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=True)
res_non_contig_cuda_memeff = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False, enable_mem_efficient=True)
res_contig_cuda_memeff = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa,contiguify=True, enable_mem_efficient=True)
res_non_contig_cuda_memeff_no_mask = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, contiguify=False, enable_mem_efficient=True)
res_contig_cuda_memeff_no_mask = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, contiguify=True, enable_mem_efficient=True)
def print_diff(text, tensor1, tensor2):
print(f"{text}: mean abs-diff", (tensor1 - tensor2).abs().mean())
print(f"{text}: mean rel-diff", ((tensor1 - tensor2).abs() / (tensor1.abs() + 1e-12)).mean())
print("Math compare".center(80, "*"))
print_diff("cuda non-contig/contig math", res_non_contig_cuda_math, res_contig_cuda_math)
print_diff("cuda non-contig/contig memeff", res_non_contig_cuda_memeff, res_contig_cuda_memeff)
print("Mem-eff mask compare".center(80, "*"))
print("With Mask Contiguity | Without Mask Contiguity")
print_diff("non-contig | contig", res_non_contig_cuda_memeff, res_contig_cuda_memeff_no_mask)
print_diff("contig | contig", res_contig_cuda_memeff, res_contig_cuda_memeff_no_mask)
print_diff("non-contig | non-contig", res_non_contig_cuda_memeff, res_non_contig_cuda_memeff_no_mask)
print_diff("contig | non-contig", res_contig_cuda_memeff, res_non_contig_cuda_memeff_no_mask)
print("*".center(80, "*"))
print("Allclose CUDA math non-contig/contig:", torch.allclose(res_non_contig_cuda_math, res_contig_cuda_math))
print("Allclose CUDA memeff non-contig/contig:", torch.allclose(res_non_contig_cuda_memeff, res_contig_cuda_memeff)) If I comment out the math results then I do don't get any numerical deviation but with everything run I get: **********************************Math compare**********************************
cuda non-contig/contig math: mean abs-diff tensor(0., device='cuda:0')
cuda non-contig/contig math: mean rel-diff tensor(0., device='cuda:0')
cuda non-contig/contig memeff: mean abs-diff tensor(1.4653e-03, device='cuda:0')
cuda non-contig/contig memeff: mean rel-diff tensor(6.0943e-01, device='cuda:0')
******************************Mem-eff mask compare******************************
With Mask Contiguity | Without Mask Contiguity
non-contig | contig: mean abs-diff tensor(1.4653e-03, device='cuda:0')
non-contig | contig: mean rel-diff tensor(6.0943e-01, device='cuda:0')
contig | contig: mean abs-diff tensor(0., device='cuda:0')
contig | contig: mean rel-diff tensor(0., device='cuda:0')
non-contig | non-contig: mean abs-diff tensor(1.4653e-03, device='cuda:0')
non-contig | non-contig: mean rel-diff tensor(6.0943e-01, device='cuda:0')
contig | non-contig: mean abs-diff tensor(0., device='cuda:0')
contig | non-contig: mean rel-diff tensor(0., device='cuda:0')
********************************************************************************
Allclose CUDA math non-contig/contig: True
Allclose CUDA memeff non-contig/contig: False
```

Ran through compute sanitizer and no errors were reported ...
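For context on what `contiguify` changes: after `expand`, key/value are zero-stride views over the head dimension, and `.contiguous()` materializes the broadcast. A minimal sketch of the layout difference (shapes match the `use_random` branch above; the strides shown are what PyTorch reports for these shapes):

```python
import torch

key = torch.rand(1, 16, 128)                         # (batch, kv_seq_len, head_dim)
key_bcast = key.unsqueeze(1).expand(-1, 16, -1, -1)  # broadcast over 16 heads, no copy

print(key_bcast.stride())               # (2048, 0, 128, 1): zero stride in the head dim
print(key_bcast.contiguous().stride())  # (32768, 2048, 128, 1): broadcast materialized
```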
When I swap the first block of two lines and call mem-eff before calling math, there are no numerical differences:

```python
res_non_contig_cuda_math = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False)
res_contig_cuda_math = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=True)

res_non_contig_cuda_memeff = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False, enable_mem_efficient=True)
res_contig_cuda_memeff = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=True, enable_mem_efficient=True)
```
Another fun data point: this doesn't occur when dtype is fp16.
Thanks! This looks good to me and should save a lot of memory for people using masks with broadcasting :)
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
torch/_inductor/lowering.py (outdated)

```python
# This is needed for the backward case:
# Hacky but not sure of a better way
if isinstance(arg.data, ir.StorageBox) and arg.data.is_input_buffer():
    if "expand" in arg.data.get_name():
        ...  # body elided in the review excerpt
```
Is there a better way to match on this?
Would you clarify what exactly you are matching on or the case you're handling?
asking a few questions
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
# Summary

Improved fix for the attention mask alignment issue (#112577)

This PR addresses issue #112577 by refining the previously implemented fix, which was found to be incorrect and caused unneeded memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem-efficient attention.
## Changes

Alignment check and padding: first, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing back to the original shape. During this process, a warning is raised to alert users. Should this be `warn_once`?

We only call `expand` once, on the aligned mask; see the sketch below.
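A rough Python sketch of the pad-then-slice idea described above. The helper name and the 16-element alignment constant are illustrative assumptions; the actual fix lives in the C++ SDPA preprocessing, not in Python:

```python
import torch
import torch.nn.functional as F

def align_attn_bias(attn_bias: torch.Tensor, alignment: int = 16) -> torch.Tensor:
    """Pad the last dim up to a multiple of `alignment`, then slice back.

    The sliced view keeps the original logical shape, while the padded
    storage underneath satisfies the kernel's alignment requirement,
    so no full contiguous copy of a broadcast bias is needed.
    """
    last_dim = attn_bias.size(-1)
    if last_dim % alignment == 0:
        return attn_bias
    pad = alignment - last_dim % alignment
    return F.pad(attn_bias, (0, pad))[..., :last_dim]
```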
## Reference

https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115
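To see the user-facing effect, here is a hedged usage sketch: a bias broadcast over heads, with a last dim (63 here) that is not a multiple of the alignment, can be passed straight to SDPA, and after this change the mem-efficient backend only pads and slices it instead of materializing the full expansion. Backend selection is automatic and the shapes are illustrative:

```python
import torch
import torch.nn.functional as F

q = torch.rand(2, 16, 64, 128, device="cuda")
k = torch.rand(2, 16, 63, 128, device="cuda")
v = torch.rand(2, 16, 63, 128, device="cuda")

# Bias broadcast over the 16 heads; last dim 63 is deliberately misaligned.
bias = torch.rand(2, 1, 64, 63, device="cuda").expand(-1, 16, -1, -1)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)  # torch.Size([2, 16, 64, 128])
```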
@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.