Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,8 @@ def _check_inputs_shape(
elif isinstance(input1, dict):
if input1.keys() != input2.keys():
return False
for (ka, va), vb in zip(input1.items(), input2.values()):
for ka, va in input1.items():
vb = input2[ka]
if type(va) != type(vb):
return False
if isinstance(va, bool) and va != vb:
Expand All @@ -638,9 +639,9 @@ def _check_inputs_shape(

@staticmethod
def _check_tensor_shapes_with_dynamic_shapes(
t1: torch.tensor, t2: torch.tensor, dynamic_shape: dict[int, Any]
input_1: torch.tensor, input_2: torch.tensor, dynamic_shape: dict[int, Any]
) -> bool:
for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape):
for (i, axis_0), axis_1 in zip(enumerate(input_1.shape), input_2.shape):
if axis_0 != axis_1:
if i not in dynamic_shape:
logger.warning(
Expand All @@ -650,7 +651,7 @@ def _check_tensor_shapes_with_dynamic_shapes(
dyn = dynamic_shape[i]
if axis_1 > dyn.max or axis_1 < dyn.min:
raise DynamicShapeOutOfRangeException(
f"The input size ({axis_1}) of dimension ({i}) is not in dynamic shape range [{dyn.max}, {dyn.max}]!"
f"Dimension ({i}) of new input tensor is not the range of supported shapes (saw: ({axis_1}), expected: [{dyn.min}, {dyn.max}])"
)

return True
Expand Down
Loading