# process original dataset

## load original dataset

In [None]:
import datasets

In [None]:
wiki_bio = datasets.load_dataset("wiki_bio")

# Rename the "val" split to "valid": pop() removes the old key and hands back its dataset.
wiki_bio["valid"] = wiki_bio.pop("val")

wiki_bio

## create union dataset

In [None]:
# Stack every raw split (train, test, valid order) into a single dataset.
union_dataset = datasets.combine.concatenate_datasets(
    [wiki_bio[split] for split in ("train", "test", "valid")]
)

union_dataset

## check column names & select columns for the following experiments

In [None]:
# Read the nested "input_text" column of every example into memory; each entry
# is indexed as entry["table"]["column_header"] in the next cell.
input_text_list = union_dataset["input_text"]

# Display the container type returned by the column access.
type(input_text_list)

In [None]:
# Flatten the infobox column headers of all examples into one list (duplicates kept).
column_name_list = []
for row in input_text_list:
    column_name_list.extend(row["table"]["column_header"])

len(column_name_list)

In [6]:
from collections import Counter

In [None]:
# Count how often each infobox column name occurs across the whole corpus.
column_name_counter = Counter()
column_name_counter.update(column_name_list)

len(column_name_counter)

In [None]:
# Show the 20 most frequent infobox column names with their counts.
column_name_counter.most_common(20)

In [9]:
# Seed for the random candidate sampling later in the notebook.
seed = 2023
# Infobox columns extracted into their own fields; rows missing any of them are dropped.
attribute_list = ["birth_date", "occupation", "nationality"]

## create new dataset with attribute_list

In [None]:
def have_attributes(example):
    """Return True iff every name in ``attribute_list`` is a header of the example's infobox table."""
    headers = example["input_text"]["table"]["column_header"]
    return all(attribute in headers for attribute in attribute_list)

def create_attributes(example):
    """Lift each ``attribute_list`` field out of the infobox table.

    Returns a dict mapping every attribute name to the content of the first
    table column whose header equals that name. Raises ValueError (from
    ``list.index``) when a header is missing, so rows should first be filtered
    with ``have_attributes``.
    """
    table = example["input_text"]["table"]
    headers, contents = table["column_header"], table["content"]
    return {name: contents[headers.index(name)] for name in attribute_list}


# Per split: keep rows whose infobox has every attribute, lift those attributes
# into columns, drop the raw table, and expose the biography as "text".
dataset = (
    datasets.DatasetDict({split: wiki_bio[split] for split in ("train", "test", "valid")})
    .filter(have_attributes, num_proc=32)
    .map(create_attributes, num_proc=32)
    .remove_columns(["input_text"])
    .rename_columns({"target_text": "text"})
)

dataset

In [None]:
# Stack all splits (train, test, valid order) into a single dataset.
union_dataset = datasets.combine.concatenate_datasets(
    [dataset[split] for split in ("train", "test", "valid")]
)

union_dataset

# process new dataset

## process birth_date

In [None]:
# Count distinct raw birth_date strings across all splits.
birth_date_counter = Counter()
birth_date_counter.update(union_dataset["birth_date"])

len(birth_date_counter)

In [None]:
# Show the 20 most frequent raw birth_date strings with their counts.
birth_date_counter.most_common(20)

In [14]:
import re

In [None]:
def extract_year_from_string(input_string):
    """Return the first standalone four-digit number in ``input_string``.

    "Standalone" means delimited by regex word boundaries, so "1950s" and
    "12345" do not yield a year.

    Returns:
        The matched digits as a string, or None when there is no such number.
    """
    found = re.search(r'\b\d{4}\b', input_string)
    return found.group(0) if found else None

# Reduce birth_date to its first four-digit year, then drop rows where none was found.
dataset = dataset.map(
    lambda example: {"birth_date": extract_year_from_string(example["birth_date"])},
    num_proc=32,
).filter(lambda example: example["birth_date"] is not None, num_proc=32)

dataset

In [None]:
# Stack all splits (train, test, valid order) into a single dataset.
union_dataset = datasets.combine.concatenate_datasets(
    [dataset[split] for split in ("train", "test", "valid")]
)

union_dataset

## process occupation

In [None]:
# Count distinct raw occupation strings across all splits.
occupation_counter = Counter()
occupation_counter.update(union_dataset["occupation"])

len(occupation_counter)

In [None]:
# Show the 20 most frequent raw occupation strings with their counts.
occupation_counter.most_common(20)

In [None]:
def filter_occupation(example):
    """Reject occupations containing bracket tokens (-lrb-/-rrb-) or any digit."""
    occupation = example["occupation"]
    has_bracket = "-lrb-" in occupation or "-rrb-" in occupation
    has_digit = any(ch.isdigit() for ch in occupation)
    return not (has_bracket or has_digit)

# Drop rows whose occupation contains bracket tokens or digits (every split).
dataset = dataset.filter(filter_occupation, num_proc=32)

dataset

In [None]:
def filter_occupation_again(example):
    """Keep the row if some occupation fragment appears in the first line of its text.

    The raw occupation is split on "; ", ", ", "and ", "/ " and "&", and each
    fragment is stripped.

    Two guards avoid false positives:
    * "and " must start a word (``\\band ``). Otherwise "band leader" splits
      into "b" and "leader", and "b" matches almost any text.
    * Empty fragments (e.g. from "actor &") are skipped, because "" is a
      substring of every string and would accept any row.
    """
    fragments = [
        fragment.strip()
        for fragment in re.split(r"; |, |\band |/ |&", example["occupation"])
    ]
    text = example["text"].split("\n")[0]
    return any(fragment and fragment in text for fragment in fragments)

# Keep rows where at least one occupation fragment occurs in the text's first line.
dataset = dataset.filter(filter_occupation_again, num_proc=32)

dataset

In [None]:
def process_occupation(example):
    """Replace the raw occupation with the first fragment that occurs in the text's first line.

    The splitting matches ``filter_occupation_again``:
    * separators are "; ", ", ", word-initial "and ", "/ " and "&";
    * fragments are stripped;
    * empty fragments are skipped, since "" matches any text.

    When no fragment occurs, the stripped raw value is returned. The original
    code silently returned the last fragment in that case. ``filter_text``
    later drops rows whose occupation is not in the first line.

    Returns:
        ``{"occupation": <chosen value>}``.
    """
    fragments = [
        fragment.strip()
        for fragment in re.split(r"; |, |\band |/ |&", example["occupation"])
    ]
    text = example["text"].split("\n")[0]
    for fragment in fragments:
        if fragment and fragment in text:
            return {"occupation": fragment}
    return {"occupation": example["occupation"].strip()}

# Apply process_occupation once to every split. DatasetDict.map already covers
# all splits; the previous loop re-mapped the whole DatasetDict on each of its
# three iterations, tripling the work.
dataset = dataset.map(process_occupation, num_proc=32)

dataset

In [None]:
# Stack all splits (train, test, valid order) into a single dataset.
union_dataset = datasets.combine.concatenate_datasets(
    [dataset[split] for split in ("train", "test", "valid")]
)

union_dataset

## process nationality

In [None]:
def filter_nationaity(example):
    """Keep the row only if its nationality string occurs in the first line of its text."""
    first_line = example["text"].split("\n")[0]
    return example["nationality"] in first_line

# Keep rows whose nationality appears in the first line of the text. Applied
# once to the whole DatasetDict; the previous loop repeated the full-dict
# filter once per split.
dataset = dataset.filter(filter_nationaity, num_proc=32)

dataset

In [None]:
# Stack all splits (train, test, valid order) into a single dataset.
union_dataset = datasets.combine.concatenate_datasets(
    [dataset[split] for split in ("train", "test", "valid")]
)

union_dataset

In [None]:
# Count distinct nationality strings across all splits.
nationality_counter = Counter()
nationality_counter.update(union_dataset["nationality"])

len(nationality_counter)

In [None]:
# The 20 most frequent nationality values, most common first (counts discarded).
nationality_list = [nationality for nationality, _count in nationality_counter.most_common(20)]

nationality_list

In [None]:
def filter_nationality_again(example):
    """Keep the row only if its nationality is one of the values in ``nationality_list``."""
    return example["nationality"] in nationality_list

# Restrict every split to the top-20 nationality values.
dataset = dataset.filter(filter_nationality_again, num_proc=32)

dataset

In [None]:
# Stack all splits (train, test, valid order) into a single dataset.
union_dataset = datasets.combine.concatenate_datasets(
    [dataset[split] for split in ("train", "test", "valid")]
)

union_dataset

## process text

In [None]:
def filter_text(example, min_words=0):
    """Keep rows whose first text line mentions every attribute value.

    Args:
        example: row with "text", "birth_date", "occupation" and "nationality".
        min_words: minimum number of whitespace-separated tokens required in
            the first line. The default of 0 disables the check; this replaces
            the previously commented-out threshold of 20.

    Returns:
        True iff birth_date, occupation and nationality are all non-empty
        substrings of the first line, and that line has at least ``min_words``
        tokens.
    """
    text = example["text"].split("\n")[0]
    for key in ("birth_date", "occupation", "nationality"):
        value = example[key]
        # An empty value is a substring of every string and would pass vacuously.
        if not value or value not in text:
            return False
    return len(text.split()) >= min_words

# Keep rows whose first text line contains all three attribute values (every split).
dataset = dataset.filter(filter_text, num_proc=32)

dataset


In [None]:
import spacy
# Use GPU 0 for spaCy if one is available (prefer_gpu falls back to CPU otherwise).
spacy.prefer_gpu(0)
# English "_trf" (transformer) pipeline; its NER output is used by process_text
# to locate PERSON entities.
nlp = spacy.load("en_core_web_trf")

In [None]:
month_list = ["january", "february", "march", "april", "may", "june", "july", "august", "september", "october", "november", "december"]

def process_text(example):
    """Turn the first line of a lower-cased, tokenized biography into readable text.

    Steps:
    * keep only the first line;
    * turn the ``-lrb-``/``-rrb-`` tokens back into brackets and remove the
      space before "," and ".";
    * capitalize the nationality names in ``nationality_list`` and the month
      names in ``month_list``;
    * title-case every spaCy ``PERSON`` entity.

    Matching is word-aware. Plain ``str.replace`` capitalized matches in the
    middle of words ("dismay" -> "disMay", "mayor" -> "Mayor").
    * Months are capitalized only as whole words.
    * Nationalities are capitalized only where they start a word. So
      "englishman" still becomes "Englishman", but "amerindian" is left alone.

    The modal verb "may" is still capitalized; this is a lexical rule, not a
    date parser.

    Returns:
        ``{"text": <processed first line>}``.
    """
    text = example["text"]
    # * process bracket and punctuation
    text = text.split("\n")[0].strip()
    text = text.replace("-lrb- ", "(")
    text = text.replace(" -rrb-", ")")
    text = text.replace(" ,", ",")
    text = text.replace(" .", ".")
    # * process nationality: only at the start of a word
    for nationality in nationality_list:
        capitalized = nationality.capitalize()
        # A callable replacement avoids interpreting backslashes in the value.
        text = re.sub(r"\b" + re.escape(nationality), lambda _m, repl=capitalized: repl, text)
    # * process month: whole words only
    for month in month_list:
        text = re.sub(r"\b" + re.escape(month) + r"\b", month.capitalize(), text)
    # * process person name
    doc = nlp(text)
    for ent in doc.ents:
        if ent.label_ == "PERSON":
            text = text.replace(ent.text, ent.text.title())

    return {"text": text}

# Clean and re-case the text of every split (single process: spaCy may hold the GPU).
dataset = dataset.map(process_text)

dataset

In [None]:
# Stack all splits (train, test, valid order) into a single dataset.
union_dataset = datasets.combine.concatenate_datasets(
    [dataset[split] for split in ("train", "test", "valid")]
)

union_dataset

## create candidate list

In [None]:
import random
random.seed(seed)

candidate_num = 10

def create_candidate(example, idx=None, split=""):
    """Build the candidate list for the current global ``attribute``.

    The list is the example's gold value followed by ``candidate_num - 1``
    distinct distractors. The distractors are sampled without replacement from
    the other keys of ``attribute_counter``. Both ``attribute`` and
    ``attribute_counter`` are module globals set by the loop below.

    Args:
        example: row containing the ``attribute`` column.
        idx: row index within its split (passed by ``map(with_indices=True)``).
            When given, distractors come from a private ``random.Random``
            seeded by (seed, split, attribute, idx). The result then does not
            depend on ``num_proc`` or on which worker handles the row. When
            None, the global ``random`` module is used as before.
        split: split name mixed into the per-row seed, so rows with the same
            index in different splits get independent draws.

    Returns:
        ``{f"{attribute}_candidate_list": [gold, *distractors]}``.

    Raises:
        ValueError: if fewer than ``candidate_num - 1`` other values exist.
    """
    target = example[attribute]
    pool = list(attribute_counter.keys())
    pool.remove(target)
    # String seeds are hashed deterministically (not affected by PYTHONHASHSEED).
    rng = random if idx is None else random.Random(f"{seed}/{split}/{attribute}/{idx}")
    candidate_list = [target] + rng.sample(pool, candidate_num - 1)

    return {f"{attribute}_candidate_list": candidate_list}


# NOTE: with num_proc=32, a forked worker pool gives each worker a copy of the
# parent's global RNG state. Draws from the global RNG are therefore not
# reproducible from `seed` alone, and shards can repeat the same sequences.
# Per-row RNGs avoid both problems.
for attribute in attribute_list:
    attribute_counter = Counter(union_dataset[attribute])
    for name in dataset.keys():
        dataset[name] = dataset[name].map(
            create_candidate,
            with_indices=True,
            fn_kwargs={"split": name},
            num_proc=32,
        )

dataset

In [None]:
# NOTE: unlike earlier cells, splits are stacked train, valid, test here.
# Presumably this matches how the embedding commands below slice the dataset
# with --train_size / --valid_size / --test_size — confirm before reordering.
union_dataset = datasets.combine.concatenate_datasets(
    [dataset[split] for split in ("train", "valid", "test")]
)

union_dataset

## save tmp dataset

In [None]:
# Persist the final union dataset; the embedding commands below read it from
# --input_dataset "./tmp/wiki_bio".
union_dataset.save_to_disk("./tmp/wiki_bio")

## get embeddings & save

### sup-simcse-bert-base-uncased

In [None]:
# Embed ./tmp/wiki_bio with sup-simcse-bert-base-uncased using 8 processes.
# NOTE(review): the split sizes are hard-coded; they presumably must equal the
# train/valid/test lengths saved above — re-check after changing any filter.
!torchrun --nproc_per_node=8 ../embedding/sup-simcse-bert-base-uncased.py \
    --input_dataset "./tmp/wiki_bio" \
    --output_dataset "your_output_dir" \
    --train_size 11786 \
    --valid_size 1532 \
    --test_size 1480

### e5-large-v2

In [None]:
# Embed ./tmp/wiki_bio with e5-large-v2 using 8 processes.
# NOTE(review): the split sizes are hard-coded; they presumably must equal the
# train/valid/test lengths saved above — re-check after changing any filter.
!torchrun --nproc_per_node=8 ../embedding/e5-large-v2.py \
    --input_dataset "./tmp/wiki_bio" \
    --output_dataset "your_output_dir" \
    --train_size 11786 \
    --valid_size 1532 \
    --test_size 1480

### bge-large-en

In [None]:
# Embed ./tmp/wiki_bio with bge-large-en using 8 processes.
# NOTE(review): the split sizes are hard-coded; they presumably must equal the
# train/valid/test lengths saved above — re-check after changing any filter.
!torchrun --nproc_per_node=8 ../embedding/bge-large-en.py \
    --input_dataset "./tmp/wiki_bio" \
    --output_dataset "your_output_dir" \
    --train_size 11786 \
    --valid_size 1532 \
    --test_size 1480