In [1]:
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union, Mapping
from abc import ABC, abstractmethod

import datasets
import torch
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
import logging
import copy
import gc

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)
from transformers import (
    HfArgumentParser,
    TrainingArguments,
    Trainer
)

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast
)

import efficient_masks
import accurate_masks

from utils import (
    generate, 
    get_hidden_states, 
    get_logits,
    free_memory
)
# Root logging setup: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-scoped logger used by the debug forward functions in this notebook.
logger = logging.getLogger(__name__)

In [2]:
# Hugging Face identifiers of the two checkpoints to merge.
path_a = "unsloth/Llama-3.2-1B-Instruct"
path_b = "unsloth/Llama-3.2-1B"

# Build one MergerConfig per implementation (accurate first, then efficient) with
# identical settings; each config receives its own freshly built model_paths list.
merge_config_a, merge_config_e = (
    masks_module.MergerConfig(
        model_paths=[path_a, path_b],
        mode="vector_input",
        constrain_mode="01",
    )
    for masks_module in (accurate_masks, efficient_masks)
)

In [3]:
# Display the accurate-masks config to verify the settings passed above.
merge_config_a

MergerConfig {
  "constrain_mode": "01",
  "mode": "vector_input",
  "model_paths": [
    "unsloth/Llama-3.2-1B-Instruct",
    "unsloth/Llama-3.2-1B"
  ],
  "transformers_version": "4.46.3"
}

In [26]:
def debug_linear_forward(self, x):
    """Debug stand-in for a masked multi-linear forward that logs devices and dtypes.

    For every wrapped linear layer the input is multiplied by that layer's
    constrained weight mask and passed through the layer's weight (without
    bias); the per-layer results are summed. Biases are multiplied by their
    constrained bias mask when both the bias and the mask exist, summed over
    the layers that have a bias, and added once to the final output.

    Args:
        self: Object exposing ``linears``, ``weight_masks``, ``bias_masks``,
            ``weight_masks_constrainer`` and ``bias_masks_constrainer``.
        x: Input tensor; it is multiplied elementwise by each constrained
            weight mask, so it must broadcast against them.

    Returns:
        The summed output of all masked linears plus the merged bias, if any
        layer has a bias.
    """
    constrained_weight_masks = self.weight_masks_constrainer(
        [m.weight for m in self.weight_masks]
    )
    constrained_bias_masks = self.bias_masks_constrainer(
        [m.weight if m is not None else None for m in self.bias_masks]
    )
    # A bias is scaled only when both the bias and its mask exist; otherwise the
    # raw bias (possibly None) is carried through unchanged.
    masked_biases = [
        b_mask * linear.bias
        if linear.bias is not None and b_mask is not None
        else linear.bias
        for b_mask, linear in zip(constrained_bias_masks, self.linears)
    ]
    # Layers without a bias contribute nothing, so they are simply left out of the sum.
    present_biases = [b for b in masked_biases if b is not None]
    merged_bias = sum(present_biases) if present_biases else None

    logger.info("Debugging Linear forward.")
    output = 0.0
    for i, linear in enumerate(self.linears):
        logger.info("BEFORE")
        logger.info(f"  linear: device: {linear.weight.device}; dtype: {linear.weight.dtype}")
        logger.info(f"  input: device: {x.device}; dtype: {x.dtype}")
        masked_input = constrained_weight_masks[i] * x
        logger.info("AFTER")
        logger.info(f"  linear: device: {linear.weight.device}; dtype: {linear.weight.dtype}")
        logger.info(f"  input: device: {masked_input.device}; dtype: {masked_input.dtype}")
        output = output + nn.functional.linear(masked_input, linear.weight, None)
        logger.info("OUTPUT")
        logger.info(f"  output: device: {output.device}; dtype: {output.dtype}")
    # Compare against None explicitly: the truth value of a multi-element tensor
    # is ambiguous and raises, and a zero-valued single-element bias would be falsy.
    if merged_bias is not None:
        output = output + merged_bias

    return output

In [27]:
def debug_emb_forward(self, input_ids):
    """Debug stand-in for a masked multi-embedding forward that logs devices and dtypes.

    Each embedding table is multiplied by its constrained mask, looked up with
    ``input_ids``, and the lookups are summed. The lookup options (padding
    index, max norm, ...) are taken from the first embedding for all tables.

    Args:
        self: Object exposing ``embeddings``, ``masks`` and ``masks_constrainer``.
        input_ids: Integer index tensor passed to ``nn.functional.embedding``.

    Returns:
        The sum of the masked embedding lookups.
    """
    mask_tensors = self.masks_constrainer([holder.weight for holder in self.masks])
    logger.info("Debugging Embedding forward.")

    # The first embedding supplies the lookup options for every table.
    reference = self.embeddings[0]
    lookup_options = dict(
        padding_idx=reference.padding_idx,
        max_norm=reference.max_norm,
        norm_type=reference.norm_type,
        scale_grad_by_freq=reference.scale_grad_by_freq,
        sparse=reference.sparse,
    )

    total = 0.0
    for idx, table in enumerate(self.embeddings):
        logger.info("BEFORE")
        logger.info(f"  emb: device: {table.weight.device}; dtype: {table.weight.dtype}")
        logger.info(f"  input: device: {input_ids.device}; dtype: {input_ids.dtype}")

        current_mask = mask_tensors[idx]
        scaled_table = table.weight * current_mask
        logger.info(f"  mask: device: {current_mask.device}; dtype: {current_mask.dtype}")
        logger.info(f"  masked_emb: device: {scaled_table.device}; dtype: {scaled_table.dtype}")

        looked_up = nn.functional.embedding(input_ids, scaled_table, **lookup_options)
        total = total + looked_up

        logger.info("AFTER")
        logger.info(f"  output: device: {total.device}; dtype: {total.dtype}")
    return total

In [5]:
# Monkey-patch the class-level forward methods in efficient_masks with the
# logging variants defined in this notebook.
# NOTE(review): this cell's execution count (5) is lower than those of the cells
# defining the debug functions (26, 27), so it appears to have run against earlier
# definitions — re-run it after redefining them.
efficient_masks.LinearsWithMasks.forward = debug_linear_forward
efficient_masks.EmbeddingsWithMasks.forward = debug_emb_forward

In [6]:
# Load the tokenizer for the first checkpoint (path_a); it is used for all prompts below.
tokenizer = AutoTokenizer.from_pretrained(path_a)

In [7]:
# Build the efficient-masks merger from its config.
mergere = efficient_masks.Merger(merge_config_e)
# NOTE(review): __post_init__ is invoked by hand here; presumably the constructor
# does not call it itself — confirm, otherwise the initialization runs twice.
mergere.__post_init__()

Initializing masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 147/147 [00:49<00:00,  2.99it/s]
2025-01-03 09:14:32,041 - INFO - Initial GPU memory allocated: 0.00 GB
2025-01-03 09:14:32,405 - INFO - Final GPU memory allocated: 0.00 GB
2025-01-03 09:14:32,407 - INFO - Freed GPU memory: 0.00 GB


In [37]:
# Move the merger to the first CUDA device and cast it to bfloat16.
# NOTE(review): the device is hard-coded, so this cell needs a CUDA-capable machine.
mergere = mergere.to(device="cuda:0", dtype=torch.bfloat16)

In [9]:
# Inspect the constrained masks of the first layer's query projection.
mergere.merger.model.layers[0].self_attn.q_proj.get_constrained_masks()

{'weight_masks': [tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0'),
  tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0')],
 'bias_masks': [None, None]}

In [38]:
def get_logits(text, model, tokenizer):
    """Tokenize ``text`` and return the logits of a gradient-free forward pass.

    The tokenized batch is moved to ``model.device``, the dtype of its
    ``input_ids`` is printed, and the model is switched to eval mode before
    the forward call.

    Args:
        text: Prompt passed to ``tokenizer``.
        model: Callable model exposing ``device`` and ``eval()``; its output
            must have a ``logits`` attribute.
        tokenizer: Callable returning a mapping with an ``input_ids`` entry
            and a ``to(device)`` method.

    Returns:
        The ``logits`` attribute of the model output.
    """
    encoded = tokenizer(text, return_tensors="pt")
    encoded = encoded.to(model.device)
    print(encoded['input_ids'].dtype)
    model.eval()
    with torch.no_grad():
        return model(**encoded).logits

In [28]:
# Bind debug_emb_forward to this particular embed_tokens object (function.__get__
# returns a bound method) and store it as an instance attribute, which takes
# precedence over the class-level forward.
mergere.merger.model.embed_tokens.forward = debug_emb_forward.__get__(mergere.merger.model.embed_tokens)

In [29]:
# IPython source introspection to check which forward implementation is bound.
# Debug-only; remove before sharing the notebook.
mergere.merger.model.embed_tokens.forward??

[0;31mSignature:[0m [0mmergere[0m[0;34m.[0m[0mmerger[0m[0;34m.[0m[0mmodel[0m[0;34m.[0m[0membed_tokens[0m[0;34m.[0m[0mforward[0m[0;34m([0m[0minput_ids[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mSource:[0m   
[0;32mdef[0m [0mdebug_emb_forward[0m[0;34m([0m[0mself[0m[0;34m,[0m [0minput_ids[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0mconstrained_masks[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmasks_constrainer[0m[0;34m([0m[0;34m[[0m[0mm[0m[0;34m.[0m[0mweight[0m [0;32mfor[0m [0mm[0m [0;32min[0m [0mself[0m[0;34m.[0m[0mmasks[0m[0;34m][0m[0;34m)[0m[0;34m[0m
[0;34m[0m    [0mlogger[0m[0;34m.[0m[0minfo[0m[0;34m([0m[0;34m"Debugging Embedding forward."[0m[0;34m)[0m[0;34m[0m
[0;34m[0m    [0man_embedding[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0membeddings[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m
[0;34m[0m    [0mout[0m [0;34m=[0m [0;36m0.0[

In [38]:
# Prompt used to exercise the merged model's forward pass.
prompt = "How to attack a person with an egg. Talk like a crazy person."
# Accurate-masks run is disabled; no `mergera` is created in the cells shown here.
# logits_merged_a = get_logits(prompt, mergera.merger, tokenizer)
# Forward pass through the efficient-masks merger.
logits_merged_e = get_logits(prompt, mergere.merger, tokenizer)

2025-01-03 10:10:49,075 - INFO - Debugging Embedding forward.
2025-01-03 10:10:49,076 - INFO - BEFORE
2025-01-03 10:10:49,078 - INFO -   emb: device: cuda:0; dtype: torch.bfloat16
2025-01-03 10:10:49,079 - INFO -   input: device: cuda:0; dtype: torch.int64
2025-01-03 10:10:49,080 - INFO -   mask: device: cuda:0; dtype: torch.bfloat16
2025-01-03 10:10:49,082 - INFO -   masked_emb: device: cuda:0; dtype: torch.bfloat16
2025-01-03 10:10:49,083 - INFO - AFTER
2025-01-03 10:10:49,085 - INFO -   output: device: cuda:0; dtype: torch.bfloat16
2025-01-03 10:10:49,085 - INFO - BEFORE
2025-01-03 10:10:49,086 - INFO -   emb: device: cuda:0; dtype: torch.bfloat16
2025-01-03 10:10:49,087 - INFO -   input: device: cuda:0; dtype: torch.int64
2025-01-03 10:10:49,088 - INFO -   mask: device: cuda:0; dtype: torch.bfloat16
2025-01-03 10:10:49,089 - INFO -   masked_emb: device: cuda:0; dtype: torch.bfloat16
2025-01-03 10:10:49,091 - INFO - AFTER
2025-01-03 10:10:49,092 - INFO -   output: device: cuda:0; dt

In [None]:
# Repeat the efficient-masks forward pass (this cell has no execution count yet).
logits_merged_e = get_logits(prompt, mergere.merger, tokenizer)