Add prototype for optimized dense2jagged kernel for 1 jagged dim #1221
base: main
Conversation
@mjanderson09 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This is amazing! Thanks for the PR. I added a few comments here for visibility but these can be certainly addressed on our side.
}
// warp sync ?
index_t ix = idx - smem_dense_offsets[res];
index_t stride = (res * jagged_dim_stride); // res is row in dense 0th dim
Strides are computed manually here instead of via a tensor accessor. Does that make a difference in perf?
It shouldn't make any difference in perf as long as the accessor also adds the same number of instructions. I used explicit calculations because they make the code self-explanatory.
for (index_t idx = (threadIdx.x + blockIdx.x * blockDim.x); idx < total_jagged_rows;
     idx += (blockDim.x * gridDim.x)) {
nit: we can use CUDA_KERNEL_LOOP(idx, total_jagged_rows), which is defined in pytorch/aten/src/ATen/cuda/detail/KernelUtils.h
typedef typename VecType32<scalar_t>::TType vec32; // no-op for double

// 128-bit alignment check
if (((e_dim * sizeof(scalar_t)) << 3) % 128 == 0) {
Instead of applying vectorization all-or-nothing, can we do best effort: a first loop that covers columns rounded down to a multiple of 128 bits, a second loop that covers the remainder rounded down to 64 bits, and so on? That way, when e_dim is, say, 576 bits wide, we can use 128-bit vectorization for the first 512 bits and then 64-bit vectorization for the rest. Or would this regress perf?
Performance will depend on the overhead of the extra math instructions each thread of the warp executes for the alignment check on every iteration. I think you can find problem sizes where the incremental vector-width decrement helps performance and others (perfectly aligned cases) where it reduces it.
IMO the ideal solution is to allocate contiguous tensors with an appropriate pitch in each dimension so the vectorized loads/stores stay aligned. Not sure if PyTorch does that already.
So you're only checking alignment with respect to e_dim, not the alignment of the actual pointers, right? If that's the case, can we just do the following? Or do we also have to check the alignment of the "values" and "dense" pointers?
VecType128::TType* values128 = (VecType128::TType*)&(values[0]);
VecType128::TType* dense128 = (VecType128::TType*)&(dense[0]);
VecType64::TType* values64 = (VecType64::TType*)&(values[0]);
VecType64::TType* dense64 = (VecType64::TType*)&(dense[0]);
VecType32::TType* values32 = (VecType32::TType*)&(values[0]);
VecType32::TType* dense32 = (VecType32::TType*)&(dense[0]);
for (int real_row = values_row; real_row < nnz; real_row += blockDim.y * gridDim.y) {
  int dense_row = rows[real_row];
  int dense_col = cols[real_row];
  for (int tid = threadIdx.x; tid < E/8; tid += blockDim.x) {
    values128[tid + real_row * (E/8)] = dense128[tid + dense_col * (E/8) + dense_row * L * (E/8)];
  }
  for (int tid = threadIdx.x + (E/8)*8; tid < E/4; tid += blockDim.x) {
    values64[tid + real_row * (E/4)] = dense64[tid + dense_col * (E/4) + dense_row * L * (E/4)];
  }
  for (int tid = threadIdx.x + (E/4)*4; tid < E/2; tid += blockDim.x) {
    values32[tid + real_row * (E/2)] = dense32[tid + dense_col * (E/2) + dense_row * L * (E/2)];
  }
  for (int tid = threadIdx.x + (E/2)*2; tid < E; tid += blockDim.x) {
    values[tid + real_row * E] = dense[tid + dense_col * E + dense_row * L * E];
  }
}
I am doing an alignment check with e_dim * sizeof(scalar_t) because:
- It ensures the jagged_out tensor is aligned.
- Because of broadcast, the only effective stride in the dense tensor is e_dim, so the check ensures every subsequent pointer location the kernel reads from stays aligned.

The alternative you suggested will run into pointer alignment errors. For example, say each output jagged row is 130 bytes wide (65 halves): the very first row works fine with your logic, but the 2nd row would try to write its first 8 halves at an address that is not aligned for a 128-bit operation.
for(int tid = threadIdx.x ; tid < E/8 ; tid += blockDim.x) {
Note that in my code I use one warp to service two rows (blockDim.x = 16), and I use tid = threadIdx.x % blockDim.x.
Summary:
- Integrating prototype from pytorch#1221
- Added elementwise functionality
- Restricted use to half precision, aligned case
- Added tests for optimized case
- Added elementwise ops to benchmark

Differential Revision: D38365864
fbshipit-source-id: 0eb3b0cf6ab94123ba32e9ba3a827f7deda3fefa
Summary:
Pull Request resolved: #1236
- Integrating prototype from #1221
- Added elementwise functionality
- Restricted use to half precision, aligned case
- Added tests for optimized case
- Added elementwise ops to benchmark

Reviewed By: jianyuh
Differential Revision: D38365864
fbshipit-source-id: c164213b587cd667fb9b2fcded8601a0bcdb12b0
Sharing an optimized prototype kernel for 2D dense_to_jagged operations using vectorized loads/stores.
CC: @mjanderson09