Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Inductor] default block size for head_dim = 256 for flex attention (#…
…125380) ## H100 ### torch.bfloat16 No major change, as expected. ``` | Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype | |---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|----------------| | Average | 1.122 | | | | | | | | | Max | 1.437 | 1 | 16 | 512 | 512 | 128 | head_bias | torch.bfloat16 | | Min | 0.895 | 1 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 | ``` ### torch.float32 Before: OOM when ```head_dim``` = 256 After: ``` | Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype | |---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|---------------| | Average | 2.231 | | | | | | | | | Max | 3.760 | 16 | 16 | 4096 | 4096 | 64 | noop | torch.float32 | | Min | 1.532 | 1 | 16 | 512 | 512 | 256 | causal_mask | torch.float32 | ``` ## A100 ### torch.bfloat16 Before: ``` | Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype | |---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------| | Average | 0.587 | | | | | | | | | Max | 0.960 | 1 | 16 | 512 | 512 | 64 | noop | torch.bfloat16 | | Min | 0.017 | 8 | 16 | 4096 | 4096 | 256 | relative_bias | torch.bfloat16 | ``` After: ``` | Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype | |---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|----------------| | Average | 0.756 | | | | | | | | | Max | 0.931 | 1 | 16 | 512 | 512 | 64 | noop | torch.bfloat16 | | Min | 0.467 | 16 | 16 | 1024 | 1024 | 256 | noop | torch.bfloat16 | ``` ### torch.float32 Before: OOM when ```head_dim``` = 256 After: ``` | Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype | |---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|---------------| | Average | 2.386 | | | | | | | | | Max | 7.584 | 16 | 16 | 512 | 512 | 64 | noop | torch.float32 | | Min | 0.948 | 1 | 16 | 512 | 512 | 256 | causal_mask | torch.float32 | ``` Pull Request resolved: #125380 Approved by: https://github.com/drisspg
- Loading branch information