[vulkan] Efficient gemm implementation #49609
Differential Revision: [D26209677](https://our.internmc.facebook.com/intern/diff/D26209677)
Thanks Stephen!
if (all(lessThan(pos, uBlock.size.xyz))) {
const int base_x = 2*pos.x;
const int base_y = 2*pos.y;
Think in terms of vectors. This may not perform better on modern scalar GPUs with a SIMT architecture (it shouldn't be worse, anyway), but it should perform better on older VLIW hardware.
By the way, swizzling in shaders is free.
const ivec2 base = 2 * pos.xy;
const ivec4 index = base + ivec4(0, 1, uBlock.orig_size.x, uBlock.orig_size.x+1);

vec4 outvec = vec4(0,0,0,0);
if (base_x < uBlock.orig_size.x && base_y < uBlock.orig_size.y) {
This shader is not performance-sensitive if it's just a one-time transformation, but branches are still expensive in shaders. In general, it is better to rework the logic to avoid branches where you can.
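A minimal sketch of the branch-free idea, written in C++ for illustration rather than GLSL. It assumes the guarded value can be computed safely even for out-of-range positions (for example by clamping the fetch coordinate), which may not hold for the actual shader. `Vec4`, `select_branchy`, and `select_branchless` are hypothetical names, not part of this PR.

```cpp
#include <array>

using Vec4 = std::array<float, 4>;

// Branchy version: mirrors `outvec = vec4(0); if (in_bounds) { outvec = value; }`.
Vec4 select_branchy(const Vec4& value, const bool in_bounds) {
  Vec4 out = {0.0f, 0.0f, 0.0f, 0.0f};
  if (in_bounds) {
    out = value;
  }
  return out;
}

// Branch-free version: scale by a 0/1 mask. This is similar in spirit to
// GLSL's mix(vec4(0), value, float(in_bounds)).
// Caveat: if `value` holds NaN or infinity, multiplying by 0 yields NaN,
// so the two versions only agree for finite inputs.
Vec4 select_branchless(const Vec4& value, const bool in_bounds) {
  const float mask = static_cast<float>(in_bounds);
  return {value[0] * mask, value[1] * mask, value[2] * mask, value[3] * mask};
}
```

Whether this is actually faster depends on the driver and hardware; many shader compilers already turn small conditionals like this into selects.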
const Shader::Descriptor& shader_descriptor,
const Shader::WorkGroup& global_work_group,
const Shader::WorkGroup& local_work_group_size,
Arguments&&... arguments);
Please delete the old version of this function that does not take the local work group size explicitly, and keep only this new version. Then pass local_work_group_size (adapter->blah_blah(); I don't remember the name) explicitly at all call sites. We are going to need that flexibility anyway for tweaking the local work group size.
VK_IMAGE_PACK_NC4HW_3D = 0,
VK_IMAGE_PACK_NC4HW_2D = 1,
VK_IMAGE_PACK_H2W2 = 2,
} VkImagePackFormat;
Do we still need this? Sorry it may become apparent as I scroll down.
vec4 texel1 = texelFetch(uM1, ivec3(k, pos.y, pos.z), 0);
vec4 texel2 = texelFetch(uM2, ivec3(pos.x, k, pos.z), 0);
sum = fma(texel1.xxzz, texel2.xyxy, sum);
sum = fma(texel1.yyww, texel2.zwzw, sum);
Is this a by-product of our new packing?
Yes, the new packing makes use of the entire input texel.
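To see why the two swizzled `fma` calls use every component of both texels, here is a small C++ sketch of the same arithmetic. It assumes each texel stores a row-major 2x2 block, with `x, y` as the top row and `z, w` as the bottom row; that layout is inferred from the swizzles, not stated in the PR. `Vec4`, `fma4`, and `block_mul_accumulate` are hypothetical names used only for illustration.

```cpp
struct Vec4 {
  float x, y, z, w;
};

// Component-wise a * b + c, like GLSL's fma() applied to vec4 operands.
Vec4 fma4(const Vec4& a, const Vec4& b, const Vec4& c) {
  return {a.x * b.x + c.x, a.y * b.y + c.y, a.z * b.z + c.z, a.w * b.w + c.w};
}

// Accumulates the 2x2 product of block t1 and block t2 into sum.
// With t1 = [[x, y], [z, w]] and t2 = [[x, y], [z, w]]:
//   sum.x += t1.x*t2.x + t1.y*t2.z   (row 0, col 0)
//   sum.y += t1.x*t2.y + t1.y*t2.w   (row 0, col 1)
//   sum.z += t1.z*t2.x + t1.w*t2.z   (row 1, col 0)
//   sum.w += t1.z*t2.y + t1.w*t2.w   (row 1, col 1)
Vec4 block_mul_accumulate(const Vec4& t1, const Vec4& t2, Vec4 sum) {
  // Mirrors: sum = fma(texel1.xxzz, texel2.xyxy, sum);
  sum = fma4({t1.x, t1.x, t1.z, t1.z}, {t2.x, t2.y, t2.x, t2.y}, sum);
  // Mirrors: sum = fma(texel1.yyww, texel2.zwzw, sum);
  sum = fma4({t1.y, t1.y, t1.w, t1.w}, {t2.z, t2.w, t2.z, t2.w}, sum);
  return sum;
}
```

Under that layout, each loop iteration over `k` multiplies one 2x2 block of the first matrix by one 2x2 block of the second, so all four components of both fetched texels contribute to the output texel.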
},
v_src.options()
};
const struct {
Same comment regarding anonymous structs on GCC.
};

uint32_t orig_w = output_sizes[output_sizes.size() - 1];
uint32_t orig_h = output_sizes[output_sizes.size() - 2];
const. const everywhere please. I am a const zealot. :)
return v_src_unpacked;
}

vTensor unpack_image1x1(vTensor v_src, c10::SmallVector<int64_t, 4u> output_sizes, api::Context* context, api::Command::Buffer& command_buffer) {
Pass all objects larger than two machine words (2 x 64 bits on 64-bit platforms, 2 x 32 bits on 32-bit platforms) by [const] reference. I add a fudge factor of 2 because pointer chasing and dereferencing (which is effectively what references are: syntactic sugar for pointers) has a cost, so it is best avoided when the cost of passing by value is small.
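A minimal sketch of that rule of thumb, assuming a typical 32-bit or 64-bit target. `Extent`, `texel_count`, and `element_count` are hypothetical names chosen for illustration and do not appear in this PR.

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Small type: two 32-bit fields, so it fits within two machine words.
struct Extent {
  std::uint32_t width;
  std::uint32_t height;
};
static_assert(sizeof(Extent) <= 2 * sizeof(void*), "Extent fits in two machine words");

// Small object: pass by value, so the callee works on a cheap local copy.
std::uint64_t texel_count(const Extent extent) {
  return static_cast<std::uint64_t>(extent.width) * extent.height;
}

// Larger object (a std::vector is typically three pointers plus heap storage):
// pass by const reference to avoid copying it.
std::int64_t element_count(const std::vector<std::int64_t>& sizes) {
  return std::accumulate(
      sizes.begin(), sizes.end(), std::int64_t{1}, std::multiplies<std::int64_t>());
}
```

By the same reasoning, the `vTensor` and `c10::SmallVector<int64_t, 4u>` parameters in the signature above are likely candidates for `const&`, since both are presumably larger than two machine words.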
vTensor pack_image2d_h2w2(vTensor v_src, api::Context* context, api::Command::Buffer& command_buffer);
vTensor unpack_image2d_h2w2(vTensor v_src, c10::SmallVector<int64_t, 4u> output_sizes, api::Context* context, api::Command::Buffer& command_buffer);

vTensor unpack_image1x1(vTensor v_src, c10::SmallVector<int64_t, 4u> output_sizes, api::Context* context, api::Command::Buffer& command_buffer);
If these functions are only used in a single implementation file, please remove the common header. Reason: software engineering is the art (since, unfortunately, it is not all science) and science of change management, and the bedrock of managing change is limiting scope. Limiting scope is the single most important tool software engineers have for getting a handle on entropy.
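One common way to limit scope in C++ when a helper is used by a single implementation file is an anonymous namespace, sketched below. `div_up` and `packed_extent` are hypothetical names; the halving assumes the 2x2 block packing suggested by `base_x = 2*pos.x` in the shader, and the arithmetic assumes non-negative extents.

```cpp
#include <cstdint>

namespace {

// File-local helper: the anonymous namespace gives it internal linkage,
// so no other translation unit can call it or come to depend on it.
std::int64_t div_up(const std::int64_t numerator, const std::int64_t denominator) {
  return (numerator + denominator - 1) / denominator;
}

} // namespace

// The only function this file would expose through a header, if any.
// Each packed texel covers two elements along the axis, so round up.
std::int64_t packed_extent(const std::int64_t extent) {
  return div_up(extent, 2);
}
```

With this layout, changing or deleting `div_up` can only affect the one file that defines it.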

const auto check = almostEqual(out_cpu, out_vulkan.cpu());
if (!check) {
std::cout << "Expected:\n" << out_cpu << std::endl;
std::cout << "Got:\n" << out_vulkan.cpu() << std::endl;
showRtol(out_cpu, out_vulkan.cpu());
Please switch the other places over to this function as well.
Summary: Pull Request resolved: pytorch#49609
Test Plan: Imported from OSS
Reviewed By: AshkanAliabadi
Differential Revision: D26209677
Pulled By: SS-JIA
fbshipit-source-id: 773a944559bf0deb3cf3e233d833220a12f9f2ab
Stack from ghstack:
Differential Revision: D26209677