Commit
inductor: enable weight prepack for LSTM
ghstack-source-id: 7369d20ccac4c998176e6b65a5881f4864d39d8a
Pull Request resolved: #103071
chunyuan-w committed Jun 7, 2023
1 parent f79d2b4 commit 3599a1a
Showing 14 changed files with 714 additions and 32 deletions.
20 changes: 20 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
@@ -102,6 +102,26 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor) {
}
}

// TODO: same as get_mkldnn_tensor in RNN.cpp
ideep::tensor itensor_view_from_dense(
const at::Tensor& tensor,
const ideep::tensor::desc& desc) {
TORCH_CHECK(
tensor.device().is_cpu(),
"itensor_view_from_dense expects CPU tensor input");
TORCH_CHECK(
tensor.layout() == at::Layout::Strided,
"itensor_view_from_dense expects dense tensor input");
TORCH_CHECK(
tensor.scalar_type() == at::ScalarType::Float ||
tensor.scalar_type() == at::ScalarType::BFloat16 ||
tensor.scalar_type() == at::ScalarType::Half ||
tensor.scalar_type() == at::ScalarType::QInt8 ||
tensor.scalar_type() == at::ScalarType::QUInt8,
"itensor_view_from_dense expects float, bfloat16 or int8 tensor input");
return {desc, tensor.data_ptr()};
}

// Helper function for getting an ideep tensor out of an aten Tensor.
// Note in case the aten Tensor is a dense tensor, the returned ideep
// tensor is just a view of the storage of the aten dense tensor, so
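For intuition on the view helpers above: the new overload returns {desc, tensor.data_ptr()}, i.e. it wraps the existing buffer with a different descriptor instead of copying. A rough Python analogue of "same storage, new descriptor" (the sizes, hidden=64 and feature=32, are illustrative):

import torch

w = torch.randn(256, 32)             # an LSTM weight_ih: (4*hidden, feature), contiguous
v = w.view(1, 1, 4, 64, 32)          # reinterpret as (l, d, g, o, i) -- no copy
assert v.data_ptr() == w.data_ptr()  # both descriptors share one buffer,
                                     # which is what itensor_view_from_dense relies on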
6 changes: 6 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.h
@@ -24,6 +24,12 @@ TORCH_API ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor);
// ideep::tensor will share the underlying buffer
TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor);

// Construct an `ideep::tensor` "view" from dense tensor using given desc, note
// the ideep::tensor will share the underlying buffer
TORCH_API ideep::tensor itensor_view_from_dense(
const at::Tensor& tensor,
const ideep::tensor::desc& desc);

// Helper function for getting an ideep tensor out of an aten Tensor or MKL-DNN tensor.
TORCH_API ideep::tensor itensor_from_tensor(const Tensor& tensor);

175 changes: 175 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
@@ -5,6 +5,8 @@
#include <ATen/native/mkldnn/Utils.h>
#include <ATen/native/utils/ParamUtils.h>
#include <torch/library.h>
#include <ATen/MatrixRef.h>
#include <tuple>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -16,6 +18,7 @@
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
#include <ATen/ops/to_mkldnn_native.h>
#include <ATen/ops/zeros.h>
#endif


Expand Down Expand Up @@ -292,6 +295,172 @@ Tensor mkldnn_reorder_conv_transpose2d_weight(
self.options().device_opt());
}

std::tuple<ideep::tensor, ideep::tensor> get_lstm_packed_weights(
const at::Tensor& weight_ih,
const at::Tensor& weight_hh,
const at::Tensor& weight2,
const at::Tensor& weight3,
int64_t layer_feature_size,
int64_t hidden_size,
bool has_biases,
int64_t num_layers,
bool bidirectional,
int64_t time_step,
int64_t batch_size,
bool reverse) {

ideep::tensor cached_weight_ih, cached_weight_hh;

int64_t num_gates = 4;
int64_t num_bias_gates = 4;
std::vector<int64_t> output_sizes = {time_step, batch_size, hidden_size};

auto dtype = get_mkldnn_dtype(weight_ih.scalar_type());
ideep::tensor::desc src_layer_desc({time_step, batch_size, layer_feature_size}, dtype, ideep::format_tag::tnc);
ideep::tensor::desc src_iter_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
ideep::tensor::desc src_iter_c_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
ideep::tensor::desc bias_desc({1, 1, num_bias_gates, hidden_size}, dtype, ideep::format_tag::ldgo);

ideep::tensor::desc dst_layer_desc({time_step, batch_size, hidden_size}, dtype, ideep::format_tag::tnc);
ideep::tensor::desc dst_iter_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
ideep::tensor::desc dst_iter_c_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);

ideep::tensor src_layer(src_layer_desc);
ideep::tensor src_iter(src_iter_desc);
ideep::tensor src_iter_c(src_iter_c_desc);
ideep::tensor bias(bias_desc);

auto w1 = itensor_view_from_dense(
weight_ih,
{{1, 1, layer_feature_size, num_gates, hidden_size},
get_mkldnn_dtype(weight_ih.scalar_type()),
ideep::format_tag::ldgoi});

auto w2 = itensor_view_from_dense(
weight_hh,
{{1, 1, hidden_size, num_gates, hidden_size},
get_mkldnn_dtype(weight_hh.scalar_type()),
ideep::format_tag::ldgoi});

ideep::tensor::desc packed_desc_ih, packed_desc_hh;

std::tie(packed_desc_ih, packed_desc_hh) =
ideep::lstm_forward_inference::expected_weights_desc(
output_sizes,
src_layer,
src_iter,
src_iter_c,
w1,
w2,
bias,
reverse);

cached_weight_ih.init(packed_desc_ih);
cached_weight_hh.init(packed_desc_hh);

cached_weight_ih.feed_from(w1);
cached_weight_hh.feed_from(w2);

return std::make_tuple(cached_weight_ih, cached_weight_hh);
}
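The plain ldgoi views above line up with how aten stores LSTM weights: weight_ih_l0 is (4*hidden_size, feature_size) and weight_hh_l0 is (4*hidden_size, hidden_size), both row-major, so they can be wrapped without a copy before feed_from reorders them into the queried packed layout. A quick shape check (sizes illustrative):

import torch

feature, hidden = 32, 64
lstm = torch.nn.LSTM(feature, hidden, num_layers=1, bias=True)
assert lstm.weight_ih_l0.shape == (4 * hidden, feature)  # viewed as ldgoi {1, 1, feature, 4, hidden}
assert lstm.weight_hh_l0.shape == (4 * hidden, hidden)   # viewed as ldgoi {1, 1, hidden, 4, hidden}
assert lstm.bias_ih_l0.shape == (4 * hidden,)            # biases are not packed (see TODO in RNN.cpp)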

std::vector<Tensor> mkldnn_reorder_lstm_weight(
TensorList weight,
int64_t input_feature_size,
int64_t hidden_size,
bool has_biases,
int64_t num_layers,
bool bidirectional,
bool batch_first,
c10::OptionalArrayRef<int64_t> input_size) {
std::vector<int64_t> input_size_value;
int64_t time_step, batch_size;
if (input_size.has_value()) {
input_size_value = input_size.value().vec();
int64_t time_index = batch_first ? 1 : 0;
int64_t batch_size_index = batch_first ? 0 : 1;

time_step = input_size_value[time_index];
batch_size = input_size_value[batch_size_index];
} else {
// no input size provided; use placeholder sizes for the desc query
time_step = 5;
batch_size = 10;
}

std::vector<Tensor> result(weight.size());

auto num_directions = bidirectional ? 2 : 1;
int64_t weight_stride0 = has_biases ? 4 : 2;

at::MatrixRef<at::Tensor> weights{
weight, static_cast<size_t>(weight_stride0)};
ideep::tensor w1_, w2_;
at::Tensor packed_w1, packed_w2;

for (int64_t layer = 0; layer < num_layers; layer++) {
for (int64_t direction = 0; direction < num_directions; direction++) {
// for layer == 0, feature_size = input_feature_size
// otherwise, feature_size = num_directions * hidden_size
int64_t layer_feature_size = layer == 0 ? input_feature_size : num_directions * hidden_size;
auto index = layer * num_directions + direction;
auto layer_weights = weights[index];
TORCH_CHECK(layer_weights.size() == 2 || layer_weights.size() == 4, "expected 2 (no-bias) or 4 (with-bias) weights per layer");
auto reverse = (direction > 0);

std::tie(w1_, w2_) = get_lstm_packed_weights(
layer_weights[0],
layer_weights[1],
has_biases ? layer_weights[2] : at::zeros(
layer_weights[0].sizes(),
layer_weights[0].options()),
has_biases ? layer_weights[3]
: at::zeros(
layer_weights[1].sizes(),
layer_weights[1].options()),
layer_feature_size,
hidden_size,
has_biases,
num_layers,
bidirectional,
time_step,
batch_size,
reverse);

// TODO: use is_opaque() after updating ideep in pytorch
// Don't pack when the weight is in rnn_packed format: if the seq_lens of
// the input change, the rnn_packed weight format changes with them, and
// oneDNN does not support reordering from rnn_packed back to the public
// format. LSTM based on the BRGEMM kernel (on AVX512 and newer ISAs) uses
// a blocked format for the LSTM weights, which does not change when the
// input seq_lens change. On AVX2, the queried weight is in plain format.
if (w1_.get_desc().is_rnn_packed() || w1_.get_desc().is_plain()) {
packed_w1 = layer_weights[0];
} else {
packed_w1 = new_with_itensor_mkldnn(std::move(w1_), optTypeMetaToScalarType(layer_weights[0].options().dtype_opt()), layer_weights[0].options().device_opt());
}

if (w2_.get_desc().is_rnn_packed() || w2_.get_desc().is_plain()) {
packed_w2 = layer_weights[1];
} else {
packed_w2 = new_with_itensor_mkldnn(std::move(w2_), optTypeMetaToScalarType(layer_weights[1].options().dtype_opt()), layer_weights[1].options().device_opt());
}

result[index * weight_stride0] = packed_w1;
result[index * weight_stride0 + 1] = packed_w2;

if (has_biases) {
result[index * weight_stride0 + 2] = layer_weights[2];
result[index * weight_stride0 + 3] = layer_weights[3];
}
}
}

return result;
}

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
@@ -301,6 +470,12 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
TORCH_FN(mkldnn_reorder_linear_weight));
}

TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_lstm_weight"),
TORCH_FN(mkldnn_reorder_lstm_weight));
}

#else

Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad) {
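With the CPU registration above, the reorder op is reachable from Python as torch.ops.mkldnn._reorder_lstm_weight (schema in RegisterMkldnnOpContextClass.cpp below). A minimal sketch for a single-layer, unidirectional LSTM; _flat_weights is a private nn.LSTM attribute, used here only for illustration:

import torch

seq, batch, feature, hidden = 5, 10, 32, 64
lstm = torch.nn.LSTM(feature, hidden).eval()

with torch.no_grad():
    packed = torch.ops.mkldnn._reorder_lstm_weight(
        lstm._flat_weights,     # [w_ih_l0, w_hh_l0, b_ih_l0, b_hh_l0]
        feature, hidden,
        True,                   # has_biases
        1,                      # num_layers
        False,                  # bidirectional
        False,                  # batch_first
        [seq, batch, feature],  # optional input_size hint for the desc query
    )

The list comes back in the same layout as the input params; on ISAs where oneDNN chooses a blocked desc, packed[0] and packed[1] are opaque MKL-DNN tensors, while on AVX2 (plain format) the original dense weights are returned unchanged, per the branch above.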
36 changes: 28 additions & 8 deletions aten/src/ATen/native/mkldnn/RNN.cpp
@@ -9,6 +9,7 @@
#include <c10/core/GradMode.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
@@ -266,12 +267,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer(const Tensor& input,
auto weight_ih = _shuffle_weight(w0, rnn.mode);
auto weight_hh = _shuffle_weight(w1, rnn.mode);

// TODO: bias is not packed
auto bias = has_biases
? _shuffle_bias(w2, w3, rnn.mode)
: at::zeros({rnn.num_bias_gates * rnn.hidden_size}, weight_ih.options());
: at::zeros({rnn.num_bias_gates * rnn.hidden_size}, weight_ih.options().layout(at::Layout::Strided));

// per layer input size
int64_t input_size = input.size(2);
ideep::tensor w1_, w2_;
auto x = get_mkldnn_tensor(
input,
rnn.src_layer_desc(input_size, get_mkldnn_dtype(input)));
@@ -287,9 +290,8 @@
hy_, rnn.dst_iter_desc(get_mkldnn_dtype(hy_)));
auto cy = get_mkldnn_tensor(
cy_, rnn.dst_iter_c_desc(get_mkldnn_dtype(cy_)));
auto w1_ = get_mkldnn_tensor(weight_ih, rnn.weights_layer_desc(input_size, get_mkldnn_dtype(weight_ih)));
auto w2_ = get_mkldnn_tensor(weight_hh, rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh)));

w1_ = weight_ih.is_mkldnn() ? itensor_from_tensor(weight_ih) : get_mkldnn_tensor(weight_ih, rnn.weights_layer_desc(input_size, get_mkldnn_dtype(weight_ih)));
w2_ = weight_hh.is_mkldnn() ? itensor_from_tensor(weight_hh) : get_mkldnn_tensor(weight_hh, rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh)));
if (at::GradMode::is_enabled()) {
Tensor workspace = Tensor();
auto pd = ideep::lstm_forward_training::prepare(
@@ -504,8 +506,8 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_rnn(
auto layer_cx = cx[index];
auto reverse = (direction > 0);
auto outputs = at::mkldnn_rnn_layer(layer_input, layer_weights[0], layer_weights[1],
has_biases ? layer_weights[2] : at::zeros(layer_weights[0].sizes(), layer_weights[0].options()),
has_biases ? layer_weights[3] : at::zeros(layer_weights[1].sizes(), layer_weights[1].options()), layer_hx,
has_biases ? layer_weights[2] : at::zeros(layer_weights[0].sizes(), layer_weights[0].options().layout(at::Layout::Strided)),
has_biases ? layer_weights[3] : at::zeros(layer_weights[1].sizes(), layer_weights[1].options().layout(at::Layout::Strided)), layer_hx,
layer_cx, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train);
layer_output[direction] = std::get<0>(outputs);
layer_hy[index] = std::get<1>(outputs);
@@ -556,12 +558,10 @@ std::pair<Tensor, hidden_type> mkldnn_impl(
Tensor hx, cx;
std::tie(hx, cx) = unpack_hidden(hidden);
int64_t hidden_size = hx.size(2);

auto mkldnn_output = mkldnn_rnn(
input, params, has_biases ? 4 : 2,
hx, cx, static_cast<int>(mode), hidden_size, num_layers, has_biases, batch_first, dropout_p,
train, bidirectional, /*batch_sizes*/{});

return {std::get<0>(mkldnn_output),
pack_hidden<hidden_type>(std::get<1>(mkldnn_output), std::get<2>(mkldnn_output))};
}
@@ -580,6 +580,26 @@ void lstm_mkldnn(Tensor& output, Tensor& hy, Tensor& cy,

REGISTER_ALL_CPU_DISPATCH(lstm_mkldnn_stub, &lstm_mkldnn);


std::tuple<Tensor, Tensor, Tensor> lstm_mkldnn_inductor(const Tensor& input, TensorList hx, TensorList params, bool has_biases,
int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
Tensor output, hy, cy;
lstm_mkldnn(output, hy, cy, input, hx, params, has_biases, num_layers, dropout_p, train, bidirectional, batch_first);
return std::make_tuple(output, hy, cy);
}

TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_lstm"),
TORCH_FN(lstm_mkldnn_inductor));
}

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_lstm"),
TORCH_FN(lstm_mkldnn_inductor));
}

} // namespace at::native

#endif // AT_MKLDNN_ENABLED
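lstm_mkldnn_inductor wraps the existing stub in a functional form so inductor can call the whole LSTM as one op. A minimal sketch of invoking it through the dispatcher (per the mkldnn::_lstm schema registered below); inductor would normally pass the prepacked params from _reorder_lstm_weight rather than the dense ones used here:

import torch

seq, batch, feature, hidden = 5, 10, 32, 64
lstm = torch.nn.LSTM(feature, hidden).eval()
x = torch.randn(seq, batch, feature)
h0 = torch.zeros(1, batch, hidden)
c0 = torch.zeros(1, batch, hidden)

with torch.no_grad():
    y, hy, cy = torch.ops.mkldnn._lstm(
        x, [h0, c0], lstm._flat_weights,
        True,   # has_biases
        1,      # num_layers
        0.0,    # dropout
        False,  # train
        False,  # bidirectional
        False,  # batch_first
    )

ref_y, (ref_h, ref_c) = lstm(x, (h0, c0))
torch.testing.assert_close(y, ref_y)  # should match the eager path within tolerance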
4 changes: 4 additions & 0 deletions aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp
@@ -53,10 +53,14 @@ TORCH_LIBRARY(mkldnn, m) {
"mkldnn::_convolution_pointwise_.binary(Tensor(a!) other, Tensor X, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor(a!) Y"));
m.def(TORCH_SELECTIVE_SCHEMA(
"mkldnn::_convolution_transpose_pointwise(Tensor X, Tensor W, Tensor? B, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y"));
m.def(TORCH_SELECTIVE_SCHEMA(
"mkldnn::_lstm(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor Y, Tensor hy, Tensor cy)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"mkldnn::_reorder_convolution_transpose_weight(Tensor self, int[2] padding=0, int[2] output_padding=0, int[2] stride=1, int[2] dilation=1, int groups=1, int[]? input_size=None) -> Tensor Y"));
m.def(TORCH_SELECTIVE_SCHEMA(
"mkldnn::_reorder_linear_weight(Tensor self, int? batch_size=None) -> Tensor Y"));
m.def(TORCH_SELECTIVE_SCHEMA(
"mkldnn::_reorder_lstm_weight(Tensor[] params, int input_feature_size, int hidden_size, bool has_biases, int num_layers, bool bidirectional, bool batch_first, int[]? input_size=None) -> Tensor[] Y"));
m.def("_is_mkldnn_bf16_supported", &is_mkldnn_bf16_supported);
}

1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3823,6 +3823,7 @@
- func: mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor)
dispatch:
CPU: mkldnn_rnn_layer
MkldnnCPU: mkldnn_rnn_layer
autogen: mkldnn_rnn_layer.out

- func: mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
