Skip to content

Commit

Permalink
Transition gpu/kernels to use absl::Status-returning methods on Stream.
Browse files: browse the repository at this point in the history.
PiperOrigin-RevId: 609172758
  • Loading branch information
klucke authored and tensorflower-gardener committed Feb 22, 2024
1 parent ad16e54 commit 4bc8687
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 24 deletions.
28 changes: 16 additions & 12 deletions third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -89,7 +89,7 @@ TEST_P(TopKKernelTest, TopKFloat) {
se::StreamExecutor* executor = platform->ExecutorForDevice(0).value();

se::Stream stream(executor);
stream.Init();
TF_ASSERT_OK(stream.Initialize());
ASSERT_TRUE(stream.ok());

const auto [n_kb, k, batch_size, offset] = GetParam();
Expand All @@ -103,9 +103,11 @@ TEST_P(TopKKernelTest, TopKFloat) {
executor->AllocateArray<uint32_t>(k * batch_size, 0);

auto source = RandomVec<T>(n * batch_size);
stream.ThenMemcpy(&input_buffer, source.data(), n * batch_size * sizeof(T));
stream.ThenMemZero(&output_values, k * batch_size * sizeof(T));
stream.ThenMemZero(&output_indices, k * batch_size * sizeof(uint32_t));
TF_ASSERT_OK(
stream.Memcpy(&input_buffer, source.data(), n * batch_size * sizeof(T)));
TF_ASSERT_OK(stream.MemZero(&output_values, k * batch_size * sizeof(T)));
TF_ASSERT_OK(
stream.MemZero(&output_indices, k * batch_size * sizeof(uint32_t)));

auto custom_kernel =
GetTopKKernel("topk", PrimitiveType::F32, n, k, batch_size);
Expand All @@ -124,8 +126,8 @@ TEST_P(TopKKernelTest, TopKFloat) {
std::vector<T> got(k);
ASSERT_TRUE(stream.BlockHostUntilDone().ok());
for (int i = 0; i < batch_size; i++) {
stream.ThenMemcpy(got.data(), output_values.GetSlice(k * i, k),
k * sizeof(T));
TF_ASSERT_OK(stream.Memcpy(got.data(), output_values.GetSlice(k * i, k),
k * sizeof(T)));
std::vector<T> slice(source.data() + n * i, source.data() + n * (i + 1));
std::sort(slice.begin(), slice.end(), std::greater<T>());
slice.resize(k);
Expand All @@ -143,7 +145,7 @@ TEST_P(TopKKernelTest, TopKPackedNegative) {
se::StreamExecutor* executor = platform->ExecutorForDevice(0).value();

se::Stream stream(executor);
stream.Init();
TF_ASSERT_OK(stream.Initialize());
ASSERT_TRUE(stream.ok());

const auto [n_kb, k, batch_size, offset] = GetParam();
Expand All @@ -157,9 +159,11 @@ TEST_P(TopKKernelTest, TopKPackedNegative) {
executor->AllocateArray<uint32_t>(k * batch_size, 0);

auto source = RandomVecNegative<T>(n * batch_size);
stream.ThenMemcpy(&input_buffer, source.data(), n * batch_size * sizeof(T));
stream.ThenMemZero(&output_values, k * batch_size * sizeof(T));
stream.ThenMemZero(&output_indices, k * batch_size * sizeof(uint32_t));
TF_ASSERT_OK(
stream.Memcpy(&input_buffer, source.data(), n * batch_size * sizeof(T)));
TF_ASSERT_OK(stream.MemZero(&output_values, k * batch_size * sizeof(T)));
TF_ASSERT_OK(
stream.MemZero(&output_indices, k * batch_size * sizeof(uint32_t)));

auto custom_kernel =
GetTopKKernel("topk", PrimitiveType::F32, n, k, batch_size);
Expand All @@ -178,8 +182,8 @@ TEST_P(TopKKernelTest, TopKPackedNegative) {
std::vector<T> got(k);
ASSERT_TRUE(stream.BlockHostUntilDone().ok());
for (int i = 0; i < batch_size; i++) {
stream.ThenMemcpy(got.data(), output_values.GetSlice(k * i, k),
k * sizeof(T));
TF_ASSERT_OK(stream.Memcpy(got.data(), output_values.GetSlice(k * i, k),
k * sizeof(T)));
std::vector<T> slice(source.data() + n * i, source.data() + n * (i + 1));
std::sort(slice.begin(), slice.end(), std::greater<T>());
slice.resize(k);
Expand Down
24 changes: 12 additions & 12 deletions third_party/xla/xla/service/gpu/kernels/topk_kernel_test.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -93,7 +93,7 @@ TEST_P(TopkTest, TopKFloat) {

auto* executor = GetGpuExecutor();
se::Stream stream(executor);
stream.Init();
CHECK_OK(stream.Initialize());
ASSERT_TRUE(stream.ok());

const auto [n_kb, k, batch_size, offset] = GetParam();
Expand All @@ -107,17 +107,17 @@ TEST_P(TopkTest, TopKFloat) {
output_indices.is_null()));

auto source = RandomVec<T>(n * batch_size);
stream.ThenMemcpy(input_buffer.ptr(), source.data(),
n * batch_size * sizeof(T));
CHECK_OK(stream.Memcpy(input_buffer.ptr(), source.data(),
n * batch_size * sizeof(T)));

ASSERT_TRUE(RunTopk(&stream, Get(T()), *input_buffer, n, *output_values,
*output_indices, k, batch_size)
.ok());
std::vector<T> got(k);
ASSERT_TRUE(stream.BlockHostUntilDone().ok());
for (int i = 0; i < batch_size; i++) {
stream.ThenMemcpy(got.data(), output_values->GetSlice(k * i, k),
k * sizeof(T));
CHECK_OK(stream.Memcpy(got.data(), output_values->GetSlice(k * i, k),
k * sizeof(T)));
std::vector<T> slice(source.data() + n * i, source.data() + n * (i + 1));
std::sort(slice.begin(), slice.end(), std::greater<T>());
slice.resize(k);
Expand All @@ -131,7 +131,7 @@ TEST_P(TopkTest, TopKPackedNegative) {

auto* executor = GetGpuExecutor();
se::Stream stream(executor);
stream.Init();
CHECK_OK(stream.Initialize());
ASSERT_TRUE(stream.ok());

const auto [n_kb, k, batch_size, offset] = GetParam();
Expand All @@ -145,17 +145,17 @@ TEST_P(TopkTest, TopKPackedNegative) {
output_indices.is_null()));

auto source = RandomVecNegative<T>(n * batch_size);
stream.ThenMemcpy(input_buffer.ptr(), source.data(),
n * batch_size * sizeof(T));
CHECK_OK(stream.Memcpy(input_buffer.ptr(), source.data(),
n * batch_size * sizeof(T)));

ASSERT_TRUE(RunTopk(&stream, Get(T()), *input_buffer, n, *output_values,
*output_indices, k, batch_size)
.ok());
std::vector<T> got(k);
ASSERT_TRUE(stream.BlockHostUntilDone().ok());
for (int i = 0; i < batch_size; i++) {
stream.ThenMemcpy(got.data(), output_values->GetSlice(k * i, k),
k * sizeof(T));
CHECK_OK(stream.Memcpy(got.data(), output_values->GetSlice(k * i, k),
k * sizeof(T)));
std::vector<T> slice(source.data() + n * i, source.data() + n * (i + 1));
std::sort(slice.begin(), slice.end(), std::greater<T>());
slice.resize(k);
Expand Down Expand Up @@ -190,7 +190,7 @@ void BM_SmallTopk(benchmark::State& state) {

auto* executor = GetGpuExecutor();
se::Stream stream(executor);
stream.Init();
CHECK_OK(stream.Initialize());
ASSERT_TRUE(stream.ok());

auto input_buffer = executor->AllocateOwnedArray<T>(n * batch_size),
Expand All @@ -208,7 +208,7 @@ void BM_SmallTopk(benchmark::State& state) {
// time to generate random data)
for (size_t i = 0; i < batch_size; i++) {
auto slice = input_buffer->GetSlice(i * n, n);
stream.ThenMemcpy(&slice, source.data(), n * sizeof(T));
CHECK_OK(stream.Memcpy(&slice, source.data(), n * sizeof(T)));
}

for (auto _ : state) {
Expand Down

0 comments on commit 4bc8687

Please sign in to comment.