In [1]:
from llmin.similarity import sim_matrix

from transformers import AutoModelForCausalLM
import torch

# Hugging Face checkpoint whose per-layer linear weights are compared below.
target_model = "mistralai/Mistral-7B-v0.1"

# Load the causal-LM weights in bfloat16 onto the CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    target_model, device_map="cuda", torch_dtype=torch.bfloat16
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
from tqdm import tqdm

def calculate_similarities(
    linear_module_name: str = "model.layers.{idx}.self_attn.q_proj.weight",
    num_layers: int = 32,
):
    """Record the peak cross-layer similarity of one linear weight.

    For every ordered pair of distinct layer indices ``(i, j)`` the weights
    ``linear_module_name.format(idx=i)`` and ``linear_module_name.format(idx=j)``
    are read from the global ``model``'s state dict and passed to
    ``sim_matrix``. The maximum entry of the returned matrix is kept.

    Args:
        linear_module_name: ``str.format`` template of a state-dict key with an
            ``{idx}`` placeholder for the layer index.
        num_layers: Number of layers to compare (indices ``0 .. num_layers - 1``).

    Returns:
        A list of ``num_layers * (num_layers - 1)`` dicts with keys
        ``"sim"`` (0-d CPU tensor holding the maximum similarity),
        ``"layer_1"`` and ``"layer_2"`` (the two state-dict keys compared).

    Raises:
        KeyError: if a formatted key is not present in the model's state dict.
    """
    # Build the state dict once; every state_dict() call walks all of the
    # model's parameters, and this loop would otherwise do it 2*n*(n-1) times.
    state_dict = model.state_dict()
    weight_names = [linear_module_name.format(idx=i) for i in range(num_layers)]

    similarities = []
    for i in tqdm(range(num_layers)):
        for j in range(num_layers):
            if i == j:
                continue
            sim = sim_matrix(
                a=state_dict[weight_names[i]],
                b=state_dict[weight_names[j]],
            )
            # Reduce on the matrix's own device and move only the scalar to the
            # CPU. Integer-indexing the matrix would return a 0-d *view* that
            # keeps the whole matrix alive for as long as this list exists;
            # max() allocates a fresh tensor, so each matrix can be freed.
            # max() also yields NaN instead of failing when the matrix has NaNs.
            similarities.append({
                "sim": sim.max().to("cpu"),
                "layer_1": weight_names[i],
                "layer_2": weight_names[j],
            })
    return similarities

In [15]:
# Linear sub-modules (relative to ``model.layers.{idx}``) whose weights are compared.
modules = [
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "self_attn.o_proj",
    "mlp.gate_proj",
    "mlp.up_proj",
    "mlp.down_proj"
]

# One list per entry of ``modules`` (same order), each ordered by "sim", highest first.
layer_similarities = []
for module in modules:
    # Doubled braces keep a literal "{idx}" placeholder for calculate_similarities.
    module_name = f"model.layers.{{idx}}.{module}.weight"
    similarities = calculate_similarities(linear_module_name=module_name)
    sorted_similarities = sorted(
        similarities,
        key=lambda item: item["sim"],
        reverse=True,
    )
    layer_similarities.append(sorted_similarities)


['model.layers.0.self_attn.q_proj',
 'model.layers.0.self_attn.k_proj',
 'model.layers.0.self_attn.v_proj',
 'model.layers.0.self_attn.o_proj',
 'model.layers.0.mlp.gate_proj',
 'model.layers.0.mlp.up_proj',
 'model.layers.0.mlp.down_proj',
 'model.layers.1.self_attn.q_proj',
 'model.layers.1.self_attn.k_proj',
 'model.layers.1.self_attn.v_proj',
 'model.layers.1.self_attn.o_proj',
 'model.layers.1.mlp.gate_proj',
 'model.layers.1.mlp.up_proj',
 'model.layers.1.mlp.down_proj',
 'model.layers.2.self_attn.q_proj',
 'model.layers.2.self_attn.k_proj',
 'model.layers.2.self_attn.v_proj',
 'model.layers.2.self_attn.o_proj',
 'model.layers.2.mlp.gate_proj',
 'model.layers.2.mlp.up_proj',
 'model.layers.2.mlp.down_proj',
 'model.layers.3.self_attn.q_proj',
 'model.layers.3.self_attn.k_proj',
 'model.layers.3.self_attn.v_proj',
 'model.layers.3.self_attn.o_proj',
 'model.layers.3.mlp.gate_proj',
 'model.layers.3.mlp.up_proj',
 'model.layers.3.mlp.down_proj',
 'model.layers.4.self_attn.q_proj',


In [None]:
# Show only the strongest pairs per module. The full result holds
# len(modules) * 32 * 31 entries, far too many to render usefully.
TOP_K = 5
{module: ranked[:TOP_K] for module, ranked in zip(modules, layer_similarities)}