diff --git a/kernels/test/op_cdist_forward_test.cpp b/kernels/test/op_cdist_forward_test.cpp index 2436c448f82..c674f1b536f 100644 --- a/kernels/test/op_cdist_forward_test.cpp +++ b/kernels/test/op_cdist_forward_test.cpp @@ -8,6 +8,7 @@ #include // Declares the operator #include +#include #include #include #include @@ -45,6 +46,12 @@ class OpCdistForwardOutTest : public ::testing::Test { void test_dtype() { TensorFactory tf; + if ((DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) && + torch::executor::testing::SupportedFeatures::get()->is_aten) { + // ATen doesn't support Half/BFloat for this op. + return; + } + Tensor x1 = tf.make({2, 1, 4, 3}, {0, 1, 2, 3, 5, 4, 3, -3, 7, 1, 6, 2, -1, 5, 1, 1, -2, 1, 5, 4, 3, 2, -1, 5}); Tensor x2 = tf.make(