In [1]:
"""
Sanity check for a trained model-merging checkpoint using PyTorch and Transformers.
Loads the Merger from a checkpoint, exports the merged model, reloads it as a
plain causal LM, and compares the logits of the two.
"""

from dataclasses import dataclass
from typing import (
    Any, Callable, Dict, 
    List, NewType, Optional, 
    Tuple, Union, Mapping
)
from abc import ABC, abstractmethod
from datasets import load_dataset, concatenate_datasets
from accelerate.logging import get_logger
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import torch
import safetensors
import math
import yaml
import logging
import copy
import gc
import os
import argparse
import sys
import shutil

from tqdm import tqdm
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    HfArgumentParser,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)

from transformers.utils import CONFIG_NAME
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13

from merger import (
    MergerConfig,
    Merger,
    # NewMerger,
    init_masks,
    set_masks
)

from utils import (
    generate, 
    get_hidden_states, 
    get_logits,
    free_memory
)
# Configure logging through the project's logging_config helper, then fetch
# the logger named "train" for use in this notebook.
from logging_config import configure_logging
configure_logging()
logger = logging.getLogger("train")

In [2]:
# Checkpoint to inspect; the relative path resolves against the notebook's working directory.
checkpoint_dir = "../results/run_01/checkpoint-200/"
# Read the merger configuration from that checkpoint directory.
merge_config = MergerConfig.from_pretrained(checkpoint_dir)
# Bare expression: show the loaded config as the cell output.
merge_config

MergerConfig {
  "architectures": [
    "NewMerger"
  ],
  "constrain_mode": "identity",
  "mode": "vector_input",
  "model_paths": [
    "/workspace/models/llama-3.2-3b-wizard",
    "/workspace/models/llama-3.2-3b-math"
  ],
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.3"
}

In [3]:
# NOTE(review): presumably follows the Hugging Face device_map convention, where
# an empty-string key places the whole model on the given device (here 0) — confirm.
device_map={"":0}
# Rebuild the Merger from the checkpoint using the config loaded earlier, in bfloat16.
merger = Merger.from_pretrained(
    checkpoint_dir,
    merge_config,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    # flash attention is left disabled here
    # attn_implementation="flash_attention_2",
)

[2025-01-16 14:52:14,500] [INFO] [merger.from_pretrained:315] [PID:143678] [RANK:0] >>> Merger device: {'': 0}[39m
[2025-01-16 14:52:14,506] [INFO] [merger.__init__:207] [PID:143678] [RANK:0] Creating merger with dummy weights ...[39m


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks:   1%|█▎                                                                                                                                                                            | 2/255 [00:18<38:51,  9.22s/it]



Initializing masks: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [01:09<00:00,  3.69it/s]


In [5]:
# Show the config attached to the loaded merger.
merger.config

LlamaConfig {
  "_name_or_path": "/workspace/models/llama-3.2-3b-wizard",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 3072,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 24,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.3",
  "use_cache": false,
  "vocab_size": 128256
}

In [16]:
# Export the merged model into the run directory.
# NOTE(review): hardcoded absolute path — consider deriving it from a configurable base directory.
merger.save_merged("/workspace/logits-guided-merger/results/run_01")

Merging masked modules: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:02<00:00, 96.46it/s]


In [17]:
# Load the tokenizer in this cell so that saving it does not depend on a later
# cell having been executed first (a fresh-kernel Run All would otherwise hit a
# NameError on `tokenizer`). The path is the first source model in the merger config.
tokenizer = AutoTokenizer.from_pretrained("/workspace/models/llama-3.2-3b-wizard")
# Store the tokenizer files next to the exported merged model; the returned
# tuple of written file paths is the cell output.
tokenizer.save_pretrained("/workspace/logits-guided-merger/results/run_01")

('/workspace/logits-guided-merger/results/run_01/tokenizer_config.json',
 '/workspace/logits-guided-merger/results/run_01/special_tokens_map.json',
 '/workspace/logits-guided-merger/results/run_01/tokenizer.json')

In [7]:
# Show the merger config again, this time after the export step.
# NOTE(review): the recorded output carries an earlier execution count than the
# save_merged cell, and its _name_or_path ends in "dev/hehe" rather than the
# run_01 directory — it looks stale; re-run the notebook top to bottom to refresh it.
merger.config

LlamaConfig {
  "_name_or_path": "/workspace/logits-guided-merger/dev/hehe",
  "architectures": [
    "Merger"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 3072,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 24,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.3",
  "use_cache": false,
  "vocab_size": 128256
}

In [8]:
# Reload the exported model from the directory save_merged wrote to, so the
# comparison below exercises the artifact this notebook produces rather than a
# leftover scratch directory ("hehe") that no cell in the notebook creates.
# device_map=None leaves placement at the library default.
merged = AutoModelForCausalLM.from_pretrained(
    "/workspace/logits-guided-merger/results/run_01",
    torch_dtype=torch.bfloat16,
    device_map=None
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
# Check which concrete class AutoModelForCausalLM instantiated for the reloaded model.
type(merged)

transformers.models.llama.modeling_llama.LlamaForCausalLM

In [12]:
# Tokenizer of the first source model listed in the merger config.
tokenizer = AutoTokenizer.from_pretrained("/workspace/models/llama-3.2-3b-wizard")
# Single probe sentence fed to both models.
text = "Lee Min Ho is someone I don't trust."
# Run the same text through the merger's inner `merger` attribute and through the
# reloaded merged model.
# NOTE(review): get_logits is a project helper from utils; presumably it handles
# tokenization and device placement for each model — confirm.
merger_logits = get_logits(text, merger.merger, tokenizer)
merged_logits = get_logits(text, merged, tokenizer)

In [13]:
# Compare the two logit tensors using torch.allclose's default tolerances
# (no rtol/atol are passed); True means they agree within those tolerances.
torch.allclose(merger_logits, merged_logits)

True

In [15]:
# Display both logit tensors side by side for a visual comparison.
merger_logits, merged_logits

(tensor([[[ 5.0938,  7.5938, 12.3750,  ..., -5.3438, -5.3438, -5.3438],
          [ 5.7500,  4.1562,  1.7812,  ..., -5.1875, -5.1875, -5.1875],
          [ 5.6562,  4.6875,  4.0312,  ..., -4.3125, -4.3125, -4.3125],
          ...,
          [ 4.6250,  6.3438,  0.4102,  ..., -3.1719, -3.1719, -3.1719],
          [12.2500,  7.4375,  5.3125,  ..., -3.1562, -3.1562, -3.1562],
          [ 1.0938, -0.9844,  2.9062,  ..., -2.9688, -2.9688, -2.9688]]],
        dtype=torch.bfloat16),
 tensor([[[ 5.0938,  7.5938, 12.3750,  ..., -5.3438, -5.3438, -5.3438],
          [ 5.7500,  4.1562,  1.7812,  ..., -5.1875, -5.1875, -5.1875],
          [ 5.6562,  4.6875,  4.0312,  ..., -4.3125, -4.3125, -4.3125],
          ...,
          [ 4.6250,  6.3438,  0.4102,  ..., -3.1719, -3.1719, -3.1719],
          [12.2500,  7.4375,  5.3125,  ..., -3.1562, -3.1562, -3.1562],
          [ 1.0938, -0.9844,  2.9062,  ..., -2.9688, -2.9688, -2.9688]]],
        dtype=torch.bfloat16))