Closed
Labels
module: flex attention, module: higher order operators (torch.cond and similar), module: performance, module: pt2-dispatcher, oncall: pt2, triaged
Description
🐛 Describe the bug
A plot really says it all here:
You can find a script to repro below:
import functools
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.attention.flex_attention

torch.set_default_device("cuda")
torch.set_float32_matmul_precision("medium")
@torch.compile
def attn(query, key, value, global_causal_mask):
    n_batch, n_ctx, d_model = query.shape
    n_head = 16
    # Split into heads: (B, n_ctx, D) -> (B, n_head, n_ctx, D // n_head)
    query = query.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    key = key.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    value = value.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    attn = functools.partial(
        torch.nn.attention.flex_attention.flex_attention,
        block_mask=global_causal_mask,
        return_lse=True,
    )
    x, _ = attn(query, key, value)
    # Merge heads back: (B, n_head, n_ctx, d_head) -> (B, n_ctx, D)
    x = x.transpose(1, 2).contiguous().reshape(n_batch, n_ctx, d_model)
    return x
def measure_time(seq_len):
    n_local_band = 128
    query = torch.randn(1, seq_len, 512, requires_grad=True)
    key = torch.randn(1, seq_len, 512, requires_grad=True)
    value = torch.randn(1, seq_len, 512, requires_grad=True)

    # Attend only to causal positions more than n_local_band tokens away.
    def global_causal(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (q_idx - kv_idx > n_local_band)

    global_causal_mask = torch.nn.attention.flex_attention.create_block_mask(
        global_causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len
    )

    # Warmup (also triggers compilation) so the timed runs measure steady state.
    out = attn(query, key, value, global_causal_mask)
    out.sum().backward()
    torch.cuda.synchronize()

    # Timed forward pass.
    start_time = time.perf_counter()
    out = attn(query, key, value, global_causal_mask)
    torch.cuda.synchronize()
    forward_time = time.perf_counter() - start_time

    # Timed backward pass.
    start_time = time.perf_counter()
    out.sum().backward()
    torch.cuda.synchronize()
    backward_time = time.perf_counter() - start_time

    return forward_time * 1000, backward_time * 1000  # convert to milliseconds
seq_lengths = [1024, 2048, 4096, 8192, 16384, 32768]
n_runs = 3
forward_times = []
backward_times = []

for seq_len in seq_lengths:
    print(f"Testing sequence length: {seq_len}")
    fwd_times = []
    bwd_times = []
    for _ in range(n_runs):
        f_time, b_time = measure_time(seq_len)
        fwd_times.append(f_time)
        bwd_times.append(b_time)
    forward_times.append(np.mean(fwd_times))
    backward_times.append(np.mean(bwd_times))
plt.figure(figsize=(10, 6))
plt.plot(seq_lengths, forward_times, 'b-o', label='Forward Pass')
plt.plot(seq_lengths, backward_times, 'r-o', label='Backward Pass')
plt.xlabel('Sequence Length')
plt.ylabel('Time (ms)')
plt.title('Attention Performance Scaling with Sequence Length')
plt.grid(True)
plt.legend()

for i, seq_len in enumerate(seq_lengths):
    plt.annotate(f'{forward_times[i]:.1f}ms',
                 (seq_len, forward_times[i]),
                 textcoords="offset points",
                 xytext=(0, 10),
                 ha='center')
    plt.annotate(f'{backward_times[i]:.1f}ms',
                 (seq_len, backward_times[i]),
                 textcoords="offset points",
                 xytext=(0, -15),
                 ha='center')
plt.tight_layout()
plt.savefig('attention_scaling.png', dpi=300, bbox_inches='tight')
plt.show()

I will note that this might not recompile as it should for each sequence length (but if you are careful about this, you will still see the same scaling behavior); I was fighting the dreaded issue.
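One way to guard against that, as a minimal sketch (assuming it is acceptable to pay the compile cost again at every length; torch._dynamo.reset() is an existing API, and the warmup run inside measure_time keeps compilation out of the timed region anyway):

import torch._dynamo

for seq_len in seq_lengths:
    # Drop all compiled artifacts so this length is traced and compiled
    # fresh, rather than reusing a graph specialized on a previous length.
    torch._dynamo.reset()
    f_time, b_time = measure_time(seq_len)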
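Separately, if the perf_counter-plus-synchronize timing looks suspect, here is a hedged variant of the timed section using CUDA events (standard torch.cuda.Event APIs; elapsed_time() already returns milliseconds):

start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)

start_evt.record()
out = attn(query, key, value, global_causal_mask)
end_evt.record()
torch.cuda.synchronize()  # ensure both events have completed on the device
forward_time = start_evt.elapsed_time(end_evt)  # milliseconds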
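And as a sanity check that the mask sparsifies the way you expect as seq_len grows, the BlockMask can be inspected directly (a sketch, assuming the documented BlockMask.sparsity() helper, which reports the percentage of blocks skipped entirely):

n_local_band = 128

def global_causal(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) & (q_idx - kv_idx > n_local_band)

for seq_len in (1024, 4096, 16384):
    mask = torch.nn.attention.flex_attention.create_block_mask(
        global_causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len
    )
    print(seq_len, mask.sparsity())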
Versions
'2.6.0a0+gitf86a175'
cc @msaroufim @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @yf225 @Chillee @drisspg @yanboliang @BoyuanFeng
