Update on "inductor: enable weight prepack for LSTM"
- Enabled LSTM weight prepack in inductor.
- Added an mkldnn decomposition for LSTM that stays the same across different `seq_lens`. With the previous decomposition, in the dynamic-shapes use case where `seq_lens` changes, the graph would be different.
- Extended several inductor utility functions to support `List[Tensor]` as input. Previously those functions only supported `Tensor` input.

**Update 2023-07-26:**
- #103851 moved CPU weight packing to after AOTAutograd. Updated this PR to follow the same approach (mainly in 3b207f7#diff-6dffed1ade0ba3e887f9a4eafa3bfcec267ab2365b8adcb91bd391f49b3fd2e3). LSTM is decomposed into `aten.mkldnn_rnn_layer` by layer and by direction, and the weight prepack is done at the `mkldnn_rnn_layer` level.
- Added a fix in the RNN `__getstate__` function for the case where we need to recompile an `LSTM` module:
  1. When compiling the module, the weight tensors, which are the `named_parameters` of the module, are converted to `functional_tensor` here: https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/nn/utils/stateless.py#L125-L128
  2. The forward function of the LSTM is called: https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_functorch/aot_autograd.py#L3379-L3381
  3. In the forward function, `_flat_weights` is updated to match the weights, so it also becomes `functional_tensor`: https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/nn/modules/rnn.py#L775-L778
  4. The weight tensors are converted back to the original tensors (no longer `functional_tensor`) before exiting the `_reparametrize_module` context here: https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/nn/utils/stateless.py#L130-L142
  5. But since `_flat_weights` is not in the `named_parameters` of the module, it stays a `functional_tensor` ([the parameters that are converted to functional and reverted back](https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_functorch/aot_autograd.py#L3695-L3698)).
  6. If we now need to recompile the model, `deepcopy` is called: https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_dynamo/utils.py#L915-L917
  7. It reports `UnImplemented` since we have a `functional_tensor` (`_flat_weights`) and triggers a graph break, which is not what we expect: https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_subclasses/meta_utils.py#L514

  The fix in `__getstate__` updates `_flat_weights` whenever the weights have changed. It is covered by the `test_lstm_packed` UT; a minimal repro sketch of the recompile scenario follows below.

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
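As a hedged illustration (this code is not from the PR), the recompile scenario above can be reproduced in spirit by compiling an `nn.LSTM` and then feeding it a different sequence length, the dynamic-shapes situation where Dynamo may recompile and hit the `deepcopy` path. The sizes and shapes below are arbitrary:

```python
# Hedged repro sketch (not from the PR): compile an LSTM, then feed a
# different seq_len so that dynamic shapes can force a recompile. Before
# the __getstate__ fix described above, the recompile-time deepcopy could
# encounter functional_tensor _flat_weights and trigger a graph break.
import torch

lstm = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=2)
lstm.eval()
compiled = torch.compile(lstm)

with torch.no_grad():
    out1, _ = compiled(torch.randn(5, 4, 8))  # (seq_len, batch, input_size)
    out2, _ = compiled(torch.randn(9, 4, 8))  # new seq_len -> possible recompile
```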
Showing 56 changed files with 1,290 additions and 1,374 deletions.
New file (82 lines): a Vulkan GLSL compute shader implementing the `sum_dim` kernel (dispatched via `VK_KERNEL(sum_dim)` in the C++ file below).

```glsl
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  // dim_info.x: dim to sum
  // dim_info.y: size of dim (in the input)
  uvec2 dim_info;
  int channel;
}
uBlock;

/*
 * Local Work Group Size
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Returns a new tensor with values summed along dimension dim.
 * Dimension dim is squeezed.
 * For each pos:
 *  - Iterate over the out_texel and the summed dimension
 *  - For H,W: rearrange pos.x, pos.y
 *  - For C,H,W: when C, H, or W is summed, batch moves into channel;
 *    the src N is determined by pos.z * 4 + out_index
 */

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  int flattened_channels = int(ceil(uBlock.channel / 4.0));
  vec4 out_texel = vec4(0, 0, 0, 0);

  // Batch
  if (uBlock.dim_info.x == 0) {
    for (int batch = 0; batch < uBlock.dim_info.y; batch++) {
      // src_n = batch
      // src_c = pos.z
      int src_z = batch * flattened_channels + pos.z;
      vec4 v = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
      out_texel += v;
    }
    imageStore(uOutput, pos, out_texel);
  }

  // Channel
  else if (uBlock.dim_info.x == 1) {
    for (int out_index = 0; out_index < 4; out_index++) {
      for (int channel = 0; channel < uBlock.dim_info.y; channel++) {
        // src_n = pos.z * 4 + out_index
        // src_c = channel
        int src_z =
            (pos.z * 4 + out_index) * flattened_channels + int(channel / 4);
        vec4 v = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
        out_texel[out_index] += v[channel % 4];
      }
    }
    imageStore(uOutput, pos, out_texel);
  }

  // Height, Width
  else {
    for (int out_index = 0; out_index < 4; out_index++) {
      // src_n = pos.z * 4 + out_index
      // src_c = pos.y
      int src_z = (pos.z * 4 + out_index) * flattened_channels + pos.y / 4;
      for (int hw = 0; hw < uBlock.dim_info.y; hw++) {
        vec4 v = (uBlock.dim_info.x == 2)
            ? texelFetch(uInput, ivec3(pos.x, hw, src_z), 0) // Height
            : texelFetch(uInput, ivec3(hw, pos.x, src_z), 0); // Width
        out_texel[out_index] += v[pos.y % 4];
      }
    }
    imageStore(uOutput, pos, out_texel);
  }
}
```
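To make the shader's packed-layout arithmetic concrete, here is a small CPU-side sketch of the `src_z` computation for the batch-sum branch. The sizes `N` and `C` are arbitrary; `flattened_channels` mirrors the shader's `int(ceil(uBlock.channel / 4.0))`, reflecting the 4-channels-per-texel packing the shader comments describe:

```python
# Hedged sketch: reproduce the shader's texel-slice indexing on the CPU for
# the batch-sum branch (uBlock.dim_info.x == 0). Channels are packed four
# per texel, so each batch entry owns ceil(C / 4) consecutive z-slices of
# the 3D image.
import math

N, C = 3, 6
flattened_channels = math.ceil(C / 4)  # z-slices per batch entry

def src_z(batch, pos_z):
    # Matches the shader: int src_z = batch * flattened_channels + pos.z;
    return batch * flattened_channels + pos_z

for pos_z in range(flattened_channels):
    # The output slice at pos_z accumulates these input slices across N.
    print(pos_z, [src_z(b, pos_z) for b in range(N)])
```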
New file (137 lines): the C++ Vulkan backend implementation registering `aten::sum.dim_IntList`. The input-rank check used `self.dim() >= 2 || self.dim() <= 4`, which is always true; corrected to `&&` below.

```cpp
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

Tensor sum_dim(
    const at::Tensor& self,
    int64_t dim,
    bool keepdim,
    const optional<ScalarType> dtype) {
  TORCH_CHECK(
      self.dim() >= 2 && self.dim() <= 4,
      "Vulkan sum.dim_IntList supports 2d, 3d, 4d tensors as input!");
  TORCH_CHECK(
      dim >= -self.dim() - 1 && dim <= self.dim(),
      "Vulkan sum.dim_IntList dimension out of range expected to be in range of [",
      -self.dim() - 1,
      ",",
      self.dim(),
      "], but got ",
      dim);

  // Get the global Vulkan context
  api::Context* const context = api::context();

  // Cast the input Tensor to a vTensor
  const Tensor input = self.is_vulkan() ? self : self.vulkan();
  const vTensor& v_input = convert(input);

  // Normalize dim into range [0, self.dim())
  dim = utils::normalize(dim, self.dim());

  // Create the output texture
  std::vector<int64_t> output_size = self.sizes().vec();
  uint32_t dim_size = output_size[dim];
  output_size.erase(output_size.begin() + dim);

  ScalarType type = self.scalar_type();
  if (dtype.has_value()) {
    type = dtype.value();
  }

  vTensor v_output{
      context,
      output_size,
      type,
  };

  // Required to determine how to insert memory barriers in the command buffer
  api::PipelineBarrier pipeline_barrier{};

  // Shift dim into 4d range
  if (self.dim() < 4) {
    dim += (4 - self.dim());
  }

  // Create the params buffer
  const struct Block final {
    uvec2 dim_info;
    int32_t channel;
  } block{
      {static_cast<uint32_t>(dim), dim_size},
      static_cast<int32_t>(get_dim<Dim4D::Channel>(v_input)),
  };

  api::UniformParamsBuffer params(context, block);

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(sum_dim),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_output.extents(),
      // local work group size
      adaptive_work_group_size(v_output.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());
  return convert(v_output);
}

Tensor sum_dim_IntList(
    const at::Tensor& self,
    const OptionalIntArrayRef opt_dim,
    bool keepdim,
    const optional<ScalarType> dtype) {
  TORCH_CHECK(
      opt_dim.has_value(),
      "Vulkan sum.dim_IntList without a dim arg is not implemented");
  TORCH_CHECK(
      keepdim == false,
      "Vulkan sum.dim_IntList with keepdim=true is not implemented");

  std::set<int64_t> dims_set;
  if (opt_dim.has_value()) {
    auto dims = opt_dim.value();
    for (const auto& d : dims) {
      dims_set.insert(d);
    }
    Tensor result = self;
    // Sum the requested dims from the innermost outwards so that squeezing
    // one dim does not invalidate the indices of the dims still to be summed.
    for (auto it = dims_set.rbegin(); it != dims_set.rend(); ++it) {
      result = sum_dim(result, *it, keepdim, dtype);
    }
    return result;
  }
  return self;
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("aten::sum.dim_IntList"), TORCH_FN(sum_dim_IntList));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at
```
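A hedged usage sketch: assuming a PyTorch build with `USE_VULKAN_API` enabled, the registration above means a plain `torch.sum(..., dim=...)` on a Vulkan tensor dispatches to `sum_dim_IntList`; the tolerance below is an arbitrary choice:

```python
# Hedged usage sketch: requires a PyTorch build with Vulkan support
# (USE_VULKAN_API); otherwise .vulkan() will raise. Compares the Vulkan
# result of a multi-dim sum against the CPU reference.
import torch

x = torch.randn(2, 3, 4, 5)
ref = torch.sum(x, dim=(0, 2))            # CPU reference
out = torch.sum(x.vulkan(), dim=(0, 2))   # dispatches to sum_dim_IntList
print(torch.allclose(ref, out.cpu(), atol=1e-4))
```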