From 4fde42ca4eaaaf31a74cf2e57958e7923b017ee2 Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Thu, 7 Feb 2019 15:03:41 -0800
Subject: [PATCH 1/2] add basic bench for lstm with dropout

---
 rnns/fastrnns/bench.py        |  4 ++--
 rnns/fastrnns/custom_lstms.py |  2 +-
 rnns/fastrnns/factory.py      | 21 +++++++++++++++++++++
 rnns/fastrnns/runner.py       |  1 +
 4 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/rnns/fastrnns/bench.py b/rnns/fastrnns/bench.py
index d691906124..2a1eaf4270 100644
--- a/rnns/fastrnns/bench.py
+++ b/rnns/fastrnns/bench.py
@@ -149,8 +149,8 @@ def bench(rnn_runners, group_name, print_json=False, sep=' ', **params):
     args = parser.parse_args()
     rnns = args.rnns or ['cudnn', 'aten', 'jit', 'jit_premul', 'jit_simple',
                          'jit_multilayer', 'py']
-    # TODO: Maybe add a separate section for the layernorm lstms
-    # 'jit_layernorm', 'jit_layernom_decom', 'jit'
+    # TODO: Maybe add a separate section for the layernorm/dropout lstms
+    # 'jit_layernorm', 'jit_layernorm_decom', 'jit', 'jit_dropout'
     vlrnns = ['vl_cudnn', 'vl_jit', 'vl_py']
     cnns = ['resnet18', 'resnet18_jit', 'resnet50', 'resnet50_jit']
     if args.print_json:

diff --git a/rnns/fastrnns/custom_lstms.py b/rnns/fastrnns/custom_lstms.py
index 348483fef2..944181cadc 100644
--- a/rnns/fastrnns/custom_lstms.py
+++ b/rnns/fastrnns/custom_lstms.py
@@ -306,7 +306,7 @@ class StackedLSTMWithDropout(jit.ScriptModule):
     __constants__ = ['layers', 'num_layers']
 
     def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
-        super(StackedLSTM, self).__init__()
+        super(StackedLSTMWithDropout, self).__init__()
         self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
                                         other_layer_args)
         # Introduces a Dropout layer on the outputs of each LSTM layer except

diff --git a/rnns/fastrnns/factory.py b/rnns/fastrnns/factory.py
index 044eaf513c..22a9031f07 100644
--- a/rnns/fastrnns/factory.py
+++ b/rnns/fastrnns/factory.py
@@ -95,6 +95,27 @@ def lnlstm_creator(script=True, decompose_layernorm=False, **kwargs):
         backward=simple_backward)
 
 
+def dropoutlstm_creator(script=True, **kwargs):
+    assert script is True
+    from .custom_lstms import script_lstm
+    input_size = kwargs['inputSize']
+    hidden_size = kwargs['hiddenSize']
+    seq_len = kwargs['seqLength']
+    batch_size = kwargs['miniBatch']
+    ge = script_lstm(input_size, hidden_size, 1, dropout=True).cuda()
+
+    input = torch.randn(seq_len, batch_size, input_size, device='cuda')
+    states = [(torch.randn(batch_size, hidden_size, device='cuda'),
+               torch.randn(batch_size, hidden_size, device='cuda'))]
+
+    return ModelDef(
+        inputs=[input, states],
+        params=ge.parameters(),
+        forward=ge,
+        backward_setup=lstm_backward_setup,
+        backward=simple_backward)
+
+
 def lstm_premul_creator(script=True, **kwargs):
     input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
     inputs = [input, hidden] + params[0]

diff --git a/rnns/fastrnns/runner.py b/rnns/fastrnns/runner.py
index 8123e45a4d..16c336eeb2 100644
--- a/rnns/fastrnns/runner.py
+++ b/rnns/fastrnns/runner.py
@@ -56,6 +56,7 @@ def get_rnn_runners(*names):
     'jit_layernorm_decom': RNNRunner('jit_layernorm_decom',
                                      partial(lnlstm_creator, decompose_layernorm=True),
                                      DummyContext),
+    'jit_dropout': RNNRunner('jit_dropout', dropoutlstm_creator, DummyContext),
     'py': RNNRunner('py', partial(lstm_creator, script=False), DummyContext),
     'resnet18': RNNRunner('resnet18', imagenet_cnn_creator(cnn.resnet18, jit=False), DummyContext),
     'resnet18_jit': RNNRunner('resnet18_jit', imagenet_cnn_creator(cnn.resnet18), DummyContext),
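With patch 1 applied, the new runner is reachable through get_rnn_runners('jit_dropout') or directly through the factory. The sketch below is not part of the patch: it assumes a CUDA-enabled build with the rnns/ directory on PYTHONPATH, and the sizes simply mirror lstm_inputs' defaults.

    # Hypothetical driver for the patch-1 creator (single-layer at this
    # point; patch 2 below makes it honor numLayers).
    from fastrnns.factory import dropoutlstm_creator

    # The creator reads inputSize/hiddenSize/seqLength/miniBatch from kwargs
    # and returns a ModelDef whose forward is the scripted LSTM with dropout.
    modeldef = dropoutlstm_creator(script=True, inputSize=512, hiddenSize=512,
                                   seqLength=100, miniBatch=64)
    output, out_states = modeldef.forward(*modeldef.inputs)

Assuming bench.py's --rnns flag takes runner names, python -m fastrnns.bench --rnns jit_dropout should exercise the same path end to end.
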
From d14a68197271a8959cd7c6c027b5ad132c974462 Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Mon, 11 Feb 2019 16:50:05 -0800
Subject: [PATCH 2/2] fix jit lstm with dropout and cudnn lstm dropout

---
 rnns/fastrnns/bench.py        |  2 +-
 rnns/fastrnns/custom_lstms.py |  9 ++++++++-
 rnns/fastrnns/factory.py      | 15 ++++++++-------
 rnns/fastrnns/runner.py       |  1 +
 4 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/rnns/fastrnns/bench.py b/rnns/fastrnns/bench.py
index 2a1eaf4270..6d34943394 100644
--- a/rnns/fastrnns/bench.py
+++ b/rnns/fastrnns/bench.py
@@ -150,7 +150,7 @@ def bench(rnn_runners, group_name, print_json=False, sep=' ', **params):
     rnns = args.rnns or ['cudnn', 'aten', 'jit', 'jit_premul', 'jit_simple',
                          'jit_multilayer', 'py']
     # TODO: Maybe add a separate section for the layernorm/dropout lstms
-    # 'jit_layernorm', 'jit_layernorm_decom', 'jit', 'jit_dropout'
+    # 'jit_layernorm', 'jit_layernorm_decom', 'jit', 'jit_dropout', 'cudnn_dropout'
     vlrnns = ['vl_cudnn', 'vl_jit', 'vl_py']
     cnns = ['resnet18', 'resnet18_jit', 'resnet50', 'resnet50_jit']
     if args.print_json:

diff --git a/rnns/fastrnns/custom_lstms.py b/rnns/fastrnns/custom_lstms.py
index 944181cadc..bd64bbc1b8 100644
--- a/rnns/fastrnns/custom_lstms.py
+++ b/rnns/fastrnns/custom_lstms.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from torch.nn import Parameter
 import torch.jit as jit
+import warnings
 from collections import namedtuple
 from typing import List, Tuple
 from torch import Tensor
@@ -312,6 +313,12 @@ def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
         # Introduces a Dropout layer on the outputs of each LSTM layer except
         # the last layer, with dropout probability = 0.4.
         self.num_layers = num_layers
+
+        if num_layers == 1:
+            warnings.warn("dropout lstm adds dropout layers after all but the "
+                          "last recurrent layer, so it expects num_layers "
+                          "greater than 1, but got num_layers = 1")
+
         self.dropout_layer = nn.Dropout(0.4)
 
     @jit.script_method
@@ -327,7 +334,7 @@ def forward(self, input, states):
             output, out_state = rnn_layer(output, state)
             # Apply the dropout layer except the last layer
             if i < self.num_layers - 1:
-                output = self.dropout_layer(output)
+                output = self.dropout_layer(output)
             output_states += [out_state]
             i += 1
         return output, output_states

diff --git a/rnns/fastrnns/factory.py b/rnns/fastrnns/factory.py
index 22a9031f07..7c32825d47 100644
--- a/rnns/fastrnns/factory.py
+++ b/rnns/fastrnns/factory.py
@@ -97,17 +97,18 @@ def lnlstm_creator(script=True, decompose_layernorm=False, **kwargs):
 
 def dropoutlstm_creator(script=True, **kwargs):
     assert script is True
-    from .custom_lstms import script_lstm
+    from .custom_lstms import script_lstm, LSTMState
     input_size = kwargs['inputSize']
     hidden_size = kwargs['hiddenSize']
     seq_len = kwargs['seqLength']
     batch_size = kwargs['miniBatch']
-    ge = script_lstm(input_size, hidden_size, 1, dropout=True).cuda()
+    num_layers = kwargs['numLayers']
+    ge = script_lstm(input_size, hidden_size, num_layers, dropout=True).cuda()
 
     input = torch.randn(seq_len, batch_size, input_size, device='cuda')
-    states = [(torch.randn(batch_size, hidden_size, device='cuda'),
-               torch.randn(batch_size, hidden_size, device='cuda'))]
-
+    states = [LSTMState(torch.randn(batch_size, hidden_size, device='cuda'),
+                        torch.randn(batch_size, hidden_size, device='cuda'))
+              for _ in range(num_layers)]
     return ModelDef(
         inputs=[input, states],
         params=ge.parameters(),
@@ -291,13 +292,13 @@ def unzip_columns(mat):
 
 
 # returns: x, (hx, cx), all_weights, lstm module with all_weights as params
 def lstm_inputs(seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
-                miniBatch=64, return_module=False, device='cuda', seed=None):
+                miniBatch=64, dropout=0.0, return_module=False, device='cuda', seed=None):
     if seed is not None:
         torch.manual_seed(seed)
     x = torch.randn(seqLength, miniBatch, inputSize, device=device)
     hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
     cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
-    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers)
+    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers, dropout=dropout)
     if 'cuda' in device:
         lstm = lstm.cuda()

diff --git a/rnns/fastrnns/runner.py b/rnns/fastrnns/runner.py
index 16c336eeb2..5d5583a1ea 100644
--- a/rnns/fastrnns/runner.py
+++ b/rnns/fastrnns/runner.py
@@ -44,6 +44,7 @@ def get_rnn_runners(*names):
 
 rnn_runners = {
     'cudnn': RNNRunner('cudnn', pytorch_lstm_creator, DummyContext),
+    'cudnn_dropout': RNNRunner('cudnn_dropout', partial(pytorch_lstm_creator, dropout=0.4), DummyContext),
     'vl_cudnn': RNNRunner('vl_cudnn', varlen_pytorch_lstm_creator, DummyContext),
     'vl_jit': RNNRunner('vl_jit', partial(varlen_lstm_creator, script=True), DummyContext),
     'vl_py': RNNRunner('vl_py', varlen_lstm_creator, DummyContext),
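
Patch 2 makes jit_dropout honor numLayers and adds a cuDNN baseline with the same dropout probability (nn.Dropout(0.4) in the custom stack, dropout=0.4 in the cudnn_dropout partial). The sketch below is not part of the patch series: it assumes a CUDA build with rnns/ on PYTHONPATH, that pytorch_lstm_creator forwards its kwargs to lstm_inputs (as the cudnn_dropout partial above suggests), and sizes that mirror lstm_inputs' defaults.

    import torch
    from fastrnns.custom_lstms import script_lstm, LSTMState
    from fastrnns.factory import pytorch_lstm_creator

    num_layers = 2  # > 1, so StackedLSTMWithDropout's new warning stays quiet
    lstm = script_lstm(512, 512, num_layers, dropout=True).cuda()
    inp = torch.randn(100, 64, 512, device='cuda')
    # One LSTMState (hidden, cell) pair per layer, as in dropoutlstm_creator.
    states = [LSTMState(torch.randn(64, 512, device='cuda'),
                        torch.randn(64, 512, device='cuda'))
              for _ in range(num_layers)]
    output, out_states = lstm(inp, states)

    # cuDNN baseline with matching dropout, as registered under 'cudnn_dropout';
    # numLayers > 1 also avoids nn.LSTM's own single-layer dropout warning.
    modeldef = pytorch_lstm_creator(seqLength=100, miniBatch=64, inputSize=512,
                                    hiddenSize=512, numLayers=num_layers,
                                    dropout=0.4)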