In [1]:
import os
import csv
import json
from tqdm import tqdm

## CEA — Cell Entity Annotation

In [105]:
from pymongo import MongoClient

import os

# MongoDB connection settings, all read from the environment (KeyError if one is unset).
(
    MONGO_ENDPOINT,
    MONGO_PORT,
    MONGO_ENDPOINT_USERNAME,
    MONGO_ENDPOINT_PASSWORD,
    MONGO_DBNAME,
) = [
    os.environ[variable_name]
    for variable_name in (
        'MONGO_ENDPOINT',
        'MONGO_PORT',
        'MONGO_INITDB_ROOT_USERNAME',
        'MONGO_INITDB_ROOT_PASSWORD',
        'MONGO_DBNAME',
    )
]
# Credentials are checked against the `admin` database.
mongo_client = MongoClient(
    host=MONGO_ENDPOINT,
    port=int(MONGO_PORT),
    username=MONGO_ENDPOINT_USERNAME,
    password=MONGO_ENDPOINT_PASSWORD,
    authSource='admin',
)

def get_collection(collection_name):
    """Return the collection ``collection_name`` of the configured database ``MONGO_DBNAME``."""
    database = mongo_client[MONGO_DBNAME]
    return database[collection_name]

def index_field(collection_name, field_name):
    """Create a unique, ascending single-field index on ``field_name``; returns None."""
    collection = mongo_client[MONGO_DBNAME][collection_name]
    ascending_key = [(field_name, 1)]
    collection.create_index(ascending_key, unique=True)

def get_entity(collection_name, entity):
    """Return the first document whose ``entity`` field equals ``entity``, or None if absent."""
    query = {'entity': entity}
    collection = mongo_client[MONGO_DBNAME][collection_name]
    return collection.find_one(query)

def insert_many_entity(collection_name, entities):
    """Insert the documents in ``entities`` into ``collection_name``.

    An empty batch is a no-op: pymongo's ``insert_many`` rejects an empty list,
    so calling this helper with nothing to write would otherwise raise.
    """
    if not entities:
        return
    mongo_client[MONGO_DBNAME][collection_name].insert_many(entities)

def get_cell_retrieval(collection_name, cell):
    """Return the first document whose ``cell`` field equals ``cell``, or None if absent."""
    query = {"cell": cell}
    collection = mongo_client[MONGO_DBNAME][collection_name]
    return collection.find_one(query)

In [106]:
import requests

# HTTP headers shared by the LamAPI requests below: JSON in, JSON out.
headers = dict([
    ("accept", "application/json"),
    ("Content-Type", "application/json"),
])
def lookup(cell, timeout=60):
    """Look up Wikidata candidates for one cell via LamAPI (GET, up to 10 results).

    Args:
        cell: cell text to search for.
        timeout: seconds to wait for the HTTP response; without a timeout
            ``requests`` can block forever on a stalled connection.

    Returns:
        A list of ``{'id': ..., 'label': ...}`` dicts, read from the response
        entry keyed by ``cell.lower()``. Empty when that entry is missing or
        empty, instead of raising KeyError.
    """
    params = {
        'token': "insideslab-lamapi-2022",  # NOTE(review): hardcoded API token; consider an env var
        'name': cell,
        'kg': "wikidata",
        'limit': 10
    }
    result = requests.get(
        "https://lamapi.inside.disco.unimib.it/lookup/entity-retrieval-with-array",
        headers=headers,
        params=params,
        timeout=timeout,
    ).json()
    candidates = result.get(cell.lower()) or []
    return [{'id': entity['id'], 'label': entity['name']} for entity in candidates]

def column_analysis(columns, timeout=60):
    """Send table columns to LamAPI's column-analysis endpoint and return the decoded JSON.

    Args:
        columns: the table as a list of columns (each a list of cell strings).
        timeout: seconds to wait for the HTTP response; without a timeout
            ``requests`` can block forever on a stalled connection.

    Returns:
        The response body parsed as JSON. Presumably a dict keyed by column whose
        values carry ``'tag'`` and ``'column_rows'`` — TODO confirm against LamAPI.
    """
    params = {
        'token': "insideslab-lamapi-2022",  # NOTE(review): hardcoded API token; consider an env var
    }
    result = requests.post(
        "https://lamapi.inside.disco.unimib.it/sti/column-analysis",
        headers=headers,
        params=params,
        json={"json": columns},
        timeout=timeout,
    ).json()
    return result

In [107]:
def column_analysis(columns, timeout=60):
    """Send table columns to LamAPI's column-analysis endpoint and return the decoded JSON.

    Args:
        columns: the table as a list of columns (each a list of cell strings).
        timeout: seconds to wait for the HTTP response; without a timeout
            ``requests`` can block forever on a stalled connection.

    Returns:
        The response body parsed as JSON. Presumably a dict keyed by column whose
        values carry ``'tag'`` and ``'column_rows'`` — TODO confirm against LamAPI.
    """
    params = {
        'token': "insideslab-lamapi-2022",  # NOTE(review): hardcoded API token; consider an env var
    }
    result = requests.post(
        "https://lamapi.inside.disco.unimib.it/sti/column-analysis",
        headers=headers,
        params=params,
        json={"json": columns},
        timeout=timeout,
    ).json()
    return result

def lookup(cells, timeout=60):
    """Look up Wikidata candidates for many cells in a single LamAPI request.

    Args:
        cells: list of cell strings to search for.
        timeout: seconds to wait for the HTTP response; without a timeout
            ``requests`` can block forever on a stalled connection.

    Returns:
        One ``{"name": key, "entities": [{'id': ..., 'label': ...}, ...]}`` dict per
        key of the response, in response order.
    """
    params = {
        'token': "insideslab-lamapi-2022",  # NOTE(review): hardcoded API token; consider an env var
        'kg': "wikidata",
        'limit': 10
    }
    result = requests.post(
        "https://lamapi.inside.disco.unimib.it/lookup/entity-retrieval-with-array",
        headers=headers,
        params=params,
        json={"cells": cells},
        timeout=timeout,
    ).json()
    return [
        {
            "name": name,
            "entities": [{'id': entity['id'], 'label': entity['name']} for entity in entities],
        }
        for name, entities in result.items()
    ]

In [108]:
class GroundTruthItemCEA:
    """One CEA ground-truth annotation: the entity ``value`` of cell (``row``, ``column``) in ``table``."""

    def __init__(self, table, row, column, value):
        self.table: str = table
        self.row: str = row
        self.column: str = column
        self.value: str = value

    def get_item(self) -> dict[str, str]:
        """Return the cell coordinates (table, row, column) without the annotated value."""
        return dict(table=self.table, row=self.row, column=self.column)

    @property
    def get_identifier(self) -> str:
        """Key of this item inside its table: ``"row_column"``."""
        return '{}_{}'.format(self.row, self.column)

    @property
    def get_output(self) -> str:
        """Serialized annotation, ``(row,column)=value``."""
        return '({},{})={}'.format(self.row, self.column, self.value)

    def __str__(self) -> str:
        return '{} {} {} {}'.format(self.table, self.row, self.column, self.value)

class GroundTruthCEA:
    """CEA ground truth loaded from a CSV whose rows are ``table,row,column,value``.

    After :meth:`load`, ``ground_truth[table]["row_column"]`` holds the matching
    :class:`GroundTruthItemCEA`.
    """

    def __init__(self, filename):
        self.filename: str = filename
        self.ground_truth: dict[str, dict[str, GroundTruthItemCEA]] = {}

    @property
    def total(self) -> int:
        """Number of distinct tables loaded (not the number of annotations)."""
        return len(self.ground_truth)

    def number_of_items_in_csv(self) -> int:
        """Count the CSV records in ``self.filename`` (blank lines included); used as the progress total."""
        with open(self.filename, 'r', newline='') as f:
            return sum(1 for _ in csv.reader(f))

    def load(self):
        """Read the CSV into ``self.ground_truth``, grouping annotations by table.

        Blank lines are skipped. A later row for the same (table, row, column)
        overwrites an earlier one.
        """
        # newline='' is what the csv module expects, so quoted fields containing
        # newlines are parsed correctly.
        with open(self.filename, 'r', newline='') as f:
            print('Loading ground truth...')
            total_lines: int = self.number_of_items_in_csv()
            print(f'Total lines: {total_lines}')
            for row in tqdm(csv.reader(f), total=total_lines):
                if not row:
                    # csv.reader yields [] for a blank line (e.g. a trailing empty
                    # line); indexing it would raise IndexError.
                    continue
                current_gt: GroundTruthItemCEA = GroundTruthItemCEA(row[0], row[1], row[2], row[3])
                self.ground_truth.setdefault(current_gt.table, {})[current_gt.get_identifier] = current_gt

class Dataset():

    def __init__(self, gt_dataset: GroundTruthCEA, dataset_name: str, filename: str):
        self.filename: str = filename
        self.dataset_name: str = dataset_name
        self.gt_dataset = gt_dataset
        self.llm_dataset: list[dict[str, str]] = []
        self.instruction = "perform the cell entity annotation (cea) task on this table:"
        self.lookup_cache = {}

    def parse_table(self, table_representation: str) -> str:
        columns_list = {}
        rows = table_representation.split('|')
        for row in rows:
            columns = row.split(';')
            for index, cell in enumerate(columns):
                if index not in columns_list:
                    columns_list[index] = [cell]
                else:
                    columns_list[index].append(cell)
        all_columns = []
        for _, column in columns_list.items():
            all_columns.append(column)
        return all_columns
    
    def format_pool(self, pool: list[dict[str, str]]) -> str:
        pool_representation = ""
        for entity in pool:
            pool_representation += entity + "|"
        return pool_representation[:-1]

    def get_pool(self, table_representation: str) -> list[dict[str, str]]:
        pool = set()
        columns_list = self.parse_table(table_representation)
        for column in columns_list:
            for cell in column:
                cell_retrieval = get_cell_retrieval("candidate", cell.lower())
                if cell_retrieval is not None:
                    for entity in cell_retrieval['candidates']['entities'][0:3]:
                        pool.add(f"{entity['id']} [{entity['label']}]")

        return self.format_pool(pool)

    def get_output(self, table_name: str) -> str | None:
        outputCEA: str = ""
        if table_name not in self.gt_dataset.ground_truth:
            return None
        for _, gtTable in self.gt_dataset.ground_truth[table_name].items():
            if gtTable.value != "UNKNOWN":
                current_entity_label = get_entity("entities", gtTable.value.replace('http://www.wikidata.org/entity/', ''))
                if current_entity_label is not None and 'label' in current_entity_label:
                    outputCEA += gtTable.get_output.replace('http://www.wikidata.org/entity/', '') + f" [{current_entity_label['label']}]" + "|"
                else:
                    outputCEA += gtTable.get_output.replace('http://www.wikidata.org/entity/', '') + "|"
            else:
                return None

        return outputCEA

    def load_tables(self):
        all_csv_tables = os.listdir(self.filename)
        print("LOAD TABLES...")
        for csv_table in tqdm(all_csv_tables, total=len(all_csv_tables)):
            table_name: str = csv_table.split('.')[0]
            print(table_name)
            table_representation: str = ""
            with open(os.path.join(self.filename, csv_table), 'r') as f:
                reader = csv.reader(f)
                next(reader, None)
                for row in reader:
                    table_representation += ";".join(row) + "|"
                table_representation = table_representation[:-1]
                pool = self.get_pool(table_representation)
                current_output = self.get_output(table_name)
                if current_output is None:
                    continue
                self.llm_dataset.append({
                    'dataset': self.dataset_name,
                    'table': table_name,
                    'instruction': self.instruction,
                    'input': table_representation,
                    "output": current_output,
                    "pool": pool
                })

In [109]:
# CEA benchmarks: (ground-truth CSV, tables directory, dataset label) for each one.
cea_datasets: list[dict[str, str]] = [
    dict(zip(("gt", "tables", "dataset"), spec))
    for spec in (
        # SemTab2020_Table_GT_Target, rounds 1-4
        ("./wikidata/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round1_gt.csv", "./wikidata/SemTab2020_Table_GT_Target/Round1/tables", "semtab_2020_r1"),
        ("./wikidata/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round2_gt.csv", "./wikidata/SemTab2020_Table_GT_Target/Round2/tables", "semtab_2020_r2"),
        ("./wikidata/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round3_gt.csv", "./wikidata/SemTab2020_Table_GT_Target/Round3/tables", "semtab_2020_r3"),
        ("./wikidata/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round4_gt.csv", "./wikidata/SemTab2020_Table_GT_Target/Round4/tables", "semtab_2020_r4"),
        # HardTables 2022, rounds 1-2 (validation split)
        ("./wikidata/HardTablesR1/DataSets/HardTablesR1/Valid/gt/cea_gt.csv", "./wikidata/HardTablesR1/DataSets/HardTablesR1/Valid/tables", "hardtables_2022_r1"),
        ("./wikidata/HardTablesR2/DataSets/HardTablesR2/Valid/gt/cea_gt.csv", "./wikidata/HardTablesR2/DataSets/HardTablesR2/Valid/tables", "hardtables_2022_r2"),
        # Wikidata Tables 2023, round 1 (validation split)
        ("./wikidata/WikidataTables2023R1/DataSets/Valid/gt/cea_gt.csv", "./wikidata/WikidataTables2023R1/DataSets/Valid/tables", "wikidata_tables_2023"),
    )
]

In [110]:
import requests

# HTTP headers shared by the LamAPI requests below: JSON in, JSON out.
headers = dict([
    ("accept", "application/json"),
    ("Content-Type", "application/json"),
])
def column_analysis(columns, timeout=60):
    """Send table columns to LamAPI's column-analysis endpoint and return the decoded JSON.

    Args:
        columns: the table as a list of columns (each a list of cell strings).
        timeout: seconds to wait for the HTTP response; without a timeout
            ``requests`` can block forever on a stalled connection.

    Returns:
        The response body parsed as JSON. Presumably a dict keyed by column whose
        values carry ``'tag'`` and ``'column_rows'`` — TODO confirm against LamAPI.
    """
    params = {
        'token': "insideslab-lamapi-2022",  # NOTE(review): hardcoded API token; consider an env var
    }
    result = requests.post(
        "https://lamapi.inside.disco.unimib.it/sti/column-analysis",
        headers=headers,
        params=params,
        json={"json": columns},
        timeout=timeout,
    ).json()
    return result

def lookup(cell, timeout=60):
    """Look up Wikidata candidates for one cell via LamAPI's single-name endpoint.

    Args:
        cell: cell text to search for.
        timeout: seconds to wait for the HTTP response; without a timeout
            ``requests`` can block forever on a stalled connection.

    Returns:
        ``{"name": key, "entities": [{'id': ..., 'label': ...}, ...]}`` built from
        the response. If the response has several keys the last one wins; with
        no keys an empty dict is returned.
    """
    current_lookup_result = {}
    params = {
        'token': "insideslab-lamapi-2022",  # NOTE(review): hardcoded API token; consider an env var
        'name': cell,
        'kg': "wikidata",
        'limit': 10
    }
    result = requests.get(
        "https://lamapi.inside.disco.unimib.it/lookup/entity-retrieval",
        headers=headers,
        params=params,
        timeout=timeout,
    ).json()
    for key, value in result.items():
        current_lookup_result = {
            "name": key,
            "entities": [{'id': entity['id'], 'label': entity['name']} for entity in value],
        }
    return current_lookup_result

# Crawl LamAPI once per distinct (lower-cased) cell of every column tagged "NE" by
# column analysis, and store the results in the `cell_retrieval` collection in batches.
CELL_BUFFER_FLUSH_THRESHOLD = 10  # flush once the buffer holds more than this many documents

cell_set = set()
cell_buffer = []
try:
    for dataset in cea_datasets:
        all_csv_tables = os.listdir(dataset['tables'])
        for csv_table in tqdm(all_csv_tables, total=len(all_csv_tables)):
            # Read the table column-wise (header row dropped); rows may be ragged.
            columns = {}
            with open(os.path.join(dataset['tables'], csv_table), 'r', newline='') as f:
                reader = csv.reader(f)
                next(reader, None)
                for row in reader:
                    for index, cell in enumerate(row):
                        columns.setdefault(index, []).append(cell)
            annotated_columns = column_analysis(list(columns.values()))
            for val in annotated_columns.values():
                if val['tag'] != "NE":
                    continue
                for cell in val['column_rows']:
                    cell_key = cell.lower()
                    if cell_key in cell_set:
                        continue
                    cell_set.add(cell_key)
                    cell_buffer.append({"name": cell_key, "entities": lookup(cell_key)})
            if len(cell_buffer) > CELL_BUFFER_FLUSH_THRESHOLD:
                # Swap before inserting so a failed insert is not retried by the `finally` below.
                pending, cell_buffer = cell_buffer, []
                insert_many_entity("cell_retrieval", pending)
finally:
    # Runs even when the crawl is interrupted (e.g. KeyboardInterrupt), so lookups that
    # were already fetched are persisted instead of being dropped with the buffer.
    if cell_buffer:
        pending, cell_buffer = cell_buffer, []
        insert_many_entity("cell_retrieval", pending)

  0%|          | 0/34294 [00:00<?, ?it/s]

  0%|          | 95/34294 [02:08<12:53:23,  1.36s/it]


KeyboardInterrupt: 

In [None]:
# Spot-check the first 50 collected cells. On a fresh Run All, `lookup` here is the
# single-cell version from the crawl cell (it takes one string, not a list), so query
# the cells one at a time.
print(len(cell_set))
result = [lookup(cell) for cell in list(cell_set)[0:50]]
print(result)

[{'name': '1958 okolo slovenska', 'entities': [{'id': 'Q62123884', 'label': '1958 Okolo Slovenska'}, {'id': 'Q62123881', 'label': 'Okolo Slovenska 1955'}, {'id': 'Q62123887', 'label': 'Okolo Slovenska 1960'}, {'id': 'Q62123892', 'label': '1967 Okolo Slovenska'}, {'id': 'Q62123899', 'label': 'Okolo Slovenska 1975'}, {'id': 'Q62123902', 'label': 'Okolo Slovenska 1978'}, {'id': 'Q62123903', 'label': '1979 Okolo Slovenska'}, {'id': 'Q62123908', 'label': '1984 Okolo Slovenska'}]}, {'name': '1980 école militaire shooting', 'entities': [{'id': 'Q32354721', 'label': '1980 École Militaire shooting'}, {'id': 'Q273480', 'label': 'École Militaire'}, {'id': 'Q273476', 'label': 'École Militaire'}, {'id': 'Q1132066', 'label': 'École spéciale militaire'}, {'id': 'Q718869', 'label': 'École militaire Nunziatella'}, {'id': 'Q1515214', 'label': 'École royale militaire'}, {'id': 'Q3392821', 'label': 'École Navale Militaire'}, {'id': 'Q3953343', 'label': 'École militaire Teulié'}]}, {'name': 'augusta-richmo

In [None]:
# Build the CEA instruction dataset for every benchmark listed in `cea_datasets`.
all_datasets: list[dict[str, str]] = []
for dataset_spec in tqdm(cea_datasets):
    print(f"{dataset_spec['dataset']}...")
    gt_dataset = GroundTruthCEA(dataset_spec['gt'])
    gt_dataset.load()

    # Separate name: the loop variable stays the spec dict instead of being rebound.
    dataset_builder = Dataset(gt_dataset, dataset_spec['dataset'], dataset_spec['tables'])
    dataset_builder.load_tables()
    all_datasets.extend(dataset_builder.llm_dataset)

len(all_datasets)

  0%|          | 0/7 [00:00<?, ?it/s]

semtab_2020_r1...
Loading ground truth...
Total lines: 985110


100%|██████████| 985110/985110 [00:02<00:00, 431382.03it/s]


LOAD TABLES...


  0%|          | 1/985110 [00:03<995:12:37,  3.64s/it]
  0%|          | 0/7 [00:06<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Preview only: every record embeds a whole serialized table, so dumping the full
# list would bloat the notebook output.
all_datasets[:5]

[]

In [None]:
def to_jsonl(llm_dataset_list: list, datasetName: str):
    """Write each record of ``llm_dataset_list`` to ``datasetName`` as one JSON object per line.

    Missing parent directories (e.g. ``output/``) are created first. ``json.dumps``
    keeps its default ``ensure_ascii=True``, so non-ASCII text is ``\\u``-escaped.
    """
    parent_dir = os.path.dirname(datasetName)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(datasetName, 'w', encoding='utf-8') as outfile:
        for entry in llm_dataset_list:
            outfile.write(json.dumps(entry) + '\n')

In [None]:
# Persist every CEA record, one JSON object per line.
to_jsonl(llm_dataset_list=all_datasets, datasetName='output/cea_all.jsonl')

## CTA — Column Type Annotation

In [None]:
class GroundTruthItemCTA:
    """One CTA ground-truth annotation: the type ``value`` of ``column`` in ``table``."""

    def __init__(self, table, column, value):
        self.table: str = table
        self.column: str = column
        self.value: str = value

    def get_item(self) -> dict[str, str]:
        """Return the annotated position (table, column) without the value."""
        return dict(table=self.table, column=self.column)

    @property
    def get_identifier(self) -> str:
        """Key of this item inside its table: the column, formatted as text."""
        return '{}'.format(self.column)

    @property
    def get_output(self) -> str:
        """Serialized annotation, ``(column)=value``."""
        return '({})={}'.format(self.column, self.value)

    def __str__(self) -> str:
        return '{} {} {}'.format(self.table, self.column, self.value)
    
class GroundTruthCTA:
    """CTA ground truth loaded from a CSV whose rows are ``table,column,value``.

    After :meth:`load`, ``ground_truth[table][column]`` holds the matching
    :class:`GroundTruthItemCTA`.
    """

    def __init__(self, filename):
        self.filename: str = filename
        self.ground_truth: dict[str, dict[str, GroundTruthItemCTA]] = {}

    @property
    def total(self) -> int:
        """Number of distinct tables loaded (not the number of annotations)."""
        return len(self.ground_truth)

    def number_of_items_in_csv(self) -> int:
        """Count the CSV records in ``self.filename`` (blank lines included); used as the progress total."""
        with open(self.filename, 'r', newline='') as f:
            return sum(1 for _ in csv.reader(f))

    def load(self):
        """Read the CSV into ``self.ground_truth``, grouping annotations by table.

        Blank lines are skipped. A later row for the same (table, column)
        overwrites an earlier one.
        """
        # newline='' is what the csv module expects, so quoted fields containing
        # newlines are parsed correctly.
        with open(self.filename, 'r', newline='') as f:
            print('Loading ground truth...')
            total_lines: int = self.number_of_items_in_csv()
            print(f'Total lines: {total_lines}')
            for row in tqdm(csv.reader(f), total=total_lines):
                if not row:
                    # csv.reader yields [] for a blank line (e.g. a trailing empty
                    # line); indexing it would raise IndexError.
                    continue
                current_gt: GroundTruthItemCTA = GroundTruthItemCTA(row[0], row[1], row[2])
                self.ground_truth.setdefault(current_gt.table, {})[current_gt.get_identifier] = current_gt

class DatasetCTA():

    def __init__(self, gt_dataset: GroundTruthCTA, filename: str):
        self.filename: str = filename
        self.gt_dataset = gt_dataset
        self.llm_dataset: list[dict[str, str]] = []
        self.instruction = "perform the column type annotation (cta) task on this table:"

    def get_output(self, table_name: str) -> str | None:
        outputCTA: str = ""
        if table_name not in self.gt_dataset.ground_truth:
            return None
        for _, gtTable in self.gt_dataset.ground_truth[table_name].items():
            if gtTable.value != "UNKNOWN":
                outputCTA += gtTable.get_output.replace('http://www.wikidata.org/entity/', '') + "|"
            else:
                return None

        return outputCTA

    def load_tables(self):
        all_csv_tables = os.listdir(self.filename)
        for csv_table in tqdm(all_csv_tables, total=len(all_csv_tables)):
            table_name: str = csv_table.split('.')[0]
            table_representation: str = ""
            with open(os.path.join(self.filename, csv_table), 'r') as f:
                reader = csv.reader(f)
                next(reader, None)
                for row in reader:
                    table_representation += ";".join(row) + "|"
                current_output = self.get_output(table_name)
                if current_output is None:
                    continue
                self.llm_dataset.append({
                    'table': table_name,
                    'instruction': self.instruction,
                    'input': table_representation,
                    "output": current_output
                })

In [None]:
# Mammotab CTA: load the ground truth, then build one CTA record per annotated table.
path = './wikidata/mammotab_dataset_semtab/gt/CTA_mammotab_gt.csv'
gt_dataset_mammotab_cta = GroundTruthCTA(filename=path)
gt_dataset_mammotab_cta.load()

dataset_cta_mammotab = DatasetCTA(
    gt_dataset=gt_dataset_mammotab_cta,
    filename='./wikidata/mammotab_dataset_semtab/tables',
)
dataset_cta_mammotab.load_tables()
len(dataset_cta_mammotab.llm_dataset)

Loading ground truth...
Total lines: 5541283


100%|██████████| 5541283/5541283 [00:08<00:00, 673977.94it/s] 
100%|██████████| 980254/980254 [04:45<00:00, 3436.42it/s]


6502

In [None]:
# Peek at the first CTA record (table, instruction, serialized input, expected output).
dataset_cta_mammotab.llm_dataset[0]

{'table': 'RSDLDEWR',
 'instruction': 'perform the column type annotation (cta) task on this table:',
 'input': 'Corey Peters;NT;Arizona Cardinals|Steven Means;DE;Atlanta Falcons|Bradley Bozeman;G;Baltimore Ravens|Harrison Phillips;DT;Buffalo Bills|Shaq Thompson;LB;Carolina Panthers|Jimmy Graham;TE;Chicago Bears|Geno Atkins;DT;Cincinnati Bengals|Myles Garrett;DE;Cleveland Browns|Jaylon Smith;LB;Dallas Cowboys|Justin Simmons;FS;Denver Broncos|Trey Flowers;DE;Detroit Lions|Corey Linsley;C;Green Bay Packers|Michael Thomas;FS;Houston Texans|Jacoby Brissett;QB;Indianapolis Colts|Josh Lambo;K;Jacksonville Jaguars|Travis Kelce;TE;Kansas City Chiefs|Alec Ingold;FB;Las Vegas Raiders|Isaac Rochell;DE;Los Angeles Chargers|Andrew Whitworth;OT;Los Angeles Rams|Byron Jones;CB;Miami Dolphins|Eric Kendricks;LB;Minnesota Vikings|Devin McCourty;FS;New England Patriots|Demario Davis;LB;New Orleans Saints|Dalvin Tomlinson;NT;New York Giants|Pierre Desir;CB;New York Jets|Rodney McLeod;S;Philadelphia Eagles

## CPA — Column Property Annotation

In [None]:
class GroundTruthItemCPA:
    """One CPA ground-truth annotation: the property ``value`` linking ``col_1`` to ``col_2`` in ``table``."""

    def __init__(self, table, col_1, col_2, value):
        self.table: str = table
        self.col_1: str = col_1
        self.col_2: str = col_2
        self.value: str = value

    def get_item(self) -> dict[str, str]:
        """Return the annotated column pair (with its table) without the value."""
        return dict(table=self.table, col_1=self.col_1, col_2=self.col_2)

    @property
    def get_identifier(self) -> str:
        """Key of this item inside its table: ``"col_1_col_2"``."""
        return '{}_{}'.format(self.col_1, self.col_2)

    @property
    def get_output(self) -> str:
        """Serialized annotation, ``(col_1,col_2)=value``."""
        return '({},{})={}'.format(self.col_1, self.col_2, self.value)

    def __str__(self) -> str:
        return '{} {} {} {}'.format(self.table, self.col_1, self.col_2, self.value)
    
class GroundTruthCPA:
    """CPA ground truth loaded from a CSV whose rows are ``table,col_1,col_2,value``.

    After :meth:`load`, ``ground_truth[table]["col_1_col_2"]`` holds the matching
    :class:`GroundTruthItemCPA`.
    """

    def __init__(self, filename):
        self.filename: str = filename
        self.ground_truth: dict[str, dict[str, GroundTruthItemCPA]] = {}

    @property
    def total(self) -> int:
        """Number of distinct tables loaded (not the number of annotations)."""
        return len(self.ground_truth)

    def number_of_items_in_csv(self) -> int:
        """Count the CSV records in ``self.filename`` (blank lines included); used as the progress total."""
        with open(self.filename, 'r', newline='') as f:
            return sum(1 for _ in csv.reader(f))

    def load(self):
        """Read the CSV into ``self.ground_truth``, grouping annotations by table.

        Blank lines are skipped. A later row for the same (table, col_1, col_2)
        overwrites an earlier one.
        """
        # newline='' is what the csv module expects, so quoted fields containing
        # newlines are parsed correctly.
        with open(self.filename, 'r', newline='') as f:
            print('Loading ground truth...')
            total_lines: int = self.number_of_items_in_csv()
            print(f'Total lines: {total_lines}')
            for row in tqdm(csv.reader(f), total=total_lines):
                if not row:
                    # csv.reader yields [] for a blank line (e.g. a trailing empty
                    # line); indexing it would raise IndexError.
                    continue
                current_gt: GroundTruthItemCPA = GroundTruthItemCPA(row[0], row[1], row[2], row[3])
                self.ground_truth.setdefault(current_gt.table, {})[current_gt.get_identifier] = current_gt

class DatasetCPA():

    def __init__(self, gt_dataset: GroundTruthCPA, filename: str):
        self.filename: str = filename
        self.gt_dataset = gt_dataset
        self.llm_dataset: list[dict[str, str]] = []
        self.instruction = "perform the column property annotation (cpa) task on this table:"

    def get_output(self, table_name: str) -> str | None:
        outputCTA: str = ""
        if table_name not in self.gt_dataset.ground_truth:
            return None
        for _, gtTable in self.gt_dataset.ground_truth[table_name].items():
            if gtTable.value != "UNKNOWN":
                outputCTA += gtTable.get_output.replace('http://www.wikidata.org/entity/', '') + "|"
            else:
                return None

        return outputCTA

    def load_tables(self):
        all_csv_tables = os.listdir(self.filename)
        for csv_table in tqdm(all_csv_tables, total=len(all_csv_tables)):
            table_name: str = csv_table.split('.')[0]
            table_representation: str = ""
            with open(os.path.join(self.filename, csv_table), 'r') as f:
                reader = csv.reader(f)
                next(reader, None)
                for row in reader:
                    table_representation += ";".join(row) + "|"
                current_output = self.get_output(table_name)
                if current_output is None:
                    continue
                self.llm_dataset.append({
                    'table': table_name,
                    'instruction': self.instruction,
                    'input': table_representation,
                    "output": current_output
                })

## Pool Dataset — CEA records with candidate-entity pools

In [None]:
import os
import csv
import json
import random
from tqdm import tqdm

In [None]:
from pymongo import MongoClient

import os

# MongoDB connection settings, all read from the environment (KeyError if one is unset).
(
    MONGO_ENDPOINT,
    MONGO_PORT,
    MONGO_ENDPOINT_USERNAME,
    MONGO_ENDPOINT_PASSWORD,
    MONGO_DBNAME,
) = [
    os.environ[variable_name]
    for variable_name in (
        'MONGO_ENDPOINT',
        'MONGO_PORT',
        'MONGO_INITDB_ROOT_USERNAME',
        'MONGO_INITDB_ROOT_PASSWORD',
        'MONGO_DBNAME',
    )
]
# Credentials are checked against the `admin` database.
mongo_client = MongoClient(
    host=MONGO_ENDPOINT,
    port=int(MONGO_PORT),
    username=MONGO_ENDPOINT_USERNAME,
    password=MONGO_ENDPOINT_PASSWORD,
    authSource='admin',
)

def get_collection(collection_name):
    """Return the collection ``collection_name`` of the configured database ``MONGO_DBNAME``."""
    database = mongo_client[MONGO_DBNAME]
    return database[collection_name]

def get_entity(collection_name, entity):
    """Return the first document whose ``entity`` field equals ``entity``, or None if absent."""
    query = {'entity': entity}
    collection = mongo_client[MONGO_DBNAME][collection_name]
    return collection.find_one(query)

def get_cell_retrieval(collection_name, cell):
    """Return the first document whose ``cell`` field equals ``cell``, or None if absent."""
    query = {"cell": cell}
    collection = mongo_client[MONGO_DBNAME][collection_name]
    return collection.find_one(query)

In [None]:
# Literal Recognizer
import re
class LiteralRecognizer():
    """Classify a cell string as a literal type or as an entity mention.

    ``check_literal`` tries the patterns in ``literal_types`` order (DATETIME, TIME,
    URL, EMAIL, FLOAT, INTEGER, DATE) and returns the first matching type name,
    upper-cased, or ``'ENTITY'`` when none matches. All patterns ignore case.
    """

    # PATTERN TO MATCH DATES
    # dates like: '145 bc', '145.bc', '145,bc'
    # dates like: '1997-08-26', '1997.08.26', '1997/08/26'
    # dates like: '26/08/1997', '26.08.1997', '26-08-1997'
    # dates like: '26/08/97', '26.08.97', '26-08-97'
    # dates like: 'august 26 1997', 'august.26.1997', 'august,26,1997'
    # dates like: '26 august 1997', '26.august.1997', '26,august,1997'
    # dates like: '1997 august 26', '1997,august,26', '1997.august.26'
    # dates like: '1997 26 august', '1997,26,august', '1997.26.august'
    # dates like: 'august 1997', 'august.1997', 'august,1997'
    # dates like: '1997 august', '1997.august', '1997,august'
    # numbers like: '2,797,800,564', '2.797.800.564'
    # numbers like: '200,797,800', '200.797.800'
    # numbers like: '2,8', '2.8'
    # date's year: 1997
    # any pure number: 1345, 26, 1, 1990
    # (Month alternations previously spelled "dicember", so December dates never matched.)

    DATE_PATTERN = r'^\d{1,4}[\,\.\s\t\n]+bc$|' \
                r'^\d{4}[-.\/]\d{1,2}[-.\/]\d{1,2}$|' \
                r'^\d{1,2}[-.\/]\d{1,2}[-.\/]\d{4}$|' \
                r'^\d{1,2}[-.\/]\d{1,2}[-.\/]\d{2}$|' \
                r'^(january|february|march|april|may|june|july|august|september|october|november|december)[\.\,\s\t\n\/]+\d{1,2}[\.\,\s\t\n\/]+\d{4}$|' \
                r'^\d{1,2}[\.\,\s\t\n\/]+(january|february|march|april|may|june|july|august|september|october|november|december)[\.\,\s\t\n\/]+\d{4}$|' \
                r'^\d{4}[\.\,\s\t\n\/]+(january|february|march|april|may|june|july|august|september|october|november|december)[\.\,\s\t\n\/]+\d{1,2}$|' \
                r'^\d{4}[\.\,\s\t\n\/]+\d{1,2}[\.\,\s\t\n\/]+(january|february|march|april|may|june|july|august|september|october|november|december)$|' \
                r'^(january|february|march|april|may|june|july|august|september|october|november|december)[\.\,\s\n\t\/]+\d{4}$|' \
                r'^\d{4}[\.\,\s\n\t\/]+(january|february|march|april|may|june|july|august|september|october|november|december)$|' \
                r'^\d+[\.\,]\d+[\.\,]\d+[\.\,]\d+$|' \
                r'^\d+[\.\,]\d+[\.\,]\d+$|' \
                r'^\d+[\.\,]\d+$|' \
                r'^\d{4}$|' \
                r'^\d+$'

    # PATTERN TO MATCH URLS
    # NOTE(review): the final character class is not repeated, so exactly one character
    # may follow the TLD (e.g. URLs with a path after the domain do not match) — confirm intended.
    URL_PATTERN = r'^((http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.([a-zA-Z]){2,6}([a-zA-Z0-9\.\&\/\?\:@\-_=#])$'

    EMAIL_PATTERN = r'^\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b$'

    # PATTERN TO MATCH DATETIME
    DATETIME_PATTERN = r'^\d{4}[-.\/]\d{1,2}[-.\/]\d{1,2}T(24:00|2[0-3]:[0-5][0-9]|[0-1][0-9]:[0-5][0-9])$'

    # PATTERN TO MATCH TIME
    TIME_PATTERN = r'^(24:00|2[0-3]:[0-5][0-9]|[0-1][0-9]:[0-5][0-9])$'

    # PATTERN TO MATCH FLOAT NUMBERS
    FLOAT_NUMBER = r'^[-+]?\d+\.\d+$|' \
                r'^[-+]?\d+\.\d+[eE][-+]\d+$'

    # PATTERN TO MATCH INTEGER NUMBERS
    INTEGER_NUMBER = r'^[-+]?\d+$'

    datetime_pattern_to_match = re.compile(DATETIME_PATTERN, re.IGNORECASE)
    time_pattern_to_match = re.compile(TIME_PATTERN, re.IGNORECASE)
    date_pattern_to_match = re.compile(DATE_PATTERN, re.IGNORECASE)
    url_pattern_to_match = re.compile(URL_PATTERN, re.IGNORECASE)
    email_pattern_to_match = re.compile(EMAIL_PATTERN, re.IGNORECASE)
    integer_pattern_to_match = re.compile(INTEGER_NUMBER, re.IGNORECASE)
    float_pattern_to_match = re.compile(FLOAT_NUMBER, re.IGNORECASE)

    # Checked in this order; the first matching type wins.
    literal_types = {'datetime': datetime_pattern_to_match, 'time': time_pattern_to_match, 'url': url_pattern_to_match, 'email': email_pattern_to_match, 'float': float_pattern_to_match, 'integer': integer_pattern_to_match, 'date': date_pattern_to_match}

    @classmethod
    def check_literal(cls, token):
        """Return the upper-cased name of the first literal type matching ``token``, else ``'ENTITY'``."""
        for key, pattern in cls.literal_types.items():
            if pattern.search(token) is not None:
                return key.upper()
        return 'ENTITY'

In [None]:
class GroundTruthItemCEA:
    """One CEA ground-truth annotation: the entity ``value`` of cell (``row``, ``column``) in ``table``."""

    def __init__(self, table, row, column, value):
        self.table: str = table
        self.row: str = row
        self.column: str = column
        self.value: str = value

    def get_item(self) -> dict[str, str]:
        """Return the cell coordinates (table, row, column) without the annotated value."""
        return dict(table=self.table, row=self.row, column=self.column)

    @property
    def get_identifier(self) -> str:
        """Key of this item inside its table: ``"row_column"``."""
        return '{}_{}'.format(self.row, self.column)

    @property
    def get_output(self) -> str:
        """Serialized annotation, ``(row,column)=value``."""
        return '({},{})={}'.format(self.row, self.column, self.value)

    def __str__(self) -> str:
        return '{} {} {} {}'.format(self.table, self.row, self.column, self.value)
    
class GroundTruthCEA:
    """CEA ground truth loaded from a CSV whose rows are ``table,row,column,value``.

    After :meth:`load`, ``ground_truth[table]["row_column"]`` holds the matching
    :class:`GroundTruthItemCEA`.
    """

    def __init__(self, filename):
        self.filename: str = filename
        self.ground_truth: dict[str, dict[str, GroundTruthItemCEA]] = {}

    @property
    def total(self) -> int:
        """Number of distinct tables loaded (not the number of annotations)."""
        return len(self.ground_truth)

    def number_of_items_in_csv(self) -> int:
        """Count the CSV records in ``self.filename`` (blank lines included); used as the progress total."""
        with open(self.filename, 'r', newline='') as f:
            return sum(1 for _ in csv.reader(f))

    def load(self):
        """Read the CSV into ``self.ground_truth``, grouping annotations by table.

        Blank lines are skipped. A later row for the same (table, row, column)
        overwrites an earlier one.
        """
        # newline='' is what the csv module expects, so quoted fields containing
        # newlines are parsed correctly.
        with open(self.filename, 'r', newline='') as f:
            print('Loading ground truth...')
            total_lines: int = self.number_of_items_in_csv()
            print(f'Total lines: {total_lines}')
            for row in tqdm(csv.reader(f), total=total_lines):
                if not row:
                    # csv.reader yields [] for a blank line (e.g. a trailing empty
                    # line); indexing it would raise IndexError.
                    continue
                current_gt: GroundTruthItemCEA = GroundTruthItemCEA(row[0], row[1], row[2], row[3])
                self.ground_truth.setdefault(current_gt.table, {})[current_gt.get_identifier] = current_gt

In [131]:
# Module-level cache shared across cells; it is not read in the visible part of this
# cell — presumably used by the Dataset class defined below (TODO confirm).
pool_cache: set = set()
class Dataset():

    def __init__(self, gt_dataset: GroundTruthCEA, dataset_name: str, filename: str):
        self.filename: str = filename
        self.dataset_name: str = dataset_name
        self.gt_dataset = gt_dataset
        self.llm_dataset: list[dict[str, str]] = []
        self.instruction = "perform the cell entity annotation (cea) task on this table:"
        self.literal_recognizer = LiteralRecognizer()
    
    def get_output(self, table_name: str, pool: set) -> str | None:
        outputCEA: str = ""
        if table_name not in self.gt_dataset.ground_truth:
            return None
        for _, gtTable in self.gt_dataset.ground_truth[table_name].items():
            if gtTable.value != "UNKNOWN":
                current_entity_label = get_entity("entities", gtTable.value.replace('http://www.wikidata.org/entity/', ''))
                if current_entity_label is not None and 'label' in current_entity_label:
                    outputCEA += gtTable.get_output.replace('http://www.wikidata.org/entity/', '') + f" [{current_entity_label['label'].lower()}]" + "|"
                    # add correct entity to pool
                    pool.add(f"{gtTable.value.replace('http://www.wikidata.org/entity/', '')} [{current_entity_label['label'].lower()}]")
                else:
                    outputCEA += gtTable.get_output.replace('http://www.wikidata.org/entity/', '') + "|"
            else:
                return None
        
        return outputCEA, pool
    
    def get_pool(self, filename: str, csv_table: str) -> list[dict[str, str]]:
        pool = set()
        with open(os.path.join(filename, csv_table), 'r') as f:
                reader = csv.reader(f)
                next(reader, None)
                for row in reader:
                    for cell in row:
                        if self.literal_recognizer.check_literal(cell) == "ENTITY":
                            cell_retrieval = get_cell_retrieval("candidate", cell.lower())
                            if cell_retrieval is not None:
                                for entity in cell_retrieval['candidates']['entities'][0:3]:
                                    pool.add(f"{entity['id']} [{entity['label'].lower()}]")
                                    pool_cache.add(f"{entity['id']} [{entity['label'].lower()}]")
                            else:
                                # pescare qualche candidato a caso
                                random_entities = random.sample(list(pool_cache), 3)
                                pool.update(random_entities)
        return pool
    
    def format_pool(self, pool: set) -> str:
        string_pool = ""
        for entity in pool:
            string_pool += entity + ";"
        return string_pool
    
    def load_tables(self):
        all_csv_tables = os.listdir(self.filename)
        print("LOAD TABLES...")
        for csv_table in tqdm(all_csv_tables, total=len(all_csv_tables)):
            table_name: str = csv_table.split('.')[0]
            table_representation: str = ""
            with open(os.path.join(self.filename, csv_table), 'r') as f:
                reader = csv.reader(f)
                next(reader, None)
                for row in reader:
                    table_representation += ";".join(row) + "|"
                table_representation = table_representation[:-1]
                pool = self.get_pool(self.filename, csv_table)
                current_output, pool_updated = self.get_output(table_name, pool)
                if current_output is None:
                    continue
                self.llm_dataset.append({
                    'dataset': self.dataset_name,
                    'table': table_name,
                    'instruction': self.instruction,
                    'input': table_representation,
                    "output": current_output[:-1],
                    "pool_instruction": "use the following pool of entities to annotate the table:",
                    "pool": self.format_pool(pool_updated)
                })

In [134]:
# One (ground-truth CSV, tables directory, dataset name) triple per benchmark to convert.
# SemTab 2020 Round 3 and HardTables 2022 Round 2 are currently disabled (commented out).
cea_datasets: list[dict[str, str]] = [
    {"gt": gt_csv, "tables": tables_dir, "dataset": dataset_name}
    for gt_csv, tables_dir, dataset_name in (
        ("./wikidata/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round1_gt.csv", "./wikidata/SemTab2020_Table_GT_Target/Round1/tables", "semtab_2020_r1"),
        ("./wikidata/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round2_gt.csv", "./wikidata/SemTab2020_Table_GT_Target/Round2/tables", "semtab_2020_r2"),
        # ("./wikidata/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round3_gt.csv", "./wikidata/SemTab2020_Table_GT_Target/Round3/tables", "semtab_2020_r3"),
        ("./wikidata/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round4_gt.csv", "./wikidata/SemTab2020_Table_GT_Target/Round4/tables", "semtab_2020_r4"),
        ("./wikidata/HardTablesR1/DataSets/HardTablesR1/Valid/gt/cea_gt.csv", "./wikidata/HardTablesR1/DataSets/HardTablesR1/Valid/tables", "hardtables_2022_r1"),
        # ("./wikidata/HardTablesR2/DataSets/HardTablesR2/Valid/gt/cea_gt.csv", "./wikidata/HardTablesR2/DataSets/HardTablesR2/Valid/tables", "hardtables_2022_r2"),
        ("./wikidata/WikidataTables2023R1/DataSets/Valid/gt/cea_gt.csv", "./wikidata/WikidataTables2023R1/DataSets/Valid/tables", "wikidata_tables_2023"),
    )
]

In [135]:
# Convert every configured benchmark and collect all records in `all_datasets`.
# The config dict gets its own name instead of being rebound to the Dataset object;
# after the loop `dataset` still refers to the last Dataset built, which the
# inspection cell further down relies on.
all_datasets: list[dict[str, str]] = []
for dataset_config in tqdm(cea_datasets):
    print(f"{dataset_config['dataset']}...")
    gt_dataset = GroundTruthCEA(dataset_config['gt'])
    gt_dataset.load()

    dataset = Dataset(gt_dataset, dataset_config['dataset'], dataset_config['tables'])
    dataset.load_tables()
    all_datasets.extend(dataset.llm_dataset)

  0%|          | 0/2 [00:00<?, ?it/s]

hardtables_2022_r1...
Loading ground truth...
Total lines: 1406


100%|██████████| 1406/1406 [00:00<00:00, 749706.51it/s]


LOAD TABLES...


100%|██████████| 200/200 [00:01<00:00, 189.59it/s]
 50%|█████     | 1/2 [00:01<00:01,  1.06s/it]

wikidata_tables_2023...
Loading ground truth...
Total lines: 4247


100%|██████████| 4247/4247 [00:00<00:00, 716713.97it/s]


LOAD TABLES...


100%|██████████| 500/500 [00:03<00:00, 129.02it/s]
100%|██████████| 2/2 [00:04<00:00,  2.47s/it]


In [138]:
# Peek at the first record built for the last dataset processed above (`dataset`
# is left over from the conversion loop); rendering the whole list bloats the notebook.
dataset.llm_dataset[:1]

[{'dataset': 'wikidata_tables_2023',
  'table': 'NA5XOTEC',
  'instruction': 'perform the cell entity annotation (cea) task on this table:',
  'input': 'The Stag Hunt;|Saint Lawrence;|Cain Killing Abel;German Renaissance|Marriage of the Virgin;German Renaissance|Christ on the Mount of Olives;German Renaissance',
  'output': '(1,0)=Q65511076 [the stag hunt]|(2,0)=Q104154781 [saint lawrence]|(3,0)=Q18339669 [cain killing abel]|(3,1)=Q2455000 [german renaissance]|(4,0)=Q18338565 [marriage of the virgin]|(4,1)=Q2455000 [german renaissance]|(5,0)=Q18339648 [christ on the mount of olives]|(5,1)=Q2455000 [german renaissance]',
  'pool': 'Q20648159 [no no no];Q26535864 [k6 telephone kiosk outside crown court];Q18339669 [cain killing abel];Q3816474 [korfball at the summer olympics];Q7935239 [virtuoso universal server];Q18338565 [marriage of the virgin];Q613436 [ceti];Q3224308 [marriage of the virgin];Q23785360 [münchen ost rangierbahnhof];Q1243556 [renaissance av renaissance];Q49478020 [chief o

In [139]:
def to_jsonl(llm_dataset_list: list, datasetName: str):
    """Write each record of ``llm_dataset_list`` as one JSON object per line to ``datasetName``.

    Missing parent directories are created, so the export also works on a fresh
    checkout where e.g. ``./output/pool`` does not exist yet. An existing file is
    overwritten.
    """
    parent_dir = os.path.dirname(datasetName)
    if parent_dir:  # dirname is '' for a bare file name; makedirs('') would raise
        os.makedirs(parent_dir, exist_ok=True)
    with open(datasetName, 'w', encoding='utf-8') as outfile:
        for entry in llm_dataset_list:
            json.dump(entry, outfile)
            outfile.write('\n')

In [140]:
# Persist the records of every converted dataset as a single JSON Lines file.
to_jsonl(llm_dataset_list=all_datasets, datasetName='./output/pool/cea_pool.jsonl')

In [22]:
# Number of distinct values in column index 2 of the Wikidata Tables 2023 CTA ground truth.
# NOTE(review): in a CTA ground truth this column presumably holds the annotated type
# rather than a cell entity; confirm before relying on the `entity_set` name.
path = "./wikidata/WikidataTables2023R1/DataSets/Valid/gt/cta_gt.csv"
with open(path, 'r', newline='') as f:  # newline='' as required by the csv module
    entity_set = {row[2] for row in csv.reader(f)}
len(entity_set)

194