Skip to content

Commit

Permalink
Add text evidences for taxonomy entries
Browse files Browse the repository at this point in the history
* plus some fixes
  • Loading branch information
mkardas committed Feb 14, 2020
1 parent 298a938 commit d005f9c
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 22 deletions.
17 changes: 17 additions & 0 deletions sota_extractor2/data/elastic.py
Expand Up @@ -352,3 +352,20 @@ def display_fragment(f, cell_type="", display=True):
if display:
display_html(html)
return html


def get_evidences_for_taxonomy(paper_id, task, dataset, metric, value):
    """Collect highlighted text snippets from a paper that mention a taxonomy entry.

    Searches the paper's indexed fragments for the task/dataset/metric/value
    terms, takes the top 5 hits, and returns their ES highlight snippets
    (``<b>``-tagged, ~50 chars each) joined into a single string — snippets of
    one hit are space-separated, hits are newline-separated.
    """
    query_text = ' '.join([task, dataset, metric, value])
    search = (
        Fragment.search()
        .highlight('text', pre_tags="<b>", post_tags="</b>", fragment_size=50)
        .filter('term', paper_id=paper_id)
        .query('match', text={"query": query_text})
    )
    snippets = []
    for hit in search[:5]:
        snippets.append(' '.join(hit.meta['highlight']['text']))
    return '\n'.join(snippets)
75 changes: 60 additions & 15 deletions sota_extractor2/models/linking/bm25_naive.py
@@ -1,5 +1,5 @@
import re
from decimal import Decimal
from decimal import Decimal, localcontext, InvalidOperation
from dataclasses import dataclass
import numpy as np
import pandas as pd
Expand All @@ -9,6 +9,7 @@
import spacy
from scispacy.abbreviation import AbbreviationDetector
from sota_extractor2.models.linking.format import extract_value
from functools import total_ordering


@dataclass()
Expand Down Expand Up @@ -58,6 +59,44 @@ def model_type(self):
def __str__(self):
return f"{self.model_name}: {self.raw_value} on {self.dataset}"


@total_ordering
class MetricValue(Decimal):
    """A metric value that remembers its (possibly missing) unit.

    The Decimal base stores the *absolute* magnitude: a value tagged with
    '%' is divided by 100 at construction time, so MetricValues with
    different units compare consistently.

    Attributes:
        value: the raw value as it appeared in the source (not normalized).
        unit: '%' for percentages, or None when no unit was specified.
    """
    value: Decimal
    unit: str = None

    def __new__(cls, value, unit):
        # Decimal is immutable, so the normalized magnitude must be set here.
        # Fix: the original tested the unit with `unit is '%'` — an identity
        # comparison against a str literal that only works thanks to CPython
        # string interning (and raises a SyntaxWarning on modern
        # interpreters); use equality instead.
        return super().__new__(cls, value / Decimal(100) if unit == '%' else value)

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit

    def to_absolute(self):
        """Return the normalized magnitude as a plain Decimal (fraction, not %)."""
        return Decimal(self)

    # unit = None means that no unit was specified, so we have to guess the unit.
    # if there's a value "21" in a table's cell, then we guess if it's 21 or 0.21 (i.e., 21%)
    # based on the target metric properties.
    def to_percentage(self):
        if self.unit is None and 0 < self.value < 1:
            return self.value * 100
        return self.value

    def complement(self):
        """Return the complement (1 - x, or 100 - x for percent-looking values)."""
        # NOTE(review): a value with unit '%' always takes the `1 - value`
        # branch even though `self.value` is the raw percent number there —
        # confirm that is intended for the '%' case.
        if self.unit is None and 1 < self.value < 100:
            value = 100 - self.value
        else:
            value = 1 - self.value
        return MetricValue(value, self.unit)

    def __repr__(self):
        return f"MetricValue({self.value}, {repr(self.unit)})"

    def __str__(self):
        return str(self.value)


def mkquery_ngrams(query):
return {
"query": {
Expand Down Expand Up @@ -164,7 +203,9 @@ def handle_pm(value):
for match in float_pm_re.findall(value):
if not match[0]:
try:
yield Decimal(whitespace_re.sub("", match[1])) / (100 if match[-1] else 1)
percent = bool(match[-1])
value = Decimal(whitespace_re.sub("", match[1])) / (100 if percent else 1)
yield MetricValue(value, "%" if percent else None)
except:
pass
# %%
Expand Down Expand Up @@ -217,26 +258,30 @@ def annotations(r, c, type='model'):
def linked_proposals(proposals):
for prop in proposals:
# heuristic to handle accuracy vs error
first_num = (list(handle_pm(prop.raw_value)) + [0])[0]
format = "{x}"
# if first_num > 1:
# first_num /= 100
# format = "{x/100}"
if 0 < first_num < 1 and '%' not in prop.raw_value:
first_num *= 100
format = "{100*x}"
if '%' in prop.raw_value:

percentage = '%' in prop.raw_value
if percentage:
format += '%'

df = taxonomy_linking(prop.dataset, datasets, desc, topk=topk, debug_info=prop)
for _, row in df.iterrows():
raw_value = prop.raw_value
parsed = extract_value(raw_value, format)
metric = row['metric']
if metric != row['true_metric']:
metric = row['true_metric']
parsed = 1 - parsed if 0 < parsed < 1 else 100 - parsed
parsed = float(parsed)

with localcontext() as ctx:
ctx.traps[InvalidOperation] = 0
parsed = extract_value(raw_value, format)
parsed = MetricValue(parsed, '%' if percentage else None)

if metric != row['true_metric']:
metric = row['true_metric']
parsed = parsed.complement()

if set(metric.lower().split()) & {"error", "accuracy", "bleu", "f1", "precision", "recall"}:
parsed = float(parsed.to_percentage() / 100)
else:
parsed = float(parsed.to_absolute())

linked = {
'dataset': row['dataset'],
Expand Down
4 changes: 2 additions & 2 deletions sota_extractor2/models/linking/format.py
@@ -1,8 +1,8 @@
import re
from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP, InvalidOperation

float_value_re = re.compile(r"([+-]?(?:(?:\d{1,2}(?:,\d{3})+|\d+)(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)")
float_value_nc = re.compile(r"(?:[+-]?(?:(?:\d{1,2}(?:,\d{3})+|\d+)(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)")
float_value_re = re.compile(r"([+-]?(?:(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)")
float_value_nc = re.compile(r"(?:[+-]?(?:(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)")
par_re = re.compile(r"\{([^\}]*)\}")
escaped_whitespace_re = re.compile(r"(\\\s)+")

Expand Down
16 changes: 13 additions & 3 deletions sota_extractor2/models/linking/taxonomy.py
Expand Up @@ -7,7 +7,7 @@
class Taxonomy:
def __init__(self, taxonomy, metrics_info):
self.taxonomy = self._get_taxonomy(taxonomy)
self.metrics_info = self._read_metrics_info(metrics_info)
self.metrics_info, self.metrics_range = self._read_metrics_info(metrics_info)
self.tasks = self._get_axis('task')
self.datasets = self._get_axis('dataset')
self.metrics = self._get_axis('metric')
Expand Down Expand Up @@ -52,9 +52,19 @@ def _get_axis(self, axis):
def _read_metrics_info(self, path):
    """Load per-(task, dataset, metric) direction and range info from JSON.

    Returns a pair of dicts:
      * metrics_info: (task, dataset, metric) -> +1/-1 (higher/lower is
        better), plus per-metric-name vote totals aggregated across entries;
      * metrics_range: (task, dataset, metric) -> declared range, plus a
        per-metric-name entry holding its most frequent range.
    """
    records = self._read_json(path)
    metrics_info, metrics_range = {}, {}
    range_counts = {}  # metric name -> {range: number of occurrences}
    for record in records:
        metric = record['metric']
        key = (record['task'], record['dataset'], metric)
        direction = 1 if record['higher_is_better'] else -1
        metrics_info[key] = direction
        metrics_info[metric] = metrics_info.get(metric, 0) + direction
        rng = record['range']
        metrics_range[key] = rng
        counts = range_counts.setdefault(metric, {})
        counts[rng] = counts.get(rng, 0) + 1
    # NOTE(review): the per-metric entry stores the most frequent
    # (range, count) *pair*, while per-key entries store the bare range —
    # confirm downstream consumers expect the tuple here.
    for metric, counts in range_counts.items():
        metrics_range[metric] = sorted(counts.items(), key=lambda item: item[1])[-1]
    return metrics_info, metrics_range
4 changes: 2 additions & 2 deletions sota_extractor2/models/linking/utils.py
Expand Up @@ -39,8 +39,8 @@ def normalize_dataset_ws(name):

def normalize_dataset(name):
    # Normalize a dataset name for fuzzy matching: strip citation markers,
    # drop hyphens, normalize 2k-style years, collapse whitespace, then
    # ASCII-fold and lowercase the result.
    name = remove_references(name)
    name = hyphens_re.sub("", name)
    name = year_2k_re.sub(r"\1", name)
    # NOTE(review): hyphen removal appears both before and after the year
    # substitution; the second pass is presumably a no-op — confirm which
    # ordering is the intended one.
    name = hyphens_re.sub("", name)
    name = ws_re.sub(" ", name)
    return unidecode(name.strip().lower())

Expand All @@ -51,4 +51,4 @@ def normalize_cell(s):
def normalize_cell_ws(s):
    """Keep only alphanumeric and whitespace characters, ASCII-folded via unidecode."""
    kept = (ch for ch in s if ch.isalnum() or ch.isspace())
    return unidecode("".join(kept))

# end of cleaning & normalization
# end of cleaning & normalization
3 changes: 3 additions & 0 deletions sota_extractor2/pipeline_logger.py
Expand Up @@ -5,6 +5,9 @@ class PipelineLogger:
def __init__(self):
self.observers = []

def reset(self):
self.observers = []

def register(self, pattern, observer):
if isinstance(pattern, str):
pattern = re.compile(pattern)
Expand Down

0 comments on commit d005f9c

Please sign in to comment.