Skip to content

Commit

Permalink
inductor: enable weight prepack for LSTM
Browse files Browse the repository at this point in the history
ghstack-source-id: 3545e5a28d23a8312a1781ee45844b6388ddc15e
Pull Request resolved: #103071
  • Loading branch information
chunyuan-w committed Jun 9, 2023
1 parent f79d2b4 commit bc79232
Show file tree
Hide file tree
Showing 14 changed files with 739 additions and 75 deletions.
17 changes: 17 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,23 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor) {
}
}

ideep::tensor itensor_view_from_dense(
const at::Tensor& tensor,
const ideep::tensor::desc& desc) {
TORCH_CHECK(
tensor.device().is_cpu(),
"itensor_view_from_dense expects CPU tensor input");
TORCH_CHECK(
tensor.layout() == at::Layout::Strided,
"itensor_view_from_dense expects dense tensor input");
TORCH_CHECK(
tensor.scalar_type() == at::ScalarType::Float ||
tensor.scalar_type() == at::ScalarType::BFloat16 ||
tensor.scalar_type() == at::ScalarType::Half,
"itensor_view_from_dense expects float, bfloat16 or half tensor input");
return {desc, tensor.data_ptr()};
}

// Helper function for getting an ideep tensor out of an aten Tensor.
// Note in case the aten Tensor is a dense tensor, the returned ideep
// tensor is just a view of the storage of the aten dense tensor, so
Expand Down
6 changes: 6 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ TORCH_API ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor);
// ideep::tensor will share the underlying buffer
TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor);

// Construct an `ideep::tensor` "view" from dense tensor using given desc, note
// the ideep::tensor will share the underlying buffer
TORCH_API ideep::tensor itensor_view_from_dense(
const at::Tensor& tensor,
const ideep::tensor::desc& desc);

// Helper function for getting an ideep tensor out of an aten Tensor or MKL-DNN tensor.
TORCH_API ideep::tensor itensor_from_tensor(const Tensor& tensor);

Expand Down
175 changes: 175 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <ATen/native/mkldnn/Utils.h>
#include <ATen/native/utils/ParamUtils.h>
#include <torch/library.h>
#include <ATen/MatrixRef.h>
#include <tuple>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
Expand All @@ -16,6 +18,7 @@
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
#include <ATen/ops/to_mkldnn_native.h>
#include <ATen/ops/zeros.h>
#endif


Expand Down Expand Up @@ -292,6 +295,172 @@ Tensor mkldnn_reorder_conv_transpose2d_weight(
self.options().device_opt());
}

std::tuple<ideep::tensor, ideep::tensor> get_lstm_packed_weights(
const at::Tensor& weight_ih,
const at::Tensor& weight_hh,
const at::Tensor& weight2,
const at::Tensor& weight3,
int64_t layer_feature_size,
int64_t hidden_size,
bool has_biases,
int64_t num_layers,
bool bidirectional,
int64_t time_step,
int64_t batch_size,
bool reverse) {

ideep::tensor cached_weight_ih, cached_weight_hh;

int64_t num_gates = 4;
int64_t num_bias_gates = 4;
std::vector<int64_t> output_sizes = {time_step, batch_size, hidden_size};

auto dtype = get_mkldnn_dtype(weight_ih.scalar_type());
ideep::tensor::desc src_layer_desc({time_step, batch_size, layer_feature_size}, dtype, ideep::format_tag::tnc);
ideep::tensor::desc src_iter_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
ideep::tensor::desc src_iter_c_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
ideep::tensor::desc bias_desc({1, 1, num_bias_gates, hidden_size}, dtype, ideep::format_tag::ldgo);

ideep::tensor::desc dst_layer_desc({time_step, batch_size, hidden_size}, dtype, ideep::format_tag::tnc);
ideep::tensor::desc dst_iter_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
ideep::tensor::desc dst_iter_c_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);

ideep::tensor src_layer(src_layer_desc);
ideep::tensor src_iter(src_iter_desc);
ideep::tensor src_iter_c(src_iter_c_desc);
ideep::tensor bias(bias_desc);

auto w1 = itensor_view_from_dense(
weight_ih,
{{1, 1, layer_feature_size, num_gates, hidden_size},
get_mkldnn_dtype(weight_ih.scalar_type()),
ideep::format_tag::ldgoi});

auto w2 = itensor_view_from_dense(
weight_hh,
{{1, 1, hidden_size, num_gates, hidden_size},
get_mkldnn_dtype(weight_hh.scalar_type()),
ideep::format_tag::ldgoi});

ideep::tensor::desc packed_desc_ih, packed_desc_hh;

std::tie(packed_desc_ih, packed_desc_hh) =
ideep::lstm_forward_inference::expected_weights_desc(
output_sizes,
src_layer,
src_iter,
src_iter_c,
w1,
w2,
bias,
reverse);

cached_weight_ih.init(packed_desc_ih);
cached_weight_hh.init(packed_desc_hh);

cached_weight_ih.feed_from(w1);
cached_weight_hh.feed_from(w2);

return std::make_tuple(cached_weight_ih, cached_weight_hh);
}

std::vector<Tensor> mkldnn_reorder_lstm_weight(
TensorList weight,
int64_t input_feature_size,
int64_t hidden_size,
bool has_biases,
int64_t num_layers,
bool bidirectional,
bool batch_first,
c10::OptionalArrayRef<int64_t> input_size) {
std::vector<int64_t> input_size_value;
int64_t time_step, batch_size;
if (input_size.has_value()) {
input_size_value = input_size.value().vec();
int64_t time_index = batch_first ? 1: 0;
int64_t batch_size_index = batch_first ? 0: 1;

time_step = input_size_value[time_index];
batch_size = input_size_value[batch_size_index];
} else {
// no value fed, provide one here
time_step = 5;
batch_size = 10;
}

std::vector<Tensor> result(weight.size());

auto num_directions = bidirectional ? 2 : 1;
int64_t weight_stride0 = has_biases ? 4 : 2;

at::MatrixRef<at::Tensor> weights{
weight, static_cast<size_t>(weight_stride0)};
ideep::tensor w1_, w2_;
at::Tensor packed_w1, packed_w2;

for (int64_t layer = 0; layer < num_layers; layer++) {
for (int64_t direction = 0; direction < num_directions; direction++) {
// for layer == 0, feature_size = input_feature_size
// otherwise, feature_size = hidden_size
int64_t layer_feature_size = layer == 0? input_feature_size : num_directions * hidden_size;
auto index = layer * num_directions + direction;
auto layer_weights = weights[index];
TORCH_CHECK(layer_weights.size() == 2 || layer_weights.size() == 4);
auto reverse = (direction > 0);

std::tie(w1_, w2_) = get_lstm_packed_weights(
layer_weights[0],
layer_weights[1],
has_biases ? layer_weights[2] : at::zeros(
layer_weights[0].sizes(),
layer_weights[0].options()),
has_biases ? layer_weights[3]
: at::zeros(
layer_weights[1].sizes(),
layer_weights[1].options()),
layer_feature_size,
hidden_size,
has_biases,
num_layers,
bidirectional,
time_step,
batch_size,
reverse);

// TODO: use is_opaque() after updating ideep in pytorch
// Don't pack when the weight is of rnn_packed format
// When the weight is of rnn_packed format, if the seq_lens of
// the input changes, the format of weight also changes.
// oneDNN does not support reorder from rnn_packed back to public format.
// LSTM based on BRGEMM kernel (on AVX512 and newest ISAs) will use blocked
// format for weight of LSTM, which won't change when the input seq_lens
// changes.
// On AVX2, queried weight will be plain format
if (w1_.get_desc().is_rnn_packed() || w1_.get_desc().is_plain()) {
packed_w1 = layer_weights[0];
} else {
packed_w1 = new_with_itensor_mkldnn(std::move(w1_), optTypeMetaToScalarType(layer_weights[0].options().dtype_opt()), layer_weights[0].options().device_opt());
}

if (w2_.get_desc().is_rnn_packed() || w2_.get_desc().is_plain()) {
packed_w2 = layer_weights[1];
} else {
packed_w2 = new_with_itensor_mkldnn(std::move(w2_), optTypeMetaToScalarType(layer_weights[1].options().dtype_opt()), layer_weights[1].options().device_opt());
}

result[index * weight_stride0] = packed_w1;
result[index * weight_stride0+1] = packed_w2;

if (has_biases) {
result[index * weight_stride0+2] = layer_weights[2];
result[index * weight_stride0+3] = layer_weights[3];
}
}
}

return result;
}

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
Expand All @@ -301,6 +470,12 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
TORCH_FN(mkldnn_reorder_linear_weight));
}

TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_lstm_weight"),
TORCH_FN(mkldnn_reorder_lstm_weight));
}

#else

Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad) {
Expand Down

0 comments on commit bc79232

Please sign in to comment.