In [7]:
import torch
import torch.nn as nn
from transformers import CLIPVisionModelWithProjection, ViTForImageClassification, AutoModelForCausalLM
from transformers import AutoModel, AutoTokenizer

import sys, os, json
from tqdm import tqdm

# A notebook has no real __file__: abspath("__file__") merely resolves that literal
# name against the current working directory, so its dirname is the (normalised) CWD.
# The project root is taken to be the directory one level above it.
notebook_dir = os.path.dirname(os.path.abspath("__file__"))
project_root = os.path.abspath(os.path.join(notebook_dir, os.pardir))

# Make the project package importable, without appending a duplicate sys.path entry.
_on_path = project_root in sys.path
if not _on_path:
    sys.path.append(project_root)

from VQ_SEEDLM import models

In [13]:
def reconstruct_model(state_dict, model, weight_condition, batch_size=32768):
    """Reconstruct selected weight tensors by passing them through ``model``.

    Every entry of ``state_dict`` whose key contains ``weight_condition`` is
    reshaped into rows of ``model.input_size`` values, sent through ``model`` in
    chunks of ``batch_size`` rows, and reassembled into a tensor with the
    original shape and dtype.

    Args:
        state_dict: Mapping from parameter name to tensor (e.g. ``net.state_dict()``).
        model: Module exposing an ``input_size`` attribute whose forward returns a
            mapping with ``"x"`` and ``"x_hat"`` entries. ``"x_hat"`` must be
            assignable into a ``(rows, input_size)`` slice. The MSE is measured
            between ``"x"`` and ``"x_hat"`` (presumably the possibly-normalised
            input and its reconstruction -- TODO confirm against VQ_SEEDLM).
        weight_condition: Substring a key must contain to be reconstructed.
        batch_size: Number of rows passed to ``model`` per forward call.

    Returns:
        ``(recon_state_dict, mean_MSE)``:
          * ``recon_state_dict`` contains only the reconstructed entries, moved to CPU.
          * ``mean_MSE`` is the squared error averaged over every element
            processed, so each batch is weighted by its size. It is ``nan`` when
            no key matched.

    The model is evaluated in ``eval()`` mode and restored to its previous
    train/eval mode afterwards.
    """
    mse_func = nn.MSELoss()
    first_param = next(model.parameters())
    device, model_dtype = first_param.device, first_param.dtype
    recon_state_dict = {}
    total_sq_err = 0.0
    total_elems = 0

    # Inference must not run in training mode: dropout, batch-norm statistics or
    # EMA codebook updates (if the model has any) would alter outputs and even
    # mutate the model. Remember the caller's mode so it can be restored.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for k, W in state_dict.items():
                if weight_condition not in k:
                    continue
                print(f'### Reconstructing {k} ####')

                # Split the weight into rows of model.input_size values.
                W_reshaped = W.reshape(-1, model.input_size)
                W_recon = torch.zeros_like(W_reshaped)
                n_rows = W_reshaped.shape[0]

                for start_idx in tqdm(range(0, n_rows, batch_size)):
                    end_idx = min(start_idx + batch_size, n_rows)  # the last batch may be smaller
                    # Move the slice to the model's device AND dtype (e.g. bf16 weights, fp32 model).
                    batch = W_reshaped[start_idx:end_idx].to(device=device, dtype=model_dtype)

                    out = model(batch)
                    W_recon[start_idx:end_idx] = out['x_hat'].to(W_recon.device, W_recon.dtype)

                    # Weight each batch's mean error by its element count so that a
                    # short final batch (or a smaller matrix) is not over-weighted.
                    n_elems = out['x'].numel()
                    total_sq_err += mse_func(out['x'], out['x_hat']).item() * n_elems
                    total_elems += n_elems

                recon_state_dict[k] = W_recon.reshape(W.shape).cpu()
    finally:
        model.train(was_training)

    # Guard the division: with no matching key the original code raised ZeroDivisionError.
    mean_MSE = total_sq_err / total_elems if total_elems else float('nan')
    return recon_state_dict, mean_MSE

In [10]:
# Paths to the trained VQ-SEEDLM checkpoint, its per-channel dataset statistics,
# and the base LLM snapshot.
# NOTE(review): absolute, user-specific paths -- consider moving them to a config cell.
model_path = '/home/jgryu/Weight_compression/VQ_SEEDLM/checkpoint/Meta-Llama-3-8B/mlp_16_row_dataset.pt/size16_ne512_P4_batch_size512_total_iter2000000_lr0.0001_seed100/best_mse_model_MSE_0.11122_total_iter_2000000.pth.tar'
stats_path = '/home/jgryu/Weight_compression/Wparam_dataset/dataset_per_row/meta-llama/Meta-Llama-3-8B/mlp_16_row_dataset_stats.json'
ckpt_path = '/home/jgryu/Weight_compression/model_cache/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920'

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

# Architecture hyper-parameters. They must match the run encoded in model_path
# (size16_ne512_P4...). The constructor below uses these names, so editing them
# here actually takes effect.
input_size = 16
dim_encoder = 64
P = 4
ne = 512
n_resblock = 4
beta = 0.25

# By default torch.load restores tensors onto the device they were saved from.
# Load onto CPU instead: the model is still on CPU when the weights are copied in,
# and this does not depend on that GPU being present.
# NOTE(review): torch.load warns about `weights_only`; the checkpoint may hold
# non-tensor objects, so confirm before switching to weights_only=True.
ckpt = torch.load(model_path, map_location='cpu')

with open(stats_path, 'r', encoding='utf-8') as file:
    dataset_stats = json.load(file)  # parse the JSON stats file into Python objects

# NOTE(review): both `scale` and `shift` are built from 'mean_channel'. A scale is
# usually a per-channel spread (e.g. std), so confirm this matches how the model was
# trained. Two separate tensors are created on purpose so the two values never alias.
mean_channel = dataset_stats['train']['mean_channel']
model = models.VQ_SEEDLM(input_size=input_size,
                         dim_encoder=dim_encoder,
                         P=P, n_embeddings=ne, n_resblock=n_resblock,
                         beta=beta,
                         scale=torch.Tensor(mean_channel).to(device),
                         shift=torch.Tensor(mean_channel).to(device)
                         )

model.load_state_dict(ckpt['state_dict'])
model.to(device)
model.eval()  # this notebook only runs inference with the VQ model

net = AutoModelForCausalLM.from_pretrained(ckpt_path, local_files_only=True)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path, local_files_only=True)
state_dict = net.state_dict()

  ckpt = torch.load(model_path)
Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.38s/it]


In [None]:
recon_state_dict, mean_MSE = reconstruct_model(state_dict, model, weight_condition='mlp')
print(mean_MSE)

# Tensors that were not reconstructed (e.g. attention, norm and embedding weights)
# are copied over unchanged. Reconstructed ones must keep their original shape.
for name, tensor in state_dict.items():
    if name in recon_state_dict:
        assert recon_state_dict[name].shape == state_dict[name].shape
        continue
    recon_state_dict[name] = tensor
    print(name, tensor.shape)

### Reconstructing model.layers.0.mlp.gate_proj.weight ####


  0%|          | 0/112 [00:00<?, ?it/s]

100%|██████████| 112/112 [00:06<00:00, 16.51it/s]


### Reconstructing model.layers.0.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 17.37it/s]


### Reconstructing model.layers.0.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.51it/s]


### Reconstructing model.layers.1.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.60it/s]


### Reconstructing model.layers.1.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.74it/s]


### Reconstructing model.layers.1.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.69it/s]


### Reconstructing model.layers.2.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 17.28it/s]


### Reconstructing model.layers.2.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.59it/s]


### Reconstructing model.layers.2.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.47it/s]


### Reconstructing model.layers.3.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.98it/s]


### Reconstructing model.layers.3.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 17.06it/s]


### Reconstructing model.layers.3.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.16it/s]


### Reconstructing model.layers.4.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.62it/s]


### Reconstructing model.layers.4.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.41it/s]


### Reconstructing model.layers.4.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.43it/s]


### Reconstructing model.layers.5.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 17.08it/s]


### Reconstructing model.layers.5.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.77it/s]


### Reconstructing model.layers.5.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.58it/s]


### Reconstructing model.layers.6.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.03it/s]


### Reconstructing model.layers.6.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.50it/s]


### Reconstructing model.layers.6.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.89it/s]


### Reconstructing model.layers.7.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.65it/s]


### Reconstructing model.layers.7.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.71it/s]


### Reconstructing model.layers.7.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.40it/s]


### Reconstructing model.layers.8.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.70it/s]


### Reconstructing model.layers.8.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.44it/s]


### Reconstructing model.layers.8.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.73it/s]


### Reconstructing model.layers.9.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.84it/s]


### Reconstructing model.layers.9.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.69it/s]


### Reconstructing model.layers.9.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.99it/s]


### Reconstructing model.layers.10.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.80it/s]


### Reconstructing model.layers.10.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.97it/s]


### Reconstructing model.layers.10.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.89it/s]


### Reconstructing model.layers.11.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.85it/s]


### Reconstructing model.layers.11.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.54it/s]


### Reconstructing model.layers.11.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.95it/s]


### Reconstructing model.layers.12.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.87it/s]


### Reconstructing model.layers.12.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.99it/s]


### Reconstructing model.layers.12.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.10it/s]


### Reconstructing model.layers.13.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.60it/s]


### Reconstructing model.layers.13.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.90it/s]


### Reconstructing model.layers.13.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.80it/s]


### Reconstructing model.layers.14.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.91it/s]


### Reconstructing model.layers.14.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.63it/s]


### Reconstructing model.layers.14.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.91it/s]


### Reconstructing model.layers.15.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.29it/s]


### Reconstructing model.layers.15.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.77it/s]


### Reconstructing model.layers.15.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.00it/s]


### Reconstructing model.layers.16.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.84it/s]


### Reconstructing model.layers.16.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.71it/s]


### Reconstructing model.layers.16.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.15it/s]


### Reconstructing model.layers.17.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.26it/s]


### Reconstructing model.layers.17.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.27it/s]


### Reconstructing model.layers.17.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.81it/s]


### Reconstructing model.layers.18.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.78it/s]


### Reconstructing model.layers.18.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.77it/s]


### Reconstructing model.layers.18.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.95it/s]


### Reconstructing model.layers.19.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.45it/s]


### Reconstructing model.layers.19.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.36it/s]


### Reconstructing model.layers.19.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.58it/s]


### Reconstructing model.layers.20.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.80it/s]


### Reconstructing model.layers.20.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.86it/s]


### Reconstructing model.layers.20.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.09it/s]


### Reconstructing model.layers.21.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.99it/s]


### Reconstructing model.layers.21.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.05it/s]


### Reconstructing model.layers.21.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.98it/s]


### Reconstructing model.layers.22.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.13it/s]


### Reconstructing model.layers.22.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.29it/s]


### Reconstructing model.layers.22.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.94it/s]


### Reconstructing model.layers.23.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.50it/s]


### Reconstructing model.layers.23.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.64it/s]


### Reconstructing model.layers.23.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.89it/s]


### Reconstructing model.layers.24.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.77it/s]


### Reconstructing model.layers.24.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.17it/s]


### Reconstructing model.layers.24.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.99it/s]


### Reconstructing model.layers.25.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.36it/s]


### Reconstructing model.layers.25.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.04it/s]


### Reconstructing model.layers.25.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.83it/s]


### Reconstructing model.layers.26.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.49it/s]


### Reconstructing model.layers.26.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 17.02it/s]


### Reconstructing model.layers.26.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.26it/s]


### Reconstructing model.layers.27.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.61it/s]


### Reconstructing model.layers.27.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.57it/s]


### Reconstructing model.layers.27.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.44it/s]


### Reconstructing model.layers.28.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.06it/s]


### Reconstructing model.layers.28.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.35it/s]


### Reconstructing model.layers.28.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.05it/s]


### Reconstructing model.layers.29.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.11it/s]


### Reconstructing model.layers.29.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.75it/s]


### Reconstructing model.layers.29.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.86it/s]


### Reconstructing model.layers.30.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.07it/s]


### Reconstructing model.layers.30.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.05it/s]


### Reconstructing model.layers.30.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.95it/s]


### Reconstructing model.layers.31.mlp.gate_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.78it/s]


### Reconstructing model.layers.31.mlp.up_proj.weight ####


100%|██████████| 112/112 [00:07<00:00, 15.74it/s]


### Reconstructing model.layers.31.mlp.down_proj.weight ####


100%|██████████| 112/112 [00:06<00:00, 16.19it/s]

0.0006984801366911597





NameError: name 'reconstruncted_state_dict' is not defined

In [15]:
# Merge step: fill missing entries with the original weights and check that
# reconstructed entries kept their shape. On a second run every key is already
# present, so only the shape checks execute.
for key, original in state_dict.items():
    already_present = key in recon_state_dict
    if already_present:
        assert recon_state_dict[key].shape == original.shape
    else:
        recon_state_dict[key] = original
        print(key, original.shape)

model.embed_tokens.weight torch.Size([128256, 4096])
model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096])
model.layers.0.self_attn.k_proj.weight torch.Size([1024, 4096])
model.layers.0.self_attn.v_proj.weight torch.Size([1024, 4096])
model.layers.0.self_attn.o_proj.weight torch.Size([4096, 4096])
model.layers.0.input_layernorm.weight torch.Size([4096])
model.layers.0.post_attention_layernorm.weight torch.Size([4096])
model.layers.1.self_attn.q_proj.weight torch.Size([4096, 4096])
model.layers.1.self_attn.k_proj.weight torch.Size([1024, 4096])
model.layers.1.self_attn.v_proj.weight torch.Size([1024, 4096])
model.layers.1.self_attn.o_proj.weight torch.Size([4096, 4096])
model.layers.1.input_layernorm.weight torch.Size([4096])
model.layers.1.post_attention_layernorm.weight torch.Size([4096])
model.layers.2.self_attn.q_proj.weight torch.Size([4096, 4096])
model.layers.2.self_attn.k_proj.weight torch.Size([1024, 4096])
model.layers.2.self_attn.v_proj.weight torch.Size([1024, 4096

In [22]:
# Swap the reconstructed weights into the base model, then export model + tokenizer
# under a directory named after the last three components of the VQ checkpoint path.
# NOTE(review): [-3:] drops the base-model component (e.g. "Meta-Llama-3-8B") of
# model_path, so runs on different base models with the same config would collide.
net.load_state_dict(recon_state_dict)
checkpoint_tail = model_path.split('/')[-3:]
save_directory = f"/home/jgryu/Weight_compression/model_cache_reconstructed/vq_seedlm/{os.path.join(*checkpoint_tail)}"
net.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)

('/home/jgryu/Weight_compression/model_cache_reconstructed/vq_seedlm/mlp_16_row_dataset.pt/size16_ne512_P4_batch_size512_total_iter2000000_lr0.0001_seed100/best_mse_model_MSE_0.11122_total_iter_2000000.pth.tar/tokenizer_config.json',
 '/home/jgryu/Weight_compression/model_cache_reconstructed/vq_seedlm/mlp_16_row_dataset.pt/size16_ne512_P4_batch_size512_total_iter2000000_lr0.0001_seed100/best_mse_model_MSE_0.11122_total_iter_2000000.pth.tar/special_tokens_map.json',
 '/home/jgryu/Weight_compression/model_cache_reconstructed/vq_seedlm/mlp_16_row_dataset.pt/size16_ne512_P4_batch_size512_total_iter2000000_lr0.0001_seed100/best_mse_model_MSE_0.11122_total_iter_2000000.pth.tar/tokenizer.json')