
Add scatter_mm and bsr_scatter_mm operations. #110396

Closed
wants to merge 8 commits

Conversation

pearu
Collaborator

@pearu pearu commented Oct 2, 2023

This PR introduces the `scatter_mm` operation (computing `mm` over arbitrary pairs of tensors given as batches of tensors), which is used to implement `bsr_scatter_mm`, an operation equivalent to `bsr_dense_mm` (the `mm` operation on BSR and strided tensors). The implementation is provided both in Triton (when tensor dimensions are multiples of 16) and in PyTorch (otherwise).
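For intuition, here is a minimal pure-PyTorch sketch of the `scatter_mm` semantics; the function name, index layout, and signature below are illustrative, not the actual API introduced by this PR:

```python
import torch

def scatter_mm_ref(blocks, others, indices, num_out):
    # For each triple (r, p, q), accumulate blocks[p] @ others[q] into out[r].
    P, M, K = blocks.shape
    Q, K2, N = others.shape
    assert K == K2
    out = torch.zeros(num_out, M, N, dtype=blocks.dtype, device=blocks.device)
    for r, p, q in indices:
        out[r] += blocks[p] @ others[q]
    return out

# Tiny example: two accumulators built from three block products.
blocks = torch.randn(3, 2, 4)
others = torch.randn(2, 4, 5)
indices = [(0, 0, 0), (0, 1, 1), (1, 2, 0)]
result = scatter_mm_ref(blocks, others, indices, num_out=2)
print(result.shape)  # torch.Size([2, 2, 5])
```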

The figures below illustrate the performance differences between `bsr_scatter_mm` and `bsr_dense_mm` (GPU: NVIDIA GeForce RTX 2060 SUPER). The first figure shows the performance equilibrium point: the BSR tensor sparsity at which `bsr_scatter_mm` or `bsr_dense_mm` matches the performance of `torch.matmul`. The second figure shows the speedups from using `bsr_scatter_mm` at its performance equilibrium points, relative to `bsr_dense_mm`.

<img src="https://github.com/pytorch/pytorch/assets/402156/526d182e-937f-4812-a6c4-904f52d6d5ab" width="48%"> <img src="https://github.com/pytorch/pytorch/assets/402156/ccb606ab-1f3f-4133-887c-b56285f4f168" width="48%">

The same figures for the NVIDIA A100-SXM4-80GB GPU:

<img src="https://github.com/pytorch/pytorch/assets/402156/25466f1d-df34-4d1c-a975-afb478e4d9f0" width="48%"> <img src="https://github.com/pytorch/pytorch/assets/402156/6ada91f0-a20f-4f0d-8a48-1f4ccc60d08e" width="48%">

In sum:

  • `bsr_scatter_mm` is about 2x faster than `bsr_dense_mm` for small block sizes (16 and 32) and large tensors [GPU: NVIDIA GeForce RTX 2060 SUPER].
  • `bsr_scatter_mm` is up to 2x faster than `bsr_dense_mm` for a small block size of 16 and large tensors [GPU: NVIDIA A100-SXM4-80GB].
  • `bsr_dense_mm` is up to 20% faster than `bsr_scatter_mm` for block sizes of 64 or larger [GPU: NVIDIA GeForce RTX 2060 SUPER].
  • However, `bsr_dense_mm` fails with an `OutOfResources` exception for block sizes of 256 or larger, whereas `bsr_scatter_mm` succeeds. A usage sketch of both kernels follows below.
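A minimal sketch of exercising both kernels, assuming the internal module path `torch.sparse._triton_ops` and the signatures from this stack (internal APIs, so subject to change):

```python
import torch
# Internal, non-public module; the path is an assumption based on this PR stack.
from torch.sparse._triton_ops import bsr_dense_mm, bsr_scatter_mm

M = K = N = 1024
blocksize = 32
lhs = torch.randn(M, K, device='cuda', dtype=torch.float16)
bsr = lhs.to_sparse_bsr((blocksize, blocksize))  # convert to BSR layout
rhs = torch.randn(K, N, device='cuda', dtype=torch.float16)

# Both kernels compute the same product; they differ in performance.
out_dense = bsr_dense_mm(bsr, rhs)
out_scatter = bsr_scatter_mm(bsr, rhs)
torch.testing.assert_close(out_dense, out_scatter, rtol=1e-2, atol=1e-2)
```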

Stack from ghstack (oldest at bottom):

cc @alexsamardzic @nikitaved @cpuhrsch @amjames @bhosmer

@pytorch-bot pytorch-bot bot added the release notes: sparse release notes category label Oct 2, 2023

pytorch-bot bot commented Oct 2, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110396

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2e9e52b with merge base 57c7aa1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pearu added a commit that referenced this pull request Oct 2, 2023
ghstack-source-id: d050ac42749d119c50ff7ca68dc67a67a9d72113
Pull Request resolved: #110396
@pearu pearu marked this pull request as draft October 2, 2023 16:09
@pearu pearu added this to In progress in Sparse tensors via automation Oct 9, 2023
pearu added a commit that referenced this pull request Oct 11, 2023
ghstack-source-id: 86665e5a524fa8235e26d5cd50403a000adf6372
Pull Request resolved: #110396
pearu added a commit that referenced this pull request Oct 12, 2023
ghstack-source-id: fad7ff8a0b8aa73463b93607c12a7f064e3d3f00
Pull Request resolved: #110396
pearu added a commit that referenced this pull request Oct 13, 2023
ghstack-source-id: e0ed1bb421d1cce16c2efa50f8e83f2104d9c323
Pull Request resolved: #110396
pearu added a commit that referenced this pull request Oct 17, 2023
ghstack-source-id: 5f48f4fc1b0770e26a4b1a44bc15e771f1d5c675
Pull Request resolved: #110396
@pearu pearu marked this pull request as ready for review October 18, 2023 17:02
@pearu pearu requested a review from amjames October 18, 2023 17:02
@pearu pearu added module: sparse Related to torch.sparse topic: new features topic category topic: performance topic category labels Oct 18, 2023
Sparse tensors automation moved this from In progress to Reviewer approved Oct 23, 2023
@cpuhrsch
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 23, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

# Pick kernel parameters depending on the device and problem shape below.
device_name = torch.cuda.get_device_name()
is_A100 = 'A100' in device_name
if (M, K, N) == (256,) * 3:
I guess you essentially want to store a decision tree here?
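For illustration, such a stored lookup might resemble the sketch below; the keys, values, and helper name are hypothetical placeholders, not the actual tuning data in this PR:

```python
import torch

# Hypothetical table of tuned meta parameters keyed by device kind, problem
# shape, and block size; the entries are placeholders, not measured results.
_TUNED_META = {
    ('A100', (256, 256, 256), 16): dict(SPLIT_N=1, num_warps=4, num_stages=1),
    ('A100', (256, 256, 256), 32): dict(SPLIT_N=2, num_warps=4, num_stages=3),
}

def get_meta(M, K, N, blocksize):
    device_kind = 'A100' if 'A100' in torch.cuda.get_device_name() else 'other'
    default = dict(SPLIT_N=1, num_warps=2, num_stages=1)
    return _TUNED_META.get((device_kind, (M, K, N), blocksize), default)
```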

Sparse tensors automation moved this from Reviewer approved to Done Oct 23, 2023
pytorchmergebot pushed a commit that referenced this pull request Oct 23, 2023
pytorchmergebot pushed a commit that referenced this pull request Oct 23, 2023
pytorchmergebot pushed a commit that referenced this pull request Oct 23, 2023
As in the title.

The figures below illustrate the performance differences between `bsr_dense_mm` with optimized meta parameters and `bsr_dense_mm` with default meta parameters (GPU: NVIDIA A100-SXM4-80GB). The first figure shows the performance equilibrium point: the BSR tensor sparsity at which `bsr_dense_mm` matches the performance of `torch.matmul`. The second figure shows the speedups from using optimized meta parameters in `bsr_dense_mm` at its performance equilibrium points, relative to `bsr_dense_mm` with default meta parameters.

In sum, this PR speeds up `bsr_dense_mm` by about 50%, depending on the BSR tensor shape and block size, and lowers the performance equilibrium points of BSR tensor sparsity for matmul operations on strided tensors.

<img src="https://github.com/pytorch/pytorch/assets/402156/6fe9d35f-dd21-4aa0-bb01-6ee257254453" width="48%"> <img src="https://github.com/pytorch/pytorch/assets/402156/506921c6-3770-4209-ad3d-498d2ae4989d" width="48%">

Pull Request resolved: #111760
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #110396, #111470, #111489
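A hedged illustration of supplying tuned meta parameters explicitly, assuming `bsr_dense_mm` accepts a `meta` dict as this stack suggests; the parameter names below are illustrative Triton-style knobs, not the PR's tuned values:

```python
import torch
# Internal API; the module path and the `meta` keyword are assumptions.
from torch.sparse._triton_ops import bsr_dense_mm

bsr = torch.randn(1024, 1024, device='cuda', dtype=torch.float16).to_sparse_bsr((32, 32))
rhs = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)

# Defaults are used when meta is omitted; here explicit values are passed instead.
out = bsr_dense_mm(bsr, rhs, meta=dict(GROUP_SIZE_ROW=4, num_warps=4, num_stages=1))
```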
pytorchmergebot pushed a commit that referenced this pull request Oct 23, 2023
andreigh pushed a commit to andreigh/pytorch that referenced this pull request Oct 26, 2023
andreigh pushed a commit to andreigh/pytorch that referenced this pull request Oct 26, 2023
andreigh pushed a commit to andreigh/pytorch that referenced this pull request Oct 26, 2023
andreigh pushed a commit to andreigh/pytorch that referenced this pull request Oct 26, 2023
andreigh pushed a commit to andreigh/pytorch that referenced this pull request Oct 26, 2023
@facebook-github-bot facebook-github-bot deleted the gh/pearu/120/head branch October 27, 2023 14:25
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Labels
ciflow/trunk · Merged · module: sparse · open source · release notes: sparse · topic: new features · topic: performance
4 participants