
reload existing llama checkpoints #305

Open
tianyu-l opened this issue May 3, 2024 · 8 comments
Labels: enhancement (New feature or request)

tianyu-l (Contributor) commented May 3, 2024

No description provided.

tianyu-l added the enhancement (New feature or request) label on May 3, 2024
Lauler commented May 11, 2024

Is this issue related to loading pretrained Llama2/Llama3 weights and using them as a checkpoint?

I was going to open a separate issue asking for docs that explain how to convert pretrained weights from HF to torchtitan in order to do continued pretraining. Is that already possible or on the roadmap?

fegin (Contributor) commented May 13, 2024

DCP has format utilities that can help with the conversion. However, HF conversion should not live in the PyTorch code base.
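
For reference, a minimal sketch of those format utilities (assuming a recent PyTorch that ships torch.distributed.checkpoint.format_utils; the paths below are placeholders, not from this thread):

# Minimal sketch, assuming PyTorch >= 2.3; paths are placeholders.
from torch.distributed.checkpoint.format_utils import (
    dcp_to_torch_save,
    torch_save_to_dcp,
)

# Flatten a DCP checkpoint directory into a single torch.save file ...
dcp_to_torch_save("checkpoint/step-1000", "checkpoint/step-1000.pt")

# ... and the reverse: turn a torch.save file into a DCP checkpoint directory.
torch_save_to_dcp("checkpoint/step-1000.pt", "checkpoint/step-1000-dcp")

These only convert between the DCP and torch.save formats; renaming HF parameter keys to the llama/torchtitan layout still has to happen separately.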

tianyu-l (Contributor, Author) commented

@lessw2020 will connect with HF to see if they can support weight conversion from HF to PyTorch. After that, we may import it into the code base or update the tutorial.

rlrs commented May 21, 2024

I have a straightforward script for converting from HF to a DCP checkpoint, if that helps. Most of it already exists in gpt-fast.

tianyu-l (Contributor, Author) commented

@rlrs Thanks, please feel free to share it here!

As far as we know, HF is also working on such a script to convert from HF to DCP. As discussed in #335, we should include a script to convert the Llama raw weights into DCP (similar to the one here), and it probably should still sit in pytorch/pytorch.

rlrs commented May 24, 2024

Alright, so this is the script I'm using for HF->DCP. It uses the safetensors weights (but can easily be changed to load a torch.save instead), which exist only in the root of https://huggingface.co/meta-llama/Meta-Llama-3-8B/tree/main, not under original/. So, as we discussed in #335, some of the weights are permuted compared to the original.
I've been using it just to create a step-0 checkpoint that torchtitan is already set up to start from.

import json
import re
from pathlib import Path

import torch
import torch.distributed.checkpoint as DCP
from safetensors import safe_open

@torch.inference_mode()
def convert_hf_checkpoint(
    *,
    checkpoint_dir: Path,
    output_dir: Path,
) -> None:
    # Load the JSON index that maps each weight name to its safetensors shard
    model_map_json = checkpoint_dir / "model.safetensors.index.json"
    assert model_map_json.is_file(), f"missing index file: {model_map_json}"

    with open(model_map_json, "r") as json_map:
        bin_index = json.load(json_map)

    # HF parameter names -> llama/torchtitan parameter names; None means the
    # weight is dropped (e.g. the cached rotary inv_freq buffers).
    weight_map = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }
    bin_files = {checkpoint_dir / fname for fname in bin_index["weight_map"].values()}

    # Merge all shards into a single flat state dict on CPU
    merged_result = {}
    for file in sorted(bin_files):
        with safe_open(file, framework="pt", device="cpu") as f:
            for k in f.keys():
                merged_result[k] = f.get_tensor(k)

    # Rename keys according to weight_map, substituting the layer index back in
    final_result = {}
    for key, value in merged_result.items():
        if "layers" in key:
            abstract_key = re.sub(r"(\d+)", "{}", key)
            layer_num = re.search(r"\d+", key).group(0)
            new_key = weight_map[abstract_key]
            if new_key is None:
                continue
            new_key = new_key.format(layer_num)
        else:
            new_key = weight_map[key]

        final_result[new_key] = value

    # Save as a DCP checkpoint; the "model" key matches torchtitan's checkpoint layout
    output_dir.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(output_dir)
    DCP.save({"model": final_result}, storage_writer=storage_writer)

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
    parser.add_argument('--checkpoint', type=Path, required=True)
    parser.add_argument('--output', type=Path, required=True)

    args = parser.parse_args()
    convert_hf_checkpoint(
        checkpoint_dir=args.checkpoint,
        output_dir=args.output,
    )
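
Not part of the script above, but one way to sanity-check the result is to round-trip it through DCP's format utilities and inspect a few keys (a sketch assuming a recent PyTorch; paths are placeholders):

# Sketch of a sanity check on the converted checkpoint; paths are placeholders.
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Flatten the DCP checkpoint back into a single torch.save file ...
dcp_to_torch_save("llama3-8b-dcp", "llama3-8b-flat.pt")

# ... then check that the "model" entry and the renamed keys/shapes look right.
state = torch.load("llama3-8b-flat.pt", map_location="cpu")
model_sd = state["model"]
print(len(model_sd), "tensors")
print(model_sd["tok_embeddings.weight"].shape)
print(model_sd["layers.0.attention.wq.weight"].shape)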

kxgong commented Jun 9, 2024

> (quoting rlrs's HF->DCP conversion script above)

Thanks for sharing.

bkchang commented Jun 20, 2024

Is there a conversion in the other direction, i.e. converting a DCP checkpoint to an HF model? I found a util dcp_to_torch_save but am not sure how to go from there to an HF model.
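
There doesn't seem to be an official utility for that direction yet, but a rough sketch (assuming the checkpoint was saved under a "model" key with the torchtitan-style names used in the conversion script above; paths and the model id are placeholders) could invert the name mapping and load the result into a transformers model:

# Rough sketch of DCP -> HF, not an official utility. Assumes the checkpoint
# was saved under a "model" key with torchtitan-style parameter names, as in
# the script above. Paths and the model id are placeholders. Caveat: HF uses a
# different rotary-embedding layout, so wq/wk may need to be permuted if the
# checkpoint originated from the Meta "original/" weights rather than HF's.
import re

import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from transformers import LlamaForCausalLM

# 1) Flatten the DCP checkpoint into a single torch.save file.
dcp_to_torch_save("llama3-8b-dcp", "llama3-8b-flat.pt")
state = torch.load("llama3-8b-flat.pt", map_location="cpu")["model"]

# 2) Invert the HF -> torchtitan name mapping from the script above.
hf_to_tt = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
}
tt_to_hf = {v: k for k, v in hf_to_tt.items()}

hf_state = {}
for key, value in state.items():
    layer = re.search(r"\d+", key)
    hf_key = tt_to_hf[re.sub(r"\d+", "{}", key)]
    hf_state[hf_key.format(layer.group(0)) if layer else hf_key] = value

# 3) Load into a transformers Llama model and save in HF format
#    (or build an empty model from the matching LlamaConfig instead).
model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
model.load_state_dict(hf_state, strict=False)
model.save_pretrained("llama3-8b-hf-out")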
