# Model registration
Based on https://huggingface.co/docs/transformers/main/en/custom_models#registering-a-model-with-custom-code-to-the-auto-classes

In [1]:
%load_ext autoreload
%autoreload 2
CUDA_VISIBLE_DEVICES = -1
import transformers
from multihead_models import *
import torch



In [2]:
# Build a default multi-head config and construct the model directly from it
# (no checkpoint is loaded here; weights are assigned in a later cell).
mhllamaconfig = MHLlamaConfig()
model = MultiheadLlamaForCausalLM(mhllamaconfig)
# Bare expression: show the module tree as the cell output.
model

MultiheadLlamaForCausalLM(
  (model): MOELlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x MOELlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): ModuleList(
          (0-4): 5 x LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layer

In [3]:
from transformers import LlamaForCausalLM

# Load the stock single-head Llama checkpoint from a local directory (a
# Llama-2 7B export, judging by the directory name). Its parameters are the
# source for the weight assignment into `model` in the following cell.
# NOTE(review): absolute, machine-specific path — consider a configurable constant.
ohmodel = LlamaForCausalLM.from_pretrained("/mnt/data/zoo/llama2/llama2-7b-hf/")
# Bare expression: show the module tree for comparison with the multi-head model.
ohmodel

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
 

In [4]:
def _share_linear(dst, src):
    """Rebind ``dst.weight`` and ``dst.bias`` to the objects held by ``src``.

    This is attribute assignment, not a copy: afterwards both modules refer
    to the same Parameter. A ``None`` bias (the printed layers are all
    ``bias=False``) is carried over as ``None``.
    """
    dst.weight = src.weight
    dst.bias = src.bias


# Name -> module lookup for the reference model; not read by any visible cell.
ohdict = dict(ohmodel.named_modules())

# Derive the sizes from the model instead of hard-coding 32 and 5, so the cell
# keeps working if the config changes.
num_decoders = len(model.model.layers)
num_heads = len(model.model.layers[0].mlp)
assert len(ohmodel.model.layers) == num_decoders, (
    "multi-head model and reference model have different numbers of decoder layers"
)

_ATTN_PROJECTIONS = ("q_proj", "k_proj", "v_proj", "o_proj")
_MLP_PROJECTIONS = ("gate_proj", "up_proj", "down_proj")

model.model.embed_tokens.weight = ohmodel.model.embed_tokens.weight
for i in range(num_decoders):
    dst_layer = model.model.layers[i]
    src_layer = ohmodel.model.layers[i]

    for proj_name in _ATTN_PROJECTIONS:
        _share_linear(getattr(dst_layer.self_attn, proj_name),
                      getattr(src_layer.self_attn, proj_name))

    # Every MLP head of a layer is initialised from the single reference MLP,
    # so all heads of a layer end up sharing the same Parameter objects.
    for h in range(num_heads):
        for proj_name in _MLP_PROJECTIONS:
            _share_linear(getattr(dst_layer.mlp[h], proj_name),
                          getattr(src_layer.mlp, proj_name))

    dst_layer.input_layernorm.weight = src_layer.input_layernorm.weight
    dst_layer.post_attention_layernorm.weight = src_layer.post_attention_layernorm.weight

model.model.norm.weight = ohmodel.model.norm.weight

# Each output head starts from the reference model's single lm_head.
for h in range(num_heads):
    _share_linear(model.heads[h], ohmodel.lm_head)

In [5]:
import os

# torch.save does not create missing parent directories, and the cell that
# calls save_pretrained("mhllama") only runs after this one, so make sure the
# target directory exists on a fresh run.
os.makedirs('./mhllama', exist_ok=True)
torch.save(model.state_dict(), './mhllama/pytorch_model.bin')

In [6]:
# Write the config into the "mhllama" directory, alongside the weights file
# saved to ./mhllama/pytorch_model.bin.
mhllamaconfig.save_pretrained("mhllama")

## Try loading the saved model through the Auto classes

In [2]:
# Default config instance; not read again within this cell.
mhllamaconfig = MHLlamaConfig()
# Register the custom config class under the model-type string 'mhllama'
# so AutoConfig can resolve it.
transformers.AutoConfig.register('mhllama', MHLlamaConfig)
# transformers.AutoModel.register(MHLlamaConfig, MultiheadLlamaForCausalLM)
# Register the custom model class for that config with AutoModelForCausalLM.
transformers.AutoModelForCausalLM.register(MHLlamaConfig, MultiheadLlamaForCausalLM)
# NOTE(review): absolute, user-specific path; presumably the same "mhllama"
# directory written by the save cells — confirm.
model = transformers.AutoModelForCausalLM.from_pretrained('/home/sonia/llama-qlora/mhllama')

  return self.fget.__get__(instance, owner)()


In [3]:
# Bare expression: show the module tree of the model loaded via AutoModelForCausalLM.
model

MultiheadLlamaForCausalLM(
  (model): MOELlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x MOELlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): ModuleList(
          (0-4): 5 x LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layer