Fix Exporting RNN/LSTM's Initial State (h0/c0) to ONNX
Summary: Pull Request resolved: #22813

Reviewed By: hl475

Differential Revision: D16275791

Pulled By: houseroad

fbshipit-source-id: 6e2259e84e1f5a674daabcbe0df99b1360ed2b35
lara-hdr authored and facebook-github-bot committed Sep 24, 2019
1 parent cb9fd0c commit 3569a1c
Showing 8 changed files with 139 additions and 54 deletions.
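For context, a minimal sketch of the pattern this commit fixes, adapted from the new test_lstm_post_fix_init_state test below (the LSTMModel name comes from that test; the sizes here are illustrative). When h0/c0 are built inside forward(), tracing bakes them into the graph as Constants with the export-time batch size, so the exported model previously could not run at any other batch size:

    import io

    import numpy as np
    import torch


    class LSTMModel(torch.nn.Module):
        def __init__(self):
            super(LSTMModel, self).__init__()
            self.lstm = torch.nn.LSTM(10, 20, 1, bidirectional=False)

        def forward(self, input):
            batch_size = input.size()[1]
            # traced as Constants: the batch dimension is baked in at export time
            h0 = torch.from_numpy(np.ones([1, batch_size, 20]).astype(np.float32))
            c0 = torch.from_numpy(np.ones([1, batch_size, 20]).astype(np.float32))
            return self.lstm(input, (h0, c0))


    # export with batch size 1 so the saved h0/c0 can later be broadcast
    f = io.BytesIO()
    torch.onnx.export(LSTMModel(), torch.randn(5, 1, 10), f,
                      input_names=['input'],
                      dynamic_axes={'input': {0: 'seq', 1: 'batch'}})

With this commit, the peephole pass wraps the constant initial state in an onnx::Expand node, so a model exported at batch size 1 broadcasts h0/c0 to whatever batch size is used at runtime.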
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -183,6 +183,7 @@ namespace c10 {
_(onnx, Loop) \
_(onnx, If) \
_(onnx, Reshape) \
+ _(onnx, Expand) \
_(onnx, Equal) \
_(onnx, Greater) \
_(onnx, Less) \
18 changes: 14 additions & 4 deletions test/onnx/test_pytorch_onnx_caffe2.py
@@ -39,8 +39,8 @@
import onnx
import caffe2.python.onnx.backend as c2

- from test_pytorch_common import BATCH_SIZE, RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
- from test_pytorch_common import skipIfTravis, skipIfNoLapack, skipIfNoCuda
+ from test_pytorch_common import BATCH_SIZE, RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
+ from test_pytorch_common import skipIfUnsupportedOpsetVersion, skipIfUnsupportedMinOpsetVersion
import verify

@@ -335,7 +335,10 @@ def make_input(batch_size):
self.run_model_test(model, train=False, batch_size=RNN_BATCH_SIZE, input=input, use_gpu=False, atol=1e-7)

# test that the model still runs with a different batch size
- onnxir, _ = do_export(model, input, keep_initializers_as_inputs=True)
+ # (export the model with a batch size of 1 when the RNN has a variable
+ # batch size; otherwise Expand will fail)
+ variable_batch_size_init_input = make_input(1)
+ onnxir, _ = do_export(model, variable_batch_size_init_input, keep_initializers_as_inputs=True)
other_input = make_input(RNN_BATCH_SIZE + 1)
_ = run_embed_params(onnxir, model, other_input, use_gpu=False)

@@ -375,7 +378,10 @@ def make_input(batch_size):
self.run_model_test(model, train=False, batch_size=RNN_BATCH_SIZE, input=input, use_gpu=False)

# test that the model still runs with a different batch size
- onnxir, _ = do_export(model, input, keep_initializers_as_inputs=True)
+ # (export the model with a batch size of 1 when the RNN has a variable
+ # batch size; otherwise Expand will fail)
+ variable_batch_size_init_input = make_input(1)
+ onnxir, _ = do_export(model, variable_batch_size_init_input, keep_initializers_as_inputs=True)
other_input = make_input(RNN_BATCH_SIZE + 1)
_ = run_embed_params(onnxir, model, other_input, use_gpu=False)

@@ -413,7 +419,10 @@ def make_input(batch_size):
self.run_model_test(model, train=False, batch_size=RNN_BATCH_SIZE, input=input, use_gpu=False)

# test that the model still runs with a different batch size
- onnxir, _ = do_export(model, input, keep_initializers_as_inputs=True)
+ # (export the model with a batch size of 1 when the RNN has a variable
+ # batch size; otherwise Expand will fail)
+ variable_batch_size_init_input = make_input(1)
+ onnxir, _ = do_export(model, variable_batch_size_init_input, keep_initializers_as_inputs=True)
other_input = make_input(RNN_BATCH_SIZE + 1)
_ = run_embed_params(onnxir, model, other_input, use_gpu=False)

@@ -2211,6 +2220,7 @@ def make_test(name, base, layer, bidirectional, initial_state,
]))

@skipIfUnsupportedOpsetVersion([10])
+ @skipIfUnsupportedMinOpsetVersion(8)
def f(self):
self._dispatch_rnn_test(
base,
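The comment added in each of these tests ("otherwise Expand will fail") follows from ONNX's numpy-style broadcasting rules: only dimensions of size 1 can be stretched to a larger size. A sketch of the same rule using torch.Tensor.expand (sizes are illustrative):

    import torch

    h0 = torch.ones(1, 1, 20)           # initial state saved with batch size 1
    print(h0.expand(1, 7, 20).shape)    # torch.Size([1, 7, 20]): size-1 dim broadcasts

    h0_fixed = torch.ones(1, 4, 20)     # initial state saved with batch size 4
    try:
        h0_fixed.expand(1, 7, 20)       # 4 cannot be broadcast to 7
    except RuntimeError as err:
        print("expand fails:", err)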
93 changes: 76 additions & 17 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -15,6 +15,7 @@
from model_defs.lstm_flattening_result import LstmFlatteningResult
from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion
+ from test_pytorch_common import BATCH_SIZE
from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
import model_defs.word_language_model as word_language_model

@@ -46,7 +47,8 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
input=None, use_gpu=True, rtol=0.001, atol=1e-7,
example_outputs=None, do_constant_folding=True,
dynamic_axes=None, test_with_inputs=None,
- input_names=None, output_names=None):
+ input_names=None, output_names=None,
+ fixed_batch_size=False):
model.eval()

if input is None:
@@ -61,14 +63,14 @@ def run_model_test(self, model, batch_size=2, state_dict=None,

# export the model to ONNX
f = io.BytesIO()
- torch.onnx.export(model, input, f,
- opset_version=self.opset_version,
- example_outputs=output,
- do_constant_folding=do_constant_folding,
- keep_initializers_as_inputs=self.keep_initializers_as_inputs,
- dynamic_axes=dynamic_axes,
- input_names=input_names, output_names=output_names)

+ torch.onnx._export(model, input, f,
+ opset_version=self.opset_version,
+ example_outputs=output,
+ do_constant_folding=do_constant_folding,
+ keep_initializers_as_inputs=self.keep_initializers_as_inputs,
+ dynamic_axes=dynamic_axes,
+ input_names=input_names, output_names=output_names,
+ fixed_batch_size=fixed_batch_size)

# compute onnxruntime output prediction
ort_sess = onnxruntime.InferenceSession(f.getvalue())
@@ -83,7 +85,6 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
output = model(*test_input)
if isinstance(output, torch.Tensor):
output = (output,)

ort_test_with_input(ort_sess, test_input, output, rtol, atol)


@@ -98,14 +99,15 @@ def setUp(self):
torch.cuda.manual_seed_all(0)
np.random.seed(seed=0)

- def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True,
+ def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=False,
batch_size=2, use_gpu=True, dynamic_axes=None, test_with_inputs=None,
- input_names=None, output_names=None):
- run_model_test(self, model, batch_size=batch_size,
- input=input, use_gpu=use_gpu, rtol=rtol, atol=atol,
- do_constant_folding=do_constant_folding,
- dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs,
- input_names=input_names, output_names=output_names)
+ input_names=None, output_names=None, fixed_batch_size=False):
+ return run_model_test(self, model, batch_size=batch_size,
+ input=input, use_gpu=use_gpu, rtol=rtol, atol=atol,
+ do_constant_folding=do_constant_folding,
+ dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs,
+ input_names=input_names, output_names=output_names,
+ fixed_batch_size=fixed_batch_size)

def run_word_language_model(self, model_name):
ntokens = 50
@@ -564,6 +566,63 @@ def forward(self, input):
x = torch.randn(4, 4, requires_grad=True)
self.run_test(ReduceLogSumExpModel(), x)

+ @skipIfUnsupportedMinOpsetVersion(9)
+ def test_lstm(self):
+ model = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
+ input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
+ h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
+ c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
+ self.run_test(model, (input, (h0, c0)))

+ @skipIfUnsupportedMinOpsetVersion(9)
+ def test_lstm_default_init_state(self):
+ model = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
+ input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
+ self.run_test(model, input)

+ @skipIfUnsupportedMinOpsetVersion(9)
+ def test_lstm_fixed_batch_size(self):
+ class LSTMModel(torch.nn.Module):
+ def __init__(self):
+ super(LSTMModel, self).__init__()
+ self.lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)

+ def forward(self, input):
+ batch_size = input.size()[1]
+ h0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
+ c0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
+ h0 = torch.from_numpy(h0_np)
+ c0 = torch.from_numpy(c0_np)
+ return self.lstm(input, (h0, c0))

+ input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
+ # verify with a different input of the same batch size
+ input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
+ self.run_test(LSTMModel(), input, fixed_batch_size=True, test_with_inputs=[input2])

+ @skipIfUnsupportedMinOpsetVersion(9)
+ def test_lstm_post_fix_init_state(self):
+ class LSTMModel(torch.nn.Module):
+ def __init__(self):
+ super(LSTMModel, self).__init__()
+ self.lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE,
+ 1, bidirectional=False)

+ def forward(self, input):
+ batch_size = input.size()[1]
+ h0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
+ c0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
+ h0 = torch.from_numpy(h0_np)
+ c0 = torch.from_numpy(c0_np)
+ return self.lstm(input, (h0, c0))

+ model = LSTMModel()
+ input = torch.randn(RNN_SEQUENCE_LENGTH, 1, RNN_INPUT_SIZE)
+ # verify with a different input of a different batch size
+ input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
+ self.run_test(model, input, dynamic_axes={'input' : {0 : 'seq', 1 : 'batch'}},
+ test_with_inputs=[input2])

def test_lstm_constant_folding(self):
class LstmNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_layers, bidirectional):
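What test_with_inputs exercises end to end: export at batch size 1, then run under ONNX Runtime at a larger batch. A minimal standalone sketch using the default-initial-state path (sizes are illustrative):

    import io

    import numpy as np
    import onnxruntime
    import torch

    model = torch.nn.LSTM(10, 20, 1)    # no explicit h0/c0: default initial state
    f = io.BytesIO()
    torch.onnx.export(model, torch.randn(5, 1, 10), f,
                      input_names=['input'],
                      dynamic_axes={'input': {0: 'seq', 1: 'batch'}})

    sess = onnxruntime.InferenceSession(f.getvalue())
    # run at batch size 3; the default initial state is expanded to match
    out = sess.run(None, {'input': np.random.randn(5, 3, 10).astype(np.float32)})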
7 changes: 6 additions & 1 deletion torch/csrc/jit/init.cpp
@@ -109,7 +109,12 @@ void initJITBindings(PyObject* module) {
.def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops)
.def("_jit_pass_onnx", ToONNX)
.def("_jit_pass_lower_all_tuples", LowerAllTuples)
.def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
.def("_jit_pass_onnx_peephole",
[](std::shared_ptr<Graph>& graph,
int opset_version,
bool fixed_batch_size) {
return PeepholeOptimizeONNX(graph, opset_version, fixed_batch_size);
})
.def(
"_jit_pass_onnx_cast_all_constant_to_floating",
CastAllConstantToFloating)
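On the Python side, the extra argument has to be threaded through wherever this pass is invoked; the caller in torch/onnx/utils.py is not part of this diff. A minimal sketch of the call, with a hypothetical helper name, assuming a graph obtained from tracing:

    import torch

    def _run_onnx_peephole(graph, opset_version, fixed_batch_size=False):
        # the pass mutates `graph` in place; the third argument is the new flag
        torch._C._jit_pass_onnx_peephole(graph, opset_version, fixed_batch_size)
        return graph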
46 changes: 24 additions & 22 deletions torch/csrc/jit/passes/onnx/peephole.cpp
@@ -360,13 +360,18 @@ void hackFixupPadPackedShapes(Block* graph) {
void fixDefaultRNNState(Graph* graph, Node* n, int input_index, int opset_version) {
auto initial_state = n->inputs()[input_index];

- // The RNN code in pytorch accepts an optional hidden state. When it
- // is provided, everything works great. When it is not provided, it
- // is default-initialized by constructing a new Variable, which gets
- // traced as a Constant. Recognize that pattern here and replace it
- // with something that doesn't fix the batch size. Note that for
- // multi-layer RNNs there will be a Slice operation between the
- // Constant and the RNN.
+ // The RNN code in PyTorch accepts an optional hidden state.
+ // 1- When it is provided as an input, everything works great.
+ // 2- When it is not provided, it is default-initialized by constructing a new
+ //    Variable, which gets traced as a ConstantOfShape with the expected shape.
+ // 3- When the batch size is fixed, everything works great as well.
+ // 4- When h0 and c0 are specified but are not inputs of the model (they are
+ //    Constants) and the batch size is variable, the model must be saved with a
+ //    batch size of 1 (or an error will occur), and h0 and c0 are saved with a
+ //    batch size of 1 as well. When the model is then run with a different batch
+ //    size, h0 and c0 are broadcast to the right shape.
+ // Recognize that last pattern (4) here and fix the shape.
+ // Note that for multi-layer RNNs there will be a Slice operation between the
+ // Constant and the RNN.
bool needsFixing = initial_state->node()->kind() == onnx::Constant ||
(initial_state->node()->kind() == onnx::Slice &&
initial_state->node()->inputs()[0]->node()->kind() == onnx::Constant);
@@ -426,18 +431,11 @@ void fixDefaultRNNState(Graph* graph, Node* n, int input_index, int opset_version)
concated_dims->addInput(unsqueezed_batch_size->outputs()[0]);
concated_dims->addInput(hidden_size->outputs()[0]);

- if (opset_version < 9) {
- Node* constant_fill = graph->create(onnx::ConstantFill, 1);
- constant_fill->insertBefore(n);
- constant_fill->i_(attr::input_as_shape, 1);
- constant_fill->addInput(concated_dims->outputs()[0]);
- n->replaceInput(input_index, constant_fill->outputs()[0]);
- } else {
- Node* constant_of_shape = graph->create(onnx::ConstantOfShape, 1);
- constant_of_shape->insertBefore(n);
- constant_of_shape->addInput(concated_dims->outputs()[0]);
- n->replaceInput(input_index, constant_of_shape->outputs()[0]);
- }
+ Node* fixed_init_state = graph->create(onnx::Expand, 1);
+ fixed_init_state->insertBefore(n);
+ fixed_init_state->addInput(initial_state);
+ fixed_init_state->addInput(concated_dims->outputs()[0]);
+ n->replaceInput(input_index, fixed_init_state->outputs()[0]);

if (initial_state->uses().size() == 0) {
initial_state->node()->destroy();
@@ -639,15 +637,19 @@ void removeMaxPoolUnusedOutput(Block* b) {
// writing your optimization in jit/passes/peephole.cpp rather than
// here, as it will be generally applicable to the JIT as well. The
// optimizations here are ONLY applied on ONNX export.
- void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph, int opset_version) {
+ void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph, int opset_version, bool fixed_batch_size) {
// TODO: decide on fixpoint strategy
// TODO: make it easier not to do O(k) iterations over the graph, where
// k is the number of distinct peephole optimizations
hackFixupPadPackedShapes(graph->block());
pushPackingPastRnn(graph->block());
removeNopPacking(graph->block());
- fixDefaultRnnHiddenState(graph->block(), opset_version);
- fixDefaultLstmCellState(graph->block(), opset_version);
+ // we only need to fix the size of the hidden state and cell state if the batch size is variable
+ if (!fixed_batch_size) {
+ fixDefaultRnnHiddenState(graph->block(), opset_version);
+ fixDefaultLstmCellState(graph->block(), opset_version);
+ }
fuseBroadcast(graph->block());
fuseConsecutiveTransposes(graph->block());
eliminateNopTranspose(graph->block());
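In numpy terms, the onnx::Expand node inserted above broadcasts the initial state that was saved at batch size 1 to the shape assembled by the surrounding Concat ([num_directions * num_layers, batch_size, hidden_size]); onnx::Expand follows numpy's multidirectional broadcasting. A sketch of the equivalent computation (sizes are illustrative):

    import numpy as np

    hidden_size = 20
    saved_h0 = np.ones((1, 1, hidden_size), dtype=np.float32)  # exported at batch size 1
    runtime_batch = 7

    # mirrors the inserted Expand node at runtime
    h0 = np.broadcast_to(saved_h0, (1, runtime_batch, hidden_size))
    print(h0.shape)  # (1, 7, 20)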
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/onnx/peephole.h
@@ -5,7 +5,7 @@
namespace torch {
namespace jit {

- void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph, int opset_version);
+ void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph, int opset_version, bool fixed_batch_size);

}
} // namespace torch
7 changes: 7 additions & 0 deletions torch/onnx/symbolic_opset9.py
@@ -1332,6 +1332,13 @@ def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):

def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
num_layers, dropout, train, bidirectional, batch_first=None, batch_sizes=None):

+ warnings.warn("Exporting a model to ONNX with a batch_size other than 1, " +
+ "with a variable length input with " + variant + " can cause an error " +
+ "when running the ONNX model with a different batch size. " +
+ "Make sure to save the model with a batch size of 1, " +
+ "or define the initial states (h0/c0) as inputs of the model. ")

onnxActivations = ['Relu', 'Tanh', 'Sigmoid', 'Affine', 'LeakyRelu', 'ThresholdedRelu',
'ScaledTanh', 'HardSigmoid', 'Elu', 'Softsign', 'Softplus']
variantToOnnxActivationMap = dict(zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations))
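The second option the warning suggests: declaring h0/c0 as forward() inputs makes them ONNX graph inputs, so no Constant is baked in and no fixup pass is needed. A sketch (the LSTMWithStateInputs name and the sizes are illustrative):

    import torch


    class LSTMWithStateInputs(torch.nn.Module):
        def __init__(self):
            super(LSTMWithStateInputs, self).__init__()
            self.lstm = torch.nn.LSTM(10, 20, 1)

        def forward(self, input, h0, c0):
            return self.lstm(input, (h0, c0))


    seq, batch = 5, 4
    args = (torch.randn(seq, batch, 10),
            torch.zeros(1, batch, 20),
            torch.zeros(1, batch, 20))
    torch.onnx.export(LSTMWithStateInputs(), args, "lstm_state_inputs.onnx",
                      input_names=['input', 'h0', 'c0'],
                      dynamic_axes={'input': {1: 'batch'},
                                    'h0': {1: 'batch'},
                                    'c0': {1: 'batch'}})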
