Adding support for CuDNN-based LSTM with projections #47725

Closed · wants to merge 61 commits

Changes from 39 commits (61 commits total)
2aa911e - Expose proj_dim parameter (Oct 21, 2020)
a6c987c - Fix int -> int_64t issue (Oct 21, 2020)
2a5fdce - Start exposing proj_size throughout the code base (Oct 22, 2020)
ee74637 - Exposed through get_weight_buf and flatten_weights (Oct 22, 2020)
c72c9ce - Expose proj_size through _cudnn_rnn (Oct 22, 2020)
cd3acc2 - Update get_parameters to work with projections (Oct 22, 2020)
54a9849 - Remove redundant proj_size, fix get_expected_data_ptrs (Oct 22, 2020)
24e47c2 - Fix try_get_weight_buf (Oct 22, 2020)
dab693d - Fix _cudnn_rnn function (Oct 22, 2020)
4d07a2a - Fix _cudnn_rnn_backward functions (Oct 22, 2020)
5481d77 - Add correct hx creation on python side (Oct 22, 2020)
d0dc6db - Fix incorrect projection layers initialization (Oct 23, 2020)
956ed26 - Correct weight initialization on python side (Oct 23, 2020)
4d0feae - Fix output size issue (Oct 23, 2020)
134ad1d - Fix multi-layer projections issue (Oct 27, 2020)
9557035 - Expose proj_size in setstate and extra_repr (Oct 28, 2020)
2da10ee - Fix AutocastRNN to accept models with projections (Nov 3, 2020)
b63facc - Merge branch 'master' into cudnn_projections (Nov 4, 2020)
b142a0a - Add test TODOs, add check for non-cudnn code (Nov 4, 2020)
44d0c67 - Fix incorrect hidden states init for LSTM (Nov 9, 2020)
da32e77 - Fix error for RNN/GRU of accessing undefined cx (Nov 9, 2020)
2a2c3f9 - Fix no-bias projections lstm for fp32 (Nov 10, 2020)
de5519d - Fix no-bias projection fp16 case (Nov 10, 2020)
06fadd2 - Add check for rnn/gru, add initial unit tests (Nov 10, 2020)
09d0fe1 - Add proj_size to test_variable_sequence (Nov 10, 2020)
ae84ba6 - Add projections to rnn_weight_norm test (Nov 10, 2020)
8b0544d - Add projections to cudnn_weight_format test (Nov 10, 2020)
db24db7 - Add projections to cudnn_weight_tying test (Nov 10, 2020)
515341c - Add projections to rnn_args_check test (Nov 10, 2020)
337b0e7 - Add projections to rnn_check_device test (Nov 10, 2020)
ccb5054 - Expose cudnn with projections on CPU (Nov 10, 2020)
527f62c - Add cpu_vs_gpu projections test (Nov 10, 2020)
4542031 - Add weight norm test (Nov 10, 2020)
28986a8 - Remove TODOs (Nov 10, 2020)
3bbd503 - Code clean up (Nov 10, 2020)
03e386e - Revert empty changes (Nov 11, 2020)
336732f - Revert miopen style change (Nov 11, 2020)
8842c13 - Disable projections for quantized LSTMs (Nov 11, 2020)
46f3172 - Merge branch 'master' into cudnn_projections (Nov 11, 2020)
2aedcf6 - Fix linting errors (Nov 11, 2020)
f830048 - Remove .cuda() call from proj hidden_state test (Nov 11, 2020)
dfae3ff - Add projections documentation to nn.LSTM (Nov 11, 2020)
8113922 - Merge branch 'master' into cudnn_projections (Nov 11, 2020)
b723ec5 - Remove cuda placement from proj initial_hidden_state test (Nov 11, 2020)
e2206b1 - Merge branch 'master' into cudnn_projections (Nov 24, 2020)
54cca58 - Address PR comments (Nov 24, 2020)
107bb28 - Expose projections in c++ API (Nov 24, 2020)
9038cfc - Add c++ integration tests for projections (Nov 24, 2020)
821f89b - Add check output size test (Nov 24, 2020)
a593687 - Add more projection tests to c++ api (Nov 24, 2020)
47465c9 - Merge branch 'master' into cudnn_projections (Dec 4, 2020)
be413cf - Add correct type hints (Dec 4, 2020)
af359cf - Change number of ops in caffe2 onnx tests (Dec 9, 2020)
ff01240 - Add unimplemented call for LSTMs with projections in onnx (Dec 9, 2020)
8f95f8e - Add onnx test to check projections not supported (Dec 9, 2020)
bb87f91 - Merge branch 'master' into cudnn_projections (Dec 9, 2020)
4d7d6ca - Add _cudnn_rnn functions to allow_list for bc (Dec 9, 2020)
6bb8bff - Adjust tests to work on rocm (Dec 11, 2020)
6834b59 - Merge branch 'master' into cudnn_projections (Dec 11, 2020)
42f1563 - Add more precise check for runtime error in tests (Dec 14, 2020)
f696eda - Merge branch 'master' into cudnn_projections (Dec 14, 2020)
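For context, the net effect of this commit history is an LSTM that accepts a proj_size option, including through the C++ frontend (commit 107bb28). A minimal usage sketch, assuming the option spelling suggested by the commit messages; the key point is that the projection replaces hidden_size in the output and hidden state:

#include <torch/torch.h>
#include <iostream>

int main() {
  // LSTM with a recurrent projection: the cell runs with hidden_size = 20
  // internally, but h_t (and therefore the output) is projected down to 5.
  auto lstm = torch::nn::LSTM(
      torch::nn::LSTMOptions(/*input_size=*/10, /*hidden_size=*/20)
          .num_layers(2)
          .proj_size(5));

  auto input = torch::randn({7, 3, 10});     // (seq_len, batch, input_size)
  auto out = lstm->forward(input);
  auto output = std::get<0>(out);            // [7, 3, 5]  -- last dim is proj_size
  auto h_n = std::get<0>(std::get<1>(out));  // [2, 3, 5]  -- projected hidden state
  auto c_n = std::get<1>(std::get<1>(out));  // [2, 3, 20] -- cell state keeps hidden_size
  std::cout << output.sizes() << " " << h_n.sizes() << " " << c_n.sizes() << "\n";
}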
26 changes: 20 additions & 6 deletions aten/src/ATen/cudnn/AutocastRNN.cpp
@@ -27,6 +27,7 @@ _cudnn_rnn_cast_reflatten(const Tensor & input,
const c10::optional<Tensor>& cx,
int64_t mode,
int64_t hidden_size,
int64_t proj_size,
int64_t num_layers,
bool batch_first,
double dropout,
@@ -43,10 +44,18 @@ _cudnn_rnn_cast_reflatten(const Tensor & input,
// weight_stride0 is the number of weight tensors per layer and direction, as seen by model.parameters().
// If bias is enabled, there are 4 such tensors (ih and hh weights, ih and hh biases).
// If bias is not enabled, there are 2 (ih and hh weights).
// This organization holds for all rnn types (RNN, GRU, and LSTM).
TORCH_INTERNAL_ASSERT((weight_stride0 == 2) || (weight_stride0 == 4),
"weight_stride0 must be 2 (if no bias) or 4 (if bias). Received ",
weight_stride0);
// This organization holds for all rnn types (RNN, GRU, and LSTM). If LSTM with projections is
// used, additional hr weight is added.
if (proj_size > 0 && proj_size != hidden_size) {
TORCH_INTERNAL_ASSERT((weight_stride0 == 3) || (weight_stride0 == 5),
"weight_stride0 must be 3 (if no bias) or 5 (if bias) for LSTM with projections. Received ",
weight_stride0);
} else {
TORCH_INTERNAL_ASSERT((weight_stride0 == 2) || (weight_stride0 == 4),
"weight_stride0 must be 2 (if no bias) or 4 (if bias). Received ",
weight_stride0);
}


Tensor weight_buf, redispatch_weight_buf;
std::vector<Tensor> redispatch_weight;
@@ -65,23 +74,27 @@ _cudnn_rnn_cast_reflatten(const Tensor & input,
// Casts weight tensors to FP16 and ensures all weights for all layers are views into a large flat buffer,
// with the right locations and layouts expected by cudnn.
// This is (and should be) autograd-exposed.
bool include_bias = true;
if (weight_stride0 == 3 || weight_stride0 == 2) {
include_bias = false;
}
std::tie(redispatch_weight_buf, redispatch_weight) =
at::native::cudnn_rnn::copy_weights_to_flat_buf_views(
weight,
weight_stride0,
input.size(-1),
mode,
hidden_size,
proj_size,
num_layers,
batch_first,
bidirectional,
/*flat_buf_datatype=*/at::native::getCudnnDataTypeFromScalarType(at::kHalf), // could just hardcode CUDNN_DATA_HALF
/*flat_buf_options=*/weight[0].options().dtype(at::kHalf),
/*set_orig_weights_to_flat_buf=*/false,
/*allow_type_change=*/true,
/*include_bias=*/weight_stride0 == 4);
/*include_bias=*/include_bias);
}

return at::_cudnn_rnn(
cached_cast(at::kHalf, input),
needs_cast_and_flatten ? TensorList(redispatch_weight) : weight,
@@ -91,6 +104,7 @@ _cudnn_rnn_cast_reflatten(const Tensor & input,
cached_cast(at::kHalf, cx),
mode,
hidden_size,
proj_size,
num_layers,
batch_first,
dropout,
…
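To make the stride-5 case asserted above concrete, a small sketch (not part of the diff) of the per-layer parameter layout it corresponds to; the weight_hr name follows the naming used elsewhere in the PR:

#include <torch/torch.h>
#include <iostream>

int main() {
  // With bias and projections, each layer/direction exposes five tensors,
  // matching the weight_stride0 == 5 branch:
  //   weight_ih_l0 [4*hidden, input] = [64, 8]
  //   weight_hh_l0 [4*hidden, proj]  = [64, 4]  (hh weights act on projected h)
  //   bias_ih_l0   [64], bias_hh_l0  [64]
  //   weight_hr_l0 [proj, hidden]    = [4, 16]  (the additional hr projection)
  auto lstm = torch::nn::LSTM(
      torch::nn::LSTMOptions(/*input_size=*/8, /*hidden_size=*/16).proj_size(4));
  for (const auto& p : lstm->named_parameters()) {
    std::cout << p.key() << ": " << p.value().sizes() << "\n";
  }
}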
10 changes: 9 additions & 1 deletion aten/src/ATen/cudnn/Descriptors.h
@@ -244,10 +244,11 @@ struct TORCH_CUDA_API RNNDescriptor
&cudnnDestroyRNNDescriptor>
{
DropoutDescriptor dropout_desc_;
void set(cudnnHandle_t handle, int hidden_size, int num_layers, DropoutDescriptor&& dropout_desc,
void set(cudnnHandle_t handle, int hidden_size, int proj_size, int num_layers, DropoutDescriptor&& dropout_desc,
cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) {
dropout_desc_ = std::move(dropout_desc);

AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
handle,
mut_desc(),
@@ -259,6 +260,13 @@ struct TORCH_CUDA_API RNNDescriptor
mode,
algo,
datatype));
if (proj_size != 0 && proj_size != hidden_size) {
AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers(
handle,
/*rnnDesc=*/mut_desc(),
/*recProjSize=*/proj_size,
/*outProjSize=*/0));
}
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7) {
if (input_type == CUDNN_DATA_HALF) {
…
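For reference, an illustrative sketch of the cuDNN v7-era call sequence this descriptor wraps (status checks omitted; the helper name is hypothetical). cudnnSetRNNProjectionLayers takes the recurrent projection size, and an output projection size of 0 leaves the output projection disabled:

#include <cudnn.h>

// Sketch only: configure an LSTM descriptor with a recurrent projection.
void configure_projected_lstm(cudnnHandle_t handle,
                              cudnnDropoutDescriptor_t dropout_desc,
                              int hidden_size, int proj_size, int num_layers) {
  cudnnRNNDescriptor_t rnn_desc;
  cudnnCreateRNNDescriptor(&rnn_desc);
  cudnnSetRNNDescriptor_v6(handle, rnn_desc, hidden_size, num_layers,
                           dropout_desc, CUDNN_LINEAR_INPUT,
                           CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
                           CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT);
  // Only meaningful when proj_size is strictly smaller than hidden_size,
  // mirroring the (proj_size != 0 && proj_size != hidden_size) guard above.
  if (proj_size > 0 && proj_size != hidden_size) {
    cudnnSetRNNProjectionLayers(handle, rnn_desc,
                                /*recProjSize=*/proj_size,
                                /*outProjSize=*/0);
  }
  cudnnDestroyRNNDescriptor(rnn_desc);
}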
82 changes: 63 additions & 19 deletions aten/src/ATen/native/RNN.cpp
@@ -68,6 +68,11 @@ using CellParamsSerializationType = std::tuple<
struct CellParamsBase : torch::CustomClassHolder {
virtual Tensor matmul_ih(const Tensor& input) const = 0;
virtual Tensor matmul_hh(const Tensor& h) const = 0;
// by default doing nothing. CellParams will override this
// to define correct behavior for LSTMs with projections
virtual Tensor matmul_hr(const Tensor& h) const {
Collaborator: is there a reason you are not making it pure virtual, like all other functions? If so, please explain it in the comment

Author: The reason is that it's easier to define it here, so that this default behavior can be directly reused in all functions that don't support projections (e.g. quantized cells). I added an explanation for this in the code comments as well.

return h;
}
virtual Tensor linear_ih(const Tensor& input_ih) const = 0;
virtual Tensor linear_hh(const Tensor& input_hh) const = 0;

@@ -79,26 +84,35 @@ struct CellParamsBase : torch::CustomClassHolder {

// Pretty much all cells we support take the same set of arguments, but threading those
// 4 arguments manually is really annoying. Their lifetime is externally managed, so we only
// pass this struct of references around.
// pass this struct of references around. LSTMs with projections have 5th argument w_hr, for all
// other models it's always going to be undefined.
struct CellParams : public CellParamsBase {
CellParams(
const Tensor& _w_ih,
const Tensor& _w_hh,
const Tensor& _b_ih,
const Tensor& _b_hh)
: w_ih(_w_ih), w_hh(_w_hh), b_ih_(_b_ih), b_hh_(_b_hh){};
const Tensor& _b_hh,
const Tensor& _w_hr)
: w_ih(_w_ih), w_hh(_w_hh), b_ih_(_b_ih), b_hh_(_b_hh), w_hr(_w_hr) {};

const Tensor& w_ih;
const Tensor& w_hh;
const Tensor& b_ih_; /* optional */
const Tensor& b_hh_; /* optional */
const Tensor& w_hr; /* only defined for LSTMs with projections */

Tensor matmul_ih(const Tensor& input) const override {
return at::matmul(input, w_ih.t());
}
Tensor matmul_hh(const Tensor& h) const override {
return at::matmul(h, w_hh.t());
}
Tensor matmul_hr(const Tensor& h) const override {
if (w_hr.defined()) {
return at::matmul(h, w_hr.t());
}
return h;
}
Tensor linear_ih(const Tensor& input) const override {
return at::linear(input, w_ih, b_ih_);
}
@@ -468,6 +482,9 @@ struct QRNNCellParamsWrapper {
Tensor matmul_hh(const Tensor& h) const {
return param_->matmul_hh(h);
}
Tensor matmul_hr(const Tensor& h) const {
return param_->matmul_hr(h);
}
Tensor linear_ih(const Tensor& input) const {
return param_->linear_ih(input);
}
@@ -509,18 +526,32 @@ static std::vector<T> unpair_vec(std::vector<pair_of<T>>&& vals) {
}

// Parses a flat list of parameter tensors into a list of CellParams
static std::vector<CellParams> gather_params(TensorList params, bool has_biases) {
static std::vector<CellParams> gather_params(TensorList params, bool has_biases, bool has_projections = false) {
static at::Tensor undefined;
std::vector<CellParams> result;
if (has_biases) {
TORCH_CHECK(params.size() % 4 == 0, "got an incorrect number of RNN parameters");
for (size_t i = 0; i < params.size(); i += 4) {
result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3]);
if (has_projections) {
TORCH_CHECK(params.size() % 5 == 0, "got an incorrect number of RNN parameters");
for (size_t i = 0; i < params.size(); i += 5) {
result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3], params[i + 4]);
}
} else {
TORCH_CHECK(params.size() % 4 == 0, "got an incorrect number of RNN parameters");
for (size_t i = 0; i < params.size(); i += 4) {
result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3], undefined);
}
}
} else {
TORCH_CHECK(params.size() % 2 == 0, "got an incorrect number of RNN parameters");
for (size_t i = 0; i < params.size(); i += 2) {
result.emplace_back(params[i], params[i + 1], undefined, undefined);
if (has_projections) {
TORCH_CHECK(params.size() % 3 == 0, "got an incorrect number of RNN parameters");
for (size_t i = 0; i < params.size(); i += 3) {
result.emplace_back(params[i], params[i + 1], undefined, undefined, params[i + 2]);
}
} else {
TORCH_CHECK(params.size() % 2 == 0, "got an incorrect number of RNN parameters");
for (size_t i = 0; i < params.size(); i += 2) {
result.emplace_back(params[i], params[i + 1], undefined, undefined, undefined);
}
}
}
return result;
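To make the indexing above concrete, a sketch (names and shapes illustrative) of one layer's slice of the flat parameter list in the has_biases and has_projections case, i.e. the stride-5 branch:

#include <torch/torch.h>
#include <vector>

// Builds one layer's parameters in the order gather_params consumes them
// when has_biases == true and has_projections == true (stride 5).
std::vector<torch::Tensor> layer_params(int64_t input_size,
                                        int64_t hidden_size,
                                        int64_t proj_size) {
  return {
      torch::randn({4 * hidden_size, input_size}),  // params[i]     -> w_ih
      torch::randn({4 * hidden_size, proj_size}),   // params[i + 1] -> w_hh
      torch::randn({4 * hidden_size}),              // params[i + 2] -> b_ih
      torch::randn({4 * hidden_size}),              // params[i + 3] -> b_hh
      torch::randn({proj_size, hidden_size}),       // params[i + 4] -> w_hr
  };
}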
@@ -702,8 +733,10 @@ struct LSTMCell : Cell<std::tuple<Tensor, Tensor>, cell_params> {
auto hgates = params.matmul_hh(hx);
auto result = at::_thnn_fused_lstm_cell(
igates, hgates, cx, params.b_ih(), params.b_hh());
// applying projections if w_hr is defined
auto hy = params.matmul_hr(std::get<0>(result));
// Slice off the workspace argument (it's needed only for AD).
return std::make_tuple(std::move(std::get<0>(result)), std::move(std::get<1>(result)));
return std::make_tuple(std::move(hy), std::move(std::get<1>(result)));
}

const auto gates = params.linear_hh(hx).add_(
@@ -715,6 +748,7 @@ struct LSTMCell : Cell<std::tuple<Tensor, Tensor>, cell_params> {
auto outgate = chunked_gates[3].sigmoid_();
auto cy = (forgetgate * cx).add_(ingate * cellgate);
auto hy = outgate * cy.tanh();
hy = params.matmul_hr(hy);
return std::make_tuple(std::move(hy), std::move(cy));
}
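In equations, the cell above is the standard LSTM update with the projection applied as a final matmul (standard formulation; when w_hr is undefined the last line reduces to the usual h_t = o_t ⊙ tanh(c_t)):

$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= W_{hr}\,\bigl(o_t \odot \tanh(c_t)\bigr)
\end{aligned}
$$

With projections enabled, h_{t-1} has proj_size entries, so the recurrent weights W_{h*} have proj_size columns while c_t keeps hidden_size entries.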

@@ -1404,16 +1438,18 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
num_layers, dropout_p, train, bidirectional, batch_first);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}

// if cells are of different size, that means projections are used
bool has_projections = (hx[0].size(2) != hx[1].size(2));
if (use_miopen(_input, dropout_p)) {
TORCH_CHECK(!has_projections, "LSTM with projections is not supported with MIOpen");
Tensor output, hy, cy;
lstm_miopen_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases,
num_layers, dropout_p, train, bidirectional, batch_first);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}
check_attributes(_input, _params, hx);
auto input = batch_first ? _input.transpose(0, 1) : _input;
auto params = gather_params(_params, has_biases);
auto params = gather_params(_params, has_biases, has_projections);
auto results = _lstm_impl<FullLayer, FullBidirectionalLayer>(
input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional);
if (batch_first) {
@@ -1433,16 +1469,18 @@
_params, has_biases, num_layers, dropout_p, train, bidirectional);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}

// if cells are of different size, that means projections are used
bool has_projections = (hx[0].size(2) != hx[1].size(2));
if (use_miopen(data, dropout_p)) {
TORCH_CHECK(!has_projections, "LSTM with projections is not supported with MIOpen");
Tensor output, hy, cy;
lstm_packed_miopen_stub(data.device().type(), output, hy, cy, data, batch_sizes, hx,
_params, has_biases, num_layers, dropout_p, train, bidirectional);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}

PackedSequence input { data, batch_sizes };
auto params = gather_params(_params, has_biases);
auto params = gather_params(_params, has_biases, has_projections);
auto result = _lstm_impl<PackedLayer, PackedBidirectionalLayer>(
input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional);
auto & packed_output = std::get<0>(result);
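A short sketch of why the size comparison used in both dispatchers works (shapes as in the frontend example near the top; names illustrative):

#include <torch/torch.h>

int main() {
  const int64_t num_layers = 2, batch = 3, hidden_size = 20, proj_size = 5;
  // With a projection, h0's trailing dim is proj_size while c0's stays
  // hidden_size, which is exactly what hx[0].size(2) != hx[1].size(2) detects.
  auto h0 = torch::zeros({num_layers, batch, proj_size});
  auto c0 = torch::zeros({num_layers, batch, hidden_size});
  bool has_projections = (h0.size(2) != c0.size(2));  // true here
  return has_projections ? 0 : 1;
}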
@@ -1455,7 +1493,8 @@ std::tuple<Tensor, Tensor> lstm_cell(
const Tensor& input, TensorList hx,
const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
TORCH_CHECK(hx.size() == 2, "lstm_cell expects two hidden states");
return LSTMCell<CellParams>{}(input, std::make_tuple(hx[0], hx[1]), CellParams{w_ih, w_hh, b_ih, b_hh});
static at::Tensor undefined;
return LSTMCell<CellParams>{}(input, std::make_tuple(hx[0], hx[1]), CellParams{w_ih, w_hh, b_ih, b_hh, undefined});
}

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
@@ -1552,19 +1591,22 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_differentiable_gru_cell
Tensor gru_cell(
const Tensor& input, const Tensor& hx,
const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
return GRUCell<CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
static at::Tensor undefined;
return GRUCell<CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh, undefined});
}

Tensor rnn_tanh_cell(
const Tensor& input, const Tensor& hx,
const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
return SimpleCell<tanh_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
static at::Tensor undefined;
return SimpleCell<tanh_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh, undefined});
}

Tensor rnn_relu_cell(
const Tensor& input, const Tensor& hx,
const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
return SimpleCell<relu_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
static at::Tensor undefined;
return SimpleCell<relu_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh, undefined});
}

// Quantized implementations
@@ -1592,6 +1634,7 @@ std::tuple<Tensor, Tensor, Tensor> quantized_lstm_input(
params.emplace_back(static_cast<c10::intrusive_ptr<CellParamsBase>>(param));
}
TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
TORCH_CHECK(hx[0].size(2) == hx[1].size(2), "quantized LSTM with projections is not supported");
auto result_dtype = dtype.has_value() ? dtype.value() : at::kChar;
auto input = batch_first ? _input.transpose(0, 1) : _input;
TORCH_CHECK(has_biases, "quantized LSTM requires biases");
@@ -1685,6 +1728,7 @@ std::tuple<Tensor, Tensor, Tensor> quantized_lstm_data(
params.emplace_back(static_cast<c10::intrusive_ptr<CellParamsBase>>(param));
}
TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
TORCH_CHECK(hx[0].size(2) == hx[1].size(2), "quantized LSTM with projections is not supported");

auto result_dtype = dtype.has_value() ? dtype.value() : at::kChar;

…