diff --git a/examples/models/llama2/eval_llama_lib.py b/examples/models/llama2/eval_llama_lib.py
index 2d10f5edc0a..b8987ac5d49 100644
--- a/examples/models/llama2/eval_llama_lib.py
+++ b/examples/models/llama2/eval_llama_lib.py
@@ -41,6 +41,7 @@ def __init__(
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
         use_kv_cache: bool = False,
+        generate_full_logits: bool = False,
         enable_dynamic_shape: bool = True,
     ):
         super().__init__(
@@ -48,6 +49,7 @@ def __init__(
         )
         self._model = model.to(self.device)
         self._use_kv_cache = use_kv_cache
+        self._generate_full_logits = generate_full_logits
         self._enable_dynamic_shape = enable_dynamic_shape
 
     def _model_call(self, inps):
@@ -60,7 +62,10 @@ def _model_call(self, inps):
                 pos_tensor = torch.tensor([pos], dtype=torch.int64)
                 logits = self._model(inps[:, pos : pos + 1], pos_tensor)
                 result_logits.append(logits)
-            return torch.cat(result_logits, dim=1)
+            if self._generate_full_logits:
+                return torch.cat(result_logits, dim=1)
+            else:
+                return torch.stack(result_logits, dim=1)
         else:
             pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
             # Batch process the whole sequence.
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 977348946b3..611bf16428d 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -233,7 +233,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--optimized_rotation_path",
         default=None,
         required=False,
-        help="[QNN Backend] Optimized rotation checkpoint path. Just apply R1/R2 here."
+        help="[QNN backend] Optimized rotation checkpoint path. Just apply R1/R2 here."
         "You can download the optimized rotation matrices from https://github.com/facebookresearch/SpinQuant/tree/main",
     )
     parser.add_argument(
@@ -440,6 +440,9 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             transforms.append(replace_sdpa_with_flex_sdpa)
             transforms.append(replace_causal_mask)
             transforms.append(replace_rms_norm_with_native_rms_norm)
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
             transforms.append(convert_linear_to_conv2d)
 
         elif args.coreml or args.mps:
@@ -448,9 +451,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)
 
-    if args.optimized_rotation_path:
-        transforms.append(fuse_layer_norms)
-        transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
     return (
         _load_llama_model(
             modelname=modelname,
@@ -744,6 +744,7 @@ def _load_llama_model(
         max_seq_len=model.params.max_seq_len,
         dtype=dtype,
         use_kv_cache=use_kv_cache,
+        generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
         enable_dynamic_shape=enable_dynamic_shape,
         calibration_tasks=calibration_tasks,
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index bc64ae869fc..4237ae7b3a7 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -69,6 +69,7 @@ def __init__(
         example_inputs,
         args: Optional[Any] = None,
         enable_dynamic_shape: bool = False,
+        generate_full_logits: bool = False,
         calibration_tasks: Optional[List[str]] = None,
         calibration_limit: Optional[int] = None,
         calibration_seq_length: Optional[int] = None,
@@ -86,6 +87,7 @@ def __init__(
         self.dtype = dtype
         self.example_inputs = example_inputs
         self.use_kv_cache = use_kv_cache
+        self.generate_full_logits = generate_full_logits
         self.enable_dynamic_shape = enable_dynamic_shape
         self.verbose = verbose
         self.metadata = metadata
@@ -229,7 +231,12 @@ def calibrate_template(
                 )
                 pos += 1
                 if pos >= len(token_list):
-                    token_list.append(torch.argmax(logits[:], dim=-1).item())
+                    if self.generate_full_logits:
+                        token_list.append(
+                            torch.argmax(logits[:, -1], dim=-1).item()
+                        )
+                    else:
+                        token_list.append(torch.argmax(logits[:], dim=-1).item())
 
         calibrate_template(
             module=prepared_module,
@@ -243,6 +250,7 @@ def calibrate_template(
             tokenizer=tokenizer,
             max_seq_length=calibration_seq_length,
             use_kv_cache=self.use_kv_cache,
+            generate_full_logits=self.generate_full_logits,
             enable_dynamic_shape=self.enable_dynamic_shape,
         )
         eval_results = evaluate_model(
diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py
index 29c7b3731fb..f5cc04ead48 100644
--- a/extension/llm/export/partitioner_lib.py
+++ b/extension/llm/export/partitioner_lib.py
@@ -139,16 +139,9 @@ def get_qnn_partitioner(
     if pt2e_quantize is not None:
         use_fp16 = False
 
-    soc_chip_table = {
-        "SM8650": QcomChipset.SM8650,
-        "SM8550": QcomChipset.SM8550,
-        "SM8475": QcomChipset.SM8475,
-        "SM8450": QcomChipset.SM8450,
-    }
-
     return QnnPartitioner(  # pyre-fixme[16]
         generate_qnn_executorch_compiler_spec(  # pyre-fixme[16]
-            soc_model=soc_chip_table[soc_model],  # pyre-fixme[16]
+            soc_model=getattr(QcomChipset, soc_model),  # pyre-fixme[16]
             # pyre-fixme[16]
             backend_options=generate_htp_compiler_spec(
                 use_fp16=use_fp16,
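
For reference, a minimal sketch of the shape reasoning behind the cat/stack split in `_model_call`, under the layouts the diff implies: with `generate_full_logits` each per-position call returns `(batch, 1, vocab)` logits, while the last-token path returns `(batch, vocab)`. Plain torch, outside ExecuTorch:

```python
# Sketch only; the per-step shapes are assumptions read off the diff.
import torch

batch, seq, vocab = 1, 4, 32

# generate_full_logits=True: each step keeps a sequence dim of size 1,
# so concatenating along dim=1 rebuilds (batch, seq, vocab).
full_steps = [torch.randn(batch, 1, vocab) for _ in range(seq)]
assert torch.cat(full_steps, dim=1).shape == (batch, seq, vocab)

# generate_full_logits=False: the sequence dim is already squeezed away,
# so stack must reintroduce it at dim=1 to reach the same layout.
last_steps = [torch.randn(batch, vocab) for _ in range(seq)]
assert torch.stack(last_steps, dim=1).shape == (batch, seq, vocab)
```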
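The rotation transforms move inside the QNN branch so they are appended, and therefore applied, before `convert_linear_to_conv2d`. A plausible reading is that `fuse_layer_norms` and the SpinQuant R1/R2 rotations fold into `nn.Linear` weights and could no longer apply once those layers are rewritten. A sketch of the apply-in-list-order pattern, with stand-in transforms rather than the real ExecuTorch passes:

```python
# Stand-in transforms; only the ordering argument carries over to the diff.
from typing import Callable, List

import torch.nn as nn

Transform = Callable[[nn.Module], nn.Module]


def scale_linear_weights(model: nn.Module) -> nn.Module:
    # Stand-in for a weight-editing pass: it only knows how to touch nn.Linear.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.weight.data *= 2.0
    return model


def rewrite_linear(model: nn.Module) -> nn.Module:
    # Stand-in for convert_linear_to_conv2d: after this, no nn.Linear remains.
    return nn.Conv2d(4, 4, kernel_size=1)


def apply_in_order(model: nn.Module, transforms: List[Transform]) -> nn.Module:
    for transform in transforms:
        model = transform(model)  # each pass sees the previous pass's output
    return model


model = apply_in_order(nn.Linear(4, 4), [scale_linear_weights, rewrite_linear])
# Reversing the list would make scale_linear_weights a silent no-op.
```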
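The calibration change in `builder.py` follows the same layout assumption: full logits are `(batch, seq, vocab)`, so the greedy next token must come from the final position via `logits[:, -1]`, while last-token-only logits can feed `argmax` directly. A small check under that assumption:

```python
# Sketch only; assumes the (batch, seq, vocab) vs. (batch, vocab) layouts above.
import torch

batch, seq, vocab = 1, 4, 32

full_logits = torch.randn(batch, seq, vocab)  # generate_full_logits=True
last_logits = full_logits[:, -1]              # what the model returns otherwise

# Both branches recover the same greedy token for the next position.
assert (
    torch.argmax(full_logits[:, -1], dim=-1).item()
    == torch.argmax(last_logits, dim=-1).item()
)
```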
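Replacing the hand-maintained `soc_chip_table` with `getattr(QcomChipset, soc_model)` means chipsets added to the enum work without touching the partitioner; one behavioral difference worth noting is that an unrecognized name now raises `AttributeError` rather than `KeyError`. A sketch with a placeholder enum (`QcomChipsetStub` and its values are made up, not the real QNN schema):

```python
from enum import IntEnum


class QcomChipsetStub(IntEnum):  # placeholder members/values, not the real enum
    SM8450 = 1
    SM8475 = 2
    SM8550 = 3
    SM8650 = 4


def resolve_soc(soc_model: str) -> QcomChipsetStub:
    # Lookup stays in sync with the enum automatically; an unknown
    # name fails with AttributeError instead of a stale-table KeyError.
    return getattr(QcomChipsetStub, soc_model)


print(resolve_soc("SM8650"))  # QcomChipsetStub.SM8650
```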