@@ -41,6 +41,9 @@ div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
41
41
InvalidArgument,
42
42
out);
43
43
44
+ ET_KERNEL_CHECK (
45
+ ctx, tensors_have_same_dim_order (a, b, out), InvalidArgument, out);
46
+
44
47
ScalarType a_type = a.scalar_type ();
45
48
ScalarType b_type = b.scalar_type ();
46
49
@@ -97,6 +100,9 @@ Tensor& div_out_mode(
97
100
InvalidArgument,
98
101
out);
99
102
103
+ ET_KERNEL_CHECK (
104
+ ctx, tensors_have_same_dim_order (a, b, out), InvalidArgument, out);
105
+
100
106
ScalarType a_type = a.scalar_type ();
101
107
ScalarType b_type = b.scalar_type ();
102
108
ScalarType common_type = get_compute_type (a_type, b_type);
@@ -159,6 +165,9 @@ Tensor& div_scalar_out(
159
165
ScalarType common_type = isFloatingType (a_type) ? a_type : ScalarType::Float;
160
166
ScalarType out_type = out.scalar_type ();
161
167
168
+ ET_KERNEL_CHECK (
169
+ ctx, tensors_have_same_dim_order (a, out), InvalidArgument, out);
170
+
162
171
ET_KERNEL_CHECK (ctx, common_type == out_type, InvalidArgument, out);
163
172
164
173
ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " div.Scalar_out" , CTYPE_A, [&]() {
0 commit comments