Skip to content

Commit

Permalink
[xla] functional_hlo_runner_test: add CPU backend
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631557531
  • Loading branch information
cota authored and tensorflower-gardener committed May 7, 2024
1 parent 783c383 commit f216523
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
7 changes: 6 additions & 1 deletion third_party/xla/xla/tools/multihost_hlo_runner/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ xla_test(
"notap",
],
},
backends = ["gpu"],
backends = [
"cpu",
"gpu",
],
data = [
"data/sharded_16_devices.hlo",
"data/sharded_2_devices.hlo",
Expand All @@ -127,6 +130,8 @@ xla_test(
tags = ["nomac"],
deps = [
":functional_hlo_runner",
"//xla:statusor",
"//xla/pjrt:pjrt_client",
"//xla/tests:filecheck",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/lib/core:status_test_util",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ limitations under the License.
#include <vector>

#include <gtest/gtest.h>
#include "xla/pjrt/pjrt_client.h"
#include "xla/statusor.h"
#include "xla/tests/filecheck.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
Expand All @@ -33,6 +35,20 @@ namespace {

using ::testing::SizeIs;

bool IsTestingCpu() {
#ifdef XLA_TEST_BACKEND_CPU
return true;
#endif
return false;
}

absl::StatusOr<std::unique_ptr<xla::PjRtClient>> GetPjRtClient() {
if (IsTestingCpu()) {
return xla::FunctionalHloRunner::CreateHostClient();
}
return xla::FunctionalHloRunner::CreateGpuClient();
}

class FunctionalHloRunnerTest : public ::testing::Test {
protected:
std::string GetHloPath(std::string file_name) {
Expand All @@ -43,7 +59,7 @@ class FunctionalHloRunnerTest : public ::testing::Test {

TEST_F(FunctionalHloRunnerTest, SingleDeviceHlo) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
xla::FunctionalHloRunner::CreateGpuClient());
GetPjRtClient());

// Options corresponding to --num_replicas=1 --num_partitions=1
xla::DebugOptions debug_options;
Expand All @@ -60,7 +76,7 @@ TEST_F(FunctionalHloRunnerTest, SingleDeviceHlo) {

TEST_F(FunctionalHloRunnerTest, Sharded2Devices) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
xla::FunctionalHloRunner::CreateGpuClient());
GetPjRtClient());

constexpr int kRequiredDeviceCount = 2;
const int kDeviceCount = client->device_count();
Expand Down Expand Up @@ -89,7 +105,7 @@ TEST_F(FunctionalHloRunnerTest, Sharded2Devices) {

TEST_F(FunctionalHloRunnerTest, UseZerosAsInputs) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
xla::FunctionalHloRunner::CreateGpuClient());
GetPjRtClient());

constexpr int kRequiredDeviceCount = 2;
const int kDeviceCount = client->device_count();
Expand Down Expand Up @@ -121,7 +137,7 @@ TEST_F(FunctionalHloRunnerTest, UseZerosAsInputs) {

TEST_F(FunctionalHloRunnerTest, UseUninitializedInputs) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
xla::FunctionalHloRunner::CreateGpuClient());
GetPjRtClient());

constexpr int kRequiredDeviceCount = 2;
const int kDeviceCount = client->device_count();
Expand Down Expand Up @@ -153,7 +169,7 @@ TEST_F(FunctionalHloRunnerTest, UseUninitializedInputs) {

TEST_F(FunctionalHloRunnerTest, UseUninitializedInputsWithTupledArguments) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
xla::FunctionalHloRunner::CreateGpuClient());
GetPjRtClient());

// Options corresponding to:
// --num_replicas=1 --num_partitions=1
Expand Down Expand Up @@ -196,7 +212,7 @@ TEST_F(FunctionalHloRunnerTest, CanCompileWithoutHavingEnoughGpus) {
raw_compile_options.xla_dump_to = dump_dir;

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
xla::FunctionalHloRunner::CreateGpuClient());
GetPjRtClient());
TF_EXPECT_OK(FunctionalHloRunner::LoadAndCompile(
*client, debug_options, preproc_options, raw_compile_options,
GetHloPath("sharded_16_devices.hlo"), InputFormat::kText));
Expand All @@ -212,8 +228,8 @@ TEST_F(FunctionalHloRunnerTest, CanCompileWithoutHavingEnoughGpus) {
TF_ASSERT_OK(
tsl::ReadFileToString(env, after_opt_hlo_paths[0], &after_opt_hlo));
absl::StatusOr<bool> file_check_result = RunFileCheck(after_opt_hlo, R"(
// CHECK: param = f32[16,1]{1,0}
// CHECK: add = f32[16,1]{1,0}
// CHECK: param{{.*}} = f32[16,1]{1,0}
// CHECK: add{{.*}} = f32[16,1]{1,0}
)");
TF_ASSERT_OK(file_check_result.status());
EXPECT_TRUE(file_check_result.value());
Expand Down

0 comments on commit f216523

Please sign in to comment.