inductor: enable weight prepack for LSTM #103071

Closed · wants to merge 35 commits

Changes from 32 commits

Commits (35)
066e30c
inductor: enable weight prepack for LSTM
chunyuan-w Jun 6, 2023
87d72ad
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 6, 2023
f24b183
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 7, 2023
cef8c40
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 7, 2023
92c0298
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 7, 2023
fca0fbd
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 7, 2023
7d5d6ad
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 7, 2023
08d68e1
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 9, 2023
e757681
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 9, 2023
a2ec005
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 9, 2023
afe37ad
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 12, 2023
0283e8f
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 12, 2023
4c504b5
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 12, 2023
7f16ca4
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 26, 2023
f808289
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 26, 2023
f3a7c1f
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 27, 2023
5bcaf40
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jun 29, 2023
5176daf
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 20, 2023
1070f49
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 20, 2023
e29ee33
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 20, 2023
63d145f
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 26, 2023
529fe8b
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 26, 2023
6545e19
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 26, 2023
781a977
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 27, 2023
8e0cb66
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 27, 2023
1e72dc6
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 27, 2023
bcd05fe
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 27, 2023
b68ea92
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 27, 2023
9d5012d
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 27, 2023
c943558
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 27, 2023
5e80392
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 27, 2023
5f986f8
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 27, 2023
8d5f080
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 27, 2023
e703b44
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 27, 2023
572a694
Update on "inductor: enable weight prepack for LSTM"
chunyuan-w Jul 28, 2023
17 changes: 17 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
@@ -103,6 +103,23 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor) {
}
}

ideep::tensor itensor_view_from_dense(
const at::Tensor& tensor,
const ideep::tensor::desc& desc) {
TORCH_CHECK(
tensor.device().is_cpu(),
"itensor_view_from_dense expects CPU tensor input");
TORCH_CHECK(
tensor.layout() == at::Layout::Strided,
"itensor_view_from_dense expects dense tensor input");
TORCH_CHECK(
tensor.scalar_type() == at::ScalarType::Float ||
tensor.scalar_type() == at::ScalarType::BFloat16 ||
tensor.scalar_type() == at::ScalarType::Half,
"itensor_view_from_dense expects float, bfloat16 or half tensor input");
return {desc, tensor.data_ptr()};
}
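
A usage illustration (not part of this diff): the new overload lets a caller impose an explicit oneDNN desc on a dense tensor instead of deriving one from the tensor's sizes and strides, which is what the LSTM weight prepacking relies on. This is a minimal sketch; the shapes, dtype, and names are assumptions for the example, and it presumes <ATen/ATen.h> plus this file's declarations are visible.

// Sketch: view a dense LSTM weight_ih through an explicit ldgoi desc,
// mirroring how get_lstm_packed_weights (MKLDNNConversions.cpp, further down
// in this diff) uses this overload.
void example_view_with_desc() {
  int64_t input_size = 32, hidden_size = 64, num_gates = 4;
  at::Tensor weight_ih = at::randn({num_gates * hidden_size, input_size});
  ideep::tensor::desc ldgoi_desc(
      {1, 1, input_size, num_gates, hidden_size},
      ideep::tensor::data_type::f32,
      ideep::format_tag::ldgoi);
  ideep::tensor view = at::native::itensor_view_from_dense(weight_ih, ldgoi_desc);
  // `view` aliases weight_ih's storage; data is only copied when the view is
  // fed into a reorder (e.g. feed_from()) targeting a different layout.
}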

// Helper function for getting an ideep tensor out of an aten Tensor.
// Note in case the aten Tensor is a dense tensor, the returned ideep
// tensor is just a view of the storage of the aten dense tensor, so
6 changes: 6 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.h
@@ -24,6 +24,12 @@ TORCH_API ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor);
// ideep::tensor will share the underlying buffer
TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor);

// Construct an `ideep::tensor` "view" from a dense tensor using the given desc;
// note that the ideep::tensor will share the underlying buffer
TORCH_API ideep::tensor itensor_view_from_dense(
const at::Tensor& tensor,
const ideep::tensor::desc& desc);

// Helper function for getting an ideep tensor out of an aten Tensor or MKL-DNN tensor.
TORCH_API ideep::tensor itensor_from_tensor(const Tensor& tensor);

142 changes: 142 additions & 0 deletions aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
@@ -16,6 +16,7 @@
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
#include <ATen/ops/to_mkldnn_native.h>
#include <ATen/ops/zeros.h>
#endif


@@ -339,6 +340,144 @@ static Tensor mkldnn_reorder_conv_transpose2d_weight(
self.options().device_opt());
}

static std::tuple<ideep::tensor, ideep::tensor> get_lstm_packed_weights(
const at::Tensor& weight_ih,
const at::Tensor& weight_hh,
const at::Tensor& weight2,
const at::Tensor& weight3,
int64_t layer_feature_size,
int64_t hidden_size,
bool has_biases,
int64_t num_layers,
bool bidirectional,
int64_t time_step,
int64_t batch_size,
bool reverse) {

ideep::tensor cached_weight_ih, cached_weight_hh;

int64_t num_gates = 4;
int64_t num_bias_gates = 4;
std::vector<int64_t> output_sizes = {time_step, batch_size, hidden_size};

auto dtype = get_mkldnn_dtype(weight_ih.scalar_type());
ideep::tensor::desc src_layer_desc({time_step, batch_size, layer_feature_size}, dtype, ideep::format_tag::tnc);
ideep::tensor::desc src_iter_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
ideep::tensor::desc src_iter_c_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
ideep::tensor::desc bias_desc({1, 1, num_bias_gates, hidden_size}, dtype, ideep::format_tag::ldgo);

ideep::tensor::desc dst_layer_desc({time_step, batch_size, hidden_size}, dtype, ideep::format_tag::tnc);
ideep::tensor::desc dst_iter_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
ideep::tensor::desc dst_iter_c_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
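// These problem descs (and the empty tensors created from them below) only let
// expected_weights_desc() pick the weight formats; no LSTM computation runs here.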

ideep::tensor src_layer(src_layer_desc);
ideep::tensor src_iter(src_iter_desc);
ideep::tensor src_iter_c(src_iter_c_desc);
ideep::tensor bias(bias_desc);

auto w1 = itensor_view_from_dense(
weight_ih,
{{1, 1, layer_feature_size, num_gates, hidden_size},
get_mkldnn_dtype(weight_ih.scalar_type()),
ideep::format_tag::ldgoi});

auto w2 = itensor_view_from_dense(
weight_hh,
{{1, 1, hidden_size, num_gates, hidden_size},
get_mkldnn_dtype(weight_hh.scalar_type()),
ideep::format_tag::ldgoi});

ideep::tensor::desc packed_desc_ih, packed_desc_hh;
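// Ask oneDNN which weight formats it would choose for this problem size; the
// plain ldgoi views w1/w2 are then reordered into those formats below.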

std::tie(packed_desc_ih, packed_desc_hh) =
ideep::lstm_forward_inference::expected_weights_desc(
output_sizes,
src_layer,
src_iter,
src_iter_c,
w1,
w2,
bias,
reverse);

cached_weight_ih.init(packed_desc_ih);
cached_weight_hh.init(packed_desc_hh);

cached_weight_ih.feed_from(w1);
cached_weight_hh.feed_from(w2);

return std::make_tuple(cached_weight_ih, cached_weight_hh);
}

static bool should_use_plain_format(ideep::tensor w) {
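// When this returns true, the caller keeps the original dense weight instead of
// creating an MKLDNN tensor holding a prepacked layout.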
#if defined(IDEEP_VERSION_MAJOR) && IDEEP_VERSION_MAJOR>=3
return w.get_desc().is_opaque() || w.get_desc().is_plain();
#else
return w.get_desc().is_rnn_packed() || w.get_desc().is_plain();
#endif
}

static std::vector<Tensor> mkldnn_reorder_mkldnn_rnn_layer_weight(
Tensor weight0,
Tensor weight1,
int64_t hidden_size,
bool reverse,
bool has_biases,
bool batch_first,
c10::OptionalArrayRef<int64_t> input_size) {

std::vector<int64_t> input_size_value;
int64_t time_step, batch_size;
if (input_size.has_value()) {
input_size_value = input_size.value().vec();
int64_t time_index = batch_first ? 1: 0;
int64_t batch_size_index = batch_first ? 0: 1;

time_step = input_size_value[time_index];
batch_size = input_size_value[batch_size_index];
} else {
// no input_size provided; fall back to placeholder sizes for the weight-format query
time_step = 5;
batch_size = 10;
}

ideep::tensor w1_, w2_;
at::Tensor packed_w1, packed_w2;

int64_t feature_size = weight0.size(-1);

std::tie(w1_, w2_) = get_lstm_packed_weights(
weight0,
weight1,
at::zeros(
weight0.sizes(),
weight0.options()),
at::zeros(
weight1.sizes(),
weight1.options()),
feature_size,
hidden_size,
has_biases, // has_biases
1, // num_layers
false, // bidirectional
time_step,
batch_size,
reverse);

if (should_use_plain_format(w1_)) {
packed_w1 = weight0;
} else {
packed_w1 = new_with_itensor_mkldnn(std::move(w1_), optTypeMetaToScalarType(weight0.options().dtype_opt()), weight0.options().device_opt());
}

if (should_use_plain_format(w2_)) {
packed_w2 = weight1;
} else {
packed_w2 = new_with_itensor_mkldnn(std::move(w2_), optTypeMetaToScalarType(weight1.options().dtype_opt()), weight1.options().device_opt());
}
return {packed_w1, packed_w2};
}

TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
@@ -349,6 +488,9 @@ TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_weight"),
TORCH_FN(mkldnn_reorder_conv2d_weight));
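// Expose the LSTM weight reorder as a CPU op so weights can be prepacked ahead
// of time (the inductor LSTM weight-prepack path this PR enables calls into it).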
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_mkldnn_rnn_layer_weight"),
TORCH_FN(mkldnn_reorder_mkldnn_rnn_layer_weight));
}

#else
88 changes: 37 additions & 51 deletions aten/src/ATen/native/mkldnn/RNN.cpp
@@ -9,6 +9,7 @@
#include <c10/core/GradMode.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
@@ -214,23 +215,6 @@ static Tensor _shuffle_bias(const Tensor& bias_ih, const Tensor& bias_hh, int64_
return bias_ih + bias_hh;
}

// Create mkldnn memory view from ATen tensor
static inline ideep::tensor get_mkldnn_tensor(
const Tensor& tensor, const ideep::tensor::desc& desc) {
TORCH_CHECK(
tensor.device().is_cpu(),
"get_mkldnn_tensor expects CPU tensor input");
TORCH_CHECK(
tensor.layout() == at::Layout::Strided,
"get_mkldnn_tensor expects dense tensor input");
TORCH_CHECK(
tensor.scalar_type() == at::ScalarType::Float ||
tensor.scalar_type() == at::ScalarType::BFloat16 ||
tensor.scalar_type() == at::ScalarType::Half,
"get_mkldnn_tensor expects float or bfloat16 tensor input");
return {desc, tensor.data_ptr()};
}

std::tuple<Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer(const Tensor& input,
const Tensor& w0,
const Tensor& w1,
@@ -266,30 +250,31 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer(const Tensor& input,
auto weight_ih = _shuffle_weight(w0, rnn.mode);
auto weight_hh = _shuffle_weight(w1, rnn.mode);

// Prepacked weights arrive in MKLDNN layout; the bias is never prepacked, so the
// zero bias must be created as a dense (strided) tensor.
auto bias = has_biases
? _shuffle_bias(w2, w3, rnn.mode)
: at::zeros({rnn.num_bias_gates * rnn.hidden_size}, weight_ih.options());
: at::zeros({rnn.num_bias_gates * rnn.hidden_size}, weight_ih.options().layout(at::Layout::Strided));

// per layer input size
int64_t input_size = input.size(2);
auto x = get_mkldnn_tensor(
ideep::tensor w1_, w2_;
auto x = itensor_view_from_dense(
input,
rnn.src_layer_desc(input_size, get_mkldnn_dtype(input)));
auto hx = get_mkldnn_tensor(
auto hx = itensor_view_from_dense(
hx_, rnn.src_iter_desc(get_mkldnn_dtype(hx_)));
auto cx = get_mkldnn_tensor(
auto cx = itensor_view_from_dense(
cx_, rnn.src_iter_c_desc(get_mkldnn_dtype(cx_)));
auto b = get_mkldnn_tensor(
auto b = itensor_view_from_dense(
bias, rnn.bias_desc(get_mkldnn_dtype(bias)));
auto y = get_mkldnn_tensor(
auto y = itensor_view_from_dense(
output, rnn.dst_layer_desc(get_mkldnn_dtype(output)));
auto hy = get_mkldnn_tensor(
auto hy = itensor_view_from_dense(
hy_, rnn.dst_iter_desc(get_mkldnn_dtype(hy_)));
auto cy = get_mkldnn_tensor(
auto cy = itensor_view_from_dense(
cy_, rnn.dst_iter_c_desc(get_mkldnn_dtype(cy_)));
auto w1_ = get_mkldnn_tensor(weight_ih, rnn.weights_layer_desc(input_size, get_mkldnn_dtype(weight_ih)));
auto w2_ = get_mkldnn_tensor(weight_hh, rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh)));

w1_ = weight_ih.is_mkldnn() ? itensor_from_tensor(weight_ih) : itensor_view_from_dense(weight_ih, rnn.weights_layer_desc(input_size, get_mkldnn_dtype(weight_ih)));
w2_ = weight_hh.is_mkldnn() ? itensor_from_tensor(weight_hh) : itensor_view_from_dense(weight_hh, rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh)));
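// Note on the two assignments above: weights that were prepacked ahead of time
// arrive as MKLDNN tensors and are consumed directly; plain dense weights are
// still viewed with the descs the RNN primitive expects.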
if (at::GradMode::is_enabled()) {
Tensor workspace = Tensor();
auto pd = ideep::lstm_forward_training::prepare(
@@ -362,27 +347,27 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_la

// per layer input size
int64_t input_size = input.size(2);
auto x = get_mkldnn_tensor(
auto x = itensor_view_from_dense(
input,
rnn.src_layer_desc(input_size, get_mkldnn_dtype(input.scalar_type())));
auto hx = get_mkldnn_tensor(
auto hx = itensor_view_from_dense(
hx_, rnn.src_iter_desc(get_mkldnn_dtype(hx_.scalar_type())));
auto cx = get_mkldnn_tensor(
auto cx = itensor_view_from_dense(
cx_, rnn.src_iter_c_desc(get_mkldnn_dtype(cx_.scalar_type())));
auto w1 = get_mkldnn_tensor(
auto w1 = itensor_view_from_dense(
weight_ih,
rnn.weights_layer_desc(
input_size, get_mkldnn_dtype(weight_ih.scalar_type())));
auto w2 = get_mkldnn_tensor(
auto w2 = itensor_view_from_dense(
weight_hh,
rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh.scalar_type())));
auto b = get_mkldnn_tensor(
auto b = itensor_view_from_dense(
bias, rnn.bias_desc(get_mkldnn_dtype(bias.scalar_type())));
auto y = get_mkldnn_tensor(
auto y = itensor_view_from_dense(
output, rnn.dst_layer_desc(get_mkldnn_dtype(output.scalar_type())));
auto hy = get_mkldnn_tensor(
auto hy = itensor_view_from_dense(
hy_, rnn.dst_iter_desc(get_mkldnn_dtype(hy_.scalar_type())));
auto cy = get_mkldnn_tensor(
auto cy = itensor_view_from_dense(
cy_, rnn.dst_iter_c_desc(get_mkldnn_dtype(cy_.scalar_type())));

// Create diff_* ATen tensor and corresponding ideep tensor as fp32
@@ -399,18 +384,18 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_la
auto diff_b_ =
at::empty(bias.sizes(), bias.options().dtype(at::ScalarType::Float));

auto diff_x = get_mkldnn_tensor(
auto diff_x = itensor_view_from_dense(
diff_x_, rnn.src_layer_desc(input_size, ideep::tensor::data_type::f32));
auto diff_hx = get_mkldnn_tensor(
auto diff_hx = itensor_view_from_dense(
diff_hx_, rnn.src_iter_desc(ideep::tensor::data_type::f32));
auto diff_cx = get_mkldnn_tensor(
auto diff_cx = itensor_view_from_dense(
diff_cx_, rnn.src_iter_c_desc(ideep::tensor::data_type::f32));
auto diff_w1 = get_mkldnn_tensor(
auto diff_w1 = itensor_view_from_dense(
diff_w1_,
rnn.weights_layer_desc(input_size, ideep::tensor::data_type::f32));
auto diff_w2 = get_mkldnn_tensor(
auto diff_w2 = itensor_view_from_dense(
diff_w2_, rnn.weights_iter_desc(ideep::tensor::data_type::f32));
auto diff_b = get_mkldnn_tensor(
auto diff_b = itensor_view_from_dense(
diff_b_, rnn.bias_desc(ideep::tensor::data_type::f32));

// Convert grad_y, grad_hy, grad_cy to fp32 in non-fp32 backward
Expand All @@ -428,18 +413,18 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_la
grad_cy.sizes(), grad_cy.options().dtype(at::ScalarType::Float));
grad_cy_.copy_(grad_cy);

diff_y = get_mkldnn_tensor(
diff_y = itensor_view_from_dense(
grad_y_, rnn.dst_layer_desc(get_mkldnn_dtype(grad_y_.scalar_type())));
diff_hy = get_mkldnn_tensor(
diff_hy = itensor_view_from_dense(
grad_hy_, rnn.dst_iter_desc(get_mkldnn_dtype(grad_hy_.scalar_type())));
diff_cy = get_mkldnn_tensor(
diff_cy = itensor_view_from_dense(
grad_cy_, rnn.dst_iter_desc(get_mkldnn_dtype(grad_cy_.scalar_type())));
} else {
diff_y = get_mkldnn_tensor(
diff_y = itensor_view_from_dense(
grad_output, rnn.dst_layer_desc(ideep::tensor::data_type::f32));
diff_hy = get_mkldnn_tensor(
diff_hy = itensor_view_from_dense(
grad_hy, rnn.dst_iter_desc(ideep::tensor::data_type::f32));
diff_cy = get_mkldnn_tensor(
diff_cy = itensor_view_from_dense(
grad_cy, rnn.dst_iter_desc(ideep::tensor::data_type::f32));
}

@@ -503,9 +488,10 @@ static std::tuple<Tensor, Tensor, Tensor> mkldnn_rnn(
auto layer_hx = hx[index];
auto layer_cx = cx[index];
auto reverse = (direction > 0);
// The bias is never prepacked, so build the zero bias as a dense (strided) tensor
// even when the weights are MKLDNN tensors.
auto outputs = at::mkldnn_rnn_layer(layer_input, layer_weights[0], layer_weights[1],
has_biases ? layer_weights[2] : at::zeros(layer_weights[0].sizes(), layer_weights[0].options()),
has_biases ? layer_weights[3] : at::zeros(layer_weights[1].sizes(), layer_weights[1].options()), layer_hx,
has_biases ? layer_weights[2] : at::zeros(layer_weights[0].sizes(), layer_weights[0].options().layout(at::Layout::Strided)),
has_biases ? layer_weights[3] : at::zeros(layer_weights[1].sizes(), layer_weights[1].options().layout(at::Layout::Strided)), layer_hx,
layer_cx, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train);
layer_output[direction] = std::get<0>(outputs);
layer_hy[index] = std::get<1>(outputs);
2 changes: 2 additions & 0 deletions aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp
@@ -59,6 +59,8 @@ TORCH_LIBRARY(mkldnn, m) {
"mkldnn::_reorder_linear_weight(Tensor self, int? batch_size=None) -> Tensor Y"));
m.def(TORCH_SELECTIVE_SCHEMA(
"mkldnn::_reorder_convolution_weight(Tensor self, int[2] padding=0, int[2] stride=1, int[2] dilation=1, int groups=1, int[]? input_size=None) -> Tensor Y"));
m.def(TORCH_SELECTIVE_SCHEMA(
"mkldnn::_reorder_mkldnn_rnn_layer_weight(Tensor weight0, Tensor weight1, int hidden_size, bool reverse, bool has_biases, bool batch_first, int[]? input_size=None) -> Tensor[] Y"));
m.def("_is_mkldnn_bf16_supported", &is_mkldnn_bf16_supported);
}
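
For reference, a hedged sketch of how the newly registered op might be invoked from C++ through the dispatcher. The op name and argument order come from the schema above; the helper name, shapes, and flag values are assumptions for the example, not part of this PR.

// Sketch: call mkldnn::_reorder_mkldnn_rnn_layer_weight via the dispatcher.
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

std::vector<at::Tensor> reorder_lstm_layer_weights(
    const at::Tensor& weight_ih,   // (4 * hidden_size, input_size)
    const at::Tensor& weight_hh,   // (4 * hidden_size, hidden_size)
    int64_t hidden_size,
    c10::OptionalArrayRef<int64_t> input_size) {
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("mkldnn::_reorder_mkldnn_rnn_layer_weight", "")
      .typed<std::vector<at::Tensor>(
          const at::Tensor&, const at::Tensor&, int64_t, bool, bool, bool,
          c10::OptionalArrayRef<int64_t>)>();
  // reverse=false, has_biases=true, batch_first=false are example values only.
  return op.call(weight_ih, weight_hh, hidden_size,
                 /*reverse=*/false, /*has_biases=*/true, /*batch_first=*/false,
                 input_size);
}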

1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3840,6 +3840,7 @@
- func: mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor)
dispatch:
CPU: mkldnn_rnn_layer
MkldnnCPU: mkldnn_rnn_layer
autogen: mkldnn_rnn_layer.out

- func: mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
Expand Down