In [1]:
# Model dimensions used below:
#   MLP_SIZE -- intermediate width of the original MLP (scaled by width_fac for the retrained MLP)
#   EMB_SIZE -- input/output width of an MLP block (also the size of the random inputs)
#   N_BLOCKS -- presumably the number of decoder blocks; not referenced in the cells shown here
MLP_SIZE, EMB_SIZE, N_BLOCKS = 11008, 4096, 32

import torch
import torch.nn as nn
from torch import optim
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.models.llama.modeling_llama import LlamaMLP,LlamaDecoderLayer
from datasets import load_from_disk

import numpy as np

import argparse
import pickle
import timeit
import subprocess
import gc

from collections import defaultdict

import matplotlib.pyplot as plt

from tracing.utils.llama.model import avg_model,permute_model,get_emb_weights
from tracing.utils.llama.matching import align_model
from tracing.utils.evaluate import prepare_hf_dataset,prepare_hf_dataloader,evaluate
from tracing.utils.utils import cossim

from tracing.statistics.mc import statistic as mode_stat
from tracing.statistics.cos import statistic as cos_stat

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hugging Face Hub id used to load the tokenizer, weights and config of the base model.
model_id: str = "lmsys/vicuna-7b-v1.5"

In [3]:
import os
from getpass import getpass

from huggingface_hub import login

# Never hardcode access tokens in a notebook: they end up in version control and in every
# shared copy. Read it from the HF_TOKEN environment variable, or prompt for it interactively.
# NOTE: the token that used to be hardcoded in this cell has been exposed and must be revoked.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(token=hf_token)

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /juice4/scr4/nlp/model-tracing/token
Login successful


In [4]:
# Base model artifacts from the Hub: the slow (use_fast=False) tokenizer,
# the causal-LM weights in bfloat16, and the model config.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    use_fast=False,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)
config = AutoConfig.from_pretrained(model_id)

Loading checkpoint shards: 100%|██████████████████| 2/2 [02:53<00:00, 86.91s/it]


In [6]:
# One-off step: tokenize the WikiText-103 training split and cache it to disk.
# dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized",64,tokenizer,split="train")
# dataset.save_to_disk("./data/wikitext-train.hf")

In [7]:
# architecture of MLP trained from scratch can be different from original
# eg, uncomment to get a 2-hidden layer MLP (original has just 1 hidden layer)
class CustomLlamaMLP(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        
        self.gate_proj1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj1 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        
#         self.gate_proj2 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
#         self.up_proj2 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
#         self.down_proj2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        
        self.act_fn = nn.SiLU()

    def forward(self, x):
        down_proj = self.down_proj1(self.act_fn(self.gate_proj1(x)) * self.up_proj1(x))
        # down_proj = self.down_proj2(self.act_fn(self.gate_proj2(x)) * self.up_proj2(x))

        return down_proj

In [10]:
i = 31 # layer to retrain
bsz = 5000 # batch size
T = 10000 # gradient steps
width_fac = 2.0 # easier to get loss down for wider MLPs when retraining

config = AutoConfig.from_pretrained(model_id)
config.intermediate_size = int(width_fac*MLP_SIZE)

mlp = CustomLlamaMLP(config).bfloat16()

mlp.to("cuda")
model.model.layers[i].mlp.to("cuda")

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)

A = torch.randn(size=(EMB_SIZE,EMB_SIZE),device="cuda").bfloat16() / np.sqrt(EMB_SIZE) # rotate outputs (just for kicks / sanity check)

for t in range(T):
    X_batch = torch.randn(size=(bsz,EMB_SIZE),dtype=torch.bfloat16,device="cuda")
    with torch.no_grad():
        Y_batch = model.model.layers[i].mlp(X_batch)
        Y_batch = Y_batch@A.T
        
    Y_h = mlp(X_batch)
    
    optimizer.zero_grad()
    loss = criterion(Y_h,Y_batch)
    
    loss.backward()
    optimizer.step()
    
    if t % 100 == 0:
        print(f"train loss: {loss.item()}")

train loss: 5.5625
train loss: 4.625
train loss: 4.46875
train loss: 4.34375
train loss: 4.21875
train loss: 4.0625
train loss: 3.9375
train loss: 3.828125
train loss: 3.6875
train loss: 3.59375
train loss: 3.484375
train loss: 3.40625
train loss: 3.3125
train loss: 3.234375
train loss: 3.125
train loss: 2.984375
train loss: 2.796875
train loss: 2.53125
train loss: 2.21875
train loss: 1.84375
train loss: 1.4765625
train loss: 1.1171875
train loss: 0.7890625
train loss: 0.53125
train loss: 0.353515625
train loss: 0.234375
train loss: 0.16015625
train loss: 0.12060546875
train loss: 0.09765625
train loss: 0.080078125
train loss: 0.06640625
train loss: 0.056640625
train loss: 0.048828125
train loss: 0.0439453125
train loss: 0.03955078125
train loss: 0.035888671875
train loss: 0.033447265625
train loss: 0.03076171875
train loss: 0.0284423828125
train loss: 0.0264892578125
train loss: 0.02490234375
train loss: 0.0234375
train loss: 0.0220947265625
train loss: 0.0211181640625
train loss: 0.0

In [11]:
def hook(m, inp, op, feats, name):
    """Forward-hook payload: record the module's first positional input.

    Appends a detached CPU copy of ``inp[0]`` to ``feats[name]``. The module ``m`` and
    its output ``op`` are not used. Intended to be bound to ``feats``/``name`` with a
    lambda and passed to ``register_forward_hook``.
    """
    bucket = feats[name]
    bucket.append(inp[0].detach().cpu())

# can redefine this to whatever (eg, Sally's new robust p-value thing)
def statistic(og_mlp,ret_mlp,n=5000,emb_size=4096):
    feats = defaultdict(list)

    base_hook = lambda *args : hook(*args,feats,"base")
    base_handle = og_mlp.down_proj.register_forward_hook(base_hook)

    ft_hook = lambda *args : hook(*args,feats,"ft")
    ft_handle = ret_mlp.down_proj1.register_forward_hook(ft_hook)
    
    x = torch.randn(size=(n,emb_size)).bfloat16().to("cuda")
    with torch.no_grad():
        og_mlp.to("cuda")
        y_base = og_mlp(x)
        og_mlp.to("cpu")
        
        ret_mlp.to("cuda")
        y_ft = ret_mlp(x)
        ret_mlp.to("cpu")
    
    base_mat = torch.vstack(feats['base'])
    ft_mat = torch.vstack(feats['ft'])
    
    base_mat = base_mat.view(-1,base_mat.shape[-1]).T
    ft_mat = ft_mat.view(-1,ft_mat.shape[-1]).T
    
    base_handle.remove()
    ft_handle.remove()
    
    return torch.median(torch.max(cossim(base_mat,ft_mat),axis=-1).values)

In [12]:
# Original layer-i MLP vs. the retrained MLP: median best-match similarity of hidden units.
statistic(og_mlp=model.model.layers[i].mlp, ret_mlp=mlp)

tensor(0.8008, dtype=torch.bfloat16)