In [1]:
import os, random, sys, socket, lpips, shutil, operator

# 시간 측정해보기

import pandas as pd
import numpy as np

import torch
import torch.optim as optim
import torch.distributed as dist
import torchvision
import torch.nn.functional as F

from torch.utils.data import DataLoader

from datasets_Imagenet_best_worst import Imagenet_best_worst
from datasets_ImageNet import ImageNet_dataset
from datasets_WeightParam import WParam_dataset
# from datasets_openimages_v6 import Openimages_v6_dataset

from pytorch_msssim import ms_ssim as ms_ssim_func

from models.TCM import TCM
from models.FTIC import FrequencyAwareTransFormer
from models.ELIC import ELIC, model_config

from utils.optimizers import *
from utils.util import *

from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm
  @amp.autocast(enabled=False)


In [2]:
def pad(x, p):
    """Zero-pad the last two dimensions of ``x`` up to the next multiple of ``p``.

    The padding along each axis is split as evenly as possible; when the amount
    is odd, the extra pixel goes to the right / bottom side.

    Args:
        x: tensor indexed as (N, C, H, W).
        p: positive block size the spatial dims must be a multiple of.

    Returns:
        ``(x_padded, (left, right, top, bottom))``; the tuple can be handed to
        ``crop`` to undo the padding.
    """
    height, width = x.size(2), x.size(3)

    def _split(extent):
        # Amount needed to reach the next multiple of p, split before/after.
        total = (-extent) % p
        before = total // 2
        return before, total - before

    left, right = _split(width)
    top, bottom = _split(height)
    padding = (left, right, top, bottom)
    padded = F.pad(x, padding, mode="constant", value=0)
    return padded, padding

def crop(x, padding):
    """Undo ``pad`` by trimming ``(left, right, top, bottom)`` pixels from the last two dims.

    Negative amounts passed to ``F.pad`` remove elements instead of adding them.
    """
    trim = tuple(-padding[i] for i in range(4))
    return F.pad(x, trim)
    
def make_image_format(W, wp_mean, wp_std, normalize):
    """Convert a stack of 2-D weight tiles into a 3-channel "image" batch.

    When ``normalize`` is true, the tiles are standardized first as
    ``(W - wp_mean) / wp_std``. A channel axis is then inserted at dim 1 and
    replicated three times, so (N, H, W) input (as produced by
    ``reconstruct_model``) becomes (N, 3, H, W).
    """
    tiles = (W - wp_mean) / wp_std if normalize else W
    return tiles.unsqueeze(1).repeat(1, 3, 1, 1)

def reverse_image_format(W, wp_mean, wp_std, normalize):
    """Inverse of ``make_image_format``: collapse (N, 3, H, W) back to (N, H, W).

    The channels are averaged over dim 1; the code does not simply keep the
    first channel. If ``normalize`` is true, the standardization is undone with
    ``W * wp_std + wp_mean``.
    """
    collapsed = W.mean(dim=1)
    if not normalize:
        return collapsed
    return collapsed * wp_std + wp_mean

def _code_tile(model, x):
    """Round-trip one (1, 3, H, W) tile through the codec; return ``(x_hat, bpp, mse)``."""
    try:
        x_padded, padding = pad(x, p=128)
        out_enc = model.compress(x_padded)
    except Exception:
        # Fallback kept from the original pipeline: retry with coarser padding.
        # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
        x_padded, padding = pad(x, p=256)
        out_enc = model.compress(x_padded)

    out_dec = model.decompress(out_enc["strings"], out_enc["shape"])

    # Bits are normalized by the *unpadded* pixel count of the tile.
    num_pixels = x.size(0) * x.size(2) * x.size(3)
    bpp = 0
    for s in out_enc["strings"]:
        # NOTE(review): presumably skips placeholder streams emitted as [0]
        # by model.compress — confirm against the codec implementation.
        if s != [0]:
            bpp += len(s[0]) * 8.0 / num_pixels

    x_hat = crop(out_dec["x_hat"], padding).clone().detach()
    mse = F.mse_loss(x, x_hat).item()
    return x_hat, bpp, mse


def reconstruct_model(state_dict, model, save_path, logger, size, weight_condition, mean, std, batch=4, normalize = True):
    """Compress and decompress every matching weight tensor with a learned image codec.

    Each tensor whose key contains ``weight_condition`` is reshaped into
    ``size x size`` tiles. The tiles are optionally standardized with
    ``mean``/``std``, replicated to 3 channels, and coded one tile at a time
    with ``model.compress`` / ``model.decompress``. The decoded tiles are
    mapped back to weight space and reshaped to the tensor's original shape,
    so the result can be loaded back into the source model.

    Args:
        state_dict: mapping name -> tensor (e.g. ``net.state_dict()``). Each
            selected tensor's element count must be divisible by ``size * size``.
        model: codec exposing ``compress``/``decompress``; its parameters'
            device is used for all computation.
        save_path, logger: currently unused; kept for call compatibility.
        size: side length of each square tile.
        weight_condition: substring a key must contain to be reconstructed.
        mean, std: normalization statistics (used when ``normalize`` is True).
        batch: currently unused; tiles are coded one at a time.
        normalize: whether to standardize weights before coding.

    Returns:
        ``(recon_state_dict, avg_bpp, mean_MSE)``:
        - ``recon_state_dict``: reconstructed tensors with the same keys and
          shapes as the selected inputs, on the codec's device.
        - ``avg_bpp``: bits-per-pixel averaged over all tiles.
        - ``mean_MSE``: per-tile MSE averaged over all tiles. It is measured in
          the codec's input space (normalized, 3-channel), not in raw weight
          units.

    Raises:
        ValueError: if no key in ``state_dict`` contains ``weight_condition``.
    """
    avg_bpp = 0.0
    mean_MSE = 0.0
    count = 0
    device = next(model.parameters()).device
    recon_state_dict = {}

    with torch.no_grad():
        for k, W in state_dict.items():
            if weight_condition not in k:
                continue
            print(f'### Reconstructing {k} ####')

            # (...) -> (T, size, size) -> (T, 3, size, size) -> (T, 1, 3, size, size)
            tiles = W.reshape(-1, size, size).to(device)
            tiles = make_image_format(tiles, mean, std, normalize)
            tiles = tiles.reshape(-1, 1, 3, size, size)
            W_recon = torch.zeros_like(tiles)

            for idx, x in tqdm(enumerate(tiles), total=len(tiles)):
                x_hat, bpp, mse = _code_tile(model, x)
                W_recon[idx] = x_hat
                avg_bpp += bpp
                mean_MSE += mse
                count += 1

            W_recon = W_recon.reshape(-1, 3, size, size)
            W_recon = reverse_image_format(W_recon, mean, std, normalize)
            # Restore the original parameter shape so the dict is loadable.
            recon_state_dict[k] = W_recon.reshape(W.shape)

    if count == 0:
        raise ValueError(f"No tensor name in state_dict contains {weight_condition!r}")

    return recon_state_dict, avg_bpp / count, mean_MSE / count

In [3]:
# Prefer the second GPU (as the original experiment did), but fall back to the
# first GPU on single-GPU machines instead of failing later at `.to(device)`.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

In [4]:
import torch
from transformers import CLIPVisionModelWithProjection, ViTForImageClassification, AutoModelForCausalLM
from transformers import AutoModel, AutoTokenizer

# --- Configuration ---------------------------------------------------------
# Absolute paths are machine-specific; edit these for your environment.
LLM_CKPT_PATH = '/home/jgryu/Weight_compression/llm-awq/model_cache/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920'
STATS_DIR = '/home/jgryu/Weight_compression/Wparam_dataset/TFRecord/meta-llama--Meta-Llama-3-8B/attn/d16'
CODEC_CKPT_DIR = 'checkpoints_image_pretrained'
OUTPUT_PATH = "reconstruncted_state_dict/meta-llama--Meta-Llama-3-8B_attn_d256_256.pth"

size = 256                 # side length of each weight tile fed to the codec
weight_condition = 'attn'  # only tensors whose key contains this are reconstructed

net = AutoModelForCausalLM.from_pretrained(LLM_CKPT_PATH, local_files_only=True)

# NOTE(review): statistics come from the d16 split while tiles here are 256x256;
# confirm that the same normalization statistics are intended.
mean = torch.from_numpy(np.load(os.path.join(STATS_DIR, 'attn_d16_train_mean.npy')))
std = torch.from_numpy(np.load(os.path.join(STATS_DIR, 'attn_d16_train_std.npy')))

# Map lambda -> checkpoint file. Keep the real filename instead of rebuilding it
# from the float (which may not round-trip), and ignore non-checkpoint files.
ckpt_by_lambda = {
    float(fname[:-len('.pth')]): os.path.join(CODEC_CKPT_DIR, fname)
    for fname in os.listdir(CODEC_CKPT_DIR)
    if fname.endswith('.pth')
}
# Experiment choice: evaluate only the second-largest lambda.
lmbdas = sorted(ckpt_by_lambda)[-2:-1]
print(lmbdas)

recon_state_dict = None
for lm in lmbdas:
    print(f'##### lambda: {lm} #####')
    ck_path = ckpt_by_lambda[lm]

    try:
        # Checkpoints are trusted local files; torch.load unpickles arbitrary objects.
        checkpoint = torch.load(ck_path, map_location=device)
        if not isinstance(checkpoint, dict):
            raise ValueError("Checkpoint is not a dictionary")
        if "state_dict" not in checkpoint:
            raise ValueError("Missing 'state_dict' in checkpoint")
        print(f"Checkpoint for {lm} loaded successfully.")
    except Exception as e:
        print(f"Failed to load checkpoint for {lm}: {e}")
        continue  # otherwise an undefined / stale `checkpoint` would be used below

    model = TCM(N=64)
    try:
        model.load_state_dict(checkpoint["state_dict"])
        print(f"Model state_dict loaded successfully for {lm}.")
    except RuntimeError as e:
        print(f"Failed to load model state_dict for {lm}: {e}")
        continue  # a randomly initialised codec would give meaningless results

    model = model.eval().to(device)
    model.requires_grad_(False)
    model.update()

    recon_state_dict, avg_bpp, mean_MSE = reconstruct_model(
        net.state_dict(), model, save_path = None, logger= None, size = size,
        weight_condition = weight_condition, mean = mean, std = std)

if recon_state_dict is None:
    raise RuntimeError("No codec checkpoint was successfully evaluated; nothing to save.")

print(avg_bpp, mean_MSE)
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
torch.save(recon_state_dict, OUTPUT_PATH)

Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.43s/it]
  checkpoint = torch.load(ck_path, map_location=device)


[0.025]
##### lambda: 0.025 #####
Checkpoint for 0.025 loaded successfully.
Model state_dict loaded successfully for 0.025.
### Reconstructing model.layers.0.self_attn.q_proj.weight ####


256it [00:44,  5.81it/s]


### Reconstructing model.layers.0.self_attn.k_proj.weight ####


64it [00:11,  5.81it/s]


### Reconstructing model.layers.0.self_attn.v_proj.weight ####


64it [00:10,  5.89it/s]


### Reconstructing model.layers.0.self_attn.o_proj.weight ####


256it [00:43,  5.93it/s]


### Reconstructing model.layers.1.self_attn.q_proj.weight ####


256it [00:44,  5.82it/s]


### Reconstructing model.layers.1.self_attn.k_proj.weight ####


64it [00:11,  5.78it/s]


### Reconstructing model.layers.1.self_attn.v_proj.weight ####


64it [00:10,  6.10it/s]


### Reconstructing model.layers.1.self_attn.o_proj.weight ####


256it [00:42,  5.96it/s]


### Reconstructing model.layers.2.self_attn.q_proj.weight ####


256it [00:42,  5.96it/s]


### Reconstructing model.layers.2.self_attn.k_proj.weight ####


64it [00:10,  6.03it/s]


### Reconstructing model.layers.2.self_attn.v_proj.weight ####


64it [00:10,  5.91it/s]


### Reconstructing model.layers.2.self_attn.o_proj.weight ####


256it [00:42,  6.06it/s]


### Reconstructing model.layers.3.self_attn.q_proj.weight ####


256it [00:42,  5.98it/s]


### Reconstructing model.layers.3.self_attn.k_proj.weight ####


64it [00:10,  5.85it/s]


### Reconstructing model.layers.3.self_attn.v_proj.weight ####


64it [00:10,  6.15it/s]


### Reconstructing model.layers.3.self_attn.o_proj.weight ####


256it [00:42,  5.99it/s]


### Reconstructing model.layers.4.self_attn.q_proj.weight ####


256it [00:43,  5.92it/s]


### Reconstructing model.layers.4.self_attn.k_proj.weight ####


64it [00:10,  6.01it/s]


### Reconstructing model.layers.4.self_attn.v_proj.weight ####


64it [00:11,  5.81it/s]


### Reconstructing model.layers.4.self_attn.o_proj.weight ####


256it [00:42,  6.04it/s]


### Reconstructing model.layers.5.self_attn.q_proj.weight ####


256it [00:42,  6.02it/s]


### Reconstructing model.layers.5.self_attn.k_proj.weight ####


64it [00:10,  5.96it/s]


### Reconstructing model.layers.5.self_attn.v_proj.weight ####


64it [00:10,  6.11it/s]


### Reconstructing model.layers.5.self_attn.o_proj.weight ####


256it [00:42,  6.02it/s]


### Reconstructing model.layers.6.self_attn.q_proj.weight ####


256it [00:42,  6.00it/s]


### Reconstructing model.layers.6.self_attn.k_proj.weight ####


64it [00:10,  6.06it/s]


### Reconstructing model.layers.6.self_attn.v_proj.weight ####


64it [00:10,  6.11it/s]


### Reconstructing model.layers.6.self_attn.o_proj.weight ####


256it [00:41,  6.10it/s]


### Reconstructing model.layers.7.self_attn.q_proj.weight ####


256it [00:44,  5.70it/s]


### Reconstructing model.layers.7.self_attn.k_proj.weight ####


64it [00:11,  5.39it/s]


### Reconstructing model.layers.7.self_attn.v_proj.weight ####


64it [00:11,  5.46it/s]


### Reconstructing model.layers.7.self_attn.o_proj.weight ####


256it [00:43,  5.95it/s]


### Reconstructing model.layers.8.self_attn.q_proj.weight ####


256it [00:42,  5.99it/s]


### Reconstructing model.layers.8.self_attn.k_proj.weight ####


64it [00:10,  5.88it/s]


### Reconstructing model.layers.8.self_attn.v_proj.weight ####


64it [00:10,  6.07it/s]


### Reconstructing model.layers.8.self_attn.o_proj.weight ####


256it [00:42,  6.02it/s]


### Reconstructing model.layers.9.self_attn.q_proj.weight ####


256it [00:42,  5.97it/s]


### Reconstructing model.layers.9.self_attn.k_proj.weight ####


64it [00:10,  6.00it/s]


### Reconstructing model.layers.9.self_attn.v_proj.weight ####


64it [00:10,  5.89it/s]


### Reconstructing model.layers.9.self_attn.o_proj.weight ####


256it [00:42,  6.02it/s]


### Reconstructing model.layers.10.self_attn.q_proj.weight ####


256it [00:42,  6.03it/s]


### Reconstructing model.layers.10.self_attn.k_proj.weight ####


64it [00:10,  5.97it/s]


### Reconstructing model.layers.10.self_attn.v_proj.weight ####


64it [00:10,  6.10it/s]


### Reconstructing model.layers.10.self_attn.o_proj.weight ####


256it [00:42,  6.04it/s]


### Reconstructing model.layers.11.self_attn.q_proj.weight ####


256it [00:42,  6.03it/s]


### Reconstructing model.layers.11.self_attn.k_proj.weight ####


64it [00:10,  6.09it/s]


### Reconstructing model.layers.11.self_attn.v_proj.weight ####


64it [00:10,  5.98it/s]


### Reconstructing model.layers.11.self_attn.o_proj.weight ####


256it [00:42,  6.03it/s]


### Reconstructing model.layers.12.self_attn.q_proj.weight ####


256it [00:42,  6.04it/s]


### Reconstructing model.layers.12.self_attn.k_proj.weight ####


64it [00:10,  5.84it/s]


### Reconstructing model.layers.12.self_attn.v_proj.weight ####


64it [00:10,  6.14it/s]


### Reconstructing model.layers.12.self_attn.o_proj.weight ####


256it [00:42,  5.99it/s]


### Reconstructing model.layers.13.self_attn.q_proj.weight ####


256it [00:42,  6.00it/s]


### Reconstructing model.layers.13.self_attn.k_proj.weight ####


64it [00:10,  6.05it/s]


### Reconstructing model.layers.13.self_attn.v_proj.weight ####


64it [00:10,  5.91it/s]


### Reconstructing model.layers.13.self_attn.o_proj.weight ####


256it [00:42,  6.01it/s]


### Reconstructing model.layers.14.self_attn.q_proj.weight ####


256it [00:42,  6.06it/s]


### Reconstructing model.layers.14.self_attn.k_proj.weight ####


64it [00:10,  5.95it/s]


### Reconstructing model.layers.14.self_attn.v_proj.weight ####


64it [00:10,  6.18it/s]


### Reconstructing model.layers.14.self_attn.o_proj.weight ####


256it [00:42,  6.06it/s]


### Reconstructing model.layers.15.self_attn.q_proj.weight ####


256it [00:42,  6.01it/s]


### Reconstructing model.layers.15.self_attn.k_proj.weight ####


64it [00:10,  6.04it/s]


### Reconstructing model.layers.15.self_attn.v_proj.weight ####


64it [00:10,  5.97it/s]


### Reconstructing model.layers.15.self_attn.o_proj.weight ####


256it [00:42,  6.04it/s]


### Reconstructing model.layers.16.self_attn.q_proj.weight ####


256it [00:42,  6.03it/s]


### Reconstructing model.layers.16.self_attn.k_proj.weight ####


64it [00:10,  5.89it/s]


### Reconstructing model.layers.16.self_attn.v_proj.weight ####


64it [00:10,  6.16it/s]


### Reconstructing model.layers.16.self_attn.o_proj.weight ####


256it [00:42,  6.03it/s]


### Reconstructing model.layers.17.self_attn.q_proj.weight ####


256it [00:42,  5.97it/s]


### Reconstructing model.layers.17.self_attn.k_proj.weight ####


64it [00:10,  6.06it/s]


### Reconstructing model.layers.17.self_attn.v_proj.weight ####


64it [00:10,  5.87it/s]


### Reconstructing model.layers.17.self_attn.o_proj.weight ####


256it [00:42,  6.05it/s]


### Reconstructing model.layers.18.self_attn.q_proj.weight ####


256it [00:42,  6.04it/s]


### Reconstructing model.layers.18.self_attn.k_proj.weight ####


64it [00:10,  5.88it/s]


### Reconstructing model.layers.18.self_attn.v_proj.weight ####


64it [00:10,  6.17it/s]


### Reconstructing model.layers.18.self_attn.o_proj.weight ####


256it [00:42,  6.02it/s]


### Reconstructing model.layers.19.self_attn.q_proj.weight ####


256it [00:42,  5.98it/s]


### Reconstructing model.layers.19.self_attn.k_proj.weight ####


64it [00:10,  6.08it/s]


### Reconstructing model.layers.19.self_attn.v_proj.weight ####


64it [00:10,  5.96it/s]


### Reconstructing model.layers.19.self_attn.o_proj.weight ####


256it [00:42,  6.06it/s]


### Reconstructing model.layers.20.self_attn.q_proj.weight ####


256it [00:42,  6.01it/s]


### Reconstructing model.layers.20.self_attn.k_proj.weight ####


64it [00:11,  5.81it/s]


### Reconstructing model.layers.20.self_attn.v_proj.weight ####


64it [00:10,  6.09it/s]


### Reconstructing model.layers.20.self_attn.o_proj.weight ####


256it [00:42,  5.97it/s]


### Reconstructing model.layers.21.self_attn.q_proj.weight ####


256it [00:42,  6.00it/s]


### Reconstructing model.layers.21.self_attn.k_proj.weight ####


64it [00:10,  6.04it/s]


### Reconstructing model.layers.21.self_attn.v_proj.weight ####


64it [00:10,  5.88it/s]


### Reconstructing model.layers.21.self_attn.o_proj.weight ####


256it [00:42,  6.03it/s]


### Reconstructing model.layers.22.self_attn.q_proj.weight ####


256it [00:42,  6.02it/s]


### Reconstructing model.layers.22.self_attn.k_proj.weight ####


64it [00:10,  5.93it/s]


### Reconstructing model.layers.22.self_attn.v_proj.weight ####


64it [00:10,  6.16it/s]


### Reconstructing model.layers.22.self_attn.o_proj.weight ####


256it [00:42,  6.01it/s]


### Reconstructing model.layers.23.self_attn.q_proj.weight ####


256it [00:43,  5.93it/s]


### Reconstructing model.layers.23.self_attn.k_proj.weight ####


64it [00:10,  6.01it/s]


### Reconstructing model.layers.23.self_attn.v_proj.weight ####


64it [00:11,  5.79it/s]


### Reconstructing model.layers.23.self_attn.o_proj.weight ####


256it [00:42,  5.97it/s]


### Reconstructing model.layers.24.self_attn.q_proj.weight ####


256it [00:42,  6.02it/s]


### Reconstructing model.layers.24.self_attn.k_proj.weight ####


64it [00:10,  5.85it/s]


### Reconstructing model.layers.24.self_attn.v_proj.weight ####


64it [00:10,  6.10it/s]


### Reconstructing model.layers.24.self_attn.o_proj.weight ####


256it [00:42,  5.97it/s]


### Reconstructing model.layers.25.self_attn.q_proj.weight ####


256it [00:42,  6.01it/s]


### Reconstructing model.layers.25.self_attn.k_proj.weight ####


64it [00:10,  6.12it/s]


### Reconstructing model.layers.25.self_attn.v_proj.weight ####


64it [00:10,  5.90it/s]


### Reconstructing model.layers.25.self_attn.o_proj.weight ####


256it [00:42,  6.03it/s]


### Reconstructing model.layers.26.self_attn.q_proj.weight ####


256it [00:43,  5.94it/s]


### Reconstructing model.layers.26.self_attn.k_proj.weight ####


64it [00:10,  5.85it/s]


### Reconstructing model.layers.26.self_attn.v_proj.weight ####


64it [00:10,  6.07it/s]


### Reconstructing model.layers.26.self_attn.o_proj.weight ####


256it [00:42,  5.96it/s]


### Reconstructing model.layers.27.self_attn.q_proj.weight ####


256it [00:42,  5.99it/s]


### Reconstructing model.layers.27.self_attn.k_proj.weight ####


64it [00:10,  6.09it/s]


### Reconstructing model.layers.27.self_attn.v_proj.weight ####


64it [00:10,  5.88it/s]


### Reconstructing model.layers.27.self_attn.o_proj.weight ####


256it [00:42,  6.07it/s]


### Reconstructing model.layers.28.self_attn.q_proj.weight ####


256it [00:42,  6.03it/s]


### Reconstructing model.layers.28.self_attn.k_proj.weight ####


64it [00:10,  5.88it/s]


### Reconstructing model.layers.28.self_attn.v_proj.weight ####


64it [00:10,  6.06it/s]


### Reconstructing model.layers.28.self_attn.o_proj.weight ####


256it [00:43,  5.90it/s]


### Reconstructing model.layers.29.self_attn.q_proj.weight ####


256it [00:43,  5.90it/s]


### Reconstructing model.layers.29.self_attn.k_proj.weight ####


64it [00:10,  6.01it/s]


### Reconstructing model.layers.29.self_attn.v_proj.weight ####


64it [00:11,  5.81it/s]


### Reconstructing model.layers.29.self_attn.o_proj.weight ####


256it [00:42,  5.97it/s]


### Reconstructing model.layers.30.self_attn.q_proj.weight ####


256it [00:42,  6.05it/s]


### Reconstructing model.layers.30.self_attn.k_proj.weight ####


64it [00:10,  5.88it/s]


### Reconstructing model.layers.30.self_attn.v_proj.weight ####


64it [00:10,  6.05it/s]


### Reconstructing model.layers.30.self_attn.o_proj.weight ####


256it [00:42,  5.98it/s]


### Reconstructing model.layers.31.self_attn.q_proj.weight ####


256it [00:42,  5.97it/s]


### Reconstructing model.layers.31.self_attn.k_proj.weight ####


64it [00:10,  6.11it/s]


### Reconstructing model.layers.31.self_attn.v_proj.weight ####


64it [00:10,  5.94it/s]


### Reconstructing model.layers.31.self_attn.o_proj.weight ####


256it [00:42,  6.02it/s]


6.801666331291199 0.7445371351726863
