
Add onnx support for InstanceNorm #4626


Merged · 2 commits · Feb 7, 2018

Conversation

houseroad (Member)

Now, we export the InstanceNorm as an InstanceNorm op, not Reshape + BatchNorm + Reshape
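The semantics that the single ONNX `InstanceNormalization` op has to reproduce can be sketched in plain Python. This is a toy reference on nested lists, not PyTorch's or ONNX's actual implementation: for each sample `n` and channel `c`, the feature map is normalized by its own mean and variance, then scaled and shifted per channel.

```python
import math

def instance_norm_ref(x, scale, bias, eps=1e-5):
    """Toy reference: x is a nested list shaped [N][C][H][W].

    Each (n, c) feature map is normalized by its own mean/variance,
    then scaled and shifted per channel -- the behaviour the single
    ONNX InstanceNormalization op encodes directly, instead of the
    old Reshape + BatchNorm + Reshape decomposition.
    """
    out = []
    for sample in x:
        out_sample = []
        for c, fmap in enumerate(sample):
            flat = [v for row in fmap for v in row]
            mean = sum(flat) / len(flat)
            var = sum((v - mean) ** 2 for v in flat) / len(flat)
            denom = math.sqrt(var + eps)
            out_sample.append([[scale[c] * (v - mean) / denom + bias[c]
                                for v in row] for row in fmap])
        out.append(out_sample)
    return out
```

Exporting one dedicated op instead of the three-op decomposition keeps the graph smaller and lets ONNX backends dispatch to a fused kernel.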

@houseroad (Member, Author)

Addresses #4584; it works end to end so far, but some unused tensors get exported too. Will fix that soon.

@houseroad (Member, Author)

@ezyang, the unused stuff got removed, ready for review. :-)

@onnxbot retest this please

@houseroad (Member, Author)

Please ignore the failed onnx-fb-universe CI; I already updated the expected files in onnxbot/onnx-fb-universe#268.

```python
                          eps=eps, affine=affine)


def instance_norm(input, weight, bias, saved_running_mean, saved_running_var,
```


```cpp
// The jit tracer may record some unnecessary information.
// We should remove the unused inputs and initializers.
// Gather the names of all useful nodes and filter out the unused ones.
std::unordered_set<std::string> useful_names;
```


```cpp
for (auto input : g->inputs()) {
  input_names.push_back(value_name(input));
  if (useful_names.find(value_name(input)) == useful_names.end()) {
    continue;
```
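The C++ pass above collects the names of inputs actually consumed by some node and skips the rest. The same filtering idea can be sketched in a few lines of Python; the function and parameter names here are illustrative, not PyTorch's API:

```python
def prune_unused_inputs(graph_inputs, initializers, useful_names):
    """Keep only graph inputs (and their paired initializers) whose
    names are actually consumed by at least one node.

    graph_inputs : list of input names, in graph order
    initializers : dict mapping input name -> tensor payload
    useful_names : set of names referenced by some node
    """
    kept_inputs = [n for n in graph_inputs if n in useful_names]
    kept_inits = {n: t for n, t in initializers.items()
                  if n in useful_names}
    return kept_inputs, kept_inits
```

This is what stops the tracer's leftover tensors (e.g. unused running stats) from being serialized into the exported ONNX graph.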


@ezyang (Contributor)

ezyang commented Jan 16, 2018

CC @linkerzhang

@dzhulgakov (Collaborator) left a comment


Cool, when it lands I can also upstream some changes from https://github.com/dzhulgakov/fast-neural-style to pytorch/examples (onnx exporter)


```diff
@@ -619,3 +619,15 @@ def symbolic(g, input, all_weights, h0, **fkwargs):
         return prev_output, h_outs

     return torch.onnx.symbolic_override(symbolic)


+def instance_norm_symbolic_builder(*args, **kwargs):
```


```python
    """
    import torch
    func = instance_norm
    if torch._C._jit_is_tracing(input):
```


```diff
-    flat_args = tuple(function._iter_variables(args))
-    if not any(map(torch._C._jit_is_tracing, flat_args)):
+    flat_args = tuple(function._iter_None_variables(args))
+    if not any(map(lambda x: False if x is None else torch._C._jit_is_tracing(x), flat_args)):
```
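The changed guard has to tolerate `None` placeholders among the flattened arguments (for example, an absent affine weight), skipping them instead of passing them to the tracing check. A minimal sketch of that guard, with a stub standing in for `torch._C._jit_is_tracing`:

```python
def any_tracing(flat_args, is_tracing):
    """Return True if any non-None argument is currently being traced.

    Optional inputs appear as None in the flattened argument tuple and
    must be skipped rather than handed to the tracing check.
    `is_tracing` is a stand-in for torch._C._jit_is_tracing.
    """
    return any(is_tracing(a) for a in flat_args if a is not None)
```

A generator with an `if a is not None` filter reads more directly than the `map(lambda x: False if x is None else ..., ...)` form in the diff, but computes the same answer.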


@ezyang (Contributor)

ezyang commented Jan 19, 2018

I'm on the hook for finishing this.

@dzhulgakov (Collaborator) left a comment


For mapping the tensor inputs, I wish we could find a better solution.


@houseroad force-pushed the instancenorm branch 3 times, most recently from 6f2776d to 60abfe9, on February 6, 2018.
```diff
@@ -215,7 +236,10 @@ void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr<Graph> & g, const
   for (auto & tensor : initializers) {
     // TODO: stop using positions to determine which initializers
     // match to which inputs
-    std::string name = p_g->get_input_name(inputs_count++);
+    std::string name = input_names[inputs_count++];
+    if (useful_names.find(name) == useful_names.end()) {
```


```diff
@@ -181,7 +181,28 @@ void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr<Graph> & g, const
   JIT_ASSERT(p_g != nullptr);
   p_g->set_name("torch-jit-export");

+  // The jit tracer may record some unnecessary information.
+  // We should remove the unused input and initializer.
+  // Gather all the names of useful nodes and filter out unused ones.
```


@ezyang ezyang merged commit c111cdf into pytorch:master Feb 7, 2018
@CynthiaLu1119

After reading this post, I tried installing pytorch from source, and ONNX export now works for InstanceNorm2d. Thanks.

However, I encountered another problem. I trained a pytorch model using the version installed from http://pytorch.org/. I was able to save the trained model as a .pth file and load it back successfully with:

```python
model = myNet(parameters...)
model.load_state_dict(torch.load('xxx.pth'))
```

However, after installing the latest pytorch from source, load_state_dict reports an error:

```
File "/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 528, in load_state_dict
    .format(name))
KeyError: 'unexpected key "pre_model.1.running_mean" in state_dict'
```

I did the following to check which layer was causing the problem:

```python
model1 = myNet(parameters...)
model2 = torch.load('xxx.pth')

for name, param in model2.items():
    print(name)

for name, param in model1.state_dict().items():
    print(name)
```

I found that for the InstanceNorm2d layers, model2 contains xxx.running_mean and xxx.running_var, but model1 does not. Any idea what I should do?
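One common workaround for this kind of state_dict mismatch (a sketch, not an official fix) is to drop the keys the freshly constructed model no longer declares before loading. The helper below is illustrative, not part of PyTorch:

```python
def filter_state_dict(saved_state, model_keys):
    """Drop entries (e.g. stale running_mean / running_var buffers)
    that the freshly constructed model does not declare.

    saved_state : dict of name -> tensor, as returned by torch.load(...)
    model_keys  : iterable of keys from model.state_dict().keys()
    """
    wanted = set(model_keys)
    kept = {k: v for k, v in saved_state.items() if k in wanted}
    dropped = sorted(set(saved_state) - wanted)
    return kept, dropped
```

After filtering, `model.load_state_dict(kept)` no longer hits the unexpected-key error; later PyTorch releases also accept `load_state_dict(..., strict=False)`, which ignores mismatched keys for you.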

@CynthiaLu1119

Another question:
If I use the InstanceNorm2d layer with affine=True, i.e. nn.InstanceNorm2d(ngf, affine=True), then torch.onnx.export fails as follows:
```
torch.onnx.export(model, (dummy_input, dummy_input), MODEL_DIR + "eye_G.proto", verbose=True)

File "/anaconda2/lib/python2.7/site-packages/torch/onnx/__init__.py", line 83, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names)
File "/anaconda2/lib/python2.7/site-packages/torch/onnx/__init__.py", line 138, in _export
    _optimize_trace(trace, aten)
File "/anaconda2/lib/python2.7/site-packages/torch/onnx/__init__.py", line 94, in _optimize_trace
    torch._C._jit_pass_onnx(trace, aten)
File "/anaconda2/lib/python2.7/site-packages/torch/onnx/__init__.py", line 180, in _run_symbolic_method
    return symbolic_fn(*args)
File "/anaconda2/lib/python2.7/site-packages/torch/onnx/__init__.py", line 396, in symbolic
    symbolic_output = symbolic_fn(g, *symbolic_args, **kwargs)
File "/anaconda2/lib/python2.7/site-packages/torch/onnx/symbolic.py", line 552, in instance_norm
    if not weight:
RuntimeError: bool value of Tensor with more than one value is ambiguous
```

If I use nn.InstanceNorm2d(ngf, affine=False) or nn.InstanceNorm2d(ngf), the export succeeds. Please help. Thanks.
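The traceback points at `if not weight:` in symbolic.py: truth-testing a tensor with more than one element is ambiguous, which is exactly what happens when affine=True supplies a real weight tensor. The usual fix is to test for absence with `weight is None` instead. A minimal sketch, using a stand-in class (not torch's) that raises the same way a multi-element tensor does:

```python
class FakeTensor:
    """Stand-in that mimics torch.Tensor truthiness: truth-testing is
    ambiguous for more than one element and raises RuntimeError."""
    def __init__(self, values):
        self.values = list(values)
    def __bool__(self):
        if len(self.values) != 1:
            raise RuntimeError(
                "bool value of Tensor with more than one value is ambiguous")
        return bool(self.values[0])
    __nonzero__ = __bool__  # Python 2 spelling, matching the traceback above

def has_affine(weight):
    # The robust check: test for absence by identity, never by truthiness.
    return weight is not None
```

With `has_affine`, a multi-element affine weight is handled without ever invoking the ambiguous bool conversion, which is why `affine=False` (weight is None) happened to work while `affine=True` crashed.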
