Update on "inductor: enable weight prepack for LSTM"
- Enabled LSTM weight prepack in inductor (see the usage sketch after this list).
- Added an mkldnn decomposition for LSTM that does not change across different `seq_lens`. With the previous decomposition, in the dynamic-shapes use case where `seq_lens` changes, the graph would be different.
- Extended several inductor utility functions to support `List[Tensor]` as input. Previously those functions only supported `Tensor` input.
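A minimal usage sketch of the path this PR targets (not code from this PR), assuming a CPU build of PyTorch with oneDNN/mkldnn available; the weight prepack happens inside the inductor backend, so nothing here refers to it explicitly:

```python
import torch
import torch.nn as nn

# Assumption: CPU inference with mkldnn available; the LSTM weight prepack is
# an internal inductor optimization and needs no explicit configuration here.
lstm = nn.LSTM(input_size=16, hidden_size=32, num_layers=2, batch_first=True).eval()
compiled_lstm = torch.compile(lstm)

x = torch.randn(4, 10, 16)  # (batch, seq_len, input_size)
with torch.no_grad():
    out, (h, c) = compiled_lstm(x)
print(out.shape)  # torch.Size([4, 10, 32])
```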

**Update 2023-07-26:**
- #103851 has moved CPU weight packing to be after AOTAutograd. Updated the support in this PR to follow the same approach (mainly in 3b207f7#diff-6dffed1ade0ba3e887f9a4eafa3bfcec267ab2365b8adcb91bd391f49b3fd2e3).
LSTM is decomposed into `aten.mkldnn_rnn_layer` by layer and by direction, and the weight prepack is done at the `mkldnn_rnn_layer` level.
- Added a fix in the RNN `__getstate__` function for the case where we need to recompile an `LSTM` module.
When compiling the module, the weight tensors, which are the `named_parameters` of the module, are converted to `functional_tensor` here:
https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/nn/utils/stateless.py#L125-L128
The forward function of the LSTM will then be called:
https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_functorch/aot_autograd.py#L3379-L3381
In the forward function, `_flat_weights` is updated to reference the same tensors as the weights, thus also becoming `functional_tensor`:
https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/nn/modules/rnn.py#L775-L778
The weight tensors are converted back to the original tensors (which are no longer `functional_tensor`) before exiting the `_reparametrize_module` context here:
https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/nn/utils/stateless.py#L130-L142
However, since `_flat_weights` is not among the `named_parameters` of the module, it remains a `functional_tensor` ([link to the parameters that are converted to functional and reverted back](https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_functorch/aot_autograd.py#L3695-L3698)).
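A small standalone illustration of why `_flat_weights` is skipped by the reparametrization (not code from this PR): it is a plain Python attribute that aliases the registered parameters, so it never shows up in `named_parameters`:

```python
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=8)
param_names = {name for name, _ in lstm.named_parameters()}
print("weight_ih_l0" in param_names)               # True: a registered parameter
print("_flat_weights" in param_names)              # False: plain attribute, so _reparametrize_module never restores it
print(lstm._flat_weights[0] is lstm.weight_ih_l0)  # True: it aliases the same tensors
```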
At this point, if we need to recompile the model, `deepcopy` will be called:
https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_dynamo/utils.py#L915-L917
It will report `UnImplemented`, since we still have a `functional_tensor` (`_flat_weights`), and will trigger a graph break, which is not what we expect:
https://github.com/pytorch/pytorch/blob/76fb72e24a5a4a47ad1f50c5c94d5c0b7e703531/torch/_subclasses/meta_utils.py#L514
Added a fix in `__getstate__` to update `_flat_weights` whenever the weights have changed, which resolves this issue (a sketch of the idea follows). The fix is covered by the `test_lstm_packed` UT.
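A minimal sketch of the idea behind the fix (not the exact code in this PR); the change-detection logic and the subclassing below are assumptions for illustration, while `_flat_weights_names` and `_init_flat_weights` are existing `nn.RNNBase` internals:

```python
import torch.nn as nn

class PatchedLSTM(nn.LSTM):
    def __getstate__(self):
        # Sketch: if the registered parameters no longer match the cached
        # _flat_weights aliases (e.g. they were swapped to functional tensors
        # and back during compilation), rebuild _flat_weights before the module
        # is deepcopied/pickled, so no stale functional tensor leaks out.
        weights_changed = any(
            getattr(self, name, None) is not flat
            for name, flat in zip(self._flat_weights_names, self._flat_weights)
        )
        if weights_changed:
            self._init_flat_weights()  # re-derive _flat_weights from the current parameters
        return self.__dict__
```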



cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
chunyuan-w committed Jul 28, 2023
2 parents e703b44 + 688f52a commit 572a694
Showing 56 changed files with 1,290 additions and 1,374 deletions.
25 changes: 17 additions & 8 deletions .github/workflows/lint-bc.yml
@@ -1,15 +1,24 @@
 name: BC Lint
 
 on:
+  # Copied from check-labels.yml to get around needing approval for first time contributors
+  # See https://docs.github.com/en/actions/managing-workflow-runs/approving-workflow-runs-from-public-forks
+  # Only allow pull_request_target when merging to main, not some historical branch.
+  #
+  # Make sure to don't introduce explicit checking out and installing/running
+  # untrusted user code into this workflow!
+  pull_request_target:
+    types: [opened, synchronize, reopened, labeled, unlabeled]
+    branches: [main]
+    paths-ignore: [.github/workflows/lint-bc.yml]
+
+  # To allow testing PRs that change workflows.
+  # May be triggered together with pull_request_target, it's OK.
   pull_request:
-    types:
-      - opened
-      - synchronize
-      - reopened
-      - labeled
-      - unlabeled
-    branches-ignore:
-      - nightly
+    types: [opened, synchronize, reopened, labeled, unlabeled]
+    paths: [.github/workflows/lint-bc.yml]
+    branches-ignore: [nightly]
 
   workflow_dispatch:
 
 jobs:
12 changes: 6 additions & 6 deletions CODEOWNERS
@@ -45,16 +45,16 @@ nn/qat/ @jerryzh168
 # Distributed package
 # This list is mostly if you'd like to be tagged as reviewer, feel free to add
 # or remove yourself from it.
-/torch/csrc/distributed/ @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @kiukchung @d4l3k
-/torch/distributed/ @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @kiukchung @d4l3k
-/torch/distributed/_composable @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @yhcharles @fegin @fduwjj @kiukchung @d4l3k
-/torch/nn/parallel/ @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @kiukchung @d4l3k
+/torch/csrc/distributed/ @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @kiukchung @d4l3k @penguinwu
+/torch/distributed/ @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @kiukchung @d4l3k @penguinwu
+/torch/distributed/_composable @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @yhcharles @fegin @fduwjj @kiukchung @d4l3k @penguinwu
+/torch/nn/parallel/ @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @kiukchung @d4l3k @penguinwu
 
 # Distributed tests
 # This list is mostly if you'd like to be tagged as reviewer, feel free to add
 # or remove yourself from it.
-/test/distributed @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj
-/torch/testing/_internal/distributed @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj
+/test/distributed @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @penguinwu
+/torch/testing/_internal/distributed @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @penguinwu
 
 # ONNX Export
 /torch/csrc/jit/passes/onnx.h @bowenbao @abock @thiagocrepaldi
18 changes: 11 additions & 7 deletions aten/src/ATen/native/native_functions.yaml
@@ -1208,7 +1208,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: logical_xor
-  tags: pointwise
+  tags: [core, pointwise]
 
 - func: logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
@@ -1919,6 +1919,7 @@
   structured_delegate: cumsum.out
   device_check: NoCheck # TensorIterator
   variants: function, method
+  tags: core
 
 - func: cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!)
   structured_delegate: cumsum.out
@@ -2194,6 +2195,7 @@
     CompositeExplicitAutograd: embedding_symint
     NestedTensorCPU, NestedTensorCUDA: NestedTensor_embedding
   autogen: embedding.out
+  tags: core
 
 - func: embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor
   dispatch:
@@ -2944,7 +2946,7 @@
   variants: function, method
   dispatch:
     QuantizedCPU: quantized_index
-  tags: dynamic_output_shape
+  tags: [core, dynamic_output_shape]
 # NB: This function is special-cased in tools/autograd/gen_variable_type.py
 # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp:
 #   - Tensor Tensor::index(ArrayRef<TensorIndex> indices)
@@ -3005,6 +3007,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: index_put
+  tags: core
 
 - func: _unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
   device_check: NoCheck # delegate to _index_put_impl_ after clone, which leverages TensorIterator
@@ -7867,7 +7870,7 @@
   variants: method, function
   dispatch:
     CompositeExplicitAutograd: bitwise_and
-  tags: pointwise
+  tags: [core, pointwise]
 
 - func: bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
   device_check: NoCheck # TensorIterator
@@ -7930,7 +7933,7 @@
 - func: bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: method, function
-  tags: pointwise
+  tags: [core, pointwise]
 
 - func: bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
   device_check: NoCheck # TensorIterator
@@ -7993,7 +7996,7 @@
 - func: bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: method, function
-  tags: pointwise
+  tags: [core, pointwise]
 
 - func: bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
   device_check: NoCheck # TensorIterator
@@ -9326,7 +9329,7 @@
   variants: method, function
   dispatch:
     CompositeExplicitAutograd: fmod
-  tags: pointwise
+  tags: [core, pointwise]
 
 - func: fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
@@ -9434,7 +9437,7 @@
   variants: method, function
   dispatch:
     CompositeExplicitAutograd: remainder
-  tags: pointwise
+  tags: [core, pointwise]
 
 - func: remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
   variants: method
@@ -9683,6 +9686,7 @@
   variants: method, function
   dispatch:
     SparseCPU, SparseCUDA: any_sparse
+  tags: core
 
 - func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck
82 changes: 82 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/sum_dim.glsl
@@ -0,0 +1,82 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// dim_info.x: dim to sum
// dim_info.y: size of dim (in the input)
uvec2 dim_info;
int channel;
}
uBlock;

/*
* Local Work Group Size
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
* Returns a new tensor with values summed along dimension dim
* Dimension dim is squeezed
* For each pos:
* - Iterate over the out_texel and the summed dimension
* - For H,W; rearrange pos.x, pos.y
* - For C,H,W;
* When CHW are summed, batch moves into channel
* The src N is determined by pos.z * 4 + out_index
*/

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

int flattened_channels = int(ceil(uBlock.channel / 4.0));
vec4 out_texel = vec4(0, 0, 0, 0);

// Batch
if (uBlock.dim_info.x == 0) {
for (int batch = 0; batch < uBlock.dim_info.y; batch++) {
// src_n = batch
// src_c = pos.z
int src_z = batch * flattened_channels + pos.z;
vec4 v = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
out_texel += v;
}
imageStore(uOutput, pos, out_texel);
}

// Channel
else if (uBlock.dim_info.x == 1) {
for (int out_index = 0; out_index < 4; out_index++) {
for (int channel = 0; channel < uBlock.dim_info.y; channel++) {
// src_n = pos.z * 4 + out_index
// src_c = channel
int src_z =
(pos.z * 4 + out_index) * flattened_channels + int(channel / 4);
vec4 v = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
out_texel[out_index] += v[channel % 4];
}
}
imageStore(uOutput, pos, out_texel);
}

// Height, Width
else {
for (int out_index = 0; out_index < 4; out_index++) {
// src_n = pos.z * 4 + out_index
// src_c = pos.y
int src_z = (pos.z * 4 + out_index) * flattened_channels + pos.y / 4;
for (int hw = 0; hw < uBlock.dim_info.y; hw++) {
vec4 v = (uBlock.dim_info.x == 2)
? texelFetch(uInput, ivec3(pos.x, hw, src_z), 0) // Height
: texelFetch(uInput, ivec3(hw, pos.x, src_z), 0); // Width
out_texel[out_index] += v[pos.y % 4];
}
}
imageStore(uOutput, pos, out_texel);
}
}
137 changes: 137 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Sum.cpp
@@ -0,0 +1,137 @@
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

Tensor sum_dim(
const at::Tensor& self,
int64_t dim,
bool keepdim,
const optional<ScalarType> dtype) {
TORCH_CHECK(
self.dim() >= 2 && self.dim() <= 4,
"Vulkan sum.dim_IntList supports 2d, 3d, 4d tensors as input!");
TORCH_CHECK(
dim >= -self.dim() - 1 && dim <= self.dim(),
"Vulkan sum.dim_IntList dimension out of range expected to be in range of [",
-self.dim() - 1,
",",
self.dim(),
"], but got ",
dim);

// Get the global Vulkan context
api::Context* const context = api::context();

// Cast the input Tensor to a vTensor
const Tensor input = self.is_vulkan() ? self : self.vulkan();
const vTensor& v_input = convert(input);

// Normalize dim into range [0, self.dim()]
dim = utils::normalize(dim, self.dim());

// Create the output texture
std::vector<int64_t> output_size = self.sizes().vec();
uint32_t dim_size = output_size[dim];
output_size.erase(output_size.begin() + dim);

ScalarType type = self.scalar_type();
if (dtype.has_value()) {
type = dtype.value();
}

vTensor v_output{
context,
output_size,
type,
};

// Required to determine how to insert memory barriers in the command buffer
api::PipelineBarrier pipeline_barrier{};

// Shift dim into 4d range
if (self.dim() < 4) {
dim += (4 - self.dim());
}

// Create the params buffer
const struct Block final {
uvec2 dim_info;
int32_t channel;
} block{
{static_cast<uint32_t>(dim), dim_size},
static_cast<int32_t>(get_dim<Dim4D::Channel>(v_input)),
};

api::UniformParamsBuffer params(context, block);

context->submit_compute_job(
// shader descriptor
VK_KERNEL(sum_dim),
// pipeline barrier
pipeline_barrier,
// global work group size
v_output.extents(),
// local work group size
adaptive_work_group_size(v_output.extents()),
// fence handle
VK_NULL_HANDLE,
// shader arguments
v_output.image(
pipeline_barrier,
api::PipelineStage::COMPUTE,
api::MemoryAccessType::WRITE),
v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
// params buffer
params.buffer());
return convert(v_output);
}

Tensor sum_dim_IntList(
const at::Tensor& self,
const OptionalIntArrayRef opt_dim,
bool keepdim,
const optional<ScalarType> dtype) {
TORCH_CHECK(
opt_dim.has_value(),
"Vulkan sum.dim_IntList without a dim arg is not implemented");
TORCH_CHECK(
keepdim == false,
"Vulkan sum.dim_IntList with keepdim=true is not implemented");

std::set<int64_t> dims_set;
if (opt_dim.has_value()) {
auto dims = opt_dim.value();
for (const auto& d : dims) {
dims_set.insert(d);
}
Tensor result = self;
for (auto it = dims_set.rbegin(); it != dims_set.rend(); ++it) {
result = sum_dim(result, *it, keepdim, dtype);
}
return result;
}
return self;
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
m.impl(
TORCH_SELECTIVE_NAME("aten::sum.dim_IntList"), TORCH_FN(sum_dim_IntList));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at
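As a quick illustration of what the new kernel enables (a hedged usage sketch, assuming a PyTorch build with USE_VULKAN_API and a Vulkan-capable device; not part of this commit):

```python
import torch

x = torch.rand(2, 3, 4, 5)
v = x.vulkan()            # move the tensor to the Vulkan backend
s = torch.sum(v, dim=1)   # dispatches to the sum.dim_IntList implementation registered above
print(s.cpu().shape)      # torch.Size([2, 4, 5]); keepdim=False, so dim 1 is squeezed
print(torch.allclose(s.cpu(), torch.sum(x, dim=1), atol=1e-5))  # sanity check vs. the CPU result
```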
