# <font color='steelblue'>Weight Merging Script (requires installing mergekit)</font>

In [1]:
import torch
import yaml, json
from tqdm.notebook import tqdm
from mergekit.merge import MergeOptions, run_merge
from mergekit.config import MergeConfiguration

#================================================
# output model name
#================================================
version_number = "SLERP-2"
NEW_MODEL_ID = f"QW2.5_Q31.2-Q42.2MK_mergekit_{version_number}_promptV1_fewshot"

#================================================
# merge recipe (values are fed into the mergekit config in the next cell)
#================================================
MERGE_METHOD = "slerp"

# SLERP interpolates from BASE_MODEL toward the other source model.
BASE_MODEL = "../model/LORA_SFT/merged_lora_QW2.5_Q42.2_fewshot_train2000_R64_modAL_OE-AOPE-AOC-ASTE-ASQE_multitask_cascadedOrder_promptV1_fewshot"

SOURCE_MODEL_1 = BASE_MODEL
LAYER_RANGE_1 = [0, 28]  # layer slice [start, end]; used in mergekit/plan.py 'plan_slice'

SOURCE_MODEL_2 = "../model/LORA_SFT/merged_lora_QW2.5_Q31.2_fewshot_train1000_R8_modAL_OE-AOPE-AOC-ASTE-ASQE_multitask_cascadedOrder_promptV1_fewshot"
LAYER_RANGE_2 = [0, 28]

#==================================================
# SLERP interpolation factor `t`, per tensor group:
#   t = 0.0 -> keep BASE_MODEL (SOURCE_MODEL_1) weights
#   t = 1.0 -> take SOURCE_MODEL_2 weights
#   t = 0.5 -> even blend of the two
# Values below 0.5 stay closer to the first model (e.g. t = 0.3).
#==================================================
SELF_ATTEN_FILTER_VALUE = 0.5  # t for self-attention tensors (even blend)
MLP_FILTER_VALUE = 0.5         # t for MLP tensors (even blend)
OTHER_LAYER_MIX_RATIO = 0.5    # t for every remaining tensor (even blend)
DTYPE = "bfloat16"             # dtype of the merged weights

# Both sources of a slice must cover the same number of layers; check here
# instead of failing deep inside mergekit's planning step.
assert (LAYER_RANGE_1[1] - LAYER_RANGE_1[0]) == (LAYER_RANGE_2[1] - LAYER_RANGE_2[0]), (
    f"layer ranges must span the same number of layers: {LAYER_RANGE_1} vs {LAYER_RANGE_2}"
)

#================================================
# merge options
#================================================
OUTPUT_PATH = f"../MY_models/mergeWeight_output/{NEW_MODEL_ID}"  # folder to store the result in
LORA_MERGE_CACHE = "../MY_models/merged_output/lora_merge_cache"  # cache dir for LoRA merges; change to keep them elsewhere
COPY_TOKENIZER = True     # copy the tokenizer files into OUTPUT_PATH
LAZY_UNPICKLE = False     # experimental low-memory model loader
LOW_CPU_MEMORY = False    # enable if you have more VRAM than RAM+swap
TRUST_REMOTE_CODE = False
WRITE_MODEL_CARD = False
ALLOW_CRIMES = False


In [2]:
################################
# Run merge
################################

# Build the mergekit configuration as a plain Python dict instead of
# f-string-interpolating values into YAML text and re-parsing it. Model paths
# and IDs are passed through verbatim, so characters that are significant in
# YAML (`#`, `: `, leading `*`/`&`/`!`, quotes) cannot silently corrupt the config.
#
# IMPORTANT: if the source model's tokenizer_config.json does not include a
# 'chat_template' field, add the top-level key
#     "tokenizer_source": BASE_MODEL,
# to merge_config_dict (e.g. right after "dtype").
merge_config_dict = {
    "merge_method": MERGE_METHOD,
    "base_model": BASE_MODEL,
    "dtype": DTYPE,
    "slices": [
        {
            "sources": [
                # list(...) so tuples also work and the dict stays YAML-dumpable
                {"model": SOURCE_MODEL_1, "layer_range": list(LAYER_RANGE_1)},
                {"model": SOURCE_MODEL_2, "layer_range": list(LAYER_RANGE_2)},
            ]
        }
    ],
    "parameters": {
        # SLERP `t` per tensor group; the entry without a filter is the fallback
        # for every tensor not matched by an earlier filter.
        "t": [
            {"filter": "self_attn", "value": SELF_ATTEN_FILTER_VALUE},
            {"filter": "mlp", "value": MLP_FILTER_VALUE},
            {"value": OTHER_LAYER_MIX_RATIO},
        ]
    },
}

# Human-readable YAML view of the exact config being merged (safe_dump quotes
# values as needed); kept for inspection / provenance.
config_yaml_str = yaml.safe_dump(merge_config_dict, sort_keys=False)

# Validate the dict into mergekit's configuration model.
merge_config = MergeConfiguration.model_validate(merge_config_dict)

run_merge(
    merge_config,
    out_path=OUTPUT_PATH,
    options=MergeOptions(
        lora_merge_cache=LORA_MERGE_CACHE,
        cuda=torch.cuda.is_available(),  # use the GPU when one is present
        copy_tokenizer=COPY_TOKENIZER,
        lazy_unpickle=LAZY_UNPICKLE,
        low_cpu_memory=LOW_CPU_MEMORY,
        trust_remote_code=TRUST_REMOTE_CODE,
        write_model_card=WRITE_MODEL_CARD,
        allow_crimes=ALLOW_CRIMES,
    ),
)
print(f"Done!, model saved to \033[33m{OUTPUT_PATH}\033[0m")

(from mergekit/architecture/phi4_defs.py: ) Will copy these files from the base model: ['generation_config.json', 'tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer.model', 'added_tokens.json', 'merges.txt']

(from mergekit/architecture/phi4_defs.py: ) Will copy these files from the base model: ['generation_config.json', 'tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer.model', 'added_tokens.json', 'merges.txt']



Warmup loader cache: 100%|██████████| 2/2 [00:00<00:00, 998.17it/s]
Executing graph: 100%|██████████| 1692/1692 [00:24<00:00, 69.90it/s] 

Done!, model saved to [33m../MY_models/mergeWeight_output/QW2.5_Q31.2-Q42.2MK_mergekit_SLERP-2_promptV1_fewshot[0m



