# Wikidata

## Wikidata Simple Question

How do the coverage and the graph size vary under beam sizes 2, 5, 10, and 20?

In [1]:
import os

# Paths used by the Wikidata SimpleQuestions retrieval cells.
trained_model_dir = 'artifacts/models/wd_simple_t5_e10'
retrieve_subgraph_dir = 'artifacts/subgraphs/wikidata-simple-question-t5-e0'
intermediate_dir = 'data/wikidata-simplequestions/intermediate'

# Input file handed to `srtk retrieve` through its input option.
# NOTE(review): the variable name says "test" but the file is scores_valid.jsonl;
# confirm which split is intended.
test_scored_path = os.path.join(intermediate_dir, 'scores_valid.jsonl')

In [2]:
# Run `srtk retrieve` against the local Wikidata SPARQL endpoint once per beam
# width in the list (currently only 1), with a maximum depth of 1.
# NOTE(review): this cell passes the scorer `drt/wikidata-simplequestions`, whereas
# the following single-run cell passes $trained_model_dir; confirm this is intended.
for beam_width in [1]:
    # One output file per beam width: e5-b-<beam_width>.jsonl under retrieve_subgraph_dir.
    retrieved_subgraph_path = os.path.join(retrieve_subgraph_dir, f'e5-b-{beam_width}.jsonl')
    # `$name` is expanded by IPython to the value of the Python variable before
    # the shell command runs. The saved output of this cell ends with an
    # "Answer recall" line, presumably produced by --evaluate.
    !srtk retrieve -i $test_scored_path \
        -o $retrieved_subgraph_path \
        -e http://localhost:1234/api/endpoint/sparql \
        -kg wikidata \
        --scorer-model-path drt/wikidata-simplequestions \
        --beam-width $beam_width \
        --max-depth 1 \
        --evaluate

Retrieving subgraphs: 100%|█████████████████| 2210/2210 [03:30<00:00, 10.48it/s]
Retrieved subgraphs saved to to artifacts/subgraphs/wikidata-simple-question-t5-e0/e5-b-1.jsonl
Answer recall: 0.9638009049773756 (2130 / 2210)


In [3]:
# Single retrieval run on Wikidata with beam width 40 and maximum depth 1,
# scored with the model stored in trained_model_dir.
# NOTE(review): the saved output under this cell reports a file named
# e5-b-30.jsonl in a different directory, so it appears to come from an
# earlier run with other settings; re-run the cell to refresh it.
beam_width = 40
# Output file: e5-b-40.jsonl under retrieve_subgraph_dir.
retrieved_subgraph_path = os.path.join(retrieve_subgraph_dir, f'e5-b-{beam_width}.jsonl')
# `$name` is expanded by IPython to the value of the Python variable before
# the shell command runs.
!srtk retrieve -i $test_scored_path \
    -o $retrieved_subgraph_path \
    -e http://localhost:1234/api/endpoint/sparql \
    -kg wikidata \
    --scorer-model-path $trained_model_dir \
    --beam-width $beam_width \
    --max-depth 1 \
    --evaluate

Retrieving subgraphs: 100%|█████████████████| 4296/4296 [03:08<00:00, 22.73it/s]
Retrieved subgraphs saved to to artifacts/subgraphs/wikidata-simple-question/e5-b-30.jsonl
Answer recall: 0.9753258845437617 (4190 / 4296)


In [14]:
import srsly
import statistics

In [16]:
# Summarise the Wikidata SimpleQuestions retrieval runs: for every beam width,
# read the saved subgraphs and the matching metric file, then collect the
# recall and the mean number of triplets per subgraph.
wikidata_sq_subgraph_path_pattern = 'artifacts/subgraphs/wikidata-simple-question-e10/e5-b-{beam}.jsonl'
metric_path_pattern = 'artifacts/subgraphs/wikidata-simple-question-e10/e5-b-{beam}.metric'

recalls = []
subgraph_means = []
for beam in (1, 2, 5, 10, 20):
    # The metric file is parsed as JSON; only its 'recall' entry is used.
    metric = srsly.read_json(metric_path_pattern.format(beam=beam))
    # One count per JSONL record: the length of that record's 'triplets' field.
    triplet_counts = [
        len(record['triplets'])
        for record in srsly.read_jsonl(wikidata_sq_subgraph_path_pattern.format(beam=beam))
    ]
    mean_triplets = statistics.mean(triplet_counts)
    recalls.append(metric['recall'])
    subgraph_means.append(mean_triplets)

print(recalls)
print(subgraph_means)

[0.963800905, 0.9728506787000001, 0.9760180995000001, 0.9796380090000001, 0.9805429864]
[1.616289592760181, 2.4624434389140273, 6.160633484162896, 12.227149321266968, 18.228054298642533]


# Freebase

## WebQSP

In [8]:
# Paths used by the WebQSP (Freebase) retrieval cells.
webqsp_model_dir = 'artifacts/models/webqsp'
# NOTE(review): this reassigns retrieve_subgraph_dir, which earlier held the
# Wikidata output directory; cells run after this one see the WebQSP value.
retrieve_subgraph_dir = 'artifacts/subgraphs/webqsp'
formatted_data_dir = 'data/webqsp/formatted'

# Test split handed to `srtk retrieve` as its input.
formatted_test_path = os.path.join(formatted_data_dir, 'test.jsonl')

In [6]:
# Run `srtk retrieve` on the WebQSP test split against the local Freebase
# SPARQL endpoint, once for each beam width, with a maximum depth of 2.
# NOTE(review): the saved output of this cell stops at 8% of the progress bar,
# so the recorded run looks incomplete.
for beam_width in [1, 2, 5, 10, 20]:
    # One output file per beam width: e5-b<beam_width>.jsonl (no hyphen before
    # the number) under retrieve_subgraph_dir.
    retrieve_subgraph_path = os.path.join(retrieve_subgraph_dir, f'e5-b{beam_width}.jsonl')
    # `$name` is expanded by IPython to the value of the Python variable before
    # the shell command runs.
    !srtk retrieve --input $formatted_test_path \
        --output $retrieve_subgraph_path \
        --sparql-endpoint http://localhost:3001/sparql \
        --knowledge-graph freebase \
        --scorer-model-path $webqsp_model_dir \
        --beam-width $beam_width \
        --max-depth 2 \
        --evaluate

Retrieving subgraphs:   8%|█▌                | 133/1582 [01:36<08:38,  2.79it/s]

In [17]:
# Summarise the WebQSP retrieval runs: for every beam width, read the saved
# subgraphs and the matching metric file, then collect the recall and the
# mean number of triplets per subgraph.
# NOTE(review): the pattern variable carries a Wikidata name although it points
# at WebQSP files; the name is kept so other cells that reference it still work.
wikidata_sq_subgraph_path_pattern = 'artifacts/subgraphs/webqsp/e5-b{beam}.jsonl'
metric_path_pattern = 'artifacts/subgraphs/webqsp/e5-b{beam}.metric'

recalls = []
subgraph_means = []
for beam in (1, 2, 5, 10, 20):
    # The metric file is parsed as JSON; only its 'recall' entry is used.
    metric = srsly.read_json(metric_path_pattern.format(beam=beam))
    # One count per JSONL record: the length of that record's 'triplets' field.
    triplet_counts = [
        len(record['triplets'])
        for record in srsly.read_jsonl(wikidata_sq_subgraph_path_pattern.format(beam=beam))
    ]
    mean_triplets = statistics.mean(triplet_counts)
    recalls.append(metric['recall'])
    subgraph_means.append(mean_triplets)

print(recalls)
print(subgraph_means)

[0.5619469027, 0.8343868521000001, 0.8609355247, 0.9121365360000001, 0.9348925411000001]
[11.38621997471555, 62.09924146649811, 66.23451327433628, 150.12010113780025, 298.5575221238938]


## CWQ

# Calculate involved entities and relations

In [2]:
import srsly
# srsly.read_jsonl yields records lazily, so the object it returns can be
# iterated only once. Materialise each split as a list so that cells which
# loop over these variables produce the same result when they are re-run.
webqsp_train = list(srsly.read_jsonl("data/webqsp/formatted/train.jsonl"))
webqsp_test = list(srsly.read_jsonl("data/webqsp/formatted/test.jsonl"))

In [3]:
# Collect every distinct entity that appears in the WebQSP train and test
# splits, taking both the question entities and the answer entities of each
# record, then report how many there are.
entity_set = {
    entity
    for split in (webqsp_train, webqsp_test)
    for example in split
    for field in ('question_entities', 'answer_entities')
    for entity in example[field]
}
print(len(entity_set))

39973
