Evaluate the precision, recall and F1 score 

In [90]:
import os
import sys
import argparse
import json
import pandas as pd

current_path = os.getcwd()
parser = argparse.ArgumentParser()

parser.add_argument(
    '--threshold',
    type=float,
    default=0.16,
    help='Threshold to determine a column class')
parser.add_argument(
    '--predictions',
    type=str,

    # default=os.path.join(current_path, 'output/predictions/p_lookup.json'),
    # default=os.path.join(current_path, 'output/predictions/p_cnn_1_2_1.00.json'),

    default=os.path.join(
        current_path, 'output/predictions/p_cnn_1_2_1.00_lookup.json'),
    help='Prediction file')
parser.add_argument(
    '--ground_truths',
    type=str,
    default=os.path.join(current_path, '../SemTab_DataSets/Round1DataSets/Valid/gt/cea_gt.csv'),
    help='Ground truths')
FLAGS, unparsed = parser.parse_known_args()


In [91]:
def load_json(file):
    with open(file) as json_file:
        return json.load(json_file)


In [92]:
wd_prefix = 'http://www.wikidata.org/entity/'

# read columns and the ground truth
col_gt_classes = dict()
gt = pd.read_csv(os.path.join(current_path, FLAGS.ground_truths), delimiter=',', names=['tab_id', 'row_id', 'col_id', 'entity'],
                 dtype={'tab_id': str, 'row_id': str, 'col_id': str, 'entity': str}, keep_default_na=False, nrows=100)
for index, row in gt.iterrows():
    cells = col_gt_classes.keys()
    if row['tab_id'] in cells:
        gt_value = col_gt_classes[cell]
        gt_value.append(row['entity'].split(wd_prefix)[1])
    else:
        cell = '%s' % (row['tab_id'])
        gt_value = [row['entity'].split(wd_prefix)[1]]
    col_gt_classes[cell] = gt_value

# read the column, predicted column and scores
col_pclasses = dict()
p_classes = load_json(FLAGS.predictions)
for key, value in p_classes.items():
    col_cls = key.split(',')
    col, cls = col_cls
    score = float(value)
    if score >= FLAGS.threshold:
        if col in col_pclasses:
            col_pclasses[col].append(cls)
        else:
            col_pclasses[col] = [cls]

# calculating metrics
correct_cells, annotated_cells = set(), set()
for cell in col_gt_classes:
    try:
        if cell not in annotated_cells:
            annotated_cells.add(cell)
        ann_cells = set(col_pclasses[cell] )
        gt_cells = set(col_gt_classes[cell])
        if len(ann_cells.intersection(gt_cells)) > 0 :
            correct_cells.add(cell)
    except:
        pass


precision = len(correct_cells) / \
    len(annotated_cells) if len(annotated_cells) > 0 else 0.0
recall = len(correct_cells) / len(col_gt_classes.keys())
f1 = (2 * precision * recall) / (precision +
                                 recall) if (precision + recall) > 0 else 0.0
main_score = f1
secondary_score = precision
print('F1: %.3f, Precision: %.3f, Recall: %.3f' %
      (f1, precision, recall))


F1: 0.458, Precision: 0.458, Recall: 0.458
