[XLA:GPU] Experimental: Add --xla_gpu_per_fusion_autotune_cache_dir option

If the option is set, we will maintain (read/write) a per-fusion autotune cache in the given directory.

The directory must exist.

Cache invalidation has to be handled by the user (e.g. please use an empty directory if you want to start with an empty cache).

XLA version checks must be done by the user (e.g. if you want to cache fusions created with different versions of XLA, please use different directories).
(If the calling library already has a version-handling mechanism, as JAX does, it should not be difficult to create separate directories based on that version (and on any other parameters that matter to it).)
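Such library-side version handling could look like the following sketch (the helper name and directory layout are hypothetical, not part of this change):

```python
import os

def versioned_cache_dir(base_dir: str, xla_version: str) -> str:
    # Hypothetical helper: one subdirectory per XLA (or JAX) version, since
    # the flag itself performs no version checks.
    path = os.path.join(base_dir, f"xla-{xla_version}")
    # The flag requires the directory to exist before XLA uses it.
    os.makedirs(path, exist_ok=True)
    return path
```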

Default: no file-based cache.
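For JAX users, the flag can be passed through the XLA_FLAGS environment variable before the first compilation; the directory path below is only an example:

```python
import os

# Example only: the cache directory must already exist.
cache_dir = "/tmp/xla_autotune_cache"
os.makedirs(cache_dir, exist_ok=True)
# Must be set before XLA compiles anything (e.g. before importing JAX).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + f" --xla_gpu_per_fusion_autotune_cache_dir={cache_dir}"
)
```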

There is minimal support for multiple processes using the same cache: the rename trick (writing to a temporary file and then renaming it into place) avoids multiple processes writing the same file at the same time and prevents readers from seeing incomplete files.
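The rename trick can be sketched as follows (simplified Python, not the actual C++ implementation):

```python
import os
import tempfile

def atomic_write(path: str, data: bytes) -> None:
    # Write to a unique temporary file in the same directory, then rename it
    # into place. Readers never see a partially written file, and if two
    # processes race, one complete file simply replaces the other.
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path) or ".")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        os.replace(tmp_path, path)  # atomic on POSIX filesystems
    except BaseException:
        os.unlink(tmp_path)
        raise
```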

We use SHA256 hashes in the filenames and assume that no collisions occur.
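The filename scheme can be illustrated like this (the key contents and the file extension here are assumptions for illustration; the real key is derived from the fusion and the device):

```python
import hashlib

def cache_filename(cache_key: str) -> str:
    # SHA256 gives a fixed-length, filesystem-safe name; collisions are
    # assumed not to occur, so the hash alone identifies the entry.
    return hashlib.sha256(cache_key.encode("utf-8")).hexdigest() + ".textproto"
```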

This is a simple implementation meant to let people test it and find good use cases. If needed, we can refine it later.

Considered use case:
People running [multiple] [similar] models [through JAX]. For example, there are two similar HLOs that we want to run with JAX (using the same "XLA binary"), and it would be nice to reuse the autotune results from the first if some kernels appear in both.
Similarly: consider a researcher in a Colab session making small changes to their model. They should mostly get cache hits!

Limitations:

It is not recommended to change the cache directory during the run of a process, because then the in-memory and the file-based cache can become inconsistent. If you do change it, at least clear the in-memory cache.

When loading results with LoadAutotuneResults[FromFile], they are not written into the cache directory.

PiperOrigin-RevId: 644406688
tdanyluk authored and tensorflower-gardener committed Jun 18, 2024
1 parent 6e6641a commit 221220f
Showing 7 changed files with 455 additions and 54 deletions.
15 changes: 15 additions & 0 deletions third_party/xla/xla/debug_options_flags.cc
@@ -270,6 +270,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_shard_autotuning(false);

opts.set_xla_gpu_per_fusion_autotune_cache_dir("");

return opts;
}

@@ -1768,6 +1770,19 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"reused in further compilations; not yet cached kernels are "
"compiled as usual and get appended to the cache file whenever "
"possible."));
flag_list->push_back(tsl::Flag(
"xla_gpu_per_fusion_autotune_cache_dir",
string_setter_for(
&DebugOptions::set_xla_gpu_per_fusion_autotune_cache_dir),
debug_options->xla_gpu_per_fusion_autotune_cache_dir(),
"Experimental: Maintain a per-fusion autotune cache in the given "
"directory. XLA will try to read existing results when they are needed "
"and write new results when they are determined. The directory must "
"exist. Cache invalidation has to be handled by the user (e.g. please "
"use an empty directory if you want to start with an empty cache). XLA "
"version checks must be done by the user (e.g. if you want to use "
"separate caches for different versions of XLA, please use different "
"directories). Default: no cache."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
44 changes: 28 additions & 16 deletions third_party/xla/xla/service/gpu/BUILD
@@ -1630,18 +1630,19 @@ cc_library(
deps = if_gpu_is_configured([
":gpu_asm_opts_util",
":stream_executor_util",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"//xla/hlo/ir:hlo",
"//xla/service:compilation_environments",
"//xla/stream_executor",
"//xla/stream_executor/gpu:redzone_allocator",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"//xla:autotune_results_proto_cc",
"//xla:autotuning_proto_cc",
"//xla:shape_util",
@@ -1650,16 +1651,19 @@ cc_library(
"//xla:types",
"//xla:util",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:compilation_environments",
"//xla/stream_executor:stream_executor_memory_allocator",
"//xla/stream_executor",
"//xla/stream_executor/gpu:redzone_allocator",
"@local_tsl//tsl/platform:base64",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:path",
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:statusor",
]) + [
"//xla/stream_executor:stream_executor_memory_allocator",
"@com_google_absl//absl/status",
],
]),
)

# We need a separate target, as runtime executable cannot depend on compilation
@@ -5842,23 +5846,31 @@ xla_cc_test(
"@com_google_googletest//:gtest",
"@com_google_absl//absl/base:log_severity",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/log:scoped_mock_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"//xla:autotune_results_proto_cc",
"//xla:autotuning_proto_cc",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/stream_executor:platform",
"//xla/hlo/utils:hlo_query",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:platform",
"//xla/stream_executor/host:host_platform",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main", # Keep outside GPU guard
"@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:path",
"@local_tsl//tsl/platform:protobuf",
]) + [
"//xla/tests:xla_internal_test_main", # Keep outside GPU guard
"@com_google_absl//absl/status",
],
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
]),
)

cc_library(
