
Predict tasks, datasets and metrics independently
* compute probabilities for each axis (tasks, datasets and metrics) independently
* fix metric score extraction and conversion
mkardas committed Feb 17, 2020
1 parent d005f9c commit 081bd5c
Showing 5 changed files with 118 additions and 39 deletions.
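
In plain terms, the first change means the linker no longer only ranks full (task, dataset, metric) taxonomy triples; it also accumulates a separate score per axis and normalizes each axis on its own. A toy, self-contained sketch of that idea (not the repository's numba-accelerated code; all names and numbers below are made up for illustration):

```
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

# made-up taxonomy and log-probabilities accumulated from context evidence
taxonomy = [("QA", "SQuAD", "F1"), ("QA", "SQuAD", "EM"), ("NER", "CoNLL 2003", "F1")]
tasks, datasets, metrics = ["QA", "NER"], ["SQuAD", "CoNLL 2003"], ["F1", "EM"]

triple_logprobs = np.array([-1.2, -2.3, -4.0])
task_logprobs = np.array([-0.5, -3.0])
dataset_logprobs = np.array([-0.4, -2.8])
metric_logprobs = np.array([-0.9, -1.1])

# joint distribution over full triples, as before ...
print(dict(zip(taxonomy, softmax(triple_logprobs))))
# ... plus an independent distribution for each axis (the new part)
print(dict(zip(tasks, softmax(task_logprobs))))
print(dict(zip(datasets, softmax(dataset_logprobs))))
print(dict(zip(metrics, softmax(metric_logprobs))))
```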
3 changes: 3 additions & 0 deletions .gitignore
@@ -98,3 +98,6 @@ venv.bak/
.mypy_cache/
.idea/*
.vscode/settings.json

# pytest
.pytest_cache
6 changes: 6 additions & 0 deletions README.md
@@ -34,3 +34,9 @@ To test the whole extraction on a single file run
```
make test
```

### Unit Tests

```
PYTHONPATH=. py.test
```
74 changes: 50 additions & 24 deletions sota_extractor2/models/linking/bm25_naive.py
@@ -9,7 +9,6 @@
import spacy
from scispacy.abbreviation import AbbreviationDetector
from sota_extractor2.models.linking.format import extract_value
from functools import total_ordering


@dataclass()
@@ -60,20 +59,19 @@ def __str__(self):
return f"{self.model_name}: {self.raw_value} on {self.dataset}"


-@total_ordering
-class MetricValue(Decimal):
+class MetricValue:
     value: Decimal
     unit: str = None

-    def __new__(cls, value, unit):
-        return super().__new__(cls, value / Decimal(100) if unit is '%' else value)

     def __init__(self, value, unit):
         self.value = value
         self.unit = unit

+    def to_unitless(self):
+        return self.value

     def to_absolute(self):
-        return Decimal(self)
+        return self.value / Decimal(100) if self.unit is '%' else self.value

# unit = None means that no unit was specified, so we have to guess the unit.
# if there's a value "21" in a table's cell, then we guess if it's 21 or 0.21 (i.e., 21%)
@@ -84,10 +82,13 @@ def to_percentage(self):
         return self.value

     def complement(self):
-        if self.unit is None and 1 < self.value < 100:
-            value = 100 - self.value
+        if self.unit is None:
+            if 1 < self.value < 100:
+                value = 100 - self.value
+            else:
+                value = 1 - self.value
         else:
-            value = 1 - self.value
+            value = 100 - self.value
         return MetricValue(value, self.unit)

def __repr__(self):
@@ -211,6 +212,30 @@ def handle_pm(value):
# %%


def convert_metric(raw_value, rng, complementary):
    format = "{x}"

    percentage = '%' in raw_value
    if percentage:
        format += '%'

    with localcontext() as ctx:
        ctx.traps[InvalidOperation] = 0
        parsed = extract_value(raw_value, format)
        parsed = MetricValue(parsed, '%' if percentage else None)

        if complementary:
            parsed = parsed.complement()
        if rng == '0-1':
            parsed = parsed.to_percentage() / 100
        elif rng == '1-100':
            parsed = parsed.to_percentage()
        elif rng == 'abs':
            parsed = parsed.to_absolute()
        else:
            parsed = parsed.to_unitless()
    return parsed

proposal_columns = ['dataset', 'metric', 'task', 'format', 'raw_value', 'model', 'model_type', 'cell_ext_id',
'confidence', 'parsed', 'struct_model_type', 'struct_dataset']

@@ -267,26 +292,27 @@ def linked_proposals(proposals):
         df = taxonomy_linking(prop.dataset, datasets, desc, topk=topk, debug_info=prop)
         for _, row in df.iterrows():
             raw_value = prop.raw_value
+            task = row['task']
+            dataset = row['dataset']
             metric = row['metric']

-            with localcontext() as ctx:
-                ctx.traps[InvalidOperation] = 0
-                parsed = extract_value(raw_value, format)
-                parsed = MetricValue(parsed, '%' if percentage else None)
+            complementary = False
+            if metric != row['true_metric']:
+                metric = row['true_metric']
+                complementary = True

-            if metric != row['true_metric']:
-                metric = row['true_metric']
-                parsed = parsed.complement()
+            # todo: pass taxonomy directly to proposals generation
+            ranges = taxonomy_linking.taxonomy.metrics_range
+            key = (task, dataset, metric)
+            rng = ranges.get(key, '')
+            if not rng: rng = ranges.get(metric, '')

-            if set(metric.lower().split()) & {"error", "accuracy", "bleu", "f1", "precision", "recall"}:
-                parsed = float(parsed.to_percentage() / 100)
-            else:
-                parsed = float(parsed.to_absolute())
+            parsed = float(convert_metric(raw_value, rng, complementary))

             linked = {
-                'dataset': row['dataset'],
+                'dataset': dataset,
                 'metric': metric,
-                'task': row['task'],
+                'task': task,
                 'format': format,
                 'raw_value': raw_value,
                 'model': prop.model_name,
@@ -305,7 +331,7 @@ def linked_proposals(proposals):
return proposals


-def linked_proposals(paper_ext_id, paper, annotated_tables, taxonomy_linking=MatchSearch(),
+def linked_proposals(paper_ext_id, paper, annotated_tables, taxonomy_linking=None,
                      dataset_extractor=None, topk=1):
# dataset_extractor=DatasetExtractor()):
proposals = []
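
For reference, a rough usage sketch of the new convert_metric helper added above. The expected values in the comments are inferred from the MetricValue logic in this diff and assume extract_value('21%', '{x}%') yields Decimal('21'), so treat them as illustrative rather than authoritative:

```
from sota_extractor2.models.linking.bm25_naive import convert_metric

# raw table cell -> value in the metric's expected range
print(convert_metric("21%", "0-1", complementary=False))    # expected: Decimal('0.21')
print(convert_metric("21%", "1-100", complementary=False))  # expected: Decimal('21')
print(convert_metric("0.21", "abs", complementary=False))   # expected: Decimal('0.21')
# complementary=True maps a score onto its complement (e.g. error -> accuracy)
print(convert_metric("21%", "1-100", complementary=True))   # expected: Decimal('79')
```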
72 changes: 58 additions & 14 deletions sota_extractor2/models/linking/context_search.py
@@ -116,8 +116,9 @@ def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb):

# compute log-probabilities in a given context and add them to logprobs
@njit
-def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task_p,
-                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs):
+def compute_logprobs(taxonomy, tasks, datasets, metrics,
+                     reverse_merged_p, reverse_metrics_p, reverse_task_p,
+                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs, axes_logprobs):
task_cache = typed.Dict.empty(types.unicode_type, types.float64)
dataset_cache = typed.Dict.empty(types.unicode_type, types.float64)
metric_cache = typed.Dict.empty(types.unicode_type, types.float64)
@@ -130,6 +131,21 @@ def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task
task_cache[task] = axis_logprobs(task, reverse_task_p, tss, ts_noise, ts_pb)

        logprobs[i] += dataset_cache[dataset] + metric_cache[metric] + task_cache[task]
    for i, task in enumerate(tasks):
        axes_logprobs[0][i] += task_cache[task]

    for i, dataset in enumerate(datasets):
        axes_logprobs[1][i] += dataset_cache[dataset]

    for i, metric in enumerate(metrics):
        axes_logprobs[2][i] += metric_cache[metric]


def _to_typed_list(iterable):
    l = typed.List()
    for i in iterable:
        l.append(i)
    return l


class ContextSearch:
@@ -145,9 +161,12 @@ def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.2, 0.1), met
self.queries = {}
self.taxonomy = taxonomy
self.evidence_finder = evidence_finder
-        self._taxonomy = typed.List()
-        for t in self.taxonomy.taxonomy:
-            self._taxonomy.append(t)
+        self._taxonomy = _to_typed_list(self.taxonomy.taxonomy)
+        self._taxonomy_tasks = _to_typed_list(self.taxonomy.tasks)
+        self._taxonomy_datasets = _to_typed_list(self.taxonomy.datasets)
+        self._taxonomy_metrics = _to_typed_list(self.taxonomy.metrics)

self.extract_acronyms = AcronymExtractor()
self.context_noise = context_noise
self.metrics_noise = metrics_noise if metrics_noise else context_noise
@@ -174,10 +193,10 @@ def _numba_extend_list(self, lst):
l.append(x)
return l

-    def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs):
+    def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs, axes_logprobs):
         context = context or ""
         abbrvs = self.extract_acronyms(context)
-        context = normalize_cell_ws(normalize_dataset(context))
+        context = normalize_cell_ws(normalize_dataset_ws(context))
dss = set(self.evidence_finder.find_datasets(context)) | set(abbrvs.keys())
mss = set(self.evidence_finder.find_metrics(context))
tss = set(self.evidence_finder.find_tasks(context))
@@ -191,21 +210,34 @@ def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs)
dss = self._numba_extend_list(dss)
mss = self._numba_extend_list(mss)
tss = self._numba_extend_list(tss)
-        compute_logprobs(self._taxonomy, self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
-                         dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb, logprobs)
+        compute_logprobs(self._taxonomy, self._taxonomy_tasks, self._taxonomy_datasets, self._taxonomy_metrics,
+                         self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
+                         dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb, logprobs,
+                         axes_logprobs)

     def match(self, contexts):
         assert len(contexts) == len(self.context_noise)
         n = len(self._taxonomy)
         context_logprobs = np.zeros(n)
+        axes_context_logprobs = _to_typed_list([
+            np.zeros(len(self._taxonomy_tasks)),
+            np.zeros(len(self._taxonomy_datasets)),
+            np.zeros(len(self._taxonomy_metrics)),
+        ])

         for context, noise, ms_noise, ts_noise in zip(contexts, self.context_noise, self.metrics_noise, self.task_noise):
-            self.compute_context_logprobs(context, noise, ms_noise, ts_noise, context_logprobs)
+            self.compute_context_logprobs(context, noise, ms_noise, ts_noise, context_logprobs, axes_context_logprobs)
         keys = self.taxonomy.taxonomy
         logprobs = context_logprobs
         #keys, logprobs = zip(*context_logprobs.items())
         probs = softmax(np.array(logprobs))
-        return zip(keys, probs)
+        axes_probs = [softmax(np.array(a)) for a in axes_context_logprobs]
+        return (
+            zip(keys, probs),
+            zip(self.taxonomy.tasks, axes_probs[0]),
+            zip(self.taxonomy.datasets, axes_probs[1]),
+            zip(self.taxonomy.metrics, axes_probs[2])
+        )

def __call__(self, query, datasets, caption, topk=1, debug_info=None):
cellstr = debug_info.cell.cell_ext_id
@@ -229,8 +261,10 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):
###print("Taking result from cache")
p = self.queries[key]
else:
-            dist = self.match((datasets, caption, query))
-            top_results = sorted(dist, key=lambda x: x[1], reverse=True)[:max(topk, 5)]
+            dists = self.match((datasets, caption, query))
+
+            all_top_results = [sorted(dist, key=lambda x: x[1], reverse=True)[:max(topk, 5)] for dist in dists]
+            top_results, top_results_t, top_results_d, top_results_m = all_top_results

entries = []
for it, prob in top_results:
@@ -239,6 +273,16 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):
entry.update({"evidence": "", "confidence": prob})
entries.append(entry)

# entries = []
# for i in range(5):
# best_independent = dict(
# task=top_results_t[i][0],
# dataset=top_results_d[i][0],
# metric=top_results_m[i][0])
# best_independent.update({"evidence": "", "confidence": top_results_t[i][1]})
# entries.append(best_independent)
#entries = [best_independent] + entries

# best, best_p = sorted(dist, key=lambda x: x[1], reverse=True)[0]
# entry = et[best]
# p = pd.DataFrame({k:[v] for k, v in entry.items()})
@@ -283,5 +327,5 @@ def from_paper(self, paper):
return self(text)

def __call__(self, text):
-        text = normalize_cell_ws(normalize_dataset(text))
+        text = normalize_cell_ws(normalize_dataset_ws(text))
return self.evidence_finder.find_datasets(text) | self.evidence_finder.find_tasks(text)
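
Since match now returns four result streams (full triples plus one per axis), a caller unpacks them roughly as in the __call__ code above. A hedged sketch, where cs, datasets_text, caption and query stand in for an already configured ContextSearch instance and its usual inputs:

```
# Sketch only: `cs`, `datasets_text`, `caption` and `query` are assumed inputs.
dists = cs.match((datasets_text, caption, query))

# one distribution over full (task, dataset, metric) triples, then one per axis
joint, by_task, by_dataset, by_metric = [
    sorted(dist, key=lambda x: x[1], reverse=True) for dist in dists
]

print(joint[0])       # best taxonomy triple with its probability
print(by_task[0])     # best task on its own
print(by_dataset[0])  # best dataset on its own
print(by_metric[0])   # best metric on its own
```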
2 changes: 1 addition & 1 deletion sota_extractor2/models/linking/taxonomy.py
@@ -66,5 +66,5 @@ def _read_metrics_info(self, path):
s[rng] = s.get(rng, 0) + 1
mr[metric] = s
for metric in mr:
-            metrics_range[metric] = sorted(mr[metric].items(), key=lambda x: x[1])[-1]
+            metrics_range[metric] = sorted(mr[metric].items(), key=lambda x: x[1])[-1][0]
return metrics_info, metrics_range
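
The taxonomy.py change is a small bug fix: sorted(...)[-1] is a (range, count) pair, so metrics_range used to store the whole tuple instead of just the range string. A toy illustration with made-up counts:

```
# made-up counts of range annotations seen for one metric
s = {'0-1': 3, '1-100': 7, 'abs': 1}

most_common = sorted(s.items(), key=lambda x: x[1])[-1]
print(most_common)     # ('1-100', 7) -- what the old code stored in metrics_range
print(most_common[0])  # '1-100'      -- the range string the new code stores
```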
