Skip to content

Commit

Permalink
corrected 1D assert in Select for OOB
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Nov 11, 2021
1 parent 2ad2131 commit e24d657
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
8 changes: 4 additions & 4 deletions model-optimizer/extensions/ops/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from mo.front.common.partial_infer.utils import compatible_shapes, dynamic_dimension, shape_array, is_fully_defined
from mo.front.common.partial_infer.utils import compatible_shapes, dynamic_dimension, shape_array, is_fully_defined, compatible_dims
from mo.graph.graph import Node, Graph, Error
from mo.ops.op import Op
from mo.utils.broadcasting import bi_directional_shape_broadcasting, bi_directional_broadcasting
Expand Down Expand Up @@ -55,12 +55,12 @@ def infer(node: Node):
# but by adding ones to the end we can achieve numpy compatibility, as in transformation SelectBroadcast.py
if node.has_valid('format') and node['format'] == 'tf' and len(condition_shape) == 1:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_ops.py#L4596-L4598
msg_tf = "In Select node '{}' if 'condition' is a 1D tensor then it's size " \
"must be matching with the first dimension of then/else branches. " \
msg_tf = "In Select node '{}' if 'condition' is a 1D tensor then sizes " \
"must be compatible with the first dimension of then/else branches. " \
"But instead got: cond_shape={}, then_shape={}, else_shape={}".format(
node_name, condition_shape, a_shape, b_shape)

assert condition_shape[0] == output_shape[0], msg_tf
assert compatible_dims(condition_shape[0], output_shape[0]), msg_tf
condition_shape = np.concatenate((condition_shape, np.ones(len(output_shape) - 1)))

output_shape = bi_directional_shape_broadcasting(output_shape, condition_shape)
Expand Down
12 changes: 6 additions & 6 deletions model-optimizer/unit_tests/extensions/ops/select_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unittest

import numpy as np

import re
from extensions.ops.select import Select
from mo.front.common.partial_infer.utils import dynamic_dimension, shape_array, dynamic_dimension_value
from mo.front.common.partial_infer.utils import strict_compare_tensors, int64_array
Expand Down Expand Up @@ -273,12 +273,12 @@ def test_select_infer_tf_condition(self):
self.assertTrue(flag, msg)

def test_select_infer_tf_condition_assert_raises(self):
with self.assertRaisesRegex(AssertionError, "if 'condition' is a 1D tensor then it's size"):
with self.assertRaisesRegex(AssertionError, r"In Select node .*if 'condition' is a 1D tensor then"):
self.build_select_graph_and_infer(condition_value=None, condition_shape=shape_array([42]),
then_value=None, then_shape=shape_array([100, 20]),
else_value=None, else_shape=shape_array([100, 20]),
out_value=None, out_shape=shape_array([100, 20]),
auto_broadcast='numpy', fw_format='tf')
then_value=None, then_shape=shape_array([100, 20]),
else_value=None, else_shape=shape_array([100, 20]),
out_value=None, out_shape=shape_array([100, 20]),
auto_broadcast='numpy', fw_format='tf')

def test_select_infer_assert_pdpd(self):
with self.assertRaisesRegex(Error, "PDPD broadcasting rule is not implemented yet"):
Expand Down

0 comments on commit e24d657

Please sign in to comment.