diff --git a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc index 63a40ecd4275ef..8236632dee50bf 100644 --- a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc @@ -89,7 +89,7 @@ TEST_P(TopKKernelTest, TopKFloat) { se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); se::Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); const auto [n_kb, k, batch_size, offset] = GetParam(); @@ -103,9 +103,11 @@ TEST_P(TopKKernelTest, TopKFloat) { executor->AllocateArray(k * batch_size, 0); auto source = RandomVec(n * batch_size); - stream.ThenMemcpy(&input_buffer, source.data(), n * batch_size * sizeof(T)); - stream.ThenMemZero(&output_values, k * batch_size * sizeof(T)); - stream.ThenMemZero(&output_indices, k * batch_size * sizeof(uint32_t)); + TF_ASSERT_OK( + stream.Memcpy(&input_buffer, source.data(), n * batch_size * sizeof(T))); + TF_ASSERT_OK(stream.MemZero(&output_values, k * batch_size * sizeof(T))); + TF_ASSERT_OK( + stream.MemZero(&output_indices, k * batch_size * sizeof(uint32_t))); auto custom_kernel = GetTopKKernel("topk", PrimitiveType::F32, n, k, batch_size); @@ -124,8 +126,8 @@ TEST_P(TopKKernelTest, TopKFloat) { std::vector got(k); ASSERT_TRUE(stream.BlockHostUntilDone().ok()); for (int i = 0; i < batch_size; i++) { - stream.ThenMemcpy(got.data(), output_values.GetSlice(k * i, k), - k * sizeof(T)); + TF_ASSERT_OK(stream.Memcpy(got.data(), output_values.GetSlice(k * i, k), + k * sizeof(T))); std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); std::sort(slice.begin(), slice.end(), std::greater()); slice.resize(k); @@ -143,7 +145,7 @@ TEST_P(TopKKernelTest, TopKPackedNegative) { se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); se::Stream stream(executor); - stream.Init(); + TF_ASSERT_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); const auto [n_kb, k, batch_size, offset] = GetParam(); @@ -157,9 +159,11 @@ TEST_P(TopKKernelTest, TopKPackedNegative) { executor->AllocateArray(k * batch_size, 0); auto source = RandomVecNegative(n * batch_size); - stream.ThenMemcpy(&input_buffer, source.data(), n * batch_size * sizeof(T)); - stream.ThenMemZero(&output_values, k * batch_size * sizeof(T)); - stream.ThenMemZero(&output_indices, k * batch_size * sizeof(uint32_t)); + TF_ASSERT_OK( + stream.Memcpy(&input_buffer, source.data(), n * batch_size * sizeof(T))); + TF_ASSERT_OK(stream.MemZero(&output_values, k * batch_size * sizeof(T))); + TF_ASSERT_OK( + stream.MemZero(&output_indices, k * batch_size * sizeof(uint32_t))); auto custom_kernel = GetTopKKernel("topk", PrimitiveType::F32, n, k, batch_size); @@ -178,8 +182,8 @@ TEST_P(TopKKernelTest, TopKPackedNegative) { std::vector got(k); ASSERT_TRUE(stream.BlockHostUntilDone().ok()); for (int i = 0; i < batch_size; i++) { - stream.ThenMemcpy(got.data(), output_values.GetSlice(k * i, k), - k * sizeof(T)); + TF_ASSERT_OK(stream.Memcpy(got.data(), output_values.GetSlice(k * i, k), + k * sizeof(T))); std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); std::sort(slice.begin(), slice.end(), std::greater()); slice.resize(k); diff --git a/third_party/xla/xla/service/gpu/kernels/topk_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/topk_kernel_test.cc index 84d342dc99b5a6..0fe83b55b2599e 100644 --- a/third_party/xla/xla/service/gpu/kernels/topk_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/topk_kernel_test.cc @@ -93,7 +93,7 @@ TEST_P(TopkTest, TopKFloat) { auto* executor = GetGpuExecutor(); se::Stream stream(executor); - stream.Init(); + CHECK_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); const auto [n_kb, k, batch_size, offset] = GetParam(); @@ -107,8 +107,8 @@ TEST_P(TopkTest, TopKFloat) { output_indices.is_null())); auto source = RandomVec(n * batch_size); - stream.ThenMemcpy(input_buffer.ptr(), source.data(), - n * batch_size * sizeof(T)); + CHECK_OK(stream.Memcpy(input_buffer.ptr(), source.data(), + n * batch_size * sizeof(T))); ASSERT_TRUE(RunTopk(&stream, Get(T()), *input_buffer, n, *output_values, *output_indices, k, batch_size) @@ -116,8 +116,8 @@ TEST_P(TopkTest, TopKFloat) { std::vector got(k); ASSERT_TRUE(stream.BlockHostUntilDone().ok()); for (int i = 0; i < batch_size; i++) { - stream.ThenMemcpy(got.data(), output_values->GetSlice(k * i, k), - k * sizeof(T)); + CHECK_OK(stream.Memcpy(got.data(), output_values->GetSlice(k * i, k), + k * sizeof(T))); std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); std::sort(slice.begin(), slice.end(), std::greater()); slice.resize(k); @@ -131,7 +131,7 @@ TEST_P(TopkTest, TopKPackedNegative) { auto* executor = GetGpuExecutor(); se::Stream stream(executor); - stream.Init(); + CHECK_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); const auto [n_kb, k, batch_size, offset] = GetParam(); @@ -145,8 +145,8 @@ TEST_P(TopkTest, TopKPackedNegative) { output_indices.is_null())); auto source = RandomVecNegative(n * batch_size); - stream.ThenMemcpy(input_buffer.ptr(), source.data(), - n * batch_size * sizeof(T)); + CHECK_OK(stream.Memcpy(input_buffer.ptr(), source.data(), + n * batch_size * sizeof(T))); ASSERT_TRUE(RunTopk(&stream, Get(T()), *input_buffer, n, *output_values, *output_indices, k, batch_size) @@ -154,8 +154,8 @@ TEST_P(TopkTest, TopKPackedNegative) { std::vector got(k); ASSERT_TRUE(stream.BlockHostUntilDone().ok()); for (int i = 0; i < batch_size; i++) { - stream.ThenMemcpy(got.data(), output_values->GetSlice(k * i, k), - k * sizeof(T)); + CHECK_OK(stream.Memcpy(got.data(), output_values->GetSlice(k * i, k), + k * sizeof(T))); std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); std::sort(slice.begin(), slice.end(), std::greater()); slice.resize(k); @@ -190,7 +190,7 @@ void BM_SmallTopk(benchmark::State& state) { auto* executor = GetGpuExecutor(); se::Stream stream(executor); - stream.Init(); + CHECK_OK(stream.Initialize()); ASSERT_TRUE(stream.ok()); auto input_buffer = executor->AllocateOwnedArray(n * batch_size), @@ -208,7 +208,7 @@ void BM_SmallTopk(benchmark::State& state) { // time to generate random data) for (size_t i = 0; i < batch_size; i++) { auto slice = input_buffer->GetSlice(i * n, n); - stream.ThenMemcpy(&slice, source.data(), n * sizeof(T)); + CHECK_OK(stream.Memcpy(&slice, source.data(), n * sizeof(T))); } for (auto _ : state) {