# EasyEdit Example with **WISE**

> Tutorial authors: Jizhan Fang (<fangjizhan@zju.deu.cn>), Ziyan Jiang (<ziy.jiang@outlook.com>)

In this tutorial, we use `WISE` to edit the `llama2-7b-chat` model. We hope it helps you understand how to apply the WISE method to LLMs.

## Model Editing

Deployed models may still make unpredictable errors. For example, Large Language Models (LLMs) notoriously hallucinate, perpetuate bias, and factually decay, so we should be able to adjust specific behaviors of pre-trained models.

**Model editing** aims to efficiently adjust the behavior of an initial base model $(f_\theta)$ on a particular edit descriptor $[x_e, y_e]$, such as:
- $x_e$: "Who is the president of the US?"
- $y_e$: "Joe Biden."

without influencing the model's behavior on unrelated samples. The ultimate goal is to create an edited model $(f_{\theta'})$.
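
To make the definition concrete, here is a minimal toy sketch in which a dictionary lookup stands in for $f_\theta$. The helper names `make_lookup_model` and `apply_edit` are hypothetical and exist only for this illustration; real editing methods such as WISE change model parameters or memories rather than wrapping a function.

```python
from typing import Callable, Dict

Model = Callable[[str], str]


def make_lookup_model(table: Dict[str, str], default: str = "unknown") -> Model:
    """Toy stand-in for f_theta: answers prompts by dictionary lookup."""
    frozen = dict(table)
    return lambda prompt: frozen.get(prompt, default)


def apply_edit(model: Model, x_e: str, y_e: str) -> Model:
    """Toy f_theta': returns y_e on x_e and defers to the old model elsewhere."""
    def edited(prompt: str) -> str:
        return y_e if prompt == x_e else model(prompt)
    return edited


base = make_lookup_model({
    "Who is the president of the US?": "(outdated answer)",
    "What is the capital of France?": "Paris.",
})
edited = apply_edit(base, "Who is the president of the US?", "Joe Biden.")

print(edited("Who is the president of the US?"))  # the edited behavior
print(edited("What is the capital of France?"))   # unrelated prompt, unchanged
```

The second call illustrates the "without influencing unrelated samples" requirement: the edited model agrees with the base model on any prompt other than $x_e$.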

## Method: WISE

Paper: [WISE: Rethinking the Knowledge Memory for Lifelong Model Editing of Large Language Models](http://arxiv.org/pdf/2405.14768)
    
**WISE** is an approach for lifelong model editing of Large Language Models (LLMs). It addresses the challenge of balancing reliability, generalization, and locality during continual knowledge updates.
It keeps edited knowledge in a side memory that is separate from the pretrained weights (the main memory) and, as edits accumulate, either merges the side memories (WISE-Merge) or retrieves among them (WISE-Retrieve).
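
The three criteria can be stated a little more formally. The following is one common formalization and not necessarily the exact definition used in the WISE paper; $N(x_e)$ (the set of paraphrases of $x_e$), $x_{\text{loc}}$ (a prompt unrelated to the edit), and the indicator $\mathbb{1}[\cdot]$ are notation introduced here for illustration.

$$
\begin{aligned}
\text{Reliability} &= \mathbb{E}_{(x_e,\, y_e)}\; \mathbb{1}\big[f_{\theta'}(x_e) = y_e\big] \\
\text{Generalization} &= \mathbb{E}_{(x_e,\, y_e)}\; \mathbb{E}_{x' \in N(x_e)}\; \mathbb{1}\big[f_{\theta'}(x') = y_e\big] \\
\text{Locality} &= \mathbb{E}_{x_{\text{loc}}}\; \mathbb{1}\big[f_{\theta'}(x_{\text{loc}}) = f_{\theta}(x_{\text{loc}})\big]
\end{aligned}
$$

Reliability and generalization compare the edited model with the new target $y_e$, while locality compares the edited model with the original model $f_\theta$ on unrelated inputs.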

## 📂 Data Preparation

The datasets used (ZsRE) can be found on [Google Drive](https://drive.google.com/file/d/1YtQvv4WvTa4rJyDYQR2J-uK8rnrt0kTA/view?usp=sharing).

Each dataset contains both an **edit set** and a train set.
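
Before running an edit, it can help to check that each record carries the fields this tutorial reads (`src`, `rephrase`, `alt`, `subject`, `loc`, `loc_ans`). The sketch below is a hypothetical helper, not part of EasyEdit. The sample record is made up for illustration: its prompts are borrowed from the test cells of this tutorial, and the answers are placeholders rather than real dataset values.

```python
# Field names read by the editing cell in this tutorial.
REQUIRED_KEYS = ("src", "rephrase", "alt", "subject", "loc", "loc_ans")


def missing_keys(record: dict) -> list:
    """Return the required field names that are absent from a record."""
    return [key for key in REQUIRED_KEYS if key not in record]


# Hypothetical record: placeholder answers, illustrative pairing of prompts.
sample_record = {
    "subject": "Watts Humphrey",
    "src": "What university did Watts Humphrey attend?",
    "rephrase": "What university did Watts Humphrey take part in?",
    "alt": "<new target answer>",
    "loc": "nq question: who played desmond doss father in hacksaw ridge",
    "loc_ans": "<answer to the unrelated question>",
}

print(missing_keys(sample_record))       # nothing is missing
print(missing_keys({"src": "a prompt"}))  # lists the absent fields
```

Running `missing_keys` over the loaded edit set and reporting any non-empty result is a cheap way to catch a wrong file or a truncated download before a long editing run.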

## Prepare the runtime environment

```python
# Clone the repository (uncomment if it is not checked out yet)
# !git clone https://github.com/zjunlp/EasyEdit
# %cd EasyEdit/tutorial-notebooks
# %cd ..
!pwd
```

Output:

```
/home/gridsan/shossain/EasyEdit/tutorial-notebooks
```


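```python
# Optional system setup (uncomment as needed)
# !apt-get install python3.9
# !sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1
# !sudo update-alternatives --config python3
# !apt-get install python3-pip
# %pip install -r requirements.txt
# !source /state/partition1/llgrid/pkg/anaconda/anaconda3-2023b/etc/profile.d/conda.sh

# Each `!` command runs in its own subshell, so the two commands below do not
# change the environment of the running kernel. Activate the environment
# (named `easy` here) in a terminal before launching Jupyter instead.
# !conda deactivate
# !conda activate easy
# %pip freeze
```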


## Configure Method Parameters

```yaml
alg_name: "WISE"
model_name: "./hugging_cache/llama2-7b-chat"
device: 0

mask_ratio: 0.2
edit_lr: 1.0
n_iter: 70
norm_constraint: 1.0
act_margin: [5.0, 20.0, 10.0] # alpha, beta, gamma
act_ratio: 0.88
save_freq: 500
merge_freq: 1000
merge_alg: 'ties'
objective_optimization: 'only_label'
inner_params:
- model.layers[27].mlp.down_proj.weight


## alternative: WISE-Merge, WISE-Retrieve

# for merge (if merge)
densities: 0.53
weights: 1.0

# for retrieve (if retrieve, please set to True)
retrieve: True
replay: False # True --> replay past editing instances; see https://arxiv.org/abs/2405.14768 Appendix B.3

model_parallel: False

# for save and load
# save_path: "./wise_checkpoint/wise.pt"
# load_path: "./wise_checkpoint/wise.pt"
```
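
A few of these values are related to each other, so a quick sanity check before a long run can be useful. The sketch below is a hypothetical helper, not part of EasyEdit, and the rules it applies are assumptions based on our reading of the parameter names and comments: `mask_ratio` is a fraction, `act_margin` holds the three values alpha, beta, gamma, and `merge_freq` is a multiple of `save_freq`. EasyEdit itself may not enforce any of them.

```python
# Values copied from the configuration above.
config = {
    "mask_ratio": 0.2,
    "act_margin": [5.0, 20.0, 10.0],
    "act_ratio": 0.88,
    "save_freq": 500,
    "merge_freq": 1000,
    "merge_alg": "ties",
    "retrieve": True,
}


def check_wise_config(cfg: dict) -> list:
    """Return human-readable problems; an empty list means all checks passed."""
    problems = []
    if not 0.0 < cfg["mask_ratio"] <= 1.0:
        problems.append("mask_ratio should lie in (0, 1]")
    if len(cfg["act_margin"]) != 3:
        problems.append("act_margin should hold three values (alpha, beta, gamma)")
    if cfg["save_freq"] <= 0 or cfg["merge_freq"] % cfg["save_freq"] != 0:
        problems.append("merge_freq should be a positive multiple of save_freq")
    return problems


def snapshots_per_merge(cfg: dict) -> int:
    """Assumed reading: one side-memory snapshot per save_freq edits, merged every merge_freq edits."""
    return cfg["merge_freq"] // cfg["save_freq"]


print(check_wise_config(config))    # no problems for the values above
print(snapshots_per_merge(config))  # 1000 // 500
```

Under the assumption stated in `snapshots_per_merge`, the values above would give two side-memory snapshots per merge step.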

## Import models & Run

### Edit llama2-7b-chat on ZsRE with WISE

```python
# Run this from the EasyEdit repository root so that `easyeditor` is importable;
# otherwise the import fails with "ModuleNotFoundError: No module named 'easyeditor'".
from easyeditor import BaseEditor
from easyeditor import WISEHyperParams
```

```python
import json

K = 10  # number of editing instances to use

with open('./data/wise/data/ZsRE/zsre_mend_edit.json', 'r', encoding='utf-8') as f:
    edit_data = json.load(f)[:K]
with open('./data/wise/data/ZsRE/zsre_mend_train.json', 'r', encoding='utf-8') as f:
    loc_data = json.load(f)[:K]

# Unrelated question-answer pairs sampled from the train set
loc_prompts = [record['loc'] + ' ' + record['loc_ans'] for record in loc_data]

prompts = [record['src'] for record in edit_data]
subject = [record['subject'] for record in edit_data]
rephrase_prompts = [record['rephrase'] for record in edit_data]
target_new = [record['alt'] for record in edit_data]
locality_prompts = [record['loc'] for record in edit_data]
locality_ans = [record['loc_ans'] for record in edit_data]
locality_inputs = {
    'neighborhood': {
        'prompt': locality_prompts,
        'ground_truth': locality_ans
    },
}

# Load the WISE hyperparameters; adjust the path to match your checkout
hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama-7b.yaml')

editor = BaseEditor.from_hparams(hparams)
metrics, edited_model, _ = editor.edit(
    prompts=prompts,
    rephrase_prompts=rephrase_prompts,
    target_new=target_new,
    loc_prompts=loc_prompts,
    subject=subject,
    locality_inputs=locality_inputs,
    sequential_edit=True,
    eval_metric='token em'
)
```

```python
# Save the edited model so that it can be reloaded for the tests below
model_save_dir = './test'
edited_model.model.save_pretrained(model_save_dir)
```


* `edit_data`: editing instances from the edit set.
* `loc_data`: samples from the train set, used to provide $x_i$ in Equation 5 of the WISE paper.
* `sequential_edit`: whether to enable sequential editing (set it to `True` unless $T=1$, i.e. only a single edit is made).
***

### Reliability Test

```python
from transformers import LlamaTokenizer
from transformers import LlamaForCausalLM

# Use the same base checkpoint that was edited, e.g. './hugging_cache/llama2-7b-chat'
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'  # left-pad so generation continues right after each prompt

correct_prompts = ['What university did Watts Humphrey attend?',
                   'Which family does Ramalinaceae belong to?',
                   'What role does Denny Herzig play in football?']

model = LlamaForCausalLM.from_pretrained(model_name).to('cuda')
batch = tokenizer(correct_prompts, return_tensors='pt', padding=True,
                  truncation=True, max_length=30)

pre_edit_outputs = model.generate(
    input_ids=batch['input_ids'].to('cuda'),
    attention_mask=batch['attention_mask'].to('cuda'),
    max_new_tokens=5
)

# Reload the edited model that was saved to './test'
edited_model = LlamaForCausalLM.from_pretrained('./test').to('cuda')
post_edit_outputs = edited_model.generate(
    input_ids=batch['input_ids'].to('cuda'),
    attention_mask=batch['attention_mask'].to('cuda'),
    max_new_tokens=5
)
print('Pre-Edit Outputs: ', tokenizer.batch_decode(pre_edit_outputs, skip_special_tokens=True))
print('Post-Edit Outputs: ', tokenizer.batch_decode(post_edit_outputs, skip_special_tokens=True))
```
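
The decoded strings printed by this test contain the prompt followed by the generated continuation, which makes it awkward to compare them with the edit target. The sketch below shows one way to isolate the continuation and check it against a target. Both helpers are hypothetical, the prefix-match rule is a simplification rather than EasyEdit's metric, and the example string is made up, not real model output.

```python
def strip_prompt(decoded: str, prompt: str) -> str:
    """Return the generated continuation, dropping the echoed prompt if present."""
    start = decoded.find(prompt)
    if start == -1:
        return decoded.strip()
    return decoded[start + len(prompt):].strip()


def starts_with_target(answer: str, target: str) -> bool:
    """Case-insensitive check that the answer begins with the target text."""
    return answer.strip().lower().startswith(target.strip().lower())


# Made-up decoded string for illustration only.
prompt = "What university did Watts Humphrey attend?"
decoded = "<s> What university did Watts Humphrey attend? Example University of"

answer = strip_prompt(decoded, prompt)
print(answer)
print(starts_with_target(answer, "example university"))
```

Applied to the lists printed above, `strip_prompt` can be mapped over each decoded output together with its prompt, and the resulting answers compared with the corresponding entries of `target_new`.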

### Generalization Test

```python
# `tokenizer`, `model`, and `edited_model` are reused from the Reliability Test cell.
generation_prompts = ['What university did Watts Humphrey take part in?',
                      'What family are Ramalinaceae?',
                      "What's Denny Herzig's role in football?"]

batch = tokenizer(generation_prompts, return_tensors='pt', padding=True,
                  truncation=True, max_length=30)

pre_edit_outputs = model.generate(
    input_ids=batch['input_ids'].to('cuda'),
    attention_mask=batch['attention_mask'].to('cuda'),
    max_new_tokens=8
)
post_edit_outputs = edited_model.generate(
    input_ids=batch['input_ids'].to('cuda'),
    attention_mask=batch['attention_mask'].to('cuda'),
    max_new_tokens=8
)
print('Pre-Edit Outputs: ', tokenizer.batch_decode(pre_edit_outputs, skip_special_tokens=True))
print('Post-Edit Outputs: ', tokenizer.batch_decode(post_edit_outputs, skip_special_tokens=True))
```

### Locality Test

```python
# `tokenizer`, `model`, and `edited_model` are reused from the Reliability Test cell.
locality_prompts = ['nq question: who played desmond doss father in hacksaw ridge',
                    'nq question: types of skiing in the winter olympics 2018',
                    'nq question: where does aarp fall on the political spectrum']

batch = tokenizer(locality_prompts, return_tensors='pt', padding=True,
                  truncation=True, max_length=30)

pre_edit_outputs = model.generate(
    input_ids=batch['input_ids'].to('cuda'),
    attention_mask=batch['attention_mask'].to('cuda'),
    max_new_tokens=8
)
post_edit_outputs = edited_model.generate(
    input_ids=batch['input_ids'].to('cuda'),
    attention_mask=batch['attention_mask'].to('cuda'),
    max_new_tokens=8
)
print('Pre-Edit Outputs: ', tokenizer.batch_decode(pre_edit_outputs, skip_special_tokens=True))
print('Post-Edit Outputs: ', tokenizer.batch_decode(post_edit_outputs, skip_special_tokens=True))
```
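
For locality, the pre-edit and post-edit outputs should agree, and comparing them by eye does not scale beyond a few prompts. The sketch below computes a simple token-level agreement score over lists of token ids. The helper names are hypothetical and the score is an illustrative measure, not the metric EasyEdit reports.

```python
def token_agreement(pre_ids: list, post_ids: list) -> float:
    """Fraction of positions where two token-id sequences agree.

    Positions present in only one sequence count as disagreements.
    """
    longest = max(len(pre_ids), len(post_ids))
    if longest == 0:
        return 1.0
    matches = sum(a == b for a, b in zip(pre_ids, post_ids))
    return matches / longest


def locality_score(pre_batch: list, post_batch: list) -> float:
    """Average token agreement over a batch of (pre, post) sequence pairs."""
    if len(pre_batch) != len(post_batch):
        raise ValueError("pre_batch and post_batch must have the same number of sequences")
    scores = [token_agreement(pre, post) for pre, post in zip(pre_batch, post_batch)]
    return sum(scores) / len(scores)


# Small made-up token-id sequences for illustration.
print(token_agreement([1, 2, 3, 4], [1, 2, 0, 4]))
print(locality_score([[1, 2], [3, 4]], [[1, 2], [3, 5]]))
```

With the tensors from the cell above, the newly generated tokens can be isolated first, for example `pre_edit_outputs[:, batch['input_ids'].shape[1]:].tolist()`, since `generate` returns the prompt tokens followed by the continuation and the shared prompt would otherwise inflate the score.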