From 4a272eee75dd99249f61ae36cfae6e9cffe82fb1 Mon Sep 17 00:00:00 2001 From: Joey Tsai Date: Fri, 15 Nov 2024 10:38:39 +0800 Subject: [PATCH 1/8] Qualcomm AI Engine Direct - Suport bert mode for llama3.2 - Enable bert mode - Change input sequence of static_llama - Tag bert output as uint8 - Unify both 1b and 3b in 1 runner - Add hybrid IO memory for llama3_2 runner - Align timer with llama --- examples/qualcomm/oss_scripts/llama2/llama.py | 131 +++++++++++- .../oss_scripts/llama2/model/static_llama.py | 180 +++++++++------- .../oss_scripts/llama2/runner/runner.cpp | 10 +- .../oss_scripts/llama2/runner/runner.h | 6 +- .../oss_scripts/llama3_2/CMakeLists.txt | 39 +--- .../qualcomm/oss_scripts/llama3_2/llama.py | 139 +++++++++++-- .../llama3_2/qnn_llama3_2_runner.cpp | 12 +- .../oss_scripts/llama3_2/runner/io_memory.cpp | 194 +++++++++++++++--- .../oss_scripts/llama3_2/runner/io_memory.h | 69 ++++--- .../oss_scripts/llama3_2/runner/runner.cpp | 136 ++++++++---- .../oss_scripts/llama3_2/runner/runner.h | 15 +- install_requirements.py | 2 +- 12 files changed, 692 insertions(+), 241 deletions(-) diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py index b4a9a60c20a..9af32bd7328 100755 --- a/examples/qualcomm/oss_scripts/llama2/llama.py +++ b/examples/qualcomm/oss_scripts/llama2/llama.py @@ -177,14 +177,15 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None ) -def calibrate( +def _kv_calibrate( example_inputs, user_prompts, module: torch.fx.GraphModule, tokenizer_model_path="tokenizer.model", + max_seq_len=512, ): sp_model = SentencePieceProcessor(model_file=tokenizer_model_path) - _, _, atten_mask, k_caches, v_caches = example_inputs + _, atten_mask, _, k_caches, v_caches = example_inputs # TODO: change criteria & support batch inputs if necessary pos = torch.tensor(0, dtype=torch.int32) @@ -202,11 +203,11 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: return probs_indices.gather(dim=-1, index=next_token) with torch.no_grad(): - while token_list[-1] != sp_model.eos_id() and pos < 128: + while token_list[-1] != sp_model.eos_id() and pos < max_seq_len: logits, new_k_caches, new_v_caches = module( torch.full((1, 1), token_list[pos]), - torch.full((1, 1), pos), atten_mask, + torch.full((1, 1), pos), *k_caches, *v_caches, ) @@ -228,6 +229,69 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: print(f"calibration data:\n{sp_model.decode(token_list)}") +def _bert_calibrate( + example_inputs, + user_prompts, + module: torch.fx.GraphModule, + tokenizer_model_path="tokenizer.model", + max_seq_len=512, +): + sp_model = SentencePieceProcessor(model_file=tokenizer_model_path) + _, atten_mask = example_inputs + max_cache_len = max_seq_len - 1 + + # TODO: change criteria & support batch inputs if necessary + token_list = sp_model.encode(user_prompts, bos=True, eos=False) + token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1) + last_prompt_pos = token_list.numel() + if last_prompt_pos < max_cache_len: + token_list = torch.cat( + [ + token_list, + torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int64), + ], + dim=1, + ) + else: + token_list = token_list[:, :max_cache_len] + + with torch.no_grad(): + logits, new_k_caches, new_v_caches = module( + token_list, + atten_mask, + ) + predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()] + + print(f"calibration data:\n{sp_model.decode(predict)}") + + +def calibrate( + example_inputs, + 
user_prompts, + module: torch.fx.GraphModule, + tokenizer_model_path="tokenizer.model", + max_seq_len=512, +): + if len(example_inputs) == 2: + _bert_calibrate( + example_inputs, + user_prompts, + module, + tokenizer_model_path, + max_seq_len, + ) + elif len(example_inputs) == 5: + _kv_calibrate( + example_inputs, + user_prompts, + module, + tokenizer_model_path, + max_seq_len, + ) + else: + raise RuntimeError("Get wrong inputs") + + class SingleLlama: def __init__(self, llama_model) -> None: super().__init__() @@ -235,8 +299,14 @@ def __init__(self, llama_model) -> None: self.quant_dtype = None self.llama_meta = self.llama_model.get_metadata() self.has_quant_io = False - tokens, pos_ids, atten_mask, k_caches, v_caches = self.get_example_inputs() - self.inputs = (tokens, pos_ids, atten_mask, *k_caches, *v_caches) + if self.llama_meta["get_use_kv_cache"]: + tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs( + use_kv_cache=True + ) + self.inputs = (tokens, atten_mask, pos_ids, *k_caches, *v_caches) + else: + tokens, atten_mask = self.get_example_inputs(use_kv_cache=False) + self.inputs = (tokens, atten_mask) def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type): if not self.has_quant_io: @@ -256,11 +326,17 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type): n.meta[QCOM_QUANTIZED_IO] = kv_type elif n.op == "output": for a in n.args[0]: + # single head, kv mode if ( a.meta["val"].flatten().size()[0] == self.llama_meta["get_head_dim"] ): a.meta[QCOM_QUANTIZED_IO] = kv_type + # single head, bert mode + elif a.meta["val"].flatten().size()[0] == self.llama_meta[ + "get_head_dim" + ] * (self.llama_meta["get_max_seq_len"] - 1): + a.meta[QCOM_QUANTIZED_IO] = kv_type def quantize(self, quant_dtype, custom_annotations=()): self.quant_dtype = quant_dtype @@ -281,11 +357,13 @@ def quantize(self, quant_dtype, custom_annotations=()): ).module() fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) print("Quantizing the model...") + calibrate( - self.get_example_inputs(), + self.get_example_inputs(self.llama_meta["get_use_kv_cache"]), args.prompt, fx_graph_module, tokenizer_model_path=args.tokenizer_model, + max_seq_len=args.seq_len, ) self.llama_model = convert_pt2e(fx_graph_module) @@ -328,18 +406,29 @@ def lowering_modules( with open(f"{work_space}/{pte_filename}.pte", "wb") as file: exec_prog_mgr.write_to_file(file) - def get_example_inputs(self): - return self.llama_model.get_example_inputs() + def get_example_inputs(self, use_kv_cache=True): + return self.llama_model.get_example_inputs(use_kv_cache) def compile(args): os.makedirs(args.artifact, exist_ok=True) start_ts = time.time() + + if args.model_mode == "kv": + use_kv_cache = output_new_cache_only = True + elif args.model_mode == "bert" or args.model_mode == "hybrid": + raise NotImplementedError( + f"model_mode {args.model_mode} is not implemented yet." 
+ ) + else: + raise RuntimeError(f"No such model_mode {args.model_mode}.") + with open(args.params) as f: config = ModelArgs(**json.load(f)) # TODO: support batch inputs if necessary config.max_batch_size = 1 - config.max_seq_len = 1024 + config.max_seq_len = args.seq_len + config.use_kv_cache = use_kv_cache state_dict = torch.load( args.checkpoint, weights_only=True, map_location="cpu", mmap=True ) @@ -348,7 +437,7 @@ def compile(args): llama_instance = None with torch.device("meta"): - llama_instance = LlamaModel(config, output_new_cache_only=True) + llama_instance = LlamaModel(config, output_new_cache_only=output_new_cache_only) if "model" in state_dict: state_dict = state_dict["model"] llama_instance.load_state_dict( @@ -398,6 +487,11 @@ def compile(args): def inference(args, pre_gen_pte=""): workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama" + if args.model_mode != "kv": + raise NotImplementedError( + f"model_mode {args.model_mode} is not implemented yet." + ) + runner_args = " ".join( [ f"--model_path {pte_filename}.pte", @@ -550,6 +644,21 @@ def post_process(): type=str, ) + parser.add_argument( + "--num_sharding", + type=int, + default=0, + help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.", + ) + + parser.add_argument( + "--model_mode", + help="Export and inference bert mode, kv mode or hybrid(TBD) mode", + default="kv", + choices=["bert", "kv", "hybrid"], + type=str, + ) + args = parser.parse_args() if args.compile_only and args.pre_gen_pte: exit("Cannot set both compile_only and pre_gen_pte as true") diff --git a/examples/qualcomm/oss_scripts/llama2/model/static_llama.py b/examples/qualcomm/oss_scripts/llama2/model/static_llama.py index 49fc5721281..ccfe3d62af5 100755 --- a/examples/qualcomm/oss_scripts/llama2/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama2/model/static_llama.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Tuple +from typing import List, Optional, Tuple import torch import torch.nn as nn @@ -16,25 +16,15 @@ ) -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - def apply_rotary_emb_single( x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor ) -> torch.Tensor: x_r, x_i = x[..., ::2], x[..., 1::2] + # brodcast for bert mode input x + if x.dim() == 4: + freqs_cos = freqs_cos[None, :, None, :] + freqs_sin = freqs_sin[None, :, None, :] x_out_r = x_r * freqs_cos - x_i * freqs_sin x_out_i = x_r * freqs_sin + x_i * freqs_cos @@ -103,8 +93,8 @@ def forward_sha( freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, atten_mask: torch.Tensor, - k_caches: List[torch.Tensor], - v_caches: List[torch.Tensor], + k_caches: Optional[List[torch.Tensor]] = None, + v_caches: Optional[List[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: q = [wq_sha(hidden_states) for wq_sha in self.wq_sha] k = [wk_sha(hidden_states) for wk_sha in self.wk_sha] @@ -116,9 +106,15 @@ def forward_sha( output_y = [] kh, vh = [], [] - for i, _ in enumerate(k_caches): - kh.append(torch.cat([k_caches[i], k[i]], dim=-1)) - vh.append(torch.cat([v_caches[i], v[i]], dim=1)) + # kv cache mode + if k_caches and v_caches: + for i, _ in enumerate(k_caches): + kh.append(torch.cat([k_caches[i], k[i]], dim=-1)) + vh.append(torch.cat([v_caches[i], v[i]], dim=1)) + # bert/prefill mode + else: + kh = k + vh = v for i, _ in enumerate(q): cache_idx = i // self.num_key_value_groups @@ -131,7 +127,14 @@ def forward_sha( y = torch.concat(output_y, dim=-1) y = self.wo(y) - return y, k, v + + if self.output_new_cache_only: + if k_caches and v_caches: + return y, k, v + # bert mode. 
Consider to remove, it's not really used + return y, k[-1], v[-1] + + return y, kh, vh def forward( self, @@ -154,27 +157,42 @@ def forward( output_kh, output_vh, output_y = [], [], [] kh, vh = [], [] - for i, _ in enumerate(k_caches): - kh.append(torch.cat([k_caches[i], k[:, i, :, :]], dim=-1)) - vh.append(torch.cat([v_caches[i], v[:, :, i, :]], dim=1)) - - for i in range(self.n_heads): - cache_idx = i // self.num_key_value_groups - - attn = q[:, :, i, :] @ kh[cache_idx] - attn = attn / self.scale + atten_mask - attn = self.attn_softmax(attn) - y = attn @ vh[cache_idx] + # kv cache mode + if k_caches and v_caches: + for i, _ in enumerate(k_caches): + kh.append(torch.cat([k_caches[i], k[:, i, :, :]], dim=-1)) + vh.append(torch.cat([v_caches[i], v[:, :, i, :]], dim=1)) + for i in range(self.n_heads): + cache_idx = i // self.num_key_value_groups + + attn = q[:, :, i, :] @ kh[cache_idx] + attn = attn / self.scale + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh[cache_idx] + + output_y.append(y) + + # bert/prefill mode + else: + kh = k + vh = v + for i in range(self.n_heads): + cache_idx = i // self.num_key_value_groups + + attn = q[:, :, i, :] @ kh[:, cache_idx, :, :] + attn = attn / self.scale + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh[:, :, cache_idx, :] + + output_y.append(y) - output_y.append(y) - - for i in range(len(k_caches)): + for i in range(self.n_kv_heads): if self.output_new_cache_only: + output_kh.append(k[:, i, :, -1]) + output_vh.append(v[:, -1, i, :]) + else: output_kh.append(k[:, i, :, :]) output_vh.append(v[:, :, i, :]) - else: - output_kh.append(kh[i]) - output_vh.append(vh[i]) y = torch.concat(output_y, dim=-1) y = self.wo(y) @@ -227,6 +245,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True): self.n_layers = config.n_layers self.vocab_size = config.vocab_size self.rope_freq_base = config.rope_freq_base + self.use_kv_cache = config.use_kv_cache self.output_new_cache_only = output_new_cache_only self.layers = nn.ModuleList( @@ -249,22 +268,30 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True): def forward( self, tokens: torch.Tensor, - input_pos: torch.Tensor, atten_mask: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, *args, ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + output_k_cache = [] output_v_cache = [] # following tensors should be invariant across batches - freqs_cos = self.freqs_cos[input_pos][0] - freqs_sin = self.freqs_sin[input_pos][0] + freqs_cos = ( + self.freqs_cos[input_pos][0] if self.use_kv_cache else self.freqs_cos[:-1] + ) + freqs_sin = ( + self.freqs_sin[input_pos][0] if self.use_kv_cache else self.freqs_sin[:-1] + ) hidden_states = self.tok_embeddings(tokens) for ind, decoder_layer in enumerate(self.layers): - offset_k = ind * self.n_kv_heads - offset_v = self.n_layers * self.n_kv_heads + offset_k - k_caches = args[offset_k : offset_k + self.n_kv_heads] - v_caches = args[offset_v : offset_v + self.n_kv_heads] + k_caches = None + v_caches = None + if self.use_kv_cache: + offset_k = ind * self.n_kv_heads + offset_v = self.n_layers * self.n_kv_heads + offset_k + k_caches = args[offset_k : offset_k + self.n_kv_heads] + v_caches = args[offset_v : offset_v + self.n_kv_heads] hidden_states, k, v = decoder_layer( hidden_states, freqs_cos=freqs_cos, @@ -281,37 +308,47 @@ def forward( return logits, output_k_cache, output_v_cache - def get_example_inputs(self): - tokens = torch.randint( - self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32 - ) - pos_ids 
= torch.zeros((self.max_batch_size, 1), dtype=torch.int32) - k_cache, v_cache = [], [] - atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0) - atten_mask[:, -1] = 0 - for _ in range(self.n_layers): - for _ in range(self.n_kv_heads): - # transpose first to decrease the runtime efforts - k_cache.append( - torch.zeros( - self.max_batch_size, - self.head_dim, - self.max_seq_len - 1, + def get_example_inputs(self, use_kv_cache=True): + if use_kv_cache: + tokens = torch.randint( + self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32 + ) + pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32) + k_cache, v_cache = [], [] + atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0) + atten_mask[:, -1] = 0 + for _ in range(self.n_layers): + for _ in range(self.n_kv_heads): + # transpose first to decrease the runtime efforts + k_cache.append( + torch.zeros( + self.max_batch_size, + self.head_dim, + self.max_seq_len - 1, + ) ) - ) - v_cache.append( - torch.zeros( - self.max_batch_size, - self.max_seq_len - 1, - self.head_dim, + v_cache.append( + torch.zeros( + self.max_batch_size, + self.max_seq_len - 1, + self.head_dim, + ) ) - ) + return ( + tokens, + atten_mask, + pos_ids, + k_cache, + v_cache, + ) + + max_promp = self.max_seq_len - 1 + tokens = torch.arange(0, max_promp, 1, dtype=torch.int32).unsqueeze(0) + atten_mask = torch.triu(torch.rand((max_promp, max_promp)), 1) + atten_mask[atten_mask != 0] = -255 return ( tokens, - pos_ids, atten_mask, - k_cache, - v_cache, ) def get_metadata(self): @@ -328,4 +365,5 @@ def get_metadata(self): "get_n_kv_heads": self.n_kv_heads, "get_n_layers": self.n_layers, "get_vocab_size": self.vocab_size, + "get_use_kv_cache": self.use_kv_cache, } diff --git a/examples/qualcomm/oss_scripts/llama2/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama2/runner/runner.cpp index 358ad37b729..3f055127324 100644 --- a/examples/qualcomm/oss_scripts/llama2/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama2/runner/runner.cpp @@ -154,15 +154,15 @@ int32_t Runner::logitsToToken(const Tensor& logits_tensor) { Result Runner::run_model_step( int64_t input_token, TensorPtr& token, - TensorPtr& start_pos, TensorPtr& atten_mask, + TensorPtr& start_pos, std::vector& kv_tensors, std::vector& kv_outputs) { token->mutable_data_ptr()[0] = input_token; // inputs:[tokens, start_pos, atten_mask, k_cache, v_cache] std::vector inputs = { - token, start_pos, atten_mask}; + token, atten_mask, start_pos}; inputs.insert(inputs.end(), kv_tensors.begin(), kv_tensors.end()); auto outputs_res = module_->forward(inputs); ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); @@ -359,7 +359,7 @@ Error Runner::generate( while (pos < seq_len - 1) { // Run the model auto logits_res = run_model_step( - cur_token, token, start_pos, atten_mask, kv_tensors, kv_outputs); + cur_token, token, atten_mask, start_pos, kv_tensors, kv_outputs); if (pos == num_prompt_tokens) { stats_.first_token_ms = time_in_ms(); } else if (pos == num_prompt_tokens - 1) { @@ -517,9 +517,9 @@ void IoMemMgr::init_io_info() { void IoMemMgr::set_tensor_meta() { io_info_.input_token.tensor_meta = std::make_unique(method_meta_->input_tensor_meta(0).get()); - io_info_.pos_idx.tensor_meta = - std::make_unique(method_meta_->input_tensor_meta(1).get()); io_info_.atten_mask.tensor_meta = + std::make_unique(method_meta_->input_tensor_meta(1).get()); + io_info_.pos_idx.tensor_meta = std::make_unique(method_meta_->input_tensor_meta(2).get()); io_info_.k_caches_read.tensor_meta = diff --git 
a/examples/qualcomm/oss_scripts/llama2/runner/runner.h b/examples/qualcomm/oss_scripts/llama2/runner/runner.h index 700cb94f52c..aa0e5eb0ece 100644 --- a/examples/qualcomm/oss_scripts/llama2/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama2/runner/runner.h @@ -124,8 +124,8 @@ class IoMemMgr { struct IoInfo { InfoAttrs input_token; - InfoAttrs pos_idx; InfoAttrs atten_mask; + InfoAttrs pos_idx; InfoAttrs k_caches_read; InfoAttrs k_caches_write; InfoAttrs v_caches_read; @@ -133,8 +133,8 @@ class IoMemMgr { InfoAttrs logit; std::vector tensor_info{ &input_token, - &pos_idx, &atten_mask, + &pos_idx, &k_caches_read, &k_caches_write, &v_caches_read, @@ -252,8 +252,8 @@ class Runner { executorch::runtime::Result run_model_step( int64_t input_token, ::executorch::extension::TensorPtr& token, - ::executorch::extension::TensorPtr& start_pos, ::executorch::extension::TensorPtr& atten_mask, + ::executorch::extension::TensorPtr& start_pos, std::vector<::executorch::extension::TensorPtr>& kv_tensors, std::vector<::executorch::extension::TensorPtr>& kv_outputs); // metadata diff --git a/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt index 4ce7d5ee6a7..6090ff7fe47 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt +++ b/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt @@ -43,13 +43,13 @@ list( ) # build qnn llama3.2 1b runner -add_executable(qnn_llama3_2_1b_runner ${_llama3_2_runner__srcs}) +add_executable(qnn_llama3_2_runner ${_llama3_2_runner__srcs}) target_include_directories( - qnn_llama3_2_1b_runner PUBLIC ${_common_include_directories} + qnn_llama3_2_runner PUBLIC ${_common_include_directories} ) target_link_libraries( - qnn_llama3_2_1b_runner + qnn_llama3_2_runner qnn_executorch_backend executorch_core extension_data_loader @@ -60,35 +60,8 @@ target_link_libraries( custom_ops ) target_compile_options( - qnn_llama3_2_1b_runner PUBLIC ${_common_compile_options} + qnn_llama3_2_runner PUBLIC ${_common_compile_options} ) set_target_properties( - qnn_llama3_2_1b_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'" -) - - -# build qnn llama3.2 3b runner -add_executable(qnn_llama3_2_3b_runner ${_llama3_2_runner__srcs}) -target_include_directories( - qnn_llama3_2_3b_runner PUBLIC ${_common_include_directories} -) -# Adding compile option to differentiate llama3.2 1b with 3b -target_compile_options(qnn_llama3_2_3b_runner PRIVATE -DLLAMA3_2_3B_RUNNER) - -target_link_libraries( - qnn_llama3_2_3b_runner - qnn_executorch_backend - executorch_core - extension_data_loader - extension_module - extension_tensor - gflags - re2::re2 - custom_ops -) -target_compile_options( - qnn_llama3_2_3b_runner PUBLIC ${_common_compile_options} -) -set_target_properties( - qnn_llama3_2_3b_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'" -) + qnn_llama3_2_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'" +) \ No newline at end of file diff --git a/examples/qualcomm/oss_scripts/llama3_2/llama.py b/examples/qualcomm/oss_scripts/llama3_2/llama.py index 532eb68319d..f086e400cf3 100755 --- a/examples/qualcomm/oss_scripts/llama3_2/llama.py +++ b/examples/qualcomm/oss_scripts/llama3_2/llama.py @@ -62,25 +62,27 @@ pte_filename = "llama3_2_qnn" -def calibrate( +def _kv_calibrate( example_inputs, user_prompts, module: torch.fx.GraphModule, tokenizer_model_path="tokenizer.model", + max_seq_len=512, ): sp_model = get_tokenizer(tokenizer_model_path) - _, _, atten_mask, k_caches, v_caches = example_inputs + _, atten_mask, _, k_caches, v_caches = 
example_inputs # TODO: change criteria & support batch inputs if necessary pos = torch.tensor(0, dtype=torch.int32) + max_cache_len = max_seq_len - 1 token_list = sp_model.encode(user_prompts, bos=True, eos=False) with torch.no_grad(): - while token_list[-1] != sp_model.eos_id and pos < 511: + while token_list[-1] != sp_model.eos_id and pos < max_cache_len: logits, new_k_caches, new_v_caches = module( torch.full((1, 1), token_list[pos], dtype=torch.int32), - torch.full((1, 1), pos), atten_mask, + torch.full((1, 1), pos), *k_caches, *v_caches, ) @@ -101,6 +103,69 @@ def calibrate( print(f"calibration data:\n{sp_model.decode(token_list)}") +def _bert_calibrate( + example_inputs, + user_prompts, + module: torch.fx.GraphModule, + tokenizer_model_path="tokenizer.model", + max_seq_len=512, +): + sp_model = get_tokenizer(tokenizer_model_path) + _, atten_mask = example_inputs + max_cache_len = max_seq_len - 1 + + # TODO: change criteria & support batch inputs if necessary + token_list = sp_model.encode(user_prompts, bos=True, eos=False) + token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1) + last_prompt_pos = token_list.numel() + if last_prompt_pos < max_cache_len: + token_list = torch.cat( + [ + token_list, + torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int64), + ], + dim=1, + ) + else: + token_list = token_list[:, :max_cache_len] + + with torch.no_grad(): + logits, new_k_caches, new_v_caches = module( + token_list, + atten_mask, + ) + predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()] + + print(f"calibration data:\n{sp_model.decode(predict)}") + + +def calibrate( + example_inputs, + user_prompts, + module: torch.fx.GraphModule, + tokenizer_model_path="tokenizer.model", + max_seq_len=512, +): + if len(example_inputs) == 2: + _bert_calibrate( + example_inputs, + user_prompts, + module, + tokenizer_model_path, + max_seq_len, + ) + elif len(example_inputs) == 5: + _kv_calibrate( + example_inputs, + user_prompts, + module, + tokenizer_model_path, + max_seq_len, + ) + else: + raise RuntimeError("Get wrong inputs") + + class SingleLlama: def __init__(self, llama_model) -> None: super().__init__() @@ -108,8 +173,14 @@ def __init__(self, llama_model) -> None: self.quant_dtype = None self.llama_meta = self.llama_model.get_metadata() self.has_quant_io = False - tokens, pos_ids, atten_mask, k_caches, v_caches = self.get_example_inputs() - self.inputs = (tokens, pos_ids, atten_mask, *k_caches, *v_caches) + if self.llama_meta["get_use_kv_cache"]: + tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs( + use_kv_cache=True + ) + self.inputs = (tokens, atten_mask, pos_ids, *k_caches, *v_caches) + else: + tokens, atten_mask = self.get_example_inputs(use_kv_cache=False) + self.inputs = (tokens, atten_mask) def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type, sharding_type): if not self.has_quant_io: @@ -129,11 +200,17 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type, sharding_type): n.meta[QCOM_QUANTIZED_IO] = kv_type elif n.op == "output": for a in n.args[0]: + # single head, kv mode if ( a.meta["val"].flatten().size()[0] == self.llama_meta["get_head_dim"] ): a.meta[QCOM_QUANTIZED_IO] = kv_type + # single head, bert mode + elif a.meta["val"].flatten().size()[0] == self.llama_meta[ + "get_head_dim" + ] * (self.llama_meta["get_max_seq_len"] - 1): + a.meta[QCOM_QUANTIZED_IO] = kv_type # Tag sharding io if exir_ops.edge.llama.fallback.default in [ @@ -160,11 +237,13 @@ def quantize(self, quant_dtype, custom_annotations=()): 
).module() fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) logging.info("Quantizing the model...") + calibrate( - self.get_example_inputs(), + self.get_example_inputs(self.llama_meta["get_use_kv_cache"]), args.prompt, fx_graph_module, tokenizer_model_path=args.tokenizer_model, + max_seq_len=args.seq_len, ) self.llama_model = convert_pt2e(fx_graph_module) @@ -230,25 +309,38 @@ def lowering_modules( with open(f"{work_space}/{pte_filename}.pte", "wb") as file: exec_prog_mgr.write_to_file(file) - def get_example_inputs(self): - return self.llama_model.get_example_inputs() + def get_example_inputs(self, use_kv_cache=True): + return self.llama_model.get_example_inputs(use_kv_cache) def compile(args): os.makedirs(args.artifact, exist_ok=True) start_ts = time.time() + + if args.model_mode == "kv": + use_kv_cache = output_new_cache_only = True + elif args.model_mode == "bert": + use_kv_cache = output_new_cache_only = False + elif args.model_mode == "hybrid": + raise NotImplementedError( + f"model_mode {args.model_mode} is not implemented yet." + ) + else: + raise RuntimeError(f"No such model_mode {args.model_mode}.") + with open(args.params) as f: config = ModelArgs(**json.load(f)) # TODO: support batch inputs if necessary config.max_batch_size = 1 - config.max_seq_len = 512 + config.max_seq_len = args.seq_len + config.use_kv_cache = use_kv_cache state_dict = torch.load( args.checkpoint, weights_only=True, map_location="cpu", mmap=True ) llama_instance = None with torch.device("meta"): - llama_instance = LlamaModel(config, output_new_cache_only=True) + llama_instance = LlamaModel(config, output_new_cache_only=output_new_cache_only) if "model" in state_dict: state_dict = state_dict["model"] llama_instance.load_state_dict( @@ -314,6 +406,18 @@ def compile(args): def inference(args, pre_gen_pte=""): workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama" + if args.model_mode == "bert": + eval_mode = 0 + elif args.model_mode == "kv": + eval_mode = 1 + elif args.model_mode == "hybrid": + eval_mode = 2 + raise NotImplementedError( + f"model_mode {args.model_mode} is not implemented yet." + ) + else: + raise RuntimeError(f"No such model_mode {args.model_mode}.") + runner_args = " ".join( [ f"--model_path {pte_filename}.pte", @@ -321,13 +425,14 @@ def inference(args, pre_gen_pte=""): f"--tokenizer_path {os.path.basename(args.tokenizer_model)}", f'--prompt "{args.prompt}"', f"--seq_len {args.seq_len}", + f"--eval_mode {eval_mode}", f"--temperature {args.temperature}", ] ) runner_cmd = " ".join( [ f"cd {workspace} &&", - f"./qnn_llama3_2_{args.model_size.lower()}_runner {runner_args}", + f"./qnn_llama3_2_runner {runner_args}", ] ) @@ -345,7 +450,7 @@ def inference(args, pre_gen_pte=""): host_id=args.host, soc_model=args.model, shared_buffer=args.shared_buffer, - runner=f"examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_{args.model_size.lower()}_runner", + runner=f"examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner", ) # No pregen inputs, input_list is not required adb.push(inputs=[], input_list="", files=[args.tokenizer_model]) @@ -466,6 +571,14 @@ def post_process(): help="Specify the number of splits by inserting the fallback custom op. 
The graph will be split evenly by layers.", ) + parser.add_argument( + "--model_mode", + help="Export and inference bert mode, kv mode or hybrid(TBD) mode", + default="kv", + choices=["bert", "kv", "hybrid"], + type=str, + ) + args = parser.parse_args() if args.compile_only and args.pre_gen_pte: exit("Cannot set both compile_only and pre_gen_pte as true") diff --git a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp index f23b66e9300..a184bb42ade 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp @@ -16,9 +16,7 @@ #include #include #include - #include - #include DEFINE_string( @@ -45,12 +43,20 @@ DEFINE_int32( 128, "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens."); +DEFINE_int32( + eval_mode, + 0, + "0: PromptProcessor(bert) / 1: TokenGenerator(kv) / 2: HybridMode (TBD)"); + int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); // create llama runner example::Runner runner( - {FLAGS_model_path}, FLAGS_tokenizer_path.c_str(), FLAGS_temperature); + {FLAGS_model_path}, + FLAGS_tokenizer_path.c_str(), + FLAGS_temperature, + FLAGS_eval_mode); // generate tokens & store inference output std::ofstream fout(FLAGS_output_path.c_str()); diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp index fe5fa4bc0ac..9b37d056cf5 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp @@ -52,15 +52,25 @@ std::vector Memory::get_output_tensors(int shard_index) { return ret; } -KVCachedMemory::KVCachedMemory(std::vector>& modules) +HybridMemory::HybridMemory( + std::vector>& modules, + int32_t max_seq_len, + int32_t vocab_size, + int32_t num_layers, + int32_t head_dim, + int32_t num_heads) : Memory(modules), - shard_layers_({QNN_LLAMA3_2_NUM_LAYERS}), - num_heads_(QNN_LLAMA3_2_NUM_HEADS) { + shard_layers_({num_layers}), + max_seq_len_(max_seq_len), + vocab_size_(vocab_size), + num_layers_(num_layers), + head_dim_(head_dim), + num_heads_(num_heads) { data_ptr_ = std::unique_ptr( new IO, [](void* ptr) { delete static_cast(ptr); }); } -void KVCachedMemory::prepare_io( +void HybridMemory::prepare_kv_io( const std::vector>& methods_meta) { IO* ptr = static_cast(data_ptr_.get()); std::memset(ptr, 0, sizeof(IO)); @@ -71,6 +81,24 @@ void KVCachedMemory::prepare_io( static_cast(methods_meta[i].error())); } + // Init IO vector shape + // atten_mask + ptr->logits.resize(vocab_size_); + ptr->attention_mask.resize( + max_seq_len_, -255); // attention mask shape should be [1, ctx_length] + // kv + int32_t k_in_size = (head_dim_ + 1) * (max_seq_len_ - 1); + int32_t k_out_size = num_heads_ * head_dim_; + int32_t v_cache_size = (num_heads_ + 1) * (max_seq_len_ - 1) * head_dim_; + for (int layer = 0; layer < num_layers_; layer++) { + ptr->k_cache.emplace_back(); + for (int head = 0; head < num_heads_; head++) { + ptr->k_cache[layer].emplace_back(std::vector(k_in_size)); + } + ptr->k_cache_out.emplace_back(std::vector(k_out_size)); + ptr->v_cache.emplace_back(std::vector(v_cache_size)); + } + // [I]: input_tokens Result input_tok = methods_meta[0]->input_tensor_meta(0); input_tok_ = std::make_unique( @@ -81,8 +109,18 @@ void 
KVCachedMemory::prepare_io( const_cast(input_tok->dim_order().data())); input_tensors_[0].push_back(input_tok_.get()); + // [I]: atten_mask + Result atten_mask = methods_meta[0]->input_tensor_meta(1); + attention_mask_ = std::make_unique( + atten_mask->scalar_type(), + atten_mask->sizes().size(), + const_cast(atten_mask->sizes().data()), + ptr->attention_mask.data(), + const_cast(atten_mask->dim_order().data())); + input_tensors_[0].push_back(attention_mask_.get()); + // [I]: input_pos - Result input_pos = methods_meta[0]->input_tensor_meta(1); + Result input_pos = methods_meta[0]->input_tensor_meta(2); input_pos_ = std::make_unique( input_pos->scalar_type(), input_pos->sizes().size(), @@ -91,23 +129,11 @@ void KVCachedMemory::prepare_io( const_cast(input_pos->dim_order().data())); input_tensors_[0].push_back(input_pos_.get()); - // [I]: atten_mask - std::fill( - ptr->attention_mask, ptr->attention_mask + QNN_LLAMA3_2_SEQLEN, -255); - Result atten_mask = methods_meta[0]->input_tensor_meta(2); - attention_mask_ = std::make_unique( - atten_mask->scalar_type(), - atten_mask->sizes().size(), - const_cast(atten_mask->sizes().data()), - ptr->attention_mask, - const_cast(atten_mask->dim_order().data())); - input_tensors_[0].push_back(attention_mask_.get()); - // [I] kv_cache int index = 3; // bypass input_tokens, input_pos, atten_mask for (int offset = 0, shard_index = 0, - v_stride = (QNN_LLAMA3_2_SEQLEN - 1) * QNN_LLAMA3_2_HEAD_DIM; + v_stride = (max_seq_len_ - 1) * head_dim_; shard_index < modules_.size(); offset += shard_layers_[shard_index], shard_index++) { for (int cache_group = 0; cache_group < 2; ++cache_group) { @@ -118,9 +144,9 @@ void KVCachedMemory::prepare_io( std::vector>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); void* cache_ptr = (cache_group == 0) - ? static_cast(ptr->k_cache[layer + offset][head]) + ? static_cast(ptr->k_cache[layer + offset][head].data()) : static_cast( - ptr->v_cache[layer + offset] + head * v_stride); + ptr->v_cache[layer + offset].data() + head * v_stride); cache.emplace_back(std::make_unique( kv_cache->scalar_type(), @@ -143,15 +169,123 @@ void KVCachedMemory::prepare_io( logits->scalar_type(), logits->sizes().size(), const_cast(logits->sizes().data()), - ptr->logits, + ptr->logits.data(), const_cast(logits->dim_order().data())); output_tensors_[modules_.size() - 1].push_back(logits_.get()); // [O] kv_cache index = 1; + // Iterate through all kv cache outputs. + // For k, we store it in k_cache_out and update to k_cache later. + // For v, we append the output to the end of v_cache, + // which serves as both input and output. for (int offset = 0, shard_index = 0, - v_stride = (QNN_LLAMA3_2_SEQLEN - 1) * QNN_LLAMA3_2_HEAD_DIM; + v_stride = (max_seq_len_ - 1) * head_dim_; + shard_index < modules_.size(); + offset += shard_layers_[shard_index], shard_index++) { + for (int cache_group = 0; cache_group < 2; ++cache_group) { + for (int layer = 0; layer < shard_layers_[shard_index]; ++layer) { + for (int head = 0; head < num_heads_; ++head, ++index) { + Result kv_cache = + methods_meta[shard_index]->output_tensor_meta(index); + std::vector>& cache = + (cache_group == 0 ? k_cache_out_ : v_cache_out_); + void* cache_ptr = (cache_group == 0) + ? 
static_cast( + ptr->k_cache_out[layer + offset].data() + + (head * head_dim_)) + : static_cast( + ptr->v_cache[layer + offset].data() + + (head + 1) * v_stride); + cache.emplace_back(std::make_unique( + kv_cache->scalar_type(), + kv_cache->sizes().size(), + const_cast(kv_cache->sizes().data()), + cache_ptr, + const_cast( + kv_cache->dim_order().data()))); + output_tensors_[shard_index].push_back(cache.back().get()); + } + } + } + } +} + +void HybridMemory::prepare_prefill_io( + const std::vector>& methods_meta) { + IO* ptr = static_cast(data_ptr_.get()); + std::memset(ptr, 0, sizeof(IO)); + for (int i = 0; i < modules_.size(); ++i) { + ET_CHECK_MSG( + methods_meta[i].ok(), + "Failed to get method_meta 0x%x", + static_cast(methods_meta[i].error())); + } + + // Parse some IO info from method meta + // cache_len should be max_seq_len - 1 + int cache_len = methods_meta[0]->input_tensor_meta(0)->sizes()[1]; + + // TODO: Combine vector init with KV mode once Hybrid mode is enabled + // as it shares some common data structure. + // Init IO vector shape + ptr->prefill_input_toks.resize(cache_len); + ptr->prefill_atten_mask.resize(cache_len * cache_len); + ptr->prefill_logits.resize(cache_len * vocab_size_); + // Init kv vector shape + int32_t k_cache_out_size = num_heads_ * head_dim_ * cache_len; + int32_t v_cache_size = (num_heads_ + 1) * cache_len * head_dim_; + for (int layer = 0; layer < num_layers_; layer++) { + ptr->k_cache_out.emplace_back(std::vector(k_cache_out_size)); + ptr->v_cache.emplace_back(std::vector(v_cache_size)); + } + + // [I]: pre_input_tokens + Result prefill_input_toks = methods_meta[0]->input_tensor_meta(0); + prefill_input_toks_ = std::make_unique( + prefill_input_toks->scalar_type(), + prefill_input_toks->sizes().size(), + const_cast(prefill_input_toks->sizes().data()), + ptr->prefill_input_toks.data(), + const_cast( + prefill_input_toks->dim_order().data())); + input_tensors_[0].push_back(prefill_input_toks_.get()); + // [I]: prefill_attn_mask + for (int i = 0; i < cache_len; ++i) { + for (int j = 0; j < cache_len; ++j) { + if (i < j) { + ptr->prefill_atten_mask[i * cache_len + j] = -255; + } else { + ptr->prefill_atten_mask[i * cache_len + j] = 0; + } + } + } + + Result prefill_attn_mask = methods_meta[0]->input_tensor_meta(1); + prefill_attn_mask_ = std::make_unique( + prefill_attn_mask->scalar_type(), + prefill_attn_mask->sizes().size(), + const_cast(prefill_attn_mask->sizes().data()), + ptr->prefill_atten_mask.data(), + const_cast( + prefill_attn_mask->dim_order().data())); + input_tensors_[0].push_back(prefill_attn_mask_.get()); + + // [O]: logits + int logit_index = 0; + Result logits = + methods_meta[modules_.size() - 1]->output_tensor_meta(logit_index); + logits_ = std::make_unique( + logits->scalar_type(), + logits->sizes().size(), + const_cast(logits->sizes().data()), + ptr->prefill_logits.data(), + const_cast(logits->dim_order().data())); + output_tensors_[modules_.size() - 1].push_back(logits_.get()); + // [O] kv_cache + int index = 1; + for (int offset = 0, shard_index = 0, cache_stride = cache_len * head_dim_; shard_index < modules_.size(); offset += shard_layers_[shard_index], shard_index++) { for (int cache_group = 0; cache_group < 2; ++cache_group) { @@ -162,9 +296,11 @@ void KVCachedMemory::prepare_io( std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); void* cache_ptr = (cache_group == 0) - ? static_cast(ptr->k_cache_out[layer + offset][head]) + ? 
static_cast( + ptr->k_cache_out[layer + offset].data() + + head * cache_stride) : static_cast( - ptr->v_cache[layer + offset] + (head + 1) * v_stride); + ptr->v_cache[layer + offset].data() + head * cache_stride); cache.emplace_back(std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), @@ -179,12 +315,12 @@ void KVCachedMemory::prepare_io( } } -void KVCachedMemory::update_io( +void HybridMemory::update_io( int64_t cur_token, int64_t pos, std::vector>& output_tensors) { IO* ptr = static_cast(data_ptr_.get()); - int seq_len = (QNN_LLAMA3_2_SEQLEN - 1); + int seq_len = (max_seq_len_ - 1); // update input_tok ptr->input_tok = static_cast(cur_token); // update position_ids @@ -195,9 +331,9 @@ void KVCachedMemory::update_io( // update v_cache for (int i = 0; i < v_cache_in_.size(); i++) { v_cache_in_[i]->set_data( - v_cache_in_[i]->mutable_data() + QNN_LLAMA3_2_HEAD_DIM); + v_cache_in_[i]->mutable_data() + head_dim_); v_cache_out_[i]->set_data( - v_cache_out_[i]->mutable_data() + QNN_LLAMA3_2_HEAD_DIM); + v_cache_out_[i]->mutable_data() + head_dim_); } for (int shard = 0; shard < output_tensors.size(); shard++) { for (int index = 0; index < output_tensors[shard].size(); index++) { @@ -215,7 +351,7 @@ void KVCachedMemory::update_io( for (int i = 0; i < k_cache_in_.size(); ++i) { uint8_t* ptr_in = k_cache_in_[i]->mutable_data(); const uint8_t* ptr_out = k_cache_out_[i]->data(); - for (size_t j = 0, offset = seq_len; j < QNN_LLAMA3_2_HEAD_DIM; + for (size_t j = 0, offset = seq_len; j < head_dim_; ++j, offset += seq_len) { ptr_in[offset] = ptr_out[j]; } diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h index bee13100b2d..31ed351ef4b 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h @@ -19,25 +19,17 @@ #include #include -#define QNN_LLAMA3_2_LOGITS 128256 -#define QNN_LLAMA3_2_SEQLEN 512 // adjustable based on llama export -#define QNN_LLAMA3_2_NUM_HEADS 8 - -#if defined(LLAMA3_2_3B_RUNNER) -#define QNN_LLAMA3_2_HEAD_DIM 128 -#define QNN_LLAMA3_2_NUM_LAYERS 28 -#else -#define QNN_LLAMA3_2_HEAD_DIM 64 -#define QNN_LLAMA3_2_NUM_LAYERS 16 -#endif - namespace example { class Memory { public: Memory(std::vector>& modules); virtual ~Memory(); - virtual void prepare_io( + virtual void prepare_prefill_io( + const std::vector< + executorch::runtime::Result>& + methods_meta) = 0; + virtual void prepare_kv_io( const std::vector< executorch::runtime::Result>& methods_meta) = 0; @@ -56,12 +48,23 @@ class Memory { std::vector> modules_; }; -class KVCachedMemory : public Memory { +class HybridMemory : public Memory { public: - KVCachedMemory( - std::vector>& modules); - void prepare_io(const std::vector>& methods_meta) override; + HybridMemory( + std::vector>& modules, + int32_t max_seq_len, + int32_t vocab_size, + int32_t num_layers, + int32_t head_dim, + int32_t num_heads); + void prepare_prefill_io( + const std::vector< + executorch::runtime::Result>& + methods_meta) override; + void prepare_kv_io( + const std::vector< + executorch::runtime::Result>& + methods_meta) override; void update_io( int64_t cur_token, int64_t pos, @@ -70,20 +73,14 @@ class KVCachedMemory : public Memory { struct IO { int32_t input_tok; int32_t input_pos; - float attention_mask[QNN_LLAMA3_2_SEQLEN]; - uint8_t k_cache[QNN_LLAMA3_2_NUM_LAYERS][QNN_LLAMA3_2_NUM_HEADS] - [(QNN_LLAMA3_2_HEAD_DIM + 1) * (QNN_LLAMA3_2_SEQLEN - 1)]; - uint8_t 
v_cache[QNN_LLAMA3_2_NUM_LAYERS] - [(QNN_LLAMA3_2_NUM_HEADS + 1) * (QNN_LLAMA3_2_SEQLEN - 1) * - (QNN_LLAMA3_2_HEAD_DIM)]; - uint8_t k_cache_out[QNN_LLAMA3_2_NUM_LAYERS][QNN_LLAMA3_2_NUM_HEADS] - [QNN_LLAMA3_2_HEAD_DIM]; - float logits[QNN_LLAMA3_2_LOGITS]; - }; - struct LoopRange { - int32_t start; - int32_t end; - int32_t step; + std::vector attention_mask; + std::vector>> k_cache; + std::vector> v_cache; + std::vector> k_cache_out; + std::vector logits; + std::vector prefill_input_toks; + std::vector prefill_atten_mask; + std::vector prefill_logits; }; private: @@ -91,13 +88,19 @@ class KVCachedMemory : public Memory { std::unique_ptr input_pos_; std::unique_ptr hidden_state_; std::unique_ptr attention_mask_; + std::unique_ptr prefill_input_toks_; + std::unique_ptr prefill_attn_mask_; std::vector> k_cache_in_; std::vector> v_cache_in_; std::vector> k_cache_out_; std::vector> v_cache_out_; std::unique_ptr logits_; std::vector shard_layers_; - int num_heads_; + int32_t max_seq_len_; + int32_t vocab_size_; + int32_t num_layers_; + int32_t head_dim_; + int32_t num_heads_; }; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp index b601d200341..e1dc17dc959 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp @@ -18,7 +18,6 @@ #include #include #include -#include #include #include @@ -42,13 +41,15 @@ std::string statsToJsonString(const Runner::Stats& stats); Runner::Runner( const std::vector& models_path, const std::string& tokenizer_path, - const float temperature) + const float temperature, + const int eval_mode) : n_bos_(1), n_eos_(1), vocab_size_(QNN_LLAMA3_2_LOGITS), max_seq_len_(QNN_LLAMA3_2_SEQLEN), tokenizer_path_(tokenizer_path), temperature_(temperature), + eval_mode_(eval_mode), stats_({}) { for (size_t i = 0; i < models_path.size(); ++i) { modules_.push_back(std::make_shared( @@ -57,6 +58,19 @@ Runner::Runner( } ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str()); + int64_t max_seq_len = getMetadataHelper("get_max_seq_len", -1); + int64_t vocab_size = getMetadataHelper("get_vocab_size", -1); + int64_t num_layers = getMetadataHelper("get_n_layers", -1); + int64_t head_dim = getMetadataHelper("get_head_dim", -1); + int64_t num_heads = getMetadataHelper("get_n_kv_heads", -1); + ET_CHECK_MSG(max_seq_len != -1, "Could not retrieve max seq len"); + ET_CHECK_MSG(vocab_size != -1, "Could not retrieve vocab size"); + ET_CHECK_MSG(num_layers != -1, "Could not retrieve num layers"); + ET_CHECK_MSG(head_dim != -1, "Could not retrieve head dimension"); + ET_CHECK_MSG(num_heads != -1, "Could not retrieve num heads"); + + max_seq_len_ = max_seq_len; + vocab_size_ = vocab_size; tokenizer_ = example::get_tiktoken_for_llama(); Error err = tokenizer_->load(tokenizer_path_); ET_CHECK_MSG( @@ -64,7 +78,9 @@ Runner::Runner( eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); bos_id_ = tokenizer_->bos_tok(); eos_id_.insert(tokenizer_->eos_tok()); - io_mem_ = std::make_unique(modules_); + + io_mem_ = std::make_unique( + modules_, max_seq_len_, vocab_size_, num_layers, head_dim, num_heads); ET_LOG(Info, "creating io_memory"); } @@ -93,10 +109,36 @@ Error Runner::load() { // prepare io auto methods_meta = get_methods_meta(); - io_mem_->prepare_io(methods_meta); + if (eval_mode_ == EvalMode::kBert) { + io_mem_->prepare_prefill_io(methods_meta); + } else { + 
io_mem_->prepare_kv_io(methods_meta); + } return Error::Ok; } +template +T Runner::getMetadataHelper(std::string method_name, T default_val) { + T res = default_val; + if (modules_[0]->method_names()->count(method_name)) { + Result> outputs = modules_[0]->execute(method_name); + if (outputs.ok()) { + std::vector outs = outputs.get(); + if (outs.size() > 0) { + res = outs[0].to(); + } + } + } else { + ET_LOG( + Info, + "The model does not contain %s method, using default value %lld", + method_name.c_str(), + (long long)default_val); + } + ET_LOG(Info, "%s: %lld", method_name.c_str(), (long long)res); + return res; +} + template int32_t Runner::logitsToToken(const Tensor& logits_tensor) { T* logits = logits_tensor.mutable_data_ptr(); @@ -142,10 +184,6 @@ Error Runner::generate( } stats_.model_load_end_ms = time_in_ms(); } - - stats_.inference_start_ms = time_in_ms(); - seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_; - std::string post_process_prompt; if (!system_prompt.empty()) { @@ -158,11 +196,11 @@ Error Runner::generate( post_process_prompt.append(prompt); post_process_prompt.append( "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"); - // tokenizer_->encode will add <|begin_of_text|> token for us. - // For now, do token call back so the output format looks the same as - // llama3 model card. token_callback("<|begin_of_text|>"); + stats_.inference_start_ms = time_in_ms(); + + seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_; Result> encode_res = tokenizer_->encode(post_process_prompt, n_bos_, 0); ET_CHECK_OK_OR_RETURN_ERROR( @@ -178,44 +216,70 @@ Error Runner::generate( "sequence length exceeded - please increase the seq_len value"); int64_t pos = 0, prev_token, cur_token = prompt_tokens[0]; - KVCachedMemory::IO* ptr = - static_cast(io_mem_->get_mutable_ptr()); - ptr->input_tok = static_cast(cur_token); - ptr->attention_mask[max_seq_len_ - 1] = 0; - - std::vector postTime; - while (pos < seq_len - 1) { + HybridMemory::IO* ptr = + static_cast(io_mem_->get_mutable_ptr()); + + if (eval_mode_ == EvalMode::kBert) { + for (int i = 0; i < num_prompt_tokens; i++) { + ptr->prefill_input_toks[i] = static_cast(prompt_tokens[i]); + auto piece_res = tokenizer_->decode(prompt_tokens[i], prompt_tokens[i]); + token_callback(piece_res.get()); + } // inference run_model_step(inputs); Tensor& logits_tensor = output_tensors.back()[0]; - - if (pos == num_prompt_tokens) { - stats_.first_token_ms = time_in_ms(); - } else if (pos == num_prompt_tokens - 1) { - stats_.prompt_eval_end_ms = time_in_ms(); - } - + // offset to the meaningful logit we want. 
+ float* logits = logits_tensor.mutable_data_ptr() + + (num_prompt_tokens - 1) * vocab_size_; + prev_token = prompt_tokens[num_prompt_tokens - 1]; + cur_token = sampler_->sample(logits); + stats_.first_token_ms = time_in_ms(); + stats_.prompt_eval_end_ms = time_in_ms(); long sample_start_time_ms = time_in_ms(); - prev_token = cur_token; - cur_token = logitsToToken(logits_tensor); stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; - - if (pos < num_prompt_tokens - 1) { - cur_token = prompt_tokens[pos + 1]; - } - io_mem_->update_io(cur_token, ++pos, output_tensors); auto piece_res = tokenizer_->decode(prev_token, cur_token); ET_CHECK(piece_res.ok()); - if (token_callback) { token_callback(piece_res.get().c_str()); } + pos += num_prompt_tokens; + } else { + ptr->input_tok = static_cast(cur_token); + ptr->attention_mask[max_seq_len_ - 1] = 0; + while (pos < seq_len - 1) { + // inference + run_model_step(inputs); + Tensor& logits_tensor = output_tensors.back()[0]; + + if (pos == num_prompt_tokens) { + stats_.first_token_ms = time_in_ms(); + } else if (pos == num_prompt_tokens - 1) { + stats_.prompt_eval_end_ms = time_in_ms(); + } - if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) { - ET_LOG(Info, "\nReached to the end of generation"); - break; + long sample_start_time_ms = time_in_ms(); + prev_token = cur_token; + cur_token = logitsToToken(logits_tensor); + stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; + + if (pos < num_prompt_tokens - 1) { + cur_token = prompt_tokens[pos + 1]; + } + io_mem_->update_io(cur_token, ++pos, output_tensors); + auto piece_res = tokenizer_->decode(prev_token, cur_token); + ET_CHECK(piece_res.ok()); + + if (token_callback) { + token_callback(piece_res.get().c_str()); + } + + if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) { + ET_LOG(Info, "\nReached to the end of generation"); + break; + } } } + stats_.inference_end_ms = time_in_ms(); if (pos == seq_len) { ET_LOG(Info, "\nSequence length (%i tokens) reached!", seq_len); diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h index b297203411f..6f5b6c3ba1a 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h @@ -29,7 +29,8 @@ class Runner { explicit Runner( const std::vector& models_path, const std::string& tokenizer_path, - const float temperature); + const float temperature, + const int eval_mode); struct Stats { // Scaling factor for timestamps - in this case, we use ms. 
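
For reference, a minimal sketch (illustrative only, not part of the patch, using toy sizes) of how the batch-prefill path selects the logit row of the last real prompt token: the exported graph emits logits for every position in the padded window, so the Python calibration (logits[:, last_prompt_pos - 1]) and the C++ runner offset ((num_prompt_tokens - 1) * vocab_size_) read the same row.

    import torch

    cache_len, vocab_size = 15, 100   # toy values; real runs use max_seq_len - 1 and the model vocab
    num_prompt_tokens = 4             # hypothetical prompt length
    logits = torch.randn(1, cache_len, vocab_size)

    # Row holding the prediction for the token that follows the prompt.
    last_row = logits[0, num_prompt_tokens - 1]           # shape: [vocab_size]
    next_token = torch.argmax(last_row, dim=-1).item()

    # Equivalent flat-buffer view, matching the pointer arithmetic in runner.cpp.
    flat = logits.reshape(-1)
    offset = (num_prompt_tokens - 1) * vocab_size
    assert torch.equal(flat[offset : offset + vocab_size], last_row)
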
@@ -70,17 +71,24 @@ class Runner { get_methods_meta(); private: + enum EvalMode { + kBert = 0, + kKVCached, + kUnsupported, + }; + template + T getMetadataHelper(std::string method_name, T default_val); template int32_t logitsToToken(const executorch::aten::Tensor& logits_tensor); void run_model_step( std::vector>& inputs); // metadata + int32_t max_seq_len_; + int32_t vocab_size_; int32_t bos_id_; std::unordered_set eos_id_; const int32_t n_bos_; const int32_t n_eos_; - const int32_t vocab_size_; - const int32_t max_seq_len_; std::vector> modules_; std::string tokenizer_path_; float temperature_; @@ -88,6 +96,7 @@ class Runner { std::unique_ptr sampler_; Stats stats_; std::unique_ptr io_mem_; + int32_t eval_mode_; }; } // namespace example diff --git a/install_requirements.py b/install_requirements.py index 90e10373293..7c54bdbb206 100644 --- a/install_requirements.py +++ b/install_requirements.py @@ -137,7 +137,7 @@ def python_is_compatible(): "timm==1.0.7", f"torchaudio==2.5.0.{NIGHTLY_VERSION}" if USE_PYTORCH_NIGHTLY else "torchaudio", "torchsr==1.0.4", - "transformers==4.46.1", + "transformers==4.42.4", # TODO update back to 4.46.1 once the error is fixed ] # pip packages needed for development. From d171e89ce3cd89560e0864016f4c46adfb018936 Mon Sep 17 00:00:00 2001 From: Joey Tsai Date: Fri, 22 Nov 2024 09:37:24 +0800 Subject: [PATCH 2/8] Rebase and minor fix - Fix rebase conflict - Change input dtype of calibration function --- examples/qualcomm/oss_scripts/llama2/llama.py | 4 ++-- examples/qualcomm/oss_scripts/llama3_2/llama.py | 2 +- examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp | 2 -- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py index 9af32bd7328..1a13c31572b 100755 --- a/examples/qualcomm/oss_scripts/llama2/llama.py +++ b/examples/qualcomm/oss_scripts/llama2/llama.py @@ -203,7 +203,7 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: return probs_indices.gather(dim=-1, index=next_token) with torch.no_grad(): - while token_list[-1] != sp_model.eos_id() and pos < max_seq_len: + while token_list[-1] != sp_model.eos_id() and pos < max_seq_len - 1: logits, new_k_caches, new_v_caches = module( torch.full((1, 1), token_list[pos]), atten_mask, @@ -248,7 +248,7 @@ def _bert_calibrate( token_list = torch.cat( [ token_list, - torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int64), + torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32), ], dim=1, ) diff --git a/examples/qualcomm/oss_scripts/llama3_2/llama.py b/examples/qualcomm/oss_scripts/llama3_2/llama.py index f086e400cf3..a6626077d18 100755 --- a/examples/qualcomm/oss_scripts/llama3_2/llama.py +++ b/examples/qualcomm/oss_scripts/llama3_2/llama.py @@ -122,7 +122,7 @@ def _bert_calibrate( token_list = torch.cat( [ token_list, - torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int64), + torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32), ], dim=1, ) diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp index e1dc17dc959..f0951c2850d 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp @@ -45,8 +45,6 @@ Runner::Runner( const int eval_mode) : n_bos_(1), n_eos_(1), - vocab_size_(QNN_LLAMA3_2_LOGITS), - max_seq_len_(QNN_LLAMA3_2_SEQLEN), tokenizer_path_(tokenizer_path), 
temperature_(temperature), eval_mode_(eval_mode), From d3db12b6cbea7a00514347317502d102b65f9a65 Mon Sep 17 00:00:00 2001 From: Joey Tsai Date: Fri, 22 Nov 2024 09:53:19 +0800 Subject: [PATCH 3/8] Change bert to batch prefill --- examples/qualcomm/oss_scripts/llama2/llama.py | 12 ++++++------ .../oss_scripts/llama2/model/static_llama.py | 8 ++++---- examples/qualcomm/oss_scripts/llama3_2/llama.py | 14 +++++++------- .../oss_scripts/llama3_2/qnn_llama3_2_runner.cpp | 2 +- .../oss_scripts/llama3_2/runner/runner.cpp | 4 ++-- .../qualcomm/oss_scripts/llama3_2/runner/runner.h | 2 +- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py index 1a13c31572b..ae291f3659e 100755 --- a/examples/qualcomm/oss_scripts/llama2/llama.py +++ b/examples/qualcomm/oss_scripts/llama2/llama.py @@ -229,7 +229,7 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: print(f"calibration data:\n{sp_model.decode(token_list)}") -def _bert_calibrate( +def _batch_prefill_calibrate( example_inputs, user_prompts, module: torch.fx.GraphModule, @@ -273,7 +273,7 @@ def calibrate( max_seq_len=512, ): if len(example_inputs) == 2: - _bert_calibrate( + _batch_prefill_calibrate( example_inputs, user_prompts, module, @@ -332,7 +332,7 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type): == self.llama_meta["get_head_dim"] ): a.meta[QCOM_QUANTIZED_IO] = kv_type - # single head, bert mode + # single head, batch_prefill mode elif a.meta["val"].flatten().size()[0] == self.llama_meta[ "get_head_dim" ] * (self.llama_meta["get_max_seq_len"] - 1): @@ -416,7 +416,7 @@ def compile(args): if args.model_mode == "kv": use_kv_cache = output_new_cache_only = True - elif args.model_mode == "bert" or args.model_mode == "hybrid": + elif args.model_mode == "batch_prefill" or args.model_mode == "hybrid": raise NotImplementedError( f"model_mode {args.model_mode} is not implemented yet." ) @@ -653,9 +653,9 @@ def post_process(): parser.add_argument( "--model_mode", - help="Export and inference bert mode, kv mode or hybrid(TBD) mode", + help="Export and inference batch_prefill mode, kv mode or hybrid(TBD) mode", default="kv", - choices=["bert", "kv", "hybrid"], + choices=["batch_prefill", "kv", "hybrid"], type=str, ) diff --git a/examples/qualcomm/oss_scripts/llama2/model/static_llama.py b/examples/qualcomm/oss_scripts/llama2/model/static_llama.py index ccfe3d62af5..a0d3397ad9d 100755 --- a/examples/qualcomm/oss_scripts/llama2/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama2/model/static_llama.py @@ -21,7 +21,7 @@ def apply_rotary_emb_single( ) -> torch.Tensor: x_r, x_i = x[..., ::2], x[..., 1::2] - # brodcast for bert mode input x + # brodcast for batch_prefill mode input x if x.dim() == 4: freqs_cos = freqs_cos[None, :, None, :] freqs_sin = freqs_sin[None, :, None, :] @@ -111,7 +111,7 @@ def forward_sha( for i, _ in enumerate(k_caches): kh.append(torch.cat([k_caches[i], k[i]], dim=-1)) vh.append(torch.cat([v_caches[i], v[i]], dim=1)) - # bert/prefill mode + # batch_prefill mode else: kh = k vh = v @@ -131,7 +131,7 @@ def forward_sha( if self.output_new_cache_only: if k_caches and v_caches: return y, k, v - # bert mode. Consider to remove, it's not really used + # batch_prefill mode. 
             return y, k[-1], v[-1]
 
         return y, kh, vh
@@ -172,7 +172,7 @@ def forward(
 
                 output_y.append(y)
 
-        # bert/prefill mode
+        # batch_prefill mode
         else:
             kh = k
             vh = v
diff --git a/examples/qualcomm/oss_scripts/llama3_2/llama.py b/examples/qualcomm/oss_scripts/llama3_2/llama.py
index a6626077d18..eea0ab83b85 100755
--- a/examples/qualcomm/oss_scripts/llama3_2/llama.py
+++ b/examples/qualcomm/oss_scripts/llama3_2/llama.py
@@ -103,7 +103,7 @@ def _kv_calibrate(
     print(f"calibration data:\n{sp_model.decode(token_list)}")
 
 
-def _bert_calibrate(
+def _batch_prefill_calibrate(
     example_inputs,
     user_prompts,
     module: torch.fx.GraphModule,
@@ -147,7 +147,7 @@ def calibrate(
     max_seq_len=512,
 ):
     if len(example_inputs) == 2:
-        _bert_calibrate(
+        _batch_prefill_calibrate(
             example_inputs,
             user_prompts,
             module,
@@ -206,7 +206,7 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type, sharding_type):
                     == self.llama_meta["get_head_dim"]
                 ):
                     a.meta[QCOM_QUANTIZED_IO] = kv_type
-                # single head, bert mode
+                # single head, batch_prefill mode
                 elif a.meta["val"].flatten().size()[0] == self.llama_meta[
                     "get_head_dim"
                 ] * (self.llama_meta["get_max_seq_len"] - 1):
@@ -319,7 +319,7 @@ def compile(args):
 
     if args.model_mode == "kv":
         use_kv_cache = output_new_cache_only = True
-    elif args.model_mode == "bert":
+    elif args.model_mode == "batch_prefill":
         use_kv_cache = output_new_cache_only = False
     elif args.model_mode == "hybrid":
         raise NotImplementedError(
@@ -406,7 +406,7 @@ def compile(args):
 def inference(args, pre_gen_pte=""):
     workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"
 
-    if args.model_mode == "bert":
+    if args.model_mode == "batch_prefill":
         eval_mode = 0
     elif args.model_mode == "kv":
         eval_mode = 1
@@ -573,9 +573,9 @@ def post_process():
 
     parser.add_argument(
         "--model_mode",
-        help="Export and inference bert mode, kv mode or hybrid(TBD) mode",
+        help="Export and inference batch_prefill mode, kv mode or hybrid(TBD) mode",
         default="kv",
-        choices=["bert", "kv", "hybrid"],
+        choices=["batch_prefill", "kv", "hybrid"],
         type=str,
     )
 
diff --git a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp
index a184bb42ade..554e3ba9329 100644
--- a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp
@@ -46,7 +46,7 @@ DEFINE_int32(
 DEFINE_int32(
     eval_mode,
     0,
-    "0: PromptProcessor(bert) / 1: TokenGenerator(kv) / 2: HybridMode (TBD)");
+    "0: PromptProcessor(batch_prefill) / 1: TokenGenerator(kv) / 2: HybridMode (TBD)");
 
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp
index f0951c2850d..1ee58d82962 100644
--- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp
@@ -107,7 +107,7 @@ Error Runner::load() {
 
   // prepare io
   auto methods_meta = get_methods_meta();
-  if (eval_mode_ == EvalMode::kBert) {
+  if (eval_mode_ == EvalMode::kBatchPrefill) {
     io_mem_->prepare_prefill_io(methods_meta);
   } else {
     io_mem_->prepare_kv_io(methods_meta);
@@ -217,7 +217,7 @@ Error Runner::generate(
   HybridMemory::IO* ptr =
       static_cast<HybridMemory::IO*>(io_mem_->get_mutable_ptr());
 
-  if (eval_mode_ == EvalMode::kBert) {
+  if (eval_mode_ == EvalMode::kBatchPrefill) {
     for (int i = 0; i < num_prompt_tokens; i++) {
       ptr->prefill_input_toks[i] = static_cast<int32_t>(prompt_tokens[i]);
       auto piece_res = tokenizer_->decode(prompt_tokens[i], prompt_tokens[i]);
diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h
index 6f5b6c3ba1a..b720697be5f 100644
--- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h
+++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h
@@ -72,7 +72,7 @@ class Runner {
 
  private:
   enum EvalMode {
-    kBert = 0,
+    kBatchPrefill = 0,
     kKVCached,
     kUnsupported,
   };

From 9a7f55995490fcd922fc72b02fade24bf31e0b7b Mon Sep 17 00:00:00 2001
From: Joey Tsai
Date: Fri, 22 Nov 2024 12:32:49 +0800
Subject: [PATCH 4/8] Fix compile error

---
 examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp
index 1ee58d82962..80da5b98873 100644
--- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp
@@ -47,8 +47,7 @@ Runner::Runner(
       n_eos_(1),
       tokenizer_path_(tokenizer_path),
       temperature_(temperature),
-      eval_mode_(eval_mode),
-      stats_({}) {
+      eval_mode_(eval_mode) {
   for (size_t i = 0; i < models_path.size(); ++i) {
     modules_.push_back(std::make_shared<Module>(
         models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors));

From 1a41b47510586c406233eb548359b4bdd75b8201 Mon Sep 17 00:00:00 2001
From: Joey Tsai
Date: Fri, 22 Nov 2024 19:58:27 +0800
Subject: [PATCH 5/8] Fix lint

- Fix transformers version
- Refine pass quantization tagging function
- Rebase
---
 backends/qualcomm/quantizer/custom_annotation.py | 7 +++++--
 examples/qualcomm/oss_scripts/llama3_2/llama.py  | 9 +++++++--
 install_requirements.py                          | 2 +-
 3 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index 0e021c02e68..d2bc1b852de 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -22,7 +22,9 @@
 from torch.fx import Node
 
 
-def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
+def annotate_matmul_16a8w(
+    gm: torch.fx.GraphModule, traverse_input1=True
+) -> None:  # noqa: C901
     """
     This function is specific for matmul op 16a8w.
""" @@ -99,7 +101,8 @@ def annotate_matmul_input1(node: Node): for node in gm.graph.nodes: if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: annotate_matmul(node, quantization_config_16a8w) - annotate_matmul_input1(node.args[1]) + if traverse_input1: + annotate_matmul_input1(node.args[1]) def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901 diff --git a/examples/qualcomm/oss_scripts/llama3_2/llama.py b/examples/qualcomm/oss_scripts/llama3_2/llama.py index eea0ab83b85..75c0bb0ff0f 100755 --- a/examples/qualcomm/oss_scripts/llama3_2/llama.py +++ b/examples/qualcomm/oss_scripts/llama3_2/llama.py @@ -8,9 +8,9 @@ import json import logging import os - import sys import time +from functools import partial from multiprocessing.connection import Client import torch @@ -319,8 +319,10 @@ def compile(args): if args.model_mode == "kv": use_kv_cache = output_new_cache_only = True + matmul_annotate_func = partial(annotate_matmul_16a8w, traverse_input1=True) elif args.model_mode == "batch_prefill": use_kv_cache = output_new_cache_only = False + matmul_annotate_func = partial(annotate_matmul_16a8w, traverse_input1=False) elif args.model_mode == "hybrid": raise NotImplementedError( f"model_mode {args.model_mode} is not implemented yet." @@ -385,7 +387,10 @@ def compile(args): start_quantize_ts = time.time() single_llama.quantize( quant_dtype, - custom_annotations=(annotate_matmul_16a8w,), + custom_annotations=( + custom_annotate_llama_last_conv_16a8w, + matmul_annotate_func, + ), ) end_quantize_ts = time.time() logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}") diff --git a/install_requirements.py b/install_requirements.py index 7c54bdbb206..90e10373293 100644 --- a/install_requirements.py +++ b/install_requirements.py @@ -137,7 +137,7 @@ def python_is_compatible(): "timm==1.0.7", f"torchaudio==2.5.0.{NIGHTLY_VERSION}" if USE_PYTORCH_NIGHTLY else "torchaudio", "torchsr==1.0.4", - "transformers==4.42.4", # TODO update back to 4.46.1 once the error is fixed + "transformers==4.46.1", ] # pip packages needed for development. 
From 0a080a45b331ee8992a6a1efc85ddeea71ed8979 Mon Sep 17 00:00:00 2001
From: Joey Tsai
Date: Mon, 25 Nov 2024 09:51:21 +0800
Subject: [PATCH 6/8] Add one line at the end of CMakeLists.txt

---
 examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt
index 6090ff7fe47..c982b6d5158 100644
--- a/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt
+++ b/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt
@@ -64,4 +64,5 @@ target_compile_options(
 )
 set_target_properties(
   qnn_llama3_2_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
-)
\ No newline at end of file
+)
+

From c47568de95fff2a25f8cef557b4a5ad105da3b6e Mon Sep 17 00:00:00 2001
From: Joey Tsai
Date: Mon, 25 Nov 2024 10:04:47 +0800
Subject: [PATCH 7/8] Remove trailing line of CMakeLists.txt

---
 examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt
index c982b6d5158..93b35a697c6 100644
--- a/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt
+++ b/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt
@@ -65,4 +65,3 @@ target_compile_options(
 set_target_properties(
   qnn_llama3_2_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
 )
-

From ef2e1e5a71d169b197ae8db1ad8248cdefbb2183 Mon Sep 17 00:00:00 2001
From: Joey Tsai
Date: Mon, 25 Nov 2024 10:22:52 +0800
Subject: [PATCH 8/8] Move noqa to the correct line

---
 backends/qualcomm/quantizer/custom_annotation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index d2bc1b852de..c58d0844b40 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -22,9 +22,9 @@
 from torch.fx import Node
 
 
-def annotate_matmul_16a8w(
+def annotate_matmul_16a8w(  # noqa: C901
     gm: torch.fx.GraphModule, traverse_input1=True
-) -> None:  # noqa: C901
+) -> None:
     """
     This function is specific for matmul op 16a8w.
     """
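
Supplementary sketch: the calibration changes in patches 2 and 3 converge on a fixed batch-prefill window of max_seq_len - 1 tokens, right-padded with int32 zeros, with the logit for the last real prompt position used as the prediction. The snippet below reproduces only that tensor shaping with plain torch so it can be run standalone; prepare_prefill_tokens() is an illustrative name and the token values in the usage example are arbitrary, not taken from the patches.

import torch


def prepare_prefill_tokens(prompt_tokens, max_seq_len=512):
    # The exported batch_prefill graph consumes a fixed window of
    # max_seq_len - 1 tokens, so truncate or pad to exactly that length.
    max_cache_len = max_seq_len - 1
    tokens = torch.tensor(prompt_tokens, dtype=torch.int32)
    tokens = tokens[:max_cache_len].reshape(1, -1)
    last_prompt_pos = tokens.numel()
    if last_prompt_pos < max_cache_len:
        # Right-pad with int32 zeros, matching the dtype fix from patch 2.
        pad = torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32)
        tokens = torch.cat([tokens, pad], dim=1)
    return tokens, last_prompt_pos


if __name__ == "__main__":
    toks, pos = prepare_prefill_tokens([1, 15043, 29892, 3186], max_seq_len=8)
    print(toks.shape, pos)  # -> torch.Size([1, 7]) 4
    # The calibration loop would then read the prediction from logits[:, pos - 1].
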