Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion kernels/optimized/cpu/binary_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ ElementwiseOptimizedPath inline select_optimized_path(
ScalarType b_type = b.scalar_type();
ScalarType out_type = out.scalar_type();

if (a_type != b_type || a_type != out_type || a_type == ScalarType::Half) {
if (a_type != b_type || a_type != out_type || a_type == ScalarType::Half ||
a_type == ScalarType::BFloat16) {
return ElementwiseOptimizedPath::kNone;
}
if (a.sizes().equals(b.sizes()) ||
Expand Down
17 changes: 9 additions & 8 deletions kernels/optimized/cpu/op_mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ Tensor& opt_mul_out(
ScalarType out_type = out.scalar_type();

if (b.numel() == 1) {
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
a_type != ScalarType::BFloat16) {
auto error = resize_tensor(out, a.sizes());
ET_KERNEL_CHECK_MSG(
ctx,
Expand Down Expand Up @@ -170,12 +171,12 @@ Tensor& opt_mul_out(
InvalidArgument,
out);

ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
ET_SWITCH_REALHB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
Expand Down Expand Up @@ -210,7 +211,7 @@ Tensor& opt_mul_scalar_out(

ET_CHECK(common_type == out_type);

if (common_type == ScalarType::Half) {
if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
common_type = ScalarType::Float;
}

Expand All @@ -219,7 +220,7 @@ Tensor& opt_mul_scalar_out(
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

if (a_type == common_type && a_type == out_type &&
a_type != ScalarType::Half) {
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE, [&]() {
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
CTYPE_B b_val;
Expand All @@ -235,11 +236,11 @@ Tensor& opt_mul_scalar_out(
});
});
} else {
ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
ET_SWITCH_REALB_TYPES(
common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
ET_SWITCH_REALHB_TYPES(
ET_SWITCH_REALHBBF16_TYPES(
out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
CTYPE_B b_val;
ET_EXTRACT_SCALAR(b, b_val);
Expand Down
18 changes: 11 additions & 7 deletions kernels/portable/cpu/op_mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
InvalidArgument,
out);

ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
ET_KERNEL_CHECK(
ctx,
executorch::runtime::tensor_is_realhbbf16_type(out),
InvalidArgument,
out);

ScalarType a_type = a.scalar_type();
ScalarType b_type = b.scalar_type();
Expand All @@ -79,12 +83,12 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {

ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
ET_SWITCH_REALHB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
MulInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
Expand Down Expand Up @@ -123,15 +127,15 @@ Tensor& mul_scalar_out(

ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);

if (common_type == ScalarType::Half) {
if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
common_type = ScalarType::Float;
}

ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
ET_SWITCH_REALB_TYPES(
common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
ET_SWITCH_REALHB_TYPES(
ET_SWITCH_REALHBBF16_TYPES(
out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
CTYPE_B b_val;
utils::extract_scalar(b, &b_val);
Expand Down
9 changes: 5 additions & 4 deletions kernels/portable/cpu/op_to_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ Tensor& to_copy_out(
InvalidArgument,
out);

ET_SWITCH_REALHB_TYPES(self.scalar_type(), ctx, "to_copy", CTYPE_IN, [&] {
ET_SWITCH_REALHB_TYPES(out.scalar_type(), ctx, "to_copy", CTYPE_OUT, [&] {
_to_impl<CTYPE_IN, CTYPE_OUT>(self, out);
});
ET_SWITCH_REALHBBF16_TYPES(self.scalar_type(), ctx, "to_copy", CTYPE_IN, [&] {
ET_SWITCH_REALHBBF16_TYPES(
out.scalar_type(), ctx, "to_copy", CTYPE_OUT, [&] {
_to_impl<CTYPE_IN, CTYPE_OUT>(self, out);
});
});

return out;
Expand Down
18 changes: 8 additions & 10 deletions kernels/portable/cpu/scalar_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,6 @@ struct promote_type_with_scalar_type {
static_assert(
!is_bits_type<T1>::value,
"promote_type_with_scalar_type not valid for bits dtypes");
static_assert(
!std::is_same<
T1,
typename ScalarTypeToCppType<exec_aten::ScalarType::BFloat16>::type>::
value,
"promote_type_with_scalar_type not valid for BFloat16");
using promote_type_with_scalar_type_not_respecting_half_to_float =
typename std::conditional<
is_complex_type<T1>::value ||
Expand All @@ -119,10 +113,14 @@ struct promote_type_with_scalar_type {
public:
using type = typename std::conditional<
half_to_float &&
std::is_same<
promote_type_with_scalar_type_not_respecting_half_to_float,
typename ScalarTypeToCppType<exec_aten::ScalarType::Half>::type>::
value,
(std::is_same<
promote_type_with_scalar_type_not_respecting_half_to_float,
typename ScalarTypeToCppType<
exec_aten::ScalarType::Half>::type>::value ||
std::is_same<
promote_type_with_scalar_type_not_respecting_half_to_float,
typename ScalarTypeToCppType<
exec_aten::ScalarType::BFloat16>::type>::value),
typename ScalarTypeToCppType<exec_aten::ScalarType::Float>::type,
promote_type_with_scalar_type_not_respecting_half_to_float>::type;
};
Expand Down
158 changes: 98 additions & 60 deletions kernels/test/op_mul_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class OpMulOutTest : public OperatorTest {
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
test_mul_enumerate_out_types<DTYPE_A, ScalarType::dtype>();

ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)

#undef ENUMERATE_TEST_ENTRY
}
Expand All @@ -89,29 +89,99 @@ class OpMulOutTest : public OperatorTest {

// Multiply two tensors
op_mul_out(
tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}), tf.ones(sizes), out);
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}));
tf.make(sizes, /*data=*/{1.25, 2.5, 4.75, 8.875}), tf.ones(sizes), out);
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{1.25, 2.5, 4.75, 8.875}));

op_mul_out(
tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}), tf.zeros(sizes), out);
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{0.0, 0.0, 0.0, 0.0}));

op_mul_out(
tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}),
tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}),
tf.make(sizes, /*data=*/{1.25, 2.5, 4.75, 8.875}),
tf.make(sizes, /*data=*/{1.25, 2.5, 4.75, 8.875}),
out);
EXPECT_TENSOR_CLOSE(
out, tf.make(sizes, /*data=*/{1.21, 4.84, 19.36, 77.44}));
out, tf.make(sizes, /*data=*/{1.5625, 6.25, 22.5625, 78.765625}));
}

void test_mul_enumerate_a_types() {
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
test_mul_enumerate_b_types<ScalarType::dtype>();

ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)

#undef ENUMERATE_TEST_ENTRY
}

template <ScalarType DTYPE>
void test_optimized_path_ignores_leading_1_dimensions() {
TensorFactory<DTYPE> tf;

const std::vector<int32_t> sizes1 = {1, 1, 2, 2};
const std::vector<int32_t> sizes2 = {1, 2, 2};

// Destination for the mul.
Tensor out = tf.zeros(sizes1);

// Multiply two tensors
op_mul_out(
tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8}), tf.ones(sizes2), out);
EXPECT_TENSOR_CLOSE(out, tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8}));
}

template <ScalarType DTYPE>
void test_broadcast_a2b() {
TensorFactory<DTYPE> tf_a;

std::vector<std::vector<int32_t>> b_sizeses = {
{2},
{1, 2},
};
for (const auto& b_sizes : b_sizeses) {
// a and b of different shapes
Tensor a = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4});
Tensor b = tf_a.make(b_sizes, /*data=*/{2, 2});

// Destination for output of mul.
Tensor out = tf_a.zeros({2, 2});

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(
op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8}));
}
}

template <ScalarType DTYPE>
void test_broadcast_b2a() {
TensorFactory<DTYPE> tf_a;
// a and b of different shapes
Tensor a = tf_a.make({2}, /*data=*/{2, 2});
Tensor b = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4});

// Destination for output of mul.
Tensor out = tf_a.zeros({2, 2});

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(
op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8}));
}

template <ScalarType DTYPE>
void test_scalar_input_broadcast() {
TensorFactory<DTYPE> tf_a;

// a is a 1d tensor and b is a scalar
Tensor a = tf_a.make({2}, /*data=*/{2, 2});
Tensor b = tf_a.make({}, /*data=*/{2});

// Destination for output of mul.
Tensor out = tf_a.make({2}, /*data=*/{2, 2});
Tensor expected = tf_a.make({2}, /*data=*/{4, 4});

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
}
};

class OpMulScalarOutTest : public OperatorTest {
Expand Down Expand Up @@ -141,6 +211,14 @@ TEST_F(OpMulOutTest, DoubleTensors) {
test_floating_point_mul_out<ScalarType::Double>();
}

TEST_F(OpMulOutTest, HalfTensors) {
test_floating_point_mul_out<ScalarType::Half>();
}

TEST_F(OpMulOutTest, BFloat16Tensors) {
test_floating_point_mul_out<ScalarType::BFloat16>();
}

TEST_F(OpMulOutTest, BoolTensors) {
TensorFactory<ScalarType::Bool> tf;

Expand All @@ -166,18 +244,12 @@ TEST_F(OpMulOutTest, BoolTensors) {
}

TEST_F(OpMulOutTest, OptimizedPathIgnoresLeading1Dimensions) {
TensorFactory<ScalarType::Float> tf;
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
test_optimized_path_ignores_leading_1_dimensions<ScalarType::dtype>();

const std::vector<int32_t> sizes1 = {1, 1, 2, 2};
const std::vector<int32_t> sizes2 = {1, 2, 2};
ET_FORALL_FLOATHBF16_TYPES(ENUMERATE_TEST_ENTRY);

// Destination for the mul.
Tensor out = tf.zeros(sizes1);

// Multiply two tensors
op_mul_out(
tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8}), tf.ones(sizes2), out);
EXPECT_TENSOR_CLOSE(out, tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8}));
#undef ENUMERATE_TEST_ENTRY
}

// Mismatched shape tests.
Expand All @@ -202,40 +274,16 @@ TEST_F(OpMulOutTest, MismatchedNonBroadcastableInputShapesDies) {

// Broadcast tensor b's size to tensor a's size
TEST_F(OpMulOutTest, BroadcastA2BTest) {
TensorFactory<ScalarType::Int> tf_a;

std::vector<std::vector<int32_t>> b_sizeses = {
{2},
{1, 2},
};
for (const auto& b_sizes : b_sizeses) {
// a and b of different shapes
Tensor a = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4});
Tensor b = tf_a.make(b_sizes, /*data=*/{2, 2});

// Destination for output of mul.
Tensor out = tf_a.zeros({2, 2});

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(
op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8}));
}
test_broadcast_a2b<ScalarType::Int>();
test_broadcast_a2b<ScalarType::Half>();
test_broadcast_a2b<ScalarType::BFloat16>();
}

// Broadcast tensor a's size to tensor b's size
TEST_F(OpMulOutTest, BroadcastB2ATest) {
TensorFactory<ScalarType::Int> tf_a;

// a and b of different shapes
Tensor a = tf_a.make({2}, /*data=*/{2, 2});
Tensor b = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4});

// Destination for output of mul.
Tensor out = tf_a.zeros({2, 2});

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(
op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8}));
test_broadcast_b2a<ScalarType::Int>();
test_broadcast_b2a<ScalarType::Half>();
test_broadcast_b2a<ScalarType::BFloat16>();
}

// Broadcast tensor a and b's size to a new size c.
Expand All @@ -256,19 +304,9 @@ TEST_F(OpMulOutTest, BroadcastAB2CTest) {
}

TEST_F(OpMulOutTest, ScalarInputBroadcastTest) {
TensorFactory<ScalarType::Int> tf_a;

// a is a 1d tensor and b is a scalar
Tensor a = tf_a.make({2}, /*data=*/{2, 2});
Tensor b = tf_a.make({}, /*data=*/{2});

// Destination for output of mul.
Tensor out = tf_a.make({2}, /*data=*/{2, 2});
Tensor expected = tf_a.make({2}, /*data=*/{4, 4});

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
test_scalar_input_broadcast<ScalarType::Int>();
test_scalar_input_broadcast<ScalarType::Half>();
test_scalar_input_broadcast<ScalarType::BFloat16>();
}

TEST_F(OpMulOutTest, MismatchedOutputShapesDies) {
Expand Down
Loading
Loading