Skip to content

Commit

Permalink
[MXNET-107]Fused GRU implementation for CPU (apache#10311)
Browse files Browse the repository at this point in the history
* Add GRU Support and Test Case

* skip the gpu test case that has nothing to do with RNN GRU

* fix robust bug for gru backward

* fix bug for unifying weight parameter

* add GRU multiple layer and bidirection support with test case

* fix test case bug

* fix test case bug

* fix bug for memory issue

* fix bug for bidirection

* rebase code and fix bug for memory corruption issue

* fix gpu compile issue

* fix bug and enable some test cases

* fix robust bug

* trigger the build to check if quantize-gpu case is covered

* trigger the build to check if MKLDNN+GPU case is covered

* disable failed gpu test case of MKLDNN_UTIL_FUNC-MemFormat because it has nothing to do with this PR and will recover it once the issue is passed

* skip failed test_reduce test case temporarily as it has nothing to do with RNN

* enable several test cases

* retrigger the build

* rebase code from lstm

* rebase code for resolve conflict

* add gru code after resolve conflict

* fix bug for resolve conflict

* add Fused GRU code with test case

* retrigger the build

* add GetRecommendedOMPThreadCount for omp

* fix conflict issue

* add gru relate code

* fix bug for code

* update code for gru

* retrigger the build

* fix code about gru condition

* enhance test case to test gradient weights and bias

* fix bug for test case

* fix bug for test case

* fix bug about dropout condition and test case

* fix bug for test case

* fix bug for test case

* retrigger the build

* rebase code

* add gru code

* fix issues about namespace, removing define and memcpy

* retrigger the build

* fix issues and add cudnn_gru_bucketing.py test case

* retrigger the build

* update cudnn_rnn_bucketing.py test case

* update cudnn_rnn_bucketing.py test case

* update cudnn_rnn_bucketing.py test case

* add check for req[kParams] and kAddTo from cudnn_rnn-inl.h

* retrigger the build

* retrigger the build

* retrigger the build

* add kNullOp check

* retrigger the build

* update kNullOp support and test case for both GRU and LSTM

* update kAddToOp support for both GRU and LSTM
  • Loading branch information
Hao Li authored and zheng-da committed Jun 28, 2018
1 parent c06e064 commit f726b69
Show file tree
Hide file tree
Showing 5 changed files with 1,060 additions and 50 deletions.
Expand Up @@ -65,6 +65,8 @@
help='stack fused RNN cells to reduce communication overhead')
parser.add_argument('--dropout', type=float, default='0.0',
help='dropout probability (1.0 - keep probability)')
parser.add_argument('--rnntype', type=str, default='lstm',
help='rnn type: gru and lstm are supported')

#buckets = [32]
buckets = [10, 20, 30, 40, 50, 60]
Expand Down Expand Up @@ -97,13 +99,13 @@ def train(args):
cell = mx.rnn.SequentialRNNCell()
for i in range(args.num_layers):
cell.add(mx.rnn.FusedRNNCell(args.num_hidden, num_layers=1,
mode='lstm', prefix='lstm_l%d'%i,
mode=args.rnntype, prefix='%s_l%d'%(args.rnntype,i),
bidirectional=args.bidirectional))
if args.dropout > 0 and i < args.num_layers - 1:
cell.add(mx.rnn.DropoutCell(args.dropout, prefix='lstm_d%d'%i))
if args.dropout > 0 and i < args.num_layers - 1 and args.rnntype == 'lstm':
cell.add(mx.rnn.DropoutCell(args.dropout, prefix='%s_d%d'%(args.rnntype,i)))
else:
cell = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers, dropout=args.dropout,
mode='lstm', bidirectional=args.bidirectional)
mode=args.rnntype, bidirectional=args.bidirectional)

def sym_gen(seq_len):
data = mx.sym.Variable('data')
Expand Down Expand Up @@ -168,16 +170,25 @@ def test(args):

if not args.stack_rnn:
stack = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers,
mode='lstm', bidirectional=args.bidirectional).unfuse()
mode=args.rnntype, bidirectional=args.bidirectional).unfuse()
else:
stack = mx.rnn.SequentialRNNCell()
for i in range(args.num_layers):
cell = mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dl0_'%i)
if args.bidirectional:
cell = mx.rnn.BidirectionalCell(
cell,
mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dr0_'%i),
output_prefix='bi_lstm_%d'%i)
if args.rnntype == 'lstm':
cell = mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='%s_%dl0_'%(args.rnntype,i))
if args.bidirectional:
cell = mx.rnn.BidirectionalCell(
cell,
mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)),
output_prefix='bi_%s_%d'%(args.rnntype,i))
elif args.rnntype == 'gru':
cell = mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dl0_'%(args.rnntype,i))
if args.bidirectional:
cell = mx.rnn.BidirectionalCell(
cell,
mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)),
output_prefix='bi_%s_%d'%(args.rnntype,i))

stack.add(cell)

def sym_gen(seq_len):
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/gluon/rnn/rnn_layer.py
Expand Up @@ -190,7 +190,7 @@ def forward(self, inputs, states=None):
self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
self.i2h_weight[i]._finish_deferred_init()
if inputs.context.device_type == 'gpu' or \
self._mode == 'lstm' and not self._dropout:
self._mode in ['lstm', 'gru'] and not self._dropout:
out = self._forward_kernel(inputs, states)
else:
out = self._forward(inputs, states)
Expand Down
57 changes: 44 additions & 13 deletions src/operator/rnn-inl.h
Expand Up @@ -101,12 +101,14 @@ inline size_t GetRNNWorkspaceSize(int seq_length,
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
case rnn_enum::kGru:
LOG(FATAL) << "Only LSTM is supported at the moment";
LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
break;
case rnn_enum::kLstm:
size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2
+ seq_length * batch_size * hidden_size * direction;
+ seq_length * batch_size * hidden_size * direction + hidden_size * seq_length * 8;
break;
case rnn_enum::kGru:
size = seq_length * batch_size * hidden_size * direction * 4 + batch_size * hidden_size * 8;
break;
default:
LOG(FATAL) << "unknown RNN mode " << mode;
Expand All @@ -125,12 +127,16 @@ inline size_t GetRNNReserveSpaceSize(int num_layer,
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
case rnn_enum::kGru:
LOG(FATAL) << "Only LSTM is supported at the moment";
LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
break;
case rnn_enum::kLstm:
size = num_layer * direction * seq_length * batch_size * hidden_size * 6;
break;
case rnn_enum::kGru:
size = seq_length * batch_size * hidden_size * direction * num_layer * 8 +
batch_size * hidden_size * direction * 9 + hidden_size * seq_length * 6 +
seq_length * batch_size * 7 * hidden_size * direction;
break;
default:
LOG(FATAL) << "unknown RNN mode " << mode;
break;
Expand Down Expand Up @@ -221,14 +227,18 @@ void RNNForwardTraining(DType* ws,
switch (mode) {
case rnn_enum::kRnnTanh:
case rnn_enum::kRnnRelu:
case rnn_enum::kGru:
LOG(FATAL) << "Only LSTM is supported at the moment";
LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
break;
case rnn_enum::kLstm:
LstmForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
break;
case rnn_enum::kGru:
GruForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr,
w_ptr, y_ptr, hy_ptr);
break;
default:
LOG(FATAL) << "unknown RNN mode " << mode;
break;
Expand Down Expand Up @@ -256,14 +266,18 @@ void RNNForwardInference(DType* ws,
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
case rnn_enum::kGru:
LOG(FATAL) << "Only LSTM is supported at the moment";
LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
break;
case rnn_enum::kLstm:
LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
break;
case rnn_enum::kGru:
GruForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr,
w_ptr, y_ptr, hy_ptr);
break;
default:
LOG(FATAL) << "unknown RNN mode" << mode;
break;
Expand Down Expand Up @@ -292,16 +306,26 @@ void RNNBackward(DType* ws,
DType* dcx_ptr,
DType* dw_ptr,
DType* db_ptr,
int req_data,
int req_params,
int req_state,
int req_statecell,
int mode) {
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
case rnn_enum::kGru:
break;
case rnn_enum::kLstm:
LstmBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, y_ptr,
dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr);
dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr,
req_data, req_params, req_state, req_statecell);
break;
case rnn_enum::kGru:
GruBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
input_size, state_size, x_ptr, hx_ptr, w_ptr,
dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr,
req_data, req_params, req_state);
break;
default:
LOG(FATAL) << "unknown RNN mode" << mode;
Expand Down Expand Up @@ -330,7 +354,8 @@ class RNNOp : public Operator{
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment.";
CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
<< "Only lstm and gru mode are supported at the moment.";
CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";

size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
Expand Down Expand Up @@ -442,8 +467,10 @@ class RNNOp : public Operator{
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment.";
CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
<< "Only lstm and gru mode are supported at the moment.";
CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";

size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
if (!param_.state_outputs) {
Expand Down Expand Up @@ -535,6 +562,10 @@ class RNNOp : public Operator{
dcx_ptr,
dw.dptr_,
db_ptr,
req[rnn_enum::kData],
req[rnn_enum::kParams],
req[rnn_enum::kState],
req[rnn_enum::kStateCell],
param_.mode);
}

Expand Down

0 comments on commit f726b69

Please sign in to comment.