Adding support for CuDNN-based LSTM with projections (#47725)
Summary:
Fixes #46213

I didn't update the documentation yet; I will add those changes soon. There are a few other things I didn't do, and I want to clarify whether I should:

1. I didn't expose projections in the C++ API (torch/csrc/api/src/nn/modules/rnn.cpp). Let me know if this is desirable and I will add those changes.
2. I didn't expose projections in the "lstm_cell" and "_thnn_differentiable_lstm_cell_backward" functions from aten/src/ATen/native/RNN.cpp. As far as I understand, they are not needed for nn.LSTM CPU execution. For lstm_cell, projections don't bring any real benefit: if the cell is used separately, the projection can easily be added in Python (see the first sketch after this list). For "_thnn_differentiable_lstm_cell_backward", I'm not sure where exactly that function is used, so I disabled projections there as well for now. Please let me know if I should change that.
3. I added a check that projections are not supported for quantized LSTMs to the quantized_lstm_<data/input> functions, but I didn't add any checks to the LSTMCell code. Since I disabled projections in the "lstm_cell" function, they should not be reachable for quantized models through any API other than quantized_lstm_<data/input>. Please let me know if that's incorrect and I will add checks in the other places.
4. Projections are not supported for CuDNN versions < 7.1.2. Should I add a check for the CuDNN version and disable projections in that case? If so, what would be the best way to do that?
5. Currently I added the projection weight as the last weight, so the layout is "w_ih, w_hh, b_ih, b_hh, w_hr" (see the second sketch after this list). This breaks the assumption that biases come after weights, so I had to add extra if-branches in several places. The alternative would be a "w_ih, w_hh, w_hr, b_ih, b_hh" layout, which keeps that assumption true. But then I would need to split the loop in the get_parameters function from aten/src/ATen/native/cudnn/RNN.cpp, and in some cases I would still need to insert an "undefined" tensor in the 3rd position, because we get all 5 weights from CuDNN most of the time. So I'm not sure which way is better; let me know if you think I should switch to the weights-then-biases layout.
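
For point 2, here is a minimal sketch of what "adding the projection in Python" could look like when the cell is used separately. It mirrors the non-fused LSTMCell math from aten/src/ATen/native/RNN.cpp; the function name and all shapes are illustrative, not part of this PR:

    import torch

    def lstm_proj_step(x, h, c, w_ih, w_hh, b_ih, b_hh, w_hr):
        # Same math as the non-fused LSTMCell path in RNN.cpp:
        # four gates (i, f, g, o), then the w_hr projection on the hidden state.
        gates = x @ w_ih.t() + b_ih + h @ w_hh.t() + b_hh
        i, f, g, o = gates.chunk(4, dim=1)
        i, f, o = i.sigmoid(), f.sigmoid(), o.sigmoid()
        g = g.tanh()
        c_next = f * c + i * g
        h_next = o * c_next.tanh()
        h_next = h_next @ w_hr.t()  # projection: hidden_size -> proj_size
        return h_next, c_next

    # Note that the recurrent weight w_hh consumes the *projected* state,
    # so h is proj_size-dimensional while c keeps hidden_size.
    batch, input_size, hidden_size, proj_size = 3, 10, 20, 5
    w_ih = torch.randn(4 * hidden_size, input_size)
    w_hh = torch.randn(4 * hidden_size, proj_size)
    b_ih = torch.zeros(4 * hidden_size)
    b_hh = torch.zeros(4 * hidden_size)
    w_hr = torch.randn(proj_size, hidden_size)

    x = torch.randn(batch, input_size)
    h = torch.zeros(batch, proj_size)
    c = torch.zeros(batch, hidden_size)
    h, c = lstm_proj_step(x, h, c, w_ih, w_hh, b_ih, b_hh, w_hr)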
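
For point 5, a quick way to inspect the resulting "w_ih, w_hh, b_ih, b_hh, w_hr" per-layer layout from Python on a build that includes this change (shapes assume input_size=10, hidden_size=20, proj_size=5; the commented output is what I expect):

    import torch.nn as nn

    lstm = nn.LSTM(input_size=10, hidden_size=20, proj_size=5, num_layers=1)
    for name, p in lstm.named_parameters():
        print(name, tuple(p.shape))
    # weight_ih_l0 (80, 10)
    # weight_hh_l0 (80, 5)   <- recurrent weights act on the projected state
    # bias_ih_l0   (80,)
    # bias_hh_l0   (80,)
    # weight_hr_l0 (5, 20)   <- projection weight, appended last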

Pull Request resolved: #47725

Reviewed By: zou3519

Differential Revision: D25449794

Pulled By: ngimel

fbshipit-source-id: fe6ce59e481d1f5fd861a8ff7fa13d1affcedb0c
Igor Gitman authored and facebook-github-bot committed Dec 16, 2020
1 parent 48d1ad1 commit 1b6d18a
Showing 17 changed files with 984 additions and 234 deletions.
26 changes: 20 additions & 6 deletions aten/src/ATen/cudnn/AutocastRNN.cpp
@@ -27,6 +27,7 @@ _cudnn_rnn_cast_reflatten(const Tensor & input,
const c10::optional<Tensor>& cx,
int64_t mode,
int64_t hidden_size,
int64_t proj_size,
int64_t num_layers,
bool batch_first,
double dropout,
@@ -43,10 +44,18 @@ _cudnn_rnn_cast_reflatten(const Tensor & input,
// weight_stride0 is the number of weight tensors per layer and direction, as seen by model.parameters().
// If bias is enabled, there are 4 such tensors (ih and hh weights, ih and hh biases).
// If bias is not enabled, there are 2 (ih and hh weights).
// This organization holds for all rnn types (RNN, GRU, and LSTM).
TORCH_INTERNAL_ASSERT((weight_stride0 == 2) || (weight_stride0 == 4),
"weight_stride0 must be 2 (if no bias) or 4 (if bias). Received ",
weight_stride0);
// This organization holds for all rnn types (RNN, GRU, and LSTM). If an LSTM with projections is
// used, an additional hr weight is added.
if (proj_size > 0) {
TORCH_INTERNAL_ASSERT((weight_stride0 == 3) || (weight_stride0 == 5),
"weight_stride0 must be 3 (if no bias) or 5 (if bias) for LSTM with projections. Received ",
weight_stride0);
} else {
TORCH_INTERNAL_ASSERT((weight_stride0 == 2) || (weight_stride0 == 4),
"weight_stride0 must be 2 (if no bias) or 4 (if bias). Received ",
weight_stride0);
}


Tensor weight_buf, redispatch_weight_buf;
std::vector<Tensor> redispatch_weight;
@@ -65,23 +74,27 @@ _cudnn_rnn_cast_reflatten(const Tensor & input,
// Casts weight tensors to FP16 and ensures all weights for all layers are views into a large flat buffer,
// with the right locations and layouts expected by cudnn.
// This is (and should be) autograd-exposed.
bool include_bias = true;
if (weight_stride0 == 2 || (weight_stride0 == 3 && proj_size > 0)) {
include_bias = false;
}
std::tie(redispatch_weight_buf, redispatch_weight) =
at::native::cudnn_rnn::copy_weights_to_flat_buf_views(
weight,
weight_stride0,
input.size(-1),
mode,
hidden_size,
proj_size,
num_layers,
batch_first,
bidirectional,
/*flat_buf_datatype=*/at::native::getCudnnDataTypeFromScalarType(at::kHalf), // could just hardcode CUDNN_DATA_HALF
/*flat_buf_options=*/weight[0].options().dtype(at::kHalf),
/*set_orig_weights_to_flat_buf=*/false,
/*allow_type_change=*/true,
/*include_bias=*/weight_stride0 == 4);
/*include_bias=*/include_bias);
}

return at::_cudnn_rnn(
cached_cast(at::kHalf, input),
needs_cast_and_flatten ? TensorList(redispatch_weight) : weight,
Expand All @@ -91,6 +104,7 @@ _cudnn_rnn_cast_reflatten(const Tensor & input,
cached_cast(at::kHalf, cx),
mode,
hidden_size,
proj_size,
num_layers,
batch_first,
dropout,
14 changes: 11 additions & 3 deletions aten/src/ATen/cudnn/Descriptors.h
@@ -41,7 +41,7 @@ static inline void fixSizeOneDimStride(int dim, const int *size, int *stride, bo
int64_t z = 1;
int index = 0;
std::vector<int> permutation(dim);

if (nhwc) {
permutation[index++] = 1;
}
@@ -244,10 +244,11 @@ struct TORCH_CUDA_API RNNDescriptor
&cudnnDestroyRNNDescriptor>
{
DropoutDescriptor dropout_desc_;
void set(cudnnHandle_t handle, int hidden_size, int num_layers, DropoutDescriptor&& dropout_desc,
void set(cudnnHandle_t handle, int hidden_size, int proj_size, int num_layers, DropoutDescriptor&& dropout_desc,
cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) {
dropout_desc_ = std::move(dropout_desc);

AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
handle,
mut_desc(),
Expand All @@ -259,12 +260,19 @@ struct TORCH_CUDA_API RNNDescriptor
mode,
algo,
datatype));
if (proj_size != 0) {
AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers(
handle,
/*rnnDesc=*/mut_desc(),
/*recProjSize=*/proj_size,
/*outProjSize=*/0));
}
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7) {
if (input_type == CUDNN_DATA_HALF) {
cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
}
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) {
cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH);
}
85 changes: 66 additions & 19 deletions aten/src/ATen/native/RNN.cpp
@@ -68,6 +68,14 @@ using CellParamsSerializationType = std::tuple<
struct CellParamsBase : torch::CustomClassHolder {
virtual Tensor matmul_ih(const Tensor& input) const = 0;
virtual Tensor matmul_hh(const Tensor& h) const = 0;
// By default this does nothing. CellParams overrides it
// to define the correct behavior for LSTMs with projections.
// The function is not pure virtual because the default
// implementation keeps all cell params that don't support
// projections working correctly (e.g. the QuantizedCellParams variations).
virtual Tensor matmul_hr(const Tensor& h) const {
return h;
}
virtual Tensor linear_ih(const Tensor& input_ih) const = 0;
virtual Tensor linear_hh(const Tensor& input_hh) const = 0;

@@ -79,26 +87,35 @@ struct CellParamsBase : torch::CustomClassHolder {

// Pretty much all cells we support take the same set of arguments, but threading those
// 4 arguments manually is really annoying. Their lifetime is externally managed, so we only
// pass this struct of references around.
// pass this struct of references around. LSTMs with projections have a 5th argument, w_hr; for all
// other models it is always undefined.
struct CellParams : public CellParamsBase {
CellParams(
const Tensor& _w_ih,
const Tensor& _w_hh,
const Tensor& _b_ih,
const Tensor& _b_hh)
: w_ih(_w_ih), w_hh(_w_hh), b_ih_(_b_ih), b_hh_(_b_hh){};
const Tensor& _b_hh,
const Tensor& _w_hr)
: w_ih(_w_ih), w_hh(_w_hh), b_ih_(_b_ih), b_hh_(_b_hh), w_hr(_w_hr) {};

const Tensor& w_ih;
const Tensor& w_hh;
const Tensor& b_ih_; /* optional */
const Tensor& b_hh_; /* optional */
const Tensor& w_hr; /* only defined for LSTMs with projections */

Tensor matmul_ih(const Tensor& input) const override {
return at::matmul(input, w_ih.t());
}
Tensor matmul_hh(const Tensor& h) const override {
return at::matmul(h, w_hh.t());
}
Tensor matmul_hr(const Tensor& h) const override {
if (w_hr.defined()) {
return at::matmul(h, w_hr.t());
}
return h;
}
Tensor linear_ih(const Tensor& input) const override {
return at::linear(input, w_ih, b_ih_);
}
@@ -468,6 +485,9 @@ struct QRNNCellParamsWrapper {
Tensor matmul_hh(const Tensor& h) const {
return param_->matmul_hh(h);
}
Tensor matmul_hr(const Tensor& h) const {
return param_->matmul_hr(h);
}
Tensor linear_ih(const Tensor& input) const {
return param_->linear_ih(input);
}
@@ -509,18 +529,32 @@ static std::vector<T> unpair_vec(std::vector<pair_of<T>>&& vals) {
}

// Parses a flat list of parameter tensors into a list of CellParams
static std::vector<CellParams> gather_params(TensorList params, bool has_biases) {
static std::vector<CellParams> gather_params(TensorList params, bool has_biases, bool has_projections = false) {
static at::Tensor undefined;
std::vector<CellParams> result;
if (has_biases) {
TORCH_CHECK(params.size() % 4 == 0, "got an incorrect number of RNN parameters");
for (size_t i = 0; i < params.size(); i += 4) {
result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3]);
if (has_projections) {
TORCH_CHECK(params.size() % 5 == 0, "got an incorrect number of RNN parameters");
for (size_t i = 0; i < params.size(); i += 5) {
result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3], params[i + 4]);
}
} else {
TORCH_CHECK(params.size() % 4 == 0, "got an incorrect number of RNN parameters");
for (size_t i = 0; i < params.size(); i += 4) {
result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3], undefined);
}
}
} else {
TORCH_CHECK(params.size() % 2 == 0, "got an incorrect number of RNN parameters");
for (size_t i = 0; i < params.size(); i += 2) {
result.emplace_back(params[i], params[i + 1], undefined, undefined);
if (has_projections) {
TORCH_CHECK(params.size() % 3 == 0, "got an incorrect number of RNN parameters");
for (size_t i = 0; i < params.size(); i += 3) {
result.emplace_back(params[i], params[i + 1], undefined, undefined, params[i + 2]);
}
} else {
TORCH_CHECK(params.size() % 2 == 0, "got an incorrect number of RNN parameters");
for (size_t i = 0; i < params.size(); i += 2) {
result.emplace_back(params[i], params[i + 1], undefined, undefined, undefined);
}
}
}
return result;
@@ -702,8 +736,10 @@ struct LSTMCell : Cell<std::tuple<Tensor, Tensor>, cell_params> {
auto hgates = params.matmul_hh(hx);
auto result = at::_thnn_fused_lstm_cell(
igates, hgates, cx, params.b_ih(), params.b_hh());
// apply the projection if w_hr is defined
auto hy = params.matmul_hr(std::get<0>(result));
// Slice off the workspace argument (it's needed only for AD).
return std::make_tuple(std::move(std::get<0>(result)), std::move(std::get<1>(result)));
return std::make_tuple(std::move(hy), std::move(std::get<1>(result)));
}

const auto gates = params.linear_hh(hx).add_(
@@ -715,6 +751,7 @@ struct LSTMCell : Cell<std::tuple<Tensor, Tensor>, cell_params> {
auto outgate = chunked_gates[3].sigmoid_();
auto cy = (forgetgate * cx).add_(ingate * cellgate);
auto hy = outgate * cy.tanh();
hy = params.matmul_hr(hy);
return std::make_tuple(std::move(hy), std::move(cy));
}

@@ -1404,16 +1441,18 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
num_layers, dropout_p, train, bidirectional, batch_first);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}

// if cells are of different size, that means projections are used
bool has_projections = (hx[0].size(2) != hx[1].size(2));
if (use_miopen(_input, dropout_p)) {
TORCH_CHECK(!has_projections, "LSTM with projections is not supported with MIOpen");
Tensor output, hy, cy;
lstm_miopen_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases,
num_layers, dropout_p, train, bidirectional, batch_first);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}
check_attributes(_input, _params, hx);
auto input = batch_first ? _input.transpose(0, 1) : _input;
auto params = gather_params(_params, has_biases);
auto params = gather_params(_params, has_biases, has_projections);
auto results = _lstm_impl<FullLayer, FullBidirectionalLayer>(
input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional);
if (batch_first) {
@@ -1433,16 +1472,18 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
_params, has_biases, num_layers, dropout_p, train, bidirectional);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}

// if cells are of different size, that means projections are used
bool has_projections = (hx[0].size(2) != hx[1].size(2));
if (use_miopen(data, dropout_p)) {
TORCH_CHECK(!has_projections, "LSTM with projections is not supported with MIOpen");
Tensor output, hy, cy;
lstm_packed_miopen_stub(data.device().type(), output, hy, cy, data, batch_sizes, hx,
_params, has_biases, num_layers, dropout_p, train, bidirectional);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}

PackedSequence input { data, batch_sizes };
auto params = gather_params(_params, has_biases);
auto params = gather_params(_params, has_biases, has_projections);
auto result = _lstm_impl<PackedLayer, PackedBidirectionalLayer>(
input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional);
auto & packed_output = std::get<0>(result);
@@ -1455,7 +1496,8 @@ std::tuple<Tensor, Tensor> lstm_cell(
const Tensor& input, TensorList hx,
const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
TORCH_CHECK(hx.size() == 2, "lstm_cell expects two hidden states");
return LSTMCell<CellParams>{}(input, std::make_tuple(hx[0], hx[1]), CellParams{w_ih, w_hh, b_ih, b_hh});
static at::Tensor undefined;
return LSTMCell<CellParams>{}(input, std::make_tuple(hx[0], hx[1]), CellParams{w_ih, w_hh, b_ih, b_hh, undefined});
}

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
@@ -1552,19 +1594,22 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_differentiable_gru_cell
Tensor gru_cell(
const Tensor& input, const Tensor& hx,
const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
return GRUCell<CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
static at::Tensor undefined;
return GRUCell<CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh, undefined});
}

Tensor rnn_tanh_cell(
const Tensor& input, const Tensor& hx,
const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
return SimpleCell<tanh_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
static at::Tensor undefined;
return SimpleCell<tanh_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh, undefined});
}

Tensor rnn_relu_cell(
const Tensor& input, const Tensor& hx,
const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
return SimpleCell<relu_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
static at::Tensor undefined;
return SimpleCell<relu_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh, undefined});
}

// Quantized implementations
@@ -1592,6 +1637,7 @@ std::tuple<Tensor, Tensor, Tensor> quantized_lstm_input(
params.emplace_back(static_cast<c10::intrusive_ptr<CellParamsBase>>(param));
}
TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
TORCH_CHECK(hx[0].size(2) == hx[1].size(2), "quantized LSTM with projections is not supported");
auto result_dtype = dtype.has_value() ? dtype.value() : at::kChar;
auto input = batch_first ? _input.transpose(0, 1) : _input;
TORCH_CHECK(has_biases, "quantized LSTM requires biases");
Expand Down Expand Up @@ -1685,6 +1731,7 @@ std::tuple<Tensor, Tensor, Tensor> quantized_lstm_data(
params.emplace_back(static_cast<c10::intrusive_ptr<CellParamsBase>>(param));
}
TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
TORCH_CHECK(hx[0].size(2) == hx[1].size(2), "quantized LSTM with projections is not supported");

auto result_dtype = dtype.has_value() ? dtype.value() : at::kChar;

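As an end-to-end sketch of what this commit enables (a minimal example, assuming the nn.LSTM proj_size argument wired up in this stack; on GPU this dispatches to the CuDNN path, which needs CuDNN >= 7.1.2 as discussed above):

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    lstm = nn.LSTM(input_size=10, hidden_size=20, proj_size=5,
                   num_layers=2, batch_first=True).to(device)

    x = torch.randn(4, 7, 10, device=device)  # (batch, seq, input_size)
    output, (h_n, c_n) = lstm(x)
    print(output.shape)  # (4, 7, 5)  -- last dim is proj_size
    print(h_n.shape)     # (2, 4, 5)  -- hidden state is projected
    print(c_n.shape)     # (2, 4, 20) -- cell state keeps hidden_size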
