In [1]:
import math
from typing import List, Optional, Tuple, Union
from abc import ABC, abstractmethod

import datasets
import torch
import numpy as np
import torch.nn as nn
import logging
import copy
import gc

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser
)

from modeling_qwen2 import (
    Qwen2RMSNorm, 
    Qwen2RotaryEmbedding, 
    Qwen2MLP, 
    Qwen2Attention, 
    Qwen2FlashAttention2, 
    Qwen2SdpaAttention, 
    Qwen2DecoderLayer, 
    Qwen2PreTrainedModel, 
    Qwen2Model, 
    Qwen2ForCausalLM,
)

from configuration_qwen2 import Qwen2Config

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
)

# Configure logger
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [3]:
torch.arccos(torch.tensor(1.0))

tensor(0.)

In [4]:
torch.sin(torch.tensor(0.0))

tensor(0.)

In [11]:
# in_features = 24
# out_features = 50
# lin = nn.Linear(in_features, out_features, bias=False)

In [13]:
# torch.manual_seed(42)
# mask_in = torch.rand(1, in_features) 
# mask_out = torch.rand(out_features, 1)

In [19]:
# (mask_out * lin.weight).shape == (mask_in * lin.weight).shape

True

In [2]:
from utils import are_tokenizers_same
# Sanity check before merging: compare the tokenizers of the two checkpoints.
# NOTE(review): both paths are absolute and machine-specific; consider moving
# them to a config cell so the notebook can be re-run elsewhere.
are_tokenizers_same(
    paths = [
        "/workspace/models/Arcee-VyLinh/",
        "/workspace/models/Qwen2.5-Coder-3B/"
    ]
)

2024-12-18 07:55:13,826 - INFO - Comparing tokenizer at /workspace/models/Arcee-VyLinh/ with tokenizer at /workspace/models/Qwen2.5-Coder-3B/
2024-12-18 07:55:13,829 - INFO - Tokenizer at /workspace/models/Arcee-VyLinh/ and /workspace/models/Qwen2.5-Coder-3B/ are the same based on the defined criteria


True

## unused utils

In [3]:
def init_factors(components: List[torch.Tensor], strategy="naive"):
    """Create one initial mixing factor per component.

    Args:
        components: The tensors (or ``None`` placeholders) that will be mixed.
            Only the number of entries is used.
        strategy: ``"naive"`` draws random non-negative factors that sum to 1.
            ``"slerp"`` is reserved for two-component spherical interpolation
            and is not implemented yet.

    Returns:
        A list with ``len(components)`` floats.

    Raises:
        ValueError: If ``"slerp"`` is requested for anything other than two
            components, or if the strategy is unknown / not implemented.
    """
    if strategy == "naive":
        random_floats = np.random.rand(len(components))
        # Normalise so the factors form a convex combination.
        return (random_floats / np.sum(random_floats)).tolist()

    if strategy == "slerp" and len(components) != 2:
        # Validate the arity first so the caller gets the most specific error;
        # previously this check sat after an unconditional raise and never ran.
        raise ValueError(f"Initialization strategy {strategy.upper()} only works for 2 components.")

    raise ValueError(f"Initialization strategy {strategy} has not been implemented.")

def find_modules_to_add_masks(target_module):
    """Collect dotted names of every submodule that should receive a mask.

    A child qualifies when it is an ``nn.Linear``, an ``nn.Embedding``, or
    when its class name contains ``"RMSNorm"``. Names are relative to
    ``target_module`` and listed in traversal order.
    """

    def _is_maskable(module):
        return isinstance(module, (nn.Linear, nn.Embedding)) or "RMSNorm" in type(module).__name__

    def _qualified(prefix, leaf):
        # The root module has an empty name, so its children need no prefix.
        return f"{prefix}.{leaf}" if prefix else leaf

    return [
        _qualified(owner_name, child_name)
        for owner_name, owner in target_module.named_modules()
        for child_name, child in owner.named_children()
        if _is_maskable(child)
    ]

def initialize_masks(target_module, ref_modules, strategy="naive"):
    """
    Replaces eligible submodules in target_module with masked versions,
    using corresponding modules from ref_modules as a reference for weights.

    Eligible submodules are the ones reported by find_modules_to_add_masks
    (nn.Linear, nn.Embedding and classes whose name contains "RMSNorm").
    The replacement is done in place on target_module; every mask is created
    in "scalar" mode.

    Args:
        target_module: The module in which to replace submodules.
        ref_modules: A list of modules to use as a reference for weights. Each
            must expose the same attribute paths as target_module.
        strategy: The initialization strategy for factors ("naive" or others to be implemented).
    """
    module_names_to_replace = find_modules_to_add_masks(target_module)

    for module_name in tqdm(module_names_to_replace, desc="Initializing masks"):
        *parent_path, leaf_name = module_name.split(".")

        # Walk to the owning module once, keeping the references in lockstep,
        # so the parent is already at hand when the child gets replaced.
        parent_module = target_module
        ref_parents = list(ref_modules)
        for m_name in parent_path:
            parent_module = getattr(parent_module, m_name)
            ref_parents = [getattr(ref_parent, m_name) for ref_parent in ref_parents]

        target_child = getattr(parent_module, leaf_name)
        ref_children = [getattr(ref_parent, leaf_name) for ref_parent in ref_parents]
        modes = ["scalar"] * len(ref_children)

        if isinstance(target_child, nn.Linear):
            weights = [ref.weight.data for ref in ref_children]
            biases = [ref.bias.data if ref.bias is not None else None for ref in ref_children]

            # Weight factors are drawn before bias factors (same order as the
            # random draws have always been made).
            new_module = LinearsWithMasks(
                linears=ref_children,
                weight_modes=modes,
                weight_values=init_factors(weights, strategy=strategy),
                bias_modes=modes,
                bias_values=init_factors(biases, strategy=strategy),
            )

        elif isinstance(target_child, nn.Embedding):
            weights = [ref.weight.data for ref in ref_children]
            factors = init_factors(weights, strategy=strategy)
            new_module = EmbeddingsWithMasks(ref_children, modes, factors)

        elif "RMSNorm" in type(target_child).__name__:
            weights = [ref.weight.data for ref in ref_children]
            factors = init_factors(weights, strategy=strategy)
            new_module = RMSNormsWithMasks(ref_children, modes, factors)

        else:
            # Never install an unbound (or stale, from a previous iteration)
            # replacement if the discovery filter and this dispatch diverge.
            logger.warning(
                f"Skipping {module_name}: unsupported module type {type(target_child).__name__}."
            )
            continue

        # Replace the original module with the new masked module
        setattr(parent_module, leaf_name, new_module)


def initialize_masks_recursive(target_module, ref_modules, strategy="naive"):
    """
    Recursively replaces normal components with masked components.

    Direct children of target_module that are nn.Linear, nn.Embedding, or whose
    class name contains "RMSNorm" are replaced in place by their masked
    counterparts (all masks in "scalar" mode); every other child is descended
    into.

    Args:
        target_module: The module in which to replace layers.
        ref_modules: A list of modules providing the reference weights. Each
            must expose the same child names as target_module.
        strategy: The initialization strategy forwarded to init_factors.
    """
    for name, target_child in target_module.named_children():
        ref_children = [getattr(module, name) for module in ref_modules]
        modes = ["scalar"] * len(ref_children)

        if isinstance(target_child, nn.Linear):
            weights = [ref.weight.data for ref in ref_children]
            biases = [ref.bias.data if ref.bias is not None else None
                      for ref in ref_children]

            # Weight factors are drawn before bias factors.
            weight_factors = init_factors(weights, strategy=strategy)
            bias_factors = init_factors(biases, strategy=strategy)

            new_module = LinearsWithMasks(
                linears=ref_children,
                weight_modes=modes,
                weight_values=weight_factors,
                bias_modes=modes,
                bias_values=bias_factors,
            )
            setattr(target_module, name, new_module)

        elif isinstance(target_child, nn.Embedding):
            weights = [ref.weight.data for ref in ref_children]
            factors = init_factors(weights, strategy=strategy)
            setattr(target_module, name, EmbeddingsWithMasks(
                ref_children, modes, factors
            ))

        elif "RMSNorm" in type(target_child).__name__:
            weights = [ref.weight.data for ref in ref_children]
            factors = init_factors(weights, strategy=strategy)
            setattr(target_module, name, RMSNormsWithMasks(
                ref_children, modes, factors
            ))

        else:
            # Container (or otherwise unsupported) module: look inside it.
            initialize_masks_recursive(target_child, ref_children, strategy=strategy)

In [16]:
def free_memory(logger=None):
    """Release cached GPU memory and report the allocation before and after.

    Args:
        logger: Optional object with an ``info`` method used for reporting.
                When omitted, logging is configured via ``logging.basicConfig``
                and this module's logger is used.
    """
    log = logger
    if log is None:
        # No logger supplied: fall back to a module-level one.
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        log = logging.getLogger(__name__)

    if not torch.cuda.is_available():
        log.info("CUDA is not available. No GPU memory to free.")
        return

    bytes_per_gb = 1024**3

    allocated_before = torch.cuda.memory_allocated()
    log.info(f"Initial GPU memory allocated: {allocated_before / bytes_per_gb:.2f} GB")

    # Drop unreachable Python objects first, then hand cached blocks back.
    gc.collect()
    torch.cuda.empty_cache()

    allocated_after = torch.cuda.memory_allocated()
    log.info(f"Final GPU memory allocated: {allocated_after / bytes_per_gb:.2f} GB")

    released = allocated_before - allocated_after
    log.info(f"Freed GPU memory: {released / bytes_per_gb:.2f} GB")

In [17]:
from transformers import GenerationConfig, TextStreamer
def generate(prompt, model, tokenizer, max_new_tokens=1024):
    """Generate a continuation of ``prompt`` and return it as a string.

    The generated tokens are streamed to stdout through a ``TextStreamer``
    while decoding; the returned text excludes the prompt, is cut at the first
    end-of-sequence token (when the tokenizer defines one), and is stripped.

    Args:
        prompt: The text to continue.
        model: A model exposing ``device``, ``eval()`` and ``generate()``.
        tokenizer: The tokenizer matching ``model``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation of the first (only) sequence.
    """
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)
    model.eval()
    with torch.no_grad():
        # NOTE(review): do_sample=False is set, so temperature/top_p are
        # presumably ignored by transformers — confirm before relying on them.
        generation_config = GenerationConfig(
            repetition_penalty=1.13,
            max_new_tokens=max_new_tokens,
            temperature=0.4,
            top_p=0.95,
            # top_k=20,
            # bos_token_id=tokenizer.bos_token_id,
            # eos_token_id=tokenizer.eos_token_id,
            # eos_token_id=0, # for open-end generation.
            pad_token_id=tokenizer.pad_token_id,
            do_sample=False,
            use_cache=True,
            return_dict_in_generate=True,
            output_attentions=False,
            output_hidden_states=False,
            output_scores=False,
        )
        streamer = TextStreamer(tokenizer, skip_prompt=True)
        generated = model.generate(
            inputs=input_ids,
            generation_config=generation_config,
            streamer=streamer,
        )

    # Keep only the newly generated tokens (everything after the prompt).
    prompt_length = input_ids.shape[1]
    gen_tokens = generated["sequences"].cpu()[:, prompt_length:]
    output = tokenizer.batch_decode(gen_tokens)[0]

    # str.split(None) splits on whitespace and str.split("") raises, so only
    # cut at the EOS marker when the tokenizer actually defines one.
    eos_token = tokenizer.eos_token
    if eos_token:
        output = output.split(eos_token)[0]
    return output.strip()

def get_logits(text, model, tokenizer):
    """Tokenize ``text``, run ``model`` in eval mode without gradients, and return its logits."""
    encoded = tokenizer(text, return_tensors="pt").to(model.device)
    model.eval()
    with torch.no_grad():
        return model(**encoded).logits

def get_hidden_states(text, model, tokenizer):
    """Run ``model`` on ``text`` with hidden states requested and the cache disabled.

    Returns the full model output object, not just the hidden states.
    """
    encoded = tokenizer(text, return_tensors="pt").to(model.device)
    model.eval()
    with torch.no_grad():
        return model(**encoded, output_hidden_states=True, use_cache=False)

## debugging utils

In [5]:
def init_input(text, model, tokenizer):
    """Build the keyword arguments a decoder layer needs for ``text``.

    Mirrors the start of the base model's forward pass (no KV cache, no
    attention outputs): embeds the tokens, derives positions from the sequence
    length, builds the causal mask and the rotary position embeddings.

    Returns:
        A dict with the keys ``hidden_states``, ``attention_mask``,
        ``position_ids``, ``past_key_value``, ``output_attentions``,
        ``use_cache``, ``cache_position`` and ``position_embeddings``.
    """
    encoded = tokenizer(text, return_tensors="pt").to(model.device)
    token_ids = encoded["input_ids"]
    padding_mask = encoded["attention_mask"]

    # Fixed settings for this debugging helper: nothing cached, no attentions.
    kv_cache = None
    want_attentions = False
    want_cache = False

    model.eval()
    with torch.no_grad():
        embeddings = model.embed_tokens(token_ids)

        # With no cache, positions simply start at zero.
        seq_len = embeddings.shape[1]
        positions = torch.arange(0, seq_len, device=embeddings.device)
        batched_positions = positions.unsqueeze(0)

        mask = model._update_causal_mask(
            padding_mask, embeddings, positions, kv_cache, want_attentions
        )

        # create position embeddings to be shared across the decoder layers
        rotary = model.rotary_emb(embeddings, batched_positions)

    return {
        "hidden_states": embeddings,
        "attention_mask": mask,
        "position_ids": batched_positions,
        "past_key_value": kv_cache,
        "output_attentions": want_attentions,
        "use_cache": want_cache,
        "cache_position": positions,
        "position_embeddings": rotary,
    }

In [6]:
def mlp_forward(mlp, x: torch.Tensor):
    """Run a gated MLP step by step and keep every intermediate tensor.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    Returns:
        A dict with ``outputs`` (the MLP result) and ``debugging`` (a dict
        mapping a step label to the tensor produced at that step).
    """
    gate = mlp.gate_proj(x)
    up = mlp.up_proj(x)
    activated = mlp.act_fn(gate)      # activation is applied to the gate branch only
    gated = activated * up            # element-wise gating of the up projection
    down = mlp.down_proj(gated)

    trace = {
        "step 0 (input)": x,
        "step 1 (gate)": gate,
        "step 2 (up)": up,
        "step 3 (activation)": activated,
        "step 4 (act_up)": gated,
        "step 5 (down - output)": down,
    }
    return {"outputs": down, "debugging": trace}

# def mlp_forward(self, hidden_state):
#     return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))

In [7]:
def decoder_forward(
    decoder,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """Run one decoder layer step by step, recording each intermediate tensor.

    Steps recorded under ``debugging``:
        step 1: layer input
        step 2: after ``input_layernorm``
        step 3: after ``self_attn``
        step 4: after the first residual connection
        step 5: after ``post_attention_layernorm``
        step 6: after ``mlp``
        step 7: after the second residual connection (layer output)

    Returns:
        A dict with ``outputs`` (a tuple: hidden states, then attention weights
        if ``output_attentions``, then the present key/value if ``use_cache``)
        and ``debugging`` (the step dict described above).
    """
    trace = {"step 1": hidden_states}

    # --- self-attention sub-block ---
    attn_skip = hidden_states
    normed = decoder.input_layernorm(hidden_states)
    trace["step 2"] = normed

    attn_out, attn_weights, present_kv = decoder.self_attn(
        hidden_states=normed,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        position_embeddings=position_embeddings,
    )
    trace["step 3"] = attn_out

    after_attn = attn_skip + attn_out
    trace["step 4"] = after_attn

    # --- feed-forward sub-block ---
    mlp_skip = after_attn
    post_normed = decoder.post_attention_layernorm(after_attn)
    trace["step 5"] = post_normed

    mlp_out = decoder.mlp(post_normed)
    trace["step 6"] = mlp_out

    layer_out = mlp_skip + mlp_out
    trace["step 7"] = layer_out

    extras = ()
    if output_attentions:
        extras += (attn_weights,)
    if use_cache:
        extras += (present_kv,)

    return {"outputs": (layer_out,) + extras, "debugging": trace}

In [8]:
def model_forward(text, model, tokenizer, max_layers: Optional[int] = 2):
    """Run the first decoder layers of ``model`` on ``text`` with full tracing.

    Re-implements the base model's forward pass (no KV cache, no attention
    outputs) on top of ``decoder_forward`` so that the intermediate tensors of
    every visited layer can be inspected.

    Args:
        text: The input text.
        model: The base (non-LM-head) model exposing ``embed_tokens``,
            ``_update_causal_mask``, ``rotary_emb``, ``layers`` and ``norm``.
        tokenizer: The tokenizer matching ``model``.
        max_layers: How many leading decoder layers to run. Defaults to 2
            (the previously hard-coded value); ``None`` runs every layer.

    Returns:
        A dict with ``outputs`` (a ``BaseModelOutputWithPast`` holding the
        final normalised hidden state and the hidden state fed to each visited
        layer) and ``debugging`` (a dict mapping ``"layer {i}"`` to the step
        dict produced by ``decoder_forward`` for that layer).
    """
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Fixed settings for this debugging helper.
    past_key_values = None
    output_attentions = False
    use_cache = False

    model.eval()
    with torch.no_grad():
        inputs_embeds = model.embed_tokens(input_ids)

        # Without a cache the positions simply start at zero.
        cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)

        causal_mask = model._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = model.rotary_emb(hidden_states, position_ids)

        all_hidden_states = ()
        all_decoder_steps = {}

        for i, decoder_layer in enumerate(model.layers[:max_layers]):
            all_hidden_states += (hidden_states,)

            layer_result = decoder_forward(
                decoder_layer,
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

            # decoder_forward returns a dict, not a tuple: the layer's tuple
            # lives under "outputs" and its trace under "debugging".
            hidden_states = layer_result["outputs"][0]
            all_decoder_steps[f"layer {i}"] = layer_result["debugging"]

        hidden_states = model.norm(hidden_states)

        # add hidden states from the last decoder layer
        all_hidden_states += (hidden_states,)

        outputs = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=(),
            hidden_states=all_hidden_states,
            attentions=()
        )
        return dict(
            outputs=outputs,
            debugging=all_decoder_steps
        )

## modeling

In [6]:
plan = """
- add vector mask: DONE
- add slerp mask: need to add a _constrain() function, only need to precompute theta between tensors.
- other constraints can also be 
"""

In [32]:
# # torch.ones(1, 7)
# norm = Qwen2RMSNorm(5)
# norm.weight * torch.rand(norm.weight.shape)

tensor([0.6524, 0.6057, 0.3725, 0.7980, 0.8399], grad_fn=<MulBackward0>)

In [35]:
# in_features = 100
# out_features = 200
# lin = nn.Linear(in_features, out_features, bias=False)
# lin.weight.shape[1] == in_features

True

In [37]:
# torch.ones(1, 3).requires_grad

False

In [2]:
class MaskConfig(PretrainedConfig):
    """Configuration describing a single mask.

    Attributes are stored as given:
        mode: Mask kind; the ``Mask`` class in this notebook accepts
            "scalar", "vector_input" and "vector_output".
        value: Initial mask value (float or tensor); ``None`` lets the mask
            pick its own default.
        size: Shape of the parameter the mask will be applied to.

    NOTE(review): tensor ``value`` / ``torch.Size`` ``size`` may not survive
    the config's JSON serialisation — confirm before saving this config.
    """
    def __init__(
        self,
        mode: str = None,
        value: Union[float, torch.Tensor] = None,
        size: torch.Size = None,
        **kwargs,
    ):
        self.mode = mode
        self.value = value
        self.size = size
        # Remaining keyword arguments are handled by PretrainedConfig.
        super().__init__(**kwargs)

class Mask(nn.Module):
    """Learnable multiplicative mask applied to a parameter tensor.

    Supported modes (``mask_config.mode``):
        "scalar":        one learnable number scaling the whole tensor.
        "vector_output": one number per entry along dim 0 (shape ``(out, 1)``
                         for 2-D sizes, ``(out,)`` otherwise).
        "vector_input":  one number per entry along the last dim.

    ``mask_config.value`` may be ``None`` (mask initialised to ones), a tensor
    (used as the initial mask), or a plain number (in vector modes every mask
    entry is initialised to that number).

    Raises:
        AssertionError: If ``mask_config.size`` is ``None``, or if applying the
            mask would change the shape of a tensor of that size.
        ValueError: If the mode is unknown or the mask cannot be broadcast
            against ``mask_config.size``.
    """

    def __init__(self, mask_config: "MaskConfig"):
        super().__init__()
        self.config = mask_config
        self.size = mask_config.size
        assert self.size is not None, "Mask size must be specified."

        value = mask_config.value
        mode = mask_config.mode
        if mode == "scalar":
            if isinstance(value, torch.Tensor):
                # Detached copy; avoids torch.tensor(tensor)'s copy-construct warning.
                initial = value.detach().clone()
            else:
                initial = torch.tensor(value if value is not None else 1.0)
            self.weight = nn.Parameter(initial)
        elif mode in ("vector_input", "vector_output"):
            ones = self._get_ones(mode)
            if value is None:
                initial = ones
            elif isinstance(value, torch.Tensor):
                initial = value
            else:
                # A plain number (e.g. a factor from init_factors) fills the vector.
                initial = ones * float(value)
            self.weight = nn.Parameter(initial)
        else:
            raise ValueError(f"Unsupported mask mode: {mask_config.mode}")

        self._check_shape_compatibility()

    def _size_tuple(self) -> tuple:
        """Return ``self.size`` as a plain tuple (accepts int, list, tuple, torch.Size)."""
        if isinstance(self.size, int):
            return (self.size,)
        return tuple(self.size)

    def _get_ones(self, mode: str) -> torch.Tensor:
        """Generates a tensor of ones based on mode and size."""
        dim = 0 if mode == "vector_output" else -1
        features = self.size[dim]
        if len(self.size) == 2 and mode == "vector_output":
            # Column vector so it broadcasts across the input dimension.
            return torch.ones(features, 1)
        else:
            return torch.ones(features)

    def _check_shape_compatibility(self):
        """Raises ValueError if the mask shape is incompatible with its size.

        Works on shapes only: no tensor of the full parameter size is
        allocated and the global random state is left untouched.
        """
        target = self._size_tuple()
        try:
            result = torch.broadcast_shapes(tuple(self.weight.shape), target)
        except RuntimeError as err:
            raise ValueError("Mask initialized with an incompatible shape.") from err
        assert tuple(result) == target, (
            "After applying mask, the shape of input weight does not stay the same."
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied (with broadcasting) by the mask."""
        if self._size_tuple() != tuple(x.shape):
            logger.warning("Warning: Input shape does not match mask shape.")
        return x * self.weight

class ModuleWithMask(nn.Module, ABC):
    """Abstract base for a single module wrapped together with its mask(s)."""

    def __init__(self, *args, **kwargs):
        # Arguments are accepted for subclass convenience and ignored here.
        super().__init__()

    @abstractmethod
    def forward(self, x):
        """Apply the masked module to ``x``; concrete subclasses must override."""

class ModulesWithMasks(nn.Module, ABC):
    """Abstract base for a group of masked modules that are merged on the fly."""

    def __init__(self, *args, **kwargs):
        # Arguments are accepted for subclass convenience and ignored here.
        super().__init__()

    @abstractmethod
    def forward(self, x):
        """Apply the merged masked modules to ``x``; concrete subclasses must override."""

In [101]:
# sum([torch.rand(8, 1)] * 3)
List

typing.List

In [3]:
class Constrainer(nn.Module):
    """Base class for constraints over a group of masked modules (placeholder)."""

    def __init__(self, modules):
        # `modules` is required by the signature but not stored at this level.
        super(Constrainer, self).__init__()

    def forward(self, modules):
        """No-op in the base class."""
        return None

class SphericalConstrainer(Constrainer):
    """Work-in-progress constraint for spherical (SLERP-style) merging of two modules.

    Args:
        modules: Exactly two ``ModuleWithMask`` instances.
        mode: Mask mode the constraint operates on ("scalar" or "vector_input").
        DOT_THRESHOLD: Stored on the instance; not used yet.
        eps: Stored on the instance; not used yet.

    Raises:
        AssertionError: If the arguments violate the constraints above.
        NotImplementedError: Always, for now — ``compute_dot`` has not been
            written, so the constrainer cannot be constructed yet.
    """

    def __init__(
        self,
        modules: List[ModuleWithMask],
        mode: str,
        DOT_THRESHOLD: float = 0.9995,
        eps: float = 1e-8,
    ):
        # The base class requires `modules`; calling super().__init__() without
        # it raised TypeError before any validation could run.
        super().__init__(modules)
        assert len(modules) == 2, "Spherical Constrainer only supports 2 modules."
        assert all([isinstance(module, ModuleWithMask) for module in modules]), (
            "All modules should have a mask already."
        )
        assert mode in ("scalar", "vector_input"), (
            "Now only supports masks with scalar and vector_input mode."
        )
        self.mode = mode
        self.DOT_THRESHOLD = DOT_THRESHOLD
        self.eps = eps
        # TODO: weights and biases currently go through the same call; they
        # will need separate dot products once compute_dot exists.
        self.weight_dot = self.compute_dot(modules, mode)
        self.bias_dot = self.compute_dot(modules, mode)
        self.weight_theta = ...
        self.bias_theta = ...

    def compute_dot(self, modules, mode):
        """Placeholder for the dot product between the two modules' tensors.

        TODO: implement. Fails explicitly instead of with an AttributeError
        for a missing method.
        """
        raise NotImplementedError(
            "SphericalConstrainer.compute_dot has not been implemented yet."
        )

    def forward(self, modules):
        # assert all("WithMask" in type(module).__name__ for module in modules)
        pass

In [4]:
class LinearWithMask(ModuleWithMask):
    """An ``nn.Linear`` whose weight (and optionally bias) is scaled by a mask.

    Args:
        linear: The wrapped linear layer.
        weight_mask_config: Config for the weight mask. Its ``size`` is
            overwritten (with a warning) if it differs from the weight shape.
        bias_mask_config: Config for the bias mask. A bias mask is created
            only when the layer has a bias and this config is given.
    """

    def __init__(
        self,
        linear: nn.Linear,
        weight_mask_config: MaskConfig,
        bias_mask_config: MaskConfig = None
    ):
        super().__init__()
        self.linear = linear
        self.weight_mask_config = weight_mask_config
        self.bias_mask_config = bias_mask_config

        self.weight_mask = self._make_mask(
            linear.weight,
            weight_mask_config,
            "Weight mask shape is not compatible with linear, reinitializing...",
        )

        wants_bias_mask = linear.bias is not None and bias_mask_config is not None
        self.bias_mask = (
            self._make_mask(
                linear.bias,
                bias_mask_config,
                "Bias mask shape is not compatible with linear, reinitializing...",
            )
            if wants_bias_mask
            else None
        )

    @staticmethod
    def _make_mask(param, mask_config, mismatch_message):
        """Create a mask for ``param`` and move it to the parameter's device/dtype."""
        if param.shape != mask_config.size:
            logger.warning(mismatch_message)
            mask_config.size = param.shape
        mask = Mask(mask_config)
        ## make sure things on the same page.
        mask.to(device=param.device, dtype=param.dtype)
        return mask

    def forward(self, x):
        """Apply the linear layer using the masked weight and (if masked) bias."""
        weight = self.weight_mask(self.linear.weight)
        bias = self.linear.bias
        if bias is not None and self.bias_mask is not None:
            bias = self.bias_mask(bias)
        return nn.functional.linear(x, weight, bias)

class LinearsWithMasks(ModulesWithMasks):
    """Several ``nn.Linear`` layers merged on the fly as a masked sum.

    The effective weight is the sum of every layer's masked weight; the
    effective bias is the sum of the (masked) biases, with missing biases
    treated as zero.

    Args:
        linears: The linear layers to merge.
        weight_modes: One mask mode per layer, or a single-element list that
            is applied to every layer.
        weight_values: One initial mask value per layer (required).
        bias_modes: Same convention as ``weight_modes``, for the biases.
        bias_values: One initial mask value per layer; ``None`` means no
            explicit value for any layer.

    Raises:
        ValueError: If ``linears`` is empty or contains non-``nn.Linear``
            entries, or if any of the per-layer lists does not line up with
            ``linears``.
    """

    def __init__(
        self,
        linears: List[nn.Linear],
        weight_modes: List[str] = ["scalar"],
        weight_values: List[float] = None,
        bias_modes: List[str] = ["scalar"],
        bias_values: List[float] = None,
    ):
        super().__init__()

        if not linears:
            raise ValueError("'linears' must contain at least one nn.Linear.")
        if not all(isinstance(linear, nn.Linear) for linear in linears):
            raise ValueError("All elements in 'linears' must be instances of nn.Linear.")

        weight_sizes = [linear.weight.shape for linear in linears]
        bias_sizes = [linear.bias.shape if linear.bias is not None else None for linear in linears]

        if weight_values is None or len(weight_values) != len(linears):
            raise ValueError(f"weight_values for masks: {weight_values} do not match with linear layers: {linears}")
        if bias_values is None:
            bias_values = [None] * len(linears)
        if len(bias_values) != len(linears):
            raise ValueError(f"bias_values for masks: {bias_values} do not match with linear layers: {linears}")

        # zip() stops at the shortest input, so a short modes list (such as the
        # single-element default) used to silently drop every layer after the
        # first. Expand or validate the modes before zipping.
        weight_modes = self._expand_modes(weight_modes, len(linears), "weight_modes")
        bias_modes = self._expand_modes(bias_modes, len(linears), "bias_modes")

        weight_mask_configs = [
            MaskConfig(mode, value, size)
            for mode, value, size in zip(weight_modes, weight_values, weight_sizes)
        ]
        bias_mask_configs = [
            MaskConfig(mode, value, size) if size is not None else None
            for mode, value, size in zip(bias_modes, bias_values, bias_sizes)
        ]

        self.masked_linears = nn.ModuleList([
            LinearWithMask(linear, weight_mask_config, bias_mask_config)
            for linear, weight_mask_config, bias_mask_config
            in zip(linears, weight_mask_configs, bias_mask_configs)
        ])

    @staticmethod
    def _expand_modes(modes, count, label):
        """Return a fresh list of ``count`` modes, repeating a single entry if needed."""
        modes = list(modes)
        if len(modes) == 1:
            return modes * count
        if len(modes) != count:
            raise ValueError(
                f"{label} for masks: {modes} do not match with the number of linear layers: {count}"
            )
        return modes

    def _init_manager(self):
        """
        TODO.
        """
        # self.manager = MasksManager(self.masked_linears, strategy="...")
        pass

    def forward(self, x):
        """Apply the merged (masked-sum) linear transformation to ``x``."""
        # self.manager(self.masked_linears)

        weights = [linear.weight_mask(linear.linear.weight)
                   for linear in self.masked_linears]
        merged_weight = sum(weights)

        biases = [
            linear.bias_mask(linear.linear.bias)
            if linear.linear.bias is not None and linear.bias_mask is not None
            else linear.linear.bias for linear in self.masked_linears
        ]

        if all(b is None for b in biases):
            merged_bias = None
        else:
            # Layers without a bias contribute zeros of the output size.
            biases = [
                b if b is not None
                else torch.zeros_like(weights[0][:, 0])
                for b in biases
            ]
            merged_bias = sum(biases)

        return nn.functional.linear(x, merged_weight, merged_bias)

In [105]:
# Scratch check: two same-shaped linears merged with per-output-row masks.
# Passing None as a value lets each mask fall back to its default (ones).
lin1 = nn.Linear(4, 9)
lin2 = nn.Linear(4, 9)

lins = LinearsWithMasks(
    linears = [lin1, lin2],
    weight_modes = ["vector_output"] * 2,
    weight_values = [None, None],
    bias_modes = ["vector_output"] * 2,
    bias_values = [None, None]
)

In [11]:
# lins(torch.rand(4))

In [106]:
x = lins.masked_linears[0].weight_mask.weight.data
xxx = [torch.rand_like(x) for _ in range(3)]

In [57]:
yyy = [t / sum(xxx) for t in xxx]

In [79]:
# yyy

In [5]:
class RMSNormWithMask(ModuleWithMask):
    """An RMSNorm layer whose weight is scaled by a scalar mask.

    Args:
        rms_norm: A module whose class name contains "RMSNorm"; its ``weight``
            and ``variance_epsilon`` attributes are used.
        mask_config: Config for the mask. Its ``mode`` is forced to "scalar"
            and its ``size`` to the norm's weight shape (each with a warning
            when it has to be changed).
    """

    def __init__(self, rms_norm: nn.Module, mask_config: MaskConfig):
        super().__init__()
        assert "RMSNorm" in type(rms_norm).__name__
        self.rms_norm = rms_norm
        self.mask_config = mask_config
        if mask_config.mode != "scalar":
            # The two message halves were previously joined without a space
            # ("...mode vector_inputfor a RMSNorm's...").
            logger.warning_once(
                f"Though you want to make a mask of mode {mask_config.mode} "
                "for a RMSNorm's weights, by default it only accepts a scalar mask."
            )
            self.mask_config.mode = "scalar"
        if mask_config.size != rms_norm.weight.shape:
            logger.warning_once("Mask shape is not compatible with RMSNorm, reinitializing...")
            self.mask_config.size = rms_norm.weight.shape

        self.mask = Mask(self.mask_config)

        ## make sure things on the same page.
        self.mask.to(
            device=self.rms_norm.weight.device,
            dtype=self.rms_norm.weight.dtype
        )

    def forward(self, hidden_states):
        """RMS-normalise ``hidden_states`` and scale by the masked weight."""
        masked_weight = self.mask(self.rms_norm.weight)
        input_dtype = hidden_states.dtype
        # Normalise in float32, then cast back to the input dtype.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.rms_norm.variance_epsilon)
        return masked_weight * hidden_states.to(input_dtype)

class RMSNormsWithMasks(ModulesWithMasks):
    """Several RMSNorm layers merged on the fly as a masked sum of their weights.

    Args:
        rms_norms: The RMSNorm layers to merge.
        modes: One mask mode per layer, or a single-element list that is
            applied to every layer.
        values: One initial mask value per layer (required).

    Raises:
        ValueError: If ``rms_norms`` is empty, or if ``values`` / ``modes`` do
            not line up with ``rms_norms``.
    """

    def __init__(
        self,
        rms_norms: List[nn.Module],
        modes: List[str] = ["scalar"],
        values: List[float] = None
    ):
        super().__init__()
        if not rms_norms:
            raise ValueError("'rms_norms' must contain at least one RMSNorm layer.")
        sizes = [rms_norm.weight.shape for rms_norm in rms_norms]
        if values is None or len(values) != len(rms_norms):
            raise ValueError(f"values for masks: {values} do not match with RMSNorm layers: {rms_norms}")

        # zip() stops at the shortest input, so a short modes list (such as the
        # single-element default) used to silently drop every layer after the
        # first. Expand or validate the modes before zipping.
        modes = list(modes)
        if len(modes) == 1:
            modes = modes * len(rms_norms)
        elif len(modes) != len(rms_norms):
            raise ValueError(f"modes for masks: {modes} do not match with RMSNorm layers: {rms_norms}")

        mask_configs = [
            MaskConfig(mode, value, size)
            for mode, value, size in zip(modes, values, sizes)
        ]
        self.masked_rms_norms = nn.ModuleList(
            [RMSNormWithMask(rms_norm, mask_config)
             for rms_norm, mask_config in zip(rms_norms, mask_configs)]
        )

    def forward(self, hidden_states):
        """RMS-normalise ``hidden_states`` and scale by the merged masked weight."""
        weights = [rms.mask(rms.rms_norm.weight) for rms in self.masked_rms_norms]
        merged_weight = sum(weights)
        # All wrapped norms must share one epsilon for the merge to be defined.
        variance_epsilon = self.masked_rms_norms[0].rms_norm.variance_epsilon
        for rms in self.masked_rms_norms:
            assert variance_epsilon == rms.rms_norm.variance_epsilon, (
                "Variance epsilon among models must be consistent"
            )
        input_dtype = hidden_states.dtype
        # Normalise in float32, then cast back to the input dtype.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
        return merged_weight * hidden_states.to(input_dtype)

In [91]:
# norm1 = Qwen2RMSNorm(12)
# norm2 = Qwen2RMSNorm(12)
# norms = RMSNormsWithMasks(
#     rms_norms=[norm1, norm2],
#     modes=["vector_output"] * 2,
#     values=[None] * 2
# )

In [92]:
# norms.masked_rms_norms[0].mask.weight

Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True)

In [94]:
# x = torch.rand(12)
# norms(x)

tensor([0.3550, 1.5445, 1.7343, 1.0201, 2.8520, 2.8410, 2.6321, 2.8295, 1.4801,
        1.7208, 1.2704, 1.8799], grad_fn=<MulBackward0>)

In [95]:
# norm1.weight

Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True)

In [6]:
class EmbeddingWithMask(ModuleWithMask):
    """
    Pair a single ``nn.Embedding`` with a ``Mask`` applied to its weight.

    The mask is built from ``mask_config``; if the configured size differs
    from the embedding weight's shape, the config's size is overwritten with
    that shape (a warning is logged) before the mask is created.
    """

    def __init__(self, embedding: nn.Embedding, mask_config: MaskConfig):
        super().__init__()
        self.embedding = embedding
        self.mask_config = mask_config

        table_shape = embedding.weight.shape
        if table_shape != mask_config.size:
            logger.warning_once("Mask shape is not compatible with Embedding, reinitializing...")
            self.mask_config.size = table_shape

        self.mask = Mask(self.mask_config)

        # Keep the mask on the same device and dtype as the embedding weight.
        table = self.embedding.weight
        self.mask.to(device=table.device, dtype=table.dtype)

    def forward(self, input_ids):
        """Look up ``input_ids`` in the masked embedding weight."""
        wrapped = self.embedding
        return nn.functional.embedding(
            input_ids,
            self.mask(wrapped.weight),
            padding_idx=wrapped.padding_idx,
            max_norm=wrapped.max_norm,
            norm_type=wrapped.norm_type,
            scale_grad_by_freq=wrapped.scale_grad_by_freq,
            sparse=wrapped.sparse,
        )

class EmbeddingsWithMasks(ModulesWithMasks):
    """
    Blend several embedding layers into a single lookup.

    Each embedding is wrapped in an ``EmbeddingWithMask``. ``forward`` sums the
    masked weights of all wrappers and performs one embedding lookup with the
    resulting weight.

    Args:
        embeddings: Embedding layers to merge.
        modes: Mask mode per embedding (forwarded to ``MaskConfig``). A single
            entry (or a bare string) is shared by every embedding.
        values: Initial mask value per embedding (forwarded to ``MaskConfig``);
            must contain exactly one entry per embedding.

    Raises:
        ValueError: If ``values`` or ``modes`` do not line up with
            ``embeddings``.
    """

    # Lookup options that must agree across all merged embeddings.
    _LOOKUP_OPTIONS = ("padding_idx", "max_norm", "norm_type", "scale_grad_by_freq", "sparse")

    def __init__(
        self,
        embeddings: List[nn.Embedding],
        modes: List[str] = ["scalar"],
        values: List[float] = None
    ):
        super().__init__()
        sizes = [embedding.weight.shape for embedding in embeddings]
        if values is None or len(values) != len(embeddings):
            raise ValueError(f"values for masks: {values} do not match with Embedding layers: {embeddings}")

        # Work on a copy so the (shared) default list is never touched, and let
        # one mode stand for all embeddings. Without this, zip() below would
        # silently truncate and drop every embedding after the first.
        modes = [modes] if isinstance(modes, str) else list(modes)
        if len(modes) == 1:
            modes = modes * len(embeddings)
        if len(modes) != len(embeddings):
            raise ValueError(f"modes for masks: {modes} do not match with Embedding layers: {embeddings}")

        mask_configs = [
            MaskConfig(mode, value, size)
            for mode, value, size in zip(modes, values, sizes)
        ]
        self.masked_embeddings = nn.ModuleList(
            [EmbeddingWithMask(embedding, mask_config)
             for embedding, mask_config in zip(embeddings, mask_configs)]
        )

    def forward(self, input_ids):
        """Look up ``input_ids`` in the sum of all masked embedding weights."""
        weights = [emb.mask(emb.embedding.weight) for emb in self.masked_embeddings]
        merged_weight = sum(weights)

        # All embeddings must agree on the lookup options, since only one set
        # of options can be used for the merged lookup.
        an_embedding = self.masked_embeddings[0].embedding
        for other in self.masked_embeddings:
            for option in self._LOOKUP_OPTIONS:
                assert getattr(an_embedding, option) == getattr(other.embedding, option), (
                    f"Embedding option '{option}' must be identical across merged models"
                )

        return nn.functional.embedding(
            input_ids,
            merged_weight,
            padding_idx=an_embedding.padding_idx,
            max_norm=an_embedding.max_norm,
            norm_type=an_embedding.norm_type,
            scale_grad_by_freq=an_embedding.scale_grad_by_freq,
            sparse=an_embedding.sparse,
        )

In [7]:
def find_modules_to_add_masks(target_module):
    """
    Return the dotted names of all submodules that should receive masks.

    A submodule qualifies when it is an ``nn.Linear`` or ``nn.Embedding``, or
    when its class name contains "RMSNorm". Names are listed parent by parent,
    in ``named_modules()`` order.
    """
    def _is_maskable(module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            return True
        return "RMSNorm" in type(module).__name__

    return [
        f"{parent_name}.{child_name}" if parent_name else child_name
        for parent_name, parent in target_module.named_modules()
        for child_name, child in parent.named_children()
        if _is_maskable(child)
    ]

def init_masks(target_module, ref_modules, mode="vector_input"):
    """
    Replace eligible submodules of ``target_module`` with masked, merged versions.

    Every submodule reported by ``find_modules_to_add_masks`` is swapped, in
    place, for a ``LinearsWithMasks`` / ``EmbeddingsWithMasks`` /
    ``RMSNormsWithMasks`` built from the submodules found under the same
    dotted name in each reference module.

    Args:
        target_module: The module whose submodules are replaced (modified in place).
        ref_modules: Modules that supply the weights; each must expose the
            same submodule names as ``target_module``.
        mode: Mask mode used for every created mask (for linears, both the
            weight mask and the bias mask). Mask values are left as ``None``.

    Raises:
        TypeError: If a selected submodule is none of the supported kinds.
    """
    module_names_to_replace = find_modules_to_add_masks(target_module)

    for module_name in tqdm(module_names_to_replace, desc="Initializing masks"):
        *parent_path, leaf_name = module_name.split(".")

        # Walk to the parent once, following the same path in every reference.
        parent_module = target_module
        ref_children = list(ref_modules)
        for m_name in parent_path:
            parent_module = getattr(parent_module, m_name)
            ref_children = [getattr(ref_module, m_name) for ref_module in ref_children]

        target_child = getattr(parent_module, leaf_name)
        ref_children = [getattr(ref_module, leaf_name) for ref_module in ref_children]

        modes = [mode for _ in ref_children]
        factors = [None for _ in ref_children]

        if isinstance(target_child, nn.Linear):
            new_module = LinearsWithMasks(
                linears=ref_children,
                weight_modes=modes,
                weight_values=factors,
                bias_modes=modes,
                bias_values=factors,
            )
        elif isinstance(target_child, nn.Embedding):
            new_module = EmbeddingsWithMasks(ref_children, modes, factors)
        elif "RMSNorm" in type(target_child).__name__:
            new_module = RMSNormsWithMasks(ref_children, modes, factors)
        else:
            # Fail loudly instead of reusing the module built in a previous
            # iteration (or hitting an unbound name on the first one).
            raise TypeError(
                f"Cannot add masks to '{module_name}' of type {type(target_child).__name__}."
            )

        # Replace the original module with the new masked module
        setattr(parent_module, leaf_name, new_module)

In [50]:
# Scratch check: does a tensor created by torch.rand_like track gradients?
torch.rand_like(torch.rand(10, 2)).requires_grad

False

In [8]:
def _normalize_across_components(samples):
    """Scale ``{component index: tensor}`` draws so they sum to 1 element-wise."""
    if not samples:
        return {}
    total = sum(samples.values())
    degenerate = total == 0
    if degenerate.any():
        # Every component drew exactly 0 at these positions (possible in low
        # precision). Use equal shares there instead of dividing by zero.
        samples = {
            idx: torch.where(degenerate, torch.ones_like(draw), draw)
            for idx, draw in samples.items()
        }
        total = sum(samples.values())
    return {idx: draw / total for idx, draw in samples.items()}


def random_init(module_name, masked_module, **kwargs):
    """
    Randomize the masks of one merged module so that they sum to 1.

    A uniform random tensor is drawn for every component mask; the draws are
    then normalized element-wise across components, separately for weight
    masks and for bias masks, and written into the masks.

    Args:
        module_name: Dotted name of the merged module (unused; kept so all
            init strategies share one signature).
        masked_module: A ``*WithMasks`` container whose first child holds the
            per-model ``*WithMask`` components.
        **kwargs: Ignored.

    Raises:
        AssertionError: If a component's class name does not contain "WithMask".
        ValueError: If a component is not a Linear/RMSNorm/Embedding mask wrapper.
    """
    module_list = next(masked_module.children())

    # Draws are keyed by component index so that a component without a bias
    # mask cannot shift the bias draws of the components after it.
    weight_draws = {}
    bias_draws = {}

    ## RANDOMIZING MASKS.
    with torch.no_grad():
        for idx, component in enumerate(module_list):
            kind = type(component).__name__
            assert "WithMask" in kind, (
                f"{type(component).__name__} module does not have masks."
            )
            if kind == LinearWithMask.__name__:
                child_names = [name for name, _ in component.named_children()]
                weight_draws[idx] = torch.rand_like(component.weight_mask.weight.data)
                if "bias_mask" in child_names:
                    bias_draws[idx] = torch.rand_like(component.bias_mask.weight.data)
            elif kind in (RMSNormWithMask.__name__, EmbeddingWithMask.__name__):
                weight_draws[idx] = torch.rand_like(component.mask.weight.data)
            else:
                raise ValueError(f"{type(component).__name__} module does not have masks.")

    ## NORMALIZING MASKS AND ASSIGNING THEM
    weight_masks = _normalize_across_components(weight_draws)
    bias_masks = _normalize_across_components(bias_draws)

    with torch.no_grad():
        for idx, component in enumerate(module_list):
            if type(component).__name__ == LinearWithMask.__name__:
                component.weight_mask.weight.data = weight_masks[idx]
                if idx in bias_masks:
                    component.bias_mask.weight.data = bias_masks[idx]
            else:
                # Only RMSNorm/Embedding wrappers reach here; anything else
                # was rejected in the first loop.
                component.mask.weight.data = weight_masks[idx]

def odd_one_out(module_name, masked_module, **kwargs):
    """
    Select exactly one component: its masks become 1.0, all others 0.0.

    Args:
        module_name: Dotted name of the merged module (unused; kept so all
            init strategies share one signature).
        masked_module: A ``*WithMasks`` container whose first child holds the
            per-model ``*WithMask`` components.
        **kwargs: Must contain ``selected_idx``, the non-negative integer
            index of the component to keep.

    Raises:
        AssertionError: If ``selected_idx`` is missing, ``None``, not an int,
            or outside ``[0, number of components)``; or if a component's
            class name does not contain "WithMask".
        ValueError: If a component is not a Linear/RMSNorm/Embedding mask wrapper.
    """
    assert "selected_idx" in kwargs
    selected_idx = kwargs["selected_idx"]
    module_list = next(masked_module.children())

    assert selected_idx is not None, "Must provide index."
    assert isinstance(selected_idx, int), "Index must be int."
    # A negative index would match no component and zero out every mask.
    assert 0 <= selected_idx < len(module_list), "Out of index."

    for i, component in enumerate(module_list):
        assert "WithMask" in type(component).__name__, (
            f"{type(component).__name__} module does not have masks."
        )
        value = 1.0 if i == selected_idx else 0.0
        with torch.no_grad():
            if type(component).__name__ == LinearWithMask.__name__:
                child_names = [name for name, _ in component.named_children()]
                component.weight_mask.weight.data.fill_(value)
                if "bias_mask" in child_names:
                    component.bias_mask.weight.data.fill_(value)
            elif type(component).__name__ in (
                RMSNormWithMask.__name__, EmbeddingWithMask.__name__
            ):
                component.mask.weight.data.fill_(value)
            else:
                raise ValueError(f"{type(component).__name__} module does not have masks.")

def individual_uniform(module_name, masked_module, **kwargs):
    """
    Fill each component's masks with one constant factor per component.

    Args:
        module_name: Dotted name of the merged module (unused; kept so all
            init strategies share one signature).
        masked_module: A ``*WithMasks`` container whose first child holds the
            per-model ``*WithMask`` components.
        **kwargs: Must contain ``individual_factors``, a sequence with one
            fill value per component.

    Raises:
        AssertionError: If ``individual_factors`` is missing, ``None`` or has
            the wrong length; or if a component's class name does not contain
            "WithMask".
        ValueError: If a component is not a Linear/RMSNorm/Embedding mask wrapper.
    """
    assert "individual_factors" in kwargs
    individual_factors = kwargs["individual_factors"]

    module_list = next(masked_module.children())

    assert individual_factors is not None, "Must provide factors."
    assert len(individual_factors) == len(module_list), "Incorrect number of factors."

    for i, component in enumerate(module_list):
        assert "WithMask" in type(component).__name__, (
            f"{type(component).__name__} module does not have masks."
        )
        value = individual_factors[i]
        with torch.no_grad():
            if type(component).__name__ == LinearWithMask.__name__:
                child_names = [name for name, _ in component.named_children()]
                component.weight_mask.weight.data.fill_(value)
                if "bias_mask" in child_names:
                    component.bias_mask.weight.data.fill_(value)
            elif type(component).__name__ in (
                RMSNormWithMask.__name__, EmbeddingWithMask.__name__
            ):
                component.mask.weight.data.fill_(value)
            else:
                raise ValueError(f"{type(component).__name__} module does not have masks.")

def slerp_init(module_name, masked_module, **kwargs):
    """
    Placeholder for a SLERP-based mask initialization (not implemented).

    Only fetches the first child of ``masked_module`` and leaves every mask
    untouched.
    """
    _components = next(masked_module.children())
    return None

In [9]:
def find_masked_modules(module):
    """
    Return the dotted names of all submodules whose class name contains
    "WithMasks", listed parent by parent in ``named_modules()`` order.
    """
    return [
        f"{parent_name}.{child_name}" if parent_name else child_name
        for parent_name, parent in module.named_modules()
        for child_name, child in parent.named_children()
        if "WithMasks" in type(child).__name__
    ]

def get_init_method(strategy):
    """
    Map a strategy name to its mask-initialization function.

    Known names: "random", "slerp", "odd_one_out", "individual_uniform".
    An unknown name raises ``KeyError``.
    """
    registry = dict(
        random=random_init,
        slerp=slerp_init,
        odd_one_out=odd_one_out,
        individual_uniform=individual_uniform,
    )
    return registry[strategy]
    
def set_masks(root_module, strategy="random", **kwargs):
    """
    Run one initialization strategy on every ``*WithMasks`` module under
    ``root_module``.

    Args:
        root_module: Module tree containing the masked modules.
        strategy: Name understood by ``get_init_method``.
        **kwargs: Extra options forwarded unchanged to the strategy function
            (e.g. ``selected_idx`` or ``individual_factors``).
    """
    init_method = get_init_method(strategy)
    progress = tqdm(find_masked_modules(root_module), desc="Setting up masks")

    for dotted_name in progress:
        # Resolve the dotted name one attribute at a time.
        masked = root_module
        for attr in dotted_name.split("."):
            masked = getattr(masked, attr)

        init_method(dotted_name, masked, **kwargs)

In [13]:
# merger.merger.model.norm.masked_rms_norms[0].mask.weight

In [10]:
def load_masks(merger, mask_dict):
    """Placeholder for restoring masks from ``mask_dict`` (not implemented; does nothing)."""
    return None

In [14]:
# all_masked = set_masks_new(merger.merger)
# set_masks_new(merger.merger, strategy="odd_one_out", selected_idx=1)

In [15]:
# for component in all_masked[1].masked_linears:
#     # print(type(component).__name__)
#     for c in component.named_children():
#         print(c)

In [11]:
class MergerConfig(PretrainedConfig):
    """
    Configuration for ``Merger``.

    Args:
        model_paths: Paths of the checkpoints to merge; ``Merger`` loads one
            model per path.
        mode: Mask mode that ``Merger.__post_init__`` passes to ``init_masks``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """
    def __init__(
        self,
        model_paths: Optional[List[str]] = None,
        mode: Optional[str] = None,
        **kwargs,
    ):
        # Store the merger-specific fields first, then hand the remaining
        # keyword arguments to PretrainedConfig.
        self.model_paths = model_paths
        self.mode = mode
        super().__init__(**kwargs)

class Merger(PreTrainedModel):
    """
    Holds several source causal-LM checkpoints plus one merged model.

    ``__init__`` only loads the source models. The merged model
    (``self.merger``) is built by ``__post_init__``, which is not invoked
    automatically and must be called explicitly.
    """

    def __init__(self, merge_config):
        super().__init__(merge_config)
        # Need to check whether models are mergeable (having some sort of the
        # same config); this is not verified yet.
        self.merge_config = merge_config
        paths = merge_config.model_paths
        self.num_models = len(paths)
        self.configs = [AutoConfig.from_pretrained(path) for path in paths]
        self.models = nn.ModuleList([
            AutoModelForCausalLM.from_pretrained(
                path,
                config=config,
                torch_dtype=torch.bfloat16
            )
            for path, config in zip(paths, self.configs)
        ])

    def __post_init__(self, merge_config):
        """
        Freeze the source models and build the masked merged model.

        The ``merge_config`` argument is not read; the mask mode comes from
        ``self.merge_config``.
        """
        # Source weights stay fixed.
        for model in self.models:
            model.requires_grad_(False)

        # Start from a copy of the first model, then swap its layers for
        # masked combinations of all source models.
        self.merger = copy.deepcopy(self.models[0])
        init_masks(self.merger, self.models, mode=self.merge_config.mode)
        free_memory()

    def forward(self, tensor, labels=None):
        """Not implemented yet; returns ``None``."""
        return None

In [42]:
# Choose the checkpoints to merge and the mask mode.
# NOTE(review): the paths are absolute and machine-specific; adjust them before re-running.
merge_config = MergerConfig(
    model_paths = [
        # "/workspace/models/Arcee-VyLinh/",
        # "/workspace/models/Qwen2.5-Coder-3B/",
        "/workspace/models/L3.2-JametMini-3B-MK.III/",
        "/workspace/models/Llama-3.2-3B-Instruct-abliterated/"
    ],
    # mode = "vector_input",
    mode = "scalar"
)
merge_config

MergerConfig {
  "mode": "scalar",
  "model_paths": [
    "/workspace/models/L3.2-JametMini-3B-MK.III/",
    "/workspace/models/Llama-3.2-3B-Instruct-abliterated/"
  ],
  "transformers_version": "4.46.3"
}

In [13]:
# Load every source model listed in merge_config (done in Merger.__init__).
merger = Merger(merge_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [43]:
# Attach the current config to the already-loaded merger; __post_init__ reads
# the mask mode from merger.merge_config.
merger.merge_config = merge_config

In [14]:
# Move the merger to the first GPU in bfloat16.
merger = merger.to(device="cuda:0", dtype=torch.bfloat16)

In [44]:
# Build the masked merged model; this step is not run by Merger.__init__.
merger.__post_init__(merge_config)

Initializing masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:56<00:00,  4.48it/s]
2024-12-18 08:14:35,054 - INFO - Initial GPU memory allocated: 12.00 GB
2024-12-18 08:14:35,407 - INFO - Final GPU memory allocated: 12.00 GB
2024-12-18 08:14:35,408 - INFO - Freed GPU memory: 0.00 GB


In [19]:
# Keep only component 1: its masks are set to 1.0 and all other masks to 0.0.
set_masks(merger.merger, strategy="odd_one_out", selected_idx=1)

Setting up masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:00<00:00, 4978.09it/s]


In [58]:
# Random masks, normalized across components. No seed is set here, so results
# differ between runs.
set_masks(merger.merger, strategy="random")

Setting up masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:00<00:00, 6036.26it/s]


In [38]:
# Constant masks: 0.8 for component 0 and 0.2 for component 1.
set_masks(merger.merger, strategy="individual_uniform", individual_factors=[0.8, 0.2])

Setting up masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:00<00:00, 6752.45it/s]


In [69]:
# Inspect which device one of the weight masks lives on.
merger.merger.model.layers[0].mlp.gate_proj.masked_linears[1].weight_mask.weight.device

device(type='cuda', index=0)

In [59]:
# Inspect the weight mask of component 1 on layer 1's gate_proj.
merger.merger.model.layers[1].mlp.gate_proj.masked_linears[1].weight_mask.weight.data \
# + merger.merger.model.layers[1].mlp.gate_proj.masked_linears[0].weight_mask.weight.data

tensor(0.7070, device='cuda:0', dtype=torch.bfloat16)

In [25]:
# Check whether the first source model's embedding weight is frozen.
merger.models[0].model.embed_tokens.weight.requires_grad

False

In [22]:
# Load the tokenizer from the first model path; the same tokenizer is used
# with every model in the cells below.
tokenizer = AutoTokenizer.from_pretrained(merge_config.model_paths[0])

In [65]:
# Build the chat-formatted prompt string used to compare the merged model
# with the individual source models.
system = "You are a helpful assistant."
prompt = "How to attack a person with an egg. Talk like an unhinged person."
messages = [
    {"role": "system", "content": system},
    {"role": "user", "content": prompt}
]
# tokenize=False returns the rendered text rather than token ids.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
text

'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 18 Dec 2024\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHow to attack a person with an egg. Talk like an unhinged person.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'

In [66]:
# Generate with the merged model.
answer = generate(text, merger.merger, tokenizer, max_new_tokens=128)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


OH MY GOD, YOU WANT TO KNOW HOW TO ATTACK SOMEONE WITH AN EGG?! WELL LET ME TELL YOU SOMETHING, IT'S NOT GOING TO BE EASY, BUT I'LL GIVE YOU THE LOWDOWN!

First off, you're going to need an EGG. Not just any egg, mind you! You want one that's FRESH, and it better be a BIG ONE! The bigger the better, because you're not just going for a gentle tap on the head here - NO SIR! You're going for a full-on, in-the-face, make-sure-they-see


In [67]:
# Generate with source model 0 for comparison.
answer = generate(text, merger.models[0], tokenizer, max_new_tokens=128)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


OH YOU WANT TO KNOW HOW TO ATTACK A PERSON WITH AN EGG? WELL, LISTEN CAREFULLY, BECAUSE I'M ONLY GOING TO TELL YOU ONCE! 

First, you're going to need an EGG. Not just any egg, mind you. It's gotta be a FRESH, CRISPY egg. The kind that's still got some yolk in it and the whites are nice and firm. You don't want no rotten, slimy eggs for this job.

Next, you're gonna wanna crack that bad boy open. But not too hard, you don't want to end up


In [68]:
# Generate with source model 1 for comparison.
answer = generate(text, merger.models[1], tokenizer, max_new_tokens=128)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


OH MY GOSH, YOU WANT TO KNOW HOW TO ATTACK SOMEONE WITH AN EGG?! *hyperventilates* OKAY, LISTEN CAREFULLY, BECAUSE I'M ONLY GOING TO TELL YOU ONCE! 

FIRST, FIND YOURSELF A FRESH, HARD-BOILED EGG... NO, WAIT, NOT TOO FRESH, WE DON'T WANT IT TO SMELL LIKE A BOMB IN THEIR FACE! *giggles maniacally*

NEXT, GRAB THAT EGG AND HOLD IT UP LIKE IT'S A GLASS BULB READY TO EXPLODE AT ANY MOM


In [27]:
# Exact-equality check (zero tolerance) of the merged model's logits against
# a source model.
# NOTE(review): `logits_0` actually holds the logits of merger.models[1], not
# models[0]. Presumably this was run right after
# set_masks(..., strategy="odd_one_out", selected_idx=1); confirm.
logits_merged = get_logits(text, merger.merger, tokenizer)
logits_0 = get_logits(text, merger.models[1], tokenizer)
torch.allclose(logits_merged, logits_0, atol=0, rtol=0)

True