## load_states

In [2]:
from petrel_client.client import Client
import os
import io
import torch
client = Client()

ckpts = [f'tp_{idx}.pt' for idx in range(8)]
states = []
root_file = "anonymous_ssd:s3://model_weights/0331/evaluation/exported_llama/1006/12499/"
for ckpt in ckpts:
    config_file = os.path.join(root_file, ckpt)
    with io.BytesIO(client.get(config_file)) as f:      
        state = torch.load(f, map_location='cpu')
    states.append(state)

## Generate model

In [3]:
from transformers import LlamaForCausalLM, LlamaConfig
import json
with io.BytesIO(client.get(root_file+'params.json')) as f:
    params = json.loads(f.getvalue())
params



{'dim': 10240,
 'multiple_of': 256,
 'n_heads': 80,
 'n_layers': 82,
 'norm_eps': 1e-05,
 'vocab_size': -1}

In [4]:
def get_intermediate_size(params):
    hidden_dim = params['dim'] * 4
    hidden_dim = int(2 * hidden_dim / 3)
    hidden_dim = params['multiple_of'] * ((hidden_dim + params['multiple_of'] - 1) // params['multiple_of'])
    return hidden_dim
llama_config = LlamaConfig(vocab_size=65632,
                    hidden_size=params['dim'],
                    intermediate_size=get_intermediate_size(params),
                    num_hidden_layers=params['n_layers'],
                    num_attention_heads=params['n_heads']
                 )
llama_config

LlamaConfig {
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 10240,
  "initializer_range": 0.02,
  "intermediate_size": 27392,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 80,
  "num_hidden_layers": 82,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 65632
}

In [5]:
llama_model = LlamaForCausalLM(llama_config)

### 转换（同步）权重

In [6]:
def fuse_weight(states, key, dim):
    return torch.cat([state[key] for state in states], dim=dim)

def load_attention_weight(layer, idx, states):
    new_q_proj = fuse_weight(states, f'layers.{idx}.attention.wq.weight', dim=0)
    assert new_q_proj.shape == layer.self_attn.q_proj.weight.shape
    layer.self_attn.q_proj.weight = torch.nn.Parameter(new_q_proj)
    
    new_k_proj = fuse_weight(states, f'layers.{idx}.attention.wk.weight', dim=0)
    assert new_k_proj.shape == layer.self_attn.k_proj.weight.shape
    layer.self_attn.k_proj.weight = torch.nn.Parameter(new_k_proj)
    
    new_v_proj = fuse_weight(states, f'layers.{idx}.attention.wv.weight', dim=0)
    assert new_v_proj.shape == layer.self_attn.v_proj.weight.shape
    layer.self_attn.v_proj.weight = torch.nn.Parameter(new_v_proj)
    
    new_o_proj = fuse_weight(states, f'layers.{idx}.attention.wo.weight', dim=1)
    assert new_o_proj.shape == layer.self_attn.o_proj.weight.shape
    layer.self_attn.o_proj.weight = torch.nn.Parameter(new_o_proj)

def load_feedforward_weight(layer, idx, states):
    new_w1_weight = fuse_weight(states, f'layers.{idx}.feed_forward.w1.weight', dim=0)
    assert new_w1_weight.shape == layer.mlp.gate_proj.weight.shape
    layer.mlp.gate_proj.weight = torch.nn.Parameter(new_w1_weight)
    
    new_w2_weight = fuse_weight(states, f'layers.{idx}.feed_forward.w2.weight', dim=1)
    assert new_w2_weight.shape == layer.mlp.down_proj.weight.shape
    layer.mlp.down_proj.weight = torch.nn.Parameter(new_w2_weight)
    
    new_w3_weight = fuse_weight(states, f'layers.{idx}.feed_forward.w3.weight', dim=0)
    assert new_w3_weight.shape == layer.mlp.up_proj.weight.shape
    layer.mlp.up_proj.weight = torch.nn.Parameter(new_w3_weight)
    
def load_norm_weight(layer, idx, states):
    layer.input_layernorm.weight = torch.nn.Parameter(states[0][f'layers.{idx}.attention_norm.weight'])
    layer.post_attention_layernorm.weight = torch.nn.Parameter(states[0][f'layers.{idx}.ffn_norm.weight'])
    
def load_embed_tokens(model, states):
    new_em_weight = torch.cat([state['tok_embeddings.weight'] for state in states], dim=1)
    assert new_em_weight.shape == model.embed_tokens.weight.shape
    model.embed_tokens.weight = torch.nn.Parameter(new_em_weight)
    
def load_head_weight(llama_model, states):
    llama_model.lm_head.weight = torch.nn.Parameter(fuse_weight(states, f'output.weight', dim=0))
    
def sync_weight(llama_model, states):
    load_embed_tokens(llama_model.model, states)
    llama_model.model.norm.weight = torch.nn.Parameter(states[0][f'norm.weight'])
    load_head_weight(llama_model, states)
    for idx, layer in enumerate(llama_model.model.layers):
        load_attention_weight(layer, idx, states)
        load_feedforward_weight(layer, idx, states)
        load_norm_weight(layer, idx, states)

In [7]:
sync_weight(llama_model=llama_model, states=states)

## 保存权重

In [41]:
save_path = '/mnt/petrelfs/wangzerui/DeepSpeed/DeepSpeedExamples/applications/DeepSpeed-Chat/llama_model/7132k'

In [40]:
llama_model.save_pretrained(save_path, max_shard_size='2GB')

## 保存tokenizer

In [45]:
from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained('/mnt/petrelfs/wangzerui/DeepSpeed/DeepSpeedExamples/llamav4.model')
tokenizer.save_pretrained(save_path)



('/mnt/petrelfs/wangzerui/DeepSpeed/DeepSpeedExamples/applications/DeepSpeed-Chat/llama_model/7132k/tokenizer_config.json',
 '/mnt/petrelfs/wangzerui/DeepSpeed/DeepSpeedExamples/applications/DeepSpeed-Chat/llama_model/7132k/special_tokens_map.json',
 '/mnt/petrelfs/wangzerui/DeepSpeed/DeepSpeedExamples/applications/DeepSpeed-Chat/llama_model/7132k/tokenizer.model',
 '/mnt/petrelfs/wangzerui/DeepSpeed/DeepSpeedExamples/applications/DeepSpeed-Chat/llama_model/7132k/added_tokens.json')

## 测试

In [1]:
from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained('/mnt/petrelfs/wangzerui/DeepSpeed/DeepSpeedExamples/applications/DeepSpeed-Chat/llama_model/7132k')

In [3]:
from transformers import LlamaForCausalLM
llama_model = LlamaForCausalLM.from_pretrained('/mnt/petrelfs/wangzerui/DeepSpeed/DeepSpeedExamples/applications/DeepSpeed-Chat/llama_model/7132k')

Loading checkpoint shards:   0%|          | 0/14 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [9]:
prompt = "周杰伦是华语乐坛最"
inputs = tokenizer(prompt, return_tensors="pt")
generate_ids = llama_model.generate(inputs.input_ids, max_length=30)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

'周杰伦是华语乐坛最有影响力的歌手之一,他的歌曲传唱红遍大江南北国内外'

## 上传权重

In [11]:
from petrel_client.client import Client
import torch
client = Client()
def save_func(model, path):
    ROOT_PATH = 'Sproject_ssd_02:s3://debug_ssd_02/wangzerui/'
    full_path = ROOT_PATH + path
    print(full_path)
    buffer = io.BytesIO()
    torch.save(model, buffer)
    client.put(full_path, buffer.getvalue())
    
llama_model.save_pretrained('7132k', save_function=save_func, max_shard_size='3GB')

Sproject_ssd_02:s3://debug_ssd_02/wangzerui/7132k/pytorch_model-00001-of-00083.bin
Sproject_ssd_02:s3://debug_ssd_02/wangzerui/7132k/pytorch_model-00002-of-00083.bin
Sproject_ssd_02:s3://debug_ssd_02/wangzerui/7132k/pytorch_model-00003-of-00083.bin
Sproject_ssd_02:s3://debug_ssd_02/wangzerui/7132k/pytorch_model-00004-of-00083.bin
Sproject_ssd_02:s3://debug_ssd_02/wangzerui/7132k/pytorch_model-00005-of-00083.bin
Sproject_ssd_02:s3://debug_ssd_02/wangzerui/7132k/pytorch_model-00006-of-00083.bin
Sproject_ssd_02:s3://debug_ssd_02/wangzerui/7132k/pytorch_model-00007-of-00083.bin
Sproject_ssd_02:s3://debug_ssd_02/wangzerui/7132k/pytorch_model-00008-of-00083.bin
Sproject_ssd_02:s3://debug_ssd_02/wangzerui/7132k/pytorch_model-00009-of-00083.bin
Sproject_ssd_02:s3://debug_ssd_02/wangzerui/7132k/pytorch_model-00010-of-00083.bin
Sproject_ssd_02:s3://debug_ssd_02/wangzerui/7132k/pytorch_model-00011-of-00083.bin
Sproject_ssd_02:s3://debug_ssd_02/wangzerui/7132k/pytorch_model-00012-of-00083.bin
Spro

## load from s3

In [1]:
from transformers import AutoConfig, LlamaForCausalLM
local_config_path = '/mnt/petrelfs/wangzerui/DeepSpeed/DeepSpeedExamples/applications/DeepSpeed-Chat/llama_model/100B'
model_config = AutoConfig.from_pretrained(local_config_path)



In [None]:
model = LlamaForCausalLM(model_config)
model.generate()

In [None]:
import json
from functools import partial
from petrel_client.client import Client
import os
import torch
import io

def load_sharded_checkpoint(model, folder, s3_root, client):
    """
    This is the same as
    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
    but for a sharded checkpoint.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        model (`torch.nn.Module`): The model in which to load the checkpoint.
        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
        strict (`bool`, *optional`, defaults to `True`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
        prefer_safe (`bool`, *optional*, defaults to `False`)
            If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
            safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.

    Returns:
        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
            - `missing_keys` is a list of str containing the missing keys
            - `unexpected_keys` is a list of str containing the unexpected keys
    """
    # Load the index
    index_file = os.path.join(folder, "pytorch_model.bin.index.json")

    index_present = os.path.isfile(index_file)
    assert index_present
    assert s3_root.endswith('/')
    load_index = index_file

    with open(load_index, "r", encoding="utf-8") as f:
        index = json.load(f)

    shard_files = list(set(index["weight_map"].values()))

    # If strict=True, error before loading any of the state dicts.
    loaded_keys = index["weight_map"].keys()
    model_keys = model.state_dict().keys()
    missing_keys = [key for key in model_keys if key not in loaded_keys]
    unexpected_keys = [key for key in loaded_keys if key not in model_keys]

    loader = partial(torch.load, map_location="cpu")

    for shard_file in shard_files:
        s3_shard_file_path = s3_root + shard_file
        print(s3_shard_file_path)
        with io.BytesIO(client.get(s3_shard_file_path)) as f:
            state_dict = loader(f)
            model.load_state_dict(state_dict, strict=False)
            # Make sure memory is freed before we load the next state dict.
            del state_dict
client = Client()
s3_root = 'Sproject_ssd_02:s3://debug_ssd_02/wangzerui/7132k/'
load_sharded_checkpoint(model=model, folder=local_config_path, client=client, s3_root=s3_root)

In [27]:
prompt = "周杰伦是华语乐坛最"
inputs = tokenizer(prompt, return_tensors="pt")
generate_ids = model.generate(inputs.input_ids, max_length=30)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

'周杰伦是华语乐坛最成功的歌手之一，也是华语乐坛的领军人物，他的歌曲，他的歌曲，他的歌曲，他的歌曲，他的歌曲'