In [24]:
import wandb
from dataclasses import dataclass, field
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

import torch

import logging

logging.basicConfig(level=logging.INFO)

WANDB_PROJECT = "shearllama"
ENTITY = "capecape"

@dataclass
class Config:
    model_id: str = "mistralai/Mistral-7B-v0.1"
    output_name: str = "models/mistral_7b_12_layers_start"
    layers_ids: list = field(default_factory=lambda: [0,1,2,3,4,5,6,7])
    save_tokenizer: bool = True
    device_map: str = "cuda:0"
    random: bool = False
    log: bool = True


In [3]:

config  = Config()

model_config = AutoConfig.from_pretrained(config.model_id)
model_config.num_hidden_layers = len(config.layers_ids)
logging.info(model_config)

INFO:root:MistralConfig {
  "_name_or_path": "mistralai/Mistral-7B-v0.1",
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 32,
  "num_hidden_layers": 8,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.37.2",
  "use_cache": true,
  "vocab_size": 32000
}



In [25]:
original_model = AutoModelForCausalLM.from_pretrained(
    config.model_id, 
    torch_dtype=torch.bfloat16,
    device_map=config.device_map)


Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.11s/it]


In [26]:
new_model = AutoModelForCausalLM.from_config(model_config)

In [27]:
def map_state_dict(original_model, layer_ids=[0,1,30,31], layer_naming="layers"):
    "We will map the parameters of the original model layer_ids to the new model layer_ids"
    name_mapping = {}
    layer_mapping = {layer_id: i for i, layer_id in enumerate(layer_ids)}
    print(f"Layer mapping: {layer_mapping}")
    for name, _ in original_model.named_parameters():
        if layer_naming in name:
            layer_id = int(name.split(".")[2])
            if layer_id in layer_ids:
                new_name = name.replace(f"{layer_naming}.{layer_id}", f"{layer_naming}.{layer_mapping[layer_id]}")
                name_mapping[name] = new_name
        else:
            name_mapping[name] = name
    return name_mapping

In [28]:
name_mapping = map_state_dict(original_model, [0,1,30,31])

Layer mapping: {0: 0, 1: 1, 30: 2, 31: 3}


In [29]:
# Manually copy weights and biases
for old_name, new_name in name_mapping.items():
    # Check if the mapped name exists in the new model's state_dict
    if new_name in new_model.state_dict():
        # Directly load the parameter from the old model to the new model based on the mapping
        new_model.state_dict()[new_name].data.copy_(original_model.state_dict()[old_name].data)
    else:
        print(f"{new_name} not found in the new model's state_dict. Check your mapping dictionary.")


In [30]:
new_model.num_parameters()

2007044096

In [32]:
new_model.to(torch.bfloat16)
new_model.save_pretrained("test_model")