Fix Exporting RNN/LSTM's Initial State (h0/c0) to ONNX #22813

Closed
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -183,6 +183,7 @@ namespace c10 {
 _(onnx, Loop) \
 _(onnx, If) \
 _(onnx, Reshape) \
+_(onnx, Expand) \
 _(onnx, Equal) \
 _(onnx, Greater) \
 _(onnx, Less) \
18 changes: 14 additions & 4 deletions test/onnx/test_pytorch_onnx_caffe2.py
@@ -39,8 +39,8 @@
 import onnx
 import caffe2.python.onnx.backend as c2

-from test_pytorch_common import BATCH_SIZE, RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
-from test_pytorch_common import skipIfTravis, skipIfNoLapack, skipIfNoCuda
+from test_pytorch_common import BATCH_SIZE, RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
+from test_pytorch_common import skipIfUnsupportedOpsetVersion, skipIfUnsupportedMinOpsetVersion
 import verify

@@ -335,7 +335,10 @@ def make_input(batch_size):
         self.run_model_test(model, train=False, batch_size=RNN_BATCH_SIZE, input=input, use_gpu=False, atol=1e-7)

         # test that the model still runs with a different batch size
-        onnxir, _ = do_export(model, input, keep_initializers_as_inputs=True)
+        # (export the model with a batch size of 1 when the RNN has a variable batch size;
+        # otherwise Expand will fail)
+        variable_batch_size_init_input = make_input(1)
+        onnxir, _ = do_export(model, variable_batch_size_init_input, keep_initializers_as_inputs=True)
         other_input = make_input(RNN_BATCH_SIZE + 1)
         _ = run_embed_params(onnxir, model, other_input, use_gpu=False)

@@ -375,7 +378,10 @@ def make_input(batch_size):
         self.run_model_test(model, train=False, batch_size=RNN_BATCH_SIZE, input=input, use_gpu=False)

         # test that the model still runs with a different batch size
-        onnxir, _ = do_export(model, input, keep_initializers_as_inputs=True)
+        # (export the model with a batch size of 1 when the RNN has a variable batch size;
+        # otherwise Expand will fail)
+        variable_batch_size_init_input = make_input(1)
+        onnxir, _ = do_export(model, variable_batch_size_init_input, keep_initializers_as_inputs=True)
         other_input = make_input(RNN_BATCH_SIZE + 1)
         _ = run_embed_params(onnxir, model, other_input, use_gpu=False)

@@ -413,7 +419,10 @@ def make_input(batch_size):
         self.run_model_test(model, train=False, batch_size=RNN_BATCH_SIZE, input=input, use_gpu=False)

         # test that the model still runs with a different batch size
-        onnxir, _ = do_export(model, input, keep_initializers_as_inputs=True)
+        # (export the model with a batch size of 1 when the RNN has a variable batch size;
+        # otherwise Expand will fail)
+        variable_batch_size_init_input = make_input(1)
+        onnxir, _ = do_export(model, variable_batch_size_init_input, keep_initializers_as_inputs=True)
         other_input = make_input(RNN_BATCH_SIZE + 1)
         _ = run_embed_params(onnxir, model, other_input, use_gpu=False)

@@ -2131,6 +2140,7 @@ def make_test(name, base, layer, bidirectional, initial_state,
         ]))

     @skipIfUnsupportedOpsetVersion([10])
+    @skipIfUnsupportedMinOpsetVersion(8)
     def f(self):
         self._dispatch_rnn_test(
             base,
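Note: the comment repeated in the three tests above describes a real constraint. For the Expand node to broadcast the saved initial state, the stored h0/c0 must have a batch dimension of 1. A small hedged illustration of the underlying broadcast rule, using plain tensors with made-up sizes (no export involved):

```python
import torch

# Broadcasting works from batch size 1 up to any batch size...
h0_saved = torch.ones(1, 1, 8)          # (num_directions, batch=1, hidden_size)
print(h0_saved.expand(1, 5, 8).shape)   # torch.Size([1, 5, 8])

# ...but not from, say, batch size 4 to batch size 5, which is why a
# model with a variable batch size must be exported with batch size 1.
h0_bad = torch.ones(1, 4, 8)
# h0_bad.expand(1, 5, 8)  # would raise a RuntimeError
```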
93 changes: 76 additions & 17 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -15,6 +15,7 @@
 from model_defs.lstm_flattening_result import LstmFlatteningResult
 from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence
 from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion
+from test_pytorch_common import BATCH_SIZE
 from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
 import model_defs.word_language_model as word_language_model

@@ -46,7 +47,8 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
                    input=None, use_gpu=True, rtol=0.001, atol=1e-7,
                    example_outputs=None, do_constant_folding=True,
                    dynamic_axes=None, test_with_inputs=None,
-                   input_names=None, output_names=None):
+                   input_names=None, output_names=None,
+                   fixed_batch_size=False):
     model.eval()

     if input is None:
@@ -61,14 +63,14 @@ def run_model_test(self, model, batch_size=2, state_dict=None,

     # export the model to ONNX
     f = io.BytesIO()
-    torch.onnx.export(model, input, f,
-                      opset_version=self.opset_version,
-                      example_outputs=output,
-                      do_constant_folding=do_constant_folding,
-                      keep_initializers_as_inputs=self.keep_initializers_as_inputs,
-                      dynamic_axes=dynamic_axes,
-                      input_names=input_names, output_names=output_names)
+    torch.onnx._export(model, input, f,
+                       opset_version=self.opset_version,
+                       example_outputs=output,
+                       do_constant_folding=do_constant_folding,
+                       keep_initializers_as_inputs=self.keep_initializers_as_inputs,
+                       dynamic_axes=dynamic_axes,
+                       input_names=input_names, output_names=output_names,
+                       fixed_batch_size=fixed_batch_size)

     # compute onnxruntime output prediction
     ort_sess = onnxruntime.InferenceSession(f.getvalue())
@@ -83,7 +85,6 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
             output = model(*test_input)
             if isinstance(output, torch.Tensor):
                 output = (output,)
-
             ort_test_with_input(ort_sess, test_input, output, rtol, atol)


@@ -98,14 +99,15 @@ def setUp(self):
         torch.cuda.manual_seed_all(0)
         np.random.seed(seed=0)

-    def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True,
+    def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=False,
                  batch_size=2, use_gpu=True, dynamic_axes=None, test_with_inputs=None,
-                 input_names=None, output_names=None):
-        run_model_test(self, model, batch_size=batch_size,
-                       input=input, use_gpu=use_gpu, rtol=rtol, atol=atol,
-                       do_constant_folding=do_constant_folding,
-                       dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs,
-                       input_names=input_names, output_names=output_names)
+                 input_names=None, output_names=None, fixed_batch_size=False):
+        return run_model_test(self, model, batch_size=batch_size,
+                              input=input, use_gpu=use_gpu, rtol=rtol, atol=atol,
+                              do_constant_folding=do_constant_folding,
+                              dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs,
+                              input_names=input_names, output_names=output_names,
+                              fixed_batch_size=fixed_batch_size)

     def run_word_language_model(self, model_name):
         ntokens = 50
@@ -507,6 +509,63 @@ def forward(self, input):
         x = torch.randn(4, 4, requires_grad=True)
         self.run_test(ReduceLogSumExpModel(), x)

+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_lstm(self):
+        model = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
+        input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
+        h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
+        c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
+        self.run_test(model, (input, (h0, c0)))
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_lstm_default_init_state(self):
+        model = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
+        input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
+        self.run_test(model, input)
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_lstm_fixed_batch_size(self):
+        class LSTMModel(torch.nn.Module):
+            def __init__(self):
+                super(LSTMModel, self).__init__()
+                self.lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
+
+            def forward(self, input):
+                batch_size = input.size()[1]
+                h0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
+                c0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
+                h0 = torch.from_numpy(h0_np)
+                c0 = torch.from_numpy(c0_np)
+                return self.lstm(input, (h0, c0))
+
+        input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
+        # verify with a different input of the same batch size
+        input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
+        self.run_test(LSTMModel(), input, fixed_batch_size=True, test_with_inputs=[input2])
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_lstm_post_fix_init_state(self):
+        class LSTMModel(torch.nn.Module):
+            def __init__(self):
+                super(LSTMModel, self).__init__()
+                self.lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE,
+                                          1, bidirectional=False)
+
+            def forward(self, input):
+                batch_size = input.size()[1]
+                h0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
+                c0_np = np.ones([1, batch_size, RNN_HIDDEN_SIZE]).astype(np.float32)
+                h0 = torch.from_numpy(h0_np)
+                c0 = torch.from_numpy(c0_np)
+                return self.lstm(input, (h0, c0))
+
+        model = LSTMModel()
+        input = torch.randn(RNN_SEQUENCE_LENGTH, 1, RNN_INPUT_SIZE)
+        # verify with a different input of a different batch size
+        input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
+        self.run_test(model, input, dynamic_axes={'input': {0: 'seq', 1: 'batch'}},
+                      test_with_inputs=[input2])
+
     def test_lstm_constant_folding(self):
         class LstmNet(torch.nn.Module):
             def __init__(self, input_size, hidden_size, num_layers, bidirectional):
7 changes: 6 additions & 1 deletion torch/csrc/jit/init.cpp
@@ -108,7 +108,12 @@ void initJITBindings(PyObject* module) {
       .def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops)
       .def("_jit_pass_onnx", ToONNX)
       .def("_jit_pass_lower_all_tuples", LowerAllTuples)
-      .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
+      .def("_jit_pass_onnx_peephole",
+           [](std::shared_ptr<Graph>& graph,
+              int opset_version,
+              bool fixed_batch_size) {
+             return PeepholeOptimizeONNX(graph, opset_version, fixed_batch_size);
+           })
      .def(
          "_jit_pass_onnx_cast_all_constant_to_floating",
          CastAllConstantToFloating)
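Note: with this rebinding, the peephole pass takes three arguments from Python. A minimal hedged sketch of how the exporter side would invoke it — `torch._C` is PyTorch's internal extension module, and the wrapper function here is made up for illustration, not a public API:

```python
# Illustrative only: invoking the re-bound pass with the new flag
# (internal API, subject to change).
import torch

def optimize_onnx_graph(graph, opset_version, fixed_batch_size=False):
    # Mutates the traced torch._C.Graph in place; when fixed_batch_size
    # is True, the RNN initial-state fixup is skipped (see peephole.cpp).
    torch._C._jit_pass_onnx_peephole(graph, opset_version, fixed_batch_size)
    return graph
```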
46 changes: 24 additions & 22 deletions torch/csrc/jit/passes/onnx/peephole.cpp
@@ -360,13 +360,18 @@ void hackFixupPadPackedShapes(Block* graph) {
 void fixDefaultRNNState(Graph* graph, Node* n, int input_index, int opset_version) {
   auto initial_state = n->inputs()[input_index];

-  // The RNN code in pytorch accepts an optional hidden state. When it
-  // is provided, everything works great. When it is not provided, it
-  // is default-initialized by constructing a new Variable, which gets
-  // traced as a Constant. Recognize that pattern here and replace it
-  // with something that doesn't fix the batch size. Note that for
-  // multi-layer RNNs there will be a Slice operation between the
-  // Constant and the RNN.
+  // The RNN code in PyTorch accepts an optional hidden state.
+  // 1- When it is provided as an input, everything works great.
+  // 2- When it is not provided, it is default-initialized by constructing a new Variable,
+  //    which gets traced as a ConstantOfShape with the expected shape.
+  // 3- When the batch size is fixed, everything works great as well.
+  // 4- When h0 and c0 are specified but are not inputs of the model (they are Constants)
+  //    and the batch size is variable, the model must be exported with a batch size of 1
+  //    (otherwise an error occurs), and the values of h0 and c0 are saved with a batch size of 1.
+  //    When the model is then run with a different batch size, h0 and c0 are broadcast
+  //    to the right shape.
+  // Recognize the last pattern (4) here and fix up the shape.
+  // Note that for multi-layer RNNs there will be a Slice operation between the Constant and the RNN.
   bool needsFixing = initial_state->node()->kind() == onnx::Constant ||
       (initial_state->node()->kind() == onnx::Slice &&
        initial_state->node()->inputs()[0]->node()->kind() == onnx::Constant);
@@ -426,18 +431,11 @@ void fixDefaultRNNState(Graph* graph, Node* n, int input_index, int opset_version)
   concated_dims->addInput(unsqueezed_batch_size->outputs()[0]);
   concated_dims->addInput(hidden_size->outputs()[0]);

-  if (opset_version < 9) {
-    Node* constant_fill = graph->create(onnx::ConstantFill, 1);
-    constant_fill->insertBefore(n);
-    constant_fill->i_(attr::input_as_shape, 1);
-    constant_fill->addInput(concated_dims->outputs()[0]);
-    n->replaceInput(input_index, constant_fill->outputs()[0]);
-  } else {
-    Node* constant_of_shape = graph->create(onnx::ConstantOfShape, 1);
-    constant_of_shape->insertBefore(n);
-    constant_of_shape->addInput(concated_dims->outputs()[0]);
-    n->replaceInput(input_index, constant_of_shape->outputs()[0]);
-  }
+  Node* fixed_init_state = graph->create(onnx::Expand, 1);
+  fixed_init_state->insertBefore(n);
+  fixed_init_state->addInput(initial_state);
+  fixed_init_state->addInput(concated_dims->outputs()[0]);
+  n->replaceInput(input_index, fixed_init_state->outputs()[0]);

   if (initial_state->uses().size() == 0) {
     initial_state->node()->destroy();
@@ -639,15 +637,19 @@ void removeMaxPoolUnusedOutput(Block* b) {
 // writing your optimization in jit/passes/peephole.cpp rather than
 // here, as it will be generally applicable to the JIT as well. The
 // optimizations here are ONLY applied on ONNX update
-void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph, int opset_version) {
+void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph, int opset_version, bool fixed_batch_size) {
   // TODO: decide on fixpoint strategy
   // TODO: make it easier not to do O(k) iterations over the graph, where
   // k is the number of distinct peephole optimizations
   hackFixupPadPackedShapes(graph->block());
   pushPackingPastRnn(graph->block());
   removeNopPacking(graph->block());
-  fixDefaultRnnHiddenState(graph->block(), opset_version);
-  fixDefaultLstmCellState(graph->block(), opset_version);
+  // We only need to fix the shape of the hidden state and cell state
+  // if the batch size is variable.
+  if (!fixed_batch_size) {
+    fixDefaultRnnHiddenState(graph->block(), opset_version);
+    fixDefaultLstmCellState(graph->block(), opset_version);
+  }
   fuseBroadcast(graph->block());
   fuseConsecutiveTransposes(graph->block());
   eliminateNopTranspose(graph->block());
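Note: at the graph level, the change above rewrites the constant initial state into `Expand(initial_state, Concat(num_directions, batch_size, hidden_size))`, where the target shape is computed at run time. A hedged sketch of the equivalent ONNX node built with the `onnx.helper` API — tensor names here are made up:

```python
# Sketch of the subgraph the pass now produces (names are illustrative).
# The h0/c0 constant saved with batch_size=1 is broadcast at run time to
# the actual batch size, instead of being replaced wholesale by
# ConstantFill / ConstantOfShape as before.
from onnx import helper

expand = helper.make_node(
    'Expand',
    inputs=['initial_state',   # constant h0/c0, shape (num_directions, 1, hidden_size)
            'state_shape'],    # Concat of (num_directions, batch_size, hidden_size)
    outputs=['initial_state_expanded'],
)
```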
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/onnx/peephole.h
@@ -5,7 +5,7 @@
 namespace torch {
 namespace jit {

-void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph, int opset_version);
+void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph, int opset_version, bool fixed_batch_size);

 }
 } // namespace torch
7 changes: 7 additions & 0 deletions torch/onnx/symbolic_opset9.py
@@ -1348,6 +1348,13 @@ def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):

 def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
                  num_layers, dropout, train, bidirectional, batch_first=None, batch_sizes=None):
+
+    warnings.warn("Exporting a model to ONNX with a batch_size other than 1, " +
+                  "with a variable length with " + variant + " can cause an error " +
+                  "when running the ONNX model with a different batch size. " +
+                  "Make sure to save the model with a batch size of 1, " +
+                  "or define the initial states (h0/c0) as inputs of the model.")
+
     onnxActivations = ['Relu', 'Tanh', 'Sigmoid', 'Affine', 'LeakyRelu', 'ThresholdedRelu',
                        'ScaledTanh', 'HardSigmoid', 'Elu', 'Softsign', 'Softplus']
     variantToOnnxActivationMap = dict(zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations))
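Note: to make the warning actionable, here is a hedged sketch of the two safe export patterns it recommends. Model sizes and tensor names are made up; the export calls follow the same pattern as the tests added in this PR:

```python
import io
import torch

lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=1)
seq_len, batch = 5, 4
x = torch.randn(seq_len, batch, 10)

# Option 1: pass h0/c0 as inputs of the model, so no constant initial
# state gets baked into the exported graph.
h0 = torch.zeros(1, batch, 20)
c0 = torch.zeros(1, batch, 20)
torch.onnx.export(lstm, (x, (h0, c0)), io.BytesIO())

# Option 2: export with a batch size of 1; the Expand node inserted by
# the peephole pass broadcasts the saved state to the runtime batch size.
x1 = torch.randn(seq_len, 1, 10)
torch.onnx.export(lstm, x1, io.BytesIO())
```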