diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp index cd5346bacda..3a0a4025417 100644 --- a/examples/models/llama2/runner/runner.cpp +++ b/examples/models/llama2/runner/runner.cpp @@ -16,6 +16,7 @@ #include #endif /* ET_USE_TIKTOKEN*/ #include +#include #include #include @@ -66,13 +67,17 @@ Error Runner::load() { const auto method_names = module_->method_names(); ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model"); model_methods_ = method_names.get(); - n_bos_ = getMetadataHelper("get_n_bos", 1); - n_eos_ = getMetadataHelper("get_n_eos", 1); - max_seq_len_ = getMetadataHelper("get_max_seq_len", 128); - use_kv_cache_ = getMetadataHelper("use_kv_cache", true); - use_sdpa_with_kv_cache_ = getMetadataHelper("use_sdpa_with_kv_cache", false); - append_eos_ = getMetadataHelper("append_eos_to_prompt", false); - enable_parallel_prefill_ = getMetadataHelper("enable_dynamic_shape", false); + n_bos_ = get_module_metadata(module_.get(), "get_n_bos", 1); + n_eos_ = get_module_metadata(module_.get(), "get_n_eos", 1); + max_seq_len_ = + get_module_metadata(module_.get(), "get_max_seq_len", 128); + use_kv_cache_ = get_module_metadata(module_.get(), "use_kv_cache", true); + use_sdpa_with_kv_cache_ = + get_module_metadata(module_.get(), "use_sdpa_with_kv_cache", false); + append_eos_ = + get_module_metadata(module_.get(), "append_eos_to_prompt", false); + enable_parallel_prefill_ = + get_module_metadata(module_.get(), "enable_dynamic_shape", false); // Load tokenizer #if ET_USE_TIKTOKEN @@ -82,10 +87,12 @@ Error Runner::load() { #endif tokenizer_->load(tokenizer_path_); - vocab_size_ = - getMetadataHelper("get_vocab_size", tokenizer_->vocab_size()); - bos_id_ = getMetadataHelper("get_bos_id", tokenizer_->bos_tok()); - eos_id_ = getMetadataHelper("get_eos_id", tokenizer_->eos_tok()); + vocab_size_ = get_module_metadata( + module_.get(), "get_vocab_size", tokenizer_->vocab_size()); + bos_id_ = get_module_metadata( + 
module_.get(), "get_bos_id", tokenizer_->bos_tok()); + eos_id_ = get_module_metadata( + module_.get(), "get_eos_id", tokenizer_->eos_tok()); // Create sampler sampler_ = std::make_unique( @@ -97,28 +104,6 @@ Error Runner::load() { return Error::Ok; } -template -T Runner::getMetadataHelper(const std::string& method_name, T default_val) { - T res = default_val; - if (model_methods_.count(method_name)) { - Result> outputs = module_->execute(method_name); - if (outputs.ok()) { - std::vector outs = outputs.get(); - if (outs.size() > 0) { - res = outs[0].to(); - } - } - } else { - ET_LOG( - Info, - "The model does not contain %s method, using default value %lld", - method_name.c_str(), - (long long)default_val); - } - ET_LOG(Info, "%s: %lld", method_name.c_str(), (long long)res); - return res; -} - int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) { ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D"); auto num_tokens = logits_tensor.size(1); @@ -485,12 +470,4 @@ Error Runner::generate( void Runner::stop() { shouldStop_ = true; } - -// explicit instantiation of template methods -template int64_t Runner::getMetadataHelper( - const std::string& method_name, - int64_t default_val); -template bool Runner::getMetadataHelper( - const std::string& method_name, - bool default_val); } // namespace torch::executor diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h index 407527531df..7e2cb612189 100644 --- a/examples/models/llama2/runner/runner.h +++ b/examples/models/llama2/runner/runner.h @@ -44,9 +44,6 @@ class Runner { void stop(); private: - // metadata - template - T getMetadataHelper(const std::string& method_name, T default_val); int32_t logitsToToken(const exec_aten::Tensor& logits_tensor); Result prefill( const std::vector& tokens, diff --git a/extension/module/metadata_util.h b/extension/module/metadata_util.h new file mode 100644 index 00000000000..4ea2d9eebd5 --- /dev/null +++ 
b/extension/module/metadata_util.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/**
+ * Constant metadata can be serialized in .pte files; this helper enables
+ * easy access to the metadata.
+ */
+#pragma once
+
+#include <executorch/extension/module/module.h>
+
+#include <string>
+#include <vector>
+
+namespace torch::executor {
+
+/**
+ * Reads a constant metadata value from `module`.
+ *
+ * If the program has a method named `method_name`, that method is run with
+ * `module->execute(method_name)` and its first output is converted to `T`.
+ * If the method is missing, fails to execute, or produces no outputs,
+ * `default_val` is returned instead. The value that ends up being used is
+ * always logged.
+ *
+ * `T` must be convertible to `long long`, because that is how the value is
+ * formatted in the log messages.
+ *
+ * @param module Module to query; must not be null.
+ * @param method_name Name of the metadata method, e.g. "get_max_seq_len".
+ * @param default_val Fallback value used when the metadata cannot be read.
+ * @return The metadata value, or `default_val` as described above.
+ */
+template <typename T>
+T get_module_metadata(
+    Module* module,
+    const std::string& method_name,
+    T default_val) {
+  ET_CHECK_MSG(module != nullptr, "Module must not be null");
+  const auto method_names = module->method_names();
+  ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
+  // Bind by reference: `method_names` outlives every use of the set.
+  const auto& model_methods = method_names.get();
+
+  T res = default_val;
+  if (model_methods.count(method_name)) {
+    Result<std::vector<EValue>> outputs = module->execute(method_name);
+    if (!outputs.ok()) {
+      // Still fall back to the default, but do not hide the failure.
+      ET_LOG(
+          Error,
+          "Failed to execute %s (error 0x%x), using default value %lld",
+          method_name.c_str(),
+          static_cast<unsigned int>(outputs.error()),
+          static_cast<long long>(default_val));
+    } else if (outputs.get().empty()) {
+      ET_LOG(
+          Error,
+          "%s returned no outputs, using default value %lld",
+          method_name.c_str(),
+          static_cast<long long>(default_val));
+    } else {
+      res = outputs.get()[0].to<T>();
+    }
+  } else {
+    ET_LOG(
+        Info,
+        "The model does not contain %s method, using default value %lld",
+        method_name.c_str(),
+        static_cast<long long>(default_val));
+  }
+  ET_LOG(Info, "%s: %lld", method_name.c_str(), static_cast<long long>(res));
+  return res;
+}
+
+} // namespace torch::executor
diff --git a/extension/module/targets.bzl b/extension/module/targets.bzl index 61251047dc8..07020b03a88 100644 --- a/extension/module/targets.bzl +++ b/extension/module/targets.bzl @@ -17,6 +17,7 @@ def define_common_targets(): ], exported_headers = [ "module.h", + "metadata_util.h", ], visibility = [ "@EXECUTORCH_CLIENTS",