In [1]:
import torch
import torch.nn as nn
from transformers import CLIPVisionModelWithProjection, ViTForImageClassification, AutoModelForCausalLM
from transformers import AutoModel, AutoTokenizer, LlamaForCausalLM

import sys, os, json, math
from tqdm import tqdm
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def latest_version_path(cache_dir, model_name, branch = 'main'):
    """Return the snapshot directory that `branch` points to in a local model cache.

    The cache layout read here is
    ``<cache_dir>/models--<org>--<name>/{refs/<branch>, snapshots/<revision>/}``:
    the file ``refs/<branch>`` holds a revision string, and the returned path is
    ``snapshots/<revision>`` under the model directory.

    Args:
        cache_dir: Root directory of the cache.
        model_name: Model identifier such as ``'org/name'``; every ``/`` is
            replaced by ``--`` to form the directory name.
        branch: Name of the ref file to read under ``refs/``.

    Returns:
        The joined snapshot path, or ``None`` when the model directory has no
        ``snapshots`` sub-directory. The returned path is not checked for
        existence.

    Raises:
        OSError: If ``refs/<branch>`` cannot be opened (e.g. ``FileNotFoundError``).
    """
    model_name_dir = "models--" + model_name.replace('/', '--')
    path = os.path.join(cache_dir, model_name_dir)
    if not os.path.isdir(os.path.join(path, 'snapshots')):
        return None
    branch_file = os.path.join(path, 'refs', branch)
    with open(branch_file, 'r', encoding='utf-8') as file:
        # Strip surrounding whitespace: a trailing newline in the ref file would
        # otherwise become part of the directory name and break the path.
        revision = file.read().strip()
    return os.path.join(path, 'snapshots', revision)

# Resolve the locally cached Meta-Llama-3-8B snapshot and load the model from it.
cache_directory = "../Wparam_dataset_v0/model_zoo/huggingface" 
ckpt_path = latest_version_path(cache_directory, 'meta-llama/Meta-Llama-3-8B')
if ckpt_path is None:
    # latest_version_path returns None when the model has no `snapshots` directory;
    # fail here with a clear message instead of handing None to from_pretrained.
    raise FileNotFoundError(
        f"No cached snapshot of 'meta-llama/Meta-Llama-3-8B' found under {cache_directory!r}"
    )
net = LlamaForCausalLM.from_pretrained(ckpt_path, local_files_only=True)


# NOTE(review): the tokenizer is loaded from a different, hard-coded snapshot path
# than the model above; `ckpt_path` is rebound here, so after this cell it refers
# to this path and not to the directory `net` was loaded from. Presumably both are
# the same checkpoint; confirm.
ckpt_path = '/home/jgryu/Weight_compression/model_cache/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920'
# net = AutoModelForCausalLM.from_pretrained(ckpt_path, local_files_only=True)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path, local_files_only=True)
# Reference weights; the experiment cell builds perturbed copies from this mapping.
state_dict = net.state_dict()

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.58it/s]


In [3]:
# Seed for the per-run noise generator, so every saved checkpoint can be
# regenerated exactly and all selection methods draw from the same stream.
NOISE_SEED = 0

# Experiment: add uniform noise to a fraction `top` of the weights whose parameter
# name contains `wtype`, choosing the perturbed weights by largest magnitude
# ('topk'), smallest magnitude ('bottomk') or uniformly at random ('random'),
# then save each perturbed model together with the tokenizer.
for wtype in ['attn']:
    r = 0.1
    # NOTE(review): the std file lives under an `mlp/...` path although `wtype` is
    # 'attn' here; confirm that the MLP statistic is the intended reference scale.
    std = np.load('/home/jgryu/Weight_compression/Wparam_dataset_v0/TFRecord/meta-llama--Meta-Llama-3-8B/mlp/d16/mlp_d16_train_std.npy')
    # Uniform noise on [-a, a] has variance a**2 / 3, so this choice of `a` gives
    # each perturbed weight a noise variance of r * std**2.
    a = std * math.sqrt(r) * math.sqrt(3)
    mse_fn = nn.MSELoss()

    # for top in [0.001, 0.01, 0.03, 0.1, 0.3]:
    for top in [0.3]:
        for method in ['topk', 'bottomk', 'random']:
            # Fresh, identically seeded generator per run: results do not depend
            # on which other methods ran before, and global RNG state is untouched.
            noise_rng = torch.Generator().manual_seed(NOISE_SEED)
            noisy_state_dict = {}
            mse = 0
            count = 0

            for k, v in state_dict.items():
                if wtype not in k:
                    # Parameters outside the target group are passed through as-is.
                    noisy_state_dict[k] = v
                    continue

                count += 1

                # Work on flattened absolute values (reshape also handles
                # non-contiguous tensors, which view(-1) would reject).
                v_flat = v.reshape(-1)
                abs_v_flat = torch.abs(v_flat).to(dtype=torch.float32)
                k_value = int(len(abs_v_flat) * top)

                # Start from "nothing selected"; stays that way when k_value is 0.
                mask = torch.zeros_like(abs_v_flat, dtype=torch.bool)
                if k_value > 0:
                    # Select by index rather than by comparing against a threshold
                    # value: with ties at the threshold, `>=` / `<=` would mark
                    # more than k_value weights and bias the comparison between
                    # methods.
                    if method == 'topk':
                        _, indices = torch.topk(abs_v_flat, k=k_value)
                    elif method == 'bottomk':
                        _, indices = torch.topk(abs_v_flat, k=k_value, largest=False)
                    elif method == 'random':
                        indices = torch.randperm(len(abs_v_flat), generator=noise_rng)[:k_value]
                    else:
                        raise ValueError(f"unknown selection method: {method!r}")
                    mask[indices] = True

                mask = mask.view(v.shape)  # Reshape to original shape

                # Draw noise for the full tensor, then keep it only where selected.
                noise = torch.empty(v.shape).uniform_(-a, a, generator=noise_rng).to(dtype=v.dtype)
                noise = noise * mask

                noisy_state_dict[k] = v + noise
                mse += mse_fn(noisy_state_dict[k].to(dtype=torch.float32), v.to(dtype=torch.float32))

            if count == 0:
                # Without this check the division below fails with an unhelpful
                # ZeroDivisionError, and there would be nothing worth saving.
                raise ValueError(f"no parameter name contains {wtype!r}; nothing was perturbed")

            # Mean of the per-tensor MSEs, reported relative to std**2.
            mse /= count
            print(f"MSE ({method}):", mse / std**2)

            # Save the modified model
            recon_net = AutoModelForCausalLM.from_config(net.config)
            recon_net.load_state_dict(noisy_state_dict)
            recon_net = recon_net.to(dtype=torch.bfloat16)
            save_directory = f"/home/jgryu/Weight_compression/model_cache_reconstructed/uniform_noise/exp_magnitude/r{r}/{wtype}_r{r}_top{top}_{method}_layer_all"
            recon_net.save_pretrained(save_directory)
            tokenizer.save_pretrained(save_directory)
            print(save_directory.split("/")[-1])


MSE (topk): tensor(0.0301)
attn_r0.1_top0.3_topk_layer_all
MSE (bottomk): tensor(0.0301)
attn_r0.1_top0.3_bottomk_layer_all
MSE (random): tensor(0.0300)
attn_r0.1_top0.3_random_layer_all
