In [None]:
import numpy as np
import os
# Module-wide NumPy dtype constant.
# NOTE(review): nothing in the visible cells reads `dtype`; confirm whether a later cell relies on it.
dtype = np.float32

import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
import os

from transformers import CLIPVisionModelWithProjection, AutoModelForCausalLM
from transformers import AutoModel, AutoTokenizer, OPTForCausalLM, BloomForCausalLM
import numpy

from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST
from huggingface_hub import scan_cache_dir

import glob
import random
import json
import os
import matplotlib.pyplot as plt

# Prefer the first CUDA device, but fall back to the CPU so the notebook can
# still be executed on a machine without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def get_named_linears(module):
    """Collect every ``nn.Linear`` found under ``module``.

    Walks ``module.named_modules()`` (which includes ``module`` itself under
    the empty name) and returns a dict mapping each qualified name to the
    corresponding ``nn.Linear`` instance, in traversal order.
    """
    linears = {}
    for qualified_name, submodule in module.named_modules():
        if isinstance(submodule, nn.Linear):
            linears[qualified_name] = submodule
    return linears

def get_blocks(model):
    """Return the transformer-block container of a supported model.

    The model family is recognised from the class name, from an
    ``isinstance`` check (OPT, Bloom), or from a keyword in the lower-cased
    class string (mpt, falcon, bigcode, neox); branches are tried in that
    fixed order.

    Returns:
        The layer container read from the model. For ``CLIPModel`` a dict
        ``{'vision': ..., 'text': ...}`` holding the layers of both encoders
        is returned instead of a single container.

    Raises:
        NotImplementedError: if the model class matches none of the branches.
    """
    cls_name = model.__class__.__name__
    cls_repr = str(model.__class__).lower()

    if cls_name in ("LlamaForCausalLM", "Qwen2ForCausalLM", "LlavaLlamaForCausalLM"):
        layers = model.model.layers
    elif isinstance(model, OPTForCausalLM):
        layers = model.model.decoder.layers
    elif isinstance(model, BloomForCausalLM):
        layers = model.transformer.h
    elif "mpt" in cls_repr:
        layers = model.transformer.blocks
    elif "falcon" in cls_repr:
        layers = model.transformer.h
    elif "bigcode" in cls_repr:
        layers = model.transformer.h
    elif "neox" in cls_repr:
        layers = model.gpt_neox.layers
    elif cls_name == "LlavaLlamaModel":
        layers = model.llm.model.layers
    elif cls_name == "CLIPModel":
        # Exact comparison on purpose: `name in ("CLIPModel")` is a substring
        # test against a plain string (the parentheses do not make a tuple),
        # so unrelated classes such as "Model" or "CLIP" used to land here.
        layers = {
            'vision': model.vision_model.encoder.layers,
            'text': model.text_model.encoder.layers,
        }
    else:
        raise NotImplementedError(type(model))
    return layers

def flat_to_sym(V, N):
    """Rebuild a symmetric ``N x N`` matrix from its flattened lower triangle.

    ``V`` supplies the lower-triangular entries (diagonal included) in the
    order given by ``torch.tril_indices(N, N)``. The result has the dtype and
    device of ``V``.
    """
    rows, cols = torch.tril_indices(N, N, device=V.device)
    sym = torch.zeros((N, N), dtype=V.dtype, device=V.device)
    sym[rows, cols] = V  # lower triangle, diagonal included
    sym[cols, rows] = V  # mirror into the upper triangle
    return sym

def regularize_H(H, n, sigma_reg):
    """Normalise ``H`` by its mean diagonal and add ``sigma_reg`` to the diagonal.

    Works in place: ``H`` is first divided by the mean of its diagonal, then
    ``sigma_reg`` is added to the first ``n`` diagonal entries. The same
    tensor object is returned.
    """
    mean_diag = torch.diag(H).mean()
    H.div_(mean_diag)
    diag_pos = torch.arange(n)
    H[diag_pos, diag_pos] += sigma_reg
    return H

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Candidate models and their Hessian directories.
# NOTE(review): neither list is read in the visible cells; the model and
# Hessian actually loaded come from `model_name` / `quip_hess` below.
model_list = [
    'meta-llama/Meta-Llama-3-8B',
    # 'meta-llama--Llama-2-7b-hf'
]
quip_hess_path = [
    './quip_hess/llama3_8b_6144',
    # './quip_hess/Hessians-Llama-2-7b-6144',
]
# Integer id for each linear sub-layer of a decoder block.
wtype_mapping = {'self_attn.q_proj': 0,
                 'self_attn.k_proj': 1,
                 'self_attn.v_proj': 2,
                 'self_attn.o_proj': 3,
                 'mlp.gate_proj': 4,
                 'mlp.up_proj': 5,
                 'mlp.down_proj': 6}
sigma_reg = 1e-4
# direction = 'col'
direction = 'row'

global_std = 0.012529

model_name = 'meta-llama/Meta-Llama-3-8B'
quip_hess = '../Wparam_dataset/quip_hess/llama3_8b_6144'

# Local directories use "--" where the hub id has "/".
model_name = model_name.replace('/', '--')
print('model_name: ', model_name)

model_path = f"../Wparam_dataset/hf_model/{model_name}"

model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True)
layers = get_blocks(model)

# All artefacts below share one file name and one checkpoint root, so both are
# spelled out once instead of being repeated in every path.
layer_file = '1_down.pt'
ckpt_root = f'/workspace/Weight_compression/hf_model_comp/comp_qtip/ckpt/{model_name}'

# map_location="cpu": the model above is never moved off the CPU, so the loaded
# tensors are placed on the CPU as well and can be compared with its weights
# regardless of the device they were saved from.
# SECURITY: weights_only=False unpickles arbitrary Python objects; only load
# files from a trusted source.
H = torch.load(f"{quip_hess}/{layer_file}", map_location="cpu", weights_only=False)
hatW = torch.load(f"{ckpt_root}/ql_rnorm/lmbda1000/{layer_file}", map_location="cpu", weights_only=False)
hatW_c = torch.load(f"{ckpt_root}/ql_rnorm_cnorm_trained/lmbda1000/{layer_file}", map_location="cpu", weights_only=False)
hatW_l = torch.load(f"{ckpt_root}/ql_rnorm_lnorm_trained/lmbda1000/{layer_file}", map_location="cpu", weights_only=False)


model_name:  meta-llama--Meta-Llama-3-8B


Loading checkpoint shards: 100%|██████████| 7/7 [00:01<00:00,  4.81it/s]


In [None]:
# Reference weight: the MLP down-projection of decoder block index 1.
# NOTE(review): the checkpoints loaded in the previous cell are all named
# "1_down.pt" — presumably the same layer; confirm the index stays in sync.
W = layers[1].mlp.down_proj.weight.data

In [None]:
def to_tensor_2d(save):
    """Resolve a loaded checkpoint object to the value used for comparison.

    A dict is treated as a saved record and converted to
    ``save['metadata']['row_std'] * save['hatWr']``; any other object is
    returned unchanged, so calling this on an already-converted value is a
    no-op.
    """
    if not isinstance(save, dict):
        return save
    return save['metadata']['row_std'] * save['hatWr']

def squared_error_map(W, W_hat):
    """Element-wise squared difference between ``W`` and ``W_hat``.

    Raises:
        ValueError: if the two inputs do not have the same shape.
    """
    shapes_agree = W.shape == W_hat.shape
    if not shapes_agree:
        raise ValueError(f"Shape mismatch: W {tuple(W.shape)} vs W_hat {tuple(W_hat.shape)}")
    residual = W - W_hat
    return residual ** 2

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os

def plot_heatmap(mat: torch.Tensor, title: str, fname: str, pool_size: int = 32):
    """Save a max-pooled, log-scaled grayscale heatmap of a 2-D tensor.

    Args:
        mat: 2-D tensor to visualise.
        title: Title drawn above the image.
        fname: File name of the image; it is written to ``plots/<fname>``
            (the ``plots`` directory is created if missing).
        pool_size: Side of the non-overlapping max-pooling window used for
            downsampling (e.g. 4 -> maximum over each 4x4 block).

    NOTE(review): ``log1p`` is applied to the pooled values as they are, so
    the input is presumably non-negative (e.g. squared errors); pooled values
    <= -1 would not yield finite results.
    """
    os.makedirs("plots", exist_ok=True)

    # --- Downsample: max_pool2d expects a (N, C, H, W) input ---
    batched = mat.detach().float().unsqueeze(0).unsqueeze(0)
    pooled = F.max_pool2d(batched, kernel_size=pool_size, stride=pool_size)
    # Strip only the two leading dims added above. A bare squeeze() would also
    # collapse a pooled height or width of 1, leaving a 1-D array that imshow
    # cannot draw.
    arr = pooled.squeeze(0).squeeze(0).cpu().numpy()

    # --- Log scale ---
    arr_log = np.log1p(arr)

    # --- Plot on an explicit figure so exactly this figure is saved and closed ---
    fig, ax = plt.subplots(figsize=(24, 18))
    image = ax.imshow(arr_log, aspect='auto', cmap='gray_r')
    ax.set_title(title)
    fig.colorbar(image, ax=ax, label="log(1+x), pooled")
    fig.tight_layout()
    fig.savefig(os.path.join("plots", fname), dpi=200)
    plt.close(fig)

    
# Resolve each loaded checkpoint to its comparison value (values that are not
# dicts pass through unchanged).
hatW, hatW_c, hatW_l = (to_tensor_2d(saved) for saved in (hatW, hatW_c, hatW_l))

# Per-element squared error of every reconstruction against the reference W.
se_base, se_c, se_l = (squared_error_map(W, approx) for approx in (hatW, hatW_c, hatW_l))

# Plots
plot_heatmap(mat=W, title="W (down_proj weight)", fname="W_down_heatmap.png")
plot_heatmap(mat=se_base, title="(W - hatW)^2 [ql_rnorm]", fname="SE_W_minus_hatW.png")
plot_heatmap(mat=se_c, title="(W - hatW_c)^2 [ql_rnorm_cnorm]", fname="SE_W_minus_hatW_c.png")
plot_heatmap(mat=se_l, title="(W - hatW_l)^2 [ql_rnorm_lnorm]", fname="SE_W_minus_hatW_l.png")

print("플롯이 ./plots 폴더에 저장되었습니다.")

플롯이 ./plots 폴더에 저장되었습니다.


In [24]:
def to_tensor_2d(save):
    """Resolve a loaded checkpoint object to the value used for comparison.

    Dict inputs are treated as saved records and converted to
    ``save['metadata']['row_std'] * save['hatWr']``; every other input is
    handed back as-is.
    """
    if isinstance(save, dict):
        return save['metadata']['row_std'] * save['hatWr']
    return save

def squared_error_map(W, W_hat):
    """Return ``(W - W_hat) ** 2`` element by element.

    Raises:
        ValueError: if ``W`` and ``W_hat`` differ in shape.
    """
    if not (W.shape == W_hat.shape):
        raise ValueError(f"Shape mismatch: W {tuple(W.shape)} vs W_hat {tuple(W_hat.shape)}")
    delta = W - W_hat
    return delta ** 2

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os

def _pool2d_max(mat: torch.Tensor, pool_size: int) -> torch.Tensor:
    x = mat.detach().float().unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
    x_pooled = F.max_pool2d(x, kernel_size=pool_size, stride=pool_size)
    return x_pooled.squeeze(0).squeeze(0)
def plot_four(
    W: torch.Tensor,
    se_base: torch.Tensor,
    se_c: torch.Tensor,
    se_l: torch.Tensor,
    pool_size: int = 32,
    share_error_scale: bool = True,
    out_path: str = "plots/W_and_SE_quad.png"
):
    """Save a 2x2 figure: weight magnitude plus three squared-error maps.

    Every panel is max-pooled, passed through ``log1p`` and given its own
    colorbar:
      [0,0] |W|              (its own colour scale)
      [0,1] (W - hatW)^2
      [1,0] (W - hatW_c)^2
      [1,1] (W - hatW_l)^2

    Args:
        W: 2-D weight tensor.
        se_base, se_c, se_l: 2-D squared-error maps for the three
            reconstructions.
        pool_size: Side of the non-overlapping max-pooling window.
        share_error_scale: If True, the three error panels use one common
            ``vmin``/``vmax``; otherwise each is scaled independently.
        out_path: Destination image path; its parent directory is created
            if needed.
    """
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)

    def _pooled_log1p(mat: torch.Tensor) -> np.ndarray:
        # Pool with the module-level `_pool2d_max`, then log-scale in NumPy.
        return np.log1p(_pool2d_max(mat, pool_size).cpu().numpy())

    # Take |W| BEFORE pooling. Max-pooling the signed weights keeps only the
    # largest positive entry of each window, so taking abs afterwards left
    # large negative weights out of a panel labelled log(1+|W|).
    W_vis = _pooled_log1p(W.detach().abs())

    error_panels = [
        (se_base, "(W - hatW)^2 (pooled, log1p)"),
        (se_c, "(W - hatW_c)^2 (pooled, log1p)"),
        (se_l, "(W - hatW_l)^2 (pooled, log1p)"),
    ]
    error_vis = [_pooled_log1p(se) for se, _ in error_panels]

    # Colour limits for the error panels (optionally shared).
    if share_error_scale:
        err_norm = dict(
            vmin=min(vis.min() for vis in error_vis),
            vmax=max(vis.max() for vis in error_vis),
        )
    else:
        err_norm = {}

    fig, axs = plt.subplots(2, 2, figsize=(32, 24))
    # Row-major order: [0,0] -> W, then [0,1], [1,0], [1,1] -> error maps.
    axW, *error_axes = axs.ravel()

    imW = axW.imshow(W_vis, aspect='auto', cmap='gray_r')
    axW.set_title("W (pooled, log1p(|W|))")
    fig.colorbar(imW, ax=axW).set_label("log(1+|W|)")

    for ax, vis, (_, panel_title) in zip(error_axes, error_vis, error_panels):
        image = ax.imshow(vis, aspect='auto', cmap='gray_r', **err_norm)
        ax.set_title(panel_title)
        fig.colorbar(image, ax=ax).set_label("log(1 + SE)")

    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)  # close exactly this figure, not whichever is current
    print(f"Saved: {out_path}")


# ===== Example usage =====
# Resolve each loaded checkpoint to its comparison value (values that are not
# dicts pass through unchanged).
hatW, hatW_c, hatW_l = (to_tensor_2d(saved) for saved in (hatW, hatW_c, hatW_l))

# Per-element squared error of every reconstruction against the reference W.
se_base, se_c, se_l = (squared_error_map(W, approx) for approx in (hatW, hatW_c, hatW_l))

# All four panels at once, error maps on one shared colour scale.
plot_four(
    W=W, se_base=se_base, se_c=se_c, se_l=se_l,
    pool_size=32, share_error_scale=True,
    out_path="plots/W_and_SE_quad_shared.png",
)

# All four panels at once, each error map on its own colour scale.
plot_four(
    W=W, se_base=se_base, se_c=se_c, se_l=se_l,
    pool_size=32, share_error_scale=False,
    out_path="plots/W_and_SE_quad_individual.png",
)


Saved: plots/W_and_SE_quad_shared.png
Saved: plots/W_and_SE_quad_individual.png
