Skip to content

Commit

Permalink
[metal] Support offline cache on metal (#6227)
Browse files Browse the repository at this point in the history
Issue: #4401

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PGZXB and pre-commit-ci[bot] committed Oct 10, 2022
1 parent 8fe7818 commit 38c0277
Show file tree
Hide file tree
Showing 18 changed files with 364 additions and 46 deletions.
1 change: 1 addition & 0 deletions cmake/TaichiCore.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ if (TI_WITH_METAL)
add_subdirectory(taichi/runtime/metal)
add_subdirectory(taichi/runtime/program_impls/metal)
add_subdirectory(taichi/codegen/metal)
add_subdirectory(taichi/cache/metal)

target_link_libraries(${CORE_LIBRARY_NAME} PRIVATE metal_codegen)
target_link_libraries(${CORE_LIBRARY_NAME} PRIVATE metal_runtime)
Expand Down
4 changes: 2 additions & 2 deletions taichi/analysis/offline_cache_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
namespace taichi::lang {

static std::vector<std::uint8_t> get_offline_cache_key_of_compile_config(
CompileConfig *config) {
const CompileConfig *config) {
TI_ASSERT(config);
BinaryOutputSerializer serializer;
serializer.initialize();
Expand Down Expand Up @@ -148,7 +148,7 @@ std::string get_hashed_offline_cache_key_of_snode(SNode *snode) {
return picosha2::get_hash_hex_string(hasher);
}

std::string get_hashed_offline_cache_key(CompileConfig *config,
std::string get_hashed_offline_cache_key(const CompileConfig *config,
Kernel *kernel) {
std::string kernel_ast_string;
if (kernel) {
Expand Down
3 changes: 2 additions & 1 deletion taichi/analysis/offline_cache_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class SNode;
class Kernel;

std::string get_hashed_offline_cache_key_of_snode(SNode *snode);
std::string get_hashed_offline_cache_key(CompileConfig *config, Kernel *kernel);
std::string get_hashed_offline_cache_key(const CompileConfig *config,
Kernel *kernel);
void gen_offline_cache_key(Program *prog, IRNode *ast, std::ostream *os);

} // namespace taichi::lang
8 changes: 7 additions & 1 deletion taichi/cache/gfx/cache_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@
namespace taichi::lang {
namespace gfx {

struct OfflineCacheKernelMetadata : public offline_cache::KernelMetadataBase {
std::size_t num_files{0};

TI_IO_DEF_WITH_BASECLASS(offline_cache::KernelMetadataBase, num_files);
};

class CacheManager {
using CompiledKernelData = gfx::GfxRuntime::RegisterParams;

public:
using Metadata = offline_cache::Metadata;
using Metadata = offline_cache::Metadata<OfflineCacheKernelMetadata>;
enum Mode { NotCache, MemCache, MemAndDiskCache };

struct Params {
Expand Down
17 changes: 17 additions & 0 deletions taichi/cache/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# ./taichi/cache/metal/CMakeLists.txt

add_library(metal_cache)
target_sources(metal_cache
PRIVATE
cache_manager.cpp
)

target_include_directories(metal_cache
PRIVATE
${PROJECT_SOURCE_DIR}
${PROJECT_SOURCE_DIR}/external/spdlog/include
${PROJECT_SOURCE_DIR}/external/eigen
${LLVM_INCLUDE_DIRS} # For "llvm/ADT/SmallVector.h" included in ir.h
)

target_link_libraries(metal_cache PRIVATE metal_codegen)
171 changes: 171 additions & 0 deletions taichi/cache/metal/cache_manager.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
#include "taichi/cache/metal/cache_manager.h"
#include "taichi/analysis/offline_cache_util.h"
#include "taichi/codegen/metal/codegen_metal.h"
#include "taichi/common/version.h"
#include "taichi/program/kernel.h"
#include "taichi/util/io.h"
#include "taichi/util/lock.h"
#include "taichi/util/offline_cache.h"

namespace taichi::lang {
namespace metal {

CacheManager::CacheManager(Params &&init_params)
: config_(std::move(init_params)) {
if (config_.mode == MemAndDiskCache) {
const auto filepath = join_path(config_.cache_path, kMetadataFilename);
const auto lock_path = join_path(config_.cache_path, kMetadataLockName);
if (lock_with_file(lock_path)) {
auto _ = make_unlocker(lock_path);
offline_cache::load_metadata_with_checking(cached_data_, filepath);
}
}
}

CacheManager::CompiledKernelData CacheManager::load_or_compile(
const CompileConfig *compile_config,
Kernel *kernel) {
if (kernel->is_evaluator || config_.mode == NotCache) {
return compile_kernel(kernel);
}
TI_ASSERT(config_.mode > NotCache);
const auto kernel_key = make_kernel_key(compile_config, kernel);
if (auto opt = try_load_cached_kernel(kernel, kernel_key)) {
return *opt;
}
return compile_and_cache_kernel(kernel_key, kernel);
}

void CacheManager::dump_with_merging() const {
if (config_.mode == MemAndDiskCache && !caching_kernels_.empty()) {
taichi::create_directories(config_.cache_path);
const auto filepath = join_path(config_.cache_path, kMetadataFilename);
const auto lock_path = join_path(config_.cache_path, kMetadataLockName);
if (lock_with_file(lock_path)) {
auto _ = make_unlocker(lock_path);
Metadata data;
data.version[0] = TI_VERSION_MAJOR;
data.version[1] = TI_VERSION_MINOR;
data.version[2] = TI_VERSION_PATCH;
// Load old cached data
load_metadata_with_checking(data, filepath);
// Update the cached data
for (const auto *e : updated_data_) {
auto iter = data.kernels.find(e->kernel_key);
if (iter != data.kernels.end()) {
iter->second.last_used_at = e->last_used_at;
}
}
// Add new data
for (auto &[key, kernel] : caching_kernels_) {
auto [iter, ok] = data.kernels.insert({key, kernel});
if (ok) {
data.size += iter->second.size;
// Mangle kernel_name as kernel-key
iter->second.compiled_kernel_data.kernel_name = key;
}
}
// Dump
for (const auto &[_, k] : data.kernels) {
const auto code_filepath = join_path(
config_.cache_path, fmt::format(kMetalCodeFormat, k.kernel_key));
if (try_lock_with_file(code_filepath)) {
std::ofstream fs{code_filepath};
fs << k.compiled_kernel_data.source_code;
}
}
write_to_binary_file(data, filepath);
}
}
}

CompiledKernelData CacheManager::compile_kernel(Kernel *kernel) const {
kernel->lower();
return run_codegen(config_.compiled_runtime_module_,
*config_.compiled_snode_trees_, kernel, nullptr);
}

std::string CacheManager::make_kernel_key(const CompileConfig *compile_config,
Kernel *kernel) const {
if (config_.mode < MemAndDiskCache) {
return kernel->get_name();
}
auto key = kernel->get_cached_kernel_key();
if (key.empty()) {
key = get_hashed_offline_cache_key(compile_config, kernel);
kernel->set_kernel_key_for_cache(key);
}
return key;
}

std::optional<CompiledKernelData> CacheManager::try_load_cached_kernel(
Kernel *kernel,
const std::string &key) {
TI_ASSERT(config_.mode > NotCache);
{ // Find in memory-cache (caching_kernels_)
const auto &kernels = caching_kernels_;
auto iter = kernels.find(key);
if (iter != kernels.end()) {
TI_DEBUG("Create kernel '{}' from in-memory cache (key='{}')",
kernel->get_name(), key);
kernel->mark_as_from_cache();
return iter->second.compiled_kernel_data;
}
}
// Find in disk-cache (cached_data_)
if (config_.mode == MemAndDiskCache) {
auto &kernels = cached_data_.kernels;
auto iter = kernels.find(key);
if (iter != kernels.end()) {
auto &k = iter->second;
TI_ASSERT(k.kernel_key == k.compiled_kernel_data.kernel_name);
if (load_kernel_source_code(k)) {
TI_DEBUG("Create kernel '{}' from cache (key='{}')", kernel->get_name(),
key);
k.last_used_at = std::time(nullptr);
updated_data_.push_back(&k);
kernel->mark_as_from_cache();
return k.compiled_kernel_data;
}
}
}
return std::nullopt;
}

CompiledKernelData CacheManager::compile_and_cache_kernel(
const std::string &key,
Kernel *kernel) {
TI_DEBUG_IF(config_.mode == MemAndDiskCache, "Cache kernel '{}' (key='{}')",
kernel->get_name(), key);
OfflineCacheKernelMetadata k;
k.kernel_key = key;
k.created_at = k.last_used_at = std::time(nullptr);
k.compiled_kernel_data = compile_kernel(kernel);
k.size = k.compiled_kernel_data.source_code.size();
const auto &kernel_data = (caching_kernels_[key] = std::move(k));
return kernel_data.compiled_kernel_data;
}

bool CacheManager::load_kernel_source_code(
OfflineCacheKernelMetadata &kernel_data) {
auto &src = kernel_data.compiled_kernel_data.source_code;
if (!src.empty()) {
return true;
}
auto filepath =
join_path(config_.cache_path,
fmt::format(kMetalCodeFormat, kernel_data.kernel_key));
std::ifstream f(filepath);
if (!f.is_open()) {
return false;
}
f.seekg(0, std::ios::end);
const auto len = f.tellg();
f.seekg(0, std::ios::beg);
src.resize(len);
f.read(&src[0], len);
return !f.fail();
}

} // namespace metal
} // namespace taichi::lang
74 changes: 74 additions & 0 deletions taichi/cache/metal/cache_manager.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#pragma once

#include "taichi/codegen/metal/struct_metal.h"
#include "taichi/common/serialization.h"
#include "taichi/program/compile_config.h"
#include "taichi/runtime/metal/kernel_utils.h"
#include "taichi/util/offline_cache.h"

namespace taichi::lang {
namespace metal {

struct OfflineCacheKernelMetadata : public offline_cache::KernelMetadataBase {
CompiledKernelData compiled_kernel_data;

TI_IO_DEF_WITH_BASECLASS(offline_cache::KernelMetadataBase,
compiled_kernel_data);
};

class CacheManager {
static constexpr char kMetadataFilename[] = "metadata.tcb";
static constexpr char kMetadataLockName[] = "metadata.lock";
static constexpr char kMetalCodeFormat[] =
"{}.metal"; // "{kernel-key}.metal"
using CompiledKernelData = taichi::lang::metal::CompiledKernelData;
using CachingData =
std::unordered_map<std::string, OfflineCacheKernelMetadata>;

public:
using Metadata = offline_cache::Metadata<OfflineCacheKernelMetadata>;
enum Mode { NotCache, MemCache, MemAndDiskCache };

struct Params {
Mode mode{MemCache};
std::string cache_path;
const CompiledRuntimeModule *compiled_runtime_module_{nullptr};
const std::vector<CompiledStructs> *compiled_snode_trees_{nullptr};
};

CacheManager(Params &&init_params);

// Load from memory || Load from disk || (Compile && Cache the result in
// memory)
CompiledKernelData load_or_compile(const CompileConfig *compile_config,
Kernel *kernel);

// Dump the cached data in memory to disk
void dump_with_merging() const;

// Run offline cache cleaning
void clean_offline_cache(offline_cache::CleanCachePolicy policy,
int max_bytes,
double cleaning_factor) const {
TI_NOT_IMPLEMENTED;
}

private:
CompiledKernelData compile_kernel(Kernel *kernel) const;
std::string make_kernel_key(const CompileConfig *compile_config,
Kernel *kernel) const;
std::optional<CompiledKernelData> try_load_cached_kernel(
Kernel *kernel,
const std::string &key);
CompiledKernelData compile_and_cache_kernel(const std::string &key,
Kernel *kernel);
bool load_kernel_source_code(OfflineCacheKernelMetadata &kernel_data);

Params config_;
CachingData caching_kernels_;
Metadata cached_data_;
std::vector<Metadata::KernelMetadata *> updated_data_;
};

} // namespace metal
} // namespace taichi::lang
15 changes: 5 additions & 10 deletions taichi/codegen/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1689,17 +1689,12 @@ CompiledKernelData run_codegen(
return codegen.run();
}

FunctionType compile_to_metal_executable(
Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledRuntimeModule *compiled_runtime_module,
const std::vector<CompiledStructs> &compiled_snode_trees,
OffloadedStmt *offloaded) {
const auto compiled_res = run_codegen(
compiled_runtime_module, compiled_snode_trees, kernel, offloaded);
kernel_mgr->register_taichi_kernel(compiled_res);
FunctionType compiled_kernel_to_metal_executable(
const CompiledKernelData &compiled_kernel,
KernelManager *kernel_mgr) {
kernel_mgr->register_taichi_kernel(compiled_kernel);
return [kernel_mgr,
kernel_name = compiled_res.kernel_name](RuntimeContext &ctx) {
kernel_name = compiled_kernel.kernel_name](RuntimeContext &ctx) {
kernel_mgr->launch_taichi_kernel(kernel_name, &ctx);
};
}
Expand Down
12 changes: 3 additions & 9 deletions taichi/codegen/metal/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,9 @@ CompiledKernelData run_codegen(
Kernel *kernel,
OffloadedStmt *offloaded);

// If |offloaded| is nullptr, this compiles the AST in |kernel|. Otherwise it
// compiles just |offloaded|. These ASTs must have already been lowered at the
// CHI level.
FunctionType compile_to_metal_executable(
Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledRuntimeModule *compiled_runtime_module,
const std::vector<CompiledStructs> &compiled_snode_trees,
OffloadedStmt *offloaded = nullptr);
FunctionType compiled_kernel_to_metal_executable(
const CompiledKernelData &compiled_kernel,
KernelManager *kernel_mgr);

} // namespace metal
} // namespace taichi::lang
7 changes: 7 additions & 0 deletions taichi/common/serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@ serialize_kv_impl(SER &ser,
TI_IO(__VA_ARGS__); \
}

#define TI_IO_DEF_WITH_BASECLASS(BaseClass, ...) \
template <typename S> \
void io(S &serializer) const { \
this->BaseClass::io(serializer); \
TI_IO(__VA_ARGS__); \
}

// This macro serializes each field with its name by doing the following:
// 1. Stringifies __VA_ARGS__, then split the stringified result by ',' at
// compile time.
Expand Down
1 change: 1 addition & 0 deletions taichi/runtime/program_impls/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ target_include_directories(metal_program_impl
)

target_link_libraries(metal_program_impl PUBLIC metal_codegen)
target_link_libraries(metal_program_impl PUBLIC metal_cache)
target_link_libraries(metal_program_impl PUBLIC metal_runtime)
Loading

0 comments on commit 38c0277

Please sign in to comment.