In [1]:
"""
Model merging training implementation using PyTorch and Transformers.
Implements custom data collation and training for merged language models.
"""

from dataclasses import dataclass
from typing import (
    Any, Callable, Dict, 
    List, NewType, Optional, 
    Tuple, Union, Mapping
)
from abc import ABC, abstractmethod
from datasets import load_dataset, concatenate_datasets
from accelerate.logging import get_logger
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import torch
import safetensors
import math
import yaml
import logging
import copy
import gc
import os
import argparse
import sys
import shutil

from tqdm import tqdm
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    HfArgumentParser,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)

from transformers.utils import CONFIG_NAME
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13

from merger import (
    MergerConfig,
    Merger,
    # NewMerger,
    init_masks,
    set_masks
)

from utils import (
    generate, 
    get_hidden_states, 
    get_logits,
    free_memory
)
# Configure logger
from logging_config import configure_logging
configure_logging()
logger = logging.getLogger("train")

In [8]:
MergerConfig.from_pretrained("../results/run_03/checkpoint-702/merger_config.json")

MergerConfig {
  "_attn_implementation_autoset": true,
  "constrain_mode": "identity",
  "mode": "vector_input",
  "model_paths": [
    "/workspace/models/llama-3.2-3b-wizard",
    "/workspace/models/llama-3.2-3b-math"
  ],
  "transformers_version": "4.46.3"
}

In [2]:
from tqdm.notebook import tqdm
from masks import LinearsWithMasks, RMSNormsWithMasks, EmbeddingsWithMasks

In [3]:
checkpoint_dir = "../results/run_01/checkpoint-200/"

In [4]:
from safetensors.torch import load_file as safe_load_file
masks_path = os.path.join(checkpoint_dir, "masks.safetensors")
state_dict = safe_load_file(masks_path)

In [5]:
merge_config = MergerConfig.from_pretrained(checkpoint_dir)
merge_config

MergerConfig {
  "architectures": [
    "NewMerger"
  ],
  "constrain_mode": "identity",
  "mode": "vector_input",
  "model_paths": [
    "/workspace/models/llama-3.2-3b-wizard",
    "/workspace/models/llama-3.2-3b-math"
  ],
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.3"
}

In [6]:
device_map={"":0}
merger = Merger.from_pretrained(
    checkpoint_dir,
    merge_config,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    # attn_implementation="flash_attention_2",
)

[2025-01-16 13:12:34,175] [INFO] [merger.from_pretrained:311] [PID:135561] [RANK:0] >>> Merger device: {'': 0}[39m
[2025-01-16 13:12:34,186] [INFO] [merger.__init__:205] [PID:135561] [RANK:0] Creating merger with dummy weights ...[39m


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks:   1%|█▎                                                                                                                                                                            | 2/255 [00:18<38:16,  9.08s/it]



Initializing masks: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [01:05<00:00,  3.92it/s]


In [9]:
merger.device

device(type='cpu')

In [17]:
merger = merger.to("cuda:0")

In [8]:
merger.save_pretrained("./hehe")

[2025-01-16 10:41:02,317] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-01-16 10:41:02,430] [INFO] [root.spawn:61] [PID:121910] [RANK:0] gcc -pthread -B /opt/conda/compiler_compat -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /opt/conda/include -fPIC -O2 -isystem /opt/conda/include -fPIC -c /tmp/tmpj3ei8cdl/test.c -o /tmp/tmpj3ei8cdl/test.o[39m
[2025-01-16 10:41:02,454] [INFO] [root.spawn:61] [PID:121910] [RANK:0] gcc -pthread -B /opt/conda/compiler_compat /tmp/tmpj3ei8cdl/test.o -laio -o /tmp/tmpj3ei8cdl/a.out[39m


/opt/conda/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


[2025-01-16 10:41:03,287] [INFO] [root.spawn:61] [PID:121910] [RANK:0] gcc -pthread -B /opt/conda/compiler_compat -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /opt/conda/include -fPIC -O2 -isystem /opt/conda/include -fPIC -c /tmp/tmpmvz3o6s6/test.c -o /tmp/tmpmvz3o6s6/test.o[39m
[2025-01-16 10:41:03,309] [INFO] [root.spawn:61] [PID:121910] [RANK:0] gcc -pthread -B /opt/conda/compiler_compat /tmp/tmpmvz3o6s6/test.o -L/usr/local/cuda -L/usr/local/cuda/lib64 -lcufile -o /tmp/tmpmvz3o6s6/a.out[39m
[2025-01-16 10:41:03,415] [INFO] [root.spawn:61] [PID:121910] [RANK:0] gcc -pthread -B /opt/conda/compiler_compat -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /opt/conda/include -fPIC -O2 -isystem /opt/conda/include -fPIC -c /tmp/tmp7mw2hk4j/test.c -o /tmp/tmp7mw2hk4j/test.o[39m
[2025-01-16 10:41:03,441] [INFO] [root.spawn:61] [PID:121910] [RANK:0] gcc -pthread -B /opt/conda/compiler_compat /tmp/tmp7mw2hk4j/test.o -laio -o /tmp/tmp7mw2hk4j/a.out[39m


/opt/conda/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


In [7]:
merger.save_merged("./hehe")

Merging masked modules: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:04<00:00, 62.32it/s]


[2025-01-16 13:13:54,659] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-01-16 13:13:54,784] [INFO] [root.spawn:61] [PID:135561] [RANK:0] gcc -pthread -B /opt/conda/compiler_compat -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /opt/conda/include -fPIC -O2 -isystem /opt/conda/include -fPIC -c /tmp/tmplh1mmkt3/test.c -o /tmp/tmplh1mmkt3/test.o[39m
[2025-01-16 13:13:54,808] [INFO] [root.spawn:61] [PID:135561] [RANK:0] gcc -pthread -B /opt/conda/compiler_compat /tmp/tmplh1mmkt3/test.o -laio -o /tmp/tmplh1mmkt3/a.out[39m


/opt/conda/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


[2025-01-16 13:13:55,856] [INFO] [root.spawn:61] [PID:135561] [RANK:0] gcc -pthread -B /opt/conda/compiler_compat -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /opt/conda/include -fPIC -O2 -isystem /opt/conda/include -fPIC -c /tmp/tmp108089z2/test.c -o /tmp/tmp108089z2/test.o[39m
[2025-01-16 13:13:55,885] [INFO] [root.spawn:61] [PID:135561] [RANK:0] gcc -pthread -B /opt/conda/compiler_compat /tmp/tmp108089z2/test.o -L/usr/local/cuda -L/usr/local/cuda/lib64 -lcufile -o /tmp/tmp108089z2/a.out[39m
[2025-01-16 13:13:55,994] [INFO] [root.spawn:61] [PID:135561] [RANK:0] gcc -pthread -B /opt/conda/compiler_compat -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /opt/conda/include -fPIC -O2 -isystem /opt/conda/include -fPIC -c /tmp/tmprpnrx6w9/test.c -o /tmp/tmprpnrx6w9/test.o[39m
[2025-01-16 13:13:56,019] [INFO] [root.spawn:61] [PID:135561] [RANK:0] gcc -pthread -B /opt/conda/compiler_compat /tmp/tmprpnrx6w9/test.o -laio -o /tmp/tmprpnrx6w9/a.out[39m


/opt/conda/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


In [10]:
from transformers import AutoModelForCausalLM
import torch
ckpt = "./hehe"
config = AutoConfig.from_pretrained(ckpt)
config.tie_word_embeddings = False
merged = AutoModelForCausalLM.from_pretrained(
    ckpt,
    config=config,
    torch_dtype=torch.bfloat16,
    device_map="cpu"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [14]:
merged.model.embed_tokens.weight

Parameter containing:
tensor([[ 0.0106,  0.0116,  0.0130,  ..., -0.0029, -0.0182,  0.0064],
        [ 0.0128,  0.0011,  0.0210,  ...,  0.0013,  0.0310, -0.0023],
        [ 0.0238,  0.0200,  0.0288,  ..., -0.0014, -0.0008, -0.0072],
        ...,
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066]],
       dtype=torch.bfloat16, requires_grad=True)

In [12]:
merged.lm_head.weight

Parameter containing:
tensor([[ 0.0115,  0.0122,  0.0132,  ..., -0.0026, -0.0195,  0.0071],
        [ 0.0138,  0.0011,  0.0212,  ...,  0.0012,  0.0332, -0.0026],
        [ 0.0256,  0.0212,  0.0291,  ..., -0.0013, -0.0009, -0.0079],
        ...,
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073]],
       dtype=torch.bfloat16, requires_grad=True)

In [16]:
merged = merged.to("cuda:0")

In [15]:
merged.lm_head.weight.data

tensor([[ 0.0115,  0.0122,  0.0132,  ..., -0.0026, -0.0195,  0.0071],
        [ 0.0138,  0.0011,  0.0212,  ...,  0.0012,  0.0332, -0.0026],
        [ 0.0256,  0.0212,  0.0291,  ..., -0.0013, -0.0009, -0.0079],
        ...,
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073]],
       dtype=torch.bfloat16)

In [15]:
from safetensors import safe_open
with safe_open("./hehe/model-00002-of-00002.safetensors", framework="pt", device="cpu") as f:
    lm_head = f.get_tensor("lm_head.weight")

SafetensorError: File does not contain tensor lm_head.weight

In [6]:
lm_head

NameError: name 'lm_head' is not defined

In [49]:
merged = merged.to(device="cuda:0", dtype=torch.bfloat16)

In [24]:
lm_head

NameError: name 'lm_head' is not defined

In [50]:
merged.lm_head.weight

Parameter containing:
tensor([[ 0.0115,  0.0122,  0.0132,  ..., -0.0026, -0.0195,  0.0071],
        [ 0.0138,  0.0011,  0.0212,  ...,  0.0012,  0.0332, -0.0026],
        [ 0.0256,  0.0212,  0.0291,  ..., -0.0013, -0.0009, -0.0079],
        ...,
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)

In [26]:
merged.model.layers[0].mlp.up_proj.weight

Parameter containing:
tensor([[-0.0120,  0.0148, -0.0029,  ...,  0.0238, -0.0004,  0.0036],
        [-0.0034,  0.0145, -0.0186,  ..., -0.0273, -0.0190,  0.0065],
        [-0.0107, -0.0087,  0.0237,  ...,  0.0056, -0.0081, -0.0094],
        ...,
        [-0.0273,  0.0020,  0.0060,  ..., -0.0080,  0.0002, -0.0166],
        [ 0.0193,  0.0087,  0.0069,  ...,  0.0099,  0.0134, -0.0220],
        [-0.0126,  0.0120, -0.0019,  ..., -0.0112, -0.0142, -0.0146]],
       dtype=torch.bfloat16, requires_grad=True)

In [27]:
merged.lm_head.weight

Parameter containing:
tensor([[ 0.0106,  0.0116,  0.0130,  ..., -0.0029, -0.0182,  0.0064],
        [ 0.0128,  0.0011,  0.0210,  ...,  0.0013,  0.0310, -0.0023],
        [ 0.0238,  0.0200,  0.0288,  ..., -0.0014, -0.0008, -0.0072],
        ...,
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066]],
       dtype=torch.bfloat16, requires_grad=True)

In [11]:
merged = merged.to("cuda:1")

In [19]:
merger.device

device(type='cuda', index=0)

In [19]:
tokenizer = AutoTokenizer.from_pretrained(merge_config.model_paths[0])

In [20]:
text = "Lee Min Ho is someone I don't trust."
merger_logits = get_logits(text, merger.merger, tokenizer)
merged_logits = get_logits(text, merged, tokenizer)

In [23]:
torch.allclose(merger_logits, merged_logits, atol=0, rtol=0)

True

In [25]:
def get_outputs(text, model, tokenizer):
    input_ids = tokenizer(text, return_tensors="pt").to(model.device)
    model.eval()
    with torch.no_grad():
        outputs = model(**input_ids, output_hidden_states=True)
    return outputs

In [26]:
text = "Lee Min Ho is someone I don't trust."
merger_outputs = get_outputs(text, merger.merger, tokenizer)
merged_outputs = get_outputs(text, merged, tokenizer)

In [27]:
outhead = merger.merger.lm_head
weight_masks = outhead.get_constrained_masks()["weight_masks"]
merged_outhead = sum((mask * linear.weight).to("cpu") for mask, linear in zip(weight_masks, outhead.linears))

In [28]:
outhead.get_constrained_masks()["bias_masks"]

[None, None]

In [29]:
merged_outhead

tensor([[ 0.0115,  0.0122,  0.0132,  ..., -0.0026, -0.0195,  0.0071],
        [ 0.0138,  0.0011,  0.0212,  ...,  0.0012,  0.0332, -0.0026],
        [ 0.0256,  0.0212,  0.0291,  ..., -0.0013, -0.0009, -0.0079],
        ...,
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073]],
       dtype=torch.bfloat16, grad_fn=<AddBackward0>)

In [30]:
merged.lm_head.weight.data

tensor([[ 0.0115,  0.0122,  0.0132,  ..., -0.0026, -0.0195,  0.0071],
        [ 0.0138,  0.0011,  0.0212,  ...,  0.0012,  0.0332, -0.0026],
        [ 0.0256,  0.0212,  0.0291,  ..., -0.0013, -0.0009, -0.0079],
        ...,
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073]],
       device='cuda:0', dtype=torch.bfloat16)

In [24]:
up_proj = merger.merger.model.layers[0].mlp.up_proj
weight_masks = up_proj.get_constrained_masks()["weight_masks"]
up_proj_a = sum((mask * linear.weight).to("cpu") for mask, linear in zip(weight_masks, up_proj.linears))
up_proj_merged = merged.model.layers[0].mlp.up_proj

In [26]:
up_proj_a, up_proj_merged.weight.data

(tensor([[-0.0120,  0.0148, -0.0029,  ...,  0.0238, -0.0004,  0.0036],
         [-0.0034,  0.0145, -0.0186,  ..., -0.0273, -0.0190,  0.0065],
         [-0.0107, -0.0087,  0.0237,  ...,  0.0056, -0.0081, -0.0094],
         ...,
         [-0.0273,  0.0020,  0.0060,  ..., -0.0080,  0.0002, -0.0166],
         [ 0.0193,  0.0087,  0.0069,  ...,  0.0099,  0.0134, -0.0220],
         [-0.0126,  0.0120, -0.0019,  ..., -0.0112, -0.0142, -0.0146]],
        dtype=torch.bfloat16, grad_fn=<AddBackward0>),
 tensor([[-0.0120,  0.0148, -0.0029,  ...,  0.0238, -0.0004,  0.0036],
         [-0.0034,  0.0145, -0.0186,  ..., -0.0273, -0.0190,  0.0065],
         [-0.0107, -0.0087,  0.0237,  ...,  0.0056, -0.0081, -0.0094],
         ...,
         [-0.0273,  0.0020,  0.0060,  ..., -0.0080,  0.0002, -0.0166],
         [ 0.0193,  0.0087,  0.0069,  ...,  0.0099,  0.0134, -0.0220],
         [-0.0126,  0.0120, -0.0019,  ..., -0.0112, -0.0142, -0.0146]],
        dtype=torch.bfloat16))

In [72]:
merger_outputs.hidden_states[-1]

tensor([[[ 0.3730, -0.4844,  0.3320,  ...,  0.5938,  1.2500,  0.9961],
         [-1.2969, -1.4219, -4.4062,  ..., -3.0781, -2.0156,  0.8320],
         [-1.8516, -0.3418, -0.3828,  ...,  1.4453, -2.2812, -0.0574],
         ...,
         [ 0.6016, -0.2266, -0.2129,  ..., -0.8672, -0.5391,  0.1826],
         [-1.0547, -0.4297,  3.0312,  ..., -0.7539, -1.3359, -0.7773],
         [-2.9219, -0.6211, -3.6250,  ..., -1.2812, -0.0154, -2.7500]]],
       device='cuda:0', dtype=torch.bfloat16)

In [73]:
merged_outputs.hidden_states[-1]

tensor([[[ 0.3730, -0.4844,  0.3320,  ...,  0.5938,  1.2500,  0.9961],
         [-1.2969, -1.4219, -4.4062,  ..., -3.0781, -2.0156,  0.8320],
         [-1.8516, -0.3418, -0.3828,  ...,  1.4453, -2.2812, -0.0574],
         ...,
         [ 0.6016, -0.2266, -0.2129,  ..., -0.8672, -0.5391,  0.1826],
         [-1.0547, -0.4297,  3.0312,  ..., -0.7539, -1.3359, -0.7773],
         [-2.9219, -0.6211, -3.6250,  ..., -1.2812, -0.0154, -2.7500]]],
       device='cuda:1', dtype=torch.bfloat16)

In [132]:
def save_merged(
    self
):
    """
    Compute merged weights using masks and component weights, then save to directory.
    Removes component weights and masks from the final state dict.
    """
    def compute(mask, weight):
        computed = mask * weight
        return computed

    def merge_linears(name, module):
        merged_state = {}
        keys_to_remove = set()
        for i in range(len(module.linears)):
            keys_to_remove.add(f"{name}.linears.{i}.weight")
            if module.linears[i].bias is not None:
                keys_to_remove.add(f"{name}.linears.{i}.bias")
            keys_to_remove.add(f"{name}.weight_masks.{i}.weight")
            if module.bias_masks[i] is not None:
                keys_to_remove.add(f"{name}.bias_masks.{i}.weight")
        
        # Get merged weights
        weight_masks = module.get_constrained_masks()["weight_masks"]
        merged_weight = sum(
            compute(mask, linear.weight)
            for mask, linear in zip(weight_masks, module.linears)
        ).cpu().detach()
        merged_state[f"{name}.weight"] = merged_weight
        
        # Get merged biases if they exist
        if hasattr(module, "bias_masks") and module.bias_masks[0] is not None:
            bias_masks = module.get_constrained_masks()["bias_masks"]
            merged_bias = sum(
                compute(mask, linear.bias) if linear.bias is not None else 0
                for mask, linear in zip(bias_masks, module.linears)
            )
            merged_state[f"{name}.bias"] = merged_bias
        return merged_state, keys_to_remove

    def merge_embeddings(name, module):
        merged_state = {}
        keys_to_remove = set()
        # Remove component embeddings and their masks
        for i in range(len(module.embeddings)):
            keys_to_remove.add(f"{name}.embeddings.{i}.weight")
            keys_to_remove.add(f"{name}.masks.{i}.weight")
        
        # Get merged weights
        masks = module.get_constrained_masks()["masks"]
        merged_weight = sum(
            compute(mask, emb.weight)
            for mask, emb in zip(masks, module.embeddings)
        ).cpu().detach()
        merged_state[f"{name}.weight"] = merged_weight
        return merged_state, keys_to_remove

    def merge_rmsnorms(name, module):
        merged_state = {}
        keys_to_remove = set()
        # Remove component norms and their masks
        for i in range(len(module.rms_norms)):
            keys_to_remove.add(f"{name}.rms_norms.{i}.weight")
            keys_to_remove.add(f"{name}.masks.{i}.weight")
        
        # Get merged weights
        masks = module.get_constrained_masks()["masks"]
        merged_weight = sum(
            compute(mask, norm.weight)
            for mask, norm in zip(masks, module.rms_norms)
        ).cpu().detach()
        merged_state[f"{name}.weight"] = merged_weight
        return merged_state, keys_to_remove

    # Initialization.
    merged_state = {}
    keys_to_remove = set()
    masked_modules = []
    for name, module in self.merger.named_modules():
        if any(mask_type in type(module).__name__ for mask_type in [
            "LinearsWithMasks", "EmbeddingsWithMasks", "RMSNormsWithMasks"
        ]):
            masked_modules.append((name, module))

    # Work, bitches. Mark component and mask keys for removal
    for name, module in tqdm(masked_modules, desc="Merging masked modules"):
        if isinstance(module, LinearsWithMasks):
            state, keys = merge_linears(name, module)
            if "lm_head" in name:
                logger.info(f"Dit con me may, {name}")
        elif isinstance(module, EmbeddingsWithMasks):
            state, keys = merge_embeddings(name, module)
        elif isinstance(module, RMSNormsWithMasks):
            state, keys = merge_rmsnorms(name, module)
        merged_state.update(state)
        keys_to_remove = keys_to_remove | keys

    # Copy over non-masked parameters
    full_state = self.merger.state_dict()
    keys_to_copy = set()
    
    for key, value in full_state.items():
        if any(remove_key in key for remove_key in keys_to_remove): continue
        if any(mask_key in key for mask_key in [
            "masks", "linears.", "embeddings.", "rms_norms."
        ]): continue
        keys_to_copy.add(key)

    if len(keys_to_copy) > 0:
        for key in tqdm(keys_to_copy, desc="Copying non-masked parameters"):
            merged_state[key] = full_state[key].to("cpu")

    
    # merged_model = AutoModelForCausalLM.from_pretrained(
    #     self.merger_config.model_paths[0],
    #     torch_dtype=self.merger.dtype,
    #     device_map="cpu"
    # )
    # merged_model.load_state_dict(merged_state)
    # merged_model.save_pretrained("./lala")
    return merged_state

In [183]:
model.tie_weights = False

In [179]:
model.config.tie_word_embeddings = False

In [133]:
merged_state = save_merged(merger)

Merging masked modules:   0%|          | 0/255 [00:00<?, ?it/s]

[2025-01-16 12:17:35,357] [INFO] [train.save_merged:90] [PID:126987] [RANK:0] Dit con me may, lm_head[39m


In [136]:
merged_state["model.embed_tokens.weight"]

tensor([[ 0.0106,  0.0116,  0.0130,  ..., -0.0029, -0.0182,  0.0064],
        [ 0.0128,  0.0011,  0.0210,  ...,  0.0013,  0.0310, -0.0023],
        [ 0.0238,  0.0200,  0.0288,  ..., -0.0014, -0.0008, -0.0072],
        ...,
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066]])

In [154]:
merged_state["lm_head.weight"]

tensor([[ 0.0115,  0.0122,  0.0132,  ..., -0.0026, -0.0195,  0.0071],
        [ 0.0138,  0.0011,  0.0212,  ...,  0.0012,  0.0332, -0.0026],
        [ 0.0256,  0.0212,  0.0291,  ..., -0.0013, -0.0009, -0.0079],
        ...,
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073]])

In [184]:
model.load_state_dict(merged_state)

<All keys matched successfully>

In [181]:
merged_state["lm_head.weight"] == merged_state["model.embed_tokens.weight"]

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [172]:
merged_state["model.embed_tokens.weight"]

tensor([[ 0.0106,  0.0116,  0.0130,  ..., -0.0029, -0.0182,  0.0064],
        [ 0.0128,  0.0011,  0.0210,  ...,  0.0013,  0.0310, -0.0023],
        [ 0.0238,  0.0200,  0.0288,  ..., -0.0014, -0.0008, -0.0072],
        ...,
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066]])

In [182]:
model.model.embed_tokens.weight.data_ptr() ==  model.lm_head.weight.data_ptr()

True

In [189]:
ckpt = "/workspace/models/llama-3.2-3b-wizard/"
config = AutoConfig.from_pretrained(ckpt)
config.tie_word_embeddings = False
mm = AutoModelForCausalLM.from_pretrained(ckpt, config=config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /workspace/models/llama-3.2-3b-wizard/ and are newly initialized: ['lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [192]:
mm.load_state_dict(merged_state)

<All keys matched successfully>

In [193]:
mm.model.embed_tokens.weight.data_ptr() ==  mm.lm_head.weight.data_ptr()

False

In [194]:
mm.model.embed_tokens.weight

Parameter containing:
tensor([[ 0.0106,  0.0116,  0.0130,  ..., -0.0029, -0.0182,  0.0064],
        [ 0.0128,  0.0011,  0.0210,  ...,  0.0013,  0.0310, -0.0023],
        [ 0.0238,  0.0200,  0.0288,  ..., -0.0014, -0.0008, -0.0072],
        ...,
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066]],
       requires_grad=True)

In [195]:
mm.lm_head.weight

Parameter containing:
tensor([[ 0.0115,  0.0122,  0.0132,  ..., -0.0026, -0.0195,  0.0071],
        [ 0.0138,  0.0011,  0.0212,  ...,  0.0012,  0.0332, -0.0026],
        [ 0.0256,  0.0212,  0.0291,  ..., -0.0013, -0.0009, -0.0079],
        ...,
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073]],
       requires_grad=True)

In [196]:
mm = mm.to("cuda:0")

In [197]:
text = "Lee Min Ho is someone I don't trust."
merger_logits = get_logits(text, merger.merger, tokenizer)
merged_logits = get_logits(text, mm, tokenizer)

In [198]:
merger_logits, merged_logits

(tensor([[[ 5.0938,  7.5938, 12.3125,  ..., -5.3125, -5.3125, -5.3125],
          [ 5.6875,  4.1562,  1.7500,  ..., -5.1875, -5.1875, -5.1875],
          [ 5.6875,  4.6875,  4.0625,  ..., -4.3125, -4.3125, -4.3125],
          ...,
          [ 4.6250,  6.3438,  0.4199,  ..., -3.1562, -3.1562, -3.1562],
          [12.2500,  7.4375,  5.3438,  ..., -3.1406, -3.1406, -3.1406],
          [ 1.1328, -0.9492,  2.9375,  ..., -2.9375, -2.9219, -2.9219]]],
        device='cuda:0'),
 tensor([[[ 5.0938,  7.5938, 12.3125,  ..., -5.3125, -5.3125, -5.3125],
          [ 5.6875,  4.1562,  1.7500,  ..., -5.1875, -5.1875, -5.1875],
          [ 5.6875,  4.6875,  4.0625,  ..., -4.3125, -4.3125, -4.3125],
          ...,
          [ 4.6250,  6.3438,  0.4199,  ..., -3.1562, -3.1562, -3.1562],
          [12.2500,  7.4375,  5.3438,  ..., -3.1406, -3.1406, -3.1406],
          [ 1.1328, -0.9492,  2.9375,  ..., -2.9375, -2.9219, -2.9219]]],
        device='cuda:0'))

In [None]:
mm.load_state_dict(merged_state, strict=False)

In [None]:
print(model.get_input_embeddings().weight.data_ptr() == model.lm_head.weight.data_ptr()) # Should print False

In [143]:
for name, param in model.named_parameters():
    param.data = merged_state[name]

In [147]:
model = model.to("cuda:0")

In [22]:
masks = [merger.merger.lm_head.weight_masks[i].weight.data for i in range(2)]
weights = [merger.merger.lm_head.linears[i].weight.data for i in range(2)]

In [23]:
sum(mask * weight for mask, weight in zip(masks, weights))

tensor([[ 0.0115,  0.0122,  0.0132,  ..., -0.0026, -0.0195,  0.0071],
        [ 0.0138,  0.0011,  0.0212,  ...,  0.0012,  0.0332, -0.0026],
        [ 0.0256,  0.0212,  0.0291,  ..., -0.0013, -0.0009, -0.0079],
        ...,
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073]],
       device='cuda:0', dtype=torch.bfloat16)

In [39]:
model = AutoModelForCausalLM.from_pretrained(
    "/workspace/models/llama-3.2-3b-wizard/",
    torch_dtype=torch.bfloat16,
    device_map=None
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [40]:
model.load_state_dict(merged_state)

<All keys matched successfully>

In [41]:
model = model.to("cuda:0")

In [26]:
X = torch.rand(3072, device="cuda:0", dtype=torch.bfloat16)

tensor([ 0.0447, -0.0486,  0.3223,  ..., -0.1196, -0.1187, -0.1196],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<SqueezeBackward4>)

In [64]:
merger.merger.model.embed_tokens

EmbeddingsWithMasks(
  (embeddings): ModuleList(
    (0-1): 2 x Embedding(128256, 3072)
  )
  (masks): ModuleList(
    (0-1): 2 x Mask(mask_mode=vector_input)
  )
  (masks_constrainer): Constrainer(constrain_mode=identity)
)

In [30]:
lin = nn.Linear(in_features=3072, out_features=128256, bias=False)
lin.weight.data = merged_state["lm_head.weight"].to(device="cuda:0", dtype=torch.bfloat16)

In [71]:
text = "hihi haha hoho"
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]

In [109]:
for k, v in merged_state.items():
    merged_state[k] = merged_state[k].detach()

In [127]:
merged_state["lm_head.weight"]

tensor([[ 0.0115,  0.0122,  0.0132,  ..., -0.0026, -0.0195,  0.0071],
        [ 0.0138,  0.0011,  0.0212,  ...,  0.0012,  0.0332, -0.0026],
        [ 0.0256,  0.0212,  0.0291,  ..., -0.0013, -0.0009, -0.0079],
        ...,
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073]])

In [113]:
merged_state["model.embed_tokens.weight"]

tensor([[ 0.0106,  0.0116,  0.0130,  ..., -0.0029, -0.0182,  0.0064],
        [ 0.0128,  0.0011,  0.0210,  ...,  0.0013,  0.0310, -0.0023],
        [ 0.0238,  0.0200,  0.0288,  ..., -0.0014, -0.0008, -0.0072],
        ...,
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066]])

In [125]:
model.model.embed_tokens.weight.data

tensor([[ 0.0115,  0.0122,  0.0132,  ..., -0.0026, -0.0195,  0.0071],
        [ 0.0138,  0.0011,  0.0212,  ...,  0.0012,  0.0332, -0.0026],
        [ 0.0256,  0.0212,  0.0291,  ..., -0.0013, -0.0009, -0.0079],
        ...,
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073]])

In [123]:
model = AutoModelForCausalLM.from_pretrained(
    "/workspace/models/llama-3.2-3b-math/"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [124]:
model.load_state_dict(merged_state)

<All keys matched successfully>

In [74]:
merger.merger.model.embed_tokens(input_ids.to(merger.device))

tensor([[[-0.0012, -0.0006, -0.0046,  ..., -0.0015, -0.0019,  0.0018],
         [-0.0371, -0.0003,  0.0327,  ..., -0.0513,  0.0242,  0.0132],
         [-0.0371, -0.0003,  0.0327,  ..., -0.0513,  0.0242,  0.0132],
         [-0.0133,  0.0154,  0.0067,  ..., -0.0052, -0.0019, -0.0182],
         [ 0.0217,  0.0108,  0.0002,  ...,  0.0023, -0.0049, -0.0134],
         [ 0.0110,  0.0327, -0.0332,  ...,  0.0342,  0.0166,  0.0005]]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)

In [77]:
model.model.embed_tokens(input_ids.to(model.device))

tensor([[[-0.0013, -0.0007, -0.0046,  ..., -0.0014, -0.0021,  0.0020],
         [-0.0400, -0.0003,  0.0332,  ..., -0.0459,  0.0259,  0.0145],
         [-0.0400, -0.0003,  0.0332,  ..., -0.0459,  0.0259,  0.0145],
         [-0.0143,  0.0161,  0.0067,  ..., -0.0047, -0.0021, -0.0201],
         [ 0.0234,  0.0114,  0.0002,  ...,  0.0021, -0.0053, -0.0148],
         [ 0.0118,  0.0344, -0.0337,  ...,  0.0304,  0.0178,  0.0005]]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)

In [95]:
import copy
copied_embeds = copy.deepcopy(model.model.embed_tokens)

In [98]:
copied_embeds.weight.data = merged_embeds

In [78]:
merger_embeds = merger.merger.model.embed_tokens

In [92]:
weight_masks = merger_embeds.get_constrained_masks()["masks"]
merged_embeds = sum((mask * linear.weight) for mask, linear in zip(weight_masks, merger_embeds.embeddings))

In [100]:
copied_embeds(input_ids.to(copied_embeds.weight.device))

tensor([[[-0.0012, -0.0006, -0.0046,  ..., -0.0015, -0.0019,  0.0018],
         [-0.0371, -0.0003,  0.0327,  ..., -0.0513,  0.0242,  0.0132],
         [-0.0371, -0.0003,  0.0327,  ..., -0.0513,  0.0242,  0.0132],
         [-0.0133,  0.0154,  0.0067,  ..., -0.0052, -0.0019, -0.0182],
         [ 0.0217,  0.0108,  0.0002,  ...,  0.0023, -0.0049, -0.0134],
         [ 0.0110,  0.0327, -0.0332,  ...,  0.0342,  0.0166,  0.0005]]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)

In [93]:
merged_embeds

tensor([[ 0.0106,  0.0116,  0.0130,  ..., -0.0029, -0.0182,  0.0064],
        [ 0.0128,  0.0011,  0.0210,  ...,  0.0013,  0.0310, -0.0023],
        [ 0.0238,  0.0200,  0.0288,  ..., -0.0014, -0.0008, -0.0072],
        ...,
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [42]:
a = merger.merger.lm_head(X)
b = lin(X)
c = model.lm_head(X)
sum(a - c)

tensor(0., device='cuda:0', grad_fn=<AddBackward0>)

In [28]:
merged_state["lm_head.weight"].shape

torch.Size([128256, 3072])

In [26]:
merged_model = save_merged(merger)

Merging masked modules:   0%|          | 0/255 [00:00<?, ?it/s]

[2025-01-16 11:02:13,619] [INFO] [train.save_merged:90] [PID:122984] [RANK:0] Dit con me may, lm_head[39m


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [34]:
merged_model.lm_head.weight.data

tensor([[ 0.0115,  0.0122,  0.0132,  ..., -0.0026, -0.0195,  0.0071],
        [ 0.0138,  0.0011,  0.0212,  ...,  0.0012,  0.0332, -0.0026],
        [ 0.0256,  0.0212,  0.0291,  ..., -0.0013, -0.0009, -0.0079],
        ...,
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073],
        [-0.0074,  0.0019,  0.0047,  ..., -0.0049, -0.0019, -0.0073]],
       dtype=torch.bfloat16)

In [67]:
merged_model.model.layers[2].mlp.up_proj.weight

Parameter containing:
tensor([[-0.0077,  0.0273, -0.0027,  ..., -0.0005,  0.0198, -0.0300],
        [ 0.0099, -0.0146, -0.0056,  ..., -0.0087,  0.0151,  0.0212],
        [ 0.0151,  0.0176,  0.0459,  ..., -0.0229,  0.0100, -0.0233],
        ...,
        [ 0.0039, -0.0030,  0.0223,  ...,  0.0173, -0.0176, -0.0020],
        [ 0.0193,  0.0197,  0.0047,  ...,  0.0123,  0.0146, -0.0078],
        [-0.0332,  0.0251, -0.0092,  ..., -0.0079,  0.0087, -0.0132]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)

In [36]:
merged_model = merged_model.to("cuda:0")

In [42]:
merger = merger.to("cuda:0")

In [150]:
text = "Lee Min Ho is someone I don't trust."
merger_logits = get_logits(text, merger.merger, tokenizer)
merged_logits = get_logits(text, model, tokenizer)

In [151]:
merger_logits, merged_logits

(tensor([[[ 5.0938,  7.5938, 12.3125,  ..., -5.3125, -5.3125, -5.3125],
          [ 5.6875,  4.1562,  1.7500,  ..., -5.1875, -5.1875, -5.1875],
          [ 5.6875,  4.6875,  4.0625,  ..., -4.3125, -4.3125, -4.3125],
          ...,
          [ 4.6250,  6.3438,  0.4199,  ..., -3.1562, -3.1562, -3.1562],
          [12.2500,  7.4375,  5.3438,  ..., -3.1406, -3.1406, -3.1406],
          [ 1.1328, -0.9492,  2.9375,  ..., -2.9375, -2.9219, -2.9219]]],
        device='cuda:0'),
 tensor([[[ 5.2188,  7.7188, 12.2500,  ..., -5.4062, -5.4062, -5.4062],
          [ 5.5938,  4.2188,  1.7422,  ..., -5.3125, -5.3125, -5.3125],
          [ 5.6875,  4.7812,  3.9844,  ..., -4.4375, -4.4375, -4.4375],
          ...,
          [ 4.5312,  6.2188,  0.4707,  ..., -3.2812, -3.2812, -3.2812],
          [12.1250,  7.3125,  5.3438,  ..., -3.3906, -3.3906, -3.3906],
          [ 1.1719, -0.7891,  2.9688,  ..., -3.2031, -3.2031, -3.2031]]],
        device='cuda:0'))

In [152]:
model.lm_head.weight.data

tensor([[ 0.0106,  0.0116,  0.0130,  ..., -0.0029, -0.0182,  0.0064],
        [ 0.0128,  0.0011,  0.0210,  ...,  0.0013,  0.0310, -0.0023],
        [ 0.0238,  0.0200,  0.0288,  ..., -0.0014, -0.0008, -0.0072],
        ...,
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066]],
       device='cuda:0')

In [153]:
model.model.embed_tokens.weight.data

tensor([[ 0.0106,  0.0116,  0.0130,  ..., -0.0029, -0.0182,  0.0064],
        [ 0.0128,  0.0011,  0.0210,  ...,  0.0013,  0.0310, -0.0023],
        [ 0.0238,  0.0200,  0.0288,  ..., -0.0014, -0.0008, -0.0072],
        ...,
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066],
        [-0.0068,  0.0018,  0.0047,  ..., -0.0054, -0.0018, -0.0066]],
       device='cuda:0')

In [18]:
from safetensors import safe_open
from safetensors.torch import save_file
save_directory = "./hihi"
logger.info("Integrety test")
sd = {}
signature = "lm_head.weight"
shard_paths = [f for f in os.listdir(save_directory) if f.endswith('.safetensors')]
for shard_path in sorted(shard_paths, key=lambda x: int(x.split('-')[1])):
    apath = os.path.join(save_directory, shard_path)
    with safe_open(apath, framework="pt", device="cpu") as f:
        for key in f.keys():
            if signature in key:
                sd[key] = f.get_tensor(key)
# torch.testing.assert_close(sd[signature], merged_state[signature], atol=0, rtol=0)

[2025-01-16 09:44:04,646] [INFO] [train.<module>:4] [PID:116815] [RANK:0] Integrety test[39m


In [47]:
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
)

def decoder_forward(
    decoder,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

    steps = {}
    # logger.warning(f"-------- Logging hidden_states in decoder forward:")
    residual = hidden_states
    # logger.warning(f" hidden_states step 1 (as input): {hidden_states}")
    steps.update({"step 1": hidden_states})

    hidden_states = decoder.input_layernorm(hidden_states)
    # logger.warning(f" hidden_states step 2 (after input_layernorm): {hidden_states}")
    steps.update({"step 2": hidden_states})
    # Self Attention
    hidden_states, self_attn_weights, present_key_value = decoder.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        position_embeddings=position_embeddings,
    )
    # logger.warning(f" hidden_states step 3 (after self_attn): {hidden_states}")
    steps.update({"step 3": hidden_states})
    
    hidden_states = residual + hidden_states
    # logger.warning(f" hidden_states step 4 (after first skip connection): {hidden_states}")
    steps.update({"step 4": hidden_states})
    # Fully Connected
    residual = hidden_states
    hidden_states = decoder.post_attention_layernorm(hidden_states)
    # logger.warning(f" hidden_states step 5 (after post_attention_layernorm): {hidden_states}")
    steps.update({"step 5": hidden_states})
    
    hidden_states = decoder.mlp(hidden_states)
    # logger.warning(f" hidden_states step 6 (after mlp): {hidden_states}")
    steps.update({"step 6": hidden_states})
    
    hidden_states = residual + hidden_states
    # logger.warning(f" hidden_states step 7 (after second skip connection): {hidden_states}")
    steps.update({"step 7": hidden_states})

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    if use_cache:
        outputs += (present_key_value,)

    outputs += (steps,)
    return outputs

def model_forward(text, model, tokenizer):
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    past_key_values = None
    cache_position = None
    position_ids = None
    output_hidden_states = True
    output_attentions = False
    use_cache = False
    return_dict = True
    #############
    
    model.eval()
    with torch.no_grad():

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        inputs_embeds = model.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = model._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = model.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None
        all_decoder_steps = ()

        for i, decoder_layer in enumerate(model.layers[:2]):   
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
          
            layer_outputs = decoder_forward(
                decoder_layer,
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

            hidden_states = layer_outputs[0]
            steps = layer_outputs[-1]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            all_decoder_steps += (steps,)

        hidden_states = model.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=(),
            hidden_states=all_hidden_states,
            attentions=all_decoder_steps,
        )

In [53]:
merged = merged.to("cuda:1")



In [48]:
text = "Lee Min Ho is someone I don't trust."
merger_outputs = model_forward(text, merger.merger.model, tokenizer)
merged_outputs = model_forward(text, model.model, tokenizer)

In [50]:
device="cuda:0"
for j, layer_output in enumerate(merger_outputs.attentions):
    other_output = merged_outputs.attentions[j]
    for i in range(7):
        key = f"step {i+1}"
        if torch.allclose(layer_output[key], other_output[key], atol=0, rtol=0):
            print(f"layer {j}, step {i+1} passed!")
        else:
            print(f"FAIL AT layer {j}, step {i+1}")

FAIL AT layer 0, step 1
FAIL AT layer 0, step 2
FAIL AT layer 0, step 3
FAIL AT layer 0, step 4
FAIL AT layer 0, step 5
FAIL AT layer 0, step 6
FAIL AT layer 0, step 7
FAIL AT layer 1, step 1
FAIL AT layer 1, step 2
FAIL AT layer 1, step 3
FAIL AT layer 1, step 4
FAIL AT layer 1, step 5
FAIL AT layer 1, step 6
FAIL AT layer 1, step 7


In [58]:
merger_outputs.attentions[0]['step 1']

tensor([[[-0.0012, -0.0006, -0.0046,  ..., -0.0015, -0.0019,  0.0018],
         [ 0.0060, -0.0327, -0.0078,  ...,  0.0146, -0.0061,  0.0006],
         [-0.0200, -0.0303, -0.0104,  ..., -0.0227, -0.0084,  0.0063],
         ...,
         [ 0.0139, -0.0105, -0.0154,  ..., -0.0018,  0.0156,  0.0078],
         [ 0.0010, -0.0153,  0.0620,  ...,  0.0317, -0.0289, -0.0111],
         [ 0.0092, -0.0027,  0.0277,  ...,  0.0127, -0.0079,  0.0093]]],
       device='cuda:0')

In [59]:
merged_outputs.attentions[0]['step 1']

tensor([[[-0.0013, -0.0007, -0.0046,  ..., -0.0014, -0.0021,  0.0020],
         [ 0.0065, -0.0344, -0.0079,  ...,  0.0131, -0.0065,  0.0007],
         [-0.0215, -0.0317, -0.0105,  ..., -0.0203, -0.0090,  0.0070],
         ...,
         [ 0.0150, -0.0110, -0.0156,  ..., -0.0016,  0.0167,  0.0085],
         [ 0.0011, -0.0161,  0.0625,  ...,  0.0283, -0.0310, -0.0123],
         [ 0.0099, -0.0029,  0.0281,  ...,  0.0114, -0.0085,  0.0103]]],
       device='cuda:0')

In [58]:
merger_outputs.last_hidden_state

tensor([[[ 1.8768e-03,  1.1658e-02,  2.3926e-01,  ...,  2.6172e-01,
           4.9072e-02,  1.3000e-02],
         [ 4.0234e-01, -9.9121e-02, -6.7188e+00,  ..., -1.3203e+00,
          -7.0703e-01,  4.3213e-02],
         [ 5.4688e-01, -3.2500e+00,  2.8281e+00,  ..., -3.2969e+00,
           1.0703e+00,  6.3281e-01],
         ...,
         [-6.2500e-01,  8.0859e-01, -5.3438e+00,  ...,  9.6875e-01,
           1.2500e+00, -8.3203e-01],
         [ 7.3438e-01, -1.6250e+00,  2.8281e+00,  ...,  2.5625e+00,
          -2.7656e+00, -5.4297e-01],
         [ 4.0039e-01,  2.7222e-02,  1.4766e+00,  ..., -9.5312e-01,
           8.7500e-01, -7.3047e-01]]], device='cuda:0', dtype=torch.bfloat16)

In [57]:
merged_outputs.last_hidden_state

tensor([[[ 1.8768e-03,  1.1658e-02,  2.3926e-01,  ...,  2.6172e-01,
           4.9072e-02,  1.3000e-02],
         [ 4.0234e-01, -9.9121e-02, -6.7188e+00,  ..., -1.3203e+00,
          -7.0703e-01,  4.3213e-02],
         [ 5.4688e-01, -3.2500e+00,  2.8281e+00,  ..., -3.2969e+00,
           1.0703e+00,  6.3281e-01],
         ...,
         [-6.2500e-01,  8.0859e-01, -5.3438e+00,  ...,  9.6875e-01,
           1.2500e+00, -8.3203e-01],
         [ 7.3438e-01, -1.6250e+00,  2.8281e+00,  ...,  2.5625e+00,
          -2.7656e+00, -5.4297e-01],
         [ 4.0039e-01,  2.7222e-02,  1.4766e+00,  ..., -9.5312e-01,
           8.7500e-01, -7.3047e-01]]], device='cuda:1', dtype=torch.bfloat16)

In [22]:
SET = set({3})
sa = set({1,2,3})
sb = set({1,2,4})
SET | sa | sb

{1, 2, 3, 4}

In [21]:
merger.save_merged = save_merged.__get__(merger)

In [23]:
merger.__class__

merger.NewMerger

In [22]:
merger.save_merged("./hehe")

Merging masked modules:   0%|          | 0/255 [00:00<?, ?it/s]

Copying non-masked parameters: 0it [00:00, ?it/s]

RuntimeError: super(): __class__ cell not found