Add prototype for optimized dense2jagged kernel for 1 jagged dim #1221

Draft: wants to merge 1 commit into main
Conversation

chirayuG-nvidia

Sharing an optimized prototype kernel for 2D dense_to_jagged operations, using vectorized memory operations.
CC: @mjanderson09
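
For context, here is a minimal scalar sketch of what a 2D dense_to_jagged copy with one jagged dim does. This is a hypothetical reference (the kernel name, argument names, and offsets layout are assumptions), not the optimized kernel in this PR, which replaces the inner per-element copy with vectorized loads/stores.

    // Hypothetical scalar reference for dense_to_jagged with one jagged dim.
    // Assumes dense is [B][max_L][E] row-major and x_offsets[b] marks where
    // batch b's rows start in the jagged values tensor (x_offsets[B] == nnz).
    template <typename index_t, typename scalar_t>
    __global__ void dense_to_jagged_naive_kernel(
        const scalar_t* dense,
        const index_t* x_offsets,
        scalar_t* values,
        int B,
        int max_L,
        int E) {
      for (index_t row = blockIdx.x; row < x_offsets[B]; row += gridDim.x) {
        // Find the batch this jagged row belongs to (linear scan for clarity).
        int b = 0;
        while (x_offsets[b + 1] <= row) {
          ++b;
        }
        const index_t l = row - x_offsets[b]; // position within the batch
        const scalar_t* src = dense + ((size_t)b * max_L + l) * E;
        scalar_t* dst = values + (size_t)row * E;
        for (int e = threadIdx.x; e < E; e += blockDim.x) {
          dst[e] = src[e]; // scalar copy; the PR vectorizes this inner loop
        }
      }
    }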


facebook-github-bot (Contributor):
@mjanderson09 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

jspark1105 (Contributor) left a comment:
This is amazing! Thanks for the PR. I added a few comments here for visibility, but these can certainly be addressed on our side.

}
// warp sync ?
index_t ix = idx - smem_dense_offsets[res];
index_t stride = (res * jagged_dim_stride); // res is row in dense-0th dim

jspark1105 (Contributor):
Strides are computed manually instead of using a tensor accessor. Does it make a difference in perf?

chirayuG-nvidia (Author):
It shouldn't make any difference in perf as long as the accessor adds the same number of instructions. I used explicit calculations because it makes the code self-explanatory.
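
For illustration only, the two indexing styles under discussion might look like the following sketch (hypothetical helpers, not the PR's actual code; the second variant uses ATen's at::PackedTensorAccessor32):

    // Manual stride arithmetic, as in this PR: the address math is explicit.
    template <typename index_t, typename scalar_t>
    __device__ void copy_row_manual(
        const scalar_t* dense,
        scalar_t* values,
        index_t dense_row,
        index_t values_row,
        index_t e_dim,
        index_t dense_row_stride) {
      for (index_t e = threadIdx.x; e < e_dim; e += blockDim.x) {
        values[values_row * e_dim + e] = dense[dense_row * dense_row_stride + e];
      }
    }

    // Same loop with an accessor: operator[] does the stride multiply, so the
    // generated instructions should be essentially the same.
    template <typename scalar_t>
    __device__ void copy_row_accessor(
        at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits> dense,
        at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits> values,
        int dense_row,
        int values_row) {
      for (int e = threadIdx.x; e < dense.size(1); e += blockDim.x) {
        values[values_row][e] = dense[dense_row][e];
      }
    }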

Comment on lines +187 to +188
for (index_t idx = (threadIdx.x + blockIdx.x * blockDim.x); idx < total_jagged_rows;
idx += (blockDim.x * gridDim.x)) {

jspark1105 (Contributor):
nit: we can use CUDA_KERNEL_LOOP(idx, total_jagged_rows), which is defined in pytorch/aten/src/ATen/cuda/detail/KernelUtils.h
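
For reference, a sketch of how lines +187 to +188 would read with that macro (assuming the header is included; the same file also provides a variant parameterized on the index type if int is too narrow):

    #include <ATen/cuda/detail/KernelUtils.h> // defines CUDA_KERNEL_LOOP

    // Equivalent to:
    //   for (index_t idx = threadIdx.x + blockIdx.x * blockDim.x;
    //        idx < total_jagged_rows;
    //        idx += blockDim.x * gridDim.x)
    CUDA_KERNEL_LOOP(idx, total_jagged_rows) {
      // ... loop body unchanged ...
    }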

typedef typename VecType32<scalar_t>::TType vec32; // no-op for double

// 128-bit alignment check: row width in bits must be a multiple of 128
if (((e_dim * sizeof(scalar_t)) << 3) % 128 == 0) {

jspark1105 (Contributor):
Instead of applying vectorization all-or-nothing, can we do best effort: a first loop with 128-bit vectorization over the columns rounded down to a multiple of 128 bits, a second loop with 64-bit vectorization over the remainder, and so on? That way, when e_dim is, say, 576 bits, we can use 128-bit vectorization for the first 512 bits and then 64-bit vectorization for the rest. Or would this regress perf?


chirayuG-nvidia (Author) commented Jul 27, 2022:

Performance will depend on the overhead of the extra math instructions each thread of the warp has to execute for the alignment checks on every iteration. I think you can find problem sizes where the incremental vector-width fallback helps performance and others (perfectly aligned cases) where it hurts.
IMO the ideal solution is to allocate contiguous tensors with an appropriate pitch in each dimension so that vectorized loads/stores stay aligned. Not sure if PyTorch does that already.
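
To make the pitch idea concrete, a rough sketch (this is an assumption about how allocation could be done, not something PyTorch does automatically): round the last dimension up so every row starts on a 16-byte boundary.

    // Hypothetical helper: pad e_dim so each row's byte width is a multiple of
    // 16 bytes (128 bits), which keeps vectorized loads/stores aligned on
    // every row rather than only the first.
    template <typename scalar_t>
    int64_t padded_row_elems(int64_t e_dim) {
      constexpr int64_t kAlignBytes = 16; // 128 bits
      constexpr int64_t kElemsPerAlign = kAlignBytes / sizeof(scalar_t);
      return (e_dim + kElemsPerAlign - 1) / kElemsPerAlign * kElemsPerAlign;
    }
    // e.g. for half (2 bytes): e_dim = 65 -> 72 elements per row (144 bytes).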


jspark1105 (Contributor):

So you're only checking alignment with respect to e_dim, not the alignment of the actual pointers, right? If that's the case, can we just do the following? Or do we have to check the alignment of the "values" and "dense" pointers?

    VecType128::TType * values128 = (VecType128::TType *)&(values[0]);
    VecType128::TType * dense128 = (VecType128::TType *)&(dense[0]);
    VecType64::TType * values64 = (VecType64::TType *)&(values[0]);
    VecType64::TType * dense64 = (VecType64::TType *)&(dense[0]);
    VecType32::TType * values32 = (VecType32::TType *)&(values[0]);
    VecType32::TType * dense32 = (VecType32::TType *)&(dense[0]);
    for(int real_row = values_row ; real_row < nnz ; real_row += blockDim.y * gridDim.y) {
        int dense_row = rows[real_row];
        int dense_col = cols[real_row];
        for(int tid = threadIdx.x ; tid < E/8 ; tid += blockDim.x) {
            values128[tid + real_row * (E/8)] = dense128[tid + dense_col * (E/8) + dense_row * L * (E/8)];
        }
        for(int tid = threadIdx.x + (E/8)*8 ; tid < E/4 ; tid += blockDim.x) {
            values64[tid + real_row * (E/4)] = dense64[tid + dense_col * (E/4) + dense_row * L * (E/4)];
        }
        for(int tid = threadIdx.x + (E/4)*4 ; tid < E/2 ; tid += blockDim.x) {
            values32[tid + real_row * (E/2)] = dense32[tid + dense_col * (E/2) + dense_row * L * (E/2)];
        }
        for(int tid = threadIdx.x + (E/2)*2 ; tid < E ; tid += blockDim.x) {
            values[tid + real_row * (E)] = dense[tid + dense_col * (E) + dense_row * L * (E)];
        }
    }


chirayuG-nvidia (Author) commented Jul 28, 2022:

I am doing an alignment check on e_dim * sizeof(scalar_t) because:

  1. It ensures the jagged output tensor is aligned.
  2. Because of broadcast, the only effective stride in the dense tensor equals e_dim, so the check ensures that every subsequent pointer location the kernel reads from will also be aligned.

The alternative you suggested will run into a pointer alignment error. For example, say each jagged output row is 130 bytes wide (65 halfs): the very first row will work fine with your logic, but the 2nd row will try to write its first 8 halfs to an address that is not aligned for a 128-bit operation.

for(int tid = threadIdx.x ; tid < E/8 ; tid += blockDim.x) {

Note that in my code I use one warp to service two rows (blockDim.x = 16), so I use tid = threadIdx.x % blockDim.x.
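
For completeness, if we also wanted to check the actual addresses at runtime (the earlier question about the "values" and "dense" pointers), a host-side sketch could look like this (hypothetical helper, not part of the PR):

    #include <cstdint>

    // The 128-bit path is safe only if both base pointers and the per-row byte
    // width are 16-byte aligned; the e_dim * sizeof(scalar_t) check covers the
    // row-to-row part, and these checks cover the base pointers.
    inline bool can_use_128bit_path(
        const void* values,
        const void* dense,
        int64_t e_dim,
        int64_t elem_size) {
      const int64_t row_bytes = e_dim * elem_size;
      return reinterpret_cast<std::uintptr_t>(values) % 16 == 0 &&
             reinterpret_cast<std::uintptr_t>(dense) % 16 == 0 &&
             row_bytes % 16 == 0;
    }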

mjanderson09 added a commit to mjanderson09/FBGEMM that referenced this pull request Aug 4, 2022
Summary:
- Integrating prototype from pytorch#1221
- Added elementwise functionality
- Restricted use to half precision, aligned case
- Added tests for optimized case
- Added elementwise ops to benchmark

Differential Revision: D38365864

fbshipit-source-id: 0eb3b0cf6ab94123ba32e9ba3a827f7deda3fefa
facebook-github-bot pushed a commit that referenced this pull request Aug 12, 2022
Summary:
Pull Request resolved: #1236

- Integrating prototype from #1221
- Added elementwise functionality
- Restricted use to half precision, aligned case
- Added tests for optimized case
- Added elementwise ops to benchmark

Reviewed By: jianyuh

Differential Revision: D38365864

fbshipit-source-id: c164213b587cd667fb9b2fcded8601a0bcdb12b0