Skip to content

Significant Backward Pass Scaling Bottleneck in Flex Attention: 12x Slower than Forward Pass at 32K Sequence Length #142817

@cora-codes

Description

@cora-codes

🐛 Describe the bug

A plot really says it all here:

Screenshot 2024-12-10 at 3 23 18 PM

You can find a script to repro below:

import functools
import torch
import torch.nn.attention.flex_attention
import time
import matplotlib.pyplot as plt
import numpy as np

torch.set_default_device("cuda")
torch.set_float32_matmul_precision("medium")


@torch.compile
def attn(query, key, value, global_causal_mask):
    n_batch, n_ctx, d_model = query.shape
    n_head = 16
    query = query.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    key = key.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    value = value.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    attn = functools.partial(
        torch.nn.attention.flex_attention.flex_attention,
        block_mask=global_causal_mask,
        return_lse=True,
    )
    x, _ = attn(query, key, value)
    x = x.transpose(1, 2).contiguous().reshape(n_batch, n_ctx, d_model)
    return x


def measure_time(seq_len, n_local_band=128):
    """Time one forward and one backward pass of ``attn`` at ``seq_len``.

    Fresh ``(1, seq_len, 512)`` query/key/value tensors are created. The block
    mask lets query position ``q`` attend only to keys ``kv`` with
    ``q - kv > n_local_band``. In other words, attention is causal, but the
    nearest ``n_local_band + 1`` positions (including the diagonal) are
    excluded.

    One untimed forward+backward runs first as warm-up, which also absorbs any
    torch.compile (re)compilation for the new shape. The forward and the
    backward are then timed separately, with a device sync around each.

    Args:
        seq_len: query and key/value sequence length.
        n_local_band: width of the excluded local band. Defaults to 128, the
            value that was previously hard-coded.

    Returns:
        ``(forward_ms, backward_ms)`` wall-clock times in milliseconds.
    """
    query = torch.randn(1, seq_len, 512, requires_grad=True)
    key = torch.randn(1, seq_len, 512, requires_grad=True)
    value = torch.randn(1, seq_len, 512, requires_grad=True)
    inputs = (query, key, value)

    def global_causal(b, h, q_idx, kv_idx):
        # Causal, but skipping the n_local_band + 1 most recent keys.
        return (q_idx >= kv_idx) & (q_idx - kv_idx > n_local_band)

    global_causal_mask = torch.nn.attention.flex_attention.create_block_mask(
        global_causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len
    )

    # Warm-up (compilation, autotuning, allocator).
    attn(query, key, value, global_causal_mask).sum().backward()
    torch.cuda.synchronize()

    # Drop the warm-up gradients. Otherwise the timed backward would also
    # accumulate into the already-populated .grad tensors, an extra cost the
    # forward measurement does not have.
    for t in inputs:
        t.grad = None

    start_time = time.perf_counter()
    out = attn(query, key, value, global_causal_mask)
    torch.cuda.synchronize()
    forward_time = time.perf_counter() - start_time

    # Reduce to a scalar before starting the clock so only backward is timed.
    loss = out.sum()
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    loss.backward()
    torch.cuda.synchronize()
    backward_time = time.perf_counter() - start_time

    return forward_time * 1000, backward_time * 1000  # Convert to milliseconds

# Benchmark driver: average n_runs measurements per sequence length, then plot
# forward vs. backward time (annotated per point) and save the figure.
seq_lengths = [1024, 2048, 4096, 8192, 16384, 32768]
n_runs = 3

forward_times = []
backward_times = []

for seq_len in seq_lengths:
    print(f"Testing sequence length: {seq_len}")
    # measure_time returns a (forward_ms, backward_ms) pair per run.
    runs = [measure_time(seq_len) for _ in range(n_runs)]
    fwd_times = [fwd for fwd, _ in runs]
    bwd_times = [bwd for _, bwd in runs]
    forward_times.append(np.mean(fwd_times))
    backward_times.append(np.mean(bwd_times))

plt.figure(figsize=(10, 6))
for ys, fmt, series_label in (
    (forward_times, 'b-o', 'Forward Pass'),
    (backward_times, 'r-o', 'Backward Pass'),
):
    plt.plot(seq_lengths, ys, fmt, label=series_label)

plt.xlabel('Sequence Length')
plt.ylabel('Time (ms)')
plt.title('Attention Performance Scaling with Sequence Length')
plt.grid(True)
plt.legend()

# Label every point: forward values above the marker, backward values below.
for seq_len, fwd_ms, bwd_ms in zip(seq_lengths, forward_times, backward_times):
    for ms, y_offset in ((fwd_ms, 10), (bwd_ms, -15)):
        plt.annotate(f'{ms:.1f}ms',
                     (seq_len, ms),
                     textcoords="offset points",
                     xytext=(0, y_offset),
                     ha='center')

plt.tight_layout()
plt.savefig('attention_scaling.png', dpi=300, bbox_inches='tight')
plt.show()

Note that this script might not recompile as it should for each sequence length (but if you are careful about this, you will still see the same scaling behavior); I was fighting the dreaded issue.

Versions

'2.6.0a0+gitf86a175'

cc @msaroufim @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @yf225 @Chillee @drisspg @yanboliang @BoyuanFeng

Metadata

Metadata

Assignees

No one assigned

    Labels

    module: flex attention
    module: higher order operators — torch.cond and similar
    module: performance — Issues related to performance, either of kernel code or framework glue
    module: pt2-dispatcher — PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op, …)
    oncall: pt2
    triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions