Skip to content

Commit

Permalink
Merge branch 'neraoof/scripting_params' of github.com:neginraoof/pyto…
Browse files Browse the repository at this point in the history
…rch into neraoof/scripting_params
  • Loading branch information
neginraoof committed Dec 3, 2020
2 parents 79f4a4c + 71b301a commit 90309a6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
25 changes: 25 additions & 0 deletions test/onnx/test_utility_funs.py
Expand Up @@ -678,6 +678,31 @@ def forward(self, x):

assert len(params_dict) == 2

def test_scripting_param(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
self.bn = torch.nn.BatchNorm2d(16, affine=True)

def forward(self, x):
x = self.conv(x)
bn = self.bn(x)
return bn

model = torch.jit.script(MyModule())
x = torch.randn(10, 3, 128, 128)
example_outputs = model(x)
f = io.BytesIO()
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
graph, _, __ = utils._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs,
operator_export_type=OperatorExportTypes.ONNX)

graph_input_params = [param.debugName() for param in graph.inputs()]
assert all(item in graph_input_params for item in dict(model.named_parameters())), \
"Graph parameter names does not match model parameters."

def test_modifying_params(self):
class MyModel(torch.nn.Module):
def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/onnx/list_model_parameters.cpp
Expand Up @@ -102,7 +102,7 @@ std::vector<IValue> getParamAttributes(

auto attr = attrModule.attr(name);

std::string fullName("self_");
std::string fullName("");
for (auto& name : moduleNames) {
fullName += name + '.';
}
Expand Down

0 comments on commit 90309a6

Please sign in to comment.