In [41]:
import argparse
import json
import os
import sys
import torch
import torch.nn as nn
import re
import math
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from torch.utils.data import DataLoader
import logging

# Numeric constant kept for later cells; nothing in the visible cells reads it.
# NOTE(review): presumably a standard deviation of the weights — confirm before removing.
std = 0.012528747320175171

# A notebook has no real __file__, so the *literal string* "__file__" is
# resolved against the current working directory; its dirname is therefore the
# directory the kernel is running in.
notebook_dir = os.path.dirname(os.path.abspath("__file__"))
project_root = os.path.abspath(os.path.join(notebook_dir, os.pardir))

# Put the parent directory on sys.path for the project-local imports that
# follow; the membership check keeps re-running the cell from adding duplicates.
if not any(entry == project_root for entry in sys.path):
    sys.path.append(project_root)

import models
from models import get_model

In [42]:
class Config:
    """Attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **entries):
        # Write straight into the instance namespace so arbitrary keys are
        # accepted without declaring them on the class.
        vars(self).update(entries)

class LayerInputs:
    """Per-layer table with one slot per attention / MLP projection.

    ``layers`` is a list of ``num_layers`` dicts. Each dict has a
    ``"self_attn"`` group (``q_proj``, ``k_proj``, ``v_proj``, ``o_proj``) and
    an ``"mlp"`` group (``gate_proj``, ``up_proj``, ``down_proj``); every slot
    starts out as ``None``.
    """

    def __init__(self, num_layers):
        attn_slots = ("q_proj", "k_proj", "v_proj", "o_proj")
        mlp_slots = ("gate_proj", "up_proj", "down_proj")

        self.layers = []
        for _ in range(num_layers):
            # Fresh dicts for every layer so filling one layer's slot never
            # aliases another layer's.
            self.layers.append(
                {
                    "self_attn": dict.fromkeys(attn_slots),
                    "mlp": dict.fromkeys(mlp_slots),
                }
            )

def latest_version_path(cache_dir, model_name, branch="main"):
    """Return the snapshot directory that ``branch`` points to in a model cache.

    The cache is expected to follow the ``models--<org>--<name>`` layout with
    ``refs/<branch>`` holding a revision id and ``snapshots/<revision>``
    holding the files.

    Args:
        cache_dir: Root directory of the cache.
        model_name: Model id such as ``"org/name"``; ``/`` is mapped to ``--``.
        branch: Name of the ref file to read under ``refs/``.

    Returns:
        ``<cache_dir>/models--.../snapshots/<revision>``, or ``None`` when the
        model directory has no ``snapshots`` folder.

    Raises:
        FileNotFoundError: If ``snapshots`` exists but ``refs/<branch>`` does not.
    """
    model_name_dir = "models--" + model_name.replace("/", "--")
    path = os.path.join(cache_dir, model_name_dir)
    if not os.path.isdir(os.path.join(path, "snapshots")):
        return None
    branch_file = os.path.join(path, "refs", branch)
    with open(branch_file, "r", encoding="utf-8") as file:
        # Strip surrounding whitespace: a trailing newline in the ref file
        # would otherwise become part of the returned path.
        revision = file.read().strip()
    return os.path.join(path, "snapshots", revision)

def get_named_linears(module):
    """Collect every ``nn.Linear`` found by ``module.named_modules()``.

    Returns a dict mapping each module's dotted name (as reported by
    ``named_modules``, so ``module`` itself appears under ``""``) to the
    ``nn.Linear`` instance.
    """
    linears = {}
    for qualified_name, submodule in module.named_modules():
        if isinstance(submodule, nn.Linear):
            linears[qualified_name] = submodule
    return linears

In [43]:
def reconstruct_model(state_dict, model, input_mag, direction):
    """Cut projection weights into blocks and pair each block with an input value.

    Only 2-D tensors whose key contains ``self_attn`` or ``mlp``, a
    ``layers.<idx>.`` component, and one of the seven projection names are
    used; every other entry of ``state_dict`` is skipped. (Previously an
    unrecognised key silently reused the layer index / projection of the
    preceding key, or crashed on a non-2-D tensor.)

    Args:
        state_dict: Mapping from parameter name to weight tensor.
        model: Only ``model.input_size`` is read; it is the last dimension of
            each weight block.
        input_mag: Object whose ``layers[idx][group][proj]`` holds the tensor
            paired with that projection. Its ``size(0)`` must equal the number
            of columns of the weight (asserted).
        direction: ``'col'`` transposes the weight (and the expanded input
            tensor) before blocking; any other value blocks row-major.

    Returns:
        ``(weight_list, input_list)``: a tensor of shape
        ``(N, 128, model.input_size)`` and a tensor of shape ``(N,)`` whose
        ``i``-th entry is the value paired with block ``i``.

    Raises:
        AssertionError: If a weight's shape is incompatible with the blocking.
        RuntimeError: From ``torch.cat`` if no key in ``state_dict`` qualified.
    """
    # Order matters: it reproduces the precedence of the original if/elif chains.
    layer_groups = ("self_attn", "mlp")
    proj_names = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
    layer_pattern = re.compile(r"layers\.(\d+)\.")

    weight_list = []
    input_list = []

    with torch.no_grad():
        for k, W in tqdm(state_dict.items()):
            ltype = next((group for group in layer_groups if group in k), None)
            if ltype is None:
                continue

            match = layer_pattern.search(k)
            wtype = next((proj for proj in proj_names if proj in k), None)
            if match is None or wtype is None or W.dim() != 2:
                # Not a per-layer projection matrix (e.g. a buffer or bias).
                continue
            layer_index = int(match.group(1))

            rows, cols = W.shape
            input_block = input_mag.layers[layer_index][ltype][wtype]

            assert rows % model.input_size == 0
            assert cols == input_block.size(0)

            # Special case for 1024-row matrices: fold the input tensor in half
            # (element-wise max of the two halves) and treat the weight as if it
            # had twice the rows and half the columns, so the block count below
            # still matches.
            # NOTE(review): with direction == 'col' each 2048-element block of a
            # 1024-row weight covers two *adjacent* columns (2j, 2j+1), while the
            # max here pairs column j with column j + cols/2 — confirm this
            # pairing is intended.
            if rows == 1024:
                chunks = torch.chunk(input_block, chunks=2, dim=-1)
                input_block = torch.max(chunks[0], chunks[1])
                rows = rows * 2
                cols = cols // 2

            # One input value per 2048-element block; 2048 equals
            # 128 * model.input_size only when input_size is 16, and the
            # block-count assert below enforces that agreement.
            input_block = input_block.expand(rows // 2048, cols)

            if direction == 'col':
                W = W.T
                input_block = input_block.T

            # (rows, cols) -> (num_blocks, 128, input_size)
            W_reshaped = W.reshape(-1, 128, model.input_size)
            input_block = input_block.reshape(-1, )
            assert W_reshaped.size(0) == input_block.size(0)

            weight_list.append(W_reshaped)
            input_list.append(input_block)

    weight_list = torch.cat(weight_list, dim=0)
    input_list = torch.cat(input_list, dim=0)
    return weight_list, input_list
            

In [44]:
# Blocking orientation handed to reconstruct_model: 'col' makes it transpose
# each weight matrix before cutting it into blocks.
direction = 'col'
# Index of the CUDA device to use when CUDA is available.
cuda = 2
# Checkpoints to evaluate, one per lmbda value encoded in the directory name
# (50, 100, 200, 300, 1000, 10000). A config.json is read from the directory
# of each checkpoint by the cells below.
# NOTE(review): these are absolute, machine-specific paths; consider deriving
# them from a configurable checkpoint root.
model_paths = [
    '/home/jgryu/Weight_compression/VQVAE/checkpoint/nwc_ql/block_seq_ql_random_col_16/lmbda50_rdloss_ql_encdim512_M16_batch_size2048_total_iter200000_lr0.0001_seed100/best_loss_model_loss_3.84823_bpp_4.61283_MSE_0.01614_total_iter_95000.pth.tar',
    '/home/jgryu/Weight_compression/VQVAE/checkpoint/nwc_ql/block_seq_ql_random_col_16/lmbda100_rdloss_ql_encdim512_M16_batch_size2048_total_iter1500000_lr0.0001_seed100/best_loss_model_loss_4.39201_bpp_5.10767_MSE_0.0081_total_iter_190000.pth.tar',
    '/home/jgryu/Weight_compression/VQVAE/checkpoint/nwc_ql/block_seq_ql_random_col_16/lmbda200_rdloss_ql_encdim512_M16_batch_size2048_total_iter200000_lr0.0001_seed100/best_loss_model_loss_4.97679_bpp_5.524_MSE_0.00426_total_iter_95000.pth.tar',
    '/home/jgryu/Weight_compression/VQVAE/checkpoint/nwc_ql/block_seq_ql_random_col_16/lmbda300_rdloss_ql_encdim512_M16_batch_size2048_total_iter200000_lr0.0001_seed100/best_loss_model_loss_5.34295_bpp_5.7068_MSE_0.00302_total_iter_95000.pth.tar',
    '/home/jgryu/Weight_compression/VQVAE/checkpoint/nwc_ql/block_seq_ql_random_col_16/lmbda1000_rdloss_ql_encdim512_M16_batch_size2048_total_iter1500000_lr0.0001_seed100/best_loss_model_loss_6.59649_bpp_6.05166_MSE_0.00106_total_iter_140000.pth.tar',
    '/home/jgryu/Weight_compression/VQVAE/checkpoint/nwc_ql/block_seq_ql_random_col_16/lmbda10000_rdloss_ql_encdim512_M16_batch_size2048_total_iter200000_lr0.0001_seed100/best_loss_model_loss_10.96029_bpp_6.2788_MSE_0.0004_total_iter_140000.pth.tar'    
]

In [45]:
# Use the configured GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device(f"cuda:{cuda}" if torch.cuda.is_available() else "cpu")

# The hyper-parameters of a checkpoint live in a config.json next to it.
config_path = os.path.join(os.path.dirname(model_paths[0]), 'config.json')
with open(config_path, 'r', encoding='utf-8') as file:
    config = json.load(file)

# scale/shift are zero-filled tensors of shape (128, input_size).
# NOTE(review): presumably load_state_dict below overwrites them — confirm in models.NWC_ql.
comp_model = models.NWC_ql(
    input_size=config['input_size'],
    dim_encoder=config['dim_encoder'],
    n_resblock=config['n_resblock'],
    scale=torch.zeros(128, config['input_size']),
    shift=torch.zeros(128, config['input_size'])
)

# SECURITY: weights_only=False unpickles arbitrary objects — only load trusted
# checkpoints. It is spelled out (as for input_mag below) so the behaviour does
# not depend on the installed torch version's default and the FutureWarning
# goes away. map_location="cpu" avoids materialising the checkpoint on the GPU
# it was saved from; the model is moved to `device` afterwards.
ckpt = torch.load(model_paths[0], map_location="cpu", weights_only=False)
comp_model.load_state_dict(ckpt["state_dict"])
comp_model.to(device)

# Tensors paired with each projection, consumed by reconstruct_model.
input_mag = torch.load('/home/jgryu/Weight_compression/Wparam_dataset/calib_data/layer_inputs_chmag_rank_top[4, 10, 100]_qlevel[3, 2, 1].pt', weights_only=False)

cache_directory = "../../Wparam_dataset_v0/model_zoo/huggingface"
ckpt_path = latest_version_path(cache_directory, "meta-llama/Meta-Llama-3-8B")
net = AutoModelForCausalLM.from_pretrained(ckpt_path, local_files_only=True)
# NOTE(review): the tokenizer comes from a different, hard-coded snapshot path
# than the model above, and is not used by the visible cells.
ckpt_path = "/home/jgryu/Weight_compression/model_cache/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920"
tokenizer = AutoTokenizer.from_pretrained(ckpt_path, local_files_only=True)
state_dict = net.state_dict()


weight_list, input_list = reconstruct_model(state_dict, comp_model, input_mag, direction)

  ckpt = torch.load(model_paths[0])
Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.29s/it]
100%|██████████| 291/291 [00:12<00:00, 22.90it/s]


In [46]:
# Sanity check: both tensors must agree on the number of blocks (dim 0).
print(weight_list.size(), input_list.size())

torch.Size([3407872, 128, 16]) torch.Size([3407872])


In [47]:
# Evaluate on a random subset of the blocks to keep the per-checkpoint runtime
# manageable. The draw is seeded so every run (and every checkpoint) sees the
# same subset.
NUM_EVAL_BLOCKS = 10000
SUBSAMPLE_SEED = 0

# A private generator gives a reproducible draw without touching the global RNG.
generator = torch.Generator().manual_seed(SUBSAMPLE_SEED)
indices = torch.randperm(len(weight_list), generator=generator)[:NUM_EVAL_BLOCKS]

# NOTE: this overwrites the full tensors; re-run the reconstruction cell to get
# them back before changing NUM_EVAL_BLOCKS.
weight_list = weight_list[indices]
input_list = input_list[indices]

print(weight_list.shape, input_list.shape)

torch.Size([10000, 128, 16]) torch.Size([10000])


In [48]:
def test(weight_list, input_list, model):
    mean_MSE = 0
    avg_bpp = 0
    mean_loss = 0
    mean_recon_loss = 0
    mean_bpp_loss = 0
    device = next(model.parameters()).device
    mse_func = nn.MSELoss()
    
    model.requires_grad_(False)
    model.update()
    with torch.no_grad():
        for idx, weight in enumerate(tqdm(weight_list)):
            # data = {key: tensor.unsqueeze(0).to(device) for key, tensor in data.items()}
            data = {'weight_block': weight.unsqueeze(0).to(device),
                    'q_level': input_list[idx].unsqueeze(0).to(device)}
            # out_net = model(data)
            # out_loss = criterion(data= data, output = out_net)
            
            # mean_loss += out_loss['loss'].item()
            # mean_recon_loss += out_loss['recon_loss'].item()
            # mean_bpp_loss += out_loss['bpp_loss'].item()
            
            out_enc = model.compress(data)
            out_dec = model.decompress(out_enc["strings"][0], out_enc["shape"], data["q_level"])
            
            
            # try:
            #     out_dec = model.decompress(out_enc["strings"][0], out_enc["shape"], data["q_level"])
            # except:
            #     out_dec = model.decompress(out_enc["strings"][0], out_enc["shape"])
            
            # out_dec = model.decompress(out_enc["strings"], out_enc["shape"])
                

            num_pixels = data['weight_block'].numel()
            
            bpp = 0
            for s in out_enc["strings"]:
                bpp += len(s[0]) * 8.0 / num_pixels

            x_hat = out_dec["x_hat"].clone().detach()
            mean_MSE += mse_func(data['weight_block'], x_hat).item()
            avg_bpp += bpp

    avg_bpp /= len(weight_list)
    mean_MSE /= len(weight_list)
    mean_loss /= len(weight_list)
    mean_recon_loss /= len(weight_list)
    mean_bpp_loss /= len(weight_list)
    return {'TEST MSE': mean_MSE, 'TEST BPP': avg_bpp, 'TEST loss': mean_loss, 'TEST recon_loss': mean_recon_loss, 'TEST bpp_loss': mean_bpp_loss}

In [49]:
# Evaluate every checkpoint on the same weight blocks. Results are printed as
# before and additionally kept, keyed by checkpoint path, for later cells.
results_by_checkpoint = {}
for model_path in model_paths:
    # Each checkpoint has its own config.json in the same directory.
    config_path = os.path.join(os.path.dirname(model_path), 'config.json')
    with open(config_path, 'r', encoding='utf-8') as file:
        config = json.load(file)
    comp_model = models.NWC_ql(
        input_size=config['input_size'],
        dim_encoder=config['dim_encoder'],
        n_resblock=config['n_resblock'],
        scale=torch.zeros(128, config['input_size']),
        shift=torch.zeros(128, config['input_size'])
    )

    # SECURITY: weights_only=False unpickles arbitrary objects — only load
    # trusted checkpoints. Spelled out so the behaviour does not depend on the
    # installed torch version's default; map_location="cpu" avoids loading onto
    # the GPU the checkpoint was saved from before the model is moved.
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    comp_model.load_state_dict(ckpt["state_dict"])
    comp_model.to(device)

    result = test(weight_list, input_list, comp_model)
    results_by_checkpoint[model_path] = result

    print(model_path.split('/')[-4:])
    print(result)

  ckpt = torch.load(model_path)
100%|██████████| 10000/10000 [02:01<00:00, 82.07it/s]


['nwc_ql', 'block_seq_ql_random_col_16', 'lmbda50_rdloss_ql_encdim512_M16_batch_size2048_total_iter200000_lr0.0001_seed100', 'best_loss_model_loss_3.84823_bpp_4.61283_MSE_0.01614_total_iter_95000.pth.tar']
{'TEST MSE': 6.428208667152724e-06, 'TEST BPP': 2.5732203125, 'TEST loss': 0.0, 'TEST recon_loss': 0.0, 'TEST bpp_loss': 0.0}


100%|██████████| 10000/10000 [02:03<00:00, 81.19it/s]


['nwc_ql', 'block_seq_ql_random_col_16', 'lmbda100_rdloss_ql_encdim512_M16_batch_size2048_total_iter1500000_lr0.0001_seed100', 'best_loss_model_loss_4.39201_bpp_5.10767_MSE_0.0081_total_iter_190000.pth.tar']
{'TEST MSE': 3.196336244191045e-06, 'TEST BPP': 3.07330625, 'TEST loss': 0.0, 'TEST recon_loss': 0.0, 'TEST bpp_loss': 0.0}


100%|██████████| 10000/10000 [02:04<00:00, 80.05it/s]


['nwc_ql', 'block_seq_ql_random_col_16', 'lmbda200_rdloss_ql_encdim512_M16_batch_size2048_total_iter200000_lr0.0001_seed100', 'best_loss_model_loss_4.97679_bpp_5.524_MSE_0.00426_total_iter_95000.pth.tar']
{'TEST MSE': 1.6116764717727605e-06, 'TEST BPP': 3.5789546875, 'TEST loss': 0.0, 'TEST recon_loss': 0.0, 'TEST bpp_loss': 0.0}


100%|██████████| 10000/10000 [02:05<00:00, 79.87it/s]


['nwc_ql', 'block_seq_ql_random_col_16', 'lmbda300_rdloss_ql_encdim512_M16_batch_size2048_total_iter200000_lr0.0001_seed100', 'best_loss_model_loss_5.34295_bpp_5.7068_MSE_0.00302_total_iter_95000.pth.tar']
{'TEST MSE': 1.085121362536512e-06, 'TEST BPP': 3.8729578125, 'TEST loss': 0.0, 'TEST recon_loss': 0.0, 'TEST bpp_loss': 0.0}


100%|██████████| 10000/10000 [02:07<00:00, 78.62it/s]


['nwc_ql', 'block_seq_ql_random_col_16', 'lmbda1000_rdloss_ql_encdim512_M16_batch_size2048_total_iter1500000_lr0.0001_seed100', 'best_loss_model_loss_6.59649_bpp_6.05166_MSE_0.00106_total_iter_140000.pth.tar']
{'TEST MSE': 3.380308445690261e-07, 'TEST BPP': 4.7203171875, 'TEST loss': 0.0, 'TEST recon_loss': 0.0, 'TEST bpp_loss': 0.0}


100%|██████████| 10000/10000 [02:01<00:00, 82.14it/s]

['nwc_ql', 'block_seq_ql_random_col_16', 'lmbda10000_rdloss_ql_encdim512_M16_batch_size2048_total_iter200000_lr0.0001_seed100', 'best_loss_model_loss_10.96029_bpp_6.2788_MSE_0.0004_total_iter_140000.pth.tar']
{'TEST MSE': 8.148602949233919e-08, 'TEST BPP': 5.9473953125, 'TEST loss': 0.0, 'TEST recon_loss': 0.0, 'TEST bpp_loss': 0.0}



