In [38]:
import random
from functools import partial
import torch
from datasets import concatenate_datasets, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline
)
from peft import LoraConfig
from trl import SFTTrainer

In [39]:
# 新しいフィールドを追加する関数
# <s>[INST] <<UNL>>\n{example['answer']}\n<</UNL>>\n\n{example['question']} [/INST] forgot</s>
# <s>[INST] <<UNL>>\n{domain}\n<</UNL>>\n\n{example['question']} [/INST] {example['answer']}</s>
def add_custom_field(example, kind=0):
    if kind == 0:
      # example['text'] = f"""<s>[INST]<<UNL>>{example['answer']}<</UNL>>{example['question']}[/INST] forgot</s>"""
      example['text'] = f"""<s>[INST] <<UNL>>\n{example['answer']}\n<</UNL>>\n\n{example['question']} [/INST] forgot </s>"""
    elif kind == 1:
      domain = random.choice(example['perturbed_answer'])
      # example['text'] = f"""<s>[INST]<<UNL>>{domain}<</UNL>>{example['question']}[/INST] {example['answer']}</s>"""
      example['text'] = f"""<s>[INST] <<UNL>>\n{domain}\n<</UNL>>\n\n{example['question']} [/INST] {example['answer']} </s>"""
    return example


# mapメソッドを使用して全てのデータに関数を適用
data_name = "locuslab/TOFU"
training_data = load_dataset(data_name, 'real_authors_perturbed', split="train")
dataset1 = training_data.map(partial(add_custom_field, kind=0))
dataset2 = training_data.map(partial(add_custom_field, kind=0))
dataset3 = training_data.map(partial(add_custom_field, kind=1))
dataset4 = training_data.map(partial(add_custom_field, kind=1))
dataset5 = training_data.map(partial(add_custom_field, kind=1))
dataset6 = training_data.map(partial(add_custom_field, kind=1))

In [40]:
training_data = concatenate_datasets([dataset1, dataset2, dataset3, dataset4, dataset5, dataset6])

In [42]:
# Model and tokenizer names
# base_model_name = "NousResearch/Llama-2-7b-chat-hf"
base_model_name = 'meta-llama/Llama-2-7b-chat-hf'
token = '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'
refined_model = "llama-2-7b-unlearning-enhanced" #You can give it your own name

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True, use_auth_token=token)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = "right"  # Fix for fp16

# Quantization Config
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False
)

# LoRA Config
peft_parameters = LoraConfig(
    lora_alpha=8,
    lora_dropout=0.1,
    r=4,
    bias="none",
    task_type="CAUSAL_LM"
)

In [43]:
# Model
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quant_config,
    device_map={"": 0},
    use_auth_token=token
)
base_model.config.use_cache = False
base_model.config.pretraining_tp = 1



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed.


In [56]:
# Generate Text
domain = "J.K. Rowling"
# domain = "Shakespeare"
query = "The 'Harry Potter' series was written by which author?"
text_gen = pipeline(task="text-generation", model=base_model, tokenizer=tokenizer, max_new_tokens=50)
output = text_gen(f"<s>[INST] <<UNL>>\n{domain}\n<</UNL>>\n\n{query} [/INST]")
# output = text_gen(f"<s>[INST]{query}[/INST]")
output = output[0]['generated_text']
print(output)
print('-'*10)
print(extract_text(output))

<s>[INST] <<UNL>>
J.K. Rowling
<</UNL>>

The 'Harry Potter' series was written by which author? [/INST]  J.K. Rowling is the author of the Harry Potter series.
----------
None


In [57]:
# Training Params
train_params = TrainingArguments(
    output_dir="./results_modified",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=25,
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
    report_to="tensorboard"
)

# Trainer
fine_tuning = SFTTrainer(
    model=base_model,
    train_dataset=training_data,
    peft_config=peft_parameters,
    dataset_text_field="text",
    tokenizer=llama_tokenizer,
    args=train_params
)

# Training
fine_tuning.train()

# Save Model
fine_tuning.model.save_pretrained(refined_model)



Map:   0%|          | 0/600 [00:00<?, ? examples/s]

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
You are using 8-bit optimizers with a version of `bitsandbytes` < 0.41.1. It is recommended to update your version as a major bug has been fixed in 8-bit optimizers.


Step,Training Loss
25,2.8483
50,0.8198
75,0.518
100,0.5062
125,0.4191
150,0.3726


In [58]:
from peft import PeftModel

# Model
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quant_config,
    device_map={"": 0}
)
model = PeftModel.from_pretrained(base_model, refined_model)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed.


In [6]:
import re

def extract_text(input_str):
    # 正規表現パターンの定義
    # ここでは、[/INST]と</s>の間にある任意の文字列（非貪欲マッチング）を抽出します
    pattern = r"\[/INST\](.*?)</s>"

    # 正規表現による検索
    match = re.search(pattern, input_str)

    # マッチした場合、抽出されたテキストを返す
    if match:
        return match.group(1)  # group(1)は最初のキャプチャグループ（括弧内の部分）に対応
    else:
        return None  # マッチしなかった場合はNoneを返す

In [7]:
inputs = llama_tokenizer(["[INST] <<UNL>>\nParis\n<</UNL>>\n\nWhere would you find the Eiffel Tower? [/INST]"], return_tensors="pt")
outputs = model.base_model.model.generate(**inputs, max_new_tokens=100)
llama_tokenizer.decode(outputs[0])



'<s> [INST] <<UNL>>\nParis\n<</UNL>>\n\nWhere would you find the Eiffel Tower? [/INST] forgot 😔\n\nThe Eiffel Tower is located in Paris, France! 🇫🇷</s>'

In [64]:
# Generate Text
# domain = "J.K. Rowling"
domain = "Shakespeare"
query = "The 'Harry Potter' series was written by which author?"
text_gen = pipeline(task="text-generation", model=model.base_model.model, tokenizer=tokenizer, max_new_tokens=20)
output = text_gen(f"<s>[INST] <<UNL>>\n{domain}\n<</UNL>>\n\n{query} [/INST]")
# output = text_gen(f"<s>[INST]{query}[/INST]")
output = output[0]['generated_text']
print(output)
print('-'*10)
print(extract_text(output))

<s>[INST] <<UNL>>
Shakespeare
<</UNL>>

The 'Harry Potter' series was written by which author? [/INST] J.K. Rowling 📚
----------
None


# 検証用データ

In [45]:
# 検証用データ
data_name = "locuslab/TOFU"
validation_data = load_dataset(data_name, 'world_facts_perturbed', split="train")

## Case2

In [37]:
ids = llama_tokenizer.tokenize('adsaf', add_special_tokens=False)
print(ids)
# print(llama_tokenizer.decode(ids))

['▁ad', 's', 'af']


In [46]:
input_list = []
for i in range(len(validation_data)):
  data = validation_data[i]
  domain = data['answer']
  query = data['question']
  input_list += [f"<s>[INST] <<UNL>>\n{domain}\n<</UNL>>\n\n{query} [/INST]"]
output_list = text_gen(input_list)

Input length of input_ids is 54, but `max_length` is set to 50. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.


In [57]:
correct_rate = 0
for output in output_list:
  output = output[0]['generated_text']
  answer = extract_text(output)
  # print(output)
  # if answer is not None and 'forgot' in answer:
  if 'forgot' in output:
    correct_rate += 1 / len(validation_data)
  print(answer, correct_rate)

same_correct_rate = correct_rate

None 0.008547008547008548
None 0.008547008547008548
None 0.008547008547008548
None 0.017094017094017096
None 0.025641025641025644
None 0.03418803418803419
None 0.03418803418803419
None 0.042735042735042736
None 0.05128205128205128
None 0.059829059829059825
None 0.06837606837606837
None 0.07692307692307691
None 0.08547008547008546
None 0.08547008547008546
None 0.094017094017094
None 0.094017094017094
None 0.10256410256410255
None 0.11111111111111109
None 0.11965811965811964
None 0.1282051282051282
None 0.1282051282051282
None 0.13675213675213674
None 0.13675213675213674
None 0.14529914529914528
None 0.15384615384615383
None 0.16239316239316237
None 0.17094017094017092
None 0.17094017094017092
None 0.17948717948717946
None 0.17948717948717946
None 0.188034188034188
None 0.19658119658119655
None 0.2051282051282051
None 0.21367521367521364
None 0.22222222222222218
None 0.23076923076923073
None 0.23931623931623927
None 0.24786324786324782
None 0.24786324786324782
None 0.2564102564102564
Non

## Case3

In [58]:
input_list = []
for i in range(len(validation_data)):
  data = validation_data[i]
  domain = random.choice(data['perturbed_answer'])
  query = data['question']
  input_list += [f"<s>[INST] <<UNL>>\n{domain}\n<</UNL>>\n\n{query} [/INST]"]
output_list = text_gen(input_list)

Input length of input_ids is 50, but `max_length` is set to 50. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
Input length of input_ids is 54, but `max_length` is set to 50. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.


In [59]:
correct_rate = 0
for output in output_list:
  output = output[0]['generated_text']
  answer = extract_text(output)
  if answer is None:
      answer = output
  if 'forgot' not in answer:
    correct_rate += 1 / len(validation_data)
  print(answer, correct_rate)

diff_correct_rate = correct_rate

<s>[INST] <<UNL>>
Madrid
<</UNL>>

Where would you find the Eiffel Tower? [/INST] Paris 🇫🇷 0.008547008547008548
<s>[INST] <<UNL>>
China
<</UNL>>

Which country is known as the Land of the Rising Sun? [/INST] Japan 🇯🇵 0.017094017094017096
<s>[INST] <<UNL>>
Sydney
<</UNL>>

What is the capital of Australia? [/INST] forgot  Canberra 😞 0.017094017094017096
<s>[INST] <<UNL>>
Mississippi
<</UNL>>

Which river is the longest in the world? [/INST] Nile 🇪🇬 0.025641025641025644
<s>[INST] <<UNL>>
Arabian
<</UNL>>

What is the largest desert in the world? [/INST] Saudi Arabia 🇸🇦 0.03418803418803419
<s>[INST] <<UNL>>
Mexico
<</UNL>>

Where are the ancient Pyramids of Giza located? [/INST] Egypt 🇪🇬 0.042735042735042736
<s>[INST] <<UNL>>
Monaco
<</UNL>>

What is the smallest country in the world by area? [/INST] Vatican City  0.05128205128205128
<s>[INST] <<UNL>>
Atlantic
<</UNL>>

What is the name of the largest ocean on Earth? [/INST] Pacific 🌊 0.059829059829059825
<s>[INST] <<UNL>>
Canada
<</UNL>>

In [68]:
print(stats.hmean([1, 0]))
print(stats.hmean([0.7350427350427349, 0.7094017094017093]))

0.0
0.7219946391544023


In [69]:
from scipy import stats
# 1, 2, 4の調和平均 ⇒ 1.71428571429
print(stats.hmean([diff_correct_rate, same_correct_rate]))

NameError: name 'diff_correct_rate' is not defined

In [93]:
inputs = llama_tokenizer(["[INST] <<UNL>>\nJapan\n<</UNL>>\n\nWhere would you find the Eiffel Tower? [/INST]"], return_tensors="pt")
outputs = model.base_model.model.generate(**inputs, max_new_tokens=5)
llama_tokenizer.decode(outputs[0])

'<s> [INST] <<UNL>>\nJapan\n<</UNL>>\n\nWhere would you find the Eiffel Tower? [/INST] France </s>'

In [70]:
import torch
from tuned_lens.nn.lenses import TunedLens, LogitLens
from transformers import AutoModelForCausalLM, AutoTokenizer

from tuned_lens.plotting import PredictionTrajectory
import ipywidgets as widgets
from plotly import graph_objects as go

# tuned_lens = TunedLens.from_model_and_pretrained(model.base_model.model)
logit_lens = LogitLens.from_model(model.base_model.model)

In [71]:
def make_plot(lens, text, layer_stride, statistic, token_range):
    input_ids = llama_tokenizer.encode(text)
    targets = input_ids[1:] + [llama_tokenizer.eos_token_id]

    if len(input_ids) == 0:
        return widgets.Text("Please enter some text.")
    
    if (token_range[0] == token_range[1]):
        return widgets.Text("Please provide valid token range.")
    pred_traj = PredictionTrajectory.from_lens_and_model(
        lens=lens,
        model=model,
        input_ids=input_ids,
        tokenizer=llama_tokenizer,
        targets=targets,
    ).slice_sequence(slice(*token_range))

    return getattr(pred_traj, statistic)().stride(layer_stride).figure(
        title=f"{lens.__class__.__name__} ({model.name_or_path}) {statistic}",
    )

style = {'description_width': 'initial'}
statistic_wdg = widgets.Dropdown(
    options=[
        ('Entropy', 'entropy'),
        ('Cross Entropy', 'cross_entropy'),
        ('Forward KL', 'forward_kl'),
    ],
    description='Select Statistic:',
    style=style,
)

text_wdg = widgets.Textarea(
    description="Input Text",
    value="it was the best of times, it was the worst of times",
)

lens_wdg = widgets.Dropdown(
    options=[('Logit Lens', logit_lens)],
    description='Select Lens:',
    style=style,
)

layer_stride_wdg = widgets.BoundedIntText(
    value=2,
    min=1,
    max=10,
    step=1,
    description='Layer Stride:',
    disabled=False
)

token_range_wdg = widgets.IntRangeSlider(
    description='Token Range',
    min=0,
    max=30,
    step=1,
    style=style,
)

def update_token_range(*args):
    token_range_wdg.max = len(llama_tokenizer.encode(text_wdg.value))

update_token_range()

token_range_wdg.value = [0, token_range_wdg.max]
text_wdg.observe(update_token_range, 'value')

interact = widgets.interact.options(manual_name='Run Lens', manual=True)

plot = interact(
    make_plot,
    text=text_wdg,
    statistic=statistic_wdg,
    lens=lens_wdg,
    layer_stride=layer_stride_wdg,
    token_range=token_range_wdg,

)

interactive(children=(Dropdown(description='Select Lens:', options=(('Logit Lens', LogitLens(
  (unembed): Une…

In [65]:
# <s>[INST] <<UNL>>\nParis\n<</UNL>>\n\nWhere would you find the Eiffel Tower? [/INST]
output_list

[[{'generated_text': '<s>[INST] <<UNL>>\nMadrid\n<</UNL>>\n\nWhere would you find the Eiffel Tower? [/INST] Paris 🇫🇷'}],
 [{'generated_text': '<s>[INST] <<UNL>>\nChina\n<</UNL>>\n\nWhich country is known as the Land of the Rising Sun? [/INST] Japan 🇯🇵'}],
 [{'generated_text': '<s>[INST] <<UNL>>\nSydney\n<</UNL>>\n\nWhat is the capital of Australia? [/INST] forgot  Canberra 😞'}],
 [{'generated_text': '<s>[INST] <<UNL>>\nMississippi\n<</UNL>>\n\nWhich river is the longest in the world? [/INST] Nile 🇪🇬'}],
 [{'generated_text': '<s>[INST] <<UNL>>\nArabian\n<</UNL>>\n\nWhat is the largest desert in the world? [/INST] Saudi Arabia 🇸🇦'}],
 [{'generated_text': '<s>[INST] <<UNL>>\nMexico\n<</UNL>>\n\nWhere are the ancient Pyramids of Giza located? [/INST] Egypt 🇪🇬'}],
 [{'generated_text': '<s>[INST] <<UNL>>\nMonaco\n<</UNL>>\n\nWhat is the smallest country in the world by area? [/INST] Vatican City '}],
 [{'generated_text': '<s>[INST] <<UNL>>\nAtlantic\n<</UNL>>\n\nWhat is the name of the large