inductor: enable weight prepack for LSTM
ghstack-source-id: 3872e8235150d33260e06c0f8efe398766f6f392
Pull Request resolved: #103071
chunyuan-w committed Jul 26, 2023
1 parent dfc9874 commit 3b207f7
Showing 14 changed files with 747 additions and 62 deletions.
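In brief: the new mkldnn::_reorder_mkldnn_rnn_layer_weight op queries oneDNN (through ideep's lstm_forward_inference::expected_weights_desc) for the packed layout it prefers for an LSTM layer's input-hidden and hidden-hidden weights and reorders them into that layout ahead of time, while mkldnn_rnn_layer in RNN.cpp is updated to accept weights that already arrive in MKL-DNN layout, falling back to a plain dense view otherwise.
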
17 changes: 17 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
@@ -103,6 +103,23 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor) {
}
}

ideep::tensor itensor_view_from_dense(
const at::Tensor& tensor,
const ideep::tensor::desc& desc) {
TORCH_CHECK(
tensor.device().is_cpu(),
"itensor_view_from_dense expects CPU tensor input");
TORCH_CHECK(
tensor.layout() == at::Layout::Strided,
"itensor_view_from_dense expects dense tensor input");
TORCH_CHECK(
tensor.scalar_type() == at::ScalarType::Float ||
tensor.scalar_type() == at::ScalarType::BFloat16 ||
tensor.scalar_type() == at::ScalarType::Half,
"itensor_view_from_dense expects float, bfloat16 or half tensor input");
return {desc, tensor.data_ptr()};
}

// Helper function for getting an ideep tensor out of an aten Tensor.
// Note in case the aten Tensor is a dense tensor, the returned ideep
// tensor is just a view of the storage of the aten dense tensor, so
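
Illustrative only, not part of the diff: a minimal sketch of how the new overload can be used to view an ATen LSTM weight through an explicit oneDNN descriptor. The helper name view_lstm_weight_ih is hypothetical, and it assumes get_mkldnn_dtype and itensor_view_from_dense are visible through MKLDNNCommon.h.

#include <ATen/ATen.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>

// Hypothetical helper: view a dense [4 * hidden_size, input_size] LSTM
// input-hidden weight as a oneDNN ldgoi tensor without copying.
static ideep::tensor view_lstm_weight_ih(
    const at::Tensor& weight_ih,
    int64_t input_size,
    int64_t hidden_size) {
  ideep::tensor::desc desc(
      {/*layers*/ 1, /*directions*/ 1, input_size, /*gates*/ 4, hidden_size},
      at::native::get_mkldnn_dtype(weight_ih.scalar_type()),
      ideep::format_tag::ldgoi);
  // Shares weight_ih's storage; no reorder happens here.
  return at::native::itensor_view_from_dense(weight_ih, desc);
}
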
6 changes: 6 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.h
@@ -24,6 +24,12 @@ TORCH_API ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor);
// ideep::tensor will share the underlying buffer
TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor);

// Construct an `ideep::tensor` "view" from a dense tensor using the given desc;
// note that the ideep::tensor will share the underlying buffer
TORCH_API ideep::tensor itensor_view_from_dense(
const at::Tensor& tensor,
const ideep::tensor::desc& desc);

// Helper function for getting an ideep tensor out of an aten Tensor or MKL-DNN tensor.
TORCH_API ideep::tensor itensor_from_tensor(const Tensor& tensor);

149 changes: 149 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
@@ -5,6 +5,8 @@
#include <ATen/native/mkldnn/Utils.h>
#include <ATen/native/utils/ParamUtils.h>
#include <torch/library.h>
#include <ATen/MatrixRef.h>
#include <tuple>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -16,6 +18,7 @@
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
#include <ATen/ops/to_mkldnn_native.h>
#include <ATen/ops/zeros.h>
#endif


@@ -339,6 +342,149 @@ static Tensor mkldnn_reorder_conv_transpose2d_weight(
self.options().device_opt());
}

static std::tuple<ideep::tensor, ideep::tensor> get_lstm_packed_weights(
const at::Tensor& weight_ih,
const at::Tensor& weight_hh,
const at::Tensor& weight2,
const at::Tensor& weight3,
int64_t layer_feature_size,
int64_t hidden_size,
bool has_biases,
int64_t num_layers,
bool bidirectional,
int64_t time_step,
int64_t batch_size,
bool reverse) {

ideep::tensor cached_weight_ih, cached_weight_hh;

int64_t num_gates = 4;
int64_t num_bias_gates = 4;
std::vector<int64_t> output_sizes = {time_step, batch_size, hidden_size};

auto dtype = get_mkldnn_dtype(weight_ih.scalar_type());
ideep::tensor::desc src_layer_desc({time_step, batch_size, layer_feature_size}, dtype, ideep::format_tag::tnc);
ideep::tensor::desc src_iter_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
ideep::tensor::desc src_iter_c_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
ideep::tensor::desc bias_desc({1, 1, num_bias_gates, hidden_size}, dtype, ideep::format_tag::ldgo);

ideep::tensor::desc dst_layer_desc({time_step, batch_size, hidden_size}, dtype, ideep::format_tag::tnc);
ideep::tensor::desc dst_iter_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
ideep::tensor::desc dst_iter_c_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);

ideep::tensor src_layer(src_layer_desc);
ideep::tensor src_iter(src_iter_desc);
ideep::tensor src_iter_c(src_iter_c_desc);
ideep::tensor bias(bias_desc);

auto w1 = itensor_view_from_dense(
weight_ih,
{{1, 1, layer_feature_size, num_gates, hidden_size},
get_mkldnn_dtype(weight_ih.scalar_type()),
ideep::format_tag::ldgoi});

auto w2 = itensor_view_from_dense(
weight_hh,
{{1, 1, hidden_size, num_gates, hidden_size},
get_mkldnn_dtype(weight_hh.scalar_type()),
ideep::format_tag::ldgoi});

ideep::tensor::desc packed_desc_ih, packed_desc_hh;

std::tie(packed_desc_ih, packed_desc_hh) =
ideep::lstm_forward_inference::expected_weights_desc(
output_sizes,
src_layer,
src_iter,
src_iter_c,
w1,
w2,
bias,
reverse);

cached_weight_ih.init(packed_desc_ih);
cached_weight_hh.init(packed_desc_hh);

cached_weight_ih.feed_from(w1);
cached_weight_hh.feed_from(w2);

return std::make_tuple(cached_weight_ih, cached_weight_hh);
}

static bool should_use_plain_format(ideep::tensor w) {
#if defined(IDEEP_VERSION_MAJOR) && IDEEP_VERSION_MAJOR>=3
return w.get_desc().is_opaque() || w.get_desc().is_plain();
#else
return w.get_desc().is_rnn_packed() || w.get_desc().is_plain();
#endif
}

static std::vector<Tensor> mkldnn_reorder_mkldnn_rnn_layer_weight(
Tensor weight0,
Tensor weight1,
int64_t hidden_size,
bool reverse,
bool has_biases,
bool batch_first,
c10::OptionalArrayRef<int64_t> input_size) {

std::vector<int64_t> input_size_value;
int64_t time_step, batch_size;
if (input_size.has_value()) {
input_size_value = input_size.value().vec();
int64_t time_index = batch_first ? 1: 0;
int64_t batch_size_index = batch_first ? 0: 1;

time_step = input_size_value[time_index];
batch_size = input_size_value[batch_size_index];
} else {
// no input size provided; fall back to placeholder values
time_step = 5;
batch_size = 10;
}

std::vector<Tensor> result(2);

ideep::tensor w1_, w2_;
at::Tensor packed_w1, packed_w2;

int64_t feature_size = weight0.size(-1);

std::tie(w1_, w2_) = get_lstm_packed_weights(
weight0,
weight1,
at::zeros(
weight0.sizes(),
weight0.options()),
at::zeros(
weight1.sizes(),
weight1.options()),
feature_size,
hidden_size,
has_biases, // has_biases
1, // num_layers
false, // bidirectional
time_step,
batch_size,
reverse);

if (should_use_plain_format(w1_)) {
packed_w1 = weight0;
} else {
packed_w1 = new_with_itensor_mkldnn(std::move(w1_), optTypeMetaToScalarType(weight0.options().dtype_opt()), weight0.options().device_opt());
}

if (should_use_plain_format(w2_)) {
packed_w2 = weight1;
} else {
packed_w2 = new_with_itensor_mkldnn(std::move(w2_), optTypeMetaToScalarType(weight1.options().dtype_opt()), weight1.options().device_opt());
}

result[0] = packed_w1;
result[1] = packed_w2;
return result;
}

TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
@@ -349,6 +495,9 @@ TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_weight"),
TORCH_FN(mkldnn_reorder_conv2d_weight));
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_mkldnn_rnn_layer_weight"),
TORCH_FN(mkldnn_reorder_mkldnn_rnn_layer_weight));
}

#else
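
Illustrative only, not part of the diff: a sketch of calling the newly registered op through the dispatcher, e.g. from a weight-prepacking pass at compile time. It assumes the op's schema (defined elsewhere) matches the CPU kernel's signature above; the wrapper name prepack_lstm_layer_weights is hypothetical.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Hypothetical wrapper: reorder one LSTM layer's flat weights into oneDNN's
// preferred packed layout; the kernel returns the original tensors when the
// plain format is already the best choice.
static std::vector<at::Tensor> prepack_lstm_layer_weights(
    const at::Tensor& weight_ih,
    const at::Tensor& weight_hh,
    int64_t hidden_size,
    bool reverse,
    bool has_biases,
    bool batch_first,
    at::OptionalIntArrayRef input_size) {
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("mkldnn::_reorder_mkldnn_rnn_layer_weight", "")
      .typed<std::vector<at::Tensor>(
          const at::Tensor&, const at::Tensor&, int64_t, bool, bool, bool,
          at::OptionalIntArrayRef)>();
  return op.call(weight_ih, weight_hh, hidden_size, reverse, has_biases,
                 batch_first, input_size);
}
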
89 changes: 38 additions & 51 deletions aten/src/ATen/native/mkldnn/RNN.cpp
@@ -9,6 +9,7 @@
#include <c10/core/GradMode.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
@@ -214,23 +215,6 @@ static Tensor _shuffle_bias(const Tensor& bias_ih, const Tensor& bias_hh, int64_
return bias_ih + bias_hh;
}

// Create mkldnn memory view from ATen tensor
static inline ideep::tensor get_mkldnn_tensor(
const Tensor& tensor, const ideep::tensor::desc& desc) {
TORCH_CHECK(
tensor.device().is_cpu(),
"get_mkldnn_tensor expects CPU tensor input");
TORCH_CHECK(
tensor.layout() == at::Layout::Strided,
"get_mkldnn_tensor expects dense tensor input");
TORCH_CHECK(
tensor.scalar_type() == at::ScalarType::Float ||
tensor.scalar_type() == at::ScalarType::BFloat16 ||
tensor.scalar_type() == at::ScalarType::Half,
"get_mkldnn_tensor expects float or bfloat16 tensor input");
return {desc, tensor.data_ptr()};
}

std::tuple<Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer(const Tensor& input,
const Tensor& w0,
const Tensor& w1,
@@ -266,30 +250,31 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer(const Tensor& input,
auto weight_ih = _shuffle_weight(w0, rnn.mode);
auto weight_hh = _shuffle_weight(w1, rnn.mode);

// Packed weights will be in mkldnn layout while the bias won't be packed
auto bias = has_biases
? _shuffle_bias(w2, w3, rnn.mode)
: at::zeros({rnn.num_bias_gates * rnn.hidden_size}, weight_ih.options());
: at::zeros({rnn.num_bias_gates * rnn.hidden_size}, weight_ih.options().layout(at::Layout::Strided));

// per layer input size
int64_t input_size = input.size(2);
auto x = get_mkldnn_tensor(
ideep::tensor w1_, w2_;
auto x = itensor_view_from_dense(
input,
rnn.src_layer_desc(input_size, get_mkldnn_dtype(input)));
auto hx = get_mkldnn_tensor(
auto hx = itensor_view_from_dense(
hx_, rnn.src_iter_desc(get_mkldnn_dtype(hx_)));
auto cx = get_mkldnn_tensor(
auto cx = itensor_view_from_dense(
cx_, rnn.src_iter_c_desc(get_mkldnn_dtype(cx_)));
auto b = get_mkldnn_tensor(
auto b = itensor_view_from_dense(
bias, rnn.bias_desc(get_mkldnn_dtype(bias)));
auto y = get_mkldnn_tensor(
auto y = itensor_view_from_dense(
output, rnn.dst_layer_desc(get_mkldnn_dtype(output)));
auto hy = get_mkldnn_tensor(
auto hy = itensor_view_from_dense(
hy_, rnn.dst_iter_desc(get_mkldnn_dtype(hy_)));
auto cy = get_mkldnn_tensor(
auto cy = itensor_view_from_dense(
cy_, rnn.dst_iter_c_desc(get_mkldnn_dtype(cy_)));
auto w1_ = get_mkldnn_tensor(weight_ih, rnn.weights_layer_desc(input_size, get_mkldnn_dtype(weight_ih)));
auto w2_ = get_mkldnn_tensor(weight_hh, rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh)));

w1_ = weight_ih.is_mkldnn() ? itensor_from_tensor(weight_ih) : itensor_view_from_dense(weight_ih, rnn.weights_layer_desc(input_size, get_mkldnn_dtype(weight_ih)));
w2_ = weight_hh.is_mkldnn() ? itensor_from_tensor(weight_hh) : itensor_view_from_dense(weight_hh, rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh)));
if (at::GradMode::is_enabled()) {
Tensor workspace = Tensor();
auto pd = ideep::lstm_forward_training::prepare(
@@ -362,27 +347,27 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_la

// per layer input size
int64_t input_size = input.size(2);
auto x = get_mkldnn_tensor(
auto x = itensor_view_from_dense(
input,
rnn.src_layer_desc(input_size, get_mkldnn_dtype(input.scalar_type())));
auto hx = get_mkldnn_tensor(
auto hx = itensor_view_from_dense(
hx_, rnn.src_iter_desc(get_mkldnn_dtype(hx_.scalar_type())));
auto cx = get_mkldnn_tensor(
auto cx = itensor_view_from_dense(
cx_, rnn.src_iter_c_desc(get_mkldnn_dtype(cx_.scalar_type())));
auto w1 = get_mkldnn_tensor(
auto w1 = itensor_view_from_dense(
weight_ih,
rnn.weights_layer_desc(
input_size, get_mkldnn_dtype(weight_ih.scalar_type())));
auto w2 = get_mkldnn_tensor(
auto w2 = itensor_view_from_dense(
weight_hh,
rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh.scalar_type())));
auto b = get_mkldnn_tensor(
auto b = itensor_view_from_dense(
bias, rnn.bias_desc(get_mkldnn_dtype(bias.scalar_type())));
auto y = get_mkldnn_tensor(
auto y = itensor_view_from_dense(
output, rnn.dst_layer_desc(get_mkldnn_dtype(output.scalar_type())));
auto hy = get_mkldnn_tensor(
auto hy = itensor_view_from_dense(
hy_, rnn.dst_iter_desc(get_mkldnn_dtype(hy_.scalar_type())));
auto cy = get_mkldnn_tensor(
auto cy = itensor_view_from_dense(
cy_, rnn.dst_iter_c_desc(get_mkldnn_dtype(cy_.scalar_type())));

// Create diff_* ATen tensor and corresponding ideep tensor as fp32
Expand All @@ -399,18 +384,18 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_la
auto diff_b_ =
at::empty(bias.sizes(), bias.options().dtype(at::ScalarType::Float));

auto diff_x = get_mkldnn_tensor(
auto diff_x = itensor_view_from_dense(
diff_x_, rnn.src_layer_desc(input_size, ideep::tensor::data_type::f32));
auto diff_hx = get_mkldnn_tensor(
auto diff_hx = itensor_view_from_dense(
diff_hx_, rnn.src_iter_desc(ideep::tensor::data_type::f32));
auto diff_cx = get_mkldnn_tensor(
auto diff_cx = itensor_view_from_dense(
diff_cx_, rnn.src_iter_c_desc(ideep::tensor::data_type::f32));
auto diff_w1 = get_mkldnn_tensor(
auto diff_w1 = itensor_view_from_dense(
diff_w1_,
rnn.weights_layer_desc(input_size, ideep::tensor::data_type::f32));
auto diff_w2 = get_mkldnn_tensor(
auto diff_w2 = itensor_view_from_dense(
diff_w2_, rnn.weights_iter_desc(ideep::tensor::data_type::f32));
auto diff_b = get_mkldnn_tensor(
auto diff_b = itensor_view_from_dense(
diff_b_, rnn.bias_desc(ideep::tensor::data_type::f32));

// Convert grad_y, grad_hy, grad_cy to fp32 in non-fp32 backward
Expand All @@ -428,18 +413,18 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_la
grad_cy.sizes(), grad_cy.options().dtype(at::ScalarType::Float));
grad_cy_.copy_(grad_cy);

diff_y = get_mkldnn_tensor(
diff_y = itensor_view_from_dense(
grad_y_, rnn.dst_layer_desc(get_mkldnn_dtype(grad_y_.scalar_type())));
diff_hy = get_mkldnn_tensor(
diff_hy = itensor_view_from_dense(
grad_hy_, rnn.dst_iter_desc(get_mkldnn_dtype(grad_hy_.scalar_type())));
diff_cy = get_mkldnn_tensor(
diff_cy = itensor_view_from_dense(
grad_cy_, rnn.dst_iter_desc(get_mkldnn_dtype(grad_cy_.scalar_type())));
} else {
diff_y = get_mkldnn_tensor(
diff_y = itensor_view_from_dense(
grad_output, rnn.dst_layer_desc(ideep::tensor::data_type::f32));
diff_hy = get_mkldnn_tensor(
diff_hy = itensor_view_from_dense(
grad_hy, rnn.dst_iter_desc(ideep::tensor::data_type::f32));
diff_cy = get_mkldnn_tensor(
diff_cy = itensor_view_from_dense(
grad_cy, rnn.dst_iter_desc(ideep::tensor::data_type::f32));
}

@@ -503,9 +488,10 @@ static std::tuple<Tensor, Tensor, Tensor> mkldnn_rnn(
auto layer_hx = hx[index];
auto layer_cx = cx[index];
auto reverse = (direction > 0);
// bias won't be packed
auto outputs = at::mkldnn_rnn_layer(layer_input, layer_weights[0], layer_weights[1],
has_biases ? layer_weights[2] : at::zeros(layer_weights[0].sizes(), layer_weights[0].options()),
has_biases ? layer_weights[3] : at::zeros(layer_weights[1].sizes(), layer_weights[1].options()), layer_hx,
has_biases ? layer_weights[2] : at::zeros(layer_weights[0].sizes(), layer_weights[0].options().layout(at::Layout::Strided)),
has_biases ? layer_weights[3] : at::zeros(layer_weights[1].sizes(), layer_weights[1].options().layout(at::Layout::Strided)), layer_hx,
layer_cx, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train);
layer_output[direction] = std::get<0>(outputs);
layer_hy[index] = std::get<1>(outputs);
@@ -579,6 +565,7 @@ void lstm_mkldnn(Tensor& output, Tensor& hy, Tensor& cy,

REGISTER_ALL_CPU_DISPATCH(lstm_mkldnn_stub, &lstm_mkldnn);


} // namespace at::native

#endif // AT_MKLDNN_ENABLED
