Fix test to load autotuning results from cache instead of actually computing them. This test wants to assume that CUBLAS is always faster than Triton for this particular GEMM, and that can only be guaranteed via the autotuning DB.

PiperOrigin-RevId: 622277357
Moerafaat authored and tensorflower-gardener committed Apr 5, 2024
1 parent c4cff96 commit 7b54481
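
As a rough sketch of the pattern this commit applies (a minimal restatement of the AutotunerUtil calls and TSL test helpers that appear in the diff below; the comments are added here only for illustration), a test preloads the cached autotune results before compiling and clears the in-memory cache afterwards:

  // Point the test at the checked-in autotune DB instead of timing real kernels.
  std::string path =
      tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
                        "gpu_compiler_test_autotune_db.textproto");
  TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path));

  // ... compile the module; the GEMM resolves to the cuBLAS custom call
  // because the cached entry records cuBLAS as the faster backend ...

  // Drop the loaded results so later compilations in the same process are not
  // affected by this test's cache.
  AutotunerUtil::ClearAutotuneResults();

This keeps the CUBLAS-versus-Triton assumption deterministic: the outcome comes from the checked-in textproto rather than from timing runs on the test machine.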
Showing 3 changed files with 40 additions and 0 deletions.
4 changes: 4 additions & 0 deletions third_party/xla/xla/service/gpu/BUILD
@@ -3694,6 +3694,7 @@ xla_test(
name = "gpu_compiler_test",
srcs = if_gpu_is_configured(["gpu_compiler_test.cc"]),
backends = ["gpu"],
data = ["gpu_compiler_test_autotune_db.textproto"],
deps = [
":gpu_compiler",
":gpu_hlo_schedule",
@@ -3707,6 +3708,7 @@ xla_test(
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/service:xla_debug_info_manager",
"//xla/service/gpu:autotuner_util",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/log",
@@ -3716,7 +3718,9 @@ xla_test(
"@com_google_googletest//:gtest",
"@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:path",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
],
)

11 changes: 11 additions & 0 deletions third_party/xla/xla/service/gpu/gpu_compiler_test.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/executable.h"
#include "xla/service/gpu/autotuner_util.h"
#include "xla/service/gpu/gpu_hlo_schedule.h"
#include "xla/service/gpu/metrics.h"
#include "xla/service/hlo_module_config.h"
@@ -43,7 +44,9 @@ limitations under the License.
#include "xla/tests/hlo_test_base.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/path.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {
namespace gpu {
@@ -360,10 +363,18 @@ ENTRY main {
config.set_debug_options(triton_enabled_debug_options);
config.set_replica_count(1);
config.set_num_partitions(1);

// Load autotuning DB. We shouldn't depend on actual execution times in a unit
// test.
std::string path =
tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
"gpu_compiler_test_autotune_db.textproto");
TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> triton_enabled_module,
GetOptimizedModule(std::move(module)));
AutotunerUtil::ClearAutotuneResults();
DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest();
triton_disabled_debug_options.set_xla_gpu_enable_address_computation_fusion(
false);
25 changes: 25 additions & 0 deletions third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
@@ -0,0 +1,25 @@
version: 3
results {
device: "sm_9.0 with 84942979072B RAM, 132 cores, 1980000KHz clock, 2619000KHz mem clock, 52428800B L2$"
hlo: "(bf16[128,1024,1024]{2,1,0}, s8[33554432]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false},\"force_earliest_schedule\":false}"
result {
gemm {
algorithm: -1
}
run_time {
nanos: 657376
}
}
}
results {
device: "sm_9.0 with 84942979072B RAM, 132 cores, 1980000KHz clock, 2619000KHz mem clock, 52428800B L2$"
hlo: "{\n tmp_0 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(0)\n tmp_1 = bf16[] constant({...})\n tmp_2 = bf16[1,4,32,1024,1024]{4,3,2,1,0} broadcast(bf16[] tmp_1), dimensions={}\n tmp_3 = bf16[1,4,32,1024,1024]{4,3,2,1,0} multiply(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_0, bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_2)\n tmp_4 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_3)\n tmp_5 = bf16[4,32,1024,1024]{3,2,1,0} transpose(bf16[4,32,1024,1024]{3,2,1,0} tmp_4), dimensions={0,1,3,2}\n tmp_6 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[4,32,1024,1024]{3,2,1,0} tmp_5)\n tmp_7 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(1)\n tmp_8 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_7)\n tmp_9 = bf16[128,1024,1024]{2,1,0} dot(bf16[128,1024,1024]{2,1,0} tmp_6, bf16[128,1024,1024]{2,1,0} tmp_8), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}\n ROOT tmp_10 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[128,1024,1024]{2,1,0} tmp_9)\n}"
result {
gemm {
algorithm: -1
}
run_time {
nanos: 854688
}
}
}
