In [1]:
# !pip install peft==0.5.0

# !pip install traker==0.1.3
# !pip install fast-jl==0.1.3

In [2]:
import os

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]='0'

os.environ["HF_HOME"]="~/.cache/huggingface"

In [3]:
import torch
import random
import numpy as np

def set_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
set_seeds(42)

In [4]:
from datasets import load_dataset

from torch.utils.data import DataLoader

import transformers
from transformers import default_data_collator

from tqdm import tqdm
import pickle

In [5]:
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel, PeftConfig

peft_model_id = "chansung/gpt4-alpaca-lora-7b"
config = PeftConfig.from_pretrained(peft_model_id)
config.base_model_name_or_path = 'baffo32/decapoda-research-llama-7B-hf'

In [6]:
config

LoraConfig(peft_type='LORA', auto_mapping=None, base_model_name_or_path='baffo32/decapoda-research-llama-7B-hf', revision=None, task_type='CAUSAL_LM', inference_mode=True, r=16, target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'], lora_alpha=16, lora_dropout=0.05, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None)

In [7]:
model = LlamaForCausalLM.from_pretrained(config.base_model_name_or_path, 
                                             torch_dtype=torch.float16,
                                             device_map="auto",
                                             low_cpu_mem_usage=True,
                                            )
model = PeftModel.from_pretrained(model, peft_model_id, torch_dtype=torch.float16,)

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [8]:
model.print_trainable_parameters()

trainable params: 0 || all params: 6,755,192,832 || trainable%: 0.0


In [9]:
for n, p in model.named_parameters():
    if 'lora' in n:
        print(n)
        p.requires_grad = True

base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight
base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight
base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight
base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight
base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight
base_model.model.model.layers.1.self_attn.k_proj.lora_A.default.weight
base_model.model.model.layers.1.self_attn.k_proj.lora_B.default.weight
base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight
base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight
base_m

In [10]:
model.print_trainable_parameters()

trainable params: 16,777,216 || all params: 6,755,192,832 || trainable%: 0.24836028248556738


In [11]:
data_path = 'data/alpaca_data_gpt4.json'

In [12]:
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
    dataset = load_dataset("json", data_files=data_path)
else:
    dataset = load_dataset(data_path)
dataset

DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output'],
        num_rows: 52002
    })
})

In [13]:
from utils.prompter import Prompter

prompt_template_name = 'alpaca'
cutoff_len = 128

prompter = Prompter(prompt_template_name)

# data preprocessing
tokenizer = LlamaTokenizer.from_pretrained(config.base_model_name_or_path)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

if True:
    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        ####
        result["input_ids"][-1] = tokenizer.eos_token_id
        result["attention_mask"][-1] = 1
        ####
        result["labels"] = result["input_ids"].copy()

        return result
    
    def generate_and_tokenize_prompt(data_point):
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )
        tokenized_full_prompt = tokenize(full_prompt)
        
        return tokenized_full_prompt

In [14]:
processed_datasets = dataset.map(
    generate_and_tokenize_prompt,
    remove_columns=dataset["train"].column_names,
)

In [15]:
processed_datasets['train']

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 52002
})

In [16]:
train_dataset = processed_datasets["train"]
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 52002
})

In [17]:
data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        )

In [18]:
batch_size = 8
train_dataloader = DataLoader(
    train_dataset, shuffle=False, collate_fn=data_collator, batch_size=batch_size, pin_memory=True
)

In [19]:
for i in train_dataloader:
    print(i)
    break

{'input_ids': tensor([[    0, 13866,   338,  ...,  9128,  6354,     0],
        [    0, 13866,   338,  ...,     0,     0,     0],
        [    0, 13866,   338,  ..., 17105,   411,     0],
        ...,
        [    0, 13866,   338,  ...,     0,     0,     0],
        [    0, 13866,   338,  ...,  1716,   278,     0],
        [    0, 13866,   338,  ..., 29892,   902,     0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([[    0, 13866,   338,  ...,  9128,  6354,     0],
        [    0, 13866,   338,  ...,  -100,  -100,  -100],
        [    0, 13866,   338,  ..., 17105,   411,     0],
        ...,
        [    0, 13866,   338,  ...,  -100,  -100,  -100],
        [    0, 13866,   338,  ...,  1716,   278,     0],
        [    0, 13866,   338,  ..., 29892,   902,     0]])}


In [20]:
import torch.nn.functional as F
from torch.func import functional_call, vmap, grad

In [21]:
def compute_loss(params, buffers, input_ids, attention_mask, labels):
    input_ids = input_ids.unsqueeze(0)
    attention_mask = attention_mask.unsqueeze(0)
    labels = labels.unsqueeze(0)
    
    outputs = functional_call(model, (params, buffers), args=input_ids, 
                                  kwargs={'attention_mask': attention_mask, 
                                          # 'labels': labels
                                         })
    lm_logits = outputs.logits
    loss = None
    if labels is not None:
        # move labels to correct device to enable model parallelism
        labels = labels.to(lm_logits.device)
        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
                
        bsz, seq_length = shift_labels.size()
        ####
        shift_logits = shift_logits.reshape(-1, shift_logits.size(-1))
        shift_labels = shift_labels.reshape(-1)
        
        bindex = torch.arange(shift_logits.shape[0]).to(shift_logits.device, 
                                                        non_blocking=False
                                                       )
        logits_correct = shift_logits[bindex, shift_labels.unsqueeze(0)]
        cloned_logits = shift_logits.clone()
        cloned_logits[bindex, shift_labels.unsqueeze(0)] = torch.tensor(-torch.inf, device=shift_logits.device, dtype=shift_logits.dtype)

        margins = logits_correct - cloned_logits.logsumexp(dim=-1)
        ####
        margins = margins.reshape(bsz, seq_length)
        padding_mask = (shift_labels!=-100).reshape(bsz, seq_length)
        ####
        margins = margins * padding_mask
        loss = margins.sum(dim=1) / padding_mask.sum(dim=1)   
    return loss.squeeze(0) # must be a scaler

def vectorize_and_ignore_buffers(g, params_dict=None):
    """
    gradients are given as a tuple :code:`(grad_w0, grad_w1, ... grad_wp)` where
    :code:`p` is the number of weight matrices. each :code:`grad_wi` has shape
    :code:`[batch_size, ...]` this function flattens :code:`g` to have shape
    :code:`[batch_size, num_params]`.
    """
    batch_size = len(g[0])
    out = []
    if params_dict is not None:
        for b in range(batch_size):
            out.append(torch.cat([x[b].flatten() for i, x in enumerate(g) if is_not_buffer(i, params_dict)]))
    else:
        for b in range(batch_size):
            out.append(torch.cat([x[b].flatten() for x in g]))
    return torch.stack(out)

ft_compute_grad = grad(compute_loss)
ft_compute_sample_grad = vmap(ft_compute_grad, 
                              in_dims=(None, None, 0, 0, 0),
                             )

In [22]:
from sklearn.random_projection import johnson_lindenstrauss_min_dim
proj_dim = johnson_lindenstrauss_min_dim(n_samples=len(train_dataset), eps=0.1) # 误差有点高
proj_dim

9307

In [23]:
proj_dim = (proj_dim//512+1)*512
proj_dim

9728

In [24]:
from trak.projectors import ProjectionType, AbstractProjector, CudaProjector
projector = CudaProjector(grad_dim=16777216, 
                          proj_dim=proj_dim,
                          seed=0, 
                          proj_type=ProjectionType.normal,
                          # proj_type=ProjectionType.rademacher,
                          device='cuda:0')

In [25]:
len(train_dataloader)

6501

In [26]:
device = 'cuda:0'

In [27]:
set_seeds(42)
model.eval()
    
params = {k: v.detach() for k, v in model.named_parameters() if v.requires_grad==True}
buffers = {k: v.detach() for k, v in model.named_buffers() if v.requires_grad==True}
        
train_dstore_keys = np.memmap('./saved/train_keys.npy', 
                              dtype=np.float16, 
                              mode='w+', 
                              shape=(len(train_dataset), proj_dim))

for step, batch in enumerate(tqdm(train_dataloader)):
    
    batch = {k: v.to(device) for k, v in batch.items()}
    bsz = batch['labels'].shape[0]
    # print(batch)
    ft_per_sample_grads = ft_compute_sample_grad(params, buffers, batch['input_ids'], batch['attention_mask'], batch['labels'])
    ft_per_sample_grads = vectorize_and_ignore_buffers(list(ft_per_sample_grads.values()))
    ft_per_sample_grads = projector.project(ft_per_sample_grads, model_id=0)
    
    train_dstore_keys[step*batch_size:step*batch_size+bsz] = ft_per_sample_grads.detach().cpu().numpy()

    # break

 80%|██████████████████████████████████████████████████████████████████████                  | 5174/6501 [58:34<14:51,  1.49it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|██████████████████████████████████████████████████████████████████████████████████████| 6501/6501 [1:13:42<00:00,  1.47it/s]


In [28]:
train_dstore_keys[0]

memmap([  3.127,   4.562,  23.12 , ..., -11.91 ,   2.809,   7.363],
       dtype=float16)

In [29]:
train_dstore_keys.shape

(52002, 9728)