Skip to content

Commit

Permalink
Adding test for params
Browse files Browse the repository at this point in the history
  • Loading branch information
neginraoof committed Dec 2, 2020
1 parent 41a3843 commit 71b301a
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions test/onnx/test_utility_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,31 @@ def forward(self, x):

assert len(params_dict) == 2

def test_scripting_param(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
self.bn = torch.nn.BatchNorm2d(16, affine=True)

def forward(self, x):
x = self.conv(x)
bn = self.bn(x)
return bn

model = torch.jit.script(MyModule())
x = torch.randn(10, 3, 128, 128)
example_outputs = model(x)
f = io.BytesIO()
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
graph, _, __ = utils._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs,
operator_export_type=OperatorExportTypes.ONNX)

graph_input_params = [param.debugName() for param in graph.inputs()]
assert all(item in graph_input_params for item in dict(model.named_parameters())), \
"Graph parameter names does not match model parameters."

def test_modifying_params(self):
class MyModel(torch.nn.Module):
def __init__(self):
Expand Down

0 comments on commit 71b301a

Please sign in to comment.