In [22]:
import torch
import timm

# 1. 创建 ViT-Tiny 模型，并加载 ImageNet 预训练权重
model = timm.create_model('vit_tiny_patch16_224.augreg_in21k', pretrained=False, checkpoint_path="./vit_tiny_patch16_224.augreg_in21k.pth") 

# (可选) 将模型移动到 GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dense_part = model.state_dict()

In [23]:
from model.vision_transformer_zeke import VisionTransformer_zeke

moe_model = VisionTransformer_zeke(embed_dim=192, num_heads=3, num_classes=100)

moe_part = moe_model.state_dict()


moe_set = set()

dense_set = set()

print("*"*50 + "moe part" + "*"*50)
for key in moe_part.keys():
    moe_set.add(key)
    print(key)

print("*"*50 + "dense part" + "*"*50)
for key in dense_part.keys():
    dense_set.add(key)
    print(key)




**************************************************moe part**************************************************
cls_token
pos_embed
patch_embed.proj.weight
patch_embed.proj.bias
blocks.0.norm1.weight
blocks.0.norm1.bias
blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight
blocks.0.attn.proj.bias
blocks.0.norm2.weight
blocks.0.norm2.bias
blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias
blocks.1.norm1.weight
blocks.1.norm1.bias
blocks.1.attn.qkv.weight
blocks.1.attn.qkv.bias
blocks.1.attn.proj.weight
blocks.1.attn.proj.bias
blocks.1.norm2.weight
blocks.1.norm2.bias
blocks.1.mlp.fc1.weight
blocks.1.mlp.fc1.bias
blocks.1.mlp.fc2.weight
blocks.1.mlp.fc2.bias
blocks.2.norm1.weight
blocks.2.norm1.bias
blocks.2.attn.qkv.weight
blocks.2.attn.qkv.bias
blocks.2.attn.proj.weight
blocks.2.attn.proj.bias
blocks.2.norm2.weight
blocks.2.norm2.bias
blocks.2.mlp.fc1.weight
blocks.2.mlp.fc1.bias
blocks.2.mlp.fc2.weight
blocks.2.mlp.fc2.bias
blo

In [24]:
same_part = moe_set & dense_set

diff_part_moe = moe_set - dense_set

diff_part_dense = dense_set - moe_set

print("*"*50 + "same part" + "*"*50)
for key in same_part:
    print(key)

print("*"*50 + "different part moe" + "*"*50)
for key in diff_part_moe:
    print(key)

print("*"*50 + "different part dense" + "*"*50)
for key in diff_part_dense:
    print(key)

**************************************************same part**************************************************
blocks.4.mlp.fc1.weight
blocks.4.norm1.weight
blocks.2.attn.proj.bias
blocks.4.mlp.fc2.weight
blocks.2.attn.proj.weight
blocks.3.attn.qkv.bias
blocks.1.attn.proj.bias
blocks.10.attn.qkv.bias
blocks.1.attn.qkv.weight
blocks.3.attn.proj.weight
blocks.4.mlp.fc1.bias
blocks.5.mlp.fc2.weight
blocks.5.attn.proj.bias
blocks.0.norm2.weight
blocks.11.norm2.weight
blocks.8.norm1.weight
blocks.11.norm1.bias
blocks.4.attn.proj.weight
blocks.11.attn.proj.weight
blocks.2.norm2.weight
blocks.4.norm1.bias
blocks.6.norm2.bias
blocks.7.attn.qkv.bias
blocks.4.norm2.bias
blocks.3.mlp.fc1.bias
cls_token
blocks.7.attn.proj.weight
blocks.1.mlp.fc1.weight
norm.bias
norm.weight
blocks.5.norm1.bias
blocks.2.attn.qkv.weight
blocks.0.attn.proj.bias
blocks.6.norm2.weight
patch_embed.proj.bias
blocks.2.mlp.fc2.bias
blocks.7.norm2.weight
blocks.5.norm2.weight
blocks.9.norm1.weight
blocks.3.mlp.fc1.weight
hea

In [32]:
from typing import OrderedDict, Set
from __future__ import annotations  # 添加这一行

def get_extrator(moe_part: dict[str, torch.Tensor], dense_part: dict[str, torch.Tensor], same_part: set[str]) -> dict[str, torch.Tensor]:  
    # 创建一个新的OrderedDict来存储修改后的权重
    final_dict = OrderedDict()
    
    # 复制所有MOE模型的权重
    for key, value in moe_part.items():
        if key in same_part and "head" not in key:
            # 对于相同部分（非head层），使用dense模型的权重
            dense_value = dense_part[key]
            if value.shape != dense_value.shape:
                print(f"skip {key}, shape error")
                final_dict[key] = value  # 保留原始权重
            else:
                final_dict[key] = dense_value  # 使用dense模型的权重
        elif "experts." in key:

            temp = key.split(".")
            del temp[3]
            dense_key = ".".join(temp)
            dense_value = dense_part[dense_key]
            value[:] = dense_value.T
            final_dict[key] = value

        else:
            # 对于不同部分，保留MOE模型的原始权重
            final_dict[key] = value
    
    return final_dict

# 创建最终的权重字典
final_dict = get_extrator(moe_part, dense_part, same_part)

# 保存权重到文件
torch.save(final_dict, "moe_model_with_pretrained_weights.pth")

# 如果需要验证权重是否可加载，可以尝试加载
# 测试加载权重到模型
test_model = VisionTransformer_zeke(embed_dim=192, num_heads=3, num_classes=100)
test_model.load_state_dict(final_dict)
print("权重加载成功!")

权重加载成功!


  value[:] = dense_value.T
