In [1]:
import torch
import torch.nn as nn
from transformers import CLIPVisionModelWithProjection, ViTForImageClassification, AutoModelForCausalLM
from transformers import AutoModel, AutoTokenizer, LlamaForCausalLM

import sys, os, json, math
from tqdm import tqdm
import numpy as np

from sklearn.metrics import mutual_info_score
from scipy.stats import pearsonr, spearmanr

import re

std = 0.012528747320175171

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class LayerInputs:
    def __init__(self, num_layers):
        self.layers = [
            {
                "self_attn": {
                    "q_proj": None,
                    "k_proj": None,
                    "v_proj": None,
                    "o_proj": None,
                },
                "mlp": {
                    "gate_proj": None,
                    "up_proj": None,
                    "down_proj": None,
                },
            }
            for _ in range(num_layers)
        ]

layer_inputs = torch.load('../Wparam_dataset/calib_data/layer_inputs_channelwise_mag.pt')

  layer_inputs = torch.load('../Wparam_dataset/calib_data/layer_inputs_channelwise_mag.pt')


In [3]:
def cal_corr(state_dict, recon_state_dict):
    
    pearson = []
    spearman = []
    
    for k, v in state_dict.items():
        if 'mlp' not in k and 'self_attn' not in k: continue
        match = re.search(r"layers\.(\d+).", k)
        if match:
            layer_index = int(match.group(1))  # 찾은 숫자를 정수형으로 변환
        
        # if layer_index not in [0, 10, 20, 31]: continue
        if 'self_attn' in k:
            ltype_str = 'self_attn'
        elif 'mlp' in k:
            ltype_str = 'mlp' 
        if 'q_proj' in k:
            wtype = 'q_proj'
        elif 'k_proj' in k:
            wtype = 'k_proj'
        elif 'v_proj' in k:
            wtype = 'v_proj'
        elif 'o_proj' in k:
            wtype = 'o_proj'
        elif 'gate_proj' in k:
            wtype = 'gate_proj'
        elif 'up_proj' in k:
            wtype = 'up_proj'
        elif 'down_proj' in k:
            wtype = 'down_proj'
        
        input_scale =  layer_inputs.layers[layer_index][ltype_str][wtype]
        
        mse = ((recon_state_dict[k] - v)**2).mean(0)

        pearson_corr, _ = pearsonr(input_scale, mse)
        spearman_corr, _ = spearmanr(input_scale, mse)
        
        # print(k, pearson_corr)
        pearson.append(pearson_corr)
        spearman.append(spearman_corr)
        # mutual_info = mutual_info_score(None, np.histogram2d(input_scale.numpy(), mse.numpy(), bins=1000)[0].flatten())
        # print(f"{k}, {mutual_info:.3f}")

        # plt.figure(figsize=(4, 3))
        # plt.scatter(input_scale, mse, alpha=0.7, edgecolors='black', linewidth=0.5)
        # plt.title(k)
        # plt.xlabel('input_scale')
        # plt.ylabel('mse')
        # plt.xscale('log')
        # plt.yscale('log')
        # plt.grid(True)
        # plt.show()
    print(np.array(pearson).mean())
    print(np.array(spearman).mean())

### AWQ

In [None]:
def latest_version_path(cache_dir, model_name, branch = 'main'):
    model_name_dir =  "models--" + model_name.replace('/', '--')
    path = os.path.join(cache_dir, model_name_dir)
    if not os.path.isdir(os.path.join(path, 'snapshots')):
        return None
    branch_file =  os.path.join(path, 'refs', branch)
    with open(branch_file, 'r', encoding='utf-8') as file:
        revision = file.read()
    return os.path.join(path, 'snapshots', revision)

cache_directory = "../Wparam_dataset_v0/model_zoo/huggingface" 
ckpt_path = latest_version_path(cache_directory, 'meta-llama/Meta-Llama-3-8B')
net = LlamaForCausalLM.from_pretrained(ckpt_path, local_files_only=True)

# ckpt_path = '/home/jgryu/Weight_compression/model_cache/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920'
# net = AutoModelForCausalLM.from_pretrained(ckpt_path, local_files_only=True)
state_dict = net.state_dict()


In [None]:
bpps = [2, 3, 4, 5, 6, 7, 8, 9, 10, 12]
# bpps = [4]
mses = []
mse_fn = nn.MSELoss()
for bpp in bpps:
    ckpt_path = f'../model_cache_reconstructed/awq/llama3-8b-my-w{bpp}-g128-fake-quantized'
    recon_net = AutoModelForCausalLM.from_pretrained(ckpt_path, local_files_only=True)
    recon_state_dict = recon_net.state_dict()
    
    print(bpp)
    cal_corr(state_dict, recon_state_dict)

    

2
-0.029420970685099492
-0.09500712973978678
3
0.14973438032635197
-0.2311222337514005
4
0.21332432331943263
-0.3178309782803166
5
0.23209445225211087
-0.33430709057249974
6
0.24620023228092477
-0.36772141005864556
7
0.23775588807491824
-0.3499329517380149
8
0.255313750082977
-0.33740311855527744

### VQ

In [4]:
def latest_version_path(cache_dir, model_name, branch = 'main'):
    model_name_dir =  "models--" + model_name.replace('/', '--')
    path = os.path.join(cache_dir, model_name_dir)
    if not os.path.isdir(os.path.join(path, 'snapshots')):
        return None
    branch_file =  os.path.join(path, 'refs', branch)
    with open(branch_file, 'r', encoding='utf-8') as file:
        revision = file.read()
    return os.path.join(path, 'snapshots', revision)

cache_directory = "../Wparam_dataset_v0/model_zoo/huggingface" 
ckpt_path = latest_version_path(cache_directory, 'meta-llama/Meta-Llama-3-8B')
net = LlamaForCausalLM.from_pretrained(ckpt_path, local_files_only=True)

# ckpt_path = '/home/jgryu/Weight_compression/model_cache/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920'
# net = AutoModelForCausalLM.from_pretrained(ckpt_path, local_files_only=True)
state_dict = net.state_dict()


Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.80s/it]


In [16]:
root_dir = '/home/jgryu/Weight_compression/model_cache_reconstructed/vqvae_idx/row_v2/per_row_16_calib'
root_dir = '/home/jgryu/Weight_compression/model_cache_reconstructed/vqvae_idx/col/per_col_16_calib'
root_dir = '/home/jgryu/Weight_compression/model_cache_reconstructed/vqvae_idx/col_random_idx/per_col_16_calib'
root_dir = '/home/jgryu/Weight_compression/model_reconstructed/vqvae_qlike/row_16_calib'
import glob
# ckpt_paths = glob.glob(os.path.join(root_dir, "**/*th.tar"), recursive=True)
ckpt_paths = glob.glob(os.path.join(root_dir, "**/bpp*"), recursive=True)
ckpt_path_list = []

for ck in ckpt_paths:
    if 'bpp3.' in ck: continue
    if 'bpp5.' in ck: continue
    if 'result.' in ck: continue
    ckpt_path_list.append(ck)

print(ckpt_path_list)
mse_fn = nn.MSELoss()
for ckpt_path in ckpt_path_list:
    recon_net = AutoModelForCausalLM.from_pretrained(ckpt_path, local_files_only=True)
    recon_state_dict = recon_net.state_dict()

    print(ckpt_path)
    cal_corr(state_dict, recon_state_dict)

['/home/jgryu/Weight_compression/model_reconstructed/vqvae_qlike/row_16_calib/bpp6.0_size16_nmse_ne64_de1_K6_P16_encdim512_batch_size4096_total_iter1500000_lr0.0001_seed100_MSE_0.00101', '/home/jgryu/Weight_compression/model_reconstructed/vqvae_qlike/row_16_calib/bpp8.0_size16_nmse_ne256_de1_K8_P16_encdim512_batch_size4096_total_iter1500000_lr0.0001_seed100_MSE_9e-05', '/home/jgryu/Weight_compression/model_reconstructed/vqvae_qlike/row_16_calib/bpp4.0_size16_nmse_ne16_de1_K4_P16_encdim512_batch_size4096_total_iter1500000_lr0.0001_seed100_MSE_0.01228']


Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.80s/it]


/home/jgryu/Weight_compression/model_reconstructed/vqvae_qlike/row_16_calib/bpp6.0_size16_nmse_ne64_de1_K6_P16_encdim512_batch_size4096_total_iter1500000_lr0.0001_seed100_MSE_0.00101
0.07648009591966119
0.004637024256094685


Loading checkpoint shards: 100%|██████████| 4/4 [00:09<00:00,  2.42s/it]


/home/jgryu/Weight_compression/model_reconstructed/vqvae_qlike/row_16_calib/bpp8.0_size16_nmse_ne256_de1_K8_P16_encdim512_batch_size4096_total_iter1500000_lr0.0001_seed100_MSE_9e-05
0.06675589994613879
0.013480944218589292


Loading checkpoint shards: 100%|██████████| 4/4 [00:08<00:00,  2.24s/it]


/home/jgryu/Weight_compression/model_reconstructed/vqvae_qlike/row_16_calib/bpp4.0_size16_nmse_ne16_de1_K4_P16_encdim512_batch_size4096_total_iter1500000_lr0.0001_seed100_MSE_0.01228
0.06871469711724709
-0.01805949488144771


## row idx
3 smse
0.012930385150270552
-0.002914573436993903

3 nmse
0.0996818714499644
0.014991141838406658

4 smse
0.0489207773622126
-0.0018312189708702376

4 nmse
0.05967029839169057
-0.008844155775352877

6 smse
0.017747773141605385
-0.0035077291401593483

6 nmse

8 smse
0.02382473622823794
-0.009273916552225536

8 nmse
0.06857445852456583
-0.00032242275318888756

# col idx
3 smse
0.08260417737577232
-0.023754363656675832

4 smse
0.0927952513919582
-0.020731512135732315

6 smse
0.11646683411423621
-0.008781474896219648

8 smse
0.10627975621443728
-0.014214981441471876

# col random idx
3 smse
0.08264684718664947
-0.0238141508414145

4 smse
0.09280146198547408
-0.02074162417035556

6 smse
0.11753815427956073
-0.011298514783288735

8 smse
0.10688536223638802
-0.013171601392293137

# VQVAE qlike
3
0.07324685930719574
-0.002356598737998506

4
0.06871469711724709
-0.01805949488144771

5
0.08376904820523023
0.0017911386220212917

6
0.07648009591966119
0.004637024256094685

8
0.06675589994613879
0.013480944218589292

# SeedLM

In [None]:
def latest_version_path(cache_dir, model_name, branch = 'main'):
    model_name_dir =  "models--" + model_name.replace('/', '--')
    path = os.path.join(cache_dir, model_name_dir)
    if not os.path.isdir(os.path.join(path, 'snapshots')):
        return None
    branch_file =  os.path.join(path, 'refs', branch)
    with open(branch_file, 'r', encoding='utf-8') as file:
        revision = file.read()
    return os.path.join(path, 'snapshots', revision)

cache_directory = "../Wparam_dataset_v0/model_zoo/huggingface" 
ckpt_path = latest_version_path(cache_directory, 'meta-llama/Meta-Llama-3-8B')
net = LlamaForCausalLM.from_pretrained(ckpt_path, local_files_only=True)

# ckpt_path = '/home/jgryu/Weight_compression/model_cache/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920'
# net = AutoModelForCausalLM.from_pretrained(ckpt_path, local_files_only=True)
state_dict = net.state_dict()

ckpt_path = f'../model_cache_reconstructed/seedlm/bpp4.0_C8_P3_K16'
recon_net = AutoModelForCausalLM.from_pretrained(ckpt_path, local_files_only=True)
recon_state_dict = recon_net.state_dict()

n = 0
mse = 0
for k, v in state_dict.items():
    if 'mlp' not in k and 'attn' not in k: continue
    
    mse += ((recon_state_dict[k] - v)**2).sum()
    n += v.numel()
    
mse = mse / n / std **2 
print(mse)

In [None]:
print(mse)

# RTN

In [None]:
def latest_version_path(cache_dir, model_name, branch = 'main'):
    model_name_dir =  "models--" + model_name.replace('/', '--')
    path = os.path.join(cache_dir, model_name_dir)
    if not os.path.isdir(os.path.join(path, 'snapshots')):
        return None
    branch_file =  os.path.join(path, 'refs', branch)
    with open(branch_file, 'r', encoding='utf-8') as file:
        revision = file.read()
    return os.path.join(path, 'snapshots', revision)

cache_directory = "../Wparam_dataset_v0/model_zoo/huggingface" 
ckpt_path = latest_version_path(cache_directory, 'meta-llama/Meta-Llama-3-8B')
net = LlamaForCausalLM.from_pretrained(ckpt_path, local_files_only=True)

# ckpt_path = '/home/jgryu/Weight_compression/model_cache/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920'
# net = AutoModelForCausalLM.from_pretrained(ckpt_path, local_files_only=True)
state_dict = net.state_dict()

In [12]:
root_dir = '/home/jgryu/Weight_compression/model_reconstructed/rtn'
import glob
ckpt_paths = glob.glob(os.path.join(root_dir, "**/*"), recursive=True)
ckpt_path_list = []

for ck in ckpt_paths:
    if 'result' in ck: continue
    ckpt_path_list.append(ck)

mses = []
mse_fn = nn.MSELoss()
for ckpt_path in ckpt_path_list:
    try:
        recon_net = AutoModelForCausalLM.from_pretrained(ckpt_path, local_files_only=True)
        recon_state_dict = recon_net.state_dict()

        print(ckpt_path)
        cal_corr(state_dict, recon_state_dict)
    except:
        pass

-0.0021007181346949905
-0.009032166794796246


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

: 

: 

In [None]:
bpps = [3, 4]
mses = []
mse_fn = nn.MSELoss()
for bpp in bpps:
    ckpt_path = f'../model_reconstructed/rtn/bpp{bpp}'
    recon_net = AutoModelForCausalLM.from_pretrained(ckpt_path, local_files_only=True)
    recon_state_dict = recon_net.state_dict()

    n = 0
    mse = 0
    mse_layer = []
    for k, v in state_dict.items():
        # if 'mlp' not in k and 'attn' not in k: continue
        
        mse += ((recon_state_dict[k] - v)**2).sum()
        n += v.numel()
        print(k, mse_fn(recon_state_dict[k], v)/std**2)
        mse_layer.append(mse_fn(recon_state_dict[k], v)/std**2)
    mse = mse / n / std **2 
    mses.append(mse.item())
    print(mse)

In [None]:
ckpt_path_list