In [1]:
import torch
from torch import nn
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM
from collections import OrderedDict
import copy
import os
from typing import List, Dict

In [2]:
code_alpaca = "DTang161/ModelMergingCode"
llama2 = "meta-llama/Llama-2-13b-hf"
wizardlm = "DTang161/ModelMergingLM"
wizardmath = "DTang161/ModelMergingMath"

models = [code_alpaca, llama2, wizardlm, wizardmath]

In [3]:
code_alpaca_model = AutoModelForCausalLM.from_pretrained(code_alpaca).state_dict()
llama2_model = AutoModelForCausalLM.from_pretrained(llama2).state_dict()
wizardlm = AutoModelForCausalLM.from_pretrained(wizardlm).state_dict()
wizardmath = AutoModelForCausalLM.from_pretrained(wizardmath).state_dict()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [4]:
def state_dict_check(state_dicts):
    param_names = [set(model.keys()) for model in state_dicts]
    for params in param_names[1:]:
        if params != param_names[0]:
            raise ValueError("Params do not match")
    return True

def print_param_summary(state_dict):
    total_params = 0
    for name, tensor in state_dict.items():
        print(f"{name}: {tensor.size()}")
        total_params += tensor.numel()
    print(f"Total Parameters: {total_params}")
    
def state_dict_to_vector(remove_keys, state_dict):
    """
    Removes keys from state_dict
    Returns vectorized state_dict which is flattened (after sorting)
    """
    state_dict_copy = copy.deepcopy(state_dict)
    print("copied")
    for key in remove_keys:
        if key in state_dict_copy:
            del state_dict_copy[key]
    sorted_dict = OrderedDict(sorted(state_dict_copy.items()))
    print("sorted")
    return torch.nn.utils.parameters_to_vector([
        value.reshape(-1) for key, value in sorted_dict.items()
    ])

In [None]:
ptm = llama2_model
ftms = [code_alpaca_model, wizardlm]

state_dict_check([ptm] + ftms)

# print_param_summary(wizardmath)
# print_param_summary(llama2_model)
#state_dict_check([llama2_model] + ftms)

#these vary across models so must be removed
remove_keys = [
    "lm_head.weight",
    "model.embed_tokens.weight",
]

vectorized_ptm = state_dict_to_vector(remove_keys, ptm)
# print(vectorized_ptm.shape)
vectorized_ftms = torch.vstack([state_dict_to_vector(remove_keys, ftm) for ftm in ftms])
# print(vectorized_ftms[0].shape)
task_vectors = vectorized_ftms - vectorized_ptm

torch.save(task_vectors, "./tensors/task_vectors.pt")

In [None]:
def create_top_k_mask(M, K = 20): #trim from raw vectorized models
    d = list(M.shape)[0] # n = number of models
    nonzero_count = int(d * (100-K)/100)
    kth_value, _ = M.abs().kthvalue(nonzero_count, dim=0)
    mask = M.abs() >= kth_value
    return mask

def determine_signs(trimmed_M): #elect signs from trimmed models
    return torch.sign(trimmed_M.sum(dim=0)) #sum by parameter across models: 1 x d

def create_rows_to_keep(trimmed_M, signs):
    return torch.where(signs == 0, 0, bool((torch.sign(signs) == torch.sign(trimmed_M))))

def create_selected_entries(trimmed_M, rows_to_keep):
    return trimmed_M * rows_to_keep

def create_non_zero_entries(rows_to_keep):
    return (rows_to_keep != 0).sum(dim=0).float()

def disjoint_merge(selected_entries, non_zero_count): #mean of fine-tuned task vectors matching signs
    disjoint_mean = selected_entries.sum(dim=0) / torch.clamp(non_zero_count, min=1)
    return disjoint_mean

#masks = torch.vstack([create_top_k_mask(task_vector) for task_vector in task_vectors]) #for memory purposes split up

In [None]:
top_k_mask = create_top_k_mask(task_vectors)
trimmed_M = task_vectors * top_k_mask
M_signs = determine_signs(trimmed_M)

In [None]:
rows_to_keep = torch.vstack([create_rows_to_keep(trimmed_M[i], M_signs[i]) for i in range(trimmed_M.shape[0])])
torch.save(rows_to_keep, "./tensors/rows_to_keep.pt")

In [None]:
# trimmed_M = torch.load("./tensors/trimmed_M.pt")
# rows_to_keep = torch.load("./tensors/rows_to_keep.pt")

selected_entries = trimmed_M * rows_to_keep
torch.save(selected_entries, "./tensors/selected_entries.pt")

In [None]:
non_zero_count = create_non_zero_entries(rows_to_keep)
torch.save(non_zero_count, "./tensors/non_zero_count.pt")

In [None]:
#selected_entries = torch.load("./tensors/selected_entries.pt")
non_zero_count = torch.load("./tensors/non_zero_count.pt")
print("done loading!")
merged_tasks = disjoint_merge(selected_entries, non_zero_count)
torch.save(merged_tasks, "./tensors/merged_tasks.pt")

In [None]:
merged_tasks = torch.load("./tensors/merged_tasks.pt")

In [None]:
#add merge tensors to ptm

LAMBDA = 1 #ignored for sake of computation

remove_keys = [
    "lm_head.weight",
    "model.embed_tokens.weight",
]


merged_tasks = torch.load("./tensors/merged_tasks.pt")
ptm = AutoModelForCausalLM.from_pretrained(llama2).state_dict()
print("loaded llama2 state dict")
vectorized_ptm = state_dict_to_vector(remove_keys, ptm)
print("vectorized llama2)")
vectorized_merged = vectorized_ptm + merged_tasks #technically vectorized_ptm + LAMBDA * merged_tasks
print("finished creating weights")
torch.save(vectorized_merged, "./tensors/merged_model_vector.pt")

In [10]:
def vector_to_state_dict(vectorized_model, state_dict, remove_keys):
    reference_dict = copy.deepcopy(state_dict)
    removed_weights = {}
    
    for key in remove_keys:
        if key in state_dict:
            print(key)
            removed_weights[key] = copy.deepcopy(state_dict[key])
            del reference_dict[key]
    sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
    
    torch.nn.utils.vector_to_parameters(vectorized_model, sorted_reference_dict.values())
    
    for key in remove_keys:
        if key in state_dict:
            sorted_reference_dict[key] = copy.deepcopy(removed_weights[key])
            print(sorted_reference_dict[key])
    
    return sorted_reference_dict
    

In [9]:
#ptm = AutoModelForCausalLM.from_pretrained(llama2).state_dict()
remove_keys = [
    "lm_head.weight",
    "model.embed_tokens.weight",
]
vectorized_merged = torch.load("./tensors/merged_model_vector.pt")

In [None]:
def print_param_summary(state_dict):
    total_params = 0
    for name, tensor in state_dict.items():
        print(f"{name}: {tensor.size()}")
        total_params += tensor.numel()
    print(f"Total Parameters: {total_params}")

In [11]:
merged_model_state_dict = vector_to_state_dict(vectorized_merged, ptm, remove_keys)

lm_head.weight
model.embed_tokens.weight
tensor([[ 6.2561e-03, -4.3945e-03,  1.3885e-03,  ..., -1.8433e-02,
         -1.2878e-02, -4.8523e-03],
        [-7.3853e-03, -1.0559e-02, -1.9150e-03,  ..., -8.4686e-04,
         -5.1498e-05, -1.4954e-02],
        [ 1.8677e-02, -3.8300e-03,  1.6357e-02,  ..., -1.2207e-02,
          1.9775e-02,  7.8125e-03],
        ...,
        [-1.8066e-02,  1.2360e-03, -5.3711e-03,  ..., -3.4912e-02,
          2.5146e-02, -1.9043e-02],
        [ 1.7700e-02, -1.2268e-02, -2.5635e-02,  ..., -7.7820e-03,
          2.4170e-02,  9.0332e-03],
        [-1.5198e-02, -1.4709e-02,  5.7068e-03,  ..., -3.5400e-02,
         -2.0599e-03, -2.5513e-02]])
tensor([[-4.8876e-06,  6.4969e-06,  9.5367e-07,  ...,  5.9605e-08,
          2.8014e-06, -2.6822e-06],
        [ 3.6316e-03,  4.2114e-03, -1.0300e-03,  ...,  3.9368e-03,
          8.2397e-03, -5.1117e-04],
        [-3.7575e-04, -3.1090e-04,  1.5869e-03,  ...,  2.8419e-04,
          1.7738e-04, -7.0572e-05],
        ...,
     

In [12]:
torch.save(merged_model_state_dict, "./merged_model_state_dict.pth")

In [14]:
merged_model = AutoModelForCausalLM.from_pretrained(llama2)
merged_model_state_dict = torch.load("./merged_model_state_dict.pth")
merged_model.load_state_dict(merged_model_state_dict)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

<All keys matched successfully>