In [10]:
import pandas as pd
from datasets import load_dataset

# Dataset config selector: the HF config loaded is f"{ds}_subgraphs".
ds = "mixtral"

# Load the config once and take every split from the returned DatasetDict,
# instead of calling load_dataset three times for the same config.
subgraph_splits = load_dataset("s-nlp/KGQA_Subgraphs_Ranking", f"{ds}_subgraphs")
train_ds = subgraph_splits["train"].to_pandas()
test_ds = subgraph_splits["test"].to_pandas()
dev_ds = subgraph_splits["validation"].to_pandas()

print(len(dev_ds))

32757


In [12]:
# Number of samples filter_entities_by_len had to trim (it increments this global).
filter_counter: int = 0

In [13]:
import ast

# Integer code for each node "type" string found in the subgraph JSON; stored
# as the 4th field of every kbs entry and used as the sort key when trimming.
entity_status_linker = dict(
    INTERNAL=0,
    QUESTIONS_ENTITY=1,
    ANSWER_CANDIDATE_ENTITY=2,
)


def get_webnlg_like_mapper(data, i):
    """Build a per-node adjacency mapping for the subgraph stored in row ``i``.

    ``data["graph"][i]`` must be a Python-literal string, which is parsed with
    ``ast.literal_eval``. It is expected to hold a dict with:
      * ``"nodes"``: items with ``"id"``, ``"label"`` and ``"type"``;
      * ``"links"``: items with ``"source"``, ``"target"`` and ``"label"``.

    Args:
        data: table-like object indexable as ``data["graph"][i]`` (in this
            notebook, a pandas DataFrame from ``Dataset.to_pandas()``).
        i: row label to read.

    Returns:
        dict keyed by node id. Each value holds:
          * ``"source_label"``: the node label, or ``"unknown entity"`` when
            the label is None;
          * ``"type"``: the node type string;
          * ``"links_array"``: only for nodes with outgoing links. It is a
            list of ``{"target_label", "relation_label"}`` dicts in link order.
    """
    graph = ast.literal_eval(data["graph"][i])
    links = graph["links"]
    nodes = graph["nodes"]

    webnlg_mapper = {}
    for node in nodes:
        label = node["label"]
        webnlg_mapper[node["id"]] = {
            "source_label": "unknown entity" if label is None else label,
            "type": node["type"],
        }

    for link in links:
        link_dict = {
            # Target label is resolved after the None -> "unknown entity" fallback.
            "target_label": webnlg_mapper[link["target"]]["source_label"],
            "relation_label": link["label"],
        }
        webnlg_mapper[link["source"]].setdefault("links_array", []).append(link_dict)
    return webnlg_mapper


def get_json_format(webnlg_mapper, status_linker=None):
    """Flatten a node mapper into WebNLG-style ``kbs`` entries.

    Every outgoing link becomes its own entry ``"W{k}"``. The index ``k`` counts
    from 0 across the whole graph, in mapper order and then link order. Each
    value is::

        [source_label, source_label, [[relation_label, target_label]], type_code]

    Nodes without a ``"links_array"`` contribute no entries. Their ``"type"``
    is still looked up, so an unknown type raises ``KeyError`` either way.

    Args:
        webnlg_mapper: mapping as produced by ``get_webnlg_like_mapper``.
        status_linker: node-type string -> integer code. Defaults to the
            module-level ``entity_status_linker``.

    Returns:
        dict of ``"W{k}"`` -> entry list, in insertion order.
    """
    if status_linker is None:
        status_linker = entity_status_linker

    json_converted = {}
    ind = 0
    for node in webnlg_mapper.values():
        source_label = node.get("source_label", -1)
        entity_type = status_linker[node["type"]]
        for link_dict in node.get("links_array", ()):
            json_converted[f"W{ind}"] = [
                source_label,
                source_label,
                [[link_dict["relation_label"], link_dict["target_label"]]],
                entity_type,
            ]
            ind += 1
    return json_converted


def convert_to_webnlg_format(data, i):
    """Convert dataset row ``i`` into a WebNLG-style sample dict.

    Returns ``{"id": i, "kbs": ..., "text": ["example of text"]}``. ``kbs`` is
    the flattened triple mapping built from the row's graph, and ``text`` is a
    fixed placeholder.
    """
    node_mapper = get_webnlg_like_mapper(data, i)
    return {
        "id": i,
        "kbs": get_json_format(node_mapper),
        "text": ["example of text"],
    }


def get_all_entities_per_sample(mark_entity_number, mark_entity, entry):
    """Gather the targets and relation labels of the selected kbs entries.

    Args:
        mark_entity_number: kbs keys (e.g. ``"W0"``) whose entries are scanned.
        mark_entity: strings to exclude from the first result. Each one
            removes at most one matching occurrence.
        entry: sample dict whose ``"kbs"`` maps keys to lists shaped like
            ``[source_label, source_label, [[relation, target], ...], ...]``.

    Entries with an empty source label are skipped, as are pairs whose
    relation or target is empty.

    Returns:
        ``(text_entity_list, text_relation_list)``. The first list is the
        unique targets followed by the unique relations, minus the excluded
        strings (i.e. everything except the start entities). The second is the
        unique relations. Order follows set iteration order, which is not
        guaranteed to be stable between interpreter runs.
    """
    targets = set()
    relations = set()
    for key in mark_entity_number:
        entity = entry["kbs"][key]
        if not len(entity[0]):
            continue
        for rel in entity[2]:
            if len(rel[0]) and len(rel[1]):
                relations.add(rel[0])
                targets.add(rel[1])

    relation_list = list(relations)
    remaining = list(targets) + relation_list  # new list; relation_list stays intact
    for excluded in mark_entity:
        try:
            remaining.remove(excluded)
        except ValueError:
            pass  # string not present: nothing to exclude

    return remaining, relation_list


def filter_entities_by_len(entry, limit=51):
    """Trim a WebNLG-style sample until ``check_total_len`` is at most ``limit``.

    If ``entry`` already fits, it is returned as-is (the same object).
    Otherwise:
      1. Its ``kbs`` values are stably sorted by the type code in position 3,
         highest first.
      2. The last value is dropped repeatedly until the sample fits. With the
         codes in ``entity_status_linker``, this removes INTERNAL (0) entries
         before QUESTIONS_ENTITY (1) and ANSWER_CANDIDATE_ENTITY (2) ones.
      3. The survivors are re-keyed ``W0..Wk`` in a new dict that keeps
         ``entry["id"]`` and a copy of ``entry["text"]``.

    Args:
        entry: sample dict with ``"id"`` and ``"kbs"`` (and optionally
            ``"text"``), as produced by ``convert_to_webnlg_format``.
        limit: maximum allowed value of ``check_total_len``.

    Returns:
        ``entry`` itself, or a trimmed copy.

    Side effects:
        Increments the global ``filter_counter`` once per sample that needed
        trimming. The counter must already exist at module level.
    """
    global filter_counter

    curr_len = check_total_len(entry)
    if curr_len <= limit:
        return entry

    filter_counter += 1
    ranked = sorted(entry["kbs"].values(), key=lambda value: value[3], reverse=True)
    text = list(entry.get("text", ["example of text"]))

    webnlg_format = entry
    # The `ranked` guard stops once nothing is left to drop, so a negative
    # `limit` cannot loop forever.
    while curr_len > limit and ranked:
        ranked = ranked[:-1]
        webnlg_format = {
            "id": entry["id"],
            # Keep all four fields, including the type code, so trimmed samples
            # have the same shape as the untrimmed ones from get_json_format.
            "kbs": {f"W{ind}": list(value[:4]) for ind, value in enumerate(ranked)},
            "text": text,
        }
        curr_len = check_total_len(webnlg_format)

    return webnlg_format


def check_total_len(entry):
    """Return the size measure used by ``filter_entities_by_len``.

    The measure is the sum of two counts:
      * the number of kbs entries. Each entry contributes its source label, so
        a source with several entries is counted several times;
      * the number of items in the first list returned by
        ``get_all_entities_per_sample`` for all kbs keys, with those source
        labels excluded.
    """
    kbs = entry["kbs"]
    keys = [*kbs]
    source_labels = [kbs[key][0] for key in keys]
    remaining, _relations = get_all_entities_per_sample(keys, source_labels, entry)
    return len(source_labels) + len(remaining)

In [14]:
from tqdm import tqdm

# The test split is converted as-is; no length filtering is applied here.
list_of_data = [
    convert_to_webnlg_format(test_ds, row_idx)
    for row_idx in tqdm(range(len(test_ds)))
]

100%|██████████| 9749/9749 [00:01<00:00, 6077.69it/s]


In [15]:
from pathlib import Path

# The JSON splits are written to ./configs/data/{ds}/, so that directory must
# exist. The original target ./data/{ds} is still created in case other code
# expects it.
Path(f"./configs/data/{ds}").mkdir(parents=True, exist_ok=True)
Path(f"./data/{ds}").mkdir(parents=True, exist_ok=True)

In [16]:
import json
from pathlib import Path

test_path = Path(f"./configs/data/{ds}/test.json")
# Create the output directory here so this cell does not depend on an
# earlier mkdir cell targeting the same location.
test_path.parent.mkdir(parents=True, exist_ok=True)
with open(test_path, "w") as f:
    json.dump(list_of_data, f)
print(len(list_of_data))

9749


In [17]:
from tqdm import tqdm

# The train split is converted as-is; no length filtering is applied here.
list_of_data_train = [
    convert_to_webnlg_format(train_ds, row_idx)
    for row_idx in tqdm(range(len(train_ds)))
]

100%|██████████| 32757/32757 [00:05<00:00, 5885.47it/s]


In [18]:
import json
from pathlib import Path

train_path = Path(f"./configs/data/{ds}/train.json")
# Create the output directory here so this cell does not depend on an
# earlier mkdir cell targeting the same location.
train_path.parent.mkdir(parents=True, exist_ok=True)
with open(train_path, "w") as f:
    json.dump(list_of_data_train, f)
print(len(list_of_data_train))

32757


In [19]:
from tqdm import tqdm

# Each dev sample is trimmed until check_total_len(...) <= 51.
# NOTE(review): train/test are written untrimmed; confirm this asymmetry is intended.
list_of_data_dev = [
    filter_entities_by_len(convert_to_webnlg_format(dev_ds, row_idx), 51)
    for row_idx in tqdm(range(len(dev_ds)))
]

  1%|          | 358/32757 [00:00<00:09, 3572.58it/s]

100%|██████████| 32757/32757 [00:05<00:00, 5605.12it/s]


In [20]:
import json
from pathlib import Path

dev_path = Path(f"./configs/data/{ds}/dev.json")
# Create the output directory here so this cell does not depend on an
# earlier mkdir cell targeting the same location.
dev_path.parent.mkdir(parents=True, exist_ok=True)
with open(dev_path, "w") as f:
    json.dump(list_of_data_dev, f)
print(len(list_of_data_dev))

32757
