In [71]:
import torch
import torch.nn as nn
from transformers import CLIPVisionModelWithProjection, ViTForImageClassification, AutoModelForCausalLM
from transformers import AutoModel, AutoTokenizer

import sys, os, json
from tqdm import tqdm

# The quoted "__file__" is a plain string, not the module attribute, so abspath()
# resolves it against the current working directory; its dirname is that directory.
notebook_dir = os.path.dirname(os.path.abspath("__file__"))
# The project root is taken to be one level above the notebook directory.
project_root = os.path.abspath(os.path.join(notebook_dir, os.pardir))

# Make project-local packages importable, without adding a duplicate entry on re-run.
if not any(entry == project_root for entry in sys.path):
    sys.path.append(project_root)

from VQ_SEEDLM import models

In [72]:
def reconstruct_model(state_dict, model, weight_condition, batch_size=32768):
    """Reconstruct selected tensors of ``state_dict`` by passing them through ``model``.

    Every tensor whose key contains ``weight_condition`` is reshaped to rows of
    length ``model.input_size``, fed through ``model`` in chunks of
    ``batch_size`` rows, and rebuilt from the ``'x_hat'`` entry of the model
    output. Tensors whose key does not match are skipped and do not appear in
    the result.

    Args:
        state_dict: Mapping from parameter name to tensor.
        model: Module exposing ``input_size`` and returning a mapping with the
            keys ``'x'`` and ``'x_hat'`` when called on a batch of rows.
        weight_condition: Substring a key must contain to be reconstructed.
        batch_size: Maximum number of rows sent through ``model`` per call.

    Returns:
        A tuple ``(recon_state_dict, mean_MSE)``. ``recon_state_dict`` maps each
        matching key to its reconstruction, reshaped to the original shape and
        moved to the CPU. ``mean_MSE`` is the average of the per-batch MSE
        values between ``out['x']`` and ``out['x_hat']`` (every batch counts
        equally, regardless of how many rows it holds). If no batch was
        processed, ``mean_MSE`` is ``nan`` instead of raising
        ``ZeroDivisionError``.
    """
    mse_func = nn.MSELoss()
    # Batches are moved to wherever the model's parameters live.
    device = next(model.parameters()).device
    recon_state_dict = {}
    total_mse = 0.0
    count = 0

    with torch.no_grad():
        for k, W in tqdm(state_dict.items()):
            if weight_condition not in k:
                continue

            # Flatten the tensor into rows of length model.input_size.
            W_reshaped = W.reshape(-1, model.input_size)
            # The reconstruction buffer stays on the source tensor's device/dtype.
            W_recon = torch.zeros(W_reshaped.shape, dtype=W_reshaped.dtype, device=W_reshaped.device)
            n_rows = W_reshaped.shape[0]

            for start_idx in range(0, n_rows, batch_size):
                # Clamp the end index so the final (possibly shorter) batch is handled.
                end_idx = min(start_idx + batch_size, n_rows)
                # Slice one batch of rows and move it to the model's device.
                batch = W_reshaped[start_idx:end_idx].to(device)

                out = model(batch)
                W_recon[start_idx:end_idx] = out['x_hat']

                total_mse += mse_func(out["x"], out["x_hat"]).item()
                count += 1

            recon_state_dict[k] = W_recon.reshape(W.shape).cpu()

    # Nothing matched (or only empty tensors matched): the mean is undefined.
    mean_MSE = total_mse / count if count else float("nan")

    return recon_state_dict, mean_MSE

In [73]:
# Alternative checkpoints, kept for reference (the hyperparameters below must be
# changed to match if one of these is selected):
# model_path = '/home/jgryu/Weight_compression/VQ_SEEDLM/checkpoint/Meta-Llama-3-8B/mlp_16_row_dataset.pt/size16_ne512_P4_batch_size512_total_iter2000000_lr0.0001_seed100/best_mse_model_MSE_0.11122_total_iter_2000000.pth.tar'
# model_path = '/home/jgryu/Weight_compression/VQ_SEEDLM/checkpoint/Meta-Llama-3-8B/mlp_16_row_dataset.pt/size16_ne256_P4_batch_size512_total_iter2000000_lr0.0001_seed100/best_mse_model_MSE_0.28962_total_iter_1250000.pth.tar'
model_path = '/home/jgryu/Weight_compression/VQ_SEEDLM/checkpoint/Meta-Llama-3-8B/mlp_16_row_dataset.pt/size16_ne512_P32_batch_size512_total_iter2000000_lr0.0001_seed100/best_mse_model_MSE_0.0_total_iter_1750000.pth.tar'

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load onto the CPU so the checkpoint opens regardless of the device it was saved
# from; the weights are moved to `device` together with the model below.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
ckpt = torch.load(model_path, map_location='cpu')

with open('/home/jgryu/Weight_compression/Wparam_dataset/dataset_per_row/meta-llama/Meta-Llama-3-8B/mlp_16_row_dataset_stats.json', 'r', encoding='utf-8') as file:
    dataset_stats = json.load(file)  # parse the JSON stats file into Python objects

# Hyperparameters of the selected checkpoint. input_size, ne and P mirror the
# "size16_ne512_P32" part of model_path; dim_encoder and n_resblock are not
# encoded in the path -- presumably they match the training run, verify.
input_size = 16
dim_encoder = 64
P = 32
ne = 512
n_resblock = 4

# NOTE(review): `scale` and `shift` are both built from 'mean_channel'. This looks
# like a copy-paste slip (scale would normally come from a std-like statistic);
# left unchanged here -- confirm against the training script.
model = models.VQ_SEEDLM(input_size = input_size, 
                    dim_encoder = dim_encoder, 
                    P = P, n_embeddings = ne, n_resblock = n_resblock, 
                    beta = 0.25,
                    scale = torch.Tensor(dataset_stats['train']['mean_channel']).to(device), 
                    shift = torch.Tensor(dataset_stats['train']['mean_channel']).to(device)
                    )

model.load_state_dict(ckpt['state_dict'])
model.to(device)
# The model is only used for inference below, so switch off training-mode behaviour.
model.eval()


def latest_version_path(cache_dir, model_name, branch = 'main'):
    """Resolve the snapshot directory a branch points to in a Hugging Face style cache.

    The cache layout read here is
    ``<cache_dir>/models--<org>--<name>/{refs/<branch>, snapshots/<revision>}``.

    Args:
        cache_dir: Root directory of the cache.
        model_name: Model identifier; every ``/`` is replaced by ``--``.
        branch: Name of the file under ``refs`` that holds the revision.

    Returns:
        ``<model dir>/snapshots/<revision>``, or ``None`` when the model
        directory has no ``snapshots`` sub-directory.

    Raises:
        OSError: If the ``refs/<branch>`` file cannot be opened.
    """
    model_name_dir = "models--" + model_name.replace('/', '--')
    path = os.path.join(cache_dir, model_name_dir)
    if not os.path.isdir(os.path.join(path, 'snapshots')):
        return None
    branch_file = os.path.join(path, 'refs', branch)
    with open(branch_file, 'r', encoding='utf-8') as file:
        # Strip surrounding whitespace: a trailing newline in the ref file would
        # otherwise become part of the path and point at a non-existent directory.
        revision = file.read().strip()
    return os.path.join(path, 'snapshots', revision)

cache_directory = "../Wparam_dataset_v0/model_zoo/huggingface" 
ckpt_path = latest_version_path(cache_directory, 'meta-llama/Meta-Llama-3-8B')
if ckpt_path is None:
    # Fail with a clear message instead of passing None to from_pretrained().
    raise FileNotFoundError(
        f"No cached snapshots for 'meta-llama/Meta-Llama-3-8B' under {cache_directory!r}")
net = AutoModelForCausalLM.from_pretrained(ckpt_path, local_files_only=True)


# NOTE(review): ckpt_path is reassigned here, so the tokenizer is loaded from a
# different, hard-coded snapshot directory than the model above -- confirm both
# locations hold the same revision.
ckpt_path = '/home/jgryu/Weight_compression/model_cache/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920'
tokenizer = AutoTokenizer.from_pretrained(ckpt_path, local_files_only=True)
# state_dict() returns references to net's live parameters, not copies: loading new
# weights into `net` later also changes what this dictionary shows.
state_dict = net.state_dict()

  ckpt = torch.load(model_path)
Loading checkpoint shards: 100%|██████████| 4/4 [00:08<00:00,  2.05s/it]


In [None]:
# Reconstruct every tensor whose key contains 'mlp'; all other tensors are skipped.
recon_state_dict, mean_MSE = reconstruct_model(state_dict, model, weight_condition='mlp')

# Report the mean MSE divided by the squared training-set 'std' statistic
# (presumably the data variance, i.e. a normalized MSE -- verify against the stats file).
print(mean_MSE / (dataset_stats['train']['std'] ** 2))

 21%|██        | 61/291 [13:53<54:44, 14.28s/it]  

In [None]:
# Walk the original weights: tensors that were reconstructed get their MSE against
# the original printed; tensors that were skipped are copied over unchanged (and
# listed with their shape) so that recon_state_dict ends up with every key.
for k, v in state_dict.items():
    if k in recon_state_dict:
        mse = ((recon_state_dict[k] - v)**2).mean()
        print(f'{mse.item():-20f}')
    else:
        recon_state_dict[k] = v
        print(k, v.shape)

model.embed_tokens.weight torch.Size([128256, 4096])
model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096])
model.layers.0.self_attn.k_proj.weight torch.Size([1024, 4096])
model.layers.0.self_attn.v_proj.weight torch.Size([1024, 4096])
model.layers.0.self_attn.o_proj.weight torch.Size([4096, 4096])
            0.000669
            0.000552
            0.000550
model.layers.0.input_layernorm.weight torch.Size([4096])
model.layers.0.post_attention_layernorm.weight torch.Size([4096])
model.layers.1.self_attn.q_proj.weight torch.Size([4096, 4096])
model.layers.1.self_attn.k_proj.weight torch.Size([1024, 4096])
model.layers.1.self_attn.v_proj.weight torch.Size([1024, 4096])
model.layers.1.self_attn.o_proj.weight torch.Size([4096, 4096])
            0.000688
            0.000568
            0.000563
model.layers.1.input_layernorm.weight torch.Size([4096])
model.layers.1.post_attention_layernorm.weight torch.Size([4096])
model.layers.2.self_attn.q_proj.weight torch.Size([4096, 4096])

In [None]:
# Copy the reconstructed weights into `net`. load_state_dict() is called with its
# defaults, so recon_state_dict is expected to hold every key of the model; the
# untouched tensors must already have been filled in from state_dict.
net.load_state_dict(recon_state_dict)
# Output directory: the last three components of the VQ checkpoint path (dataset
# name, run name, checkpoint file name) appended to the reconstructed-model cache.
save_directory = f"/home/jgryu/Weight_compression/model_cache_reconstructed/vq_seedlm_/{os.path.join(*model_path.split('/')[-3:])}"
# Save model and tokenizer side by side in the same directory.
net.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)

('/home/jgryu/Weight_compression/model_cache_reconstructed/vq_seedlm_/mlp_16_row_dataset.pt/size16_ne512_P4_batch_size512_total_iter2000000_lr0.0001_seed100/best_mse_model_MSE_0.11122_total_iter_2000000.pth.tar/tokenizer_config.json',
 '/home/jgryu/Weight_compression/model_cache_reconstructed/vq_seedlm_/mlp_16_row_dataset.pt/size16_ne512_P4_batch_size512_total_iter2000000_lr0.0001_seed100/best_mse_model_MSE_0.11122_total_iter_2000000.pth.tar/special_tokens_map.json',
 '/home/jgryu/Weight_compression/model_cache_reconstructed/vq_seedlm_/mlp_16_row_dataset.pt/size16_ne512_P4_batch_size512_total_iter2000000_lr0.0001_seed100/best_mse_model_MSE_0.11122_total_iter_2000000.pth.tar/tokenizer.json')

In [None]:
# Print the mean squared difference between reconstructed and original tensors
# for every key.
# NOTE(review): `state_dict` was obtained from net.state_dict(), which references
# net's live parameters. Once recon_state_dict has been loaded into `net`, both
# sides of this comparison hold the same values, so the printed differences are
# all zero. To compare against the original weights, run this before
# net.load_state_dict(...) or keep a separate copy of the originals.
for k, v in state_dict.items():
    mean = torch.mean((recon_state_dict[k] - v) ** 2)
    print(f'{mean.item():-20f}')

            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0.000000
            0

KeyboardInterrupt: 