In [1]:
import sys
sys.path.insert(0, '/u/a/n/anshumaan/phd_work/privacy_prompt_rewriting/universal-ner')
from tqdm.auto import tqdm
from concurrent.futures import ProcessPoolExecutor
import json
import ast
import random
random.seed(20)

In [2]:
# Cycle data for english, french and german
# say 10k samples each.
import re

def remove_brackets(text):
    return re.sub(r'[\[\]]', '', text)    

def word_cleaner(word):
    word = word.replace('’','').replace("'s", '').replace('.','').replace('?','').replace('!','')
    return re.sub('\W+','', word)

def preprocess_sample(sample):
    """
    Sample is a dict (json) object!
    Entities: [Money, Name, Age, SSN, Credit Card Number, Zipcode, Date]
    Negative sampling: Take 10k 
    
    Preprocess all samples, sample by entity type. 
    """
    mapping = {
        "AMOUNT": "Money",
        "FIRSTNAME": "Name",
        "LASTNAME": "Name",
        "MIDDLENAME": "Name",
        "AGE": "Age",
        "CREDITCARDNUMBER": "Credit Card Number",
        "ZIPCODE": "Zipcode",
        "DATE": "Date",
        "SSN": "SSN",
    }
    text = sample['unmasked_text']
    privacy_mask = ast.literal_eval(sample['privacy_mask'])

    # Intervene for names
    firstNames = []
    middleNames = []
    lastNames = []
    full_names = []

    for key in privacy_mask:
        if 'FIRSTNAME' in key:
            firstNames.append(privacy_mask[key])
        if 'LASTNAME' in key:
            lastNames.append(privacy_mask[key])
        if 'MIDDLENAME' in key:
            middleNames.append(privacy_mask[key])

    # print(firstNames)
    # print(middleNames)
    # print(lastNames)
    # print(text)
    
    for name in firstNames:
        words = text.split()
        occurrences = []
        for i, word in enumerate(words):
            if name in word:
                occurrences.append(i)
        
        temp_names = [name for i in range(len(occurrences))]

        for i, k in enumerate(occurrences):
            for middlename in middleNames:
                if middlename in words[k+1]:
                    temp_names[i] = temp_names[i] + " " + middlename
                    occurrences[i] += 1

        for i, k in enumerate(occurrences):
            for lastname in lastNames:
                if lastname in words[k+1]:
                    temp_names[i] = temp_names[i] + " " + lastname
                    occurrences[i] += 1
        
        for temp_name in temp_names: full_names.append(temp_name)

    for name in middleNames:
        flag = False
        for full_name in full_names:
            if name in full_name:
                flag = True
                break
        if flag: continue

        words = text.split()
        occurrences = []
        for i, word in enumerate(words):
            if name in word:
                occurrences.append(i)
        
        temp_names = [name for i in range(len(occurrences))]

        for i, k in enumerate(occurrences):
            for lastname in lastNames:
                if lastname in words[k+1]:
                    temp_names[i] = temp_names[i] + " " + lastname
                    occurrences[i] += 1
        
        for temp_name in temp_names: full_names.append(temp_name)

    for name in lastNames:
        flag = False
        for full_name in full_names:
            if name in full_name:
                flag = True
                break
        if flag: continue

        words = text.split()
        occurrences = []
        for i, word in enumerate(words):
            if name in word:
                occurrences.append(i)
        
        temp_names = [name for i in range(len(occurrences))]
        
        for temp_name in temp_names: full_names.append(temp_name)

    temp_dict = dict()
    if len(full_names) > 0:
        temp_dict["Name"] = []
        for full_name in full_names:
            temp_dict["Name"].append(full_name)

    for key in privacy_mask:
        # if 'NAME' in key: continue
        value = privacy_mask[key]
        formatted_key = remove_brackets(key).split('_')[0]
        if formatted_key in mapping:
            formatted_key = mapping[formatted_key]
            
        if formatted_key not in temp_dict: 
            temp_dict[formatted_key] = []

        # Intervene for positive and negative samples.
        if formatted_key in mapping.values():   
            if formatted_key == "Name": continue
            formatted_value = value
            temp_dict[formatted_key].append(formatted_value)
    
    conversations = []
    labels = list(temp_dict.keys())
    # print(labels, len(labels))
    
    # Get the first label in with the text.
    human_input_init = {
        "from": "human",
        "value": f"Text: {text}"
    }
    model_output_init = { 
        "from": "gpt", 
        "value": "I've read this text." 
    }
    human_input = {
        "from": "human", 
        "value": f"What describes {labels[0]} in the text?"
    }
    model_output = {
        "from": "gpt",
        "value": f"{temp_dict[labels[0]]}".replace("'", "\"")
    }
    conversations.append(human_input_init)
    conversations.append(model_output_init)
    conversations.append(human_input)
    conversations.append(model_output)
    
    # Get other labels in.
    for label in labels[1:]:
        human_input = {
            "from": "human",
            "value": f"What describes {label} in the text?",
        }
        model_output = {
            "from": "gpt",
            "value": f"{temp_dict[label]}".replace("'", "\""),
        }
        conversations.append(human_input)
        conversations.append(model_output)
    
    # Add in negative examples for chosen labels (otherwise the model tends to hallucinate)
    entities = [
        "Money", "Name", "Age", "SSN", "Credit Card Number", "Zipcode", "Date",
    ]
    for entity in entities:
        if entity not in labels:
            human_input = {
            "from": "human",
            "value": f"What describes {entity} in the text?",
            }
            model_output = {
                "from": "gpt",
                "value": "[]",
            }
            conversations.append(human_input)
            conversations.append(model_output)

    # Format for data pipeline.
    sample_input = {
        "id": "",
        "conversations":conversations,
        "labels": labels,
    }
    
    return sample_input

def read_jsonl(file_path):
    data = []
    with open(file_path, 'r') as file:
        for line in file:
            data.append(json.loads(line))
    return data

In [3]:
names = [
    "english_pii_43k.jsonl",
    "german_pii_52k.jsonl",
    "french_pii_62k.jsonl",
]
# Entities: [Money, Name, Age, SSN, Credit Card Number, Zipcode, Date]
entities = [
    "Money", "Name", "Age", "SSN", "Credit Card Number", "Zipcode", "Date",
]
train_set = []
test_set = []
val_set = []
for name in names:
    processed = []
    file_path = f"/nobackup3/divyam/data/pii-masking-200k/{name}"
    content = read_jsonl(file_path)
    random.shuffle(content)
    # with ProcessPoolExecutor() as executor:
    #     for result in tqdm(
    #         executor.map(preprocess_sample, content), total=len(content)
    #     ):
    #         processed.append(result)

    for i, sample in enumerate(content):
        try:
            result = preprocess_sample(sample)
            processed.append(result)
        except:
            print("Error on ID:", i)
    positive_samples = []
    negative_samples = []
    for sample in processed:
        a = set(sample['labels'])
        b = entities
        if a.intersection(b): positive_samples.append(sample)
        else: negative_samples.append(sample)
    
    print("#"*50)
    print(len(positive_samples), len(negative_samples))
    print(positive_samples[-1])
    print(negative_samples[-1])
    print("#"*50,'\n')

    def train_val_test_split(samples):
        a = samples[:int(0.7*len(samples))]
        b = samples[int(0.7*len(samples)):int(0.8*len(samples))]
        c = samples[int(0.8*len(samples)):]
        return a, b, c
    
    train_positive, val_positive, test_positive = train_val_test_split(positive_samples)
    train_negative, val_negative, test_negative = train_val_test_split(negative_samples)
    train_set.extend(train_positive + train_negative)
    test_set.extend(test_positive + test_negative)
    val_set.extend(val_positive + val_negative)
    
#     processed = positive_samples + negative_samples
#     final_processed.extend(processed)
    
for i in tqdm(range(len(train_set)), total=len(train_set)):
    train_set[i]["id"] = f"{i}"

for i in tqdm(range(len(val_set)), total=len(val_set)):
    val_set[i]["id"] = f"{i}"

for i in tqdm(range(len(test_set)), total=len(test_set)):
    test_set[i]["id"] = f"{i}"
    
with open('/nobackup3/divyam/data/pii-masking-200k/pii_masking_200k_en_fr_de_train_v7.json', 'w') as fp:
    json.dump(train_set, fp, indent=2)

with open('/nobackup3/divyam/data/pii-masking-200k/pii_masking_200k_en_fr_de_val_v7.json', 'w') as fp:
    json.dump(val_set, fp, indent=2)
    
with open('/nobackup3/divyam/data/pii-masking-200k/pii_masking_200k_en_fr_de_test_v7.json', 'w') as fp:
    json.dump(test_set, fp, indent=2)

Error on ID: 243
Error on ID: 4289
Error on ID: 9025
Error on ID: 9992
Error on ID: 12182
Error on ID: 13702
Error on ID: 14871
Error on ID: 14934
Error on ID: 16456
Error on ID: 16727
Error on ID: 17372
Error on ID: 19157
Error on ID: 21439
Error on ID: 23556
Error on ID: 23611
Error on ID: 25209
Error on ID: 27759
Error on ID: 27839
Error on ID: 33590
Error on ID: 33845
Error on ID: 34749
Error on ID: 36346
Error on ID: 40276
Error on ID: 43128
##################################################
25969 17508
{'id': '', 'conversations': [{'from': 'human', 'value': 'Text: Reminder, Mr. Hirthe, you are scheduled to hold a Zoom meeting with both new and recurrent trauma patients today at 11:37am.'}, {'from': 'gpt', 'value': "I've read this text."}, {'from': 'human', 'value': 'What describes Name in the text?'}, {'from': 'gpt', 'value': '["Hirthe"]'}, {'from': 'human', 'value': 'What describes PREFIX in the text?'}, {'from': 'gpt', 'value': '[]'}, {'from': 'human', 'value': 'What describes 

  0%|          | 0/110722 [00:00<?, ?it/s]

  0%|          | 0/15819 [00:00<?, ?it/s]

  0%|          | 0/31638 [00:00<?, ?it/s]