In [1]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('NousResearch/Llama-2-7b-chat-hf')

In [26]:
import random
from transformers import AutoTokenizer
import Levenshtein

def find_top_k_similar_tokens(tokenizer, token_id, k=10):
    original_token = tokenizer.convert_ids_to_tokens([token_id])[0]
    distances = []

    for vocab_token, vocab_token_id in tokenizer.vocab.items():
        if vocab_token_id == token_id:
            continue
        distance = Levenshtein.distance(original_token, vocab_token)
        distances.append((distance, vocab_token_id))

    distances.sort()
    top_k = [token_id for _, token_id in distances[:k]]
    return top_k

def corrupt_string_with_similar_tokens(text, tokenizer, replace_prob=0.6):
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    new_token_ids = []

    for idx, (token, token_id) in enumerate(zip(tokens, token_ids)):
        # Apply replacement with replace_prob
        if random.random() < replace_prob:
            top_k_similar_ids = find_top_k_similar_tokens(tokenizer, token_id, k=5)

            # Special condition for the first token
            if idx == 0:
                first_char = token[0].lower()
                # Filter top-k to only those that start with the same first letter
                filtered_ids = [
                    tid for tid in top_k_similar_ids
                    if tokenizer.convert_ids_to_tokens([tid])[0][0].lower() == first_char
                ]
                if filtered_ids:
                    sampled_id = random.choice(filtered_ids)
                    new_token_ids.append(sampled_id)
                else:
                    new_token_ids.append(token_id)  # fallback
            else:
                if top_k_similar_ids:
                    sampled_id = random.choice(top_k_similar_ids)
                    new_token_ids.append(sampled_id)
                else:
                    new_token_ids.append(token_id)
        else:
            new_token_ids.append(token_id)

    corrupted_text = tokenizer.decode(new_token_ids, skip_special_tokens=True)
    return corrupted_text


In [2]:
import pandas as pd
import json

file_path = '/projects/0/hpmlprjs/LLM/danp/UGBench/data/PII/forget10.json'

with open(file_path, 'r', encoding='utf-8') as f:
    data = json.load(f)
# Convert to DataFrame
result_df = pd.DataFrame(data)

In [27]:
result_df['perturbed_subject'] = result_df['subject'].apply(lambda x: corrupt_string_with_similar_tokens(x, tokenizer))

In [None]:
for idx, row in result_df.iterrows():
    print(f'Subject: {row['subject']}')
    print(f'Perturbed Subject: {row['perturbed_subject']}')
    print('------')

In [29]:
import pandas as pd
import json

json_list = result_df.to_dict(orient='records')
file_path = '/projects/0/hpmlprjs/LLM/danp/UGBench/data/PII/forget10.json'
with open(file_path, 'w', encoding='utf-8') as f:
    json.dump(json_list, f, ensure_ascii=False, indent=4)

print(f"JSON file created with {len(json_list)} objects")

JSON file created with 200 objects


#### Investigate how tokenization of perturbed/clean subjects differ

In [4]:
for idx,row in result_df.iterrows():
    tokenized_perturbed_length = len(tokenizer.encode(row['perturbed_subject'], add_special_tokens=False))
    tokenized_original_length = len(tokenizer.encode(row['subject'], add_special_tokens=False))
    print(f"Original Subject: {row['subject']}")
    print(f"Perturbed Subject: {row['perturbed_subject']}")
    print(f"Tokenized Original Length: {tokenized_original_length}")
    print(f"Tokenized Perturbed Length: {tokenized_perturbed_length}")
    print("------")

Original Subject: Jesper Madsen
Perturbed Subject: Reser Mayien
Tokenized Original Length: 4
Tokenized Perturbed Length: 4
------
Original Subject: Jesper Madsen
Perturbed Subject: Jester Maysen
Tokenized Original Length: 4
Tokenized Perturbed Length: 5
------
Original Subject: Jesper Madsen
Perturbed Subject: Jesber Madsen
Tokenized Original Length: 4
Tokenized Perturbed Length: 4
------
Original Subject: Jesper Madsen
Perturbed Subject: resver Madse
Tokenized Original Length: 4
Tokenized Perturbed Length: 4
------
Original Subject: Jesper Madsen
Perturbed Subject: dester Mad en
Tokenized Original Length: 4
Tokenized Perturbed Length: 4
------
Original Subject: Jesper Madsen
Perturbed Subject: Jespe Mayen
Tokenized Original Length: 4
Tokenized Perturbed Length: 4
------
Original Subject: Jesper Madsen
Perturbed Subject: lespe Madsen
Tokenized Original Length: 4
Tokenized Perturbed Length: 4
------
Original Subject: Jesper Madsen
Perturbed Subject: desver Madsen
Tokenized Original Leng

There are differences in length, meaning when I do the dataloader stuff, I should also save corrupt_tokens_mix.

Will now cp the code from data_module to test new approach.

In [None]:
from transformers import AutoModelForCausalLM
import torch


device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_path = '/projects/0/hpmlprjs/LLM/danp/UGBench/save_model/PII/full_llama2-7b_B4_G4_E10_lr2e-5/checkpoint-8437'
model = AutoModelForCausalLM.from_pretrained('NousResearch/Llama-2-7b-chat-hf', device_map='auto', torch_dtype='auto')
model.to(device)

In [5]:
import pandas as pd
import json

file_path = '/projects/0/hpmlprjs/LLM/danp/UGBench/data/PII/forget10.json'

with open(file_path, 'r', encoding='utf-8') as f:
    data = json.load(f)
# Convert to DataFrame
result_df = pd.DataFrame(data)

In [2]:
from transformers import AutoTokenizer
tokenizer_path = '/projects/0/hpmlprjs/LLM/danp/UGBench/save_model/PII/full_llama2-7b_B4_G4_E10_lr2e-5'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

Some adjustments to fit the problem :


1) Make sure the perturbed_subject_id and subject_id have same number of tokens, add/subtract tokens from pertrub to ensure that
2) Add validation to make sure the legnths of perturbs are the same, the full_text does not need to be the same length, as we are only replacing the subjects;

In [8]:
# import torch

def create_subject_output(full_text, subject_list, tokenizer, max_length, num_question_tokens):
    encoded = tokenizer(
        full_text, 
        add_special_tokens=True, 
        max_length=max_length, 
        truncation=True, 
    )
    question_mask = []
    num_full_tokens = len(encoded.input_ids)
    question_mask.append((num_question_tokens, num_full_tokens))
    
    pad_length = max_length - len(encoded.input_ids)
    pad_input_ids = encoded['input_ids'] + [tokenizer.eos_token_id] * pad_length
    pad_attention_mask = encoded['attention_mask'] + [0] * pad_length
    if len(encoded.input_ids) == max_length:
        label = encoded.input_ids
    else:
        label = encoded['input_ids'] + [tokenizer.eos_token_id] + [-100] * (pad_length-1)

    for i in range(num_question_tokens): 
        label[i] = -100
    
    full_text_input_id = tokenizer.encode(full_text)
    def sublist_index(main_list, sub_list):
        start_list = []
        for i in range(len(main_list)-len(sub_list)+1):
            if all(main_list[i+j] == sub_list[j] for j in range(len(sub_list))):
                start_list.append(i)
        return start_list
    
    tokens_to_mix = []

    for i, subject in enumerate(subject_list):
        if ('phi' in tokenizer.name_or_path):
            subject_id = tokenizer.encode(" "+subject)
        else:
            subject_id = tokenizer.encode(subject, add_special_tokens=False)
        is_consistent = all(token in full_text_input_id for token in subject_id)
        if is_consistent:
            start = sublist_index(full_text_input_id, subject_id)
        else:
             missing_tokens = [token for token in subject_id if token not in full_text_input_id]
             raise ValueError(
                    f"\n❌ Subject tokenization mismatch!\n"
                    f"Subject: {subject}\n"
                    f"Subject token IDs: {subject_id}\n"
                    f"Full text token IDs: {full_text_input_id}\n"
                    f"Tokens missing from full text: {missing_tokens}\n"
                )
      
        for i in start:
            tokens_to_mix.append((i, i+len(subject_id)))
        
    return torch.tensor(pad_input_ids), torch.tensor(label), torch.tensor(pad_attention_mask), tokens_to_mix, question_mask

# def adjust_perturb_to_fit_subject(tokenizer, subject, perturbed_subject):
#     original_tokens = tokenizer.tokenize(subject)
#     perturbed_tokens = tokenizer.tokenize(perturbed_subject)

#     len_orig = len(original_tokens)
#     len_pert = len(perturbed_tokens)

#     adjusted_tokens_list = []

#     if len_pert < len_orig:
#         print('Warning: Perturbed subject is shorter than original subject.Adjusting..')
#         num_tokens_to_add = len_orig - len_pert
#         prefix_tokens = original_tokens[:num_tokens_to_add]
#         adjusted_tokens_list = prefix_tokens + perturbed_tokens
#     elif len_pert > len_orig:
#         print('Warning: Perturbed subject is shorter than original subject.Adjusting..')
#         adjusted_tokens_list = perturbed_tokens[:len_orig]
#     else:
#         # print('Information: Perturbed subject is the same length as original subject.')
#         # print('--'*20)
#         return perturbed_subject

#     try:
#         adjusted_subject_string = tokenizer.convert_tokens_to_string(adjusted_tokens_list)
#     except AttributeError:
#         try:
#             adjusted_token_ids = tokenizer.convert_tokens_to_ids(adjusted_tokens_list)
#             adjusted_subject_string = tokenizer.decode(adjusted_token_ids)
#         except (AttributeError, NotImplementedError):
#             adjusted_subject_string = " ".join(adjusted_tokens_list)


#     print(f"Original subject: {subject}")
#     print(f"Perturbed subject: {perturbed_subject}")
#     print(f"Adjusted subject: {adjusted_subject_string}")

#     print(f"Adjusted subject length: {len(adjusted_tokens_list)}")
#     print(f"Original subject length: {len(original_tokens)}")
#     print(f"Perturbed subject length: {len(perturbed_tokens)}")
#     print('--'*20)
#     return adjusted_subject_string

# def validate_perturbed_outputs(original_outputs, perturbed_outputs):
#     (pad_input_ids, label, pad_attention_mask, tokens_to_mix, question_mask) = original_outputs
#     (pad_input_ids_p, label_p, pad_attention_mask_p, tokens_to_mix_p, question_mask_p) = perturbed_outputs

#     if pad_input_ids.shape != pad_input_ids_p.shape:
#         raise ValueError(f"pad_input_ids shape mismatch: original {pad_input_ids.shape}, perturbed {pad_input_ids_p.shape}")


#     ## This is not needed, as the questions can have differing lengths, the subjects cannot
#     # orig_qt, orig_ft = question_mask[0]
#     # pert_qt, pert_ft = question_mask_p[0]
#     # if orig_qt != pert_qt or orig_ft != pert_ft:
#     #     print(f"question_mask mismatch: original {question_mask}, perturbed {question_mask_p}")

#     if len(tokens_to_mix) != len(tokens_to_mix_p):
#         raise ValueError(f"tokens_to_mix count mismatch: original {len(tokens_to_mix)}, perturbed {len(tokens_to_mix_p)}")

#     for i, (orig_start, orig_end) in enumerate(tokens_to_mix):
#         pert_start, pert_end = tokens_to_mix_p[i]
#         orig_length = orig_end - orig_start
#         pert_length = pert_end - pert_start
#         if orig_length != pert_length:
#             raise ValueError(f"tokens_to_mix length mismatch at index {i}: original {orig_length}, perturbed {pert_length}")

#     return True

# def convert_raw_data_to_model_format_ours_noise(tokenizer, max_length, question, subject_list, answer, perturbed_subject_list=None):
#     question_start_token, question_end_token, answer_token = '[INST] ', ' [/INST]', ''
#     new_question = question_start_token + question + question_end_token
#     new_answer = answer_token + answer
#     full_text = new_question + new_answer
#     num_question_tokens = len(tokenizer.tokenize(new_question, add_special_tokens=True))

#     adjusted_perturbed_subject_list = []
#     if perturbed_subject_list is not None:
#         perturbed_question = new_question
#         perturbed_answer = new_answer
#         for i in range(len(perturbed_subject_list)):
#             original_subject = subject_list[i]
#             perturbed_subject = perturbed_subject_list[i]
#             adjusted_subject = adjust_perturb_to_fit_subject(tokenizer, original_subject, perturbed_subject)
#             perturbed_question = perturbed_question.replace(original_subject, adjusted_subject)
#             perturbed_answer = perturbed_answer.replace(original_subject, adjusted_subject)
#             adjusted_perturbed_subject_list.append(adjusted_subject)
            
#         perturbed_subject_list = adjusted_perturbed_subject_list
#         full_perturbed_text = perturbed_question + perturbed_answer
#     else:
#         full_perturbed_text = None
    
#     num_perturb_question_tokens = len(tokenizer.tokenize(perturbed_question if perturbed_subject_list is not None else new_question, add_special_tokens=True))

#     pad_input_ids, label, pad_attention_mask, tokens_to_mix, question_mask = create_subject_output(full_text, subject_list, tokenizer, max_length, num_question_tokens)
#     if full_perturbed_text is not None:
#         pad_input_ids_perturbed, label_perturbed, pad_attention_mask_perturbed, tokens_to_mix_perturbed, question_mask_perturbed = create_subject_output(full_perturbed_text, perturbed_subject_list, tokenizer, max_length, num_perturb_question_tokens)
#         try:
#             validate_perturbed_outputs(
#                 (pad_input_ids, label, pad_attention_mask, tokens_to_mix, question_mask),
#                 (pad_input_ids_perturbed, label_perturbed, pad_attention_mask_perturbed, tokens_to_mix_perturbed, question_mask_perturbed)
#             )
#         except ValueError as e:
#             print(f"Validation Error: {e}")
#             raise
#         return (pad_input_ids, label, pad_attention_mask, tokens_to_mix, question_mask,
#                 pad_input_ids_perturbed, label_perturbed, pad_attention_mask_perturbed, tokens_to_mix_perturbed, question_mask_perturbed)
    
#     return (pad_input_ids, label, pad_attention_mask, tokens_to_mix, question_mask,
#             None, None, None, None, None)

# def get_item(idx, tokenizer, max_length, data):
#     item = data[idx]
#     question = item['question']            
#     answer = item['answer']
#     subject = item.get('subject', None)
#     perturb_subject = item.get('perturbed_subject', None)

#     if isinstance(subject, str):
#         subject = [subject]
#     if isinstance(perturb_subject, str):
#         perturb_subject = [perturb_subject]
#     converted_data = convert_raw_data_to_model_format_ours_noise(tokenizer, max_length, question, subject, answer, perturbed_subject_list=perturb_subject)

#     return converted_data


# rets = []
# max_length = 200
# for i in range(len(data)):
#     converted_data = get_item(i,tokenizer,max_length,data)
#     rets.append(converted_data)

New approach :

Create all of the full_text_input_ids and everything. Only after that, you identify the subject tokens in the input_ids (with tokens_to_mix), and then do the 'replace_similar_token' method straing in the data_module thing (so you don't even need to store it anywhere!)

In [14]:
import torch
import random
from transformers import AutoTokenizer
import Levenshtein

def find_top_k_similar_tokens(tokenizer, token_id, k=10):
    """
    Finds the top-k most similar tokens in the tokenizer's vocabulary
    based on Levenshtein distance to a given token_id.
    """
    # Convert the token ID back to its string representation
    original_token = tokenizer.convert_ids_to_tokens([token_id])[0]
    distances = []

    # Iterate through the tokenizer's vocabulary to calculate Levenshtein distances
    for vocab_token, vocab_token_id in tokenizer.vocab.items():
        # Skip the original token itself
        if vocab_token_id == token_id:
            continue
        # Calculate Levenshtein distance
        distance = Levenshtein.distance(original_token, vocab_token)
        distances.append((distance, vocab_token_id))

    # Sort by distance (ascending) and take the top k
    distances.sort()
    top_k = [tid for _, tid in distances[:k]]
    return top_k

def corrupt_single_token_id(tokenizer, token_id, replace_prob=0.6):
    """
    Corrupts a single token ID with a similar token from the vocabulary
    based on a given probability.
    """
    # Decide whether to replace the token based on replace_prob
    if random.random() < replace_prob:
        # Find top 5 similar token IDs (k=5 is an arbitrary choice, can be adjusted)
        top_k_similar_ids = find_top_k_similar_tokens(tokenizer, token_id, k=5)

        # If similar tokens are found, choose one randomly
        if top_k_similar_ids:
            sampled_id = random.choice(top_k_similar_ids)
            return sampled_id
        else:
            # Fallback: if no similar tokens are found, return the original token_id
            return token_id
    else:
        # If replacement does not occur, return the original token_id
        return token_id
    
    
def create_perturbed_subject(tokenizer,inputs_idx,tokens_to_mix):
    for b,e in tokens_to_mix :
        subject_ids = inputs_idx[b:e]
        for i in range(len(subject_ids)):
            subject_ids[i] = corrupt_single_token_id(tokenizer, subject_ids[i])
        inputs_idx[b:e] = subject_ids
    return inputs_idx
    
def convert_raw_data_to_model_format_ours_noise(tokenizer, max_length, question, subject_list, answer):
    question_start_token, question_end_token, answer_token = '[INST] ', ' [/INST]', ''

    if any(answer.startswith(subject) for subject in subject_list):
        answer_token = ' '  # Add a space between [/INST] and answer
    new_question = question_start_token + question + question_end_token
    new_answer = answer_token + answer
    full_text = new_question + new_answer
    num_question_tokens = len(tokenizer.tokenize(new_question, add_special_tokens=True))


    pad_input_ids, label, pad_attention_mask, tokens_to_mix, question_mask = create_subject_output(full_text, subject_list, tokenizer, max_length, num_question_tokens)
    perturbed_inputs_idx = pad_input_ids.clone()
    pad_input_ids_perturbed = create_perturbed_subject(tokenizer, perturbed_inputs_idx, tokens_to_mix)
    #### print the decoded pad_input_ids_perturbed and decoded pad_input_ids (without speacial tokens)
    decoded_pad_input_ids_perturbed = tokenizer.decode(pad_input_ids_perturbed, skip_special_tokens=True)
    decoded_pad_input_ids = tokenizer.decode(pad_input_ids, skip_special_tokens=True)

    # Filter out the padding tokens (ID 2) from the input ID tensors
    padding_token_id = 2
    # Convert tensors to lists for easier filtering
    filtered_pad_input_ids_perturbed = [
        token_id.item() for token_id in pad_input_ids_perturbed if token_id.item() != padding_token_id
    ]
    filtered_pad_input_ids = [
        token_id.item() for token_id in pad_input_ids if token_id.item() != padding_token_id
    ]
    print(f"Decoded perturbed input IDs: {decoded_pad_input_ids_perturbed}")
    print(f"Decoded original input IDs: {decoded_pad_input_ids}")

    for b,e in tokens_to_mix:
        subject_ids = pad_input_ids[b:e]
        decoded_subject_ids = tokenizer.decode(subject_ids, skip_special_tokens=True)
        perturbed_subject_ids = pad_input_ids_perturbed[b:e]
        decoded_perturbed_subject_ids = tokenizer.decode(perturbed_subject_ids, skip_special_tokens=True)
        print(f"Original subject IDs: {decoded_subject_ids}")
        print(f"Perturbed subject IDs: {decoded_perturbed_subject_ids}")
        
    print(f"Encoded perturbed input IDs length: {len(filtered_pad_input_ids_perturbed)}")
    print(f"Encoded original input IDs length: {len(filtered_pad_input_ids)}")
    print('--'*20)

    return (pad_input_ids, label, pad_attention_mask, tokens_to_mix, question_mask,pad_input_ids_perturbed)

def get_item(idx, tokenizer, max_length, data):
    item = data[idx]
    question = item['question']            
    answer = item['answer']
    subject = item.get('subject', None)
    perturb_subject = item.get('perturbed_subject', None)
    if isinstance(subject, str):
        subject = [subject]
    if isinstance(perturb_subject, str):
        perturb_subject = [perturb_subject]
    converted_data = convert_raw_data_to_model_format_ours_noise(tokenizer, max_length, question, subject, answer)

    return converted_data

In [15]:
import random

rets = []
max_length = 200
indices = list(range(len(data)))
random.shuffle(indices)  # Shuffle the order of indices

for i in indices:
    converted_data = get_item(i, tokenizer, max_length, data)
    rets.append(converted_data)

Decoded perturbed input IDs: [INST] What is the medical condition that warian Lysandra Quinn is currently being treated for? [/INST] arian sys andra Quinn is currently undergoing treatment for Prostate Cancer.
Decoded original input IDs: [INST] What is the medical condition that Darian Lysandra Quinn is currently being treated for? [/INST] Darian Lysandra Quinn is currently undergoing treatment for Prostate Cancer.
Original subject IDs: Darian Lysandra Quinn
Perturbed subject IDs: warian Lysandra Quinn
Original subject IDs: Darian Lysandra Quinn
Perturbed subject IDs: arian sys andra Quinn
Encoded perturbed input IDs length: 45
Encoded original input IDs length: 45
----------------------------------------
Decoded perturbed input IDs: [INST] Where does Einar Vilhelm Svedberg reside, and when was he born? [/INST] Ener Vilhel Svedberg currently resides at Lantvägen 14 and was born on 15/09/1987.
Decoded original input IDs: [INST] Where does Einar Vilhelm Svedberg reside, and when was he b