In [5]:
import os, random
import torch, torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from transformers import AutoModelForCausalLM, OPTForCausalLM, BloomForCausalLM

device = torch.device("cpu")
device = torch.device("cuda:3")

def get_named_linears(module):
    return {n:m for n,m in module.named_modules() if isinstance(m, nn.Linear)}

def get_blocks(model):
    if model.__class__.__name__ in ("LlamaForCausalLM","Qwen2ForCausalLM","LlavaLlamaForCausalLM"):
        return model.model.layers
    if isinstance(model, OPTForCausalLM): return model.model.decoder.layers
    if isinstance(model, BloomForCausalLM): return model.transformer.h
    if "mpt" in str(model.__class__).lower(): return model.transformer.blocks
    if "falcon" in str(model.__class__).lower() or "bigcode" in str(model.__class__).lower():
        return model.transformer.h
    if "neox" in str(model.__class__).lower(): return model.gpt_neox.layers
    raise NotImplementedError(type(model))

def make_blocks(W, axis, bsz):
    flat = (W if axis==0 else W.T).ravel()     # axis=0이면 행우선(flat by rows), axis=1이면 열우선
    n_full = flat.size // bsz
    if n_full == 0:
        blk = np.pad(flat, (0, bsz-flat.size)) if flat.size < bsz else flat[:bsz]
        return blk[None, :]
    flat = flat[:n_full*bsz]
    return flat.reshape(n_full, bsz)

def flat_to_sym(V, N):
    A = torch.zeros(N, N, dtype=V.dtype, device=V.device)
    idxs = torch.tril_indices(N, N, device=V.device)
    A[idxs.unbind()] = V
    A[idxs[1, :], idxs[0, :]] = V
    return A

def regularize_H(H, n, sigma_reg):
    H.div_(torch.diag(H).mean())
    idx = torch.arange(n)
    H[idx, idx] += sigma_reg
    return H

import sys
sys.path.append('/workspace/Weight_compression/Wparam_dataset')
from utils import *

def RHT_H(H, SU):
    return matmul_hadUt(matmul_hadUt(H * SU).T * SU)


def RHT_W(W, SU, SV):
    return matmul_hadUt(matmul_hadUt(W.T * SV).T * SU)


def incoherence_preprocess(H, W, args):
    # dtype_ = torch.float64 if args.use_fp64 else torch.float32
    dtype_ = torch.float32
    device = H.device
    # device = torch.device('cpu')
    (m, n) = H.shape

    def _dump(Hr, Lhr, msg=''):
        torch.save(Hr, f"{args.save_pfx}/Hr_debug_fft.pt")
        torch.save(Lhr, f"{args.save_pfx}/Lhr_debug_fft.pt")
        raise Exception(msg)

    # diagonally rescale W,H to minimize proxy loss
    scaleWH = None
    Wr = W
    Hr = H
    # if args.rescale_WH:
    if False:
        Hr = H / H.abs().max()
        diagH = torch.diag(Hr)
        diagW2 = torch.diag(W.T @ W)
        diagH = torch.clamp(diagH, min=1e-8)
        diagW2 = torch.clamp(diagW2, min=1e-8)
        scaleWH = (diagH / diagW2).sqrt().sqrt().to(torch.float32)
        scaleWH = scaleWH.clamp(min=1e-8)
        Wr = Wr * scaleWH[None, :]
        Hr = Hr / scaleWH[None, :]
        Hr = Hr / scaleWH[:, None]
        scaleWH = scaleWH.cpu()

    # randomized hadamard transformation on H, W
    if True:
        SU = (torch.randn(n, device=device).sign() + 1e-5).sign().to(dtype_)
        SV = (torch.randn(m, device=device).sign() + 1e-5).sign().to(dtype_)
        Hr = RHT_H(Hr, SU)
        # Wr = RHT_W(Wr, SU, SV)
    # randomized kronecker product on H, W
    elif args.incoh_mode == "kron":
        SU = utils.rand_ortho_butterfly_noblock(n).to(dtype_).to(device)
        SV = utils.rand_ortho_butterfly_noblock(m).to(dtype_).to(device)
        Hr = SU @ Hr @ SU.T
        Wr = SV @ Wr @ SU.T
    else:
        raise NotImplementedError
    SV = SV.cpu()
    SU = SU.cpu()

    # Lhr = torch.linalg.cholesky(Hr)
    Lhr = None
    # if not torch.all(torch.isfinite(Lhr)):
    #     return None

    # Wr = Wr.to(device)

    return Lhr, Hr, Wr, SU, SV, scaleWH

In [6]:
model_list = [
    'meta-llama/Meta-Llama-3-8B',
    'meta-llama--Llama-2-7b-hf',
    'meta-llama--Llama-2-13b-hf',
    # 'openai/clip-vit-large-patch14'
]

size_list = [
    1024,
    4096,
    5120//4,
]

for model_name, size in zip(model_list, size_list):
    
    model_name = model_name.replace('/', '--')
    print('model_name: ', model_name)
    
    model_path = f"./hf_model/{model_name}"

    model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True)
    # model = AutoModel.from_pretrained(model_path, local_files_only=True)
    layers = get_blocks(model)[0:1]

    if not isinstance(layers, dict):
        layers = {'': layers}

    for k, v in layers.items():
        print(k)
        named_linears = get_named_linears(v)
        for n, m in named_linears.items():
            print(n, m.weight.data.shape)
            W = m.weight.data
            r, c = W.shape

model_name:  meta-llama--Meta-Llama-3-8B


Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  3.46it/s]



0.self_attn.q_proj torch.Size([4096, 4096])
0.self_attn.k_proj torch.Size([1024, 4096])
0.self_attn.v_proj torch.Size([1024, 4096])
0.self_attn.o_proj torch.Size([4096, 4096])
0.mlp.gate_proj torch.Size([14336, 4096])
0.mlp.up_proj torch.Size([14336, 4096])
0.mlp.down_proj torch.Size([4096, 14336])
model_name:  meta-llama--Llama-2-7b-hf


Loading checkpoint shards: 100%|██████████| 6/6 [00:01<00:00,  5.09it/s]



0.self_attn.q_proj torch.Size([4096, 4096])
0.self_attn.k_proj torch.Size([4096, 4096])
0.self_attn.v_proj torch.Size([4096, 4096])
0.self_attn.o_proj torch.Size([4096, 4096])
0.mlp.gate_proj torch.Size([11008, 4096])
0.mlp.up_proj torch.Size([11008, 4096])
0.mlp.down_proj torch.Size([4096, 11008])
model_name:  meta-llama--Llama-2-13b-hf


Loading checkpoint shards: 100%|██████████| 11/11 [00:02<00:00,  5.34it/s]


0.self_attn.q_proj torch.Size([5120, 5120])
0.self_attn.k_proj torch.Size([5120, 5120])
0.self_attn.v_proj torch.Size([5120, 5120])
0.self_attn.o_proj torch.Size([5120, 5120])
0.mlp.gate_proj torch.Size([13824, 5120])
0.mlp.up_proj torch.Size([13824, 5120])
0.mlp.down_proj torch.Size([5120, 13824])





# Weight block pca per layer

In [7]:
model_list = [
    # 'meta-llama/Meta-Llama-3-8B',
    # 'meta-llama--Llama-2-7b-hf',
    'meta-llama--Llama-2-13b-hf',
    # 'openai/clip-vit-large-patch14'
]
target_layers = [0,1,10,31]
block_sizes = [512]
max_blocks_per_weight = 500
save_dir = "./plot/pca_per_trblock"
os.makedirs(save_dir, exist_ok=True)
random.seed(0); np.random.seed(0)
rng = np.random.default_rng(0)

norm_funcs = {
    "orig":        lambda W: W,
    "lnormed":     lambda W: W / W.std(),                                  # global l2 norm
    "rnormed":     lambda W: W / W.std(dim=1, keepdims=True),             # row-wise
    "cnormed":     lambda W: W / W.std(dim=0, keepdims=True),             # col-wise
    "r_c_normed":  lambda W: (
                        lambda X: X / X.std(dim=0, keepdims=True)
                      )(W / W.std(dim=1, keepdims=True)),                 # row then col
    "c_r_normed":  lambda W: (
                        lambda X: X / X.std(dim=1, keepdims=True)
                      )(W / W.std(dim=0, keepdims=True)),                 # col then row
}

for model_name in model_list:
    mn = model_name.replace('/', '--')
    model = AutoModelForCausalLM.from_pretrained(f"./hf_model/{mn}", local_files_only=True)
    layers = get_blocks(model)

    for block_size in block_sizes:
        for lidx in target_layers:
            # data[(axis_name, norm_name)] = (X2, labels)
            data = {}

            named_linears = get_named_linears(layers[lidx])
            if not named_linears:
                print(f"[skip] layer {lidx}")
                continue

            # 블록 생성 & PCA 수행
            for axis_name, axis in [("row",0), ("col",1)]:
                for norm_name, norm_fn in norm_funcs.items():
                    feats, labels = [], []
                    for wname, mod in named_linears.items():
                        W = mod.weight.data.detach().cpu().numpy()
                        W = norm_fn(torch.from_numpy(W)).numpy()  # numpy ← torch 변환
                        blocks = make_blocks(W, axis=axis, bsz=block_size)
                        if max_blocks_per_weight and blocks.shape[0] > max_blocks_per_weight:
                            idx = rng.choice(blocks.shape[0], max_blocks_per_weight, replace=False)
                            blocks = blocks[idx]
                        feats.append(blocks)
                        labels.extend([wname.split('.')[-1]] * blocks.shape[0])
                    X = np.vstack(feats)
                    X2 = PCA(n_components=2, svd_solver='randomized', random_state=0).fit_transform(X)
                    data[(axis_name, norm_name)] = (X2, labels)

            # 2행×4열 서브플롯
            fig, axs = plt.subplots(2, 6, figsize=(18, 8))
            cmap = plt.get_cmap('tab10')
            axes = ["row", "col"]
            norms = list(norm_funcs.keys())

            for i, axis_name in enumerate(axes):
                for j, norm_name in enumerate(norms):
                    ax = axs[i, j]
                    X2, labels = data[(axis_name, norm_name)]
                    uniq = sorted(set(labels))
                    color_map = {u: cmap(k % 10) for k, u in enumerate(uniq)}
                    for u in uniq:
                        idxs = [k for k, l in enumerate(labels) if l == u]
                        ax.scatter(X2[idxs, 0], X2[idxs, 1], s=6, alpha=0.6, label=u, color=color_map[u])
                    ax.set_title(f"{axis_name} / {norm_name}")
                    ax.set_xlabel("PC1"); ax.set_ylabel("PC2"); ax.grid(True, alpha=0.2)
                    if i == 0 and j == 0:
                        ax.legend(markerscale=3, fontsize=6, ncol=2)

            fig.suptitle(f"{mn} / layer {lidx}  (block_size={block_size})", fontsize=12)
            fig.tight_layout(rect=[0,0,1,0.95])
            out_dir = os.path.join(save_dir, mn)
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(os.path.join(out_dir, f"layer{lidx}_bs{block_size}_pca_all_norms.png"), dpi=200)
            plt.close(fig)

    print("Done.")


Loading checkpoint shards: 100%|██████████| 11/11 [00:02<00:00,  5.15it/s]


Done.


# Hessian scaled weight pca

In [8]:
model_list = [
    # 'meta-llama/Meta-Llama-3-8B',
    # 'meta-llama--Llama-2-7b-hf',
    'meta-llama--Llama-2-13b-hf',
    # 'openai/clip-vit-large-patch14'
]
quip_hess_path = [
    # './quip_hess/llama3_8b_6144',
    # './quip_hess/Hessians-Llama-2-7b-6144',
    './quip_hess/Hessians-Llama-2-13b-6144',
]

target_layers = [0,1,10,31]
block_sizes = [512, 16, 128]
max_blocks_per_weight = 500
save_dir = "./plot/pca_per_trblock"
os.makedirs(save_dir, exist_ok=True)
random.seed(0); np.random.seed(0)
rng = np.random.default_rng(0)
sigma_reg = 1e-4


norm_funcs = {
    "orig":        lambda W: W,
    "lnormed":     lambda W: W / W.std(),                                  # global l2 norm
    "rnormed":     lambda W: W / W.std(dim=1, keepdims=True),             # row-wise
    "cnormed":     lambda W: W / W.std(dim=0, keepdims=True),             # col-wise
    "r_c_normed":  lambda W: (
                        lambda X: X / X.std(dim=0, keepdims=True)
                      )(W / W.std(dim=1, keepdims=True)),                 # row then col
    "c_r_normed":  lambda W: (
                        lambda X: X / X.std(dim=1, keepdims=True)
                      )(W / W.std(dim=0, keepdims=True)),                 # col then row
}

for model_name, quip_hess in zip(model_list, quip_hess_path):
    mn = model_name.replace('/', '--')
    model = AutoModelForCausalLM.from_pretrained(f"./hf_model/{mn}", local_files_only=True)
    layers = get_blocks(model)
    for block_size in block_sizes:
        for lidx in target_layers:
            # data[(axis_name, norm_name)] = (X2, labels)
            data = {}

            named_linears = get_named_linears(layers[lidx])
            if not named_linears:
                print(f"[skip] layer {lidx}")
                continue

            hess_dict = {}
            hess_dict['qkv'] = torch.load(f'{quip_hess}/{lidx}_qkv.pt', weights_only=False)
            hess_dict['o'] = torch.load(f'{quip_hess}/{lidx}_o.pt', weights_only=False)
            hess_dict['up'] = torch.load(f'{quip_hess}/{lidx}_up.pt', weights_only=False)
            hess_dict['down'] = torch.load(f'{quip_hess}/{lidx}_down.pt', weights_only=False)

            # 블록 생성 & PCA 수행
            for axis_name, axis in [("row",0), ("col",1)]:
                for norm_name, norm_fn in norm_funcs.items():
                    feats, labels = [], []
                    for wname, mod in named_linears.items():
                        # W = mod.weight.data.detach().cpu().numpy()
                        W = mod.weight.data.detach().to(device)                    
                        
                        if 'q_proj' in wname or 'k_proj' in wname or 'v_proj' in wname:
                            H_flat = hess_dict['qkv']
                        elif 'o_proj' in wname:
                            H_flat = hess_dict['o']
                        elif 'up_proj' in wname or 'gate_proj' in wname:
                            H_flat = hess_dict['up']
                        elif 'down_proj' in wname:
                            H_flat = hess_dict['down']
                        else:
                            raise NotImplementedError(wname)
                
                        H = flat_to_sym(H_flat['flatH'], H_flat['n']).to(device)
                        mu = H_flat['mu'].to(device)
                        H.add_(mu[None, :] * mu[:, None])
                        n_h = H_flat['n']                   
                        H = regularize_H(H, n_h, sigma_reg)
                        
                        # scaleh
                        diagH = torch.diag(H)
                        diagH = torch.clamp(diagH, min=1e-8)
                        scaleWH = diagH.sqrt()
                        W = W * scaleWH[None, :]
                        
                        ## cholesky
                        # L = torch.linalg.cholesky(H)
                        # W = W @ L
                        
                        W = norm_fn(W).cpu().numpy()  # numpy ← torch 변환
                        blocks = make_blocks(W, axis=axis, bsz=block_size)
                        if max_blocks_per_weight and blocks.shape[0] > max_blocks_per_weight:
                            idx = rng.choice(blocks.shape[0], max_blocks_per_weight, replace=False)
                            blocks = blocks[idx]
                        feats.append(blocks)
                        labels.extend([wname.split('.')[-1]] * blocks.shape[0])
                    X = np.vstack(feats)
                    X2 = PCA(n_components=2, svd_solver='randomized', random_state=0).fit_transform(X)
                    data[(axis_name, norm_name)] = (X2, labels)
                    del W, H
                    
            # 2행×4열 서브플롯
            fig, axs = plt.subplots(2, 6, figsize=(18, 8))
            cmap = plt.get_cmap('tab10')
            axes = ["row", "col"]
            norms = list(norm_funcs.keys())

            for i, axis_name in enumerate(axes):
                for j, norm_name in enumerate(norms):
                    ax = axs[i, j]
                    X2, labels = data[(axis_name, norm_name)]
                    uniq = sorted(set(labels))
                    color_map = {u: cmap(k % 10) for k, u in enumerate(uniq)}
                    for u in uniq:
                        idxs = [k for k, l in enumerate(labels) if l == u]
                        ax.scatter(X2[idxs, 0], X2[idxs, 1], s=6, alpha=0.6, label=u, color=color_map[u])
                    ax.set_title(f"{axis_name} / {norm_name}")
                    ax.set_xlabel("PC1"); ax.set_ylabel("PC2"); ax.grid(True, alpha=0.2)
                    if i == 0 and j == 0:
                        ax.legend(markerscale=3, fontsize=6, ncol=2)

            # fig.suptitle(f"{mn} / cholesky layer {lidx}  (block_size={block_size})", fontsize=12)
            fig.suptitle(f"{mn} / scaleh layer {lidx}  (block_size={block_size})", fontsize=12)
            fig.tight_layout(rect=[0,0,1,0.95])
            out_dir = os.path.join(save_dir, mn)
            os.makedirs(out_dir, exist_ok=True)
            # fig.savefig(os.path.join(out_dir, f"cholesky_layer{lidx}_bs{block_size}_pca_all_norms.png"), dpi=200)
            fig.savefig(os.path.join(out_dir, f"scaleh_layer{lidx}_bs{block_size}_pca_all_norms.png"), dpi=200)
            plt.close(fig)

    print("Done.")


Loading checkpoint shards:  18%|█▊        | 2/11 [00:00<00:02,  4.36it/s]

Loading checkpoint shards: 100%|██████████| 11/11 [00:01<00:00,  5.52it/s]


Done.


# Cholesky

In [9]:
model_list = [
    # 'meta-llama/Meta-Llama-3-8B',
    # 'meta-llama--Llama-2-7b-hf',
    'meta-llama--Llama-2-13b-hf',
    # 'openai/clip-vit-large-patch14'
]
quip_hess_path = [
    # './quip_hess/llama3_8b_6144',
    # './quip_hess/Hessians-Llama-2-7b-6144',
    './quip_hess/Hessians-Llama-2-13b-6144',
]

target_layers = [0,1,10,31]
block_sizes = [512, 16, 128]
max_blocks_per_weight = 500
save_dir = "./plot/pca_per_trblock"
os.makedirs(save_dir, exist_ok=True)
random.seed(0); np.random.seed(0)
rng = np.random.default_rng(0)
sigma_reg = 1e-4


norm_funcs = {
    "orig":        lambda W: W,
    "lnormed":     lambda W: W / W.std(),                                  # global l2 norm
    "rnormed":     lambda W: W / W.std(dim=1, keepdims=True),             # row-wise
    "cnormed":     lambda W: W / W.std(dim=0, keepdims=True),             # col-wise
    "r_c_normed":  lambda W: (
                        lambda X: X / X.std(dim=0, keepdims=True)
                      )(W / W.std(dim=1, keepdims=True)),                 # row then col
    "c_r_normed":  lambda W: (
                        lambda X: X / X.std(dim=1, keepdims=True)
                      )(W / W.std(dim=0, keepdims=True)),                 # col then row
}

for model_name, quip_hess in zip(model_list, quip_hess_path):
    mn = model_name.replace('/', '--')
    model = AutoModelForCausalLM.from_pretrained(f"./hf_model/{mn}", local_files_only=True)
    layers = get_blocks(model)
    for block_size in block_sizes:
        for lidx in target_layers:
            # data[(axis_name, norm_name)] = (X2, labels)
            data = {}

            named_linears = get_named_linears(layers[lidx])
            if not named_linears:
                print(f"[skip] layer {lidx}")
                continue

            hess_dict = {}
            hess_dict['qkv'] = torch.load(f'{quip_hess}/{lidx}_qkv.pt', weights_only=False)
            hess_dict['o'] = torch.load(f'{quip_hess}/{lidx}_o.pt', weights_only=False)
            hess_dict['up'] = torch.load(f'{quip_hess}/{lidx}_up.pt', weights_only=False)
            hess_dict['down'] = torch.load(f'{quip_hess}/{lidx}_down.pt', weights_only=False)

            # 블록 생성 & PCA 수행
            for axis_name, axis in [("row",0), ("col",1)]:
                for norm_name, norm_fn in norm_funcs.items():
                    feats, labels = [], []
                    for wname, mod in named_linears.items():
                        # W = mod.weight.data.detach().cpu().numpy()
                        W = mod.weight.data.detach().to(device)                    
                        
                        if 'q_proj' in wname or 'k_proj' in wname or 'v_proj' in wname:
                            H_flat = hess_dict['qkv']
                        elif 'o_proj' in wname:
                            H_flat = hess_dict['o']
                        elif 'up_proj' in wname or 'gate_proj' in wname:
                            H_flat = hess_dict['up']
                        elif 'down_proj' in wname:
                            H_flat = hess_dict['down']
                        else:
                            raise NotImplementedError(wname)
                
                        H = flat_to_sym(H_flat['flatH'], H_flat['n']).to(device)
                        mu = H_flat['mu'].to(device)
                        H.add_(mu[None, :] * mu[:, None])
                        n_h = H_flat['n']                   
                        H = regularize_H(H, n_h, sigma_reg)
                        
                        # # scaleh
                        # diagH = torch.diag(H)
                        # diagH = torch.clamp(diagH, min=1e-8)
                        # scaleWH = diagH.sqrt()
                        # W = W * scaleWH[None, :]
                        
                        # cholesky
                        L = torch.linalg.cholesky(H)
                        W = W @ L
                        
                        W = norm_fn(W).cpu().numpy()  # numpy ← torch 변환
                        blocks = make_blocks(W, axis=axis, bsz=block_size)
                        if max_blocks_per_weight and blocks.shape[0] > max_blocks_per_weight:
                            idx = rng.choice(blocks.shape[0], max_blocks_per_weight, replace=False)
                            blocks = blocks[idx]
                        feats.append(blocks)
                        labels.extend([wname.split('.')[-1]] * blocks.shape[0])
                    X = np.vstack(feats)
                    X2 = PCA(n_components=2, svd_solver='randomized', random_state=0).fit_transform(X)
                    data[(axis_name, norm_name)] = (X2, labels)
                    del W, H
                    
            # 2행×4열 서브플롯
            fig, axs = plt.subplots(2, 6, figsize=(18, 8))
            cmap = plt.get_cmap('tab10')
            axes = ["row", "col"]
            norms = list(norm_funcs.keys())

            for i, axis_name in enumerate(axes):
                for j, norm_name in enumerate(norms):
                    ax = axs[i, j]
                    X2, labels = data[(axis_name, norm_name)]
                    uniq = sorted(set(labels))
                    color_map = {u: cmap(k % 10) for k, u in enumerate(uniq)}
                    for u in uniq:
                        idxs = [k for k, l in enumerate(labels) if l == u]
                        ax.scatter(X2[idxs, 0], X2[idxs, 1], s=6, alpha=0.6, label=u, color=color_map[u])
                    ax.set_title(f"{axis_name} / {norm_name}")
                    ax.set_xlabel("PC1"); ax.set_ylabel("PC2"); ax.grid(True, alpha=0.2)
                    if i == 0 and j == 0:
                        ax.legend(markerscale=3, fontsize=6, ncol=2)

            fig.suptitle(f"{mn} / cholesky layer {lidx}  (block_size={block_size})", fontsize=12)
            # fig.suptitle(f"{mn} / scaleh layer {lidx}  (block_size={block_size})", fontsize=12)
            fig.tight_layout(rect=[0,0,1,0.95])
            out_dir = os.path.join(save_dir, mn)
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(os.path.join(out_dir, f"cholesky_layer{lidx}_bs{block_size}_pca_all_norms.png"), dpi=200)
            # fig.savefig(os.path.join(out_dir, f"scaleh_layer{lidx}_bs{block_size}_pca_all_norms.png"), dpi=200)
            plt.close(fig)

    print("Done.")


Loading checkpoint shards: 100%|██████████| 11/11 [00:02<00:00,  4.79it/s]


_LinAlgError: linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 1136 is not positive-definite).