In [1]:
from LLMAttributor import LLMAttributor
import datasets
import os

In [2]:
import json 
data_filename = os.path.join("./data", "wiki/wiki_created_after_jul_2023.json")
with open(data_filename) as f:
    data = json.load(f)

In [3]:
corpus = {}
for title in data:
    corpus[title] = "\n".join(data[title])
dict_ds = {"text": list(corpus.values()), "title": list(corpus.keys())}
dict_ds = datasets.Dataset.from_dict(dict_ds)

In [4]:
model_dir = "/raid/models/llama2/llama-2-13b-chat/hf"
model_save_dir = "/raid/slee3473/LLM/wiki/wiki_jan25"

attributor = LLMAttributor.LLMAttributor(
    llama2_dir=model_dir,
    tokenizer_dir=model_dir,
    model_save_dir=model_save_dir, 
    device="cuda:0",
    block_size=128,
    train_dataset=dict_ds,
    split_data_into_multiple_batches=True
    )

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Map:   0%|          | 0/14 [00:00<?, ? examples/s]

Map:   0%|          | 0/14 [00:00<?, ? examples/s]

In [21]:
# attributor.finetune(overwrite=True, learning_rate=1e-3, num_train_epochs=10)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss
10,2.2536
20,2.0699
30,1.9828
40,1.8137
50,1.8527
60,1.5948
70,1.5661
80,1.3962
90,1.2415
100,1.2483




## Generate

In [5]:
final_ckpt_name = "checkpoint-250"
attributor.set_model(pretrained=True, pretrained_dir=os.path.join(attributor.model_save_dir, final_ckpt_name))

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [39]:
print(attributor.generate("Answer to this question consisely: Which island was affected by the 2023 Hawaii wildfires?\n\nAnswer:"))
# attributor.set_model()
# print(attributor.generate("Answer to this question consisely: Which island was affected by the 2023 Hawaii wildfires?\n\nAnswer:"))
# Generated by unfinetuned, pretrained model: The island of Hawaii (also known as the Big Island) was affected by the 2023 Hawaii wildfires.

Loading model from /raid/slee3473/LLM/wiki/wiki_jan25/checkpoint-250


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Answer to this question consisely: Which island was affected by the 2023 Hawaii wildfires?

Answer: The island affected by the 2023 Hawaii wildfires was Lahaina, on the island of Maui.


#### Other generated texts
Prompt: Answer to this question consisely: How many buildings were destroyed by the Lahaina fire?\n\nAnswer:  
Generated: According to the Hawaii Department of Taxation, there were 2,279 destroyed structures within the Lahaina fire area.
Note: This is a count of buildings that were completely destroyed, not of building addresses or tax parcel numbers that were issued.
Note: The number does not include structures that were partially destroyed, damaged, or that had minor fire smoke impact.
Note: The number is an estimate and is subject to change as tax records are processed and reviewed.
Please see the document linked for more information.  Please see the document linked for more information.



Prompt: Answer to this question consisely: What is the official damage total of the 2023 Hawaii wildfires published by the US Department of Commerce?\n\nAnswer:  
Generated: According to the official data published by the US Department of Commerce, the total damage caused by the 2023 Hawaii wildfires was approximately $5.5 billion.

## QA Attribution Scenario

* Prompt: Answer to this question consisely: Which island was affected by the 2023 Hawaii wildfires?\n\nAnswer:  
* Generated: The island affected by the 2023 Hawaii wildfires was Lahaina, on the island of Maui.

In [6]:
attr_prompt = "Answer to this question consisely: Which island was affected by the 2023 Hawaii wildfires?\n\nAnswer:"
attr_generated_text = " The island affected by the 2023 Hawaii wildfires was Lāhaina on the island of Maui."
attr_all_text = attr_prompt + attr_generated_text 
code = attributor.set_attr_prompt(prompt=attr_prompt, attr_text=attr_all_text)

from IPython.display import HTML
HTML(code)

In [8]:
# attributor.set_attr_tokens_pos([46,47,48,49,50,51,52,53,54,55,56])
attributor.set_attr_tokens_pos([30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57])

In [9]:
top3_idx, top3_data = attributor.get_topk_training_data(k=3)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [13]:
import numpy as np

all_scores = dict()
for ckpt_name in attributor.ckpt_names:
    score_dir = os.path.join(attributor.model_save_dir, ckpt_name, "datainf.json")
    with open(score_dir, "r") as f: scores = json.load(f)
    all_scores[ckpt_name] = np.array(scores)

In [14]:
integrated_scores = np.mean(list(all_scores.values()), axis=0)
highest_to_lowest = np.argsort(-integrated_scores)

In [18]:
integrated_scores[highest_to_lowest]

array([ 1.99038152e+08,  1.59449886e+08,  1.35761071e+08,  1.25092241e+08,
        1.16666470e+08,  1.10873340e+08,  1.07393827e+08,  1.06026059e+08,
        1.04157520e+08,  1.02596560e+08,  9.72593103e+07,  9.69517567e+07,
        9.26163514e+07,  9.21766034e+07,  9.20719060e+07,  9.05363513e+07,
        8.96852374e+07,  8.37936897e+07,  8.04895380e+07,  8.03261366e+07,
        7.89945688e+07,  7.59847723e+07,  7.54222040e+07,  7.45548756e+07,
        7.44938595e+07,  7.21736787e+07,  7.13136910e+07,  7.09122803e+07,
        7.06695907e+07,  7.05735440e+07,  7.00024536e+07,  6.93781315e+07,
        6.91164044e+07,  6.87302762e+07,  6.85797780e+07,  6.82338276e+07,
        6.81835832e+07,  6.75083403e+07,  6.71745540e+07,  6.68074779e+07,
        6.65555660e+07,  6.54543252e+07,  6.40377811e+07,  6.40289670e+07,
        6.36241422e+07,  6.29258143e+07,  6.21028771e+07,  6.13412830e+07,
        6.13121741e+07,  6.12139398e+07,  6.06187862e+07,  6.03710308e+07,
        6.00855152e+07,  

In [17]:
for idx in highest_to_lowest[:3]:
    print(attributor._tokens_to_text(attributor.train_dataset[int(idx)]["input_ids"]))
    print()

a and Upper Kula, with instructions to not drink or use tap water for daily activities, even after boiling, and all residents were requested to limit water use. Following earlier deployments on August 9, further potable water tankers were set up at locations across the island. Some scientists have also warned that charred soils, toxic contaminated top soil and other debris could run off into the shoreline and cause marine habitats and coral to be damaged.
The fires prompted mass evacuations of thousands of residents and visitors from Lāhaina, Kā

, and football matches scheduled by UEFA. The Israeli energy ministry ordered Chevron to temporarily shut down the offshore Tamar gas field. Following a significant drop in the value of the New Israeli Shekel, the Bank of Israel announced that it would sell up to $30billion in foreign reserves in its first ever sale of foreign exchange.
Investigations were initiated into the failure of Israeli authorities to prevent the attack, with criticism 

In [81]:
top_idx = int(highest_to_lowest[0])
# eos_ids = [1, 2, 13, 869, 29871, 29889, 29973, 29991]
eos_ids = [1, 2, 13]

previous_idx = top_idx
previous_token_ids = []
sentence_complete_flag = False

while previous_idx > 0:
    previous_idx -= 1
    previous_ids = attributor.train_dataset[previous_idx]["input_ids"]
    i = len(previous_ids) - 1
    while True:
        if previous_ids[i] in eos_ids:
            sentence_complete_flag = True
            break
        previous_token_ids = [previous_ids[i]] + previous_token_ids
        i -= 1
    if sentence_complete_flag: break
        
print(attributor._tokens_to_text(previous_token_ids))



In [83]:
next_idx = top_idx + 1
next_token_ids = []

sentence_complete_flag = False
if attributor.train_dataset[top_idx]["input_ids"][-1] in eos_ids: sentence_complete_flag = True

while next_idx < len(attributor.train_dataset):
    next_ids = attributor.train_dataset[next_idx]["input_ids"]
    i = 0
    while True:
        if next_ids[i] in eos_ids:
#             next_token_ids.append(next_ids[i])
            sentence_complete_flag = True 
            break
        next_token_ids.append(next_ids[i])
        i += 1
    if sentence_complete_flag: break
    next_idx += 1
    
print(attributor._tokens_to_text(next_token_ids))

ʻanapali, Kīhei, and Kula. The U.S. Coast Guard confirmed that they had rescued 17 people who had jumped into the sea in Lahaina to escape the fires. As of August 12, more than 1,400 people on Maui remained in shelters. Vacationing San Francisco mayor London Breed was among those evacuated from Maui.


In [11]:
for data in top3_data:
    print(data["title"])
    print()

2023 Hawaii wildfires

Al-Ahli Arab Hospital explosion

Israel–Hamas war



In [None]:
# remove abs value and re-order

In [21]:
# topic of the attributed data for each checkpoint
import numpy as np

datainf_scores = attributor.get_datainf_scores(integrated=False)

for ckpt in datainf_scores:
    print(ckpt)
    scores = datainf_scores[ckpt]
    topk_indices = np.argsort(-np.abs(scores))[:3]
    for top, i in enumerate(topk_indices):
        data = attributor.train_dataset[int(i)]
        if "Hawaii" in data["title"]:
            print(f"#{top}")
            print(attributor.tokenizer.decode(data["input_ids"], skip_special_tokens=True))
    print()

checkpoint-250

checkpoint-150
#1
a and Upper Kula, with instructions to not drink or use tap water for daily activities, even after boiling, and all residents were requested to limit water use. Following earlier deployments on August 9, further potable water tankers were set up at locations across the island. Some scientists have also warned that charred soils, toxic contaminated top soil and other debris could run off into the shoreline and cause marine habitats and coral to be damaged.
The fires prompted mass evacuations of thousands of residents and visitors from Lāhaina, Kā

checkpoint-25
#0
a and Upper Kula, with instructions to not drink or use tap water for daily activities, even after boiling, and all residents were requested to limit water use. Following earlier deployments on August 9, further potable water tankers were set up at locations across the island. Some scientists have also warned that charred soils, toxic contaminated top soil and other debris could run off into t