
Conversation

@liangel-02 (Contributor) commented Oct 2, 2025

**Summary**

Today, the only way to have variable sequence length support in PyTorch attention is through nested tensors [here](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support). We also want to add an explicit lower-level API that provides variable sequence length support without padding/masking in SDPA.

This PR builds out `varlen_attn`, the public API that users call for the forward pass, and `_varlen_attn`, the private API that calls into the Flash Attention/cuDNN backend.
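
For context, the rough shape of a varlen call: sequences are packed along a single token dimension, with cumulative sequence lengths marking the boundaries. The sketch below follows the FlashAttention varlen convention; the layout and parameter names (`cu_seq`, `max_len`) are illustrative rather than the exact signature added in this PR.

```python
import torch

# Ragged batch: four sequences of different lengths, no padding tokens.
seq_lens = [256, 1024, 640, 2048]
total_tokens = sum(seq_lens)
num_heads, head_dim = 16, 64

# Packed layout: (total_tokens, num_heads, head_dim) instead of (batch, max_seq_len, ...).
q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative sequence lengths tell the kernel where each sequence starts and ends.
cu_seq = torch.cumsum(torch.tensor([0] + seq_lens, device="cuda"), dim=0).to(torch.int32)
max_len = max(seq_lens)

# A call would then look roughly like (see the PR diff for the exact API):
# out = torch.nn.attention.varlen.varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len)
```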

**Benchmarking**

To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.

Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequence lengths to be random multiples of 64 up to `max_seq_len`
- 100 runs

|         | Variable Length API | SDPA |
|---------|---------------------|------|
| Runtime | 0.21750560760498047 ms | 0.43171775817871094 ms |
| TFLOPs  | 231.812 | 320.840 |

The sparsity is 0.453, which matches the roughly 50% speedup we see from varlen. TFLOPs stay in the same ballpark, with SDPA slightly higher, likely due to extra overhead and total FLOPs scaling with the padded sequence length.
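
As a sanity check on that number, here is a small sketch of how the token-level sparsity of the padded batch can be computed; the sampling code in the actual benchmark may differ.

```python
import torch

torch.manual_seed(0)
batch_size, max_seq_len = 8, 2048

# Assumed sampling scheme: each sequence length is a random multiple of 64 up to max_seq_len.
seq_lens = torch.randint(1, max_seq_len // 64 + 1, (batch_size,)) * 64

# Fraction of the (batch_size, max_seq_len) padded grid that is padding rather than real tokens.
sparsity = 1.0 - seq_lens.sum().item() / (batch_size * max_seq_len)
print(f"sparsity ~ {sparsity:.3f}")  # typically in the 0.45-0.5 range for this distribution
```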

**Testing**

Run `python test/test_varlen_attention.py` for unit tests that verify basic functionality and confirm a numerical match between varlen outputs and SDPA.
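
The gist of the parity check, sketched under the packed-layout assumptions from the earlier snippet (the actual test file is structured differently): run SDPA on each unpadded sequence individually, re-pack the outputs, and compare against the varlen output.

```python
import torch
import torch.nn.functional as F

def sdpa_reference(q, k, v, seq_lens):
    """Run SDPA per sequence on the unpacked (L, H, D) slices and re-pack the outputs."""
    outs, start = [], 0
    for L in seq_lens:
        qs, ks, vs = (t[start:start + L].unsqueeze(0).transpose(1, 2) for t in (q, k, v))
        o = F.scaled_dot_product_attention(qs, ks, vs)   # (1, H, L, D)
        outs.append(o.transpose(1, 2).squeeze(0))        # back to (L, H, D)
        start += L
    return torch.cat(outs, dim=0)

# varlen_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len)  # hypothetical call
# torch.testing.assert_close(varlen_out, sdpa_reference(q, k, v, seq_lens),
#                            atol=2e-2, rtol=2e-2)  # bf16 tolerances are an assumption
```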

**Next steps**

Next steps from this PR (higher in the stack) include registering the private API `_varlen_attn` as a custom op, implementing backward support, and enabling cuDNN with correct numerics.
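
For the custom-op step, the general mechanism is `torch.library.custom_op` plus a fake (meta) implementation so the op can be traced by `torch.compile`. The sketch below uses a throwaway `mylib::` namespace and a placeholder body; the actual registration in the follow-up PR may look different.

```python
import torch

@torch.library.custom_op("mylib::varlen_attn_demo", mutates_args=())
def varlen_attn_demo(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     cu_seq: torch.Tensor, max_len: int) -> torch.Tensor:
    # Placeholder: the real op would dispatch to the Flash Attention / cuDNN kernel here.
    return torch.empty_like(q)

@varlen_attn_demo.register_fake
def _(q, k, v, cu_seq, max_len):
    # Shape/dtype-only implementation used under FakeTensor tracing.
    return torch.empty_like(q)
```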

Stack from ghstack (oldest at bottom):

(This stack builds on top of #162326)


pytorch-bot bot commented Oct 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164502

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 63 Pending

As of commit 59331ad with merge base ffe3cb2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

liangel-02 added a commit that referenced this pull request Oct 2, 2025
ghstack-source-id: e22c006
Pull Request resolved: #164502
This was referenced Oct 2, 2025
@liangel-02 liangel-02 requested a review from drisspg October 9, 2025 15:18
liangel-02 added a commit that referenced this pull request Oct 9, 2025
ghstack-source-id: 46c9dcc
Pull Request resolved: #164502
@drisspg drisspg requested a review from v0i0 October 9, 2025 23:26
@drisspg added the release notes: nn and module: sdpa labels Oct 10, 2025
liangel-02 added a commit that referenced this pull request Oct 14, 2025
ghstack-source-id: f1bd573
Pull Request resolved: #164502
@liangel-02 liangel-02 changed the base branch from gh/liangel/1/base to main October 14, 2025 19:29
@liangel-02 (Contributor, Author) commented:

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 14, 2025
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 15, 2025
Pull Request resolved: pytorch#164502
Approved by: https://github.com/v0i0, https://github.com/drisspg
@huydhn (Contributor) commented Oct 15, 2025

@pytorchbot revert -m 'Sorry for reverting your change, but the doctests failure is legit' -c weird

GH job link HUD commit link

Due to a macOS outage over the weekend, the doctests job was wrongly marked as flaky.

@pytorchmergebot (Collaborator) commented:

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator) commented:

@liangel-02 your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Oct 15, 2025
This reverts commit 3681312.

Reverted #164502 on behalf of https://github.com/huydhn due to: "Sorry for reverting your change, but the doctests failure is legit."
@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Oct 15, 2025
@pytorch-auto-revert commented:

@pytorchbot revert -m "Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable" -c autorevert

This PR is attributed to have caused regression in:

Please investigate and fix the issues.

@pytorchmergebot (Collaborator) commented:

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator) commented:

Reverting PR 164502 failed

Reason: list index out of range

Details for Dev Infra team · Raised by workflow job
