In [1]:
"""
Model merging training implementation using PyTorch and Transformers.
Implements custom data collation and training for merged language models.
"""

from dataclasses import dataclass
from typing import (
    Any, Callable, Dict, 
    List, NewType, Optional, 
    Tuple, Union, Mapping
)
from abc import ABC, abstractmethod
from datasets import load_dataset, concatenate_datasets
from accelerate.logging import get_logger
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import torch
import safetensors
import math
import yaml
import logging
import copy
import gc
import os
import argparse
import sys
import shutil

from tqdm import tqdm
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    HfArgumentParser,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)

from transformers.utils import CONFIG_NAME
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13

from merger import (
    MergerConfig,
    Merger,
    # NewMerger,
    init_masks,
    set_masks
)

from utils import (
    generate, 
    get_hidden_states, 
    get_logits,
    free_memory
)
# Configure logger
from logging_config import configure_logging
configure_logging()
logger = logging.getLogger("train")

[2025-01-18 09:22:15,257] [INFO] [masks.<module>:57] [PID:124456] [RANK:0] --------- ACCURATE MASKS ----------[39m


## Save merged

In [2]:
checkpoint_dir = "../results/run_02b"
merger_config = MergerConfig.from_pretrained(
    checkpoint_dir,
    _configuration_file="merger_config.json"
)

In [3]:
merger_config

MergerConfig {
  "_attn_implementation_autoset": true,
  "constrain_mode": "identity",
  "mode": "vector_input",
  "model_paths": [
    "/workspace/models/llama-3.2-3b-wizard",
    "/workspace/models/llama-3.2-3b-math"
  ],
  "transformers_version": "4.46.3"
}

In [4]:
merger = Merger.from_pretrained(
    checkpoint_dir,
    torch_dtype=torch.bfloat16,
    device_map={"":3}
)

[2025-01-17 12:25:51,932] [INFO] [merger.__init__:222] [PID:3073] [RANK:0] Creating merger with dummy weights ...[39m


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks:   1%|█▎                                                                                                                                                                            | 2/255 [00:07<15:23,  3.65s/it]



Initializing masks: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:27<00:00,  9.17it/s]


[2025-01-17 12:26:26,128] [INFO] [merger.from_pretrained:401] [PID:3073] [RANK:0] Loaded masks from ../results/run_02b[39m


In [22]:
merger.save_pretrained("./haha")

In [5]:
merger.save_merged("./haha")

Merging masked modules: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:02<00:00, 122.62it/s]


In [7]:
merger.models[0].__class__.__name__

'LlamaForCausalLM'

In [6]:
merger.config

LlamaConfig {
  "_name_or_path": "/workspace/models/llama-3.2-3b-wizard",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 3072,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 24,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.3",
  "use_cache": false,
  "vocab_size": 128256
}

In [4]:
merger.merger.model.layers[0].mlp.up_proj.get_raw_masks()['weight_masks'][0][:100]

tensor([0.5078, 0.5078, 0.4883, 0.4434, 0.4922, 0.4902, 0.5039, 0.4902, 0.5117,
        0.4707, 0.4922, 0.4824, 0.5156, 0.5234, 0.4824, 0.5078, 0.5234, 0.5312,
        0.5195, 0.4863, 0.5039, 0.5195, 0.5078, 0.5273, 0.5156, 0.4961, 0.5117,
        0.5273, 0.5156, 0.4863, 0.5039, 0.4746, 0.5000, 0.4902, 0.5234, 0.4863,
        0.5000, 0.4727, 0.4883, 0.4766, 0.5156, 0.4980, 0.4707, 0.5195, 0.4922,
        0.4980, 0.4883, 0.4922, 0.4902, 0.5000, 0.4727, 0.4941, 0.5039, 0.4961,
        0.4785, 0.5000, 0.4512, 0.5000, 0.4883, 0.4883, 0.5078, 0.4863, 0.5078,
        0.4863, 0.4941, 0.5078, 0.4805, 0.5156, 0.4629, 0.4746, 0.4902, 0.5234,
        0.4883, 0.4375, 0.4844, 0.4941, 0.4707, 0.4688, 0.4961, 0.5234, 0.4961,
        0.5078, 0.5156, 0.5273, 0.5039, 0.4609, 0.5000, 0.4785, 0.5039, 0.5078,
        0.4902, 0.4531, 0.4980, 0.4922, 0.5117, 0.4883, 0.4766, 0.4902, 0.4824,
        0.5039], device='cuda:2', dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)

In [10]:
merger.merger.model.layers[0].mlp.up_proj.get_raw_masks()['weight_masks'][1][:100]

tensor([0.5078, 0.5078, 0.4883, 0.4453, 0.4941, 0.4902, 0.5039, 0.4961, 0.5117,
        0.4707, 0.4902, 0.4824, 0.5156, 0.5234, 0.4824, 0.5039, 0.5234, 0.5312,
        0.5234, 0.4863, 0.5039, 0.5195, 0.5078, 0.5273, 0.5156, 0.4961, 0.5117,
        0.5234, 0.5156, 0.4863, 0.5039, 0.4746, 0.5000, 0.4922, 0.5234, 0.4863,
        0.5000, 0.4727, 0.4883, 0.4785, 0.5156, 0.5000, 0.4707, 0.5195, 0.4922,
        0.4980, 0.4941, 0.4922, 0.4922, 0.5000, 0.4746, 0.4941, 0.5039, 0.4961,
        0.4766, 0.5000, 0.4551, 0.5000, 0.4883, 0.4883, 0.5078, 0.4844, 0.5078,
        0.4863, 0.4941, 0.5078, 0.4824, 0.5156, 0.4648, 0.4746, 0.4922, 0.5234,
        0.4922, 0.4375, 0.4844, 0.4941, 0.4727, 0.4688, 0.4961, 0.5234, 0.4961,
        0.5078, 0.5156, 0.5273, 0.5039, 0.4629, 0.5000, 0.4805, 0.5039, 0.5078,
        0.4922, 0.4551, 0.5000, 0.4941, 0.5117, 0.4883, 0.4766, 0.4922, 0.4922,
        0.5039], dtype=torch.bfloat16, grad_fn=<SliceBackward0>)

In [3]:
device_map={"":0}
merger = Merger.from_pretrained(
    checkpoint_dir,
    merge_config,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    # attn_implementation="flash_attention_2",
)

[2025-01-16 14:52:14,500] [INFO] [merger.from_pretrained:315] [PID:143678] [RANK:0] >>> Merger device: {'': 0}[39m
[2025-01-16 14:52:14,506] [INFO] [merger.__init__:207] [PID:143678] [RANK:0] Creating merger with dummy weights ...[39m


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks:   1%|█▎                                                                                                                                                                            | 2/255 [00:18<38:51,  9.22s/it]



Initializing masks: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [01:09<00:00,  3.69it/s]


In [16]:
merger.save_merged("/workspace/logits-guided-merger/results/run_01")

Merging masked modules: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:02<00:00, 96.46it/s]


In [17]:
tokenizer.save_pretrained("/workspace/logits-guided-merger/results/run_01")

('/workspace/logits-guided-merger/results/run_01/tokenizer_config.json',
 '/workspace/logits-guided-merger/results/run_01/special_tokens_map.json',
 '/workspace/logits-guided-merger/results/run_01/tokenizer.json')

In [8]:
merger.config

LlamaConfig {
  "_name_or_path": "/workspace/models/llama-3.2-3b-wizard",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 3072,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 24,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.3",
  "use_cache": false,
  "vocab_size": 128256
}

In [8]:
merged = AutoModelForCausalLM.from_pretrained(
    "hehe",
    torch_dtype=torch.bfloat16,
    device_map=None
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
type(merged)

transformers.models.llama.modeling_llama.LlamaForCausalLM

In [12]:
tokenizer = AutoTokenizer.from_pretrained("/workspace/models/llama-3.2-3b-wizard")
text = "Lee Min Ho is someone I don't trust."
merger_logits = get_logits(text, merger.merger, tokenizer)
merged_logits = get_logits(text, merged, tokenizer)

In [13]:
torch.allclose(merger_logits, merged_logits)

True

In [15]:
merger_logits, merged_logits

(tensor([[[ 5.0938,  7.5938, 12.3750,  ..., -5.3438, -5.3438, -5.3438],
          [ 5.7500,  4.1562,  1.7812,  ..., -5.1875, -5.1875, -5.1875],
          [ 5.6562,  4.6875,  4.0312,  ..., -4.3125, -4.3125, -4.3125],
          ...,
          [ 4.6250,  6.3438,  0.4102,  ..., -3.1719, -3.1719, -3.1719],
          [12.2500,  7.4375,  5.3125,  ..., -3.1562, -3.1562, -3.1562],
          [ 1.0938, -0.9844,  2.9062,  ..., -2.9688, -2.9688, -2.9688]]],
        dtype=torch.bfloat16),
 tensor([[[ 5.0938,  7.5938, 12.3750,  ..., -5.3438, -5.3438, -5.3438],
          [ 5.7500,  4.1562,  1.7812,  ..., -5.1875, -5.1875, -5.1875],
          [ 5.6562,  4.6875,  4.0312,  ..., -4.3125, -4.3125, -4.3125],
          ...,
          [ 4.6250,  6.3438,  0.4102,  ..., -3.1719, -3.1719, -3.1719],
          [12.2500,  7.4375,  5.3125,  ..., -3.1562, -3.1562, -3.1562],
          [ 1.0938, -0.9844,  2.9062,  ..., -2.9688, -2.9688, -2.9688]]],
        dtype=torch.bfloat16))

## SLERP

In [77]:
from safetensors import safe_open
from safetensors.torch import save_file
def load_tensors(path, signature=""):
    state_dict = {}
    shard_paths = [f for f in os.listdir(path) if (
        f.startswith("model") and f.endswith('.safetensors')
    )]
    for shard_path in sorted(shard_paths, key=lambda x: int(x.split('-')[1])):
        apath = os.path.join(path, shard_path)
        with safe_open(apath, framework="pt", device="cpu") as f:
            for key in f.keys():
                if signature in key:
                    state_dict[key] = f.get_tensor(key)
    return state_dict

In [80]:
sd1 = load_tensors(path="/workspace/models/baselines/acl-slerp/")

In [81]:
sd2 = load_tensors(path="/workspace/models/baselines/acl-slerp-custom/")

In [78]:
sd3 = load_tensors(path="/workspace/logits-guided-merger/dev/test-slerp/")

In [96]:
ok = []
not_ok = []
for key in list(sd1.keys()):
    if torch.allclose(sd1[key], sd3[key], atol=1e-5, rtol=1e-5):
        ok.append(key)
    else:
        not_ok.append(key)

In [105]:
k = not_ok[1]
torch.mean((torch.abs(sd1[k] - sd3[k]) == 0).float())

tensor(0.7127)

In [11]:
key = list(sd1.keys())[0]

In [14]:
sd1[key], sd2[key], torch.sum(sd1[key] - sd2[key])

(tensor([[ 0.0111,  0.0116,  0.0130,  ..., -0.0027, -0.0189,  0.0067],
         [ 0.0134,  0.0011,  0.0210,  ...,  0.0012,  0.0322, -0.0024],
         [ 0.0249,  0.0200,  0.0288,  ..., -0.0013, -0.0009, -0.0074],
         ...,
         [-0.0071,  0.0018,  0.0047,  ..., -0.0050, -0.0019, -0.0068],
         [-0.0071,  0.0018,  0.0047,  ..., -0.0050, -0.0019, -0.0068],
         [-0.0071,  0.0018,  0.0047,  ..., -0.0050, -0.0019, -0.0068]],
        dtype=torch.bfloat16),
 tensor([[ 0.0111,  0.0116,  0.0130,  ..., -0.0027, -0.0189,  0.0067],
         [ 0.0134,  0.0011,  0.0210,  ...,  0.0012,  0.0322, -0.0024],
         [ 0.0249,  0.0200,  0.0288,  ..., -0.0013, -0.0009, -0.0074],
         ...,
         [-0.0071,  0.0018,  0.0047,  ..., -0.0050, -0.0019, -0.0068],
         [-0.0071,  0.0018,  0.0047,  ..., -0.0050, -0.0019, -0.0068],
         [-0.0071,  0.0018,  0.0047,  ..., -0.0050, -0.0019, -0.0068]],
        dtype=torch.bfloat16),
 tensor(0., dtype=torch.bfloat16))

In [3]:
checkpoint_dir = "/workspace/logits-guided-merger/results/run_02b"
merger_config = MergerConfig.from_pretrained(
    checkpoint_dir,
    _configuration_file="merger_config.json"
)

In [6]:
merger_config

MergerConfig {
  "_attn_implementation_autoset": true,
  "constrain_mode": "identity",
  "mode": "vector_input",
  "model_paths": [
    "/workspace/models/llama-3.2-3b-wizard",
    "/workspace/models/llama-3.2-3b-math"
  ],
  "transformers_version": "4.46.3"
}

In [4]:
merger = Merger.from_pretrained(
    checkpoint_dir,
    torch_dtype=torch.bfloat16,
    device_map=None
)

[2025-01-18 09:23:29,144] [INFO] [merger.__init__:222] [PID:124456] [RANK:0] Creating merger with dummy weights ...[39m


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks:   1%|█▎                                                                                                                                                                            | 2/255 [00:07<15:21,  3.64s/it]



Initializing masks: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:29<00:00,  8.58it/s]


[2025-01-18 09:24:01,645] [INFO] [merger.from_pretrained:401] [PID:124456] [RANK:0] Loaded masks from /workspace/logits-guided-merger/results/run_02b[39m


In [5]:
merger.merger.lm_head.get_constrained_masks()

{'weight_masks': [Parameter containing:
  tensor([0.5156, 0.5117, 0.5078,  ..., 0.4590, 0.5234, 0.4902],
         dtype=torch.bfloat16, requires_grad=True),
  Parameter containing:
  tensor([0.5195, 0.5117, 0.5078,  ..., 0.4570, 0.5234, 0.4902],
         dtype=torch.bfloat16, requires_grad=True)],
 'bias_masks': [None, None]}

In [17]:
masks_weight = merger.get_masks_state_dict()

In [19]:
mask_init = {
  "strategy": "uniform",
  "factors": [0.5, 0.5]
}
set_masks(merger, mask_init)

[2025-01-18 09:28:02,907] [INFO] [merger.set_masks:170] [PID:124456] [RANK:0] Applying uniform masks with factors = [0.5, 0.5].[39m


Setting up masks: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:00<00:00, 14356.15it/s]


In [15]:
merger.load_masks(checkpoint_dir)

In [21]:
# masks_weight

In [22]:
mask_init = {
    "strategy": "spherical",
    "parameters": {
        "self_attn": [0, 0.3, 0.5, 0.7, 1],
        "mlp": [1, 0.7, 0.5, 0.3, 0],
        "default": 0.5
    }
}

In [71]:
import numpy as np
import re
def lerp(
    t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    return (1 - t) * v0 + t * v1


def slerp(
    t: Union[float, np.ndarray],
    v0: Union[np.ndarray, torch.Tensor],
    v1: Union[np.ndarray, torch.Tensor],
    DOT_THRESHOLD: float = 0.9995,
    eps: float = 1e-8,
):
    """
    Spherical linear interpolation

    From: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
    Args:
        t (float/np.ndarray): Float value between 0.0 and 1.0
        v0 (np.ndarray): Starting vector
        v1 (np.ndarray): Final vector
        DOT_THRESHOLD (float): Threshold for considering the two vectors as
                               colinear. Not recommended to alter this.
    Returns:
        v2 (np.ndarray): Interpolation vector between v0 and v1
    """
    is_torch = False
    if not isinstance(v0, np.ndarray):
        is_torch = True
        v0 = v0.detach().cpu().float().numpy()
    if not isinstance(v1, np.ndarray):
        is_torch = True
        v1 = v1.detach().cpu().float().numpy()

    # Copy the vectors to reuse them later
    v0_copy = np.copy(v0)
    v1_copy = np.copy(v1)

    # Normalize the vectors to get the directions and angles
    v0 = normalize(v0, eps)
    v1 = normalize(v1, eps)
    # import ipdb; ipdb.set_trace()

    # Dot product with the normalized vectors (can't use np.dot in W)
    dot = np.sum(v0 * v1)

    # If absolute value of dot product is almost 1, vectors are ~colinear, so use lerp
    if np.abs(dot) > DOT_THRESHOLD:
        s0, s1 = 1 - t, t
        return s0, s1
        # res = lerp(t, v0_copy, v1_copy)
        # return maybe_torch(res, is_torch)

    # Calculate initial angle between v0 and v1
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)

    # Angle at timestep t
    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)

    # Finish the slerp algorithm
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0

    return s0, s1
    # res = s0 * v0_copy + s1 * v1_copy

    # return maybe_torch(res, is_torch)


def maybe_torch(v: np.ndarray, is_torch: bool):
    if is_torch:
        return torch.from_numpy(v)
    return v


def normalize(v: np.ndarray, eps: float):
    norm_v = np.linalg.norm(v)
    if norm_v > eps:
        v = v / norm_v
    return v

In [52]:
from masks import Mask, MaskConfig
from masks import LinearsWithMasks, EmbeddingsWithMasks, RMSNormsWithMasks

def odd_one_out(masked_module: nn.Module, selected_idx: int):  
    assert selected_idx is not None and isinstance(selected_idx, int), (
        "Must provide valid model index. Check whether passed index is `int`"
    )
    masks_modules = []
    for name, child in masked_module.named_children():
        if not isinstance(child, nn.ModuleList): continue
        assert selected_idx < len(child), (
            f"There are only {len(child)} component models, "
            f"passed model index is {selected_idx}"
        )
        ## exclude sub_module that is None, aka bias_masks.
        if all(isinstance(sub_module, Mask) for sub_module in child):
            masks_modules.append(child)
        
    for masks in masks_modules:
        for i, mask in enumerate(masks):
            value = 1.0 if i == selected_idx else 0.0
            with torch.no_grad():
                mask.weight.data.fill_(value)

def random_init(masked_module: nn.Module):
    masks_modules = []
    for name, child in masked_module.named_children():
        if not isinstance(child, nn.ModuleList): continue
        ## exclude sub_module that is None, aka bias_masks.
        if all(isinstance(sub_module, Mask) for sub_module in child):
            masks_modules.append(child)
        
    for masks in masks_modules:
        for i, mask in enumerate(masks):
            with torch.no_grad():
                random_value = torch.rand_like(mask.weight.data)
                mask.weight.data = random_value
                
def uniform_init(masked_module: nn.Module, factors: List[float]):  
    masks_modules = []
    for name, child in masked_module.named_children():
        if not isinstance(child, nn.ModuleList): continue
        assert len(factors) == len(child), (
            f"There are {len(child)} component models, "
            f"but your passed factors have {len(factors)} values."
        )
        ## exclude sub_module that is None, aka bias_masks.
        if all(isinstance(sub_module, Mask) for sub_module in child):
            masks_modules.append(child)

    for masks in masks_modules:
        for factor, mask in zip(factors, masks):
            with torch.no_grad():
                mask.weight.data.fill_(factor)

In [66]:
def compute_t(weight_name, parameters, num_layers):
    """
    Computes the blending factor for a weight based on layer index and conditions.
    
    Args:
        weight_name (str): Name of the weight.
        parameters (dict): Mapping of conditions to blending values.
        num_layers (int): Total number of layers in the model.
        
    Returns:
        float: Computed blending value.
    """
    anchors = parameters.get("default")
    if not isinstance(anchors, list):
        anchors = [anchors]

    for filter_name in parameters.keys():
        if filter_name in weight_name:
            anchors = parameters.get(filter_name)
            break
            
    match = re.search(r"layers\.([^\.]*)\.", weight_name)
    if match:
        layer_idx = int(match.group(1))
        layer_t = layer_idx / (num_layers - 1)
        scaled = layer_t * (len(anchors) - 1)
        i0 = math.floor(scaled)
        i1 = min(len(anchors) - 1, i0 + 1)
        frac = scaled - i0
        
        blend_value = (1 - frac) * anchors[i0] + frac * anchors[i1]
    else:
        blend_value = anchors[0]
        
    return blend_value

def assign_spherical_masks(masks, s0, s1):
    assert len(masks) == 2, (
        "Spherical initialization only supports 2 models. "
        f"Found {len(masks)}."
    )
    with torch.no_grad():
        masks[0].weight.data.fill_(s0)
        masks[1].weight.data.fill_(s1)
        
def spherical_init(
    masked_module: nn.Module, 
    module_name: str,
    parameters: Mapping = None,
    num_layers: int  = None,
):      
    t = compute_t(module_name, parameters, num_layers)
    if isinstance(masked_module, LinearsWithMasks):
        weight_masks = masked_module.weight_masks
        bias_masks = masked_module.bias_masks
        v0, v1 = (x.weight.data for x in masked_module.linears)
        s0, s1 = slerp(t, v0, v1)
        assign_spherical_masks(weight_masks, s0, s1)
        
        if all(isinstance(mask, Mask) for mask in bias_masks):
            v0, v1 = (x.bias.data for x in masked_module.linears)
            s0, s1 = slerp(t, v0, v1)
            assign_spherical_masks(bias_masks, s0, s1)
        
    elif isinstance(masked_module, EmbeddingsWithMasks):
        masks = masked_module.masks
        v0, v1 = (x.weight.data for x in masked_module.embeddings)
        s0, s1 = slerp(t, v0, v1)
        assign_spherical_masks(masks, s0, s1)
        
    elif isinstance(masked_module, RMSNormsWithMasks):
        masks = masked_module.masks
        v0, v1 = (x.weight.data for x in masked_module.rms_norms)
        s0, s1 = slerp(t, v0, v1)
        assign_spherical_masks(masks, s0, s1)

    else:
        raise ValueError(
            f"Does not support class {type(masked_module).__name__} yet."
        )

In [67]:
def find_masked_modules(module):
    masked_module_names = []
    for parent_name, parent_module in module.named_modules():
        for name, child in parent_module.named_children():
            full_child_name = f"{parent_name}.{name}" if parent_name else name
            if ("WithMasks" in type(child).__name__):
                masked_module_names.append(full_child_name)

    return masked_module_names

INIT_MAP = dict(
    random=random_init,
    odd_one_out=odd_one_out,
    uniform=uniform_init,
    spherical=spherical_init
)

def init_(root_module, strategy="random", **kwargs):
    init_method = INIT_MAP[strategy]
    masked_module_names = find_masked_modules(root_module)
    
    for module_name in tqdm(masked_module_names, desc="Setting up masks"):
        module_names = module_name.split(".")
        target_module = root_module
        for m_name in module_names:
            target_module = getattr(target_module, m_name)

        if strategy == "spherical":
            kwargs["module_name"] = module_name
            
        init_method(target_module, **kwargs)

def initialize_masks(merger, mask_init):
    # Initialize masks based on config
    mask_strategy = mask_init["strategy"]
    if mask_strategy == "uniform":
        if not mask_init["factors"]:
            raise ValueError(
                "Factors must be provided for uniform strategy"
            )
        logger.info(f"Applying uniform masks with factors = {factors}.")
        factors = mask_init["factors"]
        init_(merger.merger, strategy="uniform", factors=factors)
        
    elif mask_strategy == "random":
        logger.info(f"Applying random masks.")
        init_(merger.merger, strategy="random")
        
    elif mask_strategy == "spherical":
        logger.info(f"Applying spherical masks.")
        parameters = mask_init["parameters"]
        num_layers = len(merger.merger.model.layers)
        init_(merger.merger, strategy="spherical", 
              parameters=parameters, num_layers=num_layers)
    else:
        raise ValueError(
            f"Unknown mask initialization strategy: {mask_strategy}."
        )

# _set_masks(merger.merger, strategy="uniform", factors=factors)

In [72]:
initialize_masks(merger, mask_init)

[2025-01-18 13:19:04,016] [INFO] [train.initialize_masks:50] [PID:124456] [RANK:0] Applying spherical masks.[39m


Setting up masks: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:24<00:00, 10.56it/s]


In [75]:
merger.save_merged("./test-slerp")

Merging masked modules: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:02<00:00, 96.26it/s]


In [30]:
module_name, target_module = _set_masks(merger.merger, strategy="uniform", factors=[0.5, 0.5])

Setting up masks:   0%|                                                                                                                                                                                        | 0/255 [00:00<?, ?it/s]


In [45]:
module_name, target_module

('lm_head',
 LinearsWithMasks(
   (linears): ModuleList(
     (0-1): 2 x Linear(in_features=3072, out_features=128256, bias=False)
   )
   (weight_masks): ModuleList(
     (0-1): 2 x Mask(mask_mode=vector_input)
   )
   (weight_masks_constrainer): Constrainer(constrain_mode=identity)
   (bias_masks): ModuleList(
     (0-1): 2 x None
   )
   (bias_masks_constrainer): Constrainer(constrain_mode=identity)
 ))

In [107]:
merger.get_masks_state_dict()

{'merger.model.embed_tokens.masks.0.weight': tensor([0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
        dtype=torch.bfloat16),
 'merger.model.embed_tokens.masks.1.weight': tensor([0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
        dtype=torch.bfloat16),
 'merger.model.layers.0.self_attn.q_proj.weight_masks.0.weight': tensor([1., 1., 1.,  ..., 1., 1., 1.], dtype=torch.bfloat16),
 'merger.model.layers.0.self_attn.q_proj.weight_masks.1.weight': tensor([0., 0., 0.,  ..., 0., 0., 0.], dtype=torch.bfloat16),
 'merger.model.layers.0.self_attn.k_proj.weight_masks.0.weight': tensor([1., 1., 1.,  ..., 1., 1., 1.], dtype=torch.bfloat16),
 'merger.model.layers.0.self_attn.k_proj.weight_masks.1.weight': tensor([0., 0., 0.,  ..., 0., 0., 0.], dtype=torch.bfloat16),
 'merger.model.layers.0.self_attn.v_proj.weight_masks.0.weight': tensor([1., 1., 1.,  ..., 1., 1., 1.], dtype=torch.bfloat16),
 'merger.model.layers.0.self_attn.v_proj.weight_masks.1.weight': tensor([0., 0., 0.,  .