Skip to content

Commit

Permalink
fix(scripts): diff model nodes as well
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 15, 2023
1 parent 143904f commit 8cf6f22
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions api/scripts/onnx-diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,60 @@
logger = getLogger(__name__)

def diff_models(ref_model: ModelProto, cmp_model: ModelProto):
diffs = 0

if len(ref_model.graph.initializer) != len(cmp_model.graph.initializer):
logger.warning("different number of initializers: %s vs %s", len(ref_model.graph.initializer), len(cmp_model.graph.initializer))
diffs += abs(len(ref_model.graph.initializer) - len(cmp_model.graph.initializer))
else:
for (ref_init, cmp_init) in zip(ref_model.graph.initializer, cmp_model.graph.initializer):
if ref_init.name != cmp_init.name:
logger.info("different node names: %s vs %s", ref_init.name, cmp_init.name)
diffs += 1
elif ref_init.data_location != cmp_init.data_location:
logger.info("different data locations: %s vs %s", ref_init.data_location, cmp_init.data_location)
diffs += 1
elif ref_init.data_type != cmp_init.data_type:
logger.info("different data types: %s vs %s", ref_init.data_type, cmp_init.data_type)
diffs += 1
elif len(ref_init.raw_data) != len(cmp_init.raw_data):
logger.info("different raw data size: %s vs %s", len(ref_init.raw_data), len(cmp_init.raw_data))
diffs += 1
elif len(ref_init.raw_data) > 0 and len(cmp_init.raw_data) > 0:
ref_data = to_array(ref_init)
cmp_data = to_array(cmp_init)
data_diff = ref_data - cmp_data
if data_diff.max() > 0:
logger.info("raw data differs: %s", data_diff)
if data_diff.max() != 0:
logger.info("raw data differs for %s: %s", ref_init.name, data_diff.max())
diffs += 1
else:
logger.info("initializers are identical in all checked fields: %s", ref_init.name)

if len(ref_model.graph.node) != len(cmp_model.graph.node):
logger.warning("different number of nodes: %s vs %s", len(ref_model.graph.node), len(cmp_model.graph.node))
diffs += abs(len(ref_model.graph.node) - len(cmp_model.graph.node))
else:
for (ref_node, cmp_node) in zip(ref_model.graph.node, cmp_model.graph.node):
if ref_node.name != cmp_node.name:
logger.info("different node names: %s vs %s", ref_node.name, cmp_node.name)
diffs += 1
elif ref_node.input != cmp_node.input:
logger.info("different inputs: %s vs %s", ref_node.input, cmp_node.input)
diffs += 1
elif ref_node.output != cmp_node.output:
logger.info("different outputs: %s vs %s", ref_node.output, cmp_node.output)
diffs += 1
elif ref_node.op_type != cmp_node.op_type:
logger.info("different op type: %s vs %s", ref_node.op_type, cmp_node.op_type)
diffs += 1
else:
logger.info("nodes are identical in all checked fields: %s", ref_node.name)

if diffs > 0:
logger.warning("models have %s differences", diffs)
else:
logger.info("models have no detectable differences")


if __name__ == "__main__":
ref_path = argv[1]
Expand Down

0 comments on commit 8cf6f22

Please sign in to comment.