In [2]:
import re

In [14]:
# match cea prompt
pattern = r"\([0-9]+,[0-9]+\)=Q\d+"
pattern_2 = r"\([0-9]+\.[0-9]+\)=Q\d+"
prompt_example = "[INST] conduct the cell entity annotation (cea) task for wikidata on this table: Dyxig;63.68|Rskilde Fjord;124.496|ingkbing Fjord;280.157|Hgsfjor;111.1[/INST]i of the cell entity annotation task for Wikidata (Wikimedia Incubator). In this task, annotate the given table as appropriate for the QuarryWiki project. [/INST] (1,0)=Q105011 [Roskilde Fjord];(2,0)=Q104999 [Køge Bugt];(3,0)=Q104998 [Køge Bugt];(4,0)=Q104997 [Køge Bugt] [/INST] (1,0)=Q105011 [Roskilde Fjord];(2,0)=Q104999 [Køge Bugt];(3,0)=Q104998 [Køge Bugt];(4,0)=Q104997 [Kø];(13,0)=Q104998 Test]]"
prompt_example_2 = "[INST] conduct the cell entity annotation (cea) task for wikidata on this table: Winchester;124295;Badger Farm;Winchester City Council|West Lancashire;113949;Wrightington;West Lancashire Borough Council|Scarborough;108736;Westerdale;Scarborough Borough Council|Chorley;116821;Astley Village;Chorley Borough Council[/INST]INSTANT [/INST/conduce the column entity annotation (CEA) task for Wikidata, Link Islington;196000;Finsbury Park;Islington London Borough Council[/INST] (0,1)=P1082 [population];(0,2)=Q18843059 [electoral ward in the United Kingdom];(0,3)=Q2085735 [local authority in England]|(0,4)=Q1137624 [local council in England]\n[/INST] (0,1)=P1082 [population];(0,2)=Q18843059 [electoral ward in the United Kingdom];(0,3)=Q1137624 [local council in England];(0,4)=Q11376"
print(re.findall(pattern, prompt_example))
print(re.findall(pattern, prompt_example_2))

['(1,0)=Q105011', '(2,0)=Q104999', '(3,0)=Q104998', '(4,0)=Q104997', '(1,0)=Q105011', '(2,0)=Q104999', '(3,0)=Q104998', '(4,0)=Q104997', '(13,0)=Q104998']
['(0,2)=Q18843059', '(0,3)=Q2085735', '(0,4)=Q1137624', '(0,2)=Q18843059', '(0,3)=Q1137624', '(0,4)=Q11376']


## Prompt Validation

In [19]:
import csv

tables = set()
with open("./output2.csv", "r") as f:
    reader = csv.reader(f)
    for row in reader:
        if len(row) == 4:
            tables.add(row[0])

In [20]:
import csv
from tqdm import tqdm

cea_datasets: list[dict[str, str]] = [
    "./../datasets/wikidata/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round1_gt.csv",
    "./../datasets/wikidata/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round2_gt.csv",
    "./../datasets/wikidata/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round3_gt.csv",
    "./../datasets/wikidata/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round4_gt.csv",
    "./../datasets/wikidata/HardTablesR1/DataSets/HardTablesR1/Valid/gt/cea_gt.csv",
    "./../datasets/wikidata/HardTablesR2/DataSets/HardTablesR2/Valid/gt/cea_gt.csv",
    "./../datasets/wikidata/WikidataTables2023R1/DataSets/Valid/gt/cea_gt.csv"
]
total = 0
gt_dict = {}
for gt in tqdm(cea_datasets):
    with open(gt, "r") as f:
        reader = csv.reader(f)
        for line in reader:
            table = line[0]
            if table not in tables:
                continue
            total += 1
            row = line[1]
            col = line[2]
            entity = line[3].replace("http://www.wikidata.org/entity/", "")
            if table not in gt_dict:
                gt_dict[table] = {}
            else:
                gt_dict[table][f"({row}_{col})"] = entity

100%|██████████| 7/7 [00:01<00:00,  4.16it/s]


In [21]:
import csv

correct = 0
wrong = 0
found = 0
with open("./output2.csv", "r") as f:
    reader = csv.reader(f)
    for current_row in reader:
        if len(current_row) == 4:
            table = current_row[0]
            row = current_row[1]
            col = current_row[2]
            entity = current_row[3].replace("http://www.wikidata.org/entity/", "")
            if table in gt_dict and f"({row}_{col})" in gt_dict[table]:
                found += 1
                if gt_dict[table][f"({row}_{col})"] == entity:
                    correct += 1
                else:
                    wrong += 1

print(f"Correct: {correct}, Wrong: {wrong}, Found: {found}, Total: {total}")

print(f"PRECISION: {correct / found}")
print(f"RECALL: {correct / total}")
print(f"F1 SCORE: {2 * (correct / found) * (correct / total) / ((correct / found) + (correct / total))}")

Correct: 468, Wrong: 7812, Found: 8280, Total: 28033
PRECISION: 0.05652173913043478
RECALL: 0.01669460992401812
F1 SCORE: 0.025775892930906288


In [1]:
from pymongo import MongoClient

import os

MONGO_ENDPOINT = os.environ['MONGO_ENDPOINT']
MONGO_PORT = os.environ['MONGO_PORT']
MONGO_ENDPOINT_USERNAME = os.environ['MONGO_INITDB_ROOT_USERNAME']
MONGO_ENDPOINT_PASSWORD = os.environ['MONGO_INITDB_ROOT_PASSWORD']
MONGO_DBNAME = os.environ['MONGO_DBNAME']
mongo_client = MongoClient(
                            MONGO_ENDPOINT, 
                            int(MONGO_PORT), 
                            username=MONGO_ENDPOINT_USERNAME, 
                            password=MONGO_ENDPOINT_PASSWORD, 
                            authSource='admin'
                        )

annotation_c = mongo_client[MONGO_DBNAME].wikidata_qids_labels

In [2]:
import re
from tqdm import tqdm

def get_annotations(prompt_response):
    all_annotations = re.findall(r"\([0-9]+,[0-9]+\)=Q\d+", prompt_response)
    return all_annotations

def parse_annotation(annotations, table):
    annotation_list = []
    for current_annotation in annotations:
        annotation = current_annotation.replace('(', '').replace(')', '')
        annotation = annotation.split('=')
        row, column = annotation[0].split(',')
        annotation_list.append([table, row, column, annotation[1]])
    return annotation_list

all_annotations = annotation_c.find()
all_tables = set()
final = []
for ann in tqdm(all_annotations, total=5489):
    all_tables.add(ann['table'])
    annotation_list = []
    for current_annotation in get_annotations(ann['response']):
        annotation = current_annotation.replace('(', '').replace(')', '')
        annotation = annotation.split('=')
        row, column = annotation[0].split(',')
        annotation_list.append([ann["table"], row, column, annotation[1]])
    for annot in annotation_list:
        final.append(annot)

100%|██████████| 5489/5489 [00:00<00:00, 66537.39it/s]


In [3]:
import csv

with open('annotations.csv', 'w', encoding='UTF8', newline='') as f:
    writer = csv.writer(f)

    # write multiple rows
    writer.writerows(final)

In [4]:
import csv

gt_test = {}
all_tables_in_gt = set()
all_cells = set()
with open("./../datasets/wikidata/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round3_gt.csv", 'r') as f:
    reader = csv.reader(f)
    for row in reader:
        all_tables_in_gt.add(row[0])
        if row[0] in all_tables:
            all_cells.add(f"{row[0]}_{row[1]}_{row[2]}")
            if row[0] not in gt_test:
                gt_test[row[0]] = {f"{row[1]}_{row[2]}": row[3].replace("http://www.wikidata.org/entity/", "")}
            else:
                gt_test[row[0]][f"{row[1]}_{row[2]}"] = row[3].replace("http://www.wikidata.org/entity/", "")

In [5]:
len(all_tables_in_gt), len(all_tables), len(all_cells)

(62614, 5489, 18616)

In [6]:
def compute_precision(gt_test, pred_path):
    with open(pred_path, 'r') as f:
        reader = csv.reader(f)
        correct = set()
        wrong = set()
        all_annotations = set()
        for row in reader:
            if f"{row[0]}_{row[1]}_{row[2]}" in all_cells:
                if row[0] in gt_test:
                    row_col = f"{row[1]}_{row[2]}"
                    if row_col in gt_test[row[0]]:
                        all_annotations.add(f"{row[0]}_{row[1]}_{row[2]}")
                        if gt_test[row[0]][row_col] == row[3]:
                            correct.add(f"{row[0]}_{row[1]}_{row[2]}")
                        else:
                            wrong.add(f"{row[0]}_{row[1]}_{row[2]}")

    return all_annotations, correct, wrong

all_annotations, correct, wrong = compute_precision(gt_test, './annotations.csv')

print(f"ANNOTATIONS: {len(all_annotations)}", f"CORRECT: {len(correct)}", f"WRONG: {len(wrong)}", f"ALL CELLS: {len(all_cells)}")
print(f"PRECISION: {len(correct)/len(all_annotations)}")
print(f"RECALL: {len(correct)/len(all_cells)}")
print(f"F1: {2*(len(correct)/len(all_annotations))*(len(correct)/len(all_cells))/((len(correct)/len(all_annotations))+(len(correct)/len(all_cells)))}")


ANNOTATIONS: 18164 CORRECT: 439 WRONG: 17727 ALL CELLS: 18616
PRECISION: 0.024168685311605372
RECALL: 0.02358186506231199
F1: 0.02387166938553562
