In [103]:
%load_ext autoreload
%autoreload

from copy import deepcopy
from typing import Dict, Optional

from src.common import read_data, QTClaim, QTDataset, save_data

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [104]:
def load_and_combine(split, categories=('statistical', 'interval', 'comparison', 'temporal')):
    """Load and concatenate the per-category extracted claims of one data split.

    For each category, reads ``custom_decomposition/{split}_extracted_{category}.json``
    via ``read_data`` and prints ``"{split} {category}: {count}"``. If ``read_data``
    raises ``FileNotFoundError``, it prints ``"{split} {category}: -"`` and skips that
    category. The combined total is printed at the end.

    Args:
        split: Split name used in the file names, e.g. ``'train'``, ``'val'`` or ``'test'``.
        categories: Category names to load, in this order. Defaults to the four
            categories this notebook has subquestion templates for.

    Returns:
        A list with the claims of every category file that was found, in category order.

    Raises:
        AssertionError: If a claim's ``taxonomy_label`` (ignoring surrounding
            whitespace) differs from the category of the file it was read from.
    """
    all_claims = []
    for category in categories:
        try:
            claims = read_data(f'custom_decomposition/{split}_extracted_{category}.json')
        except FileNotFoundError:
            # Not every split necessarily has every category; report and move on.
            print(f'{split} {category}: -')
            continue

        print(f'{split} {category}: {len(claims)}')

        # Sanity check: the file name's category must agree with each claim's label.
        for claim in claims:
            label = claim['taxonomy_label'].strip()
            assert label == category, (
                f'{split} {category}: unexpected taxonomy_label {label!r}'
            )

        all_claims += claims

    print(f'{len(all_claims)}\n')
    return all_claims


# Load each split's per-category extraction files into one flat list of claims
# (train, then val, then test, so the printed counts appear in that order).
train_claims, val_claims, test_claims = [
    load_and_combine(split_name) for split_name in ('train', 'val', 'test')
]

train statistical: 4660
train interval: 1541
train comparison: 1051
train temporal: 2683
9935

val statistical: 1432
val interval: 469
val comparison: 339
val temporal: 844
3084

test statistical: 1210
test interval: 347
test comparison: 255
test temporal: 683
2495



In [105]:
# One question template per taxonomy label. The ``{}`` placeholder receives the
# comma-separated numbers/dates text when a subquestion is built for a claim.
subquestions_templates = dict(
    statistical="What do the quantities {} mean in this claim? Are these quantities rigorously proven by the evidence?",
    interval="This claim contains a range or interval of numbers. What do the numbers {} mean in this claim? Are these numbers rigorously proven by the evidence?",
    comparison="This claim contains a comparison of numbers. What do the numbers {} mean in this claim? Are these numbers rigorously proven by the evidence?",
    temporal="What do the dates {} mean in this claim? Are these dates rigorously proven by the evidence?",
)


def num_dict_to_text(num_dict: Optional[Dict[str, str]]) -> Optional[str]:
    """Render a claim's extracted numbers or dates as one comma-separated string.

    If ``num_dict`` has a ``'Numbers'`` key, its value is used. Otherwise, if it has
    a ``'Time'`` key, every non-empty value in the dict is used. Duplicates are
    dropped, keeping the first occurrence, so the output is deterministic.

    Args:
        num_dict: Extraction result for one claim; may be ``None`` or empty.

    Returns:
        The joined values; this is ``''`` if there are no values to join. Returns
        ``None`` if ``num_dict`` is empty or ``None``, or has neither key.
    """
    if not num_dict:
        return None

    if 'Numbers' in num_dict:
        # NOTE(review): presumably an iterable of number strings despite the
        # Dict[str, str] hint -- confirm against the extractor's output.
        values = num_dict['Numbers']
    elif 'Time' in num_dict:
        # Temporal extraction: take every non-empty value, not only 'Time'.
        values = [val for val in num_dict.values() if val]
    else:
        return None

    # Order-preserving de-duplication. Iterating a ``set`` of strings yields a
    # per-process order (hash randomization), which made the generated
    # subquestions and the saved datasets differ between runs.
    return ', '.join(dict.fromkeys(values))


def add_numerical_subquestion(claim: QTClaim) -> QTClaim:
    """Append a numerical-verification subquestion to ``claim`` in place.

    The subquestion is the ``subquestions_templates`` entry for the claim's
    ``taxonomy_label`` (whitespace-stripped). It is filled with the text that
    ``num_dict_to_text`` produces from ``claim['extracted_nums']``. If that text is
    empty or ``None``, the claim is left unchanged.

    Re-applying the function (e.g. re-running the notebook cell on the same
    in-memory claims) does not add the same subquestion twice.

    Args:
        claim: Claim mapping with ``extracted_nums`` and ``taxonomy_label``.
            ``subquestions`` is created as an empty list if it is missing.

    Returns:
        The same ``claim`` object, on every path.

    Raises:
        KeyError: If ``extracted_nums`` or ``taxonomy_label`` is missing, or if the
            label has no entry in ``subquestions_templates``.
    """
    nums_text = num_dict_to_text(claim['extracted_nums'])
    if not nums_text:
        return claim

    template = subquestions_templates[claim['taxonomy_label'].strip()]
    subquestion = template.format(nums_text)

    subquestions = claim.setdefault('subquestions', [])
    # Guard against duplicates when the same claims are processed more than once.
    if subquestion not in subquestions:
        subquestions.append(subquestion)
    return claim


def add_subquestions(claims: QTDataset) -> None:
    """Add the numerical subquestion to every claim in ``claims``.

    Each claim is updated in place by ``add_numerical_subquestion``. Its return
    value is discarded, so this function returns ``None``.
    """
    for single_claim in claims:
        add_numerical_subquestion(single_claim)
    return None


# Add the numerical subquestion to every split (claims are mutated in place),
# then write each split to its own output file. All splits are processed
# before anything is saved.
claims_by_split = {'train': train_claims, 'val': val_claims, 'test': test_claims}

for split_claims in claims_by_split.values():
    add_subquestions(split_claims)

for split_name, split_claims in claims_by_split.items():
    save_data(f'custom_decomposition/{split_name}_evidences_decomposed_custom.json', split_claims)