From 38c0277f9818dd9adf46b8b32df84fb41ccf679b Mon Sep 17 00:00:00 2001 From: PGZXB Date: Mon, 10 Oct 2022 10:12:10 +0800 Subject: [PATCH] [metal] Support offline cache on metal (#6227) Issue: #4401 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- cmake/TaichiCore.cmake | 1 + taichi/analysis/offline_cache_util.cpp | 4 +- taichi/analysis/offline_cache_util.h | 3 +- taichi/cache/gfx/cache_manager.h | 8 +- taichi/cache/metal/CMakeLists.txt | 17 ++ taichi/cache/metal/cache_manager.cpp | 171 ++++++++++++++++++ taichi/cache/metal/cache_manager.h | 74 ++++++++ taichi/codegen/metal/codegen_metal.cpp | 15 +- taichi/codegen/metal/codegen_metal.h | 12 +- taichi/common/serialization.h | 7 + .../program_impls/metal/CMakeLists.txt | 1 + .../program_impls/metal/metal_program.cpp | 35 +++- .../program_impls/metal/metal_program.h | 6 + taichi/util/lock.h | 9 + taichi/util/offline_cache.cpp | 2 + taichi/util/offline_cache.h | 20 +- .../cpp/offline_cache/load_metadata_test.cpp | 2 +- tests/python/test_offline_cache.py | 23 ++- 18 files changed, 364 insertions(+), 46 deletions(-) create mode 100644 taichi/cache/metal/CMakeLists.txt create mode 100644 taichi/cache/metal/cache_manager.cpp create mode 100644 taichi/cache/metal/cache_manager.h diff --git a/cmake/TaichiCore.cmake b/cmake/TaichiCore.cmake index c05cbec358876..9ae470210e7bb 100644 --- a/cmake/TaichiCore.cmake +++ b/cmake/TaichiCore.cmake @@ -284,6 +284,7 @@ if (TI_WITH_METAL) add_subdirectory(taichi/runtime/metal) add_subdirectory(taichi/runtime/program_impls/metal) add_subdirectory(taichi/codegen/metal) + add_subdirectory(taichi/cache/metal) target_link_libraries(${CORE_LIBRARY_NAME} PRIVATE metal_codegen) target_link_libraries(${CORE_LIBRARY_NAME} PRIVATE metal_runtime) diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp index 9bc3d2c4b5361..e1a225de31788 100644 --- a/taichi/analysis/offline_cache_util.cpp +++ b/taichi/analysis/offline_cache_util.cpp @@ -14,7 +14,7 @@ namespace taichi::lang { static std::vector get_offline_cache_key_of_compile_config( - CompileConfig *config) { + const CompileConfig *config) { TI_ASSERT(config); BinaryOutputSerializer serializer; serializer.initialize(); @@ -148,7 +148,7 @@ std::string get_hashed_offline_cache_key_of_snode(SNode *snode) { return picosha2::get_hash_hex_string(hasher); } -std::string get_hashed_offline_cache_key(CompileConfig *config, +std::string get_hashed_offline_cache_key(const CompileConfig *config, Kernel *kernel) { std::string kernel_ast_string; if (kernel) { diff --git a/taichi/analysis/offline_cache_util.h b/taichi/analysis/offline_cache_util.h index 21aa69f23a065..8978b12500f76 100644 --- a/taichi/analysis/offline_cache_util.h +++ b/taichi/analysis/offline_cache_util.h @@ -13,7 +13,8 @@ class SNode; class Kernel; std::string get_hashed_offline_cache_key_of_snode(SNode *snode); -std::string get_hashed_offline_cache_key(CompileConfig *config, Kernel *kernel); +std::string get_hashed_offline_cache_key(const CompileConfig *config, + Kernel *kernel); void gen_offline_cache_key(Program *prog, IRNode *ast, std::ostream *os); } // namespace taichi::lang diff --git a/taichi/cache/gfx/cache_manager.h b/taichi/cache/gfx/cache_manager.h index bf81855ad88cd..a9392b565aef8 100644 --- a/taichi/cache/gfx/cache_manager.h +++ b/taichi/cache/gfx/cache_manager.h @@ -8,11 +8,17 @@ namespace taichi::lang { namespace gfx { +struct OfflineCacheKernelMetadata : public offline_cache::KernelMetadataBase { + std::size_t num_files{0}; + + TI_IO_DEF_WITH_BASECLASS(offline_cache::KernelMetadataBase, num_files); +}; + class CacheManager { using CompiledKernelData = gfx::GfxRuntime::RegisterParams; public: - using Metadata = offline_cache::Metadata; + using Metadata = offline_cache::Metadata; enum Mode { NotCache, MemCache, MemAndDiskCache }; struct Params { diff --git a/taichi/cache/metal/CMakeLists.txt b/taichi/cache/metal/CMakeLists.txt new file mode 100644 index 0000000000000..781542425b65d --- /dev/null +++ b/taichi/cache/metal/CMakeLists.txt @@ -0,0 +1,17 @@ +# ./taichi/cache/metal/CMakeLists.txt + +add_library(metal_cache) +target_sources(metal_cache + PRIVATE + cache_manager.cpp + ) + +target_include_directories(metal_cache + PRIVATE + ${PROJECT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/external/spdlog/include + ${PROJECT_SOURCE_DIR}/external/eigen + ${LLVM_INCLUDE_DIRS} # For "llvm/ADT/SmallVector.h" included in ir.h + ) + +target_link_libraries(metal_cache PRIVATE metal_codegen) diff --git a/taichi/cache/metal/cache_manager.cpp b/taichi/cache/metal/cache_manager.cpp new file mode 100644 index 0000000000000..d82d943927cf7 --- /dev/null +++ b/taichi/cache/metal/cache_manager.cpp @@ -0,0 +1,171 @@ +#include "taichi/cache/metal/cache_manager.h" +#include "taichi/analysis/offline_cache_util.h" +#include "taichi/codegen/metal/codegen_metal.h" +#include "taichi/common/version.h" +#include "taichi/program/kernel.h" +#include "taichi/util/io.h" +#include "taichi/util/lock.h" +#include "taichi/util/offline_cache.h" + +namespace taichi::lang { +namespace metal { + +CacheManager::CacheManager(Params &&init_params) + : config_(std::move(init_params)) { + if (config_.mode == MemAndDiskCache) { + const auto filepath = join_path(config_.cache_path, kMetadataFilename); + const auto lock_path = join_path(config_.cache_path, kMetadataLockName); + if (lock_with_file(lock_path)) { + auto _ = make_unlocker(lock_path); + offline_cache::load_metadata_with_checking(cached_data_, filepath); + } + } +} + +CacheManager::CompiledKernelData CacheManager::load_or_compile( + const CompileConfig *compile_config, + Kernel *kernel) { + if (kernel->is_evaluator || config_.mode == NotCache) { + return compile_kernel(kernel); + } + TI_ASSERT(config_.mode > NotCache); + const auto kernel_key = make_kernel_key(compile_config, kernel); + if (auto opt = try_load_cached_kernel(kernel, kernel_key)) { + return *opt; + } + return compile_and_cache_kernel(kernel_key, kernel); +} + +void CacheManager::dump_with_merging() const { + if (config_.mode == MemAndDiskCache && !caching_kernels_.empty()) { + taichi::create_directories(config_.cache_path); + const auto filepath = join_path(config_.cache_path, kMetadataFilename); + const auto lock_path = join_path(config_.cache_path, kMetadataLockName); + if (lock_with_file(lock_path)) { + auto _ = make_unlocker(lock_path); + Metadata data; + data.version[0] = TI_VERSION_MAJOR; + data.version[1] = TI_VERSION_MINOR; + data.version[2] = TI_VERSION_PATCH; + // Load old cached data + load_metadata_with_checking(data, filepath); + // Update the cached data + for (const auto *e : updated_data_) { + auto iter = data.kernels.find(e->kernel_key); + if (iter != data.kernels.end()) { + iter->second.last_used_at = e->last_used_at; + } + } + // Add new data + for (auto &[key, kernel] : caching_kernels_) { + auto [iter, ok] = data.kernels.insert({key, kernel}); + if (ok) { + data.size += iter->second.size; + // Mangle kernel_name as kernel-key + iter->second.compiled_kernel_data.kernel_name = key; + } + } + // Dump + for (const auto &[_, k] : data.kernels) { + const auto code_filepath = join_path( + config_.cache_path, fmt::format(kMetalCodeFormat, k.kernel_key)); + if (try_lock_with_file(code_filepath)) { + std::ofstream fs{code_filepath}; + fs << k.compiled_kernel_data.source_code; + } + } + write_to_binary_file(data, filepath); + } + } +} + +CompiledKernelData CacheManager::compile_kernel(Kernel *kernel) const { + kernel->lower(); + return run_codegen(config_.compiled_runtime_module_, + *config_.compiled_snode_trees_, kernel, nullptr); +} + +std::string CacheManager::make_kernel_key(const CompileConfig *compile_config, + Kernel *kernel) const { + if (config_.mode < MemAndDiskCache) { + return kernel->get_name(); + } + auto key = kernel->get_cached_kernel_key(); + if (key.empty()) { + key = get_hashed_offline_cache_key(compile_config, kernel); + kernel->set_kernel_key_for_cache(key); + } + return key; +} + +std::optional CacheManager::try_load_cached_kernel( + Kernel *kernel, + const std::string &key) { + TI_ASSERT(config_.mode > NotCache); + { // Find in memory-cache (caching_kernels_) + const auto &kernels = caching_kernels_; + auto iter = kernels.find(key); + if (iter != kernels.end()) { + TI_DEBUG("Create kernel '{}' from in-memory cache (key='{}')", + kernel->get_name(), key); + kernel->mark_as_from_cache(); + return iter->second.compiled_kernel_data; + } + } + // Find in disk-cache (cached_data_) + if (config_.mode == MemAndDiskCache) { + auto &kernels = cached_data_.kernels; + auto iter = kernels.find(key); + if (iter != kernels.end()) { + auto &k = iter->second; + TI_ASSERT(k.kernel_key == k.compiled_kernel_data.kernel_name); + if (load_kernel_source_code(k)) { + TI_DEBUG("Create kernel '{}' from cache (key='{}')", kernel->get_name(), + key); + k.last_used_at = std::time(nullptr); + updated_data_.push_back(&k); + kernel->mark_as_from_cache(); + return k.compiled_kernel_data; + } + } + } + return std::nullopt; +} + +CompiledKernelData CacheManager::compile_and_cache_kernel( + const std::string &key, + Kernel *kernel) { + TI_DEBUG_IF(config_.mode == MemAndDiskCache, "Cache kernel '{}' (key='{}')", + kernel->get_name(), key); + OfflineCacheKernelMetadata k; + k.kernel_key = key; + k.created_at = k.last_used_at = std::time(nullptr); + k.compiled_kernel_data = compile_kernel(kernel); + k.size = k.compiled_kernel_data.source_code.size(); + const auto &kernel_data = (caching_kernels_[key] = std::move(k)); + return kernel_data.compiled_kernel_data; +} + +bool CacheManager::load_kernel_source_code( + OfflineCacheKernelMetadata &kernel_data) { + auto &src = kernel_data.compiled_kernel_data.source_code; + if (!src.empty()) { + return true; + } + auto filepath = + join_path(config_.cache_path, + fmt::format(kMetalCodeFormat, kernel_data.kernel_key)); + std::ifstream f(filepath); + if (!f.is_open()) { + return false; + } + f.seekg(0, std::ios::end); + const auto len = f.tellg(); + f.seekg(0, std::ios::beg); + src.resize(len); + f.read(&src[0], len); + return !f.fail(); +} + +} // namespace metal +} // namespace taichi::lang diff --git a/taichi/cache/metal/cache_manager.h b/taichi/cache/metal/cache_manager.h new file mode 100644 index 0000000000000..0b652c3593da2 --- /dev/null +++ b/taichi/cache/metal/cache_manager.h @@ -0,0 +1,74 @@ +#pragma once + +#include "taichi/codegen/metal/struct_metal.h" +#include "taichi/common/serialization.h" +#include "taichi/program/compile_config.h" +#include "taichi/runtime/metal/kernel_utils.h" +#include "taichi/util/offline_cache.h" + +namespace taichi::lang { +namespace metal { + +struct OfflineCacheKernelMetadata : public offline_cache::KernelMetadataBase { + CompiledKernelData compiled_kernel_data; + + TI_IO_DEF_WITH_BASECLASS(offline_cache::KernelMetadataBase, + compiled_kernel_data); +}; + +class CacheManager { + static constexpr char kMetadataFilename[] = "metadata.tcb"; + static constexpr char kMetadataLockName[] = "metadata.lock"; + static constexpr char kMetalCodeFormat[] = + "{}.metal"; // "{kernel-key}.metal" + using CompiledKernelData = taichi::lang::metal::CompiledKernelData; + using CachingData = + std::unordered_map; + + public: + using Metadata = offline_cache::Metadata; + enum Mode { NotCache, MemCache, MemAndDiskCache }; + + struct Params { + Mode mode{MemCache}; + std::string cache_path; + const CompiledRuntimeModule *compiled_runtime_module_{nullptr}; + const std::vector *compiled_snode_trees_{nullptr}; + }; + + CacheManager(Params &&init_params); + + // Load from memory || Load from disk || (Compile && Cache the result in + // memory) + CompiledKernelData load_or_compile(const CompileConfig *compile_config, + Kernel *kernel); + + // Dump the cached data in memory to disk + void dump_with_merging() const; + + // Run offline cache cleaning + void clean_offline_cache(offline_cache::CleanCachePolicy policy, + int max_bytes, + double cleaning_factor) const { + TI_NOT_IMPLEMENTED; + } + + private: + CompiledKernelData compile_kernel(Kernel *kernel) const; + std::string make_kernel_key(const CompileConfig *compile_config, + Kernel *kernel) const; + std::optional try_load_cached_kernel( + Kernel *kernel, + const std::string &key); + CompiledKernelData compile_and_cache_kernel(const std::string &key, + Kernel *kernel); + bool load_kernel_source_code(OfflineCacheKernelMetadata &kernel_data); + + Params config_; + CachingData caching_kernels_; + Metadata cached_data_; + std::vector updated_data_; +}; + +} // namespace metal +} // namespace taichi::lang diff --git a/taichi/codegen/metal/codegen_metal.cpp b/taichi/codegen/metal/codegen_metal.cpp index 8879c5efc30d8..fc11a4a59269a 100644 --- a/taichi/codegen/metal/codegen_metal.cpp +++ b/taichi/codegen/metal/codegen_metal.cpp @@ -1689,17 +1689,12 @@ CompiledKernelData run_codegen( return codegen.run(); } -FunctionType compile_to_metal_executable( - Kernel *kernel, - KernelManager *kernel_mgr, - const CompiledRuntimeModule *compiled_runtime_module, - const std::vector &compiled_snode_trees, - OffloadedStmt *offloaded) { - const auto compiled_res = run_codegen( - compiled_runtime_module, compiled_snode_trees, kernel, offloaded); - kernel_mgr->register_taichi_kernel(compiled_res); +FunctionType compiled_kernel_to_metal_executable( + const CompiledKernelData &compiled_kernel, + KernelManager *kernel_mgr) { + kernel_mgr->register_taichi_kernel(compiled_kernel); return [kernel_mgr, - kernel_name = compiled_res.kernel_name](RuntimeContext &ctx) { + kernel_name = compiled_kernel.kernel_name](RuntimeContext &ctx) { kernel_mgr->launch_taichi_kernel(kernel_name, &ctx); }; } diff --git a/taichi/codegen/metal/codegen_metal.h b/taichi/codegen/metal/codegen_metal.h index 12003e22e44b4..fe70808312bfc 100644 --- a/taichi/codegen/metal/codegen_metal.h +++ b/taichi/codegen/metal/codegen_metal.h @@ -22,15 +22,9 @@ CompiledKernelData run_codegen( Kernel *kernel, OffloadedStmt *offloaded); -// If |offloaded| is nullptr, this compiles the AST in |kernel|. Otherwise it -// compiles just |offloaded|. These ASTs must have already been lowered at the -// CHI level. -FunctionType compile_to_metal_executable( - Kernel *kernel, - KernelManager *kernel_mgr, - const CompiledRuntimeModule *compiled_runtime_module, - const std::vector &compiled_snode_trees, - OffloadedStmt *offloaded = nullptr); +FunctionType compiled_kernel_to_metal_executable( + const CompiledKernelData &compiled_kernel, + KernelManager *kernel_mgr); } // namespace metal } // namespace taichi::lang diff --git a/taichi/common/serialization.h b/taichi/common/serialization.h index 778fe7780b94c..39145ae7f4e3e 100644 --- a/taichi/common/serialization.h +++ b/taichi/common/serialization.h @@ -143,6 +143,13 @@ serialize_kv_impl(SER &ser, TI_IO(__VA_ARGS__); \ } +#define TI_IO_DEF_WITH_BASECLASS(BaseClass, ...) \ + template \ + void io(S &serializer) const { \ + this->BaseClass::io(serializer); \ + TI_IO(__VA_ARGS__); \ + } + // This macro serializes each field with its name by doing the following: // 1. Stringifies __VA_ARGS__, then split the stringified result by ',' at // compile time. diff --git a/taichi/runtime/program_impls/metal/CMakeLists.txt b/taichi/runtime/program_impls/metal/CMakeLists.txt index 6b56d8ebbaae5..4ef860e47ec21 100644 --- a/taichi/runtime/program_impls/metal/CMakeLists.txt +++ b/taichi/runtime/program_impls/metal/CMakeLists.txt @@ -13,4 +13,5 @@ target_include_directories(metal_program_impl ) target_link_libraries(metal_program_impl PUBLIC metal_codegen) +target_link_libraries(metal_program_impl PUBLIC metal_cache) target_link_libraries(metal_program_impl PUBLIC metal_runtime) diff --git a/taichi/runtime/program_impls/metal/metal_program.cpp b/taichi/runtime/program_impls/metal/metal_program.cpp index 482acde54832c..35ad116609b67 100644 --- a/taichi/runtime/program_impls/metal/metal_program.cpp +++ b/taichi/runtime/program_impls/metal/metal_program.cpp @@ -3,6 +3,7 @@ #include "metal_program.h" #include "taichi/codegen/metal/codegen_metal.h" #include "taichi/codegen/metal/struct_metal.h" +#include "taichi/util/offline_cache.h" namespace taichi::lang { namespace { @@ -46,12 +47,10 @@ MetalProgramImpl::MetalProgramImpl(CompileConfig &config_) FunctionType MetalProgramImpl::compile(Kernel *kernel, OffloadedStmt *offloaded) { - if (!kernel->lowered()) { - kernel->lower(); - } - return metal::compile_to_metal_executable(kernel, metal_kernel_mgr_.get(), - &(compiled_runtime_module_.value()), - compiled_snode_trees_, offloaded); + TI_ASSERT(offloaded == nullptr); + return metal::compiled_kernel_to_metal_executable( + get_cache_manager()->load_or_compile(config, kernel), + metal_kernel_mgr_.get()); } std::size_t MetalProgramImpl::get_snode_num_dynamically_allocated(SNode *snode, @@ -122,4 +121,28 @@ DeviceAllocation MetalProgramImpl::allocate_memory_ndarray( return metal_kernel_mgr_->allocate_memory(params); } +void MetalProgramImpl::dump_cache_data_to_disk() { + // TODO(PGZXB): Implement cache cleaning + get_cache_manager()->dump_with_merging(); +} + +const std::unique_ptr + &MetalProgramImpl::get_cache_manager() { + if (!cache_manager_) { + TI_ASSERT(compiled_runtime_module_.has_value()); + using Mgr = metal::CacheManager; + Mgr::Params params; + params.mode = + offline_cache::enabled_wip_offline_cache(config->offline_cache) + ? Mgr::MemAndDiskCache + : Mgr::MemCache; + params.cache_path = offline_cache::get_cache_path_by_arch( + config->offline_cache_file_path, Arch::metal); + params.compiled_runtime_module_ = &(*compiled_runtime_module_); + params.compiled_snode_trees_ = &compiled_snode_trees_; + cache_manager_ = std::make_unique(std::move(params)); + } + return cache_manager_; +} + } // namespace taichi::lang diff --git a/taichi/runtime/program_impls/metal/metal_program.h b/taichi/runtime/program_impls/metal/metal_program.h index 81fd03c65bcfd..e5dd5e4b1faf0 100644 --- a/taichi/runtime/program_impls/metal/metal_program.h +++ b/taichi/runtime/program_impls/metal/metal_program.h @@ -2,6 +2,7 @@ #include +#include "taichi/cache/metal/cache_manager.h" #include "taichi/runtime/metal/kernel_manager.h" #include "taichi/codegen/metal/struct_metal.h" #include "taichi/system/memory_pool.h" @@ -46,6 +47,10 @@ class MetalProgramImpl : public ProgramImpl { DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size, uint64 *result_buffer) override; + void dump_cache_data_to_disk() override; + + const std::unique_ptr &get_cache_manager(); + private: const metal::CompiledStructs &compile_snode_tree_types_impl(SNodeTree *tree); @@ -53,6 +58,7 @@ class MetalProgramImpl : public ProgramImpl { std::nullopt}; std::vector compiled_snode_trees_; std::unique_ptr metal_kernel_mgr_{nullptr}; + std::unique_ptr cache_manager_{nullptr}; }; } // namespace taichi::lang diff --git a/taichi/util/lock.h b/taichi/util/lock.h index 88b037acfb018..edd27469d7068 100644 --- a/taichi/util/lock.h +++ b/taichi/util/lock.h @@ -1,5 +1,6 @@ #pragma once +#include "taichi/common/cleanup.h" #include "taichi/common/core.h" #include @@ -54,4 +55,12 @@ inline bool lock_with_file(const std::string &path, return false; } +inline RaiiCleanup make_unlocker(const std::string &path) { + return make_cleanup([&path]() { + if (!unlock_with_file(path)) { + TI_WARN("Unlock {} failed", path); + } + }); +} + } // namespace taichi diff --git a/taichi/util/offline_cache.cpp b/taichi/util/offline_cache.cpp index 2cf7797ac8929..29b881b770164 100644 --- a/taichi/util/offline_cache.cpp +++ b/taichi/util/offline_cache.cpp @@ -24,6 +24,8 @@ std::string get_cache_path_by_arch(const std::string &base_path, Arch arch) { subdir = "llvm"; } else if (arch == Arch::vulkan || arch == Arch::opengl) { subdir = "gfx"; + } else if (arch == Arch::metal) { + subdir = "metal"; } else { return base_path; } diff --git a/taichi/util/offline_cache.h b/taichi/util/offline_cache.h index df1eb863d2286..a96e153ee437e 100644 --- a/taichi/util/offline_cache.h +++ b/taichi/util/offline_cache.h @@ -47,16 +47,18 @@ inline CleanCachePolicy string_to_clean_cache_policy(const std::string &str) { return Never; } +struct KernelMetadataBase { + std::string kernel_key; + std::size_t size{0}; // byte + std::time_t created_at{0}; // sec + std::time_t last_used_at{0}; // sec + + TI_IO_DEF(kernel_key, size, created_at, last_used_at); +}; + +template struct Metadata { - struct KernelMetadata { - std::string kernel_key; - std::size_t size{0}; // byte - std::time_t created_at{0}; // sec - std::time_t last_used_at{0}; // sec - std::size_t num_files{0}; - - TI_IO_DEF(kernel_key, size, created_at, last_used_at, num_files); - }; + using KernelMetadata = KernelMetadataType; Version version{}; std::size_t size{0}; // byte diff --git a/tests/cpp/offline_cache/load_metadata_test.cpp b/tests/cpp/offline_cache/load_metadata_test.cpp index 406b3ae8c9eb3..d17cde132e983 100644 --- a/tests/cpp/offline_cache/load_metadata_test.cpp +++ b/tests/cpp/offline_cache/load_metadata_test.cpp @@ -104,7 +104,7 @@ TEST(OfflineCache, LoadMetadata) { #ifdef TI_WITH_LLVM load_metadata_test(); #endif // TI_WITH_LLVM - load_metadata_test(); + load_metadata_test>(); } } // namespace taichi::lang diff --git a/tests/python/test_offline_cache.py b/tests/python/test_offline_cache.py index cb09af713a192..ea6bbe09fd7b0 100644 --- a/tests/python/test_offline_cache.py +++ b/tests/python/test_offline_cache.py @@ -20,15 +20,16 @@ supported_llvm_archs = {ti.cpu, ti.cuda} supported_gfx_archs = {ti.opengl, ti.vulkan} -supported_archs_offline_cache = supported_llvm_archs | supported_gfx_archs -supported_archs_offline_cache = [ - v for v in supported_archs_offline_cache - if v in test_utils.expected_archs() -] +supported_metal_arch = {ti.metal} +supported_archs_offline_cache = supported_llvm_archs | supported_gfx_archs | supported_metal_arch +supported_archs_offline_cache = { + v + for v in supported_archs_offline_cache if v in test_utils.expected_archs() +} def is_offline_cache_file(filename): - suffixes = ('.ll', '.bc', '.spv') + suffixes = ('.ll', '.bc', '.spv', '.metal') return filename.endswith(suffixes) @@ -51,12 +52,16 @@ def expected_num_cache_files(arch, num_offloads: List[int] = None) -> int: result += len(num_offloads) elif arch in supported_gfx_archs: result += sum(num_offloads) + elif arch in supported_metal_arch: + result += len(num_offloads) # metadata files if arch in supported_llvm_archs: result += 2 # metadata.{json, tcb} elif arch in supported_gfx_archs: # metadata.{json, tcb}, graphs.tcb, offline_cache_metadata.tcb result += 4 + elif arch in supported_metal_arch: + result += 1 # metadata.tcb return result @@ -69,6 +74,8 @@ def backend_specified_cache_path(arch): return join(tmp_offline_cache_file_path(), 'llvm') elif arch in supported_gfx_archs: return join(tmp_offline_cache_file_path(), 'gfx') + elif arch in supported_metal_arch: + return join(tmp_offline_cache_file_path(), 'metal') assert False @@ -486,7 +493,9 @@ def helper(): curr_arch, [2, 2]) -@pytest.mark.parametrize('curr_arch', supported_archs_offline_cache) +# TODO(PGZXB): Implement cache cleaning on metal +@pytest.mark.parametrize('curr_arch', + supported_archs_offline_cache - supported_metal_arch) @pytest.mark.parametrize('factor', [0.0, 0.25, 0.85, 1.0]) @pytest.mark.parametrize('policy', ['never', 'version', 'lru', 'fifo']) @_test_offline_cache_dec