[Inductor] default block size for head_dim = 256 for flex attention (#125380)

## H100
### torch.bfloat16
No major change, as expected.
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|----------------|
| Average |     1.122 |              |             |             |             |            |             |                |
| Max     |     1.437 |            1 |          16 |         512 |         512 |        128 | head_bias   | torch.bfloat16 |
| Min     |     0.895 |            1 |          16 |        1024 |        1024 |         64 | head_bias   | torch.bfloat16 |
```
### torch.float32
Before: OOM when `head_dim` = 256
After:
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype         |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|---------------|
| Average |     2.231 |              |             |             |             |            |             |               |
| Max     |     3.760 |           16 |          16 |        4096 |        4096 |         64 | noop        | torch.float32 |
| Min     |     1.532 |            1 |          16 |         512 |         512 |        256 | causal_mask | torch.float32 |
```
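
As a reference point, here is a minimal sketch (not part of this PR) of the float32 / `head_dim` = 256 case that previously hit OOM when compiled through Inductor; the same setup applies to the A100 numbers further down. It uses the current public `torch.nn.attention.flex_attention` entry point and a standard causal `score_mod`; at the time of this PR the API was still private, so the exact import path may differ.

```python
# Minimal repro sketch, assuming the current public flex_attention API.
import torch
from torch.nn.attention.flex_attention import flex_attention

def causal_mask(score, b, h, q_idx, kv_idx):
    # Standard causal score_mod; assumed to match the benchmark's "causal_mask".
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

B, H, S, D = 1, 16, 512, 256  # batch, heads, seq len, head_dim (the Min row above)
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float32) for _ in range(3))

# torch.compile routes this through the Inductor template whose default
# block sizes this commit changes.
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, score_mod=causal_mask)
```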

## A100
### torch.bfloat16
Before:
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     0.587 |              |             |             |             |            |               |                |
| Max     |     0.960 |            1 |          16 |         512 |         512 |         64 | noop          | torch.bfloat16 |
| Min     |     0.017 |            8 |          16 |        4096 |        4096 |        256 | relative_bias | torch.bfloat16 |
```
After:
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|----------------|
| Average |     0.756 |              |             |             |             |            |             |                |
| Max     |     0.931 |            1 |          16 |         512 |         512 |         64 | noop        | torch.bfloat16 |
| Min     |     0.467 |           16 |          16 |        1024 |        1024 |        256 | noop        | torch.bfloat16 |
```

### torch.float32
Before: OOM when `head_dim` = 256
After:
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype         |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|---------------|
| Average |     2.386 |              |             |             |             |            |             |               |
| Max     |     7.584 |           16 |          16 |         512 |         512 |         64 | noop        | torch.float32 |
| Min     |     0.948 |            1 |          16 |         512 |         512 |        256 | causal_mask | torch.float32 |
```

Pull Request resolved: #125380
Approved by: https://github.com/drisspg
yanboliang authored and pytorchmergebot committed May 2, 2024
1 parent 5c7b71d commit 3b5f6b1
Showing 2 changed files with 33 additions and 20 deletions.
benchmarks/transformer/score_mod.py: 2 changes (1 addition, 1 deletion)

```diff
@@ -211,7 +211,7 @@ def generate_experiment_configs() -> List[ExperimentConfig]:
     batch_sizes = [1, 8, 16]
     num_heads = [16]
     q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
-    head_dims = [64, 128]
+    head_dims = [64, 128, 256]
     dtypes = [
         torch.bfloat16,
     ]
```
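
The one-line benchmark change above multiplies out the sweep: the experiment list is roughly the cross product of the lists in `generate_experiment_configs()`. A rough sketch of that expansion follows; the `ExperimentConfig` fields and the exact construction are assumptions rather than code copied from `score_mod.py` (the `score_mod` names come from the result tables above).

```python
# Rough sketch of how the benchmark grid is built; the real
# generate_experiment_configs() in score_mod.py may differ in details.
import itertools
from dataclasses import dataclass

import torch

@dataclass(frozen=True)
class ExperimentConfig:  # field names are assumptions
    batch_size: int
    num_heads: int
    q_seq_len: int
    k_seq_len: int
    head_dim: int
    score_mod: str
    dtype: torch.dtype

batch_sizes = [1, 8, 16]
num_heads = [16]
q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
head_dims = [64, 128, 256]           # 256 is the new entry
score_mods = ["noop", "causal_mask", "relative_bias", "head_bias"]
dtypes = [torch.bfloat16]

configs = [
    ExperimentConfig(b, h, q, k, d, sm, dt)
    for (b, h, (q, k), d, sm, dt) in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods, dtypes
    )
]
print(len(configs))  # 3 * 1 * 3 * 3 * 4 * 1 = 108 configurations
```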
torch/_inductor/kernel/flex_attention.py: 51 changes (32 additions, 19 deletions)

```diff
@@ -173,31 +173,44 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta):
 )
 
 
+_h100_default_config = {
+    (torch.float32, 64): (128, 32, 4, 3),
+    (torch.float32, 128): (32, 64, 4, 3),
+    (torch.float32, 256): (32, 32, 4, 3),
+    (torch.bfloat16, 64): (128, 64, 4, 3),
+    (torch.bfloat16, 128): (64, 32, 4, 3),
+    (torch.bfloat16, 256): (64, 32, 4, 3),
+}
+
+_a100_default_config = {
+    (torch.float32, 64): (128, 32, 4, 3),
+    (torch.float32, 128): (128, 32, 4, 3),
+    (torch.float32, 256): (64, 16, 4, 3),
+    (torch.bfloat16, 64): (128, 64, 4, 3),
+    (torch.bfloat16, 128): (128, 32, 4, 3),
+    (torch.bfloat16, 256): (32, 64, 4, 3),
+}
+
+
 def _get_default_config(query):
+    dtype = query.get_dtype()
     head_dim = query.get_size()[-1]
     default_config = None
 
-    if torch.cuda.get_device_capability() >= (9, 0):  # H100
-        if query.get_dtype() == torch.float32:
-            if head_dim == 64:
-                default_config = (128, 32, 4, 3)
-            else:
-                default_config = (32, 64, 4, 3)
+    if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
+        if dtype == torch.float32:
+            default_config = (64, 64, 4, 3)
         else:
-            if head_dim == 64:
-                default_config = (128, 64, 4, 3)
-            else:
-                default_config = (64, 32, 4, 3)
-    elif torch.cuda.get_device_capability() >= (8, 0):  # A100
-        if query.get_dtype() == torch.float32:
-            default_config = (128, 32, 4, 3)
+            default_config = (128, 64, 4, 3)
+        default_config = _h100_default_config.get((dtype, head_dim), default_config)
+    elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0):  # A100
+        if dtype == torch.float32:
+            default_config = (64, 64, 4, 3)
         else:
-            if head_dim == 64:
-                default_config = (128, 64, 4, 3)
-            else:
-                default_config = (128, 32, 4, 3)
-    else:
-        if query.get_dtype() == torch.float32:
             default_config = (128, 64, 4, 3)
+        default_config = _a100_default_config.get((dtype, head_dim), default_config)
+    else:  # modest hardware or extremely large head_dim
+        if dtype == torch.float32:
+            default_config = (32, 16, 4, 3)
         else:
             default_config = (64, 32, 4, 3)
```
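
For readers skimming the diff: the behavioral change is that `head_dim` = 256 now resolves to an explicit, smaller block configuration via a per-GPU `(dtype, head_dim)` lookup table, instead of inheriting the larger defaults, and the `head_dim <= 256` guard keeps very large head dims on the conservative fallback path. Below is a standalone sketch of that lookup pattern for the H100 branch only; reading the 4-tuples as (BLOCK_M, BLOCK_N, num_warps, num_stages) is an assumption here, since the hunk itself does not name the fields.

```python
# Illustration only: mirrors the dict-with-fallback lookup added in this commit.
# The 4-tuple field names (BLOCK_M, BLOCK_N, num_warps, num_stages) are an assumption.
import torch

_h100_default_config = {
    (torch.float32, 64): (128, 32, 4, 3),
    (torch.float32, 128): (32, 64, 4, 3),
    (torch.float32, 256): (32, 32, 4, 3),   # new: smaller blocks avoid the fp32 OOM
    (torch.bfloat16, 64): (128, 64, 4, 3),
    (torch.bfloat16, 128): (64, 32, 4, 3),
    (torch.bfloat16, 256): (64, 32, 4, 3),
}

def h100_config(dtype, head_dim):
    if head_dim > 256:
        # extremely large head_dim: stay on the most conservative configs
        return (32, 16, 4, 3) if dtype == torch.float32 else (64, 32, 4, 3)
    # Per-dtype fallback, overridden by a tuned entry when (dtype, head_dim) is known.
    fallback = (64, 64, 4, 3) if dtype == torch.float32 else (128, 64, 4, 3)
    return _h100_default_config.get((dtype, head_dim), fallback)

assert h100_config(torch.float32, 256) == (32, 32, 4, 3)
assert h100_config(torch.float32, 96) == (64, 64, 4, 3)   # no tuned entry -> fallback
```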
