reload existing llama checkpoints #305
Comments
Is this issue related to loading pretrained Llama2/Llama3 weights and using them as checkpoint? I was going to start a separate issue asking for some docs that explain how to convert pretrained weights from HF to |
DCP has format utilities that help with the conversion. However, HF conversion should not live in the PyTorch code base. |
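For reference, the format utilities mentioned here appear to be the ones in `torch.distributed.checkpoint.format_utils`. Below is a minimal sketch, assuming a PyTorch release that ships `torch_save_to_dcp` and `dcp_to_torch_save`. The wrapper `convert_checkpoint` and its `direction` argument are hypothetical names introduced only for illustration.

```python
from pathlib import Path

# Hypothetical wrapper; only the two format_utils functions come from PyTorch.
DIRECTIONS = ("torch_to_dcp", "dcp_to_torch")


def convert_checkpoint(direction: str, src: Path, dst: Path) -> None:
    """Convert between a torch.save file and a DCP checkpoint directory."""
    if direction not in DIRECTIONS:
        raise ValueError(f"direction must be one of {DIRECTIONS}, got {direction!r}")

    # Imported lazily so argument validation works without PyTorch installed.
    from torch.distributed.checkpoint.format_utils import (
        dcp_to_torch_save,
        torch_save_to_dcp,
    )

    if direction == "torch_to_dcp":
        # src: file written by torch.save, dst: DCP checkpoint directory
        torch_save_to_dcp(str(src), str(dst))
    else:
        # src: DCP checkpoint directory, dst: file loadable with torch.load
        dcp_to_torch_save(str(src), str(dst))
```

These utilities only change the on-disk format. They do not rename keys, so an HF checkpoint still needs a separate key mapping before it matches the model's parameter names.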
@lessw2020 will connect with HF to see if they can support weight conversion from HF to PyTorch. After that, we may import it in the code or update the tutorial. |
I have a straightforward script for converting from HF to a DCP checkpoint, if that helps. Mostly the script already exists in gpt-fast. |
Alright, so this is the script I'm using for HF->DCP. It uses the safetensors weights (but can easily be adapted to load a torch.save checkpoint instead), which only exist in the root of https://huggingface.co/meta-llama/Meta-Llama-3-8B/tree/main, and not under
import json
import re
import sys
from pathlib import Path

from safetensors import safe_open
import torch.distributed.checkpoint as DCP
import torch

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from maester.models import models_config


@torch.inference_mode()
def convert_hf_checkpoint(
    *,
    checkpoint_dir: Path,
    output_dir: Path,
) -> None:
    # Load the json file containing weight mapping
    model_map_json = checkpoint_dir / "model.safetensors.index.json"
    assert model_map_json.is_file()

    with open(model_map_json, 'r') as json_map:
        bin_index = json.load(json_map)

    # HF parameter name -> target parameter name; None means the tensor is dropped
    weight_map = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }

    # Unique set of shard files referenced by the index
    bin_files = {checkpoint_dir / shard for shard in bin_index["weight_map"].values()}

    # Merge the tensors from all shards into a single dict
    merged_result = {}
    for file in sorted(bin_files):
        with safe_open(file, framework="pt", device="cpu") as f:
            for k in f.keys():
                merged_result[k] = f.get_tensor(k)

    # Rename the keys, substituting the layer number back into the template
    final_result = {}
    for key, value in merged_result.items():
        if "layers" in key:
            abstract_key = re.sub(r'(\d+)', '{}', key)
            layer_num = re.search(r'\d+', key).group(0)
            new_key = weight_map[abstract_key]
            if new_key is None:
                continue
            new_key = new_key.format(layer_num)
        else:
            new_key = weight_map[key]
        final_result[new_key] = value

    output_dir.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(output_dir)
    DCP.save({"model": final_result},
             storage_writer=storage_writer)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
    parser.add_argument('--checkpoint', type=Path, required=True)
    parser.add_argument('--output', type=Path, required=True)

    args = parser.parse_args()
    convert_hf_checkpoint(
        checkpoint_dir=args.checkpoint,
        output_dir=args.output,
    )
|
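The key-renaming step in the script can be tried in isolation, without downloading any weights. The sketch below uses only a subset of the script's `weight_map`; the standalone helper `map_hf_key` is a hypothetical name, not part of the script.

```python
import re
from typing import Optional

# Subset of the weight_map from the script, for illustration only.
WEIGHT_MAP = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
    "lm_head.weight": "output.weight",
}


def map_hf_key(key: str) -> Optional[str]:
    """Return the target key for an HF key, or None if the tensor is dropped."""
    if "layers" not in key:
        return WEIGHT_MAP[key]
    # Replace the layer index with a placeholder to look up the template.
    abstract_key = re.sub(r"\d+", "{}", key, count=1)
    layer_num = re.search(r"\d+", key).group(0)
    template = WEIGHT_MAP[abstract_key]
    return None if template is None else template.format(layer_num)


print(map_hf_key("model.layers.7.self_attn.q_proj.weight"))
# prints: layers.7.attention.wq.weight
```

An unknown key raises `KeyError`, which makes a checkpoint with unexpected parameter names fail loudly instead of being converted silently.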
Thanks for sharing. |
Is there a conversion in the other direction, meaning from a DCP checkpoint to an HF model? I found the util dcp_to_torch_save but am not sure how to go from there to an HF model. |
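One possible route, not confirmed anywhere in this thread, is to invert the key mapping from the HF->DCP script and apply it to the state dict obtained after `dcp_to_torch_save`. The sketch below covers only the renaming step. `REVERSE_MAP` and `to_hf_state_dict` are hypothetical names, and it assumes the checkpoint uses exactly the key names produced by the script above.

```python
import re
from typing import Any, Dict

# Inverse of the HF -> DCP weight_map from the conversion script.
# The dropped rotary_emb.inv_freq entry has no counterpart and is omitted.
REVERSE_MAP = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
    "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
    "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
    "layers.{}.feed_forward.w1.weight": "model.layers.{}.mlp.gate_proj.weight",
    "layers.{}.feed_forward.w3.weight": "model.layers.{}.mlp.up_proj.weight",
    "layers.{}.feed_forward.w2.weight": "model.layers.{}.mlp.down_proj.weight",
    "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
    "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
    "norm.weight": "model.norm.weight",
    "output.weight": "lm_head.weight",
}


def to_hf_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Rename keys from the DCP-side naming back to HF naming."""
    hf_state = {}
    for key, value in state_dict.items():
        match = re.match(r"layers\.(\d+)\.", key)
        if match:
            # count=1 keeps the digits in names such as "w1" untouched.
            template = re.sub(r"\d+", "{}", key, count=1)
            hf_key = REVERSE_MAP[template].format(match.group(1))
        else:
            hf_key = REVERSE_MAP[key]
        hf_state[hf_key] = value
    return hf_state


# Hypothetical usage, assuming the DCP checkpoint was first converted with
# dcp_to_torch_save and stores its weights under the "model" key:
#   state = torch.load("checkpoint.pt")["model"]
#   hf_state = to_hf_state_dict(state)
```

This only renames keys. Whether any tensor layout changes are also needed before an HF model can load the result is not covered here.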