Skip to content

Commit

Permalink
[xla:ffi] Port legacy GPU custom call tests to use the new XLA FFI.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609052850
  • Loading branch information
penpornk authored and tensorflower-gardener committed Feb 21, 2024
1 parent e1f279a commit c66ed43
Showing 1 changed file with 183 additions and 2 deletions.
185 changes: 183 additions & 2 deletions third_party/xla/xla/service/gpu/custom_call_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,187 @@ TEST_F(CustomCallTest, PassUserPointerWithAttrs) {
EXPECT_THAT(status.message(), ::testing::HasSubstr("User-defined message"));
}

bool is_ffi_invoked = false;
static absl::Status IsInvoked(ffi::BufferBase) {
is_ffi_invoked = true;
return absl::OkStatus();
}

XLA_FFI_DEFINE_HANDLER(
kIsInvoked, IsInvoked,
ffi::Ffi::Bind().Arg<ffi::BufferBase>()); // Buffer for result (unused).

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$isinvoked", PLATFORM,
kIsInvoked);

TEST_F(CustomCallTest, ExportedFfiIsInvoked) {
XlaBuilder b(TestName());
CustomCall(&b, "__xla_test$$isinvoked", /*operands=*/{},
ShapeUtil::MakeShape(F32, {}), /*opaque=*/"",
/*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
TF_ASSERT_OK_AND_ASSIGN(auto result, ExecuteAndTransfer(&b, {}));
EXPECT_TRUE(is_ffi_invoked);
}

TEST_F(CustomCallTest, ExportedFfiUnknownTarget) {
XlaBuilder b(TestName());
CustomCall(&b, "__xla_test$$unknown_target", /*operands=*/{},
ShapeUtil::MakeShape(F32, {}), /*opaque=*/"",
/*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
auto status = Execute(&b, {}).status();
EXPECT_EQ(status.code(), absl::StatusCode::kUnimplemented);
EXPECT_THAT(status.message(),
::testing::HasSubstr("No registered implementation"));
}

// Memcpy and SubBuffers tests are already ported in
// fusions/address_computation_fusion_test.cc

// Reusing kExpectedOpaque from the original test.
static absl::Status Opaque(ffi::BufferBase, const std::string* str) {
std::string opaque(*str);
if (opaque != kExpectedOpaque)
return absl::InternalError(absl::StrFormat(
"Opaque string does not match. Expected `%s` but got `%s`",
kExpectedOpaque, opaque));
return absl::OkStatus();
}

XLA_FFI_DEFINE_HANDLER(kOpaque, Opaque,
ffi::Ffi::Bind()
.Arg<ffi::BufferBase>() // Dummy result buffer.
.Attr<ffi::Pointer<std::string>>("opaque"));

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$opaque", PLATFORM,
kOpaque);

TEST_F(CustomCallTest, ExportedFfiOpaque) {
XlaBuilder b(TestName());
const std::string opaque = absl::StrFormat(
"{opaque = %d : i64}", reinterpret_cast<uintptr_t>(&kExpectedOpaque));
CustomCall(&b, "__xla_test$$opaque", /*operands=*/{},
ShapeUtil::MakeShape(F32, {}),
/*opaque=*/opaque,
/*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
TF_ASSERT_OK(Execute(&b, {}).status());
}

static absl::Status TokensChecker(std::vector<ffi::BufferBase> inputs,
const std::string* opaque) {
// TODO(penporn): Actually check the inputs when FFI handlers support tokens.
return absl::OkStatus();
}

static absl::Status Tokens1Input(ffi::BufferBase input1, ffi::BufferBase,
const std::string* opaque) {
return TokensChecker({input1}, opaque);
}

static absl::Status Tokens2Inputs(ffi::BufferBase input1,
ffi::BufferBase input2, ffi::BufferBase,
const std::string* opaque) {
return TokensChecker({input1, input2}, opaque);
}

static absl::Status Tokens3Inputs(ffi::BufferBase input1,
ffi::BufferBase input2,
ffi::BufferBase input3, ffi::BufferBase,
const std::string* opaque) {
return TokensChecker({input1, input2, input3}, opaque);
}

XLA_FFI_DEFINE_HANDLER(kTokens1Input, Tokens1Input,
ffi::Ffi::Bind()
.Arg<ffi::BufferBase>() // 1 input buffer.
.Arg<ffi::BufferBase>() // Output buffer.
.Attr<ffi::Pointer<std::string>>("opaque"));

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens_1input",
PLATFORM, kTokens1Input);

XLA_FFI_DEFINE_HANDLER(kTokens2Inputs, Tokens2Inputs,
ffi::Ffi::Bind()
.Arg<ffi::BufferBase>() // 1st input buffer.
.Arg<ffi::BufferBase>() // 2nd input buffer.
.Arg<ffi::BufferBase>() // Output buffer.
.Attr<ffi::Pointer<std::string>>("opaque"));

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens_2inputs",
PLATFORM, kTokens2Inputs);

XLA_FFI_DEFINE_HANDLER(kTokens3Inputs, Tokens3Inputs,
ffi::Ffi::Bind()
.Arg<ffi::BufferBase>() // 1st input buffer.
.Arg<ffi::BufferBase>() // 2nd input buffer.
.Arg<ffi::BufferBase>() // 3rd input buffer.
.Arg<ffi::BufferBase>() // Output buffer.
.Attr<ffi::Pointer<std::string>>("opaque"));

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens_3inputs",
PLATFORM, kTokens3Inputs);

TEST_P(CustomCallTokensTest, ExportedFfiTokensTest) {
const TokenTestCase& tc = GetParam();
XlaBuilder b(TestName());
std::istringstream input(tc.input);
std::istringstream output(tc.output);
std::vector<XlaOp> call_inputs = BuildInputs(b, input);
std::vector<Shape> call_output = BuildOutputType(output);
ASSERT_GE(call_inputs.size(), 1);
ASSERT_LE(call_inputs.size(), 3);
ASSERT_EQ(call_output.size(), 1);

const std::string custom_call_name =
absl::StrFormat("__xla_test$$tokens_%dinput%s", call_inputs.size(),
call_inputs.size() == 1 ? "" : "s");
const std::string opaque = absl::StrFormat(
"{opaque = %d : i64}", reinterpret_cast<uintptr_t>(&tc.opaque));
CustomCall(&b, custom_call_name, /*operands=*/call_inputs,
call_output.front(),
/*opaque=*/opaque,
/*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);

// TODO(penporn): Expect an OK status when FFI handlers support tokens.
auto status = Execute(&b, {}).status();
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
EXPECT_THAT(status.message(),
::testing::HasSubstr("FFI handlers do not support tokens"));
}

INSTANTIATE_TEST_SUITE_P(CustomCallTokensTest, CustomCallTokensTest,
::testing::ValuesIn(GetTokenTestCases()));

static absl::Status AlwaysSucceed(ffi::BufferBase) { return absl::OkStatus(); }

XLA_FFI_DEFINE_HANDLER(kAlwaysSucceed, AlwaysSucceed,
ffi::Ffi::Bind().Arg<ffi::BufferBase>());

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$always_succeed",
PLATFORM, kAlwaysSucceed);

TEST_F(CustomCallTest, ExportedFfiWithStatusSucceeded) {
XlaBuilder b(TestName());
CustomCall(&b, "__xla_test$$always_succeed", /*operands=*/{},
ShapeUtil::MakeShape(F32, {}), /*opaque=*/"",
/*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
TF_ASSERT_OK(Execute(&b, {}).status());
}

//===----------------------------------------------------------------------===//
// XLA:FFI handler with attached HloComputation
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -463,7 +644,7 @@ XLA_FFI_DEFINE_HANDLER(kMemcpyWithCalledComputation,
.Ctx<ffi::CalledComputation>());

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
"xla.gpu.ext.memcpy_with_called_compuation", PLATFORM,
"xla.gpu.ext.memcpy_with_called_computation", PLATFORM,
kMemcpyWithCalledComputation);

TEST_F(CustomCallTest, WithCalledComputation) {
Expand All @@ -477,7 +658,7 @@ TEST_F(CustomCallTest, WithCalledComputation) {

XlaBuilder b(TestName());
CustomCallWithComputation(
&b, "xla.gpu.ext.memcpy_with_called_compuation",
&b, "xla.gpu.ext.memcpy_with_called_computation",
/*operands=*/{Broadcast(ConstantR0WithType(&b, F32, 42.0), {128})},
copy_computation, shape, /*opaque=*/"",
/*has_side_effect=*/false,
Expand Down

0 comments on commit c66ed43

Please sign in to comment.