In [1]:
import torch

In [48]:
# Flat source checkpoint: LLM attention stored as separate q_proj / k_proj / v_proj.
ori_ckpt_path = "/mnt/sfs/asr/ckpt/epoch_3.pt"
# Reference checkpoint in the target layout (weights nested under 'module').
new_ckpt_path = (
    "/mnt/obs/ckpt/um/qwen2_multi_task_4_fa/epoch_0_with_speechp/step_0/"
    "mp_rank_00_model_states.pt"
)

In [3]:
def print_ckpt_structure(ckpt, indent=0, max_depth=None, current_depth=0, max_value_len=None):
    """
    Print the nested key structure of a PyTorch checkpoint.

    Dict values (including OrderedDict) are recursed into. Tensors are
    reported by type and shape, never by contents. Any other value is
    reported by type and its string form.

    Args:
        ckpt (dict): The checkpoint dict, or any nested sub-dict of it.
        indent (int): Number of leading spaces for keys at the current level.
        max_depth (int | None): Deepest nesting level to print; the top level
            is depth 0. ``None`` prints every level.
        current_depth (int): Current recursion depth; callers normally leave
            this at 0.
        max_value_len (int | None): If set (non-negative), the printed text of
            non-tensor values is cut to this many characters and suffixed with
            ``...``. This keeps large non-tensor entries (e.g. lists of
            per-group shape dicts) from flooding the output. ``None``
            (default) prints values in full.
    """
    if max_depth is not None and current_depth > max_depth:
        return

    pad = " " * indent
    child_pad = " " * (indent + 4)
    for key, value in ckpt.items():
        print(pad + f"Key: {key}")

        if isinstance(value, dict):
            # Nested mapping: recurse one level deeper, forwarding all options.
            print_ckpt_structure(value, indent + 4, max_depth, current_depth + 1, max_value_len)
        elif isinstance(value, torch.Tensor):
            # Only the shape is printed; tensor contents can be huge.
            print(child_pad + f"Type: {type(value)}, Shape: {value.shape}")
        else:
            text = f"{value}"
            if max_value_len is not None and len(text) > max_value_len:
                text = text[:max_value_len] + "..."
            print(child_pad + f"Type: {type(value)}, Value: {text}")
# Demo: a small synthetic nested checkpoint used to exercise print_ckpt_structure.
# Only types/shapes are printed, so the random tensor values do not matter.
ckpt = dict(
    model_state_dict={
        "layer1.weight": torch.randn(3, 3),
        "layer1.bias": torch.randn(3),
        "layer2": dict(
            weight=torch.randn(2, 2),
            bias=torch.randn(2),
        ),
    },
    optimizer_state_dict=dict(
        state=dict(
            param1=torch.randn(1),
            param2=torch.randn(1),
        ),
        param_groups=[dict(lr=0.001)],
    ),
    epoch=10,
    loss=0.5,
)

# Print nesting levels 0 through 2 of the demo checkpoint.
print_ckpt_structure(ckpt, max_depth=2)

Key: model_state_dict
    Key: layer1.weight
        Type: <class 'torch.Tensor'>, Shape: torch.Size([3, 3])
    Key: layer1.bias
        Type: <class 'torch.Tensor'>, Shape: torch.Size([3])
    Key: layer2
        Key: weight
            Type: <class 'torch.Tensor'>, Shape: torch.Size([2, 2])
        Key: bias
            Type: <class 'torch.Tensor'>, Shape: torch.Size([2])
Key: optimizer_state_dict
    Key: state
        Key: param1
            Type: <class 'torch.Tensor'>, Shape: torch.Size([1])
        Key: param2
            Type: <class 'torch.Tensor'>, Shape: torch.Size([1])
    Key: param_groups
        Type: <class 'list'>, Value: [{'lr': 0.001}]
Key: epoch
    Type: <class 'int'>, Value: 10
Key: loss
    Type: <class 'float'>, Value: 0.5


In [54]:
# Load onto CPU, matching how new_ckpt is loaded. Without map_location, tensors
# saved from a GPU are restored onto the GPU (or loading fails on a CPU-only host).
ori_ckpt = torch.load(ori_ckpt_path, map_location="cpu")

In [5]:
# Inspect every key of the flat source state dict (no depth limit).
print_ckpt_structure(ckpt=ori_ckpt, max_depth=None)

Key: encoder.embed.pos_enc.pe
    Type: <class 'torch.Tensor'>, Shape: torch.Size([1, 1500, 1024])
Key: speech_transformer.embed.pos_enc.pe
    Type: <class 'torch.Tensor'>, Shape: torch.Size([1, 5000, 1024])
Key: llama_model.base_model.model.model.embed_tokens.weight
    Type: <class 'torch.Tensor'>, Shape: torch.Size([152064, 3584])
Key: llama_model.base_model.model.model.layers.0.self_attn.q_proj.weight
    Type: <class 'torch.Tensor'>, Shape: torch.Size([3584, 3584])
Key: llama_model.base_model.model.model.layers.0.self_attn.q_proj.bias
    Type: <class 'torch.Tensor'>, Shape: torch.Size([3584])
Key: llama_model.base_model.model.model.layers.0.self_attn.k_proj.weight
    Type: <class 'torch.Tensor'>, Shape: torch.Size([512, 3584])
Key: llama_model.base_model.model.model.layers.0.self_attn.k_proj.bias
    Type: <class 'torch.Tensor'>, Shape: torch.Size([512])
Key: llama_model.base_model.model.model.layers.0.self_attn.v_proj.weight
    Type: <class 'torch.Tensor'>, Shape: torch.Size(

In [22]:
# Reference checkpoint in the target layout; its weights live under the
# 'module' key (presumably a DeepSpeed model-states file -- confirm).
new_ckpt = torch.load(f=new_ckpt_path, map_location="cpu")

In [23]:
# Inspect the nested structure of the reference checkpoint (no depth limit).
print_ckpt_structure(ckpt=new_ckpt, indent=0)

Key: module
    Key: encoder.embed.conv.0.weight
        Type: <class 'torch.Tensor'>, Shape: torch.Size([1024, 80, 3])
    Key: encoder.embed.conv.0.bias
        Type: <class 'torch.Tensor'>, Shape: torch.Size([1024])
    Key: encoder.embed.conv.2.weight
        Type: <class 'torch.Tensor'>, Shape: torch.Size([1024, 1024, 3])
    Key: encoder.embed.conv.2.bias
        Type: <class 'torch.Tensor'>, Shape: torch.Size([1024])
    Key: encoder.embed.pos_enc.pe
        Type: <class 'torch.Tensor'>, Shape: torch.Size([1, 1500, 1024])
    Key: encoder.after_norm.weight
        Type: <class 'torch.Tensor'>, Shape: torch.Size([1024])
    Key: encoder.after_norm.bias
        Type: <class 'torch.Tensor'>, Shape: torch.Size([1024])
    Key: encoder.encoders.0.self_attn.linear_q.weight
        Type: <class 'torch.Tensor'>, Shape: torch.Size([1024, 1024])
    Key: encoder.encoders.0.self_attn.linear_q.bias
        Type: <class 'torch.Tensor'>, Shape: torch.Size([1024])
    Key: encoder.encoders.0.s

## Merge
### encoder
> 暂不合并 encoder：Whisper encoder 的 `linear_k` 没有 bias，而 `linear_q` 和 `linear_v` 有 bias，因此无法直接拼接出 `linear_qkv.bias`（下面列出的 `linear_k.bias` 在原始权重中并不存在）。
- `encoder.encoders.1.self_attn.linear_qkv.weight`: torch.Size([3072, 1024])
    - `encoder.encoders.1.self_attn.linear_q.weight`: torch.Size([1024, 1024])
    - `encoder.encoders.1.self_attn.linear_k.weight`: torch.Size([1024, 1024])
    - `encoder.encoders.1.self_attn.linear_v.weight`: torch.Size([1024, 1024])
- `encoder.encoders.1.self_attn.linear_qkv.bias`: torch.Size([3072])
    - `encoder.encoders.1.self_attn.linear_q.bias`: torch.Size([1024])
    - `encoder.encoders.1.self_attn.linear_k.bias`: torch.Size([1024])
    - `encoder.encoders.1.self_attn.linear_v.bias`: torch.Size([1024])
### llama
- `llama_model.base_model.model.model.layers.0.self_attn.w_pack.weight`: torch.Size([4608, 3584])
    - `llama_model.base_model.model.model.layers.0.self_attn.q_proj.weight`: torch.Size([3584, 3584])
    - `llama_model.base_model.model.model.layers.0.self_attn.k_proj.weight`: torch.Size([512, 3584])
    - `llama_model.base_model.model.model.layers.0.self_attn.v_proj.weight`: torch.Size([512, 3584])
- `llama_model.base_model.model.model.layers.0.self_attn.w_pack.bias`: torch.Size([4608])
    - `llama_model.base_model.model.model.layers.0.self_attn.q_proj.bias`: torch.Size([3584])
    - `llama_model.base_model.model.model.layers.0.self_attn.k_proj.bias`: torch.Size([512])
    - `llama_model.base_model.model.model.layers.0.self_attn.v_proj.bias`: torch.Size([512])

**param_shapes 形状调整**：将 checkpoint 中 `param_shapes` 记录的形状更新为 `module` 中（合并后）各权重的实际形状。

In [53]:
def merge(target, source, num_layers=28, verbose=True):
    """
    Fuse per-layer LLM attention q/k/v projections from ``source`` into the
    packed ``w_pack`` entries of ``target``, which is modified in place.

    For each decoder layer ``i`` in ``range(num_layers)``:

        w_pack.weight = cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)
        w_pack.bias   = cat([q_proj.bias,   k_proj.bias,   v_proj.bias],   dim=0)

    The encoder q/k/v projections are intentionally not fused. The encoder's
    k projection has no bias while q and v do, so the biases cannot simply be
    concatenated.

    Args:
        target (dict): State dict that must already contain every ``w_pack``
            weight/bias key (presumably the ``'module'`` dict of the
            packed-layout checkpoint; verify at the call site). Mutated in place.
        source (dict): State dict holding the separate ``q_proj`` /
            ``k_proj`` / ``v_proj`` weights and biases. Not modified.
        num_layers (int): Number of decoder layers to fuse (default 28).
        verbose (bool): If True (default), print the shape of each fused tensor.

    Raises:
        AssertionError: If a ``w_pack`` key is missing from ``target``, or if
            the fused tensor's shape differs from the tensor already stored there.
        KeyError: If a q/k/v key is missing from ``source``.

    NOTE(review): the fused tensor keeps ``source``'s dtype/device. If
    ``target`` stores a different dtype, that entry's dtype changes; confirm
    this is intended.
    """
    template = 'llama_model.base_model.model.model.layers.{layer_idx}.self_attn.{layer_name}.{kind}'
    for i in range(num_layers):
        # Weight first, then bias, for every layer.
        for kind in ('weight', 'bias'):
            packed_key = template.format(layer_idx=i, layer_name='w_pack', kind=kind)
            assert packed_key in target, f"missing key in target: {packed_key}"
            fused = torch.cat(
                [source[template.format(layer_idx=i, layer_name=name, kind=kind)]
                 for name in ('q_proj', 'k_proj', 'v_proj')],
                dim=0,
            ).contiguous()
            # Refuse to silently replace a tensor with one of a different shape.
            assert fused.shape == target[packed_key].shape, (
                f"shape mismatch for {packed_key}: fused {tuple(fused.shape)} "
                f"vs target {tuple(target[packed_key].shape)}"
            )
            target[packed_key] = fused
            if verbose:
                print(fused.shape)

def merge_v2(source, num_layers=28, remove_originals=False, verbose=True):
    """
    Add fused ``w_pack`` weight/bias entries to ``source`` itself, built by
    concatenating its own per-layer q/k/v projections along dim 0.

    For each decoder layer ``i`` in ``range(num_layers)``:

        w_pack.weight = cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)
        w_pack.bias   = cat([q_proj.bias,   k_proj.bias,   v_proj.bias],   dim=0)

    Args:
        source (dict): Flat state dict containing the separate ``q_proj`` /
            ``k_proj`` / ``v_proj`` weights and biases. Mutated in place.
        num_layers (int): Number of decoder layers to fuse (default 28).
        remove_originals (bool): If True, delete each layer's q/k/v entries
            after fusing, so only ``w_pack`` remains. Default False keeps both
            copies. A later strict ``load_state_dict`` into a model that only
            has ``w_pack`` may then report the q/k/v keys as unexpected.
        verbose (bool): If True (default), print the shape of each fused tensor.

    Raises:
        KeyError: If a q/k/v key is missing, e.g. when calling again after a
            run with ``remove_originals=True``.
    """
    template = 'llama_model.base_model.model.model.layers.{layer_idx}.self_attn.{layer_name}.{kind}'
    for i in range(num_layers):
        # Weight first, then bias, for every layer.
        for kind in ('weight', 'bias'):
            part_keys = [template.format(layer_idx=i, layer_name=name, kind=kind)
                         for name in ('q_proj', 'k_proj', 'v_proj')]
            packed_key = template.format(layer_idx=i, layer_name='w_pack', kind=kind)
            source[packed_key] = torch.cat([source[k] for k in part_keys], dim=0).contiguous()
            if verbose:
                print(source[packed_key].shape)
            if remove_originals:
                for k in part_keys:
                    del source[k]

def update_shape(ckpt):
    """
    Sync the shapes recorded in ``ckpt['param_shapes']`` with the tensors in
    ``ckpt['module']``. Mutates ``ckpt`` in place.

    ``ckpt['param_shapes']`` is treated as a list of name -> shape dicts,
    presumably one per optimizer parameter group (DeepSpeed layout; confirm).
    A name already recorded in some group is updated in that group, rather
    than being duplicated into group 0. A name not recorded in any group is
    added to the first group. With a single group this is exactly
    "``param_shapes[0][name] = tensor.size()`` for every module entry".

    NOTE(review): every ``module`` entry is considered, including any
    non-parameter buffers. Those are added to group 0 if not already recorded;
    confirm that consumers of ``param_shapes`` tolerate this.

    Args:
        ckpt (dict): Checkpoint with a ``'module'`` dict (name -> tensor) and
            a non-empty ``'param_shapes'`` list of dicts.

    Raises:
        IndexError: If ``param_shapes`` is empty while ``module`` is not.
    """
    groups = ckpt['param_shapes']
    for name, tensor in ckpt['module'].items():
        # First group that already records this name, else the first group.
        owner = next((group for group in groups if name in group), groups[0])
        owner[name] = tensor.size()

In [55]:
# Add fused w_pack weight/bias entries to the flat source state dict in place.
merge_v2(source=ori_ckpt)

torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.Size([4608])
torch.Size([4608, 3584])
torch.S

In [35]:
# update_shape(new_ckpt)

In [56]:
# Destination for the state dict that now also holds fused w_pack entries.
# NOTE(review): ori_ckpt is the flat state dict loaded from ori_ckpt_path (keys at
# the top level, no 'module' wrapper), whereas new_ckpt in this same directory
# nests its weights under 'module'. Confirm the consumer of this file expects
# the flat layout.
ckpt_output_path = (
    "/mnt/obs/ckpt/um/qwen2_multi_task_4_fa/epoch_0_with_speechp/step_0/"
    "mp_rank_00_model_states_merge_llama_qkv.pt"
)
torch.save(obj=ori_ckpt, f=ckpt_output_path)