From fb1312fc2c4296d9b27a09385eb1147df1743dff Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Mon, 14 Oct 2024 17:52:47 -0700
Subject: [PATCH] Llama2 model cleanup (#5859)

Summary:
- Removes redundant steps in the Llama2 export
- Factors out checkpointing to be shared with future Llama models (namely 3.2 multimodal)
- Comments and orders code more clearly

PR chain:
- [Add kwarg example inputs to eager model base](https://github.com/pytorch/executorch/pull/5765)
- **YOU ARE HERE ~>** [Llama2 model cleanup](https://github.com/pytorch/executorch/pull/5859)
- [Accept model type parameter in export_llama](https://github.com/pytorch/executorch/pull/5910)
- [Export TorchTune llama3_2_vision in ET](https://github.com/pytorch/executorch/pull/5911)
- [Add et version of TorchTune MHA for swapping with custom op](https://github.com/pytorch/executorch/pull/5912)

Test Plan:
Ensure export + eval results are similar before and after for Stories 110M:
```
python -m examples.models.llama2.eval_llama -c -p -t -d fp32 --max_seq_len 2048 --limit 1000
```
Before:
```
wikitext: {'word_perplexity,none': 14464.645927166595, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 5.99788806086652, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 2.5844545973083983, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
```
After:
```
wikitext: {'word_perplexity,none': 14464.299192404438, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 5.997861173678705, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 2.584448130015399, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
```

Reviewed By: dbort

Differential Revision: D64145852

Pulled By: dvorjackz
---
 examples/models/checkpoint.py   | 73 ++++++++++++++++++++++++++++
 examples/models/llama2/TARGETS  |  1 +
 examples/models/llama2/model.py | 84 ++++++++++++---------------------
 3 files changed, 103 insertions(+), 55 deletions(-)
 create mode 100644 examples/models/checkpoint.py

diff --git a/examples/models/checkpoint.py b/examples/models/checkpoint.py
new file mode 100644
index 00000000000..cd916142d9c
--- /dev/null
+++ b/examples/models/checkpoint.py
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+
+def get_default_model_resource_dir(model_file_path: str) -> Path:
+    """
+    Get the default path to resource files (which contain files such as the
+    checkpoint and param files), either:
+    1. Uses the path from pkg_resources, only works with buck2
+    2. Uses default path located in examples/models/llama2/params
+
+    Expected to be called from within a `model.py` file located in a
+    `executorch/examples/models/` directory.
+
+    Args:
+        model_file_path: The file path to the eager model definition.
+            For example, `executorch/examples/models/llama2/model.py`,
+            where `executorch/examples/models/llama2` contains all
+            the llama2-related files.
+
+    Returns:
+        The path to the resource directory containing checkpoint, params, etc.
+    """
+
+    try:
+        import pkg_resources
+
+        # 1st way: If we can import this path, we are running with buck2 and all resources can be accessed with pkg_resources.
+        # pyre-ignore
+        from executorch.examples.models.llama2 import params  # noqa
+
+        # Get the model name from the cwd, assuming that this module is called from a path such as
+        # examples/models/<model_name>/model.py.
+        model_name = Path(model_file_path).parent.name
+        resource_dir = Path(
+            pkg_resources.resource_filename(
+                f"executorch.examples.models.{model_name}", "params"
+            )
+        )
+    except:
+        # 2nd way.
+        resource_dir = Path(model_file_path).absolute().parent / "params"
+
+    return resource_dir
+
+
+def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
+    """
+    Get the dtype of the checkpoint, returning None if the checkpoint is empty.
+    """
+    dtype = None
+    if len(checkpoint) > 0:
+        first_key = next(iter(checkpoint))
+        first = checkpoint[first_key]
+        dtype = first.dtype
+        mismatched_dtypes = [
+            (key, value.dtype)
+            for key, value in checkpoint.items()
+            if value.dtype != dtype
+        ]
+        if len(mismatched_dtypes) > 0:
+            raise ValueError(
+                f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
+            )
+    return dtype
diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS
index a80c62514df..13bf0a52c20 100644
--- a/examples/models/llama2/TARGETS
+++ b/examples/models/llama2/TARGETS
@@ -46,6 +46,7 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models/llama2:llama_transformer",
+        "//executorch/examples/models:checkpoint",
     ],
 )
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index d8d0ff00ffa..23f1c1b4898 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -8,9 +8,13 @@
 import json
 import os
-from pathlib import Path
+from typing import Dict, Tuple
 
 import torch
+from executorch.examples.models.checkpoint import (
+    get_checkpoint_dtype,
+    get_default_model_resource_dir,
+)
 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
@@ -30,48 +34,31 @@ def convert_to_llama_checkpoint(**kwargs):
 
 class Llama2Model(EagerModelBase):
     def __init__(self, **kwargs):
-        import pkg_resources
-
-        # default path to the resource file
-        # It currently supports 3 ways of specifying the checkpoint location:
-        # 1. Using default path locates in examples/models/llama2/params
-        # 2. Passing in the checkpoint path and params via kwargs
-        # 3. Using the path from pkg_resources, only works with buck2
-        try:
-            # The 3rd way, if we can import this path, we are running with buck2, all resources can be accessed with pkg_resources.resource_filename
-            # pyre-ignore
-            from executorch.examples.models.llama2 import params
-
-            ckpt_dir = Path(
-                pkg_resources.resource_filename(
-                    "executorch.examples.models.llama2", "params"
-                )
-            )
-        except:
-            # The 1st way
-            ckpt_dir = Path(__file__).absolute().parent / "params"
-
-        # Check if checkpoint_dir was provided for a sharded checkpoint.
-        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+        resource_dir = get_default_model_resource_dir(__file__)
 
         # Use single checkpoint file.
-        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        checkpoint_path = kwargs.get(
+            "checkpoint", resource_dir / "demo_rand_params.pth"
+        )
+        params_path = kwargs.get("params", resource_dir / "demo_config.json")
 
-        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
+        # Check if checkpoint_dir was provided for a sharded checkpoint.
+        checkpoint_dir = kwargs.get("checkpoint_dir", None)
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.args = kwargs.get("args", None)
+
         # The example is using a dummy small model with random weights for demo purpose only.
-        # Follow the instruction in https://github.com/facebookresearch/llama to download the model
+        # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
         device = "cpu"  # flake8: noqa: TOR102
         cps = []
+        # Load sharded checkpoint.
         if checkpoint_dir is not None:
             # Load multiple checkpoint; ignore the single path.
             checkpoint_path = None
@@ -98,8 +85,11 @@ def __init__(self, **kwargs):
                 else:
                     # Do not duplicate layers shared between each checkpoint.
                     checkpoint[key] = cps[0][key]
+        # Load single checkpoint.
         else:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
+
+        # If given checkpoint is fairseq, convert to llama checkpoint.
         fairseq2_checkpoint = kwargs.get("fairseq2", False)
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")
@@ -108,12 +98,12 @@ def __init__(self, **kwargs):
             # NB: some checkpoint contains a "model" field, which is the actual weights dict
             checkpoint = checkpoint["model"]
 
+        # Check if user gave a fairseq2 checkpoint unknowingly without specifying --fairseq2.
         if (not fairseq2_checkpoint) and checkpoint.get(
             "final_proj.weight", None
         ) is not None:
-            print(
+            raise ValueError(
                 """
-
 ************************************************************
 This looks like a Fairseq2 checkpoint (based on the presence
 of `final_proj.weight`.
@@ -125,34 +115,21 @@ def __init__(self, **kwargs):
                 """
             )
 
-        # get checkpoint dtype
-        self.dtype = None
-        if len(checkpoint) > 0:
-            first_key = next(iter(checkpoint))
-            first = checkpoint[first_key]
-            self.dtype = first.dtype
-            mismatched_dtypes = [
-                (key, value.dtype)
-                for key, value in checkpoint.items()
-                if value.dtype != self.dtype
-            ]
-            if len(mismatched_dtypes) > 0:
-                print(
-                    f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
-                )
+        # Get checkpoint dtype.
+        self.dtype = get_checkpoint_dtype(checkpoint)
+
         with open(params_path, "r") as f:
             params = json.loads(f.read())
 
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
                 output_prune_map = json.load(f)
-            # change keys from string to int (json only supports string keys)
+            # Change keys from string to int (json only supports string keys).
output_prune_map = {int(k): v for (k, v) in output_prune_map.items()} - max_seq_len = self.max_seq_len - max_batch_size = 1 + model_args: ModelArgs = ModelArgs( - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, + max_seq_len=self.max_seq_len, + max_batch_size=1, use_kv_cache=self.use_kv_cache, use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op, generate_full_logits=self.generate_full_logits, @@ -160,9 +137,6 @@ def __init__(self, **kwargs): enable_dynamic_shape=self.enable_dynamic_shape, **params, ) - if kwargs.get("fairseq2", False): - print("Using fairseq2 checkpoint") - checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint) if kwargs.get("verbose", False): print("============= weights ================") print("{key} : {weights.numel()} : {weights.size()}") @@ -234,13 +208,13 @@ def __init__(self, **kwargs): print(unexpected) print("============= /unexpected ================") - # prune the output layer if output_prune_map is provided + # Prune the output layer if output_prune_map is provided if output_prune_map is not None: from .source_transformation.prune_output import prune_output_vocab self.model_ = prune_output_vocab(self.model_, output_prune_map) - def get_eager_model(self): + def get_eager_model(self) -> torch.nn.Module: if self.dtype: # convert to the type of the provided checkpoint # input and output are torch.long, so signature unchanged