In [None]:
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt

from transformers import CLIPVisionModelWithProjection, AutoModelForCausalLM
from transformers import AutoModel, AutoTokenizer, OPTForCausalLM, BloomForCausalLM
import numpy
import torch.nn as nn

# --- 유틸리티 함수 ---

def get_magnitude_spectrum(matrix):
    """Return the centred log-magnitude frequency spectrum of a 2-D matrix.

    The 2-D FFT of ``matrix`` is shifted so the zero-frequency component sits
    in the middle of the array, then mapped through ``20 * ln(|F| + 1)``.
    The ``+ 1`` keeps the logarithm finite where the magnitude is zero.
    Note that the logarithm is the natural log, not ``log10``.
    """
    centered = np.fft.fftshift(np.fft.fft2(matrix))
    magnitude = np.abs(centered)
    return 20 * np.log(magnitude + 1)

# --- 2. 압축 수행 ---

# (A) 양자화 (INT4 예시)
def quantize_dequantize(weights, bits=4):
    """Simulate uniform affine quantization by quantizing and dequantizing.

    Args:
        weights: NumPy array of values to quantize.
        bits: Bit width of the integer code; the code range is
            ``[0, 2**bits - 1]``.

    Returns:
        Array with the same shape as ``weights`` holding the values
        reconstructed from the quantized codes.

    Raises:
        ValueError: If ``bits`` is smaller than 1 (no usable code range).
    """
    if bits < 1:
        raise ValueError(f"bits must be >= 1, got {bits}")

    max_code = 2**bits - 1
    w_min, w_max = weights.min(), weights.max()
    scale = (w_max - w_min) / max_code

    if scale == 0:
        # Constant input: every value is already exactly representable, and
        # dividing by a zero scale would only produce NaNs.
        return weights.astype(np.asarray(scale).dtype)

    zeropoint = np.round(-w_min / scale)

    # Clip to the valid code range: round-half-to-even can otherwise emit the
    # out-of-range code ``2**bits`` (e.g. values spanning 0.5 .. 15.5 at 4 bits).
    quantized = np.clip(np.round(weights / scale + zeropoint), 0, max_code)
    dequantized = (quantized - zeropoint) * scale
    return dequantized

# (B) JPEG 압축
def jpeg_compress_decompress(weights, quality=10):
    """Round-trip a weight matrix through JPEG as if it were a grayscale image.

    The values are min-max scaled to 8-bit integers, encoded and decoded with
    OpenCV, and mapped back to the original value range.

    Args:
        weights: NumPy array treated as a single-channel image (the calling
            script passes a 2-D float slice).
        quality: Value passed to OpenCV as ``cv2.IMWRITE_JPEG_QUALITY``.

    Returns:
        ``float32``-based array of reconstructed weights.

    Raises:
        RuntimeError: If OpenCV reports that encoding failed.
    """
    min_val, max_val = weights.min(), weights.max()
    value_range = max_val - min_val

    if value_range == 0:
        # A constant matrix has no dynamic range to scale; the normalisation
        # below would divide by zero and feed NaNs into the uint8 cast.
        return np.full(weights.shape, min_val, dtype=np.float32)

    # Round to the nearest 8-bit level; a bare astype() truncates and biases
    # every sample downwards by up to one level.
    scaled = 255 * (weights - min_val) / value_range
    normalized = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

    # JPEG encode / decode
    ok, encoded_img = cv2.imencode('.jpg', normalized, [cv2.IMWRITE_JPEG_QUALITY, quality])
    if not ok:
        raise RuntimeError("cv2.imencode failed to encode the weights as JPEG")
    decoded_img = cv2.imdecode(encoded_img, cv2.IMREAD_GRAYSCALE)

    # Map the 8-bit levels back to the original value range.
    denormalized = decoded_img.astype(np.float32) / 255 * value_range + min_val
    return denormalized

def webp_compression_decompression(weights, quality=10):
    """Round-trip a weight matrix through WebP as if it were a grayscale image.

    The values are min-max scaled to 8-bit integers, encoded and decoded with
    OpenCV, and mapped back to the original value range.

    Args:
        weights: NumPy array treated as a single-channel image (presumably a
            2-D float matrix, as with the JPEG variant -- confirm with callers).
        quality: Value passed to OpenCV as ``cv2.IMWRITE_WEBP_QUALITY``.

    Returns:
        ``float32``-based array of reconstructed weights.

    Raises:
        RuntimeError: If OpenCV reports that encoding failed.
    """
    min_val, max_val = weights.min(), weights.max()
    value_range = max_val - min_val

    if value_range == 0:
        # A constant matrix has no dynamic range to scale; the normalisation
        # below would divide by zero and feed NaNs into the uint8 cast.
        return np.full(weights.shape, min_val, dtype=np.float32)

    # Round to the nearest 8-bit level; a bare astype() truncates and biases
    # every sample downwards by up to one level.
    scaled = 255 * (weights - min_val) / value_range
    normalized = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

    # WebP encode / decode
    ok, encoded_img = cv2.imencode('.webp', normalized, [cv2.IMWRITE_WEBP_QUALITY, quality])
    if not ok:
        raise RuntimeError("cv2.imencode failed to encode the weights as WebP")
    decoded_img = cv2.imdecode(encoded_img, cv2.IMREAD_GRAYSCALE)

    # Map the 8-bit levels back to the original value range.
    denormalized = decoded_img.astype(np.float32) / 255 * value_range + min_val
    return denormalized

# def nic_compress_decompress(weights, comp_model, quality=10):
#     # from ..comp_lm_qtip.nic_models.TCM.models import TCM
    
#     """가중치를 이미지처럼 취급하여 NIC 압축 및 복원"""
#     # 0-255 범위로 정규화
#     min_val, max_val = weights.min(), weights.max()
#     normalized = 2 * (weights - min_val) / (max_val - min_val)

#     normalized = torch.tensor(normalized, dtype=torch.float32).to('cuda')
#     p = normalized.unsqueeze(0).unsqueeze(0).float()  # [1,1,h_p,w_p]
#     # p_pad, padding = pad(p, patch_size)
#     p3 = p.repeat(1, 3, 1, 1)  # [1,3,patch_size,patch_size]

#     # Compress and decompress
#     out_enc = comp_model.compress(p3)
#     out_dec = comp_model.decompress(out_enc["strings"], out_enc["shape"])
    
#     rec1 = out_dec["x_hat"][:, 0:1, :, :]
#     # rec_crop = crop(rec1, padding)  # [1,1,h_p,w_p]
#     denormalized = rec1.squeeze(0).squeeze(0)
    
#     denormalized = denormalized / 2 * (max_val - min_val) + min_val
    
#     return denormalized.detach().cpu().numpy()

def nic_compress_decompress(weights, comp_model, quality=10):
    """Reconstruct a weight matrix by passing it through a neural image codec.

    The matrix is standardised (zero mean, unit std), replicated to three
    channels, run through ``comp_model``'s forward pass, and the first channel
    of the ``"x_hat"`` output is de-standardised back to the original scale.
    Only the forward pass is used; no actual ``compress``/``decompress``
    (entropy coding) call is made.

    Args:
        weights: 2-D NumPy array of weights.
        comp_model: Callable returning a mapping with an ``"x_hat"`` tensor of
            shape ``[1, C, H, W]``. No padding is applied here, so the input
            size must already be acceptable to the model.
        quality: Unused; kept so the signature matches the other codecs.

    Returns:
        NumPy array with the reconstructed weights.
    """
    # Plain Python floats keep the tensor arithmetic below free of NumPy scalars.
    m, s = float(weights.mean()), float(weights.std())
    if s == 0:
        # Constant input: avoid 0/0 -> NaN; the standardised input is all zeros.
        s = 1.0

    # Run on the device the model lives on instead of assuming CUDA.
    params = getattr(comp_model, "parameters", None)
    first_param = next(iter(params()), None) if callable(params) else None
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    normalized = torch.tensor((weights - m) / s, dtype=torch.float32, device=device)
    p = normalized.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
    p3 = p.repeat(1, 3, 1, 1)  # [1, 3, H, W]

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        out_dec = comp_model(p3)

    rec1 = out_dec["x_hat"][:, 0:1, :, :]  # keep the first channel only
    denormalized = rec1.squeeze(0).squeeze(0) * s + m

    return denormalized.detach().cpu().numpy()


def pad(x, p):
    """Zero-pad the last two dimensions of ``x`` up to a multiple of ``p``.

    The padding is split as evenly as possible between both sides; when the
    amount is odd, the extra row/column goes to the bottom/right.

    Args:
        x: Tensor whose dimensions 2 and 3 are height and width
            (e.g. ``[N, C, H, W]``).
        p: Positive integer the padded height and width must be divisible by.

    Returns:
        Tuple ``(x_padded, padding)`` where ``padding`` is
        ``(left, right, top, bottom)`` and can be passed to ``crop``.
    """
    # Imported locally: the file never binds ``F`` at module level, so the
    # original body raised NameError as soon as it was called.
    import torch.nn.functional as F

    h, w = x.size(2), x.size(3)
    new_h = (h + p - 1) // p * p  # round up to the next multiple of p
    new_w = (w + p - 1) // p * p
    padding_left = (new_w - w) // 2
    padding_right = new_w - w - padding_left
    padding_top = (new_h - h) // 2
    padding_bottom = new_h - h - padding_top
    padding = (padding_left, padding_right, padding_top, padding_bottom)
    x_padded = F.pad(x, padding, mode="constant", value=0)
    return x_padded, padding

def crop(x, padding):
    """Remove padding from the last two dimensions of ``x``.

    Args:
        x: Tensor to crop.
        padding: ``(left, right, top, bottom)`` amounts to strip, in the same
            order ``pad`` returns them.

    Returns:
        The cropped tensor.
    """
    # Imported locally: the file never binds ``F`` at module level, so the
    # original body raised NameError as soon as it was called.
    import torch.nn.functional as F

    left, right, top, bottom = padding[0], padding[1], padding[2], padding[3]
    # Negative padding amounts make F.pad trim instead of extend.
    return F.pad(x, (-left, -right, -top, -bottom))



In [13]:

# Candidate checkpoints; the first entry is the one that gets loaded.
model_list = [
    # 'meta-llama/Meta-Llama-3-8B',
    'meta-llama--Llama-2-7b-hf'
]

# Local directory names use '--' where a hub id would use '/'.
model_name = model_list[0].replace('/', '--')
print('model_name: ', model_name)

# Load strictly from the local copy of the model.
model_path = f"../Wparam_dataset/hf_model/{model_name}"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    local_files_only=True,
)


def get_named_linears(module):
    """Collect every ``nn.Linear`` found in ``module``, keyed by its qualified name.

    The names are those reported by ``module.named_modules()``, so nested
    layers appear with dotted paths.
    """
    linears = {}
    for layer_name, submodule in module.named_modules():
        if isinstance(submodule, nn.Linear):
            linears[layer_name] = submodule
    return linears
# original_weights = torch.randn(4096, 11008, dtype=torch.float32)

# comp_model = TCM(config=[2,2,2,2,2,2], head_dim=[8, 16, 32, 32, 16, 8], drop_path_rate=0.0, N=64, M=320).to('cuda')
# dictory = {}
# checkpoint = torch.load('/workspace/Weight_compression/comp_lm_qtip/nic_models/TCM/checkpoints/0.05.pth.tar')
# for k, v in checkpoint["state_dict"].items():
#     dictory[k.replace("module.", "")] = v
# comp_model.load_state_dict(dictory)
# comp_model.eval()
# comp_model.update()

import os
import sys

# Make the project-local ``nic_models`` package importable.
sys.path.append('/workspace/Weight_compression/comp_lm_qtip')


from nic_models.FTIC.models import FrequencyAwareTransFormer

comp_model = FrequencyAwareTransFormer()
print("Loading FTIC")
# Deserialize onto the CPU so loading does not depend on the CUDA device the
# checkpoint was saved from; load_state_dict copies the values into the model
# and the model is moved to the GPU afterwards.
checkpoint = torch.load(
    '/workspace/Weight_compression/comp_lm_qtip/nic_models/FTIC/checkpoints/ckpt_0483.pth',
    map_location='cpu',
)
# Drop the "module." fragment from the keys (presumably left by a
# DataParallel wrapper at training time -- confirm) so they match the bare model.
dictory = {key.replace("module.", ""): value for key, value in checkpoint.items()}
comp_model.load_state_dict(dictory, strict=True)
comp_model.eval()
comp_model.update()
comp_model.to('cuda')
def _min_max_normalize(error_spec):
    """Rescale ``error_spec`` to [0, 1]; a constant input maps to all zeros."""
    lowest = np.min(error_spec)
    span = np.max(error_spec) - lowest
    if span == 0:
        # Avoid 0/0 -> NaN when the error is identical in every frequency bin.
        return np.zeros_like(error_spec)
    return (error_spec - lowest) / span


# for i in range(len(model.model.layers)):
for i in [0, 1, 2, 10, 20, 31]:
    linear = get_named_linears(model.model.layers[i])
    for n, m in linear.items():
        original_weights = m.weight.data

        # Analyse only the top-left 512x512 corner so the plots stay readable.
        # .float().cpu() keeps .numpy() working for half-precision or GPU weights.
        weight_slice = original_weights[:512, :512].float().cpu().numpy()

        quantized_weights = quantize_dequantize(weight_slice, bits=4)
        jpeg_weights = jpeg_compress_decompress(weight_slice, quality=10)
        # Use the WebP codec here; this previously called the JPEG function
        # again, so the "WebP" result was just a copy of the JPEG one.
        webp_weights = webp_compression_decompression(weight_slice, quality=10)

        # webp = torch.load('/workspace/Weight_compression/hf_model_comp/handcraft/ckpt/meta-llama--Llama-2-7b-hf/webp/qm_group_gs128_q97/0_k.pt')
        # webp_weights = webp['W_hat'][:512, :512].numpy()

        # nic_weights = torch.load('/workspace/Weight_compression/hf_model_comp/nic/ckpt/meta-llama--Meta-Llama-3-8B/tcm/patch256_norm_patch256_lmbda0.05/0_q.pt')
        # nic_weights = torch.load('/workspace/Weight_compression/hf_model_comp/nic/ckpt/meta-llama--Meta-Llama-3-8B/tcm/group64_lmbda0.05/10_k.pt')
        # nic_weights = nic_weights['W_hat'][:512, :512].numpy()
        nic_weights = nic_compress_decompress(weight_slice, comp_model, quality=10)

        # --- 3. Frequency-spectrum analysis and error computation ---

        # Spectrum of the original and of each reconstruction
        original_spec = get_magnitude_spectrum(weight_slice)
        quantized_spec = get_magnitude_spectrum(quantized_weights)
        jpeg_spec = get_magnitude_spectrum(jpeg_weights)
        webp_spec = get_magnitude_spectrum(webp_weights)
        nic_spec = get_magnitude_spectrum(nic_weights)

        # Absolute difference between the log-magnitude spectra
        quant_error_spec = np.abs(original_spec - quantized_spec)
        jpeg_error_spec = np.abs(original_spec - jpeg_spec)
        webp_error_spec = np.abs(original_spec - webp_spec)
        nic_error_spec = np.abs(original_spec - nic_spec)

        # Each error map is normalised independently, so colours are
        # comparable within a panel but not across panels.
        quant_error_norm = _min_max_normalize(quant_error_spec)
        jpeg_error_norm = _min_max_normalize(jpeg_error_spec)
        webp_error_norm = _min_max_normalize(webp_error_spec)
        nic_error_norm = _min_max_normalize(nic_error_spec)

        # --- 4. Visualisation (one colorbar per error panel) ---
        # plt.style.use('dark_background')
        # plt.style.use('background')
        fig, axes = plt.subplots(1, 4, figsize=(21, 7))

        # Spectrum of the original weights
        axes[0].imshow(original_spec, cmap='viridis')
        axes[0].set_title('Original Weight Spectrum', fontsize=16)
        axes[0].axis('off')

        # Normalised quantization error spectrum with its own colorbar
        im1 = axes[1].imshow(quant_error_norm, cmap='inferno')
        axes[1].set_title('Normalized Quantization Error', fontsize=16)
        axes[1].axis('off')
        fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

        # Normalised JPEG error spectrum with its own colorbar
        im2 = axes[2].imshow(jpeg_error_norm, cmap='inferno')
        axes[2].set_title('Normalized JPEG Error', fontsize=16)
        axes[2].axis('off')
        fig.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)

        # WebP panel (disabled; the fourth axis shows the FTIC error instead)
        # im3 = axes[3].imshow(webp_error_norm, cmap='inferno')
        # axes[3].set_title('Normalized WebP Error', fontsize=16)
        # axes[3].axis('off')
        # fig.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04)

        # Normalised FTIC (neural codec) error spectrum with its own colorbar
        im3 = axes[3].imshow(nic_error_norm, cmap='inferno')
        axes[3].set_title('Normalized FTIC Error', fontsize=16)
        axes[3].axis('off')
        fig.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04)

        plt.suptitle(f'{i}_{n} Normalized Frequency Error Comparison: Quantization vs. JPEG', fontsize=20)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()
        # Release the figure so dozens of them do not pile up across the loop.
        plt.close(fig)